A Gentle Introduction to LSTM
In the previous article, we learned about RNNs and saw that they suffer from two problems: vanishing gradients and exploding gradients. LSTM (Long Short-Term Memory) is a widely used architecture designed mainly to overcome the vanishing-gradient problem. Exploding gradients are usually handled separately, for example with gradient clipping.
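Here is a toy sketch of why both problems arise. During backpropagation through time, the gradient that reaches an early time step is roughly a product of one factor per step. This sketch assumes, purely for illustration, that every step contributes the same scalar factor. Real RNNs multiply Jacobian matrices, but the exponential effect is the same.

```python
# Toy model, not a real RNN: assume each of `steps` time steps scales the
# gradient by the same constant `factor`.
def gradient_scale(factor: float, steps: int) -> float:
    """Return the overall gradient scaling after `steps` repeated multiplications."""
    scale = 1.0
    for _ in range(steps):
        scale *= factor
    return scale


for factor in (0.9, 1.0, 1.1):
    print(f"factor={factor}: scale after 100 steps = {gradient_scale(factor, 100):.3g}")
```

With 100 steps, a per-step factor of 0.9 shrinks the gradient to about 2.7e-05 (vanishing), while 1.1 blows it up to about 1.4e+04 (exploding).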
What Are LSTMs?
Long Short-Term Memory (LSTM) networks are a special kind of recurrent neural network (RNN). They add a memory cell and a set of gates that control what information flows into and out of that cell, which lets them handle long-term dependencies in sequences.
LSTM Architecture
An LSTM cell has three gates, as shown in the image below, and each gate performs a distinct function:
- Forget gate
- Input gate
- Output gate
The first part decides whether the information coming from the previous time step should be remembered or is irrelevant and can be forgotten. In the second part, the cell tries to learn new information from its current input. Finally, in the third part, the cell passes the updated information from the current time step on to the next time step.
These three parts of an LSTM cell are known as gates: the first is the forget gate, the second is the input gate, and the third is the output gate. All three gates use the sigmoid activation function, so their outputs lie between 0 and 1. Let’s examine in detail how these three gates work together.
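For reference, the gates can also be written as equations. The notation here is our own, since the article does not define symbols: $x_t$ is the current input, $h_{t-1}$ and $c_{t-1}$ are the previous hidden and cell states, $[h_{t-1}, x_t]$ is their concatenation, $W_\ast$ and $b_\ast$ are learned weights and biases, $\sigma$ is the sigmoid function, and $\odot$ is element-wise multiplication. This is the common LSTM formulation without peephole connections.

```latex
\begin{aligned}
f_t &= \sigma\left(W_f\,[h_{t-1}, x_t] + b_f\right) && \text{forget gate} \\
i_t &= \sigma\left(W_i\,[h_{t-1}, x_t] + b_i\right) && \text{input gate} \\
\tilde{c}_t &= \tanh\left(W_c\,[h_{t-1}, x_t] + b_c\right) && \text{candidate values} \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t && \text{new cell state} \\
o_t &= \sigma\left(W_o\,[h_{t-1}, x_t] + b_o\right) && \text{output gate} \\
h_t &= o_t \odot \tanh(c_t) && \text{new hidden state}
\end{aligned}
```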
Forget Gate
First, we have the forget gate. This gate decides what information in the cell state should be thrown away or kept. The previous hidden state and the current input are combined and passed through a sigmoid function, so each output value lies between 0 and 1. These values are multiplied element-wise with the previous cell state: a value close to 0 means forget, and a value close to 1 means keep.
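A minimal NumPy sketch of the forget gate for a single time step. It assumes the concatenated-input form $f_t = \sigma(W_f [h_{t-1}, x_t] + b_f)$; the sizes, random weights, and the names `forget_gate`, `W_f`, `b_f`, and `c_prev` are hypothetical and chosen only for illustration.

```python
import numpy as np


def sigmoid(z):
    """Logistic sigmoid: maps any real number into the open interval (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))


def forget_gate(h_prev, x_t, W_f, b_f):
    """f_t = sigmoid(W_f @ [h_prev; x_t] + b_f), one value per cell-state entry."""
    concat = np.concatenate([h_prev, x_t])  # combine previous hidden state and input
    return sigmoid(W_f @ concat + b_f)


# Hypothetical toy sizes and random weights, for illustration only.
hidden_size, input_size = 3, 2
rng = np.random.default_rng(0)
W_f = rng.normal(size=(hidden_size, hidden_size + input_size))
b_f = np.zeros(hidden_size)
h_prev = rng.normal(size=hidden_size)
x_t = rng.normal(size=input_size)
c_prev = np.array([1.0, -2.0, 0.5])  # previous cell state

f_t = forget_gate(h_prev, x_t, W_f, b_f)
c_kept = f_t * c_prev  # element-wise: entries near 0 are forgotten, near 1 are kept
print("f_t   :", f_t)
print("c_kept:", c_kept)
```

Because every entry of `f_t` is strictly between 0 and 1, the forget gate can only shrink each cell-state entry, never enlarge it.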
Input Gate
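To update the cell state, we have the input gate. First, we pass the previous hidden state and the current input into a sigmoid function, which decides which values will be updated by mapping them to values between 0 and 1: 0 means not important, and 1 means important. We also pass the same inputs into a tanh function, which produces candidate values between -1 and 1. Multiplying the sigmoid output by the tanh output keeps only the important candidates, and adding this product to the cell state (after the forget gate has scaled it) gives the new cell state.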
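A minimal NumPy sketch of the input gate and the cell-state update for one time step. The function name `input_gate_update`, the toy sizes, the random weights, and the fixed forget-gate values `f_t` are all hypothetical and serve only to show how the sigmoid and tanh outputs combine.

```python
import numpy as np


def sigmoid(z):
    """Logistic sigmoid: maps any real number into (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))


def input_gate_update(h_prev, x_t, c_prev, f_t, W_i, b_i, W_c, b_c):
    """Return (i_t, c_tilde, c_t) for one time step."""
    concat = np.concatenate([h_prev, x_t])
    i_t = sigmoid(W_i @ concat + b_i)      # how much of each candidate to write, in (0, 1)
    c_tilde = np.tanh(W_c @ concat + b_c)  # candidate values, in (-1, 1)
    c_t = f_t * c_prev + i_t * c_tilde     # keep part of the old state, add new information
    return i_t, c_tilde, c_t


# Hypothetical toy sizes and random weights, for illustration only.
hidden_size, input_size = 3, 2
rng = np.random.default_rng(1)
W_i = rng.normal(size=(hidden_size, hidden_size + input_size))
W_c = rng.normal(size=(hidden_size, hidden_size + input_size))
b_i = np.zeros(hidden_size)
b_c = np.zeros(hidden_size)
h_prev = rng.normal(size=hidden_size)
x_t = rng.normal(size=input_size)
c_prev = np.array([1.0, -2.0, 0.5])  # previous cell state
f_t = np.array([0.9, 0.1, 0.5])      # assumed forget-gate output for this time step

i_t, c_tilde, c_t = input_gate_update(h_prev, x_t, c_prev, f_t, W_i, b_i, W_c, b_c)
print("i_t:", i_t)
print("c_t:", c_t)
```

The update is additive: new information is added to the scaled old state instead of overwriting it, which is what helps gradients flow across many time steps.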
Output Gate
The output gate decides what the next hidden state should be. First, we pass the previous hidden state and the current input into a sigmoid function. Then we pass the newly updated cell state through a tanh function. Multiplying the tanh output by the sigmoid output decides what information the hidden state should carry, and the result is the new hidden state. The new cell state and the new hidden state are then carried over to the next time step.
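A minimal NumPy sketch of the output gate for one time step. The name `output_gate`, the toy sizes, the random weights, and the fixed cell state `c_t` are hypothetical, chosen only to illustrate the computation.

```python
import numpy as np


def sigmoid(z):
    """Logistic sigmoid: maps any real number into (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))


def output_gate(h_prev, x_t, c_t, W_o, b_o):
    """Return (o_t, h_t): the output gate values and the new hidden state."""
    concat = np.concatenate([h_prev, x_t])
    o_t = sigmoid(W_o @ concat + b_o)  # which parts of the cell state to expose
    h_t = o_t * np.tanh(c_t)           # squash the cell state to (-1, 1), then filter it
    return o_t, h_t


# Hypothetical toy sizes and random weights, for illustration only.
hidden_size, input_size = 3, 2
rng = np.random.default_rng(2)
W_o = rng.normal(size=(hidden_size, hidden_size + input_size))
b_o = np.zeros(hidden_size)
h_prev = rng.normal(size=hidden_size)
x_t = rng.normal(size=input_size)
c_t = np.array([2.0, -0.5, 0.0])  # assumed new cell state for this time step

o_t, h_t = output_gate(h_prev, x_t, c_t, W_o, b_o)
print("h_t:", h_t)
# Both h_t and c_t would be passed on to the next time step.
```

Since tanh(0) is 0, the third entry of `h_t` is exactly 0 here regardless of the gate value.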
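Code Implementation Using PyTorch
The following snippet trains an LSTM language model with PyTorch's nn.LSTM and helper functions from the d2l package that accompanies the book Dive into Deep Learning.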
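from torch import nn
from d2l import torch as d2l
train_iter, vocab = d2l.load_data_time_machine(batch_size=32, num_steps=35)  # assumed setup, as in the D2L LSTM chapter
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()  # use a GPU if one is available
num_epochs, lr = 500, 1
num_inputs = vocab_size  # each token is fed to the LSTM as a one-hot vector over the vocabulary
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))  # wraps the LSTM and adds an output layer over the vocabulary
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)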
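The d2l training loop hides the recurrence itself. As a sanity check, the sketch below (which assumes only PyTorch, not d2l) recomputes an nn.LSTM forward pass by hand using the gate equations described above and compares the results. It relies on PyTorch's documented parameter layout, where `weight_ih_l0` and `weight_hh_l0` stack the input, forget, cell-candidate, and output gate weights in that order; the sizes are arbitrary.

```python
import torch
from torch import nn

torch.manual_seed(0)
input_size, hidden_size, seq_len, batch = 4, 3, 5, 2
lstm = nn.LSTM(input_size, hidden_size)  # default input shape: (seq_len, batch, input_size)
X = torch.randn(seq_len, batch, input_size)

with torch.no_grad():
    out, (h_n, c_n) = lstm(X)

    # Manual recurrence using the same learned parameters.
    W_ih, W_hh = lstm.weight_ih_l0, lstm.weight_hh_l0
    b = lstm.bias_ih_l0 + lstm.bias_hh_l0
    h = torch.zeros(batch, hidden_size)  # nn.LSTM starts from zero states by default
    c = torch.zeros(batch, hidden_size)
    outputs = []
    for x_t in X:  # iterate over time steps
        gates = x_t @ W_ih.T + h @ W_hh.T + b
        i, f, g, o = gates.chunk(4, dim=1)  # PyTorch order: input, forget, candidate, output
        i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
        g = torch.tanh(g)
        c = f * c + i * g        # forget old information, add new candidates
        h = o * torch.tanh(c)    # output gate filters the squashed cell state
        outputs.append(h)
    manual_out = torch.stack(outputs)

print(out.shape)  # torch.Size([5, 2, 3])
print(torch.allclose(out, manual_out, atol=1e-5))
```

If the hand-written loop and nn.LSTM agree, the library layer implements the same forget, input, and output gate computations explained in this article.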