diff --git a/src/recurrent/lstm-time-step.end-to-end.test.ts b/src/recurrent/lstm-time-step.end-to-end.test.ts index ecd2f160..5a8b1cbb 100644 --- a/src/recurrent/lstm-time-step.end-to-end.test.ts +++ b/src/recurrent/lstm-time-step.end-to-end.test.ts @@ -24,4 +24,28 @@ describe('LSTMTimeStep', () => { expect(net.run([[1], [0.001]])[0]).toBeGreaterThan(0.9); expect(net.run([[1], [1]])[0]).toBeLessThan(0.1); }); + + it('can learn a simple pattern', () => { + const net = new LSTMTimeStep({ + inputSize: 2, + hiddenLayers: [4, 2], + outputSize: 1, + }); + const trainingData = [ + { input: [0, 0], output: [0] }, + { input: [0, 1], output: [1] }, + { input: [1, 0], output: [1] }, + { input: [1, 1], output: [0] }, + ]; + const errorThresh = 0.005; + const iterations = 100; + const status = net.train(trainingData, { iterations, errorThresh }); + expect( + status.error <= errorThresh || status.iterations <= iterations + ).toBeTruthy(); + expect(net.run([0, 0])[0]).toBeCloseTo(0, 1); + expect(net.run([0, 1])[0]).toBeCloseTo(1, 1); + expect(net.run([1, 0])[0]).toBeCloseTo(1, 1); + expect(net.run([1, 1])[0]).toBeCloseTo(0, 1); + }); }); diff --git a/src/recurrent/lstm-time-step.test.ts b/src/recurrent/lstm-time-step.test.ts index 5a59f77e..8a247551 100644 --- a/src/recurrent/lstm-time-step.test.ts +++ b/src/recurrent/lstm-time-step.test.ts @@ -143,4 +143,26 @@ describe('LSTMTimeStep', () => { expect(equation.states[25].forwardFn.name).toBe('multiplyElement'); }); }); + + describe('equations property', () => { + it('should initialize equations property in the constructor', () => { + const lstmTimeStep = new LSTMTimeStep({}); + expect(lstmTimeStep.equations).toBeInstanceOf(Array); + }); + + it('should populate equations property before accessing in getEquation method', () => { + const lstmTimeStep = new LSTMTimeStep({}); + const equation = new Equation(); + const inputMatrix = new Matrix(3, 1); + const previousResult = new Matrix(3, 1); + const hiddenLayer = getHiddenLSTMLayer(3, 3); + const result = lstmTimeStep.getEquation( + equation, + inputMatrix, + previousResult, + hiddenLayer + ); + expect(result).toBeInstanceOf(Matrix); + }); + }); }); diff --git a/src/recurrent/lstm-time-step.ts b/src/recurrent/lstm-time-step.ts index c7af5f4b..f720d482 100644 --- a/src/recurrent/lstm-time-step.ts +++ b/src/recurrent/lstm-time-step.ts @@ -5,6 +5,13 @@ import { RNNTimeStep } from './rnn-time-step'; import { IRNNHiddenLayer } from './rnn'; export class LSTMTimeStep extends RNNTimeStep { + equations: Equation[]; + + constructor(options: any) { + super(options); + this.equations = []; + } + getHiddenLayer(hiddenSize: number, prevSize: number): IRNNHiddenLayer { return getHiddenLSTMLayer(hiddenSize, prevSize); } @@ -15,6 +22,9 @@ export class LSTMTimeStep extends RNNTimeStep { previousResult: Matrix, hiddenLayer: IRNNHiddenLayer ): Matrix { + if (!this.equations) { + this.equations = []; + } return getLSTMEquation( equation, inputMatrix, diff --git a/src/recurrent/lstm.ts b/src/recurrent/lstm.ts index 3d2cb39c..d178aae9 100644 --- a/src/recurrent/lstm.ts +++ b/src/recurrent/lstm.ts @@ -19,6 +19,13 @@ export interface ILSTMHiddenLayer extends IRNNHiddenLayer { } export class LSTM extends RNN { + equations: Equation[]; + + constructor(options: any) { + super(options); + this.equations = []; + } + getHiddenLayer(hiddenSize: number, prevSize: number): IRNNHiddenLayer { return getHiddenLSTMLayer(hiddenSize, prevSize); } @@ -29,6 +36,9 @@ export class LSTM extends RNN { previousResult: Matrix, hiddenLayer: IRNNHiddenLayer ): Matrix { + if (!this.equations) { + this.equations = []; + } return getLSTMEquation( equation, inputMatrix,