
PyTorch inline constants in dispatch to avoid graph breaks #1118

Merged
merged 2 commits into pymc-devs:main from torch_constant_dispatch on Feb 10, 2025

Conversation

ricardoV94
Member

@ricardoV94 ricardoV94 commented Dec 12, 2024

When we have static inputs, inlining them helps torch avoid breaking the graph.

Related to #1110


📚 Documentation preview 📚: https://pytensor--1118.org.readthedocs.build/en/1118/
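
For illustration, a minimal, hypothetical sketch of the idea (not PyTensor's actual dispatch code): a constant captured from the enclosing scope gets baked into the traced graph, whereas passing it as yet another runtime argument gives dynamo one more input to track.

```python
import torch

const = torch.tensor([1.0, 2.0, 3.0])  # stands in for a static/constant input


@torch.compile
def inlined(x):
    # `const` is closed over, so torch.compile can specialize on it
    return x * const


@torch.compile
def passed_in(x, c):
    # `c` is a regular input; dynamo installs guards on its properties
    return x * c


x = torch.ones(3)
inlined(x)
passed_in(x, const)
```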

@ricardoV94 ricardoV94 added performance torch PyTorch backend labels Dec 12, 2024
@ricardoV94
Member Author

Still need to do something about the runtime broadcast check in Elemwise. Can we use torch._check for that instead of Python loops/asserts?
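
A rough, hypothetical sketch of what a torch._check-based version of that check could look like (the helper name and exact condition are assumptions, not code from this PR): torch._check raises on failure but stays visible to the compiler, unlike a plain Python assert.

```python
import torch


def check_runtime_broadcast(tensor, out_shape, declared_broadcastable):
    for size, out_size, bcastable in zip(tensor.shape, out_shape, declared_broadcastable):
        if not bcastable:
            # broadcasting a length-1 dim that was not declared broadcastable
            # is what PyTensor's runtime check forbids
            torch._check(
                size != 1 or out_size == 1,
                lambda: "Runtime broadcasting not allowed",
            )
```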


codecov bot commented Dec 12, 2024

Codecov Report

Attention: Patch coverage is 68.88889% with 14 lines in your changes missing coverage. Please review.

Project coverage is 82.27%. Comparing base (4ea4259) to head (75eef40).
Report is 18 commits behind head on main.

Files with missing lines | Patch % | Lines
pytensor/link/pytorch/dispatch/basic.py | 52.17% | 9 Missing and 2 partials ⚠️
pytensor/link/pytorch/dispatch/scalar.py | 66.66% | 1 Missing ⚠️
pytensor/link/pytorch/dispatch/shape.py | 90.00% | 1 Missing ⚠️
pytensor/link/pytorch/dispatch/subtensor.py | 87.50% | 1 Missing ⚠️

❌ Your patch status has failed because the patch coverage (68.88%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files


@@           Coverage Diff           @@
##             main    #1118   +/-   ##
=======================================
  Coverage   82.27%   82.27%           
=======================================
  Files         186      186           
  Lines       48000    48066   +66     
  Branches     8621     8633   +12     
=======================================
+ Hits        39490    39546   +56     
- Misses       6353     6360    +7     
- Partials     2157     2160    +3     
Files with missing lines | Coverage Δ
pytensor/link/pytorch/linker.py | 100.00% <100.00%> (ø)
pytensor/link/pytorch/dispatch/scalar.py | 73.68% <66.66%> (-0.39%) ⬇️
pytensor/link/pytorch/dispatch/shape.py | 85.71% <90.00%> (ø)
pytensor/link/pytorch/dispatch/subtensor.py | 89.53% <87.50%> (-0.21%) ⬇️
pytensor/link/pytorch/dispatch/basic.py | 87.40% <52.17%> (-7.10%) ⬇️

... and 13 files with indirect coverage changes

@ricardoV94
Member Author

Even without the runtime broadcast check, Elemwise seems to break the graph.

@ricardoV94 ricardoV94 force-pushed the torch_constant_dispatch branch from c08d288 to 566145a Compare December 12, 2024 10:48
@Ch0ronomato
Contributor

Did you get a chance to profile this PR?

@Ch0ronomato
Contributor

Btw, I did profile this. My machine actually failed to even compile dlogp for a model, but I suspect that's unrelated. The logp method did show some improvement. What intrigued me is that this change reduced the number of guards by a lot (roughly 10:1 compared with the other branches). I thought that might be the cause of the runtime difference, but it didn't have the payoff I was expecting.

@ricardoV94
Member Author

The cost of the guards may be non-linear, so we should try to remove them all.

@Ch0ronomato
Contributor

> The cost of the guards may be non-linear, so we should try to remove them all.

Idk about removing them all, since guards are the primitive that ensures runtime correctness. Significantly reducing them, I agree.
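
As an aside, a hypothetical way to count graphs and graph breaks for a compiled function (the toy `fn` below is illustrative only); guard logs can additionally be enabled with the TORCH_LOGS="guards" environment variable.

```python
import torch


def fn(x):
    # data-dependent control flow like this is a classic cause of a graph break
    if x.sum() > 0:
        return x + 1
    return x - 1


explanation = torch._dynamo.explain(fn)(torch.ones(3))
print(explanation.graph_count, explanation.graph_break_count)
for reason in explanation.break_reasons:
    print(reason)
```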

@Ch0ronomato
Contributor

Btw, for the actual perf benefit, these are the numbers I see.

# ricardo shape: 772 μs ± 12 μs per loop (mean ± std. dev. of 100 runs, 100 loops each)
# no ricardo shape: 818 μs ± 9.48 μs per loop (mean ± std. dev. of 100 runs, 100 loops each)

So it's like ~5%, probably more on slower CPUs. The graph breaks are definitely a problem :(

@Ch0ronomato
Contributor

Ch0ronomato commented Dec 27, 2024

If we add these two flags with the changes in this PR:

torch._dynamo.config.capture_func_transforms = True
torch._dynamo.config.capture_scalar_outputs = True

we come down to about 500 µs:

504 μs ± 12.2 μs per loop (mean ± std. dev. of 100 runs, 100 loops each)
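
Rough sketch of the configuration being benchmarked here, applied to a toy compiled function (`compiled_fn` is illustrative, not the PyTensor linker):

```python
import torch

torch._dynamo.config.capture_func_transforms = True
torch._dynamo.config.capture_scalar_outputs = True


@torch.compile
def compiled_fn(x):
    # .item() returns a Python scalar; with capture_scalar_outputs=True dynamo
    # can trace through it instead of falling back with a graph break
    return x * x.sum().item()


compiled_fn(torch.ones(3))
```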

@ricardoV94
Member Author

@Ch0ronomato can you revert the removal of the Elemwise bcast check (for now), and add those flags? Then we can merge this PR and keep playing with stuff

@Ch0ronomato
Contributor

The CI doesn't like those flags. I'll investigate.

@Ch0ronomato
Contributor

I think the path to fix this is not to use those flags by default, but only when we have a shape operation. The torch compiler might be really restrictive.
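
A hypothetical sketch of that suggestion, only flipping the dynamo flags when the PyTensor graph actually contains shape ops (illustrative, not the merged code):

```python
from pytensor.tensor.shape import Shape, Shape_i


def needs_dynamic_shape_flags(fgraph) -> bool:
    # enable the dynamo config flags only if the compiled FunctionGraph
    # contains shape-related ops
    return any(isinstance(node.op, (Shape, Shape_i)) for node in fgraph.apply_nodes)
```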

@Ch0ronomato Ch0ronomato force-pushed the torch_constant_dispatch branch from c5f26fd to dbc95e4 Compare January 26, 2025 17:46
@@ -37,6 +37,9 @@ def conversion_func_register(*args, **kwargs):
def jit_compile(self, fn):
import torch

# flag that tends to help our graphs
torch._dynamo.config.capture_dynamic_output_shape_ops = True
Contributor


Hopefully when #1159 gets merged we can just delete this flag altogether since torch will know these aren't dynamic
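
A hypothetical example of the kind of op this flag targets: torch.nonzero has a data-dependent output shape, which dynamo otherwise handles by breaking the graph (or erroring under fullgraph=True).

```python
import torch

torch._dynamo.config.capture_dynamic_output_shape_ops = True


@torch.compile(fullgraph=True)
def f(x):
    # output shape depends on the data, not just the input shape
    return torch.nonzero(x)


f(torch.tensor([0, 1, 0, 2]))
```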

@Ch0ronomato Ch0ronomato marked this pull request as ready for review January 26, 2025 20:20
@Ch0ronomato Ch0ronomato force-pushed the torch_constant_dispatch branch from eb3ff29 to 75eef40 Compare February 10, 2025 00:39
@Ch0ronomato Ch0ronomato merged commit 4fa9bb8 into pymc-devs:main Feb 10, 2025
62 checks passed