
Recover a sorted tensor in PyTorch

Problem

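When I use torch.nn.utils.rnn.pad_sequence to pad word sequences and feed the padded batch into an LSTM/RNN, the batch has to be sorted by length (torch.nn.utils.rnn.pack_padded_sequence expects decreasing lengths by default). But sorting changes the order of the samples, which makes evaluation harder because the outputs no longer line up with the original inputs and labels. So here is a way to restore the original order of a sorted tensor using only PyTorch functions.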
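To make the problem concrete, here is a minimal sketch of that pipeline. It assumes a hypothetical batch of three random sequences (lengths 3, 5 and 2, with 4-dimensional steps) and a single-layer LSTM with hidden size 6; none of these sizes come from the original setup. It sorts by length, pads, packs, runs the LSTM, and then restores the original batch order with the index-sorting trick explained in the rest of this post.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence

torch.manual_seed(0)

# Hypothetical batch: three sequences of lengths 3, 5 and 2, each step a 4-dim vector.
seqs = [torch.randn(3, 4), torch.randn(5, 4), torch.randn(2, 4)]
lengths = torch.tensor([len(s) for s in seqs])

# Sort by length, longest first, as pack_padded_sequence expects by default.
sorted_lengths, idx = torch.sort(lengths, descending=True)
sorted_seqs = [seqs[i] for i in idx.tolist()]

padded = pad_sequence(sorted_seqs, batch_first=True)  # (batch=3, max_len=5, 4)
packed = pack_padded_sequence(padded, sorted_lengths, batch_first=True)

lstm = nn.LSTM(input_size=4, hidden_size=6, batch_first=True)
packed_out, (h_n, c_n) = lstm(packed)
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)  # still in sorted order

# Undo the sort: sorting idx yields the indices that restore the original order.
_, rev_idx = torch.sort(idx)
out_original = out[rev_idx]              # (batch, max_len, hidden)
h_original = h_n[:, rev_idx]             # h_n is (num_layers, batch, hidden)
lengths_original = out_lengths[rev_idx]  # equals `lengths`
```

Newer PyTorch releases can also do this bookkeeping internally via pack_padded_sequence(..., enforce_sorted=False), which returns outputs and hidden states in the original order; the manual version above simply makes the permutation explicit.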

Let’s go

```python
import torch
x = torch.randn(10)
```

Here x is tensor([-0.4321, 0.3852, 0.6008, 0.8452, -0.4709, 0.7610, -0.9743, -0.9819, -1.1142, -0.1249]) (your values will differ, since torch.randn draws random numbers). Then we sort it.

```python
sorted_x, idx = torch.sort(x)  # ascending order by default
```

Here idx holds, for each position of sorted_x, the index of that element in x: tensor([8, 7, 6, 4, 0, 9, 1, 2, 5, 3]). We can then get the original order back just by sorting idx.
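Before undoing the sort, a quick sanity check of what idx means may help. This sketch hard-codes the x values printed above so the output is reproducible, and shows that gathering x with idx gives exactly sorted_x, i.e. sorted_x[i] == x[idx[i]] for every i.

```python
import torch

x = torch.tensor([-0.4321, 0.3852, 0.6008, 0.8452, -0.4709,
                  0.7610, -0.9743, -0.9819, -1.1142, -0.1249])
sorted_x, idx = torch.sort(x)

# idx[i] is the position in x of the i-th smallest element,
# so indexing x with idx reproduces sorted_x.
print(idx)                            # tensor([8, 7, 6, 4, 0, 9, 1, 2, 5, 3])
print(torch.equal(x[idx], sorted_x))  # True
```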

```python
_, rev_idx = torch.sort(idx)  # sort the indices themselves
sorted_x[rev_idx]             # reorder sorted_x back into the order of x
```

Evaluating sorted_x[rev_idx] gives tensor([-0.4321, 0.3852, 0.6008, 0.8452, -0.4709, 0.7610, -0.9743, -0.9819, -1.1142, -0.1249]), which equals the original x. Neat, isn't it? Next, I'll show why it works.
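Wrapped as reusable helpers, a minimal sketch might look like the following; the names restore_order and restore_order_scatter are hypothetical. The first uses the index-sorting trick from this post. The second is a different technique, not from the original post: it writes each sorted value back to the slot idx says it came from, which avoids the second sort. Both assume a 1-D tensor sorted along dim 0.

```python
import torch


def restore_order(sorted_x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Undo a 1-D torch.sort with the trick from this post: sort the indices."""
    _, rev_idx = torch.sort(idx)
    return sorted_x[rev_idx]


def restore_order_scatter(sorted_x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Alternative technique: put each sorted value back at its original position."""
    out = torch.empty_like(sorted_x)
    out[idx] = sorted_x  # sorted_x[i] came from x[idx[i]], so write it back there
    return out


x = torch.randn(10)
sorted_x, idx = torch.sort(x)
print(torch.equal(restore_order(sorted_x, idx), x))          # True
print(torch.equal(restore_order_scatter(sorted_x, idx), x))  # True
```

The scatter-style version does linear work instead of a second O(n log n) sort, while the sort-based version reuses only torch.sort and plain indexing.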

Mathematical Explanation

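Suppose our tensor has $n$ elements, $x = (x_0, x_1, \dots, x_{n-1})$, and write it as an $n$-permutation that pairs each index with its element:

$$\begin{pmatrix} 0 & 1 & \cdots & n-1 \\ x_0 & x_1 & \cdots & x_{n-1} \end{pmatrix}$$

Sorting by value rearranges the columns and gives a new permutation

$$\begin{pmatrix} p_0 & p_1 & \cdots & p_{n-1} \\ x_{p_0} & x_{p_1} & \cdots & x_{p_{n-1}} \end{pmatrix}, \qquad x_{p_0} \le x_{p_1} \le \cdots \le x_{p_{n-1}}.$$

Here idx corresponds to the vector $p = (p_0, p_1, \dots, p_{n-1})$, and sorted_x to $(x_{p_0}, x_{p_1}, \dots, x_{p_{n-1}})$. The second sort, applied to idx, gives

$$\begin{pmatrix} q_0 & q_1 & \cdots & q_{n-1} \\ p_{q_0} & p_{q_1} & \cdots & p_{q_{n-1}} \end{pmatrix}, \qquad p_{q_0} < p_{q_1} < \cdots < p_{q_{n-1}},$$

and rev_idx corresponds to the vector $q = (q_0, q_1, \dots, q_{n-1})$. The code sorted_x[rev_idx] selects elements of sorted_x using the subscripts from this second permutation, which means it selects the vector $(x_{p_{q_0}}, x_{p_{q_1}}, \dots, x_{p_{q_{n-1}}})$.

Note that $p$ is a permutation of $(0, 1, \dots, n-1)$, so $(p_{q_0}, p_{q_1}, \dots, p_{q_{n-1}})$ is also a permutation of $(0, 1, \dots, n-1)$, and it is strictly increasing. The only increasing arrangement of $0, 1, \dots, n-1$ is the identity, so $p_{q_j} = j$ for all $j$; in other words, rev_idx is the inverse permutation of idx. Finally,

$$(x_{p_{q_0}}, x_{p_{q_1}}, \dots, x_{p_{q_{n-1}}}) = (x_0, x_1, \dots, x_{n-1}) = x,$$

which is the original tensor.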
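The key step, $p_{q_j} = j$, is easy to check numerically. This sketch uses a fresh random tensor of an assumed length n = 10 and asserts that idx[rev_idx] is the identity permutation and, since rev_idx is the inverse of idx, that rev_idx[idx] is the identity too.

```python
import torch

n = 10  # assumed length; any n works
x = torch.randn(n)
sorted_x, idx = torch.sort(x)   # idx plays the role of p
_, rev_idx = torch.sort(idx)    # rev_idx plays the role of q

identity = torch.arange(n)
# Key step of the proof: p_{q_j} = j for every j.
assert torch.equal(idx[rev_idx], identity)
# q is the inverse permutation of p, so q_{p_i} = i as well.
assert torch.equal(rev_idx[idx], identity)
print("both checks passed")
```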