OpFromGraph subclasses shouldn't have __props__
When specified, Ops with identical __props__ are considered identical, in that they can be swapped and given the original inputs to obtain the same output.
ricardoV94 committed Aug 24, 2024
1 parent 1509cee commit 9f88e1f
Showing 3 changed files with 19 additions and 4 deletions.
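
For context on the commit message above: PyTensor derives Op equality and hashing from `__props__`, so two Op instances whose listed props match are treated as interchangeable by rewrites such as graph merging. An `OpFromGraph` subclass, however, also carries an inner graph that `__props__` cannot capture, so two instances can compare equal while wrapping different computations. The sketch below is a minimal, hypothetical illustration of that failure mode; the `ToyOpFromGraph` class and its `inner_fn` attribute are invented for this example and are not PyTensor's actual implementation.

```python
# Hypothetical illustration only -- not PyTensor's actual classes.
# Equality/hashing driven by __props__ ignores anything not listed there,
# including the wrapped inner graph of an OpFromGraph-like object.


class ToyOpFromGraph:
    __props__ = ("axis1", "axis2")

    def __init__(self, axis1, axis2, inner_fn):
        self.axis1 = axis1
        self.axis2 = axis2
        self.inner_fn = inner_fn  # stands in for the wrapped graph; not in __props__

    def _props(self):
        return tuple(getattr(self, name) for name in self.__props__)

    def __eq__(self, other):
        return type(self) is type(other) and self._props() == other._props()

    def __hash__(self):
        return hash((type(self), self._props()))


a = ToyOpFromGraph(0, 1, inner_fn=lambda x: x * 2)
b = ToyOpFromGraph(0, 1, inner_fn=lambda x: x * 3)

# Same props, different wrapped computation: a merge rewrite relying on Op
# equality could swap b for a and silently change what the graph computes.
assert a == b and hash(a) == hash(b)
assert a.inner_fn(1) != b.inner_fn(1)
```
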
5 changes: 3 additions & 2 deletions pytensor/tensor/basic.py
@@ -3780,15 +3780,16 @@ class AllocDiag(OpFromGraph):
     Wrapper Op for alloc_diag graphs
     """
 
-    __props__ = ("axis1", "axis2")
-
     def __init__(self, *args, axis1, axis2, offset, **kwargs):
         self.axis1 = axis1
         self.axis2 = axis2
         self.offset = offset
 
         super().__init__(*args, **kwargs, strict=True)
 
+    def __str__(self):
+        return f"AllocDiag{{{self.axis1=}, {self.axis2=}, {self.offset=}}}"
+
     @staticmethod
     def is_offset_zero(node) -> bool:
         """
5 changes: 3 additions & 2 deletions pytensor/tensor/einsum.py
@@ -52,14 +52,15 @@ class Einsum(OpFromGraph):
     desired. We haven't decided whether we want to provide this functionality.
     """
 
-    __props__ = ("subscripts", "path", "optimized")
-
     def __init__(self, *args, subscripts: str, path: PATH, optimized: bool, **kwargs):
         self.subscripts = subscripts
         self.path = path
         self.optimized = optimized
         super().__init__(*args, **kwargs, strict=True)
 
+    def __str__(self):
+        return f"Einsum{{{self.subscripts=}, {self.path=}, {self.optimized=}}}"
+
 
 def _iota(shape: TensorVariable, axis: int) -> TensorVariable:
     """
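
To see the effect of the new `Einsum.__str__`, a quick sketch follows. It assumes `einsum` is exposed at the `pytensor.tensor` level and that `pytensor.dprint` renders the Op via `str()`; neither assumption is stated on this page, so treat the exact output as illustrative.

```python
# Sketch only: builds a small einsum graph and debug-prints it; the Einsum
# node's string form should now include its subscripts/path/optimized fields.
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
y = pt.matrix("y")
out = pt.einsum("ij,jk->ik", x, y)

pytensor.dprint(out)
```
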
13 changes: 13 additions & 0 deletions tests/tensor/test_basic.py
@@ -37,6 +37,7 @@
     TensorFromScalar,
     Tri,
     alloc,
+    alloc_diag,
     arange,
     as_tensor_variable,
     atleast_Nd,
@@ -3793,6 +3794,18 @@ def test_alloc_diag_values(self):
         )
         assert np.all(true_grad_input == grad_input)
 
+    def test_multiple_ops_same_graph(self):
+        """Regression test when AllocDiag OFG was given insufficient props, causing incompatible Ops to be merged."""
+        v1 = vector("v1", shape=(2,), dtype="float64")
+        v2 = vector("v2", shape=(3,), dtype="float64")
+        a1 = alloc_diag(v1)
+        a2 = alloc_diag(v2)
+
+        fn = function([v1, v2], [a1, a2])
+        res1, res2 = fn(v1=[np.e, np.e], v2=[np.pi, np.pi, np.pi])
+        np.testing.assert_allclose(res1, np.eye(2) * np.e)
+        np.testing.assert_allclose(res2, np.eye(3) * np.pi)
+
 
 def test_diagonal_negative_axis():
     x = np.arange(2 * 3 * 3).reshape((2, 3, 3))
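
As a usage note, the new test can be reproduced outside the suite roughly as follows. The import paths are assumed from the test's own import list (`alloc_diag` from `pytensor.tensor.basic`); under the old incomplete `__props__`, the two `AllocDiag` instances compared equal despite wrapping different inner graphs, which is exactly what the test guards against.

```python
# Rough interactive reproduction of the regression test above.
# Assumes alloc_diag is importable from pytensor.tensor.basic, as in the test's imports.
import numpy as np

import pytensor
from pytensor.tensor import vector
from pytensor.tensor.basic import alloc_diag

v1 = vector("v1", shape=(2,), dtype="float64")
v2 = vector("v2", shape=(3,), dtype="float64")

# Two AllocDiag OpFromGraph instances with different inner graphs (2x2 vs 3x3 output);
# they share the same axis/offset settings, so props-based equality would merge them.
fn = pytensor.function([v1, v2], [alloc_diag(v1), alloc_diag(v2)])

res1, res2 = fn(v1=[np.e, np.e], v2=[np.pi, np.pi, np.pi])
np.testing.assert_allclose(res1, np.eye(2) * np.e)
np.testing.assert_allclose(res2, np.eye(3) * np.pi)
```
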
