
[BugFix] computation of log prob in composite distribution for batched samples #1054

Conversation


@thomasbbrunner thomasbbrunner commented Oct 22, 2024

Description

As far as I can tell, there's a bug in the latest version of the CompositeDistribution when calculating the log prob.

Motivation and Context

The log prob is erroneously flattened according to the sample's ndim, i.e., the ndim of the batch shape at the root level of the sample tensordict.

However, the sample's root ndim does not always match the batch shape of the distribution itself!

This is the case, for instance, in a multi-agent setup.

Let's take an environment with two grouped agents and a policy with two categorical heads as an example. The tensordict in this case looks something like (edited for clarity):

>>> td = policy_module(env.reset())
TensorDict(
    fields={
        agents: TensorDict(
            fields={
                action: TensorDict(
                    fields={
                        head_0: TensorDict(
                            fields={
                                action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
                                action_log_prob: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([2]),),
                        head_1: TensorDict(
                            fields={
                                action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
                                action_log_prob: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([2]),)},
                    batch_size=torch.Size([2]),),
                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([2, 70]), device=cpu, dtype=torch.float32, is_shared=False),
                params: TensorDict(
                    fields={
                        head_0: TensorDict(
                            fields={
                                logits: Tensor(shape=torch.Size([2, 9]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([2]),),
                        head_1: TensorDict(
                            fields={
                                logits: Tensor(shape=torch.Size([2, 9]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([2]),),},
                    batch_size=torch.Size([2]),),
                terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([2]),),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),)

With the current code, the resulting sample_log_prob is:

>>> dist = policy_module.get_dist(td)
>>> dist.log_prob(td)
tensor(-16.9137, grad_fn=<AddBackward0>)

Note that it outputs a single float. This does not make sense, as our logits field has a batch shape of (2,)!

The entropy is correctly computed:

>>> dist.entropy(td)
tensor([8.2005, 8.2096], grad_fn=<AddBackward0>)

With the changes in this PR, the log prob is computed just like the entropy, resulting in the correct shape:

>>> dist = policy_module.get_dist(td)
>>> dist.log_prob(td)
tensor([-9.3534, -8.6486], grad_fn=<AddBackward0>)
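
The expected shapes can be illustrated with plain torch distributions (a minimal sketch of the two-head setup above, not the actual CompositeDistribution code; all variable names are illustrative):

```python
import torch
from torch import distributions as d

# Two categorical heads, each with batch shape (2,): one entry per agent.
head_0 = d.Categorical(logits=torch.randn(2, 9))
head_1 = d.Categorical(logits=torch.randn(2, 9))

action_0 = head_0.sample()  # shape (2,)
action_1 = head_1.sample()  # shape (2,)

# Summing the per-head log probs preserves the distributions' batch shape,
# matching what entropy (and the fix) produces.
lp = head_0.log_prob(action_0) + head_1.log_prob(action_1)
print(lp.shape)  # torch.Size([2])

# The buggy path additionally flattens over the sample's batch dims,
# collapsing the result to a single scalar.
print(lp.sum().shape)  # torch.Size([])
```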

Types of changes

  • Bug fix (non-breaking change which fixes an issue)

Checklist

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 22, 2024
@thomasbbrunner
Author

@albertbou92 tagging you, as this might be of interest.

@thomasbbrunner
Author

Some tests are failing, I'll address this asap!

@vmoens
Contributor

vmoens commented Oct 22, 2024

This will go into a minor release, thanks for flagging.

@vmoens vmoens changed the title Fix computation of log prob in composite distribution for batched samples. [BugFix] computation of log prob in composite distribution for batched samples Oct 22, 2024
Contributor

@vmoens vmoens left a comment
Hey,
So I reviewed these changes and I think the current behaviour is accurate, but I'm open to changing my mind.

  1. The "advised" way of using log_prob is going to be to write the tensors in the TD from now on. If that is done, they should have the proper shape.
  2. We can also return an aggregate, but we will require it to have the same shape as the root tensordict. I thought about making it less restrictive and allowing the lp to have the shape of the parent tensordict (e.g. ("group0", "agent1")). But then what happens if we sum the lps from ("group0", "agent0"), ("group0", "agent1"), ("group1", "agent0") and ("group1", "agent1"), where the first two have, say, shape [2] and the last two have shape [3]? In that case we cannot sum them; the only thing we can do is reduce the lp to a single float. Because the content of the tensordict should not condition the shape that you get, I think that having an lp with the shape of the sample tensordict (as is the case now) is the most robust behaviour.
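
The incompatible-shape scenario can be made concrete with plain tensors (shapes borrowed from the comment above; the variable names are illustrative):

```python
import torch

# Per-group log probs with incompatible batch shapes.
lp_group0 = torch.randn(2)  # e.g. two agents in group0
lp_group1 = torch.randn(3)  # e.g. three agents in group1

# An elementwise sum is impossible: shapes [2] and [3] do not broadcast.
try:
    lp_group0 + lp_group1
except RuntimeError as err:
    print("cannot sum:", err)

# The only shape both can be reduced to is the root batch shape, here ().
total = lp_group0.sum() + lp_group1.sum()
print(total.shape)  # torch.Size([])
```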

The main consideration I have is whether this is BC-breaking - did you encounter a problem when switching to 0.6?

Happy to hear your thoughts on this!

@louisfaury
Contributor

Hi Vincent,

Some thoughts after your comment. I think that, indeed, the main pain point is "where" the summed lp should be written. In the example above, I believe it should not be in the root, but in the agents tensordict.

Let's build on the concrete example you laid out. I do not believe it makes sense to sum log-probs across group_0 and group_1. When we split agents into groups, I think it's safe to assume we do so because their i/o shapes are incompatible -- hence the joint action of each group should be generated by a different distribution. In that case, it feels natural to have an aggregate log_probs entry under both the group_0 and group_1 tensordicts.

For the same example, I do not see the learning-theoretic usefulness of the scalar log-prob. If we think of training this multi-agent dynamical system with PPO, the scalar sample_log_prob aggregated over all the agents is useless for learning -- only the one aggregated at the group level is.

@vmoens vmoens added the bug Something isn't working label Oct 25, 2024
@vmoens
Contributor

vmoens commented Oct 25, 2024

Some thoughts after your comment. I think that, indeed, the main pain point is "where" the summed lp should be written. In the example above, I do believe it should not be in the root, but in the agents tensordict.

That's already the case, no?

d[_add_suffix(name, "_log_prob")] = lp = dist.log_prob(sample.get(name))

@louisfaury
Contributor

Maybe I misunderstood your comment; it sounded to me like you were arguing to sum this across agents and have it sit in the root -- which I doubt is the right way to go.

To put it differently: I personally agree with @thomasbbrunner that in this case the log-probs should keep the shape of the group tensordict (i.e. [2,]) instead of being reduced down to a scalar. Hence I think I disagree with the statement:

[..] the current behaviour is accurate

for the reasons detailed in my first post.

@vmoens
Contributor

vmoens commented Oct 26, 2024

Just to make sure we're on the same page:
There are two distinct behaviours for log_prob and entropy: aggregate_probabilities=True and aggregate_probabilities=False.
For the reasons I explained above, when aggregate_probabilities=True we must sum the log probs until they have the shape of the root tensordict, because that is the only shape that is guaranteed to be valid (otherwise there can be a case where two leaves have incompatible shapes and the sum cannot be done).

If you don't want to sum, or you want to keep the shape of the node where the log-prob is found, you should use aggregate_probabilities=False.

Here is an example:

import torch
from torch import distributions as d

from tensordict import TensorDict
from tensordict.nn import CompositeDistribution

params = TensorDict({
    "cont": {"loc": torch.randn(3, 4), "scale": torch.rand(3, 4)},
    "nested": TensorDict(disc={"logits": torch.randn(3, 4, 5)}, batch_size=[3, 4]),
}, batch_size=[3])

dist = CompositeDistribution(params,
                             distribution_map={"cont": lambda loc, scale: d.Independent(d.Normal(loc, scale), 1),
                                               ("nested", "disc"): d.Categorical})

# 10 samples
sample = dist.sample((10,))
print("sample", sample)

# Not aggregated
print("Not aggregated", dist.log_prob(sample, aggregate_probabilities=False))

print("Aggregated", dist.log_prob(sample, aggregate_probabilities=True).shape)

print("default", dist.log_prob(sample))

So we have a sample of batch size (10, 3):

sample TensorDict(
    fields={
        cont: Tensor(shape=torch.Size([10, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                disc: Tensor(shape=torch.Size([10, 3, 4]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([10, 3]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10, 3]),
    device=None,
    is_shared=False)

When we compute the log probs and don't aggregate, we have what I understand you guys want:

Not aggregated TensorDict(
    fields={
        cont: Tensor(shape=torch.Size([10, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        cont_log_prob: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        nested: TensorDict(
            fields={
                disc: Tensor(shape=torch.Size([10, 3, 4]), device=cpu, dtype=torch.int64, is_shared=False),
                disc_log_prob: Tensor(shape=torch.Size([10, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([10, 3]),
            device=None,
            is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10, 3]),
    device=None,
    is_shared=False)

We have cont_log_prob, which has the size of the normal distribution (where we sum the last dim because of the Independent).
We have disc_log_prob, which has size (10, 3, 4), where the 10 comes from the number of samples and (3, 4) is the batch size of our categorical distribution.

Then we have the aggregate sample_log_prob with the same shape as the root tensordict. That's the value you'll get when aggregate_probabilities=True:

Aggregated torch.Size([10, 3])
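
That aggregation can be reproduced by hand (a sketch assuming the shapes printed above; root_ndim and the variable names are illustrative, not the library's internals):

```python
import torch

root_ndim = 2  # the root tensordict has batch size (10, 3)
cont_log_prob = torch.randn(10, 3)
disc_log_prob = torch.randn(10, 3, 4)

# Leaves carrying extra dims are summed down to the root batch shape
# before the per-leaf log probs are added together.
extra_dims = tuple(range(root_ndim, disc_log_prob.ndim))
sample_log_prob = cont_log_prob + disc_log_prob.sum(dim=extra_dims)
print(sample_log_prob.shape)  # torch.Size([10, 3])
```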

Is that clearer?

@priba
Contributor

priba commented Oct 30, 2024

Closely related to #1065: I agree with @vmoens's comment here. I think the shape you report comes from the ProbabilisticTensorDictModule.

@thomasbbrunner
Author

Closing this PR. It turns out that the issue was somewhere else, as pointed out by @priba.

I have just one question remaining: does it make sense that the log_prob and entropy have different shapes?

As illustrated in the example above:

>>> dist = policy_module.get_dist(td)
>>> dist.log_prob(td)
tensor(-16.9137, grad_fn=<AddBackward0>)
>>> dist.entropy(td)
tensor([8.2005, 8.2096], grad_fn=<AddBackward0>)

Wouldn't it also make sense for the entropy to return a single value?

@thomasbbrunner thomasbbrunner deleted the tbrunner/fix-composite-distribution branch November 4, 2024 15:07