From f3e83e6c75da348ff88fcaed4ec898839b8a4104 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 6 Jan 2021 11:11:43 +0100 Subject: [PATCH 1/5] scipy.interpolate.interp1d always forces to float. --- xarray/core/missing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index f608468ed9f..73d38111781 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -742,7 +742,7 @@ def interp_func(var, x, new_x, method, kwargs): interp_kwargs=kwargs, localize=localize, concatenate=True, - dtype=var.dtype, + dtype=float, # scipy.interpolate.interp1d always forces to float. new_axes=new_axes, ) From a2062cad31853b94875cd5ff1565d6617689e09c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 6 Jan 2021 11:55:06 +0100 Subject: [PATCH 2/5] Copy type-check from scipy.interpolate.interp1d --- xarray/core/missing.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 73d38111781..f2c3c3f325c 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -733,6 +733,13 @@ def interp_func(var, x, new_x, method, kwargs): # if usefull, re-use localize for each chunk of new_x localize = (method in ["linear", "nearest"]) and (new_x[0].chunks is not None) + + # scipy.interpolate.interp1d always forces to float. + # Use the same check for blockwise as well: + if not issubclass(var.dtype.type, np.inexact): + dtype = np.float_ + else: + dtype = var.dtype return da.blockwise( _dask_aware_interpnd, @@ -742,7 +749,7 @@ def interp_func(var, x, new_x, method, kwargs): interp_kwargs=kwargs, localize=localize, concatenate=True, - dtype=float, # scipy.interpolate.interp1d always forces to float. + dtype=dtype, new_axes=new_axes, ) From e935f8249707424440383f81c398ec25518e7893 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 6 Jan 2021 11:56:30 +0100 Subject: [PATCH 3/5] Update missing.py --- xarray/core/missing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index f2c3c3f325c..5116a6d2651 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -733,7 +733,7 @@ def interp_func(var, x, new_x, method, kwargs): # if usefull, re-use localize for each chunk of new_x localize = (method in ["linear", "nearest"]) and (new_x[0].chunks is not None) - + # scipy.interpolate.interp1d always forces to float. # Use the same check for blockwise as well: if not issubclass(var.dtype.type, np.inexact): From 2633469f23b6ddfb176031befba20c26e450c554 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 6 Jan 2021 12:58:08 +0100 Subject: [PATCH 4/5] Test that pre- and post-compute dtypes matches --- xarray/tests/test_missing.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py index 21d82b1948b..f9cbe6ecbfe 100644 --- a/xarray/tests/test_missing.py +++ b/xarray/tests/test_missing.py @@ -370,6 +370,20 @@ def test_interpolate_dask_raises_for_invalid_chunk_dim(): da.interpolate_na("time") +@requires_dask +@requires_scipy +@pytest.mark.parametrize("dtype", [float, int]) +def test_interpolate_dask_expected_dtype(dtype): + da = xr.DataArray( + data=np.array([0, 1], dtype=dtype), + dims=["time"], + coords=dict(time=np.array([0, 1])), + ).chunk(dict(time=2)) + da = da.interp(time=np.array([0, 0.5, 1, 2]), method="linear") + + assert da.dtype == da.compute().dtype + + @requires_bottleneck def test_ffill(): da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") From 51e0d2f20176a4ff1fad3d4d3def6c4ea8a9dbcb Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 6 Jan 2021 13:03:54 +0100 Subject: [PATCH 5/5] Update test_missing.py --- xarray/tests/test_missing.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py index f9cbe6ecbfe..2ab3508b667 100644 --- a/xarray/tests/test_missing.py +++ b/xarray/tests/test_missing.py @@ -372,14 +372,14 @@ def test_interpolate_dask_raises_for_invalid_chunk_dim(): @requires_dask @requires_scipy -@pytest.mark.parametrize("dtype", [float, int]) -def test_interpolate_dask_expected_dtype(dtype): +@pytest.mark.parametrize("dtype, method", [(int, "linear"), (int, "nearest")]) +def test_interpolate_dask_expected_dtype(dtype, method): da = xr.DataArray( data=np.array([0, 1], dtype=dtype), dims=["time"], coords=dict(time=np.array([0, 1])), ).chunk(dict(time=2)) - da = da.interp(time=np.array([0, 0.5, 1, 2]), method="linear") + da = da.interp(time=np.array([0, 0.5, 1, 2]), method=method) assert da.dtype == da.compute().dtype