
Split dataset based on file names in pytorch Dataset

Is there a way to divide the dataset into training and testing sets based on the filenames? I have a folder containing two folders: input and output. The input folder has the images and the output folder has the labels for those images. The file names in the input folder look like input01_train.png and input01_test.png, as shown below.

                          Dataset
                          /     \
                     Input       Output
                      |             |
           input01_train.png   output01_train.png
                    .                 .
                    .                 .
           input01_test.png    output01_test.png

The code I have only divides the dataset into inputs and labels, not into test and train.

class CancerDataset(Dataset):
  def __init__(self, dataset_folder):#,label_folder):
    self.dataset_folder = torchvision.datasets.ImageFolder(dataset_folder ,transform = transforms.Compose([transforms.Resize(512),transforms.ToTensor()]))
    self.label_folder = torchvision.datasets.ImageFolder(dataset_folder ,transform = transforms.Compose([transforms.Resize(512),transforms.ToTensor()]))

  def __getitem__(self,index):
    img = self.dataset_folder[index]
    label = self.label_folder[index]
    return img,label

  def __len__(self):
    return len(self.dataset_folder)

trainset = CancerDataset(dataset_folder = '/content/drive/My Drive/cancer_data/')
trainsetloader = DataLoader(trainset,batch_size = 1, shuffle = True,num_workers = 0,pin_memory = True)

I would like to be able to divide the train and test sets by their file names, if that is possible.

shreeya asked Oct 21 '25 05:10


1 Answer

You could load the images yourself in __getitem__, keeping only the files whose names contain '_train.png' or '_test.png' when you build the file list in __init__.

import os

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class CancerDataset(Dataset):
    def __init__(self, datafolder, datatype='train',
                 transform=transforms.Compose([transforms.Resize(512),
                                               transforms.ToTensor()])):
        self.datafolder = datafolder
        # Keep only the files for the requested split ('train' or 'test').
        # Sorted, because os.listdir returns names in arbitrary order.
        self.image_files_list = sorted(s for s in os.listdir(datafolder)
                                       if s.endswith('_%s.png' % datatype))
        # Same for the labels files
        self.label_files_list = ...
        self.transform = transform

    def __len__(self):
        return len(self.image_files_list)

    def __getitem__(self, idx):
        img_name = os.path.join(self.datafolder,
                                self.image_files_list[idx])
        image = Image.open(img_name)
        image = self.transform(image)
        # Same for the labels files
        label = ...  # Load in etc
        label = self.transform(label)
        return image, label

Now you could make two datasets (trainset and testset).

trainset = CancerDataset(datafolder='/content/drive/My Drive/cancer_data/', datatype='train')
testset = CancerDataset(datafolder='/content/drive/My Drive/cancer_data/', datatype='test')
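
The label file list is left open above. One way to fill it in is to build matching (input, output) path pairs per split before any image is loaded. The following is only a minimal sketch using the standard library: the helper name paired_files is hypothetical, the subfolder names Input and Output are taken from the diagram in the question, and the rule that a label file is named by swapping the leading "input" for "output" is an assumption about your data.

```python
import os


def paired_files(root, split, input_dir="Input", output_dir="Output"):
    """Return sorted (input_path, output_path) pairs for one split.

    Assumes root/Input/inputNN_<split>.png and root/Output/outputNN_<split>.png.
    Folder names and the renaming rule are assumptions; adjust to your data.
    """
    suffix = "_%s.png" % split
    in_folder = os.path.join(root, input_dir)
    out_folder = os.path.join(root, output_dir)
    pairs = []
    # Sort so the order is stable across runs (os.listdir order is arbitrary).
    for name in sorted(os.listdir(in_folder)):
        if not name.endswith(suffix):
            continue
        # Assumed naming rule: input01_train.png -> output01_train.png
        label_name = name.replace("input", "output", 1)
        label_path = os.path.join(out_folder, label_name)
        if not os.path.isfile(label_path):
            raise FileNotFoundError("No label file found for %s" % name)
        pairs.append((os.path.join(in_folder, name), label_path))
    return pairs
```

Inside the dataset, __init__ could store paired_files(datafolder, datatype), and __getitem__ could open both paths of the pair at idx with Image.open, so an image and its label can never get out of step.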

jmsinusa answered Oct 23 '25 20:10


