Merge branch 'main' into ifelse_torch
Ch0ronomato authored Sep 13, 2024
2 parents 881bca1 + b66d859 commit 738393f
Showing 17 changed files with 1,095 additions and 319 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -22,7 +22,7 @@ repos:
           )$
       - id: check-merge-conflict
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.5.6
+    rev: v0.6.3
     hooks:
       - id: ruff
         args: ["--fix", "--output-format=full"]
1 change: 1 addition & 0 deletions pytensor/compile/mode.py
@@ -471,6 +471,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
             "BlasOpt",
             "fusion",
             "inplace",
+            "local_uint_constant_indices",
         ],
     ),
 )
3 changes: 2 additions & 1 deletion pytensor/link/pytorch/dispatch/__init__.py
@@ -7,7 +7,8 @@
 import pytensor.link.pytorch.dispatch.elemwise
 import pytensor.link.pytorch.dispatch.math
 import pytensor.link.pytorch.dispatch.extra_ops
-import pytensor.link.pytorch.dispatch.nlinalg
 import pytensor.link.pytorch.dispatch.shape
 import pytensor.link.pytorch.dispatch.sort
+import pytensor.link.pytorch.dispatch.nlinalg
+import pytensor.link.pytorch.dispatch.subtensor
 # isort: on
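Note: these bare imports are load-bearing. Each dispatch module runs its @pytorch_funcify.register(...) decorators at import time, so importing the module is what populates the singledispatch registry. A minimal standalone sketch of the pattern (the names below are illustrative, not PyTensor's):

from functools import singledispatch


@singledispatch
def funcify(op, **kwargs):
    # Fallback used when no module has registered a handler for type(op).
    raise NotImplementedError(f"No handler registered for {type(op)}")


class FakeSubtensorOp:
    # Stand-in for an Op class normally defined elsewhere.
    pass


@funcify.register(FakeSubtensorOp)  # in PyTensor, this runs on module import
def _(op, **kwargs):
    return lambda x, idx: x[idx]


fn = funcify(FakeSubtensorOp())
print(fn([10, 20, 30], 1))  # prints 20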
33 changes: 28 additions & 5 deletions pytensor/link/pytorch/dispatch/basic.py
@@ -1,25 +1,41 @@
 from functools import singledispatch
 from types import NoneType
 
+import numpy as np
 import torch
 
 from pytensor.compile.ops import DeepCopyOp
 from pytensor.graph.fg import FunctionGraph
 from pytensor.ifelse import IfElse
 from pytensor.link.utils import fgraph_to_python
 from pytensor.raise_op import CheckAndRaise
-from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join, MakeVector
+from pytensor.tensor.basic import (
+    Alloc,
+    AllocEmpty,
+    ARange,
+    Eye,
+    Join,
+    MakeVector,
+    TensorFromScalar,
+)
 
 
 @singledispatch
-def pytorch_typify(data, dtype=None, **kwargs):
-    r"""Convert instances of PyTensor `Type`\s to PyTorch types."""
+def pytorch_typify(data, **kwargs):
+    raise NotImplementedError(f"pytorch_typify is not implemented for {type(data)}")
+
+
+@pytorch_typify.register(np.ndarray)
+@pytorch_typify.register(torch.Tensor)
+def pytorch_typify_tensor(data, dtype=None, **kwargs):
     return torch.as_tensor(data, dtype=dtype)
 
 
+@pytorch_typify.register(slice)
 @pytorch_typify.register(NoneType)
-def pytorch_typify_None(data, **kwargs):
-    return None
+@pytorch_typify.register(np.number)
+def pytorch_typify_no_conversion_needed(data, **kwargs):
+    return data
 
 
 @singledispatch
@@ -146,3 +162,10 @@ def ifelse(cond, *true_and_false, n_outs=n_outs):
         return torch.stack(true_and_false[n_outs:])
 
     return ifelse
+
+@pytorch_funcify.register(TensorFromScalar)
+def pytorch_funcify_TensorFromScalar(op, **kwargs):
+    def tensorfromscalar(x):
+        return torch.as_tensor(x)
+
+    return tensorfromscalar
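A quick standalone sanity check of the conversions the new handlers above rely on (requires only NumPy and torch; not part of the commit):

import numpy as np
import torch

# NumPy scalar -> 0-d tensor, as in pytorch_funcify_TensorFromScalar
print(torch.as_tensor(np.float32(3.0)))  # tensor(3.)

# NumPy array -> tensor with an optional dtype, as in pytorch_typify_tensor
print(torch.as_tensor(np.arange(3), dtype=torch.float64))  # tensor([0., 1., 2.], dtype=torch.float64)

# Slices need no conversion at all: torch indexing accepts them directly,
# which is what pytorch_typify_no_conversion_needed exploits.
print(torch.arange(6)[slice(1, 4)])  # tensor([1, 2, 3])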
124 changes: 124 additions & 0 deletions pytensor/link/pytorch/dispatch/subtensor.py
@@ -0,0 +1,124 @@
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.subtensor import (
    AdvancedIncSubtensor,
    AdvancedIncSubtensor1,
    AdvancedSubtensor,
    AdvancedSubtensor1,
    IncSubtensor,
    Subtensor,
    indices_from_subtensor,
)
from pytensor.tensor.type_other import MakeSlice, SliceType


def check_negative_steps(indices):
    for index in indices:
        if isinstance(index, slice):
            if index.step is not None and index.step < 0:
                raise NotImplementedError(
                    "Negative step sizes are not supported in Pytorch"
                )


@pytorch_funcify.register(Subtensor)
def pytorch_funcify_Subtensor(op, node, **kwargs):
    idx_list = op.idx_list

    def subtensor(x, *flattened_indices):
        indices = indices_from_subtensor(flattened_indices, idx_list)
        check_negative_steps(indices)
        return x[indices]

    return subtensor

@pytorch_funcify.register(MakeSlice)
def pytorch_funcify_makeslice(op, **kwargs):
    def makeslice(*x):
        # x holds the (start, stop, step) inputs; they must be unpacked,
        # otherwise slice() would receive a single tuple as its stop value.
        return slice(*x)

    return makeslice



@pytorch_funcify.register(AdvancedSubtensor1)
@pytorch_funcify.register(AdvancedSubtensor)
def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
def advsubtensor(x, *indices):
check_negative_steps(indices)
return x[indices]

return advsubtensor


@pytorch_funcify.register(IncSubtensor)
def pytorch_funcify_IncSubtensor(op, node, **kwargs):
idx_list = op.idx_list
inplace = op.inplace
if op.set_instead_of_inc:

def set_subtensor(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
if not inplace:
x = x.clone()
x[indices] = y
return x

return set_subtensor

else:

def inc_subtensor(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
if not inplace:
x = x.clone()
x[indices] += y
return x

return inc_subtensor


@pytorch_funcify.register(AdvancedIncSubtensor)
@pytorch_funcify.register(AdvancedIncSubtensor1)
def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
inplace = op.inplace
ignore_duplicates = getattr(op, "ignore_duplicates", False)

if op.set_instead_of_inc:

def adv_set_subtensor(x, y, *indices):
check_negative_steps(indices)
if not inplace:
x = x.clone()
x[indices] = y.type_as(x)
return x

return adv_set_subtensor

elif ignore_duplicates:

def adv_inc_subtensor_no_duplicates(x, y, *indices):
check_negative_steps(indices)
if not inplace:
x = x.clone()
x[indices] += y.type_as(x)
return x

return adv_inc_subtensor_no_duplicates

else:
if any(isinstance(idx.type, SliceType) for idx in node.inputs[2:]):
raise NotImplementedError(
"IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch"
)

def adv_inc_subtensor(x, y, *indices):
# Not needed because slices aren't supported
# check_negative_steps(indices)
if not inplace:
x = x.clone()
x.index_put_(indices, y.type_as(x), accumulate=True)
return x

return adv_inc_subtensor
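The three branches above differ chiefly in how duplicate indices are handled. A standalone sketch (not part of the commit) of why the last branch needs index_put_(..., accumulate=True) rather than the += used on the ignore_duplicates path:

import torch

x = torch.zeros(4)
idx = (torch.tensor([1, 1, 2]),)  # index 1 appears twice
y = torch.ones(3)

# += reads, adds, then scatters: duplicate writes collapse to one,
# so x2[1] ends up at 1.0 (the ignore_duplicates semantics).
x2 = x.clone()
x2[idx] += y

# index_put_ with accumulate=True sums every contribution, so x3[1] == 2.0,
# matching inc_subtensor semantics for repeated indices.
x3 = x.clone()
x3.index_put_(idx, y, accumulate=True)

print(x2)  # tensor([0., 1., 1., 0.])
print(x3)  # tensor([0., 2., 1., 0.])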
5 changes: 3 additions & 2 deletions pytensor/tensor/basic.py
@@ -3780,15 +3780,16 @@ class AllocDiag(OpFromGraph):
     Wrapper Op for alloc_diag graphs
     """
 
-    __props__ = ("axis1", "axis2")
-
     def __init__(self, *args, axis1, axis2, offset, **kwargs):
         self.axis1 = axis1
         self.axis2 = axis2
         self.offset = offset
 
         super().__init__(*args, **kwargs, strict=True)
 
+    def __str__(self):
+        return f"AllocDiag{{{self.axis1=}, {self.axis2=}, {self.offset=}}}"
+
     @staticmethod
     def is_offset_zero(node) -> bool:
         """
5 changes: 3 additions & 2 deletions pytensor/tensor/einsum.py
@@ -52,14 +52,15 @@ class Einsum(OpFromGraph):
     desired. We haven't decided whether we want to provide this functionality.
     """
 
-    __props__ = ("subscripts", "path", "optimized")
-
     def __init__(self, *args, subscripts: str, path: PATH, optimized: bool, **kwargs):
         self.subscripts = subscripts
         self.path = path
         self.optimized = optimized
         super().__init__(*args, **kwargs, strict=True)
 
+    def __str__(self):
+        return f"Einsum{{{self.subscripts=}, {self.path=}, {self.optimized=}}}"
+
 
 def _iota(shape: TensorVariable, axis: int) -> TensorVariable:
     """
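Both new __str__ methods rely on the f-string self-documenting = specifier (Python 3.8+), which renders the expression text together with its repr. A toy reproduction of the output format (Demo is a made-up stand-in, not a PyTensor class):

class Demo:
    def __init__(self, subscripts):
        self.subscripts = subscripts

    def __str__(self):
        # {{ and }} are literal braces; {self.subscripts=} expands to
        # self.subscripts='ij,jk->ik'
        return f"Einsum{{{self.subscripts=}}}"


print(Demo("ij,jk->ik"))  # Einsum{self.subscripts='ij,jk->ik'}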