Implement ScalarLoop in torch backend #958
base: main
Conversation
@pytorch_funcify.register(ScalarLoop)
def pytorch_funcify_ScalarLoop(op, node, **kwargs):
    update = pytorch_funcify(op.fgraph)
I didn't torch.compile this - not sure if we need to. I think the torch.compile in the linker will do fine (and hopefully get a bit of a loop reduction for us).
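For context, here is a minimal fleshed-out sketch of what this registration could look like, assembled from the fragments visible in this diff. The import paths, the use of op.nout, and the argument order (n_steps first, then initial state, then constants) are assumptions, not necessarily what the PR does:

    import torch
    from pytensor.link.pytorch.dispatch import pytorch_funcify
    from pytensor.scalar.loop import ScalarLoop

    @pytorch_funcify.register(ScalarLoop)
    def pytorch_funcify_ScalarLoop(op, node, **kwargs):
        update = pytorch_funcify(op.fgraph)
        n_state = op.nout  # assumption: nout counts the carried outputs

        def scalar_loop(steps, *start_and_constants):
            # Split the flat argument list into carried state and loop constants.
            carry = start_and_constants[:n_state]
            constants = start_and_constants[n_state:]
            for _ in range(int(steps)):
                carry = update(*carry, *constants)
            return torch.stack(carry)

        # Dynamo can't trace the data-dependent Python loop, so run it eagerly.
        return torch.compiler.disable(scalar_loop)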
@Ch0ronomato thanks for taking a stab, I left some comments above.
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files

@@            Coverage Diff             @@
##             main     #958      +/-   ##
==========================================
+ Coverage   81.60%   81.75%   +0.14%
==========================================
  Files         179      183       +4
  Lines       47271    47745     +474
  Branches    11481    11620     +139
==========================================
+ Hits        38574    39032     +458
- Misses       6511     6518       +7
- Partials     2186     2195       +9
const = float64("const")
x = x0 + const

op = ScalarLoop(init=[x0], constant=[const], update=[x])
We should also test one ScalarLoop with is_while set. And I would like to actually test an Elemwise form of the ScalarLoop, since that's the most common case. After defining the op you pass it to an Elemwise, and then you can call the elemwise version with tensors instead of scalars - a sketch of what that might look like follows.
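A hypothetical sketch of such a test, assuming pytensor's public ScalarLoop/Elemwise API; the variable names, the "PYTORCH" mode string, and the expected result are illustrative, not taken from the PR:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.scalar import float64
    from pytensor.scalar.loop import ScalarLoop
    from pytensor.tensor.elemwise import Elemwise

    # Scalar graph: each loop step adds `const` to the carried state.
    x0 = float64("x0")
    const = float64("const")
    x = x0 + const

    op = ScalarLoop(init=[x0], constant=[const], update=[x])
    elemwise_op = Elemwise(op)

    # Call the elemwise version with tensors instead of scalars.
    n_steps = pt.scalar("n_steps", dtype="int64")
    x0_t = pt.vector("x0", dtype="float64")
    const_t = pt.vector("const", dtype="float64")
    out = elemwise_op(n_steps, x0_t, const_t)

    fn = pytensor.function([n_steps, x0_t, const_t], out, mode="PYTORCH")
    result = fn(5, np.zeros(3), np.ones(3))  # expect [5., 5., 5.]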
carry = update(*carry, *constants)
return torch.stack(carry)

return torch.compiler.disable(scalar_loop)
Can you do recursive=False?
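For what it's worth, torch.compiler.disable does accept a recursive keyword (its signature is torch.compiler.disable(fn=None, recursive=True)), so the suggestion would look something like this sketch:

    # Disable Dynamo only for scalar_loop itself, leaving inner calls traceable.
    return torch.compiler.disable(scalar_loop, recursive=False)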
@ricardoV94 - these failures in the CI look a bit strange; I'll look into them before merging...hopefully they go away with merging main 😓
Description
Adds ScalarLoop for pytorch. I do it as a loop as opposed to trying to vectorize it...let me know if I should go that approach or not.

Related Issue
Checklist
Type of change