Skip to content

Commit

Permalink
Speed up missing._get_interpolator (#4776)
Browse files Browse the repository at this point in the history
* Speed up _get_interpolator

Importing scipy.interpolate is slow and should only be done when necessary. Test case from 200ms to 6ms.

* typos

* retain info from the except.
  • Loading branch information
Illviljan authored Jan 8, 2021
1 parent 5ddb8d5 commit 8ff8113
Showing 1 changed file with 21 additions and 25 deletions.
46 changes: 21 additions & 25 deletions xarray/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,16 @@ def bfill(arr, dim=None, limit=None):
).transpose(*arr.dims)


def _import_interpolant(interpolant, method):
"""Import interpolant from scipy.interpolate."""
try:
from scipy import interpolate

return getattr(interpolate, interpolant)
except ImportError as e:
raise ImportError(f"Interpolation with method {method} requires scipy.") from e


def _get_interpolator(method, vectorizeable_only=False, **kwargs):
"""helper function to select the appropriate interpolator class
Expand All @@ -459,12 +469,6 @@ def _get_interpolator(method, vectorizeable_only=False, **kwargs):
"akima",
]

has_scipy = True
try:
from scipy import interpolate
except ImportError:
has_scipy = False

# prioritize scipy.interpolate
if (
method == "linear"
Expand All @@ -475,32 +479,29 @@ def _get_interpolator(method, vectorizeable_only=False, **kwargs):
interp_class = NumpyInterpolator

elif method in valid_methods:
if not has_scipy:
raise ImportError("Interpolation with method `%s` requires scipy" % method)

if method in interp1d_methods:
kwargs.update(method=method)
interp_class = ScipyInterpolator
elif vectorizeable_only:
raise ValueError(
"{} is not a vectorizeable interpolator. "
"Available methods are {}".format(method, interp1d_methods)
f"{method} is not a vectorizeable interpolator. "
f"Available methods are {interp1d_methods}"
)
elif method == "barycentric":
interp_class = interpolate.BarycentricInterpolator
interp_class = _import_interpolant("BarycentricInterpolator", method)
elif method == "krog":
interp_class = interpolate.KroghInterpolator
interp_class = _import_interpolant("KroghInterpolator", method)
elif method == "pchip":
interp_class = interpolate.PchipInterpolator
interp_class = _import_interpolant("PchipInterpolator", method)
elif method == "spline":
kwargs.update(method=method)
interp_class = SplineInterpolator
elif method == "akima":
interp_class = interpolate.Akima1DInterpolator
interp_class = _import_interpolant("Akima1DInterpolator", method)
else:
raise ValueError("%s is not a valid scipy interpolator" % method)
raise ValueError(f"{method} is not a valid scipy interpolator")
else:
raise ValueError("%s is not a valid interpolator" % method)
raise ValueError(f"{method} is not a valid interpolator")

return interp_class, kwargs

Expand All @@ -512,18 +513,13 @@ def _get_interpolator_nd(method, **kwargs):
"""
valid_methods = ["linear", "nearest"]

try:
from scipy import interpolate
except ImportError:
raise ImportError("Interpolation with method `%s` requires scipy" % method)

if method in valid_methods:
kwargs.update(method=method)
interp_class = interpolate.interpn
interp_class = _import_interpolant("interpn", method)
else:
raise ValueError(
"%s is not a valid interpolator for interpolating "
"over multiple dimensions." % method
f"{method} is not a valid interpolator for interpolating "
"over multiple dimensions."
)

return interp_class, kwargs
Expand Down

0 comments on commit 8ff8113

Please sign in to comment.