Dynamo: how to deal with multiple inheritance (nn.Module/MutableMapping)? #141118

Closed
vmoens opened this issue Nov 20, 2024 · 1 comment
Labels
dynamo-dicts dynamo-nn-modules module: dynamo oncall: pt2 triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module


vmoens commented Nov 20, 2024

🐛 Describe the bug

TensorDict is a MutableMapping object, and is treated as such by torch.compile:

import torch
from tensordict import TensorDict

td = TensorDict(a=1, b=2, c=True)

@torch.compile(fullgraph=True)
def add1(td):
    return TensorDict(**td)+1

add1(td)

We also have a TensorDictParams primitive that acts a bit like ParameterList: it is a TensorDict but also an nn.Module. That's useful when you want to store a TensorDict in an nn.Module and have its leaf tensors included in the state_dict, or to dispatch ops like module.to(...) to the tensors it contains. However, _dynamo treats it as an nn.Module rather than as a MutableMapping:

import torch
from tensordict import TensorDictParams, TensorDict

td = TensorDictParams(TensorDict(a=1, b=2, c=True))

@torch.compile(fullgraph=True)
def add1(td):
    return TensorDict(**td)+1

add1(td)

breaks with

  File "/Users/vmoens/venv/rl/lib/python3.10/site-packages/torch/_dynamo/variables/dicts.py", line 357, in call_method
    dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0])
  File "/Users/vmoens/venv/rl/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 1432, in call_custom_dict
    unimplemented(f"{user_cls.__name__}(): {args} {kwargs}")
  File "/Users/vmoens/venv/rl/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 313, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: dict(): (UnspecializedNNModuleVariable(TensorDictParams),) {}

My understanding is that call_custom_dict inspects its argument: in the first case it is a variables.MutableMappingVariable, which is supported, but in the second it is an UnspecializedNNModuleVariable, which isn't treated as a mutable mapping.
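To make the classification problem concrete, here is a toy, torch-free sketch. It is not Dynamo's actual logic; `FakeModule`, `classify_first_match` and `classify_all` are hypothetical names introduced for illustration. It shows that a class inheriting from both `MutableMapping` and a module base passes both `isinstance` checks, so a dispatcher that stops at the first matching type only ever sees one of the object's roles:

```python
from collections.abc import Mapping, MutableMapping


class FakeModule:
    """Hypothetical stand-in for nn.Module, so the sketch needs no torch."""


class WeirdDict(MutableMapping, FakeModule):
    """A 'fish that can fly': both a MutableMapping and a (fake) module."""

    def __init__(self, **kwargs):
        self._items = dict(kwargs)

    def __getitem__(self, key):
        return self._items[key]

    def __setitem__(self, key, value):
        self._items[key] = value

    def __delitem__(self, key):
        del self._items[key]

    def __iter__(self):
        return iter(self._items)

    def __len__(self):
        return len(self._items)


def classify_first_match(obj):
    # Hypothetical dispatcher: checks module-ness before mapping-ness,
    # and the first matching check wins.
    if isinstance(obj, FakeModule):
        return "module"
    if isinstance(obj, Mapping):
        return "mapping"
    return "other"


def classify_all(obj):
    # Hypothetical alternative: record every role the object plays, so the
    # caller can pick the protocol it needs (e.g. the mapping protocol).
    roles = set()
    if isinstance(obj, FakeModule):
        roles.add("module")
    if isinstance(obj, Mapping):
        roles.add("mapping")
    return roles


w = WeirdDict(a=1, b=2)
print(classify_first_match(w))  # module
print(sorted(classify_all(w)))  # ['mapping', 'module']
print(dict(**w))                # {'a': 1, 'b': 2}
```

Eager Python has no trouble with `dict(**w)`; the failure in the report only appears when the tracer has already committed to the "module" role.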

So my question, besides how to fix this, is: how does dynamo handle multiple inheritance? Shouldn't there be a way to say "look, this isn't a bird or a fish, it's a fish that can fly"?

(Note that in this specific case, something(**obj) calls obj.keys() followed by obj.__getitem__ for each key, and compile is happy with both of those ops. Maybe that's what call_custom_dict should be doing?)
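The claim that `**obj` only goes through `keys()` and `__getitem__` can be checked in plain Python. The sketch below records every protocol method CPython calls while unpacking a non-dict mapping into keyword arguments; `TracingMapping` and `collect` are illustrative names, not part of tensordict or torch:

```python
from collections.abc import Mapping


class TracingMapping(Mapping):
    """A read-only mapping that logs which protocol methods get called."""

    def __init__(self, **kwargs):
        self._items = dict(kwargs)
        self.calls = []

    def keys(self):
        self.calls.append("keys")
        return self._items.keys()

    def __getitem__(self, key):
        self.calls.append(f"__getitem__({key!r})")
        return self._items[key]

    def __iter__(self):
        self.calls.append("__iter__")
        return iter(self._items)

    def __len__(self):
        return len(self._items)


def collect(**kwargs):
    # Returns the keyword arguments it received as a plain dict.
    return kwargs


m = TracingMapping(a=1, b=2)
result = collect(**m)
print(result)   # {'a': 1, 'b': 2}
print(m.calls)  # ['keys', "__getitem__('a')", "__getitem__('b')"]
```

`__iter__` is never called: CPython asks the object for its keys once, then fetches each value by key. A tracer that can handle those two calls on an object therefore has, in principle, everything it needs for `f(**obj)`, whatever else the object inherits from.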

Here is a minimal reproducible example (MRE):

import torch
from torch import nn
import collections.abc  # collections.abc is a submodule; import it explicitly

# class MyWeirdDict(collections.abc.MutableMapping):  # Works
class MyWeirdDict(collections.abc.MutableMapping, nn.Module):  # breaks
    def __init__(self, **kwargs):
        super().__init__()
        self._items = kwargs
    def keys(self):
        return self._items.keys()
    def __getitem__(self, item):
        return self._items[item]
    def __setitem__(self, key, value):
        self._items[key] = value
    def __delitem__(self, item):
        del self._items[item]
    def __len__(self):
        return len(self._items)
    def __iter__(self):
        yield from self._items
    def __hash__(self):
        return hash(id(self))
    def items(self):
        for k, v in self._items.items():
            yield (k, v)

@torch.compile(fullgraph=True)
def to_weird_dict(td):
    return MyWeirdDict(**td)

d = MyWeirdDict(a=1, b=2, c=3)
to_weird_dict(d)

Error logs

See above

Versions

nightlies

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames

vmoens changed the title from "how to deal with multiple inheritance (nn.Module/MutableMapping)?" to "Dynamo: how to deal with multiple inheritance (nn.Module/MutableMapping)?" Nov 20, 2024
yf225 added the triaged label Nov 22, 2024
StrongerXi self-assigned this Dec 4, 2024
StrongerXi added a commit that referenced this issue Dec 9, 2024
This patch applies a local and practical workaround for custom dict
construction when multiple inheritance is involved.

Handling multiple inheritance in general could be a lot more involved,
so I created #142414 to track that.

Fixes #141118.

ghstack-source-id: 796fed17826b59a5bef3db849385f5bda86f8c0b
Pull Request resolved: #142416
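The commit message only describes this as a local workaround, and the patch itself is not reproduced here. As a hedged illustration of the fallback suggested in the issue (build the dict through `keys()` and `__getitem__` whenever the argument is mapping-like), here is a plain-Python sketch. `materialize_mapping`, `construct_custom_dict`, `OtherBase` and `Both` are hypothetical names, not Dynamo functions, and this is not necessarily what #142416 does:

```python
from collections.abc import Mapping


def materialize_mapping(obj):
    """Hypothetical helper: copy a mapping-like object into a plain dict
    using only keys() and __getitem__, the protocol f(**obj) relies on."""
    if type(obj) is dict:
        return dict(obj)  # exact dicts need no protocol calls
    if not hasattr(obj, "keys"):
        raise TypeError(f"{type(obj).__name__!r} object is not a mapping")
    return {key: obj[key] for key in obj.keys()}


def construct_custom_dict(user_cls, obj):
    """Hypothetical: build user_cls(**obj) by looking only at obj's mapping
    protocol, ignoring whatever other base classes obj has."""
    return user_cls(**materialize_mapping(obj))


class OtherBase:
    """Stand-in for an unrelated second base class such as nn.Module."""


class Both(Mapping, OtherBase):
    """A read-only mapping that also inherits from an unrelated base."""

    def __init__(self, **kwargs):
        self._items = dict(kwargs)

    def __getitem__(self, key):
        return self._items[key]

    def __iter__(self):
        return iter(self._items)

    def __len__(self):
        return len(self._items)


print(construct_custom_dict(dict, Both(a=1, b=2)))  # {'a': 1, 'b': 2}
```

The design choice is to dispatch on the protocol the call actually needs rather than on the object's most specific class, which sidesteps the question of whether the object is "really" a module or a mapping.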

StrongerXi commented Dec 9, 2024

So I guess my question is (other than how can we fix this) how does dynamo look at multiple inheritance? Shouldn't there be a way to tell "look, this isn't a bird or a fish but a fish that can fly"?

@vmoens It doesn't seem to be well supported in general, but I think we need more data points to evaluate it, so I created #142414 to track that.

pytorchmergebot pushed a commit that referenced this issue Dec 13, 2024
…142416)

This patch applies a local and practical workaround for custom dict
construction when multiple inheritance is involved.

Handling multiple inheritance in general could be a lot more involved,
so I created #142414 to track that.

Fixes #141118.

Pull Request resolved: #142416
Approved by: https://github.com/jansel