PyTorch inline constants in dispatch to avoid graph breaks
ricardoV94 committed Dec 12, 2024
1 parent 7ce2a5d commit a6b2f48
Showing 3 changed files with 62 additions and 13 deletions.
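
The pattern throughout the three files: when an input such as an axis, shape, or index is a Constant in the PyTensor graph, its Python value is baked into the returned closure instead of being read out of a tensor at runtime. Calls like .item() and .tolist() typically force torch.compile (Dynamo) to break the graph, so inlining the value avoids the break. A minimal sketch of the idea (illustrative only, not code from this commit; assumes PyTorch 2.x):

    import torch

    # Dynamic variant: the axis arrives as a 0-d tensor, and .item() reads its
    # value into Python -- under torch.compile this typically forces a graph
    # break (and is an error with fullgraph=True).
    def cat_dynamic(axis, *tensors):
        return torch.cat(tensors, dim=int(axis.item()))

    # Constant variant: the axis is baked into the closure as a Python int
    # when the function is built, so nothing leaves the graph at runtime.
    def make_cat_constant(axis: int):
        def cat_constant(*tensors):
            return torch.cat(tensors, dim=axis)
        return cat_constant

    compiled = torch.compile(make_cat_constant(0), fullgraph=True)
    print(compiled(torch.ones(2, 3), torch.zeros(2, 3)).shape)  # torch.Size([4, 3])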
41 changes: 31 additions & 10 deletions pytensor/link/pytorch/dispatch/basic.py
@@ -8,6 +8,7 @@
 from pytensor.compile import PYTORCH
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.ops import DeepCopyOp
+from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.ifelse import IfElse
 from pytensor.link.utils import fgraph_to_python
@@ -121,14 +122,23 @@ def arange(start, stop, step):
 
 
 @pytorch_funcify.register(Join)
-def pytorch_funcify_Join(op, **kwargs):
-    def join(axis, *tensors):
-        # tensors could also be tuples, and in this case they don't have a ndim
-        tensors = [torch.tensor(tensor) for tensor in tensors]
-
-        return torch.cat(tensors, dim=axis)
-
-    return join
+def pytorch_funcify_Join(op, node, **kwargs):
+    axis = node.inputs[0]
+
+    if isinstance(axis, Constant):
+        axis = int(axis.data)
+
+        def join_constant_axis(_, *tensors):
+            return torch.cat(tensors, dim=axis)
+
+        return join_constant_axis
+
+    else:
+
+        def join(axis, *tensors):
+            return torch.cat(tensors, dim=axis)
+
+        return join
 
 
 @pytorch_funcify.register(Eye)
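
A hedged usage sketch for the new Join path (assumes a PyTensor installation with the torch backend available): a literal axis becomes a Constant node in the graph, so the dispatcher above returns join_constant_axis and torch.cat receives a plain Python int.

    import pytensor
    import pytensor.tensor as pt

    x = pt.matrix("x")
    y = pt.matrix("y")
    out = pt.join(0, x, y)  # the literal axis 0 is stored as a Constant node
    fn = pytensor.function([x, y], out, mode="PYTORCH")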
@@ -173,7 +183,6 @@ def ifelse(cond, *true_and_false, n_outs=n_outs):
 @pytorch_funcify.register(OpFromGraph)
 def pytorch_funcify_OpFromGraph(op, node, **kwargs):
     kwargs.pop("storage_map", None)
-
     # Apply inner rewrites
     PYTORCH.optimizer(op.fgraph)
     fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
@@ -190,7 +199,19 @@ def tensorfromscalar(x):
 
 @pytorch_funcify.register(Split)
 def pytorch_funcify_Split(op, node, **kwargs):
-    def inner_fn(x, dim, split_amounts):
-        return x.split(split_amounts.tolist(), dim=dim.item())
-
-    return inner_fn
+    x, dim, split_sizes = node.inputs
+    if isinstance(dim, Constant) and isinstance(split_sizes, Constant):
+        dim = int(dim.data)
+        split_sizes = tuple(int(size) for size in split_sizes.data)
+
+        def split_constant_axis_and_sizes(x, *_):
+            return x.split(split_sizes, dim=dim)
+
+        return split_constant_axis_and_sizes
+
+    else:
+
+        def inner_fn(x, dim, split_amounts):
+            return x.split(split_amounts.tolist(), dim=dim.item())
+
+        return inner_fn
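
The Split change follows the same pattern, and the fallback's .tolist() and .item() calls are exactly the reads that trigger graph breaks. A rough torch-level contrast (illustrative, not from the commit):

    import torch

    t = torch.arange(6.0)

    # Fallback path: sizes and dim live in tensors, so their values must be
    # pulled into Python at call time -- a graph-break point under torch.compile.
    sizes, dim = torch.tensor([2, 4]), torch.tensor(0)
    parts = t.split(sizes.tolist(), dim=dim.item())

    # Constant path: plain Python ints, nothing to extract at runtime.
    parts = t.split((2, 4), dim=0)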
19 changes: 16 additions & 3 deletions pytensor/link/pytorch/dispatch/shape.py
@@ -1,15 +1,28 @@
 import torch
 
+from pytensor.graph.basic import Constant
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
 
 
 @pytorch_funcify.register(Reshape)
 def pytorch_funcify_Reshape(op, node, **kwargs):
-    def reshape(x, shape):
-        return torch.reshape(x, tuple(shape))
-
-    return reshape
+    _, shape = node.inputs
+
+    if isinstance(shape, Constant):
+        constant_shape = tuple(int(dim) for dim in shape.data)
+
+        def reshape_constant_shape(x, *_):
+            return torch.reshape(x, constant_shape)
+
+        return reshape_constant_shape
+
+    else:
+
+        def reshape(x, shape):
+            return torch.reshape(x, tuple(shape))
+
+        return reshape
 
 
 @pytorch_funcify.register(Shape)
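
As with Join, the payoff shows up under torch.compile: a shape captured as a Python tuple compiles cleanly, whereas tuple(shape) over a shape tensor must read each dimension out of the graph. A minimal sketch (names are illustrative, not PyTensor's):

    import torch

    def make_reshape_constant(constant_shape):
        # constant_shape is an ordinary Python tuple, fixed when the
        # function is built, so torch.reshape sees only static ints.
        def reshape_constant(x):
            return torch.reshape(x, constant_shape)
        return reshape_constant

    f = torch.compile(make_reshape_constant((2, 3)), fullgraph=True)
    assert f(torch.arange(6.0)).shape == (2, 3)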
15 changes: 15 additions & 0 deletions pytensor/link/pytorch/dispatch/subtensor.py
@@ -1,3 +1,4 @@
+from pytensor.graph.basic import Constant
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.tensor.subtensor import (
     AdvancedIncSubtensor,
@@ -23,7 +24,21 @@ def check_negative_steps(indices):
 @pytorch_funcify.register(Subtensor)
 def pytorch_funcify_Subtensor(op, node, **kwargs):
     idx_list = op.idx_list
+    x, *idxs = node.inputs
+
+    if all(isinstance(idx, Constant) for idx in idxs):
+        # Use constant indices to avoid graph break
+        constant_indices = indices_from_subtensor(
+            [int(idx.data) for idx in idxs], idx_list
+        )
+        check_negative_steps(constant_indices)
+
+        def constant_index_subtensor(x, *_):
+            return x[constant_indices]
+
+        return constant_index_subtensor
+
+    # Fallback that will introduce a graph break
     def subtensor(x, *flattened_indices):
         indices = indices_from_subtensor(flattened_indices, idx_list)
         check_negative_steps(indices)
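
Here indices_from_subtensor runs once at funcify time, so the returned closure indexes with ordinary Python ints and slices rather than tensor values. A hedged sketch of why that sidesteps the break (the helper below is illustrative, not PyTensor's):

    import torch

    def make_constant_getitem(constant_indices):
        # constant_indices is a tuple of Python ints/slices computed ahead
        # of time, so indexing is static from torch.compile's point of view.
        def constant_getitem(x):
            return x[constant_indices]
        return constant_getitem

    g = torch.compile(make_constant_getitem((slice(1, None), 0)), fullgraph=True)
    print(g(torch.eye(3)))  # tensor([0., 0.])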
