Get each sequence's last item from packed sequence

I am trying to put a packed and padded sequence through a GRU and retrieve the output of the last item of each sequence. Of course I don't mean the item at index -1, but the actual last, non-padded item. We know the lengths of the sequences in advance, so it should be as easy as extracting, for each sequence, the item at index length-1.

I tried the following

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Data
input = torch.Tensor([[[0., 0., 0.],
                       [1., 0., 1.],
                       [1., 1., 0.],
                       [1., 0., 1.],
                       [1., 0., 1.],
                       [1., 1., 0.]],

                      [[1., 1., 0.],
                       [0., 1., 0.],
                       [0., 0., 0.],
                       [0., 1., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.]],

                      [[0., 0., 0.],
                       [1., 0., 0.],
                       [1., 1., 1.],
                       [0., 0., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.]],

                      [[1., 1., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.]]])

lengths = [6, 4, 3, 1]
p = pack_padded_sequence(input, lengths, batch_first=True)

# Forward
gru = torch.nn.GRU(3, 12, batch_first=True)
packed_output, gru_h = gru(p)

# Unpack
output, input_sizes = pad_packed_sequence(packed_output, batch_first=True)

last_seq_idxs = torch.LongTensor([x-1 for x in input_sizes])

last_seq_items = torch.index_select(output, 1, last_seq_idxs) 

print(last_seq_items.size())
# torch.Size([4, 4, 12])

But the shape is not what I expected. I had expected 4x12, i.e. (the last item of each individual sequence) x (hidden size).

I could loop through the batch and build a new tensor containing the items I need, but I was hoping for a built-in, vectorized approach. I fear that manually looping and building the tensor will result in very poor performance.

Bram Vanroy asked Oct 27 '25 16:10


2 Answers

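Instead of the last two operations (building last_seq_idxs and last_seq_items), you could just do last_seq_items = output[torch.arange(4), input_sizes - 1]. Here torch.arange(4) enumerates the batch, and since pad_packed_sequence returns input_sizes as a tensor, input_sizes - 1 gives each sequence's last valid time step.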

I don't think index_select does what you want: it selects the time steps at all of the indices you passed for every sequence in the batch, which is why your output size is [4, 4, 12].

Umang Gupta answered Oct 30 '25 08:10
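To see the difference between the two calls without running a GRU, here is a minimal sketch on a made-up tensor standing in for the padded output (batch 4, max length 6, hidden size 2; the values are arbitrary and chosen only so the selected rows are easy to read). index_select applies the same time-step indices to every sequence, whereas indexing with two tensors pairs batch row i with time step lengths[i] - 1. The rows you actually want sit on the diagonal of the index_select result.

```python
import torch

# Stand-in for the padded GRU output: batch=4, max_len=6, hidden=2.
# output[b, t, h] == b * 12 + t * 2 + h, so selected rows are easy to spot.
output = torch.arange(4 * 6 * 2, dtype=torch.float32).reshape(4, 6, 2)
lengths = torch.tensor([6, 4, 3, 1])

# index_select along dim 1 picks the same time steps for *every* batch
# element, so the result is (batch, len(lengths), hidden) = (4, 4, 2).
selected_all = torch.index_select(output, 1, lengths - 1)

# Advanced indexing pairs row i with time step lengths[i] - 1,
# giving one vector per sequence: (batch, hidden) = (4, 2).
batch_idx = torch.arange(output.size(0))
last_items = output[batch_idx, lengths - 1]

print(selected_all.shape)  # torch.Size([4, 4, 2])
print(last_items.shape)    # torch.Size([4, 2])
```

torch.arange(output.size(0)) is used here instead of a hard-coded torch.arange(4) so the same line works for any batch size.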


A more verbose alternative to Umang Gupta's answer:

# ...
output, input_sizes = pad_packed_sequence(packed_output, batch_first=True)
# One (1, hidden) tensor per sequence: its last actual (non-padded) item, unsqueezed
last_seq = [output[e, i-1, :].unsqueeze(0) for e, i in enumerate(input_sizes)]
# Concatenate all sequences to get a (batch, hidden) tensor
last_seq = torch.cat(last_seq, dim=0)
Bram Vanroy answered Oct 30 '25 08:10
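A minimal end-to-end sketch that checks the loop from the answer above against the one-line indexing version. It assumes a random stand-in batch with the question's shapes (4 sequences, max length 6, 3 features, hidden size 12) rather than the original data. The final check relies on a property of nn.GRU rather than anything stated in the answers: with a packed input, the returned final hidden state h_n holds each sequence's state at its last valid step, so for a single-layer, unidirectional GRU h_n[-1] should match as well.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

# Hypothetical batch mirroring the question's shapes: 4 sequences,
# max length 6, 3 features; lengths sorted in descending order as
# pack_padded_sequence expects by default (enforce_sorted=True).
lengths = [6, 4, 3, 1]
padded = torch.randn(4, 6, 3)
for i, n in enumerate(lengths):
    padded[i, n:] = 0.0  # zero out the padding positions

packed = pack_padded_sequence(padded, lengths, batch_first=True)
gru = torch.nn.GRU(3, 12, batch_first=True)
packed_output, h_n = gru(packed)
output, output_lengths = pad_packed_sequence(packed_output, batch_first=True)

# Loop version: one (1, hidden) slice per sequence, then concatenate.
last_loop = torch.cat(
    [output[i, n - 1, :].unsqueeze(0) for i, n in enumerate(output_lengths)],
    dim=0,
)

# Vectorized version: pair each batch index with its own last time step.
last_vec = output[torch.arange(output.size(0)), output_lengths - 1]

print(last_vec.shape)                      # torch.Size([4, 12])
print(torch.equal(last_loop, last_vec))    # True
print(torch.allclose(last_vec, h_n[-1]))   # True
```

Both versions select the same values; the indexing version simply does it in a single tensor operation instead of a Python-level loop over the batch.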