diff --git a/.github/unittest/linux_libs/scripts_d4rl/environment.yml b/.github/unittest/linux_libs/scripts_d4rl/environment.yml index 567ab175d7a..862a148ec87 100644 --- a/.github/unittest/linux_libs/scripts_d4rl/environment.yml +++ b/.github/unittest/linux_libs/scripts_d4rl/environment.yml @@ -17,3 +17,4 @@ dependencies: - pyyaml - scipy - hydra-core + - cython<3 diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index cd2b71a0922..98d2d40cd5c 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -43,6 +43,7 @@ We also give users the ability to compose a replay buffer using the following co Writer RoundRobinWriter TensorDictRoundRobinWriter + TensorDictMaxValueWriter Storage choice is very influential on replay buffer sampling latency, especially in distributed reinforcement learning settings with larger data volumes. :class:`LazyMemmapStorage` is highly advised in distributed settings with shared storage due to the lower serialisation cost of MemmapTensors as well as the ability to specify file storage locations for improved node failure recovery. diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index 26979e2ae96..6ac6e001cb4 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -138,7 +138,7 @@ CQL CQLLoss DT ----- +-- .. autosummary:: :toctree: generated/ @@ -148,7 +148,7 @@ DT OnlineDTLoss TD3 ----- +--- .. autosummary:: :toctree: generated/ @@ -156,6 +156,15 @@ TD3 TD3Loss +TQC +--- + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + TQCLoss + PPO --- diff --git a/examples/tqc/tqc.py b/examples/tqc/tqc.py index f78aad721f6..fb3199fd10c 100644 --- a/examples/tqc/tqc.py +++ b/examples/tqc/tqc.py @@ -29,6 +29,7 @@ """ import time + import hydra import numpy as np import torch @@ -57,7 +58,7 @@ def main(cfg: "DictConfig"): # noqa: F821 exp_name = generate_exp_name("SAC", cfg.env.exp_name) logger = None # TO-DO: Add logging back in before pushing to git repo - #if cfg.logger.backend: + # if cfg.logger.backend: # logger = get_logger( # logger_type=cfg.logger.backend, # logger_name="sac_logging/wandb", @@ -190,7 +191,9 @@ def main(cfg: "DictConfig"): # noqa: F821 episode_length ) if collected_frames >= init_random_frames: - metrics_to_log["train/critic_loss"] = losses.get("loss_critic").mean().item() + metrics_to_log["train/critic_loss"] = ( + losses.get("loss_critic").mean().item() + ) metrics_to_log["train/actor_loss"] = losses.get("loss_actor").mean().item() metrics_to_log["train/alpha_loss"] = losses.get("loss_alpha").mean().item() metrics_to_log["train/alpha"] = loss_td["alpha"].item() diff --git a/examples/tqc/utils.py b/examples/tqc/utils.py index 760f79408d5..9a89b9cd314 100644 --- a/examples/tqc/utils.py +++ b/examples/tqc/utils.py @@ -5,23 +5,27 @@ import tempfile from contextlib import nullcontext +from typing import Tuple + import torch -import numpy as np from tensordict.nn import InteractionType, TensorDictModule from tensordict.nn.distributions import NormalParamExtractor +from tensordict.tensordict import TensorDict, TensorDictBase from torch import nn, optim from torchrl.collectors import SyncDataCollector -from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer +from torchrl.data import ( + CompositeSpec, + TensorDictPrioritizedReplayBuffer, + TensorDictReplayBuffer, +) from torchrl.data.replay_buffers.storages import LazyMemmapStorage from torchrl.envs import Compose, 
DoubleToFloat, EnvCreator, ParallelEnv, TransformedEnv from torchrl.envs.libs.gym import GymEnv, set_gym_backend from torchrl.envs.transforms import InitTracker, RewardSum, StepCounter from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.modules import MLP, ProbabilisticActor, ValueOperator, ActorCriticWrapper +from torchrl.modules import ActorCriticWrapper, MLP, ProbabilisticActor, ValueOperator from torchrl.modules.distributions import TanhNormal -from torchrl.objectives import SoftUpdate -from torchrl.data import CompositeSpec -from torchrl.objectives.common import LossModule +from torchrl.objectives import SoftUpdate, TQCLoss from torchrl.objectives.utils import ( _cache_values, _GAMMA_LMBDA_DEPREC_WARNING, @@ -30,8 +34,6 @@ ValueEstimators, ) from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator -from tensordict.tensordict import TensorDict, TensorDictBase -from typing import Tuple # ==================================================================== @@ -100,17 +102,17 @@ def make_collector(cfg, train_env, actor_model_explore): def make_replay_buffer( - batch_size, - prb=False, - buffer_size=1_000_000, - buffer_scratch_dir=None, - device="cpu", - prefetch=3, + batch_size, + prb=False, + buffer_size=1_000_000, + buffer_scratch_dir=None, + device="cpu", + prefetch=3, ): with ( - tempfile.TemporaryDirectory() - if buffer_scratch_dir is None - else nullcontext(buffer_scratch_dir) + tempfile.TemporaryDirectory() + if buffer_scratch_dir is None + else nullcontext(buffer_scratch_dir) ) as scratch_dir: if prb: replay_buffer = TensorDictPrioritizedReplayBuffer( @@ -155,13 +157,15 @@ def __init__(self, cfg): } for i in range(cfg.network.n_nets): net = MLP(**qvalue_net_kwargs) - self.add_module(f'critic_net_{i}', net) + self.add_module(f"critic_net_{i}", net) self.nets.append(net) def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor: if len(inputs) > 1: inputs = (torch.cat([*inputs], -1),) - quantiles = torch.stack(tuple(net(*inputs) for net in self.nets), dim=-2) # batch x n_nets x n_quantiles + quantiles = torch.stack( + tuple(net(*inputs) for net in self.nets), dim=-2 + ) # batch x n_nets x n_quantiles return quantiles @@ -239,172 +243,26 @@ def make_tqc_agent(cfg, train_env, eval_env, device): return model, model[0] -# ==================================================================== -# Quantile Huber Loss -# ------------------- - - -def quantile_huber_loss_f(quantiles, samples): - """ - Quantile Huber loss from the original PyTorch TQC implementation. - See: https://github.com/SamsungLabs/tqc_pytorch/blob/master/tqc/functions.py - - quantiles is assumed to be of shape [batch size, n_nets, n_quantiles] - samples is assumed to be of shape [batch size, n_samples] - Arbitrary batch sizes are allowed. 
- """ - pairwise_delta = samples[..., None, None, :] - quantiles[..., None] # batch x n_nets x n_quantiles x n_samples - abs_pairwise_delta = torch.abs(pairwise_delta) - huber_loss = torch.where(abs_pairwise_delta > 1, - abs_pairwise_delta - 0.5, - pairwise_delta ** 2 * 0.5) - n_quantiles = quantiles.shape[-1] - tau = torch.arange(n_quantiles, device=quantiles.device).float() / n_quantiles + 1 / 2 / n_quantiles - loss = (torch.abs(tau[..., None, :, None] - (pairwise_delta < 0).float()) * huber_loss).mean() - return loss - - # ==================================================================== # TQC Loss # -------- -class TQCLoss(LossModule): - def __init__( - self, - actor_network, - qvalue_network, - gamma, - top_quantiles_to_drop, - alpha_init, - device - ): - super().__init__() - - self.convert_to_functional( - actor_network, - "actor", - create_target_params=False, - funs_to_decorate=["forward", "get_dist"], - ) - - self.convert_to_functional( - qvalue_network, - "critic", - create_target_params=True # Create a target critic network - ) - - self.device = device - self.log_alpha = torch.tensor([np.log(alpha_init)], requires_grad=True, device=self.device) - self.gamma = gamma - self.top_quantiles_to_drop = top_quantiles_to_drop - - # Compute target entropy - action_spec = getattr(self.actor, "spec", None) - if action_spec is None: - print("Could not deduce action spec from actor network.") - if not isinstance(action_spec, CompositeSpec): - action_spec = CompositeSpec({"action": action_spec}) - action_container_len = len(action_spec.shape) - self.target_entropy = -float(action_spec["action"].shape[action_container_len:].numel()) - - def value_loss(self, tensordict): - td_next = tensordict.get("next") - reward = td_next.get("reward") - not_done = tensordict.get("done").logical_not() - alpha = torch.exp(self.log_alpha) - - # Q-loss - with torch.no_grad(): - # get policy action - self.actor(td_next, params=self.actor_params) - self.critic(td_next, params=self.target_critic_params) - - next_log_pi = td_next.get("sample_log_prob") - next_log_pi = torch.unsqueeze(next_log_pi, dim=-1) - - # compute and cut quantiles at the next state - next_z = td_next.get("state_action_value") - sorted_z, _ = torch.sort(next_z.reshape(*tensordict.batch_size, -1)) - sorted_z_part = sorted_z[..., :-self.top_quantiles_to_drop] - - # compute target - # --- Note --- - # This is computed manually here, since the built-in value estimators in the library - # currently do not support a critic of a shape different from the reward. 
- # ------------ - target = reward + not_done * self.gamma * (sorted_z_part - alpha * next_log_pi) - - self.critic(tensordict, params=self.critic_params) - cur_z = tensordict.get("state_action_value") - critic_loss = quantile_huber_loss_f(cur_z, target) - return critic_loss - - def actor_loss(self, tensordict): - alpha = torch.exp(self.log_alpha) - self.actor(tensordict, params=self.actor_params) - self.critic(tensordict, params=self.critic_params) - new_log_pi = tensordict.get("sample_log_prob") - actor_loss = (alpha * new_log_pi - tensordict.get("state_action_value").mean(-1).mean(-1, keepdim=True)).mean() - return actor_loss, new_log_pi - - def alpha_loss(self, log_prob): - alpha_loss = -self.log_alpha * (log_prob + self.target_entropy).detach().mean() - return alpha_loss - - def entropy(self, tensordict): - with set_exploration_type(ExplorationType.RANDOM): - dist = self.actor.get_dist( - tensordict, - params=self.actor_params, - ) - a_reparm = dist.rsample() - log_prob = dist.log_prob(a_reparm).detach() - entropy = -log_prob.mean() - return entropy - - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - alpha = torch.exp(self.log_alpha) - critic_loss = self.value_loss(tensordict) - actor_loss, log_prob = self.actor_loss(tensordict) # Compute actor loss AFTER critic loss - alpha_loss = self.alpha_loss(log_prob) - entropy = self.entropy(tensordict) - - return TensorDict( - { - "loss_critic": critic_loss, - "loss_actor": actor_loss, - "loss_alpha": alpha_loss, - "alpha": alpha, - "entropy": entropy, - }, - batch_size=[] - ) - - def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): - """ - This is a dummy function, which simply checks if the value type is TD0 and raises - an error if the value type is different. As of writing of this, the value estimators - in the library do not support a critic shape different from the reward state, which - is however necessary by construction for TQC. Therefore, this function does not - actually construct a value estimator, and the value is estimated "by hand" in the - value_loss function above. 
- """ - if value_type is not ValueEstimators.TD0: - raise NotImplementedError(f"Value type {value_type} is not currently implemented.") - def make_loss_module(cfg, model): """Make loss module and target network updater.""" # Create TQC loss + top_quantiles_to_drop = ( + cfg.network.top_quantiles_to_drop_per_net * cfg.network.n_nets + ) loss_module = TQCLoss( actor_network=model[0], qvalue_network=model[1], - device=cfg.network.device, - gamma=cfg.optim.gamma, - top_quantiles_to_drop=cfg.network.top_quantiles_to_drop_per_net * cfg.network.n_nets, - alpha_init=cfg.optim.alpha_init + top_quantiles_to_drop=top_quantiles_to_drop, + alpha_init=cfg.optim.alpha_init, + ) + loss_module.make_value_estimator( + value_type=ValueEstimators.TD0, gamma=cfg.optim.gamma ) - loss_module.make_value_estimator(value_type=ValueEstimators.TD0) # Define Target Network Updater target_net_updater = SoftUpdate(loss_module, eps=cfg.optim.target_update_polyak) diff --git a/test/test_cost.py b/test/test_cost.py index 5c1a7dbc41c..ef0d1acacde 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -113,6 +113,7 @@ QMixerLoss, SACLoss, TD3Loss, + TQCLoss, ) from torchrl.objectives.common import LossModule from torchrl.objectives.deprecated import DoubleREDQLoss_deprecated, REDQLoss_deprecated @@ -2440,6 +2441,594 @@ def test_td3_notensordict( assert loss_qvalue == loss_val_td["loss_qvalue"] +@pytest.mark.skipif( + not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" +) +class TestTQC(LossModuleTestBase): + seed = 0 + + def _create_mock_actor( + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + observation_key="observation", + ): + # Actor + action_spec = BoundedTensorSpec( + -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) + ) + net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + module = TensorDictModule( + net, in_keys=[observation_key], out_keys=["loc", "scale"] + ) + actor = ProbabilisticActor( + module=module, + distribution_class=TanhNormal, + in_keys=["loc", "scale"], + spec=action_spec, + return_log_prob=True, + ) + return actor.to(device) + + def _create_mock_value( + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + out_keys=None, + action_key="action", + observation_key="observation", + ): + # Actor + class ValueClass(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(obs_dim + action_dim, 1) + + def forward(self, obs, act): + return self.linear(torch.cat([obs, act], -1)) + + module = ValueClass() + value = ValueOperator( + module=module, + in_keys=[observation_key, action_key], + out_keys=out_keys, + ) + return value.to(device) + + def _create_mock_distributional_actor( + self, batch=2, obs_dim=3, action_dim=4, atoms=5, vmin=1, vmax=5 + ): + raise NotImplementedError + + def _create_mock_common_layer_setup( + self, n_obs=3, n_act=4, ncells=4, batch=2, n_hidden=2 + ): + common = MLP( + num_cells=ncells, + in_features=n_obs, + depth=3, + out_features=n_hidden, + ) + actor_net = MLP( + num_cells=ncells, + in_features=n_hidden, + depth=1, + out_features=2 * n_act, + ) + value = MLP( + in_features=n_hidden + n_act, + num_cells=ncells, + depth=1, + out_features=1, + ) + batch = [batch] + td = TensorDict( + { + "obs": torch.randn(*batch, n_obs), + "action": torch.randn(*batch, n_act), + "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), + "next": { + "obs": torch.randn(*batch, n_obs), + "reward": torch.randn(*batch, 1), + "done": torch.zeros(*batch, 1, 
dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), + }, + }, + batch, + ) + common = Mod(common, in_keys=["obs"], out_keys=["hidden"]) + actor = ProbSeq( + common, + Mod(actor_net, in_keys=["hidden"], out_keys=["param"]), + Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]), + ProbMod( + in_keys=["loc", "scale"], + out_keys=["action"], + distribution_class=TanhNormal, + return_log_prob=True, + ), + ) + value_head = Mod( + value, in_keys=["hidden", "action"], out_keys=["state_action_value"] + ) + value = Seq(common, value_head) + return actor, value, common, td + + def _create_mock_data_td3( + self, + batch=8, + obs_dim=3, + action_dim=4, + atoms=None, + device="cpu", + action_key="action", + observation_key="observation", + reward_key="reward", + done_key="done", + terminated_key="terminated", + ): + # create a tensordict + obs = torch.randn(batch, obs_dim, device=device) + next_obs = torch.randn(batch, obs_dim, device=device) + if atoms: + raise NotImplementedError + else: + action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) + reward = torch.randn(batch, 1, device=device) + done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) + td = TensorDict( + batch_size=(batch,), + source={ + observation_key: obs, + "next": { + observation_key: next_obs, + done_key: done, + terminated_key: terminated, + reward_key: reward, + }, + action_key: action, + }, + device=device, + ) + return td + + def _create_seq_mock_data_td3( + self, batch=8, T=4, obs_dim=3, action_dim=4, atoms=None, device="cpu" + ): + # create a tensordict + total_obs = torch.randn(batch, T + 1, obs_dim, device=device) + obs = total_obs[:, :T] + next_obs = total_obs[:, 1:] + if atoms: + action = torch.randn(batch, T, atoms, action_dim, device=device).clamp( + -1, 1 + ) + else: + action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) + reward = torch.randn(batch, T, 1, device=device) + done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + td = TensorDict( + batch_size=(batch, T), + source={ + "observation": obs * mask.to(obs.dtype), + "next": { + "observation": next_obs * mask.to(obs.dtype), + "reward": reward * mask.to(obs.dtype), + "done": done, + "terminated": terminated, + }, + "collector": {"mask": mask}, + "action": action * mask.to(obs.dtype), + }, + names=[None, "time"], + device=device, + ) + return td + + @pytest.mark.skipif(not _has_functorch, reason="functorch not installed") + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("td_est", [ValueEstimators.TD0, None]) + @pytest.mark.parametrize("use_action_spec", [True, False]) + def test_tqc( + self, + device, + td_est, + use_action_spec, + ): + torch.manual_seed(self.seed) + actor = self._create_mock_actor(device=device) + value = self._create_mock_value(device=device) + td = self._create_mock_data_td3(device=device) + if use_action_spec: + action_spec = actor.spec + else: + action_spec = None + loss_fn = TQCLoss( + actor, + value, + action_spec=action_spec, + ) + if td_est is not None: + loss_fn.make_value_estimator(td_est) + with _check_td_steady(td): + loss = loss_fn(td) + + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.critic_params.values(True, True) + ) + assert all( + (p.grad is None) or (p.grad == 
0).all()
+            for p in loss_fn.actor_params.values(True, True)
+        )
+        # check that losses are independent
+        for k in loss.keys():
+            if not k.startswith("loss"):
+                continue
+            loss[k].sum().backward(retain_graph=True)
+            if k == "loss_actor":
+                assert all(
+                    (p.grad is None) or (p.grad == 0).all()
+                    for p in loss_fn.critic_params.values(True, True)
+                )
+                assert not any(
+                    (p.grad is None) or (p.grad == 0).all()
+                    for p in loss_fn.actor_params.values(True, True)
+                )
+            elif k == "loss_critic":
+                assert all(
+                    (p.grad is None) or (p.grad == 0).all()
+                    for p in loss_fn.actor_params.values(True, True)
+                )
+                assert not any(
+                    (p.grad is None) or (p.grad == 0).all()
+                    for p in loss_fn.critic_params.values(True, True)
+                )
+            elif k == "loss_alpha":
+                assert all(
+                    (p.grad is None) or (p.grad == 0).all()
+                    for p in loss_fn.actor_params.values(True, True)
+                )
+                assert all(
+                    (p.grad is None) or (p.grad == 0).all()
+                    for p in loss_fn.critic_params.values(True, True)
+                )
+                assert loss_fn.log_alpha.grad.norm() > 0
+            else:
+                raise NotImplementedError(k)
+            loss_fn.zero_grad()
+
+        sum([item for _, item in loss.items()]).backward()
+        named_parameters = list(loss_fn.named_parameters())
+        named_buffers = list(loss_fn.named_buffers())
+
+        assert len({p for n, p in named_parameters}) == len(list(named_parameters))
+        assert len({p for n, p in named_buffers}) == len(list(named_buffers))
+
+        for name, p in named_parameters:
+            if not name.startswith("target_"):
+                assert (
+                    p.grad is not None and p.grad.norm() > 0.0
+                ), f"parameter {name} (shape: {p.shape}) has a null gradient"
+            else:
+                assert (
+                    p.grad is None or p.grad.norm() == 0.0
+                ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient"
+
+    @pytest.mark.skipif(not _has_functorch, reason="functorch not installed")
+    @pytest.mark.parametrize("device", get_default_devices())
+    @pytest.mark.parametrize("use_action_spec", [True, False])
+    def test_tqc_state_dict(
+        self,
+        device,
+        use_action_spec,
+    ):
+        torch.manual_seed(self.seed)
+        actor = self._create_mock_actor(device=device)
+        value = self._create_mock_value(device=device)
+        if use_action_spec:
+            action_spec = actor.spec
+        else:
+            action_spec = None
+        loss_fn = TQCLoss(
+            actor,
+            value,
+            action_spec=action_spec,
+        )
+        sd = loss_fn.state_dict()
+        loss_fn2 = TQCLoss(
+            actor,
+            value,
+            action_spec=action_spec,
+        )
+        loss_fn2.load_state_dict(sd)
+
+    @pytest.mark.skipif(not _has_functorch, reason="functorch not installed")
+    @pytest.mark.parametrize("device", get_default_devices())
+    @pytest.mark.parametrize("separate_losses", [False, True])
+    def test_tqc_separate_losses(
+        self,
+        device,
+        separate_losses,
+        n_act=4,
+    ):
+        torch.manual_seed(self.seed)
+        actor, value, common, td = self._create_mock_common_layer_setup(n_act=n_act)
+        loss_fn = TQCLoss(
+            actor,
+            value,
+            action_spec=BoundedTensorSpec(shape=(n_act,), low=-1, high=1),
+            loss_function="l2",
+            separate_losses=separate_losses,
+        )
+        with pytest.warns(UserWarning, match="No target network updater has been"):
+            loss = loss_fn(td)
+
+        assert all(
+            (p.grad is None) or (p.grad == 0).all()
+            for p in loss_fn.critic_params.values(True, True)
+        )
+        assert all(
+            (p.grad is None) or (p.grad == 0).all()
+            for p in loss_fn.actor_params.values(True, True)
+        )
+        # check that losses are independent
+        for k in loss.keys():
+            if not k.startswith("loss"):
+                continue
+            loss[k].sum().backward(retain_graph=True)
+            if k == "loss_actor":
+                assert all(
+                    (p.grad is None) or (p.grad == 0).all()
+                    for p in loss_fn.critic_params.values(True, True)
+                )
+                
assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_params.values(True, True) + ) + elif k == "loss_qvalue": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_params.values(True, True) + ) + if separate_losses: + common_layers_no = len(list(common.parameters())) + common_layers = itertools.islice( + loss_fn.qvalue_network_params.values(True, True), + common_layers_no, + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in common_layers + ) + critic_layers = itertools.islice( + loss_fn.critic_params.values(True, True), + common_layers_no, + None, + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in critic_layers + ) + else: + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.critic_params.values(True, True) + ) + + else: + raise NotImplementedError(k) + loss_fn.zero_grad() + + @pytest.mark.skipif(not _has_functorch, reason="functorch not installed") + @pytest.mark.parametrize("n", list(range(4))) + @pytest.mark.parametrize("device", get_default_devices()) + def test_tqc_batcher( + self, n, delay_actor, delay_qvalue, device, policy_noise, noise_clip, gamma=0.9 + ): + torch.manual_seed(self.seed) + actor = self._create_mock_actor(device=device) + value = self._create_mock_value(device=device) + td = self._create_seq_mock_data_td3(device=device) + loss_fn = TQCLoss( + actor, + value, + action_spec=actor.spec, + ) + + ms = MultiStep(gamma=gamma, n_steps=n).to(device) + + td_clone = td.clone() + ms_td = ms(td_clone) + + torch.manual_seed(0) + np.random.seed(0) + + with ( + pytest.warns(UserWarning, match="No target network updater has been") + if (delay_qvalue or delay_actor) + else contextlib.nullcontext() + ), _check_td_steady(ms_td): + loss_ms = loss_fn(ms_td) + assert loss_fn.tensor_keys.priority in ms_td.keys() + + with torch.no_grad(): + torch.manual_seed(0) # log-prob is computed with a random action + np.random.seed(0) + loss = loss_fn(td) + if n == 0: + assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) + _loss = sum([item for _, item in loss.items()]) + _loss_ms = sum([item for _, item in loss_ms.items()]) + assert ( + abs(_loss - _loss_ms) < 1e-3 + ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" + else: + with pytest.raises(AssertionError): + assert_allclose_td(loss, loss_ms) + + sum([item for _, item in loss_ms.items()]).backward() + named_parameters = loss_fn.named_parameters() + + for name, p in named_parameters: + if not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" + + # Check param update effect on targets + target_actor = loss_fn.target_actor_params.clone().values( + include_nested=True, leaves_only=True + ) + target_qvalue = loss_fn.target_critic_params.clone().values( + include_nested=True, leaves_only=True + ) + for p in loss_fn.parameters(): + if p.requires_grad: + p.data += torch.randn_like(p) + target_actor2 = loss_fn.target_actor_params.clone().values( + include_nested=True, leaves_only=True + ) + target_qvalue2 = loss_fn.target_critic_params.clone().values( + include_nested=True, leaves_only=True + ) + if loss_fn.delay_actor: + assert all((p1 == p2).all() for p1, p2 in zip(target_actor, target_actor2)) + else: + assert not any( + (p1 == p2).any() for p1, p2 in zip(target_actor, 
target_actor2) + ) + if loss_fn.delay_qvalue: + assert all( + (p1 == p2).all() for p1, p2 in zip(target_qvalue, target_qvalue2) + ) + else: + assert not any( + (p1 == p2).any() for p1, p2 in zip(target_qvalue, target_qvalue2) + ) + + # check that policy is updated after parameter update + actorp_set = set(actor.parameters()) + loss_fnp_set = set(loss_fn.parameters()) + assert len(actorp_set.intersection(loss_fnp_set)) == len(actorp_set) + parameters = [p.clone() for p in actor.parameters()] + for p in loss_fn.parameters(): + if p.requires_grad: + p.data += torch.randn_like(p) + assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) + + @pytest.mark.parametrize( + "td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda] + ) + def test_tqc_tensordict_keys(self, td_est): + actor = self._create_mock_actor() + value = self._create_mock_value() + loss_fn = TQCLoss( + actor, + value, + action_spec=actor.spec, + ) + + default_keys = { + "priority": "td_error", + "state_action_value": "state_action_value", + "action": "action", + "reward": "reward", + "done": "done", + "terminated": "terminated", + } + + self.tensordict_keys_test( + loss_fn, + default_keys=default_keys, + td_est=td_est, + ) + + value = self._create_mock_value(out_keys=["state_action_value_test"]) + loss_fn = TQCLoss( + actor, + value, + action_spec=actor.spec, + ) + key_mapping = { + "state_action_value": ("value", "state_action_value_test"), + "reward": ("reward", "reward_test"), + "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), + } + self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) + + @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) + @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) + @pytest.mark.parametrize("done_key", ["done", "done2"]) + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) + def test_tqc_notensordict( + self, observation_key, reward_key, done_key, terminated_key + ): + torch.manual_seed(self.seed) + actor = self._create_mock_actor(in_keys=[observation_key]) + qvalue = self._create_mock_value( + observation_key=observation_key, out_keys=["state_action_value"] + ) + td = self._create_mock_data_td3( + observation_key=observation_key, + reward_key=reward_key, + done_key=done_key, + terminated_key=terminated_key, + ) + loss = TD3Loss(actor, qvalue, action_spec=actor.spec) + loss.set_keys(reward=reward_key, done=done_key, terminated=terminated_key) + + kwargs = { + observation_key: td.get(observation_key), + f"next_{reward_key}": td.get(("next", reward_key)), + f"next_{done_key}": td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), + f"next_{observation_key}": td.get(("next", observation_key)), + "action": td.get("action"), + } + td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") + + with pytest.warns(UserWarning, match="No target network updater has been"): + torch.manual_seed(0) + loss_val_td = loss(td) + torch.manual_seed(0) + loss_val = loss(**kwargs) + for i in loss_val: + assert i in loss_val_td.values(), f"{i} not in {loss_val_td.values()}" + + for i, key in enumerate(loss.out_keys): + torch.testing.assert_close(loss_val_td.get(key), loss_val[i]) + + # test select + loss.select_out_keys("loss_actor", "loss_qvalue") + torch.manual_seed(0) + if torch.__version__ >= "2.0.0": + loss_actor, loss_qvalue = loss(**kwargs) + else: + with pytest.raises( + RuntimeError, + match="You are likely using 
tensordict.nn.dispatch with keyword arguments", + ): + loss_actor, loss_qvalue = loss(**kwargs) + return + + assert loss_actor == loss_val_td["loss_actor"] + assert loss_qvalue == loss_val_td["loss_qvalue"] + + @pytest.mark.skipif( not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" ) diff --git a/test/test_rb.py b/test/test_rb.py index 8e894f45c3e..0b465c0b424 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -38,7 +38,10 @@ ListStorage, TensorStorage, ) -from torchrl.data.replay_buffers.writers import RoundRobinWriter +from torchrl.data.replay_buffers.writers import ( + RoundRobinWriter, + TensorDictMaxValueWriter, +) from torchrl.envs.transforms.transforms import ( BinarizeReward, CatFrames, @@ -1209,6 +1212,65 @@ def test_load_state_dict(self, storage_in, storage_out, init_out): assert (s.exclude("index") == 1).all() +@pytest.mark.parametrize("size", [20, 25, 30]) +@pytest.mark.parametrize("batch_size", [1, 10, 15]) +@pytest.mark.parametrize("reward_ranges", [(0.25, 0.5, 1.0)]) +def test_max_value_writer(size, batch_size, reward_ranges): + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(size), + sampler=SamplerWithoutReplacement(), + batch_size=batch_size, + writer=TensorDictMaxValueWriter(rank_key="key"), + ) + + max_reward1, max_reward2, max_reward3 = reward_ranges + + td = TensorDict( + { + "key": torch.clamp_max(torch.rand(size), max=max_reward1), + "obs": torch.tensor(torch.rand(size)), + }, + batch_size=size, + device="cpu", + ) + rb.extend(td) + sample = rb.sample() + assert (sample.get("key") <= max_reward1).all() + assert (0 <= sample.get("key")).all() + assert len(sample.get("index").unique()) == len(sample.get("index")) + + td = TensorDict( + { + "key": torch.clamp(torch.rand(size), min=max_reward1, max=max_reward2), + "obs": torch.tensor(torch.rand(size)), + }, + batch_size=size, + device="cpu", + ) + rb.extend(td) + sample = rb.sample() + assert (sample.get("key") <= max_reward2).all() + assert (max_reward1 <= sample.get("key")).all() + assert len(sample.get("index").unique()) == len(sample.get("index")) + + td = TensorDict( + { + "key": torch.clamp(torch.rand(size), min=max_reward2, max=max_reward3), + "obs": torch.tensor(torch.rand(size)), + }, + batch_size=size, + device="cpu", + ) + + for sample in td: + rb.add(sample) + + sample = rb.sample() + assert (sample.get("key") <= max_reward3).all() + assert (max_reward2 <= sample.get("key")).all() + assert len(sample.get("index").unique()) == len(sample.get("index")) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index 4c90146ac7f..9a12749b482 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -14,6 +14,7 @@ ReplayBuffer, RoundRobinWriter, Storage, + TensorDictMaxValueWriter, TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer, TensorDictRoundRobinWriter, diff --git a/torchrl/data/replay_buffers/__init__.py b/torchrl/data/replay_buffers/__init__.py index e27dd8572d8..6be80e26c1f 100644 --- a/torchrl/data/replay_buffers/__init__.py +++ b/torchrl/data/replay_buffers/__init__.py @@ -23,4 +23,9 @@ Storage, TensorStorage, ) -from .writers import RoundRobinWriter, TensorDictRoundRobinWriter, Writer +from .writers import ( + RoundRobinWriter, + TensorDictMaxValueWriter, + TensorDictRoundRobinWriter, + Writer, +) diff --git a/torchrl/data/replay_buffers/replay_buffers.py 
b/torchrl/data/replay_buffers/replay_buffers.py index 5d21d202eae..cfc6c90bb2c 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -718,12 +718,13 @@ def add(self, data: TensorDictBase) -> int: data_add = data index = super()._add(data_add) - if is_tensor_collection(data_add): - data_add.set("index", index) + if index is not None: + if is_tensor_collection(data_add): + data_add.set("index", index) - # priority = self._get_priority(data) - # if priority: - self.update_tensordict_priority(data_add) + # priority = self._get_priority(data) + # if priority: + self.update_tensordict_priority(data_add) return index def extend(self, tensordicts: TensorDictBase) -> torch.Tensor: diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py index 49244262f4e..8a71c5927a1 100644 --- a/torchrl/data/replay_buffers/writers.py +++ b/torchrl/data/replay_buffers/writers.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import heapq from abc import ABC, abstractmethod from typing import Any, Dict, Sequence @@ -92,3 +93,128 @@ def extend(self, data: Sequence) -> torch.Tensor: data["index"] = index self._storage[index] = data return index + + +class TensorDictMaxValueWriter(Writer): + """A Writer class for composable replay buffers that keeps the top elements based on some ranking key. + + If rank_key is not provided, the key will be ``("next", "reward")``. + + Examples: + >>> import torch + >>> from tensordict import TensorDict + >>> from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer, TensorDictMaxValueWriter + >>> from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement + >>> rb = TensorDictReplayBuffer( + ... storage=LazyTensorStorage(1), + ... sampler=SamplerWithoutReplacement(), + ... batch_size=1, + ... writer=TensorDictMaxValueWriter(rank_key="key"), + ... ) + >>> td = TensorDict({ + ... "key": torch.tensor(range(10)), + ... "obs": torch.tensor(range(10)) + ... }, batch_size=10) + >>> rb.extend(td) + >>> print(rb.sample().get("obs").item()) + 9 + >>> td = TensorDict({ + ... "key": torch.tensor(range(10, 20)), + ... "obs": torch.tensor(range(10, 20)) + ... }, batch_size=10) + >>> rb.extend(td) + >>> print(rb.sample().get("obs").item()) + 19 + >>> td = TensorDict({ + ... "key": torch.tensor(range(10)), + ... "obs": torch.tensor(range(10)) + ... }, batch_size=10) + >>> rb.extend(td) + >>> print(rb.sample().get("obs").item()) + 19 + """ + + def __init__(self, rank_key=None, **kwargs) -> None: + super().__init__(**kwargs) + self._cursor = 0 + self._current_top_values = [] + self._rank_key = rank_key + if self._rank_key is None: + self._rank_key = ("next", "reward") + + def get_insert_index(self, data: Any) -> int: + """Returns the index where the data should be inserted, or ``None`` if it should not be inserted.""" + if data.batch_dims > 1: + raise RuntimeError( + "Expected input tensordict to have no more than 1 dimension, got" + f"tensordict.batch_size = {data.batch_size}" + ) + + ret = None + rank_data = data.get(("_data", self._rank_key)) + + # If time dimension, sum along it. 
+ rank_data = rank_data.sum(-1).item() + + if rank_data is None: + raise KeyError(f"Rank key {self._rank_key} not found in data.") + + # If the buffer is not full, add the data + if len(self._current_top_values) < self._storage.max_size: + + ret = self._cursor + self._cursor = (self._cursor + 1) % self._storage.max_size + + # Add new reward to the heap + heapq.heappush(self._current_top_values, (rank_data, ret)) + + # If the buffer is full, check if the new data is better than the worst data in the buffer + elif rank_data > self._current_top_values[0][0]: + + # retrieve position of the smallest value + min_sample = heapq.heappop(self._current_top_values) + ret = min_sample[1] + + # Add new reward to the heap + heapq.heappush(self._current_top_values, (rank_data, ret)) + + return ret + + def add(self, data: Any) -> int: + """Inserts a single element of data at an appropriate index, and returns that index. + + The data passed to this module should be structured as :obj:`[]` or :obj:`[T]` where + :obj:`T` the time dimension. If the data is a trajectory, the rank key will be summed + over the time dimension. + """ + index = self.get_insert_index(data) + if index is not None: + data.set("index", index) + self._storage[index] = data + return index + + def extend(self, data: Sequence) -> None: + """Inserts a series of data points at appropriate indices. + + The data passed to this module should be structured as :obj:`[B]` or :obj:`[B, T]` where :obj:`B` is + the batch size, :obj:`T` the time dimension. If the data is a trajectory, the rank key will be summed over the + time dimension. + """ + data_to_replace = {} + for i, sample in enumerate(data): + index = self.get_insert_index(sample) + if index is not None: + data_to_replace[index] = i + + # Replace the data in the storage all at once + keys, values = zip(*data_to_replace.items()) + if len(keys) > 0: + index = data.get("index") + values = list(values) + keys = index[values] = torch.tensor(keys, dtype=index.dtype) + data.set("index", index) + self._storage[keys] = data[values] + + def _empty(self) -> None: + self._cursor = 0 + self._current_top_values = [] diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index 023b22ba3c4..88cdf397df1 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -17,6 +17,7 @@ from .reinforce import ReinforceLoss from .sac import DiscreteSACLoss, SACLoss from .td3 import TD3Loss +from .tqc import TQCLoss from .utils import ( default_value_kwargs, distance_loss, diff --git a/torchrl/objectives/tqc.py b/torchrl/objectives/tqc.py new file mode 100644 index 00000000000..3bbebd73d82 --- /dev/null +++ b/torchrl/objectives/tqc.py @@ -0,0 +1,278 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import warnings +from dataclasses import dataclass + +import numpy as np +import torch + +from tensordict import TensorDict, TensorDictBase +from tensordict.nn import InteractionType, set_interaction_type, TensorDictModule +from tensordict.utils import NestedKey +from torchrl.data import CompositeSpec +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.objectives.common import LossModule +from torchrl.objectives.utils import ValueEstimators + + +class TQCLoss(LossModule): + @dataclass + class _AcceptedKeys: + """Maintains default values for all configurable tensordict keys. 
+
+        This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
+        default values.
+
+        Attributes:
+            action (NestedKey): The input tensordict key where the action is expected.
+                Defaults to ``"action"``.
+            value (NestedKey): The input tensordict key where the state value is expected.
+                Will be used for the underlying value estimator. Defaults to ``"state_value"``.
+            state_action_value (NestedKey): The input tensordict key where the
+                state action value is expected. Defaults to ``"state_action_value"``.
+            log_prob (NestedKey): The input tensordict key where the log probability is expected.
+                Defaults to ``"sample_log_prob"``.
+            priority (NestedKey): The input tensordict key where the target priority is written to.
+                Defaults to ``"td_error"``.
+            reward (NestedKey): The input tensordict key where the reward is expected.
+                Will be used for the underlying value estimator. Defaults to ``"reward"``.
+            done (NestedKey): The key in the input TensorDict that indicates
+                whether a trajectory is done. Will be used for the underlying value estimator.
+                Defaults to ``"done"``.
+            terminated (NestedKey): The key in the input TensorDict that indicates
+                whether a trajectory is terminated. Will be used for the underlying value estimator.
+                Defaults to ``"terminated"``.
+        """
+
+        action: NestedKey = "action"
+        value: NestedKey = "state_value"
+        state_action_value: NestedKey = "state_action_value"
+        log_prob: NestedKey = "sample_log_prob"
+        priority: NestedKey = "td_error"
+        reward: NestedKey = "reward"
+        done: NestedKey = "done"
+        terminated: NestedKey = "terminated"
+
+    default_keys = _AcceptedKeys()
+    default_value_estimator = ValueEstimators.TD0
+
+    def __init__(
+        self,
+        actor_network: TensorDictModule,
+        qvalue_network: TensorDictModule,
+        top_quantiles_to_drop: int = 10,
+        alpha_init: float = 1.0,
+        # no need to pass device, should be handled by actor/qvalue nets
+        # device: torch.device,
+        # gamma should be passed to the value estimator construction
+        # for consistency with other losses
+        # gamma: float=None,
+        target_entropy=None,
+        action_spec=None,
+    ):
+        super().__init__()
+
+        self.convert_to_functional(
+            actor_network,
+            "actor",
+            create_target_params=False,
+            funs_to_decorate=["forward", "get_dist"],
+        )
+
+        self.convert_to_functional(
+            qvalue_network,
+            "critic",
+            create_target_params=True,  # Create a target critic network
+        )
+
+        # self.device = device
+        for p in self.parameters():
+            device = p.device
+            break
+        else:
+            # this should never be reached unless both networks have 0 parameters
+            raise RuntimeError
+        self.log_alpha = torch.nn.Parameter(
+            torch.tensor([np.log(alpha_init)], requires_grad=True, device=device)
+        )
+        self.top_quantiles_to_drop = top_quantiles_to_drop
+        self.target_entropy = target_entropy
+        self._action_spec = action_spec
+        self.make_value_estimator()
+
+    @property
+    def target_entropy(self):
+        target_entropy = self.__dict__.get("_target_entropy", None)
+        if target_entropy is None:
+            # Compute target entropy
+            action_spec = self._action_spec
+            if action_spec is None:
+                action_spec = getattr(self.actor, "spec", None)
+            if action_spec is None:
+                raise RuntimeError(
+                    "Could not deduce the action spec from either "
+                    "the actor network or the constructor kwargs. "
+                    "Please provide the target entropy during construction."
+ ) + if not isinstance(action_spec, CompositeSpec): + action_spec = CompositeSpec({self.tensor_keys.action: action_spec}) + action_container_len = len(action_spec.shape) + + target_entropy = -float( + action_spec[self.tensor_keys.action] + .shape[action_container_len:] + .numel() + ) + self.target_entropy = target_entropy + return target_entropy + + @target_entropy.setter + def target_entropy(self, value): + if value is not None: + value = float(value) + self._target_entropy = value + + @property + def alpha(self): + return self.log_alpha.exp().detach() + + def value_loss(self, tensordict): + tensordict_copy = tensordict.clone(False) + td_next = tensordict_copy.get("next") + reward = td_next.get(self.tensor_keys.reward) + not_done = td_next.get(self.tensor_keys.done).logical_not() + alpha = self.alpha + + # Q-loss + with torch.no_grad(): + # get policy action + self.actor(td_next, params=self.actor_params) + self.critic(td_next, params=self.target_critic_params) + next_log_pi = td_next.get(self.tensor_keys.log_prob) + next_log_pi = torch.unsqueeze(next_log_pi, dim=-1) + + # compute and cut quantiles at the next state + next_z = td_next.get(self.tensor_keys.state_action_value) + sorted_z, _ = torch.sort(next_z.reshape(*tensordict_copy.batch_size, -1)) + sorted_z_part = sorted_z[..., : -self.top_quantiles_to_drop] + + # compute target + # --- Note --- + # This is computed manually here, since the built-in value estimators in the library + # currently do not support a critic of a shape different from the reward. + # ------------ + target = reward + not_done * self.gamma * ( + sorted_z_part - alpha * next_log_pi + ) + + self.critic(tensordict_copy, params=self.critic_params) + cur_z = tensordict_copy.get(self.tensor_keys.state_action_value) + critic_loss = quantile_huber_loss_f(cur_z, target) + metadata = {} + return critic_loss, metadata + + def actor_loss(self, tensordict): + tensordict_copy = tensordict.clone(False) + alpha = self.alpha + self.actor(tensordict_copy, params=self.actor_params) + self.critic(tensordict_copy, params=self.critic_params) + new_log_pi = tensordict_copy.get(self.tensor_keys.log_prob) + tensordict.set(self.tensor_keys.log_prob, new_log_pi) + actor_loss = ( + alpha * new_log_pi + - tensordict_copy.get(self.tensor_keys.state_action_value) + .mean(-1) + .mean(-1, keepdim=True) + ).mean() + metadata = {} + return actor_loss, metadata + + def alpha_loss(self, tensordict): + log_prob = tensordict.get(self.tensor_keys.log_prob) + alpha_loss = -self.log_alpha * (log_prob + self.target_entropy).detach().mean() + return alpha_loss, {} + + def entropy(self, tensordict): + with set_exploration_type(ExplorationType.RANDOM): + dist = self.actor.get_dist( + tensordict, + params=self.actor_params, + ) + a_reparm = dist.rsample() + log_prob = dist.log_prob(a_reparm).detach() + entropy = -log_prob.mean() + return entropy + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + critic_loss, metadata_value = self.value_loss(tensordict) + actor_loss, metadata_actor = self.actor_loss( + tensordict + ) # Compute actor loss AFTER critic loss + alpha_loss, metadata_alpha = self.alpha_loss(tensordict) + metadata = { + "alpha": self.alpha, + "entropy": self.entropy(tensordict), + } + metadata.update(metadata_alpha) + metadata.update(metadata_value) + metadata.update(metadata_actor) + losses = { + "loss_critic": critic_loss, + "loss_actor": actor_loss, + "loss_alpha": alpha_loss, + } + losses.update(metadata) + return TensorDict(losses, batch_size=[]) + + def 
make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
+        """Value estimator setter for TQC.
+
+        The only value estimator supported is ``ValueEstimators.TD0``.
+
+        This method can also be used to set the ``gamma`` factor.
+
+        Args:
+            value_type (ValueEstimators, optional): the value estimator to be used.
+                Will raise an exception if it differs from ``ValueEstimators.TD0``.
+            gamma (float, optional): the gamma factor for the target computation.
+                Defaults to 0.99.
+        """
+        if value_type not in (ValueEstimators.TD0, None):
+            raise NotImplementedError(
+                f"Value type {value_type} is not currently implemented."
+            )
+        self.gamma = hyperparams.pop("gamma", 0.99)
+
+
+# ====================================================================
+# Quantile Huber Loss
+# -------------------
+
+
+def quantile_huber_loss_f(quantiles, samples):
+    """Quantile Huber loss from the original PyTorch TQC implementation.
+
+    See: https://github.com/SamsungLabs/tqc_pytorch/blob/master/tqc/functions.py
+
+    ``quantiles`` is assumed to be of shape [batch size, n_nets, n_quantiles],
+    ``samples`` is assumed to be of shape [batch size, n_samples].
+    Arbitrary batch sizes are allowed.
+    """
+    pairwise_delta = (
+        samples[..., None, None, :] - quantiles[..., None]
+    )  # batch x n_nets x n_quantiles x n_samples
+    abs_pairwise_delta = torch.abs(pairwise_delta)
+    huber_loss = torch.where(
+        abs_pairwise_delta > 1, abs_pairwise_delta - 0.5, pairwise_delta**2 * 0.5
+    )
+    n_quantiles = quantiles.shape[-1]
+    tau = (
+        torch.arange(n_quantiles, device=quantiles.device).float() / n_quantiles
+        + 1 / 2 / n_quantiles
+    )
+    loss = (
+        torch.abs(tau[..., None, :, None] - (pairwise_delta < 0).float()) * huber_loss
+    ).mean()
+    return loss
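
Usage sketch (not part of the patch): the snippet below shows one way the new TQCLoss is meant to be wired together, mirroring the constructor arguments and the make_value_estimator() call used in examples/tqc/utils.py and in TestTQC above. The toy network sizes, the ToyCritic module and the hyperparameter values are illustrative placeholders rather than APIs introduced by this diff; the only assumptions carried over from the patch are the TQCLoss interface, a critic that writes a [batch, n_nets, n_quantiles] "state_action_value" tensor, and an actor that writes "sample_log_prob".

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn
from torchrl.data import BoundedTensorSpec
from torchrl.modules import ProbabilisticActor, ValueOperator
from torchrl.modules.distributions import TanhNormal
from torchrl.objectives import SoftUpdate, TQCLoss
from torchrl.objectives.utils import ValueEstimators

obs_dim, action_dim, n_nets, n_quantiles, batch = 3, 4, 5, 25, 8

# Actor: writes "loc"/"scale" and, via return_log_prob=True, the "sample_log_prob"
# entry that TQCLoss reads in its actor and alpha losses.
actor_net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
actor = ProbabilisticActor(
    TensorDictModule(actor_net, in_keys=["observation"], out_keys=["loc", "scale"]),
    in_keys=["loc", "scale"],
    spec=BoundedTensorSpec(-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)),
    distribution_class=TanhNormal,
    return_log_prob=True,
)


# Toy critic: returns a [*batch, n_nets, n_quantiles] quantile tensor, the shape
# TQCLoss expects under "state_action_value" (see the MLP ensemble in examples/tqc/utils.py).
class ToyCritic(nn.Module):
    def __init__(self):
        super().__init__()
        self.nets = nn.ModuleList(
            nn.Linear(obs_dim + action_dim, n_quantiles) for _ in range(n_nets)
        )

    def forward(self, obs, act):
        x = torch.cat([obs, act], -1)
        return torch.stack([net(x) for net in self.nets], dim=-2)


qvalue = ValueOperator(ToyCritic(), in_keys=["observation", "action"])

loss_module = TQCLoss(
    actor_network=actor,
    qvalue_network=qvalue,
    top_quantiles_to_drop=2 * n_nets,
    alpha_init=1.0,
)
loss_module.make_value_estimator(value_type=ValueEstimators.TD0, gamma=0.99)
target_net_updater = SoftUpdate(loss_module, eps=0.995)

# A fake transition batch with the keys value_loss()/actor_loss() read.
data = TensorDict(
    {
        "observation": torch.randn(batch, obs_dim),
        "action": torch.rand(batch, action_dim) * 2 - 1,
        "next": {
            "observation": torch.randn(batch, obs_dim),
            "reward": torch.randn(batch, 1),
            "done": torch.zeros(batch, 1, dtype=torch.bool),
            "terminated": torch.zeros(batch, 1, dtype=torch.bool),
        },
    },
    batch_size=[batch],
)
losses = loss_module(data)
(losses["loss_actor"] + losses["loss_critic"] + losses["loss_alpha"]).backward()
target_net_updater.step()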