Skip to content

Commit

Permalink
coords: retain str dtype (#4759)
Browse files Browse the repository at this point in the history
* coords: retain str dtype

* fix doctests

* update what's new

* fix multiindex repr

* rename function

* ensure minimum str dtype

* fix EOL spaces
  • Loading branch information
mathause authored Jan 13, 2021
1 parent f52a95c commit fb67358
Show file tree
Hide file tree
Showing 13 changed files with 193 additions and 12 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ Bug fixes
By `Anderson Banihirwe <https://github.com/andersy005>`_
- Fix a crash in orthogonal indexing on geographic coordinates with ``engine='cfgrib'`` (:issue:`4733` :pull:`4737`).
By `Alessandro Amici <https://github.com/alexamici>`_
- Coordinates with dtype ``str`` or ``bytes`` now retain their dtype on many operations,
e.g. ``reindex``, ``align``, ``concat``, ``assign``, previously they were cast to an object dtype
(:issue:`2658` and :issue:`4543`) by `Mathias Hauser <https://github.com/mathause>`_.
- Limit number of data rows when printing large datasets. (:issue:`4736`, :pull:`4750`). By `Jimmy Westling <https://github.com/illviljan>`_.
- Add ``missing_dims`` parameter to transpose (:issue:`4647`, :pull:`4767`). By `Daniel Mesejo <https://github.com/mesejo>`_.
- Resolve intervals before appending other metadata to labels when plotting (:issue:`4322`, :pull:`4794`).
Expand Down
12 changes: 8 additions & 4 deletions xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from . import dtypes, utils
from .indexing import get_indexer_nd
from .utils import is_dict_like, is_full_slice
from .utils import is_dict_like, is_full_slice, maybe_coerce_to_str
from .variable import IndexVariable, Variable

if TYPE_CHECKING:
Expand Down Expand Up @@ -278,10 +278,12 @@ def align(
return (obj.copy(deep=copy),)

all_indexes = defaultdict(list)
all_coords = defaultdict(list)
unlabeled_dim_sizes = defaultdict(set)
for obj in objects:
for dim in obj.dims:
if dim not in exclude:
all_coords[dim].append(obj.coords[dim])
try:
index = obj.indexes[dim]
except KeyError:
Expand All @@ -306,7 +308,7 @@ def align(
any(not index.equals(other) for other in matching_indexes)
or dim in unlabeled_dim_sizes
):
joined_indexes[dim] = index
joined_indexes[dim] = indexes[dim]
else:
if (
any(
Expand All @@ -318,9 +320,11 @@ def align(
if join == "exact":
raise ValueError(f"indexes along dimension {dim!r} are not equal")
index = joiner(matching_indexes)
# make sure str coords are not cast to object
index = maybe_coerce_to_str(index, all_coords[dim])
joined_indexes[dim] = index
else:
index = matching_indexes[0]
index = all_coords[dim][0]

if dim in unlabeled_dim_sizes:
unlabeled_sizes = unlabeled_dim_sizes[dim]
Expand Down Expand Up @@ -583,7 +587,7 @@ def reindex_variables(
args: tuple = (var.attrs, var.encoding)
else:
args = ()
reindexed[dim] = IndexVariable((dim,), target, *args)
reindexed[dim] = IndexVariable((dim,), indexers[dim], *args)

for dim in sizes:
if dim not in indexes and dim in indexers:
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def concat(
array([[0, 1, 2],
[3, 4, 5]])
Coordinates:
* x (x) object 'a' 'b'
* x (x) <U1 'a' 'b'
* y (y) int64 10 20 30
>>> xr.concat([da.isel(x=0), da.isel(x=1)], "new_dim")
Expand Down Expand Up @@ -503,7 +503,7 @@ def ensure_common_dims(vars):
for k in datasets[0].variables:
if k in concat_over:
try:
vars = ensure_common_dims([ds.variables[k] for ds in datasets])
vars = ensure_common_dims([ds[k].variable for ds in datasets])
except KeyError:
raise ValueError("%r is not present in all datasets." % k)
combined = concat_vars(vars, dim, positions)
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1325,8 +1325,8 @@ def broadcast_like(
[ 2.2408932 , 1.86755799, -0.97727788],
[ nan, nan, nan]])
Coordinates:
* x (x) object 'a' 'b' 'c'
* y (y) object 'a' 'b' 'c'
* x (x) <U1 'a' 'b' 'c'
* y (y) <U1 'a' 'b' 'c'
"""
if exclude is None:
exclude = set()
Expand Down
6 changes: 3 additions & 3 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2565,7 +2565,7 @@ def reindex(
<xarray.Dataset>
Dimensions: (station: 4)
Coordinates:
* station (station) object 'boston' 'austin' 'seattle' 'lincoln'
* station (station) <U7 'boston' 'austin' 'seattle' 'lincoln'
Data variables:
temperature (station) float64 10.98 nan 12.06 nan
pressure (station) float64 211.8 nan 218.8 nan
Expand All @@ -2576,7 +2576,7 @@ def reindex(
<xarray.Dataset>
Dimensions: (station: 4)
Coordinates:
* station (station) object 'boston' 'austin' 'seattle' 'lincoln'
* station (station) <U7 'boston' 'austin' 'seattle' 'lincoln'
Data variables:
temperature (station) float64 10.98 0.0 12.06 0.0
pressure (station) float64 211.8 0.0 218.8 0.0
Expand All @@ -2589,7 +2589,7 @@ def reindex(
<xarray.Dataset>
Dimensions: (station: 4)
Coordinates:
* station (station) object 'boston' 'austin' 'seattle' 'lincoln'
* station (station) <U7 'boston' 'austin' 'seattle' 'lincoln'
Data variables:
temperature (station) float64 10.98 0.0 12.06 0.0
pressure (station) float64 211.8 100.0 218.8 100.0
Expand Down
4 changes: 3 additions & 1 deletion xarray/core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,9 +930,11 @@ def dataset_update_method(
if coord_names:
other[key] = value.drop_vars(coord_names)

# use ds.coords and not ds.indexes, else str coords are cast to object
indexes = {key: dataset.coords[key] for key in dataset.indexes.keys()}
return merge_core(
[dataset, other],
priority_arg=1,
indexes=dataset.indexes,
indexes=indexes,
combine_attrs="override",
)
19 changes: 19 additions & 0 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import numpy as np
import pandas as pd

from . import dtypes

K = TypeVar("K")
V = TypeVar("V")
T = TypeVar("T")
Expand Down Expand Up @@ -76,6 +78,23 @@ def maybe_cast_to_coords_dtype(label, coords_dtype):
return label


def maybe_coerce_to_str(index, original_coords):
"""maybe coerce a pandas Index back to a nunpy array of type str
pd.Index uses object-dtype to store str - try to avoid this for coords
"""

try:
result_type = dtypes.result_type(*original_coords)
except TypeError:
pass
else:
if result_type.kind in "SU":
index = np.asarray(index, dtype=result_type.type)

return index


def safe_cast_to_index(array: Any) -> pd.Index:
"""Given an array, safely cast it to a pandas.Index.
Expand Down
4 changes: 4 additions & 0 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
ensure_us_time_resolution,
infix_dims,
is_duck_array,
maybe_coerce_to_str,
)

NON_NUMPY_SUPPORTED_ARRAY_TYPES = (
Expand Down Expand Up @@ -2523,6 +2524,9 @@ def concat(cls, variables, dim="concat_dim", positions=None, shortcut=False):
indices = nputils.inverse_permutation(np.concatenate(positions))
data = data.take(indices)

# keep as str if possible as pandas.Index uses object (converts to numpy array)
data = maybe_coerce_to_str(data, variables)

attrs = dict(first_var.attrs)
if not shortcut:
for var in variables:
Expand Down
44 changes: 44 additions & 0 deletions xarray/tests/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,30 @@ def test_concat_fill_value(self, fill_value):
actual = concat(datasets, dim="t", fill_value=fill_value)
assert_identical(actual, expected)

@pytest.mark.parametrize("dtype", [str, bytes])
@pytest.mark.parametrize("dim", ["x1", "x2"])
def test_concat_str_dtype(self, dtype, dim):

data = np.arange(4).reshape([2, 2])

da1 = Dataset(
{
"data": (["x1", "x2"], data),
"x1": [0, 1],
"x2": np.array(["a", "b"], dtype=dtype),
}
)
da2 = Dataset(
{
"data": (["x1", "x2"], data),
"x1": np.array([1, 2]),
"x2": np.array(["c", "d"], dtype=dtype),
}
)
actual = concat([da1, da2], dim=dim)

assert np.issubdtype(actual.x2.dtype, dtype)


class TestConcatDataArray:
def test_concat(self):
Expand Down Expand Up @@ -525,6 +549,26 @@ def test_concat_combine_attrs_kwarg(self):
actual = concat([da1, da2], dim="x", combine_attrs=combine_attrs)
assert_identical(actual, expected[combine_attrs])

@pytest.mark.parametrize("dtype", [str, bytes])
@pytest.mark.parametrize("dim", ["x1", "x2"])
def test_concat_str_dtype(self, dtype, dim):

data = np.arange(4).reshape([2, 2])

da1 = DataArray(
data=data,
dims=["x1", "x2"],
coords={"x1": [0, 1], "x2": np.array(["a", "b"], dtype=dtype)},
)
da2 = DataArray(
data=data,
dims=["x1", "x2"],
coords={"x1": np.array([1, 2]), "x2": np.array(["c", "d"], dtype=dtype)},
)
actual = concat([da1, da2], dim=dim)

assert np.issubdtype(actual.x2.dtype, dtype)


@pytest.mark.parametrize("attr1", ({"a": {"meta": [10, 20, 30]}}, {"a": [1, 2, 3]}, {}))
@pytest.mark.parametrize("attr2", ({"a": [1, 2, 3]}, {}))
Expand Down
33 changes: 33 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1568,6 +1568,19 @@ def test_reindex_fill_value(self, fill_value):
)
assert_identical(expected, actual)

@pytest.mark.parametrize("dtype", [str, bytes])
def test_reindex_str_dtype(self, dtype):

data = DataArray(
[1, 2], dims="x", coords={"x": np.array(["a", "b"], dtype=dtype)}
)

actual = data.reindex(x=data.x)
expected = data

assert_identical(expected, actual)
assert actual.dtype == expected.dtype

def test_rename(self):
renamed = self.dv.rename("bar")
assert_identical(renamed.to_dataset(), self.ds.rename({"foo": "bar"}))
Expand Down Expand Up @@ -3435,6 +3448,26 @@ def test_align_without_indexes_errors(self):
DataArray([1, 2], coords=[("x", [0, 1])]),
)

def test_align_str_dtype(self):

a = DataArray([0, 1], dims=["x"], coords={"x": ["a", "b"]})
b = DataArray([1, 2], dims=["x"], coords={"x": ["b", "c"]})

expected_a = DataArray(
[0, 1, np.NaN], dims=["x"], coords={"x": ["a", "b", "c"]}
)
expected_b = DataArray(
[np.NaN, 1, 2], dims=["x"], coords={"x": ["a", "b", "c"]}
)

actual_a, actual_b = xr.align(a, b, join="outer")

assert_identical(expected_a, actual_a)
assert expected_a.x.dtype == actual_a.x.dtype

assert_identical(expected_b, actual_b)
assert expected_b.x.dtype == actual_b.x.dtype

def test_broadcast_arrays(self):
x = DataArray([1, 2], coords=[("a", [-1, -2])], name="x")
y = DataArray([1, 2], coords=[("b", [3, 4])], name="y")
Expand Down
34 changes: 34 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1949,6 +1949,16 @@ def test_reindex_like_fill_value(self, fill_value):
)
assert_identical(expected, actual)

@pytest.mark.parametrize("dtype", [str, bytes])
def test_reindex_str_dtype(self, dtype):
data = Dataset({"data": ("x", [1, 2]), "x": np.array(["a", "b"], dtype=dtype)})

actual = data.reindex(x=data.x)
expected = data

assert_identical(expected, actual)
assert actual.x.dtype == expected.x.dtype

@pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0, {"foo": 2, "bar": 1}])
def test_align_fill_value(self, fill_value):
x = Dataset({"foo": DataArray([1, 2], dims=["x"], coords={"x": [1, 2]})})
Expand Down Expand Up @@ -2134,6 +2144,22 @@ def test_align_non_unique(self):
with raises_regex(ValueError, "cannot reindex or align"):
align(x, y)

def test_align_str_dtype(self):

a = Dataset({"foo": ("x", [0, 1]), "x": ["a", "b"]})
b = Dataset({"foo": ("x", [1, 2]), "x": ["b", "c"]})

expected_a = Dataset({"foo": ("x", [0, 1, np.NaN]), "x": ["a", "b", "c"]})
expected_b = Dataset({"foo": ("x", [np.NaN, 1, 2]), "x": ["a", "b", "c"]})

actual_a, actual_b = xr.align(a, b, join="outer")

assert_identical(expected_a, actual_a)
assert expected_a.x.dtype == actual_a.x.dtype

assert_identical(expected_b, actual_b)
assert expected_b.x.dtype == actual_b.x.dtype

def test_broadcast(self):
ds = Dataset(
{"foo": 0, "bar": ("x", [1]), "baz": ("y", [2, 3])}, {"c": ("x", [4])}
Expand Down Expand Up @@ -3420,6 +3446,14 @@ def test_setitem_align_new_indexes(self):
)
assert_identical(ds, expected)

@pytest.mark.parametrize("dtype", [str, bytes])
def test_setitem_str_dtype(self, dtype):

ds = xr.Dataset(coords={"x": np.array(["x", "y"], dtype=dtype)})
ds["foo"] = xr.DataArray(np.array([0, 0]), dims=["x"])

assert np.issubdtype(ds.x.dtype, dtype)

def test_assign(self):
ds = Dataset()
actual = ds.assign(x=[0, 1, 2], y=2)
Expand Down
27 changes: 27 additions & 0 deletions xarray/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,33 @@ def test_safe_cast_to_index():
assert expected.dtype == actual.dtype


@pytest.mark.parametrize(
"a, b, expected", [["a", "b", np.array(["a", "b"])], [1, 2, pd.Index([1, 2])]]
)
def test_maybe_coerce_to_str(a, b, expected):

a = np.array([a])
b = np.array([b])
index = pd.Index(a).append(pd.Index(b))

actual = utils.maybe_coerce_to_str(index, [a, b])

assert_array_equal(expected, actual)
assert expected.dtype == actual.dtype


def test_maybe_coerce_to_str_minimal_str_dtype():

a = np.array(["a", "a_long_string"])
index = pd.Index(["a"])

actual = utils.maybe_coerce_to_str(index, [a])
expected = np.array("a")

assert_array_equal(expected, actual)
assert expected.dtype == actual.dtype


@requires_cftime
def test_safe_cast_to_index_cftimeindex():
date_types = _all_cftime_date_types()
Expand Down
11 changes: 11 additions & 0 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2094,6 +2094,17 @@ def test_concat_multiindex(self):
assert_identical(actual, expected)
assert isinstance(actual.to_index(), pd.MultiIndex)

@pytest.mark.parametrize("dtype", [str, bytes])
def test_concat_str_dtype(self, dtype):

a = IndexVariable("x", np.array(["a"], dtype=dtype))
b = IndexVariable("x", np.array(["b"], dtype=dtype))
expected = IndexVariable("x", np.array(["a", "b"], dtype=dtype))

actual = IndexVariable.concat([a, b])
assert actual.identical(expected)
assert np.issubdtype(actual.dtype, dtype)

def test_coordinate_alias(self):
with pytest.warns(Warning, match="deprecated"):
x = Coordinate("x", [1, 2, 3])
Expand Down

0 comments on commit fb67358

Please sign in to comment.