Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

coords: retain str dtype #4759

Merged
merged 11 commits into from
Jan 13, 2021
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ Bug fixes
By `Anderson Banihirwe <https://github.com/andersy005>`_
- Fix a crash in orthogonal indexing on geographic coordinates with ``engine='cfgrib'`` (:issue:`4733` :pull:`4737`).
By `Alessandro Amici <https://github.com/alexamici>`_
- Coordinates with dtype ``str`` or ``bytes`` now retain their dtype on many operations,
e.g. ``reindex``, ``align``, ``concat``, ``assign``, previously they were cast to an object dtype
(:issue:`2658` and :issue:`4543`) by `Mathias Hauser <https://github.com/mathause>`_.
- Limit number of data rows when printing large datasets. (:issue:`4736`, :pull:`4750`). By `Jimmy Westling <https://github.com/illviljan>`_.
- Add ``missing_dims`` parameter to transpose (:issue:`4647`, :pull:`4767`). By `Daniel Mesejo <https://github.com/mesejo>`_.
- Resolve intervals before appending other metadata to labels when plotting (:issue:`4322`, :pull:`4794`).
Expand Down
12 changes: 8 additions & 4 deletions xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from . import dtypes, utils
from .indexing import get_indexer_nd
from .utils import is_dict_like, is_full_slice
from .utils import is_dict_like, is_full_slice, maybe_coerce_to_str
from .variable import IndexVariable, Variable

