 

RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [64, 2]

I'm trying to build a custom CNN in PyTorch for binary classification of RGB images, but I keep getting a runtime error: my input, which should have shape [64, 3, 128, 128], reaches conv2d with shape [64, 2]. I've been trying to fix it for two days and still can't see what's wrong with the code.

Here's the code of the network:

class MyCNN(nn.Module):
  def __init__(self):
    super(MyCNN, self).__init__()
    self.network = nn.Sequential(
        nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),

        nn.Conv2d(32, 64, 3),
        nn.ReLU(),
        nn.MaxPool2d(2),

        nn.Conv2d(64, 128, 3),
        nn.ReLU(),
        nn.MaxPool2d(2),

        nn.Flatten(),
        nn.Linear(in_features=25088, out_features=2048),
        nn.ReLU(),
        nn.Linear(2048, 1024),
        nn.ReLU(),
        nn.Linear(1024, 2),
    )

  def forward(self, x):
    return self.network(x)

It's being called here:

for epoch in range(num_epochs):
    for images, labels in data_loader:  
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        loss.backward()
        optimizer.step()

    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

Here's the stack trace:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-30-fb9ee290e1d6> in <module>()
      7 
      8         # Forward pass
----> 9         outputs = model(images)
     10         loss = criterion(outputs, labels)
     11 

6 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1128         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130             return forward_call(*input, **kwargs)
   1131         # Do not call functions when jit is used
   1132         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-29-09c58015e865> in forward(self, x)
     27         x = layer(x)
     28         print(x.shape)
---> 29     return self.network(x)
     30 
     31 model = MyCNN()

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1128         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130             return forward_call(*input, **kwargs)
   1131         # Do not call functions when jit is used
   1132         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/container.py in forward(self, input)
    137     def forward(self, input):
    138         for module in self:
--> 139             input = module(input)
    140         return input
    141 

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1128         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130             return forward_call(*input, **kwargs)
   1131         # Do not call functions when jit is used
   1132         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/conv.py in forward(self, input)
    455 
    456     def forward(self, input: Tensor) -> Tensor:
--> 457         return self._conv_forward(input, self.weight, self.bias)
    458 
    459 class Conv3d(_ConvNd):

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
    452                             _pair(0), self.dilation, self.groups)
    453         return F.conv2d(input, weight, bias, self.stride,
--> 454                         self.padding, self.dilation, self.groups)
    455 
    456     def forward(self, input: Tensor) -> Tensor:

RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [64, 2]

I really appreciate the help. I apologize if the solution is simple but I didn't see it easily. Cheers.

sphynxo asked Oct 25 '25 11:10


1 Answer

The batch seems to be mixed up: the model expects images of shape (64, 3, 128, 128), but the tensor reaching it has shape (64, 2), which is the shape of the labels. When the input has the right shape, the model runs fine. Here is my code.

Code:

import torch
import torch.nn as nn
import torch.optim as optim

class MyCNN(nn.Module):
  def __init__(self):
    super(MyCNN, self).__init__()
    self.network = nn.Sequential(
        nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),

        nn.Conv2d(32, 64, 3),
        nn.ReLU(),
        nn.MaxPool2d(2),

        nn.Conv2d(64, 128, 3),
        nn.ReLU(),
        nn.MaxPool2d(2),

        nn.Flatten(),
        nn.Linear(in_features=25088, out_features=2048),
        nn.ReLU(),
        nn.Linear(2048, 1024),
        nn.ReLU(),
        nn.Linear(1024, 2),
    )

  def forward(self, x):
    return self.network(x)

model = MyCNN()

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr = 0.001)

optimizer.zero_grad()

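# Forward pass on a dummy batch with the shape the model expects
images = torch.randn(64, 3, 128, 128)  # (batch, channels, height, width)
# Random placeholder targets of shape (64, 2); CrossEntropyLoss treats float
# targets of this shape as per-class probabilities (PyTorch >= 1.10)
labels = torch.randn(64, 2)
outputs = model(images)  # shape (64, 2)
loss = criterion(outputs, labels)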
        
# Backward and optimize
loss.backward()
optimizer.step()
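
As a side check (not part of the original answer), the in_features=25088 of the first Linear layer can be worked out by hand for a 128x128 input. This is only a sketch: it assumes PyTorch's defaults (conv stride 1 and no padding, pool stride equal to its kernel size), and the helper names conv_out, pool_out and flatten_features are made up for illustration.

```python
# Spatial size after each Conv2d(kernel_size=3) + MaxPool2d(2) stage,
# assuming PyTorch defaults: conv stride 1 / padding 0, pool stride 2.
def conv_out(size, kernel=3):
    # A valid (unpadded) convolution shrinks each side by kernel - 1.
    return size - kernel + 1

def pool_out(size, kernel=2):
    # Floor division matches MaxPool2d's default floor rounding.
    return size // kernel

def flatten_features(height, width, out_channels=128, stages=3):
    # Three conv + pool stages, then channels * height * width after Flatten.
    for _ in range(stages):
        height = pool_out(conv_out(height))
        width = pool_out(conv_out(width))
    return out_channels * height * width

print(flatten_features(128, 128))  # 25088, matches in_features=25088
print(flatten_features(512, 512))  # 492032, would not match
```

So the network as written only accepts 128x128 images; a different input size would need a different in_features for the first Linear layer.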

If the data loader is yielding the labels first, I recommend changing these lines

for images, labels in data_loader:
    images, labels = images.to(device), labels.to(device)

to this

for labels, images in data_loader:
    images, labels = images.to(device), labels.to(device)
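
One more thing that may be worth ruling out. The traceback in the question shows a forward() that contains `x = layer(x)` and `print(x.shape)` right before `return self.network(x)`, which is not the forward() posted above it. If that loop runs over self.network, then x is already [64, 2] by the time it is passed through the network a second time, and the first Conv2d would raise exactly this error. That is an inference from the traceback, not something the posted code confirms. Below is a minimal sketch of a shape trace that returns the traced tensor instead; TracedCNN and its trace flag are hypothetical, and the classifier head is shrunk to a single Linear layer to keep the example light.

```python
import torch
import torch.nn as nn

class TracedCNN(nn.Module):
    """Hypothetical reduced stand-in for MyCNN: same conv stack, one small Linear head."""

    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(3, 32, 3), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3), nn.ReLU(), nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(25088, 2),
        )
        self.trace = False  # set to True to print per-layer shapes

    def forward(self, x):
        if not self.trace:
            return self.network(x)
        # Apply each layer exactly once, printing the shape after it.
        for layer in self.network:
            x = layer(x)
            print(type(layer).__name__, tuple(x.shape))
        # Return the traced result; calling self.network(x) here would feed
        # the [batch, 2] output back into the first Conv2d.
        return x

model = TracedCNN()
model.trace = True
images = torch.randn(4, 3, 128, 128)  # small dummy batch
outputs = model(images)
```

Printing images.shape just before model(images) in the training loop tells the two cases apart: if it already shows [64, 2], the batch is swapped; if it shows [64, 3, 128, 128], the problem is inside forward().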
core_not_dumped answered Oct 28 '25 00:10


