diff --git a/doc/whats-new.rst b/doc/whats-new.rst index f1137b7b2a2..398c332433f 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -66,6 +66,9 @@ Bug fixes By `Anderson Banihirwe `_ - Fix a crash in orthogonal indexing on geographic coordinates with ``engine='cfgrib'`` (:issue:`4733` :pull:`4737`). By `Alessandro Amici `_ +- Coordinates with dtype ``str`` or ``bytes`` now retain their dtype on many operations, + e.g. ``reindex``, ``align``, ``concat``, ``assign``, previously they were cast to an object dtype + (:issue:`2658` and :issue:`4543`) by `Mathias Hauser `_. - Limit number of data rows when printing large datasets. (:issue:`4736`, :pull:`4750`). By `Jimmy Westling `_. - Add ``missing_dims`` parameter to transpose (:issue:`4647`, :pull:`4767`). By `Daniel Mesejo `_. - Resolve intervals before appending other metadata to labels when plotting (:issue:`4322`, :pull:`4794`). diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 21bda8ef8d7..debf3aad96a 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -19,7 +19,7 @@ from . import dtypes, utils from .indexing import get_indexer_nd -from .utils import is_dict_like, is_full_slice +from .utils import is_dict_like, is_full_slice, maybe_coerce_to_str from .variable import IndexVariable, Variable if TYPE_CHECKING: @@ -278,10 +278,12 @@ def align( return (obj.copy(deep=copy),) all_indexes = defaultdict(list) + all_coords = defaultdict(list) unlabeled_dim_sizes = defaultdict(set) for obj in objects: for dim in obj.dims: if dim not in exclude: + all_coords[dim].append(obj.coords[dim]) try: index = obj.indexes[dim] except KeyError: @@ -306,7 +308,7 @@ def align( any(not index.equals(other) for other in matching_indexes) or dim in unlabeled_dim_sizes ): - joined_indexes[dim] = index + joined_indexes[dim] = indexes[dim] else: if ( any( @@ -318,9 +320,11 @@ def align( if join == "exact": raise ValueError(f"indexes along dimension {dim!r} are not equal") index = joiner(matching_indexes) + # make sure str coords are not cast to object + index = maybe_coerce_to_str(index, all_coords[dim]) joined_indexes[dim] = index else: - index = matching_indexes[0] + index = all_coords[dim][0] if dim in unlabeled_dim_sizes: unlabeled_sizes = unlabeled_dim_sizes[dim] @@ -583,7 +587,7 @@ def reindex_variables( args: tuple = (var.attrs, var.encoding) else: args = () - reindexed[dim] = IndexVariable((dim,), target, *args) + reindexed[dim] = IndexVariable((dim,), indexers[dim], *args) for dim in sizes: if dim not in indexes and dim in indexers: diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 1275d002cd3..5cda5aa903c 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -187,7 +187,7 @@ def concat( array([[0, 1, 2], [3, 4, 5]]) Coordinates: - * x (x) object 'a' 'b' + * x (x) >> xr.concat([da.isel(x=0), da.isel(x=1)], "new_dim") @@ -503,7 +503,7 @@ def ensure_common_dims(vars): for k in datasets[0].variables: if k in concat_over: try: - vars = ensure_common_dims([ds.variables[k] for ds in datasets]) + vars = ensure_common_dims([ds[k].variable for ds in datasets]) except KeyError: raise ValueError("%r is not present in all datasets." % k) combined = concat_vars(vars, dim, positions) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index b3a545dec73..6fdda8fc418 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1325,8 +1325,8 @@ def broadcast_like( [ 2.2408932 , 1.86755799, -0.97727788], [ nan, nan, nan]]) Coordinates: - * x (x) object 'a' 'b' 'c' - * y (y) object 'a' 'b' 'c' + * x (x) Dimensions: (station: 4) Coordinates: - * station (station) object 'boston' 'austin' 'seattle' 'lincoln' + * station (station) Dimensions: (station: 4) Coordinates: - * station (station) object 'boston' 'austin' 'seattle' 'lincoln' + * station (station) Dimensions: (station: 4) Coordinates: - * station (station) object 'boston' 'austin' 'seattle' 'lincoln' + * station (station) pd.Index: """Given an array, safely cast it to a pandas.Index. diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 0a6eef44c90..797de65bbcf 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -48,6 +48,7 @@ ensure_us_time_resolution, infix_dims, is_duck_array, + maybe_coerce_to_str, ) NON_NUMPY_SUPPORTED_ARRAY_TYPES = ( @@ -2523,6 +2524,9 @@ def concat(cls, variables, dim="concat_dim", positions=None, shortcut=False): indices = nputils.inverse_permutation(np.concatenate(positions)) data = data.take(indices) + # keep as str if possible as pandas.Index uses object (converts to numpy array) + data = maybe_coerce_to_str(data, variables) + attrs = dict(first_var.attrs) if not shortcut: for var in variables: diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 0d5507b6879..7416cab13ed 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -376,6 +376,30 @@ def test_concat_fill_value(self, fill_value): actual = concat(datasets, dim="t", fill_value=fill_value) assert_identical(actual, expected) + @pytest.mark.parametrize("dtype", [str, bytes]) + @pytest.mark.parametrize("dim", ["x1", "x2"]) + def test_concat_str_dtype(self, dtype, dim): + + data = np.arange(4).reshape([2, 2]) + + da1 = Dataset( + { + "data": (["x1", "x2"], data), + "x1": [0, 1], + "x2": np.array(["a", "b"], dtype=dtype), + } + ) + da2 = Dataset( + { + "data": (["x1", "x2"], data), + "x1": np.array([1, 2]), + "x2": np.array(["c", "d"], dtype=dtype), + } + ) + actual = concat([da1, da2], dim=dim) + + assert np.issubdtype(actual.x2.dtype, dtype) + class TestConcatDataArray: def test_concat(self): @@ -525,6 +549,26 @@ def test_concat_combine_attrs_kwarg(self): actual = concat([da1, da2], dim="x", combine_attrs=combine_attrs) assert_identical(actual, expected[combine_attrs]) + @pytest.mark.parametrize("dtype", [str, bytes]) + @pytest.mark.parametrize("dim", ["x1", "x2"]) + def test_concat_str_dtype(self, dtype, dim): + + data = np.arange(4).reshape([2, 2]) + + da1 = DataArray( + data=data, + dims=["x1", "x2"], + coords={"x1": [0, 1], "x2": np.array(["a", "b"], dtype=dtype)}, + ) + da2 = DataArray( + data=data, + dims=["x1", "x2"], + coords={"x1": np.array([1, 2]), "x2": np.array(["c", "d"], dtype=dtype)}, + ) + actual = concat([da1, da2], dim=dim) + + assert np.issubdtype(actual.x2.dtype, dtype) + @pytest.mark.parametrize("attr1", ({"a": {"meta": [10, 20, 30]}}, {"a": [1, 2, 3]}, {})) @pytest.mark.parametrize("attr2", ({"a": [1, 2, 3]}, {})) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 8215a9ddaac..3ead427e22e 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1568,6 +1568,19 @@ def test_reindex_fill_value(self, fill_value): ) assert_identical(expected, actual) + @pytest.mark.parametrize("dtype", [str, bytes]) + def test_reindex_str_dtype(self, dtype): + + data = DataArray( + [1, 2], dims="x", coords={"x": np.array(["a", "b"], dtype=dtype)} + ) + + actual = data.reindex(x=data.x) + expected = data + + assert_identical(expected, actual) + assert actual.dtype == expected.dtype + def test_rename(self): renamed = self.dv.rename("bar") assert_identical(renamed.to_dataset(), self.ds.rename({"foo": "bar"})) @@ -3435,6 +3448,26 @@ def test_align_without_indexes_errors(self): DataArray([1, 2], coords=[("x", [0, 1])]), ) + def test_align_str_dtype(self): + + a = DataArray([0, 1], dims=["x"], coords={"x": ["a", "b"]}) + b = DataArray([1, 2], dims=["x"], coords={"x": ["b", "c"]}) + + expected_a = DataArray( + [0, 1, np.NaN], dims=["x"], coords={"x": ["a", "b", "c"]} + ) + expected_b = DataArray( + [np.NaN, 1, 2], dims=["x"], coords={"x": ["a", "b", "c"]} + ) + + actual_a, actual_b = xr.align(a, b, join="outer") + + assert_identical(expected_a, actual_a) + assert expected_a.x.dtype == actual_a.x.dtype + + assert_identical(expected_b, actual_b) + assert expected_b.x.dtype == actual_b.x.dtype + def test_broadcast_arrays(self): x = DataArray([1, 2], coords=[("a", [-1, -2])], name="x") y = DataArray([1, 2], coords=[("b", [3, 4])], name="y") diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 204f08c2eec..bd1938455b1 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -1949,6 +1949,16 @@ def test_reindex_like_fill_value(self, fill_value): ) assert_identical(expected, actual) + @pytest.mark.parametrize("dtype", [str, bytes]) + def test_reindex_str_dtype(self, dtype): + data = Dataset({"data": ("x", [1, 2]), "x": np.array(["a", "b"], dtype=dtype)}) + + actual = data.reindex(x=data.x) + expected = data + + assert_identical(expected, actual) + assert actual.x.dtype == expected.x.dtype + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0, {"foo": 2, "bar": 1}]) def test_align_fill_value(self, fill_value): x = Dataset({"foo": DataArray([1, 2], dims=["x"], coords={"x": [1, 2]})}) @@ -2134,6 +2144,22 @@ def test_align_non_unique(self): with raises_regex(ValueError, "cannot reindex or align"): align(x, y) + def test_align_str_dtype(self): + + a = Dataset({"foo": ("x", [0, 1]), "x": ["a", "b"]}) + b = Dataset({"foo": ("x", [1, 2]), "x": ["b", "c"]}) + + expected_a = Dataset({"foo": ("x", [0, 1, np.NaN]), "x": ["a", "b", "c"]}) + expected_b = Dataset({"foo": ("x", [np.NaN, 1, 2]), "x": ["a", "b", "c"]}) + + actual_a, actual_b = xr.align(a, b, join="outer") + + assert_identical(expected_a, actual_a) + assert expected_a.x.dtype == actual_a.x.dtype + + assert_identical(expected_b, actual_b) + assert expected_b.x.dtype == actual_b.x.dtype + def test_broadcast(self): ds = Dataset( {"foo": 0, "bar": ("x", [1]), "baz": ("y", [2, 3])}, {"c": ("x", [4])} @@ -3420,6 +3446,14 @@ def test_setitem_align_new_indexes(self): ) assert_identical(ds, expected) + @pytest.mark.parametrize("dtype", [str, bytes]) + def test_setitem_str_dtype(self, dtype): + + ds = xr.Dataset(coords={"x": np.array(["x", "y"], dtype=dtype)}) + ds["foo"] = xr.DataArray(np.array([0, 0]), dims=["x"]) + + assert np.issubdtype(ds.x.dtype, dtype) + def test_assign(self): ds = Dataset() actual = ds.assign(x=[0, 1, 2], y=2) diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index 5f8b1770bd3..193c45f01cd 100644 --- a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -39,6 +39,33 @@ def test_safe_cast_to_index(): assert expected.dtype == actual.dtype +@pytest.mark.parametrize( + "a, b, expected", [["a", "b", np.array(["a", "b"])], [1, 2, pd.Index([1, 2])]] +) +def test_maybe_coerce_to_str(a, b, expected): + + a = np.array([a]) + b = np.array([b]) + index = pd.Index(a).append(pd.Index(b)) + + actual = utils.maybe_coerce_to_str(index, [a, b]) + + assert_array_equal(expected, actual) + assert expected.dtype == actual.dtype + + +def test_maybe_coerce_to_str_minimal_str_dtype(): + + a = np.array(["a", "a_long_string"]) + index = pd.Index(["a"]) + + actual = utils.maybe_coerce_to_str(index, [a]) + expected = np.array("a") + + assert_array_equal(expected, actual) + assert expected.dtype == actual.dtype + + @requires_cftime def test_safe_cast_to_index_cftimeindex(): date_types = _all_cftime_date_types() diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 41bf24c7f88..e1ae3e1f258 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -2094,6 +2094,17 @@ def test_concat_multiindex(self): assert_identical(actual, expected) assert isinstance(actual.to_index(), pd.MultiIndex) + @pytest.mark.parametrize("dtype", [str, bytes]) + def test_concat_str_dtype(self, dtype): + + a = IndexVariable("x", np.array(["a"], dtype=dtype)) + b = IndexVariable("x", np.array(["b"], dtype=dtype)) + expected = IndexVariable("x", np.array(["a", "b"], dtype=dtype)) + + actual = IndexVariable.concat([a, b]) + assert actual.identical(expected) + assert np.issubdtype(actual.dtype, dtype) + def test_coordinate_alias(self): with pytest.warns(Warning, match="deprecated"): x = Coordinate("x", [1, 2, 3])