I have a mask active that tracks which batch elements have not yet terminated in a recurrent process. Its shape is [batch_full,], and its True entries mark the elements that still need to be processed in the current step. The recurrent process generates another mask, terminated, which has as many elements as there are True values in active. Now I want to take the values from ~terminated and put them back into active at the right indices. Basically, I want to do:
import torch

active = torch.ones([4], dtype=torch.bool)
active[:2] = False  # first two batch elements have already finished
terminated = torch.tensor([True, False])
active[active] = ~terminated  # raises the RuntimeError below
print(active)  # expected: [False, False, False, True]
However, I get this error:
RuntimeError: unsupported operation: some elements of the input tensor and the written-to tensor refer to a single memory location. Please clone() the tensor before performing the operation.
How can I perform the operation described above efficiently?
The error occurs because active serves as both the index tensor and the assignment target, so the read and the write refer to overlapping memory. There are a few ways around this; I also give their speeds as measured with timeit (10k repetitions, on a 2021 MacBook Pro).
The simplest solution, taking 0.260 s, is to clone the index mask so it no longer shares memory with the tensor being written to:
active[active.clone()] = ~terminated
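With the clone, the assignment succeeds and prints tensor([False, False, False, True]) as expected.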
We can use the in-place masked_scatter_ operation for about a 2x speedup (0.136 s):
active.masked_scatter_(
    active,
    ~terminated,
)
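Here masked_scatter_(mask, source) copies elements from source into self at the positions where mask is True, which matches the semantics of the original boolean-mask assignment.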
The out-of-place variant, taking 0.161 s, would be:
active = torch.masked_scatter(
    active,
    active,
    ~terminated,
)
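For reference, here is a minimal sketch of how such timings could be reproduced with timeit. The helper names are mine, not from the original measurements, and this version times the tensor setup together with the operation, so expect the absolute numbers to differ:

import timeit

import torch

def make_inputs():
    # Same tensors as in the question; rebuilt on each call so the
    # in-place variants always start from a fresh mask.
    active = torch.ones([4], dtype=torch.bool)
    active[:2] = False
    terminated = torch.tensor([True, False])
    return active, terminated

def clone_assign():
    active, terminated = make_inputs()
    active[active.clone()] = ~terminated

def inplace_scatter():
    active, terminated = make_inputs()
    active.masked_scatter_(active, ~terminated)

def outofplace_scatter():
    active, terminated = make_inputs()
    # Rebinding mirrors the out-of-place snippet above.
    active = torch.masked_scatter(active, active, ~terminated)

for fn in (clone_assign, inplace_scatter, outofplace_scatter):
    print(fn.__name__, timeit.timeit(fn, number=10_000))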