
Python matplotlib, invalid shape for image data

Currently I have this code to show three images:

imshow(image1, title='1')
imshow(image2, title='2')
imshow(image3, title='3')

It works fine, but I am trying to put all three in a row instead of a column.

Here is the code I have tried:

f = plt.figure()
f.add_subplot(1,3,1)
plt.imshow(image1)
f.add_subplot(1,3,2)
plt.imshow(image2)
f.add_subplot(1,3,3)
plt.imshow(image3)

It throws

TypeError: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

If I do

f = plt.figure()
f.add_subplot(1,3,1)
plt.imshow(image1.cpu())
f.add_subplot(1,3,2)
plt.imshow(image2.cpu())
f.add_subplot(1,3,3)
plt.imshow(image3.cpu())

It throws

TypeError: Invalid shape (1, 3, 128, 128) for image data

How should I fix this, or is there an easier way to implement it?

Hekes Pekes asked Oct 27 '25 10:10


1 Answer

The matplotlib function 'imshow' expects 3-channel images with shape (h, w, 3), as described in its documentation.

It seems that you passed a "batch" containing a single image: the first dimension is the batch size (1), the second is the number of channels (3), and the third and fourth are the height and width.

You need to reshape or view your image. After converting it to the CPU, try:

image1.squeeze().permute(1,2,0)

The result will be an image of the desired shape (128, 128, 3).

squeeze() removes dimensions of size 1, which here is only the leading batch dimension, giving shape (3, 128, 128). permute(1, 2, 0) then reorders the dimensions: the channel dimension moves from the first position to the last, and the height and width dimensions move to the front.
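To see the two steps concretely, here is a small sketch that uses NumPy in place of PyTorch. This is an assumption for illustration: np.squeeze and np.transpose with axes (1, 2, 0) behave like the tensor methods squeeze() and permute(1, 2, 0). The zero-filled array is a hypothetical stand-in for one of the question's images.

```python
import numpy as np

# Hypothetical stand-in for one of the question's images:
# a batch of one 3-channel 128x128 image, shape (1, 3, 128, 128).
batch = np.zeros((1, 3, 128, 128), dtype=np.float32)

# Like tensor.squeeze(): drop every size-1 dimension (here only the batch dimension).
chw = np.squeeze(batch)
print(chw.shape)  # (3, 128, 128)

# Like tensor.permute(1, 2, 0): new axis 0 is old axis 1 (height),
# new axis 1 is old axis 2 (width), new axis 2 is old axis 0 (channels).
hwc = np.transpose(chw, (1, 2, 0))
print(hwc.shape)  # (128, 128, 3)
```

The resulting (128, 128, 3) array is in the layout imshow accepts for RGB data.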

Also, have a look here for further discussion of the GPU and CPU issues: link
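Putting the CPU conversion and the reshaping together for the original goal (three images in one row), here is a minimal sketch, assuming each image is channels-first with shape (1, C, H, W) or (C, H, W). It uses plt.subplots(1, n) instead of repeated f.add_subplot calls. The helper names to_imshow_array and show_in_row are hypothetical, not part of matplotlib or PyTorch. Tensors are detected by duck typing (a detach attribute), so the sketch also runs on plain NumPy arrays without PyTorch installed.

```python
import matplotlib.pyplot as plt
import numpy as np


def to_imshow_array(image):
    """Hypothetical helper: convert a channels-first (1, C, H, W) or (C, H, W) image for imshow."""
    # PyTorch tensors (possibly on the GPU) are detached and copied to host memory;
    # plain NumPy arrays pass through unchanged.
    if hasattr(image, "detach"):
        image = image.detach().cpu().numpy()
    array = np.asarray(image)
    # Drop a leading batch dimension of size 1: (1, C, H, W) -> (C, H, W)
    if array.ndim == 4:
        if array.shape[0] != 1:
            raise ValueError(f"expected a batch of one image, got shape {array.shape}")
        array = array[0]
    if array.ndim == 3:
        if array.shape[0] == 1:
            # Single channel: (1, H, W) -> (H, W), drawn with a colormap
            array = array[0]
        else:
            # Channels last: (C, H, W) -> (H, W, C), like permute(1, 2, 0)
            array = np.transpose(array, (1, 2, 0))
    return array


def show_in_row(images, titles):
    """Hypothetical helper: draw the images side by side in one row of subplots."""
    # squeeze=False keeps axes 2-D (shape (1, n)) even when n == 1
    fig, axes = plt.subplots(1, len(images), figsize=(4 * len(images), 4), squeeze=False)
    for ax, image, title in zip(axes[0], images, titles):
        ax.imshow(to_imshow_array(image))
        ax.set_title(title)
        ax.axis("off")
    fig.tight_layout()
    return fig
```

With the tensors from the question, the call would be show_in_row([image1, image2, image3], ["1", "2", "3"]) followed by plt.show(). Note that an image already in (H, W, C) layout would be transposed incorrectly by this sketch, so it should only be given channels-first data.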

Hope that helps.

A. Maman answered Oct 30 '25 01:10


