Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to PyTorch 1.11.0 #3045

Merged
merged 16 commits into from
Mar 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
runs-on: ubuntu-20.04
strategy:
matrix:
python-version: [3.6]
python-version: [3.7]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -39,7 +39,7 @@ jobs:
needs: lint
strategy:
matrix:
python-version: [3.6]
python-version: [3.7]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -54,7 +54,7 @@ jobs:
python -m pip install --upgrade pip wheel 'setuptools!=58.5.*'
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install .[test]
pip install -r docs/requirements.txt
pip freeze
Expand All @@ -67,7 +67,7 @@ jobs:
needs: docs
strategy:
matrix:
python-version: [3.6]
python-version: [3.7]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -82,7 +82,7 @@ jobs:
python -m pip install --upgrade pip wheel 'setuptools!=58.5.*'
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install .[test]
pip install --upgrade coveralls
pip freeze
Expand All @@ -99,7 +99,7 @@ jobs:
needs: docs
strategy:
matrix:
python-version: [3.6]
python-version: [3.7]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -116,7 +116,7 @@ jobs:
python -m pip install --upgrade pip wheel 'setuptools!=58.5.*'
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install .[test]
pip install --upgrade coveralls
pip freeze
Expand All @@ -135,7 +135,7 @@ jobs:
needs: docs
strategy:
matrix:
python-version: [3.6]
python-version: [3.7]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -150,7 +150,7 @@ jobs:
python -m pip install --upgrade pip wheel 'setuptools!=58.5.*'
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install .[test]
pip install --upgrade coveralls
pip freeze
Expand All @@ -167,7 +167,7 @@ jobs:
needs: docs
strategy:
matrix:
python-version: [3.6]
python-version: [3.7]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -182,7 +182,7 @@ jobs:
python -m pip install --upgrade pip wheel 'setuptools!=58.5.*'
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install .[test]
pip install --upgrade coveralls
pip freeze
Expand All @@ -199,7 +199,7 @@ jobs:
needs: docs
strategy:
matrix:
python-version: [3.6]
python-version: [3.7]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -214,7 +214,7 @@ jobs:
python -m pip install --upgrade pip wheel 'setuptools!=58.5.*'
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install .[test]
pip install -e .[funsor]
pip install --upgrade coveralls
Expand All @@ -233,7 +233,7 @@ jobs:
runs-on: ubuntu-20.04
strategy:
matrix:
python-version: [3.6]
python-version: [3.7]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand Down
2 changes: 1 addition & 1 deletion .readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ formats: all

# Optionally set the version of Python and requirements required to build your docs
python:
version: 3.6
version: 3.7
install:
- requirements: docs/requirements.txt
4 changes: 1 addition & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@ lint: FORCE
black --check *.py pyro examples tests scripts profiler
isort --check .
python scripts/update_headers.py --check
mypy pyro
# mypy examples # FIXME
mypy scripts
mypy --install-types --non-interactive pyro scripts

