Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to PyTorch 1.9.0 #2887

Merged
merged 12 commits into from
Jul 2, 2021
12 changes: 6 additions & 6 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ jobs:
python -m pip install --upgrade pip wheel setuptools
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch==1.8.0+cpu torchvision==0.9.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install .[test]
pip install -r docs/requirements.txt
pip freeze
Expand Down Expand Up @@ -83,7 +83,7 @@ jobs:
python -m pip install --upgrade pip wheel setuptools
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch==1.8.0+cpu torchvision==0.9.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install .[test]
pip install coveralls
pip freeze
Expand Down Expand Up @@ -112,7 +112,7 @@ jobs:
python -m pip install --upgrade pip wheel setuptools
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch==1.8.0+cpu torchvision==0.9.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install .[test]
pip install coveralls
pip freeze
Expand Down Expand Up @@ -143,7 +143,7 @@ jobs:
python -m pip install --upgrade pip wheel setuptools
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch==1.8.0+cpu torchvision==0.9.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install .[test]
pip install coveralls
pip freeze
Expand Down Expand Up @@ -172,7 +172,7 @@ jobs:
python -m pip install --upgrade pip wheel setuptools
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch==1.8.0+cpu torchvision==0.9.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install .[test]
pip install coveralls
pip freeze
Expand Down Expand Up @@ -201,7 +201,7 @@ jobs:
python -m pip install --upgrade pip wheel setuptools
# Keep track of pyro-api master branch
pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
pip install torch==1.8.0+cpu torchvision==0.9.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install .[test]
pip install -e .[funsor]
pip install coveralls
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,5 +213,5 @@ def setup(app):
# See similar line in the install steps of .github/workflows/ci.yml
if 'READTHEDOCS' in os.environ:
os.system('pip install numpy')
os.system('pip install torch==1.8.0+cpu torchvision==0.9.0+cpu '
os.system('pip install torch==1.9.0+cpu torchvision==0.10.0+cpu '
'-f https://download.pytorch.org/whl/torch_stable.html')
17 changes: 3 additions & 14 deletions pyro/contrib/examples/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,9 @@


class MNIST(datasets.MNIST):
# For older torchvision.
urls = [
"https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/train-images-idx3-ubyte.gz",
"https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/train-labels-idx1-ubyte.gz",
"https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/t10k-images-idx3-ubyte.gz",
"https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/t10k-labels-idx1-ubyte.gz",
]
# For newer torchvision.
resources = list(zip(urls, [
"f68b3c2dcbeaaa9fbdd348bbdeb94873",
"d53e105ee54ea40749a09fcbcd1e9432",
"9fb629c4189551a2d022fa330f9573f3",
"ec29112dd5afa0611ce80d1b7f02629c"
]))
mirrors = [
"https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/"
] + datasets.MNIST.mirrors


def get_data_loader(dataset_name,
Expand Down
6 changes: 3 additions & 3 deletions pyro/contrib/gp/models/gpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def model(self):
N = self.X.size(0)
Kff = self.kernel(self.X)
Kff.view(-1)[::N + 1] += self.jitter + self.noise # add noise to diagonal
Lff = Kff.cholesky()
Lff = torch.linalg.cholesky(Kff)

zero_loc = self.X.new_zeros(self.X.size(0))
f_loc = zero_loc + self.mean_function(self.X)
Expand Down Expand Up @@ -123,7 +123,7 @@ def forward(self, Xnew, full_cov=False, noiseless=True):
N = self.X.size(0)
Kff = self.kernel(self.X).contiguous()
Kff.view(-1)[::N + 1] += self.jitter + self.noise # add noise to the diagonal
Lff = Kff.cholesky()
Lff = torch.linalg.cholesky(Kff)

y_residual = self.y - self.mean_function(self.X)
loc, cov = conditional(Xnew, self.X, self.kernel, y_residual, None, Lff,
Expand Down Expand Up @@ -179,7 +179,7 @@ def sample_next(xnew, outside_vars):
X, y, Kff = outside_vars["X"], outside_vars["y"], outside_vars["Kff"]

# Compute Cholesky decomposition of kernel matrix
Lff = Kff.cholesky()
Lff = torch.linalg.cholesky(Kff)
y_residual = y - self.mean_function(X)

# Compute conditional mean and variance
Expand Down
6 changes: 3 additions & 3 deletions pyro/contrib/gp/models/sgpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def model(self):
M = self.Xu.size(0)
Kuu = self.kernel(self.Xu).contiguous()
Kuu.view(-1)[::M + 1] += self.jitter # add jitter to the diagonal
Luu = Kuu.cholesky()
Luu = torch.linalg.cholesky(Kuu)
Kuf = self.kernel(self.Xu, self.X)
W = Kuf.triangular_solve(Luu, upper=False)[0].t()

Expand Down Expand Up @@ -204,7 +204,7 @@ def forward(self, Xnew, full_cov=False, noiseless=True):

Kuu = self.kernel(self.Xu).contiguous()
Kuu.view(-1)[::M + 1] += self.jitter # add jitter to the diagonal
Luu = Kuu.cholesky()
Luu = torch.linalg.cholesky(Kuu)

Kuf = self.kernel(self.Xu, self.X)

Expand All @@ -218,7 +218,7 @@ def forward(self, Xnew, full_cov=False, noiseless=True):
W_Dinv = W / D
K = W_Dinv.matmul(W.t()).contiguous()
K.view(-1)[::M + 1] += 1 # add identity matrix to K
L = K.cholesky()
L = torch.linalg.cholesky(K)

# get y_residual and convert it into 2D tensor for packing
y_residual = self.y - self.mean_function(self.X)
Expand Down
2 changes: 1 addition & 1 deletion pyro/contrib/gp/models/vgp.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def model(self):
N = self.X.size(0)
Kff = self.kernel(self.X).contiguous()
Kff.view(-1)[::N + 1] += self.jitter # add jitter to the diagonal
Lff = Kff.cholesky()
Lff = torch.linalg.cholesky(Kff)

zero_loc = self.X.new_zeros(self.f_loc.shape)
if self.whiten:
Expand Down
2 changes: 1 addition & 1 deletion pyro/contrib/gp/models/vsgp.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def model(self):
M = self.Xu.size(0)
Kuu = self.kernel(self.Xu).contiguous()
Kuu.view(-1)[::M + 1] += self.jitter # add jitter to the diagonal
Luu = Kuu.cholesky()
Luu = torch.linalg.cholesky(Kuu)

zero_loc = self.Xu.new_zeros(self.u_loc.shape)
if self.whiten:
Expand Down
2 changes: 1 addition & 1 deletion pyro/contrib/gp/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def conditional(Xnew, X, kernel, f_loc, f_scale_tril=None, Lff=None, full_cov=Fa
if Lff is None:
Kff = kernel(X).contiguous()
Kff.view(-1)[::N + 1] += jitter # add jitter to diagonal
Lff = Kff.cholesky()
Lff = torch.linalg.cholesky(Kff)
Kfs = kernel(X, Xnew)

# convert f_loc_shape from latent_shape x N to N x latent_shape
Expand Down
4 changes: 3 additions & 1 deletion pyro/contrib/oed/glmm/guides.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ def finalize(self, loss, target_labels):
continue
hess_l = self._hessian_diag(loss, mu_l, event_shape=(self.w_sizes[l],))
cov_l = rinverse(hess_l)
self.scale_trils[l] = cov_l.cholesky(upper=False)
self.scale_trils[l] = torch.linalg.cholesky(
cov_l.transpose(-2, -1)
).transpose(-2, -1)

def forward(self, design, target_labels=None):
"""
Expand Down
8 changes: 4 additions & 4 deletions pyro/contrib/tracking/extended_kalman_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,15 +195,15 @@ def update(self, measurement):
S = H.mm(P).mm(H.transpose(-1, -2)) + R # innovation cov

K_prefix = self._cov.mm(H.transpose(-1, -2))
dx = K_prefix.mm(torch.solve(dz.unsqueeze(1), S)[0]).squeeze(1) # K*dz
dx = K_prefix.mm(torch.linalg.solve(S, dz.unsqueeze(1))).squeeze(1) # K*dz
x = self._dynamic_model.geodesic_difference(x, -dx)

I = eye_like(x, self._dynamic_model.dimension) # noqa: E741
ImKH = I - K_prefix.mm(torch.solve(H, S)[0])
ImKH = I - K_prefix.mm(torch.linalg.solve(S, H))
# *Joseph form* of covariance update for numerical stability.
S_inv_R = torch.linalg.solve(S, R)
P = ImKH.mm(self.cov).mm(ImKH.transpose(-1, -2)) \
+ K_prefix.mm(torch.solve((K_prefix.mm(torch.solve(R, S)[0])).transpose(-1, -2),
S)[0])
+ K_prefix.mm(torch.linalg.solve(S, K_prefix.mm(S_inv_R).transpose(-1, -2)))

pred_mean = x
pred_cov = P
Expand Down
2 changes: 1 addition & 1 deletion pyro/distributions/spanning_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def log_partition_function(self):
import gpytorch
log_det = gpytorch.lazy.NonLazyTensor(truncated).logdet()
except ImportError:
log_det = torch.cholesky(truncated).diag().log().sum() * 2
log_det = torch.linalg.cholesky(truncated).diag().log().sum() * 2
return log_det + log_diag[:-1].sum()

def log_prob(self, edges):
Expand Down
2 changes: 1 addition & 1 deletion pyro/distributions/torch_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _PositiveDefinite_check(self, value):
matrix_shape = value.shape[-2:]
batch_shape = value.shape[:-2]
flattened_value = value.reshape((-1,) + matrix_shape)
return torch.stack([v.symeig(eigenvectors=False)[0][:1] > 0.0
return torch.stack([torch.linalg.eigvalsh(v)[:1] > 0.0
for v in flattened_value]).view(batch_shape)


Expand Down
2 changes: 1 addition & 1 deletion pyro/distributions/transforms/cholesky.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __eq__(self, other):
return isinstance(other, CholeskyTransform)

def _call(self, x):
return torch.cholesky(x)
return torch.linalg.cholesky(x)

def _inverse(self, y):
return torch.matmul(y, torch.transpose(y, -2, -1))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def __init__(self, channels=3, permutation=None):
self.__delattr__('permutation')

# Sample a random orthogonal matrix
W, _ = torch.qr(torch.randn(channels, channels))
W, _ = torch.linalg.qr(torch.randn(channels, channels))

# Construct the partially pivoted LU-form and the pivots
LU, pivots = W.lu()
Expand Down
2 changes: 1 addition & 1 deletion pyro/infer/autoguide/guides.py
Original file line number Diff line number Diff line change
Expand Up @@ -1098,7 +1098,7 @@ def laplace_approximation(self, *args, **kwargs):
H = hessian(loss, self.loc)
cov = H.inverse()
loc = self.loc
scale_tril = cov.cholesky()
scale_tril = torch.linalg.cholesky(cov)

gaussian_guide = AutoMultivariateNormal(self.model)
gaussian_guide._setup_prototype(*args, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion pyro/infer/mcmc/adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def _matvecmul(x, y):


def _cholesky(x):
return x.sqrt() if x.dim() == 1 else x.cholesky()
return x.sqrt() if x.dim() == 1 else torch.linalg.cholesky(x)


def _transpose(x):
Expand Down
9 changes: 8 additions & 1 deletion pyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,12 +300,19 @@ def compute_expectation(self, costs):
cost = cost.masked_select(mask)
else:
cost, prob = packed.broadcast_all(cost, prob)
expected_cost = expected_cost + scale * torch.tensordot(prob, cost, prob.dim())
expected_cost = expected_cost + scale * _fulldot(prob, cost)

LAST_CACHE_SIZE[0] = count_cached_ops(cache)
return expected_cost


def _fulldot(x, y):
assert x.dim() == y.dim()
if x.dim() == 0:
return x * y
return torch.tensordot(x, y, dims=x.dim())


def check_fully_reparametrized(guide_site):
log_prob, score_function_term, entropy_term = guide_site["score_parts"]
fully_rep = (guide_site["fn"].has_rsample and not is_identically_zero(entropy_term) and
Expand Down
4 changes: 3 additions & 1 deletion pyro/ops/arrowhead.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def sqrt(x):
# is upper triangular) using some `flip` operators:
# flip(cholesky(flip(schur_complement)))
try:
top_left = torch.flip(torch.cholesky(torch.flip(schur_complement, (-2, -1))), (-2, -1))
top_left = torch.flip(
torch.linalg.cholesky(torch.flip(schur_complement, (-2, -1))), (-2, -1)
)
break
except RuntimeError:
B = B / 2
Expand Down
4 changes: 2 additions & 2 deletions pyro/ops/gamma_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def marginalize(self, left=0, right=0):
P_aa = self.precision[..., a, a]
P_ba = self.precision[..., b, a]
P_bb = self.precision[..., b, b]
P_b = P_bb.cholesky()
P_b = torch.linalg.cholesky(P_bb)
P_a = P_ba.triangular_solve(P_b, upper=False).solution
P_at = P_a.transpose(-1, -2)
precision = P_aa - P_at.matmul(P_a)
Expand Down Expand Up @@ -290,7 +290,7 @@ def event_logsumexp(self):
Integrates out all latent state (i.e. operating on event dimensions) of Gaussian component.
"""
n = self.dim()
chol_P = self.precision.cholesky()
chol_P = torch.linalg.cholesky(self.precision)
chol_P_u = self.info_vec.unsqueeze(-1).triangular_solve(chol_P, upper=False).solution.squeeze(-1)
u_P_u = chol_P_u.pow(2).sum(-1)
# considering GammaGaussian as a Gaussian with precision = s * precision, info_vec = s * info_vec,
Expand Down
4 changes: 2 additions & 2 deletions pyro/ops/tensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def inverse_haar_transform(x):
def cholesky(x):
if x.size(-1) == 1:
return x.sqrt()
return x.cholesky()
return torch.linalg.cholesky(x)


def cholesky_solve(x, y):
Expand Down Expand Up @@ -410,7 +410,7 @@ def triangular_solve(x, y, upper=False, transpose=False):


def precision_to_scale_tril(P):
Lf = torch.cholesky(torch.flip(P, (-2, -1)))
Lf = torch.linalg.cholesky(torch.flip(P, (-2, -1)))
L_inv = torch.transpose(torch.flip(Lf, (-2, -1)), -2, -1)
L = torch.triangular_solve(torch.eye(P.shape[-1], dtype=P.dtype, device=P.device),
L_inv, upper=False)[0]
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ filterwarnings = error
ignore::DeprecationWarning
ignore:CUDA initialization:UserWarning
ignore:floor_divide is deprecated:UserWarning
ignore:torch.tensor results are registered as constants in the trace
once::DeprecationWarning

doctest_optionflags = ELLIPSIS NORMALIZE_WHITESPACE IGNORE_EXCEPTION_DETAIL
Expand Down
7 changes: 4 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,10 @@
'jupyter>=1.0.0',
'graphviz>=0.8',
'matplotlib>=1.3',
'torchvision>=0.9.0',
'torchvision>=0.10.0',
'visdom>=0.1.4',
'pandas',
'pillow==8.2.0', # https://github.com/pytorch/pytorch/issues/61125
'scikit-learn',
'seaborn',
'wget',
Expand All @@ -89,7 +90,7 @@
'numpy>=1.7',
'opt_einsum>=2.3.2',
'pyro-api>=0.1.1',
'torch>=1.8.0',
'torch>=1.9.0',
'tqdm>=4.36',
],
extras_require={
Expand Down Expand Up @@ -121,7 +122,7 @@
'horovod': ['horovod[pytorch]>=0.19'],
'funsor': [
# This must be a released version when Pyro is released.
'funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@7be0ef9af6a100e52ac98ab13b203a4dec0ae42e',
'funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@383e7a6d05c9d5de9646d23698891e10c4cba927',
],
},
python_requires='>=3.6',
Expand Down
2 changes: 2 additions & 0 deletions tests/contrib/funsor/test_vectorized_markov.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,7 @@ def guide_empty(data, history, vectorized):
pass


@pytest.mark.xfail(reason="funsor version drift")
@pytest.mark.parametrize("model,guide,data,history", [
(model_0, guide_empty, torch.rand(3, 5, 4), 1),
(model_1, guide_empty, torch.rand(5, 4), 1),
Expand Down Expand Up @@ -526,6 +527,7 @@ def guide_empty_multi(weeks_data, days_data, history, vectorized):
pass


@pytest.mark.xfail(reason="funsor version drift")
@pytest.mark.parametrize("model,guide,weeks_data,days_data,history", [
(model_8, guide_empty_multi, torch.ones(3), torch.zeros(9), 1),
(model_8, guide_empty_multi, torch.ones(30), torch.zeros(50), 1),
Expand Down
4 changes: 2 additions & 2 deletions tests/contrib/gp/test_conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
X = torch.tensor([[1., 5.], [2., 1.], [3., 2.]])
kernel = Matern52(input_dim=2)
Kff = kernel(X) + torch.eye(3) * 1e-6
Lff = Kff.cholesky()
Lff = torch.linalg.cholesky(Kff)
pyro.set_rng_seed(123)
f_loc = torch.rand(3)
f_scale_tril = torch.rand(3, 3).tril(-1) + torch.rand(3).exp().diag()
Expand Down Expand Up @@ -76,7 +76,7 @@ def test_conditional_whiten(Xnew, X, kernel, f_loc, f_scale_tril, loc, cov):
loc0, cov0 = conditional(Xnew, X, kernel, f_loc, f_scale_tril, full_cov=True,
whiten=False)
Kff = kernel(X) + torch.eye(3) * 1e-6
Lff = Kff.cholesky()
Lff = torch.linalg.cholesky(Kff)
whiten_f_loc = Lff.inverse().matmul(f_loc)
whiten_f_scale_tril = Lff.inverse().matmul(f_scale_tril)
loc1, cov1 = conditional(Xnew, X, kernel, whiten_f_loc, whiten_f_scale_tril,
Expand Down
Loading