Reverse engineering the Gumbel Max Trick
Intro
I recently learned about the so-called Gumbel-Max and Gumbel-Softmax tricks. Essentially, the Gumbel-Max trick says that if we have a categorical distribution $\vec{\pi} = (\pi_1, \ldots, \pi_K)$ and i.i.d. $\mathrm{Gumbel}(0, 1)$-distributed random variables $G_i, \; 1\le i\le K$, then $$ \forall k \quad \mathbb{P}(G_k + \log(\pi_k) = \max\{G_i + \log(\pi_i) \colon 1 \le i \le K\}) = \pi_k. $$ The Gumbel-Softmax trick is then a continuous relaxation of the above, which uses the fact that letting the temperature of a softmax go to $0$ gives the (one-hot encoding of the) argmax function.
Note that $\mathrm{Gumbel}(0, 1)$ is the distribution of $-\log(-\log(u))$, where $u \sim \mathcal{U}([0, 1])$ is uniformly distributed. So, instead of sampling from $\vec{\pi}$ directly, we can sample $K$ uniforms $u_1, \ldots, u_K$, compute $-\log(-\log(u_k))$ for each, add the log-probabilities (actually, it's enough to add the log-potentials, but I'll get to that) and take the argmax. This seems like more work than simply sampling from $\vec{\pi}$ (which is normally done by sampling a single $u \sim \mathcal{U}([0, 1])$ and checking between which two consecutive entries of $\operatorname{cumsum}(\vec{\pi})$ it lies). And indeed it is, if we want just one sample from $\vec{\pi}$. But this method of sampling has other benefits, like the ability to pre-compute the Gumbel samples, to avoid normalising $\vec{\pi}$ at every forward pass, or to perform reservoir sampling. But I don't want to focus on these comparisons now (if you're interested, I found some nice concise explanations in this blogpost).
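To make the comparison concrete, here is a minimal Julia sketch of both procedures (a toy example of my own; the hard-coded probs and the variable names are just for illustration):

# One draw from a categorical distribution, two ways.
probs = [0.1, 0.2, 0.3, 0.4]

# Standard inverse-CDF sampling: a single uniform compared against cumsum(probs).
u = rand()
sample_cdf = searchsortedfirst(cumsum(probs), u)

# Gumbel-Max sampling: K uniforms -> K Gumbels, add the log-probs, take the argmax.
gumbels = -log.(-log.(rand(length(probs))))
sample_gumbel = argmax(gumbels .+ log.(probs))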
These tricks are really neat, but I wondered how people came up with them. So I tried to reverse-engineer the Gumbel-Max trick and then (in a future post) write down some of the things I've learned about softmax and temperature annealing. All these things are well-known to most people who do probability theory for a living, but I thought a couple of blogposts might be useful for me to put my thoughts in order, and potentially for readers to follow along, so here goes.
Reverse-Engineering Gumbel-Max
Suppose we are given a categorical distribution $\vec{\pi} = (\pi_1, \ldots, \pi_K)$. We want to find independent, continuous real random variables $\vec{X} = (X_1, \ldots, X_K)$ which satisfy the condition
$$ \forall k \quad \mathbb{P}(X_k = \max\{X_i \colon 1 \le i \le K\}) = \pi_k. $$ One quick observation we can make is that if $\{X_1, \ldots, X_K\}$ satisfy this condition and $f$ is any strictly increasing function, then $\{f(X_1), \ldots, f(X_K)\}$ will also satisfy it, since applying the same strictly increasing function to all the variables doesn't change which one is largest. So we already see this is a really flexible condition. But anyway, let's actually find some $X$'s that do what we want.
Let $F_i$ denote the CDF of $X_i$, that is $F_i(t) = \mathbb{P}(X_i \le t)$, and let $f_i = F_i'$ denote the pdf (I will deliberately ignore mathematical subtleties around differentiability, continuity etc. in this post). Then, the above equations can be rewritten as: $$ \begin{aligned} \pi_k &= \mathbb{P}(X_k = \max\{X_i \colon 1 \le i \le K\}) \\ &= \int_{\mathbb{R}} \mathbb{P}(s \ge X_1, \ldots, s \ge X_K \vert X_k=s)f_k(s) \,\mathrm{d}s \\ &= \int_{\mathbb{R}}f_k(s)\prod_{i \neq k} F_i(s) \,\mathrm{d}s\,, \end{aligned} $$ where in the last equality we used the independence assumption.
Notice that on the right-hand side, we are integrating a positive function of $s$ from $- \infty$ to $\infty$ and getting the number $\pi_k$ as a result. If instead we integrated from $-\infty$ to some $t \in \mathbb{R}$, we would get $H_k(t)\pi_k$, where $H_k(t)$ is some number between $0$ and $1$ which increases with $t$ (by the positivity of the integrand), approaches $0$ as $t \rightarrow -\infty$ and approaches $1$ as $t \rightarrow \infty$. In other words, the function $H_k$ is a CDF for some continuous distribution.
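In symbols, the function we have just described is $$H_k(t) \colon= \frac{1}{\pi_k}\int_{-\infty}^{t} f_k(s)\prod_{i \neq k} F_i(s)\,\mathrm{d}s.$$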
At this point, in order to find concrete solutions to the above set of $K$ equations, we make the simplifying assumption that all these CDFs are equal, i.e. $H_1 = H_2 = \cdots = H_K =\colon H$. So we get the following integral equations: $$\forall k \quad \forall t \quad \int_{-\infty}^t\prod_{i \neq k} F_i(s)f_k(s) \,\mathrm{d}s = H(t) \pi_k.$$ Now, summing these over $1 \le k \le K$ and using the product rule $\sum_{k} f_k(s)\prod_{i \neq k} F_i(s) = \frac{\mathrm{d}}{\mathrm{d}s}\prod_{1\le i \le K} F_i(s)$ (together with $\sum_k \pi_k = 1$), we obtain: $$\int_{-\infty}^{t}\frac{\mathrm{d}}{\mathrm{d}s}\prod_{1\le i \le K} F_i(s) \,\mathrm{d}s = H(t),$$ from which we obtain $\prod_{1\le i \le K} F_i(t) = H(t)$ (the constant of integration is $0$: let $t \to \infty$ and use the fact that all functions involved are CDFs, so both sides tend to $1$).
On the other hand, differentiating both sides of the integral equations with respect to $t$ gives the functional equations (writing $h = H'$ for the density of $H$): $$\forall k \quad \forall t \quad \prod_{i \neq k} F_i(t)f_k(t) = h(t) \pi_k.$$
Substituting the fact that $\prod_{1\le i \le K} F_i(t) = H(t)$ into these, we obtain: $$ \begin{align} \frac{H(t)f_k(t)}{F_k(t)} & = h(t) \pi_k \\ \Leftrightarrow \quad\quad\quad\frac{f_k(t)}{F_k(t)} & = \frac{h(t)}{H(t)} \pi_k \\ \Leftrightarrow \quad\quad\quad\frac{\mathrm{d}}{\mathrm{d}t}\log(F_k(t)) & = \frac{\mathrm{d}}{\mathrm{d}t}\log(H(t)) \pi_k \\ \end{align} $$
Integrating both sides we find $$\log(F_k(t)) = \log(H(t))\pi_k \quad \forall k$$ (constant of integration is again $0$ by considerations at $\infty$), and so we have: $$\forall 1\le k\le K \quad F_k(t) = H(t)^{\pi_k}.$$ Thus, for any choice of CDF $H$, we obtain independent random variables $X_k \sim F_k = H^{\pi_k}$ which satisfy our requirement.
We can now use Inverse Transform Sampling to sample from these distributions. That is, we know that if $U_k \sim \mathcal{U}([0, 1])$ then $F_k^{-1}(U_k) = H^{-1}\left(U_k^{\frac{1}{\pi_k}}\right)$ is distributed according to $F_k$.
But, actually, notice that we are really free in our choice of $H^{-1}$: all we need is for it to be a strictly increasing function defined on the unit interval! So, we can stop worrying about $H$ or $h$ and we can just say:
Given a categorical distribution $\vec{\pi} = (\pi_1, \ldots, \pi_K)$, a strictly increasing function $f \colon (0, 1) \rightarrow \mathbb{R}$, and independent $U_k \sim \mathcal{U}((0, 1)), \; 1\le k \le K$, the random variables $X_k \colon= f\left(U_k^{\frac{1}{\pi_k}}\right)$ satisfy $$\mathbb{P}\left(\mathrm{argmax}\left(\vec{X}\right) = k\right) = \pi_k, \quad \forall \, 1\le k \le K.$$
One obvious choice is to take $f$ to be the identity function, giving us $$X_k = U_k^{\frac{1}{\pi_k}}.$$
Another option is to take $f = \log$, giving us $$X_k = \frac{\log(U_k)}{\pi_k}.$$ Arguably, the most popular choice is to take $f(t) = -\log(-\log(t))$, which gives: $$X_k = -\log\left(-\log\left(U_k^{\frac{1}{\pi_k}}\right)\right) = -\log(-\log(U_k)) + \log(\pi_k) = G_k + \log(\pi_k),$$ where $G_k \sim \mathrm{Gumbel}(0, 1)$.
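Just to emphasise how much freedom we have in choosing $f$, here is a quick Monte-Carlo check (a toy sketch of my own, not something you would use in practice) with yet another strictly increasing function, the logit $f(t) = \log\frac{t}{1-t}$:

# Sanity check of the claim with f = logit.
logit_f(t) = log(t / (1 - t))
probs = [0.1, 0.2, 0.3, 0.4]
counts = zeros(Int, length(probs))
for _ in 1:100_000
    x = logit_f.(rand(length(probs)) .^ (1 ./ probs))  # X_k = f(U_k^(1/π_k))
    counts[argmax(x)] += 1
end
println(counts ./ sum(counts))  # should be close to probs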
And so, with the choice $f(t) = -\log(-\log(t))$, we've "rediscovered" the Gumbel-Max trick... almost. The last piece is to notice that the function $f$ can itself depend on $\vec{\pi}$, and this can be quite useful if, for example, the distribution $\vec{\pi}$ is computed by some neural network in an unnormalised form. That is, suppose our network outputs the log-potentials $\vec{\ell} = (\ell_1, \ldots, \ell_K) \in \mathbb{R}^K$, so that the probabilities are $\pi_k \propto \exp(\ell_k)$; to compute them we would need the normalising constant $Z(\vec{\ell}) = \sum_{i=1}^K \exp(\ell_i)$. The cool thing is, we don't actually need to compute the normalising constant to generate variables $X_k$ with the desired property. For example, we could set $f(t) = -\log(-\log(t)) + \mathrm{LSE}(\vec{\ell})$, where $\mathrm{LSE}$ is "LogSumExp", i.e. $\mathrm{LSE}(\vec{\ell}) \colon= \log\left(\sum_{i=1}^K \exp(\ell_i)\right)$. Then, since $\log(\pi_k) = \ell_k - \mathrm{LSE}(\vec{\ell})$, we get $$X_k = G_k + \log(\pi_k) + \mathrm{LSE}(\vec{\ell}) = G_k + \ell_k,$$ which is the more general form of the Gumbel-Max trick: it lets us sample from a categorical distribution knowing only its log-potentials.
I’ll end with some Julia code showing a few simulations which I used as a sanity check that it all works. In the next post I’ll talk about the Gumbel-Softmax trick and how I think about temperature annealing. See you then!
import Pkg; Pkg.add("StatsPlots")  # install StatsPlots if it isn't already available
import Random
import StatsPlots
# Functions that operate on u and pi
f_1 = (u, pi) -> u .^ (1 ./ pi) # identity
f_2 = (u, pi) -> log.(u) ./ pi # log
f_3 = (u, pi) -> -log.(-log.(u)) .+ log.(pi) # Gumbel
# Functions that only use the log-potentials
g_1 = (u, logit) -> u .^ (1 ./ exp.(logit)) # identity
g_2 = (u, logit) -> log.(u) ./ exp.(logit) # log
g_3 = (u, logit) -> -log.(-log.(u)) .+ logit # Gumbel
K = 12
logits = randn(Float64, (1, K))               # random log-potentials
unnormalized_pi = exp.(logits)
pi = unnormalized_pi ./ sum(unnormalized_pi)  # note: this shadows Julia's built-in `pi`
# Compute x = f(u, pi), take the argmax of each row, and return the empirical
# distribution of the winning indices. `pi` holds either the probabilities or
# the log-potentials, depending on which f is passed in.
function get_empirical_dist_from_argmaxes(f, u, pi)
    x = f.(u, pi)
    argmaxes = mapslices(argmax, x, dims=2)
    counts = [0 for i in pi]
    for val in argmaxes
        counts[val] += 1
    end
    return counts ./ sum(counts)
end
reps = 10000
dists = Dict("π"=>pi)
for (name_f, name_g, f, g) in zip(["f₁", "f₂", "f₃"], ["g₁", "g₂", "g₃"], [f_1, f_2, f_3], [g_1, g_2, g_3])
    dists[name_f] = get_empirical_dist_from_argmaxes(f, Random.rand(Float64, (reps, K)), pi)
    dists[name_g] = get_empirical_dist_from_argmaxes(g, Random.rand(Float64, (reps, K)), logits)
end
names_to_plot = ["π", "f₁", "f₂", "f₃", "g₁", "g₂", "g₃"]
dists_to_plot = vcat([dists[name] for name in names_to_plot]...)
bar_chart = StatsPlots.groupedbar(
    vec(dists_to_plot),
    group=repeat(names_to_plot, outer=K),
    xlabel="class"
)
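As an extra numerical sanity check (on top of the bar chart), we can also print how far each empirical distribution is from the true probabilities; the deviations should be small, roughly on the order of $1/\sqrt{\mathrm{reps}}$ (the "π" entry is trivially zero):

# Maximum absolute deviation of each empirical distribution from pi.
for name in names_to_plot
    println(name, ": ", maximum(abs.(vec(dists[name]) .- vec(pi))))
end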