When using PyTorch's RNN modules, it is sometimes unavoidable to use pack_padded_sequence and pad_packed_sequence, and when using a bidirectional RNN, you must use pack_padded_sequence! Otherwise, PyTorch has no way to know the actual length of each sequence, and the results of the bidirectional RNN/GRU/LSTM will not be computed correctly.
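Before the full solution, here is a minimal sketch of the pack/pad round trip itself (shapes and lengths are made up for illustration): pack a padded batch so the RNN skips the padding positions, then pad the output back.

```python
import torch
from torch import nn

# Toy batch: 2 sequences padded to length 4, feature dim 3.
batch, max_len, dim = 2, 4, 3
x = torch.zeros(batch, max_len, dim)
lengths = [4, 2]  # must already be sorted in descending order here
packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True)
rnn = nn.GRU(input_size=dim, hidden_size=5, batch_first=True, bidirectional=True)
out, _ = rnn(packed)
# Pad the packed output back to a regular [batch, time, features] tensor.
padded, out_lengths = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
print(padded.shape)  # torch.Size([2, 4, 10]) -- 2 * hidden_size for a bidirectional GRU
print(out_lengths)   # tensor([4, 2])
```

Note that positions past each sequence's true length come back as zeros in `padded`, which is exactly the behavior the bidirectional RNN needs.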
However, pack_padded_sequence has one inconvenience: the sequences in the input mini-batch must be ordered from longest to shortest, which is awkward when the order of the samples in the mini-batch matters. For example, each sample might be the character-level representation of a word, and a mini-batch might hold the words of a sentence in their original order.
In this case, if we still want to use pack_padded_sequence, we need to sort the samples in the mini-batch first, and then restore the original order after the RNN/LSTM/GRU is done.
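The restore step relies on an inverse-permutation trick: sorting the sort indices themselves recovers the original order. It can be sketched on its own with the same lengths [3, 5, 2] used in the code's comments:

```python
import torch

seq_lengths = torch.LongTensor([3, 5, 2])
# Sort descending: values [5, 3, 2], indices [1, 0, 2].
sorted_lengths, indices = torch.sort(seq_lengths, descending=True)
# Sorting `indices` ascending yields the inverse permutation.
_, desorted_indices = torch.sort(indices, descending=False)
print(sorted_lengths)                    # tensor([5, 3, 2])
print(indices)                           # tensor([1, 0, 2])
print(desorted_indices)                  # tensor([1, 0, 2])
print(sorted_lengths[desorted_indices])  # tensor([3, 5, 2]) -- original order restored
```

For this particular permutation the inverse happens to equal the permutation itself; in general the two differ, but the double sort always produces the correct inverse.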
The following code implements this method:
```python
import torch
from torch import nn


def rnn_forwarder(rnn, inputs, seq_lengths):
    """
    :param rnn: RNN instance
    :param inputs: FloatTensor, shape [batch, time, dim] if rnn.batch_first else [time, batch, dim]
    :param seq_lengths: LongTensor, shape [batch]
    :return: the result of the RNN layer
    """
    batch_first = rnn.batch_first
    # Assume seq_lengths = [3, 5, 2].
    # After sorting the lengths in descending order, sorted_seq_lengths = [5, 3, 2]
    # and indices = [1, 0, 2], which reads as:
    #   the value at position 0 of the original batch is now at position 1,
    #   the value at position 1 of the original batch is now at position 0,
    #   the value at position 2 of the original batch is now at position 2.
    sorted_seq_lengths, indices = torch.sort(seq_lengths, descending=True)
    # To restore the results to the order before sorting, we only need to sort
    # `indices` again (ascending, giving [0, 1, 2]); the resulting
    # desorted_indices is [1, 0, 2], and indexing the computed results with
    # desorted_indices undoes the sort.
    _, desorted_indices = torch.sort(indices, descending=False)
    # Sort the original sequences
    if batch_first:
        inputs = inputs[indices]
    else:
        inputs = inputs[:, indices]
    packed_inputs = nn.utils.rnn.pack_padded_sequence(inputs,
                                                      sorted_seq_lengths.cpu().numpy(),
                                                      batch_first=batch_first)
    res, state = rnn(packed_inputs)
    padded_res, _ = nn.utils.rnn.pad_packed_sequence(res, batch_first=batch_first)
    # Restore the sample order before sorting
    if batch_first:
        desorted_res = padded_res[desorted_indices]
    else:
        desorted_res = padded_res[:, desorted_indices]
    return desorted_res


if __name__ == "__main__":
    bs = 3
    max_time_step = 5
    feat_size = hidden_size = 7
    seq_lengths = torch.LongTensor([3, 5, 2])
    rnn = nn.GRU(input_size=feat_size, hidden_size=hidden_size,
                 batch_first=True, bidirectional=True)
    x = torch.FloatTensor(bs, max_time_step, feat_size).normal_()
    using_packed_res = rnn_forwarder(rnn, x, seq_lengths)
    print(using_packed_res)
    # Run without pack_padded_sequence, to compare with the results above.
    not_packed_res, _ = rnn(x)
    print(not_packed_res)
```
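As a side note, newer PyTorch releases (1.1 and later) make the manual sort/unsort unnecessary in most cases: pack_padded_sequence accepts enforce_sorted=False and performs the sorting and un-sorting internally. A brief sketch (the shapes mirror the example above):

```python
import torch
from torch import nn

rnn = nn.GRU(input_size=7, hidden_size=7, batch_first=True, bidirectional=True)
x = torch.randn(3, 5, 7)
lengths = torch.LongTensor([3, 5, 2])  # unsorted lengths are fine here
# enforce_sorted=False lets PyTorch sort internally and restore the order
# automatically when the output is unpacked.
packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True,
                                           enforce_sorted=False)
res, _ = rnn(packed)
padded, _ = nn.utils.rnn.pad_packed_sequence(res, batch_first=True)
print(padded.shape)  # torch.Size([3, 5, 14])
```

The manual approach above remains useful on older versions, and enforce_sorted=True (the default) is still required when exporting a model to ONNX.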