
How to modify the "in_channels" of the first CNN layer in a timm model?

Tags:

python

pytorch

Hi everyone. I want to train a CV model from the timm library on my dataset. Because my input data has shape (batch_size, 15, 224, 224), I need to change the "in_channels" of the first convolutional layer of different CV models. I have tried several methods, but none of them work. Could you help me solve this problem? Thanks!

import torch
import torch.nn as nn

import timm

class FrequencyModel(nn.Module):

    def __init__(
        self, 
        in_channels = 6, 
        output = 9, 
        model_name = 'resnet200d', 
        pretrained = False
        ):

        super(FrequencyModel, self).__init__()
        
        self.in_channels = in_channels
        self.output = output
        self.model_name = model_name
        self.pretrained = pretrained

        self.m = timm.create_model(self.model_name, pretrained=self.pretrained, num_classes=output)

        for layer in self.m.modules():
            if isinstance(layer, nn.Conv2d):
                layer.in_channels = self.in_channels
                break

    def forward(self, x):

        out = self.m(x)

        return out

if __name__ == "__main__":
    
    x = torch.randn((8, 15, 224, 224))

    model = FrequencyModel(
        in_channels = 15, 
        output = 9, 
        model_name = 'resnet200d', 
        pretrained = False
    )
    print(model)
    print(model(x).shape)

The error is:

RuntimeError: Given groups=1, weight of size [32, 3, 3, 3], expected input[8, 15, 224, 224] to have 3 channels, but got 15 channels instead

I would like to be able to test different CV models easily, without adjusting each one by hand.
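The error arises because assigning to `layer.in_channels` only changes a stored attribute. `nn.Conv2d` computes its output from `layer.weight`, whose shape is fixed when the layer is built, so the layer still expects 3 channels. A minimal sketch reproducing this with a standalone conv shaped like the stem conv in the error message (weight `[32, 3, 3, 3]`); the stride and padding values are assumptions for illustration only:

```python
import torch
import torch.nn as nn

# A stand-alone conv shaped like the stem conv in the error message
# (weight [32, 3, 3, 3]); stride and padding are illustrative assumptions.
conv = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)

conv.in_channels = 15            # only the stored attribute changes...
print(conv.in_channels)          # 15
print(tuple(conv.weight.shape))  # (32, 3, 3, 3): the weight still expects 3 channels

try:
    conv(torch.randn(8, 15, 224, 224))
except RuntimeError as err:
    # Same kind of channel-mismatch error as in the question.
    print("RuntimeError:", err)
```

To actually accept 15 channels, the layer has to be rebuilt with a new weight tensor, or the model has to be created with the right channel count from the start.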

Qiang asked Oct 27 '25 02:10


1 Answer

timm has a built-in option for changing the number of input channels of any model it creates. Other approaches, such as replacing the first convolution by hand, can still be useful for models from other libraries.
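For models without such an option, one general approach is to swap the first `nn.Conv2d` for a new one with the desired number of input channels. The helper below, `replace_first_conv`, is hypothetical (not part of timm or PyTorch). It assumes the first `nn.Conv2d` in `named_modules()` order is the layer that receives the raw input and that it uses `groups=1`; this holds for many CNNs but should be checked per model. The new weights are initialised by tiling the old filters across channels and rescaling, a scheme similar to the one timm uses when adapting pretrained stems:

```python
import torch
import torch.nn as nn


def replace_first_conv(model: nn.Module, in_channels: int) -> nn.Module:
    """Hypothetical helper: swap the first nn.Conv2d for one taking `in_channels` inputs.

    Assumes the first Conv2d in registration order sees the raw input and has groups=1.
    """
    for name, module in model.named_modules():
        if not isinstance(module, nn.Conv2d):
            continue
        if module.groups != 1:
            raise ValueError(f"first conv {name!r} uses groups={module.groups}")
        old_w = module.weight  # shape [out_channels, old_in, kH, kW]
        new_conv = nn.Conv2d(
            in_channels,
            module.out_channels,
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            bias=module.bias is not None,
            padding_mode=module.padding_mode,
            device=old_w.device,
            dtype=old_w.dtype,
        )
        with torch.no_grad():
            # Tile the old filters along the channel axis, crop to in_channels,
            # and rescale by old_in / in_channels to keep activations at a similar scale.
            old_in = old_w.shape[1]
            reps = -(-in_channels // old_in)  # ceiling division
            new_w = old_w.repeat(1, reps, 1, 1)[:, :in_channels]
            new_conv.weight.copy_(new_w * (old_in / in_channels))
            if module.bias is not None:
                new_conv.bias.copy_(module.bias)
        if not name:  # the model itself is a single Conv2d
            return new_conv
        # Replace the attribute on the parent module, e.g. "stem.0" -> parent "stem".
        parent_name, _, child_name = name.rpartition(".")
        setattr(model.get_submodule(parent_name), child_name, new_conv)
        return model
    raise ValueError("model contains no nn.Conv2d")


# Usage on a toy network (hypothetical architecture, for illustration only).
net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 9),
)
net = replace_first_conv(net, 15)
print(net(torch.randn(2, 15, 32, 32)).shape)  # torch.Size([2, 9])
```

Rebuilding the layer, rather than editing attributes, is what makes the weight tensor itself match the new channel count.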

For timm, pass the in_chans argument to timm.create_model:

import timm

# Without pretrained weights:
resnet = timm.create_model("resnet200d", pretrained=False, in_chans=15)

# With pretrained weights:
resnet = timm.create_model("resnet200d", pretrained=True, in_chans=15)

thanatoz answered Oct 29 '25 18:10