I followed this tutorial to train a PyTorch model for instance segmentation: https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html
I would now like to train a model on entirely different data and classes, totally unrelated to COCO. What changes do I need to make to retrain the model? From my reading, I'm guessing that besides setting the correct number of classes I just need to change this line:
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
to
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False)
But I notice there are other parameters: pretrained_backbone=True, trainable_backbone_layers=None. Should they be changed too?
The function signature is
torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=3, **kwargs)
Setting pretrained=False tells PyTorch not to download the model pre-trained on COCO train2017. That is what you want, since you're interested in training on your own data.
Usually, this is enough if you want to train on a different dataset.
When you set pretrained=False, PyTorch still downloads a ResNet50 backbone pretrained on ImageNet, because pretrained_backbone=True by default. By default it also freezes the first two blocks, named conv1 and layer1. This follows the Faster R-CNN paper, which froze the initial layers of the pretrained backbone.
(Just print the model to check its structure.) The layers that stay trainable are selected like this:
layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_layers]
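To make the slicing rule above concrete, here is a minimal pure-Python sketch of how the trainable-layer count maps to layer names. The helper names `layers_to_train` and `is_frozen` and the example parameter names are hypothetical, chosen for illustration only; the real freezing logic lives inside torchvision and may differ between versions.

```python
# Backbone stages ordered from last (closest to the heads) to first.
BACKBONE_STAGES = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1']


def layers_to_train(trainable_layers):
    """Return the stage names left trainable for a given count (0 to 5)."""
    if not 0 <= trainable_layers <= 5:
        raise ValueError("trainable_layers must be between 0 and 5")
    return BACKBONE_STAGES[:trainable_layers]


def is_frozen(param_name, trainable_layers):
    """A parameter is frozen when its name starts with no trainable stage."""
    return not any(param_name.startswith(stage)
                   for stage in layers_to_train(trainable_layers))


# Illustrative parameter names in the style of a ResNet backbone (assumed).
example_names = [
    'conv1.weight',
    'layer1.0.conv1.weight',
    'layer2.0.conv1.weight',
    'layer4.2.conv3.weight',
]

for count in (3, 5):
    frozen = [name for name in example_names if is_frozen(name, count)]
    print(count, layers_to_train(count), frozen)
```

With a count of 3, only names under conv1 and layer1 are reported as frozen; with 5, nothing is. On a real model you could probably feed in the names from model.backbone.body.named_parameters() instead of the made-up list.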
Now, if you don't want even the first two layers frozen, you can set trainable_backbone_layers=5, which trains the entire ResNet backbone. This is done automatically when both pretrained and pretrained_backbone are False, in which case the backbone is trained from scratch.
Check PR#2160.
Going by the maskrcnn_resnet50_fpn documentation, to train from scratch use:
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False, trainable_backbone_layers=5, num_classes=your_num_classes)
or:
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False, num_classes=your_num_classes)
The second form works because, in the source code of maskrcnn_resnet50_fpn:
if not (pretrained or pretrained_backbone):
    trainable_backbone_layers = 5
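As a rough illustration of that check, the sketch below mirrors the quoted condition in a standalone function and prints the resulting layer count for each flag combination. The function name `effective_trainable_layers` is hypothetical, and the default of 3 is taken from the signature quoted earlier in this thread; it is not torchvision code and other versions may behave differently.

```python
def effective_trainable_layers(pretrained, pretrained_backbone,
                               trainable_backbone_layers=3):
    """Mirror the quoted check: with no pretrained weights at all,
    every backbone layer is trained."""
    if not (pretrained or pretrained_backbone):
        trainable_backbone_layers = 5
    return trainable_backbone_layers


# Show the resulting count for every combination of the two flags.
for pretrained in (True, False):
    for pretrained_backbone in (True, False):
        count = effective_trainable_layers(pretrained, pretrained_backbone)
        print(f"pretrained={pretrained}, "
              f"pretrained_backbone={pretrained_backbone} -> {count}")
```

Only the combination where both flags are False overrides the requested value, which is why passing trainable_backbone_layers=5 explicitly is redundant in that case.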