if TYPE_CHECKING:
Expand Down Expand Up @@ -278,10 +278,12 @@ def align(
return (obj.copy(deep=copy),)

all_indexes = defaultdict(list)
all_coords = defaultdict(list)
unlabeled_dim_sizes = defaultdict(set)
for obj in objects:
for dim in obj.dims:
if dim not in exclude:
all_coords[dim].append(obj.coords[dim])
try:
index = obj.indexes[dim]
except KeyError:
Expand All @@ -306,7 +308,7 @@ def align(
any(not index.equals(other) for other in matching_indexes)
or dim in unlabeled_dim_sizes
):
joined_indexes[dim] = index
joined_indexes[dim] = indexes[dim]
else:
if (
any(
Expand All @@ -318,9 +320,11 @@ def align(
if join == "exact":
raise ValueError(f"indexes along dimension {dim!r} are not equal")
index = joiner(matching_indexes)
# make sure str coords are not cast to object
index = maybe_coerce_to_str(index, all_coords[dim])
joined_indexes[dim] = index
else:
index = matching_indexes[0]
index = all_coords[dim][0]

if dim in unlabeled_dim_sizes:
unlabeled_sizes = unlabeled_dim_sizes[dim]
Expand Down Expand Up @@ -583,7 +587,7 @@ def reindex_variables(
args: tuple = (var.attrs, var.encoding)
else:
args = ()
reindexed[dim] = IndexVariable((dim,), target, *args)
reindexed[dim] = IndexVariable((dim,), indexers[dim], *args)

for dim in sizes:
if dim not in indexes and dim in indexers:
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def concat(
array([[0, 1, 2],
[3, 4, 5]])
Coordinates:
* x (x) object 'a' 'b'
* x (x) <U1 'a' 'b'
* y (y) int64 10 20 30

>>> xr.concat([da.isel(x=0), da.isel(x=1)], "new_dim")
Expand Down Expand Up @@ -503,7 +503,7 @@ def ensure_common_dims(vars):
for k in datasets[0].variables:
if k in concat_over:
try:
vars = ensure_common_dims([ds.variables[k] for ds in datasets])
vars = ensure_common_dims([ds[k].variable for ds in datasets])
except KeyError:
raise ValueError("%r is not present in all datasets." % k)
combined = concat_vars(vars, dim, positions)
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1325,8 +1325,8 @@ def broadcast_like(
[ 2.2408932 , 1.86755799, -0.97727788],
[ nan, nan, nan]])
Coordinates:
* x (x) object 'a' 'b' 'c'
* y (y) object 'a' 'b' 'c'
* x (x) <U1 'a' 'b' 'c'
* y (y) <U1 'a' 'b' 'c'
"""
if exclude is None:
exclude = set()
Expand Down
6 changes: 3 additions & 3 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2565,7 +2565,7 @@ def reindex(
<xarray.Dataset>
Dimensions: (station: 4)
Coordinates:
* station (station) object 'boston' 'austin' 'seattle' 'lincoln'
* station (station) <U7 'boston' 'austin' 'seattle' 'lincoln'
Data variables:
temperature (station) float64 10.98 nan 12.06 nan
pressure (station) float64 211.8 nan 218.8 nan
Expand All @@ -2576,7 +2576,7 @@ def reindex(
<xarray.Dataset>
Dimensions: (station: 4)
Coordinates:
* station (station) object 'boston' 'austin' 'seattle' 'lincoln'
* station (station) <U7 'boston' 'austin' 'seattle' 'lincoln'
Data variables:
temperature (station) float64 10.98 0.0 12.06 0.0
pressure (station) float64 211.8 0.0 218.8 0.0
Expand All @@ -2589,7 +2589,7 @@ def reindex(
<xarray.Dataset>
Dimensions: (station: 4)
Coordinates:
* station (station) object 'boston' 'austin' 'seattle' 'lincoln'
* station (station) <U7 'boston' 'austin' 'seattle' 'lincoln'
Data variables:
temperature (station) float64 10.98 0.0 12.06 0.0
pressure (station) float64 211.8 100.0 218.8 100.0
Expand Down
4 changes: 3 additions & 1 deletion xarray/core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,9 +930,11 @@ def dataset_update_method(
if coord_names:
other[key] = value.drop_vars(coord_names)

# use ds.coords and not ds.indexes, else str coords are cast to object
indexes = {key: dataset.coords[key] for key in dataset.indexes.keys()}
return merge_core(
[dataset, other],
priority_arg=1,
indexes=dataset.indexes,
indexes=indexes,
combine_attrs="override",
)
19 changes: 19 additions & 0 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import numpy as np
import pandas as pd

from . import dtypes

K = TypeVar("K")
V = TypeVar("V")
T = TypeVar("T")
Expand Down Expand Up @@ -76,6 +78,23 @@ def maybe_cast_to_coords_dtype(label, coords_dtype):
return label


def maybe_coerce_to_str(index, original_coords):
"""maybe coerce a pandas Index back to a nunpy array of type str

pd.Index uses object-dtype to store str - try to avoid this for coords
"""

try:
result_type = dtypes.result_type(*original_coords)
except TypeError:
pass
else:
if result_type.kind in "SU":
index = np.asarray(index, dtype=result_type.type)

return index


def safe_cast_to_index(array: Any) -> pd.Index:
"""Given an array, safely cast it to a pandas.Index.

Expand Down
4 changes: 4 additions & 0 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
ensure_us_time_resolution,
infix_dims,
is_duck_array,
maybe_coerce_to_str,
)

NON_NUMPY_SUPPORTED_ARRAY_TYPES = (
Expand Down Expand Up @@ -2523,6 +2524,9 @@ def concat(cls, variables, dim="concat_dim", positions=None, shortcut=False):
indices = nputils.inverse_permutation(np.concatenate(positions))
data = data.take(indices)

# keep as str if possible as pandas.Index uses object (converts to numpy array)
data = maybe_coerce_to_str(data, variables)

attrs = dict(first_var.attrs)
if not shortcut:
for var in variables:
Expand Down
44 changes: 44 additions & 0 deletions xarray/tests/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,30 @@ def test_concat_fill_value(self, fill_value):
actual = concat(datasets, dim="t", fill_value=fill_value)
assert_identical(actual, expected)

@pytest.mark.parametrize("dtype", [str, bytes])
@pytest.mark.parametrize("dim", ["x1", "x2"])
def test_concat_str_dtype(self, dtype, dim):

data = np.arange(4).reshape([2, 2])

da1 = Dataset(
{
"data": (["x1", "x2"], data),
"x1": [0, 1],
"x2": np.array(["a", "b"], dtype=dtype),
}
)
da2 = Dataset(
{
"data": (["x1", "x2"], data),
"x1": np.array([1, 2]),
"x2": np.array(["c", "d"], dtype=dtype),
}
)
actual = concat([da1, da2], dim=dim)

assert np.issubdtype(actual.x2.dtype, dtype)


class TestConcatDataArray:
def test_concat(self):
Expand Down Expand Up @@ -525,6 +549,26 @@ def test_concat_combine_attrs_kwarg(self):
actual = concat([da1, da2], dim="x", combine_attrs=combine_attrs)
assert_identical(actual, expected[combine_attrs])

@pytest.mark.parametrize("dtype", [str, bytes])
@pytest.mark.parametrize("dim", ["x1", "x2"])
def test_concat_str_dtype(self, dtype, dim):

data = np.arange(4).reshape([2, 2])

da1 = DataArray(
data=data,
dims=["x1", "x2"],
coords={"x1": [0, 1], "x2": np.array(["a", "b"], dtype=dtype)},
)
da2 = DataArray(
data=data,
dims=["x1", "x2"],
coords={"x1": np.array([1, 2]), "x2": np.array(["c", "d"], dtype=dtype)},
)
actual = concat([da1, da2], dim=dim)

assert np.issubdtype(actual.x2.dtype, dtype)


@pytest.mark.parametrize("attr1", ({"a": {"meta": [10, 20, 30]}}, {"a": [1, 2, 3]}, {}))
@pytest.mark.parametrize("attr2", ({"a": [1, 2, 3]}, {}))
Expand Down
33 changes: 33 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1568,6 +1568,19 @@ def test_reindex_fill_value(self, fill_value):
)
assert_identical(expected, actual)

@pytest.mark.parametrize("dtype", [str, bytes])
def test_reindex_str_dtype(self, dtype):

data = DataArray(
[1, 2], dims="x", coords={"x": np.array(["a", "b"], dtype=dtype)}
)

actual = data.reindex(x=data.x)
expected = data

assert_identical(expected, actual)
assert actual.dtype == expected.dtype

def test_rename(self):
renamed = self.dv.rename("bar")
assert_identical(renamed.to_dataset(), self.ds.rename({"foo": "bar"}))
Expand Down Expand Up @@ -3435,6 +3448,26 @@ def test_align_without_indexes_errors(self):
DataArray([1, 2], coords=[("x", [0, 1])]),
)

def test_align_str_dtype(self):

a = DataArray([0, 1], dims=["x"], coords={"x": ["a", "b"]})
b = DataArray([1, 2], dims=["x"], coords={"x": ["b", "c"]})

expected_a = DataArray(
[0, 1, np.NaN], dims=["x"], coords={"x": ["a", "b", "c"]}
)
expected_b = DataArray(
[np.NaN, 1, 2], dims=["x"], coords={"x": ["a", "b", "c"]}
)

actual_a, actual_b = xr.align(a, b, join="outer")

assert_identical(expected_a, actual_a)
assert expected_a.x.dtype == actual_a.x.dtype

assert_identical(expected_b, actual_b)
assert expected_b.x.dtype == actual_b.x.dtype

def test_broadcast_arrays(self):
x = DataArray([1, 2], coords=[("a", [-1, -2])], name="x")
y = DataArray([1, 2], coords=[("b", [3, 4])], name="y")
Expand Down
34 changes: 34 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1949,6 +1949,16 @@ def test_reindex_like_fill_value(self, fill_value):
)
assert_identical(expected, actual)

@pytest.mark.parametrize("dtype", [str, bytes])
def test_reindex_str_dtype(self, dtype):
data = Dataset({"data": ("x", [1, 2]), "x": np.array(["a", "b"], dtype=dtype)})

actual = data.reindex(x=data.x)
expected = data

assert_identical(expected, actual)
assert actual.x.dtype == expected.x.dtype

@pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0, {"foo": 2, "bar": 1}])
def test_align_fill_value(self, fill_value):
x = Dataset({"foo": DataArray([1, 2], dims=["x"], coords={"x": [1, 2]})})
Expand Down Expand Up @@ -2134,6 +2144,22 @@ def test_align_non_unique(self):
with raises_regex(ValueError, "cannot reindex or align"):
align(x, y)

