torch.nn.utils.rnn.pack_padded_sequence()
Here "pack" is best understood as compression: the function compresses a padded batch of variable-length sequences. (Padding introduces redundant entries, so the batch is packed tightly to drop them.)
Note the packing order: with each sequence drawn as a row, packing proceeds column by column (one time step at a time across the batch), not row by row.
The result is a PackedSequence object, made up of data and batch_sizes.
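A minimal sketch of the packing order, assuming a recent PyTorch; the toy values below are illustrative and not from the original post. Three sequences of lengths 3, 2 and 1 are stored time-major (T×B), so each column is one sequence and each row is one time step:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# Time-major layout (T x B): column b is sequence b, row t is time step t.
# Sequences [1, 2, 3], [4, 5] and [6], padded with 0, sorted longest-first.
padded = torch.tensor([[1., 4., 6.],
                       [2., 5., 0.],
                       [3., 0., 0.]])
lengths = [3, 2, 1]

packed = pack_padded_sequence(padded, lengths)

# data takes step 0 of every sequence, then step 1 of those still running, ...
print(packed.data.tolist())         # [1.0, 4.0, 6.0, 2.0, 5.0, 3.0]
# batch_sizes[t] = number of sequences that still have an element at step t.
print(packed.batch_sizes.tolist())  # [3, 2, 1]
```

The padding zeros never reach data, and batch_sizes is all that is needed to tell where one time step ends and the next begins.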
The input can have shape (T×B×*), where T is the length of the longest sequence, B is the batch size, and * stands for any number of trailing dimensions (including 0). If batch_first=True, the input must instead have shape (B×T×*).
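A short sketch of the batch_first=True layout with one trailing feature dimension (so * is a single dimension of size 2); the sizes and fill values are hypothetical, chosen only to make the output easy to check:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# Batch-first layout: B x T x * = 3 sequences, max length 4, 2 features per step.
B, T, F = 3, 4, 2
lengths = [4, 2, 1]  # sorted longest-first
padded = torch.zeros(B, T, F)
for b, n in enumerate(lengths):
    padded[b, :n] = b + 1  # real steps of sequence b hold the value b + 1; padding stays 0

packed = pack_padded_sequence(padded, lengths, batch_first=True)

# The trailing dimension is kept: data has shape (sum(lengths), F).
print(packed.data.shape)            # torch.Size([7, 2])
print(packed.batch_sizes.tolist())  # [3, 2, 1, 1]
print(packed.data[:, 0].tolist())   # [1.0, 2.0, 3.0, 1.0, 2.0, 1.0, 1.0]
```

Even with batch_first=True, data is laid out time step by time step, exactly as in the time-major case.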
The sequences stored in the input Variable must be sorted by length in decreasing order, longest first and shortest last: input[:, 0] is the longest sequence and input[:, B-1] the shortest. (Since PyTorch 1.1, passing enforce_sorted=False removes this requirement; the function then sorts the batch internally.)
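One way to meet the sorting requirement is to sort the lengths yourself and reorder the batch to match. This is a sketch under the assumption of a batch-first toy batch (values are illustrative); the last lines show the enforce_sorted=False shortcut available since PyTorch 1.1:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# Unsorted batch-first input: lengths are 2, 3, 1.
padded = torch.tensor([[1., 2., 0.],
                       [3., 4., 5.],
                       [6., 0., 0.]])
lengths = torch.tensor([2, 3, 1])

# Sort longest-first, then reorder the batch dimension the same way.
sorted_lengths, order = lengths.sort(descending=True)
sorted_padded = padded[order]

packed = pack_padded_sequence(sorted_padded, sorted_lengths, batch_first=True)
print(order.tolist())        # [1, 0, 2]
print(packed.data.tolist())  # [3.0, 1.0, 6.0, 4.0, 2.0, 5.0]

# Alternative (PyTorch 1.1+): let the function sort internally;
# the permutation it applied is recorded in packed2.sorted_indices.
packed2 = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)
print(packed2.sorted_indices.tolist())  # [1, 0, 2]
```

Keep order (or sorted_indices) around if you later need to map results back to the original batch order.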
NOTE:
Any input with at least 2 dimensions can be passed to this function. For example, you can pack the labels too, then compute the loss between the RNN output and the packed labels.
The underlying Variable of a PackedSequence object is available through its .data attribute.
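A sketch of the label-packing idea from the note, assuming a per-time-step classification task; the model sizes, the LSTM plus Linear head, and the random data are hypothetical. Because inputs and labels are packed with the same lengths, they share batch_sizes, so the rows of the two .data tensors line up and padding positions never enter the loss:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)
T, B, F, C = 4, 3, 5, 6  # max length, batch size, input features, classes (illustrative)
lengths = [4, 3, 1]      # sorted longest-first

inputs = torch.randn(T, B, F)          # T x B x F
labels = torch.randint(0, C, (T, B))   # T x B: one class id per time step

rnn = nn.LSTM(input_size=F, hidden_size=8)
classifier = nn.Linear(8, C)

packed_inputs = pack_padded_sequence(inputs, lengths)
packed_labels = pack_padded_sequence(labels, lengths)  # a 2-D input is enough

packed_out, _ = rnn(packed_inputs)

# Same lengths -> same batch_sizes, so .data rows correspond one-to-one.
logits = classifier(packed_out.data)  # (sum(lengths), C)
loss = nn.functional.cross_entropy(logits, packed_labels.data)
print(logits.shape)  # torch.Size([8, 6])
```

Only the sum(lengths) = 8 real time steps contribute to the loss; no masking of padded labels is needed.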
Parameter description:
- input (Variable; a plain Tensor since PyTorch 0.4) – a padded batch of variable-length sequences.
- lengths (list[int]) – the length of each sequence in the Variable.
- batch_first (bool, optional) – if True, input must have shape B×T×*.
Return value:
A PackedSequence object.
torch.nn.utils.rnn.pad_packed_sequence()
Pads a packed sequence.
This is the inverse of pack_padded_sequence(): where that function compresses a padded batch of variable-length sequences, this one pads the packed sequence back out.
The returned Variable has size T×B×*, where T is the length of the longest sequence and B is the batch size. If batch_first=True, the returned size is B×T×* instead.
The elements of the batch are ordered by decreasing length. (If the batch was packed with enforce_sorted=False, the original batch order is restored instead.)
Parameter description:
- sequence (PackedSequence) – the batch to pad.
- batch_first (bool, optional) – if True, the output has shape B×T×*.
Return value: a tuple containing the padded sequence and the length of each sequence in the batch (returned as a tensor in current PyTorch).
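A minimal round-trip sketch, assuming a recent PyTorch and illustrative values: packing and then padding returns the original batch plus the lengths. The padding_value argument (default 0.0) is used at the end to make the padded positions visible:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

padded = torch.tensor([[1., 2., 3.],
                       [4., 5., 0.],
                       [6., 0., 0.]])  # B x T, sorted longest-first
lengths = [3, 2, 1]

packed = pack_padded_sequence(padded, lengths, batch_first=True)
unpacked, out_lengths = pad_packed_sequence(packed, batch_first=True)

print(unpacked.tolist())     # [[1.0, 2.0, 3.0], [4.0, 5.0, 0.0], [6.0, 0.0, 0.0]]
print(out_lengths.tolist())  # [3, 2, 1]

# padding_value controls what fills positions past each sequence's length.
marked, _ = pad_packed_sequence(packed, batch_first=True, padding_value=-1.0)
print(marked.tolist())       # [[1.0, 2.0, 3.0], [4.0, 5.0, -1.0], [6.0, -1.0, -1.0]]
```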
When a PackedSequence object is fed into an RNN, the RNN's output is also a PackedSequence object.
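A sketch of a full pack, RNN, unpack pipeline, assuming a single-layer unidirectional nn.GRU and random illustrative inputs. It checks that the RNN returns a PackedSequence, that unpacking zero-fills past each sequence's length, and that the final hidden state h_n matches the output at each sequence's own last valid step:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
B, T, F, H = 3, 5, 4, 7  # batch, max length, input features, hidden size (illustrative)
lengths = [5, 3, 2]      # sorted longest-first
x = torch.randn(B, T, F)

gru = nn.GRU(input_size=F, hidden_size=H, batch_first=True)
packed_out, h_n = gru(pack_padded_sequence(x, lengths, batch_first=True))
print(isinstance(packed_out, PackedSequence))  # True

out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)
print(out.shape)  # torch.Size([3, 5, 7])

# Sequence 2 has length 2, so its steps 2..4 are zero padding.
print(out[2, 2:].abs().sum().item())  # 0.0

# h_n (num_layers x B x H) holds the state after each sequence's last real step.
last_steps = out[torch.arange(B), torch.tensor(lengths) - 1]
print(torch.allclose(h_n[0], last_steps))  # True
```

This is why packing is useful with RNNs: the recurrence stops at each sequence's true end instead of running over padding.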
References:
https://www.cnblogs.com/lindaxin/p/8052043.html
https://pytorch.org/docs/stable/nn.html?highlight=pack_padded_sequence#torch.nn.utils.rnn.pack_padded_sequence