From cacea5aacedecf7537040e9cc2d9a75eb538c25d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Rozet?= Date: Sat, 15 Jul 2023 16:45:37 +0200 Subject: [PATCH 1/5] Revert to support torch>=1.11.0 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index cc12a20cf4..af8f385809 100644 --- a/setup.py +++ b/setup.py @@ -102,7 +102,7 @@ "numpy>=1.7", "opt_einsum>=2.3.2", "pyro-api>=0.1.1", - "torch>=2.0.1", + "torch>=1.11.0", "tqdm>=4.36", ], extras_require={ From 48493f37cc4030ca0a8de6e52cfdbf24d509c886 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Rozet?= Date: Sat, 15 Jul 2023 19:02:18 +0200 Subject: [PATCH 2/5] Switch to latest torch for whole CI --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b86d5e4ca6..56c61eb5f5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -54,7 +54,7 @@ jobs: python -m pip install --upgrade pip wheel 'setuptools!=58.5.*,<60' # Keep track of pyro-api master branch pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip - pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu pip install .[test] pip install -r docs/requirements.txt pip freeze @@ -82,7 +82,7 @@ jobs: python -m pip install --upgrade pip wheel 'setuptools!=58.5.*,<60' # Keep track of pyro-api master branch pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip - pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu pip install .[test] pip install -r docs/requirements.txt # requirements for tutorials (from .[dev]) From 654bb380dd89feaaf23e76a6b57df09548557250 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Rozet?= Date: Sat, 15 Jul 2023 20:13:28 +0200 Subject: [PATCH 3/5] Add pytest-xdist to test requirements --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index af8f385809..19b3bb9d03 100644 --- a/setup.py +++ b/setup.py @@ -112,6 +112,7 @@ "black>=21.4b0", "nbval", "pytest-cov", + "pytest-xdist", "pytest>=5.0", "ruff", "scipy>=1.1", From 681610f21ab8d3d86e134318f41d8707ff4122d7 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 19 Jul 2023 18:57:15 +0200 Subject: [PATCH 4/5] Fix infinite recursion --- pyro/ops/provenance.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pyro/ops/provenance.py b/pyro/ops/provenance.py index a6902a60cd..d1fca068c8 100644 --- a/pyro/ops/provenance.py +++ b/pyro/ops/provenance.py @@ -93,6 +93,11 @@ def _track_provenance_set(x, provenance: frozenset): @track_provenance.register(tuple) @track_provenance.register(dict) def _track_provenance_pytree(x, provenance: frozenset): + # avoid max-recursion depth error for torch<=2.0 + flat_args, _ = tree_flatten(x) + if flat_args[0] is x: + return x + return tree_map(partial(track_provenance, provenance=provenance), x) @@ -138,6 +143,11 @@ def _extract_provenance_set(x): @extract_provenance.register(tuple) @extract_provenance.register(dict) def _extract_provenance_pytree(x): + # avoid max-recursion depth error for torch<=2.0 + flat_args, _ = tree_flatten(x) + if flat_args[0] is x: + return x, frozenset() + flat_args, spec = tree_flatten(x) xs = [] provenance = frozenset() From bb7bdf223496d21f46aadd55317c82582d39d29a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Rozet?= Date: Thu, 20 Jul 2023 11:06:44 +0200 Subject: [PATCH 5/5] Fix empty list indexing --- pyro/ops/provenance.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyro/ops/provenance.py b/pyro/ops/provenance.py index d1fca068c8..cde77aa45d 100644 --- a/pyro/ops/provenance.py +++ b/pyro/ops/provenance.py @@ -95,7 +95,7 @@ def _track_provenance_set(x, provenance: frozenset): def _track_provenance_pytree(x, provenance: frozenset): # avoid max-recursion depth error for torch<=2.0 flat_args, _ = tree_flatten(x) - if flat_args[0] is x: + if not flat_args or flat_args[0] is x: return x return tree_map(partial(track_provenance, provenance=provenance), x) @@ -145,7 +145,7 @@ def _extract_provenance_set(x): def _extract_provenance_pytree(x): # avoid max-recursion depth error for torch<=2.0 flat_args, _ = tree_flatten(x) - if flat_args[0] is x: + if not flat_args or flat_args[0] is x: return x, frozenset() flat_args, spec = tree_flatten(x)