def test_align_str_dtype(self):

a = Dataset({"foo": ("x", [0, 1]), "x": ["a", "b"]})
b = Dataset({"foo": ("x", [1, 2]), "x": ["b", "c"]})

expected_a = Dataset({"foo": ("x", [0, 1, np.NaN]), "x": ["a", "b", "c"]})
expected_b = Dataset({"foo": ("x", [np.NaN, 1, 2]), "x": ["a", "b", "c"]})

actual_a, actual_b = xr.align(a, b, join="outer")

assert_identical(expected_a, actual_a)
assert expected_a.x.dtype == actual_a.x.dtype

assert_identical(expected_b, actual_b)
assert expected_b.x.dtype == actual_b.x.dtype

def test_broadcast(self):
ds = Dataset(
{"foo": 0, "bar": ("x", [1]), "baz": ("y", [2, 3])}, {"c": ("x", [4])}
Expand Down Expand Up @@ -3420,6 +3446,14 @@ def test_setitem_align_new_indexes(self):
)
assert_identical(ds, expected)

@pytest.mark.parametrize("dtype", [str, bytes])
def test_setitem_str_dtype(self, dtype):

ds = xr.Dataset(coords={"x": np.array(["x", "y"], dtype=dtype)})
ds["foo"] = xr.DataArray(np.array([0, 0]), dims=["x"])

