diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index 1abf951c44b..5b57815c444 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -52,6 +52,12 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/decision_trans # ==================================================================================== # # ================================ Gymnasium ========================================= # +python .github/unittest/helpers/coverage_run_parallel.py examples/impala/impala_single_node.py \ + collector.total_frames=80 \ + collector.frames_per_batch=20 \ + collector.num_workers=1 \ + logger.backend= \ + logger.test_interval=10 python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo_mujoco.py \ env.env_name=HalfCheetah-v4 \ collector.total_frames=40 \ diff --git a/examples/distributed/collectors/multi_nodes/ray_train.py b/examples/distributed/collectors/multi_nodes/ray_train.py index a5265f442b7..360c6daac28 100644 --- a/examples/distributed/collectors/multi_nodes/ray_train.py +++ b/examples/distributed/collectors/multi_nodes/ray_train.py @@ -117,7 +117,7 @@ "object_store_memory": 1024**3, } collector = RayCollector( - env_makers=[env] * num_collectors, + create_env_fn=[env] * num_collectors, policy=policy_module, collector_class=SyncDataCollector, collector_kwargs={ diff --git a/examples/impala/README.md b/examples/impala/README.md new file mode 100644 index 00000000000..00e0d010b82 --- /dev/null +++ b/examples/impala/README.md @@ -0,0 +1,33 @@ +## Reproducing Importance Weighted Actor-Learner Architecture (IMPALA) Algorithm Results + +This repository contains scripts that enable training agents using the IMPALA Algorithm on Atari environments. We follow the original paper [IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures](https://arxiv.org/abs/1802.01561) by Espeholt et al. (2018). + +## Examples Structure + +Please note that we provide two examples, one for single-node training and one for distributed training. Both examples rely on the same utils file, but are otherwise independent. Each example contains the following files: + +1. **Main Script:** The definition of algorithm components and the training loop can be found in the main script (e.g. impala_single_node.py). + +2. **Utils File:** A utility file provides various helper functions, mainly to create the environment and the models (e.g. utils.py). + +3. **Configuration File:** This file includes default hyperparameters specified in the original paper. For the multi-node cases, the file also includes the Ray cluster or SLURM configuration. Users can modify these hyperparameters to customize their experiments, either by editing the file or via command-line overrides as shown below (e.g. config_single_node.yaml).
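+
+Since the training scripts are configured with Hydra, any value in these configuration files can also be overridden from the command line instead of editing the file. A minimal sketch (the override values below mirror the repository's CI smoke test and simply produce a short run with logging disabled):
+
+```bash
+# Short illustrative run: an empty logger.backend disables logging.
+python impala_single_node.py \
+  collector.total_frames=80 \
+  collector.frames_per_batch=20 \
+  collector.num_workers=1 \
+  logger.backend=
+```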
+ + +## Running the Examples + +You can execute the single node IMPALA algorithm on Atari environments by running the following command: + +```bash +python impala_single_node.py +``` + +You can execute the multi-node IMPALA algorithm on Atari environments by running one of the following commands: + +```bash +python impala_multi_node_ray.py +``` +or + +```bash +python impala_multi_node_submitit.py +``` diff --git a/examples/impala/config_multi_node_ray.yaml b/examples/impala/config_multi_node_ray.yaml new file mode 100644 index 00000000000..e312b336651 --- /dev/null +++ b/examples/impala/config_multi_node_ray.yaml @@ -0,0 +1,65 @@ +# Environment +env: + env_name: PongNoFrameskip-v4 + +# Ray init kwargs - https://docs.ray.io/en/latest/ray-core/api/doc/ray.init.html +ray_init_config: + address: null + num_cpus: null + num_gpus: null + resources: null + object_store_memory: null + local_mode: False + ignore_reinit_error: False + include_dashboard: null + dashboard_host: 127.0.0.1 + dashboard_port: null + job_config: null + configure_logging: True + logging_level: info + logging_format: null + log_to_driver: True + namespace: null + runtime_env: null + storage: null + +# Device for the forward and backward passes +local_device: "cuda:0" + +# Resources assigned to each IMPALA rollout collection worker +remote_worker_resources: + num_cpus: 1 + num_gpus: 0.25 + memory: 1073741824 # 1 * 1024**3 = 1GB + +# collector +collector: + frames_per_batch: 80 + total_frames: 200_000_000 + num_workers: 12 + +# logger +logger: + backend: wandb + exp_name: Atari_IMPALA + test_interval: 200_000_000 + num_test_episodes: 3 + +# Optim +optim: + lr: 0.0006 + eps: 1e-8 + weight_decay: 0.0 + momentum: 0.0 + alpha: 0.99 + max_grad_norm: 40.0 + anneal_lr: True + +# loss +loss: + gamma: 0.99 + batch_size: 32 + sgd_updates: 1 + critic_coef: 0.5 + entropy_coef: 0.01 + loss_critic_type: l2 diff --git a/examples/impala/config_multi_node_submitit.yaml b/examples/impala/config_multi_node_submitit.yaml new file mode 100644 index 00000000000..f632ba15dc2 --- /dev/null +++ b/examples/impala/config_multi_node_submitit.yaml @@ -0,0 +1,46 @@ +# Environment +env: + env_name: PongNoFrameskip-v4 + +# Device for the forward and backward passes +local_device: "cuda:0" + +# SLURM config +slurm_config: + timeout_min: 10 + slurm_partition: train + slurm_cpus_per_task: 1 + slurm_gpus_per_node: 1 + +# collector +collector: + backend: gloo + frames_per_batch: 80 + total_frames: 200_000_000 + num_workers: 1 + +# logger +logger: + backend: wandb + exp_name: Atari_IMPALA + test_interval: 200_000_000 + num_test_episodes: 3 + +# Optim +optim: + lr: 0.0006 + eps: 1e-8 + weight_decay: 0.0 + momentum: 0.0 + alpha: 0.99 + max_grad_norm: 40.0 + anneal_lr: True + +# loss +loss: + gamma: 0.99 + batch_size: 32 + sgd_updates: 1 + critic_coef: 0.5 + entropy_coef: 0.01 + loss_critic_type: l2 diff --git a/examples/impala/config_single_node.yaml b/examples/impala/config_single_node.yaml new file mode 100644 index 00000000000..d39407c1a69 --- /dev/null +++ b/examples/impala/config_single_node.yaml @@ -0,0 +1,38 @@ +# Environment +env: + env_name: PongNoFrameskip-v4 + +# Device for the forward and backward passes +device: "cuda:0" + +# collector +collector: + frames_per_batch: 80 + total_frames: 200_000_000 + num_workers: 12 + +# logger +logger: + backend: wandb + exp_name: Atari_IMPALA + test_interval: 200_000_000 + num_test_episodes: 3 + +# Optim +optim: + lr: 0.0006 + eps: 1e-8 + weight_decay: 0.0 + momentum: 0.0 + alpha: 0.99 + max_grad_norm: 40.0 + anneal_lr: True +
+# loss +loss: + gamma: 0.99 + batch_size: 32 + sgd_updates: 1 + critic_coef: 0.5 + entropy_coef: 0.01 + loss_critic_type: l2 diff --git a/examples/impala/impala_multi_node_ray.py b/examples/impala/impala_multi_node_ray.py new file mode 100644 index 00000000000..a0d2d88c5a2 --- /dev/null +++ b/examples/impala/impala_multi_node_ray.py @@ -0,0 +1,278 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +This script reproduces the IMPALA Algorithm +results from Espeholt et al. 2018 for the on Atari Environments. +""" +import hydra + + +@hydra.main(config_path=".", config_name="config_multi_node_ray", version_base="1.1") +def main(cfg: "DictConfig"): # noqa: F821 + + import time + + import torch.optim + import tqdm + + from tensordict import TensorDict + from torchrl.collectors import SyncDataCollector + from torchrl.collectors.distributed import RayCollector + from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement + from torchrl.envs import ExplorationType, set_exploration_type + from torchrl.objectives import A2CLoss + from torchrl.objectives.value import VTrace + from torchrl.record.loggers import generate_exp_name, get_logger + from utils import eval_model, make_env, make_ppo_models + + device = torch.device(cfg.local_device) + + # Correct for frame_skip + frame_skip = 4 + total_frames = cfg.collector.total_frames // frame_skip + frames_per_batch = cfg.collector.frames_per_batch // frame_skip + test_interval = cfg.logger.test_interval // frame_skip + + # Extract other config parameters + batch_size = cfg.loss.batch_size # Number of rollouts per batch + num_workers = ( + cfg.collector.num_workers + ) # Number of parallel workers collecting rollouts + lr = cfg.optim.lr + anneal_lr = cfg.optim.anneal_lr + sgd_updates = cfg.loss.sgd_updates + max_grad_norm = cfg.optim.max_grad_norm + num_test_episodes = cfg.logger.num_test_episodes + total_network_updates = ( + total_frames // (frames_per_batch * batch_size) + ) * cfg.loss.sgd_updates + + # Create models (check utils.py) + actor, critic = make_ppo_models(cfg.env.env_name) + actor, critic = actor.to(device), critic.to(device) + + # Create collector + ray_init_config = { + "address": cfg.ray_init_config.address, + "num_cpus": cfg.ray_init_config.num_cpus, + "num_gpus": cfg.ray_init_config.num_gpus, + "resources": cfg.ray_init_config.resources, + "object_store_memory": cfg.ray_init_config.object_store_memory, + "local_mode": cfg.ray_init_config.local_mode, + "ignore_reinit_error": cfg.ray_init_config.ignore_reinit_error, + "include_dashboard": cfg.ray_init_config.include_dashboard, + "dashboard_host": cfg.ray_init_config.dashboard_host, + "dashboard_port": cfg.ray_init_config.dashboard_port, + "job_config": cfg.ray_init_config.job_config, + "configure_logging": cfg.ray_init_config.configure_logging, + "logging_level": cfg.ray_init_config.logging_level, + "logging_format": cfg.ray_init_config.logging_format, + "log_to_driver": cfg.ray_init_config.log_to_driver, + "namespace": cfg.ray_init_config.namespace, + "runtime_env": cfg.ray_init_config.runtime_env, + "storage": cfg.ray_init_config.storage, + } + remote_config = { + "num_cpus": cfg.remote_worker_resources.num_cpus, + "num_gpus": cfg.remote_worker_resources.num_gpus + if torch.cuda.device_count() + else 0, + "memory": cfg.remote_worker_resources.memory, + } + 
collector = RayCollector( + create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers, + policy=actor, + collector_class=SyncDataCollector, + frames_per_batch=frames_per_batch, + total_frames=total_frames, + max_frames_per_traj=-1, + ray_init_config=ray_init_config, + remote_configs=remote_config, + sync=False, + update_after_each_batch=True, + ) + + # Create data buffer + sampler = SamplerWithoutReplacement() + data_buffer = TensorDictReplayBuffer( + storage=LazyMemmapStorage(frames_per_batch * batch_size), + sampler=sampler, + batch_size=frames_per_batch * batch_size, + ) + + # Create loss and adv modules + adv_module = VTrace( + gamma=cfg.loss.gamma, + value_network=critic, + actor_network=actor, + average_adv=False, + ) + loss_module = A2CLoss( + actor=actor, + critic=critic, + loss_critic_type=cfg.loss.loss_critic_type, + entropy_coef=cfg.loss.entropy_coef, + critic_coef=cfg.loss.critic_coef, + ) + loss_module.set_keys(done="eol", terminated="eol") + + # Create optimizer + optim = torch.optim.RMSprop( + loss_module.parameters(), + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.eps, + alpha=cfg.optim.alpha, + ) + + # Create logger + logger = None + if cfg.logger.backend: + exp_name = generate_exp_name( + "IMPALA", f"{cfg.logger.exp_name}_{cfg.env.env_name}" + ) + logger = get_logger( + cfg.logger.backend, + logger_name="impala", + experiment_name=exp_name, + project="impala", + ) + + # Create test environment + test_env = make_env(cfg.env.env_name, device, is_test=True) + test_env.eval() + + # Main loop + collected_frames = 0 + num_network_updates = 0 + pbar = tqdm.tqdm(total=total_frames) + accumulator = [] + start_time = sampling_start = time.time() + for i, data in enumerate(collector): + + log_info = {} + sampling_time = time.time() - sampling_start + frames_in_batch = data.numel() + collected_frames += frames_in_batch * frame_skip + pbar.update(data.numel()) + + # Get training rewards and episode lengths + episode_rewards = data["next", "episode_reward"][data["next", "terminated"]] + if len(episode_rewards) > 0: + episode_length = data["next", "step_count"][data["next", "terminated"]] + log_info.update( + { + "train/reward": episode_rewards.mean().item(), + "train/episode_length": episode_length.sum().item() + / len(episode_length), + } + ) + + if len(accumulator) < batch_size: + accumulator.append(data) + if logger: + for key, value in log_info.items(): + logger.log_scalar(key, value, collected_frames) + continue + + losses = TensorDict({}, batch_size=[sgd_updates]) + training_start = time.time() + for j in range(sgd_updates): + + # Create a single batch of trajectories + stacked_data = torch.stack(accumulator, dim=0).contiguous() + stacked_data = stacked_data.to(device, non_blocking=True) + + # Compute advantage + with torch.no_grad(): + stacked_data = adv_module(stacked_data) + + # Add to replay buffer + for stacked_d in stacked_data: + stacked_data_reshape = stacked_d.reshape(-1) + data_buffer.extend(stacked_data_reshape) + + for batch in data_buffer: + + # Linearly decrease the learning rate and clip epsilon + alpha = 1.0 + if anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) + for group in optim.param_groups: + group["lr"] = lr * alpha + num_network_updates += 1 + + # Get a data batch + batch = batch.to(device, non_blocking=True) + + # Forward pass loss + loss = loss_module(batch) + losses[j] = loss.select( + "loss_critic", "loss_entropy", "loss_objective" + ).detach() + loss_sum = ( + loss["loss_critic"] + 
loss["loss_objective"] + loss["loss_entropy"] + ) + + # Backward pass + loss_sum.backward() + torch.nn.utils.clip_grad_norm_( + list(loss_module.parameters()), max_norm=max_grad_norm + ) + + # Update the networks + optim.step() + optim.zero_grad() + + # Get training losses and times + training_time = time.time() - training_start + losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) + for key, value in losses.items(): + log_info.update({f"train/{key}": value.item()}) + log_info.update( + { + "train/lr": alpha * lr, + "train/sampling_time": sampling_time, + "train/training_time": training_time, + } + ) + + # Get test rewards + with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( + i * frames_in_batch * frame_skip + ) // test_interval: + actor.eval() + eval_start = time.time() + test_reward = eval_model( + actor, test_env, num_episodes=num_test_episodes + ) + eval_time = time.time() - eval_start + log_info.update( + { + "eval/reward": test_reward, + "eval/time": eval_time, + } + ) + actor.train() + + if logger: + for key, value in log_info.items(): + logger.log_scalar(key, value, collected_frames) + + collector.update_policy_weights_() + sampling_start = time.time() + accumulator = [] + + collector.shutdown() + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") + + +if __name__ == "__main__": + main() diff --git a/examples/impala/impala_multi_node_submitit.py b/examples/impala/impala_multi_node_submitit.py new file mode 100644 index 00000000000..3355febbfaf --- /dev/null +++ b/examples/impala/impala_multi_node_submitit.py @@ -0,0 +1,270 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +This script reproduces the IMPALA Algorithm +results from Espeholt et al. 2018 for the on Atari Environments. 
+""" +import hydra + + +@hydra.main( + config_path=".", config_name="config_multi_node_submitit", version_base="1.1" +) +def main(cfg: "DictConfig"): # noqa: F821 + + import time + + import torch.optim + import tqdm + + from tensordict import TensorDict + from torchrl.collectors import SyncDataCollector + from torchrl.collectors.distributed import DistributedDataCollector + from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement + from torchrl.envs import ExplorationType, set_exploration_type + from torchrl.objectives import A2CLoss + from torchrl.objectives.value import VTrace + from torchrl.record.loggers import generate_exp_name, get_logger + from utils import eval_model, make_env, make_ppo_models + + device = torch.device(cfg.local_device) + + # Correct for frame_skip + frame_skip = 4 + total_frames = cfg.collector.total_frames // frame_skip + frames_per_batch = cfg.collector.frames_per_batch // frame_skip + test_interval = cfg.logger.test_interval // frame_skip + + # Extract other config parameters + batch_size = cfg.loss.batch_size # Number of rollouts per batch + num_workers = ( + cfg.collector.num_workers + ) # Number of parallel workers collecting rollouts + lr = cfg.optim.lr + anneal_lr = cfg.optim.anneal_lr + sgd_updates = cfg.loss.sgd_updates + max_grad_norm = cfg.optim.max_grad_norm + num_test_episodes = cfg.logger.num_test_episodes + total_network_updates = ( + total_frames // (frames_per_batch * batch_size) + ) * cfg.loss.sgd_updates + + # Create models (check utils.py) + actor, critic = make_ppo_models(cfg.env.env_name) + actor, critic = actor.to(device), critic.to(device) + + slurm_kwargs = { + "timeout_min": cfg.slurm_config.timeout_min, + "slurm_partition": cfg.slurm_config.slurm_partition, + "slurm_cpus_per_task": cfg.slurm_config.slurm_cpus_per_task, + "slurm_gpus_per_node": cfg.slurm_config.slurm_gpus_per_node, + } + # Create collector + device_str = "device" if num_workers <= 1 else "devices" + if cfg.collector.backend == "nccl": + collector_kwargs = {device_str: "cuda:0", f"storing_{device_str}": "cuda:0"} + elif cfg.collector.backend == "gloo": + collector_kwargs = {device_str: "cpu", f"storing_{device_str}": "cpu"} + else: + raise NotImplementedError( + f"device assignment not implemented for backend {cfg.collector.backend}" + ) + collector = DistributedDataCollector( + create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers, + policy=actor, + num_workers_per_collector=1, + frames_per_batch=frames_per_batch, + total_frames=total_frames, + collector_class=SyncDataCollector, + collector_kwargs=collector_kwargs, + slurm_kwargs=slurm_kwargs, + storing_device="cuda:0" if cfg.collector.backend == "nccl" else "cpu", + launcher="submitit", + # update_after_each_batch=True, + backend=cfg.collector.backend, + ) + + # Create data buffer + sampler = SamplerWithoutReplacement() + data_buffer = TensorDictReplayBuffer( + storage=LazyMemmapStorage(frames_per_batch * batch_size), + sampler=sampler, + batch_size=frames_per_batch * batch_size, + ) + + # Create loss and adv modules + adv_module = VTrace( + gamma=cfg.loss.gamma, + value_network=critic, + actor_network=actor, + average_adv=False, + ) + loss_module = A2CLoss( + actor=actor, + critic=critic, + loss_critic_type=cfg.loss.loss_critic_type, + entropy_coef=cfg.loss.entropy_coef, + critic_coef=cfg.loss.critic_coef, + ) + loss_module.set_keys(done="eol", terminated="eol") + + # Create optimizer + optim = torch.optim.RMSprop( + 
loss_module.parameters(), + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.eps, + alpha=cfg.optim.alpha, + ) + + # Create logger + logger = None + if cfg.logger.backend: + exp_name = generate_exp_name( + "IMPALA", f"{cfg.logger.exp_name}_{cfg.env.env_name}" + ) + logger = get_logger( + cfg.logger.backend, + logger_name="impala", + experiment_name=exp_name, + project="impala", + ) + + # Create test environment + test_env = make_env(cfg.env.env_name, device, is_test=True) + test_env.eval() + + # Main loop + collected_frames = 0 + num_network_updates = 0 + pbar = tqdm.tqdm(total=total_frames) + accumulator = [] + start_time = sampling_start = time.time() + for i, data in enumerate(collector): + + log_info = {} + sampling_time = time.time() - sampling_start + frames_in_batch = data.numel() + collected_frames += frames_in_batch * frame_skip + pbar.update(data.numel()) + + # Get training rewards and episode lengths + episode_rewards = data["next", "episode_reward"][data["next", "done"]] + if len(episode_rewards) > 0: + episode_length = data["next", "step_count"][data["next", "done"]] + log_info.update( + { + "train/reward": episode_rewards.mean().item(), + "train/episode_length": episode_length.sum().item() + / len(episode_length), + } + ) + + if len(accumulator) < batch_size: + accumulator.append(data) + if logger: + for key, value in log_info.items(): + logger.log_scalar(key, value, collected_frames) + continue + + losses = TensorDict({}, batch_size=[sgd_updates]) + training_start = time.time() + for j in range(sgd_updates): + + # Create a single batch of trajectories + stacked_data = torch.stack(accumulator, dim=0).contiguous() + stacked_data = stacked_data.to(device, non_blocking=True) + + # Compute advantage + with torch.no_grad(): + stacked_data = adv_module(stacked_data) + + # Add to replay buffer + for stacked_d in stacked_data: + stacked_data_reshape = stacked_d.reshape(-1) + data_buffer.extend(stacked_data_reshape) + + for batch in data_buffer: + + # Linearly decrease the learning rate and clip epsilon + alpha = 1.0 + if anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) + for group in optim.param_groups: + group["lr"] = lr * alpha + num_network_updates += 1 + + # Get a data batch + batch = batch.to(device) + + # Forward pass loss + loss = loss_module(batch) + losses[j] = loss.select( + "loss_critic", "loss_entropy", "loss_objective" + ).detach() + loss_sum = ( + loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] + ) + + # Backward pass + loss_sum.backward() + torch.nn.utils.clip_grad_norm_( + list(loss_module.parameters()), max_norm=max_grad_norm + ) + + # Update the networks + optim.step() + optim.zero_grad() + + # Get training losses and times + training_time = time.time() - training_start + losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) + for key, value in losses.items(): + log_info.update({f"train/{key}": value.item()}) + log_info.update( + { + "train/lr": alpha * lr, + "train/sampling_time": sampling_time, + "train/training_time": training_time, + } + ) + + # Get test rewards + with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( + i * frames_in_batch * frame_skip + ) // test_interval: + actor.eval() + eval_start = time.time() + test_reward = eval_model( + actor, test_env, num_episodes=num_test_episodes + ) + eval_time = time.time() - eval_start + log_info.update( + { + "eval/reward": test_reward, + "eval/time": eval_time, 
+ } + ) + actor.train() + + if logger: + for key, value in log_info.items(): + logger.log_scalar(key, value, collected_frames) + + collector.update_policy_weights_() + sampling_start = time.time() + accumulator = [] + + collector.shutdown() + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") + + +if __name__ == "__main__": + main() diff --git a/examples/impala/impala_single_node.py b/examples/impala/impala_single_node.py new file mode 100644 index 00000000000..cd270f4c9e9 --- /dev/null +++ b/examples/impala/impala_single_node.py @@ -0,0 +1,248 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +This script reproduces the IMPALA Algorithm +results from Espeholt et al. 2018 for the on Atari Environments. +""" +import hydra + + +@hydra.main(config_path=".", config_name="config_single_node", version_base="1.1") +def main(cfg: "DictConfig"): # noqa: F821 + + import time + + import torch.optim + import tqdm + + from tensordict import TensorDict + from torchrl.collectors import MultiaSyncDataCollector + from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement + from torchrl.envs import ExplorationType, set_exploration_type + from torchrl.objectives import A2CLoss + from torchrl.objectives.value import VTrace + from torchrl.record.loggers import generate_exp_name, get_logger + from utils import eval_model, make_env, make_ppo_models + + device = torch.device(cfg.device) + + # Correct for frame_skip + frame_skip = 4 + total_frames = cfg.collector.total_frames // frame_skip + frames_per_batch = cfg.collector.frames_per_batch // frame_skip + test_interval = cfg.logger.test_interval // frame_skip + + # Extract other config parameters + batch_size = cfg.loss.batch_size # Number of rollouts per batch + num_workers = ( + cfg.collector.num_workers + ) # Number of parallel workers collecting rollouts + lr = cfg.optim.lr + anneal_lr = cfg.optim.anneal_lr + sgd_updates = cfg.loss.sgd_updates + max_grad_norm = cfg.optim.max_grad_norm + num_test_episodes = cfg.logger.num_test_episodes + total_network_updates = ( + total_frames // (frames_per_batch * batch_size) + ) * cfg.loss.sgd_updates + + # Create models (check utils.py) + actor, critic = make_ppo_models(cfg.env.env_name) + actor, critic = actor.to(device), critic.to(device) + + # Create collector + collector = MultiaSyncDataCollector( + create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers, + policy=actor, + frames_per_batch=frames_per_batch, + total_frames=total_frames, + device=device, + storing_device=device, + max_frames_per_traj=-1, + update_at_each_batch=True, + ) + + # Create data buffer + sampler = SamplerWithoutReplacement() + data_buffer = TensorDictReplayBuffer( + storage=LazyMemmapStorage(frames_per_batch * batch_size), + sampler=sampler, + batch_size=frames_per_batch * batch_size, + ) + + # Create loss and adv modules + adv_module = VTrace( + gamma=cfg.loss.gamma, + value_network=critic, + actor_network=actor, + average_adv=False, + ) + loss_module = A2CLoss( + actor=actor, + critic=critic, + loss_critic_type=cfg.loss.loss_critic_type, + entropy_coef=cfg.loss.entropy_coef, + critic_coef=cfg.loss.critic_coef, + ) + loss_module.set_keys(done="eol", terminated="eol") + + # Create optimizer + optim = torch.optim.RMSprop( + 
loss_module.parameters(), + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.eps, + alpha=cfg.optim.alpha, + ) + + # Create logger + logger = None + if cfg.logger.backend: + exp_name = generate_exp_name( + "IMPALA", f"{cfg.logger.exp_name}_{cfg.env.env_name}" + ) + logger = get_logger( + cfg.logger.backend, + logger_name="impala", + experiment_name=exp_name, + project="impala", + ) + + # Create test environment + test_env = make_env(cfg.env.env_name, device, is_test=True) + test_env.eval() + + # Main loop + collected_frames = 0 + num_network_updates = 0 + pbar = tqdm.tqdm(total=total_frames) + accumulator = [] + start_time = sampling_start = time.time() + for i, data in enumerate(collector): + + log_info = {} + sampling_time = time.time() - sampling_start + frames_in_batch = data.numel() + collected_frames += frames_in_batch * frame_skip + pbar.update(data.numel()) + + # Get training rewards and episode lengths + episode_rewards = data["next", "episode_reward"][data["next", "terminated"]] + if len(episode_rewards) > 0: + episode_length = data["next", "step_count"][data["next", "terminated"]] + log_info.update( + { + "train/reward": episode_rewards.mean().item(), + "train/episode_length": episode_length.sum().item() + / len(episode_length), + } + ) + + if len(accumulator) < batch_size: + accumulator.append(data) + if logger: + for key, value in log_info.items(): + logger.log_scalar(key, value, collected_frames) + continue + + losses = TensorDict({}, batch_size=[sgd_updates]) + training_start = time.time() + for j in range(sgd_updates): + + # Create a single batch of trajectories + stacked_data = torch.stack(accumulator, dim=0).contiguous() + stacked_data = stacked_data.to(device, non_blocking=True) + + # Compute advantage + with torch.no_grad(): + stacked_data = adv_module(stacked_data) + + # Add to replay buffer + for stacked_d in stacked_data: + stacked_data_reshape = stacked_d.reshape(-1) + data_buffer.extend(stacked_data_reshape) + + for batch in data_buffer: + + # Linearly decrease the learning rate and clip epsilon + alpha = 1.0 + if anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) + for group in optim.param_groups: + group["lr"] = lr * alpha + num_network_updates += 1 + + # Get a data batch + batch = batch.to(device, non_blocking=True) + + # Forward pass loss + loss = loss_module(batch) + losses[j] = loss.select( + "loss_critic", "loss_entropy", "loss_objective" + ).detach() + loss_sum = ( + loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] + ) + + # Backward pass + loss_sum.backward() + torch.nn.utils.clip_grad_norm_( + list(loss_module.parameters()), max_norm=max_grad_norm + ) + + # Update the networks + optim.step() + optim.zero_grad() + + # Get training losses and times + training_time = time.time() - training_start + losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) + for key, value in losses.items(): + log_info.update({f"train/{key}": value.item()}) + log_info.update( + { + "train/lr": alpha * lr, + "train/sampling_time": sampling_time, + "train/training_time": training_time, + } + ) + + # Get test rewards + with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( + i * frames_in_batch * frame_skip + ) // test_interval: + actor.eval() + eval_start = time.time() + test_reward = eval_model( + actor, test_env, num_episodes=num_test_episodes + ) + eval_time = time.time() - eval_start + log_info.update( + { + "eval/reward": 
test_reward, + "eval/time": eval_time, + } + ) + actor.train() + + if logger: + for key, value in log_info.items(): + logger.log_scalar(key, value, collected_frames) + + collector.update_policy_weights_() + sampling_start = time.time() + accumulator = [] + + collector.shutdown() + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") + + +if __name__ == "__main__": + main() diff --git a/examples/impala/utils.py b/examples/impala/utils.py new file mode 100644 index 00000000000..2983f8a0193 --- /dev/null +++ b/examples/impala/utils.py @@ -0,0 +1,182 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn +import torch.optim +from tensordict.nn import TensorDictModule +from torchrl.data import CompositeSpec +from torchrl.envs import ( + CatFrames, + DoubleToFloat, + EndOfLifeTransform, + ExplorationType, + GrayScale, + GymEnv, + NoopResetEnv, + Resize, + RewardClipping, + RewardSum, + StepCounter, + ToTensorImage, + TransformedEnv, + VecNorm, +) +from torchrl.modules import ( + ActorValueOperator, + ConvNet, + MLP, + OneHotCategorical, + ProbabilisticActor, + ValueOperator, +) + + +# ==================================================================== +# Environment utils +# -------------------------------------------------------------------- + + +def make_env(env_name, device, is_test=False): + env = GymEnv( + env_name, frame_skip=4, from_pixels=True, pixels_only=False, device=device + ) + env = TransformedEnv(env) + env.append_transform(NoopResetEnv(noops=30, random=True)) + if not is_test: + env.append_transform(EndOfLifeTransform()) + env.append_transform(RewardClipping(-1, 1)) + env.append_transform(ToTensorImage(from_int=False)) + env.append_transform(GrayScale()) + env.append_transform(Resize(84, 84)) + env.append_transform(CatFrames(N=4, dim=-3)) + env.append_transform(RewardSum()) + env.append_transform(StepCounter(max_steps=4500)) + env.append_transform(DoubleToFloat()) + env.append_transform(VecNorm(in_keys=["pixels"])) + return env + + +# ==================================================================== +# Model utils +# -------------------------------------------------------------------- + + +def make_ppo_modules_pixels(proof_environment): + + # Define input shape + input_shape = proof_environment.observation_spec["pixels"].shape + + # Define distribution class and kwargs + num_outputs = proof_environment.action_spec.space.n + distribution_class = OneHotCategorical + distribution_kwargs = {} + + # Define input keys + in_keys = ["pixels"] + + # Define a shared Module and TensorDictModule (CNN + MLP) + common_cnn = ConvNet( + activation_class=torch.nn.ReLU, + num_cells=[32, 64, 64], + kernel_sizes=[8, 4, 3], + strides=[4, 2, 1], + ) + common_cnn_output = common_cnn(torch.ones(input_shape)) + common_mlp = MLP( + in_features=common_cnn_output.shape[-1], + activation_class=torch.nn.ReLU, + activate_last_layer=True, + out_features=512, + num_cells=[], + ) + common_mlp_output = common_mlp(common_cnn_output) + + # Define shared net as TensorDictModule + common_module = TensorDictModule( + module=torch.nn.Sequential(common_cnn, common_mlp), + in_keys=in_keys, + out_keys=["common_features"], + ) + + # Define on head for the policy + policy_net = MLP( + in_features=common_mlp_output.shape[-1], + out_features=num_outputs, + activation_class=torch.nn.ReLU, + 
num_cells=[], + ) + policy_module = TensorDictModule( + module=policy_net, + in_keys=["common_features"], + out_keys=["logits"], + ) + + # Add probabilistic sampling of the actions + policy_module = ProbabilisticActor( + policy_module, + in_keys=["logits"], + spec=CompositeSpec(action=proof_environment.action_spec), + distribution_class=distribution_class, + distribution_kwargs=distribution_kwargs, + return_log_prob=True, + default_interaction_type=ExplorationType.RANDOM, + ) + + # Define another head for the value + value_net = MLP( + activation_class=torch.nn.ReLU, + in_features=common_mlp_output.shape[-1], + out_features=1, + num_cells=[], + ) + value_module = ValueOperator( + value_net, + in_keys=["common_features"], + ) + + return common_module, policy_module, value_module + + +def make_ppo_models(env_name): + + proof_environment = make_env(env_name, device="cpu") + common_module, policy_module, value_module = make_ppo_modules_pixels( + proof_environment + ) + + # Wrap modules in a single ActorCritic operator + actor_critic = ActorValueOperator( + common_operator=common_module, + policy_operator=policy_module, + value_operator=value_module, + ) + + actor = actor_critic.get_policy_operator() + critic = actor_critic.get_value_operator() + + del proof_environment + + return actor, critic + + +# ==================================================================== +# Evaluation utils +# -------------------------------------------------------------------- + + +def eval_model(actor, test_env, num_episodes=3): + test_rewards = torch.zeros(num_episodes, dtype=torch.float32) + for i in range(num_episodes): + td_test = test_env.rollout( + policy=actor, + auto_reset=True, + auto_cast_to_device=True, + break_when_any_done=True, + max_steps=10_000_000, + ) + reward = td_test["next", "episode_reward"][td_test["next", "done"]] + test_rewards[i] = reward.sum() + del td_test + return test_rewards.mean() diff --git a/test/test_cost.py b/test/test_cost.py index eddf1dfc3bf..35297c3a1e6 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -130,6 +130,7 @@ GAE, TD1Estimator, TDLambdaEstimator, + VTrace, ) from torchrl.objectives.value.functional import ( _transpose_time, @@ -140,6 +141,7 @@ vec_generalized_advantage_estimate, vec_td1_advantage_estimate, vec_td_lambda_advantage_estimate, + vtrace_advantage_estimate, ) from torchrl.objectives.value.utils import ( _custom_conv1d, @@ -437,7 +439,7 @@ def test_dqn(self, delay_value, device, action_spec_type, td_est): action_spec_type=action_spec_type, device=device ) loss_fn = DQNLoss(actor, loss_function="l2", delay_value=delay_value) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -915,7 +917,7 @@ def test_qmixer(self, delay_value, device, action_spec_type, td_est): action_spec_type=action_spec_type, device=device ) loss_fn = QMixerLoss(actor, mixer, loss_function="l2", delay_value=delay_value) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -1400,7 +1402,7 @@ def test_ddpg(self, delay_actor, delay_value, device, td_est): delay_actor=delay_actor, delay_value=delay_value, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -2009,7 +2011,7 @@ def 
test_td3( delay_actor=delay_actor, delay_qvalue=delay_qvalue, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -2696,7 +2698,7 @@ def test_sac( **kwargs, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -3481,7 +3483,7 @@ def test_discrete_sac( loss_function="l2", **kwargs, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -4091,7 +4093,7 @@ def test_redq(self, delay_qvalue, num_qvalue, device, td_est): loss_function="l2", delay_qvalue=delay_qvalue, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -4458,7 +4460,7 @@ def test_redq_batched(self, delay_qvalue, num_qvalue, device, td_est): loss_function="l2", delay_qvalue=delay_qvalue, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -4475,7 +4477,7 @@ def test_redq_batched(self, delay_qvalue, num_qvalue, device, td_est): loss_function="l2", delay_qvalue=delay_qvalue, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn_deprec.make_value_estimator(td_est) return @@ -4895,7 +4897,7 @@ def test_cql( **kwargs, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -5305,7 +5307,7 @@ def test_dcql(self, delay_value, device, action_spec_type, td_est): action_spec_type=action_spec_type, device=device ) loss_fn = DiscreteCQLLoss(actor, loss_function="l2", delay_value=delay_value) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -5536,6 +5538,7 @@ def _create_mock_actor( action_dim=4, device="cpu", observation_key="observation", + sample_log_prob_key="sample_log_prob", ): # Actor action_spec = BoundedTensorSpec( @@ -5550,6 +5553,8 @@ def _create_mock_actor( distribution_class=TanhNormal, in_keys=["loc", "scale"], spec=action_spec, + return_log_prob=True, + log_prob_key=sample_log_prob_key, ) return actor.to(device) @@ -5587,6 +5592,7 @@ def _create_mock_actor_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu distribution_class=TanhNormal, in_keys=["loc", "scale"], spec=action_spec, + return_log_prob=True, ) module = nn.Sequential(base_layer, nn.Linear(5, 1)) value = ValueOperator( @@ -5613,6 +5619,7 @@ def _create_mock_actor_value_shared( distribution_class=TanhNormal, in_keys=["loc", "scale"], spec=action_spec, + return_log_prob=True, ) module = nn.Linear(5, 1) value_head = ValueOperator( @@ -5720,7 +5727,7 @@ def _create_seq_mock_data_ppo( @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("gradient_mode", (True, False)) - @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", 
"td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) def test_ppo(self, loss_class, device, gradient_mode, advantage, td_est): @@ -5733,6 +5740,13 @@ def test_ppo(self, loss_class, device, gradient_mode, advantage, td_est): advantage = GAE( gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + differentiable=gradient_mode, + ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, value_network=value, differentiable=gradient_mode @@ -5799,7 +5813,7 @@ def test_ppo_state_dict(self, loss_class, device, gradient_mode): loss_fn2.load_state_dict(sd) @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) - @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) def test_ppo_shared(self, loss_class, device, advantage): torch.manual_seed(self.seed) @@ -5812,6 +5826,12 @@ def test_ppo_shared(self, loss_class, device, advantage): lmbda=0.9, value_network=value, ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, @@ -5873,6 +5893,7 @@ def test_ppo_shared(self, loss_class, device, advantage): "advantage", ( "gae", + "vtrace", "td", "td_lambda", ), @@ -5892,6 +5913,12 @@ def test_ppo_shared_seq(self, loss_class, device, advantage, separate_losses): lmbda=0.9, value_network=value, ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, @@ -5943,7 +5970,7 @@ def test_ppo_shared_seq(self, loss_class, device, advantage, separate_losses): ) @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("gradient_mode", (True, False)) - @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): if pack_version.parse(torch.__version__) > pack_version.parse("1.14"): @@ -5957,6 +5984,13 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): advantage = GAE( gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + differentiable=gradient_mode, + ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, value_network=value, differentiable=gradient_mode @@ -6019,6 +6053,7 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.GAE, + ValueEstimators.VTrace, ValueEstimators.TDLambda, ], ) @@ -6060,7 +6095,7 @@ def test_ppo_tensordict_keys(self, loss_class, td_est): self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) - @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) 
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): """Test PPO loss module with non-default tensordict keys.""" @@ -6078,7 +6113,9 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): sample_log_prob_key=tensor_keys["sample_log_prob"], action_key=tensor_keys["action"], ) - actor = self._create_mock_actor() + actor = self._create_mock_actor( + sample_log_prob_key=tensor_keys["sample_log_prob"] + ) value = self._create_mock_value(out_keys=[tensor_keys["value"]]) if advantage == "gae": @@ -6088,6 +6125,13 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): value_network=value, differentiable=gradient_mode, ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + differentiable=gradient_mode, + ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, @@ -6181,7 +6225,9 @@ def test_ppo_notensordict( terminated_key=terminated_key, ) - actor = self._create_mock_actor(observation_key=observation_key) + actor = self._create_mock_actor( + observation_key=observation_key, sample_log_prob_key=sample_log_prob_key + ) value = self._create_mock_value(observation_key=observation_key) loss = loss_class(actor=actor, critic=value) @@ -6240,6 +6286,7 @@ def _create_mock_actor( action_dim=4, device="cpu", observation_key="observation", + sample_log_prob_key="sample_log_prob", ): # Actor action_spec = BoundedTensorSpec( @@ -6254,6 +6301,8 @@ def _create_mock_actor( in_keys=["loc", "scale"], spec=action_spec, distribution_class=TanhNormal, + return_log_prob=True, + log_prob_key=sample_log_prob_key, ) return actor.to(device) @@ -6344,6 +6393,7 @@ def _create_seq_mock_data_a2c( reward_key="reward", done_key="done", terminated_key="terminated", + sample_log_prob_key="sample_log_prob", ): # create a tensordict total_obs = torch.randn(batch, T + 1, obs_dim, device=device) @@ -6373,7 +6423,7 @@ def _create_seq_mock_data_a2c( }, "collector": {"mask": mask}, action_key: action.masked_fill_(~mask.unsqueeze(-1), 0.0), - "sample_log_prob": torch.randn_like(action[..., 1]).masked_fill_( + sample_log_prob_key: torch.randn_like(action[..., 1]).masked_fill_( ~mask, 0.0 ) / 10, @@ -6386,7 +6436,7 @@ def _create_seq_mock_data_a2c( return td @pytest.mark.parametrize("gradient_mode", (True, False)) - @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) def test_a2c(self, device, gradient_mode, advantage, td_est): @@ -6399,6 +6449,13 @@ def test_a2c(self, device, gradient_mode, advantage, td_est): advantage = GAE( gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + differentiable=gradient_mode, + ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, value_network=value, differentiable=gradient_mode @@ -6523,7 +6580,7 @@ def test_a2c_separate_losses(self, separate_losses): not _has_functorch, reason=f"functorch not found, {FUNCTORCH_ERR}" ) @pytest.mark.parametrize("gradient_mode", (True, False)) - @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) 
@pytest.mark.parametrize("device", get_default_devices()) def test_a2c_diff(self, device, gradient_mode, advantage): if pack_version.parse(torch.__version__) > pack_version.parse("1.14"): @@ -6541,6 +6598,13 @@ def test_a2c_diff(self, device, gradient_mode, advantage): advantage = TD1Estimator( gamma=0.9, value_network=value, differentiable=gradient_mode ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + differentiable=gradient_mode, + ) elif advantage == "td_lambda": advantage = TDLambdaEstimator( gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode @@ -6590,6 +6654,7 @@ def test_a2c_diff(self, device, gradient_mode, advantage): ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.GAE, + ValueEstimators.VTrace, ValueEstimators.TDLambda, ], ) @@ -6607,6 +6672,7 @@ def test_a2c_tensordict_keys(self, td_est): "reward": "reward", "done": "done", "terminated": "terminated", + "sample_log_prob": "sample_log_prob", } self.tensordict_keys_test( @@ -6629,8 +6695,16 @@ def test_a2c_tensordict_keys(self, td_est): } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) + @pytest.mark.parametrize( + "td_est", + [ + ValueEstimators.GAE, + ValueEstimators.VTrace, + ], + ) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", None)) @pytest.mark.parametrize("device", get_default_devices()) - def test_a2c_tensordict_keys_run(self, device): + def test_a2c_tensordict_keys_run(self, device, advantage, td_est): """Test A2C loss module with non-default tensordict keys.""" torch.manual_seed(self.seed) gradient_mode = True @@ -6639,6 +6713,7 @@ def test_a2c_tensordict_keys_run(self, device): value_key = "state_value_test" action_key = "action_test" reward_key = "reward_test" + sample_log_prob_key = "sample_log_prob_test" done_key = ("done", "test") terminated_key = ("terminated", "test") @@ -6648,24 +6723,29 @@ def test_a2c_tensordict_keys_run(self, device): reward_key=reward_key, done_key=done_key, terminated_key=terminated_key, + sample_log_prob_key=sample_log_prob_key, ) - actor = self._create_mock_actor(device=device) - value = self._create_mock_value(device=device, out_keys=[value_key]) - advantage = GAE( - gamma=0.9, - lmbda=0.9, - value_network=value, - differentiable=gradient_mode, - ) - advantage.set_keys( - advantage=advantage_key, - value_target=value_target_key, - value=value_key, - reward=reward_key, - done=done_key, - terminated=terminated_key, + actor = self._create_mock_actor( + device=device, sample_log_prob_key=sample_log_prob_key ) + value = self._create_mock_value(device=device, out_keys=[value_key]) + if advantage == "gae": + advantage = GAE( + gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode + ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + differentiable=gradient_mode, + ) + elif advantage is None: + pass + else: + raise NotImplementedError + loss_fn = A2CLoss(actor, value, loss_critic_type="l2") loss_fn.set_keys( advantage=advantage_key, @@ -6675,9 +6755,23 @@ def test_a2c_tensordict_keys_run(self, device): reward=reward_key, done=done_key, terminated=done_key, + sample_log_prob=sample_log_prob_key, ) - advantage(td) + if advantage is not None: + advantage.set_keys( + advantage=advantage_key, + value_target=value_target_key, + value=value_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, + sample_log_prob=sample_log_prob_key, + ) + advantage(td) + else: + if td_est is 
not None: + loss_fn.make_value_estimator(td_est) loss = loss_fn(td) loss_critic = loss["loss_critic"] @@ -6775,7 +6869,16 @@ class TestReinforce(LossModuleTestBase): @pytest.mark.parametrize("delay_value", [True, False]) @pytest.mark.parametrize("gradient_mode", [True, False]) @pytest.mark.parametrize("advantage", ["gae", "td", "td_lambda", None]) - @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) + @pytest.mark.parametrize( + "td_est", + [ + ValueEstimators.TD1, + ValueEstimators.TD0, + ValueEstimators.GAE, + ValueEstimators.TDLambda, + None, + ], + ) def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est): n_obs = 3 n_act = 5 @@ -7493,7 +7596,7 @@ def test_dreamer_actor(self, device, imagination_horizon, discount_loss, td_est) imagination_horizon=imagination_horizon, discount_loss=discount_loss, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_module.make_value_estimator(td_est) return @@ -8235,7 +8338,7 @@ def test_iql( expectile=expectile, loss_function="l2", ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -9596,6 +9699,113 @@ def test_gae_multidim( torch.testing.assert_close(r1, r3, rtol=1e-4, atol=1e-4) torch.testing.assert_close(r1, r2, rtol=1e-4, atol=1e-4) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("gamma", [0.99, 0.5, 0.1]) + @pytest.mark.parametrize("N", [(1,), (3,), (7, 3)]) + @pytest.mark.parametrize("T", [200, 5, 3]) + @pytest.mark.parametrize("dtype", [torch.float, torch.double]) + @pytest.mark.parametrize("has_done", [False, True]) + def test_vtrace(self, device, gamma, N, T, dtype, has_done): + torch.manual_seed(0) + + done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = done.clone() + if has_done: + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated + reward = torch.randn(*N, T, 1, device=device, dtype=dtype) + state_value = torch.randn(*N, T, 1, device=device, dtype=dtype) + next_state_value = torch.randn(*N, T, 1, device=device, dtype=dtype) + log_pi = torch.log(torch.rand(*N, T, 1, device=device, dtype=dtype)) + log_mu = torch.log(torch.rand(*N, T, 1, device=device, dtype=dtype)) + + _, value_target = vtrace_advantage_estimate( + gamma, + log_pi, + log_mu, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + ) + + assert not torch.isnan(value_target).any() + assert not torch.isinf(value_target).any() + + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("gamma", [0.99, 0.5, 0.1]) + @pytest.mark.parametrize("N", [(3,), (7, 3)]) + @pytest.mark.parametrize("T", [100, 3]) + @pytest.mark.parametrize("dtype", [torch.float, torch.double]) + @pytest.mark.parametrize("feature_dim", [[5], [2, 5]]) + @pytest.mark.parametrize("has_done", [True, False]) + def test_vtrace_multidim(self, device, gamma, N, T, dtype, has_done, feature_dim): + D = feature_dim + time_dim = -1 - len(D) + + torch.manual_seed(0) + + done = torch.zeros(*N, T, *D, device=device, dtype=torch.bool) + terminated = done.clone() + if has_done: + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated + reward = torch.randn(*N, T, *D, device=device, dtype=dtype) + state_value = torch.randn(*N, T, *D, device=device, dtype=dtype) + next_state_value = 
torch.randn(*N, T, *D, device=device, dtype=dtype) + log_pi = torch.log(torch.rand(*N, T, *D, device=device, dtype=dtype)) + log_mu = torch.log(torch.rand(*N, T, *D, device=device, dtype=dtype)) + + r1 = vtrace_advantage_estimate( + gamma, + log_pi, + log_mu, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + time_dim=time_dim, + ) + if len(D) == 2: + r2 = [ + vtrace_advantage_estimate( + gamma, + log_pi[..., i : i + 1, j], + log_mu[..., i : i + 1, j], + state_value[..., i : i + 1, j], + next_state_value[..., i : i + 1, j], + reward[..., i : i + 1, j], + terminated=terminated[..., i : i + 1, j], + done=done[..., i : i + 1, j], + time_dim=-2, + ) + for i in range(D[0]) + for j in range(D[1]) + ] + else: + r2 = [ + vtrace_advantage_estimate( + gamma, + log_pi[..., i : i + 1], + log_mu[..., i : i + 1], + state_value[..., i : i + 1], + next_state_value[..., i : i + 1], + reward[..., i : i + 1], + done=done[..., i : i + 1], + terminated=terminated[..., i : i + 1], + time_dim=-2, + ) + for i in range(D[0]) + ] + + list2 = list(zip(*r2)) + r2 = [torch.cat(list2[0], -1), torch.cat(list2[1], -1)] + if len(D) == 2: + r2 = [r2[0].unflatten(-1, D), r2[1].unflatten(-1, D)] + torch.testing.assert_close(r1, r2, rtol=1e-4, atol=1e-4) + @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("gamma", [0.5, 0.99, 0.1]) @pytest.mark.parametrize("lmbda", [0.1, 0.5, 0.99]) @@ -10530,6 +10740,7 @@ class TestAdv: [GAE, {"lmbda": 0.95}], [TD1Estimator, {}], [TDLambdaEstimator, {"lmbda": 0.95}], + [VTrace, {}], ], ) def test_dispatch( @@ -10540,18 +10751,46 @@ def test_dispatch( value_net = TensorDictModule( nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] ) - module = adv( - gamma=0.98, - value_network=value_net, - differentiable=False, - **kwargs, - ) - kwargs = { - "obs": torch.randn(1, 10, 3), - "next_reward": torch.randn(1, 10, 1, requires_grad=True), - "next_done": torch.zeros(1, 10, 1, dtype=torch.bool), - "next_obs": torch.randn(1, 10, 3), - } + if adv is VTrace: + actor_net = TensorDictModule( + nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"] + ) + actor_net = ProbabilisticActor( + module=actor_net, + in_keys=["logits"], + out_keys=["action"], + distribution_class=OneHotCategorical, + return_log_prob=True, + ) + module = adv( + gamma=0.98, + actor_network=actor_net, + value_network=value_net, + differentiable=False, + **kwargs, + ) + kwargs = { + "obs": torch.randn(1, 10, 3), + "sample_log_prob": torch.log(torch.rand(1, 10, 1)), + "next_reward": torch.randn(1, 10, 1, requires_grad=True), + "next_done": torch.zeros(1, 10, 1, dtype=torch.bool), + "next_terminated": torch.zeros(1, 10, 1, dtype=torch.bool), + "next_obs": torch.randn(1, 10, 3), + } + else: + module = adv( + gamma=0.98, + value_network=value_net, + differentiable=False, + **kwargs, + ) + kwargs = { + "obs": torch.randn(1, 10, 3), + "next_reward": torch.randn(1, 10, 1, requires_grad=True), + "next_done": torch.zeros(1, 10, 1, dtype=torch.bool), + "next_terminated": torch.zeros(1, 10, 1, dtype=torch.bool), + "next_obs": torch.randn(1, 10, 3), + } advantage, value_target = module(**kwargs) assert advantage.shape == torch.Size([1, 10, 1]) assert value_target.shape == torch.Size([1, 10, 1]) @@ -10562,6 +10801,7 @@ def test_dispatch( [GAE, {"lmbda": 0.95}], [TD1Estimator, {}], [TDLambdaEstimator, {"lmbda": 0.95}], + [VTrace, {}], ], ) def test_diff_reward( @@ -10572,23 +10812,55 @@ def test_diff_reward( value_net = TensorDictModule( nn.Linear(3, 1), in_keys=["obs"], 
out_keys=["state_value"] ) - module = adv( - gamma=0.98, - value_network=value_net, - differentiable=True, - **kwargs, - ) - td = TensorDict( - { - "obs": torch.randn(1, 10, 3), - "next": { + if adv is VTrace: + actor_net = TensorDictModule( + nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"] + ) + actor_net = ProbabilisticActor( + module=actor_net, + in_keys=["logits"], + out_keys=["action"], + distribution_class=OneHotCategorical, + return_log_prob=True, + ) + module = adv( + gamma=0.98, + actor_network=actor_net, + value_network=value_net, + differentiable=True, + **kwargs, + ) + td = TensorDict( + { "obs": torch.randn(1, 10, 3), - "reward": torch.randn(1, 10, 1, requires_grad=True), - "done": torch.zeros(1, 10, 1, dtype=torch.bool), + "sample_log_prob": torch.log(torch.rand(1, 10, 1)), + "next": { + "obs": torch.randn(1, 10, 3), + "reward": torch.randn(1, 10, 1, requires_grad=True), + "done": torch.zeros(1, 10, 1, dtype=torch.bool), + "terminated": torch.zeros(1, 10, 1, dtype=torch.bool), + }, }, - }, - [1, 10], - ) + [1, 10], + ) + else: + module = adv( + gamma=0.98, + value_network=value_net, + differentiable=True, + **kwargs, + ) + td = TensorDict( + { + "obs": torch.randn(1, 10, 3), + "next": { + "obs": torch.randn(1, 10, 3), + "reward": torch.randn(1, 10, 1, requires_grad=True), + "done": torch.zeros(1, 10, 1, dtype=torch.bool), + }, + }, + [1, 10], + ) td = module(td.clone(False)) # check that the advantage can't backprop to the value params td["advantage"].sum().backward() @@ -10603,6 +10875,7 @@ def test_diff_reward( [GAE, {"lmbda": 0.95}], [TD1Estimator, {}], [TDLambdaEstimator, {"lmbda": 0.95}], + [VTrace, {}], ], ) @pytest.mark.parametrize("shifted", [True, False]) @@ -10610,25 +10883,60 @@ def test_non_differentiable(self, adv, shifted, kwargs): value_net = TensorDictModule( nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] ) - module = adv( - gamma=0.98, - value_network=value_net, - differentiable=False, - shifted=shifted, - **kwargs, - ) - td = TensorDict( - { - "obs": torch.randn(1, 10, 3), - "next": { + + if adv is VTrace: + actor_net = TensorDictModule( + nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"] + ) + actor_net = ProbabilisticActor( + module=actor_net, + in_keys=["logits"], + out_keys=["action"], + distribution_class=OneHotCategorical, + return_log_prob=True, + ) + module = adv( + gamma=0.98, + actor_network=actor_net, + value_network=value_net, + differentiable=False, + shifted=shifted, + **kwargs, + ) + td = TensorDict( + { "obs": torch.randn(1, 10, 3), - "reward": torch.randn(1, 10, 1, requires_grad=True), - "done": torch.zeros(1, 10, 1, dtype=torch.bool), + "sample_log_prob": torch.log(torch.rand(1, 10, 1)), + "next": { + "obs": torch.randn(1, 10, 3), + "reward": torch.randn(1, 10, 1, requires_grad=True), + "done": torch.zeros(1, 10, 1, dtype=torch.bool), + "terminated": torch.zeros(1, 10, 1, dtype=torch.bool), + }, }, - }, - [1, 10], - names=[None, "time"], - ) + [1, 10], + names=[None, "time"], + ) + else: + module = adv( + gamma=0.98, + value_network=value_net, + differentiable=False, + shifted=shifted, + **kwargs, + ) + td = TensorDict( + { + "obs": torch.randn(1, 10, 3), + "next": { + "obs": torch.randn(1, 10, 3), + "reward": torch.randn(1, 10, 1, requires_grad=True), + "done": torch.zeros(1, 10, 1, dtype=torch.bool), + }, + }, + [1, 10], + names=[None, "time"], + ) td = module(td.clone(False)) assert td["advantage"].is_leaf @@ -10638,6 +10946,7 @@ def test_non_differentiable(self, adv, shifted, kwargs): [GAE, {"lmbda": 0.95}], 
[TD1Estimator, {}], [TDLambdaEstimator, {"lmbda": 0.95}], + [VTrace, {}], ], ) @pytest.mark.parametrize("has_value_net", [True, False]) @@ -10660,28 +10969,65 @@ def test_skip_existing( else: value_net = None - module = adv( - gamma=0.98, - value_network=value_net, - differentiable=True, - shifted=shifted, - skip_existing=skip_existing, - **kwargs, - ) - td = TensorDict( - { - "obs": torch.randn(1, 10, 3), - "state_value": torch.ones(1, 10, 1), - "next": { + if adv is VTrace: + actor_net = TensorDictModule( + nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"] + ) + actor_net = ProbabilisticActor( + module=actor_net, + in_keys=["logits"], + out_keys=["action"], + distribution_class=OneHotCategorical, + return_log_prob=True, + ) + module = adv( + gamma=0.98, + actor_network=actor_net, + value_network=value_net, + differentiable=True, + shifted=shifted, + skip_existing=skip_existing, + **kwargs, + ) + td = TensorDict( + { "obs": torch.randn(1, 10, 3), + "sample_log_prob": torch.log(torch.rand(1, 10, 1)), "state_value": torch.ones(1, 10, 1), - "reward": torch.randn(1, 10, 1, requires_grad=True), - "done": torch.zeros(1, 10, 1, dtype=torch.bool), + "next": { + "obs": torch.randn(1, 10, 3), + "state_value": torch.ones(1, 10, 1), + "reward": torch.randn(1, 10, 1, requires_grad=True), + "done": torch.zeros(1, 10, 1, dtype=torch.bool), + "terminated": torch.zeros(1, 10, 1, dtype=torch.bool), + }, }, - }, - [1, 10], - names=[None, "time"], - ) + [1, 10], + names=[None, "time"], + ) + else: + module = adv( + gamma=0.98, + value_network=value_net, + differentiable=True, + shifted=shifted, + skip_existing=skip_existing, + **kwargs, + ) + td = TensorDict( + { + "obs": torch.randn(1, 10, 3), + "state_value": torch.ones(1, 10, 1), + "next": { + "obs": torch.randn(1, 10, 3), + "state_value": torch.ones(1, 10, 1), + "reward": torch.randn(1, 10, 1, requires_grad=True), + "done": torch.zeros(1, 10, 1, dtype=torch.bool), + }, + }, + [1, 10], + names=[None, "time"], + ) td = module(td.clone(False)) if has_value_net and not skip_existing: exp_val = 0 @@ -10699,15 +11045,34 @@ def test_skip_existing( [GAE, {"lmbda": 0.95}], [TD1Estimator, {}], [TDLambdaEstimator, {"lmbda": 0.95}], + [VTrace, {}], ], ) def test_set_keys(self, value, adv, kwargs): value_net = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=[value]) - module = adv( - gamma=0.98, - value_network=value_net, - **kwargs, - ) + if adv is VTrace: + actor_net = TensorDictModule( + nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"] + ) + actor_net = ProbabilisticActor( + module=actor_net, + in_keys=["logits"], + out_keys=["action"], + distribution_class=OneHotCategorical, + return_log_prob=True, + ) + module = adv( + gamma=0.98, + actor_network=actor_net, + value_network=value_net, + **kwargs, + ) + else: + module = adv( + gamma=0.98, + value_network=value_net, + **kwargs, + ) module.set_keys(value=value) assert module.tensor_keys.value == value @@ -10721,6 +11086,7 @@ def test_set_keys(self, value, adv, kwargs): [GAE, {"lmbda": 0.95}], [TD1Estimator, {}], [TDLambdaEstimator, {"lmbda": 0.95}], + [VTrace, {}], ], ) def test_set_deprecated_keys(self, adv, kwargs): @@ -10729,14 +11095,36 @@ def test_set_deprecated_keys(self, adv, kwargs): ) with pytest.warns(DeprecationWarning): - module = adv( - gamma=0.98, - value_network=value_net, - value_key="test_value", - advantage_key="advantage_test", - value_target_key="value_target_test", - **kwargs, - ) + + if adv is VTrace: + actor_net = TensorDictModule( + nn.Linear(3, 4), in_keys=["obs"], 
out_keys=["logits"] + ) + actor_net = ProbabilisticActor( + module=actor_net, + in_keys=["logits"], + out_keys=["action"], + distribution_class=OneHotCategorical, + return_log_prob=True, + ) + module = adv( + gamma=0.98, + actor_network=actor_net, + value_network=value_net, + value_key="test_value", + advantage_key="advantage_test", + value_target_key="value_target_test", + **kwargs, + ) + else: + module = adv( + gamma=0.98, + value_network=value_net, + value_key="test_value", + advantage_key="advantage_test", + value_target_key="value_target_test", + **kwargs, + ) assert module.tensor_keys.value == "test_value" assert module.tensor_keys.advantage == "advantage_test" assert module.tensor_keys.value_target == "value_target_test" diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index bb7b9014f0d..92955d4cab3 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -3,11 +3,17 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import warnings +from copy import deepcopy from dataclasses import dataclass from typing import Tuple import torch -from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule +from tensordict.nn import ( + dispatch, + ProbabilisticTensorDictSequential, + repopulate_module, + TensorDictModule, +) from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey from torch import distributions as d @@ -20,7 +26,13 @@ distance_loss, ValueEstimators, ) -from torchrl.objectives.value import GAE, TD0Estimator, TD1Estimator, TDLambdaEstimator +from torchrl.objectives.value import ( + GAE, + TD0Estimator, + TD1Estimator, + TDLambdaEstimator, + VTrace, +) class A2CLoss(LossModule): @@ -202,6 +214,7 @@ class _AcceptedKeys: reward: NestedKey = "reward" done: NestedKey = "done" terminated: NestedKey = "terminated" + sample_log_prob: NestedKey = "sample_log_prob" default_keys = _AcceptedKeys() default_value_estimator: ValueEstimators = ValueEstimators.GAE @@ -389,6 +402,14 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams self._value_estimator = GAE(value_network=self.critic, **hp) elif value_type == ValueEstimators.TDLambda: self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp) + elif value_type == ValueEstimators.VTrace: + # VTrace currently does not support functional call on the actor + actor_with_params = repopulate_module( + deepcopy(self.actor), self.actor_params + ) + self._value_estimator = VTrace( + value_network=self.critic, actor_network=actor_with_params, **hp + ) else: raise NotImplementedError(f"Unknown value type {value_type}") @@ -399,5 +420,6 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, "terminated": self.tensor_keys.terminated, + "sample_log_prob": self.tensor_keys.sample_log_prob, } self._value_estimator.set_keys(**tensor_keys) diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index bdccbda3808..37c5e820d23 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -138,7 +138,7 @@ def set_keys(self, **kwargs) -> None: """ for key, value in kwargs.items(): if key not in self._AcceptedKeys.__dict__: - raise ValueError(f"{key} it not an accepted tensordict key") + raise ValueError(f"{key} is not an accepted tensordict key") if value is not None: setattr(self.tensor_keys, key, value) else: @@ 
-447,6 +447,10 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams raise NotImplementedError( f"Value type {value_type} it not implemented for loss {type(self)}." ) + elif value_type == ValueEstimators.VTrace: + raise NotImplementedError( + f"Value type {value_type} it not implemented for loss {type(self)}." + ) elif value_type == ValueEstimators.TDLambda: raise NotImplementedError( f"Value type {value_type} it not implemented for loss {type(self)}." diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index e576ca33c1c..2a2cc2fdb6e 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -4,11 +4,17 @@ # LICENSE file in the root directory of this source tree. import math import warnings +from copy import deepcopy from dataclasses import dataclass from typing import Tuple import torch -from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule +from tensordict.nn import ( + dispatch, + ProbabilisticTensorDictSequential, + repopulate_module, + TensorDictModule, +) from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey from torch import distributions as d @@ -22,7 +28,7 @@ ) from .common import LossModule -from .value import GAE, TD0Estimator, TD1Estimator, TDLambdaEstimator +from .value import GAE, TD0Estimator, TD1Estimator, TDLambdaEstimator, VTrace class PPOLoss(LossModule): @@ -469,6 +475,14 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams self._value_estimator = GAE(value_network=self.critic, **hp) elif value_type == ValueEstimators.TDLambda: self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp) + elif value_type == ValueEstimators.VTrace: + # VTrace currently does not support functional call on the actor + actor_with_params = repopulate_module( + deepcopy(self.actor), self.actor_params + ) + self._value_estimator = VTrace( + value_network=self.critic, actor_network=actor_with_params, **hp + ) else: raise NotImplementedError(f"Unknown value type {value_type}") @@ -479,6 +493,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, "terminated": self.tensor_keys.terminated, + "sample_log_prob": self.tensor_keys.sample_log_prob, } self._value_estimator.set_keys(**tensor_keys) diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index 93910f1eebf..1ae9c1e8252 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -3,12 +3,18 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
import warnings +from copy import deepcopy from dataclasses import dataclass from typing import Optional import torch -from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule +from tensordict.nn import ( + dispatch, + ProbabilisticTensorDictSequential, + repopulate_module, + TensorDictModule, +) from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey from torchrl.objectives.common import LossModule @@ -18,7 +24,13 @@ distance_loss, ValueEstimators, ) -from torchrl.objectives.value import GAE, TD0Estimator, TD1Estimator, TDLambdaEstimator +from torchrl.objectives.value import ( + GAE, + TD0Estimator, + TD1Estimator, + TDLambdaEstimator, + VTrace, +) class ReinforceLoss(LossModule): @@ -340,6 +352,14 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams self._value_estimator = GAE(value_network=self.critic, **hp) elif value_type == ValueEstimators.TDLambda: self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp) + elif value_type == ValueEstimators.VTrace: + # VTrace currently does not support functional call on the actor + actor_with_params = repopulate_module( + deepcopy(self.actor), self.actor_params + ) + self._value_estimator = VTrace( + value_network=self.critic, actor_network=actor_with_params, **hp + ) else: raise NotImplementedError(f"Unknown value type {value_type}") @@ -350,5 +370,6 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, "terminated": self.tensor_keys.terminated, + "sample_log_prob": self.tensor_keys.sample_log_prob, } self._value_estimator.set_keys(**tensor_keys) diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index bc678ed0154..b8ec5ec7c32 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -39,6 +39,7 @@ class ValueEstimators(Enum): TD1 = "TD(1) (infinity-step return)" TDLambda = "TD(lambda)" GAE = "Generalized advantage estimate" + VTrace = "V-trace" def default_value_kwargs(value_type: ValueEstimators): @@ -61,6 +62,8 @@ def default_value_kwargs(value_type: ValueEstimators): return {"gamma": 0.99, "lmbda": 0.95, "differentiable": True} elif value_type == ValueEstimators.TDLambda: return {"gamma": 0.99, "lmbda": 0.95, "differentiable": True} + elif value_type == ValueEstimators.VTrace: + return {"gamma": 0.99, "differentiable": True} else: raise NotImplementedError(f"Unknown value type {value_type}.") diff --git a/torchrl/objectives/value/__init__.py b/torchrl/objectives/value/__init__.py index 11ae2e6d9e2..51496986153 100644 --- a/torchrl/objectives/value/__init__.py +++ b/torchrl/objectives/value/__init__.py @@ -12,4 +12,5 @@ TDLambdaEstimate, TDLambdaEstimator, ValueEstimatorBase, + VTrace, ) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 4d3a25279a1..42ba404c05d 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+ + import abc import functools import warnings @@ -32,8 +34,10 @@ vec_generalized_advantage_estimate, vec_td1_return_estimate, vec_td_lambda_return_estimate, + vtrace_advantage_estimate, ) + try: from torch import vmap except ImportError as err: @@ -147,6 +151,17 @@ def _call_value_nets( return value, value_ +def _call_actor_net( + actor_net: TensorDictModuleBase, + data: TensorDictBase, + params: TensorDictBase, + log_prob_key: NestedKey, +): + # TODO: extend to handle time dimension (and vmap?) + log_pi = actor_net(data.select(actor_net.in_keys)).get(log_prob_key) + return log_pi + + class ValueEstimatorBase(TensorDictModuleBase): """An abstract parent class for value function modules. @@ -179,9 +194,11 @@ class _AcceptedKeys: whether a trajectory is done. Defaults to ``"done"``. terminated (NestedKey): The key in the input TensorDict that indicates whether a trajectory is terminated. Defaults to ``"terminated"``. - steps_to_next_obs_key (NestedKey): The key in the input tensordict + steps_to_next_obs (NestedKey): The key in the input tensordict that indicates the number of steps to the next observation. Defaults to ``"steps_to_next_obs"``. + sample_log_prob (NestedKey): The key in the input tensordict that + indicates the log probability of the sampled action. Defaults to ``"sample_log_prob"``. """ advantage: NestedKey = "advantage" @@ -191,6 +208,7 @@ class _AcceptedKeys: done: NestedKey = "done" terminated: NestedKey = "terminated" steps_to_next_obs: NestedKey = "steps_to_next_obs" + sample_log_prob: NestedKey = "sample_log_prob" default_keys = _AcceptedKeys() value_network: Union[TensorDictModule, Callable] @@ -223,6 +241,10 @@ def terminated_key(self): def steps_to_next_obs_key(self): return self.tensor_keys.steps_to_next_obs + @property + def sample_log_prob_key(self): + return self.tensor_keys.sample_log_prob + @abc.abstractmethod def forward( self, @@ -341,7 +363,7 @@ def set_keys(self, **kwargs) -> None: raise ValueError("tensordict keys cannot be None") if key not in self._AcceptedKeys.__dict__: raise KeyError( - f"{key} it not an accepted tensordict key for advantages" + f"{key} is not an accepted tensordict key for advantages" ) if ( key == "value" @@ -597,7 +619,7 @@ def value_estimate( if self.average_rewards: reward = reward - reward.mean() - reward = reward / reward.std().clamp_min(1e-4) + reward = reward / reward.std().clamp_min(1e-5) tensordict.set( ("next", self.tensor_keys.reward), reward ) # we must update the rewards if they are used later in the code @@ -799,7 +821,7 @@ def value_estimate( if self.average_rewards: reward = reward - reward.mean() - reward = reward / reward.std().clamp_min(1e-4) + reward = reward / reward.std().clamp_min(1e-5) tensordict.set( ("next", self.tensor_keys.reward), reward ) # we must update the rewards if they are used later in the code @@ -1137,7 +1159,7 @@ def __init__( def forward( self, tensordict: TensorDictBase, - *unused_args, + *, params: Optional[List[Tensor]] = None, target_params: Optional[List[Tensor]] = None, ) -> TensorDictBase: @@ -1328,6 +1350,287 @@ def value_estimate( return value_target +class VTrace(ValueEstimatorBase): + """A class wrapper around V-Trace estimate functional. + + Refer to "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures" + :ref:`here `_ for more context. + + Args: + gamma (scalar): exponential mean discount. + value_network (TensorDictModule): value operator used to retrieve the value estimates. 
+ actor_network (TensorDictModule): actor operator used to retrieve the log prob. + rho_thresh (Union[float, Tensor]): rho clipping parameter for importance weights. + Defaults to ``1.0``. + c_thresh (Union[float, Tensor]): c clipping parameter for importance weights. + Defaults to ``1.0``. + average_adv (bool): if ``True``, the resulting advantage values will be standardized. + Default is ``False``. + differentiable (bool, optional): if ``True``, gradients are propagated through + the computation of the value function. Default is ``False``. + + .. note:: + The proper way to make the function call non-differentiable is to + decorate it in a `torch.no_grad()` context manager/decorator or + pass detached parameters for functional modules. + skip_existing (bool, optional): if ``True``, the value network will skip + modules whose outputs are already present in the tensordict. + Defaults to ``None``, i.e. the value of :func:`tensordict.nn.skip_existing()` + is not affected. + advantage_key (str or tuple of str, optional): [Deprecated] the key of + the advantage entry. Defaults to ``"advantage"``. + value_target_key (str or tuple of str, optional): [Deprecated] the key + of the value target entry. Defaults to ``"value_target"``. + value_key (str or tuple of str, optional): [Deprecated] the value key to + read from the input tensordict. Defaults to ``"state_value"``. + shifted (bool, optional): if ``True``, the value and next value are + estimated with a single call to the value network. This is faster + but is only valid whenever (1) the ``"next"`` value is shifted by + only one time step (which is not the case with multi-step value + estimation, for instance) and (2) when the parameters used at time + ``t`` and ``t+1`` are identical (which is not the case when target + parameters are to be used). Defaults to ``False``. + + VTrace will return an :obj:`"advantage"` entry containing the advantage value. It will also + return a :obj:`"value_target"` entry with the V-Trace target value. + + .. note:: + As other advantage functions do, if the ``value_key`` is already present + in the input tensordict, the VTrace module will ignore the calls to the value + network (if any) and use the provided value instead.
+ + """ + + def __init__( + self, + *, + gamma: Union[float, torch.Tensor], + actor_network: TensorDictModule, + value_network: TensorDictModule, + rho_thresh: Union[float, torch.Tensor] = 1.0, + c_thresh: Union[float, torch.Tensor] = 1.0, + average_adv: bool = False, + differentiable: bool = False, + skip_existing: Optional[bool] = None, + advantage_key: Optional[NestedKey] = None, + value_target_key: Optional[NestedKey] = None, + value_key: Optional[NestedKey] = None, + shifted: bool = False, + ): + super().__init__( + shifted=shifted, + value_network=value_network, + differentiable=differentiable, + advantage_key=advantage_key, + value_target_key=value_target_key, + value_key=value_key, + skip_existing=skip_existing, + ) + try: + device = next(value_network.parameters()).device + except (AttributeError, StopIteration): + device = torch.device("cpu") + + if not isinstance(gamma, torch.Tensor): + gamma = torch.tensor(gamma, device=device) + if not isinstance(rho_thresh, torch.Tensor): + rho_thresh = torch.tensor(rho_thresh, device=device) + if not isinstance(c_thresh, torch.Tensor): + c_thresh = torch.tensor(c_thresh, device=device) + + self.register_buffer("gamma", gamma) + self.register_buffer("rho_thresh", rho_thresh) + self.register_buffer("c_thresh", c_thresh) + self.average_adv = average_adv + self.actor_network = actor_network + + if isinstance(gamma, torch.Tensor) and gamma.shape != (): + raise NotImplementedError( + "Per-value gamma is not supported yet. Gamma must be a scalar." + ) + + @property + def in_keys(self): + parent_in_keys = super().in_keys + extended_in_keys = parent_in_keys + [self.tensor_keys.sample_log_prob] + return extended_in_keys + + @_self_set_skip_existing + @_self_set_grad_enabled + @dispatch + def forward( + self, + tensordict: TensorDictBase, + *, + params: Optional[List[Tensor]] = None, + target_params: Optional[List[Tensor]] = None, + ) -> TensorDictBase: + """Computes the V-Trace correction given the data in tensordict. + + If a functional module is provided, a nested TensorDict containing the parameters + (and if relevant the target parameters) can be passed to the module. + + Args: + tensordict (TensorDictBase): A TensorDict containing the data + (an observation key, "action", "reward", "done" and "next" tensordict state + as returned by the environment) necessary to compute the value estimates and the GAE. + The data passed to this module should be structured as :obj:`[*B, T, F]` where :obj:`B` are + the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s). + params (TensorDictBase, optional): A nested TensorDict containing the params + to be passed to the functional value network module. + target_params (TensorDictBase, optional): A nested TensorDict containing the + target params to be passed to the functional value network module. + + Returns: + An updated TensorDict with an advantage and a value_error keys as defined in the constructor. + + Examples: + >>> value_net = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]) + >>> actor_net = TensorDictModule(nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"]) + >>> actor_net = ProbabilisticActor( + ... module=actor_net, + ... in_keys=["logits"], + ... out_keys=["action"], + ... distribution_class=OneHotCategorical, + ... return_log_prob=True, + ... ) + >>> module = VTrace( + ... gamma=0.98, + ... value_network=value_net, + ... actor_network=actor_net, + ... differentiable=False, + ... 
) + >>> obs, next_obs = torch.randn(2, 1, 10, 3) + >>> reward = torch.randn(1, 10, 1) + >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) + >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool) + >>> sample_log_prob = torch.randn(1, 10, 1) + >>> tensordict = TensorDict({ + ... "obs": obs, + ... "done": done, + ... "terminated": terminated, + ... "sample_log_prob": sample_log_prob, + ... "next": {"obs": next_obs, "reward": reward, "done": done, "terminated": terminated}, + ... }, batch_size=[1, 10]) + >>> _ = module(tensordict) + >>> assert "advantage" in tensordict.keys() + + The module supports non-tensordict (i.e. unpacked tensordict) inputs too: + + Examples: + >>> value_net = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]) + >>> actor_net = TensorDictModule(nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"]) + >>> actor_net = ProbabilisticActor( + ... module=actor_net, + ... in_keys=["logits"], + ... out_keys=["action"], + ... distribution_class=OneHotCategorical, + ... return_log_prob=True, + ... ) + >>> module = VTrace( + ... gamma=0.98, + ... value_network=value_net, + ... actor_network=actor_net, + ... differentiable=False, + ... ) + >>> obs, next_obs = torch.randn(2, 1, 10, 3) + >>> reward = torch.randn(1, 10, 1) + >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) + >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool) + >>> sample_log_prob = torch.randn(1, 10, 1) + >>> tensordict = TensorDict({ + ... "obs": obs, + ... "done": done, + ... "terminated": terminated, + ... "sample_log_prob": sample_log_prob, + ... "next": {"obs": next_obs, "reward": reward, "done": done, "terminated": terminated}, + ... }, batch_size=[1, 10]) + >>> advantage, value_target = module( + ... obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated, sample_log_prob=sample_log_prob + ... 
) + + """ + if tensordict.batch_dims < 1: + raise RuntimeError( + "Expected input tensordict to have at least one dimensions, got " + f"tensordict.batch_size = {tensordict.batch_size}" + ) + reward = tensordict.get(("next", self.tensor_keys.reward)) + device = reward.device + gamma = self.gamma.to(device) + steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) + if steps_to_next_obs is not None: + gamma = gamma ** steps_to_next_obs.view_as(reward) + + # Make sure we have the value and next value + if self.value_network is not None: + if params is not None: + params = params.detach() + if target_params is None: + target_params = params.clone(False) + with hold_out_net(self.value_network): + # we may still need to pass gradient, but we don't want to assign grads to + # value net params + value, next_value = _call_value_nets( + value_net=self.value_network, + data=tensordict, + params=params, + next_params=target_params, + single_call=self.shifted, + value_key=self.tensor_keys.value, + detach_next=True, + ) + else: + value = tensordict.get(self.tensor_keys.value) + next_value = tensordict.get(("next", self.tensor_keys.value)) + + # Make sure we have the log prob computed at collection time + if self.tensor_keys.sample_log_prob not in tensordict.keys(): + raise ValueError( + f"Expected {self.tensor_keys.sample_log_prob} to be in tensordict" + ) + log_mu = tensordict.get(self.tensor_keys.sample_log_prob).view_as(value) + + # Compute log prob with current policy + with hold_out_net(self.actor_network): + log_pi = _call_actor_net( + actor_net=self.actor_network, + data=tensordict, + params=None, + log_prob_key=self.tensor_keys.sample_log_prob, + ).view_as(value) + + # Compute the V-Trace correction + done = tensordict.get(("next", self.tensor_keys.done)) + terminated = tensordict.get(("next", self.tensor_keys.terminated)) + + adv, value_target = vtrace_advantage_estimate( + gamma, + log_pi, + log_mu, + value, + next_value, + reward, + done, + terminated, + rho_thresh=self.rho_thresh, + c_thresh=self.c_thresh, + time_dim=tensordict.ndim - 1, + ) + + if self.average_adv: + loc = adv.mean() + scale = adv.std().clamp_min(1e-5) + adv = adv - loc + adv = adv / scale + + tensordict.set(self.tensor_keys.advantage, adv) + tensordict.set(self.tensor_keys.value_target, value_target) + + return tensordict + + def _deprecate_class(cls, new_cls): @wraps(cls.__init__) def new_init(self, *args, **kwargs): diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index 7c33895e965..6c43af02aeb 100644 --- a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -27,6 +27,7 @@ "vec_td_lambda_return_estimate", "td_lambda_advantage_estimate", "vec_td_lambda_advantage_estimate", + "vtrace_advantage_estimate", ] from torchrl.objectives.value.utils import ( @@ -1212,6 +1213,93 @@ def vec_td_lambda_advantage_estimate( ) +######################################################################## +# V-Trace +# ----- + + +@_transpose_time +def vtrace_advantage_estimate( + gamma: float, + log_pi: torch.Tensor, + log_mu: torch.Tensor, + state_value: torch.Tensor, + next_state_value: torch.Tensor, + reward: torch.Tensor, + done: torch.Tensor, + terminated: torch.Tensor | None = None, + rho_thresh: Union[float, torch.Tensor] = 1.0, + c_thresh: Union[float, torch.Tensor] = 1.0, + time_dim: int = -2, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Computes V-Trace off-policy actor critic targets. 
+ + Refer to "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures" + https://arxiv.org/abs/1802.01561 for more context. + + Args: + gamma (scalar): exponential mean discount. + log_pi (Tensor): collection actor log probability of taking actions in the environment. + log_mu (Tensor): current actor log probability of taking actions in the environment. + state_value (Tensor): value function result with state input. + next_state_value (Tensor): value function result with next_state input. + reward (Tensor): reward of taking actions in the environment. + done (Tensor): boolean flag for end of episode. + terminated (torch.Tensor): a [B, T] boolean tensor containing the terminated states. + rho_thresh (Union[float, Tensor]): rho clipping parameter for importance weights. + c_thresh (Union[float, Tensor]): c clipping parameter for importance weights. + time_dim (int): dimension where the time is unrolled. Defaults to -2. + + All tensors (values, reward and done) must have shape + ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions. + """ + if not (next_state_value.shape == state_value.shape == reward.shape == done.shape): + raise RuntimeError(SHAPE_ERR) + + device = state_value.device + + if not isinstance(rho_thresh, torch.Tensor): + rho_thresh = torch.tensor(rho_thresh, device=device) + if not isinstance(c_thresh, torch.Tensor): + c_thresh = torch.tensor(c_thresh, device=device) + + c_thresh = c_thresh.to(device) + rho_thresh = rho_thresh.to(device) + + not_done = (~done).int() + not_terminated = not_done if terminated is None else (~terminated).int() + *batch_size, time_steps, lastdim = not_done.shape + done_discounts = gamma * not_done + terminated_discounts = gamma * not_terminated + + rho = (log_pi - log_mu).exp() + clipped_rho = rho.clamp_max(rho_thresh) + deltas = clipped_rho * ( + reward + terminated_discounts * next_state_value - state_value + ) + clipped_c = rho.clamp_max(c_thresh) + + vs_minus_v_xs = [torch.zeros_like(next_state_value[..., -1, :])] + for i in reversed(range(time_steps)): + discount_t, c_t, delta_t = ( + done_discounts[..., i, :], + clipped_c[..., i, :], + deltas[..., i, :], + ) + vs_minus_v_xs.append(delta_t + discount_t * c_t * vs_minus_v_xs[-1]) + vs_minus_v_xs = torch.stack(vs_minus_v_xs[1:], dim=time_dim) + vs_minus_v_xs = torch.flip(vs_minus_v_xs, dims=[time_dim]) + vs = vs_minus_v_xs + state_value + vs_t_plus_1 = torch.cat( + [vs[..., 1:, :], next_state_value[..., -1:, :]], dim=time_dim + ) + advantages = clipped_rho * ( + reward + terminated_discounts * vs_t_plus_1 - state_value + ) + + return advantages, vs + + ######################################################################## # Reward to go # ------------ diff --git a/torchrl/objectives/value/vtrace.py b/torchrl/objectives/value/vtrace.py deleted file mode 100644 index 43f5246502f..00000000000 --- a/torchrl/objectives/value/vtrace.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
- -import math -from typing import Tuple, Union - -import torch - - -def _c_val( - log_pi: torch.Tensor, - log_mu: torch.Tensor, - c: Union[float, torch.Tensor] = 1, -) -> torch.Tensor: - return (log_pi - log_mu).clamp_max(math.log(c)).exp() - - -def _dv_val( - rewards: torch.Tensor, - vals: torch.Tensor, - gamma: Union[float, torch.Tensor], - rho_bar: Union[float, torch.Tensor], - log_pi: torch.Tensor, - log_mu: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - rho = _c_val(log_pi, log_mu, rho_bar) - next_vals = torch.cat([vals[:, 1:], torch.zeros_like(vals[:, :1])], 1) - dv = rho * (rewards + gamma * next_vals - vals) - return dv, rho - - -def _vtrace( - rewards: torch.Tensor, - vals: torch.Tensor, - log_pi: torch.Tensor, - log_mu: torch.Tensor, - gamma: Union[torch.Tensor, float], - rho_bar: Union[float, torch.Tensor] = 1.0, - c_bar: Union[float, torch.Tensor] = 1.0, -) -> Tuple[torch.Tensor, torch.Tensor]: - T = vals.shape[1] - if not isinstance(gamma, torch.Tensor): - gamma = torch.full_like(vals, gamma) - - dv, rho = _dv_val(rewards, vals, gamma, rho_bar, log_pi, log_mu) - c = _c_val(log_pi, log_mu, c_bar) - - v_out = [] - v_out.append(vals[:, -1] + dv[:, -1]) - for t in range(T - 2, -1, -1): - _v_out = ( - vals[:, t] + dv[:, t] + gamma[:, t] * c[:, t] * (v_out[-1] - vals[:, t + 1]) - ) - v_out.append(_v_out) - v_out = torch.stack(list(reversed(v_out)), 1) - return v_out, rho
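For reference when reading the new `vtrace_advantage_estimate` helper, the quantities it returns are the V-trace target and advantage of Espeholt et al. (2018, https://arxiv.org/abs/1802.01561). In the paper's notation, with `rho_thresh` and `c_thresh` playing the roles of the clipping constants (and with the discount zeroed by the `done`/`terminated` masks at episode boundaries), the backward recursion the code implements reads:

```latex
% Clipped importance ratios, computed from log_pi (current policy) and log_mu (behaviour policy)
\rho_t = \min\!\left(\bar{\rho},\ \frac{\pi(a_t \mid x_t)}{\mu(a_t \mid x_t)}\right),
\qquad
c_t = \min\!\left(\bar{c},\ \frac{\pi(a_t \mid x_t)}{\mu(a_t \mid x_t)}\right)

% Off-policy temporal-difference term and V-trace target (unrolled backwards in time)
\delta_t = \rho_t \left( r_t + \gamma V(x_{t+1}) - V(x_t) \right)

v_t = V(x_t) + \delta_t + \gamma\, c_t \left( v_{t+1} - V(x_{t+1}) \right)

% Advantage returned alongside the target
A_t = \rho_t \left( r_t + \gamma\, v_{t+1} - V(x_t) \right)
```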
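The new `test_vtrace` unit test exercises the functional form directly; the sketch below shows the same call pattern outside of pytest. The import path and keyword arguments mirror the patch, while the tensor shapes and random inputs are illustrative only.

```python
import torch
from torchrl.objectives.value.functional import vtrace_advantage_estimate

B, T = 3, 5  # illustrative batch and time dimensions
gamma = 0.99

reward = torch.randn(B, T, 1)
state_value = torch.randn(B, T, 1)
next_state_value = torch.randn(B, T, 1)
log_pi = torch.log(torch.rand(B, T, 1))  # current (learner) policy log-probs
log_mu = torch.log(torch.rand(B, T, 1))  # behaviour policy log-probs from collection
done = torch.zeros(B, T, 1, dtype=torch.bool)
terminated = torch.zeros(B, T, 1, dtype=torch.bool)

# Returns the V-trace advantage and the value target, both shaped [B, T, 1]
advantage, value_target = vtrace_advantage_estimate(
    gamma,
    log_pi,
    log_mu,
    state_value,
    next_state_value,
    reward,
    done=done,
    terminated=terminated,
)
assert advantage.shape == (B, T, 1) and value_target.shape == (B, T, 1)
```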
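The changes to `a2c.py`, `ppo.py` and `reinforce.py` expose the estimator through `make_value_estimator`. Below is a minimal sketch of how a user would select it on an A2C loss, assuming a discrete-action actor/critic pair similar to the one built in the new tests; the module names, layer sizes and observation key are illustrative and not part of the patch.

```python
import torch
from tensordict.nn import TensorDictModule
from torchrl.modules import OneHotCategorical, ProbabilisticActor
from torchrl.objectives import A2CLoss, ValueEstimators

# Illustrative actor over a 3-dim observation and 4 discrete actions
actor_backbone = TensorDictModule(
    torch.nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"]
)
actor = ProbabilisticActor(
    module=actor_backbone,
    in_keys=["logits"],
    out_keys=["action"],
    distribution_class=OneHotCategorical,
    return_log_prob=True,
)
# Illustrative critic producing "state_value"
critic = TensorDictModule(
    torch.nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
)

loss = A2CLoss(actor, critic)
# Select the V-trace estimator added in this patch; the loss forwards the
# "sample_log_prob" key recorded at collection time to the estimator.
loss.make_value_estimator(ValueEstimators.VTrace, gamma=0.99)
```

As noted in the patch comments, the loss builds the `VTrace` module with a non-functional copy of the actor (`repopulate_module(deepcopy(self.actor), self.actor_params)`) because V-trace does not yet support a functional call on the actor.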