The previous section described how to build a CNN with PyTorch; this section moves on to something more advanced: the LSTM.
For an introduction to the theory behind LSTMs, please refer to two famous blog posts:
http://karpathy.github.io/2015/05/21/rnn-effectiveness/
http://colah.github.io/posts/2015-08-Understanding-LSTMs/
And a Chinese translation I wrote previously:
http://blog.csdn.net/q295684174/article/details/78973445
LSTM
class torch.nn.LSTM(*args, **kwargs)
Parameters:
- input_size: dimension of the input features
- hidden_size: dimension of the hidden state
- num_layers: number of stacked recurrent layers
- bias: whether the layers use bias weights; default True
- batch_first: if True, the first dimension of the input and output is the batch size
- dropout: if non-zero, adds a dropout layer on the outputs of each RNN layer except the last
- bidirectional: if True, becomes a bidirectional RNN; default False
Inputs: input, (h_0, c_0)
- input: tensor of shape (seq_len, batch, input_size) containing the features of the input sequence; if batch_first is set, batch is the first dimension
- (h_0, c_0): the initial hidden state and cell state
Outputs: output, (h_n, c_n)
- output: tensor of shape (seq_len, batch, hidden_size * num_directions) containing the output features for each time step; if batch_first is set, batch is the first dimension
- (h_n, c_n): the final hidden state and cell state
Model
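As a quick sketch of these input and output shapes (the concrete sizes here are arbitrary examples, not values from this tutorial):

```python
import torch
import torch.nn as nn

# A 2-layer LSTM: input feature dimension 10, hidden dimension 20.
lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)

seq_len, batch = 5, 3
x = torch.randn(seq_len, batch, 10)   # (seq_len, batch, input_size)

# (h_0, c_0) default to zeros when omitted.
output, (h_n, c_n) = lstm(x)

print(output.shape)  # (seq_len, batch, hidden_size)
print(h_n.shape)     # (num_layers, batch, hidden_size)
```

With batch_first=True the input would instead be laid out as (batch, seq_len, input_size), which is the convention the model below uses.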
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
                            batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)  # hidden_size * 2 for bidirectional

    def forward(self, x):
        # Forward propagate the LSTM
        out, _ = self.lstm(x)
        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])
        return out

rnn = RNN(input_size, hidden_size, num_layers, num_classes)
rnn.cuda()
PyTorch's LSTM implementation is very convenient: you only need to specify the input dimension, the hidden dimension, the number of layers, and the number of classes. If the initial state of the LSTM is not provided, it defaults to zeros. On MNIST, this model reaches 97% accuracy after only 2 epochs.
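A minimal sketch of how the model above can be driven on MNIST-shaped data: each 28x28 image is treated as a sequence of 28 rows of 28 pixels. The RNN class is repeated here so the snippet runs standalone, random tensors stand in for real MNIST images, and the hyperparameters (hidden_size=128, num_layers=2) are illustrative choices, not values fixed by this tutorial:

```python
import torch
import torch.nn as nn

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
                            batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out, _ = self.lstm(x)          # initial states default to zeros
        return self.fc(out[:, -1, :])  # classify from the last time step

# 28 rows per image, 28 pixels per row, 10 digit classes.
model = RNN(input_size=28, hidden_size=128, num_layers=2, num_classes=10)

images = torch.randn(64, 28, 28)   # random stand-in for a batch of MNIST images
labels = torch.randint(0, 10, (64,))

logits = model(images)             # (batch, num_classes)
loss = nn.CrossEntropyLoss()(logits, labels)
loss.backward()                    # one training step would follow with an optimizer
```

In a real training run the random tensors would be replaced by batches from torchvision's MNIST DataLoader, with images reshaped to (batch, 28, 28).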