Skip to content

Commit

Permalink
Implement indexing operations in pytorch
Browse files Browse the repository at this point in the history
Co-authored-by: Ricardo Vieira <28983449+ricardov94@users.noreply.github.com>
  • Loading branch information
HarshvirSandhu and ricardoV94 committed Sep 1, 2024
1 parent 1a1c62b commit b38a01c
Show file tree
Hide file tree
Showing 6 changed files with 345 additions and 9 deletions.
1 change: 1 addition & 0 deletions pytensor/compile/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
"BlasOpt",
"fusion",
"inplace",
"local_uint_constant_indices",
],
),
)
Expand Down
3 changes: 2 additions & 1 deletion pytensor/link/pytorch/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import pytensor.link.pytorch.dispatch.elemwise
import pytensor.link.pytorch.dispatch.math
import pytensor.link.pytorch.dispatch.extra_ops
import pytensor.link.pytorch.dispatch.nlinalg
import pytensor.link.pytorch.dispatch.shape
import pytensor.link.pytorch.dispatch.sort
import pytensor.link.pytorch.dispatch.nlinalg
import pytensor.link.pytorch.dispatch.subtensor
# isort: on
34 changes: 29 additions & 5 deletions pytensor/link/pytorch/dispatch/basic.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,40 @@
from functools import singledispatch
from types import NoneType

import numpy as np
import torch

from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join, MakeVector
from pytensor.tensor.basic import (
Alloc,
AllocEmpty,
ARange,
Eye,
Join,
MakeVector,
TensorFromScalar,
)


@singledispatch
def pytorch_typify(data, dtype=None, **kwargs):
r"""Convert instances of PyTensor `Type`\s to PyTorch types."""
def pytorch_typify(data, **kwargs):
raise NotImplementedError(f"pytorch_typify is not implemented for {type(data)}")

Check warning on line 24 in pytensor/link/pytorch/dispatch/basic.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/basic.py#L24

Added line #L24 was not covered by tests


@pytorch_typify.register(np.ndarray)
@pytorch_typify.register(torch.Tensor)
def pytorch_typify_tensor(data, dtype=None, **kwargs):
return torch.as_tensor(data, dtype=dtype)


@pytorch_typify.register(slice)
@pytorch_typify.register(NoneType)
def pytorch_typify_None(data, **kwargs):
return None
@pytorch_typify.register(np.number)
def pytorch_typify_no_conversion_needed(data, **kwargs):
return data


@singledispatch
Expand Down Expand Up @@ -132,3 +148,11 @@ def makevector(*x):
return torch.tensor(x, dtype=torch_dtype)

return makevector


@pytorch_funcify.register(TensorFromScalar)
def pytorch_funcify_TensorFromScalar(op, **kwargs):
def tensorfromscalar(x):
return torch.as_tensor(x)

Check warning on line 156 in pytensor/link/pytorch/dispatch/basic.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/basic.py#L156

Added line #L156 was not covered by tests

return tensorfromscalar
124 changes: 124 additions & 0 deletions pytensor/link/pytorch/dispatch/subtensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
AdvancedSubtensor,
AdvancedSubtensor1,
IncSubtensor,
Subtensor,
indices_from_subtensor,
)
from pytensor.tensor.type_other import MakeSlice, SliceType


def check_negative_steps(indices):
for index in indices:
if isinstance(index, slice):
if index.step is not None and index.step < 0:
raise NotImplementedError(
"Negative step sizes are not supported in Pytorch"
)


@pytorch_funcify.register(Subtensor)
def pytorch_funcify_Subtensor(op, node, **kwargs):
idx_list = op.idx_list

def subtensor(x, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
return x[indices]

return subtensor


@pytorch_funcify.register(MakeSlice)
def pytorch_funcify_makeslice(op, **kwargs):
def makeslice(*x):
return slice(x)

Check warning on line 38 in pytensor/link/pytorch/dispatch/subtensor.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/subtensor.py#L37-L38

Added lines #L37 - L38 were not covered by tests

return makeslice

Check warning on line 40 in pytensor/link/pytorch/dispatch/subtensor.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/subtensor.py#L40

Added line #L40 was not covered by tests


@pytorch_funcify.register(AdvancedSubtensor1)
@pytorch_funcify.register(AdvancedSubtensor)
def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
def advsubtensor(x, *indices):
check_negative_steps(indices)
return x[indices]

return advsubtensor


@pytorch_funcify.register(IncSubtensor)
def pytorch_funcify_IncSubtensor(op, node, **kwargs):
idx_list = op.idx_list
inplace = op.inplace
if op.set_instead_of_inc:

def set_subtensor(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
if not inplace:
x = x.clone()
x[indices] = y
return x

return set_subtensor

else:

def inc_subtensor(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
if not inplace:
x = x.clone()
x[indices] += y
return x

return inc_subtensor


@pytorch_funcify.register(AdvancedIncSubtensor)
@pytorch_funcify.register(AdvancedIncSubtensor1)
def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
inplace = op.inplace
ignore_duplicates = getattr(op, "ignore_duplicates", False)

if op.set_instead_of_inc:

def adv_set_subtensor(x, y, *indices):
check_negative_steps(indices)
if not inplace:
x = x.clone()
x[indices] = y.type_as(x)
return x

return adv_set_subtensor

elif ignore_duplicates:

def adv_inc_subtensor_no_duplicates(x, y, *indices):
check_negative_steps(indices)
if not inplace:
x = x.clone()

Check warning on line 104 in pytensor/link/pytorch/dispatch/subtensor.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/subtensor.py#L104

Added line #L104 was not covered by tests
x[indices] += y.type_as(x)
return x

return adv_inc_subtensor_no_duplicates

else:
if any(isinstance(idx.type, SliceType) for idx in node.inputs[2:]):
raise NotImplementedError(
"IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch"
)

def adv_inc_subtensor(x, y, *indices):
# Not needed because slices aren't supported
# check_negative_steps(indices)
if not inplace:
x = x.clone()
x.index_put_(indices, y.type_as(x), accumulate=True)
return x

Check warning on line 122 in pytensor/link/pytorch/dispatch/subtensor.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/subtensor.py#L120-L122

Added lines #L120 - L122 were not covered by tests

return adv_inc_subtensor
6 changes: 3 additions & 3 deletions tests/link/pytorch/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ def compare_pytorch_and_py(
py_res = pytensor_py_fn(*test_inputs)

if len(fgraph.outputs) > 1:
for j, p in zip(pytorch_res, py_res):
assert_fn(j.cpu(), p)
for pytorch_res_i, py_res_i in zip(pytorch_res, py_res):
assert_fn(pytorch_res_i.detach().cpu().numpy(), py_res_i)
else:
assert_fn([pytorch_res[0].cpu()], py_res)
assert_fn(pytorch_res[0].detach().cpu().numpy(), py_res[0])

return pytensor_torch_fn, pytorch_res

Expand Down
Loading

0 comments on commit b38a01c

Please sign in to comment.