TensorFlow Probability Implementation of Automatic Differentiation Variational Inference with Mixtures
In this paper, the authors suggest using the following loss instead of the traditional ELBO in order to train what is essentially a Variational Autoencoder whose approximate posterior is a Gaussian mixture instead of a single normal distribution: $$ \mathcal{L}_{SIWAE}^T(\phi)=\mathbb{E}_{\{z_{k,t}\sim q_{k,\phi}(z|x)\}_{k=1,t=1}^{K,T}}\left[\log\frac{1}{T}\sum_{t=1}^T\sum_{k=1}^K\alpha_{k,\phi}(x)\,\frac{p(x|z_{k,t})\,r(z_{k,t})}{q_\phi(z_{k,t}|x)}\right] $$ They also provide the following code, which is supposed to be a TensorFlow Probability implementation:
def siwae(prior, likelihood, posterior, x, T):
    q = posterior(x)
    z = q.components_dist.sample(T)
    z = tf.transpose(z, perm=[2, 0, 1, 3])
    loss_n = tf.math.reduce_logsumexp(
        (-tf.math.log(T) + tf.math.log_softmax(mixture_dist.logits)[:, None, :]
         + prior.log_prior(z) + likelihood(z).log_prob(x) - q.log_prob(z)),
        axis=[0, 1])
    return tf.math.reduce_mean(loss_n, axis=0)
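For context, in my minimal example the prior, the posterior (encoder) and the likelihood (decoder) are set up roughly as below. This is only a sketch of what is in the colab linked at the end; K, encoded_size, the layer sizes and the 28x28 Bernoulli likelihood are placeholders for my actual (MNIST-like) setup.

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfpl = tfp.layers

K = 3              # number of mixture components (placeholder)
encoded_size = 2   # latent dimensionality (placeholder)

# Standard normal prior r(z) over the latent space.
prior = tfd.Independent(tfd.Normal(loc=tf.zeros(encoded_size), scale=1.0),
                        reinterpreted_batch_ndims=1)

# Encoder: maps x to a MixtureSameFamily posterior q(z|x) with K Gaussian components.
encoder = tf.keras.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(tfpl.MixtureNormal.params_size(K, [encoded_size])),
    tfpl.MixtureNormal(K, [encoded_size]),
])

# Decoder: maps z to a likelihood p(x|z), here an independent Bernoulli over 28x28 pixels.
decoder = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(tfpl.IndependentBernoulli.params_size([28, 28])),
    tfpl.IndependentBernoulli([28, 28]),
])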
However, the snippet from the paper doesn't seem to work at all for me, so as someone with nearly no TensorFlow knowledge I came up with the following:
def siwae(prior, likelihood, posterior, x, T):
    q = posterior(x)  # mixture posterior; batch shape (batch_size,), event shape (encoded_size,)
    z = q.components_distribution.sample(T)
    z = tf.transpose(z, perm=[2, 0, 1, 3])  # shape (K, T, batch_size, encoded_size)
    l1 = -tf.math.log(float(T))  # shape (), log(1/T)
    l2 = tf.math.log_softmax(tf.transpose(q.mixture_distribution.logits))[:, None, :]  # shape (K, 1, batch_size), log alpha_k
    l3 = prior.log_prob(z)  # shape (K, T, batch_size), log r(z)
    l4 = likelihood(tf.reshape(z, (K * T * x.shape[0], encoded_size)))
    l4 = l4.log_prob(tf.repeat(x, repeats=K * T, axis=0))  # shape (K*T*batch_size,)
    l4 = tf.reshape(l4, (K, T, x.shape[0]))  # shape (K, T, batch_size), log p(x|z)
    l5 = -q.log_prob(z)  # shape (K, T, batch_size), -log q(z|x)
    loss_n = tf.math.reduce_logsumexp(l1 + l2 + l3 + l4 + l5, axis=[0, 1])
    return tf.math.reduce_mean(loss_n, axis=0)
There are no errors when I call this as

siwae(prior, decoder, encoder, x_test[:100, ...], T)

but after a few training steps I get only NaNs. I really have no idea whether this is due to a wrong implementation or a wrong usage of the loss, especially as I don't have much experience with TensorFlow, so any help would be greatly appreciated. For a full, minimal example I created this colab.
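For reference, the training step looks roughly like the sketch below. This is simplified from the colab; optimizer, T and train_dataset are placeholders, and since SIWAE is a lower bound I minimize its negative.

optimizer = tf.keras.optimizers.Adam(1e-3)

@tf.function
def train_step(x):
    with tf.GradientTape() as tape:
        # SIWAE is a lower bound to be maximized, so minimize its negative.
        loss = -siwae(prior, decoder, encoder, x, T)
    variables = encoder.trainable_variables + decoder.trainable_variables
    grads = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(grads, variables))
    return loss

for step, x_batch in enumerate(train_dataset):
    loss = train_step(x_batch)
    if step % 100 == 0:
        print(step, loss.numpy())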
Topic multivariate-distribution vae tensorflow probability machine-learning
Category Data Science