From 0a23ae8cfb0f695a973e53654841da10eb6eb5e4 Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 20 Mar 2024 20:42:39 +0100 Subject: [PATCH 01/37] add crossQ examples --- examples/crossQ/config.yaml | 57 ++++++ examples/crossQ/crossQ.py | 226 ++++++++++++++++++++++++ examples/crossQ/utils.py | 308 +++++++++++++++++++++++++++++++++ torchrl/objectives/__init__.py | 1 + 4 files changed, 592 insertions(+) create mode 100644 examples/crossQ/config.yaml create mode 100644 examples/crossQ/crossQ.py create mode 100644 examples/crossQ/utils.py diff --git a/examples/crossQ/config.yaml b/examples/crossQ/config.yaml new file mode 100644 index 00000000000..fc652a06783 --- /dev/null +++ b/examples/crossQ/config.yaml @@ -0,0 +1,57 @@ +# environment and task +env: + name: HalfCheetah-v4 + task: "" + library: gym + max_episode_steps: 1000 + seed: 42 + +# collector +collector: + total_frames: 1_000_000 + init_random_frames: 25000 + frames_per_batch: 1000 + init_env_steps: 1000 + device: cpu + env_per_collector: 1 + reset_at_each_iter: False + +# replay buffer +replay_buffer: + size: 1000000 + prb: 0 # use prioritized experience replay + scratch_dir: null + +# optim +optim: + utd_ratio: 1.0 + policy_update_delay: 3 + gamma: 0.99 + loss_function: l2 + lr: 3.0e-4 + weight_decay: 0.0 + batch_size: 256 + alpha_init: 1.0 + # Adam β1 = 0.5 + adam_eps: 1.0e-8 + +# network +network: + batch_norm_momentum: 0.01 + # warmup_steps: 100000 # 10^5 + critic_hidden_sizes: [2048, 2048] + actor_hidden_sizes: [256, 256] + critic_activation: tanh + actor_activation: relu + default_policy_scale: 1.0 + scale_lb: 0.1 + device: "cuda:0" + +# logging +logger: + backend: wandb + project_name: torchrl_example_crossQ + group_name: null + exp_name: ${env.name}_CrossQ + mode: online + eval_iter: 25000 diff --git a/examples/crossQ/crossQ.py b/examples/crossQ/crossQ.py new file mode 100644 index 00000000000..125518abe93 --- /dev/null +++ b/examples/crossQ/crossQ.py @@ -0,0 +1,226 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""CrossQ Example. + +This is a simple self-contained example of a CrossQ training script. + +It supports state environments like MuJoCo. + +The helper functions are coded in the utils.py associated with this script. 
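+
+Hyperparameters are read from ``config.yaml`` via Hydra, so any entry can be
+overridden from the command line. A minimal launch sketch (the override values
+below are illustrative, not tuned):
+
+    python crossQ.py env.name=HalfCheetah-v4 network.device=cuda:0 \
+        collector.total_frames=1_000_000 logger.backend=wandb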
+""" +import time + +import hydra + +import numpy as np +import torch +import torch.cuda +import tqdm +from torchrl._utils import logger as torchrl_logger +from torchrl.envs.utils import ExplorationType, set_exploration_type + +from torchrl.record.loggers import generate_exp_name, get_logger +from utils import ( + log_metrics, + make_collector, + make_environment, + make_loss_module, + make_replay_buffer, + make_sac_agent, + make_sac_optimizer, +) + + +@hydra.main(version_base="1.1", config_path=".", config_name="config") +def main(cfg: "DictConfig"): # noqa: F821 + device = torch.device(cfg.network.device) + + # Create logger + exp_name = generate_exp_name("SAC", cfg.logger.exp_name) + logger = None + if cfg.logger.backend: + logger = get_logger( + logger_type=cfg.logger.backend, + logger_name="crossq_logging", + experiment_name=exp_name, + wandb_kwargs={ + "mode": cfg.logger.mode, + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, + ) + + torch.manual_seed(cfg.env.seed) + np.random.seed(cfg.env.seed) + + # Create environments + train_env, eval_env = make_environment(cfg) + + # Create agent + model, exploration_policy = make_sac_agent(cfg, train_env, eval_env, device) + + # Create SAC loss + loss_module = make_loss_module(cfg, model) + + # Create off-policy collector + collector = make_collector(cfg, train_env, exploration_policy.eval()) + + # Create replay buffer + replay_buffer = make_replay_buffer( + batch_size=cfg.optim.batch_size, + prb=cfg.replay_buffer.prb, + buffer_size=cfg.replay_buffer.size, + scratch_dir=cfg.replay_buffer.scratch_dir, + device="cpu", + ) + + # Create optimizers + ( + optimizer_actor, + optimizer_critic, + optimizer_alpha, + ) = make_sac_optimizer(cfg, loss_module) + + # Main loop + start_time = time.time() + collected_frames = 0 + pbar = tqdm.tqdm(total=cfg.collector.total_frames) + + init_random_frames = cfg.collector.init_random_frames + num_updates = int( + cfg.collector.env_per_collector + * cfg.collector.frames_per_batch + * cfg.optim.utd_ratio + ) + prb = cfg.replay_buffer.prb + eval_iter = cfg.logger.eval_iter + frames_per_batch = cfg.collector.frames_per_batch + eval_rollout_steps = cfg.env.max_episode_steps + + sampling_start = time.time() + update_counter = 0 + delayed_updates = cfg.optim.policy_update_delay + for _, tensordict in enumerate(collector): + sampling_time = time.time() - sampling_start + + # Update weights of the inference policy + collector.update_policy_weights_() + + pbar.update(tensordict.numel()) + + tensordict = tensordict.reshape(-1) + current_frames = tensordict.numel() + # Add to replay buffer + replay_buffer.extend(tensordict.cpu()) + collected_frames += current_frames + + # Optimization steps + training_start = time.time() + if collected_frames >= init_random_frames: + ( + actor_losses, + alpha_losses, + q_losses, + ) = ([], [], []) + for _ in range(num_updates): + + # Update actor every delayed_updates + update_counter += 1 + update_actor = update_counter % delayed_updates == 0 + # Sample from replay buffer + sampled_tensordict = replay_buffer.sample() + if sampled_tensordict.device != device: + sampled_tensordict = sampled_tensordict.to( + device, non_blocking=True + ) + else: + sampled_tensordict = sampled_tensordict.clone() + + # Compute loss + q_loss, *_ = loss_module._qvalue_loss(sampled_tensordict) + + # Update critic + optimizer_critic.zero_grad() + q_loss.mean().backward() + optimizer_critic.step() + q_losses.append(q_loss.mean().detach().item()) + + if update_actor: + 
actor_loss, metadata_actor = loss_module._actor_loss( + sampled_tensordict + ) + alpha_loss = loss_module._alpha_loss( + log_prob=metadata_actor["log_prob"] + ) + + # Update actor + optimizer_actor.zero_grad() + actor_loss.mean().backward() + optimizer_actor.step() + + # Update alpha + optimizer_alpha.zero_grad() + alpha_loss.mean().backward() + optimizer_alpha.step() + + actor_losses.append(actor_loss.mean().detach().item()) + alpha_losses.append(alpha_loss.mean().detach().item()) + + # Update priority + if prb: + replay_buffer.update_priority(sampled_tensordict) + + training_time = time.time() - training_start + episode_end = ( + tensordict["next", "done"] + if tensordict["next", "done"].any() + else tensordict["next", "truncated"] + ) + episode_rewards = tensordict["next", "episode_reward"][episode_end] + + # Logging + metrics_to_log = {} + if len(episode_rewards) > 0: + episode_length = tensordict["next", "step_count"][episode_end] + metrics_to_log["train/reward"] = episode_rewards.mean().item() + metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( + episode_length + ) + if collected_frames >= init_random_frames: + metrics_to_log["train/q_loss"] = np.mean(q_losses).item() + metrics_to_log["train/actor_loss"] = np.mean(actor_losses).item() + metrics_to_log["train/alpha_loss"] = np.mean(alpha_losses).item() + # metrics_to_log["train/alpha"] = loss_td["alpha"].item() + # metrics_to_log["train/entropy"] = loss_td["entropy"].item() + metrics_to_log["train/sampling_time"] = sampling_time + metrics_to_log["train/training_time"] = training_time + + # Evaluation + if abs(collected_frames % eval_iter) < frames_per_batch: + with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + eval_start = time.time() + eval_rollout = eval_env.rollout( + eval_rollout_steps, + model[0], + auto_cast_to_device=True, + break_when_any_done=True, + ) + eval_time = time.time() - eval_start + eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() + metrics_to_log["eval/reward"] = eval_reward + metrics_to_log["eval/time"] = eval_time + if logger is not None: + log_metrics(logger, metrics_to_log, collected_frames) + sampling_start = time.time() + + collector.shutdown() + end_time = time.time() + execution_time = end_time - start_time + torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") + + +if __name__ == "__main__": + main() diff --git a/examples/crossQ/utils.py b/examples/crossQ/utils.py new file mode 100644 index 00000000000..76de788d9ac --- /dev/null +++ b/examples/crossQ/utils.py @@ -0,0 +1,308 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
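+
+# Factory helpers for the CrossQ example script (crossQ.py): environment
+# construction and transforms, the data collector, the replay buffer, the
+# actor/critic networks, the CrossQ loss module and its optimizers.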
+ +import torch +from tensordict.nn import InteractionType, TensorDictModule +from tensordict.nn.distributions import NormalParamExtractor +from torch import nn, optim +from torchrl.collectors import SyncDataCollector +from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer +from torchrl.data.replay_buffers.storages import LazyMemmapStorage +from torchrl.envs import ( + CatTensors, + Compose, + DMControlEnv, + DoubleToFloat, + EnvCreator, + ParallelEnv, + TransformedEnv, +) +from torchrl.envs.libs.gym import GymEnv, set_gym_backend +from torchrl.envs.transforms import InitTracker, RewardSum, StepCounter +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.modules import MLP, ProbabilisticActor, ValueOperator +from torchrl.modules.distributions import TanhNormal +from torchrl.objectives import CrossQLoss + + +# ==================================================================== +# Environment utils +# ----------------- + + +def env_maker(cfg, device="cpu"): + lib = cfg.env.library + if lib in ("gym", "gymnasium"): + with set_gym_backend(lib): + return GymEnv( + cfg.env.name, + device=device, + ) + elif lib == "dm_control": + env = DMControlEnv(cfg.env.name, cfg.env.task) + return TransformedEnv( + env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation") + ) + else: + raise NotImplementedError(f"Unknown lib {lib}.") + + +def apply_env_transforms(env, max_episode_steps=1000): + transformed_env = TransformedEnv( + env, + Compose( + InitTracker(), + StepCounter(max_episode_steps), + DoubleToFloat(), + RewardSum(), + ), + ) + return transformed_env + + +def make_environment(cfg): + """Make environments for training and evaluation.""" + parallel_env = ParallelEnv( + cfg.collector.env_per_collector, + EnvCreator(lambda cfg=cfg: env_maker(cfg)), + serial_for_single=True, + ) + parallel_env.set_seed(cfg.env.seed) + + train_env = apply_env_transforms(parallel_env, cfg.env.max_episode_steps) + + eval_env = TransformedEnv( + ParallelEnv( + cfg.collector.env_per_collector, + EnvCreator(lambda cfg=cfg: env_maker(cfg)), + serial_for_single=True, + ), + train_env.transform.clone(), + ) + return train_env, eval_env + + +# ==================================================================== +# Collector and replay buffer +# --------------------------- + + +def make_collector(cfg, train_env, actor_model_explore): + """Make collector.""" + collector = SyncDataCollector( + train_env, + actor_model_explore, + init_random_frames=cfg.collector.init_random_frames, + frames_per_batch=cfg.collector.frames_per_batch, + total_frames=cfg.collector.total_frames, + device=cfg.collector.device, + ) + collector.set_seed(cfg.env.seed) + return collector + + +def make_replay_buffer( + batch_size, + prb=False, + buffer_size=1000000, + scratch_dir=None, + device="cpu", + prefetch=3, +): + if prb: + replay_buffer = TensorDictPrioritizedReplayBuffer( + alpha=0.7, + beta=0.5, + pin_memory=False, + prefetch=prefetch, + storage=LazyMemmapStorage( + buffer_size, + scratch_dir=scratch_dir, + device=device, + ), + batch_size=batch_size, + ) + else: + replay_buffer = TensorDictReplayBuffer( + pin_memory=False, + prefetch=prefetch, + storage=LazyMemmapStorage( + buffer_size, + scratch_dir=scratch_dir, + device=device, + ), + batch_size=batch_size, + ) + return replay_buffer + + +# ==================================================================== +# Model +# ----- + + +def make_sac_agent(cfg, train_env, eval_env, device): + """Make SAC agent.""" + # Define 
Actor Network + in_keys = ["observation"] + action_spec = train_env.action_spec + if train_env.batch_size: + action_spec = action_spec[(0,) * len(train_env.batch_size)] + actor_net_kwargs = { + "num_cells": cfg.network.actor_hidden_sizes, + "out_features": 2 * action_spec.shape[-1], + "activation_class": get_activation(cfg.network.actor_activation), + "norm_class": nn.BatchNorm1d, # Should be BRN (https://arxiv.org/abs/1702.03275) not sure if added to torch + "norm_kwargs": { + "momentum": cfg.network.batch_norm_momentum, + "num_features": cfg.network.actor_hidden_sizes[-1], + }, + } + + actor_net = MLP(**actor_net_kwargs) + + dist_class = TanhNormal + dist_kwargs = { + "min": action_spec.space.low, + "max": action_spec.space.high, + "tanh_loc": False, + } + + actor_extractor = NormalParamExtractor( + scale_mapping=f"biased_softplus_{cfg.network.default_policy_scale}", + scale_lb=cfg.network.scale_lb, + ) + actor_net = nn.Sequential(actor_net, actor_extractor) + + in_keys_actor = in_keys + actor_module = TensorDictModule( + actor_net, + in_keys=in_keys_actor, + out_keys=[ + "loc", + "scale", + ], + ) + actor = ProbabilisticActor( + spec=action_spec, + in_keys=["loc", "scale"], + module=actor_module, + distribution_class=dist_class, + distribution_kwargs=dist_kwargs, + default_interaction_type=InteractionType.RANDOM, + return_log_prob=False, + ) + + # Define Critic Network + qvalue_net_kwargs = { + "num_cells": cfg.network.critic_hidden_sizes, + "out_features": 1, + "activation_class": get_activation(cfg.network.critic_activation), + "norm_class": nn.BatchNorm1d, # Should be BRN (https://arxiv.org/abs/1702.03275) not sure if added to torch + "norm_kwargs": { + "momentum": cfg.network.batch_norm_momentum, + "num_features": cfg.network.critic_hidden_sizes[-1], + }, + } + + qvalue_net = MLP( + **qvalue_net_kwargs, + ) + + qvalue = ValueOperator( + in_keys=["action"] + in_keys, + module=qvalue_net, + ) + + model = nn.ModuleList([actor, qvalue]).to(device) + + # init nets + with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): + td = eval_env.reset() + td = td.to(device) + for net in model: + net.eval() + net(td) + net.train() + del td + eval_env.close() + + return model, model[0] + + +# ==================================================================== +# SAC Loss +# --------- + + +def make_loss_module(cfg, model): + """Make loss module and target network updater.""" + # Create SAC loss + loss_module = CrossQLoss( + actor_network=model[0], + qvalue_network=model[1], + num_qvalue_nets=2, + loss_function=cfg.optim.loss_function, + delay_actor=False, + alpha_init=cfg.optim.alpha_init, + ) + loss_module.make_value_estimator(gamma=cfg.optim.gamma) + + return loss_module + + +def split_critic_params(critic_params): + critic1_params = [] + critic2_params = [] + + for param in critic_params: + data1, data2 = param.data.chunk(2, dim=0) + critic1_params.append(nn.Parameter(data1)) + critic2_params.append(nn.Parameter(data2)) + return critic1_params, critic2_params + + +def make_sac_optimizer(cfg, loss_module): + critic_params = list(loss_module.qvalue_network_params.flatten_keys().values()) + actor_params = list(loss_module.actor_network_params.flatten_keys().values()) + + optimizer_actor = optim.Adam( + actor_params, + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.adam_eps, + ) + optimizer_critic = optim.Adam( + critic_params, + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.adam_eps, + ) + optimizer_alpha = optim.Adam( + 
[loss_module.log_alpha], + lr=3.0e-4, + ) + return optimizer_actor, optimizer_critic, optimizer_alpha + + +# ==================================================================== +# General utils +# --------- + + +def log_metrics(logger, metrics, step): + for metric_name, metric_value in metrics.items(): + logger.log_scalar(metric_name, metric_value, step) + + +def get_activation(activation: str): + if activation == "relu": + return nn.ReLU + elif activation == "tanh": + return nn.Tanh + elif activation == "leaky_relu": + return nn.LeakyReLU + else: + raise NotImplementedError diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index f8d2bd1d977..f7ba0aaa21e 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -6,6 +6,7 @@ from .a2c import A2CLoss from .common import LossModule from .cql import CQLLoss, DiscreteCQLLoss +from .crossQ import CrossQLoss from .ddpg import DDPGLoss from .decision_transformer import DTLoss, OnlineDTLoss from .dqn import DistributionalDQNLoss, DQNLoss From 9bdee715e2cc94beeb85a615f0ccac1f45c577c8 Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 20 Mar 2024 20:43:42 +0100 Subject: [PATCH 02/37] add loss --- torchrl/objectives/crossQ.py | 639 +++++++++++++++++++++++++++++++++++ 1 file changed, 639 insertions(+) create mode 100644 torchrl/objectives/crossQ.py diff --git a/torchrl/objectives/crossQ.py b/torchrl/objectives/crossQ.py new file mode 100644 index 00000000000..2cd049e7811 --- /dev/null +++ b/torchrl/objectives/crossQ.py @@ -0,0 +1,639 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import math +from dataclasses import dataclass +from functools import wraps +from typing import Dict, Tuple, Union + +import torch +from tensordict import TensorDict, TensorDictBase + +from tensordict.nn import dispatch, TensorDictModule +from tensordict.utils import NestedKey +from torch import Tensor +from torchrl.data.tensor_specs import CompositeSpec +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.modules import ProbabilisticActor +from torchrl.objectives.common import LossModule + +from torchrl.objectives.utils import ( + _cache_values, + _GAMMA_LMBDA_DEPREC_ERROR, + _reduce, + _vmap_func, + default_value_kwargs, + distance_loss, + ValueEstimators, +) +from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator + + +def _delezify(func): + @wraps(func) + def new_func(self, *args, **kwargs): + self.target_entropy + return func(self, *args, **kwargs) + + return new_func + + +class CrossQLoss(LossModule): + """TorchRL implementation of the CrossQ loss. + + Presented in "CROSSQ: BATCH NORMALIZATION IN DEEP REINFORCEMENT LEARNING + FOR GREATER SAMPLE EFFICIENCY AND SIMPLICITY" https://openreview.net/pdf?id=PczQtTsTIX + + Args: + actor_network (ProbabilisticActor): stochastic actor + qvalue_network (TensorDictModule): Q(s, a) parametric model. + This module typically outputs a ``"state_action_value"`` entry. + + .. note:: + If not provided, the second version of SAC is assumed, where + only the Q-Value network is needed. + + num_qvalue_nets (integer, optional): number of Q-Value networks used. + Defaults to ``2``. + loss_function (str, optional): loss function to be used with + the value function loss. Default is `"smooth_l1"`. + alpha_init (float, optional): initial entropy multiplier. 
+ Default is 1.0. + min_alpha (float, optional): min value of alpha. + Default is None (no minimum value). + max_alpha (float, optional): max value of alpha. + Default is None (no maximum value). + action_spec (TensorSpec, optional): the action tensor spec. If not provided + and the target entropy is ``"auto"``, it will be retrieved from + the actor. + fixed_alpha (bool, optional): if ``True``, alpha will be fixed to its + initial value. Otherwise, alpha will be optimized to + match the 'target_entropy' value. + Default is ``False``. + target_entropy (float or str, optional): Target entropy for the + stochastic policy. Default is "auto", where target entropy is + computed as :obj:`-prod(n_actions)`. + delay_actor (bool, optional): Whether to separate the target actor + networks from the actor networks used for data collection. + Default is ``False``. + priority_key (str, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead] + Tensordict key where to write the + priority (for prioritized replay buffer usage). Defaults to ``"td_error"``. + separate_losses (bool, optional): if ``True``, shared parameters between + policy and critic will only be trained on the policy loss. + Defaults to ``False``, ie. gradients are propagated to shared + parameters for both policy and critic losses. + reduction (str, optional): Specifies the reduction to apply to the output: + ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, + ``"mean"``: the sum of the output will be divided by the number of + elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``. + + Examples: + >>> import torch + >>> from torch import nn + >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator + >>> from torchrl.modules.tensordict_module.common import SafeModule + >>> from torchrl.objectives.sac import SACLoss + >>> from tensordict import TensorDict + >>> n_act, n_obs = 4, 3 + >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) + >>> actor = ProbabilisticActor( + ... module=module, + ... in_keys=["loc", "scale"], + ... spec=spec, + ... distribution_class=TanhNormal) + >>> class ValueClass(nn.Module): + ... def __init__(self): + ... super().__init__() + ... self.linear = nn.Linear(n_obs + n_act, 1) + ... def forward(self, obs, act): + ... return self.linear(torch.cat([obs, act], -1)) + >>> module = ValueClass() + >>> qvalue = ValueOperator( + ... module=module, + ... in_keys=['observation', 'action']) + >>> module = nn.Linear(n_obs, 1) + >>> value = ValueOperator( + ... module=module, + ... in_keys=["observation"]) + >>> loss = SACLoss(actor, qvalue, value) + >>> batch = [2, ] + >>> action = spec.rand(batch) + >>> data = TensorDict({ + ... "observation": torch.randn(*batch, n_obs), + ... "action": action, + ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "reward"): torch.randn(*batch, 1), + ... ("next", "observation"): torch.randn(*batch, n_obs), + ... 
}, batch) + >>> loss(data) + TensorDict( + fields={ + alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + + This class is compatible with non-tensordict based modules too and can be + used without recurring to any tensordict-related primitive. In this case, + the expected keyword arguments are: + ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor, value, and qvalue network. + The return value is a tuple of tensors in the following order: + ``["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy"]`` + ``"loss_value"`` if version one is used. + + Examples: + >>> import torch + >>> from torch import nn + >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator + >>> from torchrl.modules.tensordict_module.common import SafeModule + >>> from torchrl.objectives.sac import SACLoss + >>> _ = torch.manual_seed(42) + >>> n_act, n_obs = 4, 3 + >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) + >>> actor = ProbabilisticActor( + ... module=module, + ... in_keys=["loc", "scale"], + ... spec=spec, + ... distribution_class=TanhNormal) + >>> class ValueClass(nn.Module): + ... def __init__(self): + ... super().__init__() + ... self.linear = nn.Linear(n_obs + n_act, 1) + ... def forward(self, obs, act): + ... return self.linear(torch.cat([obs, act], -1)) + >>> module = ValueClass() + >>> qvalue = ValueOperator( + ... module=module, + ... in_keys=['observation', 'action']) + >>> module = nn.Linear(n_obs, 1) + >>> value = ValueOperator( + ... module=module, + ... in_keys=["observation"]) + >>> loss = SACLoss(actor, qvalue, value) + >>> batch = [2, ] + >>> action = spec.rand(batch) + >>> loss_actor, loss_qvalue, _, _, _, _ = loss( + ... observation=torch.randn(*batch, n_obs), + ... action=action, + ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_observation=torch.zeros(*batch, n_obs), + ... next_reward=torch.randn(*batch, 1)) + >>> loss_actor.backward() + + The output keys can also be filtered using the :meth:`SACLoss.select_out_keys` + method. + + Examples: + >>> _ = loss.select_out_keys('loss_actor', 'loss_qvalue') + >>> loss_actor, loss_qvalue = loss( + ... observation=torch.randn(*batch, n_obs), + ... action=action, + ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_observation=torch.zeros(*batch, n_obs), + ... next_reward=torch.randn(*batch, 1)) + >>> loss_actor.backward() + """ + + @dataclass + class _AcceptedKeys: + """Maintains default values for all configurable tensordict keys. 
+ + This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their + default values. + + Attributes: + action (NestedKey): The input tensordict key where the action is expected. + Defaults to ``"advantage"``. + state_action_value (NestedKey): The input tensordict key where the + state action value is expected. Defaults to ``"state_action_value"``. + log_prob (NestedKey): The input tensordict key where the log probability is expected. + Defaults to ``"_log_prob"``. + priority (NestedKey): The input tensordict key where the target priority is written to. + Defaults to ``"td_error"``. + reward (NestedKey): The input tensordict key where the reward is expected. + Will be used for the underlying value estimator. Defaults to ``"reward"``. + done (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is done. Will be used for the underlying value estimator. + Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. + """ + + action: NestedKey = "action" + state_action_value: NestedKey = "state_action_value" + log_prob: NestedKey = "_log_prob" + priority: NestedKey = "td_error" + reward: NestedKey = "reward" + done: NestedKey = "done" + terminated: NestedKey = "terminated" + + default_keys = _AcceptedKeys() + default_value_estimator = ValueEstimators.TD0 + + def __init__( + self, + actor_network: ProbabilisticActor, + qvalue_network: TensorDictModule, + *, + num_qvalue_nets: int = 2, + loss_function: str = "smooth_l1", + alpha_init: float = 1.0, + min_alpha: float = None, + max_alpha: float = None, + action_spec=None, + fixed_alpha: bool = False, + target_entropy: Union[str, float] = "auto", + delay_actor: bool = False, + gamma: float = None, + priority_key: str = None, + separate_losses: bool = False, + reduction: str = None, + ) -> None: + self._in_keys = None + self._out_keys = None + if reduction is None: + reduction = "mean" + super().__init__() + self._set_deprecated_ctor_keys(priority_key=priority_key) + + # Actor + self.delay_actor = delay_actor + self.convert_to_functional( + actor_network, + "actor_network", + create_target_params=self.delay_actor, + ) + if separate_losses: + # we want to make sure there are no duplicates in the params: the + # params of critic must be refs to actor if they're shared + policy_params = list(actor_network.parameters()) + else: + policy_params = None + q_value_policy_params = None + + # Q value + self.num_qvalue_nets = num_qvalue_nets + + q_value_policy_params = policy_params + self.convert_to_functional( + qvalue_network, + "qvalue_network", + num_qvalue_nets, + create_target_params=False, + compare_against=q_value_policy_params, + ) + + self.loss_function = loss_function + try: + device = next(self.parameters()).device + except AttributeError: + device = torch.device("cpu") + self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) + if bool(min_alpha) ^ bool(max_alpha): + min_alpha = min_alpha if min_alpha else 0.0 + if max_alpha == 0: + raise ValueError("max_alpha must be either None or greater than 0.") + max_alpha = max_alpha if max_alpha else 1e9 + if min_alpha: + self.register_buffer( + "min_log_alpha", torch.tensor(min_alpha, device=device).log() + ) + else: + self.min_log_alpha = None + if max_alpha: + self.register_buffer( + "max_log_alpha", torch.tensor(max_alpha, device=device).log() + ) + else: + 
self.max_log_alpha = None + self.fixed_alpha = fixed_alpha + if fixed_alpha: + self.register_buffer( + "log_alpha", torch.tensor(math.log(alpha_init), device=device) + ) + else: + self.register_parameter( + "log_alpha", + torch.nn.Parameter(torch.tensor(math.log(alpha_init), device=device)), + ) + + self._target_entropy = target_entropy + self._action_spec = action_spec + if gamma is not None: + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) + self._vmap_qnetworkN0 = _vmap_func( + self.qvalue_network, (None, 0), randomness=self.vmap_randomness + ) + self.reduction = reduction + + @property + def target_entropy_buffer(self): + return self.target_entropy + + @property + def target_entropy(self): + target_entropy = self._buffers.get("_target_entropy", None) + if target_entropy is not None: + return target_entropy + target_entropy = self._target_entropy + action_spec = self._action_spec + actor_network = self.actor_network + device = next(self.parameters()).device + if target_entropy == "auto": + action_spec = ( + action_spec + if action_spec is not None + else getattr(actor_network, "spec", None) + ) + if action_spec is None: + raise RuntimeError( + "Cannot infer the dimensionality of the action. Consider providing " + "the target entropy explicitely or provide the spec of the " + "action tensor in the actor network." + ) + if not isinstance(action_spec, CompositeSpec): + action_spec = CompositeSpec({self.tensor_keys.action: action_spec}) + if ( + isinstance(self.tensor_keys.action, tuple) + and len(self.tensor_keys.action) > 1 + ): + action_container_shape = action_spec[self.tensor_keys.action[:-1]].shape + else: + action_container_shape = action_spec.shape + target_entropy = -float( + action_spec[self.tensor_keys.action] + .shape[len(action_container_shape) :] + .numel() + ) + delattr(self, "_target_entropy") + self.register_buffer( + "_target_entropy", torch.tensor(target_entropy, device=device) + ) + return self._target_entropy + + state_dict = _delezify(LossModule.state_dict) + load_state_dict = _delezify(LossModule.load_state_dict) + + def _forward_value_estimator_keys(self, **kwargs) -> None: + if self._value_estimator is not None: + self._value_estimator.set_keys( + value=self.tensor_keys.value, + reward=self.tensor_keys.reward, + done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, + ) + self._set_in_keys() + + def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): + if value_type is None: + value_type = self.default_value_estimator + self.value_type = value_type + + value_net = None + hp = dict(default_value_kwargs(value_type)) + hp.update(hyperparams) + if value_type is ValueEstimators.TD1: + self._value_estimator = TD1Estimator( + **hp, + value_network=value_net, + ) + elif value_type is ValueEstimators.TD0: + self._value_estimator = TD0Estimator( + **hp, + value_network=value_net, + ) + elif value_type is ValueEstimators.GAE: + raise NotImplementedError( + f"Value type {value_type} it not implemented for loss {type(self)}." 
+ ) + elif value_type is ValueEstimators.TDLambda: + self._value_estimator = TDLambdaEstimator( + **hp, + value_network=value_net, + ) + else: + raise NotImplementedError(f"Unknown value type {value_type}") + + tensor_keys = { + # "value_target": "value_target", + # "value": self.tensor_keys.value, + "reward": self.tensor_keys.reward, + "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, + } + self._value_estimator.set_keys(**tensor_keys) + + @property + def device(self) -> torch.device: + for p in self.parameters(): + return p.device + raise RuntimeError( + "At least one of the networks of SACLoss must have trainable " "parameters." + ) + + def _set_in_keys(self): + keys = [ + self.tensor_keys.action, + ("next", self.tensor_keys.reward), + ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), + *self.actor_network.in_keys, + *[("next", key) for key in self.actor_network.in_keys], + *self.qvalue_network.in_keys, + ] + if self._version == 1: + keys.extend(self.value_network.in_keys) + self._in_keys = list(set(keys)) + + @property + def in_keys(self): + if self._in_keys is None: + self._set_in_keys() + return self._in_keys + + @in_keys.setter + def in_keys(self, values): + self._in_keys = values + + @property + def out_keys(self): + if self._out_keys is None: + keys = ["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy"] + self._out_keys = keys + return self._out_keys + + @out_keys.setter + def out_keys(self, values): + self._out_keys = values + + @dispatch + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + shape = None + if tensordict.ndimension() > 1: + shape = tensordict.shape + tensordict_reshape = tensordict.reshape(-1) + else: + tensordict_reshape = tensordict + + loss_qvalue, value_metadata = self._qvalue_loss(tensordict_reshape) + loss_actor, metadata_actor = self._actor_loss(tensordict_reshape) + loss_alpha = self._alpha_loss(log_prob=metadata_actor["log_prob"]) + tensordict_reshape.set(self.tensor_keys.priority, value_metadata["td_error"]) + if loss_actor.shape != loss_qvalue.shape: + raise RuntimeError( + f"Losses shape mismatch: {loss_actor.shape} and {loss_qvalue.shape}" + ) + if shape: + tensordict.update(tensordict_reshape.view(shape)) + entropy = -metadata_actor["log_prob"] + out = { + "loss_actor": loss_actor, + "loss_qvalue": loss_qvalue, + "loss_alpha": loss_alpha, + "alpha": self._alpha, + "entropy": entropy.detach().mean(), + } + td_out = TensorDict(out, []) + td_out = td_out.named_apply( + lambda name, value: _reduce(value, reduction=self.reduction) + if name.startswith("loss_") + else value, + batch_size=[], + ) + return td_out + + @property + @_cache_values + def _cached_detached_qvalue_params(self): + return self.qvalue_network_params.detach() + + def _actor_loss( + self, tensordict: TensorDictBase + ) -> Tuple[Tensor, Dict[str, Tensor]]: + with set_exploration_type( + ExplorationType.RANDOM + ), self.actor_network_params.to_module(self.actor_network): + dist = self.actor_network.get_dist(tensordict) + a_reparm = dist.rsample() + log_prob = dist.log_prob(a_reparm) + + td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False) + td_q.set(self.tensor_keys.action, a_reparm) + td_q = self._vmap_qnetworkN0( + td_q, + self._cached_detached_qvalue_params, # should we clone? 
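+            # The detached q-value parameters are intentional: the actor loss
+            # alpha * log_prob - min_i Q_i(s, a) should only backpropagate into
+            # the policy parameters, not into the critics.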
+ ) + min_q_logprob = ( + td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1) + ) + + if log_prob.shape != min_q_logprob.shape: + raise RuntimeError( + f"Losses shape mismatch: {log_prob.shape} and {min_q_logprob.shape}" + ) + + return self._alpha * log_prob - min_q_logprob, {"log_prob": log_prob.detach()} + + def _compute_target(self, tensordict) -> Tensor: + r"""Value network for SAC v2. + + SAC v2 is based on a value estimate of the form: + + .. math:: + + V = Q(s,a) - \alpha * \log p(a | s) + + This class computes this value given the actor and qvalue network + + """ + tensordict = tensordict.clone(False) + # TODO asser that models are in train mode + # get actions and log-probs + with torch.no_grad(): + with set_exploration_type( + ExplorationType.RANDOM + ), self.actor_network_params.to_module(self.actor_network): + next_tensordict = tensordict.get("next").clone(False) + next_dist = self.actor_network.get_dist(next_tensordict) + next_action = next_dist.rsample() + next_tensordict.set(self.tensor_keys.action, next_action) + next_sample_log_prob = next_dist.log_prob(next_action) + + # get q-values + next_tensordict_expand = self._vmap_qnetworkN0( + next_tensordict, self.qvalue_network_params + ) + state_action_value = next_tensordict_expand.get( + self.tensor_keys.state_action_value + ) + if ( + state_action_value.shape[-len(next_sample_log_prob.shape) :] + != next_sample_log_prob.shape + ): + next_sample_log_prob = next_sample_log_prob.unsqueeze(-1) + next_state_value = state_action_value - self._alpha * next_sample_log_prob + next_state_value = next_state_value.min(0)[0] + tensordict.set( + ("next", self.value_estimator.tensor_keys.value), next_state_value + ) + target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) + return target_value + + def _qvalue_loss( + self, tensordict: TensorDictBase + ) -> Tuple[Tensor, Dict[str, Tensor]]: + # we pass the alpha value to the tensordict. Since it's a scalar, we must erase the batch-size first. 
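+        # Note: CrossQ keeps no target q-networks (create_target_params=False);
+        # the bootstrap target below is computed with the current q-value
+        # parameters in `_compute_target`, and every q-network prediction for the
+        # sampled (state, action) pair is regressed toward that shared target.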
+ target_value = self._compute_target(tensordict) + + tensordict_expand = self._vmap_qnetworkN0( + tensordict.select(*self.qvalue_network.in_keys, strict=False), + self.qvalue_network_params, + ) + pred_val = tensordict_expand.get(self.tensor_keys.state_action_value).squeeze( + -1 + ) + td_error = abs(pred_val - target_value) + loss_qval = distance_loss( + pred_val, + target_value.expand_as(pred_val), + loss_function=self.loss_function, + ).sum(0) + metadata = {"td_error": td_error.detach().max(0)[0]} + return loss_qval, metadata + + def _alpha_loss(self, log_prob: Tensor) -> Tensor: + if self.target_entropy is not None: + # we can compute this loss even if log_alpha is not a parameter + alpha_loss = -self.log_alpha * (log_prob + self.target_entropy) + else: + # placeholder + alpha_loss = torch.zeros_like(log_prob) + return alpha_loss + + @property + def _alpha(self): + if self.min_log_alpha is not None: + self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha) + with torch.no_grad(): + alpha = self.log_alpha.exp() + return alpha From 570a20ee3b379af4d91e68ac820dab9afc9a6bec Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 21 Mar 2024 10:27:39 +0100 Subject: [PATCH 03/37] Update naming experiment --- examples/crossQ/crossQ.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/crossQ/crossQ.py b/examples/crossQ/crossQ.py index 125518abe93..c9b8da7006f 100644 --- a/examples/crossQ/crossQ.py +++ b/examples/crossQ/crossQ.py @@ -38,7 +38,7 @@ def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(cfg.network.device) # Create logger - exp_name = generate_exp_name("SAC", cfg.logger.exp_name) + exp_name = generate_exp_name("CrossQ", cfg.logger.exp_name) logger = None if cfg.logger.backend: logger = get_logger( From 50862498a9888d9b087decfd6706b377d6416aae Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 21 Mar 2024 16:30:56 +0100 Subject: [PATCH 04/37] update --- examples/{crossQ => crossq}/config.yaml | 0 .../{crossQ/crossQ.py => crossq/crossq.py} | 31 ++-- examples/{crossQ => crossq}/utils.py | 10 +- torchrl/objectives/__init__.py | 2 +- torchrl/objectives/{crossQ.py => crossq.py} | 156 ++++++++++++------ 5 files changed, 125 insertions(+), 74 deletions(-) rename examples/{crossQ => crossq}/config.yaml (100%) rename examples/{crossQ/crossQ.py => crossq/crossq.py} (90%) rename examples/{crossQ => crossq}/utils.py (98%) rename torchrl/objectives/{crossQ.py => crossq.py} (84%) diff --git a/examples/crossQ/config.yaml b/examples/crossq/config.yaml similarity index 100% rename from examples/crossQ/config.yaml rename to examples/crossq/config.yaml diff --git a/examples/crossQ/crossQ.py b/examples/crossq/crossq.py similarity index 90% rename from examples/crossQ/crossQ.py rename to examples/crossq/crossq.py index 125518abe93..a1ca275663a 100644 --- a/examples/crossQ/crossQ.py +++ b/examples/crossq/crossq.py @@ -25,11 +25,11 @@ from utils import ( log_metrics, make_collector, + make_crossQ_agent, + make_crossQ_optimizer, make_environment, make_loss_module, make_replay_buffer, - make_sac_agent, - make_sac_optimizer, ) @@ -38,7 +38,7 @@ def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(cfg.network.device) # Create logger - exp_name = generate_exp_name("SAC", cfg.logger.exp_name) + exp_name = generate_exp_name("CrossQ", cfg.logger.exp_name) logger = None if cfg.logger.backend: logger = get_logger( @@ -60,9 +60,9 @@ def main(cfg: "DictConfig"): # noqa: F821 train_env, eval_env = make_environment(cfg) # Create agent - model, exploration_policy = 
make_sac_agent(cfg, train_env, eval_env, device) + model, exploration_policy = make_crossQ_agent(cfg, train_env, eval_env, device) - # Create SAC loss + # Create CrossQ loss loss_module = make_loss_module(cfg, model) # Create off-policy collector @@ -82,7 +82,7 @@ def main(cfg: "DictConfig"): # noqa: F821 optimizer_actor, optimizer_critic, optimizer_alpha, - ) = make_sac_optimizer(cfg, loss_module) + ) = make_crossQ_optimizer(cfg, loss_module) # Main loop start_time = time.time() @@ -141,33 +141,34 @@ def main(cfg: "DictConfig"): # noqa: F821 # Compute loss q_loss, *_ = loss_module._qvalue_loss(sampled_tensordict) - + q_loss = q_loss.mean() # Update critic optimizer_critic.zero_grad() - q_loss.mean().backward() + q_loss.backward() optimizer_critic.step() - q_losses.append(q_loss.mean().detach().item()) + q_losses.append(q_loss.detach().item()) if update_actor: actor_loss, metadata_actor = loss_module._actor_loss( sampled_tensordict ) + actor_loss = actor_loss.mean() alpha_loss = loss_module._alpha_loss( log_prob=metadata_actor["log_prob"] ) - + alpha_loss = alpha_loss.mean() # Update actor optimizer_actor.zero_grad() - actor_loss.mean().backward() + actor_loss.backward() optimizer_actor.step() # Update alpha optimizer_alpha.zero_grad() - alpha_loss.mean().backward() + alpha_loss.backward() optimizer_alpha.step() - actor_losses.append(actor_loss.mean().detach().item()) - alpha_losses.append(alpha_loss.mean().detach().item()) + actor_losses.append(actor_loss.detach().item()) + alpha_losses.append(alpha_loss.detach().item()) # Update priority if prb: @@ -193,8 +194,6 @@ def main(cfg: "DictConfig"): # noqa: F821 metrics_to_log["train/q_loss"] = np.mean(q_losses).item() metrics_to_log["train/actor_loss"] = np.mean(actor_losses).item() metrics_to_log["train/alpha_loss"] = np.mean(alpha_losses).item() - # metrics_to_log["train/alpha"] = loss_td["alpha"].item() - # metrics_to_log["train/entropy"] = loss_td["entropy"].item() metrics_to_log["train/sampling_time"] = sampling_time metrics_to_log["train/training_time"] = training_time diff --git a/examples/crossQ/utils.py b/examples/crossq/utils.py similarity index 98% rename from examples/crossQ/utils.py rename to examples/crossq/utils.py index 76de788d9ac..b4865f79c39 100644 --- a/examples/crossQ/utils.py +++ b/examples/crossq/utils.py @@ -143,8 +143,8 @@ def make_replay_buffer( # ----- -def make_sac_agent(cfg, train_env, eval_env, device): - """Make SAC agent.""" +def make_crossQ_agent(cfg, train_env, eval_env, device): + """Make CrossQ agent.""" # Define Actor Network in_keys = ["observation"] action_spec = train_env.action_spec @@ -233,13 +233,13 @@ def make_sac_agent(cfg, train_env, eval_env, device): # ==================================================================== -# SAC Loss +# CrossQ Loss # --------- def make_loss_module(cfg, model): """Make loss module and target network updater.""" - # Create SAC loss + # Create CrossQ loss loss_module = CrossQLoss( actor_network=model[0], qvalue_network=model[1], @@ -264,7 +264,7 @@ def split_critic_params(critic_params): return critic1_params, critic2_params -def make_sac_optimizer(cfg, loss_module): +def make_crossQ_optimizer(cfg, loss_module): critic_params = list(loss_module.qvalue_network_params.flatten_keys().values()) actor_params = list(loss_module.actor_network_params.flatten_keys().values()) diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index f7ba0aaa21e..c5155405db4 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -6,7 +6,7 
@@ from .a2c import A2CLoss from .common import LossModule from .cql import CQLLoss, DiscreteCQLLoss -from .crossQ import CrossQLoss +from .crossq import CrossQLoss from .ddpg import DDPGLoss from .decision_transformer import DTLoss, OnlineDTLoss from .dqn import DistributionalDQNLoss, DQNLoss diff --git a/torchrl/objectives/crossQ.py b/torchrl/objectives/crossq.py similarity index 84% rename from torchrl/objectives/crossQ.py rename to torchrl/objectives/crossq.py index 2cd049e7811..754d8928733 100644 --- a/torchrl/objectives/crossQ.py +++ b/torchrl/objectives/crossq.py @@ -52,10 +52,6 @@ class CrossQLoss(LossModule): qvalue_network (TensorDictModule): Q(s, a) parametric model. This module typically outputs a ``"state_action_value"`` entry. - .. note:: - If not provided, the second version of SAC is assumed, where - only the Q-Value network is needed. - num_qvalue_nets (integer, optional): number of Q-Value networks used. Defaults to ``2``. loss_function (str, optional): loss function to be used with @@ -98,7 +94,7 @@ class CrossQLoss(LossModule): >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule - >>> from torchrl.objectives.sac import SACLoss + >>> from torchrl.objectives.crossq import CrossQLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) @@ -119,11 +115,7 @@ class CrossQLoss(LossModule): >>> qvalue = ValueOperator( ... module=module, ... in_keys=['observation', 'action']) - >>> module = nn.Linear(n_obs, 1) - >>> value = ValueOperator( - ... module=module, - ... in_keys=["observation"]) - >>> loss = SACLoss(actor, qvalue, value) + >>> loss = CrossQLoss(actor, qvalue) >>> batch = [2, ] >>> action = spec.rand(batch) >>> data = TensorDict({ @@ -141,8 +133,7 @@ class CrossQLoss(LossModule): entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), loss_alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - loss_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, + loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) @@ -150,9 +141,9 @@ class CrossQLoss(LossModule): This class is compatible with non-tensordict based modules too and can be used without recurring to any tensordict-related primitive. In this case, the expected keyword arguments are: - ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor, value, and qvalue network. + ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and qvalue network. The return value is a tuple of tensors in the following order: - ``["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy"]`` + ``"loss_value"`` if version one is used. 
+ ``["loss_actor", "loss_qvalue", "loss_alpha", "alpha", "entropy"]`` Examples: >>> import torch @@ -161,7 +152,7 @@ class CrossQLoss(LossModule): >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule - >>> from torchrl.objectives.sac import SACLoss + >>> from torchrl.objectives import CrossQLoss >>> _ = torch.manual_seed(42) >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) @@ -182,14 +173,10 @@ class CrossQLoss(LossModule): >>> qvalue = ValueOperator( ... module=module, ... in_keys=['observation', 'action']) - >>> module = nn.Linear(n_obs, 1) - >>> value = ValueOperator( - ... module=module, - ... in_keys=["observation"]) - >>> loss = SACLoss(actor, qvalue, value) + >>> loss = CrossQLoss(actor, qvalue) >>> batch = [2, ] >>> action = spec.rand(batch) - >>> loss_actor, loss_qvalue, _, _, _, _ = loss( + >>> loss_actor, loss_qvalue, _, _, _ = loss( ... observation=torch.randn(*batch, n_obs), ... action=action, ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), @@ -198,7 +185,7 @@ class CrossQLoss(LossModule): ... next_reward=torch.randn(*batch, 1)) >>> loss_actor.backward() - The output keys can also be filtered using the :meth:`SACLoss.select_out_keys` + The output keys can also be filtered using the :meth:`CrossQLoss.select_out_keys` method. Examples: @@ -344,6 +331,9 @@ def __init__( self._vmap_qnetworkN0 = _vmap_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) + self._vmap_qnetwork00 = _vmap_func( + self.qvalue_network, randomness=self.vmap_randomness + ) self.reduction = reduction @property @@ -435,8 +425,6 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams raise NotImplementedError(f"Unknown value type {value_type}") tensor_keys = { - # "value_target": "value_target", - # "value": self.tensor_keys.value, "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, "terminated": self.tensor_keys.terminated, @@ -461,8 +449,6 @@ def _set_in_keys(self): *[("next", key) for key in self.actor_network.in_keys], *self.qvalue_network.in_keys, ] - if self._version == 1: - keys.extend(self.value_network.in_keys) self._in_keys = list(set(keys)) @property @@ -515,9 +501,11 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: } td_out = TensorDict(out, []) td_out = td_out.named_apply( - lambda name, value: _reduce(value, reduction=self.reduction) - if name.startswith("loss_") - else value, + lambda name, value: ( + _reduce(value, reduction=self.reduction) + if name.startswith("loss_") + else value + ), batch_size=[], ) return td_out @@ -527,6 +515,13 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: def _cached_detached_qvalue_params(self): return self.qvalue_network_params.detach() + @property + @_cache_values + def _cached_qvalue_params(self): + return torch.cat( + [self.qvalue_network_params, self.qvalue_network_params.detach()], 0 + ) + def _actor_loss( self, tensordict: TensorDictBase ) -> Tuple[Tensor, Dict[str, Tensor]]: @@ -555,9 +550,9 @@ def _actor_loss( return self._alpha * log_prob - min_q_logprob, {"log_prob": log_prob.detach()} def _compute_target(self, tensordict) -> Tensor: - r"""Value network for SAC v2. + r"""Value network for CrossQ. - SAC v2 is based on a value estimate of the form: + CrossQ is based on a value estimate of the form: .. 
math:: @@ -567,7 +562,6 @@ def _compute_target(self, tensordict) -> Tensor: """ tensordict = tensordict.clone(False) - # TODO asser that models are in train mode # get actions and log-probs with torch.no_grad(): with set_exploration_type( @@ -579,32 +573,32 @@ def _compute_target(self, tensordict) -> Tensor: next_tensordict.set(self.tensor_keys.action, next_action) next_sample_log_prob = next_dist.log_prob(next_action) - # get q-values - next_tensordict_expand = self._vmap_qnetworkN0( - next_tensordict, self.qvalue_network_params - ) - state_action_value = next_tensordict_expand.get( - self.tensor_keys.state_action_value - ) - if ( - state_action_value.shape[-len(next_sample_log_prob.shape) :] - != next_sample_log_prob.shape - ): - next_sample_log_prob = next_sample_log_prob.unsqueeze(-1) - next_state_value = state_action_value - self._alpha * next_sample_log_prob - next_state_value = next_state_value.min(0)[0] - tensordict.set( - ("next", self.value_estimator.tensor_keys.value), next_state_value - ) - target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) - return target_value + # get q-values + next_tensordict_expand = self._vmap_qnetworkN0( + next_tensordict, self.qvalue_network_params + ) + state_action_value = next_tensordict_expand.get( + self.tensor_keys.state_action_value + ) + if ( + state_action_value.shape[-len(next_sample_log_prob.shape) :] + != next_sample_log_prob.shape + ): + next_sample_log_prob = next_sample_log_prob.unsqueeze(-1) + next_state_value = state_action_value - self._alpha * next_sample_log_prob + next_state_value = next_state_value.min(0)[0] + tensordict.set( + ("next", self.value_estimator.tensor_keys.value), next_state_value + ) + target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) + return target_value def _qvalue_loss( self, tensordict: TensorDictBase ) -> Tuple[Tensor, Dict[str, Tensor]]: # we pass the alpha value to the tensordict. Since it's a scalar, we must erase the batch-size first. 
- target_value = self._compute_target(tensordict) + target_value = self._compute_target(tensordict) tensordict_expand = self._vmap_qnetworkN0( tensordict.select(*self.qvalue_network.in_keys, strict=False), self.qvalue_network_params, @@ -612,6 +606,64 @@ def _qvalue_loss( pred_val = tensordict_expand.get(self.tensor_keys.state_action_value).squeeze( -1 ) + + # ############################ + # # compute next action + # with torch.no_grad(): + # with set_exploration_type( + # ExplorationType.MODE + # ), self.actor_network_params.to_module(self.actor_network): + # next_tensordict = tensordict.get("next").clone(False) + # next_dist = self.actor_network.get_dist(next_tensordict) + # next_action = next_dist.loc #.rsample() + # next_tensordict.set(self.tensor_keys.action, next_action) + # next_sample_log_prob = next_dist.log_prob(next_action) + + # q_values_tensordict = torch.cat( + # [ + # tensordict.select(*self.qvalue_network.in_keys, strict=False).expand( + # self.num_qvalue_nets, *tensordict.batch_size + # ), + # next_tensordict.select( + # *self.qvalue_network.in_keys, strict=False + # ).expand(self.num_qvalue_nets, *tensordict.batch_size), + # ], + # 0, + # ) # shape (4, batch_size, *) + # q_values_tensordict = q_values_tensordict.contiguous() + + # q_values_tensordict = self._vmap_qnetwork00( + # q_values_tensordict, self._cached_qvalue_params + # ) + # # split q values + # (current_state_action_value, next_state_action_value) = q_values_tensordict.get( + # self.tensor_keys.state_action_value + # ).split( + # [ + # self.num_qvalue_nets, + # self.num_qvalue_nets, + # ], + # dim=0, + # ) + # # compute target value + # next_state_action_value = next_state_action_value.detach() + # if ( + # next_state_action_value.shape[-len(next_sample_log_prob.shape) :] + # != next_sample_log_prob.shape + # ): + # next_sample_log_prob = next_sample_log_prob.unsqueeze(-1) + # next_state_action_value = ( + # next_state_action_value - self._alpha * next_sample_log_prob + # ) + # next_state_action_value = next_state_action_value.min(0)[0] + # tensordict.set( + # ("next", self.value_estimator.tensor_keys.value), next_state_action_value + # ) + # target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) + # # get current q-values + # pred_val = current_state_action_value.squeeze(-1) + ############################ + # compute loss td_error = abs(pred_val - target_value) loss_qval = distance_loss( pred_val, From c3a927fa3ea1994f3bd53ad2cfaf44eb145feef3 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 21 Mar 2024 18:21:30 +0100 Subject: [PATCH 05/37] update add tests --- .../linux_examples/scripts/run_test.sh | 12 + test/test_cost.py | 750 ++++++++++++++++++ 2 files changed, 762 insertions(+) diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index e75f4b1bc1c..c74d86435d9 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -166,6 +166,18 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/discrete_sac/d logger.backend= # logger.record_video=True \ # logger.record_frames=4 \ +python .github/unittest/helpers/coverage_run_parallel.py examples/crossq/crossq.py \ + collector.total_frames=48 \ + collector.init_random_frames=10 \ + collector.frames_per_batch=16 \ + collector.env_per_collector=2 \ + collector.device=cuda:0 \ + optim.batch_size=10 \ + optim.utd_ratio=1 \ + replay_buffer.size=120 \ + env.name=Pendulum-v1 \ + network.device=cuda:0 \ + 
logger.backend= python .github/unittest/helpers/coverage_run_parallel.py examples/dreamer/dreamer.py \ total_frames=200 \ init_random_frames=10 \ diff --git a/test/test_cost.py b/test/test_cost.py index 58f56d8ed89..3fb1633884a 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -98,6 +98,7 @@ A2CLoss, ClipPPOLoss, CQLLoss, + CrossQLoss, DDPGLoss, DiscreteCQLLoss, DiscreteIQLLoss, @@ -4211,6 +4212,755 @@ def test_discrete_sac_reduction(self, reduction): assert loss[key].shape == torch.Size([]) +@pytest.mark.skipif( + not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" +) +class TestCrossQ(LossModuleTestBase): + seed = 0 + + def _create_mock_actor( + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + observation_key="observation", + action_key="action", + ): + # Actor + action_spec = BoundedTensorSpec( + -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) + ) + net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + module = TensorDictModule( + net, in_keys=[observation_key], out_keys=["loc", "scale"] + ) + actor = ProbabilisticActor( + module=module, + in_keys=["loc", "scale"], + spec=action_spec, + distribution_class=TanhNormal, + out_keys=[action_key], + ) + return actor.to(device) + + def _create_mock_qvalue( + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + observation_key="observation", + action_key="action", + out_keys=None, + ): + class ValueClass(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(obs_dim + action_dim, 1) + + def forward(self, obs, act): + return self.linear(torch.cat([obs, act], -1)) + + module = ValueClass() + qvalue = ValueOperator( + module=module, + in_keys=[observation_key, action_key], + out_keys=out_keys, + ) + return qvalue.to(device) + + def _create_mock_common_layer_setup( + self, n_obs=3, n_act=4, ncells=4, batch=2, n_hidden=2 + ): + common = MLP( + num_cells=ncells, + in_features=n_obs, + depth=3, + out_features=n_hidden, + ) + actor_net = MLP( + num_cells=ncells, + in_features=n_hidden, + depth=1, + out_features=2 * n_act, + ) + qvalue = MLP( + in_features=n_hidden + n_act, + num_cells=ncells, + depth=1, + out_features=1, + ) + batch = [batch] + td = TensorDict( + { + "obs": torch.randn(*batch, n_obs), + "action": torch.randn(*batch, n_act), + "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), + "next": { + "obs": torch.randn(*batch, n_obs), + "reward": torch.randn(*batch, 1), + "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), + }, + }, + batch, + ) + common = Mod(common, in_keys=["obs"], out_keys=["hidden"]) + actor = ProbSeq( + common, + Mod(actor_net, in_keys=["hidden"], out_keys=["param"]), + Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]), + ProbMod( + in_keys=["loc", "scale"], + out_keys=["action"], + distribution_class=TanhNormal, + ), + ) + qvalue_head = Mod( + qvalue, in_keys=["hidden", "action"], out_keys=["state_action_value"] + ) + qvalue = Seq(common, qvalue_head) + return actor, qvalue, common, td + + def _create_mock_distributional_actor( + self, batch=2, obs_dim=3, action_dim=4, atoms=5, vmin=1, vmax=5 + ): + raise NotImplementedError + + def _create_mock_data_crossq( + self, + batch=16, + obs_dim=3, + action_dim=4, + atoms=None, + device="cpu", + observation_key="observation", + action_key="action", + done_key="done", + terminated_key="terminated", + reward_key="reward", + ): + # create a 
tensordict + obs = torch.randn(batch, obs_dim, device=device) + next_obs = torch.randn(batch, obs_dim, device=device) + if atoms: + raise NotImplementedError + else: + action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) + reward = torch.randn(batch, 1, device=device) + done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) + td = TensorDict( + batch_size=(batch,), + source={ + observation_key: obs, + "next": { + observation_key: next_obs, + done_key: done, + terminated_key: terminated, + reward_key: reward, + }, + action_key: action, + }, + device=device, + ) + return td + + def _create_seq_mock_data_crossq( + self, batch=8, T=4, obs_dim=3, action_dim=4, atoms=None, device="cpu" + ): + # create a tensordict + total_obs = torch.randn(batch, T + 1, obs_dim, device=device) + obs = total_obs[:, :T] + next_obs = total_obs[:, 1:] + if atoms: + action = torch.randn(batch, T, atoms, action_dim, device=device).clamp( + -1, 1 + ) + else: + action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) + reward = torch.randn(batch, T, 1, device=device) + done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + mask = torch.ones(batch, T, dtype=torch.bool, device=device) + td = TensorDict( + batch_size=(batch, T), + source={ + "observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0), + "next": { + "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), + "done": done, + "terminated": terminated, + "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), + }, + "collector": {"mask": mask}, + "action": action.masked_fill_(~mask.unsqueeze(-1), 0.0), + }, + names=[None, "time"], + device=device, + ) + return td + + @pytest.mark.parametrize("delay_actor", (True, False)) + @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) + def test_crossq( + self, + delay_actor, + num_qvalue, + device, + version, + td_est, + ): + torch.manual_seed(self.seed) + td = self._create_mock_data_sac(device=device) + + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + + value = None + + kwargs = {} + if delay_actor: + kwargs["delay_actor"] = True + + loss_fn = CrossQLoss( + actor_network=actor, + qvalue_network=qvalue, + num_qvalue_nets=num_qvalue, + loss_function="l2", + **kwargs, + ) + + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + with pytest.raises(NotImplementedError): + loss_fn.make_value_estimator(td_est) + return + if td_est is not None: + loss_fn.make_value_estimator(td_est) + + with _check_td_steady(td), pytest.warns( + UserWarning, match="No target network updater" + ): + loss = loss_fn(td) + + assert loss_fn.tensor_keys.priority in td.keys() + + # check that losses are independent + for k in loss.keys(): + if not k.startswith("loss"): + continue + loss[k].sum().backward(retain_graph=True) + if k == "loss_actor": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + elif k == "loss_qvalue": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in 
loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ) + elif k == "loss_alpha": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ) + else: + raise NotImplementedError(k) + loss_fn.zero_grad() + + sum( + [item for name, item in loss.items() if name.startswith("loss_")] + ).backward() + named_parameters = list(loss_fn.named_parameters()) + named_buffers = list(loss_fn.named_buffers()) + + assert len({p for n, p in named_parameters}) == len(list(named_parameters)) + assert len({p for n, p in named_buffers}) == len(list(named_buffers)) + + for name, p in named_parameters: + if not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" + + @pytest.mark.parametrize("delay_actor", (True, False)) + @pytest.mark.parametrize("num_qvalue", [2]) + @pytest.mark.parametrize("device", get_default_devices()) + def test_crossq_state_dict( + self, + delay_actor, + num_qvalue, + device, + version, + ): + torch.manual_seed(self.seed) + + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + + value = None + + kwargs = {} + if delay_actor: + kwargs["delay_actor"] = True + + loss_fn = CrossQLoss( + actor_network=actor, + qvalue_network=qvalue, + num_qvalue_nets=num_qvalue, + loss_function="l2", + **kwargs, + ) + sd = loss_fn.state_dict() + loss_fn2 = CrossQLoss( + actor_network=actor, + qvalue_network=qvalue, + num_qvalue_nets=num_qvalue, + loss_function="l2", + **kwargs, + ) + loss_fn2.load_state_dict(sd) + + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("separate_losses", [False, True]) + def test_crossq_separate_losses( + self, + device, + separate_losses, + version, + n_act=4, + ): + torch.manual_seed(self.seed) + actor, qvalue, common, td = self._create_mock_common_layer_setup(n_act=n_act) + + loss_fn = CrossQLoss( + actor_network=actor, + qvalue_network=qvalue, + action_spec=UnboundedContinuousTensorSpec(shape=(n_act,)), + num_qvalue_nets=1, + separate_losses=separate_losses, + ) + with pytest.warns(UserWarning, match="No target network updater has been"): + loss = loss_fn(td) + + assert loss_fn.tensor_keys.priority in td.keys() + + # check that losses are independent + for k in loss.keys(): + if not k.startswith("loss"): + continue + loss[k].sum().backward(retain_graph=True) + if k == "loss_actor": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + elif k == "loss_qvalue": + common_layers_no = len(list(common.parameters())) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + if separate_losses: + 
common_layers = itertools.islice( + loss_fn.qvalue_network_params.values(True, True), + common_layers_no, + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in common_layers + ) + qvalue_layers = itertools.islice( + loss_fn.qvalue_network_params.values(True, True), + common_layers_no, + None, + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in qvalue_layers + ) + else: + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values(True, True) + ) + elif k == "loss_alpha": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ) + else: + raise NotImplementedError(k) + loss_fn.zero_grad() + + @pytest.mark.parametrize("n", range(1, 4)) + @pytest.mark.parametrize("delay_actor", (True, False)) + @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) + @pytest.mark.parametrize("device", get_default_devices()) + def test_crossq_batcher( + self, + n, + delay_actor, + num_qvalue, + device, + version, + ): + torch.manual_seed(self.seed) + td = self._create_seq_mock_data_sac(device=device) + + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + + value = None + + kwargs = {} + if delay_actor: + kwargs["delay_actor"] = True + + loss_fn = CrossQLoss( + actor_network=actor, + qvalue_network=qvalue, + num_qvalue_nets=num_qvalue, + loss_function="l2", + **kwargs, + ) + + ms = MultiStep(gamma=0.9, n_steps=n).to(device) + + td_clone = td.clone() + ms_td = ms(td_clone) + + torch.manual_seed(0) + np.random.seed(0) + with pytest.warns( + UserWarning, + match="No target network updater has been associated with this loss module", + ): + with _check_td_steady(ms_td): + loss_ms = loss_fn(ms_td) + assert loss_fn.tensor_keys.priority in ms_td.keys() + + with torch.no_grad(): + torch.manual_seed(0) # log-prob is computed with a random action + np.random.seed(0) + loss = loss_fn(td) + if n == 1: + assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) + _loss = sum( + [item for name, item in loss.items() if name.startswith("loss_")] + ) + _loss_ms = sum( + [item for name, item in loss_ms.items() if name.startswith("loss_")] + ) + assert ( + abs(_loss - _loss_ms) < 1e-3 + ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" + else: + with pytest.raises(AssertionError): + assert_allclose_td(loss, loss_ms) + sum( + [item for name, item in loss_ms.items() if name.startswith("loss_")] + ).backward() + named_parameters = loss_fn.named_parameters() + for name, p in named_parameters: + if not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" + + # Check param update effect on targets + target_actor = [ + p.clone() + for p in loss_fn.target_actor_network_params.values( + include_nested=True, leaves_only=True + ) + ] + assert not hasattr(loss_fn, "target_qvalue_network_params") + for p in loss_fn.parameters(): + if p.requires_grad: + p.data += torch.randn_like(p) + target_actor2 = [ + p.clone() + for p in loss_fn.target_actor_network_params.values( + include_nested=True, leaves_only=True + ) + ] + + if 
loss_fn.delay_actor: + assert all( + (p1 == p2).all() for p1, p2 in zip(target_actor, target_actor2) + ) + else: + assert not any( + (p1 == p2).any() for p1, p2 in zip(target_actor, target_actor2) + ) + + # check that policy is updated after parameter update + parameters = [p.clone() for p in actor.parameters()] + for p in loss_fn.parameters(): + if p.requires_grad: + p.data += torch.randn_like(p) + assert all( + (p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters()) + ) + + @pytest.mark.parametrize( + "td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda] + ) + def test_crossq_tensordict_keys(self, td_est, version): + td = self._create_mock_data_crossq() + + actor = self._create_mock_actor() + qvalue = self._create_mock_qvalue() + value = None + + loss_fn = CrossQLoss( + actor_network=actor, + qvalue_network=qvalue, + num_qvalue_nets=2, + loss_function="l2", + ) + + default_keys = { + "priority": "td_error", + "state_action_value": "state_action_value", + "action": "action", + "log_prob": "_log_prob", + "reward": "reward", + "done": "done", + "terminated": "terminated", + } + + self.tensordict_keys_test( + loss_fn, + default_keys=default_keys, + td_est=td_est, + ) + + value = self._create_mock_value() + loss_fn = CrossQLoss( + actor, + value, + loss_function="l2", + ) + + key_mapping = { + "reward": ("reward", "reward_test"), + "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), + } + self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) + + @pytest.mark.parametrize("action_key", ["action", "action2"]) + @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) + @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) + @pytest.mark.parametrize("done_key", ["done", "done2"]) + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) + def test_crossq_notensordict( + self, action_key, observation_key, reward_key, done_key, terminated_key, version + ): + torch.manual_seed(self.seed) + td = self._create_mock_data_crossq( + action_key=action_key, + observation_key=observation_key, + reward_key=reward_key, + done_key=done_key, + terminated_key=terminated_key, + ) + + actor = self._create_mock_actor( + observation_key=observation_key, action_key=action_key + ) + qvalue = self._create_mock_qvalue( + observation_key=observation_key, + action_key=action_key, + out_keys=["state_action_value"], + ) + + value = None + + loss = CrossQLoss( + actor_network=actor, + qvalue_network=qvalue, + ) + loss.set_keys( + action=action_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, + ) + + kwargs = { + action_key: td.get(action_key), + observation_key: td.get(observation_key), + f"next_{reward_key}": td.get(("next", reward_key)), + f"next_{done_key}": td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), + f"next_{observation_key}": td.get(("next", observation_key)), + } + td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") + + # setting the seed for each loss so that drawing the random samples from value network + # leads to same numbers for both runs + torch.manual_seed(self.seed) + with pytest.warns(UserWarning, match="No target network updater"): + loss_val = loss(**kwargs) + + torch.manual_seed(self.seed) + + loss_val_td = loss(td) + assert len(loss_val) == 5 + + torch.testing.assert_close(loss_val_td.get("loss_actor"), loss_val[0]) + torch.testing.assert_close(loss_val_td.get("loss_qvalue"), loss_val[1]) + 
torch.testing.assert_close(loss_val_td.get("loss_alpha"), loss_val[2]) + torch.testing.assert_close(loss_val_td.get("alpha"), loss_val[3]) + torch.testing.assert_close(loss_val_td.get("entropy"), loss_val[4]) + + # test select + torch.manual_seed(self.seed) + loss.select_out_keys("loss_actor", "loss_alpha") + if torch.__version__ >= "2.0.0": + loss_actor, loss_alpha = loss(**kwargs) + else: + with pytest.raises( + RuntimeError, + match="You are likely using tensordict.nn.dispatch with keyword arguments", + ): + loss_actor, loss_alpha = loss(**kwargs) + return + assert loss_actor == loss_val_td["loss_actor"] + assert loss_alpha == loss_val_td["loss_alpha"] + + def test_state_dict(self, version): + + model = torch.nn.Linear(3, 4) + actor_module = TensorDictModule(model, in_keys=["obs"], out_keys=["logits"]) + policy = ProbabilisticActor( + module=actor_module, + in_keys=["logits"], + out_keys=["action"], + distribution_class=TanhDelta, + ) + value = ValueOperator(module=model, in_keys=["obs"], out_keys="value") + + loss = CrossQLoss( + actor_network=policy, + qvalue_network=value, + action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + ) + state = loss.state_dict() + + loss = CrossQLoss( + actor_network=policy, + qvalue_network=value, + action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + ) + loss.load_state_dict(state) + + # with an access in between + loss = CrossQLoss( + actor_network=policy, + qvalue_network=value, + action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + ) + loss.target_entropy + state = loss.state_dict() + + loss = CrossQLoss( + actor_network=policy, + qvalue_network=value, + action_spec=UnboundedContinuousTensorSpec(shape=(2,)), + ) + loss.load_state_dict(state) + + @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) + def test_crossq_reduction(self, reduction, version): + torch.manual_seed(self.seed) + device = ( + torch.device("cpu") + if torch.cuda.device_count() == 0 + else torch.device("cuda") + ) + td = self._create_mock_data_crossq(device=device) + actor = self._create_mock_actor(device=device) + qvalue = self._create_mock_qvalue(device=device) + value = None + loss_fn = CrossQLoss( + actor_network=actor, + qvalue_network=qvalue, + loss_function="l2", + delay_actor=False, + reduction=reduction, + ) + loss_fn.make_value_estimator() + loss = loss_fn(td) + if reduction == "none": + for key in loss.keys(): + if key.startswith("loss"): + assert loss[key].shape == td.shape + else: + for key in loss.keys(): + if not key.startswith("loss"): + continue + assert loss[key].shape == torch.Size([]) + + @pytest.mark.skipif( not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" ) From d1c9c348ffca606af07b5c20e4acec1bbb38d7b1 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 21 Mar 2024 18:22:13 +0100 Subject: [PATCH 06/37] detach --- torchrl/objectives/crossq.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index 754d8928733..1a1c9db967a 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -591,7 +591,7 @@ def _compute_target(self, tensordict) -> Tensor: ("next", self.value_estimator.tensor_keys.value), next_state_value ) target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) - return target_value + return target_value.detach() def _qvalue_loss( self, tensordict: TensorDictBase @@ -659,10 +659,10 @@ def _qvalue_loss( # tensordict.set( # ("next", self.value_estimator.tensor_keys.value), 
next_state_action_value # ) - # target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) + # target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1).detach() # # get current q-values # pred_val = current_state_action_value.squeeze(-1) - ############################ + # ########################### # compute loss td_error = abs(pred_val - target_value) loss_qval = distance_loss( From e879b7cfe943179f12139233dc9637d678baf1c4 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 21 Mar 2024 18:44:34 +0100 Subject: [PATCH 07/37] update tests --- test/test_cost.py | 52 ++++++++++++++++++----------------------------- 1 file changed, 20 insertions(+), 32 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 3fb1633884a..d5b65e63771 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -257,9 +257,9 @@ def __init__(self): self.vmap_model = _vmap_func( self.model, (None, 0), - randomness="error" - if vmap_randomness == "error" - else self.vmap_randomness, + randomness=( + "error" if vmap_randomness == "error" else self.vmap_randomness + ), ) def forward(self, td): @@ -315,9 +315,9 @@ def _create_mock_actor( spec=CompositeSpec( { "action": action_spec, - "action_value" - if action_value_key is None - else action_value_key: None, + ( + "action_value" if action_value_key is None else action_value_key + ): None, "chosen_action_value": None, }, shape=[], @@ -4412,17 +4412,14 @@ def test_crossq( delay_actor, num_qvalue, device, - version, td_est, ): torch.manual_seed(self.seed) - td = self._create_mock_data_sac(device=device) + td = self._create_mock_data_crossq(device=device) actor = self._create_mock_actor(device=device) qvalue = self._create_mock_qvalue(device=device) - value = None - kwargs = {} if delay_actor: kwargs["delay_actor"] = True @@ -4524,15 +4521,12 @@ def test_crossq_state_dict( delay_actor, num_qvalue, device, - version, ): torch.manual_seed(self.seed) actor = self._create_mock_actor(device=device) qvalue = self._create_mock_qvalue(device=device) - value = None - kwargs = {} if delay_actor: kwargs["delay_actor"] = True @@ -4560,7 +4554,6 @@ def test_crossq_separate_losses( self, device, separate_losses, - version, n_act=4, ): torch.manual_seed(self.seed) @@ -4654,16 +4647,13 @@ def test_crossq_batcher( delay_actor, num_qvalue, device, - version, ): torch.manual_seed(self.seed) - td = self._create_seq_mock_data_sac(device=device) + td = self._create_seq_mock_data_crossq(device=device) actor = self._create_mock_actor(device=device) qvalue = self._create_mock_qvalue(device=device) - value = None - kwargs = {} if delay_actor: kwargs["delay_actor"] = True @@ -4730,7 +4720,6 @@ def test_crossq_batcher( include_nested=True, leaves_only=True ) ] - assert not hasattr(loss_fn, "target_qvalue_network_params") for p in loss_fn.parameters(): if p.requires_grad: p.data += torch.randn_like(p) @@ -4762,8 +4751,7 @@ def test_crossq_batcher( @pytest.mark.parametrize( "td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda] ) - def test_crossq_tensordict_keys(self, td_est, version): - td = self._create_mock_data_crossq() + def test_crossq_tensordict_keys(self, td_est): actor = self._create_mock_actor() qvalue = self._create_mock_qvalue() @@ -4792,10 +4780,10 @@ def test_crossq_tensordict_keys(self, td_est, version): td_est=td_est, ) - value = self._create_mock_value() + qvalue = self._create_mock_qvalue() loss_fn = CrossQLoss( actor, - value, + qvalue, loss_function="l2", ) @@ -4812,7 +4800,7 @@ def test_crossq_tensordict_keys(self, 
td_est, version): @pytest.mark.parametrize("done_key", ["done", "done2"]) @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) def test_crossq_notensordict( - self, action_key, observation_key, reward_key, done_key, terminated_key, version + self, action_key, observation_key, reward_key, done_key, terminated_key ): torch.manual_seed(self.seed) td = self._create_mock_data_crossq( @@ -4832,8 +4820,6 @@ def test_crossq_notensordict( out_keys=["state_action_value"], ) - value = None - loss = CrossQLoss( actor_network=actor, qvalue_network=qvalue, @@ -4887,7 +4873,9 @@ def test_crossq_notensordict( assert loss_actor == loss_val_td["loss_actor"] assert loss_alpha == loss_val_td["loss_alpha"] - def test_state_dict(self, version): + def test_state_dict( + self, + ): model = torch.nn.Linear(3, 4) actor_module = TensorDictModule(model, in_keys=["obs"], out_keys=["logits"]) @@ -4930,7 +4918,7 @@ def test_state_dict(self, version): loss.load_state_dict(state) @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) - def test_crossq_reduction(self, reduction, version): + def test_crossq_reduction(self, reduction): torch.manual_seed(self.seed) device = ( torch.device("cpu") @@ -4940,7 +4928,7 @@ def test_crossq_reduction(self, reduction, version): td = self._create_mock_data_crossq(device=device) actor = self._create_mock_actor(device=device) qvalue = self._create_mock_qvalue(device=device) - value = None + loss_fn = CrossQLoss( actor_network=actor, qvalue_network=qvalue, @@ -6431,9 +6419,9 @@ def _create_mock_actor( spec=CompositeSpec( { "action": action_spec, - "action_value" - if action_value_key is None - else action_value_key: None, + ( + "action_value" if action_value_key is None else action_value_key + ): None, "chosen_action_value": None, }, shape=[], From a7b79c3558dbb0786a15563bf76f26d506c0f51e Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 21 Mar 2024 18:53:12 +0100 Subject: [PATCH 08/37] move crossq to sota-implementations --- {examples => sota-implementations}/crossq/config.yaml | 0 {examples => sota-implementations}/crossq/crossq.py | 0 {examples => sota-implementations}/crossq/utils.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename {examples => sota-implementations}/crossq/config.yaml (100%) rename {examples => sota-implementations}/crossq/crossq.py (100%) rename {examples => sota-implementations}/crossq/utils.py (100%) diff --git a/examples/crossq/config.yaml b/sota-implementations/crossq/config.yaml similarity index 100% rename from examples/crossq/config.yaml rename to sota-implementations/crossq/config.yaml diff --git a/examples/crossq/crossq.py b/sota-implementations/crossq/crossq.py similarity index 100% rename from examples/crossq/crossq.py rename to sota-implementations/crossq/crossq.py diff --git a/examples/crossq/utils.py b/sota-implementations/crossq/utils.py similarity index 100% rename from examples/crossq/utils.py rename to sota-implementations/crossq/utils.py From be84f3fcba545541cffa91a5997c71b7e92985c5 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 26 Mar 2024 19:39:01 +0100 Subject: [PATCH 09/37] update loss --- torchrl/objectives/crossq.py | 114 +++++++++++------------------------ 1 file changed, 35 insertions(+), 79 deletions(-) diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index 1a1c9db967a..c467acd1426 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -519,7 +519,7 @@ def _cached_detached_qvalue_params(self): @_cache_values def _cached_qvalue_params(self): return 
torch.cat( - [self.qvalue_network_params, self.qvalue_network_params.detach()], 0 + [self.qvalue_network_params, self.qvalue_network_params], 0 # .detach() ) def _actor_loss( @@ -549,76 +549,22 @@ def _actor_loss( return self._alpha * log_prob - min_q_logprob, {"log_prob": log_prob.detach()} - def _compute_target(self, tensordict) -> Tensor: - r"""Value network for CrossQ. - - CrossQ is based on a value estimate of the form: - - .. math:: - - V = Q(s,a) - \alpha * \log p(a | s) - - This class computes this value given the actor and qvalue network + def _qvalue_loss( + self, tensordict: TensorDictBase + ) -> Tuple[Tensor, Dict[str, Tensor]]: - """ - tensordict = tensordict.clone(False) - # get actions and log-probs + # # compute next action with torch.no_grad(): with set_exploration_type( ExplorationType.RANDOM ), self.actor_network_params.to_module(self.actor_network): next_tensordict = tensordict.get("next").clone(False) next_dist = self.actor_network.get_dist(next_tensordict) - next_action = next_dist.rsample() + next_action = next_dist.sample() next_tensordict.set(self.tensor_keys.action, next_action) next_sample_log_prob = next_dist.log_prob(next_action) - # get q-values - next_tensordict_expand = self._vmap_qnetworkN0( - next_tensordict, self.qvalue_network_params - ) - state_action_value = next_tensordict_expand.get( - self.tensor_keys.state_action_value - ) - if ( - state_action_value.shape[-len(next_sample_log_prob.shape) :] - != next_sample_log_prob.shape - ): - next_sample_log_prob = next_sample_log_prob.unsqueeze(-1) - next_state_value = state_action_value - self._alpha * next_sample_log_prob - next_state_value = next_state_value.min(0)[0] - tensordict.set( - ("next", self.value_estimator.tensor_keys.value), next_state_value - ) - target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) - return target_value.detach() - - def _qvalue_loss( - self, tensordict: TensorDictBase - ) -> Tuple[Tensor, Dict[str, Tensor]]: - # we pass the alpha value to the tensordict. Since it's a scalar, we must erase the batch-size first. 
- - target_value = self._compute_target(tensordict) - tensordict_expand = self._vmap_qnetworkN0( - tensordict.select(*self.qvalue_network.in_keys, strict=False), - self.qvalue_network_params, - ) - pred_val = tensordict_expand.get(self.tensor_keys.state_action_value).squeeze( - -1 - ) - - # ############################ - # # compute next action - # with torch.no_grad(): - # with set_exploration_type( - # ExplorationType.MODE - # ), self.actor_network_params.to_module(self.actor_network): - # next_tensordict = tensordict.get("next").clone(False) - # next_dist = self.actor_network.get_dist(next_tensordict) - # next_action = next_dist.loc #.rsample() - # next_tensordict.set(self.tensor_keys.action, next_action) - # next_sample_log_prob = next_dist.log_prob(next_action) - + # TODO: we should pass them together to the qvalue network # q_values_tensordict = torch.cat( # [ # tensordict.select(*self.qvalue_network.in_keys, strict=False).expand( @@ -645,24 +591,34 @@ def _qvalue_loss( # ], # dim=0, # ) - # # compute target value - # next_state_action_value = next_state_action_value.detach() - # if ( - # next_state_action_value.shape[-len(next_sample_log_prob.shape) :] - # != next_sample_log_prob.shape - # ): - # next_sample_log_prob = next_sample_log_prob.unsqueeze(-1) - # next_state_action_value = ( - # next_state_action_value - self._alpha * next_sample_log_prob - # ) - # next_state_action_value = next_state_action_value.min(0)[0] - # tensordict.set( - # ("next", self.value_estimator.tensor_keys.value), next_state_action_value - # ) - # target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1).detach() - # # get current q-values - # pred_val = current_state_action_value.squeeze(-1) - # ########################### + + next_state_action_value = self._vmap_qnetworkN0( + next_tensordict.select(*self.qvalue_network.in_keys, strict=False), + self.qvalue_network_params, + ).get(self.tensor_keys.state_action_value) + + current_state_action_value = self._vmap_qnetworkN0( + tensordict.select(*self.qvalue_network.in_keys, strict=False), + self.qvalue_network_params, + ).get(self.tensor_keys.state_action_value) + + # compute target value + if ( + next_state_action_value.shape[-len(next_sample_log_prob.shape) :] + != next_sample_log_prob.shape + ): + next_sample_log_prob = next_sample_log_prob.unsqueeze(-1) + next_state_action_value = next_state_action_value.min(0)[0] + next_state_action_value = ( + next_state_action_value - self._alpha * next_sample_log_prob + ).detach() + tensordict.set( + ("next", self.value_estimator.tensor_keys.value), next_state_action_value + ) + target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) + # get current q-values + pred_val = current_state_action_value.squeeze(-1) + # compute loss td_error = abs(pred_val - target_value) loss_qval = distance_loss( From 2170ad85638d9429d0fb302effa18646c56f7a10 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 26 Mar 2024 20:11:09 +0100 Subject: [PATCH 10/37] update cat prediction --- torchrl/objectives/crossq.py | 69 ++++++++++++------------------------ 1 file changed, 22 insertions(+), 47 deletions(-) diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index c467acd1426..134ce254916 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -331,9 +331,6 @@ def __init__( self._vmap_qnetworkN0 = _vmap_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) - self._vmap_qnetwork00 = _vmap_func( - self.qvalue_network, randomness=self.vmap_randomness - ) 
self.reduction = reduction @property @@ -515,13 +512,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: def _cached_detached_qvalue_params(self): return self.qvalue_network_params.detach() - @property - @_cache_values - def _cached_qvalue_params(self): - return torch.cat( - [self.qvalue_network_params, self.qvalue_network_params], 0 # .detach() - ) - def _actor_loss( self, tensordict: TensorDictBase ) -> Tuple[Tensor, Dict[str, Tensor]]: @@ -564,43 +554,28 @@ def _qvalue_loss( next_tensordict.set(self.tensor_keys.action, next_action) next_sample_log_prob = next_dist.log_prob(next_action) - # TODO: we should pass them together to the qvalue network - # q_values_tensordict = torch.cat( - # [ - # tensordict.select(*self.qvalue_network.in_keys, strict=False).expand( - # self.num_qvalue_nets, *tensordict.batch_size - # ), - # next_tensordict.select( - # *self.qvalue_network.in_keys, strict=False - # ).expand(self.num_qvalue_nets, *tensordict.batch_size), - # ], - # 0, - # ) # shape (4, batch_size, *) - # q_values_tensordict = q_values_tensordict.contiguous() - - # q_values_tensordict = self._vmap_qnetwork00( - # q_values_tensordict, self._cached_qvalue_params - # ) - # # split q values - # (current_state_action_value, next_state_action_value) = q_values_tensordict.get( - # self.tensor_keys.state_action_value - # ).split( - # [ - # self.num_qvalue_nets, - # self.num_qvalue_nets, - # ], - # dim=0, - # ) - - next_state_action_value = self._vmap_qnetworkN0( - next_tensordict.select(*self.qvalue_network.in_keys, strict=False), - self.qvalue_network_params, - ).get(self.tensor_keys.state_action_value) - - current_state_action_value = self._vmap_qnetworkN0( - tensordict.select(*self.qvalue_network.in_keys, strict=False), - self.qvalue_network_params, - ).get(self.tensor_keys.state_action_value) + # next_state_action_value = self._vmap_qnetworkN0( + # next_tensordict.select(*self.qvalue_network.in_keys, strict=False), + # self.qvalue_network_params, + # ).get(self.tensor_keys.state_action_value) + + # current_state_action_value = self._vmap_qnetworkN0( + # tensordict.select(*self.qvalue_network.in_keys, strict=False), + # self.qvalue_network_params, + # ).get(self.tensor_keys.state_action_value) + + combined = torch.cat( + [ + tensordict.select(*self.qvalue_network.in_keys, strict=False), + next_tensordict.select(*self.qvalue_network.in_keys, strict=False), + ] + ) + pred_qs = self._vmap_qnetworkN0(combined, self.qvalue_network_params).get( + self.tensor_keys.state_action_value + ) + (current_state_action_value, next_state_action_value) = pred_qs.split( + tensordict.batch_size[0], dim=1 + ) # compute target value if ( From f0ac167b2535b343a9e043e5767df96a2ed46fb4 Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 26 Jun 2024 20:02:28 +0200 Subject: [PATCH 11/37] add batchrenorm to crossq --- sota-implementations/crossq/batchrenorm.py | 96 ++++++++++++++++++++++ sota-implementations/crossq/config.yaml | 7 +- sota-implementations/crossq/utils.py | 13 +-- 3 files changed, 108 insertions(+), 8 deletions(-) create mode 100644 sota-implementations/crossq/batchrenorm.py diff --git a/sota-implementations/crossq/batchrenorm.py b/sota-implementations/crossq/batchrenorm.py new file mode 100644 index 00000000000..7bbb71b355d --- /dev/null +++ b/sota-implementations/crossq/batchrenorm.py @@ -0,0 +1,96 @@ +import torch +import torch.nn as nn + + +class BatchRenorm(nn.Module): + """ + BatchRenorm Module (https://arxiv.org/abs/1702.03275). 
+ + BatchRenorm is an enhanced version of the standard BatchNorm. Unlike BatchNorm, + BatchRenorm utilizes running statistics to normalize batches after an initial warmup phase. + This approach reduces the impact of "outlier" batches that may occur during extended training periods, + making BatchRenorm more robust for long training runs. + + During the warmup phase, BatchRenorm functions identically to a BatchNorm layer. + + Args: + num_features (int): Number of features in the input tensor. + eps (float, optional): Small value added to the variance to avoid division by zero. Default is 1e-5. + momentum (float, optional): Momentum factor for computing the running mean and variance. Default is 0.01. + r_max (float, optional): Maximum value for the scaling factor r. Default is 3.0. + d_max (float, optional): Maximum value for the bias factor d. Default is 5.0. + warmup_steps (int, optional): Number of warm-up steps for the running mean and variance. Default is 5000. + """ + + def __init__( + self, + num_features, + eps=1e-5, + momentum=0.01, + r_max=3.0, + d_max=5.0, + warmup_steps=5000, + ): + + super(BatchRenorm, self).__init__() + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.r_max = r_max + self.d_max = d_max + self.warmup_steps = warmup_steps + self.step_count = 0 + + self.gamma = nn.Parameter(torch.ones(num_features)) + self.beta = nn.Parameter(torch.zeros(num_features)) + + self.register_buffer("running_mean", torch.zeros(num_features)) + self.register_buffer("running_var", torch.ones(num_features)) + + def forward(self, x): + self.step_count += 1 + + # Compute the dimensions for mean and variance calculation + dims = [i for i in range(x.dim()) if i != 1] + expand_dims = [1 if i != 1 else -1 for i in range(x.dim())] + + # Compute batch statistics + batch_mean = x.mean(dims, keepdim=True) + batch_var = x.var(dims, unbiased=False, keepdim=True) + + if self.training: + if self.step_count <= self.warmup_steps: + # Use classical BatchNorm during warmup + x_hat = (x - batch_mean) / torch.sqrt(batch_var + self.eps) + else: + # Use Batch Renormalization + with torch.no_grad(): + r = torch.clamp( + batch_var / self.running_var.view(*expand_dims), + 1.0 / self.r_max, + self.r_max, + ) + d = torch.clamp( + (batch_mean - self.running_mean.view(*expand_dims)) + / torch.sqrt(self.running_var.view(*expand_dims) + self.eps), + -self.d_max, + self.d_max, + ) + + x_hat = (x - batch_mean) / torch.sqrt(batch_var + self.eps) + x_hat = x_hat * r + d + + # Update running statistics + self.running_mean.mul_(1 - self.momentum).add_( + batch_mean.squeeze().detach() * self.momentum + ) + self.running_var.mul_(1 - self.momentum).add_( + batch_var.squeeze().detach() * self.momentum + ) + else: + # Use running statistics during inference + x_hat = (x - self.running_mean.view(*expand_dims)) / torch.sqrt( + self.running_var.view(*expand_dims) + self.eps + ) + + return self.gamma.view(*expand_dims) * x_hat + self.beta.view(*expand_dims) diff --git a/sota-implementations/crossq/config.yaml b/sota-implementations/crossq/config.yaml index fc652a06783..8c72e996136 100644 --- a/sota-implementations/crossq/config.yaml +++ b/sota-implementations/crossq/config.yaml @@ -32,16 +32,17 @@ optim: weight_decay: 0.0 batch_size: 256 alpha_init: 1.0 - # Adam β1 = 0.5 adam_eps: 1.0e-8 + beta1: 0.5 + beta2: 0.999 # network network: batch_norm_momentum: 0.01 - # warmup_steps: 100000 # 10^5 + warmup_steps: 100000 # 10^5 critic_hidden_sizes: [2048, 2048] actor_hidden_sizes: [256, 256] - 
critic_activation: tanh + critic_activation: relu actor_activation: relu default_policy_scale: 1.0 scale_lb: 0.1 diff --git a/sota-implementations/crossq/utils.py b/sota-implementations/crossq/utils.py index b4865f79c39..d195e567688 100644 --- a/sota-implementations/crossq/utils.py +++ b/sota-implementations/crossq/utils.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import torch +from batchrenorm import BatchRenorm from tensordict.nn import InteractionType, TensorDictModule from tensordict.nn.distributions import NormalParamExtractor from torch import nn, optim @@ -26,7 +27,6 @@ from torchrl.modules.distributions import TanhNormal from torchrl.objectives import CrossQLoss - # ==================================================================== # Environment utils # ----------------- @@ -120,7 +120,6 @@ def make_replay_buffer( storage=LazyMemmapStorage( buffer_size, scratch_dir=scratch_dir, - device=device, ), batch_size=batch_size, ) @@ -131,10 +130,10 @@ def make_replay_buffer( storage=LazyMemmapStorage( buffer_size, scratch_dir=scratch_dir, - device=device, ), batch_size=batch_size, ) + replay_buffer.append_transform(lambda x: x.to(device, non_blocking=True)) return replay_buffer @@ -154,10 +153,11 @@ def make_crossQ_agent(cfg, train_env, eval_env, device): "num_cells": cfg.network.actor_hidden_sizes, "out_features": 2 * action_spec.shape[-1], "activation_class": get_activation(cfg.network.actor_activation), - "norm_class": nn.BatchNorm1d, # Should be BRN (https://arxiv.org/abs/1702.03275) not sure if added to torch + "norm_class": BatchRenorm, "norm_kwargs": { "momentum": cfg.network.batch_norm_momentum, "num_features": cfg.network.actor_hidden_sizes[-1], + "warmup_steps": cfg.network.warmup_steps, }, } @@ -200,10 +200,11 @@ def make_crossQ_agent(cfg, train_env, eval_env, device): "num_cells": cfg.network.critic_hidden_sizes, "out_features": 1, "activation_class": get_activation(cfg.network.critic_activation), - "norm_class": nn.BatchNorm1d, # Should be BRN (https://arxiv.org/abs/1702.03275) not sure if added to torch + "norm_class": BatchRenorm, "norm_kwargs": { "momentum": cfg.network.batch_norm_momentum, "num_features": cfg.network.critic_hidden_sizes[-1], + "warmup_steps": cfg.network.warmup_steps, }, } @@ -273,12 +274,14 @@ def make_crossQ_optimizer(cfg, loss_module): lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay, eps=cfg.optim.adam_eps, + betas=(cfg.optim.beta1, cfg.optim.beta2), ) optimizer_critic = optim.Adam( critic_params, lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay, eps=cfg.optim.adam_eps, + betas=(cfg.optim.beta1, cfg.optim.beta2), ) optimizer_alpha = optim.Adam( [loss_module.log_alpha], From bc7675a75f17158bf0a3da7ff522ba0af17d236e Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 26 Jun 2024 20:43:33 +0200 Subject: [PATCH 12/37] small fixes --- sota-implementations/crossq/crossq.py | 14 ++++++++------ sota-implementations/crossq/utils.py | 5 ++--- torchrl/objectives/crossq.py | 20 ++++++++++---------- 3 files changed, 20 insertions(+), 19 deletions(-) diff --git a/sota-implementations/crossq/crossq.py b/sota-implementations/crossq/crossq.py index a1ca275663a..63c25e0535a 100644 --- a/sota-implementations/crossq/crossq.py +++ b/sota-implementations/crossq/crossq.py @@ -36,6 +36,8 @@ @hydra.main(version_base="1.1", config_path=".", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(cfg.network.device) + if device is None: + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # 
Create logger exp_name = generate_exp_name("CrossQ", cfg.logger.exp_name) @@ -60,7 +62,7 @@ def main(cfg: "DictConfig"): # noqa: F821 train_env, eval_env = make_environment(cfg) # Create agent - model, exploration_policy = make_crossQ_agent(cfg, train_env, eval_env, device) + model, exploration_policy = make_crossQ_agent(cfg, train_env, device) # Create CrossQ loss loss_module = make_loss_module(cfg, model) @@ -140,7 +142,7 @@ def main(cfg: "DictConfig"): # noqa: F821 sampled_tensordict = sampled_tensordict.clone() # Compute loss - q_loss, *_ = loss_module._qvalue_loss(sampled_tensordict) + q_loss, *_ = loss_module.qvalue_loss(sampled_tensordict) q_loss = q_loss.mean() # Update critic optimizer_critic.zero_grad() @@ -149,14 +151,14 @@ def main(cfg: "DictConfig"): # noqa: F821 q_losses.append(q_loss.detach().item()) if update_actor: - actor_loss, metadata_actor = loss_module._actor_loss( + actor_loss, metadata_actor = loss_module.actor_loss( sampled_tensordict ) actor_loss = actor_loss.mean() - alpha_loss = loss_module._alpha_loss( + alpha_loss = loss_module.alpha_loss( log_prob=metadata_actor["log_prob"] - ) - alpha_loss = alpha_loss.mean() + ).mean() + # Update actor optimizer_actor.zero_grad() actor_loss.backward() diff --git a/sota-implementations/crossq/utils.py b/sota-implementations/crossq/utils.py index d195e567688..8b990fb9931 100644 --- a/sota-implementations/crossq/utils.py +++ b/sota-implementations/crossq/utils.py @@ -142,7 +142,7 @@ def make_replay_buffer( # ----- -def make_crossQ_agent(cfg, train_env, eval_env, device): +def make_crossQ_agent(cfg, train_env, device): """Make CrossQ agent.""" # Define Actor Network in_keys = ["observation"] @@ -221,14 +221,13 @@ def make_crossQ_agent(cfg, train_env, eval_env, device): # init nets with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): - td = eval_env.reset() + td = train_env.fake_tensordict() td = td.to(device) for net in model: net.eval() net(td) net.train() del td - eval_env.close() return model, model[0] diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index 134ce254916..a9257e3d98a 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -51,7 +51,7 @@ class CrossQLoss(LossModule): actor_network (ProbabilisticActor): stochastic actor qvalue_network (TensorDictModule): Q(s, a) parametric model. This module typically outputs a ``"state_action_value"`` entry. - + Keyword Args: num_qvalue_nets (integer, optional): number of Q-Value networks used. Defaults to ``2``. loss_function (str, optional): loss function to be used with @@ -212,8 +212,8 @@ class _AcceptedKeys: Defaults to ``"advantage"``. state_action_value (NestedKey): The input tensordict key where the state action value is expected. Defaults to ``"state_action_value"``. - log_prob (NestedKey): The input tensordict key where the log probability is expected. - Defaults to ``"_log_prob"``. + # log_prob (NestedKey): The input tensordict key where the log probability is expected. + # Defaults to ``"_log_prob"``. priority (NestedKey): The input tensordict key where the target priority is written to. Defaults to ``"td_error"``. reward (NestedKey): The input tensordict key where the reward is expected. 
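# A minimal key-remapping sketch, assuming ``actor`` and ``qvalue`` are the already-built
# TorchRL modules; it mirrors what ``test_crossq_notensordict`` exercises. With ``log_prob``
# commented out of ``_AcceptedKeys``, the remappable entries are action/reward/done/terminated.
# The "*2" names below are illustrative placeholders, not keys defined by this patch.
from torchrl.objectives import CrossQLoss

loss = CrossQLoss(actor_network=actor, qvalue_network=qvalue)
loss.set_keys(action="action2", reward="reward2", done="done2", terminated="terminated2")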
@@ -228,7 +228,7 @@ class _AcceptedKeys: action: NestedKey = "action" state_action_value: NestedKey = "state_action_value" - log_prob: NestedKey = "_log_prob" + # log_prob: NestedKey = "_log_prob" priority: NestedKey = "td_error" reward: NestedKey = "reward" done: NestedKey = "done" @@ -478,9 +478,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: else: tensordict_reshape = tensordict - loss_qvalue, value_metadata = self._qvalue_loss(tensordict_reshape) - loss_actor, metadata_actor = self._actor_loss(tensordict_reshape) - loss_alpha = self._alpha_loss(log_prob=metadata_actor["log_prob"]) + loss_qvalue, value_metadata = self.qvalue_loss(tensordict_reshape) + loss_actor, metadata_actor = self.actor_loss(tensordict_reshape) + loss_alpha = self.alpha_loss(log_prob=metadata_actor["log_prob"]) tensordict_reshape.set(self.tensor_keys.priority, value_metadata["td_error"]) if loss_actor.shape != loss_qvalue.shape: raise RuntimeError( @@ -512,7 +512,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: def _cached_detached_qvalue_params(self): return self.qvalue_network_params.detach() - def _actor_loss( + def actor_loss( self, tensordict: TensorDictBase ) -> Tuple[Tensor, Dict[str, Tensor]]: with set_exploration_type( @@ -539,7 +539,7 @@ def _actor_loss( return self._alpha * log_prob - min_q_logprob, {"log_prob": log_prob.detach()} - def _qvalue_loss( + def qvalue_loss( self, tensordict: TensorDictBase ) -> Tuple[Tensor, Dict[str, Tensor]]: @@ -604,7 +604,7 @@ def _qvalue_loss( metadata = {"td_error": td_error.detach().max(0)[0]} return loss_qval, metadata - def _alpha_loss(self, log_prob: Tensor) -> Tensor: + def alpha_loss(self, log_prob: Tensor) -> Tensor: if self.target_entropy is not None: # we can compute this loss even if log_alpha is not a parameter alpha_loss = -self.log_alpha * (log_prob + self.target_entropy) From 9543f2e5a0e49410bf6706aa6d2b807ed1dc806d Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 26 Jun 2024 20:49:08 +0200 Subject: [PATCH 13/37] update docs and sota checks --- docs/source/reference/objectives.rst | 9 +++++++++ sota-check/run_crossq.sh | 26 ++++++++++++++++++++++++++ 2 files changed, 35 insertions(+) create mode 100644 sota-check/run_crossq.sh diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index c2f43d8e9b6..1342183f3d6 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -121,6 +121,15 @@ REDQ REDQLoss +CrossQ +---- + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + CrossQ + IQL ---- diff --git a/sota-check/run_crossq.sh b/sota-check/run_crossq.sh new file mode 100644 index 00000000000..2ae4ea51c49 --- /dev/null +++ b/sota-check/run_crossq.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +#SBATCH --job-name=crossq +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/crossq_%j.txt +#SBATCH --error=slurm_errors/crossq_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="crossq" +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/sota-implementations/crossq/crossq.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? 
+# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log +fi From 53e35f7c71f2ba685dc2ab9add1b4004ef223d53 Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 26 Jun 2024 21:13:47 +0200 Subject: [PATCH 14/37] hyperparam fix --- sota-implementations/crossq/config.yaml | 6 +++--- torchrl/objectives/crossq.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sota-implementations/crossq/config.yaml b/sota-implementations/crossq/config.yaml index 8c72e996136..baebcd50b1f 100644 --- a/sota-implementations/crossq/config.yaml +++ b/sota-implementations/crossq/config.yaml @@ -28,7 +28,7 @@ optim: policy_update_delay: 3 gamma: 0.99 loss_function: l2 - lr: 3.0e-4 + lr: 1.0e-3 weight_decay: 0.0 batch_size: 256 alpha_init: 1.0 @@ -38,8 +38,8 @@ optim: # network network: - batch_norm_momentum: 0.01 - warmup_steps: 100000 # 10^5 + batch_norm_momentum: 0.99 + warmup_steps: 100000 critic_hidden_sizes: [2048, 2048] actor_hidden_sizes: [256, 256] critic_activation: relu diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index a9257e3d98a..59599cca580 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -554,6 +554,7 @@ def qvalue_loss( next_tensordict.set(self.tensor_keys.action, next_action) next_sample_log_prob = next_dist.log_prob(next_action) + # TODO: separate forward pass seems faster than the combined. # next_state_action_value = self._vmap_qnetworkN0( # next_tensordict.select(*self.qvalue_network.in_keys, strict=False), # self.qvalue_network_params, From 172e1c014be5b5a08d2e161fd9d56ae9016fc15f Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 27 Jun 2024 11:52:22 +0200 Subject: [PATCH 15/37] test --- sota-implementations/crossq/config.yaml | 2 +- sota-implementations/crossq/utils.py | 11 ++++--- torchrl/objectives/crossq.py | 44 ++++++++++++------------- 3 files changed, 29 insertions(+), 28 deletions(-) diff --git a/sota-implementations/crossq/config.yaml b/sota-implementations/crossq/config.yaml index baebcd50b1f..1dcbd3db92d 100644 --- a/sota-implementations/crossq/config.yaml +++ b/sota-implementations/crossq/config.yaml @@ -38,7 +38,7 @@ optim: # network network: - batch_norm_momentum: 0.99 + batch_norm_momentum: 0.01 warmup_steps: 100000 critic_hidden_sizes: [2048, 2048] actor_hidden_sizes: [256, 256] diff --git a/sota-implementations/crossq/utils.py b/sota-implementations/crossq/utils.py index 8b990fb9931..dc413c36219 100644 --- a/sota-implementations/crossq/utils.py +++ b/sota-implementations/crossq/utils.py @@ -4,7 +4,8 @@ # LICENSE file in the root directory of this source tree. 
import torch -from batchrenorm import BatchRenorm + +# from batchrenorm import BatchRenorm from tensordict.nn import InteractionType, TensorDictModule from tensordict.nn.distributions import NormalParamExtractor from torch import nn, optim @@ -153,11 +154,11 @@ def make_crossQ_agent(cfg, train_env, device): "num_cells": cfg.network.actor_hidden_sizes, "out_features": 2 * action_spec.shape[-1], "activation_class": get_activation(cfg.network.actor_activation), - "norm_class": BatchRenorm, + "norm_class": nn.BatchNorm1d, "norm_kwargs": { "momentum": cfg.network.batch_norm_momentum, "num_features": cfg.network.actor_hidden_sizes[-1], - "warmup_steps": cfg.network.warmup_steps, + # "warmup_steps": cfg.network.warmup_steps, }, } @@ -200,11 +201,11 @@ def make_crossQ_agent(cfg, train_env, device): "num_cells": cfg.network.critic_hidden_sizes, "out_features": 1, "activation_class": get_activation(cfg.network.critic_activation), - "norm_class": BatchRenorm, + "norm_class": nn.BatchNorm1d, "norm_kwargs": { "momentum": cfg.network.batch_norm_momentum, "num_features": cfg.network.critic_hidden_sizes[-1], - "warmup_steps": cfg.network.warmup_steps, + # "warmup_steps": cfg.network.warmup_steps, }, } diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index 59599cca580..b87e0db2af8 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -555,28 +555,28 @@ def qvalue_loss( next_sample_log_prob = next_dist.log_prob(next_action) # TODO: separate forward pass seems faster than the combined. - # next_state_action_value = self._vmap_qnetworkN0( - # next_tensordict.select(*self.qvalue_network.in_keys, strict=False), - # self.qvalue_network_params, - # ).get(self.tensor_keys.state_action_value) - - # current_state_action_value = self._vmap_qnetworkN0( - # tensordict.select(*self.qvalue_network.in_keys, strict=False), - # self.qvalue_network_params, - # ).get(self.tensor_keys.state_action_value) - - combined = torch.cat( - [ - tensordict.select(*self.qvalue_network.in_keys, strict=False), - next_tensordict.select(*self.qvalue_network.in_keys, strict=False), - ] - ) - pred_qs = self._vmap_qnetworkN0(combined, self.qvalue_network_params).get( - self.tensor_keys.state_action_value - ) - (current_state_action_value, next_state_action_value) = pred_qs.split( - tensordict.batch_size[0], dim=1 - ) + next_state_action_value = self._vmap_qnetworkN0( + next_tensordict.select(*self.qvalue_network.in_keys, strict=False), + self.qvalue_network_params, + ).get(self.tensor_keys.state_action_value) + + current_state_action_value = self._vmap_qnetworkN0( + tensordict.select(*self.qvalue_network.in_keys, strict=False), + self.qvalue_network_params, + ).get(self.tensor_keys.state_action_value) + + # combined = torch.cat( + # [ + # tensordict.select(*self.qvalue_network.in_keys, strict=False), + # next_tensordict.select(*self.qvalue_network.in_keys, strict=False), + # ] + # ) + # pred_qs = self._vmap_qnetworkN0(combined, self.qvalue_network_params).get( + # self.tensor_keys.state_action_value + # ) + # (current_state_action_value, next_state_action_value) = pred_qs.split( + # tensordict.batch_size[0], dim=1 + # ) # compute target value if ( From fdb7e8b3d9e338449e9f39c2578a129b0c67c1fe Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 27 Jun 2024 13:40:01 +0200 Subject: [PATCH 16/37] update batch norm tests --- sota-implementations/crossq/batchrenorm.py | 6 +-- sota-implementations/crossq/config.yaml | 2 +- sota-implementations/crossq/crossq.py | 4 +- sota-implementations/crossq/utils.py | 
10 ++--- torchrl/objectives/crossq.py | 44 +++++++++++----------- 5 files changed, 32 insertions(+), 34 deletions(-) diff --git a/sota-implementations/crossq/batchrenorm.py b/sota-implementations/crossq/batchrenorm.py index 7bbb71b355d..fc34b716581 100644 --- a/sota-implementations/crossq/batchrenorm.py +++ b/sota-implementations/crossq/batchrenorm.py @@ -25,11 +25,11 @@ class BatchRenorm(nn.Module): def __init__( self, num_features, - eps=1e-5, - momentum=0.01, + eps=0.01, + momentum=0.99, r_max=3.0, d_max=5.0, - warmup_steps=5000, + warmup_steps=100000, ): super(BatchRenorm, self).__init__() diff --git a/sota-implementations/crossq/config.yaml b/sota-implementations/crossq/config.yaml index 1dcbd3db92d..baebcd50b1f 100644 --- a/sota-implementations/crossq/config.yaml +++ b/sota-implementations/crossq/config.yaml @@ -38,7 +38,7 @@ optim: # network network: - batch_norm_momentum: 0.01 + batch_norm_momentum: 0.99 warmup_steps: 100000 critic_hidden_sizes: [2048, 2048] actor_hidden_sizes: [256, 256] diff --git a/sota-implementations/crossq/crossq.py b/sota-implementations/crossq/crossq.py index 63c25e0535a..4c5b6f476db 100644 --- a/sota-implementations/crossq/crossq.py +++ b/sota-implementations/crossq/crossq.py @@ -135,9 +135,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Sample from replay buffer sampled_tensordict = replay_buffer.sample() if sampled_tensordict.device != device: - sampled_tensordict = sampled_tensordict.to( - device, non_blocking=True - ) + sampled_tensordict = sampled_tensordict.to(device) else: sampled_tensordict = sampled_tensordict.clone() diff --git a/sota-implementations/crossq/utils.py b/sota-implementations/crossq/utils.py index dc413c36219..f12cf4da51d 100644 --- a/sota-implementations/crossq/utils.py +++ b/sota-implementations/crossq/utils.py @@ -5,7 +5,7 @@ import torch -# from batchrenorm import BatchRenorm +from batchrenorm import BatchRenorm from tensordict.nn import InteractionType, TensorDictModule from tensordict.nn.distributions import NormalParamExtractor from torch import nn, optim @@ -154,11 +154,11 @@ def make_crossQ_agent(cfg, train_env, device): "num_cells": cfg.network.actor_hidden_sizes, "out_features": 2 * action_spec.shape[-1], "activation_class": get_activation(cfg.network.actor_activation), - "norm_class": nn.BatchNorm1d, + "norm_class": BatchRenorm, "norm_kwargs": { "momentum": cfg.network.batch_norm_momentum, "num_features": cfg.network.actor_hidden_sizes[-1], - # "warmup_steps": cfg.network.warmup_steps, + "warmup_steps": cfg.network.warmup_steps, }, } @@ -201,11 +201,11 @@ def make_crossQ_agent(cfg, train_env, device): "num_cells": cfg.network.critic_hidden_sizes, "out_features": 1, "activation_class": get_activation(cfg.network.critic_activation), - "norm_class": nn.BatchNorm1d, + "norm_class": BatchRenorm, "norm_kwargs": { "momentum": cfg.network.batch_norm_momentum, "num_features": cfg.network.critic_hidden_sizes[-1], - # "warmup_steps": cfg.network.warmup_steps, + "warmup_steps": cfg.network.warmup_steps, }, } diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index b87e0db2af8..59599cca580 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -555,28 +555,28 @@ def qvalue_loss( next_sample_log_prob = next_dist.log_prob(next_action) # TODO: separate forward pass seems faster than the combined. 
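# [Editor's illustration, not part of the patch] The hunk below toggles between two
# ways of evaluating the critics in the CrossQ q-value loss: two separate forward
# passes versus one pass over the concatenated current and next (state, action)
# batches. The combined pass matters because the critics contain (batch re)norm
# layers, which then compute their statistics over the joint distribution. A minimal
# sketch of the idea with plain torch modules; obs_dim, act_dim, the toy critic and
# the batch size are assumptions for the example (in the patched code the split is
# along dim=1 because the vmapped ensemble dimension comes first).
import torch
from torch import nn

obs_dim, act_dim, batch = 17, 6, 256
critic = nn.Sequential(
    nn.Linear(obs_dim + act_dim, 64),
    nn.BatchNorm1d(64, momentum=0.01),
    nn.ReLU(),
    nn.Linear(64, 1),
)

obs, act = torch.randn(batch, obs_dim), torch.randn(batch, act_dim)
next_obs, next_act = torch.randn(batch, obs_dim), torch.randn(batch, act_dim)

# Single pass: the norm layers see current and next samples in the same batch.
combined = torch.cat(
    [torch.cat([obs, act], dim=-1), torch.cat([next_obs, next_act], dim=-1)],
    dim=0,
)
q_current, q_next = critic(combined).split(batch, dim=0)  # then split back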
- next_state_action_value = self._vmap_qnetworkN0( - next_tensordict.select(*self.qvalue_network.in_keys, strict=False), - self.qvalue_network_params, - ).get(self.tensor_keys.state_action_value) - - current_state_action_value = self._vmap_qnetworkN0( - tensordict.select(*self.qvalue_network.in_keys, strict=False), - self.qvalue_network_params, - ).get(self.tensor_keys.state_action_value) - - # combined = torch.cat( - # [ - # tensordict.select(*self.qvalue_network.in_keys, strict=False), - # next_tensordict.select(*self.qvalue_network.in_keys, strict=False), - # ] - # ) - # pred_qs = self._vmap_qnetworkN0(combined, self.qvalue_network_params).get( - # self.tensor_keys.state_action_value - # ) - # (current_state_action_value, next_state_action_value) = pred_qs.split( - # tensordict.batch_size[0], dim=1 - # ) + # next_state_action_value = self._vmap_qnetworkN0( + # next_tensordict.select(*self.qvalue_network.in_keys, strict=False), + # self.qvalue_network_params, + # ).get(self.tensor_keys.state_action_value) + + # current_state_action_value = self._vmap_qnetworkN0( + # tensordict.select(*self.qvalue_network.in_keys, strict=False), + # self.qvalue_network_params, + # ).get(self.tensor_keys.state_action_value) + + combined = torch.cat( + [ + tensordict.select(*self.qvalue_network.in_keys, strict=False), + next_tensordict.select(*self.qvalue_network.in_keys, strict=False), + ] + ) + pred_qs = self._vmap_qnetworkN0(combined, self.qvalue_network_params).get( + self.tensor_keys.state_action_value + ) + (current_state_action_value, next_state_action_value) = pred_qs.split( + tensordict.batch_size[0], dim=1 + ) # compute target value if ( From 5501d43e695a4e05cbab180c9968afc9c6793d51 Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 3 Jul 2024 19:54:34 +0200 Subject: [PATCH 17/37] tests --- sota-implementations/crossq/batchrenorm.py | 94 ++++++++++++++++++++++ sota-implementations/crossq/utils.py | 14 ++-- torchrl/objectives/crossq.py | 19 +++-- 3 files changed, 112 insertions(+), 15 deletions(-) diff --git a/sota-implementations/crossq/batchrenorm.py b/sota-implementations/crossq/batchrenorm.py index fc34b716581..98f83c9f95f 100644 --- a/sota-implementations/crossq/batchrenorm.py +++ b/sota-implementations/crossq/batchrenorm.py @@ -94,3 +94,97 @@ def forward(self, x): ) return self.gamma.view(*expand_dims) * x_hat + self.beta.view(*expand_dims) + + +import torch.nn as nn + + +class AdaptiveBatchRenorm(nn.Module): + def __init__( + self, + num_features, + epsilon=1e-5, + momentum=0.99, + max_r=3.0, + max_d=5.0, + warmup_steps=10000, + ): + super(AdaptiveBatchRenorm, self).__init__() + self.num_features = num_features + self.epsilon = epsilon + self.momentum = momentum + self.max_r = max_r + self.max_d = max_d + self.warmup_steps = warmup_steps + + self.register_buffer("running_mean", torch.zeros(num_features)) + self.register_buffer("running_var", torch.ones(num_features)) + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + self.register_buffer( + "num_batches_tracked", torch.tensor(0, dtype=torch.float32) + ) + + def forward(self, x): + if x.dim() not in [2, 3]: + raise ValueError("AdaptiveBatchRenorm expects 2D or 3D inputs") + + if x.dim() == 3: + batch_size, seq_len, _ = x.size() + x = x.reshape(batch_size * seq_len, self.num_features) + + if self.training: + self.num_batches_tracked += 1 + + batch_mean = x.mean(dim=0) + batch_var = x.var(dim=0, unbiased=False) + + # Compute r and d factors + r = torch.clamp( + (batch_var.sqrt() / 
(self.running_var.sqrt() + self.epsilon)), + 1 / self.max_r, + self.max_r, + ) + d = torch.clamp( + ( + (batch_mean - self.running_mean) + / (self.running_var.sqrt() + self.epsilon) + ), + -self.max_d, + self.max_d, + ) + + # Compute warmup factor (0 during warmup, 1 after warmup) + warmup_factor = torch.clamp( + self.num_batches_tracked / self.warmup_steps, 0.0, 1.0 + ) + + # Interpolate between batch norm and renorm based on warmup factor + effective_r = 1.0 + (r - 1.0) * warmup_factor + effective_d = d * warmup_factor + + x_hat = (x - batch_mean[None, :]) * effective_r[None, :] + effective_d[ + None, : + ] + x_hat = x_hat / (batch_var[None, :] + self.epsilon).sqrt() + + # Update running statistics using Flax-style momentum + self.running_mean = ( + self.momentum * self.running_mean + (1 - self.momentum) * batch_mean + ) + self.running_var = ( + self.momentum * self.running_var + (1 - self.momentum) * batch_var + ) + + else: + x_hat = (x - self.running_mean[None, :]) / ( + self.running_var[None, :] + self.epsilon + ).sqrt() + + output = self.weight[None, :] * x_hat + self.bias[None, :] + + if x.dim() == 3: + output = output.reshape(batch_size, seq_len, self.num_features) + + return output diff --git a/sota-implementations/crossq/utils.py b/sota-implementations/crossq/utils.py index f12cf4da51d..c141f957a39 100644 --- a/sota-implementations/crossq/utils.py +++ b/sota-implementations/crossq/utils.py @@ -5,7 +5,7 @@ import torch -from batchrenorm import BatchRenorm +# from batchrenorm import AdaptiveBatchRenorm, BatchRenorm from tensordict.nn import InteractionType, TensorDictModule from tensordict.nn.distributions import NormalParamExtractor from torch import nn, optim @@ -154,11 +154,11 @@ def make_crossQ_agent(cfg, train_env, device): "num_cells": cfg.network.actor_hidden_sizes, "out_features": 2 * action_spec.shape[-1], "activation_class": get_activation(cfg.network.actor_activation), - "norm_class": BatchRenorm, + "norm_class": nn.BatchNorm1d, "norm_kwargs": { "momentum": cfg.network.batch_norm_momentum, "num_features": cfg.network.actor_hidden_sizes[-1], - "warmup_steps": cfg.network.warmup_steps, + # "warmup_steps": cfg.network.warmup_steps, }, } @@ -166,8 +166,8 @@ def make_crossQ_agent(cfg, train_env, device): dist_class = TanhNormal dist_kwargs = { - "min": action_spec.space.low, - "max": action_spec.space.high, + "low": action_spec.space.low, + "high": action_spec.space.high, "tanh_loc": False, } @@ -201,11 +201,11 @@ def make_crossQ_agent(cfg, train_env, device): "num_cells": cfg.network.critic_hidden_sizes, "out_features": 1, "activation_class": get_activation(cfg.network.critic_activation), - "norm_class": BatchRenorm, + "norm_class": nn.BatchNorm1d, "norm_kwargs": { "momentum": cfg.network.batch_norm_momentum, "num_features": cfg.network.critic_hidden_sizes[-1], - "warmup_steps": cfg.network.warmup_steps, + # "warmup_steps": cfg.network.warmup_steps, }, } diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index 59599cca580..c1bf287fda4 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -526,18 +526,16 @@ def actor_loss( td_q.set(self.tensor_keys.action, a_reparm) td_q = self._vmap_qnetworkN0( td_q, - self._cached_detached_qvalue_params, # should we clone? 
- ) - min_q_logprob = ( - td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1) + self._cached_detached_qvalue_params, ) + min_q = td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1) - if log_prob.shape != min_q_logprob.shape: + if log_prob.shape != min_q.shape: raise RuntimeError( - f"Losses shape mismatch: {log_prob.shape} and {min_q_logprob.shape}" + f"Losses shape mismatch: {log_prob.shape} and {min_q.shape}" ) - return self._alpha * log_prob - min_q_logprob, {"log_prob": log_prob.detach()} + return self._alpha * log_prob - min_q, {"log_prob": log_prob.detach()} def qvalue_loss( self, tensordict: TensorDictBase @@ -591,7 +589,12 @@ def qvalue_loss( tensordict.set( ("next", self.value_estimator.tensor_keys.value), next_state_action_value ) - target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) + # target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) + + reward = tensordict.get(("next", self.tensor_keys.reward)) + done = tensordict.get(("next", self.tensor_keys.done)) + target_value = (reward + (~done) * 0.99 * next_state_action_value).squeeze(-1) + # get current q-values pred_val = current_state_action_value.squeeze(-1) From c47ac8431e17ed45eb54bbd1589c282898acde83 Mon Sep 17 00:00:00 2001 From: BY571 Date: Fri, 5 Jul 2024 10:23:37 +0200 Subject: [PATCH 18/37] cleanup --- sota-implementations/crossq/batchrenorm.py | 90 +++------------------- sota-implementations/crossq/utils.py | 10 +-- torchrl/objectives/crossq.py | 4 +- 3 files changed, 17 insertions(+), 87 deletions(-) diff --git a/sota-implementations/crossq/batchrenorm.py b/sota-implementations/crossq/batchrenorm.py index 98f83c9f95f..7d7214425f8 100644 --- a/sota-implementations/crossq/batchrenorm.py +++ b/sota-implementations/crossq/batchrenorm.py @@ -1,3 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. import torch import torch.nn as nn @@ -15,8 +19,10 @@ class BatchRenorm(nn.Module): Args: num_features (int): Number of features in the input tensor. - eps (float, optional): Small value added to the variance to avoid division by zero. Default is 1e-5. + + Keyword Args: momentum (float, optional): Momentum factor for computing the running mean and variance. Default is 0.01. + eps (float, optional): Small value added to the variance to avoid division by zero. Default is 1e-5. r_max (float, optional): Maximum value for the scaling factor r. Default is 3.0. d_max (float, optional): Maximum value for the bias factor d. Default is 5.0. warmup_steps (int, optional): Number of warm-up steps for the running mean and variance. Default is 5000. 
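# [Editor's illustration, not part of the patch] Batch renormalization, as used by
# the module below, corrects the batch statistics toward the running statistics with
# two clamped factors:
#   r = clamp(std_batch / std_running, 1 / r_max, r_max)
#   d = clamp((mean_batch - mean_running) / std_running, -d_max, d_max)
# and normalizes as x_hat = (x - mean_batch) / std_batch * r + d, so with r = 1 and
# d = 0 (the warmup phase) it reduces to ordinary batch norm. A small sketch of that
# arithmetic; the toy batch and eps are assumptions, r_max = 3.0 and d_max = 5.0
# mirror the defaults documented above.
import torch

eps = 1e-5
x = torch.randn(256, 5) * 2.0 + 1.0               # a batch with shifted statistics
running_mean, running_std = torch.zeros(5), torch.ones(5)

b_mean = x.mean(0)
b_std = (x.var(0, unbiased=False) + eps).sqrt()

r = (b_std / (running_std + eps)).clamp(1.0 / 3.0, 3.0)
d = ((b_mean - running_mean) / (running_std + eps)).clamp(-5.0, 5.0)

x_hat = (x - b_mean) / b_std * r + d              # renormalized activations
plain_bn = (x - b_mean) / b_std                   # what the layer does during warmup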
@@ -25,91 +31,13 @@ class BatchRenorm(nn.Module): def __init__( self, num_features, - eps=0.01, momentum=0.99, - r_max=3.0, - d_max=5.0, - warmup_steps=100000, - ): - - super(BatchRenorm, self).__init__() - self.num_features = num_features - self.eps = eps - self.momentum = momentum - self.r_max = r_max - self.d_max = d_max - self.warmup_steps = warmup_steps - self.step_count = 0 - - self.gamma = nn.Parameter(torch.ones(num_features)) - self.beta = nn.Parameter(torch.zeros(num_features)) - - self.register_buffer("running_mean", torch.zeros(num_features)) - self.register_buffer("running_var", torch.ones(num_features)) - - def forward(self, x): - self.step_count += 1 - - # Compute the dimensions for mean and variance calculation - dims = [i for i in range(x.dim()) if i != 1] - expand_dims = [1 if i != 1 else -1 for i in range(x.dim())] - - # Compute batch statistics - batch_mean = x.mean(dims, keepdim=True) - batch_var = x.var(dims, unbiased=False, keepdim=True) - - if self.training: - if self.step_count <= self.warmup_steps: - # Use classical BatchNorm during warmup - x_hat = (x - batch_mean) / torch.sqrt(batch_var + self.eps) - else: - # Use Batch Renormalization - with torch.no_grad(): - r = torch.clamp( - batch_var / self.running_var.view(*expand_dims), - 1.0 / self.r_max, - self.r_max, - ) - d = torch.clamp( - (batch_mean - self.running_mean.view(*expand_dims)) - / torch.sqrt(self.running_var.view(*expand_dims) + self.eps), - -self.d_max, - self.d_max, - ) - - x_hat = (x - batch_mean) / torch.sqrt(batch_var + self.eps) - x_hat = x_hat * r + d - - # Update running statistics - self.running_mean.mul_(1 - self.momentum).add_( - batch_mean.squeeze().detach() * self.momentum - ) - self.running_var.mul_(1 - self.momentum).add_( - batch_var.squeeze().detach() * self.momentum - ) - else: - # Use running statistics during inference - x_hat = (x - self.running_mean.view(*expand_dims)) / torch.sqrt( - self.running_var.view(*expand_dims) + self.eps - ) - - return self.gamma.view(*expand_dims) * x_hat + self.beta.view(*expand_dims) - - -import torch.nn as nn - - -class AdaptiveBatchRenorm(nn.Module): - def __init__( - self, - num_features, epsilon=1e-5, - momentum=0.99, max_r=3.0, max_d=5.0, warmup_steps=10000, ): - super(AdaptiveBatchRenorm, self).__init__() + super(BatchRenorm, self).__init__() self.num_features = num_features self.epsilon = epsilon self.momentum = momentum @@ -128,7 +56,7 @@ def __init__( def forward(self, x): if x.dim() not in [2, 3]: - raise ValueError("AdaptiveBatchRenorm expects 2D or 3D inputs") + raise ValueError("BatchRenorm expects 2D or 3D inputs") if x.dim() == 3: batch_size, seq_len, _ = x.size() diff --git a/sota-implementations/crossq/utils.py b/sota-implementations/crossq/utils.py index c141f957a39..1fabd994279 100644 --- a/sota-implementations/crossq/utils.py +++ b/sota-implementations/crossq/utils.py @@ -5,7 +5,7 @@ import torch -# from batchrenorm import AdaptiveBatchRenorm, BatchRenorm +from batchrenorm import BatchRenorm from tensordict.nn import InteractionType, TensorDictModule from tensordict.nn.distributions import NormalParamExtractor from torch import nn, optim @@ -154,11 +154,11 @@ def make_crossQ_agent(cfg, train_env, device): "num_cells": cfg.network.actor_hidden_sizes, "out_features": 2 * action_spec.shape[-1], "activation_class": get_activation(cfg.network.actor_activation), - "norm_class": nn.BatchNorm1d, + "norm_class": BatchRenorm, "norm_kwargs": { "momentum": cfg.network.batch_norm_momentum, "num_features": cfg.network.actor_hidden_sizes[-1], - # 
"warmup_steps": cfg.network.warmup_steps, + "warmup_steps": cfg.network.warmup_steps, }, } @@ -201,11 +201,11 @@ def make_crossQ_agent(cfg, train_env, device): "num_cells": cfg.network.critic_hidden_sizes, "out_features": 1, "activation_class": get_activation(cfg.network.critic_activation), - "norm_class": nn.BatchNorm1d, + "norm_class": BatchRenorm, "norm_kwargs": { "momentum": cfg.network.batch_norm_momentum, "num_features": cfg.network.critic_hidden_sizes[-1], - # "warmup_steps": cfg.network.warmup_steps, + "warmup_steps": cfg.network.warmup_steps, }, } diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index c1bf287fda4..ccb82a6a501 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -515,6 +515,7 @@ def _cached_detached_qvalue_params(self): def actor_loss( self, tensordict: TensorDictBase ) -> Tuple[Tensor, Dict[str, Tensor]]: + """Compute the actor loss.""" with set_exploration_type( ExplorationType.RANDOM ), self.actor_network_params.to_module(self.actor_network): @@ -540,7 +541,7 @@ def actor_loss( def qvalue_loss( self, tensordict: TensorDictBase ) -> Tuple[Tensor, Dict[str, Tensor]]: - + """Compute the CrossQ-value loss.""" # # compute next action with torch.no_grad(): with set_exploration_type( @@ -609,6 +610,7 @@ def qvalue_loss( return loss_qval, metadata def alpha_loss(self, log_prob: Tensor) -> Tensor: + """Compute the entropy loss.""" if self.target_entropy is not None: # we can compute this loss even if log_alpha is not a parameter alpha_loss = -self.log_alpha * (log_prob + self.target_entropy) From f94165ec47648f091500d870104cdff89313439d Mon Sep 17 00:00:00 2001 From: BY571 Date: Sun, 7 Jul 2024 17:34:39 +0200 Subject: [PATCH 19/37] update --- sota-implementations/crossq/batchrenorm.py | 92 +++++++++------------- sota-implementations/crossq/config.yaml | 2 +- torchrl/objectives/crossq.py | 10 +-- 3 files changed, 40 insertions(+), 64 deletions(-) diff --git a/sota-implementations/crossq/batchrenorm.py b/sota-implementations/crossq/batchrenorm.py index 7d7214425f8..93ac9195e33 100644 --- a/sota-implementations/crossq/batchrenorm.py +++ b/sota-implementations/crossq/batchrenorm.py @@ -23,62 +23,57 @@ class BatchRenorm(nn.Module): Keyword Args: momentum (float, optional): Momentum factor for computing the running mean and variance. Default is 0.01. eps (float, optional): Small value added to the variance to avoid division by zero. Default is 1e-5. - r_max (float, optional): Maximum value for the scaling factor r. Default is 3.0. - d_max (float, optional): Maximum value for the bias factor d. Default is 5.0. - warmup_steps (int, optional): Number of warm-up steps for the running mean and variance. Default is 5000. + max_r (float, optional): Maximum value for the scaling factor r. Default is 3.0. + max_d (float, optional): Maximum value for the bias factor d. Default is 5.0. + warmup_steps (int, optional): Number of warm-up steps for the running mean and variance. Default is 10000. 
""" def __init__( self, num_features, momentum=0.99, - epsilon=1e-5, + eps=1e-5, max_r=3.0, max_d=5.0, warmup_steps=10000, ): super(BatchRenorm, self).__init__() self.num_features = num_features - self.epsilon = epsilon + self.eps = eps self.momentum = momentum self.max_r = max_r self.max_d = max_d self.warmup_steps = warmup_steps - self.register_buffer("running_mean", torch.zeros(num_features)) - self.register_buffer("running_var", torch.ones(num_features)) - self.weight = nn.Parameter(torch.ones(num_features)) - self.bias = nn.Parameter(torch.zeros(num_features)) - self.register_buffer( - "num_batches_tracked", torch.tensor(0, dtype=torch.float32) + "running_mean", torch.zeros(num_features, dtype=torch.float32) + ) + self.register_buffer( + "running_var", torch.ones(num_features, dtype=torch.float32) ) + self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.int64)) + self.weight = nn.Parameter(torch.ones(num_features, dtype=torch.float32)) + self.bias = nn.Parameter(torch.zeros(num_features, dtype=torch.float32)) - def forward(self, x): - if x.dim() not in [2, 3]: - raise ValueError("BatchRenorm expects 2D or 3D inputs") + def forward(self, x: torch.Tensor) -> torch.Tensor: + assert x.dim() >= 2 + view_dims = [1, x.shape[1]] + [1] * (x.dim() - 2) + # _v = lambda v: v.view(view_dims) - if x.dim() == 3: - batch_size, seq_len, _ = x.size() - x = x.reshape(batch_size * seq_len, self.num_features) + def _v(v): + return v.view(view_dims) - if self.training: - self.num_batches_tracked += 1 + running_std = (self.running_var + self.eps).sqrt_() - batch_mean = x.mean(dim=0) - batch_var = x.var(dim=0, unbiased=False) + if self.training: + reduce_dims = [i for i in range(x.dim()) if i != 1] + b_mean = x.mean(reduce_dims) + b_var = x.var(reduce_dims, unbiased=False) + b_std = (b_var + self.eps).sqrt_() - # Compute r and d factors - r = torch.clamp( - (batch_var.sqrt() / (self.running_var.sqrt() + self.epsilon)), - 1 / self.max_r, - self.max_r, - ) + r = torch.clamp((b_std.detach() / running_std), 1 / self.max_r, self.max_r) d = torch.clamp( - ( - (batch_mean - self.running_mean) - / (self.running_var.sqrt() + self.epsilon) - ), + (b_mean.detach() - self.running_mean) / running_std, -self.max_d, self.max_d, ) @@ -87,32 +82,17 @@ def forward(self, x): warmup_factor = torch.clamp( self.num_batches_tracked / self.warmup_steps, 0.0, 1.0 ) + r = 1.0 + (r - 1.0) * warmup_factor + d = d * warmup_factor - # Interpolate between batch norm and renorm based on warmup factor - effective_r = 1.0 + (r - 1.0) * warmup_factor - effective_d = d * warmup_factor - - x_hat = (x - batch_mean[None, :]) * effective_r[None, :] + effective_d[ - None, : - ] - x_hat = x_hat / (batch_var[None, :] + self.epsilon).sqrt() - - # Update running statistics using Flax-style momentum - self.running_mean = ( - self.momentum * self.running_mean + (1 - self.momentum) * batch_mean - ) - self.running_var = ( - self.momentum * self.running_var + (1 - self.momentum) * batch_var - ) + x = (x - _v(b_mean)) / _v(b_std) * _v(r) + _v(d) + unbiased_var = b_var.detach() * x.shape[1] / (x.shape[1] - 1) + self.running_var += self.momentum * (unbiased_var - self.running_var) + self.running_mean += self.momentum * (b_mean.detach() - self.running_mean) + self.num_batches_tracked += 1 else: - x_hat = (x - self.running_mean[None, :]) / ( - self.running_var[None, :] + self.epsilon - ).sqrt() - - output = self.weight[None, :] * x_hat + self.bias[None, :] - - if x.dim() == 3: - output = output.reshape(batch_size, seq_len, self.num_features) + 
x = (x - _v(self.running_mean)) / _v(running_std) - return output + x = _v(self.weight) * x + _v(self.bias) + return x diff --git a/sota-implementations/crossq/config.yaml b/sota-implementations/crossq/config.yaml index baebcd50b1f..1dcbd3db92d 100644 --- a/sota-implementations/crossq/config.yaml +++ b/sota-implementations/crossq/config.yaml @@ -38,7 +38,7 @@ optim: # network network: - batch_norm_momentum: 0.99 + batch_norm_momentum: 0.01 warmup_steps: 100000 critic_hidden_sizes: [2048, 2048] actor_hidden_sizes: [256, 256] diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index ccb82a6a501..6e7e6db1697 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -587,14 +587,10 @@ def qvalue_loss( next_state_action_value = ( next_state_action_value - self._alpha * next_sample_log_prob ).detach() - tensordict.set( - ("next", self.value_estimator.tensor_keys.value), next_state_action_value - ) - # target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) - reward = tensordict.get(("next", self.tensor_keys.reward)) - done = tensordict.get(("next", self.tensor_keys.done)) - target_value = (reward + (~done) * 0.99 * next_state_action_value).squeeze(-1) + target_value = self.value_estimator.value_estimate( + tensordict, next_value=next_state_action_value + ).squeeze(-1) # get current q-values pred_val = current_state_action_value.squeeze(-1) From 02c94ff4b6139f685aa735eb856c082cb21a29d4 Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 8 Jul 2024 09:32:41 +0200 Subject: [PATCH 20/37] update lr param --- sota-implementations/crossq/batchrenorm.py | 2 +- sota-implementations/crossq/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sota-implementations/crossq/batchrenorm.py b/sota-implementations/crossq/batchrenorm.py index 93ac9195e33..81000b4fa7b 100644 --- a/sota-implementations/crossq/batchrenorm.py +++ b/sota-implementations/crossq/batchrenorm.py @@ -31,7 +31,7 @@ class BatchRenorm(nn.Module): def __init__( self, num_features, - momentum=0.99, + momentum=0.01, eps=1e-5, max_r=3.0, max_d=5.0, diff --git a/sota-implementations/crossq/utils.py b/sota-implementations/crossq/utils.py index 1fabd994279..f6615689384 100644 --- a/sota-implementations/crossq/utils.py +++ b/sota-implementations/crossq/utils.py @@ -285,7 +285,7 @@ def make_crossQ_optimizer(cfg, loss_module): ) optimizer_alpha = optim.Adam( [loss_module.log_alpha], - lr=3.0e-4, + lr=cfg.optim.lr, ) return optimizer_actor, optimizer_critic, optimizer_alpha From 4b914e61b4392ca52f2a0be6659dbb07f8c61c30 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Jul 2024 10:28:05 +0100 Subject: [PATCH 21/37] Apply suggestions from code review --- sota-implementations/crossq/batchrenorm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sota-implementations/crossq/batchrenorm.py b/sota-implementations/crossq/batchrenorm.py index 81000b4fa7b..9d1a78d1135 100644 --- a/sota-implementations/crossq/batchrenorm.py +++ b/sota-implementations/crossq/batchrenorm.py @@ -37,7 +37,7 @@ def __init__( max_d=5.0, warmup_steps=10000, ): - super(BatchRenorm, self).__init__() + super().__init__() self.num_features = num_features self.eps = eps self.momentum = momentum From 7b4a69d259d860123d40a5a213fa53e5679b5973 Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 8 Jul 2024 15:21:11 +0200 Subject: [PATCH 22/37] set qnet eval in actor loss --- sota-implementations/crossq/batchrenorm.py | 1 - torchrl/objectives/crossq.py | 3 +++ 2 files changed, 3 insertions(+), 1 
deletion(-) diff --git a/sota-implementations/crossq/batchrenorm.py b/sota-implementations/crossq/batchrenorm.py index 81000b4fa7b..a355ff3ff9a 100644 --- a/sota-implementations/crossq/batchrenorm.py +++ b/sota-implementations/crossq/batchrenorm.py @@ -58,7 +58,6 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: assert x.dim() >= 2 view_dims = [1, x.shape[1]] + [1] * (x.dim() - 2) - # _v = lambda v: v.view(view_dims) def _v(v): return v.view(view_dims) diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index 6e7e6db1697..ac35ab9ea98 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -524,12 +524,15 @@ def actor_loss( log_prob = dist.log_prob(a_reparm) td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False) + self.qvalue_network.eval() td_q.set(self.tensor_keys.action, a_reparm) td_q = self._vmap_qnetworkN0( td_q, self._cached_detached_qvalue_params, ) + min_q = td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1) + self.qvalue_network.train() if log_prob.shape != min_q.shape: raise RuntimeError( From 35c7a98543966b061ab5b189521293da076b4863 Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 8 Jul 2024 15:39:25 +0200 Subject: [PATCH 23/37] take off comment --- torchrl/objectives/crossq.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index ac35ab9ea98..948b427c94c 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -556,17 +556,6 @@ def qvalue_loss( next_tensordict.set(self.tensor_keys.action, next_action) next_sample_log_prob = next_dist.log_prob(next_action) - # TODO: separate forward pass seems faster than the combined. - # next_state_action_value = self._vmap_qnetworkN0( - # next_tensordict.select(*self.qvalue_network.in_keys, strict=False), - # self.qvalue_network_params, - # ).get(self.tensor_keys.state_action_value) - - # current_state_action_value = self._vmap_qnetworkN0( - # tensordict.select(*self.qvalue_network.in_keys, strict=False), - # self.qvalue_network_params, - # ).get(self.tensor_keys.state_action_value) - combined = torch.cat( [ tensordict.select(*self.qvalue_network.in_keys, strict=False), From 68a1a9f9b2503ce1d6253a0deb0c595d104d6fe1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Jul 2024 16:04:59 +0100 Subject: [PATCH 24/37] amend --- docs/source/reference/modules.rst | 1 + sota-implementations/crossq/utils.py | 8 +-- test/test_modules.py | 35 +++++++++++- torchrl/envs/batched_envs.py | 2 +- torchrl/envs/transforms/transforms.py | 2 +- torchrl/modules/models/__init__.py | 2 + .../modules/models/batchrenorm1d.py | 57 +++++++++++-------- torchrl/objectives/crossq.py | 3 - 8 files changed, 77 insertions(+), 33 deletions(-) rename sota-implementations/crossq/batchrenorm.py => torchrl/modules/models/batchrenorm1d.py (69%) diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index ccd6cb23ed0..b46d789ed15 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -317,6 +317,7 @@ Regular modules Conv3dNet SqueezeLayer Squeeze2dLayer + BatchRenorm Algorithm-specific modules ~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/sota-implementations/crossq/utils.py b/sota-implementations/crossq/utils.py index f6615689384..26798b1ee10 100644 --- a/sota-implementations/crossq/utils.py +++ b/sota-implementations/crossq/utils.py @@ -4,8 +4,6 @@ # LICENSE file in the root directory of this source tree. 
import torch - -from batchrenorm import BatchRenorm from tensordict.nn import InteractionType, TensorDictModule from tensordict.nn.distributions import NormalParamExtractor from torch import nn, optim @@ -26,6 +24,8 @@ from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import MLP, ProbabilisticActor, ValueOperator from torchrl.modules.distributions import TanhNormal + +from torchrl.modules.models.batchrenorm1d import BatchRenorm1d from torchrl.objectives import CrossQLoss # ==================================================================== @@ -154,7 +154,7 @@ def make_crossQ_agent(cfg, train_env, device): "num_cells": cfg.network.actor_hidden_sizes, "out_features": 2 * action_spec.shape[-1], "activation_class": get_activation(cfg.network.actor_activation), - "norm_class": BatchRenorm, + "norm_class": BatchRenorm1d, "norm_kwargs": { "momentum": cfg.network.batch_norm_momentum, "num_features": cfg.network.actor_hidden_sizes[-1], @@ -201,7 +201,7 @@ def make_crossQ_agent(cfg, train_env, device): "num_cells": cfg.network.critic_hidden_sizes, "out_features": 1, "activation_class": get_activation(cfg.network.critic_activation), - "norm_class": BatchRenorm, + "norm_class": BatchRenorm1d, "norm_kwargs": { "momentum": cfg.network.batch_norm_momentum, "num_features": cfg.network.critic_hidden_sizes[-1], diff --git a/test/test_modules.py b/test/test_modules.py index 59adbea653d..e3b774a0358 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -34,7 +34,14 @@ VDNMixer, ) from torchrl.modules.distributions.utils import safeatanh, safetanh -from torchrl.modules.models import Conv3dNet, ConvNet, MLP, NoisyLazyLinear, NoisyLinear +from torchrl.modules.models import ( + BatchRenorm1d, + Conv3dNet, + ConvNet, + MLP, + NoisyLazyLinear, + NoisyLinear, +) from torchrl.modules.models.decision_transformer import ( _has_transformers, DecisionTransformer, @@ -1438,6 +1445,32 @@ def test_python_gru(device, bias, dropout, batch_first, num_layers): torch.testing.assert_close(h1, h2) +class TestBatchRenorm: + @pytest.mark.parametrize("num_steps", [0, 5]) + def test_batchrenorm(self, num_steps): + torch.manual_seed(0) + bn = torch.nn.BatchNorm1d(5, momentum=0.1, eps=1e-5) + brn = BatchRenorm1d( + 5, momentum=0.1, eps=1e-5, warmup_steps=num_steps, max_d=10000, max_r=10000 + ) + bn.train() + brn.train() + data_train = torch.randn(100, 5).split(25) + data_test = torch.randn(100, 5) + for d in data_train: + _ = bn(d) + _ = brn(d) + # if num_steps == 0: + # print(a, b) + # torch.testing.assert_close(a, b) + # else: + # assert not torch.isclose(a, b).all() + + bn.eval() + brn.eval() + torch.testing.assert_close(bn(data_test), brn(data_test)) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 7f462782757..4241f6613a0 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -26,8 +26,8 @@ LazyStackedTensorDict, TensorDict, TensorDictBase, + unravel_key, ) -from tensordict._tensordict import unravel_key from torch import multiprocessing as mp from torchrl._utils import ( _check_for_faulty_process, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index bec76c603e6..eb9cdce923d 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -39,7 +39,7 @@ unravel_key, unravel_key_list, ) -from 
tensordict._tensordict import _unravel_key_to_tuple +from tensordict._C import _unravel_key_to_tuple from tensordict.nn import dispatch, TensorDictModuleBase from tensordict.utils import expand_as_right, expand_right, NestedKey from torch import nn, Tensor diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py index fb0cc0135b8..2a2fc6b31d3 100644 --- a/torchrl/modules/models/__init__.py +++ b/torchrl/modules/models/__init__.py @@ -6,6 +6,8 @@ from torchrl.modules.tensordict_module.common import DistributionalDQNnet +from .batchrenorm1d import BatchRenorm1d + from .decision_transformer import DecisionTransformer from .exploration import NoisyLazyLinear, NoisyLinear, reset_noise from .model_based import ( diff --git a/sota-implementations/crossq/batchrenorm.py b/torchrl/modules/models/batchrenorm1d.py similarity index 69% rename from sota-implementations/crossq/batchrenorm.py rename to torchrl/modules/models/batchrenorm1d.py index 9d1a78d1135..33b4df31cc7 100644 --- a/sota-implementations/crossq/batchrenorm.py +++ b/torchrl/modules/models/batchrenorm1d.py @@ -6,14 +6,16 @@ import torch.nn as nn -class BatchRenorm(nn.Module): +class BatchRenorm1d(nn.Module): """ BatchRenorm Module (https://arxiv.org/abs/1702.03275). + The code is adapted from https://github.com/google-research/corenet + BatchRenorm is an enhanced version of the standard BatchNorm. Unlike BatchNorm, - BatchRenorm utilizes running statistics to normalize batches after an initial warmup phase. - This approach reduces the impact of "outlier" batches that may occur during extended training periods, - making BatchRenorm more robust for long training runs. + it utilizes running statistics to normalize batches after an initial warmup phase. + This approach reduces the impact of "outlier" batches that may occur during + extended training periods, making BatchRenorm more robust for long training runs. During the warmup phase, BatchRenorm functions identically to a BatchNorm layer. @@ -21,21 +23,27 @@ class BatchRenorm(nn.Module): num_features (int): Number of features in the input tensor. Keyword Args: - momentum (float, optional): Momentum factor for computing the running mean and variance. Default is 0.01. - eps (float, optional): Small value added to the variance to avoid division by zero. Default is 1e-5. - max_r (float, optional): Maximum value for the scaling factor r. Default is 3.0. - max_d (float, optional): Maximum value for the bias factor d. Default is 5.0. - warmup_steps (int, optional): Number of warm-up steps for the running mean and variance. Default is 10000. + momentum (float, optional): Momentum factor for computing the running mean and variance. + Defaults to ``0.01``. + eps (float, optional): Small value added to the variance to avoid division by zero. + Defaults to ``1e-5``. + max_r (float, optional): Maximum value for the scaling factor r. + Defaults to ``3.0``. + max_d (float, optional): Maximum value for the bias factor d. + Defaults to ``5.0``. + warmup_steps (int, optional): Number of warm-up steps for the running mean and variance. + Defaults to ``10000``. 
""" def __init__( self, - num_features, - momentum=0.01, - eps=1e-5, - max_r=3.0, - max_d=5.0, - warmup_steps=10000, + num_features: int, + *, + momentum: float = 0.01, + eps: float = 1e-5, + max_r: float = 3.0, + max_d: float = 5.0, + warmup_steps: int = 10000, ): super().__init__() self.num_features = num_features @@ -56,9 +64,12 @@ def __init__( self.bias = nn.Parameter(torch.zeros(num_features, dtype=torch.float32)) def forward(self, x: torch.Tensor) -> torch.Tensor: - assert x.dim() >= 2 + if not x.dim() >= 2: + raise ValueError( + f"The {type(self).__name__} expects a 2D (or more) tensor, got {x.dim()}." + ) + view_dims = [1, x.shape[1]] + [1] * (x.dim() - 2) - # _v = lambda v: v.view(view_dims) def _v(v): return v.view(view_dims) @@ -79,18 +90,18 @@ def _v(v): ) # Compute warmup factor (0 during warmup, 1 after warmup) - warmup_factor = torch.clamp( - self.num_batches_tracked / self.warmup_steps, 0.0, 1.0 - ) - r = 1.0 + (r - 1.0) * warmup_factor - d = d * warmup_factor + if self.warmup_steps > 0: + warmup_factor = self.num_batches_tracked / self.warmup_steps + r = 1.0 + (r - 1.0) * warmup_factor + d = d * warmup_factor x = (x - _v(b_mean)) / _v(b_std) * _v(r) + _v(d) - unbiased_var = b_var.detach() * x.shape[1] / (x.shape[1] - 1) + unbiased_var = b_var.detach() * x.shape[0] / (x.shape[0] - 1) self.running_var += self.momentum * (unbiased_var - self.running_var) self.running_mean += self.momentum * (b_mean.detach() - self.running_mean) self.num_batches_tracked += 1 + self.num_batches_tracked.clamp_max(self.warmup_steps) else: x = (x - _v(self.running_mean)) / _v(running_std) diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index 6e7e6db1697..2372b9e3163 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -251,7 +251,6 @@ def __init__( fixed_alpha: bool = False, target_entropy: Union[str, float] = "auto", delay_actor: bool = False, - gamma: float = None, priority_key: str = None, separate_losses: bool = False, reduction: str = None, @@ -326,8 +325,6 @@ def __init__( self._target_entropy = target_entropy self._action_spec = action_spec - if gamma is not None: - raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) self._vmap_qnetworkN0 = _vmap_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) From 7fbb27d2fd129b7ad511159362e5e4ad3465a07b Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Jul 2024 16:16:40 +0100 Subject: [PATCH 25/37] amend --- sota-implementations/crossq/utils.py | 2 +- torchrl/modules/models/__init__.py | 2 +- torchrl/modules/models/{batchrenorm1d.py => batchrenorm.py} | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename torchrl/modules/models/{batchrenorm1d.py => batchrenorm.py} (100%) diff --git a/sota-implementations/crossq/utils.py b/sota-implementations/crossq/utils.py index 26798b1ee10..b450451fec4 100644 --- a/sota-implementations/crossq/utils.py +++ b/sota-implementations/crossq/utils.py @@ -25,7 +25,7 @@ from torchrl.modules import MLP, ProbabilisticActor, ValueOperator from torchrl.modules.distributions import TanhNormal -from torchrl.modules.models.batchrenorm1d import BatchRenorm1d +from torchrl.modules.models.batchrenorm import BatchRenorm1d from torchrl.objectives import CrossQLoss # ==================================================================== diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py index 2a2fc6b31d3..62ccf53c30a 100644 --- a/torchrl/modules/models/__init__.py +++ b/torchrl/modules/models/__init__.py @@ -6,7 +6,7 @@ from 
torchrl.modules.tensordict_module.common import DistributionalDQNnet -from .batchrenorm1d import BatchRenorm1d +from .batchrenorm import BatchRenorm1d from .decision_transformer import DecisionTransformer from .exploration import NoisyLazyLinear, NoisyLinear, reset_noise diff --git a/torchrl/modules/models/batchrenorm1d.py b/torchrl/modules/models/batchrenorm.py similarity index 100% rename from torchrl/modules/models/batchrenorm1d.py rename to torchrl/modules/models/batchrenorm.py From ff8048122211821e82b265f77e73e52de64f3e7a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Jul 2024 16:43:14 +0100 Subject: [PATCH 26/37] amend --- test/test_modules.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/test/test_modules.py b/test/test_modules.py index e3b774a0358..313bc1617bd 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -1457,14 +1457,13 @@ def test_batchrenorm(self, num_steps): brn.train() data_train = torch.randn(100, 5).split(25) data_test = torch.randn(100, 5) - for d in data_train: - _ = bn(d) - _ = brn(d) - # if num_steps == 0: - # print(a, b) - # torch.testing.assert_close(a, b) - # else: - # assert not torch.isclose(a, b).all() + for i, d in enumerate(data_train): + b = bn(d) + a = brn(d) + if num_steps > 0 and i == 0: + torch.testing.assert_close(a, b) + else: + assert not torch.isclose(a, b).all() bn.eval() brn.eval() From caf702e69099c775d999422be793742b2002b916 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Jul 2024 16:49:14 +0100 Subject: [PATCH 27/37] amend --- torchrl/modules/models/batchrenorm.py | 3 +-- torchrl/objectives/crossq.py | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/torchrl/modules/models/batchrenorm.py b/torchrl/modules/models/batchrenorm.py index 33b4df31cc7..8f23aedda14 100644 --- a/torchrl/modules/models/batchrenorm.py +++ b/torchrl/modules/models/batchrenorm.py @@ -7,8 +7,7 @@ class BatchRenorm1d(nn.Module): - """ - BatchRenorm Module (https://arxiv.org/abs/1702.03275). + """BatchRenorm Module (https://arxiv.org/abs/1702.03275). 
The code is adapted from https://github.com/google-research/corenet diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index b79c032ac9e..9cffa28f4a4 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -22,7 +22,6 @@ from torchrl.objectives.utils import ( _cache_values, - _GAMMA_LMBDA_DEPREC_ERROR, _reduce, _vmap_func, default_value_kwargs, From 70e28827a6c412eee701fb7c1a89f23d72e7349d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Jul 2024 17:07:44 +0100 Subject: [PATCH 28/37] amend --- test/test_modules.py | 4 ++-- torchrl/modules/models/batchrenorm.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_modules.py b/test/test_modules.py index 313bc1617bd..b41d7000155 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -1460,10 +1460,10 @@ def test_batchrenorm(self, num_steps): for i, d in enumerate(data_train): b = bn(d) a = brn(d) - if num_steps > 0 and i == 0: + if num_steps > 0 and i < num_steps: torch.testing.assert_close(a, b) else: - assert not torch.isclose(a, b).all() + assert not torch.isclose(a, b).all(), i bn.eval() brn.eval() diff --git a/torchrl/modules/models/batchrenorm.py b/torchrl/modules/models/batchrenorm.py index 8f23aedda14..56aeb6a48dd 100644 --- a/torchrl/modules/models/batchrenorm.py +++ b/torchrl/modules/models/batchrenorm.py @@ -90,7 +90,7 @@ def _v(v): # Compute warmup factor (0 during warmup, 1 after warmup) if self.warmup_steps > 0: - warmup_factor = self.num_batches_tracked / self.warmup_steps + warmup_factor = self.num_batches_tracked // self.warmup_steps r = 1.0 + (r - 1.0) * warmup_factor d = d * warmup_factor From ccd1b7f9e64f6b623cbe04c956f31bfb6780442a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Jul 2024 17:14:40 +0100 Subject: [PATCH 29/37] amend --- test/test_modules.py | 15 ++++++++++++--- torchrl/modules/models/batchrenorm.py | 11 ++++++++++- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/test/test_modules.py b/test/test_modules.py index b41d7000155..592464f0a96 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -1447,11 +1447,18 @@ def test_python_gru(device, bias, dropout, batch_first, num_layers): class TestBatchRenorm: @pytest.mark.parametrize("num_steps", [0, 5]) - def test_batchrenorm(self, num_steps): + @pytest.mark.parametrize("smooth", [False, True]) + def test_batchrenorm(self, num_steps, smooth): torch.manual_seed(0) bn = torch.nn.BatchNorm1d(5, momentum=0.1, eps=1e-5) brn = BatchRenorm1d( - 5, momentum=0.1, eps=1e-5, warmup_steps=num_steps, max_d=10000, max_r=10000 + 5, + momentum=0.1, + eps=1e-5, + warmup_steps=num_steps, + max_d=10000, + max_r=10000, + smooth=smooth, ) bn.train() brn.train() @@ -1460,7 +1467,9 @@ def test_batchrenorm(self, num_steps): for i, d in enumerate(data_train): b = bn(d) a = brn(d) - if num_steps > 0 and i < num_steps: + if num_steps > 0 and ( + (i < num_steps and not smooth) or (i == 0 and smooth) + ): torch.testing.assert_close(a, b) else: assert not torch.isclose(a, b).all(), i diff --git a/torchrl/modules/models/batchrenorm.py b/torchrl/modules/models/batchrenorm.py index 56aeb6a48dd..26a2f9d50d2 100644 --- a/torchrl/modules/models/batchrenorm.py +++ b/torchrl/modules/models/batchrenorm.py @@ -32,6 +32,10 @@ class BatchRenorm1d(nn.Module): Defaults to ``5.0``. warmup_steps (int, optional): Number of warm-up steps for the running mean and variance. Defaults to ``10000``. 
+ smooth (bool, optional): if ``True``, the behaviour smoothly transitions from regular + batch-norm (when ``iter=0``) to batch-renorm (when ``iter=warmup_steps``). + Otherwise, the behaviour will transition from batch-norm to batch-renorm when + ``iter=warmup_steps``. Defaults to ``False``. """ def __init__( @@ -43,6 +47,7 @@ def __init__( max_r: float = 3.0, max_d: float = 5.0, warmup_steps: int = 10000, + smooth: bool = False, ): super().__init__() self.num_features = num_features @@ -51,6 +56,7 @@ def __init__( self.max_r = max_r self.max_d = max_d self.warmup_steps = warmup_steps + self.smooth = smooth self.register_buffer( "running_mean", torch.zeros(num_features, dtype=torch.float32) @@ -90,7 +96,10 @@ def _v(v): # Compute warmup factor (0 during warmup, 1 after warmup) if self.warmup_steps > 0: - warmup_factor = self.num_batches_tracked // self.warmup_steps + if self.smooth: + warmup_factor = self.num_batches_tracked / self.warmup_steps + else: + warmup_factor = self.num_batches_tracked // self.warmup_steps r = 1.0 + (r - 1.0) * warmup_factor d = d * warmup_factor From d3e0bb1a2e9a4f1c693225612911bb3911739052 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Jul 2024 09:07:21 +0100 Subject: [PATCH 30/37] Apply suggestions from code review --- test/test_cost.py | 3 --- torchrl/objectives/crossq.py | 49 +++++++++++++++++++++++++++++++++--- 2 files changed, 46 insertions(+), 6 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 065f9dff946..660b0b5b491 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -4217,9 +4217,6 @@ def test_discrete_sac_reduction(self, reduction): assert loss[key].shape == torch.Size([]) -@pytest.mark.skipif( - not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" -) class TestCrossQ(LossModuleTestBase): seed = 0 diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index 9cffa28f4a4..8d28dc0e0b1 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -46,10 +46,15 @@ class CrossQLoss(LossModule): Presented in "CROSSQ: BATCH NORMALIZATION IN DEEP REINFORCEMENT LEARNING FOR GREATER SAMPLE EFFICIENCY AND SIMPLICITY" https://openreview.net/pdf?id=PczQtTsTIX + This class has three loss functions that will be called sequentially by the `forward` method: + :meth:`~.qvalue_loss`, :meth:`~.actor_loss` and :meth:`~.alpha_loss`. Alternatively, they can + be called by the user that order. + Args: actor_network (ProbabilisticActor): stochastic actor qvalue_network (TensorDictModule): Q(s, a) parametric model. This module typically outputs a ``"state_action_value"`` entry. + Keyword Args: num_qvalue_nets (integer, optional): number of Q-Value networks used. Defaults to ``2``. @@ -331,6 +336,10 @@ def __init__( @property def target_entropy_buffer(self): + """The target entropy. + + This value can be controlled via the `target_entropy` kwarg in the constructor. + """ return self.target_entropy @property @@ -467,6 +476,13 @@ def out_keys(self, values): @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + """The forward method. + + Computes successively the :meth:`~.qvalue_loss`, :meth:`~.actor_loss` and :meth:`~.alpha_loss`, and returns + a tensordict with these values along with the `"alpha"` value and the `"entropy"` value (detached). + To see what keys are expected in the input tensordict and what keys are expected as output, check the + class's `"in_keys"` and `"out_keys"` attributes. 
+ """ shape = None if tensordict.ndimension() > 1: shape = tensordict.shape @@ -511,7 +527,17 @@ def _cached_detached_qvalue_params(self): def actor_loss( self, tensordict: TensorDictBase ) -> Tuple[Tensor, Dict[str, Tensor]]: - """Compute the actor loss.""" + """Compute the actor loss. + + + The actor loss should be computed after the :meth:`~.qvalue_loss` and before the `~.alpha_loss` which requires the `log_prob` field of the `metadata` returned by this method. + + Args: + tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields + are required for this to be computed. + + Returns: a differentiable tensor with the alpha loss along with a metadata dictionary containing the detached `"log_prob"` of the sampled action. + """ with set_exploration_type( ExplorationType.RANDOM ), self.actor_network_params.to_module(self.actor_network): @@ -540,7 +566,16 @@ def actor_loss( def qvalue_loss( self, tensordict: TensorDictBase ) -> Tuple[Tensor, Dict[str, Tensor]]: - """Compute the CrossQ-value loss.""" + """Compute the q-value loss. + + The q-value loss should be computed before the :meth:`~.actor_loss`. + + Args: + tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields + are required for this to be computed. + + Returns: a differentiable tensor with the qvalue loss along with a metadata dictionary containing the detached `"td_error"` to be used for prioritized sampling. + """ # # compute next action with torch.no_grad(): with set_exploration_type( @@ -594,7 +629,15 @@ def qvalue_loss( return loss_qval, metadata def alpha_loss(self, log_prob: Tensor) -> Tensor: - """Compute the entropy loss.""" + """Compute the entropy loss. + + The entropy loss should be computed last. + + Args: + log_prob: a log-probability as computed by the :meth:`~.actor_loss` and returned in the `metadata`. + + Returns: a differentiable tensor with the entropy loss. + """ if self.target_entropy is not None: # we can compute this loss even if log_alpha is not a parameter alpha_loss = -self.log_alpha * (log_prob + self.target_entropy) From 349cb289536cb30d7b7db5d5fc103debbf66a424 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Jul 2024 09:10:39 +0100 Subject: [PATCH 31/37] amend --- torchrl/objectives/crossq.py | 34 ++++++++++++++++------------------ 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index 8d28dc0e0b1..958e5bb6220 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -46,7 +46,7 @@ class CrossQLoss(LossModule): Presented in "CROSSQ: BATCH NORMALIZATION IN DEEP REINFORCEMENT LEARNING FOR GREATER SAMPLE EFFICIENCY AND SIMPLICITY" https://openreview.net/pdf?id=PczQtTsTIX - This class has three loss functions that will be called sequentially by the `forward` method: + This class has three loss functions that will be called sequentially by the `forward` method: :meth:`~.qvalue_loss`, :meth:`~.actor_loss` and :meth:`~.alpha_loss`. Alternatively, they can be called by the user that order. @@ -216,8 +216,6 @@ class _AcceptedKeys: Defaults to ``"advantage"``. state_action_value (NestedKey): The input tensordict key where the state action value is expected. Defaults to ``"state_action_value"``. - # log_prob (NestedKey): The input tensordict key where the log probability is expected. - # Defaults to ``"_log_prob"``. priority (NestedKey): The input tensordict key where the target priority is written to. 
Defaults to ``"td_error"``. reward (NestedKey): The input tensordict key where the reward is expected. @@ -232,7 +230,6 @@ class _AcceptedKeys: action: NestedKey = "action" state_action_value: NestedKey = "state_action_value" - # log_prob: NestedKey = "_log_prob" priority: NestedKey = "td_error" reward: NestedKey = "reward" done: NestedKey = "done" @@ -476,13 +473,13 @@ def out_keys(self, values): @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - """The forward method. + """The forward method. - Computes successively the :meth:`~.qvalue_loss`, :meth:`~.actor_loss` and :meth:`~.alpha_loss`, and returns - a tensordict with these values along with the `"alpha"` value and the `"entropy"` value (detached). - To see what keys are expected in the input tensordict and what keys are expected as output, check the - class's `"in_keys"` and `"out_keys"` attributes. - """ + Computes successively the :meth:`~.qvalue_loss`, :meth:`~.actor_loss` and :meth:`~.alpha_loss`, and returns + a tensordict with these values along with the `"alpha"` value and the `"entropy"` value (detached). + To see what keys are expected in the input tensordict and what keys are expected as output, check the + class's `"in_keys"` and `"out_keys"` attributes. + """ shape = None if tensordict.ndimension() > 1: shape = tensordict.shape @@ -528,10 +525,10 @@ def actor_loss( self, tensordict: TensorDictBase ) -> Tuple[Tensor, Dict[str, Tensor]]: """Compute the actor loss. - - - The actor loss should be computed after the :meth:`~.qvalue_loss` and before the `~.alpha_loss` which requires the `log_prob` field of the `metadata` returned by this method. - + + The actor loss should be computed after the :meth:`~.qvalue_loss` and before the `~.alpha_loss` which + requires the `log_prob` field of the `metadata` returned by this method. + Args: tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields are required for this to be computed. @@ -574,7 +571,8 @@ def qvalue_loss( tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields are required for this to be computed. - Returns: a differentiable tensor with the qvalue loss along with a metadata dictionary containing the detached `"td_error"` to be used for prioritized sampling. + Returns: a differentiable tensor with the qvalue loss along with a metadata dictionary containing + the detached `"td_error"` to be used for prioritized sampling. """ # # compute next action with torch.no_grad(): @@ -630,11 +628,11 @@ def qvalue_loss( def alpha_loss(self, log_prob: Tensor) -> Tensor: """Compute the entropy loss. - + The entropy loss should be computed last. - + Args: - log_prob: a log-probability as computed by the :meth:`~.actor_loss` and returned in the `metadata`. + log_prob (torch.Tensor): a log-probability as computed by the :meth:`~.actor_loss` and returned in the `metadata`. Returns: a differentiable tensor with the entropy loss. 
""" From 75a43e7a01247d7c6446a1411b73f86863561050 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Jul 2024 09:13:34 +0100 Subject: [PATCH 32/37] amend --- torchrl/objectives/crossq.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index 958e5bb6220..4d6263fbf05 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -10,7 +10,7 @@ from typing import Dict, Tuple, Union import torch -from tensordict import TensorDict, TensorDictBase +from tensordict import TensorDict, TensorDictBase, TensorDictParams from tensordict.nn import dispatch, TensorDictModule from tensordict.utils import NestedKey @@ -238,6 +238,13 @@ class _AcceptedKeys: default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 + actor_network: ProbabilisticActor + actor_network_params: TensorDictParams + target_actor_network_params: TensorDictParams + qvalue_network: TensorDictModule + qvalue_network_params: TensorDictParams + target_qvalue_network_params: TensorDictParams + def __init__( self, actor_network: ProbabilisticActor, From abada6ce85b655ecb22d8c562971f6e7fbb1e0ff Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 9 Jul 2024 10:34:00 +0200 Subject: [PATCH 33/37] fix device error --- sota-implementations/crossq/crossq.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/sota-implementations/crossq/crossq.py b/sota-implementations/crossq/crossq.py index 4c5b6f476db..43f2b7d2eb6 100644 --- a/sota-implementations/crossq/crossq.py +++ b/sota-implementations/crossq/crossq.py @@ -35,9 +35,13 @@ @hydra.main(version_base="1.1", config_path=".", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 - device = torch.device(cfg.network.device) - if device is None: - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = cfg.network.device + if device in ("", None): + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") + device = torch.device(device) # Create logger exp_name = generate_exp_name("CrossQ", cfg.logger.exp_name) From c878b81d500369fbe754d64467bb677d9d4fae90 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 9 Jul 2024 15:08:02 +0200 Subject: [PATCH 34/37] Update objective delay actor --- torchrl/objectives/crossq.py | 40 +++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index 4d6263fbf05..22d35bd5799 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -76,9 +76,6 @@ class CrossQLoss(LossModule): target_entropy (float or str, optional): Target entropy for the stochastic policy. Default is "auto", where target entropy is computed as :obj:`-prod(n_actions)`. - delay_actor (bool, optional): Whether to separate the target actor - networks from the actor networks used for data collection. - Default is ``False``. priority_key (str, optional): [Deprecated, use .set_keys(priority_key=priority_key) instead] Tensordict key where to write the priority (for prioritized replay buffer usage). Defaults to ``"td_error"``. @@ -226,6 +223,8 @@ class _AcceptedKeys: terminated (NestedKey): The key in the input TensorDict that indicates whether a trajectory is terminated. Will be used for the underlying value estimator. Defaults to ``"terminated"``. + log_prob (NestedKey): The input tensordict key where the log probability is expected. 
+ Defaults to ``"_log_prob"``. """ action: NestedKey = "action" @@ -234,15 +233,16 @@ class _AcceptedKeys: reward: NestedKey = "reward" done: NestedKey = "done" terminated: NestedKey = "terminated" + log_prob: NestedKey = "_log_prob" default_keys = _AcceptedKeys() default_value_estimator = ValueEstimators.TD0 actor_network: ProbabilisticActor actor_network_params: TensorDictParams - target_actor_network_params: TensorDictParams qvalue_network: TensorDictModule qvalue_network_params: TensorDictParams + target_actor_network_params: TensorDictParams target_qvalue_network_params: TensorDictParams def __init__( @@ -258,7 +258,6 @@ def __init__( action_spec=None, fixed_alpha: bool = False, target_entropy: Union[str, float] = "auto", - delay_actor: bool = False, priority_key: str = None, separate_losses: bool = False, reduction: str = None, @@ -271,11 +270,10 @@ def __init__( self._set_deprecated_ctor_keys(priority_key=priority_key) # Actor - self.delay_actor = delay_actor self.convert_to_functional( actor_network, "actor_network", - create_target_params=self.delay_actor, + create_target_params=False, ) if separate_losses: # we want to make sure there are no duplicates in the params: the @@ -511,16 +509,18 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: "loss_alpha": loss_alpha, "alpha": self._alpha, "entropy": entropy.detach().mean(), + **metadata_actor, + **value_metadata, } td_out = TensorDict(out, []) - td_out = td_out.named_apply( - lambda name, value: ( - _reduce(value, reduction=self.reduction) - if name.startswith("loss_") - else value - ), - batch_size=[], - ) + # td_out = td_out.named_apply( + # lambda name, value: ( + # _reduce(value, reduction=self.reduction) + # if name.startswith("loss_") + # else value + # ), + # batch_size=[], + # ) return td_out @property @@ -564,8 +564,10 @@ def actor_loss( raise RuntimeError( f"Losses shape mismatch: {log_prob.shape} and {min_q.shape}" ) - - return self._alpha * log_prob - min_q, {"log_prob": log_prob.detach()} + actor_loss = self._alpha * log_prob - min_q + return _reduce(actor_loss, reduction=self.reduction), { + "log_prob": log_prob.detach() + } def qvalue_loss( self, tensordict: TensorDictBase @@ -631,7 +633,7 @@ def qvalue_loss( loss_function=self.loss_function, ).sum(0) metadata = {"td_error": td_error.detach().max(0)[0]} - return loss_qval, metadata + return _reduce(loss_qval, reduction=self.reduction), metadata def alpha_loss(self, log_prob: Tensor) -> Tensor: """Compute the entropy loss. 
@@ -649,7 +651,7 @@ def alpha_loss(self, log_prob: Tensor) -> Tensor: else: # placeholder alpha_loss = torch.zeros_like(log_prob) - return alpha_loss + return _reduce(alpha_loss, reduction=self.reduction) @property def _alpha(self): From f222b11b2647a4ec3e6202645911d8a43e335f28 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 9 Jul 2024 15:10:07 +0200 Subject: [PATCH 35/37] Update tests not expecting target update --- test/test_cost.py | 276 ++++++++++++++++++++-------------------------- 1 file changed, 117 insertions(+), 159 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 660b0b5b491..072656be25e 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -4405,33 +4405,25 @@ def _create_seq_mock_data_crossq( ) return td - @pytest.mark.parametrize("delay_actor", (True, False)) @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) def test_crossq( self, - delay_actor, num_qvalue, device, td_est, ): torch.manual_seed(self.seed) td = self._create_mock_data_crossq(device=device) - actor = self._create_mock_actor(device=device) qvalue = self._create_mock_qvalue(device=device) - kwargs = {} - if delay_actor: - kwargs["delay_actor"] = True - loss_fn = CrossQLoss( actor_network=actor, qvalue_network=qvalue, num_qvalue_nets=num_qvalue, loss_function="l2", - **kwargs, ) if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): @@ -4441,9 +4433,7 @@ def test_crossq( if td_est is not None: loss_fn.make_value_estimator(td_est) - with _check_td_steady(td), pytest.warns( - UserWarning, match="No target network updater" - ): + with _check_td_steady(td): loss = loss_fn(td) assert loss_fn.tensor_keys.priority in td.keys() @@ -4515,12 +4505,10 @@ def test_crossq( p.grad is None or p.grad.norm() == 0.0 ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" - @pytest.mark.parametrize("delay_actor", (True, False)) @pytest.mark.parametrize("num_qvalue", [2]) @pytest.mark.parametrize("device", get_default_devices()) def test_crossq_state_dict( self, - delay_actor, num_qvalue, device, ): @@ -4529,16 +4517,11 @@ def test_crossq_state_dict( actor = self._create_mock_actor(device=device) qvalue = self._create_mock_qvalue(device=device) - kwargs = {} - if delay_actor: - kwargs["delay_actor"] = True - loss_fn = CrossQLoss( actor_network=actor, qvalue_network=qvalue, num_qvalue_nets=num_qvalue, loss_function="l2", - **kwargs, ) sd = loss_fn.state_dict() loss_fn2 = CrossQLoss( @@ -4546,7 +4529,6 @@ def test_crossq_state_dict( qvalue_network=qvalue, num_qvalue_nets=num_qvalue, loss_function="l2", - **kwargs, ) loss_fn2.load_state_dict(sd) @@ -4554,10 +4536,10 @@ def test_crossq_state_dict( @pytest.mark.parametrize("separate_losses", [False, True]) def test_crossq_separate_losses( self, - device, separate_losses, - n_act=4, + device, ): + n_act = 4 torch.manual_seed(self.seed) actor, qvalue, common, td = self._create_mock_common_layer_setup(n_act=n_act) @@ -4568,85 +4550,80 @@ def test_crossq_separate_losses( num_qvalue_nets=1, separate_losses=separate_losses, ) - with pytest.warns(UserWarning, match="No target network updater has been"): - loss = loss_fn(td) + loss = loss_fn(td) - assert loss_fn.tensor_keys.priority in td.keys() + assert loss_fn.tensor_keys.priority in td.keys() - # check that losses are independent - for k in loss.keys(): - if not k.startswith("loss"): - continue - loss[k].sum().backward(retain_graph=True) - if k == "loss_actor": + # check 
that losses are independent + for k in loss.keys(): + if not k.startswith("loss"): + continue + loss[k].sum().backward(retain_graph=True) + if k == "loss_actor": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True + ) + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + elif k == "loss_qvalue": + common_layers_no = len(list(common.parameters())) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True + ) + ) + if separate_losses: + common_layers = itertools.islice( + loss_fn.qvalue_network_params.values(True, True), + common_layers_no, + ) assert all( - (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.qvalue_network_params.values( - include_nested=True, leaves_only=True - ) + (p.grad is None) or (p.grad == 0).all() for p in common_layers + ) + qvalue_layers = itertools.islice( + loss_fn.qvalue_network_params.values(True, True), + common_layers_no, + None, ) assert not any( - (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.actor_network_params.values( - include_nested=True, leaves_only=True - ) + (p.grad is None) or (p.grad == 0).all() for p in qvalue_layers ) - elif k == "loss_qvalue": - common_layers_no = len(list(common.parameters())) - assert all( + else: + assert not any( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.actor_network_params.values( - include_nested=True, leaves_only=True - ) + for p in loss_fn.qvalue_network_params.values(True, True) ) - if separate_losses: - common_layers = itertools.islice( - loss_fn.qvalue_network_params.values(True, True), - common_layers_no, - ) - assert all( - (p.grad is None) or (p.grad == 0).all() - for p in common_layers - ) - qvalue_layers = itertools.islice( - loss_fn.qvalue_network_params.values(True, True), - common_layers_no, - None, - ) - assert not any( - (p.grad is None) or (p.grad == 0).all() - for p in qvalue_layers - ) - else: - assert not any( - (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.qvalue_network_params.values(True, True) - ) - elif k == "loss_alpha": - assert all( - (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.actor_network_params.values( - include_nested=True, leaves_only=True - ) + elif k == "loss_alpha": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values( + include_nested=True, leaves_only=True ) - assert all( - (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.qvalue_network_params.values( - include_nested=True, leaves_only=True - ) + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values( + include_nested=True, leaves_only=True ) - else: - raise NotImplementedError(k) - loss_fn.zero_grad() + ) + else: + raise NotImplementedError(k) + loss_fn.zero_grad() @pytest.mark.parametrize("n", range(1, 4)) - @pytest.mark.parametrize("delay_actor", (True, False)) @pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8]) @pytest.mark.parametrize("device", get_default_devices()) def test_crossq_batcher( self, n, - delay_actor, num_qvalue, device, ): @@ -4656,16 +4633,11 @@ def test_crossq_batcher( actor = self._create_mock_actor(device=device) qvalue = self._create_mock_qvalue(device=device) - kwargs = {} - if delay_actor: - kwargs["delay_actor"] = True - loss_fn = CrossQLoss( 
actor_network=actor, qvalue_network=qvalue, num_qvalue_nets=num_qvalue, loss_function="l2", - **kwargs, ) ms = MultiStep(gamma=0.9, n_steps=n).to(device) @@ -4675,80 +4647,68 @@ def test_crossq_batcher( torch.manual_seed(0) np.random.seed(0) - with pytest.warns( - UserWarning, - match="No target network updater has been associated with this loss module", - ): - with _check_td_steady(ms_td): - loss_ms = loss_fn(ms_td) - assert loss_fn.tensor_keys.priority in ms_td.keys() - with torch.no_grad(): - torch.manual_seed(0) # log-prob is computed with a random action - np.random.seed(0) - loss = loss_fn(td) - if n == 1: - assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) - _loss = sum( - [item for name, item in loss.items() if name.startswith("loss_")] - ) - _loss_ms = sum( - [item for name, item in loss_ms.items() if name.startswith("loss_")] - ) + with _check_td_steady(ms_td): + loss_ms = loss_fn(ms_td) + assert loss_fn.tensor_keys.priority in ms_td.keys() + + with torch.no_grad(): + torch.manual_seed(0) # log-prob is computed with a random action + np.random.seed(0) + loss = loss_fn(td) + if n == 1: + assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) + _loss = sum( + [item for name, item in loss.items() if name.startswith("loss_")] + ) + _loss_ms = sum( + [item for name, item in loss_ms.items() if name.startswith("loss_")] + ) + assert ( + abs(_loss - _loss_ms) < 1e-3 + ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" + else: + with pytest.raises(AssertionError): + assert_allclose_td(loss, loss_ms) + sum( + [item for name, item in loss_ms.items() if name.startswith("loss_")] + ).backward() + named_parameters = loss_fn.named_parameters() + for name, p in named_parameters: + if not name.startswith("target_"): assert ( - abs(_loss - _loss_ms) < 1e-3 - ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" else: - with pytest.raises(AssertionError): - assert_allclose_td(loss, loss_ms) - sum( - [item for name, item in loss_ms.items() if name.startswith("loss_")] - ).backward() - named_parameters = loss_fn.named_parameters() - for name, p in named_parameters: - if not name.startswith("target_"): - assert ( - p.grad is not None and p.grad.norm() > 0.0 - ), f"parameter {name} (shape: {p.shape}) has a null gradient" - else: - assert ( - p.grad is None or p.grad.norm() == 0.0 - ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" - # Check param update effect on targets - target_actor = [ - p.clone() - for p in loss_fn.target_actor_network_params.values( - include_nested=True, leaves_only=True - ) - ] - for p in loss_fn.parameters(): - if p.requires_grad: - p.data += torch.randn_like(p) - target_actor2 = [ - p.clone() - for p in loss_fn.target_actor_network_params.values( - include_nested=True, leaves_only=True - ) - ] + # Check param update effect on targets + target_actor = [ + p.clone() + for p in loss_fn.target_actor_network_params.values( + include_nested=True, leaves_only=True + ) + ] + for p in loss_fn.parameters(): + if p.requires_grad: + p.data += torch.randn_like(p) + target_actor2 = [ + p.clone() + for p in loss_fn.target_actor_network_params.values( + include_nested=True, leaves_only=True + ) + ] - if loss_fn.delay_actor: - assert all( - (p1 == p2).all() for p1, p2 in 
zip(target_actor, target_actor2) - ) - else: - assert not any( - (p1 == p2).any() for p1, p2 in zip(target_actor, target_actor2) - ) + assert not any((p1 == p2).any() for p1, p2 in zip(target_actor, target_actor2)) - # check that policy is updated after parameter update - parameters = [p.clone() for p in actor.parameters()] - for p in loss_fn.parameters(): - if p.requires_grad: - p.data += torch.randn_like(p) - assert all( - (p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters()) - ) + # check that policy is updated after parameter update + parameters = [p.clone() for p in actor.parameters()] + for p in loss_fn.parameters(): + if p.requires_grad: + p.data += torch.randn_like(p) + assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) @pytest.mark.parametrize( "td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda] @@ -4846,8 +4806,7 @@ def test_crossq_notensordict( # setting the seed for each loss so that drawing the random samples from value network # leads to same numbers for both runs torch.manual_seed(self.seed) - with pytest.warns(UserWarning, match="No target network updater"): - loss_val = loss(**kwargs) + loss_val = loss(**kwargs) torch.manual_seed(self.seed) @@ -4935,7 +4894,6 @@ def test_crossq_reduction(self, reduction): actor_network=actor, qvalue_network=qvalue, loss_function="l2", - delay_actor=False, reduction=reduction, ) loss_fn.make_value_estimator() From 067b5605c715b45471abe684ab80273eefb4c537 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 9 Jul 2024 19:07:51 +0200 Subject: [PATCH 36/37] update example utils --- sota-implementations/crossq/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sota-implementations/crossq/utils.py b/sota-implementations/crossq/utils.py index b450451fec4..56256afdf32 100644 --- a/sota-implementations/crossq/utils.py +++ b/sota-implementations/crossq/utils.py @@ -246,7 +246,6 @@ def make_loss_module(cfg, model): qvalue_network=model[1], num_qvalue_nets=2, loss_function=cfg.optim.loss_function, - delay_actor=False, alpha_init=cfg.optim.alpha_init, ) loss_module.make_value_estimator(gamma=cfg.optim.gamma) From c010e392b28a3796798f0a8cbc1866b6e204db9d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Jul 2024 20:55:50 +0100 Subject: [PATCH 37/37] amend --- sota-implementations/crossq/crossq.py | 2 +- sota-implementations/crossq/utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sota-implementations/crossq/crossq.py b/sota-implementations/crossq/crossq.py index 43f2b7d2eb6..df34d4ae68d 100644 --- a/sota-implementations/crossq/crossq.py +++ b/sota-implementations/crossq/crossq.py @@ -72,7 +72,7 @@ def main(cfg: "DictConfig"): # noqa: F821 loss_module = make_loss_module(cfg, model) # Create off-policy collector - collector = make_collector(cfg, train_env, exploration_policy.eval()) + collector = make_collector(cfg, train_env, exploration_policy.eval(), device=device) # Create replay buffer replay_buffer = make_replay_buffer( diff --git a/sota-implementations/crossq/utils.py b/sota-implementations/crossq/utils.py index 56256afdf32..9883bc50b17 100644 --- a/sota-implementations/crossq/utils.py +++ b/sota-implementations/crossq/utils.py @@ -90,7 +90,7 @@ def make_environment(cfg): # --------------------------- -def make_collector(cfg, train_env, actor_model_explore): +def make_collector(cfg, train_env, actor_model_explore, device): """Make collector.""" collector = SyncDataCollector( train_env, @@ -98,7 +98,7 @@ def make_collector(cfg, train_env, 
actor_model_explore): init_random_frames=cfg.collector.init_random_frames, frames_per_batch=cfg.collector.frames_per_batch, total_frames=cfg.collector.total_frames, - device=cfg.collector.device, + device=device, ) collector.set_seed(cfg.env.seed) return collector
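Taken together, the last two patches resolve the device once in crossq.py and pass it explicitly into make_collector instead of reading cfg.collector.device. A stand-alone sketch of that fallback logic, written here as a hypothetical helper for illustration (the helper name and signature are not part of this series):

    from typing import Optional, Union

    import torch


    def resolve_device(device: Optional[Union[str, torch.device]]) -> torch.device:
        """Return an explicit torch.device, falling back to cuda:0 when available."""
        if device in ("", None):
            return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        return torch.device(device)


    # Usage mirroring crossq.py after this series (cfg and make_collector assumed in scope):
    # device = resolve_device(cfg.network.device)
    # collector = make_collector(cfg, train_env, exploration_policy.eval(), device=device)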