Experimental: PyTorch backend #457

Closed · wants to merge 4 commits
21 changes: 19 additions & 2 deletions pytensor/compile/mode.py
@@ -28,6 +28,7 @@
from pytensor.link.c.basic import CLinker, OpWiseCLinker
from pytensor.link.jax.linker import JAXLinker
from pytensor.link.numba.linker import NumbaLinker
from pytensor.link.pytorch.linker import PyTorchLinker
from pytensor.link.vm import VMLinker


@@ -48,6 +49,7 @@
"cvm_nogc": VMLinker(allow_gc=False, use_cloop=True),
"jax": JAXLinker(),
"numba": NumbaLinker(),
"pytorch": PyTorchLinker(),
}


@@ -469,13 +471,26 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
),
)

PYTORCH = Mode(
PyTorchLinker(),
RewriteDatabaseQuery(
include=["fast_run"],
exclude=[
"cxx_only",
"BlasOpt",
"fusion",
"inplace",
"local_uint_constant_indices",
],
),
)

predefined_modes = {
"FAST_COMPILE": FAST_COMPILE,
"FAST_RUN": FAST_RUN,
"JAX": JAX,
"NUMBA": NUMBA,
"PYTORCH": PYTORCH,
}

instantiated_default_mode = None
@@ -548,7 +563,7 @@ def register_mode(name, mode):
predefined_modes[name] = mode


def get_target_language(mode=None) -> Tuple[Literal["py", "c", "numba", "jax"], ...]:
def get_target_language(mode=None) -> Tuple[Literal["py", "c", "numba", "jax", "pytorch"], ...]:
"""Get the compilation target language."""

if mode is None:
@@ -560,6 +575,8 @@ def get_target_language(mode=None) -> Tuple[Literal["py", "c", "numba", "jax"],
return ("numba",)
if isinstance(linker, JAXLinker):
return ("jax",)
if isinstance(linker, PyTorchLinker):
return ("pytorch",)
if isinstance(linker, PerformLinker):
return ("py",)
if isinstance(linker, CLinker):
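With `PyTorchLinker` added to `predefined_linkers` and the `PYTORCH` mode registered above, the backend should be selectable the same way as JAX or Numba. A minimal usage sketch (assuming the linker and dispatch modules introduced in this PR are importable; this snippet is not part of the diff itself):

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.exp(x).sum()

# "PYTORCH" is resolved through `predefined_modes`, which wires in
# `PyTorchLinker` plus the rewrite exclusions defined above.
fn = pytensor.function([x], y, mode="PYTORCH")
fn([0.0, 1.0, 2.0])
```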
1 change: 1 addition & 0 deletions pytensor/configdefaults.py
@@ -107,6 +107,7 @@ def _filter_mode(val):
"DEBUG_MODE",
"JAX",
"NUMBA",
"PYTORCH"
]
if val in str_options:
return val
17 changes: 17 additions & 0 deletions pytensor/link/pytorch/dispatch/__init__.py
@@ -0,0 +1,17 @@
# isort: off
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify

# Load dispatch specializations
import pytensor.link.pytorch.dispatch.scalar
import pytensor.link.pytorch.dispatch.tensor_basic
import pytensor.link.pytorch.dispatch.subtensor
import pytensor.link.pytorch.dispatch.shape
import pytensor.link.pytorch.dispatch.extra_ops
import pytensor.link.pytorch.dispatch.nlinalg
import pytensor.link.pytorch.dispatch.slinalg
import pytensor.link.pytorch.dispatch.random
import pytensor.link.pytorch.dispatch.elemwise
import pytensor.link.pytorch.dispatch.scan
import pytensor.link.pytorch.dispatch.sparse

# isort: on
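These imports are side-effectful: each submodule registers its translations on the `pytorch_funcify` single dispatcher, which is why the import order is pinned with the `isort` guards. As an illustration of the extension point (the `MyOp` import below is a hypothetical placeholder, not something added by this PR), a downstream `Op` would be registered the same way:

```python
import torch

from pytensor.link.pytorch.dispatch import pytorch_funcify
from mypackage.ops import MyOp  # hypothetical Op, for illustration only


@pytorch_funcify.register(MyOp)
def pytorch_funcify_MyOp(op, node=None, **kwargs):
    # Return a callable that mirrors `MyOp.perform` using torch primitives.
    def my_op(x):
        return torch.square(x)

    return my_op
```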
109 changes: 109 additions & 0 deletions pytensor/link/pytorch/dispatch/basic.py
@@ -0,0 +1,109 @@
import warnings
from functools import singledispatch

import numpy as np
import torch

from pytensor.compile.ops import DeepCopyOp, ViewOp
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.ifelse import IfElse
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import Assert, CheckAndRaise


@singledispatch
def pytorch_typify(data, dtype=None, **kwargs):
r"""Convert instances of PyTensor `Type`\s to PyTorch types."""
if dtype is None:
return data
else:
return torch.tensor(data, dtype=dtype)


@pytorch_typify.register(np.ndarray)
def pytorch_typify_ndarray(data, dtype=None, **kwargs):
if len(data.shape) == 0:
return data.item()
return torch.tensor(data, dtype=dtype)


@singledispatch
def pytorch_funcify(op, node=None, storage_map=None, **kwargs):
"""Create a PyTorch compatible function from an PyTensor `Op`."""
raise NotImplementedError(f"No PyTorch conversion for the given `Op`: {op}")


@pytorch_funcify.register(FunctionGraph)
def pytorch_funcify_FunctionGraph(
fgraph,
node=None,
fgraph_name="torch_funcified_fgraph",
**kwargs,
):
return fgraph_to_python(
fgraph,
pytorch_funcify,
type_conversion_fn=pytorch_typify,
fgraph_name=fgraph_name,
**kwargs,
)


@pytorch_funcify.register(IfElse)
def pytorch_funcify_IfElse(op, **kwargs):
n_outs = op.n_outs

    def ifelse(cond, *args, n_outs=n_outs):
        # `torch.where` selects between tensors element-wise and cannot take
        # branch callables, so pick the branch outputs with a Python
        # conditional on the (scalar) condition.
        res = args[:n_outs] if cond else args[n_outs:]
        return res if n_outs > 1 else res[0]

return ifelse


@pytorch_funcify.register(Assert)
@pytorch_funcify.register(CheckAndRaise)
def pytorch_funcify_CheckAndRaise(op, **kwargs):
warnings.warn(
f"""Skipping `CheckAndRaise` Op (assertion: {op.msg}) as PyTorch tracing would remove it.""",
stacklevel=2,
)

def assert_fn(x, *inputs):
return x

return assert_fn


def pytorch_safe_copy(x):
try:
res = torch.clone(x)
except NotImplementedError:
        warnings.warn(
            "`torch.clone` is not implemented for this input; "
            "falling back to the object's `copy` method."
        )
if hasattr(x, "copy"):
res = torch.tensor(x.copy())
else:
warnings.warn(f"Object has no `copy` method: {x}")
res = x

return res


@pytorch_funcify.register(DeepCopyOp)
def pytorch_funcify_DeepCopyOp(op, **kwargs):
def deepcopyop(x):
return pytorch_safe_copy(x)

return deepcopyop


@pytorch_funcify.register(ViewOp)
def pytorch_funcify_ViewOp(op, **kwargs):
def viewop(x):
return x

return viewop
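The `FunctionGraph` registration is the entry point the linker relies on: `fgraph_to_python` walks the graph and composes the per-`Op` callables, with `pytorch_typify` converting constants. A rough sketch of driving that conversion by hand, outside the linker (assuming the dispatch modules above are importable):

```python
import torch

import pytensor.tensor as pt
from pytensor.graph.fg import FunctionGraph
from pytensor.link.pytorch.dispatch import pytorch_funcify

x = pt.vector("x")
out = pt.exp(x) + 1.0

# Translate the whole graph into one Python callable whose body calls the
# dispatched PyTorch implementations.
fgraph = FunctionGraph([x], [out])
torch_fn = pytorch_funcify(fgraph)

# Returns the graph output(s) computed with torch operations.
result = torch_fn(torch.tensor([0.0, 1.0, 2.0]))
```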
113 changes: 113 additions & 0 deletions pytensor/link/pytorch/dispatch/elemwise.py
@@ -0,0 +1,113 @@
import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad


@pytorch_funcify.register(Elemwise)
def pytorch_funcify_Elemwise(op, node, **kwargs):
scalar_op = op.scalar_op
base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)

def elemwise_fn(*inputs):
        # Scalar inputs may arrive as plain Python ints/floats; wrap everything
        # with `torch.as_tensor` (a no-op for existing tensors) just for the
        # runtime broadcast check.
        Elemwise._check_runtime_broadcast(node, tuple(map(torch.as_tensor, inputs)))
return base_fn(*inputs)

return elemwise_fn


@pytorch_funcify.register(CAReduce)
def pytorch_funcify_CAReduce(op, **kwargs):
axis = op.axis
op_nfunc_spec = getattr(op, "nfunc_spec", None)
scalar_nfunc_spec = getattr(op.scalar_op, "nfunc_spec", None)
scalar_op_name = getattr(op.scalar_op, "name", None)
scalar_op_identity = getattr(op.scalar_op, "identity", None)
acc_dtype = getattr(op, "acc_dtype", None)

    def careduce(x):
        nonlocal axis, op_nfunc_spec, scalar_nfunc_spec, scalar_op_name, scalar_op_identity, acc_dtype

        if axis is None:
            axis = list(range(x.ndim))

        if acc_dtype is None:
            acc_dtype = x.dtype
        elif isinstance(acc_dtype, str):
            # PyTensor stores dtypes as strings (e.g. "float64"); map them to
            # the corresponding `torch.dtype`.
            acc_dtype = getattr(torch, acc_dtype)

        if op_nfunc_spec:
            torch_op = getattr(torch, op_nfunc_spec[0])
            return torch_op(x, dim=axis).to(acc_dtype)

        # The PyTensor `Op` didn't tell us which PyTorch reduction to use (or
        # there isn't one), so fall back to folding each reduced dimension
        # with the scalar op's binary torch function.
        if scalar_nfunc_spec:
            scalar_fn_name = scalar_nfunc_spec[0]
        elif scalar_op_name:
            scalar_fn_name = scalar_op_name

        to_reduce = sorted(axis, reverse=True)

        if to_reduce:
            torch_op = getattr(torch, scalar_fn_name)
            init_value = torch.tensor(scalar_op_identity, dtype=acc_dtype)
            res = x.to(acc_dtype)
            for dim in to_reduce:
                # Fold the dimension away, starting from the identity element;
                # reducing the largest dim first keeps the remaining dims valid.
                acc = init_value
                for chunk in res.unbind(dim):
                    acc = torch_op(acc, chunk)
                res = acc
            return res
        else:
            return x

return careduce


@pytorch_funcify.register(DimShuffle)
def pytorch_funcify_DimShuffle(op, **kwargs):
    def dimshuffle(x):
        # `torch.transpose` only swaps two dimensions; a DimShuffle needs a
        # full axis permutation, which is what `torch.permute` provides.
        res = torch.permute(x, op.transposition)

shape = list(res.shape[: len(op.shuffle)])

for augm in op.augment:
shape.insert(augm, 1)

res = torch.reshape(res, shape)

if not op.inplace:
res = torch.clone(res)

return res

return dimshuffle


@pytorch_funcify.register(Softmax)
def pytorch_funcify_Softmax(op, **kwargs):
axis = op.axis

def softmax(x):
return torch.nn.functional.softmax(x, dim=axis)

return softmax


@pytorch_funcify.register(SoftmaxGrad)
def pytorch_funcify_SoftmaxGrad(op, **kwargs):
axis = op.axis

def softmax_grad(dy, sm):
dy_times_sm = dy * sm
return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdim=True) * sm

return softmax_grad


@pytorch_funcify.register(LogSoftmax)
def pytorch_funcify_LogSoftmax(op, **kwargs):
axis = op.axis

def log_softmax(x):
return torch.nn.functional.log_softmax(x, dim=axis)

return log_softmax
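Since `Softmax`, `SoftmaxGrad`, and `LogSoftmax` map directly onto `torch.nn.functional`, results should agree with the default backend up to floating-point tolerance. A sketch of that check (assuming the `PYTORCH` mode registered earlier in this PR, and that outputs come back as objects `numpy` can coerce):

```python
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.special import softmax

x = pt.matrix("x")
out = softmax(x, axis=-1)

torch_fn = pytensor.function([x], out, mode="PYTORCH")
ref_fn = pytensor.function([x], out)  # default backend, for comparison

a = np.random.default_rng(0).normal(size=(3, 4))
np.testing.assert_allclose(np.asarray(torch_fn(a)), ref_fn(a), rtol=1e-6)
```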