
RuntimeError: Error(s) in loading state_dict for Generator: Missing key(s) in state_dict

I was trying to train a DCGAN model on the MNIST dataset, but I can't load the saved gen.state_dict() after training finishes.

import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
import torchvision
import os
from torch.autograd import Variable

workspace_dir = '/content/drive/My Drive/practice'
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

img_size=64
channel_img=1
lr=2e-4
batch_size=128
z_dim=100
epochs=10
features_gen=64
features_disc=64
save_dir = os.path.join(workspace_dir, 'logs')
os.makedirs(save_dir, exist_ok=True)
import matplotlib.pyplot as plt
transforms=transforms.Compose([transforms.Resize(img_size),transforms.ToTensor(),transforms.Normalize(mean=(0.5,),std=(0.5,))])
train_data=datasets.MNIST(root='dataset/',train=True,transform=transforms,download=True)
train_loader=torch.utils.data.DataLoader(train_data,batch_size=batch_size,shuffle=True)
count=0
for x,y in train_loader:
  if count==5:
    break
  print(x.shape,y.shape)
  count+=1

class Discriminator(nn.Module):
  def __init__(self,channels_img,features_d):
    super(Discriminator,self).__init__()
    
    self.disc=nn.Sequential(
        #input:N * channels_img * 64 *64
        nn.Conv2d(channels_img,features_d,4,2,1),#paper didn't use batchnorm in the early layers in the discriminator features_d* 32 *32
        nn.LeakyReLU(0.2),
        self._block(features_d,features_d*2,4,2,1),#features_d*2 *16 *16
        self._block(features_d*2,features_d*4,4,2,1),#features_d*4 *8 *8
        self._block(features_d*4,features_d*8,4,2,1), #features_d*8 *4 *4
        nn.Conv2d(features_d*8,1,4,2,0),#1 * 1 *1
        nn.Sigmoid()

    )
  def _block(self,in_channels,out_channels,kernel_size,stride,padding):
    return nn.Sequential(
        nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,bias=False),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.2)
    )


  def forward(self,x):
    return self.disc(x)

class Generator(nn.Module):
  def __init__(self,Z_dim,channels_img,features_g):
    super(Generator,self).__init__()
    
    self.gen=nn.Sequential(
        #input :n * z_dim * 1 *1
        self._block(Z_dim,features_g*16,4,1,0),#features_g*16 * 4 * 4
        self._block(features_g*16,features_g*8,4,2,1),#features_g*8 * 8 * 8
        self._block(features_g*8,features_g*4,4,2,1),#features_g*4 * 16 * 16
        self._block(features_g*4,features_g*2,4,2,1),#features_g*2 * 32 * 32
        nn.ConvTranspose2d(features_g*2,channels_img,4,2,1), #
        nn.Tanh()# [-1 to 1] normalize the image
    )
  def _block(self,in_channels,out_channels,kernel_size,stride,padding):
    return nn.Sequential(
          nn.ConvTranspose2d(in_channels,out_channels,kernel_size,stride,padding,bias=False),#w'=(w-1)*s-2p+k
          nn.BatchNorm2d(out_channels),
          nn.ReLU()
      )
    
  def forward(self,x):
      return self.gen(x)


def initialize_weights(model):
  for m in model.modules():
    if isinstance(m,(nn.Conv2d,nn.ConvTranspose2d,nn.BatchNorm2d)):
      nn.init.normal_(m.weight.data,0.0,0.02)

gen=Generator(z_dim,channel_img,features_gen).to(device)
disc=Discriminator(channel_img,features_disc).to(device)
initialize_weights(gen)
initialize_weights(disc)
opt_gen=torch.optim.Adam(gen.parameters(),lr=lr,betas=(0.5,0.999))
opt_disc=torch.optim.Adam(disc.parameters(),lr=lr,betas=(0.5,0.999))
criterion=nn.BCELoss()
#fixed_noise=torch.randn(32,z_dim,1,1).to(device)
#writer_real=SummaryWriter(f"logs/real")
#writer_fake=SummaryWriter(f"logs/fake")
step=0
gen.train()
disc.train()


z_sample = Variable(torch.randn(100, z_dim,1,1)).cuda()
for epoch in range(2):
  for batch_idx,(real,_) in enumerate(train_loader):
    real=real.to(device)
    noise=torch.randn((batch_size,z_dim,1,1)).to(device)
    fake=gen(noise)
    
    #Train Discriminator max log(D(x)) + log(1-D(G(z)))
    disc_real=disc(real).reshape(-1)
    loss_disc_real=criterion(disc_real,torch.ones_like(disc_real))
    disc_fake=disc(fake).reshape(-1)
    loss_disc_fake=criterion(disc_fake,torch.zeros_like(disc_fake))
    loss_disc=(loss_disc_fake+loss_disc_real)/2
    disc.zero_grad()
    loss_disc.backward(retain_graph=True)
    opt_disc.step()

    #Train Generator  min log(1-D(G(z))) <--> max log(D(G(z)))
    output=disc(fake).reshape(-1)
    loss_gen=criterion(output,torch.ones_like(output))
    gen.zero_grad()
    loss_gen.backward()
    opt_gen.step()

    
    print(f'\rEpoch [{epoch+1}/2] {batch_idx+1}/{len(train_loader)} Loss_D: {loss_disc.item():.4f} Loss_G: {loss_gen.item():.4f}', end='')
  gen.eval()
  f_imgs_sample = (gen(z_sample).data + 1) / 2.0
  filename = os.path.join(save_dir, f'Epoch_{epoch+1:03d}.jpg')
  torchvision.utils.save_image(f_imgs_sample, filename, nrow=10)
  print(f' | Save some samples to {filename}.')
  # show generated image
  grid_img = torchvision.utils.make_grid(f_imgs_sample.cpu(), nrow=10)
  plt.figure(figsize=(10,10))
  plt.imshow(grid_img.permute(1, 2, 0))
  plt.show()
  gen.train()
  
  torch.save(gen.state_dict(), os.path.join(workspace_dir, f'dcgan_d.pth'))
  torch.save(disc.state_dict(), os.path.join(workspace_dir, f'dcgan_g.pth'))
  

I can't load the generator's state_dict at this step:

# load pretrained model
#gen = Generator(z_dim,1,64)
gen=Generator(z_dim,channel_img,features_gen).to(device)
gen.load_state_dict(torch.load('/content/drive/My Drive/practice/dcgan_g.pth'))
gen.eval()
gen.cuda()

Here's the error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-18-4bda27faa444> in <module>()
      5 
      6 #gen.load_state_dict(torch.load('/content/drive/My Drive/practice/dcgan_g.pth'))
----> 7 gen.load_state_dict(torch.load(os.path.join(workspace_dir, 'dcgan_g.pth')))
      8 #/content/drive/My Drive/practice/dcgan_g.pth
      9 gen.eval()

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1050         if len(error_msgs) > 0:
   1051             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1052                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1053         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1054 

RuntimeError: Error(s) in loading state_dict for Generator:
    Missing key(s) in state_dict: "gen.0.0.weight", "gen.0.1.weight", "gen.0.1.bias", "gen.0.1.running_mean", "gen.0.1.running_var", "gen.1.0.weight", "gen.1.1.weight", "gen.1.1.bias", "gen.1.1.running_mean", "gen.1.1.running_var", "gen.2.0.weight", "gen.2.1.weight", "gen.2.1.bias", "gen.2.1.running_mean", "gen.2.1.running_var", "gen.3.0.weight", "gen.3.1.weight", "gen.3.1.bias", "gen.3.1.running_mean", "gen.3.1.running_var", "gen.4.weight", "gen.4.bias". 
    Unexpected key(s) in state_dict: "disc.0.weight", "disc.0.bias", "disc.2.0.weight", "disc.2.1.weight", "disc.2.1.bias", "disc.2.1.running_mean", "disc.2.1.running_var", "disc.2.1.num_batches_tracked", "disc.3.0.weight", "disc.3.1.weight", "disc.3.1.bias", "disc.3.1.running_mean", "disc.3.1.running_var", "disc.3.1.num_batches_tracked", "disc.4.0.weight", "disc.4.1.weight", "disc.4.1.bias", "disc.4.1.running_mean", "disc.4.1.running_var", "disc.4.1.num_batches_tracked", "disc.5.weight", "disc.5.bias".
1 Answer

You saved the weights under the wrong names. That is, you saved the generator's weights as dcgan_d.pth and, likewise, saved the discriminator's weights as dcgan_g.pth:

  torch.save(gen.state_dict(), os.path.join(workspace_dir, f'dcgan_d.pth')) # should have been dcgan_g.pth
  torch.save(disc.state_dict(), os.path.join(workspace_dir, f'dcgan_g.pth')) # should have been dcgan_d.pth
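
A quick way to confirm the mix-up is to inspect the keys of a saved file: the "gen." / "disc." prefixes come from the attribute names inside your modules, so they immediately tell you which network a checkpoint belongs to. A minimal check, assuming the file paths from your question:

import torch

# Peek at the first few keys of the checkpoint that refuses to load.
state = torch.load('/content/drive/My Drive/practice/dcgan_g.pth', map_location='cpu')
print(list(state.keys())[:3])
# e.g. ['disc.0.weight', 'disc.0.bias', 'disc.2.0.weight']  -> discriminator weights, not generator weights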

and thus, when loading, you try to load the wrong weights:

gen.load_state_dict(torch.load('/content/drive/My Drive/practice/dcgan_g.pth'))

dcgan_g.pth contains the discriminator's weights, not your generator's. Fix the swapped names in the save calls for future runs, and for the checkpoints you already have, simply rename the files accordingly and you should be fine.
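
If you want to keep the checkpoints you have already trained, swapping the two files on disk is enough. Here is a minimal sketch, assuming both files still sit under workspace_dir with the swapped names (dcgan_tmp.pth is just a scratch name I picked so neither file gets overwritten):

import os

workspace_dir = '/content/drive/My Drive/practice'
g = os.path.join(workspace_dir, 'dcgan_g.pth')    # currently holds the discriminator weights
d = os.path.join(workspace_dir, 'dcgan_d.pth')    # currently holds the generator weights
tmp = os.path.join(workspace_dir, 'dcgan_tmp.pth')

# three-way rename so each checkpoint ends up under its correct name
os.rename(g, tmp)
os.rename(d, g)
os.rename(tmp, d)

After the swap, gen.load_state_dict(torch.load(os.path.join(workspace_dir, 'dcgan_g.pth'))) should load without any missing/unexpected key errors.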
