I'm trying to build a simple autoencoder for MNIST, where the middle layer is just 10 neurons. My hope is that it will learn to classify the 10 digits, and I assume that would lead to the lowest error in the end (wrt reproducing the original image).
I have the following code, which I've already played around with a fair amount. If I run it for up to 100 epochs, the loss doesn't really go below 1.0, and when I evaluate it, it's obviously not working. What am I missing?
Training:
import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image
num_epochs = 100
batch_size = 64
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
trainset = tv.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            # 28 x 28
            nn.Conv2d(1, 4, kernel_size=5),
            nn.Dropout2d(p=0.2),
            # 4 x 24 x 24
            nn.ReLU(True),
            nn.Conv2d(4, 8, kernel_size=5),
            nn.Dropout2d(p=0.2),
            # 8 x 20 x 20 = 3200
            nn.ReLU(True),
            nn.Flatten(),
            nn.Linear(3200, 10),
            nn.ReLU(True),
            # 10
            nn.Softmax(),
            # 10
        )
        self.decoder = nn.Sequential(
            # 10
            nn.Linear(10, 400),
            nn.ReLU(True),
            # 400
            nn.Unflatten(1, (1, 20, 20)),
            # 20 x 20
            nn.Dropout2d(p=0.2),
            nn.ConvTranspose2d(1, 10, kernel_size=5),
            # 24 x 24
            nn.ReLU(True),
            nn.Dropout2d(p=0.2),
            nn.ConvTranspose2d(10, 1, kernel_size=5),
            # 28 x 28
            nn.ReLU(True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
model = Autoencoder().cpu()
distance = nn.MSELoss()
#optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-5)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
for epoch in range(num_epochs):
    for data in dataloader:
        img, _ = data
        img = Variable(img).cpu()
        output = model(img)
        loss = distance(output, img)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('epoch [{}/{}], loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))
The training loss already indicates that this is not working, but to confirm it I print the confusion matrix (which in this case need not be the identity matrix, since the 10 middle neurons can be ordered arbitrarily, but if this worked it should approximate the identity after reordering rows/columns):
import numpy as np
confusion_matrix = np.zeros((10, 10))
batch_size = 20*1000
testset = tv.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True, num_workers=4)
for data in dataloader:
    imgs, labels = data
    imgs = Variable(imgs).cpu()
    encs = model.encoder(imgs).detach().numpy()
    for i in range(len(encs)):
        predicted = np.argmax(encs[i])
        actual = labels[i]
        confusion_matrix[actual][predicted] += 1
print(confusion_matrix)
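For reference, one way to check the "reorderable" idea is to let SciPy's linear_sum_assignment (the Hungarian algorithm) pick the column permutation that maximizes the diagonal; a minimal sketch, assuming confusion_matrix has been filled as above:
from scipy.optimize import linear_sum_assignment
# Find the column permutation that puts as much mass as possible on the diagonal
# (negate the counts because linear_sum_assignment minimizes cost).
row_ind, col_ind = linear_sum_assignment(-confusion_matrix)
reordered = confusion_matrix[:, col_ind]
print(reordered)
print('best-case accuracy under any relabeling:', reordered.trace() / confusion_matrix.sum())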
An autoencoder is generally not used as a classifier. It learns to encode a given image into a short vector and to reconstruct the same image from that encoded vector; in other words, it is a way of compressing an image into a short vector.
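For illustration, here is a minimal sketch (shapes only) using the Autoencoder model you already built above; the random tensor is just a stand-in for a normalized MNIST image:
x = torch.randn(1, 1, 28, 28)   # a fake 28x28 grayscale image
code = model.encoder(x)         # shape (1, 10): the compressed representation
recon = model.decoder(code)     # shape (1, 1, 28, 28): the reconstructed image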
Since you want to train an autoencoder with classification capabilities, we need to make some changes to the model. First of all, there will be two different losses: an MSE loss for reconstructing the input image and a cross-entropy loss for classifying the digit from the 10-dimensional bottleneck.
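Concretely, these are the two loss modules used in the code below, combined later as a weighted sum:
distance = nn.MSELoss()             # reconstruction loss: decoder output vs. input image
class_loss = nn.CrossEntropyLoss()  # classification loss: 10-dim encoder logits vs. digit labels
# in the training loop: loss = mse_multp * loss_mse + cls_multp * loss_cls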
I've made a couple of changes to your code to get the combined model working. First, let's look at the code:
import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image
num_epochs = 10
batch_size = 64
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
trainset = tv.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = tv.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True, num_workers=4)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class Autoencoderv3(nn.Module):
    def __init__(self):
        super(Autoencoderv3, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=5),
            nn.Dropout2d(p=0.1),
            nn.ReLU(True),
            nn.Conv2d(4, 8, kernel_size=5),
            nn.Dropout2d(p=0.1),
            nn.ReLU(True),
            nn.Flatten(),
            nn.Linear(3200, 10)
        )
        self.softmax = nn.Softmax(dim=1)
        self.decoder = nn.Sequential(
            nn.Linear(10, 400),
            nn.ReLU(True),
            nn.Unflatten(1, (1, 20, 20)),
            nn.Dropout2d(p=0.1),
            nn.ConvTranspose2d(1, 10, kernel_size=5),
            nn.ReLU(True),
            nn.Dropout2d(p=0.1),
            nn.ConvTranspose2d(10, 1, kernel_size=5)
        )

    def forward(self, x):
        out_en = self.encoder(x)
        out = self.softmax(out_en)
        out = self.decoder(out)
        return out, out_en
model = Autoencoderv3().to(device)
distance = nn.MSELoss()
class_loss = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
mse_multp = 0.5
cls_multp = 0.5
model.train()
for epoch in range(num_epochs):
    total_mseloss = 0.0
    total_clsloss = 0.0
    for ind, data in enumerate(dataloader):
        img, labels = data[0].to(device), data[1].to(device)
        output, output_en = model(img)
        loss_mse = distance(output, img)
        loss_cls = class_loss(output_en, labels)
        loss = (mse_multp * loss_mse) + (cls_multp * loss_cls)  # Combine two losses together
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Track this epoch's loss
        total_mseloss += loss_mse.item()
        total_clsloss += loss_cls.item()
    # Check accuracy on test set after each epoch:
    model.eval()  # Turn off dropout in evaluation mode
    acc = 0.0
    total_samples = 0
    for data in testloader:
        # We only care about the 10-dimensional encoder output for classification
        img, labels = data[0].to(device), data[1].to(device)
        _, output_en = model(img)
        # output_en contains 10 values for each input; apply softmax to calculate class probabilities
        prob = nn.functional.softmax(output_en, dim=1)
        pred = torch.max(prob, dim=1)[1].detach().cpu().numpy()  # Class with max probability
        acc += (pred == labels.cpu().numpy()).sum()
        total_samples += labels.shape[0]
    model.train()  # Enable dropout again
    print('epoch [{}/{}], loss_mse: {:.4f} loss_cls: {:.4f} Acc on test: {:.4f}'.format(epoch+1, num_epochs, total_mseloss / len(dataloader), total_clsloss / len(dataloader), acc / total_samples))
This code now trains the model both as a classifier and as a generative autoencoder. In general, though, this type of approach can be tricky to get training; in this case MNIST is simple enough that the two complementary losses train together. In more complex cases, like Generative Adversarial Networks (GANs), people alternate training phases, freeze one model, and so on, to get the whole system to train. This autoencoder trains easily on MNIST without any of those tricks:
epoch [1/10], loss_mse: 0.8928 loss_cls: 0.4627 Acc on test: 0.9463
epoch [2/10], loss_mse: 0.8287 loss_cls: 0.2105 Acc on test: 0.9639
epoch [3/10], loss_mse: 0.7803 loss_cls: 0.1574 Acc on test: 0.9737
epoch [4/10], loss_mse: 0.7513 loss_cls: 0.1290 Acc on test: 0.9764
epoch [5/10], loss_mse: 0.7298 loss_cls: 0.1117 Acc on test: 0.9762
epoch [6/10], loss_mse: 0.7110 loss_cls: 0.1017 Acc on test: 0.9801
epoch [7/10], loss_mse: 0.6962 loss_cls: 0.0920 Acc on test: 0.9794
epoch [8/10], loss_mse: 0.6824 loss_cls: 0.0859 Acc on test: 0.9806
epoch [9/10], loss_mse: 0.6733 loss_cls: 0.0797 Acc on test: 0.9814
epoch [10/10], loss_mse: 0.6671 loss_cls: 0.0764 Acc on test: 0.9813
As you can see, both the MSE loss and the classification loss are decreasing, and accuracy on the test set is increasing. In the code, the MSE loss and the classification loss are added together, which means the gradients from each loss compete, each pulling the network in its own direction. I've added loss multipliers to control the contribution from each loss: if MSE has a higher multiplier, the network gets more gradient from the MSE loss and learns reconstruction better; if the classification loss has a higher multiplier, the network gets better classification accuracy. You can play with those multipliers to see how the end result changes, although MNIST is such an easy dataset that the differences may be hard to see. Currently, it doesn't do too badly at reconstructing inputs:
import numpy as np
import matplotlib.pyplot as plt
model.eval()
img, labels = next(iter(dataloader))  # grab one batch
img = img.to(device)
output, output_en = model(img)
inp = img[0:10, 0, :, :].squeeze().detach().cpu()
out = output[0:10, 0, :, :].squeeze().detach().cpu()
# Just some trick to concatenate first ten images next to each other
inp = inp.permute(1,0,2).reshape(28, -1).numpy()
out = out.permute(1,0,2).reshape(28, -1).numpy()
combined = np.vstack([inp, out])
plt.imshow(combined)
plt.show()
I'm sure that with more training and fine-tuning of the loss multipliers you can get better results.
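If you want to experiment with the weighting, here is a small sketch of one way to pull it out into a helper; the name combined_loss is just illustrative and reuses the distance and class_loss modules defined above:
def combined_loss(recon, img, logits, labels, mse_multp=0.5, cls_multp=0.5):
    # Weighted sum: a larger mse_multp favors reconstruction quality,
    # a larger cls_multp favors classification accuracy.
    return mse_multp * distance(recon, img) + cls_multp * class_loss(logits, labels)

# e.g. emphasize reconstruction over classification:
# loss = combined_loss(output, img, output_en, labels, mse_multp=0.8, cls_multp=0.2)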
Lastly, the decoder receives the softmax of the encoder output. This means the decoder tries to create the output image from a vector of 0-1 class probabilities. So if the softmax probability is 0.98 at position 0 and close to zero elsewhere, the decoder should output an image that looks like a 0. Here I feed the network one-hot vectors to create reconstructions of the digits 0 to 9:
test_arr = np.zeros([10, 10], dtype = np.float32)
ind = np.arange(0, 10)
test_arr[ind, ind] = 1.0
model.eval()
img = torch.from_numpy(test_arr).to(device)
out = model.decoder(img)
out = out[0:10, 0, :, :].squeeze().detach().cpu()
out = out.permute(1,0,2).reshape(28, -1).numpy()
plt.imshow(out)
plt.show()
I've also made a few other small changes, like printing the epoch's average loss, which don't really change the training logic; you can see them in the code, and let me know if anything looks weird.