From 58fec45d00717ecebbcd4a70e08e1ed0b99e362c Mon Sep 17 00:00:00 2001
From: Diego Sandoval <46681084+twaclaw@users.noreply.github.com>
Date: Fri, 26 Jul 2024 16:06:27 +0200
Subject: [PATCH] Implement nlinalg Ops in PyTorch (#920)

---
 pytensor/link/pytorch/dispatch/__init__.py |   2 +-
 pytensor/link/pytorch/dispatch/nlinalg.py  | 103 +++++++++++++++++++
 tests/link/pytorch/test_nlinalg.py         | 111 +++++++++++++++++++++
 3 files changed, 215 insertions(+), 1 deletion(-)
 create mode 100644 pytensor/link/pytorch/dispatch/nlinalg.py
 create mode 100644 tests/link/pytorch/test_nlinalg.py

diff --git a/pytensor/link/pytorch/dispatch/__init__.py b/pytensor/link/pytorch/dispatch/__init__.py
index fa47908d74..0295a12e8e 100644
--- a/pytensor/link/pytorch/dispatch/__init__.py
+++ b/pytensor/link/pytorch/dispatch/__init__.py
@@ -9,5 +9,5 @@
 import pytensor.link.pytorch.dispatch.extra_ops
 import pytensor.link.pytorch.dispatch.shape
 import pytensor.link.pytorch.dispatch.sort
-
+import pytensor.link.pytorch.dispatch.nlinalg
 # isort: on
diff --git a/pytensor/link/pytorch/dispatch/nlinalg.py b/pytensor/link/pytorch/dispatch/nlinalg.py
new file mode 100644
index 0000000000..91690489e9
--- /dev/null
+++ b/pytensor/link/pytorch/dispatch/nlinalg.py
@@ -0,0 +1,103 @@
+import torch
+
+from pytensor.link.pytorch.dispatch import pytorch_funcify
+from pytensor.tensor.nlinalg import (
+    SVD,
+    Det,
+    Eig,
+    Eigh,
+    KroneckerProduct,
+    MatrixInverse,
+    MatrixPinv,
+    QRFull,
+    SLogDet,
+)
+
+
+@pytorch_funcify.register(SVD)
+def pytorch_funcify_SVD(op, **kwargs):
+    full_matrices = op.full_matrices
+    compute_uv = op.compute_uv
+
+    def svd(x):
+        U, S, V = torch.linalg.svd(x, full_matrices=full_matrices)
+        if compute_uv:
+            return U, S, V
+        return S
+
+    return svd
+
+
+@pytorch_funcify.register(Det)
+def pytorch_funcify_Det(op, **kwargs):
+    def det(x):
+        return torch.linalg.det(x)
+
+    return det
+
+
+@pytorch_funcify.register(SLogDet)
+def pytorch_funcify_SLogDet(op, **kwargs):
+    def slogdet(x):
+        return torch.linalg.slogdet(x)
+
+    return slogdet
+
+
+@pytorch_funcify.register(Eig)
+def pytorch_funcify_Eig(op, **kwargs):
+    def eig(x):
+        return torch.linalg.eig(x)
+
+    return eig
+
+
+@pytorch_funcify.register(Eigh)
+def pytorch_funcify_Eigh(op, **kwargs):
+    uplo = op.UPLO
+
+    def eigh(x, uplo=uplo):
+        return torch.linalg.eigh(x, UPLO=uplo)
+
+    return eigh
+
+
+@pytorch_funcify.register(MatrixInverse)
+def pytorch_funcify_MatrixInverse(op, **kwargs):
+    def matrix_inverse(x):
+        return torch.linalg.inv(x)
+
+    return matrix_inverse
+
+
+@pytorch_funcify.register(QRFull)
+def pytorch_funcify_QRFull(op, **kwargs):
+    mode = op.mode
+    if mode == "raw":
+        raise NotImplementedError("raw mode not implemented in PyTorch")
+
+    def qr_full(x):
+        Q, R = torch.linalg.qr(x, mode=mode)
+        if mode == "r":
+            return R
+        return Q, R
+
+    return qr_full
+
+
+@pytorch_funcify.register(MatrixPinv)
+def pytorch_funcify_Pinv(op, **kwargs):
+    hermitian = op.hermitian
+
+    def pinv(x):
+        return torch.linalg.pinv(x, hermitian=hermitian)
+
+    return pinv
+
+
+@pytorch_funcify.register(KroneckerProduct)
+def pytorch_funcify_KroneckerProduct(op, **kwargs):
+    def _kron(x, y):
+        return torch.kron(x, y)
+
+    return _kron
diff --git a/tests/link/pytorch/test_nlinalg.py b/tests/link/pytorch/test_nlinalg.py
new file mode 100644
index 0000000000..7d69ac0500
--- /dev/null
+++ b/tests/link/pytorch/test_nlinalg.py
@@ -0,0 +1,111 @@
+import numpy as np
+import pytest
+
+from pytensor.compile.function import function
+from pytensor.configdefaults import config
+from pytensor.graph.fg import FunctionGraph
+from pytensor.tensor import nlinalg as pt_nla
+from pytensor.tensor.type import matrix
+from tests.link.pytorch.test_basic import compare_pytorch_and_py
+
+
+@pytest.fixture
+def matrix_test():
+    rng = np.random.default_rng(213234)
+
+    M = rng.normal(size=(3, 3))
+    test_value = M.dot(M.T).astype(config.floatX)
+
+    x = matrix("x")
+    return (x, test_value)
+
+
+@pytest.mark.parametrize(
+    "func",
+    (pt_nla.eig, pt_nla.eigh, pt_nla.slogdet, pt_nla.inv, pt_nla.det),
+)
+def test_lin_alg_no_params(func, matrix_test):
+    x, test_value = matrix_test
+
+    out = func(x)
+    out_fg = FunctionGraph([x], out if isinstance(out, list) else [out])
+
+    def assert_fn(x, y):
+        np.testing.assert_allclose(x, y, rtol=1e-3)
+
+    compare_pytorch_and_py(out_fg, [test_value], assert_fn=assert_fn)
+
+
+@pytest.mark.parametrize(
+    "mode",
+    (
+        "complete",
+        "reduced",
+        "r",
+        pytest.param("raw", marks=pytest.mark.xfail(raises=NotImplementedError)),
+    ),
+)
+def test_qr(mode, matrix_test):
+    x, test_value = matrix_test
+    outs = pt_nla.qr(x, mode=mode)
+    out_fg = FunctionGraph([x], outs if isinstance(outs, list) else [outs])
+    compare_pytorch_and_py(out_fg, [test_value])
+
+
+@pytest.mark.parametrize("compute_uv", [True, False])
+@pytest.mark.parametrize("full_matrices", [True, False])
+def test_svd(compute_uv, full_matrices, matrix_test):
+    x, test_value = matrix_test
+
+    out = pt_nla.svd(x, full_matrices=full_matrices, compute_uv=compute_uv)
+    out_fg = FunctionGraph([x], out if isinstance(out, list) else [out])
+
+    compare_pytorch_and_py(out_fg, [test_value])
+
+
+def test_pinv():
+    x = matrix("x")
+    x_inv = pt_nla.pinv(x)
+
+    fgraph = FunctionGraph([x], [x_inv])
+    x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
+    compare_pytorch_and_py(fgraph, [x_np])
+
+
+@pytest.mark.parametrize("hermitian", [False, True])
+def test_pinv_hermitian(hermitian):
+    A = matrix("A", dtype="complex128")
+    A_h_test = np.c_[[3, 3 + 2j], [3 - 2j, 2]]
+    A_not_h_test = A_h_test + 0 + 1j
+
+    A_inv = pt_nla.pinv(A, hermitian=hermitian)
+    torch_fn = function([A], A_inv, mode="PYTORCH")
+
+    assert np.allclose(torch_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=False))
+    assert np.allclose(torch_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=True))
+
+    assert (
+        np.allclose(
+            torch_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=False)
+        )
+        is not hermitian
+    )
+
+    assert (
+        np.allclose(
+            torch_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=True)
+        )
+        is hermitian
+    )
+
+
+def test_kron():
+    x = matrix("x")
+    y = matrix("y")
+    z = pt_nla.kron(x, y)
+
+    fgraph = FunctionGraph([x, y], [z])
+    x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
+    y_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
+
+    compare_pytorch_and_py(fgraph, [x_np, y_np])
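
Reviewer note (not part of the patch): a minimal, hypothetical sketch of exercising one of the newly dispatched Ops end-to-end with the PyTorch backend. It reuses the mode="PYTORCH" compilation path already shown in test_pinv_hermitian and checks pt_nla.det against NumPy; the variable names below are illustrative only.

    import numpy as np

    from pytensor.compile.function import function
    from pytensor.configdefaults import config
    from pytensor.tensor import nlinalg as pt_nla
    from pytensor.tensor.type import matrix

    # Build a small symbolic graph using one of the Ops dispatched in this patch.
    x = matrix("x")
    out = pt_nla.det(x)  # lowered to torch.linalg.det via pytorch_funcify_Det

    # mode="PYTORCH" selects the PyTorch linker, as in test_pinv_hermitian above.
    fn = function([x], out, mode="PYTORCH")

    x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
    np.testing.assert_allclose(fn(x_np), np.linalg.det(x_np), rtol=1e-4)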