Several scipy special functions now upcast integers to float64 #859

Closed

Conversation

@ricardoV94 (Member) commented Jun 26, 2024

Description

Tests started failing due to Scipy 1.14.0

Unsure whether this will be a prohibitively restrictive lower pin on the scipy dependency. Clearly 0.14 was outdated (it's also poetic that the breaking change lands exactly one major version later, in 1.14 :D).

The issue is that our math Ops default to scipy for integer-typed inputs, and we have to guarantee the right output dtype. Scipy changed a couple of functions that used to return float32 so that they now return float64 in the latest release.
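For illustration, a minimal way to observe the change (the choice of `gammaln` is illustrative; it's assumed here to be among the affected functions):

```python
import numpy as np
from scipy import special

x = np.arange(1, 5, dtype=np.int16)

# Scipy < 1.14: small integer inputs dispatch to the float32 ufunc loop.
# Scipy >= 1.14: integer inputs are upcast to float64 instead.
print(special.gammaln(x).dtype)  # float32 before 1.14, float64 from 1.14 on
```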

Alternatively, we could read the scipy version and promise the right dtype accordingly, allowing older versions of scipy. Or we could coerce the dtypes from within Elemwise, which calls scipy under the hood.

Opinions?

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@ricardoV94 (Member, Author) commented Jun 27, 2024

Should check what the JAX/Numba implementations do. It's a bit annoying to cast to float64 if we don't have to, and all backends have to respect the same dtypes.

@ricardoV94 (Member, Author)

There are like 40 custom rules on type conversion...

@jessegrabowski (Member)

Would it be easier to just write our own implementations of these functions? Could you link to some issues where the JAX/Numba teams discuss this issue?

@ricardoV94 (Member, Author)

> Would it be easier to just write our own implementations of these functions? Could you link to some issues where the JAX/Numba teams discuss this issue?

I think we should keep the old casting rule, and manually downcast the scipy output from float64 back to float32.

@jessegrabowski (Member)

float32 tests failing, must be a day that ends in "y".

It might be better to rename the helper function to cast_up_to_floatX and reference config.floatX inside it, instead of assuming you always want to go up to 64.
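A minimal sketch of that suggestion (hypothetical helper; assumes PyTensor's `config.floatX` holds the target float dtype string, e.g. `"float32"`):

```python
import numpy as np
from pytensor import config

def cast_up_to_floatX(values):
    # Upcast integer/bool inputs only as far as config.floatX,
    # instead of hard-coding float64.
    values = np.asarray(values)
    if values.dtype.kind in "biu":  # bool, signed int, unsigned int
        return values.astype(config.floatX)
    return values
```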

@ricardoV94 (Member, Author)

I'm gonna clean this up but probably keep the old behavior of int>32 -> float64, less than that to float32?

@jessegrabowski (Member)

> I'm gonna clean this up but probably keep the old behavior of int>32 -> float64, less than that to float32?

I'm gonna be 100% honest with you -- I have no idea what this sentence means.

@ricardoV94 (Member, Author) commented Jul 6, 2024

Fair enough. Until now, the behavior of the Scipy Ops was that integer inputs below 32 bits (int8/int16 and their unsigned counterparts) would lead to a float32 output when going through these Ops, while (u)int32 and (u)int64 inputs would lead to float64 outputs.
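In code, the old rule amounts to something like this (a sketch, not PyTensor's actual implementation):

```python
import numpy as np

def old_scipy_op_output_dtype(int_dtype):
    # Pre-1.14 behavior: integer dtypes below 32 bits map to float32,
    # (u)int32 and (u)int64 map to float64.
    itemsize = np.dtype(int_dtype).itemsize
    return np.dtype("float32") if itemsize < 4 else np.dtype("float64")

assert old_scipy_op_output_dtype(np.int16) == np.dtype("float32")
assert old_scipy_op_output_dtype(np.uint32) == np.dtype("float64")
```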

@jessegrabowski (Member)

I guess it's a bit surprising that we send int32 to float64 (seems like it should be float32? Just because the numbers match). Otherwise all that seems perfectly reasonable, until we get on board with the half-precision revolution anyway.

@ricardoV94 (Member, Author)

I guess the principled way would be to investigate the domain and codomain of each function and keep only as much precision as strictly needed. Say, for a monotonic single-argument function: if the largest integer input leads to a real output that doesn't overflow in float32, we could use float32. But such an analysis doesn't seem trivial.
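For a concrete flavor of that analysis, here is a brute-force check over one function and one integer domain (`gammaln` over the positive int16 range is an illustrative choice):

```python
import numpy as np
from scipy import special

# Evaluate in float64 over the whole positive int16 domain, then check
# whether downcasting to float32 would overflow anywhere.
x = np.arange(1, np.iinfo(np.int16).max + 1, dtype=np.float64)
y64 = special.gammaln(x)
y32 = y64.astype(np.float32)
overflows = np.isinf(y32) & np.isfinite(y64)
print(overflows.any())  # False -> a float32 output is safe for this domain
```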

My suggestion for keeping the old behavior is that it was there for nearly a decade and people seemed fine with it?

@jessegrabowski (Member)

Without question we should keep the old way.

@ferrine (Member) commented Jul 31, 2024

hi, do you need any help with that?

@ricardoV94 (Member, Author)

> hi, do you need any help with that?

Hi @ferrine, yeah, help is welcome. We probably want to force the old casting policy manually in the perform method of these Ops.
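Roughly along these lines (a sketch only; `self.scipy_fn` and the single-input/single-output layout are illustrative, not PyTensor's actual internals):

```python
import numpy as np

def perform(self, node, inputs, output_storage):
    (x,) = inputs
    out = self.scipy_fn(x)  # may come back as float64 under scipy >= 1.14
    # Force the dtype the node promised (e.g. float32 for int8/int16 inputs),
    # regardless of what scipy returned.
    output_storage[0][0] = np.asarray(out, dtype=node.outputs[0].dtype)
```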

@ferrine (Member) commented Aug 2, 2024

> hi, do you need any help with that?

> Hi @ferrine, yeah, help is welcome. We probably want to force the old casting policy manually in the perform method of these Ops.

As I understand it so far, the proposed change

https://github.com/pymc-devs/pytensor/pull/859/files#diff-6f6c0b80f6b733d89b8aef461f6d655fadbc8382f4c246f513465cdcea963cd6R463-R465

changes the output dtype to float64, and in the failing test there is an attempt to compute gradients, which end up with a different dtype.

Is it correct to assume that it's not the upcasting rule that should change, but rather something here:

https://github.com/pymc-devs/pytensor/blob/main/pytensor/scalar/basic.py#L1143-L1150

@ricardoV94 (Member, Author) commented Aug 3, 2024

@ferrine the current PR changed the expected output dtype to float64; we do not want to go that route, so ignore/revert that part. The Ops should still promise float32. There's also an odd case with the psi function, which already had a special case to upcast to float64, that we can get rid of.

Regarding enforcing the downcast: yes, for the scalar case that's where we want to do it, but we have to check what Elemwise does. I think when there is an nfunc_spec (or whatever it is called), we use it directly and bypass the scalar implementation. We need to make sure that path is also properly downcast with the new scipy.
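Once both paths are handled, a check along these lines should pass (hypothetical test sketch; assumes `pt.gammaln` is one of the affected Ops and that int16 inputs promise float32):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x", dtype="int16")
y = pt.gammaln(x)
assert y.dtype == "float32"  # the graph-level promise

f = pytensor.function([x], y)
out = f(np.arange(1, 5, dtype="int16"))
assert out.dtype == np.dtype("float32")  # runtime must honor it, Elemwise included
```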

@ferrine mentioned this pull request Aug 13, 2024
@ricardoV94 (Member, Author)

Closing in favor of #972

@ricardoV94 closed this Aug 14, 2024
Development

Successfully merging this pull request may close these issues:

Scipy 1.14.0 changed the upcasting rule of several special functions when passing integers