Converting a KL divergence from Torch to PyTorch

The following code computes the KL divergence between a Gaussian posterior and a mixture of Gaussian priors; it is part of the model described in this paper. The published code is written in Torch (Lua), using the nngraph package.
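
For reference, the quantity the comment inside KLDivergence encodes is, per dimension, the closed-form KL divergence between two univariate Gaussians:

$$\mathrm{KL}\left(\mathcal{N}(\mu_1,\sigma_1^2)\,\|\,\mathcal{N}(\mu_2,\sigma_2^2)\right) = \frac{1}{2}\left(\log\sigma_2^2 - \log\sigma_1^2 + \frac{\sigma_1^2 + (\mu_1-\mu_2)^2}{\sigma_2^2} - 1\right)$$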

require 'nn'
require 'nngraph'

function KLDivergence(D, M)
    -- KL  = 1/2(  logvar2 - logvar1 + (var1 + (m1-m2)^2)/var2  - 1 )
    local mean1_in = - nn.Identity() -- [NxD]
    local logVar1_in = - nn.Identity() -- [NxD]
    local mean2_in = - nn.Identity() -- [(MxN)xD]
    local logVar2_in = - nn.Identity() -- [(MxN)xD]

    local mean1 = mean1_in - nn.Replicate(M)
    local logVar1 = logVar1_in - nn.Replicate(M)
    local mean2 = mean2_in - nn.View(M, -1, D)
    local logVar2 = logVar2_in - nn.View(M, -1, D)

    local var1 = logVar1 - nn.Exp()
    local var2 = logVar2 - nn.Exp()
    local dm2 = {mean1, mean2}
                - nn.CSubTable()
                - nn.Power(2)
    local dm2_v1 = {dm2, var1} - nn.CAddTable()
    local dm2_v1_v2 = {dm2_v1, var2} - nn.CDivTable() - nn.AddConstant(-1)
    local total = {dm2_v1_v2, logVar2} - nn.CAddTable()
    local totals = {total, logVar1}
                    - nn.CSubTable()
                    - nn.MulConstant(0.5) -- [MxNxD]
                    - nn.Sum(1)           -- sum over the M dimension -> [NxD]
                    - nn.MulConstant(1/M) -- average over the M prior samples
                    - nn.View(-1, D, 1)   -- [NxDx1]

    return nn.gModule({mean1_in, logVar1_in, mean2_in, logVar2_in}, {totals})

end
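
My tentative reading of KLDivergence: it replicates the posterior parameters M times, reshapes the prior parameters from [(MxN)xD] to [MxNxD], evaluates the formula above elementwise, and averages over the M prior samples. Below is my rough PyTorch sketch of it; the function name is mine, and I am assuming plain broadcasting can stand in for nn.Replicate(M), so please correct me if I have misread the graph:

import torch

def kl_divergence(mean1, logvar1, mean2, logvar2, M, D):
    # mean1, logvar1: [N, D]     posterior parameters
    # mean2, logvar2: [(M*N), D] prior parameters, M samples per data point
    mean2 = mean2.view(M, -1, D)      # [M, N, D], as nn.View(M, -1, D)
    logvar2 = logvar2.view(M, -1, D)  # [M, N, D]
    var1 = logvar1.exp()              # [N, D]; broadcasts against [M, N, D]
    var2 = logvar2.exp()
    dm2 = (mean1 - mean2) ** 2        # broadcasting replaces nn.Replicate(M)
    kl = 0.5 * (logvar2 - logvar1 + (var1 + dm2) / var2 - 1.0)  # [M, N, D]
    kl = kl.mean(dim=0)               # nn.Sum(1) + nn.MulConstant(1/M) -> [N, D]
    return kl.view(-1, D, 1)          # [N, D, 1], as nn.View(-1, D, 1)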

function KL_Table(K, D, M)
    local KL_table = nn.ConcatTable()
    for k=1, K do
        local mean   = - nn.Identity() -- [NxD]
        local logVar = - nn.Identity() -- [NxD]
        local mean_Mixture = - nn.Identity() -- {[NxD]}k
        local logVar_Mixture = - nn.Identity() -- {[NxD]}k

        local meanK = mean_Mixture - nn.SelectTable(k)
        local logVarK = logVar_Mixture - nn.SelectTable(k)
        local KL = {mean, logVar, meanK, logVarK} - KLDivergence(D, M)
        local KL_module = nn.gModule({mean, logVar, mean_Mixture, logVar_Mixture}, {KL})
        KL_table:add(KL_module)
    end

    return KL_table
end
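
If I follow KL_Table correctly, it just builds K copies of the same KL graph, each one selecting the k-th component of the mixture tables. In PyTorch I believe this collapses to a loop; kl_table is my own name, and I have folded in the nn.JoinTable(3) that the caller applies afterwards:

def kl_table(mean, logvar, mean_mixture, logvar_mixture, K, D, M):
    # mean_mixture, logvar_mixture: lists of K tensors, each [(M*N), D]
    kls = [kl_divergence(mean, logvar, mean_mixture[k], logvar_mixture[k], M, D)
           for k in range(K)]     # K tensors of shape [N, D, 1]
    return torch.cat(kls, dim=2)  # [N, D, K], i.e. nn.JoinTable(3)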

function ExpectedKLDivergence(K, D, M)

    local q_z    = - nn.Identity() -- [NxK]
    local mean   = - nn.Identity() -- [NxD]
    local logVar = - nn.Identity() -- [NxD]
    local mean_Mixture = - nn.Identity() -- {[NxD]}k
    local logVar_Mixture = - nn.Identity() -- {[NxD]}k

    local KL_List = {mean, logVar, mean_Mixture, logVar_Mixture}
                - KL_Table(K, D, M)  -- {[NxDx1]}k
                - nn.JoinTable(3) -- [NxDxK]

    local weighted_KL = {KL_List, q_z}
                    - nn.MV()  -- [NxDxK]x[NxK] = [NxD]
    return  nn.gModule({q_z, mean,logVar, mean_Mixture, logVar_Mixture},{weighted_KL})
end
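
Finally, ExpectedKLDivergence appears to weight the per-component KL by the responsibilities q_z through a batched matrix-vector product (nn.MV). My guess at the PyTorch equivalent, reusing the two sketches above:

def expected_kl_divergence(q_z, mean, logvar, mean_mixture, logvar_mixture, K, D, M):
    # q_z: [N, K] mixture weights/responsibilities
    kl_list = kl_table(mean, logvar, mean_mixture, logvar_mixture, K, D, M)  # [N, D, K]
    # batched matrix-vector product, as nn.MV: [N, D, K] x [N, K] -> [N, D]
    return torch.bmm(kl_list, q_z.unsqueeze(2)).squeeze(2)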

I would like to rewrite this code in PyTorch, but I don't fully understand everything the Torch graph is doing, so I am not confident the sketches above are correct. I would appreciate it if someone could explain what each part does and point out where my translation goes wrong.

Tags: vae, pytorch, torch
