 

Pytorch: Assign values from one mask to another, masked by itself

Tags:

python

pytorch

I have a mask active that tracks batches that have not yet terminated in a recurrent process. Its shape is [batch_full,], and its True entries mark which elements still need to be used in the current step. The recurrent process generates another mask, terminated, which has as many elements as there are True values in active. Now I want to take values from ~terminated and write them back into active at the right indices. Basically I want to do:

import torch

active = torch.ones([4,], dtype=torch.bool)
active[:2] = False

terminated = torch.tensor([True, False])

active[active] = ~terminated

print(active)  # expected [F, F, F, T]

However, I get an error:

RuntimeError: unsupported operation: some elements of the input tensor and the written-to tensor refer to a single memory location. Please clone() the tensor before performing the operation.

How can I perform the operation described above efficiently?

Asked Sep 01 '25 by Boschie

1 Answer

There are a few solutions; I will also give their speed as measured by timeit (10k repetitions, on a 2021 MacBook Pro).

The simplest solution, taking 0.260s:

active[active.clone()] = ~terminated  # clone() breaks the read/write aliasing
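Cloning the index tensor works because the read (the index) and the write (the assignment target) no longer share memory. A minimal end-to-end check on the toy data from the question:

```python
import torch

# Batches 0 and 1 already finished; 2 and 3 are still active.
active = torch.ones([4], dtype=torch.bool)
active[:2] = False

# One terminated flag per still-active batch.
terminated = torch.tensor([True, False])

# Indexing with a clone avoids reading and writing the same storage.
active[active.clone()] = ~terminated

print(active.tolist())  # [False, False, False, True]
```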

We can use the masked_scatter_ in-place operation for about a 2x speedup (0.136s):

active.masked_scatter_(active, ~terminated)
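masked_scatter_ copies elements of the source tensor into self at the positions where the mask is True; here the mask is active itself. A runnable sketch on the same toy data:

```python
import torch

active = torch.ones([4], dtype=torch.bool)
active[:2] = False

terminated = torch.tensor([True, False])

# Copy ~terminated into `active` wherever `active` is True.
active.masked_scatter_(active, ~terminated)

print(active.tolist())  # [False, False, False, True]
```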

The out-of-place variant, taking 0.161s, would be:

active = torch.masked_scatter(active, active, ~terminated)
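The timings above can be reproduced with a small timeit harness (a sketch; the function names are my own, and absolute numbers will vary by hardware and PyTorch version):

```python
import timeit
import torch

def setup():
    # Same toy data as above: two finished batches, two active ones.
    active = torch.ones([4], dtype=torch.bool)
    active[:2] = False
    terminated = torch.tensor([True, False])
    return active, terminated

def clone_index():
    active, terminated = setup()
    active[active.clone()] = ~terminated

def inplace_scatter():
    active, terminated = setup()
    active.masked_scatter_(active, ~terminated)

def outofplace_scatter():
    active, terminated = setup()
    active = torch.masked_scatter(active, active, ~terminated)

for fn in (clone_index, inplace_scatter, outofplace_scatter):
    t = timeit.timeit(fn, number=10_000)
    print(f"{fn.__name__}: {t:.3f}s")
```

Note that this times the setup together with the operation, so the relative gap between the variants is smaller than when timing the operation alone.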
Answered Sep 03 '25 by Boschie