-
-
Notifications
You must be signed in to change notification settings - Fork 985
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Revert to support torch>=1.11.0
#3242
Conversation
Thanks, LGTM as long as you can confirm tests still pass locally using torch==1.11.0 |
I will now run (
|
There is an issue with the Makefile' s
@fritzo do you know what the |
@francois-rozet you can enable the pytest -n flag via |
There were still a few dependencies missing (
With torch 2.0.1, it leads to
With torch < 2.0.1 (1.11.0, 1.12.1 and 1.13.0), it leads to
To summarize, there seems to be an infinite recursion error within If necessary, an |
@ordabayevy do you recall the motivation behind #3223? Was it a fix to support PyTorch 2, or merely a cleanup made possible by PyTorch 2? If it's merely backwards-incompatible refactoring, we might be better off reverting it so as to continue supporting PyTorch 1.11. 🤔 |
Hi @fritzo . It was a cleanup made possible by PyTorch 2.0. It can be reverted if necessary. |
Is there a way to make it compatible with 1.11 instead of fully reverting it or maybe add a condition |
It seems the issue is that in # avoid max-recursion depth error
flat_args, _ = tree_flatten(x)
if flat_args[0] is x:
return x |
@francois-rozet can I push my proposed changes to this branch? I'm running |
Sure, i'll run them on my computer when I get back home! |
I'm not sure how to push to your branch :) The idea behind this code is that in an older version of PyTorch I can confirm that all tests in diff --git a/pyro/ops/provenance.py b/pyro/ops/provenance.py
index a6902a60..6ca6c0c6 100644
--- a/pyro/ops/provenance.py
+++ b/pyro/ops/provenance.py
@@ -93,6 +93,10 @@ def _track_provenance_set(x, provenance: frozenset):
@track_provenance.register(tuple)
@track_provenance.register(dict)
def _track_provenance_pytree(x, provenance: frozenset):
+ # avoid max-recursion depth error for torch<=2.0
+ flat_args, _ = tree_flatten(x)
+ if flat_args[0] is x:
+ return x
return tree_map(partial(track_provenance, provenance=provenance), x)
@@ -138,6 +142,10 @@ def _extract_provenance_set(x):
@extract_provenance.register(tuple)
@extract_provenance.register(dict)
def _extract_provenance_pytree(x):
+ # avoid max-recursion depth error for torch<=2.0
+ flat_args, _ = tree_flatten(x)
+ if flat_args[0] is x:
+ return x, frozenset()
flat_args, spec = tree_flatten(x)
xs = []
provenance = frozenset() |
With the fix of @ordabayevy, all the tests (
|
Unfortunately @ordabayevy, this fix breaks the tests for torch 2.0 😢 I think this arises when |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. I think it can be merged pending any further comments.
flat_args, _ = tree_flatten(x) | ||
if not flat_args or flat_args[0] is x: | ||
return x | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a heuristic solution but we can always patch bug fixes if any bugs come up in the future. So I am happy with it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Glad to have this working, thanks @francois-rozet and @ordabayevy!
If nobody objects, I'd like to release a Pyro 1.8.6 as soon as this merges, treating the semver breakage issue of 1.8.5 as a bug. (I may try to get #3243 in to the release as well.)
See #3239