Skip to content

Commit

Permalink
Make more robust for numba's picky type selection
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobachetti committed Sep 26, 2023
1 parent 0d14af1 commit d2df1ee
Showing 1 changed file with 33 additions and 19 deletions.
52 changes: 33 additions & 19 deletions stingray/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@

HAS_NUMBA = True
from numba import njit, prange, vectorize, float32, float64, int32, int64
from numba.core.errors import NumbaValueError
from numba.core.errors import NumbaValueError, NumbaNotImplementedError
except ImportError:
warnings.warn("Numba not installed. Faking it")
HAS_NUMBA = False
NumbaValueError = Exception
NumbaValueError = NumbaNotImplementedError = Exception

def njit(f=None, *args, **kwargs):
def decorator(func, *a, **kw):
Expand Down Expand Up @@ -1979,18 +1979,25 @@ def histogram2d(*args, **kwargs):
... ranges=[[0., 1.], [2., 3.]])
>>> assert np.all(Hn1 == Hn2)
"""
if "range" in kwargs:
kwargs["ranges"] = kwargs.pop("range")
kwargs_copy = kwargs.copy()
if "range" in kwargs_copy:
kwargs_copy["ranges"] = kwargs_copy.pop("range")

if "weights" not in kwargs:
return hist2d_numba_seq(*args, **kwargs)
if "weights" not in kwargs_copy:
return hist2d_numba_seq(*args, **kwargs_copy)

weights = kwargs.pop("weights")
weights = kwargs_copy.pop("weights")

if weights is not None:
return hist2d_numba_seq_weight(*args, weights, **kwargs)
try:
if weights is not None:
return hist2d_numba_seq_weight(*args, weights, **kwargs_copy)

return hist2d_numba_seq(*args, **kwargs)
return hist2d_numba_seq(*args, **kwargs_copy)
except NumbaValueError:
warnings.warn("Numba could not calculate the histogram. Trying standard numpy.")
return histogram2d_np(*args, **kwargs)[0]
except:
raise

def histogram(*args, **kwargs):
"""
Expand All @@ -2006,18 +2013,25 @@ def histogram(*args, **kwargs):
>>> Hn2 = histogram(x, bins=5, ranges=[0., 1.])
>>> assert np.all(Hn1 == Hn2)
"""
if "range" in kwargs:
kwargs["ranges"] = kwargs.pop("range")

if "weights" not in kwargs:
return hist1d_numba_seq(*args, **kwargs)
kwargs_copy = kwargs.copy()
if "range" in kwargs_copy:
kwargs_copy["ranges"] = kwargs_copy.pop("range")

weights = kwargs.pop("weights")
if "weights" not in kwargs_copy:
return hist1d_numba_seq(*args, **kwargs_copy)

if weights is not None:
return hist1d_numba_seq_weight(*args, weights, **kwargs)
weights = kwargs_copy.pop("weights")

return hist1d_numba_seq(*args, **kwargs)
try:
if weights is not None:
return hist1d_numba_seq_weight(*args, weights, **kwargs_copy)

return hist1d_numba_seq(*args, **kwargs_copy)
except (NumbaValueError, NumbaNotImplementedError):
warnings.warn("Numba could not calculate the histogram. Trying standard numpy.")
return histogram_np(*args, **kwargs)[0]
except:
raise

else:

Expand Down

0 comments on commit d2df1ee

Please sign in to comment.