Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Training Tutorial and CITATION.cff #379

Merged
merged 2 commits into from
Jul 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions CITATION.cff
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
cff-version: 1.2.0
title: Minigrid
message: >-
If you use this software, please cite it using the
metadata from this file.
authors:
- family-names: Chevalier-Boisvert
given-names: Maxime
- family-names: Dai
given-names: Bolun
- family-names: Towers
given-names: Mark
- family-names: de Lazcano
given-names: Rodrigo
- family-names: Willems
given-names: Lucas
- family-names: Lahlou
given-names: Salem
- family-names: Pal
given-names: Suman
- family-names: Castro
given-names: Pablo Samuel
- family-names: Terry
given-names: Jordan
url: "https://github.com/Farama-Foundation/Minigrid"

preferred-citation:
type: article
authors:
- family-names: Chevalier-Boisvert
given-names: Maxime
- family-names: Dai
given-names: Bolun
- family-names: Towers
given-names: Mark
- family-names: de Lazcano
given-names: Rodrigo
- family-names: Willems
given-names: Lucas
- family-names: Lahlou
given-names: Salem
- family-names: Pal
given-names: Suman
- family-names: Castro
given-names: Pablo Samuel
- family-names: Terry
given-names: Jordan
journal: CoRR
title: Minigrid
volume: abs/2306.13831
year: 2023
67 changes: 67 additions & 0 deletions docs/content/training.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
---
layout: "contents"
title: Training Minigrid Environments
firstpage:
---

## Training Minigrid Environments

The environments in the Minigrid library can be trained easily using [StableBaselines3](https://stable-baselines3.readthedocs.io/en/master/). In this tutorial we show how a PPO agent can be trained on the `MiniGrid-Empty-16x16-v0` environment.

## Create Custom Feature Extractor

Although `StableBaselines3` is fully compatible with `Gymnasium`-based environments, including Minigrid, the default CNN architecture does not directly support the Minigrid observation space. Thus, to train an agent on Minigrid environments, we therefore need to create a custom feature extractor. This can be done by creating a feature extractor class that inherits from `stable_baselines3.common.torch_layers.BaseFeaturesExtractor`

```python
class MinigridFeaturesExtractor(BaseFeaturesExtractor):
def __init__(self, observation_space: gym.Space, features_dim: int = 512, normalized_image: bool = False) -> None:
super().__init__(observation_space, features_dim)
n_input_channels = observation_space.shape[0]
self.cnn = nn.Sequential(
nn.Conv2d(n_input_channels, 16, (2, 2)),
nn.ReLU(),
nn.Conv2d(16, 32, (2, 2)),
nn.ReLU(),
nn.Conv2d(32, 64, (2, 2)),
nn.ReLU(),
nn.Flatten(),
)

# Compute shape by doing one forward pass
with torch.no_grad():
n_flatten = self.cnn(torch.as_tensor(observation_space.sample()[None]).float()).shape[1]

self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

def forward(self, observations: torch.Tensor) -> torch.Tensor:
return self.linear(self.cnn(observations))
```

This class is created based on the custom feature extractor [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#custom-feature-extractor:~:text=Custom%20Feature%20Extractor-,%EF%83%81,-If%20you%20want), the CNN architecture is copied from Lucas Willems' [rl-starter-files](https://github.com/lcswillems/rl-starter-files/blob/317da04a9a6fb26506bbd7f6c7c7e10fc0de86e0/model.py#L18).

## Train a PPO Agent

The using the custom feature extractor, we can train a PPO agent on the `MiniGrid-Empty-16x16-v0` environment. The following code snippet shows how this can be done.

```python
import minigrid
from minigrid.wrappers import ImgObsWrapper
from stable_baselines3 import PPO

policy_kwargs = dict(
features_extractor_class=MinigridFeaturesExtractor,
features_extractor_kwargs=dict(features_dim=128),
)

env = gym.make("MiniGrid-Empty-16x16-v0", render_mode="rgb_array")
env = ImgObsWrapper(env)

model = PPO("CnnPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
model.learn(2e5)
```

By default the observation of Minigrid environments are dictionaries. Since the `CnnPolicy` from StableBaseline3 by default takes in image observations, we need to wrap the environment using the `ImgObsWrapper` from the Minigrid library. This wrapper converts the dictionary observation to an image observation.

## Further Reading

One can also pass dictionary observations to StableBaseline3 policies, for a walkthrough the process of doing so see [here](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#multiple-inputs-and-dictionary-observations). An implementation utilizing this functionality can be found [here](https://github.com/BolunDai0216/MinigridMiniworldTransfer/blob/main/minigrid_gotoobj_train.py).
1 change: 1 addition & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ To cite this project please use:
content/basic_usage
content/publications
content/create_env_tutorial
content/training
```

```{toctree}
Expand Down