I am trying to create a transform that shuffles the patches of each image in a batch.
I aim to use it in the same manner as the rest of the transformations in torchvision:
trans = transforms.Compose([
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ShufflePatches(patch_size=(16, 16)),  # our new transform
])
More specifically, the input is a BxCxHxW tensor. I want to split each image in the batch into non-overlapping patches of size patch_size, shuffle them, and regroup into a single image.
Given the image (of size 224x224):

Using ShufflePatches(patch_size=(112,112)) I would like to produce the output image:

I think the solution has to do with torch.nn.functional.unfold and torch.nn.functional.fold, but I didn't manage to get any further.
Any help would be appreciated!
Indeed, unfold and fold (from torch.nn.functional) are appropriate in this case.
import torch
import torch.nn.functional as nnf

class ShufflePatches(object):
    def __init__(self, patch_size):
        self.ps = patch_size

    def __call__(self, x):
        # divide the batch of images into non-overlapping patches
        u = nnf.unfold(x, kernel_size=self.ps, stride=self.ps, padding=0)
        # permute the patches of each image in the batch
        pu = torch.cat([b_[:, torch.randperm(b_.shape[-1])][None, ...] for b_ in u], dim=0)
        # fold the permuted patches back together
        f = nnf.fold(pu, x.shape[-2:], kernel_size=self.ps, stride=self.ps, padding=0)
        return f
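Note that inside a Compose pipeline like the one in the question, ToTensor yields a single CxHxW image rather than a BxCxHxW batch, and the per-image loop above expects a batch dimension. The following is a minimal sketch of one way to accept both layouts; the class name ShufflePatchesAnyDim is hypothetical, and it swaps the Python loop for a gather over random per-image permutations. It assumes H and W are multiples of the patch size.

```python
import torch
import torch.nn.functional as nnf


class ShufflePatchesAnyDim(object):
    """Hypothetical variant: accepts CxHxW or BxCxHxW tensors."""

    def __init__(self, patch_size):
        self.ps = patch_size

    def __call__(self, x):
        # add a temporary batch dimension for a single image
        single = x.dim() == 3
        if single:
            x = x.unsqueeze(0)
        # (B, C*kh*kw, L): one column per non-overlapping patch
        u = nnf.unfold(x, kernel_size=self.ps, stride=self.ps, padding=0)
        # independent random permutation of the L patches for each image,
        # obtained by argsorting random noise
        perm = torch.rand(u.shape[0], u.shape[-1], device=u.device).argsort(dim=1)
        # reorder the patch columns of every image according to its permutation
        pu = u.gather(2, perm[:, None, :].expand_as(u))
        # fold the permuted patches back into images
        f = nnf.fold(pu, x.shape[-2:], kernel_size=self.ps, stride=self.ps, padding=0)
        return f.squeeze(0) if single else f
```

With this variant the transform can sit after ToTensor and Normalize in a Compose, or be applied to a whole batch coming out of a DataLoader.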
Here's an example with patch size=16:
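As a quick sanity check, here is a small sketch that builds a synthetic batch in which patch i is filled with the constant value i, shuffles it, and verifies that every output patch is still constant and that the same set of patches is present. The helper check_shuffle_patches and the chosen sizes (2 images, 3 channels, 64x64) are assumptions made for illustration only.

```python
import torch
import torch.nn.functional as nnf


class ShufflePatches(object):
    def __init__(self, patch_size):
        self.ps = patch_size

    def __call__(self, x):
        u = nnf.unfold(x, kernel_size=self.ps, stride=self.ps, padding=0)
        pu = torch.cat([b_[:, torch.randperm(b_.shape[-1])][None, ...] for b_ in u], dim=0)
        return nnf.fold(pu, x.shape[-2:], kernel_size=self.ps, stride=self.ps, padding=0)


def check_shuffle_patches(B=2, C=3, H=64, W=64, ps=16):
    """Hypothetical helper: returns three booleans describing the shuffle."""
    L = (H // ps) * (W // ps)  # number of patches per image
    ids = torch.arange(L, dtype=torch.float32)
    # every entry of patch column i holds the value i
    cols = ids.expand(B, C * ps * ps, L).contiguous()
    x = nnf.fold(cols, (H, W), kernel_size=ps, stride=ps)

    y = ShufflePatches(patch_size=(ps, ps))(x)
    out_cols = nnf.unfold(y, kernel_size=ps, stride=ps)

    shape_ok = y.shape == x.shape
    # each output patch is still constant, so no patch was torn apart
    patches_intact = bool((out_cols == out_cols[:, :1, :]).all())
    # the patch ids of each image are a permutation of 0..L-1
    same_patches = bool((out_cols[:, 0, :].sort(dim=-1).values == ids).all())
    return shape_ok, patches_intact, same_patches


if __name__ == "__main__":
    print(check_shuffle_patches())
```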