Code to quickly and efficiently train diffusion-based neural environments in just 500 lines of code

Francesco215/neural_env


Gymnasium Env to Neural Env

This repository is inspired by GameNGen and DIAMOND and is designed for researchers and practitioners interested in leveraging diffusion models to simulate dynamic environments.

The main objectives are:

  • Keeping the codebase small (~400 lines in the src/ folder and ~100 in train.py)
  • Running on consumer hardware, with experiments completing in ~1 hour

It creates a Neural Gymnasium Environment where the dynamics are determined by a diffusion-based world model.

Simplified code

import gymnasium as gym
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
# DiffusionModel and NeuralEnv are defined in this repository's src/ folder

device = "cuda"
state_size = 8  # number of context frames fed to the world model

# take a starting gym environment
original_env = gym.make("LunarLander-v3", render_mode="rgb_array")

# start from a pretrained model
model_id = "stabilityai/stable-diffusion-2-1"
autoencoder = AutoencoderKL.from_pretrained(model_id, subfolder="vae").requires_grad_(False)
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
diffusion_scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")

# define the diffusion model
diffusion = DiffusionModel(autoencoder, unet, diffusion_scheduler, state_size, original_env.action_space.n).to(device)

# and you can define the NeuralEnv like so
neural_env = NeuralEnv(diffusion, original_env)

The NeuralEnv class is a subclass of gymnasium.Env and inherits all its methods. The key difference is that .step(), .render(), and the other methods are evaluated using the diffusion model. This ensures that the external API of NeuralEnv is the same as that of any other gymnasium.Env.
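Because the API is unchanged, any standard gymnasium interaction loop works on the neural environment as-is. The sketch below uses a hypothetical DummyEnv stand-in so it runs without the trained model; in practice you would pass the neural_env instance from the snippet above:

```python
import random

# Minimal stand-in with the gymnasium.Env step/reset signature.
# (Hypothetical: replace with the NeuralEnv instance in real use.)
class DummyEnv:
    def reset(self, seed=None):
        self.t = 0
        return [0.0, 0.0], {}  # observation, info

    def step(self, action):
        self.t += 1
        obs = [float(self.t), float(action)]
        reward, terminated, truncated = 1.0, self.t >= 5, False
        return obs, reward, terminated, truncated, {}

def rollout(env, policy, max_steps=100):
    """Run one episode; works with any gymnasium-style env, neural or not."""
    obs, info = env.reset(seed=0)
    total_reward, steps = 0.0, 0
    for _ in range(max_steps):
        obs, reward, terminated, truncated, info = env.step(policy(obs))
        total_reward += reward
        steps += 1
        if terminated or truncated:
            break
    return total_reward, steps

total, steps = rollout(DummyEnv(), policy=lambda obs: random.randrange(4))
```

The same rollout function can drive either the original environment or the neural one, which is the point of keeping the gymnasium interface.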

The diffusion model can be trained in ~1h to simulate the dynamics of the gym.Env it was trained on.

Examples

Here is an example with a 26M-parameter LoRA, trained in ~30 min on an RTX 3090. The first 8 frames are given as starting frames, and all the others are generated:

[animation: frame_history_2800]

The rollout still shows some artifacts, but it demonstrates that pretrained models like Stable Diffusion can be adapted efficiently into world models that simulate dynamic environments.

Pretrained models already encode rich representations from extensive training, reducing the computational and data requirements for fine-tuning, and enabling faster convergence to high-quality results.

What's under the hood

The world model is a LoRA of Stable Diffusion 2.1, which makes training much faster and more efficient. The graph below shows the result of a ~30 min run on an RTX 3090:

[figure: training loss]
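The reason LoRA fine-tuning is cheap is that the pretrained weights stay frozen and only a low-rank additive update is learned. This is not the repository's actual adapter code (which targets the Stable Diffusion UNet); it is a minimal sketch of the idea on a single linear layer:

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen pretrained linear layer plus a trainable rank-r update."""
    def __init__(self, base: nn.Linear, r: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base.requires_grad_(False)  # pretrained weights stay frozen
        self.A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        # B is zero-initialized, so at the start of training the layer
        # behaves exactly like the pretrained one.
        self.B = nn.Parameter(torch.zeros(base.out_features, r))
        self.scale = alpha / r

    def forward(self, x):
        return self.base(x) + (x @ self.A.T @ self.B.T) * self.scale

base = nn.Linear(768, 768)
lora = LoRALinear(base, r=8)

trainable = sum(p.numel() for p in lora.parameters() if p.requires_grad)
total = sum(p.numel() for p in lora.parameters())
# the rank-8 update trains ~2% of the layer's parameters
```

At rank 8 the update has 2·8·768 trainable parameters against ~590k frozen ones, which is why a LoRA of a large UNet fits a short consumer-GPU training run.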

The first convolutional layer of the diffusion model is expanded so it can take in multiple frames by frame stacking (image below taken from the DIAMOND paper).

[figure: frame stacking, from the DIAMOND paper]
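One common way to do this expansion (a sketch of the technique, not necessarily this repository's exact code) is to widen the input channels of the first convolution and tile the pretrained kernel across them, scaled by 1/n_frames so that feeding n copies of a single frame reproduces the original activation:

```python
import torch
import torch.nn as nn

def expand_conv_in(conv: nn.Conv2d, n_frames: int) -> nn.Conv2d:
    """Widen a conv's input channels to accept n_frames stacked frames.

    The pretrained kernel is tiled across the new channels and scaled by
    1/n_frames, so n identical stacked frames give the original output.
    """
    new_conv = nn.Conv2d(
        conv.in_channels * n_frames, conv.out_channels,
        kernel_size=conv.kernel_size, stride=conv.stride, padding=conv.padding,
    )
    with torch.no_grad():
        new_conv.weight.copy_(conv.weight.repeat(1, n_frames, 1, 1) / n_frames)
        new_conv.bias.copy_(conv.bias)
    return new_conv

# Stable Diffusion's UNet conv_in maps 4 latent channels to 320 features;
# after expansion it accepts 8 stacked latent frames (32 channels).
conv = nn.Conv2d(4, 320, kernel_size=3, padding=1)
wide = expand_conv_in(conv, n_frames=8)
```

This initialization keeps the pretrained behavior intact at the start of fine-tuning, which helps the short training runs the repository targets.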

Installation

To set up the repository, follow these steps:

  1. Clone the repository:
    git clone https://github.com/Francesco215/neural_env.git
    cd neural_env
  2. Install the required Python dependencies:
    pip install -r requirements.txt
  3. Alternatively, set up a virtual environment with uv and install the package:
    pip install uv swig  # you still need to have swig installed globally :(
    uv venv
    . .venv/bin/activate
    uv pip install -e .

Running the Code

  • To train the model, run:
python train.py
  • To generate videos, run:
python video.py

Future plans

  • Add noise-dependent weighting to the loss function
  • The world model is currently trained on random actions; it would be nice to train it on an expert agent instead
  • Leverage the world model to train a reward model
  • Use the world and reward model to train a bootstrapped RL agent
