 

How do I register a dataset to use with detectron2? We have images and their annotations in COCO JSON format

I am trying to train a model with Detectron2 on grocery image data whose annotations are in COCO format. I am having a problem when loading the data: the model does not seem to pick up the annotations. I am following this blog: https://gilberttanner.com/blog/detectron2-train-a-instance-segmentation-model.

I am facing an issue when registering the dataset:

from detectron2.data.datasets import register_coco_instances

for d in ["train", "test"]:
    register_coco_instances(f"microcontroller_{d}", {}, f"Microcontroller Segmentation/{d}.json", f"Microcontroller Segmentation/{d}")

Is there any problem with this code?
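One thing that makes this hard to debug: register_coco_instances does not read the JSON when it is called. The file is only parsed later, when the dataset is first requested, so a wrong path or file_name entries that don't match the image folder only show up once training starts. A quick standard-library check can catch that up front. The helper check_coco_json below is hypothetical (not part of detectron2) and assumes a standard COCO instances file with top-level images, annotations and categories lists:

```python
import json
import os


def check_coco_json(json_path, image_root):
    """Return a list of human-readable problems found in a COCO-style JSON file.

    Hypothetical helper, not a detectron2 API.
    """
    with open(json_path, encoding="utf-8") as f:
        coco = json.load(f)

    problems = []
    # A COCO instances file is assumed to have these three top-level lists.
    for key in ("images", "annotations", "categories"):
        if key not in coco:
            problems.append(f"missing top-level key '{key}'")
    if problems:
        return problems

    image_ids = {img["id"] for img in coco["images"]}
    category_ids = {cat["id"] for cat in coco["categories"]}

    # Every image entry should point to a file that exists under image_root.
    for img in coco["images"]:
        path = os.path.join(image_root, img["file_name"])
        if not os.path.isfile(path):
            problems.append(f"image file not found: {path}")

    # Every annotation should reference a known image and a known category.
    for ann in coco["annotations"]:
        if ann.get("image_id") not in image_ids:
            problems.append(
                f"annotation {ann.get('id')} has unknown image_id {ann.get('image_id')}"
            )
        if ann.get("category_id") not in category_ids:
            problems.append(
                f"annotation {ann.get('id')} has unknown category_id {ann.get('category_id')}"
            )

    return problems
```

For example, check_coco_json("Microcontroller Segmentation/train.json", "Microcontroller Segmentation/train") should return an empty list if that split is consistent.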

Nikhil Kumar Pavanam asked Oct 19 '25 15:10


2 Answers

I think this might help you:

from detectron2.data.datasets import register_coco_instances
register_coco_instances("YourTrainDatasetName", {},"path to train.json", "path to train image folder")
register_coco_instances("YourTestDatasetName", {}, "path to test.json", "path to test image folder")

Let me know if it works for you. I have trained detectron2 using this :)
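Something worth knowing if you run this in a notebook: registration is lazy, and dataset names must be unique. The toy class below is a simplified model of that behaviour, not detectron2's actual DatasetCatalog (LazyCatalog and its methods are hypothetical names). Registering only stores a loader function; the loader runs on get(), and registering the same name twice fails, which is what you hit when re-running the registration cell in Colab or Jupyter. As far as I know, recent detectron2 versions provide DatasetCatalog.remove(name) for clearing a name before re-registering it.

```python
class LazyCatalog:
    """Toy name -> loader registry (a simplified model, NOT detectron2's real class)."""

    def __init__(self):
        self._loaders = {}

    def register(self, name, loader):
        # Registration only stores the function; nothing is read from disk yet.
        if name in self._loaders:
            raise AssertionError(f"Dataset '{name}' is already registered!")
        self._loaders[name] = loader

    def remove(self, name):
        # Forget a name so it can be registered again (e.g. after re-running a cell).
        self._loaders.pop(name)

    def get(self, name):
        # The loader runs only here, so bad paths or malformed JSON fail at this point.
        if name not in self._loaders:
            raise KeyError(f"Dataset '{name}' is not registered!")
        return self._loaders[name]()
```

In this model, registering a loader that points at a missing JSON file succeeds silently; the FileNotFoundError only appears when get() is called.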

yogi answered Oct 22 '25 19:10


For anyone (like me) looking for a way to confirm that the dataset was registered correctly:

After you register the dataset (something like this):

from detectron2.data.datasets import register_coco_instances
register_coco_instances("coco_custom", {}, "./data/annotations/instances.json", "./data/images/")

You can then get the metadata and the dataset dicts:

from detectron2.data import DatasetCatalog, MetadataCatalog

nuts_metadata = MetadataCatalog.get("coco_custom")
dataset_dicts = DatasetCatalog.get("coco_custom")  # the JSON is actually loaded here

And visualize the data:

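import random

import cv2
import matplotlib.pyplot as plt
from detectron2.utils.visualizer import Visualizer

# Draw the ground-truth annotations on 3 random images
for d in random.sample(dataset_dicts, 3):
    img = cv2.imread(d["file_name"])  # OpenCV loads images as BGR
    visualizer = Visualizer(img[:, :, ::-1], metadata=nuts_metadata, scale=0.5)  # Visualizer expects RGB
    vis = visualizer.draw_dataset_dict(d)
    plt.figure()
    plt.imshow(vis.get_image())  # get_image() already returns RGB, so no channel flip here
    plt.axis("off")
    plt.show()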

That way you can check whether the images and their annotations were loaded correctly.
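If you want to inspect the annotations without detectron2 (for example, to count images that end up with no boxes at all), a rough equivalent of the per-image records that DatasetCatalog.get returns can be built with the standard library. This is a simplified sketch, not load_coco_json itself: coco_to_records is a hypothetical helper that keeps only file_name, height, width, image_id and, per annotation, bbox and category_id (the real records also carry fields such as bbox_mode and segmentation), and the contiguous category remapping is an assumption about detectron2's behaviour.

```python
import json
import os
from collections import defaultdict


def coco_to_records(json_path, image_root):
    """Group COCO annotations by image (hypothetical, simplified helper).

    No segmentation handling and no crowd filtering.
    """
    with open(json_path, encoding="utf-8") as f:
        coco = json.load(f)

    # Map original category ids to contiguous 0..N-1 indices (assumed to mirror detectron2).
    cat_ids = sorted(cat["id"] for cat in coco["categories"])
    id_map = {cid: i for i, cid in enumerate(cat_ids)}

    anns_by_image = defaultdict(list)
    for ann in coco["annotations"]:
        anns_by_image[ann["image_id"]].append({
            "bbox": ann["bbox"],  # COCO boxes are [x, y, width, height] in pixels
            "category_id": id_map[ann["category_id"]],
        })

    records = []
    for img in coco["images"]:
        records.append({
            "file_name": os.path.join(image_root, img["file_name"]),
            "height": img["height"],
            "width": img["width"],
            "image_id": img["id"],
            # .get() avoids inserting empty lists into the defaultdict
            "annotations": anns_by_image.get(img["id"], []),
        })
    return records
```

Images whose annotations list is empty usually point to image_id mismatches in the JSON. Also note that detectron2's training loader drops images without annotations by default (DATALOADER.FILTER_EMPTY_ANNOTATIONS), so such images silently contribute nothing to training.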

TayJen answered Oct 22 '25 18:10