WIP: GumbelSoftmax / RelaxedOneHotCategoricalStraightThrough #562
WIP: GumbelSoftmax / RelaxedOneHotCategoricalStraightThrough #562daydreamt wants to merge 14 commits intopyro-ppl:masterfrom
Conversation
…more reading on relaxed_categorical and transformations of distributions instead
…pansion where needed
|
Hi @daydreamt , thanks for the PR! I think the main blocker of your work would be to define custom derivative rules for some of your operators. I'll update the repo to the latest JAX version today to unblock your work. |
|
@daydreamt FYI, I think @tbsexton only needs RelaxedOneHotCategorical (or GumbelSoftmax) in his feature request because he wanted to use MCMC (instead of SVI) to draw samples from the relaxed distribution. @tbsexton could you confirm that StraightThrough is not required? |
|
@fehiepsi I was originally only using HMC, though planning to test out SVI as well. As long as not having access to a backward pass doesn't preclude using NUTS for inference of latent variables, should work! This is in practice a work-around for not having discrete latent variables; see my original example problem here. |
|
Thanks, @tbsexton! In your model, you want to infer each ϕ = ny.sample("ϕ", dist.Dirichlet(np.ones(n_nodes)))
x0 = ny.sample("x0", dist.Categorical(ϕ))
infectious, hist = spread_jax(s_ij, x0, 5)by ϕ = ny.sample("ϕ", dist.Dirichlet(np.ones(n_nodes)))
infectious, hist = spread_jax(s_ij, ϕ, 5)Or if you want the prior for ϕ = ny.sample("ϕ", dist.RelaxedOneHotCategorical(temporature, logits=np.ones(n_nodes))))
infectious, hist = spread_jax(s_ij, ϕ, 5)The reason is with If you want something like straight through, you can simply use by defining "straight-through" . You can use |
|
@fehiepsi much appreciated! I think I should update the model there to reflect som local changes, but primarily I think it makes more sense to pull the dirichlet out of the plates: def diff_kg(infections):
n_cascades, n_nodes = infections.shape
n_edges = n_nodes*(n_nodes-1)//2 # complete graph
# node initial infection, relative probability
ϕ = ny.sample("ϕ", dist.Dirichlet(np.ones(n_nodes)))
# beta hyperpriors
u = ny.sample("u", dist.Uniform(np.zeros(n_edges),
np.ones(n_edges)))
v = ny.sample("v", dist.Gamma(np.ones(n_edges),
20*np.ones(n_edges)))
Λ = ny.sample("Λ", dist.Beta(u*v, (1-u)*v))
s_ij = jax_squareform(Λ) # adjacency matrix to recover via inference
with ny.plate("n_cascades", n_cascades):
# infer infection source node
x0 = ny.sample("x0", dist.Categorical(ϕ))
# simulate ode and realize
infectious, hist = spread_jax(s_ij, x0, 5)
numpyro.sample("obs", dist.Bernoulli(probs=infectious),
obs=infections)The main idea being that certain nodes in general have a tendency to be "sources", represented by the dirichlet prior, and those manifest as conditional probabilities that each node was the source (given any individual observed infection cascade). That should be realized as one node for the Maybe that dirichlet prior is unnecessary partial pooling? I will definitely give the new relaxed categorical a try. @daydreamt would it be helpful if I tested things out before the PR gets merged? |
Agree that this makes more sense. With this model, you can define RelaxedOneHotCategorical for |
|
Hey @daydreamt , any progress on this? |
|
Hi @dirmeier, not really, please feel free to take over or supersede with another MR. |
Hi all, since it's been a while I thought I should maybe give a sign of life and continue from here. This tries to implement #559.
There are still some things I haven't figured out myself yet, so I was planning to only request the review when I'm more ready, but of course feel free to take a look if you want already.