Figure 1: Conservative zero-shot RL methods suppress the values or measures on actions not in the dataset for all tasks. Black dots represent state-action samples present in the dataset.
This is the official codebase for Zero-Shot Reinforcement Learning from Low Quality Data by Scott Jeen, Tom Bewley and Jonathan Cullen.
Imagine you've collected a dataset from a system you'd like to control more efficiently. Examples include: household robots, chemical manufacturing processes, autonomous vehicles, or steel-making furnaces. An ideal solution would be to train an autonomous agent on your dataset, then have it use what it learns to solve any task inside the system. For our household robot, such tasks may include sweeping the floor, making a cup of tea, or cleaning the windows. Formally, we call this problem setting zero-shot reinforcement learning (RL), and taking steps toward realising it in the real world is the focus of this work.
If our dataset is pseudo-optimal, that is to say, it tells our domestic robot the full extent of the floorspace, where the tea bags are stored, and how many windows exist, then the existing state-of-the-art method, Forward-Backward (FB) representations, performs excellently. On average it will solve any task you want inside the system with 85% accuracy. However, if the data we've collected from the system is suboptimal--it doesn't provide all the information required to solve all tasks--then FB representations fail. They fail because they overestimate the value of state-action pairs not present in the dataset, or in RL parlance, they overestimate out-of-distribution state-action values--Figure 1 (Middle).
In this work, we resolve this by artificially suppressing these out-of-distribution values, leveraging ideas from conservatism in the Offline RL literature--Figure 1 (Right). In experiments across a variety of systems and tasks, we show these methods consistently outperform their non-conservative counterparts when the datasets are suboptimal--Figure 2.
Figure 2: Aggregate performance. (Left) Normalised average performance w.r.t. single-task baseline algorithm CQL. (Right) Performance profiles showing distribution of scores across all tasks and domains. Both conservative FB variants stochastically dominate vanilla FB.
We also find that our proposals don't sacrifice performance when the dataset is pseudo-optimal, and so present little downside over their predecessors.
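For intuition, here is a minimal sketch of the conservatism idea, assuming a generic CQL-style penalty on a critic. The `critic`, `obs` and `dataset_actions` names are hypothetical, and this is not the exact objective used by the conservative FB variants (which regularise FB-derived values and measures); see the paper for those.

```python
import torch

def conservative_penalty(critic, obs, dataset_actions, num_random=10):
    """Illustrative CQL-style regulariser, not this repo's exact loss.

    Minimising this term pushes down the critic's values on randomly
    sampled (likely out-of-distribution) actions while pushing up its
    values on actions that actually appear in the dataset.
    """
    batch_size, action_dim = dataset_actions.shape
    # Uniform random actions stand in for out-of-distribution actions
    random_actions = torch.empty(
        num_random, batch_size, action_dim, device=obs.device
    ).uniform_(-1.0, 1.0)
    # Critic values for each sampled action: shape (num_random, batch_size)
    q_random = torch.stack([critic(obs, a) for a in random_actions])
    # Critic values for the in-distribution (dataset) actions
    q_dataset = critic(obs, dataset_actions)
    # logsumexp softly maximises over the sampled actions; subtracting the
    # dataset values means only out-of-distribution value gets suppressed
    return (torch.logsumexp(q_random, dim=0) - q_dataset).mean()

# In practice the penalty is added to the usual TD loss with a weight alpha:
# critic_loss = td_loss + alpha * conservative_penalty(critic, obs, actions)
```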
For further detail we recommend reading the paper. Direct any correspondence to Scott Jeen or raise an issue!
Assuming you have MuJoCo installed, set up a conda env with Python 3.9.16 as usual:

```bash
conda create --name zsrl python=3.9.16
```

then install the dependencies from `requirements.txt`:

```bash
pip install -r requirements.txt
```
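As a quick sanity check that the environment is set up, you can try loading one of the DeepMind Control Suite domains used below; this assumes `dm_control` is pulled in by `requirements.txt`:

```python
# Sanity check: load a DeepMind Control Suite domain used in the paper.
# Assumes dm_control was installed via requirements.txt.
from dm_control import suite

env = suite.load(domain_name="walker", task_name="walk")
print(env.observation_spec())
```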
In the paper we report results with agents trained on datasets collected from different exploratory algorithms on different domains. The domains are:
Domain | Eval Tasks | Dimensionality | Type | Reward | Command Line Argument |
---|---|---|---|---|---|
Walker | `stand` `walk` `run` `flip` | Low | Locomotion | Dense | `walker` |
Quadruped | `stand` `roll` `roll_fast` `jump` `escape` | High | Locomotion | Dense | `quadruped` |
Point-mass Maze | `reach_top_left` `reach_top_right` `reach_bottom_left` `reach_bottom_right` | Low | Goal-reaching | Sparse | `point_mass_maze` |
Jaco | `reach_top_left` `reach_top_right` `reach_bottom_left` `reach_bottom_right` | High | Goal-reaching | Sparse | `jaco` |
and the dataset collecting algorithms are:
Dataset Collecting Algorithm | State Coverage | Command Line Argument |
---|---|---|
Random Network Distillation (RND) | High | `rnd` |
Diversity is All You Need (DIAYN) | Medium | `diayn` |
Random | Low | `random` |
State coverage illustrations on `point_mass_maze` are provided in Figure 3. For each domain, datasets need to be downloaded manually from the ExORL benchmark, then reformatted. Once downloaded, to reformat the `rnd` dataset for the `walker` domain, separate their command line args with an `_` and run:

```bash
python exorl_reformatter.py walker_rnd
```

This will create a single `dataset.npz` file in the `dataset/walker/rnd/buffer` directory.
Figure 3: State coverage by dataset on `point_mass_maze`.
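If you want to sanity-check a reformatted buffer before training, you can inspect it with NumPy. The exact array names inside `dataset.npz` depend on `exorl_reformatter.py`, so treat the keys printed below as the source of truth rather than assuming them:

```python
import numpy as np

# Path created by the reformatting step above
data = np.load("dataset/walker/rnd/buffer/dataset.npz")

# Print every stored array with its shape and dtype; the key names are
# whatever exorl_reformatter.py chose to write, so check this output
for key in data.files:
    print(key, data[key].shape, data[key].dtype)
```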
To use Weights & Biases for logging, create a free account and run `wandb login` from the command line. Subsequent runs will automatically log to a new project named `conservative-world-models`.
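Logging is handled by the training scripts themselves, but if you want to push extra metrics to the same project from your own code, a minimal (illustrative) pattern looks like this; the metric name is hypothetical:

```python
import wandb

# Illustrative only: the training scripts create their own runs automatically.
run = wandb.init(
    project="conservative-world-models",
    config={"algorithm": "vcfb", "domain": "walker"},
)
run.log({"eval/episode_reward": 123.4})  # hypothetical metric name
run.finish()
```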
We provide implementations of the following algorithms:
Algorithm | Authors | Command Line Argument |
---|---|---|
Conservative Q-Learning (CQL) | Kumar et al. (2020) | `cql` |
Offline TD3 | Fujimoto et al. (2021) | `td3` |
Universal Successor Features learned with Laplacian Eigenfunctions (SF-LAP) | Borsa et al. (2018) | `sf-lap` |
FB Representations | Touati et al. (2023) | `fb` |
Value-Conservative FB Representations | Jeen et al. (2024) | `vcfb` |
Measure-Conservative FB Representations | Jeen et al. (2024) | `mcfb` |
To train a standard Value-Conservative Forward-Backward Representation with the `rnd` (100k) dataset to solve all tasks in the `walker` domain, run:

```bash
python main_offline.py vcfb walker rnd --eval_task stand run walk flip
```
If you find this work informative, please consider citing the paper!
```bibtex
@article{jeen2023,
  url = {https://arxiv.org/abs/2309.15178},
  author = {Jeen, Scott and Bewley, Tom and Cullen, Jonathan M.},
  title = {Zero-Shot Reinforcement Learning from Low Quality Data},
  publisher = {arXiv},
  year = {2023},
}
```
This work is licensed under a standard MIT License; see `LICENSE.md` for further details.