diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst
index 607be49211a..89b3e5ffbf6 100644
--- a/docs/source/reference/data.rst
+++ b/docs/source/reference/data.rst
@@ -585,6 +585,51 @@ should have a considerably lower memory footprint than observations, for instanc
 This format eliminates any ambiguity regarding the matching of an observation with its action, info, or done state.
 
+A note on singleton dimensions in TED
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. _reward_done_singleton:
+
+In TorchRL, the standard practice is that `done` states (including terminated and truncated) and rewards should have a
+dimension that can be expanded to match the shape of observations, states, and actions without resorting to anything
+other than repetition (i.e., the reward must have as many dimensions as the observation and/or action, or their
+embeddings).
+
+Essentially, this format is acceptable (though not strictly enforced):
+
+    >>> print(rollout[t])
+    ... TensorDict(
+    ...     fields={
+    ...         action: Tensor(n_action),
+    ...         done: Tensor(1),  # The done state has a rightmost singleton dimension
+    ...         next: TensorDict(
+    ...             fields={
+    ...                 done: Tensor(1),
+    ...                 observation: Tensor(n_obs),
+    ...                 reward: Tensor(1),  # The reward has a rightmost singleton dimension
+    ...                 terminated: Tensor(1),
+    ...                 truncated: Tensor(1),
+    ...             batch_size=torch.Size([]),
+    ...             device=cpu,
+    ...             is_shared=False),
+    ...         observation: Tensor(n_obs),  # the observation at reset
+    ...         terminated: Tensor(1),  # the terminated at reset
+    ...         truncated: Tensor(1),  # the truncated at reset
+    ...     batch_size=torch.Size([]),
+    ...     device=cpu,
+    ...     is_shared=False)
+
+The rationale behind this is to ensure that the results of operations (such as value estimation) on observations and/or
+actions have the same number of dimensions as the reward and `done` state. This consistency allows subsequent operations
+to proceed without issues:
+
+    >>> state_value = f(observation)
+    >>> next_state_value = state_value + reward
+
+Without this singleton dimension at the end of the reward, broadcasting rules (which only work when tensors can be
+expanded from the left) would try to expand the reward on the left. This could lead to failures (at best) or introduce
+bugs (at worst).
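+
+For instance, with plain tensors (arbitrary sizes, shown for illustration only), dropping that trailing singleton
+silently changes the meaning of an elementwise sum:
+
+    >>> import torch
+    >>> state_value = torch.randn(5, 1)    # one value per step, shape [T, 1]
+    >>> reward = torch.randn(5, 1)         # TED-compliant reward, shape [T, 1]
+    >>> (state_value + reward).shape
+    torch.Size([5, 1])
+    >>> reward_flat = reward.squeeze(-1)   # shape [T]: the trailing singleton is gone
+    >>> (state_value + reward_flat).shape  # broadcasts to [T, T] instead of erroring out
+    torch.Size([5, 5])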
+
 Flattening TED to reduce memory consumption
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst
index 9e7df1bff8f..209380d6838 100644
--- a/docs/source/reference/objectives.rst
+++ b/docs/source/reference/objectives.rst
@@ -151,7 +151,7 @@ REDQ
     REDQLoss
 
 CrossQ
-----
+------
 
 .. autosummary::
     :toctree: generated/
@@ -160,7 +160,7 @@ CrossQ
     CrossQLoss
 
 IQL
-----
+---
 
 .. autosummary::
     :toctree: generated/
@@ -170,7 +170,7 @@ IQL
     IQLLoss
     DiscreteIQLLoss
 
 CQL
-----
+---
 
 .. autosummary::
     :toctree: generated/
@@ -189,7 +189,7 @@ GAIL
     GAILLoss
 
 DT
-----
+--
 
 .. autosummary::
     :toctree: generated/
@@ -199,7 +199,7 @@ DT
     OnlineDTLoss
 
 TD3
-----
+---
 
 .. autosummary::
     :toctree: generated/
@@ -208,7 +208,7 @@ TD3
     TD3Loss
 
 TD3+BC
-----
+------
 
 .. autosummary::
     :toctree: generated/
@@ -227,6 +227,85 @@ PPO
     ClipPPOLoss
     KLPENPPOLoss
 
+Using PPO with multi-head action policies
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+In some cases, we have a single advantage value but more than one action undertaken. Each action has its own
+log-probability and shape. For instance, the action space may be structured as follows:
+
+    >>> action_td = TensorDict(
+    ...     action0=Tensor(batch, n_agents, f0),
+    ...     action1=Tensor(batch, n_agents, f1, f2),
+    ...     batch_size=torch.Size((batch,))
+    ... )
+
+where `f0`, `f1` and `f2` are some arbitrary integers.
+
+Note that, in TorchRL, the root tensordict has the shape of the environment (if the environment is batch-locked), or
+otherwise the shape of the number of batched environments being run. If the tensordict is sampled from the buffer, it
+will have the shape of the replay buffer's `batch_size` instead. The `n_agents` dimension, although common to every
+action, does not in general appear in the tensordict's batch-size.
+
+There is a legitimate reason for this: the number of agents may condition some, but not all, of the environment's
+specs. For example, some environments have a done state shared among all agents. A more complete tensordict would in
+this case look like
+
+    >>> action_td = TensorDict(
+    ...     action0=Tensor(batch, n_agents, f0),
+    ...     action1=Tensor(batch, n_agents, f1, f2),
+    ...     done=Tensor(batch, 1),
+    ...     observation=Tensor(batch, n_agents, f3),
+    ...     [...]  # etc
+    ...     batch_size=torch.Size((batch,))
+    ... )
+
+Notice that `done` states and `reward` usually carry a rightmost singleton dimension. See this
+:ref:`part of the doc <reward_done_singleton>` to learn more about this restriction.
+
+The main tools to consider when building multi-head policies are: :class:`~tensordict.nn.CompositeDistribution`,
+:class:`~tensordict.nn.ProbabilisticTensorDictModule` and :class:`~tensordict.nn.ProbabilisticTensorDictSequential`.
+When dealing with these, it is recommended to call `tensordict.nn.set_composite_lp_aggregate(False).set()` at the
+beginning of the script to instruct :class:`~tensordict.nn.CompositeDistribution` that log-probabilities should not
+be aggregated but rather written as leaves in the tensordict.
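+
+As an illustration, a minimal sketch of such a policy head could look like the following (the key names, sizes and
+distribution choices below are placeholders rather than API requirements):
+
+    >>> import torch
+    >>> from tensordict import TensorDict
+    >>> from tensordict.nn import (
+    ...     CompositeDistribution,
+    ...     ProbabilisticTensorDictModule,
+    ...     set_composite_lp_aggregate,
+    ... )
+    >>> from torch import distributions as d
+    >>> set_composite_lp_aggregate(False).set()
+    >>> batch, n_agents, f0, f1 = 8, 4, 6, 5  # arbitrary sizes
+    >>> # a network would normally produce these parameters; here they are hand-crafted
+    >>> params = TensorDict(
+    ...     action0=TensorDict(logits=torch.randn(batch, n_agents, f0), batch_size=(batch,)),
+    ...     action1=TensorDict(concentration=torch.ones(batch, n_agents, f1), batch_size=(batch,)),
+    ...     batch_size=(batch,),
+    ... )
+    >>> data = TensorDict(params=params, batch_size=(batch,))
+    >>> head = ProbabilisticTensorDictModule(
+    ...     in_keys=["params"],
+    ...     out_keys=["action0", "action1"],
+    ...     distribution_class=CompositeDistribution,
+    ...     distribution_kwargs={
+    ...         "distribution_map": {"action0": d.Categorical, "action1": d.Dirichlet}
+    ...     },
+    ...     return_log_prob=True,
+    ... )
+    >>> out = head(data)
+    >>> print(out["action0"].shape, out["action1"].shape)  # sampled actions
+    torch.Size([8, 4]) torch.Size([8, 4, 5])
+    >>> print(out["action0_log_prob"].shape, out["action1_log_prob"].shape)
+    torch.Size([8, 4]) torch.Size([8, 4])
+
+In a full setup, a :class:`~tensordict.nn.TensorDictModule` backbone would write `params` from the observations, and
+the two modules would be composed in a :class:`~tensordict.nn.ProbabilisticTensorDictSequential`.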
+
+The log-probability of our actions, given their respective distributions, may look like
+
+    >>> action_td = TensorDict(
+    ...     action0_log_prob=Tensor(batch, n_agents),
+    ...     action1_log_prob=Tensor(batch, n_agents, f1),
+    ...     batch_size=torch.Size((batch,))
+    ... )
+
+or
+
+    >>> action_td = TensorDict(
+    ...     action0_log_prob=Tensor(batch, n_agents),
+    ...     action1_log_prob=Tensor(batch, n_agents),
+    ...     batch_size=torch.Size((batch,))
+    ... )
+
+i.e., the number of dimensions of a distribution's log-probability generally ranges from the sample's dimensionality
+down to anything below it, e.g., when the distribution is multivariate (:class:`~torch.distributions.Dirichlet` for
+instance) or wrapped in an :class:`~torch.distributions.Independent` instance.
+The dimension of the tensordict, on the contrary, still matches the env's / replay-buffer's batch-size.
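+
+A short, self-contained illustration of this variability (arbitrary sizes, standalone `torch.distributions` objects
+rather than a TorchRL policy):
+
+    >>> import torch
+    >>> from torch.distributions import Dirichlet, Independent
+    >>> batch, n_agents, f1, f2 = 8, 4, 10, 2
+    >>> dist = Dirichlet(torch.ones(batch, n_agents, f1, f2))
+    >>> dist.log_prob(dist.sample()).shape   # the rightmost (simplex) dim is reduced
+    torch.Size([8, 4, 10])
+    >>> dist = Independent(dist, 1)          # fold the f1 dim into the event as well
+    >>> dist.log_prob(dist.sample()).shape
+    torch.Size([8, 4])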
+
+During a call to the PPO loss, the loss module will schematically execute the following set of operations:
+
+    >>> def ppo(tensordict):
+    ...     prev_log_prob = tensordict.select(*log_prob_keys)
+    ...     action = tensordict.select(*action_keys)
+    ...     new_log_prob = dist.log_prob(action)
+    ...     log_weight = new_log_prob - prev_log_prob
+    ...     advantage = tensordict.get("advantage")  # computed by GAE earlier
+    ...     # attempt to map shape
+    ...     log_weight.batch_size = advantage.batch_size[:-1]
+    ...     log_weight = sum(log_weight.sum(dim="feature").values(True, True))  # get a single tensor of log_weights
+    ...     return minimum(log_weight.exp() * advantage, log_weight.exp().clamp(1-eps, 1+eps) * advantage)
+
+For a complete example of a PPO pipeline with a multi-head policy, see the library's
+`example directory `__.
+
+
 A2C
 ---
 
@@ -258,6 +337,7 @@ Dreamer
 
 Multi-agent objectives
 -----------------------
+
 .. currentmodule:: torchrl.objectives.multiagent
 
 These objectives are specific to multi-agent algorithms.
@@ -305,6 +385,7 @@ Returns
 
 Utils
 -----
+
 .. currentmodule:: torchrl.objectives
 
 .. autosummary::
diff --git a/examples/agents/composite_ppo.py b/examples/agents/composite_ppo.py
index 501dceb651d..a260c457ae8 100644
--- a/examples/agents/composite_ppo.py
+++ b/examples/agents/composite_ppo.py
@@ -6,6 +6,7 @@
 """
 Multi-head Agent and PPO Loss
 =============================
+
 This example demonstrates how to use TorchRL to create a multi-head agent with three separate
 distributions (Gamma, Kumaraswamy, and Mixture) and
 train it using Proximal Policy Optimization (PPO) losses.
diff --git a/test/test_cost.py b/test/test_cost.py
index d73edcc85be..daa026b0ba9 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -12,6 +12,7 @@
 import warnings
 from copy import deepcopy
 from dataclasses import asdict, dataclass
+from typing import Optional
 
 import numpy as np
 import pytest
@@ -43,6 +44,7 @@
 from torchrl._utils import _standardize
 from torchrl.data import Bounded, Categorical, Composite, MultiOneHot, OneHot, Unbounded
 from torchrl.data.postprocs.postprocs import MultiStep
+from torchrl.envs import EnvBase
 from torchrl.envs.model_based.dreamer import DreamerEnv
 from torchrl.envs.transforms import TensorDictPrimer, TransformedEnv
 from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type
@@ -199,6 +201,70 @@ def get_devices():
     return devices
 
 
+class MARLEnv(EnvBase):
+    def __init__(self):
+        batch = self.batch = (3,)
+        super().__init__(batch_size=batch)
+        self.n_agents = n_agents = (4,)
+        self.obs_feat = obs_feat = (5,)
+
+        self.full_observation_spec = Composite(
+            observation=Unbounded(batch + n_agents + obs_feat),
+            batch_size=batch,
+        )
+        self.full_done_spec = Composite(
+            done=Unbounded(batch + (1,), dtype=torch.bool),
+            terminated=Unbounded(batch + (1,), dtype=torch.bool),
+            truncated=Unbounded(batch + (1,), dtype=torch.bool),
+            batch_size=batch,
+        )
+
+        self.act_feat_dirich = act_feat_dirich = (
+            10,
+            2,
+        )
+        self.act_feat_categ = act_feat_categ = (7,)
+        self.full_action_spec = Composite(
+            dirich=Unbounded(batch + n_agents + act_feat_dirich),
+            categ=Unbounded(batch + n_agents + act_feat_categ),
+            batch_size=batch,
+        )
+
+        self.full_reward_spec = Composite(
+            reward=Unbounded(batch + n_agents + (1,)), batch_size=batch
+        )
+
+    @classmethod
+    def make_composite_dist(cls):
+        dist_cstr = functools.partial(
+            CompositeDistribution,
+            distribution_map={
+                "dirich": lambda concentration: torch.distributions.Independent(
+                    torch.distributions.Dirichlet(concentration), 1
+                ),
+                "categ": torch.distributions.Categorical,
+            },
+        )
+        return ProbabilisticTensorDictModule(
+            in_keys=["params"],
+            out_keys=["dirich", "categ"],
+            distribution_class=dist_cstr,
+            return_log_prob=True,
+        )
+
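+    # Only the specs and fake data of this env are exercised by the tests; transitions are
+    # never collected, so the methods below are intentionally left unimplemented.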
+    def _step(
+        self,
+        tensordict: TensorDictBase,
+    ) -> TensorDictBase:
+        ...
+
+    def _reset(self, tensordict):
+        ...
+
+    def _set_seed(self, seed: Optional[int]):
+        ...
+
+
 class LossModuleTestBase:
     @pytest.fixture(scope="class", autouse=True)
     def _composite_log_prob(self):
@@ -9238,6 +9304,40 @@ def mixture_constructor(logits, loc, scale):
         loss = ppo(data)
         loss.sum(reduce=True)
 
+    def test_ppo_marl_aggregate(self):
+        env = MARLEnv()
+
+        def primer(td):
+            params = TensorDict(
+                dirich=TensorDict(concentration=env.action_spec["dirich"].one()),
+                categ=TensorDict(logits=env.action_spec["categ"].one()),
+                batch_size=td.batch_size,
+            )
+            td.set("params", params)
+            return td
+
+        policy = ProbabilisticTensorDictSequential(
+            primer,
+            env.make_composite_dist(),
+            # return_composite=True,
+        )
+        output = policy(env.fake_tensordict())
+        assert output.shape == env.batch_size
+        assert output["dirich_log_prob"].shape == env.batch_size + env.n_agents
+        assert output["categ_log_prob"].shape == env.batch_size + env.n_agents
+
+        output["advantage"] = output["next", "reward"].clone()
+        output["value_target"] = output["next", "reward"].clone()
+        critic = TensorDictModule(
+            lambda obs: obs.new_zeros((*obs.shape[:-1], 1)),
+            in_keys=list(env.full_observation_spec.keys(True, True)),
+            out_keys=["state_value"],
+        )
+        ppo = ClipPPOLoss(actor_network=policy, critic_network=critic)
+        ppo.set_keys(action=list(env.full_action_spec.keys(True, True)))
+        assert isinstance(ppo.tensor_keys.action, list)
+        ppo(output)
+
 
 class TestA2C(LossModuleTestBase):
     seed = 0
diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index e4f6fce1129..63b77581ff2 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -494,7 +494,9 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
     def reset(self) -> None:
         pass
 
-    def _get_entropy(self, dist: d.Distribution) -> torch.Tensor | TensorDict:
+    def _get_entropy(
+        self, dist: d.Distribution, adv_shape: torch.Size
+    ) -> torch.Tensor | TensorDict:
         try:
             entropy = dist.entropy()
         except NotImplementedError:
@@ -513,10 +515,12 @@ def _get_entropy(self, dist: d.Distribution) -> torch.Tensor | TensorDict:
                 log_prob = log_prob.select(*self.tensor_keys.sample_log_prob)
 
             entropy = -log_prob.mean(0)
+        if is_tensor_collection(entropy) and entropy.batch_size != adv_shape:
+            entropy.batch_size = adv_shape
         return entropy.unsqueeze(-1)
 
     def _log_weight(
-        self, tensordict: TensorDictBase
+        self, tensordict: TensorDictBase, adv_shape: torch.Size
     ) -> Tuple[torch.Tensor, d.Distribution, torch.Tensor]:
 
         with self.actor_network_params.to_module(
@@ -541,7 +545,9 @@ def _log_weight(
         action = _maybe_get_or_select(tensordict, self.tensor_keys.action)
 
         prev_log_prob = _maybe_get_or_select(
-            tensordict, self.tensor_keys.sample_log_prob
+            tensordict,
+            self.tensor_keys.sample_log_prob,
+            adv_shape,
         )
 
         if prev_log_prob.requires_grad:
@@ -564,6 +570,8 @@ def _log_weight(
                         "the beginning of your script to get a proper composite log-prob.",
                         category=UserWarning,
                     )
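+                # A composite log-prob may carry per-head feature dims beyond the advantage's
+                # batch dims; re-batching it to adv_shape turns those into feature dims so the
+                # downstream feature reduction lines up with the advantage.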
+                if log_prob.batch_size != adv_shape:
+                    log_prob.batch_size = adv_shape
                 if (
                     is_composite
                     and not is_tensor_collection(prev_log_prob)
@@ -680,7 +688,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             )
             advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)
 
-        log_weight, dist, kl_approx = self._log_weight(tensordict)
+        log_weight, dist, kl_approx = self._log_weight(
+            tensordict, adv_shape=advantage.shape[:-1]
+        )
         if is_tensor_collection(log_weight):
             log_weight = _sum_td_features(log_weight)
             log_weight = log_weight.view(advantage.shape)
@@ -688,7 +698,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         td_out = TensorDict({"loss_objective": -neg_loss})
         td_out.set("kl_approx", kl_approx.detach().mean())  # for logging
         if self.entropy_bonus:
-            entropy = self._get_entropy(dist)
+            entropy = self._get_entropy(dist, adv_shape=advantage.shape[:-1])
             if is_tensor_collection(entropy):
                 # Reports the entropy of each action head.
                 td_out.set("composite_entropy", entropy.detach())
@@ -968,7 +978,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             )
             advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)
 
-        log_weight, dist, kl_approx = self._log_weight(tensordict)
+        log_weight, dist, kl_approx = self._log_weight(
+            tensordict, adv_shape=advantage.shape[:-1]
+        )
         # ESS for logging
         with torch.no_grad():
             # In theory, ESS should be computed on particles sampled from the same source. Here we sample according
@@ -995,7 +1007,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         td_out.set("kl_approx", kl_approx.detach().mean())  # for logging
 
         if self.entropy_bonus:
-            entropy = self._get_entropy(dist)
+            entropy = self._get_entropy(dist, adv_shape=advantage.shape[:-1])
             if is_tensor_collection(entropy):
                 # Reports the entropy of each action head.
                 td_out.set("composite_entropy", entropy.detach())
@@ -1275,7 +1287,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             )
             advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)
 
-        log_weight, dist, kl_approx = self._log_weight(tensordict_copy)
+        log_weight, dist, kl_approx = self._log_weight(
+            tensordict_copy, adv_shape=advantage.shape[:-1]
+        )
         neg_loss = log_weight.exp() * advantage
         if is_tensor_collection(neg_loss):
             neg_loss = _sum_td_features(neg_loss)
@@ -1295,6 +1309,13 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
                 previous_log_prob = previous_dist.log_prob(x)
                 current_log_prob = current_dist.log_prob(x)
                 if is_tensor_collection(previous_log_prob):
+                    if previous_log_prob.batch_size != advantage.shape[:-1]:
+                        previous_log_prob.batch_size = (
+                            self.samples_mc_kl,
+                        ) + advantage.shape[:-1]
+                        current_log_prob.batch_size = (
+                            self.samples_mc_kl,
+                        ) + advantage.shape[:-1]
                     previous_log_prob = _sum_td_features(previous_log_prob)
                     # Both dists have presumably the same params
                     current_log_prob = _sum_td_features(current_log_prob)
@@ -1314,7 +1335,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
         )
 
         if self.entropy_bonus:
-            entropy = self._get_entropy(dist)
+            entropy = self._get_entropy(dist, adv_shape=advantage.shape[:-1])
             if is_tensor_collection(entropy):
                 # Reports the entropy of each action head.
                 td_out.set("composite_entropy", entropy.detach())
diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py
index 3e0b97de710..87944ada5eb 100644
--- a/torchrl/objectives/utils.py
+++ b/torchrl/objectives/utils.py
@@ -622,10 +622,13 @@ def _sum_td_features(data: TensorDictBase) -> torch.Tensor:
     return data.sum(dim="feature", reduce=True)
 
 
-def _maybe_get_or_select(td, key_or_keys):
+def _maybe_get_or_select(td, key_or_keys, target_shape=None):
     if isinstance(key_or_keys, (str, tuple)):
         return td.get(key_or_keys)
-    return td.select(*key_or_keys)
+    result = td.select(*key_or_keys)
+    if target_shape is not None and result.shape != target_shape:
+        result.batch_size = target_shape
+    return result
 
 
 def _maybe_add_or_extend_key(
diff --git a/tutorials/sphinx-tutorials/multiagent_ppo.py b/tutorials/sphinx-tutorials/multiagent_ppo.py
index 0e6cc51adf6..f57e328f582 100644
--- a/tutorials/sphinx-tutorials/multiagent_ppo.py
+++ b/tutorials/sphinx-tutorials/multiagent_ppo.py
@@ -115,7 +115,7 @@
 import torch
 
 # Tensordict modules
-from tensordict.nn import TensorDictModule
+from tensordict.nn import set_composite_lp_aggregate, TensorDictModule
 from tensordict.nn.distributions import NormalParamExtractor
 
 from torch import multiprocessing
@@ -179,6 +179,9 @@
 lmbda = 0.9  # lambda for generalised advantage estimation
 entropy_eps = 1e-4  # coefficient of the entropy term in the PPO loss
 
+# disable log-prob aggregation
+set_composite_lp_aggregate(False).set()
+
 ######################################################################
 # Environment
 # -----------
@@ -454,7 +457,6 @@
         "high": env.full_action_spec_unbatched[env.action_key].space.high,
     },
     return_log_prob=True,
-    log_prob_key=("agents", "sample_log_prob"),
 )  # we'll need the log-prob for the PPO loss
 
 ######################################################################
@@ -602,7 +604,6 @@
 loss_module.set_keys(  # We have to tell the loss where to find the keys
     reward=env.reward_key,
     action=env.action_key,
-    sample_log_prob=("agents", "sample_log_prob"),
     value=("agents", "state_value"),
     # These last 2 keys will be expanded to match the reward shape
     done=("agents", "done"),