[BugFix] computation of log prob in composite distribution for batched samples #1054
Conversation
@albertbou92 tagging you, as this might be of interest.

Some tests are failing, I'll address this asap!

This will go in a minor, thanks for flagging.
Hey
So I reviewed these changes and I think the current behaviour is accurate - but I'm open to changing my mind.

- The "advised" way of using `log_prob` is going to be to write the tensors in the TD from now on. If that is done, they should have the proper shape.
- We can also return an aggregate, but we will require it to have the same shape as the root tensordict. I thought about making it less restrictive and allowing the `lp` to have the shape of the parent tensordict (e.g. `("group0", "agent1")`), but then what happens if we sum the lps from `("group0", "agent0")`, `("group0", "agent1")`, `("group1", "agent0")` and `("group1", "agent1")`, where the first two have, say, shape `[2]` and the last two have shape `[3]`? In this case we cannot sum them - the only thing we can do is to reduce the `lp` to a single float (see the sketch below). Because the content of the tensordict should not condition the shape that you get, I think that having an `lp` that has the shape of the sample tensordict (as is the case now) is the most robust behaviour.
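A minimal shape-level sketch of that argument; the group names and shapes here are made up for illustration:

```python
import torch

# Hypothetical per-agent log-probs: two groups with different batch shapes
lp_g0_a0 = torch.randn(2)  # ("group0", "agent0") -> shape [2]
lp_g0_a1 = torch.randn(2)  # ("group0", "agent1") -> shape [2]
lp_g1_a0 = torch.randn(3)  # ("group1", "agent0") -> shape [3]
lp_g1_a1 = torch.randn(3)  # ("group1", "agent1") -> shape [3]

# Within a group, elementwise summing works:
lp_g0 = lp_g0_a0 + lp_g0_a1  # shape [2]
lp_g1 = lp_g1_a0 + lp_g1_a1  # shape [3]

# Across groups, shapes [2] and [3] cannot be summed elementwise:
try:
    lp_g0 + lp_g1
except RuntimeError as e:
    print(e)  # size mismatch at dimension 0

# The only shape-agnostic reduction is a single scalar:
total = lp_g0.sum() + lp_g1.sum()
print(total.shape)  # torch.Size([])
```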
The main consideration I have is whether this is BC-breaking - did you encounter a problem when switching to 0.6?
Happy to hear your thoughts on this!
Hi Vincent,

Some thoughts after your comment. I think that, indeed, the main pain point is "where" the summed log-prob should live. Let's build on the concrete example you laid out. I do not believe it makes sense to sum log-probs across groups with different shapes. For the same example, I do not see the learning-theoretic usefulness of the scalar log-prob. If we think of training this multi-agent dynamical system with PPO, the scalar does not give us the per-agent log-probs that the loss needs.
That's already the case, no?
Maybe I misunderstood your comment; it sounded to me like you were arguing to sum this across agents and have it sit in the root -- which I doubt is the right way to go. To say it otherwise: I personally agree with @thomasbbrunner that in this case, the log-probs should keep the shape of the group tensordict, for the reasons detailed in my first post.
Just to make sure we're on the same page: if you don't want to sum, or you want to keep the shape of the node where the log-prob is to be found, you should pass `aggregate_probabilities=False`. Here is an example:

```python
import torch
from torch import distributions as d
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution

params = TensorDict({
    "cont": {"loc": torch.randn(3, 4), "scale": torch.rand(3, 4)},
    "nested": TensorDict(disc={"logits": torch.randn(3, 4, 5)}, batch_size=[3, 4]),
}, batch_size=[3])

dist = CompositeDistribution(
    params,
    distribution_map={
        "cont": lambda loc, scale: d.Independent(d.Normal(loc, scale), 1),
        ("nested", "disc"): d.Categorical,
    },
)

# 10 samples
sample = dist.sample((10,))
print("sample", sample)

# Not aggregated
print("Not aggregated", dist.log_prob(sample, aggregate_probabilities=False))
print("Aggregated", dist.log_prob(sample, aggregate_probabilities=True).shape)
print("default", dist.log_prob(sample))
```

So we have a sample of batch size `[10, 3]` (the 10 sampled batches stacked on top of the root batch size `[3]`).
When we compute the log probs and don't aggregate, we have what I understand you guys want: the log-prob of each leaf distribution, written with the shape of its own node. Then we have the aggregate `sample_log_prob`, with the same shape as the root tensordict. That's the value you'll get when `aggregate_probabilities=True`. A sketch of the expected shapes follows below.
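For the example above, the shapes work out as follows. This is a sketch; the `*_log_prob` key names are my reading of how the composite distribution writes the per-leaf values, so treat them as assumptions:

```python
lp = dist.log_prob(sample, aggregate_probabilities=False)
# Per-leaf log-probs keep the shape of their own node:
print(lp["cont_log_prob"].shape)            # torch.Size([10, 3])    (Independent sums over the event dim)
print(lp["nested", "disc_log_prob"].shape)  # torch.Size([10, 3, 4]) (one value per categorical draw)

# The aggregate is reduced to the root batch shape, plus the sample dims:
agg = dist.log_prob(sample, aggregate_probabilities=True)
print(agg.shape)                            # torch.Size([10, 3])
```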
Is that clearer?
Closing this PR. It turns out that the issue was somewhere else, as pointed out by @priba. I have just one question remaining: does it make sense that `log_prob` and `entropy` return values of different shapes? As illustrated in the example above:

```python
>>> dist = policy_module.get_dist(td)
>>> dist.log_prob(td)
tensor(-16.9137, grad_fn=<AddBackward0>)
>>> dist.entropy(td)
tensor([8.2005, 8.2096], grad_fn=<AddBackward0>)
```

Wouldn't it also make sense for the entropy to return a single value?
Description
As far as I can tell, there's a bug in the latest version of the `CompositeDistribution` when calculating the log prob.

Motivation and Context
The log prob is erroneously flattened according to the sample's ndim; more specifically, according to the ndim of the batch shape at the root level of the sample tensordict.
However, the sample's ndim does not always match the batch shape of the distribution itself!
This is the case, for instance, in a multi-agent setup.
Let's take an environment with two grouped agents and a policy with two categorical heads as an example. The tensordict in this case looks something like the sketch below (edited for clarity).
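A minimal sketch of such a tensordict, assuming a single group key `"agents"` with two agents and five actions per categorical head (all names and sizes are illustrative):

```python
import torch
from tensordict import TensorDict

td = TensorDict(
    {
        "agents": TensorDict(
            {
                "logits": torch.randn(2, 5),       # batch shape (2,): one categorical head per agent
                "action": torch.randint(5, (2,)),  # one sampled action per agent
            },
            batch_size=[2],
        ),
    },
    batch_size=[],  # the root batch shape does not match the group's (2,)
)
```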
With the current code, the resulting `sample_log_prob` is a single float. This does not make sense, as our `logits` field has a batch shape of `(2,)`!

The entropy, on the other hand, is correctly computed, with one value per agent.

With the changes in this MR, the log prob is computed just like the entropy, resulting in the correct shape.
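Reusing the outputs quoted earlier in the thread, the mismatch looks like this (the post-fix line is a sketch of the expected shape, not a recorded run):

```python
>>> dist = policy_module.get_dist(td)
>>> dist.log_prob(td)   # current behaviour: a single float
tensor(-16.9137, grad_fn=<AddBackward0>)
>>> dist.entropy(td)    # one value per agent, shape (2,)
tensor([8.2005, 8.2096], grad_fn=<AddBackward0>)
>>> dist.log_prob(td)   # with this MR: same shape as the entropy
tensor([..., ...], grad_fn=<AddBackward0>)
```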