Reverse-Engineering the Gumbel-Max Trick

Intro

I recently learned about the so-called Gumbel-Max and Gumbel-Softmax tricks. Essentially, the Gumbel-Max trick says that if we have a categorical distribution $\pi = (\pi_1, \ldots, \pi_K)$ and i.i.d. $\mathrm{Gumbel}(0,1)$-distributed random variables $G_i$, $1 \le i \le K$, then

$$\forall k \quad P\big(G_k + \log(\pi_k) = \max\{G_i + \log(\pi_i) : 1 \le i \le K\}\big) = \pi_k.$$

The Gumbel-Softmax trick is then a continuous relaxation of the above, which uses the fact that letting the temperature of a softmax go to 0 gives the (one-hot encoding of the) argmax function.

Note that $\mathrm{Gumbel}(0,1)$ is the distribution of $-\log(-\log(u))$, where $u \sim U([0,1])$ is uniformly distributed. So, instead of sampling from $\pi$ directly, we can sample $K$ uniforms, compute their double logarithms, add the log-probs (actually, it's enough to add the log-potentials, but I'll get to that) and take the argmax. This seems like more work than simply sampling from $\pi$ (which is normally done by sampling a single $u \sim U(0,1)$ and checking between which two numbers in $\mathrm{cumsum}(\pi)$ it lies). And indeed it is, if we want just one sample from $\pi$. But this method of sampling has other benefits, like the ability to pre-compute the Gumbel samples, to avoid normalising $\pi$ at every forward pass, or to perform reservoir sampling. But I don't want to focus on these comparisons now (if you're interested, I found some nice concise explanations in this blogpost).
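To make the comparison concrete, here is a minimal sketch of the two sampling routines side by side (the probability vector, the sample count, and the helper names sample_cdf and sample_gumbel are all made up for illustration):

probs = [0.1, 0.2, 0.3, 0.4]                     # an arbitrary categorical distribution

# Standard approach: one uniform, located within the CDF
sample_cdf(p) = searchsortedfirst(cumsum(p), rand())

# Gumbel-Max approach: K uniforms, double (negative) logs, add log-probs, take argmax
sample_gumbel(p) = argmax(-log.(-log.(rand(length(p)))) .+ log.(p))

n = 100_000
samples_cdf    = [sample_cdf(probs)    for _ in 1:n]
samples_gumbel = [sample_gumbel(probs) for _ in 1:n]
# Both empirical frequency vectors should be close to probs
freq_cdf    = [count(==(k), samples_cdf)    / n for k in 1:length(probs)]
freq_gumbel = [count(==(k), samples_gumbel) / n for k in 1:length(probs)]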

These tricks are really neat, but I wondered how people came up with them. So I tried to reverse-engineer the Gumbel-Max trick and then (in a future post) write down some of the things I've learned about softmax and temperature annealing. All these things are well-known to most people who do probability theory for a living, but I thought a couple of blogposts might be useful for me to put my thoughts in order and potentially for readers to follow along, so here goes.

Reverse-Engineering Gumbel-Max

Suppose we are given a categorical distribution $\pi = (\pi_1, \ldots, \pi_K)$. We want to find independent, continuous, real-valued random variables $X = (X_1, \ldots, X_K)$ which satisfy the condition

$$\forall k \quad P\big(X_k = \max\{X_i : 1 \le i \le K\}\big) = \pi_k.$$

One quick observation we can make is that if $\{X_1, \ldots, X_K\}$ satisfy this condition and $f$ is any strictly increasing function, then $\{f(X_1), \ldots, f(X_K)\}$ will also satisfy it. So we already see this is a really flexible condition. But anyway, let's actually find some $X$'s that do what we want.

Let $F_i$ denote the CDF of $X_i$, that is $F_i(t) = P(X_i \le t)$, and let $f_i = F_i'$ denote the pdf (I will deliberately ignore mathematical subtleties around differentiability, continuity, etc. in this post). Then the above equations can be rewritten as:

$$\pi_k = P\big(X_k = \max\{X_i : 1 \le i \le K\}\big) = \int_{\mathbb{R}} P\big(s \ge X_1, \ldots, s \ge X_K \,\big|\, X_k = s\big)\, f_k(s)\, ds = \int_{\mathbb{R}} f_k(s) \prod_{i \ne k} F_i(s)\, ds,$$

where in the last equality we used the independence assumption.

Notice that on the right-hand side we are integrating a positive function of $s$ from $-\infty$ to $\infty$ and getting the number $\pi_k$ as a result. If instead we integrated from $-\infty$ to some $t \in \mathbb{R}$, we would get $H_k(t)\,\pi_k$, where $H_k(t)$ is some number between 0 and 1 which increases with $t$ (by the positivity of the integrand), approaches 0 as $t \to -\infty$ and approaches 1 as $t \to \infty$. In other words, the function $H_k$ is the CDF of some continuous distribution.

At this point, in order to find concrete solutions to the above set of $K$ equations, we make the simplifying assumption that all these CDFs are equal, i.e. $H_1 = H_2 = \cdots = H_K =: H$. So we get the following integral equations:

$$\forall k \quad \int_{-\infty}^{t} \prod_{i \ne k} F_i(s)\, f_k(s)\, ds = H(t)\,\pi_k.$$

Now, summing these over $1 \le k \le K$ (the integrands sum to $\frac{d}{ds}\prod_i F_i(s)$ by the product rule, and the $\pi_k$ sum to 1), we obtain:

$$\int_{-\infty}^{t} \frac{d}{ds} \prod_{1 \le i \le K} F_i(s)\, ds = H(t),$$

from which we obtain $\prod_{1 \le i \le K} F_i(t) = H(t)$ (we can see that the constant of integration is 0 by letting $t \to -\infty$ and using the fact that all functions involved are CDFs).

On the other hand, differentiating both sides of the integral equations with respect to $t$ gives the functional equations:

$$\forall k \quad \prod_{i \ne k} F_i(t)\, f_k(t) = h(t)\,\pi_k,$$

where $h = H'$ is the density corresponding to $H$.

Substituting the fact that $\prod_{1 \le i \le K} F_i(t) = H(t)$ into these, we obtain:

$$\frac{H(t)}{F_k(t)}\, f_k(t) = h(t)\,\pi_k \;\iff\; \frac{f_k(t)}{F_k(t)} = \frac{h(t)}{H(t)}\,\pi_k \;\iff\; \frac{d}{dt}\log\big(F_k(t)\big) = \frac{d}{dt}\log\big(H(t)\big)\,\pi_k.$$

Integrating both sides, we find $\log(F_k(t)) = \log(H(t))\,\pi_k$ for all $k$ (the constant of integration is again 0, by considerations at $+\infty$), and so we have:

$$\forall\, 1 \le k \le K \quad F_k(t) = H(t)^{\pi_k}.$$

Thus, for any choice of CDF $H$, we obtain independent random variables $X_k \sim F_k = H^{\pi_k}$ which satisfy our requirement.

We can now use Inverse Transform Sampling to sample from these distributions. That is, we know that if $U_k \sim U([0,1])$ then $F_k^{-1}(U_k) = H^{-1}\big(U_k^{1/\pi_k}\big)$ is distributed according to $F_k$.
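For instance, here is a small sketch of this recipe with one concrete (and otherwise arbitrary) choice of $H$: the logistic CDF $H(t) = 1/(1+e^{-t})$, so that $H^{-1}(u) = \log\big(\tfrac{u}{1-u}\big)$ and $X_k = H^{-1}\big(U_k^{1/\pi_k}\big)$. The probability vector and the helper names Hinv and sample_logistic are made up for illustration; the empirical argmax frequencies should come out close to $\pi$.

probs = [0.1, 0.2, 0.3, 0.4]                        # an arbitrary categorical distribution
Hinv(u) = log(u / (1 - u))                          # inverse of the logistic CDF H(t) = 1/(1+exp(-t))
sample_logistic(p) = argmax(Hinv.(rand(length(p)) .^ (1 ./ p)))

n = 100_000
samples = [sample_logistic(probs) for _ in 1:n]
freqs = [count(==(k), samples) / n for k in 1:length(probs)]   # should be close to probs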

But, actually, notice that we are really free in our choice of $H^{-1}$: all we need is for it to be a strictly increasing function defined on the unit interval! So, we can stop worrying about $H$ or $h$ and we can just say:

Given a categorical distribution $\pi = (\pi_1, \ldots, \pi_K)$, a strictly increasing function $f : (0,1) \to \mathbb{R}$, and $U_k \sim U((0,1))$, $1 \le k \le K$, the random variables $X_k := f\big(U_k^{1/\pi_k}\big)$ satisfy $P(\operatorname{argmax}(X) = k) = \pi_k$ for all $1 \le k \le K$.

One obvious choice is to take $f$ to be the identity function, giving us $X_k = U_k^{1/\pi_k}$.

Another option is to take $f = \log$, giving us $X_k = \log(U_k)/\pi_k$. Arguably, the most popular choice is to take $f(t) = -\log(-\log(t))$, which gives:

$$X_k = -\log\big(-\log\big(U_k^{1/\pi_k}\big)\big) = -\log\left(\frac{-\log(U_k)}{\pi_k}\right) = -\log(-\log(U_k)) + \log(\pi_k) = G_k + \log(\pi_k),$$

where $G_k \sim \mathrm{Gumbel}(0,1)$.

And so, we've "rediscovered" the Gumbel-Max trick… almost. The last piece is to notice that the function $f$ can itself depend on $\pi$, and this can be quite useful if, for example, the distribution $\pi$ is computed by some neural network in an unnormalised form. That is, suppose our network outputs the log-potentials $\ell = (\ell_1, \ldots, \ell_K) \in \mathbb{R}^K$, and to compute the probabilities $\pi_k \propto \exp(\ell_k)$ we need to know the normalising constant $Z(\ell) = \sum_{i=1}^{K} \exp(\ell_i)$. The cool thing is, we don't actually need to compute the normalising constant to generate variables $X_k$ with the desired property. For example, we could set $f(t) = -\log(-\log(t)) + \mathrm{LSE}(\ell)$, where LSE is "LogSumExp", i.e. $\mathrm{LSE}(\ell) := \log\big(\sum_{i=1}^{K} \exp(\ell_i)\big)$. Then we get $X_k = G_k + \log(\pi_k) + \mathrm{LSE}(\ell) = G_k + \ell_k$, which is the more general Gumbel-Max trick that allows us to sample from a categorical distribution knowing only its log-potentials.
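In code, this version of the trick is essentially a one-liner (the helper name sample_gumbel_max is made up; it does the same thing as g_3 in the snippet below):

# Sample a class index given only the unnormalised log-potentials
sample_gumbel_max(logits) = argmax(logits .+ (-log.(-log.(rand(length(logits))))))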

I’ll end with some Julia code showing a few simulations which I used as a sanity check that it all works. In the next post I’ll talk about the Gumbel-Softmax trick and how I think about temperature annealing. See you then!

import Pkg; Pkg.add("StatsPlots")   # StatsPlots provides the grouped bar chart used below
import Random
import StatsPlots
# Functions that operate on u and pi
f_1 = (u, pi) -> u .^ (1 ./ pi) # identity
f_2 = (u, pi) -> log.(u) ./ pi # log
f_3 = (u, pi) -> -log.(-log.(u)) .+ log.(pi) # Gumbel

# Functions that only use the log-potentials ℓ
g_1 = (u, logit) -> u .^ (1 ./ exp.(logit)) # identity
g_2 = (u, logit) -> log.(u) ./ exp.(logit) # log
g_3 = (u, logit) -> -log.(-log.(u)) .+ logit # Gumbel
K = 12
logits = randn(Float64, (1, K))                     # log-potentials ℓ
unnormalized_pi = exp.(logits)
probs = unnormalized_pi ./ sum(unnormalized_pi)     # normalised probabilities π (named probs to avoid shadowing Base.pi)
# Empirical distribution of the row-wise argmaxes: u is a reps×K matrix of uniforms,
# params is a 1×K row of either probabilities π or log-potentials ℓ
function get_empirical_dist_from_argmaxes(f, u, params)
    x = f.(u, params)
    argmaxes = mapslices(argmax, x, dims=2)
    counts = [0 for _ in params]
    for val in argmaxes
        counts[val] += 1
    end
    return counts ./ sum(counts)
end
reps = 10000
dists = Dict("π" => probs)
for (name_f, name_g, f, g) in zip(["f₁", "f₂", "f₃"], ["g₁", "g₂", "g₃"], [f_1, f_2, f_3], [g_1, g_2, g_3])
    dists[name_f] = get_empirical_dist_from_argmaxes(f, Random.rand(Float64, (reps, K)), probs)
    dists[name_g] = get_empirical_dist_from_argmaxes(g, Random.rand(Float64, (reps, K)), logits)
end
names_to_plot = ["π", "f₁", "f₂", "f₃", "g₁", "g₂", "g₃"]
dists_to_plot = vcat([dists[name] for name in names_to_plot]...)
bar_chart = StatsPlots.groupedbar(
    vec(dists_to_plot),
    group=repeat(names_to_plot, outer=K),
    xlabel="class"
)