 

Custom Operations on Multi-dimensional Tensors

I am trying to compute the tensor R (see image), and the only way I could explain what I am trying to compute is by working it out on paper:

[image: hand-worked computation of R from O, P, O', and P']

o = torch.tensor([[[1, 3, 2], [7, 9, 8], [13, 15, 14], [19, 21, 20], [25, 27, 26]], [[31, 33, 32], [37, 39, 38], [43, 45, 44], [49, 51, 50], [55, 57, 56]]])
p = torch.tensor([[[19, 21, 20], [7, 9, 8], [13, 15, 14], [1, 3, 2], [25, 27, 26]], [[55, 57, 56], [31, 33, 32], [37, 39, 38], [43, 45, 44], [49, 51, 50]]])

# this is O' in image
o_prime = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 0.11]])

# this is P' in image
p_prime = torch.tensor([[1.1, 1.2, 1.3, 1.4, 1.5], [1.6, 1.7, 1.8, 1.9, 1.11]])

# this is R (this is what I need)
r = torch.tensor([[[0, 0, 0, 6.1, 0], [0, 24.2, 0, 0, 0], [0, 0, 42.3, 0, 0], [60.4, 0, 0, 0, 0], [0, 0, 0, 0, 78.5]], [[0, 96.6, 0, 0, 0], [0, 0, 114.7, 0, 0], [0, 0, 0, 132.8, 0], [0, 0, 0, 0, 150.9], [168.11, 0, 0, 0, 0]]])

How do I get R without looping over tensors?

Correction: in the image, I forgot to add the value of p' to sum(o) + o'.

Sam asked Oct 31 '25 19:10


1 Answer

You can construct a helper tensor containing the resulting values sum(o) + o' + p':

>>> v = o.sum(2, True) + o_prime[..., None] + p_prime[..., None]
>>> v
tensor([[[  7.2000],
         [ 25.4000],
         [ 43.6000],
         [ 61.8000],
         [ 80.0000]],

        [[ 98.2000],
         [116.4000],
         [134.6000],
         [152.8000],
         [169.2200]]])

Then you can assemble a mask for the final tensor via broadcasting:

>>> eq = o[:,None] == p[:,:,None]

Ensuring all three elements on the last dimension match:

>>> eq.all(dim=-1)
tensor([[[False, False, False,  True, False],
         [False,  True, False, False, False],
         [False, False,  True, False, False],
         [ True, False, False, False, False],
         [False, False, False, False,  True]],

        [[False, False, False, False,  True],
         [ True, False, False, False, False],
         [False,  True, False, False, False],
         [False, False,  True, False, False],
         [False, False, False,  True, False]]])

Finally, you can simply multiply both tensors, and broadcasting will handle the rest:

>>> R = eq.all(dim=-1) * v
>>> R
tensor([[[  0.0000,   0.0000,   0.0000,   7.2000,   0.0000],
         [  0.0000,  25.4000,   0.0000,   0.0000,   0.0000],
         [  0.0000,   0.0000,  43.6000,   0.0000,   0.0000],
         [ 61.8000,   0.0000,   0.0000,   0.0000,   0.0000],
         [  0.0000,   0.0000,   0.0000,   0.0000,  80.0000]],

        [[  0.0000,   0.0000,   0.0000,   0.0000,  98.2000],
         [116.4000,   0.0000,   0.0000,   0.0000,   0.0000],
         [  0.0000, 134.6000,   0.0000,   0.0000,   0.0000],
         [  0.0000,   0.0000, 152.8000,   0.0000,   0.0000],
         [  0.0000,   0.0000,   0.0000, 169.2200,   0.0000]]])
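Putting the steps together, here is a self-contained sketch of the whole computation (with the p' term included, per the correction in the question):

```python
import torch

o = torch.tensor([[[1, 3, 2], [7, 9, 8], [13, 15, 14], [19, 21, 20], [25, 27, 26]],
                  [[31, 33, 32], [37, 39, 38], [43, 45, 44], [49, 51, 50], [55, 57, 56]]])
p = torch.tensor([[[19, 21, 20], [7, 9, 8], [13, 15, 14], [1, 3, 2], [25, 27, 26]],
                  [[55, 57, 56], [31, 33, 32], [37, 39, 38], [43, 45, 44], [49, 51, 50]]])
o_prime = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 0.11]])
p_prime = torch.tensor([[1.1, 1.2, 1.3, 1.4, 1.5], [1.6, 1.7, 1.8, 1.9, 1.11]])

# Precompute sum(o) + o' + p' for each row: shape (2, 5, 1)
v = o.sum(2, keepdim=True) + o_prime[..., None] + p_prime[..., None]

# Outer equality check between rows of o and rows of p: shape (2, 5, 5)
mask = (o[:, None] == p[:, :, None]).all(dim=-1)

# Broadcasting (2, 5, 5) * (2, 5, 1) places each row's value at the matching column
R = mask * v
```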

I wanted to know: how do you visualize such problems and then come up with a solution? Any pointers would be beneficial.

I was going to say it depends a lot on the problem at hand, but that wouldn't get you very far! I believe having a toolbox of functions/tricks and scenarios you've come across (i.e. experience) helps greatly. This is true for problem-solving more generally. I can try to explain how I came up with the solution and my thought process behind it. The initial idea for this problem is to perform an outer equality check between o and p. By that I mean we are trying to construct a structure which evaluates every (o[i], p[j]) pair batch-wise.

It turns out this is a rather common operation, usually seen as an outer summation or outer product. In fact, this type of operation is also applicable to the equality operator: here we are looking to construct a 5x5 matrix of o[i] == p[j]. Keep in mind that throughout the process we have a trailing dimension containing three elements, but that doesn't change the approach. We just need to account for it by checking that all three comparisons are indeed True, hence the all(dim=-1) call.
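As a minimal illustration of the outer-comparison trick, here is the 1-D analogue (toy values, not from the question):

```python
import torch

a = torch.tensor([1, 2, 3, 4])
b = torch.tensor([3, 1, 4, 2])

# a[None, :] has shape (1, 4), b[:, None] has shape (4, 1);
# broadcasting compares every pair (b[i], a[j]) into a (4, 4) matrix.
pairwise = b[:, None] == a[None, :]
```

Row i of `pairwise` tells you which entries of `a` equal `b[i]`; the answer's `o[:,None] == p[:,:,None]` is the same idea with a batch dimension in front and the three-element dimension at the back.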

Since the desired result doesn't depend on the column position inside the mask, i.e. result = sum(o) + o' + p' whatever the column index, we can just precompute the result for each row beforehand. The final operation is simply multiplying the mask (which of course is only True at the desired locations) with the column vector of precomputed values. Intuitively, all columns get multiplied by the same value, but only the True entries allow the value to come through.
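The mask-times-column-vector step can be seen in isolation (toy values, not from the question):

```python
import torch

mask = torch.tensor([[False, True, False],
                     [True, False, False],
                     [False, False, True]])
vals = torch.tensor([[10.0], [20.0], [30.0]])  # one precomputed value per row

# Broadcasting stretches the (3, 1) column across all three columns;
# multiplying by the boolean mask keeps each row's value only where True.
out = mask * vals
```

Each row of `out` contains its row value at the masked position and zeros elsewhere, which is exactly how v is scattered into R.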

But most importantly, we have to acknowledge that your figure did all the hard work. This is in my opinion the first step before starting with any reasoning or implementation. So to summarize, I would suggest:

  1. Start with a minimal example: reduce the number of variables while keeping it representative of the actual problem you are trying to solve.

  2. Work toward the solution step by step, iterating and experimenting with this minimal setup until you get closer and closer to the result.

Most importantly, it comes with practice: with time you will find it easier to reason about your problem and to pick the right tools to manipulate your data.

Ivan answered Nov 02 '25 08:11