Skip to content

Commit

Permalink
Update erupt-gym to work with gymnasium (#37)
Browse files Browse the repository at this point in the history
* Make gymnasium work

* Remove tfevents

* wip

* Remove prints

* Fmt

* Remove experiment name from config

* Add experiment name to args parsing

* Update docs
  • Loading branch information
jaxs-ribs authored May 5, 2023
1 parent 386a678 commit fc7b1a5
Show file tree
Hide file tree
Showing 8 changed files with 17 additions and 25 deletions.
Binary file added .DS_Store
Binary file not shown.
1 change: 1 addition & 0 deletions configurations/gym-lunarlander.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ game_kind: 'LunarLander-v2'
game_config:
env_count: 10
continuous: true
render_mode: "human"
2 changes: 1 addition & 1 deletion docs/src/local-training.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# 🖥 Local training

Erupt supports launching and training locally with Gym, and potentially other engines. To do so, you'll have to define a
Erupt supports launching and training locally with Gymnasium, and potentially other engines. To do so, you'll have to define a
configuration that follows our [configuration format](./concepts/engine.md). This is an example config using `erupt-gym`
to train locally.

Expand Down
5 changes: 2 additions & 3 deletions src/py/erupt-gym/README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# Gym bindings for Erupt

This contains the bindings for training with Gym on Erupt.
This contains the bindings for training with Gymnasium on Erupt.

## Running from inside Pants

`pants run //src/py/erupt-gym:train -- --help`
`pants run //src/py/erupt-gym:train -- -c config/file/relative/to-repo-root`
`pants run cmd:train -- configurations/gym-lunarlander.yaml configurations/temp_artifact_location/`
11 changes: 3 additions & 8 deletions src/py/erupt-gym/erupt_gym/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,36 +5,32 @@
import threading
from collections import deque

from emote.callbacks import LoggingMixin
from emote.callbacks import Callback, LoggingMixin
from emote.proxies import AgentProxy, MemoryProxy

from erupt_gym.wrappers import DictGymWrapper


class GymCollector(LoggingMixin):
class GymCollector(LoggingMixin, Callback):
MAX_NUMBER_REWARDS = 1000

def __init__(
self,
env: DictGymWrapper,
agent: AgentProxy,
memory: MemoryProxy,
render: bool = True,
warmup_steps: int = 0,
):
super().__init__()
self._agent = agent
self._memory = memory
self._env = env
self._render = render
self._last_environment_rewards = deque(maxlen=1000)
self.num_envs = env.num_envs
self._warmup_steps = warmup_steps

def collect_data(self):
"""Collect a single rollout"""
if self._render:
self._env.render()
actions = self._agent(self._obs)
next_obs, ep_info = self._env.dict_step(actions)

Expand Down Expand Up @@ -71,10 +67,9 @@ def __init__(
env: DictGymWrapper,
agent: AgentProxy,
memory: MemoryProxy,
render: bool = True,
warmup_steps: int = 0,
):
super().__init__(env, agent, memory, render, warmup_steps)
super().__init__(env, agent, memory, warmup_steps)
self._warmup_steps = warmup_steps
self._stop = False
self._thread = None
Expand Down
15 changes: 7 additions & 8 deletions src/py/erupt-gym/erupt_gym/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from dataclasses import dataclass, field
from functools import partial
from typing import Any, ClassVar, Type
from typing import Any, Type

import numpy as np
import torch
Expand Down Expand Up @@ -124,7 +124,6 @@ def construct(
agent_proxy,
memory_proxy,
warmup_steps=batch_size,
render=False,
),
]

Expand Down Expand Up @@ -205,7 +204,7 @@ def protocols() -> dict[str, Type[IProtocol]]:

@classmethod
def create_experiment(
cls: ClassVar[GymEngine],
cls: Type[GymEngine],
engine_config: EngineConfig,
protocol_kind: str,
protocol_config: Any,
Expand All @@ -216,13 +215,13 @@ def create_experiment(
protocol_config = protocol_class.deserialize_protocol_configuration(protocol_config)

def _make_env(game, game_config, rank):
import gym
import gymnasium

def _thunk():
env = gym.make(game, **game_config)
env = gym.wrappers.FrameStack(env, 3)
env = gym.wrappers.FlattenObservation(env)
env.seed(rank)
env = gymnasium.make(game, **game_config)
env = gymnasium.wrappers.FrameStack(env, 3)
env = gymnasium.wrappers.FlattenObservation(env)
_ = env.reset(seed=rank)
return env

return _thunk
Expand Down
7 changes: 2 additions & 5 deletions src/py/erupt-gym/erupt_gym/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,12 @@ def __init__(self, env: VectorEnv):
DictSpace({"obs": BoxSpace(os.dtype, os.shape)}),
)

def render(self):
self.env.envs[0].render()

def dict_step(self, actions: Dict[AgentId, DictResponse]) -> Dict[AgentId, DictObservation]:
batched_actions = np.stack(
[actions[agent].list_data["actions"] for agent in self._agent_ids]
)
self.step_async(batched_actions)
next_obs, rewards, dones, info = super().step_wait()
next_obs, rewards, dones, _truncations, info = super().step_wait()
new_agents = []
results = {}
completed_episode_rewards = []
Expand Down Expand Up @@ -82,7 +79,7 @@ def dict_step(self, actions: Dict[AgentId, DictResponse]) -> Dict[AgentId, DictO
def dict_reset(self) -> Dict[AgentId, DictObservation]:
self._agent_ids = [next(self._next_agent) for i in range(self.num_envs)]
self.reset_async()
obs = self.reset_wait()
obs, _info = self.reset_wait()
return {
agent_id: DictObservation(
episode_state=EpisodeState.INITIAL,
Expand Down
1 change: 1 addition & 0 deletions src/py/erupt/erupt/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(self, args: TrainerArguments):
engine_kind = params_dict["engine_kind"]

engine_config = params_dict["engine_config"]
engine_config["experiment_name"] = args.experiment_name
engine_config["artifact_location"] = args.artifact_location

if args.infer_only == ServingMode.CONFIG:
Expand Down

0 comments on commit fc7b1a5

Please sign in to comment.