From 4358762d7ccf0d81dfbbc37d9c0665d53fe9c426 Mon Sep 17 00:00:00 2001 From: keewis Date: Thu, 14 Nov 2019 02:24:07 +0100 Subject: [PATCH] Tests for module-level functions with units (#3493) * add tests for replication functions * add tests for `xarray.dot` * add tests for apply_ufunc * explicitly set the test ids to repr * add tests for align * cover a bit more of align * add tests for broadcast * black changed how tuple unpacking should look like * correct the xfail message for full_like tests * add tests for where * add tests for concat * add tests for combine_by_coords * fix a bug in convert_units * convert the align results to the same units * rename the combine_by_coords test * convert the units for expected in combine_by_coords * add tests for combine_nested * add tests for merge with datasets * only use three datasets for merging * add tests for merge with dataarrays * update whats-new.rst --- doc/whats-new.rst | 3 +- xarray/tests/test_units.py | 871 ++++++++++++++++++++++++++++++++++++- 2 files changed, 865 insertions(+), 9 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index f840557ab5d..a7687368884 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -118,7 +118,8 @@ Internal Changes ~~~~~~~~~~~~~~~~ - Added integration tests against `pint `_. - (:pull:`3238`, :pull:`3447`, :pull:`3508`) by `Justus Magin `_. + (:pull:`3238`, :pull:`3447`, :pull:`3493`, :pull:`3508`) + by `Justus Magin `_. .. note:: diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index fd9e9b039ac..509a50d23ff 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -222,7 +222,9 @@ def convert_units(obj, to): if name != obj.name } - new_obj = xr.DataArray(name=name, data=data, coords=coords, attrs=obj.attrs) + new_obj = xr.DataArray( + name=name, data=data, coords=coords, attrs=obj.attrs, dims=obj.dims + ) elif isinstance(obj, unit_registry.Quantity): units = to.get(None) new_obj = obj.to(units) if units is not None else obj @@ -307,19 +309,689 @@ def __repr__(self): class function: - def __init__(self, name): - self.name = name - self.func = getattr(np, name) + def __init__(self, name_or_function, *args, **kwargs): + if callable(name_or_function): + self.name = name_or_function.__name__ + self.func = name_or_function + else: + self.name = name_or_function + self.func = getattr(np, name_or_function) + if self.func is None: + raise AttributeError( + f"module 'numpy' has no attribute named '{self.name}'" + ) + + self.args = args + self.kwargs = kwargs def __call__(self, *args, **kwargs): - return self.func(*args, **kwargs) + all_args = list(self.args) + list(args) + all_kwargs = {**self.kwargs, **kwargs} + + return self.func(*all_args, **all_kwargs) def __repr__(self): return f"function_{self.name}" +def test_apply_ufunc_dataarray(dtype): + func = function( + xr.apply_ufunc, np.mean, input_core_dims=[["x"]], kwargs={"axis": -1} + ) + + array = np.linspace(0, 10, 20).astype(dtype) * unit_registry.m + x = np.arange(20) * unit_registry.s + data_array = xr.DataArray(data=array, dims="x", coords={"x": x}) + + expected = attach_units(func(strip_units(data_array)), extract_units(data_array)) + result = func(data_array) + + assert_equal_with_units(expected, result) + + +@pytest.mark.xfail( + reason="pint does not implement `np.result_type` and align strips units" +) +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")), + "coords", + ), +) +@pytest.mark.parametrize("fill_value", (np.float64(10), np.float64(np.nan))) +def test_align_dataarray(fill_value, variant, unit, error, dtype): + original_unit = unit_registry.m + + variants = { + "data": (unit, original_unit, original_unit), + "dims": (original_unit, unit, original_unit), + "coords": (original_unit, original_unit, unit), + } + data_unit, dim_unit, coord_unit = variants.get(variant) + + array1 = np.linspace(0, 10, 2 * 5).reshape(2, 5).astype(dtype) * original_unit + array2 = np.linspace(0, 8, 2 * 5).reshape(2, 5).astype(dtype) * data_unit + x = np.arange(2) * original_unit + x_a1 = np.array([10, 5]) * original_unit + x_a2 = np.array([10, 5]) * coord_unit + + y1 = np.arange(5) * original_unit + y2 = np.arange(2, 7) * dim_unit + + data_array1 = xr.DataArray( + data=array1, coords={"x": x, "x_a": ("x", x_a1), "y": y1}, dims=("x", "y") + ) + data_array2 = xr.DataArray( + data=array2, coords={"x": x, "x_a": ("x", x_a2), "y": y2}, dims=("x", "y") + ) + + fill_value = fill_value * data_unit + func = function(xr.align, join="outer", fill_value=fill_value) + if error is not None: + with pytest.raises(error): + func(data_array1, data_array2) + + return + + stripped_kwargs = { + key: strip_units( + convert_units(value, {None: original_unit}) + if isinstance(value, unit_registry.Quantity) + else value + ) + for key, value in func.kwargs.items() + } + units = extract_units(data_array1) + # FIXME: should the expected_b have the same units as data_array1 + # or data_array2? + expected_a, expected_b = tuple( + attach_units(elem, units) + for elem in func( + strip_units(data_array1), + strip_units(convert_units(data_array2, units)), + **stripped_kwargs, + ) + ) + result_a, result_b = func(data_array1, data_array2) + + assert_equal_with_units(expected_a, result_a) + assert_equal_with_units(expected_b, result_b) + + +@pytest.mark.xfail( + reason="pint does not implement `np.result_type` and align strips units" +) +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")), + "coords", + ), +) +@pytest.mark.parametrize("fill_value", (np.float64(10), np.float64(np.nan))) +def test_align_dataset(fill_value, unit, variant, error, dtype): + original_unit = unit_registry.m + + variants = { + "data": (unit, original_unit, original_unit), + "dims": (original_unit, unit, original_unit), + "coords": (original_unit, original_unit, unit), + } + data_unit, dim_unit, coord_unit = variants.get(variant) + + array1 = np.linspace(0, 10, 2 * 5).reshape(2, 5).astype(dtype) * original_unit + array2 = np.linspace(0, 10, 2 * 5).reshape(2, 5).astype(dtype) * data_unit + + x = np.arange(2) * original_unit + x_a1 = np.array([10, 5]) * original_unit + x_a2 = np.array([10, 5]) * coord_unit + + y1 = np.arange(5) * original_unit + y2 = np.arange(2, 7) * dim_unit + + ds1 = xr.Dataset( + data_vars={"a": (("x", "y"), array1)}, + coords={"x": x, "x_a": ("x", x_a1), "y": y1}, + ) + ds2 = xr.Dataset( + data_vars={"a": (("x", "y"), array2)}, + coords={"x": x, "x_a": ("x", x_a2), "y": y2}, + ) + + fill_value = fill_value * data_unit + func = function(xr.align, join="outer", fill_value=fill_value) + if error is not None: + with pytest.raises(error): + func(ds1, ds2) + + return + + stripped_kwargs = { + key: strip_units( + convert_units(value, {None: original_unit}) + if isinstance(value, unit_registry.Quantity) + else value + ) + for key, value in func.kwargs.items() + } + units = extract_units(ds1) + # FIXME: should the expected_b have the same units as ds1 or ds2? + expected_a, expected_b = tuple( + attach_units(elem, units) + for elem in func( + strip_units(ds1), strip_units(convert_units(ds2, units)), **stripped_kwargs + ) + ) + result_a, result_b = func(ds1, ds2) + + assert_equal_with_units(expected_a, result_a) + assert_equal_with_units(expected_b, result_b) + + +def test_broadcast_dataarray(dtype): + array1 = np.linspace(0, 10, 2) * unit_registry.Pa + array2 = np.linspace(0, 10, 3) * unit_registry.Pa + + a = xr.DataArray(data=array1, dims="x") + b = xr.DataArray(data=array2, dims="y") + + expected_a, expected_b = tuple( + attach_units(elem, extract_units(a)) + for elem in xr.broadcast(strip_units(a), strip_units(b)) + ) + result_a, result_b = xr.broadcast(a, b) + + assert_equal_with_units(expected_a, result_a) + assert_equal_with_units(expected_b, result_b) + + +def test_broadcast_dataset(dtype): + array1 = np.linspace(0, 10, 2) * unit_registry.Pa + array2 = np.linspace(0, 10, 3) * unit_registry.Pa + + ds = xr.Dataset(data_vars={"a": ("x", array1), "b": ("y", array2)}) + + (expected,) = tuple( + attach_units(elem, extract_units(ds)) for elem in xr.broadcast(strip_units(ds)) + ) + (result,) = xr.broadcast(ds) + + assert_equal_with_units(expected, result) + + +@pytest.mark.xfail(reason="`combine_by_coords` strips units") +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")), + "coords", + ), +) +def test_combine_by_coords(variant, unit, error, dtype): + original_unit = unit_registry.m + + variants = { + "data": (unit, original_unit, original_unit), + "dims": (original_unit, unit, original_unit), + "coords": (original_unit, original_unit, unit), + } + data_unit, dim_unit, coord_unit = variants.get(variant) + + array1 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit + array2 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit + x = np.arange(1, 4) * 10 * original_unit + y = np.arange(2) * original_unit + z = np.arange(3) * original_unit + + other_array1 = np.ones_like(array1) * data_unit + other_array2 = np.ones_like(array2) * data_unit + other_x = np.arange(1, 4) * 10 * dim_unit + other_y = np.arange(2, 4) * dim_unit + other_z = np.arange(3, 6) * coord_unit + + ds = xr.Dataset( + data_vars={"a": (("y", "x"), array1), "b": (("y", "x"), array2)}, + coords={"x": x, "y": y, "z": ("x", z)}, + ) + other = xr.Dataset( + data_vars={"a": (("y", "x"), other_array1), "b": (("y", "x"), other_array2)}, + coords={"x": other_x, "y": other_y, "z": ("x", other_z)}, + ) + + if error is not None: + with pytest.raises(error): + xr.combine_by_coords([ds, other]) + + return + + units = extract_units(ds) + expected = attach_units( + xr.combine_by_coords( + [strip_units(ds), strip_units(convert_units(other, units))] + ), + units, + ) + result = xr.combine_by_coords([ds, other]) + + assert_equal_with_units(expected, result) + + +@pytest.mark.xfail(reason="blocked by `where`") +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")), + "coords", + ), +) +def test_combine_nested(variant, unit, error, dtype): + original_unit = unit_registry.m + + variants = { + "data": (unit, original_unit, original_unit), + "dims": (original_unit, unit, original_unit), + "coords": (original_unit, original_unit, unit), + } + data_unit, dim_unit, coord_unit = variants.get(variant) + + array1 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit + array2 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit + + x = np.arange(1, 4) * 10 * original_unit + y = np.arange(2) * original_unit + z = np.arange(3) * original_unit + + ds1 = xr.Dataset( + data_vars={"a": (("y", "x"), array1), "b": (("y", "x"), array2)}, + coords={"x": x, "y": y, "z": ("x", z)}, + ) + ds2 = xr.Dataset( + data_vars={ + "a": (("y", "x"), np.ones_like(array1) * data_unit), + "b": (("y", "x"), np.ones_like(array2) * data_unit), + }, + coords={ + "x": np.arange(3) * dim_unit, + "y": np.arange(2, 4) * dim_unit, + "z": ("x", np.arange(-3, 0) * coord_unit), + }, + ) + ds3 = xr.Dataset( + data_vars={ + "a": (("y", "x"), np.zeros_like(array1) * np.nan * data_unit), + "b": (("y", "x"), np.zeros_like(array2) * np.nan * data_unit), + }, + coords={ + "x": np.arange(3, 6) * dim_unit, + "y": np.arange(4, 6) * dim_unit, + "z": ("x", np.arange(3, 6) * coord_unit), + }, + ) + ds4 = xr.Dataset( + data_vars={ + "a": (("y", "x"), -1 * np.ones_like(array1) * data_unit), + "b": (("y", "x"), -1 * np.ones_like(array2) * data_unit), + }, + coords={ + "x": np.arange(6, 9) * dim_unit, + "y": np.arange(6, 8) * dim_unit, + "z": ("x", np.arange(6, 9) * coord_unit), + }, + ) + + func = function(xr.combine_nested, concat_dim=["x", "y"]) + if error is not None: + with pytest.raises(error): + func([[ds1, ds2], [ds3, ds4]]) + + return + + units = extract_units(ds1) + convert_and_strip = lambda ds: strip_units(convert_units(ds, units)) + expected = attach_units( + func( + [ + [strip_units(ds1), convert_and_strip(ds2)], + [convert_and_strip(ds3), convert_and_strip(ds4)], + ] + ), + units, + ) + result = func([[ds1, ds2], [ds3, ds4]]) + + assert_equal_with_units(expected, result) + + +@pytest.mark.xfail(reason="`concat` strips units") +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")), + ), +) +def test_concat_dataarray(variant, unit, error, dtype): + original_unit = unit_registry.m + + variants = {"data": (unit, original_unit), "dims": (original_unit, unit)} + data_unit, dims_unit = variants.get(variant) + + array1 = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m + array2 = np.linspace(-5, 0, 5).astype(dtype) * data_unit + x1 = np.arange(5, 15) * original_unit + x2 = np.arange(5) * dims_unit + + arr1 = xr.DataArray(data=array1, coords={"x": x1}, dims="x") + arr2 = xr.DataArray(data=array2, coords={"x": x2}, dims="x") + + if error is not None: + with pytest.raises(error): + xr.concat([arr1, arr2], dim="x") + + return + + expected = attach_units( + xr.concat([strip_units(arr1), strip_units(arr2)], dim="x"), extract_units(arr1) + ) + result = xr.concat([arr1, arr2], dim="x") + + assert_equal_with_units(expected, result) + + +@pytest.mark.xfail(reason="`concat` strips units") +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")), + ), +) +def test_concat_dataset(variant, unit, error, dtype): + original_unit = unit_registry.m + + variants = {"data": (unit, original_unit), "dims": (original_unit, unit)} + data_unit, dims_unit = variants.get(variant) + + array1 = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m + array2 = np.linspace(-5, 0, 5).astype(dtype) * data_unit + x1 = np.arange(5, 15) * original_unit + x2 = np.arange(5) * dims_unit + + ds1 = xr.Dataset(data_vars={"a": ("x", array1)}, coords={"x": x1}) + ds2 = xr.Dataset(data_vars={"a": ("x", array2)}, coords={"x": x2}) + + if error is not None: + with pytest.raises(error): + xr.concat([ds1, ds2], dim="x") + + return + + expected = attach_units( + xr.concat([strip_units(ds1), strip_units(ds2)], dim="x"), extract_units(ds1) + ) + result = xr.concat([ds1, ds2], dim="x") + + assert_equal_with_units(expected, result) + + +@pytest.mark.xfail(reason="blocked by `where`") +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")), + "coords", + ), +) +def test_merge_dataarray(variant, unit, error, dtype): + original_unit = unit_registry.m + + variants = { + "data": (unit, original_unit, original_unit), + "dims": (original_unit, unit, original_unit), + "coords": (original_unit, original_unit, unit), + } + data_unit, dim_unit, coord_unit = variants.get(variant) + + array1 = np.linspace(0, 1, 2 * 3).reshape(2, 3).astype(dtype) * original_unit + array2 = np.linspace(1, 2, 2 * 4).reshape(2, 4).astype(dtype) * data_unit + array3 = np.linspace(0, 2, 3 * 4).reshape(3, 4).astype(dtype) * data_unit + + x = np.arange(2) * original_unit + y = np.arange(3) * original_unit + z = np.arange(4) * original_unit + u = np.linspace(10, 20, 2) * original_unit + v = np.linspace(10, 20, 3) * original_unit + w = np.linspace(10, 20, 4) * original_unit + + arr1 = xr.DataArray( + name="a", + data=array1, + coords={"x": x, "y": y, "u": ("x", u), "v": ("y", v)}, + dims=("x", "y"), + ) + arr2 = xr.DataArray( + name="b", + data=array2, + coords={ + "x": np.arange(2, 4) * dim_unit, + "z": z, + "u": ("x", np.linspace(20, 30, 2) * coord_unit), + "w": ("z", w), + }, + dims=("x", "z"), + ) + arr3 = xr.DataArray( + name="c", + data=array3, + coords={ + "y": np.arange(3, 6) * dim_unit, + "z": np.arange(4, 8) * dim_unit, + "v": ("y", np.linspace(10, 20, 3) * coord_unit), + "w": ("z", np.linspace(10, 20, 4) * coord_unit), + }, + dims=("y", "z"), + ) + + func = function(xr.merge) + if error is not None: + with pytest.raises(error): + func([arr1, arr2, arr3]) + + return + + units = {name: original_unit for name in list("abcuvwxyz")} + convert_and_strip = lambda arr: strip_units(convert_units(arr, units)) + expected = attach_units( + func([strip_units(arr1), convert_and_strip(arr2), convert_and_strip(arr3)]), + units, + ) + result = func([arr1, arr2, arr3]) + + assert_equal_with_units(expected, result) + + +@pytest.mark.xfail(reason="blocked by `where`") +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")), + "coords", + ), +) +def test_merge_dataset(variant, unit, error, dtype): + original_unit = unit_registry.m + + variants = { + "data": (unit, original_unit, original_unit), + "dims": (original_unit, unit, original_unit), + "coords": (original_unit, original_unit, unit), + } + data_unit, dim_unit, coord_unit = variants.get(variant) + + array1 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit + array2 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit + + x = np.arange(11, 14) * original_unit + y = np.arange(2) * original_unit + z = np.arange(3) * original_unit + + ds1 = xr.Dataset( + data_vars={"a": (("y", "x"), array1), "b": (("y", "x"), array2)}, + coords={"x": x, "y": y, "z": ("x", z)}, + ) + ds2 = xr.Dataset( + data_vars={ + "a": (("y", "x"), np.ones_like(array1) * data_unit), + "b": (("y", "x"), np.ones_like(array2) * data_unit), + }, + coords={ + "x": np.arange(3) * dim_unit, + "y": np.arange(2, 4) * dim_unit, + "z": ("x", np.arange(-3, 0) * coord_unit), + }, + ) + ds3 = xr.Dataset( + data_vars={ + "a": (("y", "x"), np.zeros_like(array1) * np.nan * data_unit), + "b": (("y", "x"), np.zeros_like(array2) * np.nan * data_unit), + }, + coords={ + "x": np.arange(3, 6) * dim_unit, + "y": np.arange(4, 6) * dim_unit, + "z": ("x", np.arange(3, 6) * coord_unit), + }, + ) + + func = function(xr.merge) + if error is not None: + with pytest.raises(error): + func([ds1, ds2, ds3]) + + return + + units = extract_units(ds1) + convert_and_strip = lambda ds: strip_units(convert_units(ds, units)) + expected = attach_units( + func([strip_units(ds1), convert_and_strip(ds2), convert_and_strip(ds3)]), units + ) + result = func([ds1, ds2, ds3]) + + assert_equal_with_units(expected, result) + + @pytest.mark.parametrize("func", (xr.zeros_like, xr.ones_like)) -def test_replication(func, dtype): +def test_replication_dataarray(func, dtype): array = np.linspace(0, 10, 20).astype(dtype) * unit_registry.s data_array = xr.DataArray(data=array, dims="x") @@ -330,8 +1002,33 @@ def test_replication(func, dtype): assert_equal_with_units(expected, result) +@pytest.mark.parametrize("func", (xr.zeros_like, xr.ones_like)) +def test_replication_dataset(func, dtype): + array1 = np.linspace(0, 10, 20).astype(dtype) * unit_registry.s + array2 = np.linspace(5, 10, 10).astype(dtype) * unit_registry.Pa + x = np.arange(20).astype(dtype) * unit_registry.m + y = np.arange(10).astype(dtype) * unit_registry.m + z = y.to(unit_registry.mm) + + ds = xr.Dataset( + data_vars={"a": ("x", array1), "b": ("y", array2)}, + coords={"x": x, "y": y, "z": ("y", z)}, + ) + + numpy_func = getattr(np, func.__name__) + expected = ds.copy( + data={name: numpy_func(array.data) for name, array in ds.data_vars.items()} + ) + result = func(ds) + + assert_equal_with_units(expected, result) + + @pytest.mark.xfail( - reason="np.full_like on Variable strips the unit and pint does not allow mixed args" + reason=( + "pint is undecided on how `full_like` should work, so incorrect errors " + "may be expected: hgrecco/pint#882" + ) ) @pytest.mark.parametrize( "unit,error", @@ -344,8 +1041,9 @@ def test_replication(func, dtype): pytest.param(unit_registry.ms, None, id="compatible_unit"), pytest.param(unit_registry.s, None, id="identical_unit"), ), + ids=repr, ) -def test_replication_full_like(unit, error, dtype): +def test_replication_full_like_dataarray(unit, error, dtype): array = np.linspace(0, 5, 10) * unit_registry.s data_array = xr.DataArray(data=array, dims="x") @@ -360,6 +1058,163 @@ def test_replication_full_like(unit, error, dtype): assert_equal_with_units(expected, result) +@pytest.mark.xfail( + reason=( + "pint is undecided on how `full_like` should work, so incorrect errors " + "may be expected: hgrecco/pint#882" + ) +) +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.m, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.ms, None, id="compatible_unit"), + pytest.param(unit_registry.s, None, id="identical_unit"), + ), + ids=repr, +) +def test_replication_full_like_dataset(unit, error, dtype): + array1 = np.linspace(0, 10, 20).astype(dtype) * unit_registry.s + array2 = np.linspace(5, 10, 10).astype(dtype) * unit_registry.Pa + x = np.arange(20).astype(dtype) * unit_registry.m + y = np.arange(10).astype(dtype) * unit_registry.m + z = y.to(unit_registry.mm) + + ds = xr.Dataset( + data_vars={"a": ("x", array1), "b": ("y", array2)}, + coords={"x": x, "y": y, "z": ("y", z)}, + ) + + fill_value = -1 * unit + if error is not None: + with pytest.raises(error): + xr.full_like(ds, fill_value=fill_value) + + return + + expected = ds.copy( + data={ + name: np.full_like(array, fill_value=fill_value) + for name, array in ds.data_vars.items() + } + ) + result = xr.full_like(ds, fill_value=fill_value) + + assert_equal_with_units(expected, result) + + +@pytest.mark.xfail(reason="`where` strips units") +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize("fill_value", (np.nan, 10.2)) +def test_where_dataarray(fill_value, unit, error, dtype): + array = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m + + x = xr.DataArray(data=array, dims="x") + cond = x < 5 * unit_registry.m + # FIXME: this should work without wrapping in array() + fill_value = np.array(fill_value) * unit + + if error is not None: + with pytest.raises(error): + xr.where(cond, x, fill_value) + + return + + fill_value_ = ( + fill_value.to(unit_registry.m) + if isinstance(fill_value, unit_registry.Quantity) + and fill_value.check(unit_registry.m) + else fill_value + ) + expected = attach_units( + xr.where(cond, strip_units(x), strip_units(fill_value_)), extract_units(x) + ) + result = xr.where(cond, x, fill_value) + + assert_equal_with_units(expected, result) + + +@pytest.mark.xfail(reason="`where` strips units") +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize("fill_value", (np.nan, 10.2)) +def test_where_dataset(fill_value, unit, error, dtype): + array1 = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m + array2 = np.linspace(-5, 0, 10).astype(dtype) * unit_registry.m + x = np.arange(10) * unit_registry.s + + ds = xr.Dataset(data_vars={"a": ("x", array1), "b": ("x", array2)}, coords={"x": x}) + cond = ds.x < 5 * unit_registry.s + # FIXME: this should work without wrapping in array() + fill_value = np.array(fill_value) * unit + + if error is not None: + with pytest.raises(error): + xr.where(cond, ds, fill_value) + + return + + fill_value_ = ( + fill_value.to(unit_registry.m) + if isinstance(fill_value, unit_registry.Quantity) + and fill_value.check(unit_registry.m) + else fill_value + ) + expected = attach_units( + xr.where(cond, strip_units(ds), strip_units(fill_value_)), extract_units(ds) + ) + result = xr.where(cond, ds, fill_value) + + assert_equal_with_units(expected, result) + + +@pytest.mark.xfail(reason="pint does not implement `np.einsum`") +def test_dot_dataarray(dtype): + array1 = ( + np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) + * unit_registry.m + / unit_registry.s + ) + array2 = ( + np.linspace(10, 20, 10 * 20).reshape(10, 20).astype(dtype) * unit_registry.s + ) + + arr1 = xr.DataArray(data=array1, dims=("x", "y")) + arr2 = xr.DataArray(data=array2, dims=("y", "z")) + + expected = array1.dot(array2) + result = xr.dot(arr1, arr2) + + assert_equal_with_units(expected, result) + + class TestDataArray: @pytest.mark.filterwarnings("error:::pint[.*]") @pytest.mark.parametrize(