PyTorch: how to apply another transform to an existing Dataset?

Here is a code example:

import os
import torch
from torchvision import datasets, transforms

dataset = datasets.MNIST(root=root, train=istrain, transform=None)  # preserve raw img

print(type(dataset[0][0]))
# <class 'PIL.Image.Image'>

dataset = torch.utils.data.Subset(dataset, indices=SAMPLED_INDEX) # for resample

for ind in range(len(dataset)):
    img, label = dataset[ind] # <class 'PIL.Image.Image'> <class 'int'>/<class 'numpy.int64'>
    img.save(fp=os.path.join(saverawdir, f'{ind:02d}-{int(label):02d}.png'))

# transform for net forwarding
dataset.transform = transforms.Compose([
    transforms.RandomResizedCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

print(type(dataset[0][0]))
# expected <class 'torch.Tensor'>, however it's still <class 'PIL.Image.Image'>

Since the dataset is randomly resampled, I don't want to reload a new dataset with the transform; I just want to apply the transform to the existing dataset.

Thanks for your help :D

asked Dec 07 '25 05:12 by graphitump


1 Answer

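Assigning dataset.transform has no effect because Subset.__getitem__ simply indexes the wrapped dataset (self.dataset[self.indices[idx]]) and never reads a transform attribute of its own. Instead, you can create a small wrapper Dataset that applies the given transform to the underlying dataset on the fly, each time an item is accessed.

Here's an example that was posted on the PyTorch forums: https://discuss.pytorch.org/t/torch-utils-data-dataset-random-split/32209/4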

from torch.utils.data import Dataset


class TransformDataset(Dataset):
    """Wraps a dataset (e.g. a Subset) and applies transform to each sample on access."""

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        x, y = self.subset[index]
        if self.transform:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.subset)
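To see the difference without downloading MNIST, here is a minimal sketch that uses a hypothetical ToyDataset (returning plain ints instead of PIL images) and a hypothetical to_tensor function standing in for the torchvision transform. It assumes only torch is installed. Setting transform on the Subset changes nothing, while the wrapper returns tensors.

```python
import torch
from torch.utils.data import Dataset, Subset


class ToyDataset(Dataset):
    """Hypothetical stand-in for MNIST: returns (int sample, int label) pairs."""

    def __init__(self, n):
        self.data = list(range(n))

    def __getitem__(self, index):
        return self.data[index], index % 2  # (sample, label)

    def __len__(self):
        return len(self.data)


class TransformDataset(Dataset):
    """The wrapper from the answer: applies transform to the sample on access."""

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        x, y = self.subset[index]
        if self.transform:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.subset)


def to_tensor(x):
    # Hypothetical stand-in for transforms.Compose([... ToTensor(), Normalize(...)])
    return torch.tensor([float(x)])


base = ToyDataset(10)
subset = Subset(base, indices=[3, 7, 1])

subset.transform = to_tensor  # only sets a new attribute; Subset never reads it
print(type(subset[0][0]))     # <class 'int'>

wrapped = TransformDataset(subset, transform=to_tensor)
print(type(wrapped[0][0]))    # <class 'torch.Tensor'>
print(wrapped[0])             # (tensor([3.]), 1)
```

A different route is to set the transform on the dataset the Subset wraps (dataset.dataset.transform = ...), since torchvision's MNIST applies its own transform inside __getitem__. That mutates the shared base dataset, though, so every Subset of it is affected; the wrapper leaves the raw data untouched.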

With your code, it could look something like this:

import torch
from torchvision import datasets, transforms

dataset = datasets.MNIST(root=root, train=istrain, transform=None)  # preserve raw img

print(type(dataset[0][0]))
# <class 'PIL.Image.Image'>

dataset = torch.utils.data.Subset(dataset, indices=SAMPLED_INDEX)  # for resample

transformed_dataset = TransformDataset(dataset, transform=transforms.Compose([
    transforms.RandomResizedCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
]))

print(type(transformed_dataset[0][0]))
# <class 'torch.Tensor'>
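Because the wrapper does not modify the Subset, the raw subset can still be used for saving images while the wrapped one is fed to the network. The sketch below assumes a hypothetical in-memory list of (sample, label) pairs in place of MNIST and a hypothetical to_tensor in place of the torchvision pipeline; the wrapper class is redefined so the snippet runs on its own.

```python
import torch
from torch.utils.data import DataLoader, Dataset, Subset


class TransformDataset(Dataset):
    """The wrapper from the answer: applies transform to the sample on access."""

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        x, y = self.subset[index]
        if self.transform:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.subset)


# Hypothetical raw data: (sample, label) pairs with plain-Python samples,
# standing in for (PIL image, int label) pairs from MNIST.
raw = [([float(i), float(i) + 0.5], i % 3) for i in range(8)]
subset = Subset(raw, indices=[5, 0, 2, 6])


def to_tensor(x):
    # Hypothetical stand-in for the Compose([...]) pipeline in the answer
    return torch.tensor(x)


transformed = TransformDataset(subset, transform=to_tensor)

# The raw subset is untouched, so raw samples are still available from `subset`.
print(type(subset[0][0]))  # <class 'list'>

# The wrapped dataset yields tensors, which DataLoader's default collate stacks.
loader = DataLoader(transformed, batch_size=2, shuffle=False)
for xb, yb in loader:
    print(xb.shape, yb)
```

Since the transform runs inside __getitem__, random transforms such as RandomResizedCrop are re-sampled every time an item is loaded, which is usually what you want for training.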

answered Dec 08 '25 17:12 by zr0gravity7


