diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 420b6c55d56..59f910911f1 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -50,6 +50,8 @@ Bug fixes By `Mattia Almansi `_. - Don't call ``CachingFileManager.__del__`` on interpreter shutdown (:issue:`7814`, :pull:`7880`). By `Justus Magin `_. +- Preserve vlen dtype for empty string arrays (:issue:`7328`, :pull:`7862`). + By `Tom White `_ and `Kai Mühlbauer `_. - Ensure dtype of reindex result matches dtype of the original DataArray (:issue:`7299`, :pull:`7917`) By `Anderson Banihirwe `_. diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index d3866e90de6..8a5d48c8c1e 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -65,10 +65,12 @@ def __init__(self, variable_name, datastore): dtype = array.dtype if dtype is str: - # use object dtype because that's the only way in numpy to - # represent variable length strings; it also prevents automatic - # string concatenation via conventions.decode_cf_variable - dtype = np.dtype("O") + # use object dtype (with additional vlen string metadata) because that's + # the only way in numpy to represent variable length strings and to + # check vlen string dtype in further steps + # it also prevents automatic string concatenation via + # conventions.decode_cf_variable + dtype = coding.strings.create_vlen_dtype(str) self.dtype = dtype def __setitem__(self, key, value): diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index a4012a8a733..5c3d5781e35 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -70,7 +70,14 @@ def __init__(self, variable_name, datastore): array = self.get_array() self.shape = array.shape - dtype = array.dtype + # preserve vlen string object dtype (GH 7328) + if array.filters is not None and any( + [filt.codec_id == "vlen-utf8" for filt in array.filters] + ): + dtype = coding.strings.create_vlen_dtype(str) + else: + dtype = array.dtype + self.dtype = dtype def get_array(self): diff --git a/xarray/coding/strings.py b/xarray/coding/strings.py index ffe1b1a8d50..d0bfb1a7a63 100644 --- a/xarray/coding/strings.py +++ b/xarray/coding/strings.py @@ -29,7 +29,8 @@ def check_vlen_dtype(dtype): if dtype.kind != "O" or dtype.metadata is None: return None else: - return dtype.metadata.get("element_type") + # check xarray (element_type) as well as h5py (vlen) + return dtype.metadata.get("element_type", dtype.metadata.get("vlen")) def is_unicode_dtype(dtype): diff --git a/xarray/conventions.py b/xarray/conventions.py index 1506efc31e8..053863ace2a 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -108,6 +108,10 @@ def ensure_dtype_not_object(var: Variable, name: T_Name = None) -> Variable: if var.dtype.kind == "O": dims, data, attrs, encoding = _var_as_tuple(var) + # leave vlen dtypes unchanged + if strings.check_vlen_dtype(data.dtype) is not None: + return var + if is_duck_dask_array(data): warnings.warn( "variable {} has data in the form of a dask array with " diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 190adab3d19..e0a2262a339 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -46,6 +46,7 @@ ) from xarray.backends.pydap_ import PydapDataStore from xarray.backends.scipy_ import ScipyBackendEntrypoint +from xarray.coding.strings import check_vlen_dtype, create_vlen_dtype from xarray.coding.variables import SerializationWarning from xarray.conventions import encode_dataset_coordinates from xarray.core import indexing @@ -859,6 +860,20 @@ def test_roundtrip_string_with_fill_value_nchar(self) -> None: with self.roundtrip(original) as actual: assert_identical(expected, actual) + def test_roundtrip_empty_vlen_string_array(self) -> None: + # checks preserving vlen dtype for empty arrays GH7862 + dtype = create_vlen_dtype(str) + original = Dataset({"a": np.array([], dtype=dtype)}) + assert check_vlen_dtype(original["a"].dtype) == str + with self.roundtrip(original) as actual: + assert_identical(original, actual) + assert object == actual["a"].dtype + assert actual["a"].dtype == original["a"].dtype + # only check metadata for capable backends + # eg. NETCDF3 based backends do not roundtrip metadata + if actual["a"].dtype.metadata is not None: + assert check_vlen_dtype(actual["a"].dtype) == str + @pytest.mark.parametrize( "decoded_fn, encoded_fn", [ diff --git a/xarray/tests/test_coding_strings.py b/xarray/tests/test_coding_strings.py index cb9595f4a64..0c9f67e77ad 100644 --- a/xarray/tests/test_coding_strings.py +++ b/xarray/tests/test_coding_strings.py @@ -32,6 +32,10 @@ def test_vlen_dtype() -> None: assert strings.is_bytes_dtype(dtype) assert strings.check_vlen_dtype(dtype) is bytes + # check h5py variant ("vlen") + dtype = np.dtype("O", metadata={"vlen": str}) # type: ignore[call-overload] + assert strings.check_vlen_dtype(dtype) is str + assert strings.check_vlen_dtype(np.dtype(object)) is None diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py index acdf9c8846e..424b7db5ac4 100644 --- a/xarray/tests/test_conventions.py +++ b/xarray/tests/test_conventions.py @@ -487,3 +487,18 @@ def test_decode_cf_error_includes_variable_name(): ds = Dataset({"invalid": ([], 1e36, {"units": "days since 2000-01-01"})}) with pytest.raises(ValueError, match="Failed to decode variable 'invalid'"): decode_cf(ds) + + +def test_encode_cf_variable_with_vlen_dtype() -> None: + v = Variable( + ["x"], np.array(["a", "b"], dtype=coding.strings.create_vlen_dtype(str)) + ) + encoded_v = conventions.encode_cf_variable(v) + assert encoded_v.data.dtype.kind == "O" + assert coding.strings.check_vlen_dtype(encoded_v.data.dtype) == str + + # empty array + v = Variable(["x"], np.array([], dtype=coding.strings.create_vlen_dtype(str))) + encoded_v = conventions.encode_cf_variable(v) + assert encoded_v.data.dtype.kind == "O" + assert coding.strings.check_vlen_dtype(encoded_v.data.dtype) == str