MMDetection loading from own training checkpoint for inference produces garbage detections

I've trained a very simple model using the MMDetection Colab tutorial. Verifying the result using:

img = mmcv.imread('/content/mmdetection/20210301_145246_123456.jpg')
img = cv2.resize(img, (0,0), fx=0.25, fy=0.25)

model.cfg = cfg
result = inference_detector(model, img)
show_result_pyplot(model, img, result)

confirms that it's working great.

I then follow the same steps as for training, except that I load my own training checkpoint and skip the training itself. Running the verification snippet above then produces garbage results.

Here's that in code:

from mmcv import Config
cfg = Config.fromfile('configs/faster_rcnn/faster_rcnn_r50_caffe_fpn_mstrain_1x_coco.py')

from mmdet.apis import set_random_seed

# Modify dataset type and path
cfg.dataset_type = 'SamplesDataset'
cfg.data_root = 'samples_dataset/'

cfg.data.test.type = 'SamplesDataset'
cfg.data.test.data_root = 'samples_dataset/'
cfg.data.test.ann_file = 'train.txt'
cfg.data.test.img_prefix = 'o2h'

cfg.data.train.type = 'SamplesDataset'
cfg.data.train.data_root = 'samples_dataset/'
cfg.data.train.ann_file = 'train.txt'
cfg.data.train.img_prefix = 'o2h'

cfg.data.val.type = 'SamplesDataset'
cfg.data.val.data_root = 'samples_dataset/'
cfg.data.val.ann_file = 'val.txt'
cfg.data.val.img_prefix = 'o2h'

# modify num classes of the model in box head
cfg.model.roi_head.bbox_head.num_classes = 1
# We can still use the pre-trained Mask RCNN model though we do not need to
# use the mask branch
# cfg.load_from = 'checkpoints/mask_rcnn_r50_caffe_fpn_mstrain-poly_3x_coco_bbox_mAP-0.408__segm_mAP-0.37_20200504_163245-42aa3d00.pth'
cfg.load_from = './experiments/epoch_1.pth'

# Set up working dir to save files and logs.
cfg.work_dir = './experiments'

# The original learning rate (LR) is set for 8-GPU training.
# We divide it by 8 since we only use one GPU.
cfg.optimizer.lr = 0.02 / 8
cfg.lr_config.warmup = None
cfg.log_config.interval = 10
cfg.runner = dict(type='EpochBasedRunner', max_epochs=1)
cfg.total_epochs = 1

# Change the evaluation metric since we use customized dataset.
cfg.evaluation.metric = 'mAP'
# We can set the evaluation interval to reduce the evaluation times
# cfg.evaluation.interval = 12
# We can set the checkpoint saving interval to reduce the storage cost
cfg.checkpoint_config.interval = 1

# Set seed thus the results are more reproducible
cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)


# We can initialize the logger for training and have a look
# at the final config used for training
# print(f'Config:\n{cfg.pretty_text}')

from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.apis import train_detector

# Build dataset
# datasets = [build_dataset(cfg.data.train)]

# Build the detector
model = build_detector(cfg.model)
# Add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES

# Create work_dir
# mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
# train_detector(model, datasets, cfg, distributed=False, validate=True)

Obviously, I wouldn't normally do all that just for validating my model, but this is one of many debugging steps for me, as my goal is to download and run the model locally. This is what I'm trying to do locally:

import sys
import glob
import time

sys.path.insert(0, '../mmdetection')
from mmdet.apis import init_detector, inference_detector, show_result_pyplot
from mmdet.models import build_detector
import mmcv
import numpy as np

file_paths = glob.glob('samples/o2h/*.jpg')

cfg = mmcv.Config.fromfile('../mmdetection/configs/faster_rcnn/faster_rcnn_r50_caffe_fpn_mstrain_1x_coco.py')
cfg.model.roi_head.bbox_head.num_classes = 1
cfg.load_from = 'models/mmdet_faster_rcnn_r50_caffe_fpn_mstrain_1x_coco.pth' # my own checkpoint
model = build_detector(cfg.model)
model.CLASSES = ('hash',)
model.cfg = cfg

file_path = np.random.choice(file_paths)
print(file_path)

start = time.time()
result = inference_detector(model, file_path)
print(f"Time taken for inference: {time.time() - start:.2f}s")
show_result_pyplot(model, file_path, result)

Alexander Soare asked Nov 05 '25 08:11


1 Answer

One of the mistakes in your code is that you have not updated num_classes for mask_head (this matters when the config has a mask branch; a plain Faster R-CNN config only has bbox_head).

The aim is to use the same config for testing/validation as was used for training. If you trained the model with num_classes set to 1 for bbox_head and mask_head, but run validation/testing with the default of 80, the mismatch leads to garbage detections and segmentations.
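One way to spot such a mismatch before running inference is to compare the class count implied by the checkpoint with the one in the config. The following is a minimal sketch, not part of MMDetection: the helper names are hypothetical, and it assumes the box classifier is stored under the key roi_head.bbox_head.fc_cls.weight with one row per class plus one background row, which may not hold for every architecture.

```python
from typing import Dict, Tuple

# Assumed parameter name of the box classifier in a two-stage detector
# checkpoint; other architectures may use a different key.
CLS_KEY = "roi_head.bbox_head.fc_cls.weight"


def classes_in_checkpoint(shapes: Dict[str, Tuple[int, ...]],
                          cls_key: str = CLS_KEY) -> int:
    """Infer the class count from the classifier weight's shape.

    Assumes one output row per class plus one extra background row.
    """
    return shapes[cls_key][0] - 1


def check_num_classes(shapes: Dict[str, Tuple[int, ...]],
                      config_num_classes: int,
                      cls_key: str = CLS_KEY) -> bool:
    """Return True if the checkpoint and the config agree on num_classes."""
    trained = classes_in_checkpoint(shapes, cls_key)
    if trained != config_num_classes:
        print(f"Mismatch: checkpoint has {trained} class(es), "
              f"config expects {config_num_classes}")
        return False
    return True


# Hypothetical shapes for a 1-class model. With PyTorch they could be
# collected from a checkpoint file roughly like this (assumption):
#   state = torch.load(path, map_location="cpu")["state_dict"]
#   shapes = {k: tuple(v.shape) for k, v in state.items()}
shapes = {CLS_KEY: (2, 1024)}
print(check_num_classes(shapes, 1))
print(check_num_classes(shapes, 80))
```

The check only looks at tensor shapes, so it can run without building the model.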

There are 2 solutions for achieving the required result:

  1. Change the num_classes in config file before doing inference
  2. Save the model and config file as pickle, as soon as training is completed.

Note: The first solution is the standard one, but the second is simpler.

1. Change the num_classes in config file before doing inference.

First, find the total number of classes in your dataset. Here num_classes is the total number of classes in the training dataset.

Go to this path: mmdetection/configs/model_name (model_name is the name of the model used for training)

Inside the model_name folder, find the ..._config.py that you used for training. If this config file contains model = dict(...), change num_classes for each of these keys: bbox_head, mask_head.

bbox_head might be a list. If so, change num_classes for every entry in the list.
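Editing by hand is error-prone when there are several heads, so the same change can also be made programmatically. This is a minimal sketch with a hypothetical helper, set_num_classes, operating on a plain nested dict; it should also work on a loaded config's model section if that object behaves like a dict, which is an assumption here. The example model dict is made up for illustration.

```python
from typing import Any

# Keys named in the answer whose num_classes must match the training data.
HEAD_KEYS = ("bbox_head", "mask_head")


def set_num_classes(node: Any, num_classes: int) -> int:
    """Set num_classes on every bbox_head / mask_head found in a nested config.

    Handles a head given as a single dict or as a list of dicts.
    Returns the number of heads that were updated.
    """
    updated = 0
    if isinstance(node, dict):
        for key, value in node.items():
            if key in HEAD_KEYS:
                heads = value if isinstance(value, (list, tuple)) else [value]
                for head in heads:
                    # Only touch heads that already declare num_classes.
                    if isinstance(head, dict) and "num_classes" in head:
                        head["num_classes"] = num_classes
                        updated += 1
            else:
                updated += set_num_classes(value, num_classes)
    elif isinstance(node, (list, tuple)):
        for item in node:
            updated += set_num_classes(item, num_classes)
    return updated


# Made-up model section with a list of box heads and one mask head.
model = dict(
    type="ExampleDetector",
    roi_head=dict(
        bbox_head=[dict(type="ExampleBBoxHead", num_classes=80) for _ in range(3)],
        mask_head=dict(type="ExampleMaskHead", num_classes=80),
    ),
)
changed = set_num_classes(model, 1)
print(changed, model["roi_head"]["mask_head"]["num_classes"])
```

Returning the count of updated heads makes it easy to notice when nothing was changed, for example because the heads live under a different key.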

If model = dict(...) is not found, the first line of the file will be _base_ = '...'. Open the config file it points to and check whether it contains model = dict(...). If not, keep following the _base_ entries until you reach the file that defines it.
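Following the _base_ entries by hand can take a few hops, so the search can be sketched in code. The helper below, find_model_configs, is hypothetical and uses only the standard library: it parses each config file without executing it, and it assumes that _base_ is a string or a list of strings holding paths relative to the file that declares them.

```python
import ast
import os
from typing import Dict, List, Optional, Set


def _top_level_assignments(path: str) -> Dict[str, ast.expr]:
    """Map each top-level variable name in a Python file to its value node."""
    with open(path, encoding="utf-8") as f:
        tree = ast.parse(f.read(), filename=path)
    found: Dict[str, ast.expr] = {}
    for stmt in tree.body:
        if isinstance(stmt, ast.Assign):
            for target in stmt.targets:
                if isinstance(target, ast.Name):
                    found[target.id] = stmt.value
    return found


def find_model_configs(path: str, seen: Optional[Set[str]] = None) -> List[str]:
    """Return the files in the _base_ chain of `path` that assign `model`."""
    seen = set() if seen is None else seen
    path = os.path.abspath(path)
    if path in seen:  # guard against visiting a file twice
        return []
    seen.add(path)

    assigns = _top_level_assignments(path)
    hits = [path] if "model" in assigns else []

    if "_base_" in assigns:
        # Assumes _base_ is a literal string or list of strings.
        bases = ast.literal_eval(assigns["_base_"])
        if isinstance(bases, str):
            bases = [bases]
        for base in bases:
            base_path = os.path.join(os.path.dirname(path), base)
            hits += find_model_configs(base_path, seen)
    return hits


# Example call (the path is taken from the inference snippet in this answer):
# print(find_model_configs('./configs/scnet/scnet_x101_64x4d_fpn_20e_coco.py'))
```

A file that both overrides model and inherits from a base shows up alongside its base, so every place where num_classes may be set is listed.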

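After changing num_classes, use the code below for inference. init_detector builds the model from the config and also loads the checkpoint weights into it; build_detector on its own only constructs the model and does not read cfg.load_from: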

from mmdet.apis import init_detector, inference_detector
import mmcv
import numpy as np
import cv2
import os
import matplotlib.pyplot as plt
%matplotlib inline

config_file = './configs/scnet/scnet_x101_64x4d_fpn_20e_coco.py' #(I have used SCNet for training)

checkpoint_file = 'tutorial_exps/epoch_40.pth' #(checkpoint saved after training)

model = init_detector(config_file, checkpoint_file, device='cuda:0') #loading the model

img = 'test.png'

result = inference_detector(model, img)

#visualize the results in a new window
im1 = cv2.imread(img)[:,:,::-1]
#im_ones = np.ones(im1.shape, dtype='uint')*255
# model.show_result(im_ones, result, out_file='fine_result6.jpg')
plt.imshow(model.show_result(im1, result))


2. Save the model and config as pickle as soon as training is completed.

Another simple solution is to save both the model and the config as pickle files as soon as training is completed, rather than depending on MMDetection to do it.

Note:
The pickle files should be saved right after training is completed.

Code for saving as pickle:

import pickle

with open('mdl.pkl','wb') as f:
    pickle.dump(model, f)

with open('cfg.pkl','wb') as f:
    pickle.dump(cfg, f)

You can use this model/config wherever and whenever you want. For inference with the saved model, use this:

import pickle, mmcv
from mmdet.apis import inference_detector, show_result_pyplot

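# Load the pickled model and config; the with-blocks close the files afterwards
with open('mdl.pkl', 'rb') as f:
    model = pickle.load(f)
with open('cfg.pkl', 'rb') as f:
    cfg = pickle.load(f)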

img = mmcv.imread('images/test.png')

model.cfg = cfg
result = inference_detector(model, img)
show_result_pyplot(model, img, result)

Prakash Dahal answered Nov 07 '25 01:11