Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented Sort/Argsort Ops in PyTorch #897

Merged
merged 3 commits into from
Jul 10, 2024

Conversation

twaclaw
Copy link
Contributor

@twaclaw twaclaw commented Jul 7, 2024

Description

  • Implements SortOp in PyTorch
  • Implements ArgSortOp in PyTorch
  • ⚠️ passing axis=None throws an exception because Reshape is not yet implemented

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

Copy link

codecov bot commented Jul 7, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 81.25%. Comparing base (8e0958a) to head (cff4cf4).
Report is 2 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main     #897   +/-   ##
=======================================
  Coverage   81.24%   81.25%           
=======================================
  Files         170      171    +1     
  Lines       46920    46945   +25     
  Branches    11482    11484    +2     
=======================================
+ Hits        38120    38144   +24     
  Misses       6600     6600           
- Partials     2200     2201    +1     
Files Coverage Δ
pytensor/link/pytorch/dispatch/__init__.py 100.00% <100.00%> (ø)
pytensor/link/pytorch/dispatch/sort.py 100.00% <100.00%> (ø)

... and 2 files with indirect coverage changes

@ricardoV94
Copy link
Member

What happens with axis=None that introduces a reshape?

@ricardoV94 ricardoV94 added enhancement New feature or request torch PyTorch backend labels Jul 7, 2024
@twaclaw
Copy link
Contributor Author

twaclaw commented Jul 7, 2024

What happens with axis=None that introduces a reshape?

I think it is this line. But I should rather remove that test and don't test for axis=None for the time being.

@ricardoV94
Copy link
Member

ricardoV94 commented Jul 7, 2024

I see, you can test and mark it as xfail with pytest.mark.xfail

@@ -7,17 +7,12 @@
from tests.link.pytorch.test_basic import compare_pytorch_and_py


@pytest.mark.xfail(reason="Reshape not implemented")
@pytest.mark.parametrize("axis", [0, 1, None])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When only some of the parametrized conditions fail, you should use pytest.param to mark that specifically as xfailing. The others are working and shouldn't be marked. There are some examples in the tests somewhere

@twaclaw twaclaw force-pushed the implement_sort_argsort_ops_torch branch from a9e2343 to cff4cf4 Compare July 8, 2024 16:24
@ricardoV94 ricardoV94 merged commit ee4d4f7 into pymc-devs:main Jul 10, 2024
59 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request torch PyTorch backend
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants