
Add einsum #722

Merged
2 commits merged on Aug 4, 2024
Changes from all commits
1 change: 1 addition & 0 deletions pytensor/link/jax/dispatch/__init__.py
@@ -4,6 +4,7 @@
# Load dispatch specializations
import pytensor.link.jax.dispatch.blas
import pytensor.link.jax.dispatch.blockwise
import pytensor.link.jax.dispatch.einsum
import pytensor.link.jax.dispatch.elemwise
import pytensor.link.jax.dispatch.extra_ops
import pytensor.link.jax.dispatch.pad
20 changes: 20 additions & 0 deletions pytensor/link/jax/dispatch/einsum.py
@@ -0,0 +1,20 @@
import jax.numpy as jnp

from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.einsum import Einsum


@jax_funcify.register(Einsum)
def jax_funcify_Einsum(op, **kwargs):
"""Dispatch einsum to JAX.

This dispatch is triggered only when we couldn't optimize einsum at the PyTensor level.
This happens when some of the dimension lengths are unknown. This is never a problem in JAX,
as it always compiles a function per runtime input shape.
"""
subscripts = op.subscripts

def einsum(*operands):
return jnp.einsum(subscripts, *operands, optimize="optimal")

return einsum
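
For context, a rough sketch (not part of the diff) of how this dispatch might be exercised from user code; the variable names and the use of mode="JAX" below are illustrative assumptions:

import pytensor
import pytensor.tensor as pt

# Matrices with symbolic dimension lengths, so the einsum cannot be fully
# optimized at the PyTensor level and is left as an Einsum op in the graph.
x = pt.matrix("x")
y = pt.matrix("y")
out = pt.einsum("ij,jk->ik", x, y)

# Compiling with the JAX backend routes the remaining Einsum op to jnp.einsum.
f = pytensor.function([x, y], out, mode="JAX")
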
1 change: 1 addition & 0 deletions pytensor/tensor/__init__.py
@@ -151,6 +151,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:


# isort: off
from pytensor.tensor.einsum import einsum
from pytensor.tensor.functional import vectorize
# isort: on

34 changes: 21 additions & 13 deletions pytensor/tensor/basic.py
@@ -1700,21 +1700,22 @@ def do_constant_folding(self, fgraph, node):
return False

for client, idx in clients:
if isinstance(client.op, Output):
client_op = client.op
if isinstance(client_op, Output):
# If the output is a constant, it will have to be deepcopied
# each time the function is called. So we do not fold.
return False
# Allow alloc to be lifted out of Elemwise before constant folding it
elif isinstance(client.op, Elemwise):
return None
# Op's through which Alloc can be lifted
elif isinstance(client_op, Elemwise | DimShuffle | Alloc | Join):
return False
# Same for Blockwise, unless it has no batch_dims
elif isinstance(client.op, Blockwise) and client.op.batch_ndim(client):
return None
elif isinstance(client_op, Blockwise) and client.op.batch_ndim(client):
return False
elif (
# The following ops work inplace of their input id 0.
idx == 0
and isinstance(
client.op,
client_op,
pytensor.tensor.subtensor.IncSubtensor
| pytensor.tensor.subtensor.AdvancedIncSubtensor1
| pytensor.tensor.subtensor.AdvancedIncSubtensor
@@ -2035,10 +2036,15 @@ def transpose(x, axes=None):
_x = as_tensor_variable(x)

if axes is None:
axes = list(range((_x.type.ndim - 1), -1, -1))
axes = tuple(range((_x.type.ndim - 1), -1, -1))

if tuple(axes) == tuple(range(len(axes))):
# No-op
return _x

ret = DimShuffle(tuple(s == 1 for s in _x.type.shape), axes)(_x)

if _x.name and axes == list(range((_x.type.ndim - 1), -1, -1)):
if _x.name and axes == tuple(range((_x.type.ndim - 1), -1, -1)):
ret.name = _x.name + ".T"

return ret
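
A small illustration (not part of the diff) of the no-op shortcut added above: when the requested axes already match the input order, transpose returns the input variable unchanged instead of building a DimShuffle. The names below are only for demonstration:

import pytensor.tensor as pt

x = pt.tensor("x", shape=(2, 3, 5))
assert pt.transpose(x, axes=(0, 1, 2)) is x     # identity permutation is a no-op
assert pt.transpose(x).type.shape == (5, 3, 2)  # full reversal still goes through DimShuffle
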
@@ -3950,6 +3956,10 @@ def moveaxis(
source = normalize_axis_tuple(source, a.ndim, "source")
destination = normalize_axis_tuple(destination, a.ndim, "destination")

if source == destination:
# It's a no-op
return a

if len(source) != len(destination):
raise ValueError(
"`source` and `destination` arguments must have the same number of elements"
@@ -4260,9 +4270,7 @@ def atleast_Nd(
atleast_3d = partial(atleast_Nd, n=3)


def expand_dims(
a: np.ndarray | TensorVariable, axis: tuple[int, ...]
) -> TensorVariable:
def expand_dims(a: np.ndarray | TensorVariable, axis: Sequence[int]) -> TensorVariable:
"""Expand the shape of an array.

Insert a new axis that will appear at the `axis` position in the expanded
@@ -4281,7 +4289,7 @@ def expand_dims(
"""
a = as_tensor(a)

if not isinstance(axis, tuple | list):
if not isinstance(axis, Sequence):
axis = (axis,)

out_ndim = len(axis) + a.ndim
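
For reference, a hedged usage example of the widened expand_dims signature: axis can now be any integer Sequence (the isinstance check uses Sequence instead of tuple | list), and a single int is still wrapped into a tuple. The names below are illustrative:

import pytensor.tensor as pt

x = pt.matrix("x")              # ndim 2
y = pt.expand_dims(x, (0, -1))  # new length-1 axes at the front and the back
assert y.type.ndim == 4
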