[BugFix] adapt log-prob TD batch-size to advantage shape in PPO
ghstack-source-id: 8ccd12f65f4a74a42356a630e0e5a1f015337d4a
Pull Request resolved: #2756
vmoens committed Feb 4, 2025
1 parent 2f8c118 commit d3beaa6
Showing 7 changed files with 272 additions and 20 deletions.
45 changes: 45 additions & 0 deletions docs/source/reference/data.rst
@@ -585,6 +585,51 @@ should have a considerably lower memory footprint than observations, for instance
This format eliminates any ambiguity regarding the matching of an observation with
its action, info, or done state.

A note on singleton dimensions in TED
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. _reward_done_singleton:

In TorchRL, the standard practice is that `done` states (including terminated and truncated) and rewards should have a
dimension that can be expanded to match the shape of observations, states, and actions without resorting to anything
other than repetition (i.e., the reward must have as many dimensions as the observation and/or action, or their
embeddings).

Essentially, this format is acceptable (though not strictly enforced):

>>> print(rollout[t])
... TensorDict(
...     fields={
...         action: Tensor(n_action),
...         done: Tensor(1),  # The done state has a rightmost singleton dimension
...         next: TensorDict(
...             fields={
...                 done: Tensor(1),
...                 observation: Tensor(n_obs),
...                 reward: Tensor(1),  # The reward has a rightmost singleton dimension
...                 terminated: Tensor(1),
...                 truncated: Tensor(1),
...             batch_size=torch.Size([]),
...             device=cpu,
...             is_shared=False),
...         observation: Tensor(n_obs),  # the observation at reset
...         terminated: Tensor(1),  # the terminated at reset
...         truncated: Tensor(1),  # the truncated at reset
...     batch_size=torch.Size([]),
...     device=cpu,
...     is_shared=False)

The rationale behind this is to ensure that the results of operations (such as value estimation) on observations and/or
actions have the same number of dimensions as the reward and `done` state. This consistency allows subsequent operations
to proceed without issues:

>>> state_value = f(observation)
>>> next_state_value = state_value + reward

Without this singleton dimension at the end of the reward, broadcasting rules (which only work when tensors can be
expanded from the left) would try to expand the reward on the left. This could lead to failures (at best) or introduce
bugs (at worst).
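
For illustration only (the shapes below are made up and not part of the documentation above), this is the kind of
silent bug the trailing singleton dimension prevents:

>>> import torch
>>> value = torch.zeros(5, 1)             # output of a value network: one value per step
>>> reward = torch.ones(5)                # reward stored without its trailing singleton dimension
>>> (value + reward).shape                # the reward is expanded on the left: silent bug
torch.Size([5, 5])
>>> (value + reward.unsqueeze(-1)).shape  # with the singleton dimension restored
torch.Size([5, 1])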

Flattening TED to reduce memory consumption
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

93 changes: 87 additions & 6 deletions docs/source/reference/objectives.rst
@@ -151,7 +151,7 @@ REDQ
REDQLoss

CrossQ
----
------

.. autosummary::
:toctree: generated/
@@ -160,7 +160,7 @@ CrossQ
CrossQLoss

IQL
----
---

.. autosummary::
:toctree: generated/
@@ -170,7 +170,7 @@ IQL
DiscreteIQLLoss

CQL
----
---

.. autosummary::
:toctree: generated/
@@ -189,7 +189,7 @@ GAIL
GAILLoss

DT
----
--

.. autosummary::
:toctree: generated/
@@ -199,7 +199,7 @@ DT
OnlineDTLoss

TD3
----
---

.. autosummary::
:toctree: generated/
@@ -208,7 +208,7 @@ TD3
TD3Loss

TD3+BC
----
------

.. autosummary::
:toctree: generated/
@@ -227,6 +227,85 @@ PPO
ClipPPOLoss
KLPENPPOLoss

Using PPO with multi-head action policies
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In some cases, we have a single advantage value but more than one action was undertaken. Each action has its own
log-probability and its own shape. For instance, the action space may be structured as follows:

>>> action_td = TensorDict(
...     action0=Tensor(batch, n_agents, f0),
...     action1=Tensor(batch, n_agents, f1, f2),
...     batch_size=torch.Size((batch,))
... )

where `f0`, `f1` and `f2` are some arbitrary integers.

Note that, in TorchRL, the tensordict has the shape of the environment (if the environment is batch-locked) or of the
number of batched environments being run (otherwise). If the tensordict is sampled from a replay buffer, it has the
shape of the buffer's `batch_size` instead. The `n_agents` dimension, although common to every action, does not in
general appear in the tensordict's batch size.

There is a legitimate reason why this is the case: the number of agents may condition some, but not all, of the
environment's specs. For example, some environments share a single done state among all agents. A more complete
tensordict would in this case look like

>>> action_td = TensorDict(
...     action0=Tensor(batch, n_agents, f0),
...     action1=Tensor(batch, n_agents, f1, f2),
...     done=Tensor(batch, 1),
...     observation=Tensor(batch, n_agents, f3),
...     [...]  # etc
...     batch_size=torch.Size((batch,))
... )

Notice that the `done` states and `reward` usually carry a rightmost singleton dimension. See this :ref:`part of the doc <reward_done_singleton>`
to learn more about this restriction.

The main tools to consider when building multi-head policies are: :class:`~tensordict.nn.CompositeDistribution`,
:class:`~tensordict.nn.ProbabilisticTensorDictModule` and :class:`~tensordict.nn.ProbabilisticTensorDictSequential`.
When dealing with these, it is recommended to call `tensordict.nn.set_composite_lp_aggregate(False).set()` at the
beginning of the script to instruct :class:`~tensordict.nn.CompositeDistribution` that log-probabilities should not
be aggregated but rather written as leaves in the tensordict.
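
As a rough sketch (adapted from the test added in this commit; the `"action0"`/`"action1"` keys and the
distribution parameters are placeholders), such a multi-head distribution module can be assembled as follows:

>>> import functools
>>> import torch
>>> from tensordict.nn import (
...     CompositeDistribution,
...     ProbabilisticTensorDictModule,
...     set_composite_lp_aggregate,
... )
>>> set_composite_lp_aggregate(False).set()  # write log-probs as separate leaves
>>> dist_cls = functools.partial(
...     CompositeDistribution,
...     distribution_map={
...         "action0": torch.distributions.Categorical,
...         "action1": lambda concentration: torch.distributions.Independent(
...             torch.distributions.Dirichlet(concentration), 1
...         ),
...     },
... )
>>> policy_head = ProbabilisticTensorDictModule(
...     in_keys=["params"],
...     out_keys=["action0", "action1"],
...     distribution_class=dist_cls,
...     return_log_prob=True,
... )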

The log-probability of our actions under their respective distributions may look like

>>> action_td = TensorDict(
...     action0_log_prob=Tensor(batch, n_agents),
...     action1_log_prob=Tensor(batch, n_agents, f1),
...     batch_size=torch.Size((batch,))
... )

or

>>> action_td = TensorDict(
...     action0_log_prob=Tensor(batch, n_agents),
...     action1_log_prob=Tensor(batch, n_agents),
...     batch_size=torch.Size((batch,))
... )

i.e., the number of dimensions of a distribution's log-probability generally ranges from the sample's dimensionality
down to anything below it, e.g. when the distribution is multivariate (:class:`~torch.distributions.Dirichlet` for
instance) or wrapped in an :class:`~torch.distributions.Independent` instance.
The batch size of the tensordict, on the contrary, still matches the env's / replay buffer's batch size.
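
For instance (an illustrative snippet with arbitrary sizes, not taken from the text above), a
:class:`~torch.distributions.Dirichlet` distribution already consumes the rightmost feature dimension in its
log-probability, and wrapping it in :class:`~torch.distributions.Independent` consumes one more:

>>> import torch
>>> d = torch.distributions.Dirichlet(torch.ones(3, 4, 10, 2))  # batch=3, n_agents=4
>>> sample = d.sample()
>>> sample.shape
torch.Size([3, 4, 10, 2])
>>> d.log_prob(sample).shape  # the Dirichlet is multivariate over the last dim
torch.Size([3, 4, 10])
>>> torch.distributions.Independent(d, 1).log_prob(sample).shape
torch.Size([3, 4])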

During a call to the PPO loss, the loss module will schematically execute the following set of operations:

>>> def ppo(tensordict):
...     prev_log_prob = tensordict.select(*log_prob_keys)
...     action = tensordict.select(*action_keys)
...     new_log_prob = dist.log_prob(action)
...     log_weight = new_log_prob - prev_log_prob
...     advantage = tensordict.get("advantage")  # computed by GAE earlier
...     # attempt to map the log-weight tensordict onto the advantage shape
...     log_weight.batch_size = advantage.shape[:-1]
...     log_weight = sum(log_weight.sum(dim="feature").values(True, True))  # get a single tensor of log_weights
...     return minimum(log_weight.exp() * advantage, log_weight.exp().clamp(1 - eps, 1 + eps) * advantage)
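
As a hypothetical illustration of the shape-mapping step (with `batch=3`, `n_agents=4`, `f1=10`; this is not an
excerpt from the loss module):

>>> log_weight = TensorDict(
...     action0_log_prob=torch.zeros(3, 4),
...     action1_log_prob=torch.zeros(3, 4, 10),
...     batch_size=(3,),  # the env / replay-buffer batch size
... )
>>> advantage = torch.zeros(3, 4, 1)  # one advantage per agent, with a trailing singleton dimension
>>> log_weight.batch_size = advantage.shape[:-1]
>>> log_weight.batch_size
torch.Size([3, 4])

Once the agent dimension is part of the batch size, the remaining feature dimensions of each log-probability can be
summed out and the per-head log-weights added together, yielding a single tensor of shape `(3, 4)`.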

To see what a complete PPO pipeline looks like with multi-head policies, refer to the example in the library's
`example directory <https://github.com/pytorch/rl/blob/main/examples/agents/composite_ppo.py>`__.


A2C
---

@@ -258,6 +337,7 @@ Dreamer

Multi-agent objectives
-----------------------

.. currentmodule:: torchrl.objectives.multiagent

These objectives are specific to multi-agent algorithms.
@@ -305,6 +385,7 @@ Returns

Utils
-----

.. currentmodule:: torchrl.objectives

.. autosummary::
1 change: 1 addition & 0 deletions examples/agents/composite_ppo.py
@@ -6,6 +6,7 @@
"""
Multi-head Agent and PPO Loss
=============================
This example demonstrates how to use TorchRL to create a multi-head agent with three separate distributions
(Gamma, Kumaraswamy, and Mixture) and train it using Proximal Policy Optimization (PPO) losses.
100 changes: 100 additions & 0 deletions test/test_cost.py
@@ -12,6 +12,7 @@
import warnings
from copy import deepcopy
from dataclasses import asdict, dataclass
from typing import Optional

import numpy as np
import pytest
@@ -43,6 +44,7 @@
from torchrl._utils import _standardize
from torchrl.data import Bounded, Categorical, Composite, MultiOneHot, OneHot, Unbounded
from torchrl.data.postprocs.postprocs import MultiStep
from torchrl.envs import EnvBase
from torchrl.envs.model_based.dreamer import DreamerEnv
from torchrl.envs.transforms import TensorDictPrimer, TransformedEnv
from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type
@@ -199,6 +201,70 @@ def get_devices():
    return devices


class MARLEnv(EnvBase):
    """Stub multi-agent environment used to exercise composite (multi-head) action distributions in PPO."""

    def __init__(self):
        batch = self.batch = (3,)
        super().__init__(batch_size=batch)
        self.n_agents = n_agents = (4,)
        self.obs_feat = obs_feat = (5,)

        self.full_observation_spec = Composite(
            observation=Unbounded(batch + n_agents + obs_feat),
            batch_size=batch,
        )
        self.full_done_spec = Composite(
            done=Unbounded(batch + (1,), dtype=torch.bool),
            terminated=Unbounded(batch + (1,), dtype=torch.bool),
            truncated=Unbounded(batch + (1,), dtype=torch.bool),
            batch_size=batch,
        )

        self.act_feat_dirich = act_feat_dirich = (
            10,
            2,
        )
        self.act_feat_categ = act_feat_categ = (7,)
        self.full_action_spec = Composite(
            dirich=Unbounded(batch + n_agents + act_feat_dirich),
            categ=Unbounded(batch + n_agents + act_feat_categ),
            batch_size=batch,
        )

        self.full_reward_spec = Composite(
            reward=Unbounded(batch + n_agents + (1,)), batch_size=batch
        )

    @classmethod
    def make_composite_dist(cls):
        dist_cstr = functools.partial(
            CompositeDistribution,
            distribution_map={
                "dirich": lambda concentration: torch.distributions.Independent(
                    torch.distributions.Dirichlet(concentration), 1
                ),
                "categ": torch.distributions.Categorical,
            },
        )
        return ProbabilisticTensorDictModule(
            in_keys=["params"],
            out_keys=["dirich", "categ"],
            distribution_class=dist_cstr,
            return_log_prob=True,
        )

    def _step(
        self,
        tensordict: TensorDictBase,
    ) -> TensorDictBase:
        ...

    def _reset(self, tensordict):
        ...

    def _set_seed(self, seed: Optional[int]):
        ...


class LossModuleTestBase:
    @pytest.fixture(scope="class", autouse=True)
    def _composite_log_prob(self):
@@ -9238,6 +9304,40 @@ def mixture_constructor(logits, loc, scale):
        loss = ppo(data)
        loss.sum(reduce=True)

    def test_ppo_marl_aggregate(self):
        # PPO should adapt the per-head log-prob tensordict batch-size to the advantage shape (the fix in this commit).
        env = MARLEnv()

        def primer(td):
            params = TensorDict(
                dirich=TensorDict(concentration=env.action_spec["dirich"].one()),
                categ=TensorDict(logits=env.action_spec["categ"].one()),
                batch_size=td.batch_size,
            )
            td.set("params", params)
            return td

        policy = ProbabilisticTensorDictSequential(
            primer,
            env.make_composite_dist(),
            # return_composite=True,
        )
        output = policy(env.fake_tensordict())
        assert output.shape == env.batch_size
        assert output["dirich_log_prob"].shape == env.batch_size + env.n_agents
        assert output["categ_log_prob"].shape == env.batch_size + env.n_agents

        output["advantage"] = output["next", "reward"].clone()
        output["value_target"] = output["next", "reward"].clone()
        critic = TensorDictModule(
            lambda obs: obs.new_zeros((*obs.shape[:-1], 1)),
            in_keys=list(env.full_observation_spec.keys(True, True)),
            out_keys=["state_value"],
        )
        ppo = ClipPPOLoss(actor_network=policy, critic_network=critic)
        ppo.set_keys(action=list(env.full_action_spec.keys(True, True)))
        assert isinstance(ppo.tensor_keys.action, list)
        ppo(output)


class TestA2C(LossModuleTestBase):
seed = 0
