Skip to content

Commit

Permalink
Implement nlinalg Ops in PyTorch (#920)
Browse files Browse the repository at this point in the history
  • Loading branch information
twaclaw authored Jul 26, 2024
1 parent 367351f commit 58fec45
Show file tree
Hide file tree
Showing 3 changed files with 215 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pytensor/link/pytorch/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@
import pytensor.link.pytorch.dispatch.extra_ops
import pytensor.link.pytorch.dispatch.shape
import pytensor.link.pytorch.dispatch.sort

import pytensor.link.pytorch.dispatch.nlinalg
# isort: on
103 changes: 103 additions & 0 deletions pytensor/link/pytorch/dispatch/nlinalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import torch

from pytensor.link.pytorch.dispatch import pytorch_funcify
from pytensor.tensor.nlinalg import (
SVD,
Det,
Eig,
Eigh,
KroneckerProduct,
MatrixInverse,
MatrixPinv,
QRFull,
SLogDet,
)


@pytorch_funcify.register(SVD)
def pytorch_funcify_SVD(op, **kwargs):
full_matrices = op.full_matrices
compute_uv = op.compute_uv

def svd(x):
U, S, V = torch.linalg.svd(x, full_matrices=full_matrices)
if compute_uv:
return U, S, V
return S

return svd


@pytorch_funcify.register(Det)
def pytorch_funcify_Det(op, **kwargs):
def det(x):
return torch.linalg.det(x)

return det


@pytorch_funcify.register(SLogDet)
def pytorch_funcify_SLogDet(op, **kwargs):
def slogdet(x):
return torch.linalg.slogdet(x)

return slogdet


@pytorch_funcify.register(Eig)
def pytorch_funcify_Eig(op, **kwargs):
def eig(x):
return torch.linalg.eig(x)

return eig


@pytorch_funcify.register(Eigh)
def pytorch_funcify_Eigh(op, **kwargs):
uplo = op.UPLO

def eigh(x, uplo=uplo):
return torch.linalg.eigh(x, UPLO=uplo)

return eigh


@pytorch_funcify.register(MatrixInverse)
def pytorch_funcify_MatrixInverse(op, **kwargs):
def matrix_inverse(x):
return torch.linalg.inv(x)

return matrix_inverse


@pytorch_funcify.register(QRFull)
def pytorch_funcify_QRFull(op, **kwargs):
mode = op.mode
if mode == "raw":
raise NotImplementedError("raw mode not implemented in PyTorch")

def qr_full(x):
Q, R = torch.linalg.qr(x, mode=mode)
if mode == "r":
return R
return Q, R

return qr_full


@pytorch_funcify.register(MatrixPinv)
def pytorch_funcify_Pinv(op, **kwargs):
hermitian = op.hermitian

def pinv(x):
return torch.linalg.pinv(x, hermitian=hermitian)

return pinv


@pytorch_funcify.register(KroneckerProduct)
def pytorch_funcify_KroneckerProduct(op, **kwargs):
def _kron(x, y):
return torch.kron(x, y)

return _kron
111 changes: 111 additions & 0 deletions tests/link/pytorch/test_nlinalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import numpy as np
import pytest

from pytensor.compile.function import function
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import nlinalg as pt_nla
from pytensor.tensor.type import matrix
from tests.link.pytorch.test_basic import compare_pytorch_and_py


@pytest.fixture
def matrix_test():
rng = np.random.default_rng(213234)

M = rng.normal(size=(3, 3))
test_value = M.dot(M.T).astype(config.floatX)

x = matrix("x")
return (x, test_value)


@pytest.mark.parametrize(
"func",
(pt_nla.eig, pt_nla.eigh, pt_nla.slogdet, pt_nla.inv, pt_nla.det),
)
def test_lin_alg_no_params(func, matrix_test):
x, test_value = matrix_test

out = func(x)
out_fg = FunctionGraph([x], out if isinstance(out, list) else [out])

def assert_fn(x, y):
np.testing.assert_allclose(x, y, rtol=1e-3)

compare_pytorch_and_py(out_fg, [test_value], assert_fn=assert_fn)


@pytest.mark.parametrize(
"mode",
(
"complete",
"reduced",
"r",
pytest.param("raw", marks=pytest.mark.xfail(raises=NotImplementedError)),
),
)
def test_qr(mode, matrix_test):
x, test_value = matrix_test
outs = pt_nla.qr(x, mode=mode)
out_fg = FunctionGraph([x], outs if isinstance(outs, list) else [outs])
compare_pytorch_and_py(out_fg, [test_value])


@pytest.mark.parametrize("compute_uv", [True, False])
@pytest.mark.parametrize("full_matrices", [True, False])
def test_svd(compute_uv, full_matrices, matrix_test):
x, test_value = matrix_test

out = pt_nla.svd(x, full_matrices=full_matrices, compute_uv=compute_uv)
out_fg = FunctionGraph([x], out if isinstance(out, list) else [out])

compare_pytorch_and_py(out_fg, [test_value])


def test_pinv():
x = matrix("x")
x_inv = pt_nla.pinv(x)

fgraph = FunctionGraph([x], [x_inv])
x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
compare_pytorch_and_py(fgraph, [x_np])


@pytest.mark.parametrize("hermitian", [False, True])
def test_pinv_hermitian(hermitian):
A = matrix("A", dtype="complex128")
A_h_test = np.c_[[3, 3 + 2j], [3 - 2j, 2]]
A_not_h_test = A_h_test + 0 + 1j

A_inv = pt_nla.pinv(A, hermitian=hermitian)
torch_fn = function([A], A_inv, mode="PYTORCH")

assert np.allclose(torch_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=False))
assert np.allclose(torch_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=True))

assert (
np.allclose(
torch_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=False)
)
is not hermitian
)

assert (
np.allclose(
torch_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=True)
)
is hermitian
)


def test_kron():
x = matrix("x")
y = matrix("y")
z = pt_nla.kron(x, y)

fgraph = FunctionGraph([x, y], [z])
x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
y_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)

compare_pytorch_and_py(fgraph, [x_np, y_np])

0 comments on commit 58fec45

Please sign in to comment.