From a6b2f48e518d3a1e979a236d63be6971568ad3ee Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Thu, 12 Dec 2024 11:07:11 +0100
Subject: [PATCH] PyTorch inline constants in dispatch to avoid graph breaks

---
 pytensor/link/pytorch/dispatch/basic.py     | 41 ++++++++++++++++-----
 pytensor/link/pytorch/dispatch/shape.py     | 19 ++++++++--
 pytensor/link/pytorch/dispatch/subtensor.py | 15 ++++++++
 3 files changed, 62 insertions(+), 13 deletions(-)

diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py
index 6cf4f29aab..ef4bf10637 100644
--- a/pytensor/link/pytorch/dispatch/basic.py
+++ b/pytensor/link/pytorch/dispatch/basic.py
@@ -8,6 +8,7 @@
 from pytensor.compile import PYTORCH
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.ops import DeepCopyOp
+from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.ifelse import IfElse
 from pytensor.link.utils import fgraph_to_python
@@ -121,14 +122,23 @@ def arange(start, stop, step):
 
 
 @pytorch_funcify.register(Join)
-def pytorch_funcify_Join(op, **kwargs):
-    def join(axis, *tensors):
-        # tensors could also be tuples, and in this case they don't have a ndim
-        tensors = [torch.tensor(tensor) for tensor in tensors]
+def pytorch_funcify_Join(op, node, **kwargs):
+    axis = node.inputs[0]
 
-        return torch.cat(tensors, dim=axis)
+    if isinstance(axis, Constant):
+        axis = int(axis.data)
 
-    return join
+        def join_constant_axis(_, *tensors):
+            return torch.cat(tensors, dim=axis)
+
+        return join_constant_axis
+
+    else:
+
+        def join(axis, *tensors):
+            return torch.cat(tensors, dim=axis)
+
+        return join
 
 
 @pytorch_funcify.register(Eye)
@@ -173,7 +183,6 @@ def ifelse(cond, *true_and_false, n_outs=n_outs):
 @pytorch_funcify.register(OpFromGraph)
 def pytorch_funcify_OpFromGraph(op, node, **kwargs):
     kwargs.pop("storage_map", None)
-    # Apply inner rewrites
     PYTORCH.optimizer(op.fgraph)
     fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
 
@@ -190,7 +199,19 @@ def tensorfromscalar(x):
 
 @pytorch_funcify.register(Split)
 def pytorch_funcify_Split(op, node, **kwargs):
-    def inner_fn(x, dim, split_amounts):
-        return x.split(split_amounts.tolist(), dim=dim.item())
+    x, dim, split_sizes = node.inputs
+    if isinstance(dim, Constant) and isinstance(split_sizes, Constant):
+        dim = int(dim.data)
+        split_sizes = tuple(int(size) for size in split_sizes.data)
+
+        def split_constant_axis_and_sizes(x, *_):
+            return x.split(split_sizes, dim=dim)
+
+        return split_constant_axis_and_sizes
+
+    else:
+
+        def inner_fn(x, dim, split_amounts):
+            return x.split(split_amounts.tolist(), dim=dim.item())
 
-    return inner_fn
+        return inner_fn
diff --git a/pytensor/link/pytorch/dispatch/shape.py b/pytensor/link/pytorch/dispatch/shape.py
index f771ac7211..c15b3a3779 100644
--- a/pytensor/link/pytorch/dispatch/shape.py
+++ b/pytensor/link/pytorch/dispatch/shape.py
@@ -1,15 +1,28 @@
 import torch
 
+from pytensor.graph.basic import Constant
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
 
 
 @pytorch_funcify.register(Reshape)
 def pytorch_funcify_Reshape(op, node, **kwargs):
-    def reshape(x, shape):
-        return torch.reshape(x, tuple(shape))
+    _, shape = node.inputs
 
-    return reshape
+    if isinstance(shape, Constant):
+        constant_shape = tuple(int(dim) for dim in shape.data)
+
+        def reshape_constant_shape(x, *_):
+            return torch.reshape(x, constant_shape)
+
+        return reshape_constant_shape
+
+    else:
+
+        def reshape(x, shape):
+            return torch.reshape(x, tuple(shape))
+
+        return reshape
 
 
 @pytorch_funcify.register(Shape)
diff --git a/pytensor/link/pytorch/dispatch/subtensor.py b/pytensor/link/pytorch/dispatch/subtensor.py
index 75e7ec0776..34358797fb 100644
--- a/pytensor/link/pytorch/dispatch/subtensor.py
+++ b/pytensor/link/pytorch/dispatch/subtensor.py
@@ -1,3 +1,4 @@
+from pytensor.graph.basic import Constant
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.tensor.subtensor import (
     AdvancedIncSubtensor,
@@ -23,7 +24,21 @@ def check_negative_steps(indices):
 @pytorch_funcify.register(Subtensor)
 def pytorch_funcify_Subtensor(op, node, **kwargs):
     idx_list = op.idx_list
+    x, *idxs = node.inputs
 
+    if all(isinstance(idx, Constant) for idx in idxs):
+        # Use constant indices to avoid graph break
+        constant_indices = indices_from_subtensor(
+            [int(idx.data) for idx in idxs], idx_list
+        )
+        check_negative_steps(constant_indices)
+
+        def constant_index_subtensor(x, *_):
+            return x[constant_indices]
+
+        return constant_index_subtensor
+
+    # Fallback that will introduce a graph break
     def subtensor(x, *flattened_indices):
         indices = indices_from_subtensor(flattened_indices, idx_list)
        check_negative_steps(indices)
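
A minimal sketch (not part of the patch) of the behavior it targets: when an
axis or shape reaches the runtime closure as a tensor, torch.compile typically
has to read the value on the host (e.g. via `.item()`), which forces a graph
break; baking the Python int into the closure at dispatch time, as the patch
does for Join, Split, Reshape, and Subtensor, keeps the trace whole. The
helper names `make_join_inlined` and `make_join_dynamic` below are
illustrative only.

import torch


def make_join_inlined(axis: int):
    # `axis` was a Constant in the PyTensor graph, so its Python int is
    # captured in the closure; torch.compile sees a fixed `dim` and can
    # trace the whole call as a single graph.
    def join_constant_axis(*tensors):
        return torch.cat(tensors, dim=axis)

    return join_constant_axis


def make_join_dynamic():
    # `axis` arrives as a runtime tensor; reading its value with .item()
    # inside the traced region is a classic graph-break trigger, so by
    # default compiling this with fullgraph=True fails.
    def join(axis, *tensors):
        return torch.cat(tensors, dim=int(axis.item()))

    return join


a, b = torch.ones(2, 3), torch.zeros(2, 3)
joined = torch.compile(make_join_inlined(axis=0), fullgraph=True)(a, b)
print(joined.shape)  # torch.Size([4, 3]), traced without a graph break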