Several scipy special functions now upcast integers to float64 #859
Conversation
Changed in scipy==1.14.0
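A quick way to reproduce the change locally (gammaln is just an illustrative pick; the full list of affected functions is in the scipy 1.14.0 release notes):

```python
import numpy as np
import scipy.special

x = np.arange(3, dtype=np.int16)

# Under scipy < 1.14 a small-int input like this came back as float32;
# under scipy >= 1.14 integer inputs are upcast to float64.
print(scipy.special.gammaln(x).dtype)
```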
We should check what the JAX/Numba implementations do. It's a bit annoying to cast to float64 if we don't have to, and all backends have to respect the same dtypes.
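For comparison, a minimal probe of what the JAX backend does with small-int inputs (again using gammaln as an arbitrary example; Numba would need an analogous check):

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln

x = jnp.arange(3, dtype=jnp.int16)

# JAX applies its own type-promotion rules to integer inputs, so this
# shows the dtype the JAX backend would actually produce.
print(gammaln(x).dtype)
```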
There are like 40 custom rules on type conversion...
Would it be easier to just write our own implementations of these functions? Could you link to some issues where the JAX/Numba teams discuss this issue?
I think we should keep the old casting rule, and manually downcast scipy output from float64 back to float32
float32 tests failing, must be a day that ends in "y". It might be better to rename the helper function
I'm gonna clean this up but probably keep the old behavior of int>32 -> float64, less than that to float32?
I'm gonna be 100% honest with you -- I have no idea what this sentence means. |
Fair enough. Until now the behavior of the Scipy Ops was that any integer type below int32/uint32 would lead to a float32 output when going through these Ops, while (u)int32 and (u)int64 would lead to float64 outputs.
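A toy restatement of that old rule (not actual PyTensor code, just the mapping spelled out):

```python
import numpy as np

def old_scipy_op_out_dtype(in_dtype):
    # Integer types narrower than 32 bits -> float32;
    # (u)int32 and (u)int64 -> float64.
    if np.dtype(in_dtype).itemsize < 4:
        return np.dtype("float32")
    return np.dtype("float64")

assert old_scipy_op_out_dtype(np.int16) == np.float32
assert old_scipy_op_out_dtype(np.int32) == np.float64
assert old_scipy_op_out_dtype(np.uint64) == np.float64
```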
I guess it's a bit surprising that we send int32 to float64 (seems like it should be float32? Just because the numbers match). Otherwise all that seems perfectly reasonable, until we get onboard with the half-precision revolution anyway.
I guess the principled way would be to investigate the domain and codomain of each function and keep only as much precision as strictly needed. Say, for a monotonic single-argument function: if the largest integer input leads to a real output that doesn't overflow in float32, we could use float32. But such analysis doesn't seem trivial. My suggestion for keeping the old behavior is that it was there for nearly a decade and people seemed fine with it?
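For what it's worth, a rough version of that check for a single function (gammaln picked arbitrarily as the monotonic example; the real analysis would have to cover each Op's whole domain):

```python
import numpy as np
import scipy.special

# Evaluate at the largest int16 input and ask whether the result
# still fits in float32 without overflowing to inf.
largest = np.iinfo(np.int16).max
out = scipy.special.gammaln(np.float64(largest))
print(out, np.isfinite(np.float32(out)))  # finite -> float32 would suffice
```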
Without question we should keep the old way.
hi, do you need any help with that?
Hi @ferrine, yeah, help is welcome. We probably want to force the old casting policy manually in the perform method of these Ops.
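Something along these lines, as a minimal sketch (the class is a stand-in, not the actual PyTensor Op):

```python
import numpy as np
import scipy.special

class GammaLnOp:  # stand-in for one of the affected scalar/Elemwise Ops
    def perform(self, node, inputs, output_storage):
        (x,) = inputs
        out = scipy.special.gammaln(x)
        # scipy >= 1.14 may hand back float64 for small-int inputs;
        # force the dtype the Op originally promised (the old policy).
        expected = node.outputs[0].dtype
        output_storage[0][0] = np.asarray(out, dtype=expected)
```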
As I understand it so far, the proposed change makes the output dtype float64, and in the failing test there is an attempt to calculate gradients, which then have a different dtype. Is it correct to assume that it's not the upcasting rule that should change but something here: https://github.com/pymc-devs/pytensor/blob/main/pytensor/scalar/basic.py#L1143-L1150
@ferrine the current PR changed the expected output dtype to float64; we do not want to go that route, so ignore/revert it. It should still promise to be float32. There's an odd case with the psi function, which already had a special case to upcast to float64, that we can also get rid of. Regarding enforcing the downcast: yes, for the scalar case that's where we want to do it, but we have to check what Elemwise does. I think when there is a
Closing in favor of #972 |
Description
Tests started failing due to Scipy 1.14.0
Unsure whether this will make for a very prohibitive lower pin on the dependency. Clearly 0.14 was outdated (it's also poetic that the breaking change happens 1 version later :D).
The issue is that our math Ops default to scipy for integer-type inputs, and we have to guarantee the right output dtype. In the latest release, scipy changed a couple of functions that used to return float32 so that they now return float64.
Alternatively, we could read the scipy version and promise the right dtype accordingly, allowing older versions of scipy. Or we could coerce the dtypes from within Elemwise, which calls scipy under the hood.
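A sketch of the version-gated alternative, assuming we key the promised dtype off the installed scipy (the helper names here are hypothetical):

```python
import numpy as np
import scipy

# True once the installed scipy upcasts integer inputs to float64.
SCIPY_GE_1_14 = tuple(int(p) for p in scipy.__version__.split(".")[:2]) >= (1, 14)

def promised_out_dtype(in_dtype):
    in_dtype = np.dtype(in_dtype)
    if in_dtype.kind in "iub" and in_dtype.itemsize < 4:
        # Small integer types: float32 on old scipy, float64 on >= 1.14.
        return np.dtype("float64") if SCIPY_GE_1_14 else np.dtype("float32")
    return np.dtype("float64")
```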
Opinions?