# Plain old LSTM code, explained (polstm.q)
LSTM - Long short-term memory. A flavour of recurrent neural network, modified to beat the vanishing gradient problem. Simply put, as the derivative of the error/loss is passed back through the neural network (back-propagation), a multiplication occurs at each step, and with each multiplication the gradient shrinks, rendering it useless beyond a point. Here is a step-by-step breakdown of the code implementing an LSTM. The test case in question is the following - given a sequence of text, predict the next character in the sequence. Roughly, the LSTM does this as follows:
- For each (one-hot encoded) character, calculate the loss of prediction against a target (the next character in the sequence, also one-hot encoded)
- Back-propagate the error derivative into the neural net, to adjust weights, and continue.
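The vanishing gradient mentioned above can be seen with a toy calculation (a Python sketch, not part of the q code - the per-layer derivative of 0.5 is a hypothetical value chosen for illustration):

```python
# Repeated multiplication by derivatives < 1 shrinks the gradient
# as it is passed back through many time steps / layers.
grad = 1.0
for step in range(50):
    grad *= 0.5  # hypothetical per-step derivative, magnitude < 1

print(grad)  # a vanishingly small number, ~9e-16
```

After only 50 steps the gradient is effectively zero, which is why a plain RNN cannot learn long-range dependencies and why the LSTM's gating machinery exists.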
The above was a bird's-eye view of what happens, from start to finish. Now to delve into the gory internals of what happens inside the model, specifically the LSTM.
Viewing the LSTM as a machine, here are the nuts and bolts:
- Input layer. This layer accepts inputs into the LSTM, typically a one-hot encoded character array or a word-embedding vector for NLP applications. The dimension of this layer is usually (1×n), n being the vocabulary size or the length of the embedding vector. Inside the code, this is defined as follows:
```q
/ Read text input - shakespearean text
k:read0 `:cmplshake.txt;
/ Remove multiple spaces, keep only one space
cleanText:{x where(or)':[not " "=x]};
TXT:" " sv cleanText over'k;
P:0; SEQLEN:25;
/ calculate VOCAB_SIZE
VOCAB_SIZE:count CHARS:distinct TXT;
/ the following code runs in a while loop, from 0 to count INPUT
/ P increases as the loop proceeds
/ Send batches of SEQLEN characters each time
INPUT:TXT[P+til SEQLEN];
/ Append to INPUT the position of each (unique) character in CHARS
INPUT:INPUT,'(CHARS?INPUT);
/ T is a counter, from 0 to count INPUT
/ One-hot encoded input here
XS::raze (1,VOCAB_SIZE)#0;
XS[INPUT[T][1]]:1;
```
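For readers less familiar with q, the one-hot step above (`XS` gets a zero vector of length `VOCAB_SIZE`, with a 1 at the character's position in `CHARS`) can be sketched in Python - the sample text and helper name here are illustrative, not from the original code:

```python
# Hypothetical mirror of the q one-hot step: CHARS plays the role of
# `distinct TXT`, and CHARS.index(ch) plays the role of `CHARS?ch`.
CHARS = sorted(set("hello world"))   # distinct characters (the vocabulary)
VOCAB_SIZE = len(CHARS)

def one_hot(ch):
    xs = [0] * VOCAB_SIZE            # zero vector, like raze (1,VOCAB_SIZE)#0
    xs[CHARS.index(ch)] = 1          # set 1 at the character's position
    return xs

print(one_hot("h"))
```

Each input character thus becomes a 1-of-`VOCAB_SIZE` vector, which is the form the input layer of the LSTM consumes.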