The Gumbel-max trick

A funky way of sampling from categorical distributions

A categorical distribution is a discrete probability distribution that assigns a probability to each of $K$ classes (or categories). That is, for each class $k \in \{1, 2, \ldots, K\}$ we have some value $\pi_k$ representing the probability of drawing that class. Because we're dealing with probabilities, each $\pi_k$ must be nonnegative and we must have $\sum_{k}\pi_k = 1$ (in other words, our $\pi_k$ must lie on the simplex).

The obvious way we might think to parameterize such a distribution is to use the vector $\boldsymbol{\pi}$ of probabilities $\{\pi_k\}$. That is, if we have a categorical random variable $I$, we would write

\[I \sim Cat(\boldsymbol{\pi}).\]

However, in many machine learning problems we may instead prefer to parameterize a discrete distribution in terms of an unconstrained vector of numbers. That is, we may instead wish to parameterize our distribution with some vector $\boldsymbol{\theta} \in \mathbb{R}^{K}$ of values $\theta_{k}$ that can take on arbitrary values (i.e., they may be negative, need not sum to $1$, etc.). By doing so, we can use unconstrained optimization algorithms to optimize $\boldsymbol{\theta}$ rather than restricting ourselves to constrained optimization over $\boldsymbol{\pi}$. How do we get from $\boldsymbol{\theta}$ to probabilities? Typically, we'll use the softmax transformation:

\[\begin{equation}\label{eq:1} \pi_k = \frac{\exp(\theta_k)}{\sum_{k'=1}^{K}\exp(\theta_{k'})}. \end{equation}\]

After performing the transformation, we could then sample from $Cat(\boldsymbol{\pi})$. However, what if we don't want to explicitly construct our distribution using the softmax transform? It turns out that there exists another method for achieving the same effect: the Gumbel-max trick.
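For concreteness, here's a minimal NumPy sketch of this "explicit" route: apply the softmax to $\boldsymbol{\theta}$ and then sample from the resulting categorical distribution. (The function and variable names here are just illustrative.)

```python
import numpy as np

def softmax(theta):
    """Map an unconstrained score vector theta to probabilities on the simplex."""
    # Subtract the max before exponentiating for numerical stability;
    # this does not change the resulting probabilities.
    z = np.exp(theta - np.max(theta))
    return z / z.sum()

rng = np.random.default_rng(0)
theta = np.array([1.0, -0.5, 2.0])  # arbitrary unconstrained scores
pi = softmax(theta)

# Sample a class index from Cat(pi) the "obvious" way.
sample = rng.choice(len(pi), p=pi)
```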


The Gumbel distribution is a probability distribution with location and scale parameters $\mu \in \mathbb{R}$ and $\beta \in \mathbb{R}_{> 0}$, respectively. Its probability density function (PDF) is

\[f(x; \mu, \beta) = \frac{1}{\beta} \exp\bigg(-(x - \mu)/\beta - \exp(-(x - \mu)/\beta)\bigg)\]

and cumulative distribution function (CDF) is

\[F(x; \mu, \beta) = \exp(-\exp(-(x - \mu)/\beta)).\]

We can denote a Gumbel distribution with location $\mu$ and scale $\beta$ using the notation $G(\mu, \beta)$ and a random variable following this distribution as $G_{\mu, \beta}$. The notation here can look pretty intimidating at first, but luckily for the rest of this post we’ll only need to think about “standard” Gumbels with $\mu = 0$ and $\beta = 1$. To reduce clutter, for standard Gumbel random variables we’ll omit the subscripts and just write $G$.
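As an aside, standard Gumbel samples are easy to generate with inverse transform sampling: since the standard CDF is $F(x) = \exp(-\exp(-x))$, inverting it gives $F^{-1}(u) = -\log(-\log(u))$ for $u \in (0, 1)$. Here's a small NumPy sketch (NumPy also exposes this distribution directly via `Generator.gumbel`):

```python
import numpy as np

def sample_standard_gumbel(size, rng):
    """Draw standard Gumbel(0, 1) samples via inverse transform sampling.

    The standard Gumbel CDF is F(x) = exp(-exp(-x)), so F^{-1}(u) = -log(-log(u)).
    """
    u = rng.uniform(size=size)
    return -np.log(-np.log(u))

rng = np.random.default_rng(0)
g = sample_standard_gumbel(size=5, rng=rng)  # equivalently: rng.gumbel(size=5)
```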

Now, given our definition of the Gumbel distribution we will make the following claim:

For a set of unnormalized log-probabilities $\{\theta_k\}$, we can draw a sample from the corresponding categorical distribution as follows: to each $\theta_k$ we add an independent sample $G^{(k)}$ from the standard Gumbel distribution, and then select the index with the maximum sum. That is,

\[I = \underset{k}{\operatorname{argmax}}\{\theta_k + G^{(k)}\} \sim Cat(\boldsymbol{\pi})\]
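In code, the trick is just a couple of lines. Here's a hedged NumPy sketch (the function name is mine, not a standard API): sample one standard Gumbel per class, add it to the corresponding $\theta_k$, and return the argmax. Notice that we never have to compute the softmax normalizer $\sum_{k}\exp(\theta_k)$.

```python
import numpy as np

def gumbel_max_sample(theta, rng):
    """Sample an index I ~ Cat(softmax(theta)) without computing the softmax."""
    g = rng.gumbel(size=theta.shape)  # i.i.d. standard Gumbel(0, 1) noise
    return int(np.argmax(theta + g))

rng = np.random.default_rng(0)
theta = np.array([1.0, -0.5, 2.0])
i = gumbel_max_sample(theta, rng)
```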

To prove this, we will show that $P(I = \omega) = \pi_{\omega}$. First, as a shorthand we'll define, for each class $k$,

\[G_{\theta_k} := \theta_{k} + G^{(k)},\]

which is just a Gumbel random variable with location $\theta_k$ and scale $1$.

Now, starting from the definition of the argmax, we know that $I = \omega$ holds exactly when $G_{\theta_k} < G_{\theta_{\omega}}$ for all $k \neq \omega$ (ties occur with probability zero, since the Gumbel distribution is continuous). That is (using the shorthand $M := G_{\theta_{\omega}}$),

\[P(I = \omega) = \mathbb{E}_{M}\bigg[p(G_{\theta_{k}} < M \quad \forall k \neq \omega)\bigg]\]

Since our Gumbel variables $\{G^{(k)}\}$ are i.i.d., we can factorize the probability above to get

\[\begin{aligned} P(I = \omega) &= \mathbb{E}_{M}\bigg[\prod_{k \neq \omega} p(G_{\theta_{k}} < M)\bigg] \end{aligned}\]

Letting $f_{\omega}(\cdot)$ denote the PDF of $G_{\theta_{\omega}}$, we then have

\[\begin{aligned} P(I = \omega) &= \int_{-\infty}^{\infty}f_{\omega}(m)\prod_{k \neq \omega}p(G_{\theta_{k}} < m)dm\\ &= \int_{-\infty}^{\infty}f_{\omega}(m)\prod_{k \neq \omega}\exp(-\exp(\theta_{k}- m))dm \\ &= \int_{-\infty}^{\infty}f_{\omega}(m)\exp\bigg(-\sum_{k \neq \omega}\exp(\theta_{k}- m)\bigg)dm \\ &= \int_{-\infty}^{\infty}\exp(\theta_{\omega} - m - \exp(\theta_{\omega} - m))\exp\bigg(-\sum_{k \neq \omega}\exp(\theta_{k}- m)\bigg)dm\\ &= \int_{-\infty}^{\infty}\exp(\theta_{\omega} - m)\exp\bigg(-\sum_{k}\exp(\theta_{k} - m)\bigg)dm \\ &= \int_{-\infty}^{\infty}\exp(\theta_{\omega})\exp(-m)\exp(-\exp(-m)\sum_{k}\exp(\theta_k))dm \end{aligned}\]

Now we define $Z = \sum_{k}\exp(\theta_k)$. From Equation \ref{eq:1} we must have $\exp(\theta_{\omega}) = \pi_{\omega}Z$. We can then write

\[P(I = \omega) = \pi_{\omega}Z\int_{-\infty}^{\infty}\exp(-m)\exp(-Z\exp(-m))dm.\]

Now, using the identity (which can be verified with the substitution $u = \exp(-m)$, under which the integral becomes $\int_{0}^{\infty}\exp(-Zu)\,du$):

\[\int_{-\infty}^{\infty}\exp(-m)\exp(-Z\exp(-m))dm = \frac{1}{Z},\]

we have $P(I = \omega) = \pi_{\omega}$ as desired $\square$.
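As a quick empirical sanity check (separate from the proof above), we can draw many Gumbel-max samples and compare the empirical class frequencies against the softmax probabilities:

```python
import numpy as np

rng = np.random.default_rng(0)
theta = np.array([1.0, -0.5, 2.0])

# Softmax probabilities (the explicit parameterization).
pi = np.exp(theta - theta.max())
pi /= pi.sum()

# Gumbel-max samples: argmax over theta_k + i.i.d. standard Gumbel noise.
n = 100_000
g = rng.gumbel(size=(n, theta.size))
samples = np.argmax(theta + g, axis=1)

empirical = np.bincount(samples, minlength=theta.size) / n
print(np.round(pi, 3))         # softmax probabilities
print(np.round(empirical, 3))  # empirical frequencies; should closely agree
```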