diff --git a/LICENSE b/LICENSE index cb2434995..c4e562311 100644 --- a/LICENSE +++ b/LICENSE @@ -291,3 +291,26 @@ THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +-------------------------------------------------------------------------------- + +Code in `cleanrl/qdagger_dqn_atari_impalacnn.py` and `cleanrl/qdagger_dqn_atari_jax_impalacnn.py` are adapted from https://github.com/google-research/reincarnating_rl + +**NOTE: the original repo did not fill out the copyright section in their license +so the following copyright notice is copied as is per the license requirement. +See https://github.com/google-research/reincarnating_rl/blob/a1d402f48a9f8658ca6aa0ddf416ab391745ff2c/LICENSE#L189 + + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/README.md b/README.md index 829602bb9..4adb562d0 100644 --- a/README.md +++ b/README.md @@ -156,6 +156,8 @@ You may also use a prebuilt development environment hosted in Gitpod: | | [`td3_continuous_action_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/td3_continuous_action_jax.py), [docs](https://docs.cleanrl.dev/rl-algorithms/td3/#td3_continuous_action_jaxpy) | | ✅ [Phasic Policy Gradient (PPG)](https://arxiv.org/abs/2009.04416) | [`ppg_procgen.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppg_procgen.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ppg/#ppg_procgenpy) | | ✅ [Random Network Distillation (RND)](https://arxiv.org/abs/1810.12894) | [`ppo_rnd_envpool.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_rnd_envpool.py), [docs](/rl-algorithms/ppo-rnd/#ppo_rnd_envpoolpy) | +| ✅ [Qdagger](https://arxiv.org/abs/2206.01626) | [`qdagger_dqn_atari_impalacnn.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/qdagger_dqn_atari_impalacnn.py), [docs](https://docs.cleanrl.dev/rl-algorithms/qdagger/#qdagger_dqn_atari_impalacnnpy) | +| | [`qdagger_dqn_atari_jax_impalacnn.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/qdagger_dqn_atari_jax_impalacnn.py), [docs](https://docs.cleanrl.dev/rl-algorithms/qdagger/#qdagger_dqn_atari_jax_impalacnnpy) | ## Open RL Benchmark diff --git a/benchmark/qdagger.sh b/benchmark/qdagger.sh new file mode 100644 index 000000000..2491716a0 --- /dev/null +++ b/benchmark/qdagger.sh @@ -0,0 +1,15 @@ +poetry install -E atari +OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \ + --env-ids PongNoFrameskip-v4 BeamRiderNoFrameskip-v4 BreakoutNoFrameskip-v4 \ + --command "poetry run python cleanrl/qdagger_dqn_atari_impalacnn.py --track --capture-video" \ + --num-seeds 3 \ + --workers 1 + + +poetry install -E "atari jax" +poetry run pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +xvfb-run -a poetry run python -m cleanrl_utils.benchmark \ + --env-ids 
PongNoFrameskip-v4 BeamRiderNoFrameskip-v4 BreakoutNoFrameskip-v4 \ + --command "poetry run python cleanrl/qdagger_dqn_atari_jax_impalacnn.py --track --capture-video" \ + --num-seeds 3 \ + --workers 1 diff --git a/cleanrl/qdagger_dqn_atari_impalacnn.py b/cleanrl/qdagger_dqn_atari_impalacnn.py new file mode 100644 index 000000000..15b6e273b --- /dev/null +++ b/cleanrl/qdagger_dqn_atari_impalacnn.py @@ -0,0 +1,477 @@ +# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/qdagger/#qdagger_dqn_atari_jax_impalacnnpy +import argparse +import os +import random +import time +from collections import deque +from distutils.util import strtobool + +import gymnasium as gym +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from huggingface_hub import hf_hub_download +from rich.progress import track +from stable_baselines3.common.atari_wrappers import ( + ClipRewardEnv, + EpisodicLifeEnv, + FireResetEnv, + MaxAndSkipEnv, + NoopResetEnv, +) +from stable_baselines3.common.buffers import ReplayBuffer +from torch.utils.tensorboard import SummaryWriter + +from cleanrl.dqn_atari import QNetwork as TeacherModel +from cleanrl_utils.evals.dqn_eval import evaluate + + +def parse_args(): + # fmt: off + parser = argparse.ArgumentParser() + parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), + help="the name of this experiment") + parser.add_argument("--seed", type=int, default=1, + help="seed of the experiment") + parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, `torch.backends.cudnn.deterministic=False`") + parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, cuda will be enabled by default") + parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="if toggled, this experiment will be tracked with Weights and Biases") + parser.add_argument("--wandb-project-name", type=str, default="cleanRL", + help="the wandb's project name") + parser.add_argument("--wandb-entity", type=str, default=None, + help="the entity (team) of wandb's project") + parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="whether to capture videos of the agent performances (check out `videos` folder)") + parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="whether to save model into the `runs/{run_name}` folder") + parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="whether to upload the saved model to huggingface") + parser.add_argument("--hf-entity", type=str, default="", + help="the user or org name of the model repository from the Hugging Face Hub") + + # Algorithm specific arguments + parser.add_argument("--env-id", type=str, default="BreakoutNoFrameskip-v4", + help="the id of the environment") + parser.add_argument("--total-timesteps", type=int, default=10000000, + help="total timesteps of the experiments") + parser.add_argument("--learning-rate", type=float, default=1e-4, + help="the learning rate of the optimizer") + parser.add_argument("--num-envs", type=int, default=1, + help="the number of parallel game environments") + parser.add_argument("--buffer-size", type=int, default=1000000, + help="the 
replay memory buffer size") + parser.add_argument("--gamma", type=float, default=0.99, + help="the discount factor gamma") + parser.add_argument("--tau", type=float, default=1., + help="the target network update rate") + parser.add_argument("--target-network-frequency", type=int, default=1000, + help="the timesteps it takes to update the target network") + parser.add_argument("--batch-size", type=int, default=32, + help="the batch size of sample from the reply memory") + parser.add_argument("--start-e", type=float, default=1, + help="the starting epsilon for exploration") + parser.add_argument("--end-e", type=float, default=0.01, + help="the ending epsilon for exploration") + parser.add_argument("--exploration-fraction", type=float, default=0.10, + help="the fraction of `total-timesteps` it takes from start-e to go end-e") + parser.add_argument("--learning-starts", type=int, default=80000, + help="timestep to start learning") + parser.add_argument("--train-frequency", type=int, default=4, + help="the frequency of training") + + # QDagger specific arguments + parser.add_argument("--teacher-policy-hf-repo", type=str, default=None, + help="the huggingface repo of the teacher policy") + parser.add_argument("--teacher-eval-episodes", type=int, default=10, + help="the number of episodes to run the teacher policy evaluate") + parser.add_argument("--teacher-steps", type=int, default=500000, + help="the number of steps to run the teacher policy to generate the replay buffer") + parser.add_argument("--offline-steps", type=int, default=500000, + help="the number of steps to run the student policy with the teacher's replay buffer") + parser.add_argument("--temperature", type=float, default=1.0, + help="the temperature parameter for qdagger") + args = parser.parse_args() + # fmt: on + assert args.num_envs == 1, "vectorized envs are not supported at the moment" + + if args.teacher_policy_hf_repo is None: + args.teacher_policy_hf_repo = f"cleanrl/{args.env_id}-dqn_atari-seed1" + + return args + + +def make_env(env_id, seed, idx, capture_video, run_name): + def thunk(): + if capture_video and idx == 0: + env = gym.make(env_id, render_mode="rgb_array") + env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") + else: + env = gym.make(env_id) + env = gym.wrappers.RecordEpisodeStatistics(env) + env = NoopResetEnv(env, noop_max=30) + env = MaxAndSkipEnv(env, skip=4) + env = EpisodicLifeEnv(env) + if "FIRE" in env.unwrapped.get_action_meanings(): + env = FireResetEnv(env) + env = ClipRewardEnv(env) + env = gym.wrappers.ResizeObservation(env, (84, 84)) + env = gym.wrappers.GrayScaleObservation(env) + env = gym.wrappers.FrameStack(env, 4) + env.action_space.seed(seed) + + return env + + return thunk + + +# taken from https://github.com/AIcrowd/neurips2020-procgen-starter-kit/blob/142d09586d2272a17f44481a115c4bd817cf6a94/models/impala_cnn_torch.py +class ResidualBlock(nn.Module): + def __init__(self, channels): + super().__init__() + self.conv0 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1) + self.conv1 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1) + + def forward(self, x): + inputs = x + x = nn.functional.relu(x) + x = self.conv0(x) + x = nn.functional.relu(x) + x = self.conv1(x) + return x + inputs + + +class ConvSequence(nn.Module): + def __init__(self, input_shape, out_channels): + super().__init__() + self._input_shape = input_shape + self._out_channels = out_channels + self.conv = nn.Conv2d(in_channels=self._input_shape[0], 
out_channels=self._out_channels, kernel_size=3, padding=1) + self.res_block0 = ResidualBlock(self._out_channels) + self.res_block1 = ResidualBlock(self._out_channels) + + def forward(self, x): + x = self.conv(x) + x = nn.functional.max_pool2d(x, kernel_size=3, stride=2, padding=1) + x = self.res_block0(x) + x = self.res_block1(x) + assert x.shape[1:] == self.get_output_shape() + return x + + def get_output_shape(self): + _c, h, w = self._input_shape + return (self._out_channels, (h + 1) // 2, (w + 1) // 2) + + +# ALGO LOGIC: initialize agent here: +class QNetwork(nn.Module): + def __init__(self, env): + super().__init__() + c, h, w = envs.single_observation_space.shape + shape = (c, h, w) + conv_seqs = [] + for out_channels in [16, 32, 32]: + conv_seq = ConvSequence(shape, out_channels) + shape = conv_seq.get_output_shape() + conv_seqs.append(conv_seq) + conv_seqs += [ + nn.Flatten(), + nn.ReLU(), + nn.Linear(in_features=shape[0] * shape[1] * shape[2], out_features=256), + nn.ReLU(), + nn.Linear(in_features=256, out_features=env.single_action_space.n), + ] + self.network = nn.Sequential(*conv_seqs) + + def forward(self, x): + return self.network(x / 255.0) + + +def linear_schedule(start_e: float, end_e: float, duration: int, t: int): + slope = (end_e - start_e) / duration + return max(slope * t + start_e, end_e) + + +def kl_divergence_with_logits(target_logits, prediction_logits): + """Implementation of on-policy distillation loss.""" + out = -F.softmax(target_logits, dim=-1) * (F.log_softmax(prediction_logits, dim=-1) - F.log_softmax(target_logits, dim=-1)) + return torch.sum(out) + + +if __name__ == "__main__": + import stable_baselines3 as sb3 + + if sb3.__version__ < "2.0": + raise ValueError( + """Ongoing migration: run the following command to install the new dependencies: + +poetry run pip install "stable_baselines3==2.0.0a1" "gymnasium[atari,accept-rom-license]==0.28.1" "ale-py==0.8.1" +""" + ) + args = parse_args() + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = args.torch_deterministic + + device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") + + # env setup + envs = gym.vector.SyncVectorEnv( + [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)] + ) + assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" + + q_network = QNetwork(envs).to(device) + optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate) + target_network = QNetwork(envs).to(device) + target_network.load_state_dict(q_network.state_dict()) + + # QDAGGER LOGIC: + teacher_model_path = hf_hub_download(repo_id=args.teacher_policy_hf_repo, filename="dqn_atari.cleanrl_model") + teacher_model = TeacherModel(envs).to(device) + teacher_model.load_state_dict(torch.load(teacher_model_path, map_location=device)) + teacher_model.eval() + + # evaluate the teacher model + teacher_episodic_returns = 
evaluate( + teacher_model_path, + make_env, + args.env_id, + eval_episodes=args.teacher_eval_episodes, + run_name=f"{run_name}-teacher-eval", + Model=TeacherModel, + epsilon=0.05, + capture_video=False, + ) + writer.add_scalar("charts/teacher/avg_episodic_return", np.mean(teacher_episodic_returns), 0) + + # collect teacher data for args.teacher_steps + # we assume we don't have access to the teacher's replay buffer + # see Fig. A.19 in Agarwal et al. 2022 for more detail + teacher_rb = ReplayBuffer( + args.buffer_size, + envs.single_observation_space, + envs.single_action_space, + device, + optimize_memory_usage=True, + handle_timeout_termination=False, + ) + + obs, _ = envs.reset(seed=args.seed) + for global_step in track(range(args.teacher_steps), description="filling teacher's replay buffer"): + epsilon = linear_schedule(args.start_e, args.end_e, args.teacher_steps, global_step) + if random.random() < epsilon: + actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) + else: + q_values = teacher_model(torch.Tensor(obs).to(device)) + actions = torch.argmax(q_values, dim=1).cpu().numpy() + next_obs, rewards, terminated, truncated, infos = envs.step(actions) + real_next_obs = next_obs.copy() + for idx, d in enumerate(truncated): + if d: + real_next_obs[idx] = infos["final_observation"][idx] + teacher_rb.add(obs, real_next_obs, actions, rewards, terminated, infos) + obs = next_obs + + # offline training phase: train the student model using the qdagger loss + for global_step in track(range(args.offline_steps), description="offline student training"): + data = teacher_rb.sample(args.batch_size) + # perform a gradient-descent step + with torch.no_grad(): + target_max, _ = target_network(data.next_observations).max(dim=1) + td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten()) + teacher_q_values = teacher_model(data.observations) / args.temperature + + student_q_values = q_network(data.observations) + old_val = student_q_values.gather(1, data.actions).squeeze() + q_loss = F.mse_loss(td_target, old_val) + + student_q_values = student_q_values / args.temperature + distill_loss = torch.mean(kl_divergence_with_logits(teacher_q_values, student_q_values)) + + loss = q_loss + 1.0 * distill_loss + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # update the target network + if global_step % args.target_network_frequency == 0: + for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()): + target_network_param.data.copy_(args.tau * q_network_param.data + (1.0 - args.tau) * target_network_param.data) + + if global_step % 100 == 0: + writer.add_scalar("charts/offline/loss", loss, global_step) + writer.add_scalar("charts/offline/q_loss", q_loss, global_step) + writer.add_scalar("charts/offline/distill_loss", distill_loss, global_step) + + if global_step % 100000 == 0: + # evaluate the student model + model_path = f"runs/{run_name}/{args.exp_name}-offline-{global_step}.cleanrl_model" + torch.save(q_network.state_dict(), model_path) + print(f"model saved to {model_path}") + + episodic_returns = evaluate( + model_path, + make_env, + args.env_id, + eval_episodes=10, + run_name=f"{run_name}-eval", + Model=QNetwork, + device=device, + epsilon=0.05, + ) + print(episodic_returns) + writer.add_scalar("charts/offline/avg_episodic_return", np.mean(episodic_returns), global_step) + + rb = ReplayBuffer( + args.buffer_size, + envs.single_observation_space, + envs.single_action_space, + device, + 
optimize_memory_usage=True, + handle_timeout_termination=False, + ) + start_time = time.time() + + # TRY NOT TO MODIFY: start the game + envs = gym.vector.SyncVectorEnv( + [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)] + ) + obs, _ = envs.reset(seed=args.seed) + episodic_returns = deque(maxlen=10) + # online training phase + for global_step in track(range(args.total_timesteps), description="online student training"): + global_step += args.offline_steps + # ALGO LOGIC: put action logic here + epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step) + if random.random() < epsilon: + actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) + else: + q_values = q_network(torch.Tensor(obs).to(device)) + actions = torch.argmax(q_values, dim=1).cpu().numpy() + + # TRY NOT TO MODIFY: execute the game and log data. + next_obs, rewards, terminated, truncated, infos = envs.step(actions) + + # TRY NOT TO MODIFY: record rewards for plotting purposes + if "final_info" in infos: + for info in infos["final_info"]: + # Skip the envs that are not done + if "episode" not in info: + continue + print(f"global_step={global_step}, episodic_return={info['episode']['r']}") + writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) + writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) + writer.add_scalar("charts/epsilon", epsilon, global_step) + episodic_returns.append(info["episode"]["r"]) + break + + # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` + real_next_obs = next_obs.copy() + for idx, d in enumerate(truncated): + if d: + real_next_obs[idx] = infos["final_observation"][idx] + rb.add(obs, real_next_obs, actions, rewards, terminated, infos) + + # TRY NOT TO MODIFY: CRUCIAL step easy to overlook + obs = next_obs + + # ALGO LOGIC: training. 
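# QDagger online phase: the student minimizes the usual one-step TD loss plus a
# distillation term (KL from the teacher's temperature-scaled softmax policy to the
# student's). The distillation coefficient decays as the student's recent average
# return (last 10 episodes) approaches the teacher's evaluation return:
#     distill_coeff = max(1 - mean(student_returns) / mean(teacher_returns), 0)
# so once the student matches or exceeds the teacher, the update reduces to plain DQN.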
+ if global_step > args.learning_starts: + if global_step % args.train_frequency == 0: + data = rb.sample(args.batch_size) + # perform a gradient-descent step + if len(episodic_returns) < 10: + distill_coeff = 1.0 + else: + distill_coeff = max(1 - np.mean(episodic_returns) / np.mean(teacher_episodic_returns), 0) + with torch.no_grad(): + target_max, _ = target_network(data.next_observations).max(dim=1) + td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten()) + teacher_q_values = teacher_model(data.observations) / args.temperature + + student_q_values = q_network(data.observations) + old_val = student_q_values.gather(1, data.actions).squeeze() + q_loss = F.mse_loss(td_target, old_val) + + student_q_values = student_q_values / args.temperature + distill_loss = torch.mean(kl_divergence_with_logits(teacher_q_values, student_q_values)) + + loss = q_loss + distill_coeff * distill_loss + + if global_step % 100 == 0: + writer.add_scalar("losses/loss", loss, global_step) + writer.add_scalar("losses/td_loss", q_loss, global_step) + writer.add_scalar("losses/distill_loss", distill_loss, global_step) + writer.add_scalar("losses/q_values", old_val.mean().item(), global_step) + writer.add_scalar("charts/distill_coeff", distill_coeff, global_step) + print("SPS:", int(global_step / (time.time() - start_time))) + print(distill_coeff) + writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + + # optimize the model + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # update the target network + if global_step % args.target_network_frequency == 0: + for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()): + target_network_param.data.copy_( + args.tau * q_network_param.data + (1.0 - args.tau) * target_network_param.data + ) + + if args.save_model: + model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model" + torch.save(q_network.state_dict(), model_path) + print(f"model saved to {model_path}") + from cleanrl_utils.evals.dqn_eval import evaluate + + episodic_returns = evaluate( + model_path, + make_env, + args.env_id, + eval_episodes=10, + run_name=f"{run_name}-eval", + Model=QNetwork, + device=device, + epsilon=0.05, + ) + for idx, episodic_return in enumerate(episodic_returns): + writer.add_scalar("eval/episodic_return", episodic_return, idx) + + if args.upload_model: + from cleanrl_utils.huggingface import push_to_hub + + repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}" + repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name + push_to_hub(args, episodic_returns, repo_id, "Qdagger", f"runs/{run_name}", f"videos/{run_name}-eval") + + envs.close() + writer.close() diff --git a/cleanrl/qdagger_dqn_atari_jax_impalacnn.py b/cleanrl/qdagger_dqn_atari_jax_impalacnn.py new file mode 100644 index 000000000..ce55baf4c --- /dev/null +++ b/cleanrl/qdagger_dqn_atari_jax_impalacnn.py @@ -0,0 +1,487 @@ +# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/qdagger/#qdagger_dqn_atari_jax_impalacnnpy +import argparse +import os +import random +import time +from collections import deque +from distutils.util import strtobool +from typing import Sequence + +os.environ[ + "XLA_PYTHON_CLIENT_MEM_FRACTION" +] = "0.7" # see https://github.com/google/jax/discussions/6332#discussioncomment-1279991 + +import flax +import flax.linen as nn +import gymnasium as gym +import jax +import jax.numpy as jnp +import numpy as np +import optax +from 
flax.training.train_state import TrainState +from huggingface_hub import hf_hub_download +from rich.progress import track +from stable_baselines3.common.atari_wrappers import ( + ClipRewardEnv, + EpisodicLifeEnv, + FireResetEnv, + MaxAndSkipEnv, + NoopResetEnv, +) +from stable_baselines3.common.buffers import ReplayBuffer +from torch.utils.tensorboard import SummaryWriter + +from cleanrl.dqn_atari_jax import QNetwork as TeacherModel +from cleanrl_utils.evals.dqn_jax_eval import evaluate + + +def parse_args(): + # fmt: off + parser = argparse.ArgumentParser() + parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), + help="the name of this experiment") + parser.add_argument("--seed", type=int, default=1, + help="seed of the experiment") + parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="if toggled, this experiment will be tracked with Weights and Biases") + parser.add_argument("--wandb-project-name", type=str, default="cleanRL", + help="the wandb's project name") + parser.add_argument("--wandb-entity", type=str, default=None, + help="the entity (team) of wandb's project") + parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="whether to capture videos of the agent performances (check out `videos` folder)") + parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="whether to save model into the `runs/{run_name}` folder") + parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="whether to upload the saved model to huggingface") + parser.add_argument("--hf-entity", type=str, default="", + help="the user or org name of the model repository from the Hugging Face Hub") + + # Algorithm specific arguments + parser.add_argument("--env-id", type=str, default="BreakoutNoFrameskip-v4", + help="the id of the environment") + parser.add_argument("--total-timesteps", type=int, default=10000000, + help="total timesteps of the experiments") + parser.add_argument("--learning-rate", type=float, default=1e-4, + help="the learning rate of the optimizer") + parser.add_argument("--num-envs", type=int, default=1, + help="the number of parallel game environments") + parser.add_argument("--buffer-size", type=int, default=1000000, + help="the replay memory buffer size") + parser.add_argument("--gamma", type=float, default=0.99, + help="the discount factor gamma") + parser.add_argument("--tau", type=float, default=1., + help="the target network update rate") + parser.add_argument("--target-network-frequency", type=int, default=1000, + help="the timesteps it takes to update the target network") + parser.add_argument("--batch-size", type=int, default=32, + help="the batch size of sample from the reply memory") + parser.add_argument("--start-e", type=float, default=1, + help="the starting epsilon for exploration") + parser.add_argument("--end-e", type=float, default=0.01, + help="the ending epsilon for exploration") + parser.add_argument("--exploration-fraction", type=float, default=0.10, + help="the fraction of `total-timesteps` it takes from start-e to go end-e") + parser.add_argument("--learning-starts", type=int, default=80000, + help="timestep to start learning") + parser.add_argument("--train-frequency", type=int, default=4, + help="the frequency of training") + + # QDagger specific arguments + 
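# These flags control the QDagger phases: `--teacher-steps` environment transitions are
# first collected with the downloaded teacher policy into a replay buffer,
# `--offline-steps` gradient steps then distill the teacher into the student from that
# buffer, and `--total-timesteps` online environment steps follow the usual DQN loop
# with an added, decaying distillation term. `--temperature` scales the Q-values before
# the softmax used in the distillation KL.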
parser.add_argument("--teacher-policy-hf-repo", type=str, default=None, + help="the huggingface repo of the teacher policy") + parser.add_argument("--teacher-eval-episodes", type=int, default=10, + help="the number of episodes to run the teacher policy evaluate") + parser.add_argument("--teacher-steps", type=int, default=500000, + help="the number of steps to run the teacher policy to generate the replay buffer") + parser.add_argument("--offline-steps", type=int, default=500000, + help="the number of steps to run the student policy with the teacher's replay buffer") + parser.add_argument("--temperature", type=float, default=1.0, + help="the temperature parameter for qdagger") + args = parser.parse_args() + # fmt: on + assert args.num_envs == 1, "vectorized envs are not supported at the moment" + + if args.teacher_policy_hf_repo is None: + args.teacher_policy_hf_repo = f"cleanrl/{args.env_id}-dqn_atari_jax-seed1" + + return args + + +def make_env(env_id, seed, idx, capture_video, run_name): + def thunk(): + if capture_video and idx == 0: + env = gym.make(env_id, render_mode="rgb_array") + env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") + else: + env = gym.make(env_id) + env = gym.wrappers.RecordEpisodeStatistics(env) + env = NoopResetEnv(env, noop_max=30) + env = MaxAndSkipEnv(env, skip=4) + env = EpisodicLifeEnv(env) + if "FIRE" in env.unwrapped.get_action_meanings(): + env = FireResetEnv(env) + env = ClipRewardEnv(env) + env = gym.wrappers.ResizeObservation(env, (84, 84)) + env = gym.wrappers.GrayScaleObservation(env) + env = gym.wrappers.FrameStack(env, 4) + env.action_space.seed(seed) + + return env + + return thunk + + +# taken from https://github.com/AIcrowd/neurips2020-procgen-starter-kit/blob/142d09586d2272a17f44481a115c4bd817cf6a94/models/impala_cnn_torch.py +class ResidualBlock(nn.Module): + channels: int + + @nn.compact + def __call__(self, x): + inputs = x + x = nn.relu(x) + x = nn.Conv( + self.channels, + kernel_size=(3, 3), + )(x) + x = nn.relu(x) + x = nn.Conv( + self.channels, + kernel_size=(3, 3), + )(x) + return x + inputs + + +class ConvSequence(nn.Module): + channels: int + + @nn.compact + def __call__(self, x): + x = nn.Conv( + self.channels, + kernel_size=(3, 3), + )(x) + x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="SAME") + x = ResidualBlock(self.channels)(x) + x = ResidualBlock(self.channels)(x) + return x + + +# ALGO LOGIC: initialize agent here: +class QNetwork(nn.Module): + action_dim: int + channelss: Sequence[int] = (16, 32, 32) + + @nn.compact + def __call__(self, x): + x = jnp.transpose(x, (0, 2, 3, 1)) + x = x / (255.0) + for channels in self.channelss: + x = ConvSequence(channels)(x) + x = nn.relu(x) + x = x.reshape((x.shape[0], -1)) + x = nn.Dense(256)(x) + x = nn.relu(x) + x = nn.Dense(self.action_dim)(x) + return x + + +class TrainState(TrainState): + target_params: flax.core.FrozenDict + + +def linear_schedule(start_e: float, end_e: float, duration: int, t: int): + slope = (end_e - start_e) / duration + return max(slope * t + start_e, end_e) + + +if __name__ == "__main__": + import stable_baselines3 as sb3 + + if sb3.__version__ < "2.0": + raise ValueError( + """Ongoing migration: run the following command to install the new dependencies: + +poetry run pip install "stable_baselines3==2.0.0a1" "gymnasium[atari,accept-rom-license]==0.28.1" "ale-py==0.8.1" +""" + ) + args = parse_args() + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + if args.track: + import wandb + + wandb.init( + 
project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + np.random.seed(args.seed) + key = jax.random.PRNGKey(args.seed) + key, q_key = jax.random.split(key, 2) + + # env setup + envs = gym.vector.SyncVectorEnv( + [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)] + ) + assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" + + q_network = QNetwork(channelss=(16, 32, 32), action_dim=envs.single_action_space.n) + + q_state = TrainState.create( + apply_fn=q_network.apply, + params=q_network.init(q_key, envs.observation_space.sample()), + target_params=q_network.init(q_key, envs.observation_space.sample()), + tx=optax.adam(learning_rate=args.learning_rate), + ) + q_network.apply = jax.jit(q_network.apply) + + # QDAGGER LOGIC: + teacher_model_path = hf_hub_download(repo_id=args.teacher_policy_hf_repo, filename="dqn_atari_jax.cleanrl_model") + teacher_model = TeacherModel(action_dim=envs.single_action_space.n) + teacher_model_key = jax.random.PRNGKey(args.seed) + teacher_params = teacher_model.init(teacher_model_key, envs.observation_space.sample()) + with open(teacher_model_path, "rb") as f: + teacher_params = flax.serialization.from_bytes(teacher_params, f.read()) + teacher_model.apply = jax.jit(teacher_model.apply) + + # evaluate the teacher model + teacher_episodic_returns = evaluate( + teacher_model_path, + make_env, + args.env_id, + eval_episodes=args.teacher_eval_episodes, + run_name=f"{run_name}-teacher-eval", + Model=TeacherModel, + epsilon=0.05, + capture_video=False, + ) + writer.add_scalar("charts/teacher/avg_episodic_return", np.mean(teacher_episodic_returns), 0) + + # collect teacher data for args.teacher_steps + # we assume we don't have access to the teacher's replay buffer + # see Fig. A.19 in Agarwal et al. 
2022 for more detail + teacher_rb = ReplayBuffer( + args.buffer_size, + envs.single_observation_space, + envs.single_action_space, + "cpu", + optimize_memory_usage=True, + handle_timeout_termination=False, + ) + + obs, _ = envs.reset(seed=args.seed) + for global_step in track(range(args.teacher_steps), description="filling teacher's replay buffer"): + epsilon = linear_schedule(args.start_e, args.end_e, args.teacher_steps, global_step) + if random.random() < epsilon: + actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) + else: + q_values = teacher_model.apply(teacher_params, obs) + actions = q_values.argmax(axis=-1) + actions = jax.device_get(actions) + next_obs, rewards, terminated, truncated, infos = envs.step(actions) + real_next_obs = next_obs.copy() + for idx, d in enumerate(truncated): + if d: + real_next_obs[idx] = infos["final_observation"][idx] + teacher_rb.add(obs, real_next_obs, actions, rewards, terminated, infos) + obs = next_obs + + def kl_divergence_with_logits(target_logits, prediction_logits): + """Implementation of on-policy distillation loss.""" + out = -nn.softmax(target_logits) * (nn.log_softmax(prediction_logits) - nn.log_softmax(target_logits)) + return jnp.sum(out) + + @jax.jit + def update(q_state, observations, actions, next_observations, rewards, dones, distill_coeff): + q_next_target = q_network.apply(q_state.target_params, next_observations) # (batch_size, num_actions) + q_next_target = jnp.max(q_next_target, axis=-1) # (batch_size,) + td_target = rewards + (1 - dones) * args.gamma * q_next_target + teacher_q_values = teacher_model.apply(teacher_params, observations) + + def loss(params, td_target, teacher_q_values, distill_coeff): + student_q_values = q_network.apply(params, observations) # (batch_size, num_actions) + q_pred = student_q_values[np.arange(student_q_values.shape[0]), actions.squeeze()] # (batch_size,) + q_loss = ((q_pred - td_target) ** 2).mean() + teacher_q_values = teacher_q_values / args.temperature + student_q_values = student_q_values / args.temperature + distill_loss = jnp.mean(jax.vmap(kl_divergence_with_logits)(teacher_q_values, student_q_values)) + overall_loss = q_loss + distill_coeff * distill_loss + return overall_loss, (q_loss, q_pred, distill_loss) + + (loss_value, (q_loss, q_pred, distill_loss)), grads = jax.value_and_grad(loss, has_aux=True)( + q_state.params, td_target, teacher_q_values, distill_coeff + ) + q_state = q_state.apply_gradients(grads=grads) + return loss_value, q_loss, q_pred, distill_loss, q_state + + # offline training phase: train the student model using the qdagger loss + for global_step in track(range(args.offline_steps), description="offline student training"): + data = teacher_rb.sample(args.batch_size) + # perform a gradient-descent step + loss, q_loss, old_val, distill_loss, q_state = update( + q_state, + data.observations.numpy(), + data.actions.numpy(), + data.next_observations.numpy(), + data.rewards.flatten().numpy(), + data.dones.flatten().numpy(), + 1.0, + ) + + # update the target network + if global_step % args.target_network_frequency == 0: + q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, args.tau)) + + if global_step % 100 == 0: + writer.add_scalar("charts/offline/loss", jax.device_get(loss), global_step) + writer.add_scalar("charts/offline/q_loss", jax.device_get(q_loss), global_step) + writer.add_scalar("charts/offline/distill_loss", jax.device_get(distill_loss), global_step) + + if global_step % 100000 == 
0: + # evaluate the student model + model_path = f"runs/{run_name}/{args.exp_name}-offline-{global_step}.cleanrl_model" + with open(model_path, "wb") as f: + f.write(flax.serialization.to_bytes(q_state.params)) + print(f"model saved to {model_path}") + + episodic_returns = evaluate( + model_path, + make_env, + args.env_id, + eval_episodes=10, + run_name=f"{run_name}-eval", + Model=QNetwork, + epsilon=0.05, + ) + print(episodic_returns) + writer.add_scalar("charts/offline/avg_episodic_return", np.mean(episodic_returns), global_step) + + rb = ReplayBuffer( + args.buffer_size, + envs.single_observation_space, + envs.single_action_space, + "cpu", + optimize_memory_usage=True, + handle_timeout_termination=False, + ) + start_time = time.time() + + # TRY NOT TO MODIFY: start the game + envs = gym.vector.SyncVectorEnv( + [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)] + ) + obs, _ = envs.reset(seed=args.seed) + episodic_returns = deque(maxlen=10) + # online training phase + for global_step in track(range(args.total_timesteps), description="online student training"): + global_step += args.offline_steps + # ALGO LOGIC: put action logic here + epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step) + if random.random() < epsilon: + actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) + else: + q_values = q_network.apply(q_state.params, obs) + actions = q_values.argmax(axis=-1) + actions = jax.device_get(actions) + + # TRY NOT TO MODIFY: execute the game and log data. + next_obs, rewards, terminated, truncated, infos = envs.step(actions) + + # TRY NOT TO MODIFY: record rewards for plotting purposes + if "final_info" in infos: + for info in infos["final_info"]: + # Skip the envs that are not done + if "episode" not in info: + continue + print(f"global_step={global_step}, episodic_return={info['episode']['r']}") + writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) + writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) + writer.add_scalar("charts/epsilon", epsilon, global_step) + episodic_returns.append(info["episode"]["r"]) + break + + # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` + real_next_obs = next_obs.copy() + for idx, d in enumerate(truncated): + if d: + real_next_obs[idx] = infos["final_observation"][idx] + rb.add(obs, real_next_obs, actions, rewards, terminated, infos) + + # TRY NOT TO MODIFY: CRUCIAL step easy to overlook + obs = next_obs + + # ALGO LOGIC: training. 
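# QDagger online phase (JAX): the jitted `update` computes
#     loss = q_loss + distill_coeff * distill_loss,
# where distill_loss is the KL from the teacher's temperature-scaled softmax policy to
# the student's. distill_coeff is recomputed in Python each training step below and
# passed into `update`, decaying as the student's recent average return approaches the
# teacher's evaluation return:
#     distill_coeff = max(1 - mean(student_returns) / mean(teacher_returns), 0)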
+ if global_step > args.learning_starts: + if global_step % args.train_frequency == 0: + data = rb.sample(args.batch_size) + # perform a gradient-descent step + if len(episodic_returns) < 10: + distill_coeff = 1.0 + else: + distill_coeff = max(1 - np.mean(episodic_returns) / np.mean(teacher_episodic_returns), 0) + loss, q_loss, old_val, distill_loss, q_state = update( + q_state, + data.observations.numpy(), + data.actions.numpy(), + data.next_observations.numpy(), + data.rewards.flatten().numpy(), + data.dones.flatten().numpy(), + distill_coeff, + ) + + if global_step % 100 == 0: + writer.add_scalar("losses/loss", jax.device_get(loss), global_step) + writer.add_scalar("losses/td_loss", jax.device_get(q_loss), global_step) + writer.add_scalar("losses/distill_loss", jax.device_get(distill_loss), global_step) + writer.add_scalar("losses/q_values", jax.device_get(old_val).mean(), global_step) + writer.add_scalar("charts/distill_coeff", distill_coeff, global_step) + print("SPS:", int(global_step / (time.time() - start_time))) + print(distill_coeff) + writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + + # update the target network + if global_step % args.target_network_frequency == 0: + q_state = q_state.replace( + target_params=optax.incremental_update(q_state.params, q_state.target_params, args.tau) + ) + + if args.save_model: + model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model" + with open(model_path, "wb") as f: + f.write(flax.serialization.to_bytes(q_state.params)) + print(f"model saved to {model_path}") + from cleanrl_utils.evals.dqn_jax_eval import evaluate + + episodic_returns = evaluate( + model_path, + make_env, + args.env_id, + eval_episodes=10, + run_name=f"{run_name}-eval", + Model=QNetwork, + epsilon=0.05, + ) + for idx, episodic_return in enumerate(episodic_returns): + writer.add_scalar("eval/episodic_return", episodic_return, idx) + + if args.upload_model: + from cleanrl_utils.huggingface import push_to_hub + + repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}" + repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name + push_to_hub(args, episodic_returns, repo_id, "Qdagger", f"runs/{run_name}", f"videos/{run_name}-eval") + + envs.close() + writer.close() diff --git a/docs/rl-algorithms/overview.md b/docs/rl-algorithms/overview.md index 22f7288e7..6c2d94ea8 100644 --- a/docs/rl-algorithms/overview.md +++ b/docs/rl-algorithms/overview.md @@ -29,3 +29,5 @@ | | :material-github: [`td3_continuous_action_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/td3_continuous_action_jax.py), :material-file-document: [docs](/rl-algorithms/td3/#td3_continuous_action_jaxpy) | | ✅ [Phasic Policy Gradient (PPG)](https://arxiv.org/abs/2009.04416) | :material-github: [`ppg_procgen.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppg_procgen.py), :material-file-document: [docs](/rl-algorithms/ppg/#ppg_procgenpy) | | ✅ [Random Network Distillation (RND)](https://arxiv.org/abs/1810.12894) | :material-github: [`ppo_rnd_envpool.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_rnd_envpool.py), :material-file-document: [docs](/rl-algorithms/ppo-rnd/#ppo_rnd_envpoolpy) | +| ✅ [Qdagger](https://arxiv.org/abs/2206.01626) | :material-github: [`qdagger_dqn_atari_impalacnn.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/qdagger_dqn_atari_impalacnn.py), :material-file-document: [docs](/rl-algorithms/qdagger/#qdagger_dqn_atari_impalacnnpy) | +| | :material-github: 
[`qdagger_dqn_atari_jax_impalacnn.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/qdagger_dqn_atari_jax_impalacnn.py), :material-file-document: [docs](/rl-algorithms/qdagger/#qdagger_dqn_atari_jax_impalacnnpy) |
\ No newline at end of file
diff --git a/docs/rl-algorithms/qdagger.md b/docs/rl-algorithms/qdagger.md
new file mode 100644
index 000000000..678c73818
--- /dev/null
+++ b/docs/rl-algorithms/qdagger.md
@@ -0,0 +1,201 @@
# QDagger

## Overview

QDagger is an extension of the DQN algorithm that reuses previously computed results, such as a teacher policy and a teacher replay buffer, to help train a student policy. This avoids learning from scratch, improving sample efficiency and reducing the computational effort of training a new policy.

Original paper:

* [Reincarnating Reinforcement Learning: Reusing Prior Computation to Accelerate Progress](https://arxiv.org/abs/2206.01626)

Reference resources:

* :material-github: [google-research/reincarnating_rl](https://github.com/google-research/reincarnating_rl)
* [Original Paper's Website](https://agarwl.github.io/reincarnating_rl/)

## Implemented Variants

| Variants Implemented | Description |
| ----------- | ----------- |
| :material-github: [`qdagger_dqn_atari_impalacnn.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/qdagger_dqn_atari_impalacnn.py), :material-file-document: [docs](/rl-algorithms/qdagger/#qdagger_dqn_atari_impalacnnpy) | For playing Atari games. It uses the Impala-CNN architecture and common Atari pre-processing techniques. |
| :material-github: [`qdagger_dqn_atari_jax_impalacnn.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/qdagger_dqn_atari_jax_impalacnn.py), :material-file-document: [docs](/rl-algorithms/qdagger/#qdagger_dqn_atari_jax_impalacnnpy) | For playing Atari games. It uses the Impala-CNN architecture and common Atari pre-processing techniques. |


Below are our single-file implementations of QDagger:


## `qdagger_dqn_atari_impalacnn.py`

The [qdagger_dqn_atari_impalacnn.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/qdagger_dqn_atari_impalacnn.py) has the following features:

* For playing Atari games. It uses the Impala-CNN architecture (from the IMPALA paper) and common Atari pre-processing techniques.
* Its teacher policy uses CleanRL's `dqn_atari` policy from the [huggingface/cleanrl](https://huggingface.co/cleanrl) repository.
* Works with Atari's pixel `Box` observation space of shape `(210, 160, 3)`
* Works with the `Discrete` action space

### Usage

```bash
poetry install -E atari
python cleanrl/qdagger_dqn_atari_impalacnn.py --env-id BreakoutNoFrameskip-v4
python cleanrl/qdagger_dqn_atari_impalacnn.py --env-id PongNoFrameskip-v4
```

=== "poetry"

    ```bash
    poetry install -E atari
    poetry run python cleanrl/qdagger_dqn_atari_impalacnn.py --env-id BreakoutNoFrameskip-v4
    poetry run python cleanrl/qdagger_dqn_atari_impalacnn.py --env-id PongNoFrameskip-v4
    ```

=== "pip"

    ```bash
    pip install -r requirements/requirements-atari.txt
    python cleanrl/qdagger_dqn_atari_impalacnn.py --env-id BreakoutNoFrameskip-v4
    python cleanrl/qdagger_dqn_atari_impalacnn.py --env-id PongNoFrameskip-v4
    ```


### Explanation of the logged metrics

Running `python cleanrl/qdagger_dqn_atari_impalacnn.py` will automatically record various metrics such as value or distillation losses in TensorBoard.
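One way to inspect these metrics locally is to point TensorBoard at the run directory; this assumes experiments were launched from the repository root, where the script writes event files to `runs/{run_name}`:

```bash
tensorboard --logdir runs
```

When `--track` is set, the same metrics are also synced to Weights and Biases via `wandb.init(..., sync_tensorboard=True)`.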
Below is the documentation for these metrics:

* `charts/episodic_return`: episodic return of the game
* `charts/SPS`: number of steps per second
* `losses/td_loss`: the mean squared error (MSE) between the Q values at timestep $t$ and the Bellman update target estimated using the reward $r_t$ and the Q values at timestep $t+1$, thus minimizing the *one-step* temporal difference. Formally, it can be expressed by the equation below,
$$
    J(\theta^{Q}) = \mathbb{E}_{(s,a,r,s') \sim \mathcal{D}} \big[ (Q(s, a) - y)^2 \big],
$$
where $y = r + \gamma \, \max_{a'} Q^{'}(s', a')$ is the Bellman update target and $\mathcal{D}$ is the replay buffer.
* `losses/q_values`: implemented as `q_network(data.observations).gather(1, data.actions).squeeze()`, it is the average Q value of the sampled state-action pairs in the replay buffer; useful when gauging if under or over estimation happens.
* `losses/distill_loss`: the distillation loss, i.e., the KL divergence from the teacher policy $\pi_T$ to the student policy $\pi$, where both policies are softmax distributions over the temperature-scaled Q values. It is weighted by $\lambda_t$ in the overall loss $L = L_{\text{TD}} + \lambda_t \, L_{\text{distill}}$. Formally, it can be expressed by the equation below.
$$
    L_{\text{distill}} = \mathbb{E}_{(s,a,r,s') \sim \mathcal{D}} \left[ \sum_a \pi_T(a|s) \big( \log \pi_T(a|s) - \log \pi(a|s) \big) \right]
$$
* `charts/distill_coeff`: the coefficient $\lambda_t$ for the distillation loss, which is a function of the ratio between the student's and the teacher's average episodic returns. Formally, it can be expressed by the equation below,
$$
    \lambda_t = \max\left(1 - \frac{\bar{R}_{\text{student}}}{\bar{R}_{\text{teacher}}}, \; 0\right),
$$
where $\bar{R}_{\text{student}}$ is the student's average return over its last 10 episodes and $\bar{R}_{\text{teacher}}$ is the teacher's average evaluation return; $\lambda_t = 1$ during the offline phase and until 10 student episodes have been recorded.

Learning curve comparison with `dqn_atari`:

Tracked experiments and game play videos:


## `qdagger_dqn_atari_jax_impalacnn.py`


The [qdagger_dqn_atari_jax_impalacnn.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/qdagger_dqn_atari_jax_impalacnn.py) has the following features:

* Uses [Jax](https://github.com/google/jax), [Flax](https://github.com/google/flax), and [Optax](https://github.com/deepmind/optax) instead of `torch`. [qdagger_dqn_atari_jax_impalacnn.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/qdagger_dqn_atari_jax_impalacnn.py) is roughly 25%-50% faster than [qdagger_dqn_atari_impalacnn.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/qdagger_dqn_atari_impalacnn.py).
* For playing Atari games. It uses the Impala-CNN architecture and common Atari pre-processing techniques.
* Its teacher policy uses CleanRL's `dqn_atari_jax` policy from the [huggingface/cleanrl](https://huggingface.co/cleanrl) repository.
* Works with Atari's pixel `Box` observation space of shape `(210, 160, 3)`
* Works with the `Discrete` action space

### Usage


=== "poetry"

    ```bash
    poetry install -E "atari jax"
    poetry run pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    poetry run python cleanrl/qdagger_dqn_atari_jax_impalacnn.py --env-id BreakoutNoFrameskip-v4
    poetry run python cleanrl/qdagger_dqn_atari_jax_impalacnn.py --env-id PongNoFrameskip-v4
    ```

=== "pip"

    ```bash
    pip install -r requirements/requirements-atari.txt
    pip install -r requirements/requirements-jax.txt
    pip install --upgrade "jax[cuda]==0.3.17" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    python cleanrl/qdagger_dqn_atari_jax_impalacnn.py --env-id BreakoutNoFrameskip-v4
    python cleanrl/qdagger_dqn_atari_jax_impalacnn.py --env-id PongNoFrameskip-v4
    ```


???+ warning

    Note that JAX does not work on Windows :fontawesome-brands-windows:.
    The official [docs](https://github.com/google/jax#installation) recommend using Windows Subsystem for Linux (WSL) to install JAX.

### Explanation of the logged metrics

See [related docs](/rl-algorithms/qdagger/#explanation-of-the-logged-metrics) for `qdagger_dqn_atari_impalacnn.py`.

### Implementation details

See [related docs](/rl-algorithms/qdagger/#implementation-details) for `qdagger_dqn_atari_impalacnn.py`.

### Experiment results

Below are the average episodic returns for `qdagger_dqn_atari_jax_impalacnn.py`.

| Environment | `qdagger_dqn_atari_jax_impalacnn.py` 10M steps (40M frames) | (Agarwal et al., 2022)[^1] 10M frames |
| ----------- | ----------- | ----------- |
| BreakoutNoFrameskip-v4 | 335.08 ± 19.12 | 275.15 ± 20.65 |
| PongNoFrameskip-v4 | 18.75 ± 0.19 | - |
| BeamRiderNoFrameskip-v4 | 8024.75 ± 579.02 | 6514.25 ± 411.10 |

Learning curves:
+ +Learning curve comparison with `dqn_atari_jax`: + + + + +[^1]:Agarwal, Rishabh, Max Schwarzer, Pablo Samuel Castro, Aaron Courville, and Marc G. Bellemare. “Reincarnating Reinforcement Learning: Reusing Prior Computation to Accelerate Progress.” arXiv, October 4, 2022. http://arxiv.org/abs/2206.01626. diff --git a/docs/rl-algorithms/qdagger/BeamRiderNoFrameskip-v4.png b/docs/rl-algorithms/qdagger/BeamRiderNoFrameskip-v4.png new file mode 100644 index 000000000..c1af79311 Binary files /dev/null and b/docs/rl-algorithms/qdagger/BeamRiderNoFrameskip-v4.png differ diff --git a/docs/rl-algorithms/qdagger/BreakoutNoFrameskip-v4.png b/docs/rl-algorithms/qdagger/BreakoutNoFrameskip-v4.png new file mode 100644 index 000000000..ed2e1e649 Binary files /dev/null and b/docs/rl-algorithms/qdagger/BreakoutNoFrameskip-v4.png differ diff --git a/docs/rl-algorithms/qdagger/PongNoFrameskip-v4.png b/docs/rl-algorithms/qdagger/PongNoFrameskip-v4.png new file mode 100644 index 000000000..a2b14778a Binary files /dev/null and b/docs/rl-algorithms/qdagger/PongNoFrameskip-v4.png differ diff --git a/docs/rl-algorithms/qdagger/compare.png b/docs/rl-algorithms/qdagger/compare.png new file mode 100644 index 000000000..c136b878f Binary files /dev/null and b/docs/rl-algorithms/qdagger/compare.png differ diff --git a/docs/rl-algorithms/qdagger/jax/BeamRiderNoFrameskip-v4.png b/docs/rl-algorithms/qdagger/jax/BeamRiderNoFrameskip-v4.png new file mode 100644 index 000000000..ede257ba5 Binary files /dev/null and b/docs/rl-algorithms/qdagger/jax/BeamRiderNoFrameskip-v4.png differ diff --git a/docs/rl-algorithms/qdagger/jax/BreakoutNoFrameskip-v4.png b/docs/rl-algorithms/qdagger/jax/BreakoutNoFrameskip-v4.png new file mode 100644 index 000000000..5c385e26c Binary files /dev/null and b/docs/rl-algorithms/qdagger/jax/BreakoutNoFrameskip-v4.png differ diff --git a/docs/rl-algorithms/qdagger/jax/PongNoFrameskip-v4.png b/docs/rl-algorithms/qdagger/jax/PongNoFrameskip-v4.png new file mode 100644 index 000000000..aa8811f80 Binary files /dev/null and b/docs/rl-algorithms/qdagger/jax/PongNoFrameskip-v4.png differ diff --git a/docs/rl-algorithms/qdagger/jax/compare.png b/docs/rl-algorithms/qdagger/jax/compare.png new file mode 100644 index 000000000..b37e4bcec Binary files /dev/null and b/docs/rl-algorithms/qdagger/jax/compare.png differ diff --git a/mkdocs.yml b/mkdocs.yml index 39aa19bb4..7af22abe2 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -48,6 +48,7 @@ nav: - rl-algorithms/ppg.md - rl-algorithms/ppo-rnd.md - rl-algorithms/rpo.md + - rl-algorithms/qdagger.md - Advanced: - advanced/hyperparameter-tuning.md - advanced/resume-training.md diff --git a/poetry.lock b/poetry.lock index 34191031f..d85b6b168 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand. +# This file is automatically @generated by Poetry and should not be changed by hand. [[package]] name = "absl-py" @@ -574,7 +574,7 @@ name = "commonmark" version = "0.9.1" description = "Python parser for the CommonMark Markdown spec" category = "main" -optional = true +optional = false python-versions = "*" files = [ {file = "commonmark-0.9.1-py2.py3-none-any.whl", hash = "sha256:da2f38c92590f83de410ba1a3cbceafbc74fee9def35f9251ba9a971d6d66fd9"}, @@ -3463,7 +3463,7 @@ name = "pygments" version = "2.15.1" description = "Pygments is a syntax highlighting package written in Python." 
category = "main" -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "Pygments-2.15.1-py3-none-any.whl", hash = "sha256:db2db3deb4b4179f399a09054b023b6a586b76499d36965813c71aa8ed7b5fd1"}, @@ -3920,7 +3920,7 @@ name = "rich" version = "11.2.0" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" category = "main" -optional = true +optional = false python-versions = ">=3.6.2,<4.0.0" files = [ {file = "rich-11.2.0-py3-none-any.whl", hash = "sha256:d5f49ad91fb343efcae45a2b2df04a9755e863e50413623ab8c9e74f05aee52b"}, @@ -4578,6 +4578,7 @@ files = [ {file = "tinyscaler-1.2.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4af0a9502e9ef118c84de80b09544407c8dbbe815af215b1abb8eb170271ab71"}, {file = "tinyscaler-1.2.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f0bde14fb15027d73f4cc5ac837e849feb1cbedbfc0a0c0928f11756f08f6626"}, {file = "tinyscaler-1.2.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46c75799068330ff7c28fd01f10409d4f12c22f1adbe732f1699228449a4d712"}, + {file = "tinyscaler-1.2.5.tar.gz", hash = "sha256:deb47df1a53a55b53f0ae15b89b4814af184d149a8149385e54e11afc57364a5"}, ] [package.dependencies] @@ -4984,30 +4985,32 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] [extras] -atari = ["AutoROM", "ale-py", "opencv-python"] +atari = ["ale-py", "AutoROM", "opencv-python"] c51 = [] -c51-atari = ["AutoROM", "ale-py", "opencv-python"] -c51-atari-jax = ["AutoROM", "ale-py", "flax", "jax", "jaxlib", "opencv-python"] -c51-jax = ["flax", "jax", "jaxlib"] -cloud = ["awscli", "boto3"] -dm-control = ["mujoco", "shimmy"] -docs = ["markdown-include", "mkdocs-material", "openrlbenchmark"] +c51-atari = ["ale-py", "AutoROM", "opencv-python"] +c51-atari-jax = ["ale-py", "AutoROM", "opencv-python", "jax", "jaxlib", "flax"] +c51-jax = ["jax", "jaxlib", "flax"] +cloud = ["boto3", "awscli"] +dm-control = ["shimmy", "mujoco"] +docs = ["mkdocs-material", "markdown-include", "openrlbenchmark"] dqn = [] -dqn-atari = ["AutoROM", "ale-py", "opencv-python"] -dqn-atari-jax = ["AutoROM", "ale-py", "flax", "jax", "jaxlib", "opencv-python"] -dqn-jax = ["flax", "jax", "jaxlib"] +dqn-atari = ["ale-py", "AutoROM", "opencv-python"] +dqn-atari-jax = ["ale-py", "AutoROM", "opencv-python", "jax", "jaxlib", "flax"] +dqn-jax = ["jax", "jaxlib", "flax"] envpool = ["envpool"] -jax = ["flax", "jax", "jaxlib"] -mujoco = ["imageio", "mujoco"] +jax = ["jax", "jaxlib", "flax"] +mujoco = ["mujoco", "imageio"] mujoco-py = ["free-mujoco-py"] -optuna = ["optuna", "optuna-dashboard", "rich"] +optuna = ["optuna", "optuna-dashboard"] pettingzoo = ["PettingZoo", "SuperSuit", "multi-agent-ale-py"] plot = [] -ppo-atari-envpool-xla-jax-scan = ["AutoROM", "ale-py", "envpool", "flax", "jax", "jaxlib", "opencv-python"] +ppo-atari-envpool-xla-jax-scan = ["ale-py", "AutoROM", "opencv-python", "jax", "jaxlib", "flax", "envpool"] procgen = ["procgen"] pytest = ["pytest"] +qdagger-dqn-atari-impalacnn = ["ale-py", "AutoROM", "opencv-python"] +qdagger-dqn-atari-jax-impalacnn = ["ale-py", "AutoROM", "opencv-python", "jax", "jaxlib", "flax"] [metadata] lock-version = "2.0" python-versions = ">=3.7.1,<3.11" 
-content-hash = "35fa1060eb934e0873608e938cfa5370d24c003c01efe7afe2ee569c492396d1" +content-hash = "2b738b0e6e88ac605120cd147cfbc026499171d38afd45d1ff15c8c0152ec5cb" diff --git a/pyproject.toml b/pyproject.toml index 3bec254e2..bfb1427ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ gymnasium = ">=0.28.1" moviepy = "^1.0.3" pygame = "2.1.0" huggingface-hub = "^0.11.1" +rich = "<12.0" ale-py = {version = "0.7.4", optional = true} AutoROM = {extras = ["accept-rom-license"], version = "^0.4.2", optional = true} @@ -39,7 +40,6 @@ jaxlib = {version = "^0.3.15", optional = true} flax = {version = "^0.6.0", optional = true} optuna = {version = "^3.0.1", optional = true} optuna-dashboard = {version = "^0.7.2", optional = true} -rich = {version = "<12.0", optional = true} envpool = {version = "^0.6.4", optional = true} PettingZoo = {version = "1.18.1", optional = true} SuperSuit = {version = "3.4.0", optional = true} @@ -73,7 +73,7 @@ mujoco_py = ["free-mujoco-py"] jax = ["jax", "jaxlib", "flax"] docs = ["mkdocs-material", "markdown-include", "openrlbenchmark"] envpool = ["envpool"] -optuna = ["optuna", "optuna-dashboard", "rich"] +optuna = ["optuna", "optuna-dashboard"] pettingzoo = ["PettingZoo", "SuperSuit", "multi-agent-ale-py"] cloud = ["boto3", "awscli"] dm_control = ["shimmy", "mujoco"] @@ -97,4 +97,11 @@ ppo_atari_envpool_xla_jax_scan = [ "ale-py", "AutoROM", "opencv-python", # atari "jax", "jaxlib", "flax", # jax "envpool", # envpool +] +qdagger_dqn_atari_impalacnn = [ + "ale-py", "AutoROM", "opencv-python" +] +qdagger_dqn_atari_jax_impalacnn = [ + "ale-py", "AutoROM", "opencv-python", # atari + "jax", "jaxlib", "flax", # jax ] \ No newline at end of file diff --git a/requirements/requirements-atari.txt b/requirements/requirements-atari.txt index 7d17166c2..e2c7fb94b 100644 --- a/requirements/requirements-atari.txt +++ b/requirements/requirements-atari.txt @@ -8,7 +8,8 @@ certifi==2023.5.7 ; python_full_version >= "3.7.1" and python_version < "3.11" charset-normalizer==3.1.0 ; python_full_version >= "3.7.1" and python_version < "3.11" click==8.1.3 ; python_full_version >= "3.7.1" and python_version < "3.11" cloudpickle==2.2.1 ; python_full_version >= "3.7.1" and python_version < "3.11" -colorama==0.4.4 ; python_full_version >= "3.7.1" and python_version < "3.11" and platform_system == "Windows" +colorama==0.4.4 ; python_full_version >= "3.7.1" and python_version < "3.11" +commonmark==0.9.1 ; python_full_version >= "3.7.1" and python_version < "3.11" cycler==0.11.0 ; python_full_version >= "3.7.1" and python_version < "3.11" decorator==4.4.2 ; python_full_version >= "3.7.1" and python_version < "3.11" docker-pycreds==0.4.0 ; python_full_version >= "3.7.1" and python_version < "3.11" @@ -48,12 +49,14 @@ psutil==5.9.5 ; python_full_version >= "3.7.1" and python_version < "3.11" pyasn1-modules==0.3.0 ; python_full_version >= "3.7.1" and python_version < "3.11" pyasn1==0.5.0 ; python_full_version >= "3.7.1" and python_version < "3.11" pygame==2.1.0 ; python_full_version >= "3.7.1" and python_version < "3.11" +pygments==2.15.1 ; python_full_version >= "3.7.1" and python_version < "3.11" pyparsing==3.0.9 ; python_full_version >= "3.7.1" and python_version < "3.11" python-dateutil==2.8.2 ; python_full_version >= "3.7.1" and python_version < "3.11" pytz==2023.3 ; python_full_version >= "3.7.1" and python_version < "3.11" pyyaml==5.4.1 ; python_full_version >= "3.7.1" and python_version < "3.11" requests-oauthlib==1.3.1 ; python_full_version >= "3.7.1" and python_version < 
"3.11" requests==2.30.0 ; python_full_version >= "3.7.1" and python_version < "3.11" +rich==11.2.0 ; python_full_version >= "3.7.1" and python_version < "3.11" rsa==4.7.2 ; python_full_version >= "3.7.1" and python_version < "3.11" sentry-sdk==1.22.2 ; python_full_version >= "3.7.1" and python_version < "3.11" setproctitle==1.3.2 ; python_full_version >= "3.7.1" and python_version < "3.11" diff --git a/requirements/requirements-cloud.txt b/requirements/requirements-cloud.txt index 99ca1e86f..cb33f79e7 100644 --- a/requirements/requirements-cloud.txt +++ b/requirements/requirements-cloud.txt @@ -9,6 +9,7 @@ charset-normalizer==3.1.0 ; python_full_version >= "3.7.1" and python_version < click==8.1.3 ; python_full_version >= "3.7.1" and python_version < "3.11" cloudpickle==2.2.1 ; python_full_version >= "3.7.1" and python_version < "3.11" colorama==0.4.4 ; python_full_version >= "3.7.1" and python_version < "3.11" +commonmark==0.9.1 ; python_full_version >= "3.7.1" and python_version < "3.11" cycler==0.11.0 ; python_full_version >= "3.7.1" and python_version < "3.11" decorator==4.4.2 ; python_full_version >= "3.7.1" and python_version < "3.11" docker-pycreds==0.4.0 ; python_full_version >= "3.7.1" and python_version < "3.11" @@ -48,12 +49,14 @@ psutil==5.9.5 ; python_full_version >= "3.7.1" and python_version < "3.11" pyasn1-modules==0.3.0 ; python_full_version >= "3.7.1" and python_version < "3.11" pyasn1==0.5.0 ; python_full_version >= "3.7.1" and python_version < "3.11" pygame==2.1.0 ; python_full_version >= "3.7.1" and python_version < "3.11" +pygments==2.15.1 ; python_full_version >= "3.7.1" and python_version < "3.11" pyparsing==3.0.9 ; python_full_version >= "3.7.1" and python_version < "3.11" python-dateutil==2.8.2 ; python_full_version >= "3.7.1" and python_version < "3.11" pytz==2023.3 ; python_full_version >= "3.7.1" and python_version < "3.11" pyyaml==5.4.1 ; python_full_version >= "3.7.1" and python_version < "3.11" requests-oauthlib==1.3.1 ; python_full_version >= "3.7.1" and python_version < "3.11" requests==2.30.0 ; python_full_version >= "3.7.1" and python_version < "3.11" +rich==11.2.0 ; python_full_version >= "3.7.1" and python_version < "3.11" rsa==4.7.2 ; python_full_version >= "3.7.1" and python_version < "3.11" s3transfer==0.6.1 ; python_full_version >= "3.7.1" and python_version < "3.11" sentry-sdk==1.22.2 ; python_full_version >= "3.7.1" and python_version < "3.11" diff --git a/requirements/requirements-dm_control.txt b/requirements/requirements-dm_control.txt index 456fd56bc..de47adb1c 100644 --- a/requirements/requirements-dm_control.txt +++ b/requirements/requirements-dm_control.txt @@ -5,7 +5,8 @@ certifi==2023.5.7 ; python_full_version >= "3.7.1" and python_version < "3.11" charset-normalizer==3.1.0 ; python_full_version >= "3.7.1" and python_version < "3.11" click==8.1.3 ; python_full_version >= "3.7.1" and python_version < "3.11" cloudpickle==2.2.1 ; python_full_version >= "3.7.1" and python_version < "3.11" -colorama==0.4.4 ; python_full_version >= "3.7.1" and python_version < "3.11" and platform_system == "Windows" +colorama==0.4.4 ; python_full_version >= "3.7.1" and python_version < "3.11" +commonmark==0.9.1 ; python_full_version >= "3.7.1" and python_version < "3.11" cycler==0.11.0 ; python_full_version >= "3.7.1" and python_version < "3.11" decorator==4.4.2 ; python_full_version >= "3.7.1" and python_version < "3.11" dm-control==1.0.11 ; python_full_version >= "3.7.1" and python_version < "3.11" @@ -51,6 +52,7 @@ psutil==5.9.5 ; python_full_version >= 
"3.7.1" and python_version < "3.11" pyasn1-modules==0.3.0 ; python_full_version >= "3.7.1" and python_version < "3.11" pyasn1==0.5.0 ; python_full_version >= "3.7.1" and python_version < "3.11" pygame==2.1.0 ; python_full_version >= "3.7.1" and python_version < "3.11" +pygments==2.15.1 ; python_full_version >= "3.7.1" and python_version < "3.11" pyopengl==3.1.6 ; python_full_version >= "3.7.1" and python_version < "3.11" pyparsing==3.0.9 ; python_full_version >= "3.7.1" and python_version < "3.11" python-dateutil==2.8.2 ; python_full_version >= "3.7.1" and python_version < "3.11" @@ -58,6 +60,7 @@ pytz==2023.3 ; python_full_version >= "3.7.1" and python_version < "3.11" pyyaml==5.4.1 ; python_full_version >= "3.7.1" and python_version < "3.11" requests-oauthlib==1.3.1 ; python_full_version >= "3.7.1" and python_version < "3.11" requests==2.30.0 ; python_full_version >= "3.7.1" and python_version < "3.11" +rich==11.2.0 ; python_full_version >= "3.7.1" and python_version < "3.11" rsa==4.7.2 ; python_full_version >= "3.7.1" and python_version < "3.11" scipy==1.7.3 ; python_full_version >= "3.7.1" and python_version < "3.11" sentry-sdk==1.22.2 ; python_full_version >= "3.7.1" and python_version < "3.11" diff --git a/requirements/requirements-envpool.txt b/requirements/requirements-envpool.txt index 83f24488b..9ec85812e 100644 --- a/requirements/requirements-envpool.txt +++ b/requirements/requirements-envpool.txt @@ -7,7 +7,8 @@ chardet==4.0.0 ; python_full_version >= "3.7.1" and python_version < "3.11" charset-normalizer==3.1.0 ; python_full_version >= "3.7.1" and python_version < "3.11" click==8.1.3 ; python_full_version >= "3.7.1" and python_version < "3.11" cloudpickle==2.2.1 ; python_full_version >= "3.7.1" and python_version < "3.11" -colorama==0.4.4 ; python_full_version >= "3.7.1" and python_version < "3.11" and platform_system == "Windows" +colorama==0.4.4 ; python_full_version >= "3.7.1" and python_version < "3.11" +commonmark==0.9.1 ; python_full_version >= "3.7.1" and python_version < "3.11" cycler==0.11.0 ; python_full_version >= "3.7.1" and python_version < "3.11" decorator==4.4.2 ; python_full_version >= "3.7.1" and python_version < "3.11" dill==0.3.6 ; python_full_version >= "3.7.1" and python_version < "3.11" @@ -60,6 +61,7 @@ pytz==2023.3 ; python_full_version >= "3.7.1" and python_version < "3.11" pyyaml==5.4.1 ; python_full_version >= "3.7.1" and python_version < "3.11" requests-oauthlib==1.3.1 ; python_full_version >= "3.7.1" and python_version < "3.11" requests==2.30.0 ; python_full_version >= "3.7.1" and python_version < "3.11" +rich==11.2.0 ; python_full_version >= "3.7.1" and python_version < "3.11" rsa==4.7.2 ; python_full_version >= "3.7.1" and python_version < "3.11" sentry-sdk==1.22.2 ; python_full_version >= "3.7.1" and python_version < "3.11" setproctitle==1.3.2 ; python_full_version >= "3.7.1" and python_version < "3.11" diff --git a/requirements/requirements-mujoco.txt b/requirements/requirements-mujoco.txt index d038c98a6..44a4a6654 100644 --- a/requirements/requirements-mujoco.txt +++ b/requirements/requirements-mujoco.txt @@ -5,7 +5,8 @@ certifi==2023.5.7 ; python_full_version >= "3.7.1" and python_version < "3.11" charset-normalizer==3.1.0 ; python_full_version >= "3.7.1" and python_version < "3.11" click==8.1.3 ; python_full_version >= "3.7.1" and python_version < "3.11" cloudpickle==2.2.1 ; python_full_version >= "3.7.1" and python_version < "3.11" -colorama==0.4.4 ; python_full_version >= "3.7.1" and python_version < "3.11" and platform_system == 
"Windows" +colorama==0.4.4 ; python_full_version >= "3.7.1" and python_version < "3.11" +commonmark==0.9.1 ; python_full_version >= "3.7.1" and python_version < "3.11" cycler==0.11.0 ; python_full_version >= "3.7.1" and python_version < "3.11" decorator==4.4.2 ; python_full_version >= "3.7.1" and python_version < "3.11" docker-pycreds==0.4.0 ; python_full_version >= "3.7.1" and python_version < "3.11" @@ -45,6 +46,7 @@ psutil==5.9.5 ; python_full_version >= "3.7.1" and python_version < "3.11" pyasn1-modules==0.3.0 ; python_full_version >= "3.7.1" and python_version < "3.11" pyasn1==0.5.0 ; python_full_version >= "3.7.1" and python_version < "3.11" pygame==2.1.0 ; python_full_version >= "3.7.1" and python_version < "3.11" +pygments==2.15.1 ; python_full_version >= "3.7.1" and python_version < "3.11" pyopengl==3.1.6 ; python_full_version >= "3.7.1" and python_version < "3.11" pyparsing==3.0.9 ; python_full_version >= "3.7.1" and python_version < "3.11" python-dateutil==2.8.2 ; python_full_version >= "3.7.1" and python_version < "3.11" @@ -52,6 +54,7 @@ pytz==2023.3 ; python_full_version >= "3.7.1" and python_version < "3.11" pyyaml==5.4.1 ; python_full_version >= "3.7.1" and python_version < "3.11" requests-oauthlib==1.3.1 ; python_full_version >= "3.7.1" and python_version < "3.11" requests==2.30.0 ; python_full_version >= "3.7.1" and python_version < "3.11" +rich==11.2.0 ; python_full_version >= "3.7.1" and python_version < "3.11" rsa==4.7.2 ; python_full_version >= "3.7.1" and python_version < "3.11" sentry-sdk==1.22.2 ; python_full_version >= "3.7.1" and python_version < "3.11" setproctitle==1.3.2 ; python_full_version >= "3.7.1" and python_version < "3.11" diff --git a/requirements/requirements-mujoco_py.txt b/requirements/requirements-mujoco_py.txt index 2a82f088e..18155a287 100644 --- a/requirements/requirements-mujoco_py.txt +++ b/requirements/requirements-mujoco_py.txt @@ -6,7 +6,8 @@ cffi==1.15.1 ; python_full_version >= "3.7.1" and python_version < "3.11" charset-normalizer==3.1.0 ; python_full_version >= "3.7.1" and python_version < "3.11" click==8.1.3 ; python_full_version >= "3.7.1" and python_version < "3.11" cloudpickle==2.2.1 ; python_full_version >= "3.7.1" and python_version < "3.11" -colorama==0.4.4 ; python_full_version >= "3.7.1" and python_version < "3.11" and platform_system == "Windows" +colorama==0.4.4 ; python_full_version >= "3.7.1" and python_version < "3.11" +commonmark==0.9.1 ; python_full_version >= "3.7.1" and python_version < "3.11" cycler==0.11.0 ; python_full_version >= "3.7.1" and python_version < "3.11" cython==0.29.34 ; python_full_version >= "3.7.1" and python_version < "3.11" decorator==4.4.2 ; python_full_version >= "3.7.1" and python_version < "3.11" @@ -50,12 +51,14 @@ pyasn1-modules==0.3.0 ; python_full_version >= "3.7.1" and python_version < "3.1 pyasn1==0.5.0 ; python_full_version >= "3.7.1" and python_version < "3.11" pycparser==2.21 ; python_full_version >= "3.7.1" and python_version < "3.11" pygame==2.1.0 ; python_full_version >= "3.7.1" and python_version < "3.11" +pygments==2.15.1 ; python_full_version >= "3.7.1" and python_version < "3.11" pyparsing==3.0.9 ; python_full_version >= "3.7.1" and python_version < "3.11" python-dateutil==2.8.2 ; python_full_version >= "3.7.1" and python_version < "3.11" pytz==2023.3 ; python_full_version >= "3.7.1" and python_version < "3.11" pyyaml==5.4.1 ; python_full_version >= "3.7.1" and python_version < "3.11" requests-oauthlib==1.3.1 ; python_full_version >= "3.7.1" and python_version < "3.11" 
requests==2.30.0 ; python_full_version >= "3.7.1" and python_version < "3.11" +rich==11.2.0 ; python_full_version >= "3.7.1" and python_version < "3.11" rsa==4.7.2 ; python_full_version >= "3.7.1" and python_version < "3.11" sentry-sdk==1.22.2 ; python_full_version >= "3.7.1" and python_version < "3.11" setproctitle==1.3.2 ; python_full_version >= "3.7.1" and python_version < "3.11" diff --git a/requirements/requirements-pettingzoo.txt b/requirements/requirements-pettingzoo.txt index bd03422e2..fd860c418 100644 --- a/requirements/requirements-pettingzoo.txt +++ b/requirements/requirements-pettingzoo.txt @@ -5,7 +5,8 @@ certifi==2023.5.7 ; python_full_version >= "3.7.1" and python_version < "3.11" charset-normalizer==3.1.0 ; python_full_version >= "3.7.1" and python_version < "3.11" click==8.1.3 ; python_full_version >= "3.7.1" and python_version < "3.11" cloudpickle==2.2.1 ; python_full_version >= "3.7.1" and python_version < "3.11" -colorama==0.4.4 ; python_full_version >= "3.7.1" and python_version < "3.11" and platform_system == "Windows" +colorama==0.4.4 ; python_full_version >= "3.7.1" and python_version < "3.11" +commonmark==0.9.1 ; python_full_version >= "3.7.1" and python_version < "3.11" cycler==0.11.0 ; python_full_version >= "3.7.1" and python_version < "3.11" decorator==4.4.2 ; python_full_version >= "3.7.1" and python_version < "3.11" docker-pycreds==0.4.0 ; python_full_version >= "3.7.1" and python_version < "3.11" @@ -45,12 +46,14 @@ psutil==5.9.5 ; python_full_version >= "3.7.1" and python_version < "3.11" pyasn1-modules==0.3.0 ; python_full_version >= "3.7.1" and python_version < "3.11" pyasn1==0.5.0 ; python_full_version >= "3.7.1" and python_version < "3.11" pygame==2.1.0 ; python_full_version >= "3.7.1" and python_version < "3.11" +pygments==2.15.1 ; python_full_version >= "3.7.1" and python_version < "3.11" pyparsing==3.0.9 ; python_full_version >= "3.7.1" and python_version < "3.11" python-dateutil==2.8.2 ; python_full_version >= "3.7.1" and python_version < "3.11" pytz==2023.3 ; python_full_version >= "3.7.1" and python_version < "3.11" pyyaml==5.4.1 ; python_full_version >= "3.7.1" and python_version < "3.11" requests-oauthlib==1.3.1 ; python_full_version >= "3.7.1" and python_version < "3.11" requests==2.30.0 ; python_full_version >= "3.7.1" and python_version < "3.11" +rich==11.2.0 ; python_full_version >= "3.7.1" and python_version < "3.11" rsa==4.7.2 ; python_full_version >= "3.7.1" and python_version < "3.11" sentry-sdk==1.22.2 ; python_full_version >= "3.7.1" and python_version < "3.11" setproctitle==1.3.2 ; python_full_version >= "3.7.1" and python_version < "3.11" diff --git a/requirements/requirements-procgen.txt b/requirements/requirements-procgen.txt index f28d31ed3..017f49d47 100644 --- a/requirements/requirements-procgen.txt +++ b/requirements/requirements-procgen.txt @@ -6,7 +6,8 @@ cffi==1.15.1 ; python_full_version >= "3.7.1" and python_version < "3.11" charset-normalizer==3.1.0 ; python_full_version >= "3.7.1" and python_version < "3.11" click==8.1.3 ; python_full_version >= "3.7.1" and python_version < "3.11" cloudpickle==2.2.1 ; python_full_version >= "3.7.1" and python_version < "3.11" -colorama==0.4.4 ; python_full_version >= "3.7.1" and python_version < "3.11" and platform_system == "Windows" +colorama==0.4.4 ; python_full_version >= "3.7.1" and python_version < "3.11" +commonmark==0.9.1 ; python_full_version >= "3.7.1" and python_version < "3.11" cycler==0.11.0 ; python_full_version >= "3.7.1" and python_version < "3.11" decorator==4.4.2 ; 
python_full_version >= "3.7.1" and python_version < "3.11" docker-pycreds==0.4.0 ; python_full_version >= "3.7.1" and python_version < "3.11" @@ -50,12 +51,14 @@ pyasn1-modules==0.3.0 ; python_full_version >= "3.7.1" and python_version < "3.1 pyasn1==0.5.0 ; python_full_version >= "3.7.1" and python_version < "3.11" pycparser==2.21 ; python_full_version >= "3.7.1" and python_version < "3.11" pygame==2.1.0 ; python_full_version >= "3.7.1" and python_version < "3.11" +pygments==2.15.1 ; python_full_version >= "3.7.1" and python_version < "3.11" pyparsing==3.0.9 ; python_full_version >= "3.7.1" and python_version < "3.11" python-dateutil==2.8.2 ; python_full_version >= "3.7.1" and python_version < "3.11" pytz==2023.3 ; python_full_version >= "3.7.1" and python_version < "3.11" pyyaml==5.4.1 ; python_full_version >= "3.7.1" and python_version < "3.11" requests-oauthlib==1.3.1 ; python_full_version >= "3.7.1" and python_version < "3.11" requests==2.30.0 ; python_full_version >= "3.7.1" and python_version < "3.11" +rich==11.2.0 ; python_full_version >= "3.7.1" and python_version < "3.11" rsa==4.7.2 ; python_full_version >= "3.7.1" and python_version < "3.11" sentry-sdk==1.22.2 ; python_full_version >= "3.7.1" and python_version < "3.11" setproctitle==1.3.2 ; python_full_version >= "3.7.1" and python_version < "3.11" diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 023f0c544..15e082f86 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -5,7 +5,8 @@ certifi==2023.5.7 ; python_full_version >= "3.7.1" and python_version < "3.11" charset-normalizer==3.1.0 ; python_full_version >= "3.7.1" and python_version < "3.11" click==8.1.3 ; python_full_version >= "3.7.1" and python_version < "3.11" cloudpickle==2.2.1 ; python_full_version >= "3.7.1" and python_version < "3.11" -colorama==0.4.4 ; python_full_version >= "3.7.1" and python_version < "3.11" and platform_system == "Windows" +colorama==0.4.4 ; python_full_version >= "3.7.1" and python_version < "3.11" +commonmark==0.9.1 ; python_full_version >= "3.7.1" and python_version < "3.11" cycler==0.11.0 ; python_full_version >= "3.7.1" and python_version < "3.11" decorator==4.4.2 ; python_full_version >= "3.7.1" and python_version < "3.11" docker-pycreds==0.4.0 ; python_full_version >= "3.7.1" and python_version < "3.11" @@ -43,12 +44,14 @@ psutil==5.9.5 ; python_full_version >= "3.7.1" and python_version < "3.11" pyasn1-modules==0.3.0 ; python_full_version >= "3.7.1" and python_version < "3.11" pyasn1==0.5.0 ; python_full_version >= "3.7.1" and python_version < "3.11" pygame==2.1.0 ; python_full_version >= "3.7.1" and python_version < "3.11" +pygments==2.15.1 ; python_full_version >= "3.7.1" and python_version < "3.11" pyparsing==3.0.9 ; python_full_version >= "3.7.1" and python_version < "3.11" python-dateutil==2.8.2 ; python_full_version >= "3.7.1" and python_version < "3.11" pytz==2023.3 ; python_full_version >= "3.7.1" and python_version < "3.11" pyyaml==5.4.1 ; python_full_version >= "3.7.1" and python_version < "3.11" requests-oauthlib==1.3.1 ; python_full_version >= "3.7.1" and python_version < "3.11" requests==2.30.0 ; python_full_version >= "3.7.1" and python_version < "3.11" +rich==11.2.0 ; python_full_version >= "3.7.1" and python_version < "3.11" rsa==4.7.2 ; python_full_version >= "3.7.1" and python_version < "3.11" sentry-sdk==1.22.2 ; python_full_version >= "3.7.1" and python_version < "3.11" setproctitle==1.3.2 ; python_full_version >= "3.7.1" and python_version < "3.11" diff 
--git a/tests/test_atari_gymnasium.py b/tests/test_atari_gymnasium.py index 3d629a4c8..3153577db 100644 --- a/tests/test_atari_gymnasium.py +++ b/tests/test_atari_gymnasium.py @@ -15,3 +15,19 @@ def test_dqn_eval(): shell=True, check=True, ) + + +def test_qdagger_dqn_atari_impalacnn(): + subprocess.run( + "python cleanrl/qdagger_dqn_atari_impalacnn.py --learning-starts 10 --total-timesteps 16 --buffer-size 10 --batch-size 4 --teacher-steps 16 --offline-steps 16 --teacher-eval-episodes 1", + shell=True, + check=True, + ) + + +def test_qdagger_dqn_atari_impalacnn_eval(): + subprocess.run( + "python cleanrl/qdagger_dqn_atari_impalacnn.py --save-model True --learning-starts 10 --total-timesteps 16 --buffer-size 10 --batch-size 4 --teacher-steps 16 --offline-steps 16 --teacher-eval-episodes 1", + shell=True, + check=True, + ) diff --git a/tests/test_atari_jax_gymnasium.py b/tests/test_atari_jax_gymnasium.py index 86fc1ec09..b73e692d9 100644 --- a/tests/test_atari_jax_gymnasium.py +++ b/tests/test_atari_jax_gymnasium.py @@ -15,3 +15,19 @@ def test_dqn_jax_eval(): shell=True, check=True, ) + + +def test_qdagger_dqn_atari_jax_impalacnn(): + subprocess.run( + "python cleanrl/qdagger_dqn_atari_jax_impalacnn.py --learning-starts 10 --total-timesteps 16 --buffer-size 10 --batch-size 4 --teacher-steps 16 --offline-steps 16 --teacher-eval-episodes 1", + shell=True, + check=True, + ) + + +def test_qdagger_dqn_atari_jax_impalacnn_eval(): + subprocess.run( + "python cleanrl/qdagger_dqn_atari_jax_impalacnn.py --save-model True --learning-starts 10 --total-timesteps 16 --buffer-size 10 --batch-size 4 --teacher-steps 16 --offline-steps 16 --teacher-eval-episodes 1", + shell=True, + check=True, + )
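
The new smoke tests above exercise the distillation-specific flags added by this patch (`--teacher-steps`, `--offline-steps`, `--teacher-eval-episodes`). For intuition only, below is a minimal, hypothetical sketch of a QDagger-style objective in the spirit of Agarwal et al. (2022): a standard TD loss on the student's Q-values plus a temperature-softened KL term that pulls the student's policy toward the teacher's. The function name, argument layout, and the fixed `lam` weight are illustrative assumptions, not the code introduced in `qdagger_dqn_atari_impalacnn.py`.

```python
# Illustrative sketch only -- not the implementation added by this patch.
# Assumes (batch, num_actions) Q-value tensors for the student and teacher,
# and that `lam` (the distillation weight) is scheduled elsewhere, e.g.
# decayed as the student's return approaches the teacher's.
import torch
import torch.nn.functional as F


def qdagger_loss(q_student, q_student_next_target, q_teacher, actions,
                 rewards, dones, gamma=0.99, temperature=1.0, lam=1.0):
    # Standard DQN TD loss on the student's own Q-values.
    with torch.no_grad():
        td_target = rewards + gamma * q_student_next_target.max(dim=1).values * (1.0 - dones)
    q_pred = q_student.gather(1, actions.unsqueeze(1)).squeeze(1)
    td_loss = F.mse_loss(q_pred, td_target)

    # Distillation term: KL divergence from the teacher's softmax policy
    # to the student's, both softened by a temperature.
    teacher_policy = F.softmax(q_teacher / temperature, dim=1)
    student_log_policy = F.log_softmax(q_student / temperature, dim=1)
    distill_loss = F.kl_div(student_log_policy, teacher_policy, reduction="batchmean")

    return td_loss + lam * distill_loss
```

In this framing, the offline phase would be driven by transitions collected with the teacher, while during the online phase `lam` would shrink toward zero so the student eventually learns from its own TD signal alone.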