license: FORCE
python scripts/update_headers.py
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,6 @@ def setup(app):
if "READTHEDOCS" in os.environ:
os.system("pip install numpy")
os.system(
"pip install torch==1.9.0+cpu torchvision==0.10.0+cpu "
"pip install torch==1.11.0+cpu torchvision==0.12.0+cpu "
"-f https://download.pytorch.org/whl/torch_stable.html"
)
1 change: 1 addition & 0 deletions examples/baseball.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def get_summary_table(
if site_summary["mean"].shape:
site_df = pd.DataFrame(site_summary, index=player_names)
else:
site_summary = {k: float(v) for k, v in site_summary.items()}
site_df = pd.DataFrame(site_summary, index=[0])
if not diagnostics:
site_df = site_df.drop(["n_eff", "r_hat"], axis=1)
Expand Down
8 changes: 4 additions & 4 deletions pyro/contrib/gp/models/sgpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def model(self):
Kuu.view(-1)[:: M + 1] += self.jitter # add jitter to the diagonal
Luu = torch.linalg.cholesky(Kuu)
Kuf = self.kernel(self.Xu, self.X)
W = Kuf.triangular_solve(Luu, upper=False)[0].t()
W = torch.linalg.solve_triangular(Luu, Kuf, upper=False).t()

D = self.noise.expand(N)
if self.approx == "FITC" or self.approx == "VFE":
Expand Down Expand Up @@ -227,7 +227,7 @@ def forward(self, Xnew, full_cov=False, noiseless=True):

Kuf = self.kernel(self.Xu, self.X)

W = Kuf.triangular_solve(Luu, upper=False)[0]
W = torch.linalg.solve_triangular(Luu, Kuf, upper=False)
D = self.noise.expand(N)
if self.approx == "FITC":
Kffdiag = self.kernel(self.X, diag=True)
Expand All @@ -247,9 +247,9 @@ def forward(self, Xnew, full_cov=False, noiseless=True):
# End caching ----------

Kus = self.kernel(self.Xu, Xnew)
Ws = Kus.triangular_solve(Luu, upper=False)[0]
Ws = torch.linalg.solve_triangular(Luu, Kus, upper=False)
pack = torch.cat((W_Dinv_y, Ws), dim=1)
Linv_pack = pack.triangular_solve(L, upper=False)[0]
Linv_pack = torch.linalg.solve_triangular(L, pack, upper=False)
# unpack
Linv_W_Dinv_y = Linv_pack[:, : W_Dinv_y.shape[1]]
Linv_Ws = Linv_pack[:, W_Dinv_y.shape[1] :]
Expand Down
4 changes: 2 additions & 2 deletions pyro/contrib/gp/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,15 @@ def conditional(

if whiten:
v_2D = f_loc_2D
W = Kfs.triangular_solve(Lff, upper=False)[0].t()
W = torch.linalg.solve_triangular(Lff, Kfs, upper=False).t()
if f_scale_tril is not None:
S_2D = f_scale_tril_2D
else:
pack = torch.cat((f_loc_2D, Kfs), dim=1)
if f_scale_tril is not None:
pack = torch.cat((pack, f_scale_tril_2D), dim=1)

Lffinv_pack = pack.triangular_solve(Lff, upper=False)[0]
Lffinv_pack = torch.linalg.solve_triangular(Lff, pack, upper=False)
# unpack
v_2D = Lffinv_pack[:, : f_loc_2D.size(1)]
W = Lffinv_pack[:, f_loc_2D.size(1) : f_loc_2D.size(1) + M].t()
Expand Down
9 changes: 3 additions & 6 deletions pyro/distributions/multivariate_studentt.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,9 @@ def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
n = self.loc.size(-1)
y = (
(value - self.loc)
.unsqueeze(-1)
.triangular_solve(self.scale_tril, upper=False)
.solution.squeeze(-1)
)
y = torch.linalg.solve_triangular(
self.scale_tril, (value - self.loc).unsqueeze(-1), upper=False
).squeeze(-1)
Z = (
self.scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
+ 0.5 * n * self.df.log()
Expand Down
2 changes: 1 addition & 1 deletion pyro/distributions/omt_mvn.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def backward(ctx, grad_output):
loc_grad = sum_leftmost(grad_output, -1)

identity = eye_like(g, dim)
R_inv = torch.triangular_solve(identity, L.t(), transpose=False, upper=True)[0]
R_inv = torch.linalg.solve_triangular(L.t(), identity, upper=True)

z_ja = z.unsqueeze(-1)
g_R_inv = torch.matmul(g, R_inv).unsqueeze(-2)
Expand Down
4 changes: 2 additions & 2 deletions pyro/distributions/transforms/generalized_channel_permute.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ def _inverse(self, y):
LUx = (y_flat.unsqueeze(-3) * self.permutation.T.unsqueeze(-1)).sum(-2)

# Solve L(Ux) = P^1y
Ux, _ = torch.triangular_solve(LUx, self.L, upper=False)
Ux = torch.linalg.solve_triangular(self.L, LUx, upper=False)

# Solve Ux = (PL)^-1y
x, _ = torch.triangular_solve(Ux, self.U)
x = torch.linalg.solve_triangular(self.U, Ux, upper=True)

# Unflatten x (works when context variable has batch dim)
return x.reshape(x.shape[:-1] + y.shape[-2:])
Expand Down
6 changes: 3 additions & 3 deletions pyro/distributions/transforms/lower_cholesky_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ def _inverse(self, y):

Inverts y => x.
"""
return torch.triangular_solve(
(y - self.loc).unsqueeze(-1), self.scale_tril, upper=False, transpose=False
)[0].squeeze(-1)
return torch.linalg.solve_triangular(
self.scale_tril, (y - self.loc).unsqueeze(-1), upper=False
).squeeze(-1)

def log_abs_det_jacobian(self, x, y):
"""
Expand Down
8 changes: 5 additions & 3 deletions pyro/infer/autoguide/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from pyro.distributions import constraints
from pyro.infer.inspect import get_dependencies, is_sample_site
from pyro.nn.module import PyroModule, PyroParam
from pyro.ops.linalg import ignore_torch_deprecation_warnings
from pyro.poutine.runtime import am_i_wrapped, get_plates
from pyro.poutine.util import site_is_subsample

Expand Down Expand Up @@ -553,12 +554,13 @@ def _precision_to_scale_tril(P):
# Ref: https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril
Lf = torch.linalg.cholesky(torch.flip(P, (-2, -1)))
L_inv = torch.transpose(torch.flip(Lf, (-2, -1)), -2, -1)
L = torch.triangular_solve(
torch.eye(P.shape[-1], dtype=P.dtype, device=P.device), L_inv, upper=False
)[0]
L = torch.linalg.solve_triangular(
L_inv, torch.eye(P.shape[-1], dtype=P.dtype, device=P.device), upper=False
)
return L


@ignore_torch_deprecation_warnings()
def _try_possibly_intractable(fn, *args, **kwargs):
# Convert ValueError into NotImplementedError.
try:
Expand Down
2 changes: 1 addition & 1 deletion pyro/infer/mcmc/adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def _triu_inverse(x):
return x.reciprocal()
else:
identity = torch.eye(x.size(-1), dtype=x.dtype, device=x.device)
return torch.triangular_solve(identity, x, upper=True)[0]
return torch.linalg.solve_triangular(x, identity, upper=True)


class BlockMassMatrix:
Expand Down
2 changes: 1 addition & 1 deletion pyro/ops/arrowhead.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def triu_inverse(x):
B_Dinv = B / x.bottom_diag.unsqueeze(-2)

identity = torch.eye(head_size, dtype=A.dtype, device=A.device)
top_left = torch.triangular_solve(identity, A, upper=True)[0]
top_left = torch.linalg.solve_triangular(A, identity, upper=True)
top_right = -top_left.matmul(B_Dinv) # complexity: head_size^2 x N
top = torch.cat([top_left, top_right], -1)
bottom_diag = x.bottom_diag.reciprocal()
Expand Down
12 changes: 5 additions & 7 deletions pyro/ops/gamma_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,13 +275,13 @@ def marginalize(self, left=0, right=0):
P_ba = self.precision[..., b, a]
P_bb = self.precision[..., b, b]
P_b = torch.linalg.cholesky(P_bb)
P_a = P_ba.triangular_solve(P_b, upper=False).solution
P_a = torch.linalg.solve_triangular(P_b, P_ba, upper=False)
P_at = P_a.transpose(-1, -2)
precision = P_aa - P_at.matmul(P_a)

info_a = self.info_vec[..., a]
info_b = self.info_vec[..., b]
b_tmp = info_b.unsqueeze(-1).triangular_solve(P_b, upper=False).solution
b_tmp = torch.linalg.solve_triangular(P_b, info_b.unsqueeze(-1), upper=False)
info_vec = info_a
if n_b < n:
info_vec = info_vec - P_at.matmul(b_tmp).squeeze(-1)
Expand Down Expand Up @@ -320,11 +320,9 @@ def event_logsumexp(self):
"""
n = self.dim()
chol_P = torch.linalg.cholesky(self.precision)
chol_P_u = (
self.info_vec.unsqueeze(-1)
.triangular_solve(chol_P, upper=False)
.solution.squeeze(-1)
)
chol_P_u = torch.linalg.solve_triangular(
chol_P, self.info_vec.unsqueeze(-1), upper=False
).squeeze(-1)
u_P_u = chol_P_u.pow(2).sum(-1)
# considering GammaGaussian as a Gaussian with precision = s * precision, info_vec = s * info_vec,
# marginalize x variable, we get
Expand Down
2 changes: 1 addition & 1 deletion pyro/ops/integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def potential_grad(potential_fn, z):
potential_energy = potential_fn(z)
# deal with singular matrices
except RuntimeError as e:
if "singular U" in str(e) or "input is not positive-definite" in str(e):
if "singular" in str(e) or "input is not positive-definite" in str(e):
grads = {k: v.new_zeros(v.shape) for k, v in z.items()}
return grads, z_nodes[0].new_tensor(float("nan"))
else:
Expand Down
10 changes: 10 additions & 0 deletions pyro/ops/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,20 @@
# SPDX-License-Identifier: Apache-2.0

import math
import warnings
from contextlib import contextmanager

import torch


@contextmanager
def ignore_torch_deprecation_warnings():
with warnings.catch_warnings():
# Ignore deprecation warning until funsor updates to torch>=1.10.
warnings.filterwarnings("ignore", "torch.triangular_solve is deprecated")
yield


def rinverse(M, sym=False):
"""Matrix inversion of rightmost dimensions (batched).

Expand Down
4 changes: 1 addition & 3 deletions pyro/ops/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ def __new__(cls, data: torch.Tensor, provenance=frozenset(), **kwargs):
assert not isinstance(data, ProvenanceTensor)
if not provenance:
return data
instance = torch.Tensor.__new__(cls)
instance.__init__(data, provenance)
return instance
return super().__new__(cls)

def __init__(self, data, provenance=frozenset()):
assert isinstance(provenance, frozenset)
Expand Down
11 changes: 7 additions & 4 deletions pyro/ops/tensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,15 +420,18 @@ def matvecmul(x, y):
def triangular_solve(x, y, upper=False, transpose=False):
if y.size(-1) == 1:
return x / y
return x.triangular_solve(y, upper=upper, transpose=transpose).solution
if transpose:
y = y.transpose(-1, -2)
upper = not upper
return torch.linalg.solve_triangular(y, x, upper=upper)


def precision_to_scale_tril(P):
Lf = torch.linalg.cholesky(torch.flip(P, (-2, -1)))
L_inv = torch.transpose(torch.flip(Lf, (-2, -1)), -2, -1)
L = torch.triangular_solve(
torch.eye(P.shape[-1], dtype=P.dtype, device=P.device), L_inv, upper=False
)[0]
L = torch.linalg.solve_triangular(
L_inv, torch.eye(P.shape[-1], dtype=P.dtype, device=P.device), upper=False
)
return L


Expand Down
Loading