-
Notifications
You must be signed in to change notification settings - Fork 685
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update to support Gymnasium #277
Conversation
The latest updates on your projects. Learn more about Vercel for Git ↗︎
|
Related to #271 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @arjun-kg thanks for preparing the PR. Looking forward to using the latest gym. I left some very preliminary comments.
One important thing to note during the refactor is to see if the change could result in a performance difference (not just a simple variable renaming). For example, the current PPO scripts did not handle the time out correctly, so handling time out correctly in this PR is a performance-impacting change.
We need to be careful with the performance-impacting changes because we would need to re-run the benchmarks on those changes to ensure there is no surprise regression in the performance.
cleanrl/ppo.py
Outdated
@@ -213,18 +213,18 @@ def get_action_and_value(self, x, action=None): | |||
writer.add_scalar("charts/episodic_length", item["episode"]["l"], global_step) | |||
break |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This part may need to be changed to
if "episode" in info:
for item in info["episode"]["r"]:
print(f"global_step={global_step}, episodic_return={item}")
writer.add_scalar("charts/episodic_return", item, global_step)
break
for item in info["episode"]["l"]:
writer.add_scalar("charts/episodic_length", item, global_step)
break
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To replicate the original behavior, you probably need something like (but hopefully better-looking than!):
if "episode" in info:
first_idx = info["_episode"].nonzero()[0][0]
r = info["episode"]["r"][first_idx]
l = info["episode"]["l"][first_idx]
print(f"global_step={global_step}, episodic_return={r}")
writer.add_scalar("charts/episodic_return", r, global_step)
writer.add_scalar("charts/episodic_length", l, global_step)
There's no guarantee that the first index in "episode" won't just be a zero, need the mask to specify which one.
Alternatively, it might be better to track a running average using the deques built into the RecordEpisodeStatistics wrapper, though that would likely results in different performance graphs.
@@ -159,12 +159,12 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int): | |||
envs.single_observation_space, | |||
envs.single_action_space, | |||
device, | |||
handle_timeout_termination=True, | |||
handle_timeout_termination=False, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perfect!
cleanrl/ppo.py
Outdated
with torch.no_grad(): | ||
next_value = agent.get_value(next_obs).reshape(1, -1) | ||
if args.gae: | ||
advantages = torch.zeros_like(rewards).to(device) | ||
lastgaelam = 0 | ||
for t in reversed(range(args.num_steps)): | ||
if t == args.num_steps - 1: | ||
nextnonterminal = 1.0 - next_done | ||
nextnonterminal = 1.0 - next_terminated |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A note for myself: this is a change that could impact performance. We would need to re-run the benchmark here.
Thanks @arjun-kg for the PR. We look forward to supporting the next generation of gym. It's important to identify the performance-impacting changes and non-performance-impacting changes:
In this PR for initial support for v0.26.1, let's aim to make only non-performance-impacting changes. With that said, I have added a todo list in the PR description. |
@arjun-kg I made the first pass of editing to make Btw the plan is to have an announcement like the following on the main page, since I expect to encounter more issues. |
Hi ! |
@GaetanLepage yeah, we should do @arjun-kg I added some changes to |
@arjun-kg we are thinking of probably supporting both gymnasium and gym simultaneously. See #318 (comment) as an example. This will give us a much smoother transition |
@vwxyzjn sounds good, will check it out. I'm a bit tied up this week. I'll continue work on this from next week if it's okay. |
for idx, d in enumerate(dones): | ||
if d: | ||
real_next_obs[idx] = infos[idx]["terminal_observation"] | ||
rb.add(obs, real_next_obs, actions, rewards, dones, infos) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, I guess this line has been forgotten in the migration :-).
We have just released Gymnasium v0.27.0, this should be backward compatible. Would it be possible to update this Pr to v0.27 and check that nothing new breaks |
@vwxyzjn recently SB3 supports gymnasium with a branch, but I'm not sure if some parallel work is going on to update cleanrl to gymnasium? Would you like me to update this PR to gymnasium with SB3 on the gymnasium branch? |
real_next_obs = next_obs.copy() | ||
for idx, d in enumerate(dones): | ||
for idx, d in enumerate(terminateds): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should it use truncated
instead of terminated
here ?
With truncated
, the results are identical with same seeding between the old and new implementation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this was a mistake, it should be truncated
@@ -191,12 +190,12 @@ def forward(self, x): | |||
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) | |||
break |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Making the assumption that there will be no parrallel env, this could work:
if "final_info" in infos:
info = infos["final_info"][0]
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
But I have seen that there is a different solution in the DQN file
@@ -71,7 +70,7 @@ def thunk(): | |||
if capture_video: | |||
if idx == 0: | |||
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") | |||
env.seed(seed) | |||
|
|||
env.action_space.seed(seed) | |||
env.observation_space.seed(seed) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think env.observation_space.seed(seed)
can be remove
Just tried the PR with: diff --git a/cleanrl/dqn.py b/cleanrl/dqn.py
import time
from distutils.util import strtobool
-import gym
+import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn with stable-baselines3==2.0.0a5 and gymnasium==0.28.1 when i run
always after 'global_step=10009' execution stop with this error:
|
i think it was not intended to remove the following line:
This could be the fix (fixed it for me): diff --git a/cleanrl/dqn.py b/cleanrl/dqn.py
index 14864e7..4e73a6e 100644
--- a/cleanrl/dqn.py
+++ b/cleanrl/dqn.py
@@ -156,7 +156,7 @@ if __name__ == "__main__":
start_time = time.time()
# TRY NOT TO MODIFY: start the game
- obs = envs.reset(seed=args.seed)
+ obs, _ = envs.reset(seed=args.seed)
for global_step in range(args.total_timesteps):
# ALGO LOGIC: put action logic here
epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
@@ -185,6 +185,7 @@ if __name__ == "__main__":
for idx, d in enumerate(infos["_final_observation"]):
if d:
real_next_obs[idx] = infos["final_observation"][idx]
+ rb.add(obs, real_next_obs, actions, rewards, terminateds, infos)
# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs |
@pseudo-rnd-thoughts absolutely. Closing this PR now. |
Description
A draft PR updating CleanRL to support Gymnasium. Closes #263
This mostly includes updating step and seed API. Tries to use gymnasium branches on the dependent packages (SB3 etc) After these are updated, will verify the changes, check the tests, and get the PR ready for review.
Costa's comment:
Thanks @arjun-kg for the PR. We look forward to supporting the next generation of gym.
It's important to identify the performance-impacting changes and non-performance-impacting changes:
In this PR for initial support fo v0.26.1, let's aim to make only non-performance-impacting changes. With that said, here is a todo list:
Checklist:
pre-commit run --all-files
passes (required).mkdocs serve
.If you are adding new algorithms or your change could result in performance difference, you may need to (re-)run tracked experiments. See #137 as an example PR.
--capture-video
flag toggled on (required).mkdocs serve
.width=500
andheight=300
).