 

PyTorch "Unfold" equivalent in TensorFlow [duplicate]

Say I have grayscale images of size 50x50, with a batch size of 2, and I use the PyTorch Unfold function as follows:

import numpy as np
from torch import nn
from torch import tensor

image1 = np.random.rand(1,50,50)
image2 = np.random.rand(1,50,50)
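image = np.stack((image1,image2))  # shape (2, 1, 50, 50): NCHW, a batch of 2 single-channel images
image = tensor(image)

ds = nn.Unfold(kernel_size=(2,2),stride=2)  # non-overlapping 2x2 blocks, no padding
x = ds(image).numpy()
print(x.shape)

## OUTPUT: (2, 4, 625), i.e. (batch, C*2*2, number of blocks = 25*25)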
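For reference, when the stride equals the kernel size and there is no padding (as here), Unfold simply cuts each image into non-overlapping 2x2 blocks. The sketch below reproduces that layout in plain NumPy to make the target ordering explicit: dimension 1 is ordered (channel, kernel row, kernel column) and dimension 2 walks the blocks row by row. `unfold_blocks_np` is a hypothetical helper name, and it only covers this non-overlapping, unpadded case where H and W are multiples of the kernel size.

```python
import numpy as np

def unfold_blocks_np(images, kh, kw):
    """Hypothetical helper: NumPy sketch of nn.Unfold(kernel_size=(kh, kw), stride=(kh, kw))
    with no padding. Assumes H and W are exact multiples of kh and kw."""
    n, c, h, w = images.shape
    oh, ow = h // kh, w // kw
    blocks = images.reshape(n, c, oh, kh, ow, kw)   # split H and W into (block index, offset)
    blocks = blocks.transpose(0, 1, 3, 5, 2, 4)     # (N, C, kh, kw, oh, ow)
    return blocks.reshape(n, c * kh * kw, oh * ow)  # (N, C*kh*kw, L)

image = np.random.rand(2, 1, 50, 50)
print(unfold_blocks_np(image, 2, 2).shape)  # (2, 4, 625)
```

On a 4x4 image holding 0..15, the first column of the result is [0, 1, 4, 5], the top-left 2x2 block read row by row.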

What would be the equivalent TensorFlow implementation, such that its output exactly matches 'x'? I've tried the tf.image.extract_patches function, but it doesn't seem to give me quite what I want.

So the question is: what is the TensorFlow implementation of Unfold?

D. Ramsook asked Oct 17 '25 23:10



1 Answer

tf.image.extract_patches() is analogous to torch.nn.Unfold, but you need to rejig the parameters slightly:

# image must be channels-last (NHWC); padding='VALID' matches Unfold's default padding=0
tf.image.extract_patches(image, sizes=[1,2,2,1], strides=[1,2,2,1], rates=[1,1,1,1], padding='VALID')
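The call above returns a channels-last tensor of shape (batch, out_rows, out_cols, kh*kw*C), so its values still have to be rearranged before they line up with `x`. A minimal sketch of the full conversion, assuming TensorFlow 2.x in eager mode and PyTorch's defaults `padding=0` and `dilation=1`: transpose NCHW to NHWC, extract patches with 'VALID' padding, then reorder the patch axis from TensorFlow's (kernel row, kernel column, channel) order to Unfold's (channel, kernel row, kernel column) order and move it in front of the block axis. `unfold_tf` is a hypothetical helper name, not a TensorFlow API.

```python
import numpy as np
import tensorflow as tf

def unfold_tf(images_nchw, kh, kw, stride):
    """Hypothetical helper: mimics nn.Unfold(kernel_size=(kh, kw), stride=stride)
    with padding=0 and dilation=1. Input is NCHW, like PyTorch."""
    x = tf.convert_to_tensor(images_nchw)
    n, c = x.shape[0], x.shape[1]
    # extract_patches expects channels-last input: NCHW -> NHWC
    x = tf.transpose(x, [0, 2, 3, 1])
    # 'VALID' means no padding, matching Unfold's default padding=0
    patches = tf.image.extract_patches(
        x, sizes=[1, kh, kw, 1], strides=[1, stride, stride, 1],
        rates=[1, 1, 1, 1], padding='VALID')  # (N, out_h, out_w, kh*kw*C)
    num_blocks = patches.shape[1] * patches.shape[2]
    # Flatten the block grid row by row; the last axis is ordered (kh, kw, C)
    patches = tf.reshape(patches, [n, num_blocks, kh, kw, c])
    # Unfold orders each column as (C, kh, kw) and puts the block axis last
    patches = tf.transpose(patches, [0, 4, 2, 3, 1])  # (N, C, kh, kw, L)
    return tf.reshape(patches, [n, c * kh * kw, num_blocks])

# Same setup as the question (cast to float32 here)
image = np.random.rand(2, 1, 50, 50).astype(np.float32)
x_tf = unfold_tf(image, 2, 2, 2).numpy()
print(x_tf.shape)  # (2, 4, 625)
```

To check against the question's `x`, feed the same data, e.g. `np.allclose(unfold_tf(image.numpy(), 2, 2, 2).numpy(), x)` with the question's torch `image`. For single-channel images the (kh, kw, C) and (C, kh, kw) orders coincide, so the final transpose matters only when C > 1.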
iacob answered Oct 19 '25 13:10
