Rewrite scalar dot as multiplication #1205

Open
ricardoV94 opened this issue Feb 12, 2025 · 1 comment
ricardoV94 commented Feb 12, 2025

Description

In #1178 we rewrote away batched dots that are really just multiplication, but left core dots untouched because those use BLAS operations (whether BLAS is worth it there is a separate question). But there is one case where it is definitely not worth it: scalar multiplication.

The following graph should definitely be simplified:

import pytensor
import pytensor.tensor as pt

x = pt.tensor("x", shape=(1, 1))
y = pt.tensor("y", shape=(1, 1))
out = x @ y
pytensor.function([x, y], out).dprint()

CGer{non-destructive} [id A] 2
 ├─ [[0.]] [id B]
 ├─ 1.0 [id C]
 ├─ DropDims{axis=1} [id D] 1
 │  └─ x [id E]
 └─ DropDims{axis=0} [id F] 0
    └─ y [id G]

Or without the BLAS stuff:

pytensor.function([x, y], out, mode="FAST_COMPILE").dprint()

Dot22 [id A] 0
 ├─ x [id B]
 └─ y [id C]

Both should just be a mul, because that can be fused with other Elemwise operations (and calling BLAS for it is the silliest thing ever).
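As a quick numerical sanity check of the equivalence, here is a minimal sketch in plain NumPy (NumPy stands in for PyTensor here, and the helper name `scalar_dot_as_mul` is hypothetical, not an existing PyTensor function). It only illustrates that a (1, 1) @ (1, 1) dot gives the same result as an elementwise multiply:

```python
import numpy as np


def scalar_dot_as_mul(x, y):
    # Hypothetical helper: what the rewritten graph would compute.
    # Only valid when both inputs have shape (1, 1).
    assert x.shape == (1, 1) and y.shape == (1, 1)
    return x * y


rng = np.random.default_rng(0)
x = rng.normal(size=(1, 1))
y = rng.normal(size=(1, 1))

# The matrix product and the elementwise product agree for (1, 1) inputs.
np.testing.assert_allclose(x @ y, scalar_dot_as_mul(x, y))
```

The shape check mirrors the fact that the rewrite would presumably only fire when both static shapes are known to be (1, 1), as in the example graph above.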

@ricardoV94

We should also consider the remaining cases that are just multiplication, especially in non-default backends where the BLAS question is completely irrelevant. Even in the C backend I saw many cases where it was faster without BLAS (but some where it was slower :( ).
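One such case, assuming "just multiplication" here means a dot whose contracted dimension has length 1, is the outer-product-like (m, 1) @ (1, n). A minimal NumPy sketch (again standing in for PyTensor, with a hypothetical helper name) showing that broadcasting multiplication gives the same result:

```python
import numpy as np


def unit_inner_dim_dot_as_mul(x, y):
    # Hypothetical helper: valid only when the contracted dimension is 1,
    # i.e. x has shape (m, 1) and y has shape (1, n).
    assert x.ndim == 2 and y.ndim == 2
    assert x.shape[1] == 1 and y.shape[0] == 1
    # (m, 1) * (1, n) broadcasts to (m, n), matching the dot's output shape.
    return x * y


x = np.arange(3.0).reshape(3, 1)
y = np.arange(4.0).reshape(1, 4)

np.testing.assert_allclose(x @ y, unit_inner_dim_dot_as_mul(x, y))
```

Whether this is actually faster than the BLAS call would still need to be benchmarked per backend, as noted above.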
