Dynamo: how to deal with multiple inheritance (nn.Module/MutableMapping)? #141118
Labels: dynamo-dicts, dynamo-nn-modules, module: dynamo, oncall: pt2, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Comments
vmoens changed the title from "how to deal with multiple inheritance (nn.Module/MutableMapping)?" to "Dynamo: how to deal with multiple inheritance (nn.Module/MutableMapping)?" on Nov 20, 2024
yf225 added the triaged label on Nov 22, 2024
This was referenced Dec 9, 2024
StrongerXi added a commit that referenced this issue on Dec 9, 2024:
This patch applies a local and practical workaround for custom dict construction when multiple inheritance is involved. Handling multiple inheritance in general could be a lot more involved, so I created #142414 to track that. Fixes #141118. ghstack-source-id: 796fed17826b59a5bef3db849385f5bda86f8c0b Pull Request resolved: #142416
@vmoens It doesn't seem to be well supported in general, but I think we need more datapoints to evaluate, so I created #142414 to track that.
mori360 pushed a commit to mori360/pytorch that referenced this issue on Dec 11, 2024:
…ytorch#142416) This patch applies a local and practical workaround for custom dict construction when multiple inheritance is involved. Handling multiple inheritance in general could be a lot more involved, so I created pytorch#142414 to track that. Fixes pytorch#141118. Pull Request resolved: pytorch#142416 Approved by: https://github.com/jansel
StrongerXi added a commit that referenced this issue on Dec 11, 2024:
This patch applies a local and practical workaround for custom dict construction when multiple inheritance is involved. Handling multiple inheritance in general could be a lot more involved, so I created #142414 to track that. Fixes #141118. ghstack-source-id: c1c788bcb2cd91156ebab37b7f0ade0315e4c70f Pull Request resolved: #142416
StrongerXi added a commit that referenced this issue on Dec 11, 2024:
This patch applies a local and practical workaround for custom dict construction when multiple inheritance is involved. Handling multiple inheritance in general could be a lot more involved, so I created #142414 to track that. Fixes #141118. ghstack-source-id: 938c7732b433c9869efc070479f9cd0ba6e91637 Pull Request resolved: #142416
StrongerXi added a commit that referenced this issue on Dec 12, 2024:
This patch applies a local and practical workaround for custom dict construction when multiple inheritance is involved. Handling multiple inheritance in general could be a lot more involved, so I created #142414 to track that. Fixes #141118. ghstack-source-id: da6e5eafbb97c9a641909bb2253ac74224d17f1f Pull Request resolved: #142416
pytorchmergebot pushed a commit that referenced this issue on Dec 13, 2024:
…142416) This patch applies a local and practical workaround for custom dict construction when multiple inheritance is involved. Handling multiple inheritance in general could be a lot more involved, so I created #142414 to track that. Fixes #141118. Pull Request resolved: #142416 Approved by: https://github.com/jansel
bluenote10 pushed a commit to bluenote10/pytorch that referenced this issue on Dec 14, 2024:
…ytorch#142416) This patch applies a local and practical workaround for custom dict construction when multiple inheritance is involved. Handling multiple inheritance in general could be a lot more involved, so I created pytorch#142414 to track that. Fixes pytorch#141118. Pull Request resolved: pytorch#142416 Approved by: https://github.com/jansel
🐛 Describe the bug
TensorDict is a MutableMapping object, and is treated as such by torch.compile:
We also have a TensorDictParams primitive that acts a bit like ParameterList: it is a TensorDict but also an nn.Module. That's useful when you want to set a TensorDict in an nn.Module and have the leaf tensors included in the state_dict, or dispatch ops like module.to(...) to the tensors it contains. However, _dynamo looks at it as an nn.Module and not as a MutableMapping, and breaks with:
My understanding is that call_custom_dict looks at the arg: in one case it's a variables.MutableMappingVariable, which is fine, but in the other it's an UnspecializedNNModuleVariable, which isn't a mutable mapping. So I guess my question is (other than how we can fix this): how does dynamo look at multiple inheritance? Shouldn't there be a way to say "look, this isn't a bird or a fish but a fish that can fly"?
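From plain Python's point of view there is no "bird or fish" ambiguity: an instance of a class with both bases passes both isinstance checks, and both bases appear in its MRO. Which behaviour a consumer picks depends on the order of its checks. The sketch below uses hypothetical stand-ins (Module in place of nn.Module, Params in place of TensorDictParams, and two toy classifiers); it does not reproduce Dynamo's actual variable-tracker dispatch and needs no torch.

```python
from collections.abc import MutableMapping


class Module:
    """Hypothetical stand-in for torch.nn.Module (no torch required)."""


class Params(Module, MutableMapping):
    """Hypothetical stand-in for TensorDictParams: both a Module and a mapping."""

    def __init__(self, **data):
        self._data = dict(data)

    def __getitem__(self, key):
        return self._data[key]

    def __setitem__(self, key, value):
        self._data[key] = value

    def __delitem__(self, key):
        del self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)


def classify_first_match(obj):
    # Toy dispatcher: the first isinstance check that succeeds wins,
    # so the order of the checks decides how a dual-typed object is treated.
    if isinstance(obj, Module):
        return "module"
    if isinstance(obj, MutableMapping):
        return "mapping"
    return "other"


def classify_all(obj):
    # Alternative toy: report every role the object can play.
    roles = []
    if isinstance(obj, Module):
        roles.append("module")
    if isinstance(obj, MutableMapping):
        roles.append("mapping")
    return roles


p = Params(a=1, b=2)
print([cls.__name__ for cls in type(p).__mro__])  # both bases are listed
print(classify_first_match(p))  # module
print(classify_all(p))          # ['module', 'mapping']
print(dict(**p))                # {'a': 1, 'b': 2}
```

A first-match dispatcher silently drops the mapping role, which mirrors the symptom described in this issue; checking every role (or the more specific protocol first) is one possible design, not necessarily the one Dynamo adopted.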
(Note that in this specific case, smth(**obj) will call obj.keys() followed by obj.__getitem__, which are ops that compile is happy about. Maybe that's what call_custom_dict should be doing?)
Here is an MRE:
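Separately from that reproduction, the unpacking behaviour from the note can be checked in plain Python, outside torch.compile: when the argument of ** is not a plain dict, CPython calls its keys() method and then __getitem__ once per key. A minimal sketch, assuming a hypothetical LoggingMapping that records those calls (it does not involve TensorDict or Dynamo):

```python
from collections.abc import MutableMapping


class LoggingMapping(MutableMapping):
    """Hypothetical mapping that records which protocol methods ** unpacking uses."""

    def __init__(self, data):
        self._data = dict(data)
        self.calls = []

    def keys(self):
        self.calls.append("keys")
        return super().keys()  # KeysView backed by __iter__ / __len__

    def __getitem__(self, key):
        self.calls.append(("getitem", key))
        return self._data[key]

    def __setitem__(self, key, value):
        self._data[key] = value

    def __delitem__(self, key):
        del self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)


def smth(**kwargs):
    return kwargs


obj = LoggingMapping({"a": 1, "b": 2})
# Observed in CPython: non-dict ** unpacking calls keys(), then __getitem__ per key.
result = smth(**obj)
print(result)     # {'a': 1, 'b': 2}
print(obj.calls)  # ['keys', ('getitem', 'a'), ('getitem', 'b')]
```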
Error logs
See above
Versions
nightlies
cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames