[BUG] Stacked univariate distributions incompatible with PPOLoss and/or in composites with multivariate distributions #2753
Comments
Here is the current status, maybe you could help me work out what the expected behaviour should be:

```python
import functools

import torch
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution, set_composite_lp_aggregate
from torch.distributions import Categorical, Dirichlet, Independent

set_composite_lp_aggregate(False).set()

d0 = Independent(Dirichlet(torch.ones(100, 10)), 1)
print(d0.log_prob(d0.sample()).shape)  # torch.Size([])

d1 = Categorical(probs=torch.ones(10) / 10)
print(d1.log_prob(d1.sample()).shape)  # torch.Size([])

td = TensorDict(
    dirich=TensorDict(concentration=torch.ones(100, 10)),
    categ=TensorDict(probs=torch.ones(10) / 10),
)

c = CompositeDistribution(
    params=td,
    distribution_map={
        "dirich": functools.partial(Dirichlet),
        "categ": functools.partial(Categorical),
    },
)
print(c.log_prob(c.sample()))
# TensorDict(
#     fields={
#         categ_log_prob: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
#         dirich_log_prob: Tensor(shape=torch.Size([100]), device=cpu, dtype=torch.float32, is_shared=False)},
#     batch_size=torch.Size([]),
#     device=None,
#     is_shared=False)

c = CompositeDistribution(
    params=td,
    distribution_map={
        "dirich": lambda concentration: Independent(Dirichlet(concentration), 1),
        "categ": functools.partial(Categorical),
    },
)
print(c.log_prob(c.sample()))
# TensorDict(
#     fields={
#         categ_log_prob: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
#         dirich_log_prob: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
#     batch_size=torch.Size([]),
#     device=None,
#     is_shared=False)

# With a TD of shape (10,)
c = CompositeDistribution(
    params=td.expand(10),
    distribution_map={
        "dirich": lambda concentration: Independent(Dirichlet(concentration), 1),
        "categ": functools.partial(Categorical),
    },
)
print(c.log_prob(c.sample()))
# TensorDict(
#     fields={
#         categ_log_prob: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
#         dirich_log_prob: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
#     batch_size=torch.Size([10]),
#     device=None,
#     is_shared=False)
```

So far I don't see any shape mismatch. Calling `print(_sum_td_features(c.log_prob(c.sample())))` prints a tensor of shape `(10,)`, which also seems ok to me. Based on these, do you feel you could write a minimal repro of the problem you're facing? I didn't really grasp what the issue was.
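For readers following along: `_sum_td_features` is a private torchrl helper. A minimal stand-in for it, written here as an assumption about its behaviour rather than a copy of torchrl's actual implementation, might look like this:

```python
import torch
from tensordict import TensorDict

def sum_td_features(td: TensorDict) -> torch.Tensor:
    """Aggregate every leaf tensor of ``td`` into one tensor.

    Assumed stand-in for torchrl's private ``_sum_td_features``;
    the real helper may reduce over feature dims differently.
    """
    total = None
    for leaf in td.values(include_nested=True, leaves_only=True):
        # Relies on broadcasting if leaf shapes differ.
        total = leaf if total is None else total + leaf
    return total

# With the last composite above, both leaves have shape (10,),
# so the aggregate is a (10,) tensor as well.
```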
I think I understood what the problem was and what solution to bring.

Context

In general, PPO and other similar losses need to multiply an advantage by (some transform of) a log-prob. In MARL, we often have obs, actions, done or reward entries whose shapes carry an extra agent dimension.

Problem

Coming back to PPO: the product we need to do involves a log-prob and an advantage. Per se, the log-prob doesn't have that trailing "feature" dim (because that's how a log-prob behaves), so we usually unsqueeze. In the MARL case (or anywhere we have more than one action), we want to return a tensordict containing the various log-probs and not a single aggregated tensor. The reason for this is complex, but it boils down to: (1) it's safer, (2) it's clearer, and (3) it allows us to do fancy stuff like clipping each PPO objective independently. If we bind these things together, we realize that we end up with a tensordict whose shape does not match the advantage's.

Solution

What we need is to reshape the tensordict to the shape of the advantage, which we can do with a simple reshape before taking the product.
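As a rough sketch of that solution (an illustrative example with made-up shapes, not the actual `PPOLoss` code), the aggregation and reshape could look like:

```python
import torch
from tensordict import TensorDict

# Hypothetical MARL shapes: 8 env steps, 4 agents.
log_probs = TensorDict(
    {
        "agent_a_log_prob": torch.randn(8, 4),
        "agent_b_log_prob": torch.randn(8, 4),
    },
    batch_size=[8, 4],
)
advantage = torch.randn(8, 4, 1)  # trailing singleton feature dim

# Aggregate the per-key log-probs into a single tensor ...
log_weight = sum(log_probs.values(include_nested=True, leaves_only=True))
# ... then reshape/unsqueeze it to the advantage's shape before the product.
log_weight = log_weight.reshape(advantage.shape[:-1]).unsqueeze(-1)
gain = advantage * log_weight.exp()
print(gain.shape)  # torch.Size([8, 4, 1])
```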
Describe the bug
Some PyTorch distributions, like the Dirichlet, are multivariate by nature. Others, like the Beta or Gamma distributions, are univariate by default. However, they can also be initialized with multivariate parameters, which causes torch to stack a number of independent univariate distributions.

This does not seem to work properly with the composite distribution and PPO loss in torchrl. Using such stacked univariate distributions leads to a shape mismatch when calculating the PPO loss, as their `log_prob` and `log_weight` shapes are 4-dimensional, while the natively multivariate distributions' log shapes are 3-dimensional. Perhaps this issue only occurs when mixing stacked and native multivariate distributions, but I have not tested this.
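To make the dimensionality difference concrete, here is a short standalone illustration (an editor's addition using plain `torch.distributions`, not the actor from the linked repo): with identically shaped parameters, a natively multivariate Dirichlet consumes the last dimension as an event dimension, while a Beta treats it as a batch of independent univariate distributions, so its `log_prob` keeps one extra trailing dimension.

```python
import torch
from torch.distributions import Beta, Dirichlet

params = torch.ones(32, 5, 10)  # e.g. (batch, agents, action_dim)

multi = Dirichlet(params)  # event_shape (10,): natively multivariate
print(multi.log_prob(multi.sample()).shape)  # torch.Size([32, 5])

stacked = Beta(params, params)  # batch_shape (32, 5, 10): stacked univariate
print(stacked.log_prob(stacked.sample()).shape)  # torch.Size([32, 5, 10]), one extra dim
```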
To Reproduce
Create a probabilistic actor with a composite distribution using stacked univariate distributions.
A runnable example can be found here: https://github.com/rerz/rltest
I also tried splitting the stacked distributions into multiple ones with just one set of parameters each, but this still produces 4-dimensional log shapes.

The repo also contains a very unclean workaround in the forward method of `PPOLoss`, which basically just removes this extra dimension and changes the reduction of the log weights, ending up with a shape that is compatible with the advantage tensor.

Expected behavior
I can use any distribution, no matter whether it is univariate or multivariate by default.
System info
torchrl and tensordict from the current main branches.
Reason and Possible fixes
A workaround can be found in the linked repo above. While the fix is unclean, I believe I have found the root cause of the issue.
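Another possible user-side fix, sketched here as a suggestion rather than the workaround from the linked repo, is to wrap each stacked univariate distribution in `Independent` inside the `distribution_map`, so that its `log_prob` shape matches the natively multivariate entries:

```python
import torch
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution, set_composite_lp_aggregate
from torch.distributions import Beta, Dirichlet, Independent

set_composite_lp_aggregate(False).set()

def independent_beta(concentration1, concentration0):
    # Fold the trailing parameter dim into the event dim so the
    # log_prob has as many dims as the Dirichlet's.
    return Independent(Beta(concentration1, concentration0), 1)

params = TensorDict(
    dirich=TensorDict(concentration=torch.ones(32, 10)),
    beta=TensorDict(
        concentration1=torch.ones(32, 10),
        concentration0=torch.ones(32, 10),
    ),
)
dist = CompositeDistribution(
    params=params,
    distribution_map={"dirich": Dirichlet, "beta": independent_beta},
)
lp = dist.log_prob(dist.sample())
print(lp["dirich_log_prob"].shape, lp["beta_log_prob"].shape)
# torch.Size([32]) torch.Size([32]) -- the shapes now agree
```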