PyTorch Study Notes (21): Using pack_padded_sequence

Source: Internet
Author: User
Tags: pytorch

When using PyTorch's RNN modules, it is sometimes unavoidable to use pack_padded_sequence and pad_packed_sequence. In particular, when using a bidirectional RNN you must use pack_padded_sequence! Otherwise PyTorch cannot know the true length of each sequence and will not compute the results of the bidirectional RNN/GRU/LSTM correctly.
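As a reminder of the basic round trip, here is a minimal sketch, assuming a batch-first bidirectional GRU and lengths that are already sorted in descending order (the shapes and variable names here are illustrative, not taken from the article's code below):

import torch
from torch import nn

gru = nn.GRU(input_size=4, hidden_size=6, batch_first=True, bidirectional=True)
padded = torch.randn(3, 5, 4)            # [batch, time, dim]; positions past each length are padding
lengths = torch.LongTensor([5, 3, 2])    # true lengths, sorted from long to short
packed = nn.utils.rnn.pack_padded_sequence(padded, lengths, batch_first=True)
out, h = gru(packed)                     # the GRU now sees the real length of every sequence
out, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)  # back to [batch, time, 2 * hidden]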

However, there is a catch when using pack_padded_sequence: the sequences in the input mini-batch must be ordered from longest to shortest, which is awkward when the order of the samples in the mini-batch matters. For example, suppose each sample is the character-level representation of a word, and a mini-batch holds the words of one sentence in order.

In that case, if we still want to use pack_padded_sequence, we need to sort the samples in the mini-batch first and then restore the original order after the RNN/LSTM/GRU has run (a short sketch of the index bookkeeping follows).
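The index bookkeeping can be seen in isolation in the small sketch below, using the same example lengths [3, 5, 2] as the comments in the full code further down:

import torch

seq_lengths = torch.LongTensor([3, 5, 2])
# Sort the lengths in descending order; indices records where each original sample went.
sorted_seq_lengths, indices = torch.sort(seq_lengths, descending=True)
# sorted_seq_lengths = [5, 3, 2], indices = [1, 0, 2]
# Sorting indices again (ascending) yields the permutation that undoes the first sort.
_, desorted_indices = torch.sort(indices, descending=False)
# desorted_indices = [1, 0, 2]
print(sorted_seq_lengths[desorted_indices])  # tensor([3, 5, 2]) -- the original order is back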

The full implementation below puts this method together:

import torch
from torch import nn
from torch.autograd import Variable


def rnn_forwarder(rnn, inputs, seq_lengths):
    """
    :param rnn: RNN instance
    :param inputs: FloatTensor, shape [batch, time, dim] if rnn.batch_first else [time, batch, dim]
    :param seq_lengths: LongTensor, shape [batch]
    :return: the result of the RNN layer
    """
    batch_first = rnn.batch_first
    # Assume seq_lengths = [3, 5, 2].
    # Sort the sequence lengths (descending): sorted_seq_lengths = [5, 3, 2], indices = [1, 0, 2].
    # The value of indices can be read as:
    # the value at position 0 in the original batch is now at position 1,
    # the value at position 1 in the original batch is now at position 0,
    # the value at position 2 in the original batch is now at position 2.
    sorted_seq_lengths, indices = torch.sort(seq_lengths, descending=True)

    # To restore the computed results to the order before sorting, we only need to
    # sort indices again (ascending), which gives [0, 1, 2];
    # desorted_indices is then [1, 0, 2].
    # Indexing the results with desorted_indices restores the original order.
    _, desorted_indices = torch.sort(indices, descending=False)

    # Sort the original sequences
    if batch_first:
        inputs = inputs[indices]
    else:
        inputs = inputs[:, indices]
    packed_inputs = nn.utils.rnn.pack_padded_sequence(inputs,
                                                      sorted_seq_lengths.cpu().numpy(),
                                                      batch_first=batch_first)

    res, state = rnn(packed_inputs)

    padded_res, _ = nn.utils.rnn.pad_packed_sequence(res, batch_first=batch_first)

    # Restore the sample order before sorting
    if batch_first:
        desorted_res = padded_res[desorted_indices]
    else:
        desorted_res = padded_res[:, desorted_indices]
    return desorted_res


if __name__ == "__main__":
    bs = 3
    max_time_step = 5
    feat_size = hidden_size = 7
    # seq_lengths must be a LongTensor, as the docstring of rnn_forwarder expects
    seq_lengths = torch.LongTensor([3, 5, 2])
    rnn = nn.GRU(input_size=feat_size, hidden_size=hidden_size,
                 batch_first=True, bidirectional=True)
    x = Variable(torch.FloatTensor(bs, max_time_step, feat_size).normal_())
    using_packed_res = rnn_forwarder(rnn, x, seq_lengths)
    print(using_packed_res)

    # Do not use pack_padded_sequence, to compare with the result above.
    not_packed_res, _ = rnn(x)
    print(not_packed_res)
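What the final comparison in the script illustrates, for the bidirectional GRU, is that the two outputs should agree only on the forward half of the hidden state at valid (non-padded) time steps, while the backward half differs for the shorter sequences, because without packing the backward pass starts from the padded positions. A small check along these lines (appended to the end of the __main__ block above; sample 0 has length 3) might look like this:

    # Forward-direction outputs (first hidden_size channels) at valid time steps should match,
    # since the forward pass never looks at the padding that follows them.
    print(torch.allclose(using_packed_res[0, :3, :hidden_size],
                         not_packed_res[0, :3, :hidden_size]))   # expected: True
    # Backward-direction outputs should differ, since without packing the backward pass
    # runs over the padded positions first.
    print(torch.allclose(using_packed_res[0, :3, hidden_size:],
                         not_packed_res[0, :3, hidden_size:]))   # expected: False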
