From d2df1ee4a47ee5109ec25f6b4679ee61b21515c1 Mon Sep 17 00:00:00 2001 From: Matteo Bachetti Date: Tue, 26 Sep 2023 08:48:40 +0200 Subject: [PATCH] Make more robust for numba's picky type selection --- stingray/utils.py | 52 ++++++++++++++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/stingray/utils.py b/stingray/utils.py index f27d0a116..b592ccfac 100644 --- a/stingray/utils.py +++ b/stingray/utils.py @@ -46,11 +46,11 @@ HAS_NUMBA = True from numba import njit, prange, vectorize, float32, float64, int32, int64 - from numba.core.errors import NumbaValueError + from numba.core.errors import NumbaValueError, NumbaNotImplementedError except ImportError: warnings.warn("Numba not installed. Faking it") HAS_NUMBA = False - NumbaValueError = Exception + NumbaValueError = NumbaNotImplementedError = Exception def njit(f=None, *args, **kwargs): def decorator(func, *a, **kw): @@ -1979,18 +1979,25 @@ def histogram2d(*args, **kwargs): ... ranges=[[0., 1.], [2., 3.]]) >>> assert np.all(Hn1 == Hn2) """ - if "range" in kwargs: - kwargs["ranges"] = kwargs.pop("range") + kwargs_copy = kwargs.copy() + if "range" in kwargs_copy: + kwargs_copy["ranges"] = kwargs_copy.pop("range") - if "weights" not in kwargs: - return hist2d_numba_seq(*args, **kwargs) + if "weights" not in kwargs_copy: + return hist2d_numba_seq(*args, **kwargs_copy) - weights = kwargs.pop("weights") + weights = kwargs_copy.pop("weights") - if weights is not None: - return hist2d_numba_seq_weight(*args, weights, **kwargs) + try: + if weights is not None: + return hist2d_numba_seq_weight(*args, weights, **kwargs_copy) - return hist2d_numba_seq(*args, **kwargs) + return hist2d_numba_seq(*args, **kwargs_copy) + except NumbaValueError: + warnings.warn("Numba could not calculate the histogram. Trying standard numpy.") + return histogram2d_np(*args, **kwargs)[0] + except: + raise def histogram(*args, **kwargs): """ @@ -2006,18 +2013,25 @@ def histogram(*args, **kwargs): >>> Hn2 = histogram(x, bins=5, ranges=[0., 1.]) >>> assert np.all(Hn1 == Hn2) """ - if "range" in kwargs: - kwargs["ranges"] = kwargs.pop("range") - - if "weights" not in kwargs: - return hist1d_numba_seq(*args, **kwargs) + kwargs_copy = kwargs.copy() + if "range" in kwargs_copy: + kwargs_copy["ranges"] = kwargs_copy.pop("range") - weights = kwargs.pop("weights") + if "weights" not in kwargs_copy: + return hist1d_numba_seq(*args, **kwargs_copy) - if weights is not None: - return hist1d_numba_seq_weight(*args, weights, **kwargs) + weights = kwargs_copy.pop("weights") - return hist1d_numba_seq(*args, **kwargs) + try: + if weights is not None: + return hist1d_numba_seq_weight(*args, weights, **kwargs_copy) + + return hist1d_numba_seq(*args, **kwargs_copy) + except (NumbaValueError, NumbaNotImplementedError): + warnings.warn("Numba could not calculate the histogram. Trying standard numpy.") + return histogram_np(*args, **kwargs)[0] + except: + raise else: