PyTorch: mask missing values when calculating RMSE

Tags:

pytorch

I'm trying to calculate the RMSE between two torch tensors. I would like to ignore (mask) the rows where the labels are 0, which mark missing values. How could I modify this line to take that restriction into account?

torch.sqrt(((preds.detach() - labels) ** 2).mean()).item()

Thank you in advance.

asked Oct 13 '25 at 05:10 by razvanc92


1 Answer

This can be solved by defining a custom MSE loss function* that masks out the positions where the target is missing (0 in your case) in both the input and the target tensors:

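import torch

def mse_loss_with_nans(input, target):

    # If missing data are NaNs, use this mask instead:
    # mask = torch.isnan(target)

    # Missing data are 0s: mask is True where the label is missing
    mask = target == 0

    # Squared errors over the observed (unmasked) entries only
    out = (input[~mask] - target[~mask]) ** 2
    # Mean over the observed entries (NaN if every label is missing)
    loss = out.mean()

    return loss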
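A minimal sketch of the full RMSE computation the question asks for, assuming the labels use 0 to mark missing entries and that preds and labels have the same shape. The helper name masked_rmse and its missing_value parameter are hypothetical, not part of PyTorch; it applies the same boolean mask as the function above and then takes the square root, mirroring the question's one-liner.

```python
import torch


def masked_rmse(preds: torch.Tensor, labels: torch.Tensor, missing_value: float = 0.0) -> float:
    """Hypothetical helper: RMSE over entries whose label is not `missing_value`."""
    # True where the label is present (observed)
    observed = labels != missing_value
    if not observed.any():
        # Every label is missing: return NaN explicitly instead of
        # relying on mean() of an empty tensor.
        return float("nan")
    # Index both tensors with the same mask, then square, average, sqrt
    diff = preds.detach()[observed] - labels[observed]
    return torch.sqrt((diff ** 2).mean()).item()


preds = torch.tensor([1.0, 2.0, 3.0, 4.0])
labels = torch.tensor([1.0, 0.0, 5.0, 0.0])  # the two 0s are missing values

print(masked_rmse(preds, labels))  # ≈ 1.4142, i.e. sqrt((0**2 + (-2)**2) / 2)
```

Boolean indexing flattens the selected entries into a 1-D tensor, so this averages over all observed elements; if a whole row should be dropped only when its entire label row is missing, the mask would have to be built per row instead.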

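(*) Minimising MSE gives the same optimum as minimising RMSE, because the square root is monotonically increasing, and skipping the square root makes MSE slightly cheaper to compute. To report the RMSE value itself, as in the question, take the square root of the result: torch.sqrt(mse_loss_with_nans(preds.detach(), labels)).item()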
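To illustrate the optimisation claim: by the chain rule, the gradient of sqrt(MSE) equals the MSE gradient divided by 2 * RMSE, so the two losses push the predictions in the same direction and differ only by a positive scale (which does change the effective step size under plain SGD). The sketch below checks this with autograd on made-up labels, assuming 0 marks a missing label; the tensors are illustrative only.

```python
import torch

torch.manual_seed(0)

labels = torch.tensor([1.0, 0.0, 5.0, 0.0, 2.5])  # 0 marks a missing label
observed = labels != 0

# Two copies of the same predictions so each loss gets its own .grad
preds_mse = torch.randn(5, requires_grad=True)
preds_rmse = preds_mse.detach().clone().requires_grad_(True)

mse = ((preds_mse[observed] - labels[observed]) ** 2).mean()
rmse = torch.sqrt(((preds_rmse[observed] - labels[observed]) ** 2).mean())

mse.backward()
rmse.backward()

# Chain rule: d sqrt(m)/dx = (dm/dx) / (2 * sqrt(m)), so the gradients
# differ only by the positive factor 1 / (2 * rmse).
scaled = preds_mse.grad / (2 * rmse.detach())
print(torch.allclose(scaled, preds_rmse.grad))  # True

# Masked-out (missing) positions receive zero gradient under both losses
print(preds_mse.grad[~observed], preds_rmse.grad[~observed])
```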

answered Oct 16 '25 at 12:10 by prl900


