Vectorize scipy.stats.norm.logpdf

I am trying to train a Bayesian NN, and at some point I need to compute the log-likelihood of some data points under a multivariate Gaussian with diagonal covariance and parameters (mu, sigma). I have 2 problems:

  1. I don't know the shape of the values in advance. 'values', 'mu' and 'rho' are guaranteed to have the same size, but they can be either 1D or 2D, which forces me to write an ugly if statement. Ideally I would just iterate over the elements regardless of the tensor's shape.

  2. This is painfully slow. I don't see how to vectorize the logpdf the NumPy way: passing values, mu and sigma directly to norm.logpdf seems to implicitly construct a covariance matrix (which is too big and makes the program crash).
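Because the covariance is diagonal, the joint density factorizes into independent univariate normals, one per element. The total log-likelihood is therefore a plain sum over all elements, whatever the tensor's shape, and no covariance matrix is needed. With $\sigma_i = \log(1 + e^{\rho_i})$ (a softplus, as in the code that follows):

$$
\log p(\mathbf{x} \mid \boldsymbol{\mu}, \boldsymbol{\sigma})
= \sum_i \log \mathcal{N}\!\left(x_i \mid \mu_i, \sigma_i^2\right)
= \sum_i \left[ -\frac{(x_i - \mu_i)^2}{2\sigma_i^2} - \log\!\left(\sqrt{2\pi}\,\sigma_i\right) \right]
$$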

    import numpy as np
    import torch
    from scipy.stats import norm

    ...

    # Inside the model's log-likelihood method:
    mu    = self.mu.detach().numpy()
    sigma = np.log(1 + np.exp(self.rho.detach().numpy()))  # softplus(rho)
    vals  = values.detach().numpy()
    log_likelihood_val = 0
    # Separate loops for the 2D and 1D cases (the "ugly if statement")
    if len(values.size()) == 2:
        for i in range(values.size()[0]):
            for j in range(values.size()[1]):
                log_likelihood_val += norm.logpdf(vals[i, j], loc=mu[i, j], scale=sigma[i, j])
    else:
        for i in range(values.size()[0]):
            log_likelihood_val += norm.logpdf(vals[i], loc=mu[i], scale=sigma[i])
    return torch.tensor(log_likelihood_val)

How should I implement it instead?
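For comparison, here is a minimal loop-free sketch of the SciPy/NumPy route, assuming `vals`, `mu` and `rho` are NumPy arrays of identical shape (`diag_gaussian_log_likelihood` is a hypothetical helper name). `scipy.stats.norm` is a univariate distribution, so `norm.logpdf` is evaluated elementwise under NumPy broadcasting rather than building a covariance matrix. One plausible way to end up with a huge array anyway is a shape mismatch such as `(N,)` against `(N, 1)`, which broadcasts to `(N, N)`, so the sketch checks shapes explicitly.

```python
import numpy as np
from scipy.stats import norm


def diag_gaussian_log_likelihood(vals, mu, rho):
    """Total log-likelihood of `vals` under independent N(mu, sigma^2), sigma = softplus(rho).

    Hypothetical helper: works for arrays of any shape (1D, 2D, ...) as long as
    all three arguments share that shape.
    """
    vals, mu, rho = np.asarray(vals), np.asarray(mu), np.asarray(rho)
    if not (vals.shape == mu.shape == rho.shape):
        # A mismatch such as (N,) vs (N, 1) would silently broadcast to (N, N).
        raise ValueError(f"shape mismatch: {vals.shape}, {mu.shape}, {rho.shape}")
    # log(1 + exp(rho)) computed as logaddexp(0, rho), which does not overflow for large rho.
    sigma = np.logaddexp(0.0, rho)
    # norm is univariate: logpdf is evaluated elementwise, then summed over all elements.
    return norm.logpdf(vals, loc=mu, scale=sigma).sum()


# Usage: the same call handles 2D and 1D inputs, so no if statement is needed.
x = np.array([[0.5, -1.0], [2.0, 0.0]])
print(diag_gaussian_log_likelihood(x, np.zeros_like(x), np.zeros_like(x)))
print(diag_gaussian_log_likelihood(x.ravel(), np.zeros(4), np.zeros(4)))
```

Like the code in the question, this round-trips through NumPy after `.detach()`, so no gradients reach `mu` and `rho`; that is one reason to keep the computation in PyTorch instead.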



I ended up writing my own logpdf in PyTorch, which vectorizes naturally over tensors of any shape and solved both problems at once:

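    import math

    import torch

    def log_likelihood(self, values: torch.Tensor) -> torch.Tensor:
        def logpdf(x, mu, sigma):
            # Elementwise log-density of N(mu, sigma^2); works for tensors of any shape
            return -(((x - mu) / sigma) ** 2) / 2 - torch.log(math.sqrt(2 * math.pi) * sigma)

        # softplus(rho), computed in torch so gradients flow back to rho
        sigma = torch.log(1 + torch.exp(self.rho))
        log_likelihood_val = torch.sum(logpdf(values, self.mu, sigma))
        return log_likelihood_val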
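To sanity-check the formula, here is a minimal standalone sketch that compares it with PyTorch's built-in `torch.distributions.Normal(...).log_prob`, which also evaluates elementwise over tensors of any shape. The free function `log_likelihood(values, mu, rho)` is a hypothetical adaptation of the method above, with `mu` and `rho` passed in explicitly. It also swaps `torch.log(1 + torch.exp(rho))` for `torch.nn.functional.softplus`, which computes the same function but does not overflow for large `rho`.

```python
import math

import torch
import torch.nn.functional as F
from torch.distributions import Normal


def logpdf(x, mu, sigma):
    # Same elementwise formula as in the answer above.
    return -(((x - mu) / sigma) ** 2) / 2 - torch.log(math.sqrt(2 * math.pi) * sigma)


def log_likelihood(values, mu, rho):
    # Hypothetical free-function version of the method, with mu and rho passed in.
    sigma = F.softplus(rho)  # same as log(1 + exp(rho)), but stable for large rho
    return logpdf(values, mu, sigma).sum()


torch.manual_seed(0)
mu = torch.randn(3, 4, requires_grad=True)
rho = torch.randn(3, 4, requires_grad=True)
values = torch.randn(3, 4)

ll = log_likelihood(values, mu, rho)
# Built-in equivalent: Normal.log_prob is elementwise, then summed.
ll_builtin = Normal(mu, F.softplus(rho)).log_prob(values).sum()
ll.backward()  # gradients now reach mu and rho
```

Either form keeps the whole computation in PyTorch, so `backward()` populates gradients for `mu` and `rho`, which the NumPy round-trip in the question cannot do.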

Hope this might help someone.

About

Geeks Mental is a community that publishes articles and tutorials about Web, Android, Data Science, new techniques and Linux security.