Skip to content

Commit

Permalink
Implemented Sort/Argsort Ops in PyTorch (#897)
Browse files Browse the repository at this point in the history
  • Loading branch information
twaclaw authored Jul 10, 2024
1 parent a99d067 commit ee4d4f7
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 0 deletions.
1 change: 1 addition & 0 deletions pytensor/link/pytorch/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
import pytensor.link.pytorch.dispatch.scalar
import pytensor.link.pytorch.dispatch.elemwise
import pytensor.link.pytorch.dispatch.extra_ops
import pytensor.link.pytorch.dispatch.sort
# isort: on
25 changes: 25 additions & 0 deletions pytensor/link/pytorch/dispatch/sort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.sort import ArgSortOp, SortOp


@pytorch_funcify.register(SortOp)
def pytorch_funcify_Sort(op, **kwargs):
stable = op.kind == "stable"

def sort(arr, axis):
sorted, _ = torch.sort(arr, dim=axis, stable=stable)
return sorted

return sort


@pytorch_funcify.register(ArgSortOp)
def pytorch_funcify_ArgSort(op, **kwargs):
stable = op.kind == "stable"

def argsort(arr, axis):
return torch.argsort(arr, dim=axis, stable=stable)

return argsort
26 changes: 26 additions & 0 deletions tests/link/pytorch/test_sort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import numpy as np
import pytest

from pytensor.graph import FunctionGraph
from pytensor.tensor import matrix
from pytensor.tensor.sort import argsort, sort
from tests.link.pytorch.test_basic import compare_pytorch_and_py


@pytest.mark.parametrize("func", (sort, argsort))
@pytest.mark.parametrize(
"axis",
[
pytest.param(0),
pytest.param(1),
pytest.param(
None, marks=pytest.mark.xfail(reason="Reshape Op not implemented")
),
],
)
def test_sort(func, axis):
x = matrix("x", shape=(2, 2), dtype="float64")
out = func(x, axis=axis)
fgraph = FunctionGraph([x], [out])
arr = np.array([[1.0, 4.0], [5.0, 2.0]])
compare_pytorch_and_py(fgraph, [arr])

0 comments on commit ee4d4f7

Please sign in to comment.