[BUG] Stacked univariate distributions incompatible with PPOLoss and/or in composites with multivariate distributions #2753

Closed
rerz opened this issue Feb 3, 2025 · 2 comments · Fixed by #2756
Labels
bug Something isn't working

rerz commented Feb 3, 2025

Describe the bug

Some PyTorch distributions, like the Dirichlet, are multivariate by nature. Others, like the Beta or Gamma distributions, are univariate by default.
However, they can also be initialized with vector-valued parameters, in which case torch creates a stack of independent univariate distributions.

This does not seem to work properly with CompositeDistribution and PPOLoss in torchrl. Using such stacked univariate distributions leads to a shape mismatch when computing the PPO loss: their log_prob and log_weight tensors are 4-dimensional, while those of natively multivariate distributions are 3-dimensional.
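The extra dimension comes from how torch.distributions splits parameter dims into `batch_shape` and `event_shape`. A minimal plain-PyTorch sketch (the `(4, 8)` leading shape and `k = 3` components are arbitrary choices for illustration): a Dirichlet treats the last parameter dim as its event, so `log_prob` drops it, whereas a Beta with vector parameters treats it as a batch dim and keeps it. Wrapping the Beta in `Independent(..., 1)` reinterprets that dim as an event dim and sums over it, which is the same wrapping the maintainer applies to the Dirichlet in this thread.

```python
import torch
from torch.distributions import Beta, Dirichlet, Independent

lead = (4, 8)  # arbitrary leading (batch) shape, chosen for illustration
k = 3          # number of action components

# Natively multivariate: the last parameter dim is the event dim.
dirichlet = Dirichlet(torch.full((*lead, k), 2.0))
# Univariate with vector parameters: the last dim becomes part of batch_shape.
beta = Beta(torch.full((*lead, k), 2.0), torch.full((*lead, k), 2.0))

print(dirichlet.batch_shape, dirichlet.event_shape)  # torch.Size([4, 8]) torch.Size([3])
print(beta.batch_shape, beta.event_shape)            # torch.Size([4, 8, 3]) torch.Size([])

lp_dirichlet = dirichlet.log_prob(dirichlet.sample())  # shape (4, 8)
lp_beta = beta.log_prob(beta.sample())                 # shape (4, 8, 3): one extra trailing dim

# Reinterpret the last batch dim as an event dim: per-component log-probs are summed.
beta_joint = Independent(beta, 1)
lp_joint = beta_joint.log_prob(beta_joint.sample())    # shape (4, 8)
```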

Perhaps this issue only occurs when mixing stacked and natively multivariate distributions, but I have not tested this.

To Reproduce

Create a probabilistic actor with a composite distribution using stacked univariate distributions.

A runnable example can be found here: https://github.com/rerz/rltest

I also tried splitting the stacked distributions into several distributions, each with a single set of parameters, but this still produces 4-dimensional log-prob shapes.

The repo also contains a very unclean workaround in the forward method of PPOLoss, which simply removes this extra dimension and changes the reduction of the log weights so that the result has a shape compatible with the advantage tensor.

Expected behavior

I can use any distribution, regardless of whether it is univariate or multivariate by default.

System info

torchrl and tensordict from the current main branches.

Reason and Possible fixes

A workaround can be found in the linked repo above. While the fix is unclean, I believe I have found the root cause of the issue.

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
@rerz rerz added the bug Something isn't working label Feb 3, 2025
vmoens commented Feb 4, 2025

Here is the current status; maybe you could help me by commenting on what the expected behaviour should be:

```python
import functools

import torch
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution, set_composite_lp_aggregate
from torch.distributions import Categorical, Dirichlet, Independent

set_composite_lp_aggregate(False).set()

d0 = Independent(Dirichlet(torch.ones(100, 10)), 1)
print(d0.log_prob(d0.sample()).shape) # torch.Size([])
d1 = Categorical(probs=torch.ones(10)/10)
print(d1.log_prob(d1.sample()).shape) # torch.Size([])

td = TensorDict(dirich=TensorDict(concentration=torch.ones(100, 10)), categ=TensorDict(probs=torch.ones(10)/10))
c = CompositeDistribution(params=td, distribution_map={
    "dirich": functools.partial(Dirichlet),
    "categ": functools.partial(Categorical),
})

print(c.log_prob(c.sample()))
# TensorDict(
#     fields={
#         categ_log_prob: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
#         dirich_log_prob: Tensor(shape=torch.Size([100]), device=cpu, dtype=torch.float32, is_shared=False)},
#     batch_size=torch.Size([]),
#     device=None,
#     is_shared=False)

c = CompositeDistribution(params=td, distribution_map={
    "dirich": lambda concentration: Independent(Dirichlet(concentration), 1),
    "categ": functools.partial(Categorical),
})

print(c.log_prob(c.sample()))
# TensorDict(
#     fields={
#         categ_log_prob: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
#         dirich_log_prob: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
#     batch_size=torch.Size([]),
#     device=None,
#     is_shared=False)

# With a TD of shape (10,)
c = CompositeDistribution(params=td.expand(10), distribution_map={
    "dirich": lambda concentration: Independent(Dirichlet(concentration), 1),
    "categ": functools.partial(Categorical),
})

print(c.log_prob(c.sample()))
# TensorDict(
#     fields={
#         categ_log_prob: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
#         dirich_log_prob: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
#     batch_size=torch.Size([10]),
#     device=None,
#     is_shared=False)
```

So far I don't see any shape mismatch.

If we call _sum_td_features, as is done in PPO, we get

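```python
from torchrl.objectives.utils import _sum_td_features  # private helper used by the PPO losses

print(_sum_td_features(c.log_prob(c.sample())))
# prints a tensor of shape (10,)
```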

which also seems ok to me.
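To see why the result has shape `(10,)`, here is a plain-PyTorch sketch of the kind of reduction a helper like `_sum_td_features` performs, assuming it sums every log-prob leaf over all dims past the tensordict's batch dims and then adds the leaves together. `sum_features_sketch` is a hypothetical stand-in written for illustration, not the torchrl implementation. Note that under this assumption any extra trailing dim on a leaf (such as a stacked Beta's component dim) is summed away silently.

```python
import torch

def sum_features_sketch(leaves: dict, batch_ndim: int) -> torch.Tensor:
    """Hypothetical stand-in: reduce each leaf to the batch shape, then add the leaves."""
    total = None
    for lp in leaves.values():
        feature_dims = tuple(range(batch_ndim, lp.ndim))
        # Guard: in some torch versions sum(dim=()) reduces over *all* dims.
        reduced = lp.sum(dim=feature_dims) if feature_dims else lp
        total = reduced if total is None else total + reduced
    return total

# Shapes from the example above: batch_size=(10,), each leaf has shape (10,).
leaves = {"categ_log_prob": torch.randn(10), "dirich_log_prob": torch.randn(10)}
out_two = sum_features_sketch(leaves, batch_ndim=1)
print(out_two.shape)  # torch.Size([10])

# A leaf with an extra trailing dim is summed over that dim.
leaves["beta_log_prob"] = torch.randn(10, 3)
out_three = sum_features_sketch(leaves, batch_ndim=1)
print(out_three.shape)  # torch.Size([10])
```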

Based on these, do you feel you could write a minimal repro of the problem you're facing? I didn't really grasp what the issue was.

vmoens commented Feb 4, 2025

I think I understood what the problem was and what solution to bring.

Context

In general, PPO and other similar losses need to multiply an advantage by (some transform of) a log-prob.
The advantage can have any shape, but in general we make it so that it has shape (*batch, feature_dim), where more often than not feature_dim=1.
The reason we have a feature_dim in the first place is a torchrl convention: rewards and done states must have a non-empty feature shape to match the observation, so that we can do f(obs, action) + reward (or similar) safely. If we didn't have that extra dim, we would occasionally broadcast: in the best case, an exception is raised because the reward doesn't have the proper shape; in the worst case, the reward broadcasts silently and ends up expanded to the shape of f(obs, action)!
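A short plain-PyTorch illustration of that worst case (the batch size of 5 is arbitrary, and `value` is just a stand-in for f(obs, action)): if the reward has no trailing feature dim while the value estimate does, adding them silently broadcasts to a `(batch, batch)` matrix instead of raising an error.

```python
import torch

batch = 5
value = torch.randn(batch, 1)                # stand-in for f(obs, action), with a feature dim of 1

reward_no_feat = torch.randn(batch)          # shape (5,): no feature dim
reward_feat = reward_no_feat.unsqueeze(-1)   # shape (5, 1): torchrl-style reward

bad = value + reward_no_feat   # (5, 1) + (5,) broadcasts to (5, 5), no error raised
good = value + reward_feat     # (5, 1) + (5, 1) stays (5, 1)
print(bad.shape, good.shape)   # torch.Size([5, 5]) torch.Size([5, 1])
```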

In MARL, we often have obs, actions, done or reward states of shape (batch, n_agent, *feature_dim). The root tensordict is most of the time shaped as (batch,) because some, but not all, of these entries have an agent dim (i.e., a tensordict with shape (batch, n_agent) is not valid in general).

Problem

Coming back to PPO: the product we need to compute involves a log-prob and an advantage. Per se, the log-prob doesn't have that trailing "feature dim" (because that's how a log-prob behaves), so we usually unsqueeze it.
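A minimal sketch of that product for a single action, using the standard clipped PPO surrogate. The shapes, `clip_eps=0.2` and the variable names are illustrative assumptions, not torchrl internals; the point is that the log-weight only lines up with the `(batch, 1)` advantage after the unsqueeze.

```python
import torch

batch, clip_eps = 6, 0.2
advantage = torch.randn(batch, 1)   # torchrl convention: trailing feature dim
log_prob = torch.randn(batch)       # log-probs carry no feature dim
prev_log_prob = torch.randn(batch)  # log-probs under the behaviour policy

# Add the missing feature dim so the importance weight lines up with the advantage.
log_weight = (log_prob - prev_log_prob).unsqueeze(-1)  # (batch, 1)
ratio = log_weight.exp()

surrogate = torch.min(
    ratio * advantage,
    ratio.clamp(1.0 - clip_eps, 1.0 + clip_eps) * advantage,
)
loss = -surrogate.mean()
print(surrogate.shape)  # torch.Size([6, 1])
```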

In the MARL case (or anywhere we have more than one action), we want to return a tensordict containing the various log-probs and not a single aggregated tensor. The reason for this is complex, but it boils down to: (1) it's safer, (2) it's clearer and (3) it allows us to do fancy stuff like clipping each PPO objective independently.
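To make point (3) concrete, here is a hedged sketch with two hypothetical action heads, `move` and `grip`. With a single aggregated log-prob, the joint ratio is the product of the per-head ratios and is clipped once; keeping the log-probs separate lets each head's ratio be clipped on its own. Summing the per-head objectives at the end is an illustrative choice, not a description of how PPOLoss combines them.

```python
import torch

def clipped_surrogate(log_weight, advantage, clip_eps=0.2):
    """Standard PPO clipped objective for one log-importance-weight tensor."""
    ratio = log_weight.exp()
    return torch.min(ratio * advantage, ratio.clamp(1 - clip_eps, 1 + clip_eps) * advantage)

batch = 4
advantage = torch.randn(batch, 1)
# Hypothetical per-head log importance weights (log pi_new - log pi_old), with a feature dim.
log_weights = {"move": 0.5 * torch.randn(batch, 1), "grip": 0.5 * torch.randn(batch, 1)}

# Aggregated: summing log-weights multiplies the ratios; the joint ratio is clipped once.
joint = clipped_surrogate(sum(log_weights.values()), advantage)

# Per-head: each ratio is clipped independently, then the objectives are combined.
per_head = {name: clipped_surrogate(lw, advantage) for name, lw in log_weights.items()}
per_head_total = sum(per_head.values())
```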

If we put things together, we end up with a tensordict of shape (batch,) containing the log-probs (which we usually unsqueeze to (batch, 1)) and an advantage of shape (batch, n_agents, 1). These things don't broadcast!
Worse, within our log-prob tensordict, we have tensors of shape (batch, 1, n_agent) after we do the unsqueeze.
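The leaf-level mismatch can be reproduced with plain tensors (a sketch; `batch=4` and `n_agents=3` are arbitrary). Unsqueezing a tensordict of batch size `(batch,)` inserts the new dim right after the batch dims, which for a `(batch, n_agents)` leaf amounts to `unsqueeze(1)`, while the advantage needs the new dim at the end. Taken on their own, the two leaf shapes even broadcast silently to `(batch, n_agents, n_agents)`, the worst case described in the Context section.

```python
import torch

batch, n_agents = 4, 3
advantage = torch.randn(batch, n_agents, 1)
log_prob = torch.randn(batch, n_agents)  # per-agent log-prob leaf

# What unsqueezing a (batch,)-shaped tensordict does to this leaf: new dim after the batch dims.
wrong = log_prob.unsqueeze(1)   # (batch, 1, n_agents)
# What the advantage actually needs: the feature dim at the end.
right = log_prob.unsqueeze(-1)  # (batch, n_agents, 1)

print((wrong * advantage).shape)  # torch.Size([4, 3, 3]): silent broadcast
print((right * advantage).shape)  # torch.Size([4, 3, 1])
```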

Solution

What we need is to give the tensordict the same leading shape as the advantage. We can simply do log_prob_td.batch_size = advantage.shape[:-1] before we run _sum_td_features. That way, the unsqueeze op occurs at the right place.

@vmoens vmoens closed this as completed Feb 7, 2025