 

PyTorch: multiple branches of a model

[Image: diagram of the model]

Hi, I'm trying to build this model using PyTorch.

Each input consists of 20 images of size 28 x 28, which are C1 ~ Cp in the image. Each image goes through a CNN of the same structure, and their outputs are eventually concatenated.
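A minimal sketch of one way to lay out such an input, assuming each sample's 20 images are stacked into a single tensor of shape (N, 20, 28, 28); the batch size of 4 and the variable names are illustrative only. torch.split with a chunk size of 1 along dim 1 turns that tensor into 20 tensors of shape (N, 1, 28, 28), one per CNN branch.

```python
import torch

# Hypothetical batch: 4 samples, each made of 20 single-channel 28 x 28 images
batch = torch.randn(4, 20, 28, 28)

# Split along dim 1 into chunks of size 1; each chunk keeps a channel
# dimension of 1, which is what nn.Conv2d(1, ...) expects as input.
per_image = list(torch.split(batch, 1, dim=1))

print(len(per_image))      # 20
print(per_image[0].shape)  # torch.Size([4, 1, 28, 28])
```

Keeping the channel dimension (rather than indexing with batch[:, i], which would drop it) avoids an extra unsqueeze before each branch.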

I'm currently struggling with feeding each of the inputs to its respective CNN model. Each model in the first box, with three convolutional layers, looks like this in code, but I'm not sure how to feed 20 different inputs to separate models of the same structure and then concatenate their outputs.

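        # padding=1 with kernel_size=3 keeps every feature map at 28 x 28
        self.features = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(10, 14, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(14, 18, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Flatten(),               # (N, 18, 28, 28) -> (N, 18 * 28 * 28)
            nn.Linear(28*28*18, 256)    # one 256-dim feature vector per image
        )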
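As a quick sanity check of the layer sizes above, here is a minimal sketch that builds one branch on its own and runs a dummy batch through it; the batch size of 4 is an arbitrary assumption. Because each convolution keeps the 28 x 28 spatial size, Flatten produces 18 * 28 * 28 features, matching the Linear layer's input size.

```python
import torch
import torch.nn as nn

# One branch, built from the same layers as the snippet above
features = nn.Sequential(
    nn.Conv2d(1, 10, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(10, 14, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(14, 18, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(28 * 28 * 18, 256),
)

x = torch.randn(4, 1, 28, 28)  # hypothetical batch of 4 single-channel images
out = features(x)
print(out.shape)  # torch.Size([4, 256])
```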

I've tried passing a list of inputs to the forward function, but it ended up with an error. I'll be more than happy to explain further if anything is unclear.

asked Oct 17 '25 at 11:10 by Burger



1 Answer

Simply define forward as taking a list of tensors as input, then process each input with its corresponding CNN. In the example snippet, the CNNs share the same structure but don't share parameters, which is what I assume you need. You'll need to fill in the dots (...) according to your specifications.

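import torch
from torch import Tensor

class MyModel(torch.nn.Module):
    def __init__(self, ...):
        super().__init__()
        ...
        # 20 branches with the same structure but independent parameters
        self.cnns = torch.nn.ModuleList([torch.nn.Sequential(...) for _ in range(20)])

    def forward(self, xs: list[Tensor]):
        # apply the i-th CNN to the i-th input, then concatenate the outputs
        return torch.cat([cnn(x) for x, cnn in zip(xs, self.cnns)], dim=...)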
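For reference, a minimal runnable sketch of the snippet above with the dots filled in under stated assumptions: each branch uses the layers from the question, each branch output is (N, 256), so concatenation is done on dim=1. The helper make_branch, the num_branches argument, and the length check are hypothetical additions for illustration, not part of the original answer.

```python
import torch
import torch.nn as nn
from torch import Tensor


def make_branch() -> nn.Sequential:
    # Hypothetical helper: same layers as the CNN in the question.
    # Each call creates new layers, so every branch has its own parameters.
    return nn.Sequential(
        nn.Conv2d(1, 10, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(10, 14, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(14, 18, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(28 * 28 * 18, 256),
    )


class MyModel(nn.Module):
    def __init__(self, num_branches: int = 20):
        super().__init__()
        # ModuleList registers every branch, so .parameters() and .to(device) see them
        self.cnns = nn.ModuleList([make_branch() for _ in range(num_branches)])

    def forward(self, xs: list[Tensor]) -> Tensor:
        # zip() would silently drop extra inputs or branches, so check the count
        if len(xs) != len(self.cnns):
            raise ValueError(f"expected {len(self.cnns)} inputs, got {len(xs)}")
        # Each branch output is (N, 256); dim=1 gives (N, 256 * num_branches)
        return torch.cat([cnn(x) for x, cnn in zip(xs, self.cnns)], dim=1)


model = MyModel()
xs = [torch.randn(4, 1, 28, 28) for _ in range(20)]  # hypothetical batch of 4
out = model(xs)
print(out.shape)  # torch.Size([4, 5120])
```

Note that each branch's Linear layer alone holds 14112 x 256 (about 3.6 million) weights, so 20 independent branches come to roughly 72 million parameters. If the branches should share weights instead, a single branch applied to the images folded into the batch dimension, i.e. reshaping (N, 20, 28, 28) to (N * 20, 1, 28, 28), would do the same job with one set of parameters.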
answered Oct 20 '25 at 23:10 by KonstantinosKokos
