Commit 18f26a3 (1 parent: 23427a0)
Showing 5 changed files with 112 additions and 0 deletions.
@@ -0,0 +1,28 @@
import torch

from pytensor.graph import FunctionGraph
from pytensor.link.pytorch.dispatch import pytorch_funcify
from pytensor.tensor.blockwise import Blockwise


@pytorch_funcify.register(Blockwise)
def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
    batched_dims = op.batch_ndim(node)

    # Build the core (unbatched) graph and compile it to a torch callable.
    core_node = op._create_dummy_core_node(node.inputs)
    core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs)
    core_func = pytorch_funcify(core_fgraph)

    if len(node.outputs) == 1:
        # The compiled FunctionGraph returns a tuple; unwrap the single output.
        def inner_func(*inputs):
            return core_func(*inputs)[0]

    else:
        inner_func = core_func

    # Wrap with torch.vmap once per batch dimension.
    for _ in range(batched_dims):
        inner_func = torch.vmap(inner_func)

    def batcher(*inputs):
        op._check_runtime_broadcast(node, inputs)
        return inner_func(*inputs)

    return batcher
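For context, here is a minimal standalone sketch (not part of the commit) of the idea behind the nested torch.vmap calls above: wrapping a core function once per leading batch dimension maps it over those axes. The core_matmul helper and the shapes below are purely illustrative.

import torch

def core_matmul(a, b):
    # Core function that only understands 2-D operands.
    return torch.matmul(a, b)

# One torch.vmap wrapper per batch dimension, mirroring the
# "for _ in range(batched_dims)" loop in the dispatcher above.
batched_matmul = torch.vmap(torch.vmap(core_matmul))

a = torch.randn(4, 3, 2, 5)  # two leading batch dimensions
b = torch.randn(4, 3, 5, 7)
assert batched_matmul(a, b).shape == (4, 3, 2, 7)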
@@ -0,0 +1,28 @@
import torch.linalg

from pytensor.link.pytorch.dispatch import pytorch_funcify
from pytensor.tensor.slinalg import Cholesky, SolveTriangular


@pytorch_funcify.register(Cholesky)
def pytorch_funcify_Cholesky(op, **kwargs):
    lower = op.lower

    def cholesky(a, lower=lower):
        # torch takes an `upper` flag instead of SciPy's `lower`.
        return torch.linalg.cholesky(a, upper=not lower)

    return cholesky


@pytorch_funcify.register(SolveTriangular)
def pytorch_funcify_SolveTriangular(op, **kwargs):
    lower = op.lower
    trans = op.trans
    unit_diagonal = op.unit_diagonal

    def solve_triangular(A, b):
        # torch has no `trans` argument: solve A^T x = b by transposing A
        # and flipping which triangle is used.
        if trans in ("T", 1):
            A_p, lower_p = A.mT, not lower
        else:
            A_p, lower_p = A, lower
        return torch.linalg.solve_triangular(
            A_p, b, upper=not lower_p, unitriangular=unit_diagonal
        )

    return solve_triangular
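A hedged aside, not part of the diff: SciPy and torch spell the triangular-solve options differently, which is why the dispatcher above flips lower into upper and uses torch's unitriangular keyword. A small sketch of the correspondence on a concrete system:

import numpy as np
import scipy.linalg
import torch

rng = np.random.default_rng(0)
A = np.tril(rng.normal(size=(4, 4))) + 4 * np.eye(4)  # well-conditioned lower-triangular matrix
b = rng.normal(size=(4, 1))

# SciPy convention: lower=..., unit_diagonal=...
x_scipy = scipy.linalg.solve_triangular(A, b, lower=True, unit_diagonal=False)

# torch convention: upper=not lower, unitriangular=...
x_torch = torch.linalg.solve_triangular(
    torch.as_tensor(A), torch.as_tensor(b), upper=False, unitriangular=False
)

np.testing.assert_allclose(x_scipy, x_torch.numpy())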
@@ -0,0 +1,32 @@
import pytest

from pytensor.graph.replace import vectorize_node
from pytensor.tensor import tensor
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import MatrixInverse


torch = pytest.importorskip("torch")


def test_vectorize_blockwise():
    mat = tensor(shape=(None, None))
    tns = tensor(shape=(None, None, None))

    # Something that falls back to Blockwise
    node = MatrixInverse()(mat).owner
    vect_node = vectorize_node(node, tns)
    assert isinstance(vect_node.op, Blockwise) and isinstance(
        vect_node.op.core_op, MatrixInverse
    )
    assert vect_node.op.signature == "(m,m)->(m,m)"
    assert vect_node.inputs[0] is tns

    # Useless blockwise
    tns4 = tensor(shape=(5, None, None, None))
    new_vect_node = vectorize_node(vect_node, tns4)
    assert new_vect_node.op is vect_node.op
    assert isinstance(new_vect_node.op, Blockwise) and isinstance(
        new_vect_node.op.core_op, MatrixInverse
    )
    assert new_vect_node.inputs[0] is tns4
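If the PyTorch linker is available, a Blockwise graph like the one built in this test can also be compiled and executed end to end. A minimal sketch, assuming the backend is registered under the "PYTORCH" compile mode; the mode name and the use of matrix_inverse here are assumptions for illustration, not shown in this diff.

import numpy as np
import pytensor
from pytensor.tensor import tensor
from pytensor.tensor.nlinalg import matrix_inverse

x = tensor("x", shape=(2, None, None))  # a batch of square matrices
out = matrix_inverse(x)  # expected to lower to Blockwise(MatrixInverse)

# Assumption: the PyTorch linker is exposed as the "PYTORCH" mode.
fn = pytensor.function([x], out, mode="PYTORCH")

vals = np.stack([np.eye(3) * 2.0, np.eye(3) * 4.0])
np.testing.assert_allclose(fn(vals), np.linalg.inv(vals))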
@@ -0,0 +1,22 @@
import numpy as np
import pytest

import pytensor
from pytensor.tensor import tensor
from pytensor.tensor.slinalg import cholesky


@pytest.mark.parametrize(
    "cov_batch_shape", [(), (1000,), (4, 1000)], ids=lambda arg: f"cov:{arg}"
)
def test_batched_mvnormal_logp_and_dlogp(cov_batch_shape):
    rng = np.random.default_rng(sum(map(ord, "batched_mvnormal")))

    cov = tensor("cov", shape=(*cov_batch_shape, 10, 10))

    test_values = np.eye(cov.type.shape[-1]) * np.abs(rng.normal(size=cov.type.shape))

    chol_cov = cholesky(cov, lower=True, on_error="raise")

    fn = pytensor.function([cov], [chol_cov])
    assert np.all(np.isclose(fn(test_values), np.linalg.cholesky(test_values)))
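As an illustrative cross-check (not part of the commit), torch's batched Cholesky agrees with NumPy's on the kind of batched diagonal covariance used in the test above, which is the property the PyTorch Cholesky dispatch relies on.

import numpy as np
import torch

rng = np.random.default_rng(0)
# Batched diagonal, positive-definite matrices of shape (4, 10, 10).
cov = np.eye(10) * np.abs(rng.normal(size=(4, 10, 10)))

np.testing.assert_allclose(
    torch.linalg.cholesky(torch.as_tensor(cov), upper=False).numpy(),
    np.linalg.cholesky(cov),
)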