Here is a minimal example of what I'm doing:
import os
import torch
from torchvision import datasets, transforms

dataset = datasets.MNIST(root=root, train=istrain, transform=None)  # preserve raw img
print(type(dataset[0][0]))
# <class 'PIL.Image.Image'>
dataset = torch.utils.data.Subset(dataset, indices=SAMPLED_INDEX)  # for resampling
for ind in range(len(dataset)):
    img, label = dataset[ind]  # <class 'PIL.Image.Image'>, <class 'int'>/<class 'numpy.int64'>
    img.save(fp=os.path.join(saverawdir, f'{ind:02d}-{int(label):02d}.png'))
# transform for net forwarding
dataset.transform = transforms.Compose([
    transforms.RandomResizedCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
print(type(dataset[0][0]))
# expected <class 'torch.Tensor'>, but it's still <class 'PIL.Image.Image'>
Since the dataset is randomly resampled, I don't want to reload a new dataset with the transform; I just want to apply the transform to the already existing Subset.
Thanks for your help :D
Setting .transform on a Subset has no effect: Subset.__getitem__ simply delegates to the underlying MNIST dataset, which was constructed with transform=None, so your assignment just creates an unused attribute on the Subset object. Instead, you can create a small wrapper Dataset that applies the given transform to the underlying dataset on the fly.
Here's an example posted on the PyTorch forums: https://discuss.pytorch.org/t/torch-utils-data-dataset-random-split/32209/4
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        x, y = self.subset[index]
        if self.transform:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.subset)
With your code it could look something like:
dataset = datasets.MNIST(root=root, train=istrain, transform=None)  # preserve raw img
print(type(dataset[0][0]))
# <class 'PIL.Image.Image'>
dataset = torch.utils.data.Subset(dataset, indices=SAMPLED_INDEX)  # for resampling
transformed_dataset = MyDataset(dataset, transform=transforms.Compose([
    transforms.RandomResizedCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
]))
print(type(transformed_dataset[0][0]))
# <class 'torch.Tensor'>
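To see why the wrapper pattern works independently of torchvision, here is a minimal framework-free sketch. The toy list of (value, label) pairs and the doubling lambda are placeholders standing in for the MNIST Subset and the Compose pipeline; the wrapper logic itself is the same as MyDataset above:

```python
class TransformWrapper:
    """Applies a transform lazily on each access, like MyDataset above."""
    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        x, y = self.subset[index]
        if self.transform:
            x = self.transform(x)  # transform runs here, at access time
        return x, y

    def __len__(self):
        return len(self.subset)

# Toy stand-ins: raw "images" are ints, labels are strings
raw = [(1, 'a'), (2, 'b'), (3, 'c')]
wrapped = TransformWrapper(raw, transform=lambda x: x * 10)

print(wrapped[0])   # (10, 'a') -- transform applied on the fly
print(raw[0])       # (1, 'a')  -- underlying data untouched
print(len(wrapped)) # 3
```

Because the transform runs inside `__getitem__`, the underlying dataset is never modified, so you can keep the raw-image Subset around for saving PNGs while the wrapped view feeds the network.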