Update of mean and variance of weights

I'm trying to understand the Bayes by Backprop algorithm from the paper Weight Uncertainty in Neural Networks; the idea is to build a NN in which each weight has its own probability distribution. I get the theory, but I don't understand how the mean and variance are updated during learning. I found code in PyTorch which simply does:

class BayesianLinear(nn.Module):
    def __init__(self, in_features, out_features):
        (...)
        # Weight parameters
        self.weight_mu = nn.Parameter(torch.Tensor(out_features, in_features).uniform_(-0.2, 0.2))
        self.weight_rho = nn.Parameter(torch.Tensor(out_features, in_features).uniform_(-5, -4))
        self.weight = Gaussian(self.weight_mu, self.weight_rho)

How does the optimizer know how to update mu (the mean) and rho (which determines the variance)? Is it the nn.Parameter wrapper that allows this? Or is it happening somewhere else in the code? (it looks like a normal neural net from there!)

Topic weight-initialization pytorch bayesian probability neural-network

Category Data Science


PyTorch records a computation graph: for every value produced during the forward pass it remembers which operation created it, and for all basic operations it knows how to compute gradients. That is why you get the optimization for free in PyTorch, at least for common neural nets. Wrapping a tensor in nn.Parameter specifically tells PyTorch to track its gradients and let the optimizer change it during training.
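As a minimal sketch (the module and names below are illustrative, not taken from the code you found): any tensor wrapped in nn.Parameter inside an nn.Module shows up in model.parameters(), so a standard optimizer updates weight_mu and weight_rho exactly like ordinary weights.

import torch
import torch.nn as nn

class TinyBayesianLayer(nn.Module):
    def __init__(self):
        super().__init__()
        # Both tensors are registered as trainable parameters of the module
        self.weight_mu = nn.Parameter(torch.zeros(3, 2))
        self.weight_rho = nn.Parameter(torch.full((3, 2), -5.0))

layer = TinyBayesianLayer()
# The optimizer sees mu and rho as ordinary leaf tensors with requires_grad=True
optimizer = torch.optim.Adam(layer.parameters(), lr=1e-3)
print([name for name, _ in layer.named_parameters()])  # ['weight_mu', 'weight_rho']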

This machinery works for Bayes by Backprop as well, for the most part, but there are two not-so-common steps:

  1. For each forward pass we sample the weights. Technically, we first sample noise from $N(0, 1)$ and then apply the trainable parameters to it (the reparameterization trick). The specific values drawn from $N(0, 1)$ act as extra inputs, and for the subsequent operations they enter the gradients (see the sketch after this list).
  2. The loss usually depends only on the outputs and the true labels, but the KL term is computed in closed form from the weight distribution parameters themselves. You will probably need to explicitly collect the KL contribution from each layer and add it to the data-fit loss.
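
Here is a minimal sketch of both steps, assuming the parameterization $\sigma = \log(1 + e^{\rho})$ from the paper and a standard normal prior (the paper itself uses a scale-mixture prior, so treat the KL formula here as an illustrative simplification; names and method layout are mine, not from the code you found):

import torch
import torch.nn as nn
import torch.nn.functional as F

class BayesianLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight_mu = nn.Parameter(torch.Tensor(out_features, in_features).uniform_(-0.2, 0.2))
        self.weight_rho = nn.Parameter(torch.Tensor(out_features, in_features).uniform_(-5, -4))

    def forward(self, x):
        # Step 1: reparameterization trick -- sample eps ~ N(0, 1) and build
        # weight = mu + sigma * eps, so gradients flow back into mu and rho
        sigma = F.softplus(self.weight_rho)   # sigma = log(1 + exp(rho))
        eps = torch.randn_like(sigma)
        weight = self.weight_mu + sigma * eps

        # Step 2: closed-form KL(q(w) || N(0, 1)), stored so the training loop
        # can collect it from every layer and add it to the data-fit loss
        self.kl = (torch.log(1.0 / sigma) + (sigma**2 + self.weight_mu**2) / 2 - 0.5).sum()
        return F.linear(x, weight)

A training step would then combine the two parts, e.g. loss = F.cross_entropy(out, y) + kl_weight * sum(layer.kl for layer in bayesian_layers), and a single loss.backward() gives gradients for every mu and rho.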

Everything else should be handled fine by PyTorch.
