PyTorch inline constants in dispatch to avoid graph breaks #1118

Merged: 2 commits, Feb 10, 2025
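The changes below all follow the same pattern: when an input such as an axis, shape, or index is a `Constant` in the PyTensor graph, its Python value is baked into the returned callable at dispatch time instead of being read out of a tensor at run time. A minimal, hypothetical sketch (not part of the diff; names and values are illustrative) of why this matters for `torch.compile`:

```python
import torch

# A value known when the function is built can be closed over as a plain Python int,
# so dynamo traces the whole call in a single graph (fullgraph=True would raise on a
# graph break). Recovering the same value from a tensor at run time (e.g. via .item())
# is a typical graph-break trigger.
axis = 1  # constant, baked in at build time

@torch.compile(fullgraph=True)
def cat_constant(a, b):
    return torch.cat((a, b), dim=axis)

print(cat_constant(torch.ones(2, 3), torch.zeros(2, 4)).shape)  # torch.Size([2, 7])
```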
44 changes: 37 additions & 7 deletions pytensor/link/pytorch/dispatch/basic.py
@@ -8,6 +8,7 @@
 from pytensor.compile import PYTORCH
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.ops import DeepCopyOp
+from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.ifelse import IfElse
 from pytensor.link.utils import fgraph_to_python
@@ -19,6 +20,7 @@
     Eye,
     Join,
     MakeVector,
+    Split,
     TensorFromScalar,
 )
 
@@ -120,14 +122,23 @@ def arange(start, stop, step):
 
 
 @pytorch_funcify.register(Join)
-def pytorch_funcify_Join(op, **kwargs):
-    def join(axis, *tensors):
-        # tensors could also be tuples, and in this case they don't have a ndim
-        tensors = [torch.tensor(tensor) for tensor in tensors]
+def pytorch_funcify_Join(op, node, **kwargs):
+    axis = node.inputs[0]
 
-        return torch.cat(tensors, dim=axis)
+    if isinstance(axis, Constant):
+        axis = int(axis.data)
 
-    return join
+        def join_constant_axis(_, *tensors):
+            return torch.cat(tensors, dim=axis)
+
+        return join_constant_axis
+
+    else:
+
+        def join(axis, *tensors):
+            return torch.cat(tensors, dim=axis)
+
+        return join
 
 
 @pytorch_funcify.register(Eye)
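A usage-level sketch of the constant-axis branch (values are hypothetical): the linker still passes the axis input when it calls the function, so the constant branch keeps a throwaway first parameter to preserve the `(axis, *tensors)` calling convention while ignoring the runtime value.

```python
import torch

# The baked-in Python int supplies dim; the first argument (the runtime axis,
# here a tensor) is accepted but deliberately unused.
constant_axis = 0

def join_constant_axis(_, *tensors):
    return torch.cat(tensors, dim=constant_axis)

out = join_constant_axis(torch.tensor(0), torch.ones(2), torch.zeros(3))
print(out.shape)  # torch.Size([5])
```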
@@ -172,7 +183,6 @@ def ifelse(cond, *true_and_false, n_outs=n_outs):
 @pytorch_funcify.register(OpFromGraph)
 def pytorch_funcify_OpFromGraph(op, node, **kwargs):
     kwargs.pop("storage_map", None)
-
     # Apply inner rewrites
     PYTORCH.optimizer(op.fgraph)
     fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
@@ -185,3 +195,23 @@ def tensorfromscalar(x):
         return torch.as_tensor(x)
 
     return tensorfromscalar
+
+
+@pytorch_funcify.register(Split)
+def pytorch_funcify_Split(op, node, **kwargs):
+    x, dim, split_sizes = node.inputs
+    if isinstance(dim, Constant) and isinstance(split_sizes, Constant):
+        dim = int(dim.data)
+        split_sizes = tuple(int(size) for size in split_sizes.data)
+
+        def split_constant_axis_and_sizes(x, *_):
+            return x.split(split_sizes, dim=dim)
+
+        return split_constant_axis_and_sizes
+
+    else:
+
+        def inner_fn(x, dim, split_amounts):
+            return x.split(split_amounts.tolist(), dim=dim.item())
+
+        return inner_fn
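A hedged sketch of the two `Split` branches (hypothetical values): with constant `dim` and sizes, `Tensor.split` sees static sections; the fallback has to call `split_amounts.tolist()` and `dim.item()`, which read tensor data at run time and are the usual source of the graph break this PR works around.

```python
import torch

x = torch.arange(10.0)

# Constant branch: dim and sizes are plain Python values baked in at dispatch time.
dim, split_sizes = 0, (2, 3, 5)

def split_constant_axis_and_sizes(x, *_):
    return x.split(split_sizes, dim=dim)

print([tuple(p.shape) for p in split_constant_axis_and_sizes(x)])  # [(2,), (3,), (5,)]

# Fallback branch: the same values arrive as tensors and must be converted at run time.
def inner_fn(x, dim, split_amounts):
    return x.split(split_amounts.tolist(), dim=dim.item())

print([tuple(p.shape) for p in inner_fn(x, torch.tensor(0), torch.tensor([2, 3, 5]))])
```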
6 changes: 6 additions & 0 deletions pytensor/link/pytorch/dispatch/scalar.py
@@ -5,12 +5,18 @@
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.scalar.basic import (
     Cast,
+    Invert,
     ScalarOp,
 )
 from pytensor.scalar.loop import ScalarLoop
 from pytensor.scalar.math import Softplus
 
 
+@pytorch_funcify.register(Invert)
+def pytorch_funcify_invert(op, node, **kwargs):
+    return torch.bitwise_not
+
+
 @pytorch_funcify.register(ScalarOp)
 def pytorch_funcify_ScalarOp(op, node, **kwargs):
     """Return pytorch function that implements the same computation as the Scalar Op.
19 changes: 16 additions & 3 deletions pytensor/link/pytorch/dispatch/shape.py
@@ -1,15 +1,28 @@
 import torch
 
+from pytensor.graph.basic import Constant
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
 
 
 @pytorch_funcify.register(Reshape)
 def pytorch_funcify_Reshape(op, node, **kwargs):
-    def reshape(x, shape):
-        return torch.reshape(x, tuple(shape))
+    _, shape = node.inputs
 
-    return reshape
+    if isinstance(shape, Constant):
+        constant_shape = tuple(int(dim) for dim in shape.data)
+
+        def reshape_constant_shape(x, *_):
+            return torch.reshape(x, constant_shape)
+
+        return reshape_constant_shape
+
+    else:
+
+        def reshape(x, shape):
+            return torch.reshape(x, tuple(shape))
+
+        return reshape
 
 
 @pytorch_funcify.register(Shape)
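A hedged sketch of the constant-shape branch (hypothetical values): the target shape is frozen into a Python tuple at dispatch time, so no shape tensor has to be turned into ints inside the compiled region; the runtime shape input is still accepted but ignored.

```python
import torch

constant_shape = (3, 4)

def reshape_constant_shape(x, *_):
    return torch.reshape(x, constant_shape)

out = reshape_constant_shape(torch.arange(12.0), torch.tensor([3, 4]))
print(out.shape)  # torch.Size([3, 4])
```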
15 changes: 15 additions & 0 deletions pytensor/link/pytorch/dispatch/subtensor.py
@@ -1,3 +1,4 @@
+from pytensor.graph.basic import Constant
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.tensor.subtensor import (
     AdvancedIncSubtensor,
@@ -23,7 +24,21 @@ def check_negative_steps(indices):
 @pytorch_funcify.register(Subtensor)
 def pytorch_funcify_Subtensor(op, node, **kwargs):
     idx_list = op.idx_list
+    x, *idxs = node.inputs
 
+    if all(isinstance(idx, Constant) for idx in idxs):
+        # Use constant indices to avoid graph break
+        constant_indices = indices_from_subtensor(
+            [int(idx.data) for idx in idxs], idx_list
+        )
+        check_negative_steps(constant_indices)
+
+        def constant_index_subtensor(x, *_):
+            return x[constant_indices]
+
+        return constant_index_subtensor
+
+    # Fallback that will introduce a graph break
     def subtensor(x, *flattened_indices):
         indices = indices_from_subtensor(flattened_indices, idx_list)
         check_negative_steps(indices)
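A hedged sketch of the constant-index branch (hypothetical index values): when every index input is a `Constant`, the slice objects are built once from Python ints and the returned callable just applies them, so nothing data-dependent remains inside the traced function.

```python
import torch

# Equivalent of x[1:4, 2] with the indices frozen at dispatch time; the runtime
# index inputs are accepted but ignored.
constant_indices = (slice(1, 4), 2)

def constant_index_subtensor(x, *_):
    return x[constant_indices]

x = torch.arange(25.0).reshape(5, 5)
print(constant_index_subtensor(x, torch.tensor(1), torch.tensor(4), torch.tensor(2)))
# tensor([ 7., 12., 17.])
```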
3 changes: 3 additions & 0 deletions pytensor/link/pytorch/linker.py
@@ -37,6 +37,9 @@ def conversion_func_register(*args, **kwargs):
     def jit_compile(self, fn):
         import torch
 
+        # flag that tends to help our graphs
+        torch._dynamo.config.capture_dynamic_output_shape_ops = True
+
Contributor comment on the flag above: Hopefully when #1159 gets merged we can just delete this flag altogether, since torch will know these aren't dynamic.


         from pytensor.link.pytorch.dispatch import pytorch_typify
 
         class wrapper:
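A hedged sketch of what `capture_dynamic_output_shape_ops` buys us (illustrative, not from the diff): ops whose output shape depends on tensor data, `torch.nonzero` being the canonical case and a data-driven split behaving similarly, normally force dynamo to fall back to eager; with the flag set they can stay inside the captured graph.

```python
import torch

torch._dynamo.config.capture_dynamic_output_shape_ops = True

@torch.compile
def positive_positions(x):
    # nonzero's output length depends on the data, i.e. a dynamic output shape
    return torch.nonzero(x > 0)

print(positive_positions(torch.tensor([-1.0, 0.5, 0.0, 2.0])))  # tensor([[1], [3]])
```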
50 changes: 50 additions & 0 deletions tests/link/pytorch/test_basic.py
@@ -471,3 +471,53 @@ def test_ScalarLoop_Elemwise_multi_carries():
     compare_pytorch_and_py(
         f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6)
     )
+
+
+rng = np.random.default_rng(42849)
+
+
+@pytest.mark.parametrize(
+    "n_splits, axis, values, sizes",
+    [
+        (
+            0,
+            0,
+            rng.normal(size=20).astype(config.floatX),
+            [],
+        ),
+        (
+            5,
+            0,
+            rng.normal(size=5).astype(config.floatX),
+            rng.multinomial(5, np.ones(5) / 5),
+        ),
+        (
+            5,
+            0,
+            rng.normal(size=10).astype(config.floatX),
+            rng.multinomial(10, np.ones(5) / 5),
+        ),
+        (
+            5,
+            -1,
+            rng.normal(size=(11, 7)).astype(config.floatX),
+            rng.multinomial(7, np.ones(5) / 5),
+        ),
+        (
+            5,
+            -2,
+            rng.normal(size=(11, 7)).astype(config.floatX),
+            rng.multinomial(11, np.ones(5) / 5),
+        ),
+    ],
+)
+def test_Split(n_splits, axis, values, sizes):
+    i = pt.tensor("i", shape=values.shape, dtype=config.floatX)
+    s = pt.vector("s", dtype="int64")
+    g = pt.split(i, s, n_splits, axis=axis)
+    assert len(g) == n_splits
+    if n_splits == 0:
+        return
+    g_fg = FunctionGraph(inputs=[i, s], outputs=[g] if n_splits == 1 else g)
+
+    compare_pytorch_and_py(g_fg, [values, sizes])
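Beyond the parametrized test, an end-to-end sketch of the constant path (assumes the PyTorch backend is installed and registered under mode="PYTORCH"; values are illustrative): literal split sizes should be stored as constants in the graph, letting the dispatcher take the constant-sizes branch rather than the fallback.

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x", dtype="float64")
parts = pt.split(x, [2, 3], 2, axis=0)  # literal sizes become graph constants

fn = pytensor.function([x], parts, mode="PYTORCH")
print([tuple(p.shape) for p in fn(np.arange(5.0))])  # [(2,), (3,)]
```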