Skip to content

Commit

Permalink
Add blockwise and Cholesky
Browse files Browse the repository at this point in the history
  • Loading branch information
Ch0ronomato committed Sep 8, 2024
1 parent 23427a0 commit 18f26a3
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pytensor/link/pytorch/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@
import pytensor.link.pytorch.dispatch.shape
import pytensor.link.pytorch.dispatch.sort
import pytensor.link.pytorch.dispatch.nlinalg
import pytensor.link.pytorch.dispatch.slinalg
import pytensor.link.pytorch.dispatch.blockwise
# isort: on
28 changes: 28 additions & 0 deletions pytensor/link/pytorch/dispatch/blockwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import torch

from pytensor.graph import FunctionGraph
from pytensor.link.pytorch.dispatch import pytorch_funcify
from pytensor.tensor.blockwise import Blockwise


@pytorch_funcify.register(Blockwise)
def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
batched_dims = op.batch_ndim(node)
core_node = op._create_dummy_core_node(node.inputs)
core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs)
core_func = pytorch_funcify(core_fgraph)
if len(node.outputs) == 1:

def inner_func(*inputs):
return core_func(*inputs)[0]
else:
inner_func = core_func

for _ in range(batched_dims):
inner_func = torch.vmap(inner_func)

def batcher(*inputs):
op._check_runtime_broadcast(node, inputs)
return inner_func(*inputs)

return batcher
28 changes: 28 additions & 0 deletions pytensor/link/pytorch/dispatch/slinalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import torch.linalg

from pytensor.link.pytorch.dispatch import pytorch_funcify
from pytensor.tensor.slinalg import Cholesky, SolveTriangular


@pytorch_funcify.register(Cholesky)
def pytorch_funcify_Cholesky(op, **kwargs):
lower = op.lower

def cholesky(a, lower=lower):
return torch.linalg.cholesky(a, upper=not lower)

return cholesky


@pytorch_funcify.register(SolveTriangular)
def pytorch_funcify_SolveTriangular(op, **kwargs):
lower = op.lower
trans = op.trans
unit_diagonal = op.unit_diagonal

def solve_triangular(A, b):
return torch.linalg.solve_triangular(
A, b, upper=not lower, unit_triangle=unit_diagonal, left=trans == "T"
)

return solve_triangular
32 changes: 32 additions & 0 deletions tests/link/pytorch/test_blockwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pytest

from pytensor.graph.replace import vectorize_node
from pytensor.tensor import tensor
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import MatrixInverse


torch = pytest.importorskip("torch")


def test_vectorize_blockwise():
mat = tensor(shape=(None, None))
tns = tensor(shape=(None, None, None))

# Something that falls back to Blockwise
node = MatrixInverse()(mat).owner
vect_node = vectorize_node(node, tns)
assert isinstance(vect_node.op, Blockwise) and isinstance(
vect_node.op.core_op, MatrixInverse
)
assert vect_node.op.signature == ("(m,m)->(m,m)")
assert vect_node.inputs[0] is tns

# Useless blockwise
tns4 = tensor(shape=(5, None, None, None))
new_vect_node = vectorize_node(vect_node, tns4)
assert new_vect_node.op is vect_node.op
assert isinstance(new_vect_node.op, Blockwise) and isinstance(
new_vect_node.op.core_op, MatrixInverse
)
assert new_vect_node.inputs[0] is tns4
22 changes: 22 additions & 0 deletions tests/link/pytorch/test_slinalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import numpy as np
import pytest

import pytensor
from pytensor.tensor import tensor
from pytensor.tensor.slinalg import cholesky


@pytest.mark.parametrize(
"cov_batch_shape", [(), (1000,), (4, 1000)], ids=lambda arg: f"cov:{arg}"
)
def test_batched_mvnormal_logp_and_dlogp(cov_batch_shape):
rng = np.random.default_rng(sum(map(ord, "batched_mvnormal")))

cov = tensor("cov", shape=(*cov_batch_shape, 10, 10))

test_values = np.eye(cov.type.shape[-1]) * np.abs(rng.normal(size=cov.type.shape))

chol_cov = cholesky(cov, lower=True, on_error="raise")

fn = pytensor.function([cov], [chol_cov])
assert np.all(np.isclose(fn(test_values), np.linalg.cholesky(test_values)))

0 comments on commit 18f26a3

Please sign in to comment.