assert np.issubdtype(ds.x.dtype, dtype)

def test_assign(self):
ds = Dataset()
actual = ds.assign(x=[0, 1, 2], y=2)
Expand Down
27 changes: 27 additions & 0 deletions xarray/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,33 @@ def test_safe_cast_to_index():
assert expected.dtype == actual.dtype


@pytest.mark.parametrize(
"a, b, expected", [["a", "b", np.array(["a", "b"])], [1, 2, pd.Index([1, 2])]]
)
def test_maybe_coerce_to_str(a, b, expected):

a = np.array([a])
b = np.array([b])
index = pd.Index(a).append(pd.Index(b))

actual = utils.maybe_coerce_to_str(index, [a, b])

assert_array_equal(expected, actual)
assert expected.dtype == actual.dtype


def test_maybe_coerce_to_str_minimal_str_dtype():

a = np.array(["a", "a_long_string"])
index = pd.Index(["a"])

actual = utils.maybe_coerce_to_str(index, [a])
expected = np.array("a")

assert_array_equal(expected, actual)
assert expected.dtype == actual.dtype


@requires_cftime
def test_safe_cast_to_index_cftimeindex():
date_types = _all_cftime_date_types()
Expand Down
11 changes: 11 additions & 0 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2094,6 +2094,17 @@ def test_concat_multiindex(self):
assert_identical(actual, expected)
assert isinstance(actual.to_index(), pd.MultiIndex)

@pytest.mark.parametrize("dtype", [str, bytes])
def test_concat_str_dtype(self, dtype):

a = IndexVariable("x", np.array(["a"], dtype=dtype))
b = IndexVariable("x", np.array(["b"], dtype=dtype))
expected = IndexVariable("x", np.array(["a", "b"], dtype=dtype))

actual = IndexVariable.concat([a, b])
assert actual.identical(expected)
assert np.issubdtype(actual.dtype, dtype)

def test_coordinate_alias(self):
with pytest.warns(Warning, match="deprecated"):
x = Coordinate("x", [1, 2, 3])
Expand Down