Kumar Ramanathan edited this page Feb 20, 2017 · 19 revisions

# Plain old LSTM code, explained (polstm.q)

LSTM stands for long short-term memory: a flavour of recurrent neural network, modified to counter the vanishing gradient problem. Simply put, as the derivative of the error (loss) is passed back through the network during back-propagation, it is multiplied by another factor at each step. When those factors are smaller than one, the gradient shrinks with every multiplication until it is too small to be useful; a factor of 0.9 applied over 100 steps, for instance, leaves less than 0.003% of the original gradient. What follows is a step-by-step breakdown of code implementing an LSTM. The test case is the following: given a sequence of text, predict the next character in the sequence. Roughly, the LSTM does this as follows:

  • For each (one-hot encoded) character, predict the next character and calculate the loss of that prediction against a target (the actual next character in the sequence, also one-hot encoded).
  • Back-propagate the error derivative through the neural net to adjust the weights, and continue.
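
The two steps above can be sketched for a single character in q. This is a minimal illustration, not the code from polstm.q: it assumes a softmax output with cross-entropy loss (a common choice for character-level models, but not something stated so far), and the names `logits`, `softmax`, `probs`, `target`, `loss` and `dlogits`, as well as the 4-character vocabulary, are hypothetical.

```q
/ Hypothetical raw scores (logits) for the next character, vocabulary of 4
logits:1.0 2.0 0.5 0.1;
/ Softmax turns scores into probabilities; subtracting the max keeps exp stable
softmax:{e:exp x-max x; e%sum e};
probs:softmax logits;
/ One-hot target: the true next character sits at index 1 (assumed)
target:0 1 0 0f;
/ Cross-entropy loss: negative log-probability given to the true character
loss:neg sum target*log probs;
/ For softmax followed by cross-entropy, the gradient of the loss with
/ respect to the logits is simply probs-target; this is the error
/ derivative that gets back-propagated
dlogits:probs-target;
```

Because `target` is one-hot, the sum in `loss` picks out a single term, so the loss is small when the model assigns high probability to the correct next character and large otherwise.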

The above was a bird's-eye view of what happens, from start to finish. Now to delve into the gory internals of what happens inside the model, specifically the LSTM.

Viewing the LSTM as a machine, here are the nuts and bolts:

  • Input layer. This layer accepts inputs into the LSTM, typically a one-hot encoded character array, or a word-embedding vector for NLP applications. The dimension of this layer is usually 1 × n, where n is the vocabulary size or the length of the embedding vector. In the code, this is defined as follows:

      / Read text input - shakespearean text
      k:read0 `:cmplshake.txt;
      / Collapse runs of multiple spaces into a single space.
      cleanText:{x where(or)':[not " "=x]};
      TXT:" " sv cleanText over 'k;
      P:0;
      SEQLEN:25;
      / calculate VOCAB_SIZE
      VOCAB_SIZE:count CHARS: distinct TXT;
      / the following code will run in a while loop, from 0 to size(INPUT)
              / P increases as loop proceeds
              / Send batches of SEQLEN characters each time
              INPUT:TXT[P+til SEQLEN];
              / Pair each character of INPUT with its position in the (unique) character list CHARS
              INPUT:INPUT,'(CHARS?INPUT);
              / T is a counter, from 0 to count(INPUT)
              / One-hot encoded input here. 
              XS::raze (1,VOCAB_SIZE)#0;
              XS[INPUT[T][1]]:1; 
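
To see the one-hot encoding in isolation, here is a small sketch on a toy string. It is a variation on the snippet above rather than a copy of it: it encodes every character at once as a matrix instead of one character per step `T`, and the names `txt`, `chars`, `vocabSize`, `idx` and `xs` are hypothetical.

```q
/ Toy text standing in for the Shakespeare corpus (assumption)
txt:"hello world";
/ Vocabulary: distinct characters in order of first appearance ("helo wrd")
chars:distinct txt;
/ Vocabulary size is 8 for this string
vocabSize:count chars;
/ Position of each character of the text in the vocabulary
/ 0 1 2 2 3 4 5 3 6 2 7
idx:chars?txt;
/ Compare each index against 0..vocabSize-1 (each-left):
/ one boolean row per character, with a single 1b in each row
xs:idx=\:til vocabSize;
/ First row encodes "h", which is entry 0 of the vocabulary
xs 0
```

Each row of `xs` has length `vocabSize`, matching the 1 × n input dimension described above, and `sum each xs` should come out as all ones, which is a quick sanity check on the encoding.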
    