PyTorch MNIST autoencoder to learn 10-digit classification

I'm trying to build a simple autoencoder for MNIST, where the middle layer is just 10 neurons. My hope is that it will learn to classify the 10 digits, and I assume that would lead to the lowest error in the end (with respect to reproducing the original image).

I have the following code, which I've already played around with a fair amount. If I run it for up to 100 epochs, the loss doesn't really go below 1.0, and when I evaluate it, it's obviously not working. What am I missing?

Training:

import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image

num_epochs = 100
batch_size = 64

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
trainset = tv.datasets.MNIST(root='./data',  train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder,self).__init__()
        self.encoder = nn.Sequential(
            # 28 x 28
            nn.Conv2d(1, 4, kernel_size=5),
            nn.Dropout2d(p=0.2),
            # 4 x 24 x 24
            nn.ReLU(True),
            nn.Conv2d(4, 8, kernel_size=5),
            nn.Dropout2d(p=0.2),
            # 8 x 20 x 20 = 3200
            nn.ReLU(True),
            nn.Flatten(),
            nn.Linear(3200, 10),
            nn.ReLU(True),
            # 10
            nn.Softmax(),
            # 10
            )
        self.decoder = nn.Sequential(
            # 10
            nn.Linear(10, 400),
            nn.ReLU(True),
            # 400
            nn.Unflatten(1, (1, 20, 20)),
            # 20 x 20
            nn.Dropout2d(p=0.2),
            nn.ConvTranspose2d(1, 10, kernel_size=5),
            # 24 x 24
            nn.ReLU(True),
            nn.Dropout2d(p=0.2),
            nn.ConvTranspose2d(10, 1, kernel_size=5),
            # 28 x 28
            nn.ReLU(True),
            nn.Sigmoid(),
            )
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

model = Autoencoder().cpu()
distance = nn.MSELoss()
#optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-5)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

for epoch in range(num_epochs):
    for data in dataloader:
        img, _ = data
        img = Variable(img).cpu()
        output = model(img)
        loss = distance(output, img)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('epoch [{}/{}], loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

The training loss already indicates that this isn't working, but I also print out the confusion matrix. (In this case it should not necessarily be the identity matrix, since the neurons can be ordered arbitrarily, but it should approximate the identity after some row/column reordering, if this worked; a reordering check is sketched after the code below.)

import numpy as np

confusion_matrix = np.zeros((10, 10))

batch_size = 20*1000  # bigger than the 10k test set, so one batch covers everything

testset = tv.datasets.MNIST(root='./data',  train=False, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True, num_workers=4)

for data in dataloader:
    imgs, labels = data
    imgs = Variable(imgs).cpu()
    encs = model.encoder(imgs).detach().numpy()
    for i in range(len(encs)):
        predicted = np.argmax(encs[i])
        actual = labels[i]
        confusion_matrix[actual][predicted] += 1
print(confusion_matrix)
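
For reference, one way to test the "row-col-reorderable" idea is the Hungarian algorithm, assuming SciPy is available: scipy.optimize.linear_sum_assignment finds the neuron-to-digit permutation that puts the most mass on the diagonal.

from scipy.optimize import linear_sum_assignment

# Find the column permutation that maximizes the diagonal of the
# confusion matrix, then measure accuracy under that relabeling.
row_ind, col_ind = linear_sum_assignment(confusion_matrix, maximize=True)
reordered = confusion_matrix[:, col_ind]
print(reordered)
print('best-permutation accuracy: {:.4f}'.format(reordered.trace() / confusion_matrix.sum()))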
asked by Marton Trencseni

1 Answer

Autoencoders are generally not used as classifiers. They learn to encode a given image into a short vector and to reconstruct the same image from that encoded vector; it is a way of compressing an image into a short vector:

[Figure: autoencoder diagram, image encoded into a short vector and decoded back into an image]

Since you want to train an autoencoder with classification capabilities, we need to make some changes to the model. First of all, there will be two different losses:

  1. MSE loss: the current autoencoder reconstruction loss. This forces the network to output an image as close as possible to the given image, working from the compressed representation.
  2. Classification loss: classic cross-entropy should do the trick. This loss takes the compressed representation (C-dimensional) and the target labels and computes a negative log-likelihood. It forces the encoder to produce a compressed representation that aligns well with the target class (see the short sketch after this list).
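
One detail worth spelling out: PyTorch's nn.CrossEntropyLoss applies log-softmax internally, which is why the softmax is pulled out of the encoder in the code below, so the loss sees raw logits. A minimal sketch of that equivalence (the logits and labels here are made up):

 import torch
 import torch.nn as nn

 logits = torch.randn(4, 10)          # hypothetical encoder outputs for a batch of 4
 labels = torch.tensor([3, 1, 7, 0])  # hypothetical target digits

 ce = nn.CrossEntropyLoss()(logits, labels)
 # CrossEntropyLoss is log-softmax followed by negative log-likelihood:
 manual = nn.NLLLoss()(torch.log_softmax(logits, dim=1), labels)
 assert torch.allclose(ce, manual)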

I've made a couple of changes to your code to get the combined model working. First, let's look at the code:

 import torch
 import torchvision as tv
 import torchvision.transforms as transforms
 import torch.nn as nn

 num_epochs = 10
 batch_size = 64
 transform = transforms.Compose([
     transforms.ToTensor(),
     transforms.Normalize((0.1307,), (0.3081,))
 ])     
 
 trainset = tv.datasets.MNIST(root='./data',  train=True, download=True, transform=transform)
 testset  = tv.datasets.MNIST(root='./data',  train=False, download=True, transform=transform)
 dataloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)
 testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True, num_workers=4)
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 class Autoencoderv3(nn.Module):
     def __init__(self):
         super(Autoencoderv3,self).__init__()
         self.encoder = nn.Sequential(
             nn.Conv2d(1, 4, kernel_size=5),
             nn.Dropout2d(p=0.1),
             nn.ReLU(True),
             nn.Conv2d(4, 8, kernel_size=5),
             nn.Dropout2d(p=0.1),
             nn.ReLU(True),
             nn.Flatten(),
             nn.Linear(3200, 10)
             )
         self.softmax = nn.Softmax(dim=1)
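         # Softmax lives outside the encoder so the raw logits can go to
         # CrossEntropyLoss (which applies log-softmax itself), while the
         # decoder still receives the softmaxed probabilities.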
         self.decoder = nn.Sequential(
             nn.Linear(10, 400),
             nn.ReLU(True),
             nn.Unflatten(1, (1, 20, 20)),
             nn.Dropout2d(p=0.1),
             nn.ConvTranspose2d(1, 10, kernel_size=5),
             nn.ReLU(True),
             nn.Dropout2d(p=0.1),
             nn.ConvTranspose2d(10, 1, kernel_size=5)
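             # Note: no final ReLU/Sigmoid here - the Normalize transform makes
             # target pixel values negative as well, so the output must be
             # unbounded to match them under MSE.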
             )
         
     def forward(self, x):
         out_en = self.encoder(x)
         out = self.softmax(out_en)
         out = self.decoder(out)
         return out, out_en
 
 model = Autoencoderv3().to(device)
 distance   = nn.MSELoss()
 class_loss = nn.CrossEntropyLoss()
 
 optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
 
 mse_multp = 0.5
 cls_multp = 0.5
 
 model.train()
 
 for epoch in range(num_epochs):
     total_mseloss = 0.0
     total_clsloss = 0.0
     for ind, data in enumerate(dataloader):
         img, labels = data[0].to(device), data[1].to(device) 
         output, output_en = model(img)
         loss_mse = distance(output, img)
         loss_cls = class_loss(output_en, labels)
         loss = (mse_multp * loss_mse) + (cls_multp * loss_cls)  # Combine two losses together
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
         # Track this epoch's loss
         total_mseloss += loss_mse.item()
         total_clsloss += loss_cls.item()
 
     # Check accuracy on test set after each epoch:
     model.eval()   # Turn off dropout in evaluation mode
     acc = 0.0
     total_samples = 0
     for data in testloader:
         # We only care about the 10 dimensional encoder output for classification
         img, labels = data[0].to(device), data[1].to(device) 
         _, output_en = model(img)   
         # output_en contains 10 values for each input, apply softmax to calculate class probabilities
         prob = nn.functional.softmax(output_en, dim = 1)
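         # (Softmax is monotonic, so argmax over the raw logits would give
         # the same predictions; softmax is just for readability.)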
         pred = torch.max(prob, dim=1)[1].detach().cpu().numpy() # Max prob assigned to class 
         acc += (pred == labels.cpu().numpy()).sum()
         total_samples += labels.shape[0]
     model.train()   # Enables dropout back again
     print('epoch [{}/{}], loss_mse: {:.4f}  loss_cls: {:.4f}  Acc on test: {:.4f}'.format(epoch+1, num_epochs, total_mseloss / len(dataloader), total_clsloss / len(dataloader), acc / total_samples))
   

This code should now train the model both as a classifier and as a generative autoencoder. In general, though, this type of approach can be tricky to get training; here the MNIST data is simple enough that the two complementary losses train well together. In more complex cases, such as Generative Adversarial Networks (GANs), people alternate training phases, freeze one sub-model, etc., to get the whole model trained. This autoencoder trains easily on MNIST without such tricks:

 epoch [1/10], loss_mse: 0.8928  loss_cls: 0.4627  Acc on test: 0.9463
 epoch [2/10], loss_mse: 0.8287  loss_cls: 0.2105  Acc on test: 0.9639
 epoch [3/10], loss_mse: 0.7803  loss_cls: 0.1574  Acc on test: 0.9737
 epoch [4/10], loss_mse: 0.7513  loss_cls: 0.1290  Acc on test: 0.9764
 epoch [5/10], loss_mse: 0.7298  loss_cls: 0.1117  Acc on test: 0.9762
 epoch [6/10], loss_mse: 0.7110  loss_cls: 0.1017  Acc on test: 0.9801
 epoch [7/10], loss_mse: 0.6962  loss_cls: 0.0920  Acc on test: 0.9794
 epoch [8/10], loss_mse: 0.6824  loss_cls: 0.0859  Acc on test: 0.9806
 epoch [9/10], loss_mse: 0.6733  loss_cls: 0.0797  Acc on test: 0.9814
 epoch [10/10], loss_mse: 0.6671  loss_cls: 0.0764  Acc on test: 0.9813

As you can see, both the MSE loss and the classification loss are decreasing, and accuracy on the test set is increasing. In the code, the MSE loss and the classification loss are added together, which means the gradients computed from each loss fight to pull the network in their own direction. I've added loss multipliers to control the contribution of each loss: if MSE has the higher multiplier, the network gets more gradient from the reconstruction loss and learns reconstruction better; if the classification loss has the higher multiplier, the network gets better classification accuracy. You can play with those multipliers to see how the end result changes, though MNIST is such an easy dataset that the differences may be hard to see. Currently, it doesn't do too badly at reconstructing inputs:

 import numpy as np
 import matplotlib.pyplot as plt
 
 model.eval()
 img, labels = next(iter(dataloader))  # grab one batch
 img = img.to(device)
 output, output_en = model(img)
 inp = img[0:10, 0, :, :].squeeze().detach().cpu()
 out = output[0:10, 0, :, :].squeeze().detach().cpu()
 
 # Just some trick to concatenate first ten images next to each other
 inp = inp.permute(1,0,2).reshape(28, -1).numpy()
 out = out.permute(1,0,2).reshape(28, -1).numpy()
 combined = np.vstack([inp, out])
 
 plt.imshow(combined)
 plt.show()

[Figure: reconstruction - input digits (top row) vs. reconstructed digits (bottom row)]

I am sure that with more training, and by fine-tuning the loss multipliers, you can get better results.

Lastly, the decoder receives the softmax of the encoder output. This means the decoder tries to create the output image from the 0-1 class probabilities of the input. So if the softmax probability vector is 0.98 at location 0 and close to zero elsewhere, the decoder should output an image that looks like a 0. Here I feed the network one-hot vectors to create reconstructions of the digits 0 to 9:

 test_arr = np.zeros([10, 10], dtype = np.float32)
 ind = np.arange(0, 10)
 test_arr[ind, ind] = 1.0
 
 model.eval()
 img = torch.from_numpy(test_arr).to(device)
 out = model.decoder(img)
 out = out[0:10, 0, :, :].squeeze().detach().cpu()
 out = out.permute(1,0,2).reshape(28, -1).numpy()
 plt.imshow(out)
 plt.show()

[Figure: decoder outputs for one-hot inputs, digits 0 to 9]

I've also made a few small changes to the code, such as printing the epoch-average loss, that don't really change the training logic; you can see those changes in the code above, and let me know if anything looks weird.
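
As an aside, if the two losses did start fighting on a harder dataset, the freezing trick mentioned earlier is easy to sketch in PyTorch (a hypothetical schedule, not something MNIST needs):

 # Freeze the decoder and take a few classification-only steps,
 # then unfreeze it so reconstruction training resumes.
 for p in model.decoder.parameters():
     p.requires_grad_(False)

 # ... run some optimizer steps driven mainly by loss_cls ...

 for p in model.decoder.parameters():
     p.requires_grad_(True)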

answered by yutasrobot


