Use PyTorch SSIM loss function in my model

I am trying out the SSIM loss implemented by this repo for image restoration.

Following the original sample code on the author's GitHub, I tried:

import cv2

model.train()
for epo in range(epoch):
    for i, data in enumerate(trainloader, 0):
        # Variable() is unnecessary since PyTorch 0.4; plain tensors work directly
        inputs = data
        optimizer.zero_grad()
        inputs = inputs.view(bs, 1, 128, 128)
        top = model.upward(inputs)
        outputs = model.downward(top, shortcut=True)
        outputs = outputs.view(bs, 1, 128, 128)

        if i % 20 == 0:
            # Save the first reconstruction in the batch, multiplied by 255 for an 8-bit PNG
            out = outputs[0].view(128, 128).detach().cpu().numpy() * 255
            cv2.imwrite("/home/tk/Documents/recover/SSIM/" + str(epo) + "_" + str(i) + "_re.png", out)

        # Negated SSIM, as in the author's example: minimizing -SSIM maximizes SSIM
        loss = - criterion(inputs, outputs)
        ssim_value = - loss.item()
        print(ssim_value)
        loss.backward()
        optimizer.step()

However, the results were not what I expected: after the first 10 epochs, the saved output images were all black.

loss = - criterion(inputs, outputs) is what the author proposes. However, classical PyTorch training code uses loss = criterion(y_pred, target), so I thought it should be loss = criterion(inputs, outputs) here.

I also tried loss = criterion(inputs, outputs), but the results were the same.

Can anyone share some thoughts about how to properly utilize SSIM loss? Thanks.

asked Sep 06 '25 20:09 by sealpuppy


1 Answer

The usual way to transform a similarity (higher is better) into a loss is to compute 1 - similarity(x, y).
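Since the 1 only shifts the value, 1 - similarity and the author's -similarity produce exactly the same gradient; the difference is that 1 - similarity is 0 for a perfect match. Using the similarity directly as the loss (loss = criterion(...) without negation) flips the gradient, so the optimizer minimizes similarity instead. A quick check, using mean cosine similarity as a stand-in for SSIM (an assumption made only to keep the sketch self-contained; similarity and gradient are hypothetical helper names):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
target = torch.rand(4, 16)
start = torch.rand(4, 16)


def similarity(x, y):
    # Stand-in for SSIM: mean cosine similarity, higher means more alike
    return F.cosine_similarity(x, y, dim=1).mean()


def gradient(loss_fn):
    # Fresh leaf tensor on every call, so gradients never accumulate across losses
    pred = start.clone().requires_grad_()
    loss_fn(pred, target).backward()
    return pred.grad


g_neg = gradient(lambda p, t: -similarity(p, t))            # the author's -criterion
g_one_minus = gradient(lambda p, t: 1. - similarity(p, t))  # 1 - similarity
g_raw = gradient(similarity)                                # similarity used as the loss

print(torch.allclose(g_neg, g_one_minus))  # True: same gradient, only a constant offset
print(torch.allclose(g_raw, -g_neg))       # True: the raw similarity pushes the other way
```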

To create this loss, you can define a new function:

def ssim_loss(x, y):
    return 1. - ssim(x, y)
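The function above assumes an ssim function is already in scope. To make what it computes concrete, here is a minimal, self-contained single-scale SSIM sketch in plain PyTorch. It assumes (N, C, H, W) images scaled to [0, 1], an 11x11 Gaussian window with sigma 1.5, and zero padding at the borders; these are common defaults, not necessarily what the linked repo or piqa use, and gaussian_window is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F


def gaussian_window(size: int = 11, sigma: float = 1.5) -> torch.Tensor:
    # 1D Gaussian normalized to sum to 1; the outer product gives a 2D kernel
    coords = torch.arange(size, dtype=torch.float32) - (size - 1) / 2
    g = torch.exp(-coords ** 2 / (2 * sigma ** 2))
    g = g / g.sum()
    return torch.outer(g, g)


def ssim(x: torch.Tensor, y: torch.Tensor, window_size: int = 11,
         value_range: float = 1.0) -> torch.Tensor:
    """Mean SSIM of two (N, C, H, W) batches with pixel values in [0, value_range]."""
    channels = x.shape[1]
    # One copy of the window per channel, applied as a depthwise (grouped) convolution
    window = gaussian_window(window_size).to(x)[None, None].repeat(channels, 1, 1, 1)

    def blur(t: torch.Tensor) -> torch.Tensor:
        return F.conv2d(t, window, padding=window_size // 2, groups=channels)

    # Stabilizing constants from the original SSIM paper (K1 = 0.01, K2 = 0.03)
    c1 = (0.01 * value_range) ** 2
    c2 = (0.03 * value_range) ** 2

    # Local means, variances and covariance under the Gaussian window
    mu_x, mu_y = blur(x), blur(y)
    var_x = blur(x * x) - mu_x ** 2
    var_y = blur(y * y) - mu_y ** 2
    cov_xy = blur(x * y) - mu_x * mu_y

    ssim_map = ((2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)) / (
        (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    )
    return ssim_map.mean()


def ssim_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # SSIM equals 1 for identical images, so this loss is 0 for a perfect match
    return 1. - ssim(x, y)
```

Because every operation is differentiable, ssim_loss can be returned from a training step and back-propagated like any other PyTorch loss.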

Alternatively, if the similarity is a class (nn.Module), you can subclass it and override its forward method to create the loss.

class SSIMLoss(SSIM):
    def forward(self, x, y):
        return 1. - super().forward(x, y)
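The same subclass-and-override pattern works for any similarity module. As a runnable stand-in (an assumption, since no SSIM module is defined here), this sketch subclasses PyTorch's nn.CosineSimilarity; CosineSimilarityLoss is a hypothetical name. Unlike an SSIM module with mean reduction, nn.CosineSimilarity returns one value per row, so the sketch averages before subtracting from 1.

```python
import torch
import torch.nn as nn


class CosineSimilarityLoss(nn.CosineSimilarity):
    """1 - cosine similarity, built by subclassing and overriding forward."""

    def forward(self, x, y):
        # super().forward returns one similarity per row; average it to a scalar loss
        return 1. - super().forward(x, y).mean()


criterion = CosineSimilarityLoss(dim=1)  # constructor arguments pass through unchanged
pred = torch.rand(4, 16, requires_grad=True)
target = torch.rand(4, 16)

loss = criterion(pred, target)
loss.backward()
print(loss.item(), pred.grad.shape)
```

Keeping the parent's constructor means all of its options (here dim and eps) remain available on the loss.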

Also, there are better implementations of SSIM than the one in this repo. For example, the one in the piqa Python package is faster. The package can be installed with

pip install piqa

For your problem

from piqa import SSIM

class SSIMLoss(SSIM):
    def forward(self, x, y):
        return 1. - super().forward(x, y)

criterion = SSIMLoss(n_channels=1) # grayscale (N, 1, H, W) inputs; piqa defaults to 3 channels. Add .cuda() if you need GPU support

...
loss = criterion(x, y)
...

should work well.

answered Sep 08 '25 10:09 by Donshel