
LSTM implemented from scratch and with PyTorch's nn.LSTM, trained using PyTorch Lightning on a toy stock prediction task. Educational and beginner-friendly.


LSTM_FROM_SCRATCH

This project demonstrates two approaches to implementing an LSTM neural network: a manual LSTM cell written from scratch and PyTorch's built-in nn.LSTM module, both trained with PyTorch Lightning.

📈 Project Overview

The project predicts a company's stock value from the previous four days' values:

  • Company A: [0, 0.5, 0.25, 1] → Predicted: 0
  • Company B: [1, 0.5, 0.25, 1] → Predicted: 1
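The toy dataset above can be sketched as two tensors (an illustrative reconstruction, not necessarily the repo's exact variable names):

```python
import torch

# Four days of stock values per company, plus the day-5 target the
# model should learn to predict.
inputs = torch.tensor([[0.0, 0.5, 0.25, 1.0],   # Company A
                       [1.0, 0.5, 0.25, 1.0]])  # Company B
labels = torch.tensor([0.0, 1.0])               # expected predictions
```

Note that the two sequences differ only on day 1, so the model must remember the first value across the whole sequence to predict correctly.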


📺 Based on StatQuest Tutorial

Long Short-Term Memory with PyTorch + Lightning - StatQuest with Josh Starmer

Two LSTM Implementations

1. LSTM from Scratch (lstm_scratch.py)

  • Manual LSTM cell implementation
  • Shows LSTM gates (forget, input, output)
  • Educational: understand LSTM internals
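A minimal sketch of what a from-scratch LSTM step looks like (names, shapes, and weight layout are assumptions for illustration, not the repo's exact code):

```python
import torch

# One LSTM time step. W, U, b stack the weights for the four gates:
# forget (f), input (i), candidate (g), and output (o).
def lstm_cell(x, h_prev, c_prev, W, U, b):
    z = W @ x + U @ h_prev + b      # all four gate pre-activations at once
    f, i, g, o = z.chunk(4)
    f = torch.sigmoid(f)            # forget gate: how much of c_prev to keep
    i = torch.sigmoid(i)            # input gate: how much new info to write
    g = torch.tanh(g)               # candidate values for the cell state
    o = torch.sigmoid(o)            # output gate: how much of c to expose
    c = f * c_prev + i * g          # long-term memory update
    h = o * torch.tanh(c)           # short-term memory / output
    return h, c

# Unroll over one company's 4-day sequence (input size 1, hidden size 1)
torch.manual_seed(0)
W, U, b = torch.randn(4, 1), torch.randn(4, 1), torch.zeros(4)
h, c = torch.zeros(1), torch.zeros(1)
for day in [0.0, 0.5, 0.25, 1.0]:
    h, c = lstm_cell(torch.tensor([day]), h, c, W, U, b)
# h after the last step is the prediction for day 5
```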

2. PyTorch LSTM Module (lstm_nn.py)

  • Uses PyTorch's nn.LSTM module
  • Optimized and faster
  • Production ready
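The same forward pass with the built-in module is much shorter; a minimal sketch (tensor layout is the assumption here, not the repo's exact code):

```python
import torch
import torch.nn as nn

# PyTorch's optimized LSTM with the same tiny architecture
lstm = nn.LSTM(input_size=1, hidden_size=1)

# nn.LSTM expects (seq_len, batch, features) by default
seq = torch.tensor([0.0, 0.5, 0.25, 1.0]).view(4, 1, 1)
out, (h_n, c_n) = lstm(seq)
prediction = out[-1].squeeze()   # hidden state after the last day
```

nn.LSTM handles the unrolling, gate math, and weight initialization internally, which is why it trains faster and converges to a lower loss in the comparison below.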

Implementation Comparison

| Epochs | From Scratch Loss | PyTorch LSTM Loss |
|--------|-------------------|-------------------|
| 500    | 0.45              | 0.23              |
| 1000   | 0.25              | 0.12              |
| 3000   | 0.10              | 0.05              |

Lightning Trainer Features

  • Checkpointing: Auto-save best models, resume training
  • Logging: TensorBoard integration, progress bars
  • Training: Early stopping, gradient clipping
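These features are enabled through the Trainer's arguments and callbacks; an illustrative configuration (the specific monitor keys and thresholds are assumptions, not necessarily what this repo uses):

```python
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

trainer = L.Trainer(
    max_epochs=3000,
    gradient_clip_val=1.0,   # gradient clipping
    callbacks=[
        # auto-save the best model by training loss
        ModelCheckpoint(monitor="train_loss", save_top_k=1),
        # stop early once the loss stops improving
        EarlyStopping(monitor="train_loss", patience=50),
    ],
)
```

TensorBoard logging and progress bars are on by default; passing a logger (as in the resume example below) controls where the event files land.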

TensorBoard Visualization

tensorboard --logdir=lstm_nn/  # replace with your checkpoint folder

TensorBoard Log Structure

lstm_nn/lstm_nn/version_X/
├── events.out.tfevents.*  # TensorBoard events
├── hparams.yaml          # Hyperparameters
└── checkpoints/          # Model checkpoints

Model Architecture

  • Input: sequence of 4 values (one per day)
  • LSTM: input size 1, hidden size 1
  • Output: 1 predicted value (the day-5 stock value)
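This architecture is deliberately tiny; a quick sanity check of its size (a sketch, using the standard nn.LSTM parameter layout):

```python
import torch.nn as nn

# Same architecture as above: input_size=1, hidden_size=1
lstm = nn.LSTM(input_size=1, hidden_size=1)

# 4 gates x (1 input weight + 1 recurrent weight + 2 biases) = 16 parameters
n_params = sum(p.numel() for p in lstm.parameters())
```

With only 16 trainable parameters, the whole model is small enough to inspect weight by weight, which is the point of the exercise.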

Usage

Running LSTM Implementations

cd predict
python lstm_nn.py        # PyTorch LSTM implementation
python lstm_scratch.py   # Manual LSTM implementation

Continue Training from Checkpoints

import lightning as L

# Resume training from a saved checkpoint (substitute your own version folder)
trainer = L.Trainer(max_epochs=4000, logger=logger)
trainer.fit(model, train_dataloaders=dataloader,
            ckpt_path="lstm_nn/lstm_nn/version_X/checkpoints/epoch=2999-step=4000.ckpt")
