Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert to support torch>=1.11.0 #3242

Merged
merged 5 commits into from
Jul 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ jobs:
python -m pip install --upgrade pip wheel 'setuptools!=58.5.*,<60'
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install .[test]
pip install -r docs/requirements.txt
pip freeze
Expand Down Expand Up @@ -82,7 +82,7 @@ jobs:
python -m pip install --upgrade pip wheel 'setuptools!=58.5.*,<60'
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install .[test]
pip install -r docs/requirements.txt
# requirements for tutorials (from .[dev])
Expand Down
10 changes: 10 additions & 0 deletions pyro/ops/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ def _track_provenance_set(x, provenance: frozenset):
@track_provenance.register(tuple)
@track_provenance.register(dict)
def _track_provenance_pytree(x, provenance: frozenset):
# avoid max-recursion depth error for torch<=2.0
flat_args, _ = tree_flatten(x)
if not flat_args or flat_args[0] is x:
return x

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a heuristic solution but we can always patch bug fixes if any bugs come up in the future. So I am happy with it.

return tree_map(partial(track_provenance, provenance=provenance), x)


Expand Down Expand Up @@ -138,6 +143,11 @@ def _extract_provenance_set(x):
@extract_provenance.register(tuple)
@extract_provenance.register(dict)
def _extract_provenance_pytree(x):
# avoid max-recursion depth error for torch<=2.0
flat_args, _ = tree_flatten(x)
if not flat_args or flat_args[0] is x:
return x, frozenset()

flat_args, spec = tree_flatten(x)
xs = []
provenance = frozenset()
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
"numpy>=1.7",
"opt_einsum>=2.3.2",
"pyro-api>=0.1.1",
"torch>=2.0.1",
"torch>=1.11.0",
"tqdm>=4.36",
],
extras_require={
Expand All @@ -112,6 +112,7 @@
"black>=21.4b0",
"nbval",
"pytest-cov",
"pytest-xdist",
"pytest>=5.0",
"ruff",
"scipy>=1.1",
Expand Down