How to represent a "switch"-like behavior in a neural network?

I have three input variables $x_1$, $x_2$ and $d$, where $x_1$ and $x_2$ are numerical variables and $d$ is a dummy variable that takes the value 1 or 2. How can I represent the black-box part of a neural network so that when $d=1$, $x_1$ and $x_2$ are sent to layer $T_1$ for transformation, and when $d=2$, they are sent to layer $T_2$ instead?

Topic representation neural-network

Category Data Science


It turns out PyTorch has fairly native support for this kind of conditional branching. Here is an example:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # one transformation layer per value of the dummy variable d
        self.transformation1 = nn.Linear(2, 10)
        self.transformation2 = nn.Linear(2, 10)
        # shared output layer applied to both branches
        self.common_layer = nn.Linear(10, 1)

    def forward(self, x):
        # each row of x packs [x1, x2, d]; split off the dummy column
        d = x[:, 2]
        x = x[:, :2]
        # boolean masks selecting the rows for each branch
        idxs1 = d == 1
        idxs2 = d == 2
        # route each subset of rows through its own transformation
        x1 = F.relu(self.transformation1(x[idxs1]))
        x2 = F.relu(self.transformation2(x[idxs2]))
        x1 = self.common_layer(x1)
        x2 = self.common_layer(x2)
        # scatter the branch outputs back into their original row positions
        logits = torch.zeros(d.shape[0], 1)
        logits[idxs1] = x1
        logits[idxs2] = x2
        return torch.sigmoid(logits)[:, 0]

This model sends the rows where $d=1$ through the layer transformation1 and the rows where $d=2$ through the layer transformation2. The output of whichever branch a row took is then passed to a common layer that produces a probability score for a binary classification task.
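The routing itself relies only on boolean-mask indexing and in-place assignment, which can be seen in isolation without the model. Below is a minimal sketch of that mechanic; the batch values and the stand-in "transformations" (multiplying by 10 or 100) are made up purely for illustration:

```python
import torch

# a batch of 4 rows; columns are x1, x2, d
batch = torch.tensor([[0.5, 1.0, 1.0],
                      [0.1, 2.0, 2.0],
                      [0.3, 0.7, 1.0],
                      [0.9, 0.2, 2.0]])
d = batch[:, 2]
x = batch[:, :2]

idxs1 = d == 1  # rows routed to the first branch
idxs2 = d == 2  # rows routed to the second branch

# stand-in transformations for the two branches
out1 = x[idxs1] * 10
out2 = x[idxs2] * 100

# scatter each branch's output back to its original row positions
out = torch.zeros_like(x)
out[idxs1] = out1
out[idxs2] = out2
```

Rows 0 and 2 end up scaled by 10 and rows 1 and 3 by 100, in their original order. Inside the model the same pattern applies, with the learned layers in place of the multiplications; gradients flow through the masked assignment, so both branches train on exactly the rows they saw.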
