I recently learned about the so-called Gumbel-Max and Gumbel-Softmax tricks. Essentially, the Gumbel-Max trick says that if we have a categorical distribution $\pi = (\pi_1, \ldots, \pi_n)$ and i.i.d. $\mathrm{Gumbel}(0,1)$-distributed random variables $G_1, \ldots, G_n$, then
$$\arg\max_k \,(\log \pi_k + G_k) \sim \mathrm{Categorical}(\pi).$$
The Gumbel-Softmax trick is then a continuous relaxation of the above which uses the fact that letting the temperature of a softmax go to 0 gives the (one-hot encoding of the) argmax function.
Note that $\mathrm{Gumbel}(0,1)$ is the distribution of $-\log(-\log U)$, where $U$ is uniformly distributed on $(0,1)$.
So, instead of sampling from $\pi$ directly, we can sample $n$ uniforms, compute their double logarithms, add the log-probs (actually, it's enough to add the log-potentials but I'll get to that) and take the argmax. This seems like more work than simply sampling from $\pi$ (which is normally done by sampling a single uniform and checking between which two consecutive partial sums $\pi_1 + \dots + \pi_k$ it lies). And indeed it is, if we want just one sample from $\pi$. But this method of sampling has other benefits, like the ability to pre-compute the Gumbel samples, to avoid the normalisation of $\pi$ at every forward pass, or to perform reservoir sampling. But I don't want to focus on these comparisons now (if you're interested, I found some nice concise explanations in this blogpost).
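To make the comparison concrete, here is a minimal Julia sketch (my own, not from the original post; the example distribution `pi_` and the helper names are purely illustrative) of both sampling routes:

```julia
using Random

pi_ = [0.2, 0.5, 0.3]            # example categorical distribution

# Direct route: one uniform, locate it among the cumulative sums of pi_.
direct_sample(p) = searchsortedfirst(cumsum(p), rand())

# Gumbel-Max route: n uniforms, double log, add log-probs, take the argmax.
gumbel_max_sample(p) = argmax(log.(p) .- log.(-log.(rand(length(p)))))

function freqs(sampler, p; n = 100_000)
    samples = [sampler(p) for _ in 1:n]
    [count(==(k), samples) / n for k in eachindex(p)]
end

println(freqs(direct_sample, pi_))       # both should be close to [0.2, 0.5, 0.3]
println(freqs(gumbel_max_sample, pi_))
```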
These tricks are really neat, but I wondered how people came up with them. So I tried to reverse-engineer the Gumbel-Max trick and then (in a future post) write down some of the things I've learned about softmax and temperature annealing. All of this is well known to people who do probability theory for a living, but I thought a couple of blogposts might be useful for me to put my thoughts in order, and potentially for readers to follow along, so here goes.
Reverse-Engineering Gumbel-Max
Suppose we are given a categorical distribution $\pi = (\pi_1, \ldots, \pi_n)$, with $\pi_k > 0$ and $\sum_k \pi_k = 1$.
We want to find independent, continuous, real-valued random variables $X_1, \ldots, X_n$ which satisfy the condition
$$\mathbb{P}\big(\arg\max_j X_j = k\big) = \pi_k \quad \text{for every } k.$$
One quick observation we can make is that if $X_1, \ldots, X_n$ satisfy this condition and $h$ is any strictly increasing function, then $h(X_1), \ldots, h(X_n)$ will also satisfy it (applying the same increasing $h$ to every variable doesn't change which one is largest). So we already see this is a really flexible condition. But anyway, let's actually find some $X_k$'s that do what we want.
Let $F_k$ denote the CDF of $X_k$, that is $F_k(x) = \mathbb{P}(X_k \le x)$, and let $f_k = F_k'$ denote the pdf (I will deliberately ignore mathematical subtleties around differentiability, continuity etc. in this post). Then, the above equations can be rewritten as:
$$\pi_k = \mathbb{P}\big(X_k > X_j \ \text{for all } j \ne k\big) = \int_{-\infty}^{\infty} f_k(x)\, \mathbb{P}\big(X_j < x \ \text{for all } j \ne k\big)\, dx = \int_{-\infty}^{\infty} f_k(x) \prod_{j \ne k} F_j(x)\, dx,$$
where in the last equality we used the independence assumption.
Notice that on the right-hand side, we are integrating a positive function of $x$ from $-\infty$ to $\infty$ and getting the number $\pi_k$ as a result. If instead we integrated from $-\infty$ to some $t$, we would get $\pi_k H_k(t)$, where $H_k(t)$ is some number between $0$ and $1$ which increases with $t$ (by the positivity of the integrand), approaches $0$ as $t \to -\infty$ and approaches $1$ as $t \to \infty$. In other words, the function $H_k$ is the CDF of some continuous distribution.
At this point, in order to find concrete solutions to the above set of equations, we make the simplifying assumption that all these CDFs are equal, i.e. $H_1 = H_2 = \cdots = H_n =: H$. So we get the following integral equations:
$$\pi_k\, H(t) = \int_{-\infty}^{t} f_k(x) \prod_{j \ne k} F_j(x)\, dx \quad \text{for } k = 1, \ldots, n.$$
Now, summing these over $k$, we obtain:
$$H(t) = \sum_{k} \pi_k\, H(t) = \sum_{k} \int_{-\infty}^{t} f_k(x) \prod_{j \ne k} F_j(x)\, dx = \int_{-\infty}^{t} \frac{d}{dx}\Big(\prod_{j} F_j(x)\Big)\, dx$$
(the last equality is just the product rule run backwards),
from which we obtain $H(t) = \prod_{j} F_j(t)$ (we can see that the constant of integration is $0$ by letting $t \to -\infty$ and using the fact that all functions involved are CDFs).
On the other hand, differentiating both sides of the integral equations with respect to $t$ gives the functional equations:
$$\pi_k\, H'(t) = f_k(t) \prod_{j \ne k} F_j(t).$$
Substituting the fact that $\prod_j F_j(t) = H(t)$ into these, we obtain:
$$\pi_k\, H'(t) = f_k(t)\, \frac{H(t)}{F_k(t)}, \quad \text{i.e.} \quad \pi_k\, \frac{d}{dt} \log H(t) = \frac{d}{dt} \log F_k(t).$$
Integrating both sides we find
$$\log F_k(t) = \pi_k \log H(t)$$
(the constant of integration is again $0$, by considerations at $t \to \infty$, where both sides tend to $0$), and so we have:
$$F_k(t) = H(t)^{\pi_k}.$$
Thus, for any choice of CDF $H$, setting $F_k = H^{\pi_k}$ gives us independent random variables $X_1, \ldots, X_n$ which satisfy our requirement.
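As a quick check (my own aside, not part of the original derivation), plugging these CDFs back into the earlier integral confirms the claim: with $f_k = \pi_k H^{\pi_k - 1} H'$ and $\prod_{j \ne k} F_j = H^{\sum_{j \ne k} \pi_j} = H^{1 - \pi_k}$,
$$\mathbb{P}\big(\arg\max_j X_j = k\big) = \int_{-\infty}^{\infty} \pi_k\, H(x)^{\pi_k - 1} H'(x)\, H(x)^{1 - \pi_k}\, dx = \pi_k \int_{-\infty}^{\infty} H'(x)\, dx = \pi_k.$$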
We can now use Inverse Transform Sampling to sample from these distributions. That is, we know that if
$$U_k \sim \mathrm{Uniform}(0,1) \quad \text{and} \quad X_k = F_k^{-1}(U_k) = H^{-1}\big(U_k^{1/\pi_k}\big),$$
then $X_k$ is distributed according to $F_k$.
But, actually, notice that we are really free in our choice of $H^{-1}$: all we need is for it to be a strictly increasing function defined on the unit interval! So, we can stop worrying about $F_k$ or $H$ and we can just say:
Given a categorical distribution $\pi = (\pi_1, \ldots, \pi_n)$, a strictly increasing function $h : (0,1) \to \mathbb{R}$, and i.i.d. $U_1, \ldots, U_n \sim \mathrm{Uniform}(0,1)$, the random variables $X_k = h\big(U_k^{1/\pi_k}\big)$ satisfy
$$\mathbb{P}\big(\arg\max_j X_j = k\big) = \pi_k.$$
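Here is a small Julia sketch (my own; the choice of $h$ as the logistic quantile and the example distribution are arbitrary) checking this general statement by Monte Carlo:

```julia
using Random

pi_ = [0.1, 0.3, 0.6]
h(u) = log(u / (1 - u))      # any strictly increasing function on (0, 1) should work

# One sample: raise each uniform to the power 1/pi_k, apply h, take the argmax.
sample_index(p) = argmax(h.(rand(length(p)) .^ (1 ./ p)))

samples = [sample_index(pi_) for _ in 1:100_000]
println([count(==(k), samples) / length(samples) for k in eachindex(pi_)])
# should be close to [0.1, 0.3, 0.6]
```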
One obvious choice is to take $h$ to be the identity function, giving us
$$X_k = U_k^{1/\pi_k}.$$
Another option is to take $h = \log$, giving us
$$X_k = \frac{\log U_k}{\pi_k}.$$
Arguably, the most popular choice is to take $h(u) = -\log(-\log u)$, which gives:
$$X_k = -\log\big(-\log U_k^{1/\pi_k}\big) = -\log(-\log U_k) + \log \pi_k = G_k + \log \pi_k,$$
where $G_k = -\log(-\log U_k)$ is $\mathrm{Gumbel}(0,1)$-distributed.
And so, we've "rediscovered" the Gumbel-Max trick... almost. The last piece is to notice that the function $h$ can itself depend on $\pi$, and this can be quite useful if, for example, the distribution $\pi$ is computed by some neural network in an unnormalised form. That is, suppose our network outputs the log-potentials $\ell_k$, and to compute the probabilities $\pi_k = e^{\ell_k}/Z$ we need to know the normalising constant $Z = \sum_j e^{\ell_j}$. The cool thing is, we don't actually need to compute the normalising constant to generate variables with the desired property. For example, we could set $h(u) = -\log(-\log u) + L$, where $L$ is the "LogSumExp" of the log-potentials, i.e. $L = \log \sum_j e^{\ell_j} = \log Z$. Then we get
$$X_k = -\log(-\log U_k) + \log \pi_k + \log Z = G_k + \ell_k,$$
which is the more general form of the Gumbel-Max trick: it allows us to sample from a categorical distribution knowing only its log-potentials.
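A minimal sketch of this log-potential version (mine, not the author's code; the `logits` vector is an arbitrary example, and the softmax is computed only to verify the empirical frequencies):

```julia
using Random

logits = [1.2, -0.3, 0.5, 2.0]                 # unnormalised log-potentials
softmax(l) = exp.(l .- maximum(l)) ./ sum(exp.(l .- maximum(l)))

# One sample: add a Gumbel(0,1) draw to each log-potential and take the argmax;
# no normalising constant is ever computed in the sampler itself.
sample_index(l) = argmax(l .- log.(-log.(rand(length(l)))))

samples = [sample_index(logits) for _ in 1:100_000]
println(round.([count(==(k), samples) / length(samples) for k in eachindex(logits)], digits = 3))
println(round.(softmax(logits), digits = 3))   # the two lines should roughly match
```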
I’ll end with some Julia code showing a few simulations which I used as a sanity check that it all works. In the next post I’ll talk about the Gumbel-Softmax trick and how I think about temperature annealing. See you then!
```julia
import Plots, Random
import Pkg; Pkg.add("StatsPlots"); import StatsPlots

# Functions that operate on u and pi
f_1 = (u, pi) -> u .^ (1 ./ pi)                  # identity
f_2 = (u, pi) -> log.(u) ./ pi                   # log
f_3 = (u, pi) -> -log.(-log.(u)) .+ log.(pi)     # Gumbel

# Functions that only use the log-potentials
g_1 = (u, logit) -> u .^ (1 ./ exp.(logit))      # identity
g_2 = (u, logit) -> log.(u) ./ exp.(logit)       # log
g_3 = (u, logit) -> -log.(-log.(u)) .+ logit     # Gumbel
```
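For a concrete sanity check along those lines, here is a small driver (a sketch of my own, assuming the definitions above have been run; the example distribution, the sample count, and the idea of plotting with `StatsPlots.groupedbar` are my choices, not necessarily the author's original simulation):

```julia
# Sanity check: empirical argmax frequencies vs. the target distribution.
# Assumes f_1, f_2, f_3, g_1, g_2, g_3 from above are already defined.
p     = [0.1, 0.2, 0.3, 0.4]          # example target distribution
logit = log.(p) .+ 2.0                # unnormalised log-potentials (arbitrary shift)
n     = 100_000

function empirical_freqs(fn, params, n)
    counts = zeros(Int, length(params))
    for _ in 1:n
        counts[argmax(fn(rand(length(params)), params))] += 1
    end
    counts ./ n
end

for (name, fn) in [("identity", f_1), ("log", f_2), ("Gumbel", f_3)]
    println(name, " (probs):  ", round.(empirical_freqs(fn, p, n), digits = 3))
end
for (name, fn) in [("identity", g_1), ("log", g_2), ("Gumbel", g_3)]
    println(name, " (logits): ", round.(empirical_freqs(fn, logit, n), digits = 3))
end
# All six rows should be close to p; StatsPlots.groupedbar can then compare them against p visually.
```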