Why do we "pack" the sequences in PyTorch?
I was trying to replicate How to use packing for variable-length sequence inputs for rnn but I guess I first need to understand why we need to "pack" the sequence.
I understand why we need to "pad" them, but why is "packing" (through pack_padded_sequence) necessary?
Any high-level explanation would be appreciated!
Solution 1:
I have stumbled upon this problem too and below is what I figured out.
When training an RNN (LSTM or GRU or vanilla-RNN), it is difficult to batch variable-length sequences. For example: if the lengths of the sequences in a batch of size 8 are [4,6,8,5,4,3,7,8], you will pad all the sequences, and that will result in 8 sequences of length 8. You would end up doing 64 computations (8x8), but you needed to do only 45. Moreover, if you wanted to do something fancy like using a bidirectional-RNN, it would be harder to do batch computations just by padding, and you might end up doing more computations than required.
Instead, PyTorch allows us to pack the sequence. Internally, a packed sequence is a tuple of two lists: one contains the elements of the sequences, interleaved by time step (see the example below), and the other contains the batch size at each time step, i.e. how many sequences are still active at that step. This is helpful both for recovering the actual sequences and for telling the RNN what the batch size is at each time step. This has been pointed out by @Aerin. The packed sequence can be passed to the RNN, and it will internally optimize the computations.
I might have been unclear at some points, so let me know and I can add more explanations.
Here's a code example:
import torch

a = [torch.tensor([1, 2, 3]), torch.tensor([3, 4])]

# pad both sequences to the length of the longest one (3)
b = torch.nn.utils.rnn.pad_sequence(a, batch_first=True)
>>>> tensor([[1, 2, 3],
             [3, 4, 0]])

# pack the padded batch; the elements are interleaved by time step
torch.nn.utils.rnn.pack_padded_sequence(b, lengths=[3, 2], batch_first=True)
>>>> PackedSequence(data=tensor([1, 3, 2, 4, 3]), batch_sizes=tensor([2, 2, 1]))
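If you then want to run the packed batch through an RNN, here is a minimal sketch (the LSTM sizes below are illustrative, not part of the original answer) showing that the PackedSequence can be fed to the layer directly and unpacked afterwards with pad_packed_sequence:

import torch

a = [torch.tensor([1., 2., 3.]), torch.tensor([3., 4.])]
b = torch.nn.utils.rnn.pad_sequence(a, batch_first=True)              # shape (2, 3)
packed = torch.nn.utils.rnn.pack_padded_sequence(
    b.unsqueeze(-1), lengths=[3, 2], batch_first=True)                # input_size = 1

lstm = torch.nn.LSTM(input_size=1, hidden_size=4, batch_first=True)
packed_out, (h_n, c_n) = lstm(packed)      # padded positions are never computed

# recover a padded (2, 3, 4) output plus the original lengths
out, lengths = torch.nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
print(out.shape, lengths)                  # torch.Size([2, 3, 4]) tensor([3, 2])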
Solution 2:
Here are some visual explanations1 that might help to develop better intuition for the functionality of pack_padded_sequence().
TL;DR: It is performed primarily to save compute. Consequently, the time required for training neural network models is also (drastically) reduced, especially when carried out on very large (a.k.a. web-scale) datasets.
Let's assume we have 6 sequences (of variable lengths) in total. You can also consider this number 6 as the batch_size hyperparameter. (The effective batch size will vary with the time step, depending on the sequence lengths; cf. Fig. 2 below.)
Now, we want to pass these sequences to some recurrent neural network architecture(s). To do so, we have to pad all of the sequences (typically with 0s) in our batch to the maximum sequence length in our batch (max(sequence_lengths)), which in the below figure is 9.
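For concreteness, here is a minimal sketch of that setup; the sequence contents are random placeholders, and only the six lengths (with a maximum of 9, chosen to match the operation counts further below) mirror the example:

import torch

seqs = [torch.randn(n) for n in (9, 8, 6, 4, 3, 2)]    # 6 variable-length sequences
padded_batch_of_sequences = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True)
print(padded_batch_of_sequences.shape)                 # torch.Size([6, 9])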
So, the data preparation work should be complete by now, right? Not really, because there is still one pressing problem: how much compute we have to do compared with the computation that is actually required.
For the sake of understanding, let's also assume that we will matrix multiply the above padded_batch_of_sequences of shape (6, 9) with a weight matrix W of shape (9, 3).
Thus, we will have to perform 6x9 = 54 multiplication and 6x8 = 48 addition operations (nrows x ncols multiplications and nrows x (ncols - 1) additions, counted per output column of the product), only to throw away most of the computed results, since they would be 0s (where we have pads). The actually required compute in this case is as follows:
9-mult 8-add
8-mult 7-add
6-mult 5-add
4-mult 3-add
3-mult 2-add
2-mult 1-add
---------------
32-mult 26-add
------------------------------
#savings: 22-mult & 22-add ops
(54-32) (48-26)
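A quick Python sanity check of this arithmetic (per output column of the product; the per-sequence lengths are read off the rows above):

lengths = [9, 8, 6, 4, 3, 2]
padded_mults, padded_adds = 6 * 9, 6 * 8               # 54, 48
packed_mults = sum(lengths)                            # 32
packed_adds = sum(n - 1 for n in lengths)              # 26
print(padded_mults - packed_mults, padded_adds - packed_adds)   # 22 22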
That's a LOT of savings even for this very simple (toy) example. You can now imagine how much compute (and eventually: cost, energy, time, carbon emissions, etc.) can be saved by using pack_padded_sequence() for large tensors with millions of entries, with a million-plus systems all over the world doing this, again and again.
The functionality of pack_padded_sequence() can be understood from the figure below, with the help of the color-coding used:
As a result of using pack_padded_sequence(), we will get a tuple of tensors containing (i) the flattened (along axis-1, in the above figure) sequences and (ii) the corresponding batch sizes, tensor([6,6,5,4,3,3,2,2,1]) for the above example.
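You can reproduce that batch_sizes tensor with a small sketch (reusing the same six lengths as above; the sequence values themselves are placeholders):

import torch

seqs = [torch.randn(n) for n in (9, 8, 6, 4, 3, 2)]
padded = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True)       # shape (6, 9)
packed = torch.nn.utils.rnn.pack_padded_sequence(
    padded, lengths=[9, 8, 6, 4, 3, 2], batch_first=True)
print(packed.batch_sizes)   # tensor([6, 6, 5, 4, 3, 3, 2, 2, 1])
print(packed.data.shape)    # torch.Size([32]) -- only the non-pad elements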
The data tensor (i.e. the flattened sequences) could then be passed to objective functions such as CrossEntropy for loss calculations.
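Here is a minimal, self-contained sketch of that idea (the layer sizes, feature dimension, and targets are made up for illustration): the per-time-step loss is computed directly on the packed data, so padded positions never contribute to it:

import torch

lengths = [9, 8, 6, 4, 3, 2]
seqs = [torch.randn(n, 5) for n in lengths]               # feature size 5 per time step
targets = [torch.randint(0, 3, (n,)) for n in lengths]    # one of 3 classes per time step

packed_in = torch.nn.utils.rnn.pack_sequence(seqs)        # lengths already sorted
packed_tgt = torch.nn.utils.rnn.pack_sequence(targets)    # same interleaving as packed_in

lstm = torch.nn.LSTM(input_size=5, hidden_size=8)
proj = torch.nn.Linear(8, 3)

packed_out, _ = lstm(packed_in)                           # packed_out.data: (32, 8)
loss = torch.nn.functional.cross_entropy(proj(packed_out.data), packed_tgt.data)
print(loss.item())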
1 image credits to @sgrvinod