What input for a combined model (3 nets)

I have this architecture, made of 3 NNs. In code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class VGGBlock(nn.Module):
    def __init__(self, in_channels, out_channels,batch_norm=False):

        super(VGGBlock,self).__init__()

        conv2_params = {'kernel_size': (3, 3),
                        'stride'     : (1, 1),
                        'padding'   : 1
                        }

        self._batch_norm = batch_norm

        # nn.Identity() is a built-in no-op module; unlike a bare lambda it
        # shows up cleanly in the module tree and survives pickling/scripting
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **conv2_params)
        self.bn1 = nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity()

        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, **conv2_params)
        self.bn2 = nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity()

        self.max_pooling = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

    @property
    def batch_norm(self):
        return self._batch_norm

    def forward(self,x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)

        x = self.max_pooling(x)

        return x
class VGG16(nn.Module):
  # feature extractor: four VGG blocks followed by a flatten
  # (the unused num_classes parameter is dropped; classification
  # happens in VGG16Classifier below)

  def __init__(self, input_size, batch_norm=False):
    super(VGG16, self).__init__()

    self.in_channels, self.in_width, self.in_height = input_size

    self.block_1 = VGGBlock(self.in_channels,64,batch_norm=batch_norm)
    self.block_2 = VGGBlock(64, 128,batch_norm=batch_norm)
    self.block_3 = VGGBlock(128, 256,batch_norm=batch_norm)
    self.block_4 = VGGBlock(256,512,batch_norm=batch_norm)

  @property
  def input_size(self):
      return self.in_channels,self.in_width,self.in_height

  def forward(self, x):

    x = self.block_1(x)
    x = self.block_2(x)
    x = self.block_3(x)
    x = self.block_4(x)
    x = torch.flatten(x,1)

    return x
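# Note on shapes: with the (1, 32, 32) inputs used below, each VGGBlock
# halves the spatial size (32 -> 16 -> 8 -> 4 -> 2), so torch.flatten
# produces 512 * 2 * 2 = 2048 features, exactly the in_features of the
# classifier's first nn.Linear(2048, 2048) in VGG16Classifier.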
class VGG16Classifier(nn.Module):

  def __init__(self, vgg_a, vgg_b, vgg_star, num_classes=1, classifier=None):
    super(VGG16Classifier, self).__init__()

    # the three feature extractors are passed in, so the custom loss can
    # refer to the very same modules that are being trained
    self._vgg_a = vgg_a
    self._vgg_b = vgg_b
    self._vgg_star = vgg_star
    self.classifier = classifier

    if (self.classifier is None):
        self.classifier = nn.Sequential(
          nn.Linear(2048, 2048),
          nn.ReLU(True),
          nn.Dropout(p=0.5),
          nn.Linear(2048, 512),
          nn.ReLU(True),
          nn.Dropout(p=0.5),
          nn.Linear(512, num_classes)
        )

  def forward(self, x1, x2, x3):
      op1 = self._vgg_a(x1)
      op2 = self._vgg_b(x2)
      op3 = self._vgg_star(x3)

      # the same classifier head is shared across the three branches
      x1 = self.classifier(op1)
      x2 = self.classifier(op2)
      x3 = self.classifier(op3)

      return x1, x2, x3
model1 = VGG16((1,32,32),batch_norm=True)
model2 = VGG16((1,32,32),batch_norm=True)
model_star = VGG16((1,32,32),batch_norm=True)
model_combo = VGG16Classifier(model1,model2,model_star)
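
As a quick smoke test of the combined forward pass (assuming batches of 32x32 grayscale images):

x = torch.randn(8, 1, 32, 32)
o1, o2, o3 = model_combo(x, x, x)
print(o1.shape)  # torch.Size([8, 1]) with the default num_classes=1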

I have to implement this custom loss. In code:

class CombinedLoss(nn.Module):
    def __init__(self, loss_a, loss_b, loss_star, net_a, net_b, net_star, _lambda=1.0):
        super().__init__()
        self.loss_a = loss_a
        self.loss_b = loss_b
        self.loss_star = loss_star

        # keep references to the three networks so the regularizer can
        # read their parameters instead of relying on globals
        self.net_a = net_a
        self.net_b = net_b
        self.net_star = net_star

        self.register_buffer('_lambda', torch.tensor(float(_lambda), dtype=torch.float32))

    def forward(self, y_hat, y):
        # Regularization term (assumed intent, since a VGG16 has no single
        # .weight tensor and torch.cdist takes two tensors): lambda times the
        # squared L2 distance between the star network's parameters and the
        # sum of the other two networks' parameters.
        reg = sum(torch.sum((p_star - (p_a + p_b)) ** 2)
                  for p_star, p_a, p_b in zip(self.net_star.parameters(),
                                              self.net_a.parameters(),
                                              self.net_b.parameters()))

        return (self.loss_a(y_hat[0], y[0]) +
                self.loss_b(y_hat[1], y[1]) +
                self.loss_star(y_hat[2], y[2]) +
                self._lambda * reg)

Maybe this loss is not well written (I accept advice :D). Now, my question: given that I have a model made of 3 networks, and a loss that takes 3 sets of labels (y[0], y[1] and y[2]), do I have to split my dataset into 3 parts? For example, if I use MNIST, do I have to split it into MNIST_A, MNIST_B and MNIST_STAR, create a loader for each of these 3 datasets, and pass them to the model at training time?


No need to split the dataset. You can do this at runtime with a single data loader. Following is the code, with explanations included in the comments:

# imports
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets

# Download MNIST once; the raw tensors are then read straight from the
# dataset object (training_data.data / training_data.targets) rather than
# from the .pt files on disk, whose layout changes across torchvision versions.
training_data = datasets.MNIST(
            root="data",
            train=True,
            download=True
        )

# A single custom dataset
class MnistImageDataset(Dataset):
    def __init__(self, mnist):
        # raw uint8 images of shape (N, 28, 28) and integer labels of shape (N,)
        self.m_images = mnist.data
        self.m_labels = mnist.targets

        # number of images each of the three models is trained on,
        # e.g. 60000 MNIST images -> 20000 per model
        self.no_labels = len(self.m_labels) // 3

    def __len__(self):
        # each index yields one sample for each of the three models,
        # so the usable length is a third of the full dataset
        return self.no_labels

    def __getitem__(self, idx):
        # create dictionaries
        images = {}
        labels = {}

        # One image from each third of the dataset, one per model.
        # uint8 [0, 255] is converted to float [0, 1], and the 28x28 MNIST
        # digits are padded to the 32x32 input size the VGG16s were built
        # for (otherwise flatten would give 512 features, not 2048).
        def prep(img):
            return F.pad(img.float() / 255.0, (2, 2, 2, 2))

        images['model_a'] = prep(self.m_images[idx])
        images['model_b'] = prep(self.m_images[idx + self.no_labels])
        images['model_star'] = prep(self.m_images[idx + self.no_labels * 2])

        # matching labels from each third
        labels['model_a'] = self.m_labels[idx]
        labels['model_b'] = self.m_labels[idx + self.no_labels]
        labels['model_star'] = self.m_labels[idx + self.no_labels * 2]

        return images, labels

# create dataset and wrap into dataloader
mnist = MnistImageDataset(training_data)
train_dataloader = DataLoader(mnist, batch_size=64)

for images, labels in train_dataloader:
    # sample inference: unsqueeze adds the channel dim, (B, 32, 32) -> (B, 1, 32, 32)
    op1, op2, op3 = model_combo(images['model_a'].unsqueeze(1),
                                images['model_b'].unsqueeze(1),
                                images['model_star'].unsqueeze(1))
    # next comes your loss and the rest of the training pipeline
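
To close the loop, here is a minimal sketch of a full training step with the CombinedLoss from the question. The cross-entropy criteria, the Adam settings and num_classes=10 (for the 10 MNIST digits) are assumptions, not something fixed by the question:

# hypothetical wiring; criteria and hyperparameters are assumptions
model_combo = VGG16Classifier(model1, model2, model_star, num_classes=10)
criterion = CombinedLoss(nn.CrossEntropyLoss(), nn.CrossEntropyLoss(), nn.CrossEntropyLoss(),
                         model1, model2, model_star, _lambda=1.0)
optimizer = torch.optim.Adam(model_combo.parameters(), lr=1e-4)

for images, labels in train_dataloader:
    optimizer.zero_grad()
    op1, op2, op3 = model_combo(images['model_a'].unsqueeze(1),
                                images['model_b'].unsqueeze(1),
                                images['model_star'].unsqueeze(1))
    loss = criterion((op1, op2, op3),
                     (labels['model_a'], labels['model_b'], labels['model_star']))
    loss.backward()
    optimizer.step()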
