
Zero Shot Reinforcement Learning from Low Quality Data


Figure 1: Conservative zero-shot RL methods suppress the values or measures on actions not in the dataset for all tasks. Black dots represent state-action samples present in the dataset.

This is the official codebase for Zero-Shot Reinforcement Learning from Low Quality Data by Scott Jeen, Tom Bewley and Jonathan Cullen.

Summary

Imagine you've collected a dataset from a system you'd like to control more efficiently. Examples include household robots, chemical manufacturing processes, autonomous vehicles, and steel-making furnaces. An ideal solution would be to train an autonomous agent on your dataset, then have it use what it learns to solve any task inside the system. For our household robot, such tasks may include sweeping the floor, making a cup of tea, or cleaning the windows. Formally, we call this problem setting zero-shot reinforcement learning (RL), and taking steps toward realising it in the real world is the focus of this work.

If our dataset is pseudo-optimal, that is to say, it tells our household robot the full extent of the floorspace, where the tea bags are stored, and how many windows exist, then the existing state-of-the-art method, Forward-Backward (FB) representations, performs excellently: on average it will solve any task you want inside the system with 85% accuracy. However, if the data we've collected from the system is suboptimal--it doesn't provide all the information required to solve all tasks--then FB representations fail. They fail because they overestimate the value of state-action pairs not present in the dataset, or in RL parlance, they overestimate out-of-distribution state-action values--Figure 1 (Middle).

In this work, we resolve this by artificially suppressing these out-of-distribution values, leveraging ideas from conservatism in the Offline RL literature--Figure 1 (Right). In experiments across a variety of systems and tasks, we show these methods consistently outperform their non-conservative counterparts when the datasets are suboptimal--Figure 2.

Figure 2: Aggregate performance. (Left) Normalised average performance w.r.t. single-task baseline algorithm CQL. (Right) Performance profiles showing distribution of scores across all tasks and domains. Both conservative FB variants stochastically dominate vanilla FB.

We also find that our proposals don't sacrifice performance when the dataset is pseudo-optimal, and so present little downside over their predecessors.
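If you're unfamiliar with conservatism in offline RL, the snippet below sketches the core idea with a generic CQL-style penalty on a Q-network: the critic is rewarded for assigning low values to randomly sampled (likely out-of-distribution) actions relative to the actions in the dataset. This is a simplified illustration for intuition only; the names (`q_net`, `dataset_actions`, `num_random`) are placeholders, and VC-FB/MC-FB apply the analogous idea to FB values and measures rather than using this exact loss.

```python
import torch

def conservative_penalty(q_net, states, dataset_actions, num_random=10):
    """CQL-style regulariser: push Q-values down on randomly sampled
    (likely out-of-distribution) actions and up on dataset actions.
    Illustrative only; not the loss used by VC-FB/MC-FB in this repo."""
    batch_size, action_dim = dataset_actions.shape

    # Candidate OOD actions sampled uniformly in [-1, 1]^action_dim.
    random_actions = torch.rand(batch_size, num_random, action_dim) * 2 - 1

    # Evaluate Q on (state, random action) pairs.
    states_rep = states.unsqueeze(1).expand(-1, num_random, -1).reshape(-1, states.shape[-1])
    q_ood = q_net(states_rep, random_actions.reshape(-1, action_dim)).view(batch_size, num_random)

    # Evaluate Q on the (state, action) pairs actually in the dataset.
    q_data = q_net(states, dataset_actions)

    # Soft maximum over OOD actions minus the in-distribution value;
    # adding this term to the critic loss suppresses OOD overestimation.
    return torch.logsumexp(q_ood, dim=1).mean() - q_data.mean()
```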

For further detail, we recommend reading the paper. Direct any correspondence to Scott Jeen or raise an issue!

Setup

Dependencies

Assuming you have MuJoCo installed, set up a conda environment with Python 3.9.16:

conda create --name zsrl python=3.9.16

then install the dependencies from requirements.txt:

pip install -r requirements.txt
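To check the install worked, you can optionally load one of the domains used below. This assumes dm_control is among the installed dependencies (the domains come from the DeepMind Control Suite); adjust if your setup differs:

```python
# Quick sanity check that MuJoCo and dm_control import and load a domain.
from dm_control import suite

env = suite.load(domain_name="walker", task_name="walk")
timestep = env.reset()
print(timestep.observation)  # should print an ordered dict of observations
```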

Domains and Datasets

In the paper we report results for agents trained on datasets collected by different exploratory algorithms across four domains. The domains are:

| Domain | Eval Tasks | Dimensionality | Type | Reward | Command Line Argument |
| --- | --- | --- | --- | --- | --- |
| Walker | stand, walk, run, flip | Low | Locomotion | Dense | walker |
| Quadruped | stand, roll, roll_fast, jump, escape | High | Locomotion | Dense | quadruped |
| Point-mass Maze | reach_top_left, reach_top_right, reach_bottom_left, reach_bottom_right | Low | Goal-reaching | Sparse | point_mass_maze |
| Jaco | reach_top_left, reach_top_right, reach_bottom_left, reach_bottom_right | High | Goal-reaching | Sparse | jaco |

and the dataset collecting algorithms are:

| Dataset Collecting Algorithm | State Coverage | Command Line Argument |
| --- | --- | --- |
| Random Network Distillation (RND) | High | rnd |
| Diversity is All You Need (DIAYN) | Medium | diayn |
| Random | Low | random |

State coverage illustrations on point_mass_maze are provided in Figure 3. For each domain, datasets need to be downloaded manually from the ExORL benchmark and then reformatted. To reformat the rnd dataset for the walker domain, join the domain and dataset names with an underscore and run:

python exorl_reformatter.py walker_rnd

This will create a single dataset.npz file in the dataset/walker/rnd/buffer directory.
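If you want to sanity-check the result, a minimal inspection with NumPy is sketched below. The path follows the directory layout described above; the array names stored inside the archive aren't assumed here, we simply list whatever is present:

```python
import numpy as np

# Load the reformatted dataset and list the arrays it contains.
data = np.load("dataset/walker/rnd/buffer/dataset.npz")
for key in data.files:
    print(key, data[key].shape)
```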

Figure 3: State coverage by dataset on point_mass_maze.

WandB

To use Weights & Biases for logging, create a free account and run wandb login from the command line. Subsequent runs will automatically log to a new project named conservative-world-models.
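For reference, the logging is roughly equivalent to the following sketch. The project name comes from the text above; the config fields are illustrative assumptions rather than the training script's exact arguments:

```python
import wandb

# Initialise a run in the project used for logging.
run = wandb.init(
    project="conservative-world-models",
    config={"algorithm": "vcfb", "domain": "walker", "dataset": "rnd"},  # assumed fields
)
run.log({"eval/episode_reward": 0.0})  # example of a metric logged during training
run.finish()
```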

Algorithms

We provide implementations of the following algorithms:

| Algorithm | Authors | Command Line Argument |
| --- | --- | --- |
| Conservative $Q$-learning (CQL) | Kumar et al. (2020) | cql |
| Offline TD3 | Fujimoto et al. (2021) | td3 |
| Universal Successor Features learned with Laplacian Eigenfunctions (SF-LAP) | Borsa et al. (2018) | sf-lap |
| FB Representations | Touati et al. (2023) | fb |
| Value-Conservative FB Representations (VC-FB) | Jeen et al. (2024) | vcfb |
| Measure-Conservative FB Representations (MC-FB) | Jeen et al. (2024) | mcfb |

Training

To train a standard Value-Conservative Forward Backward Representation with the rnd (100k) dataset to solve all tasks in the walker domain, run:

python main_offline.py vcfb walker rnd --eval_task stand run walk flip
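For intuition on what happens at evaluation time, the sketch below shows how FB-style agents are typically adapted to a new task zero-shot: a task embedding is inferred from reward-labelled samples and passed to the policy. This is a conceptual sketch based on the paper, with assumed names (backward_net, actor) rather than this repo's exact API:

```python
import torch

def infer_task_embedding(backward_net, states, rewards, z_dim):
    """Zero-shot task inference for FB-style agents: the task embedding is a
    reward-weighted average of backward embeddings, z ~ E_{s in D}[ r(s) * B(s) ],
    projected onto a sphere of radius sqrt(z_dim). Names and normalisation are
    assumptions, not this repo's API."""
    z = (rewards.unsqueeze(-1) * backward_net(states)).mean(dim=0)
    return torch.sqrt(torch.tensor(float(z_dim))) * z / z.norm()

# The policy is then conditioned on z to act on the new task without any
# further training, e.g. action = actor(state, z).
```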

Citation

If you find this work informative, please consider citing the paper!

@article{jeen2023,
  url = {https://arxiv.org/abs/2309.15178},
  author = {Jeen, Scott and Bewley, Tom and Cullen, Jonathan M.},  
  title = {Zero-Shot Reinforcement Learning from Low Quality Data},
  publisher = {arXiv},
  year = {2023},
}

License

This work is licensed under a standard MIT License; see LICENSE.md for further details.