Skip to content

Commit

Permalink
Bump PyTensor to 2.9.1 (#6431)
Browse files Browse the repository at this point in the history
* Bump PyTensor to 2.9.1

* Allow installation from `.conda` artifacts

* Workaround type issues in `shape_utils`

Caused by pymc-devs/pytensor#193
  • Loading branch information
michaelosthege authored Jan 12, 2023
1 parent ecb3666 commit c3b8ff4
Show file tree
Hide file tree
Showing 9 changed files with 22 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ jobs:
environment-file: conda-envs/environment-test.yml
python-version: 3.9
use-mamba: true
use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly!
use-only-tar-bz2: false # IMPORTANT: This may break caching of conda packages! See https://github.com/conda-incubator/setup-miniconda/issues/267
- name: Install-pymc and mypy dependencies
run: |
conda activate pymc-test
Expand Down
10 changes: 5 additions & 5 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ jobs:
environment-file: conda-envs/environment-test.yml
python-version: ${{matrix.python-version}}
use-mamba: true
use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly!
use-only-tar-bz2: false # IMPORTANT: This may break caching of conda packages! See https://github.com/conda-incubator/setup-miniconda/issues/267
- name: Install-pymc
run: |
conda activate pymc-test
Expand Down Expand Up @@ -211,7 +211,7 @@ jobs:
environment-file: conda-envs/windows-environment-test.yml
python-version: ${{matrix.python-version}}
use-mamba: true
use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly!
use-only-tar-bz2: false # IMPORTANT: This may break caching of conda packages! See https://github.com/conda-incubator/setup-miniconda/issues/267
- name: Install-pymc
run: |
conda activate pymc-test
Expand Down Expand Up @@ -290,7 +290,7 @@ jobs:
environment-file: conda-envs/environment-test.yml
python-version: ${{matrix.python-version}}
use-mamba: true
use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly!
use-only-tar-bz2: false # IMPORTANT: This may break caching of conda packages! See https://github.com/conda-incubator/setup-miniconda/issues/267
- name: Install pymc
run: |
conda activate pymc-test
Expand Down Expand Up @@ -355,7 +355,7 @@ jobs:
environment-file: conda-envs/environment-test.yml
python-version: ${{matrix.python-version}}
use-mamba: true
use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly!
use-only-tar-bz2: false # IMPORTANT: This may break caching of conda packages! See https://github.com/conda-incubator/setup-miniconda/issues/267
- name: Install pymc
run: |
conda activate pymc-test
Expand Down Expand Up @@ -425,7 +425,7 @@ jobs:
environment-file: conda-envs/windows-environment-test.yml
python-version: ${{matrix.python-version}}
use-mamba: true
use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly!
use-only-tar-bz2: false # IMPORTANT: This may break caching of conda packages! See https://github.com/conda-incubator/setup-miniconda/issues/267
- name: Install-pymc
run: |
conda activate pymc-test
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor=2.8.11
- pytensor=2.9.1
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor=2.8.11
- pytensor=2.9.1
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor=2.8.11
- pytensor=2.9.1
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor=2.8.11
- pytensor=2.9.1
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
17 changes: 10 additions & 7 deletions pymc/distributions/shape_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,14 +740,15 @@ def get_support_shape(
observed.shape[i] - support_shape_offset[i] for i in np.arange(-ndim_supp, 0)
]

# We did not learn anything
if inferred_support_shape is None and support_shape is None:
return None
# Only source of information was the originally provided support_shape
elif inferred_support_shape is None:
inferred_support_shape = support_shape
# There were two sources of support_shape, make sure they are consistent
if inferred_support_shape is None:
if support_shape is not None:
# Only source of information was the originally provided support_shape
inferred_support_shape = support_shape
else:
# We did not learn anything
return None
elif support_shape is not None:
# There were two sources of support_shape, make sure they are consistent
inferred_support_shape = [
cast(
Variable,
Expand All @@ -758,6 +759,8 @@ def get_support_shape(
for inferred, explicit in zip(inferred_support_shape, support_shape)
]

# Workaround https://github.com/pymc-devs/pytensor/issues/193 typing bug in stack signature
inferred_support_shape = cast(Sequence[TensorVariable], inferred_support_shape)
return at.stack(inferred_support_shape)


Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ numpydoc
pandas>=0.24.0
polyagamma
pre-commit>=2.8.0
pytensor==2.8.11
pytensor==2.9.1
pytest-cov>=2.5
pytest>=3.0
scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ cloudpickle
fastprogress>=0.2.0
numpy>=1.15.0
pandas>=0.24.0
pytensor==2.8.11
pytensor==2.9.1
scipy>=1.4.1
typing-extensions>=3.7.4

0 comments on commit c3b8ff4

Please sign in to comment.