
Implement ScalarLoop in torch backend #958

Open
wants to merge 13 commits into base: main

Conversation

Ch0ronomato (Contributor)

Description

Adds ScalarLoop for PyTorch. I implemented it as a loop rather than trying to vectorize it; let me know if I should take the vectorized approach instead.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):


@pytorch_funcify.register(ScalarLoop)
def pytorch_funcify_ScalarLoop(op, node, **kwargs):
    update = pytorch_funcify(op.fgraph)
Ch0ronomato (Contributor, Author):

I didn't torch.compile this; I'm not sure we need to. I think the torch.compile in the linker will do fine (and hopefully gets us a bit of loop reduction).
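
For reference, here is a minimal sketch of the loop-based dispatch under discussion, pieced together from the excerpts in this thread. The n_carry split, the int(n_steps) coercion, and the import paths are my assumptions, and the is_while (early-exit) case is omitted; the merged code may differ:

    import torch

    from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
    from pytensor.scalar.loop import ScalarLoop


    @pytorch_funcify.register(ScalarLoop)
    def pytorch_funcify_ScalarLoop(op, node, **kwargs):
        # Compile the inner update graph once; it maps
        # (carry..., constants...) -> updated carry.
        update = pytorch_funcify(op.fgraph)
        n_carry = op.nout  # assumption: carry length equals the op's output count

        def scalar_loop(n_steps, *args):
            # ScalarLoop's runtime signature is (n_steps, *init, *constants).
            carry, constants = args[:n_carry], args[n_carry:]
            for _ in range(int(n_steps)):
                carry = update(*carry, *constants)
            return torch.stack(carry)

        # Plain Python loop with a data-dependent trip count; compilation is
        # left to the torch.compile call in the linker.
        return torch.compiler.disable(scalar_loop)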

(Five resolved review threads on pytensor/link/pytorch/dispatch/scalar.py, now outdated.)
ricardoV94 added the labels enhancement (New feature or request) and torch (PyTorch backend) on Aug 3, 2024.
ricardoV94 (Member):

@Ch0ronomato thanks for taking a stab, I left some comments above

codecov bot commented on Aug 11, 2024:

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 81.75%. Comparing base (58fec45) to head (e4c2b9d).
Report is 26 commits behind head on main.

Additional details and impacted files:

@@            Coverage Diff             @@
##             main     #958      +/-   ##
==========================================
+ Coverage   81.60%   81.75%   +0.14%     
==========================================
  Files         179      183       +4     
  Lines       47271    47745     +474     
  Branches    11481    11620     +139     
==========================================
+ Hits        38574    39032     +458     
- Misses       6511     6518       +7     
- Partials     2186     2195       +9     
Files with missing lines                    Coverage Δ
pytensor/link/pytorch/dispatch/scalar.py    86.84% <100.00%> (+16.25%) ⬆️

... and 28 files with indirect coverage changes

(Two more resolved review threads on pytensor/link/pytorch/dispatch/scalar.py, now outdated.)
ricardoV94 changed the title from "Add torch scalar loop" to "Implement ScalarLoop in torch backend" on Sep 1, 2024.
const = float64("const")
x = x0 + const

op = ScalarLoop(init=[x0], constant=[const], update=[x])
ricardoV94 (Member) commented on Sep 1, 2024:

We should test a ScalarLoop with is_while as well. And I would like to actually test an Elemwise form of the ScalarLoop, since that's the most common use case. After defining the op, you pass it to an Elemwise, and then you can call the Elemwise version with tensors instead of scalars.
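
A rough sketch of what such an Elemwise test could look like, building on the snippet above. The PYTORCH mode string, the concrete shapes, and the expected values are my assumptions:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.scalar import float64
    from pytensor.scalar.loop import ScalarLoop
    from pytensor.tensor.elemwise import Elemwise

    # The scalar loop from the diff above: x <- x + const, applied n times.
    x0 = float64("x0")
    const = float64("const")
    x = x0 + const
    op = ScalarLoop(init=[x0], constant=[const], update=[x])
    # For the is_while case, pass until=<scalar boolean graph> as well.

    # Wrapping the op in Elemwise lets it broadcast over tensor inputs.
    n = pt.scalar("n", dtype="int64")
    X0 = pt.vector("X0", dtype="float64")
    C = pt.vector("C", dtype="float64")
    out = Elemwise(op)(n, X0, C)

    fn = pytensor.function([n, X0, C], out, mode="PYTORCH")
    # Starting from zeros and adding ones for 5 steps should give all 5s.
    np.testing.assert_allclose(fn(5, np.zeros(3), np.ones(3)), np.full(3, 5.0))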

        carry = update(*carry, *constants)
    return torch.stack(carry)

return torch.compiler.disable(scalar_loop)
ricardoV94 (Member):

Can you do recursive=False?
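
For context, torch.compiler.disable accepts a recursive flag; with recursive=False only the wrapped function's own frame is skipped, so torch.compile can still optimize the inner update it calls. Applied to the return above, the suggestion reads:

    # Skip compiling scalar_loop's own frame (the data-dependent Python loop),
    # while still allowing torch.compile to trace the functions it calls.
    return torch.compiler.disable(scalar_loop, recursive=False)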

Ch0ronomato (Contributor, Author):

@ricardoV94 these CI failures look a bit strange; I'll look into them before merging. Hopefully they go away after merging main 😓

Labels: enhancement (New feature or request), torch (PyTorch backend)
Projects: None yet
2 participants