Chris J. Maddison

Gumbel Machinery

Recently, Laurent Dinh wrote a great blogpost asking (and answering) whether it was possible to invert the Gumbel-Max trick: given a sample from a discrete random variable, can we sample from the Gumbels that produced it?

We thought it would be valuable to show an alternative approach to Laurent’s question. Taking this tack, we can prove the Gumbel-Max trick, answer Laurent’s question, and more, all in a few short lines. Altogether this results in four central properties that form a sort of Gumbel machinery of intuitions, which were indispensable during the development of our NIPS 2014 paper A* Sampling.

Implemented in Python, the Gumbel-Max trick looks like this:

import numpy as np

alpha = np.array([1.0, 4.0, 6.0, 4.0, 1.0])  # unnormalized probabilities
uniform = np.random.rand(5)
gumbels = -np.log(-np.log(uniform)) + np.log(alpha)  # Gumbels with locations log(alpha)
K = np.argmax(gumbels)
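As a sanity check of this claim (a sketch of our own, not part of the original code), we can run the trick many times in one vectorized batch and compare the empirical frequencies of K with alpha / alpha.sum(). The seed, the sample size, and the use of NumPy's `default_rng` are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
alpha = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
n_samples = 200_000

# Each row is one independent run of the Gumbel-Max trick.
uniform = rng.random((n_samples, len(alpha)))
gumbels = -np.log(-np.log(uniform)) + np.log(alpha)
K = np.argmax(gumbels, axis=1)

empirical = np.bincount(K, minlength=len(alpha)) / n_samples
target = alpha / alpha.sum()  # [1, 4, 6, 4, 1] / 16
print(np.round(empirical, 3))
print(target)
```

With 200,000 draws, each empirical frequency should land within about a percentage point of its target.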

The trick is that K sampled in this way is a sample from the discrete distribution proportional to alpha[i]. In mathematical notation,

$$P(K = k) = \frac{\alpha_k}{Z},$$

where each $\alpha_i > 0$ and $Z = \sum_{i=1}^{n} \alpha_i$. We will get to the proof, but for now, Laurent’s question was simply: given K, what is the distribution over gumbels?

The Gumbels

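The apparently arbitrary choice of noise in the Gumbel-Max trick is its namesake. If $U \sim \mathrm{Uniform}(0, 1)$, then

$$G = -\log(-\log U) + \log\alpha$$

has a Gumbel distribution with location $\log\alpha$. The cumulative distribution function (CDF) of a Gumbel is

$$F_{\log\alpha}(g) = P(G \le g) = \exp\left(-e^{-g + \log\alpha}\right) = \exp\left(-\alpha e^{-g}\right).$$

The derivative of this is the density of the Gumbel,

$$f_{\log\alpha}(g) = \frac{d}{dg} F_{\log\alpha}(g) = \alpha e^{-g} \exp\left(-\alpha e^{-g}\right) = \alpha e^{-g} F_{\log\alpha}(g).$$

Thus, writing $g_i$ for gumbels[i] and $\alpha_i$ for alpha[i], the joint density of gumbels is

$$p(g_1, \dots, g_n) = \prod_{i=1}^{n} f_{\log\alpha_i}(g_i).$$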
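These formulas are easy to check numerically. The sketch below compares the closed-form CDF with the empirical CDF of -log(-log(U)) + log(alpha), and the closed-form density with a finite-difference derivative of the CDF. The helper names `gumbel_cdf` and `gumbel_pdf`, the value of alpha, and the grid are illustrative choices.

```python
import numpy as np

def gumbel_cdf(g, alpha):
    # F_{log alpha}(g) = exp(-alpha * e^{-g})
    return np.exp(-alpha * np.exp(-g))

def gumbel_pdf(g, alpha):
    # f_{log alpha}(g) = alpha * e^{-g} * F_{log alpha}(g)
    return alpha * np.exp(-g) * gumbel_cdf(g, alpha)

rng = np.random.default_rng(1)
alpha = 4.0
samples = -np.log(-np.log(rng.random(200_000))) + np.log(alpha)

# Empirical CDF of the transformed uniforms vs. the closed form.
grid = np.linspace(-1.0, 5.0, 13)
empirical_cdf = np.array([(samples <= g).mean() for g in grid])
cdf_error = np.max(np.abs(empirical_cdf - gumbel_cdf(grid, alpha)))

# A central finite difference of the CDF should reproduce the density.
h = 1e-5
numeric_pdf = (gumbel_cdf(grid + h, alpha) - gumbel_cdf(grid - h, alpha)) / (2 * h)
pdf_error = np.max(np.abs(numeric_pdf - gumbel_pdf(grid, alpha)))
print(cdf_error, pdf_error)
```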

Product of Gumbel CDFs.

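The keystone for any understanding of Gumbels is that multiplying their CDFs accumulates the parameters:

$$\prod_{i=1}^{n} F_{\log\alpha_i}(g) = F_{\log Z}(g), \qquad Z = \sum_{i=1}^{n} \alpha_i.$$

The derivation of this property is just some simple algebra,

$$\prod_{i=1}^{n} F_{\log\alpha_i}(g) = \prod_{i=1}^{n} \exp\left(-\alpha_i e^{-g}\right) = \exp\left(-\Big(\sum_{i=1}^{n} \alpha_i\Big) e^{-g}\right) = \exp\left(-Z e^{-g}\right) = F_{\log Z}(g).$$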
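A direct numerical check of the identity, assuming the CDF formula $F_{\log\alpha}(g) = \exp(-\alpha e^{-g})$; `gumbel_cdf` is a hypothetical helper and the grid is arbitrary:

```python
import numpy as np

def gumbel_cdf(g, alpha):
    # CDF of a Gumbel with location log(alpha)
    return np.exp(-alpha * np.exp(-g))

alpha = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
Z = alpha.sum()
grid = np.linspace(-3.0, 6.0, 10)

# Left side: product of the five CDFs, all evaluated at the same point g.
product_of_cdfs = np.prod([gumbel_cdf(grid, a) for a in alpha], axis=0)
# Right side: one Gumbel CDF whose location is the log of the summed parameters.
cdf_with_log_Z = gumbel_cdf(grid, Z)
print(np.max(np.abs(product_of_cdfs - cdf_with_log_Z)))
```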

The Joint

Our strategy will be to write down the joint distribution of the gumbels and K in an intuitive form, then to manipulate it to reveal a structure like

$$p(k, g_1, \dots, g_n) = p(k)\, p(g_k \mid k) \prod_{i \neq k} p(g_i \mid g_k, k),$$

which answers Laurent’s question. This section is a bit tedious, but it only needs to be done once before getting at the valuable properties of Gumbels.

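Knowing K restricts the possible Gumbel events to ones in which gumbels[K] > gumbels[i] for i != K. So, intuitively, the joint of gumbels and K is

$$p(k, g_1, \dots, g_n) = \prod_{i=1}^{n} f_{\log\alpha_i}(g_i) \prod_{i \neq k} [g_i < g_k],$$

where $k$ is one of $1, \dots, n$ and $[A]$ is the Iverson bracket notation for the indicator function of set $A$. If you’re still not convinced, then sum over $k$ or integrate over $g_1, \dots, g_n$. In both cases you will get the correct marginal events.

First, we multiply by a judicious choice of 1,

$$= \frac{F_{\log Z}(g_k)}{F_{\log Z}(g_k)} \prod_{i=1}^{n} f_{\log\alpha_i}(g_i) \prod_{i \neq k} [g_i < g_k],$$

pull the density of gumbels[K] out of the product,

$$= \frac{F_{\log Z}(g_k)\, f_{\log\alpha_k}(g_k)}{F_{\log Z}(g_k)} \prod_{i \neq k} f_{\log\alpha_i}(g_i)\, [g_i < g_k],$$

apply the product of Gumbel CDFs property in reverse to the denominator,

$$= \frac{F_{\log Z}(g_k)\, f_{\log\alpha_k}(g_k)}{\prod_{i=1}^{n} F_{\log\alpha_i}(g_k)} \prod_{i \neq k} f_{\log\alpha_i}(g_i)\, [g_i < g_k],$$

distribute the denominator factors with $i \neq k$ into the product,

$$= \frac{F_{\log Z}(g_k)\, f_{\log\alpha_k}(g_k)}{F_{\log\alpha_k}(g_k)} \prod_{i \neq k} \frac{f_{\log\alpha_i}(g_i)\, [g_i < g_k]}{F_{\log\alpha_i}(g_k)},$$

expand the density of gumbels[K] as $f_{\log\alpha_k}(g_k) = \alpha_k e^{-g_k} F_{\log\alpha_k}(g_k)$ and cancel,

$$= \alpha_k e^{-g_k} F_{\log Z}(g_k) \prod_{i \neq k} \frac{f_{\log\alpha_i}(g_i)\, [g_i < g_k]}{F_{\log\alpha_i}(g_k)},$$

finally, multiply by $Z/Z$ and recognize $Z e^{-g_k} F_{\log Z}(g_k) = f_{\log Z}(g_k)$ to get:

$$p(k, g_1, \dots, g_n) = \frac{\alpha_k}{Z}\, f_{\log Z}(g_k) \prod_{i \neq k} \frac{f_{\log\alpha_i}(g_i)\, [g_i < g_k]}{F_{\log\alpha_i}(g_k)}.$$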
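Since every step of this derivation is an exact identity, the first and last forms of the joint should agree pointwise. A minimal numerical check (the helpers `joint_direct` and `joint_factored` are our own illustrative names) evaluates both forms at random points for every k:

```python
import numpy as np

def cdf(g, a):
    # Gumbel CDF with location log(a)
    return np.exp(-a * np.exp(-g))

def pdf(g, a):
    # Gumbel density with location log(a)
    return a * np.exp(-g) * cdf(g, a)

def joint_direct(k, g, alpha):
    # prod_i f_i(g_i) * prod_{i != k} [g_i < g_k]   (the starting form)
    others = np.arange(len(alpha)) != k
    return np.prod(pdf(g, alpha)) * np.all(g[others] < g[k])

def joint_factored(k, g, alpha):
    # (alpha_k / Z) f_{log Z}(g_k) prod_{i != k} f_i(g_i) [g_i < g_k] / F_i(g_k)
    Z = alpha.sum()
    others = np.arange(len(alpha)) != k
    truncated = (pdf(g[others], alpha[others]) * (g[others] < g[k])
                 / cdf(g[k], alpha[others]))
    return alpha[k] / Z * pdf(g[k], Z) * np.prod(truncated)

rng = np.random.default_rng(2)
alpha = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
max_rel_err, zero_mismatches = 0.0, 0
for _ in range(1000):
    g = rng.gumbel(size=len(alpha)) + np.log(alpha)
    for k in range(len(alpha)):
        direct, factored = joint_direct(k, g, alpha), joint_factored(k, g, alpha)
        if direct > 0:
            max_rel_err = max(max_rel_err, abs(direct - factored) / direct)
        elif factored != 0:
            zero_mismatches += 1
print(max_rel_err, zero_mismatches)
```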

Four Important Properties of Gumbels

The value of that tedious algebra is what it reveals. We can now simply read off the following properties from our form of the joint density of K and gumbels. Referring back to the code:

alpha = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
uniform = np.random.rand(5)
gumbels = -np.log(-np.log(uniform)) + np.log(alpha)
K = np.argmax(gumbels)

1. Gumbel-Max Trick.

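K is distributed as $P(K = k) = \alpha_k / Z$, since

$$P(K = k) = \frac{\alpha_k}{Z} \int_{-\infty}^{\infty} f_{\log Z}(g_k) \prod_{i \neq k} \left( \int_{-\infty}^{g_k} \frac{f_{\log\alpha_i}(g_i)}{F_{\log\alpha_i}(g_k)}\, dg_i \right) dg_k = \frac{\alpha_k}{Z} \int_{-\infty}^{\infty} f_{\log Z}(g_k)\, dg_k = \frac{\alpha_k}{Z}.$$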

2. The max Gumbel integrates over alphas.

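gumbels[K] is distributed as a Gumbel with location $\log Z$, where $Z = \sum_{i=1}^{n} \alpha_i$, since

$$p(g_K = g) = \sum_{k=1}^{n} \frac{\alpha_k}{Z}\, f_{\log Z}(g) \prod_{i \neq k} \int_{-\infty}^{g} \frac{f_{\log\alpha_i}(g_i)}{F_{\log\alpha_i}(g)}\, dg_i = f_{\log Z}(g) \sum_{k=1}^{n} \frac{\alpha_k}{Z} = f_{\log Z}(g).$$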
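A Monte Carlo sketch of this property: the maximum of independent Gumbels with locations log(alpha[i]) should match a Gumbel with location log(Z), both in its CDF and in its mean, log(Z) + γ, where γ is the Euler–Mascheroni constant (available in NumPy as `np.euler_gamma`). The sample size, seed, and grid are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(3)
alpha = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
Z = alpha.sum()

# Generator.gumbel draws standard Gumbels; adding log(alpha) sets the locations.
gumbels = rng.gumbel(size=(200_000, len(alpha))) + np.log(alpha)
max_gumbel = gumbels.max(axis=1)  # gumbels[K] for every row

# A unit-scale Gumbel with location log(Z) has mean log(Z) + Euler's gamma.
mean_gap = abs(max_gumbel.mean() - (np.log(Z) + np.euler_gamma))

grid = np.linspace(1.0, 6.0, 11)
empirical_cdf = np.array([(max_gumbel <= g).mean() for g in grid])
cdf_error = np.max(np.abs(empirical_cdf - np.exp(-Z * np.exp(-grid))))
print(mean_gap, cdf_error)
```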

3. Argmax and max are independent.

K and gumbels[K] are independent,

$$p(k, g_k) = \frac{\alpha_k}{Z}\, f_{\log Z}(g_k) = P(K = k)\, p(g_K = g_k).$$
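Simulation cannot prove independence, but two of its consequences are easy to check: the mean of the max should not depend on which index won, and the frequencies of K should stay at alpha / Z even when we keep only unusually large maxima. A sketch, with an arbitrary seed, sample size, and 90% quantile cutoff:

```python
import numpy as np

rng = np.random.default_rng(4)
alpha = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
Z = alpha.sum()
gumbels = rng.gumbel(size=(400_000, len(alpha))) + np.log(alpha)
K = gumbels.argmax(axis=1)
max_gumbel = gumbels.max(axis=1)

# If K and the max are independent, the mean of the max is the same in every
# K group (and equal to the mean of a Gumbel with location log(Z)) ...
cond_means = np.array([max_gumbel[K == k].mean() for k in range(len(alpha))])
# ... and K keeps the law alpha / Z even among the largest 10% of maxima.
big = max_gumbel > np.quantile(max_gumbel, 0.9)
cond_probs = np.bincount(K[big], minlength=len(alpha)) / big.sum()
print(np.round(cond_means, 3), np.log(Z) + np.euler_gamma)
print(np.round(cond_probs, 3), alpha / Z)
```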

4. The remaining Gumbels are still independent but truncated.

Given K and gumbels[K], consider

  remaining_gumbels = np.delete(gumbels, K)  # every entry except gumbels[K]
  remaining_alpha = np.delete(alpha, K)

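The remaining Gumbels remaining_gumbels[i] are independent Gumbels with location np.log(remaining_alpha[i]), truncated at gumbels[K], since

$$p(g_1, \dots, g_{k-1}, g_{k+1}, \dots, g_n \mid k, g_k) = \frac{p(k, g_1, \dots, g_n)}{p(k, g_k)} = \prod_{i \neq k} \frac{f_{\log\alpha_i}(g_i)\, [g_i < g_k]}{F_{\log\alpha_i}(g_k)}.$$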
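One way to probe this property empirically, assuming the CDF formula from earlier: condition on K == k by keeping only the rows where index k won, then map each remaining Gumbel through F(g) / F(gumbels[K]), the CDF of the claimed truncated distribution. If the property holds, these values are Uniform(0, 1) and uncorrelated with gumbels[K]. The choice k = 2, the seed, and the tolerances are illustrative, and matching a few moments is a necessary check, not a proof.

```python
import numpy as np

def cdf(g, a):
    # Gumbel CDF with location log(a)
    return np.exp(-a * np.exp(-g))

rng = np.random.default_rng(5)
alpha = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
gumbels = rng.gumbel(size=(200_000, len(alpha))) + np.log(alpha)
K = gumbels.argmax(axis=1)

k = 2  # condition on one particular value of K
rows = gumbels[K == k]
top = rows[:, k]  # gumbels[K] in each kept row
remaining_gumbels = np.delete(rows, k, axis=1)
remaining_alpha = np.delete(alpha, k)

# If remaining_gumbels[:, j] is Gumbel(log remaining_alpha[j]) truncated at top,
# then F(g) / F(top) is Uniform(0, 1) whatever the value of top.
u = cdf(remaining_gumbels, remaining_alpha) / cdf(top[:, None], remaining_alpha)
corrs = [np.corrcoef(u[:, j], top)[0, 1] for j in range(u.shape[1])]
print(np.round(u.mean(axis=0), 3))
print(np.round(u.var(axis=0), 4))
print(np.round(corrs, 3))
```

Each column of u should have mean close to 1/2, variance close to 1/12, and near-zero correlation with the top Gumbel.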

Laurent’s Question and Beyond

Reading off the density, we also get a simple answer to Laurent’s question. To sample from $p(g_1, \dots, g_n \mid K = k)$, sample the top Gumbel $g_k$ with location $\log Z$, and for $i \neq k$ sample Gumbels with locations $\log\alpha_i$ truncated at $g_k$. This code samples a truncated Gumbel,

def truncated_gumbel(alpha, truncation):
    gumbel = np.random.gumbel() + np.log(alpha)
    return -np.log(np.exp(-gumbel) + np.exp(-truncation))
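Why this formula works: writing $T$ for the truncation and $Y = -\log(e^{-G} + e^{-T})$, for $y < T$ we have $P(Y \le y) = P(e^{-G} \ge e^{-y} - e^{-T}) = \exp(-\alpha(e^{-y} - e^{-T})) = F_{\log\alpha}(y) / F_{\log\alpha}(T)$, which is exactly the CDF of a Gumbel with location $\log\alpha$ conditioned to lie below $T$. The sketch below checks this empirically with `truncated_gumbel_batch`, a hypothetical vectorized variant that applies the same formula to many draws at once; alpha, the truncation point, and the seed are arbitrary.

```python
import numpy as np

def cdf(g, a):
    # Gumbel CDF with location log(a)
    return np.exp(-a * np.exp(-g))

def truncated_gumbel_batch(alpha, truncation, size, rng):
    # Same formula as truncated_gumbel, vectorized over `size` draws.
    gumbel = rng.gumbel(size=size) + np.log(alpha)
    return -np.log(np.exp(-gumbel) + np.exp(-truncation))

rng = np.random.default_rng(6)
alpha, truncation = 4.0, 1.0
samples = truncated_gumbel_batch(alpha, truncation, 200_000, rng)

# Target CDF: F(y) / F(truncation) for y <= truncation.
grid = np.linspace(-1.0, truncation, 9)
empirical_cdf = np.array([(samples <= y).mean() for y in grid])
cdf_error = np.max(np.abs(empirical_cdf - cdf(grid, alpha) / cdf(truncation, alpha)))
print(samples.max(), cdf_error)
```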

This code samples from the desired posterior,

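def topdown(alphas, k):
    # The top Gumbel has location log(sum(alphas)); the others are truncated at it.
    topgumbel = np.random.gumbel() + np.log(sum(alphas))
    gumbels = []
    for i in range(len(alphas)):
        if i == k:
            gumbel = topgumbel
        else:
            gumbel = truncated_gumbel(alphas[i], topgumbel)
        gumbels.append(gumbel)
    return gumbels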

For reference, here is the rejection-sampling version:

def rejection(alphas, k):
    log_alphas = np.log(alphas)
    gumbels = np.random.gumbel(size=len(alphas))
    while k != np.argmax(gumbels + log_alphas):
        gumbels = np.random.gumbel(size=len(alphas))
    return (gumbels + log_alphas).tolist()
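To check that the top-down sampler and the rejection sampler target the same posterior, we can compare their per-coordinate means and standard deviations. This is a sketch: the three functions are restated compactly (the top-down loop becomes a list comprehension) so the snippet runs on its own, and the seed, k = 2, and the sample size are arbitrary. We pick a k with a large alpha so that rejection sampling accepts often.

```python
import numpy as np

def truncated_gumbel(alpha, truncation):
    # Gumbel with location log(alpha), truncated to (-inf, truncation)
    gumbel = np.random.gumbel() + np.log(alpha)
    return -np.log(np.exp(-gumbel) + np.exp(-truncation))

def topdown(alphas, k):
    # Top Gumbel ~ Gumbel(log Z); every other coordinate is truncated at it.
    topgumbel = np.random.gumbel() + np.log(sum(alphas))
    return [topgumbel if i == k else truncated_gumbel(a, topgumbel)
            for i, a in enumerate(alphas)]

def rejection(alphas, k):
    # Redraw all Gumbels until the argmax lands on k.
    log_alphas = np.log(alphas)
    while True:
        gumbels = np.random.gumbel(size=len(alphas)) + log_alphas
        if np.argmax(gumbels) == k:
            return gumbels.tolist()

np.random.seed(7)
alphas = [1.0, 4.0, 6.0, 4.0, 1.0]
k, n_samples = 2, 20_000
td = np.array([topdown(alphas, k) for _ in range(n_samples)])
rj = np.array([rejection(alphas, k) for _ in range(n_samples)])
print(np.round(td.mean(axis=0), 2))
print(np.round(rj.mean(axis=0), 2))
```

Agreement of a few moments does not prove the two samplers are identical, but a large discrepancy would point to a bug in one of them.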

Note that this code differs slightly from the routine in Laurent’s blog post: each Gumbel here is shifted by $\log\alpha_i$.

The above alternative factorization of the joint and the properties of Gumbels that it implies have been valuable to us, allowing manipulation of Gumbels at a higher level of abstraction. In fact, the factorization is a special case of the Top-Down construction that is at the core of A* Sampling. By applying these intuitions recursively, we can easily derive algorithms for sampling a set of Gumbels in decreasing order, or even sampling a heap of Gumbels from the top down. We hope to cover the more general version in a future blog post.