Skip to content

Commit

Permalink
CF encoding should preserve vlen dtype for empty arrays (#7862)
Browse files Browse the repository at this point in the history
* CF encoding should preserve vlen dtype for empty arrays

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* preserve vlen string dtype in netcdf4 and zarr backends

* check for h5py-variant ("vlen") in coding.strings.check_vlen_dtype

* add test to check preserving vlen dtype for empty vlen string arrays

* ignore call_overload error for np.dtype("O", metadata={"vlen": str})

* use filter.codec_id instead of private filter._meta as suggested in review

* update comment and add whats-new.rst entry

* fix whats-new.rst

* fix whats-new.rst (missing dot)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kai Mühlbauer <kai.muehlbauer@uni-bonn.de>
Co-authored-by: Kai Mühlbauer <kmuehlbauer@wradlib.org>
Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
  • Loading branch information
5 people authored Jun 16, 2023
1 parent 99f9559 commit 0c876e4
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 6 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ Bug fixes
By `Mattia Almansi <https://github.com/malmans2>`_.
- Don't call ``CachingFileManager.__del__`` on interpreter shutdown (:issue:`7814`, :pull:`7880`).
By `Justus Magin <https://github.com/keewis>`_.
- Preserve vlen dtype for empty string arrays (:issue:`7328`, :pull:`7862`).
By `Tom White <https://github.com/tomwhite>`_ and `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
- Ensure dtype of reindex result matches dtype of the original DataArray (:issue:`7299`, :pull:`7917`)
By `Anderson Banihirwe <https://github.com/andersy005>`_.

Expand Down
10 changes: 6 additions & 4 deletions xarray/backends/netCDF4_.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,12 @@ def __init__(self, variable_name, datastore):

dtype = array.dtype
if dtype is str:
# use object dtype because that's the only way in numpy to
# represent variable length strings; it also prevents automatic
# string concatenation via conventions.decode_cf_variable
dtype = np.dtype("O")
# use object dtype (with additional vlen string metadata) because that's
# the only way in numpy to represent variable length strings and to
# check vlen string dtype in further steps
# it also prevents automatic string concatenation via
# conventions.decode_cf_variable
dtype = coding.strings.create_vlen_dtype(str)
self.dtype = dtype

def __setitem__(self, key, value):
Expand Down
9 changes: 8 additions & 1 deletion xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,14 @@ def __init__(self, variable_name, datastore):
array = self.get_array()
self.shape = array.shape

dtype = array.dtype
# preserve vlen string object dtype (GH 7328)
if array.filters is not None and any(
[filt.codec_id == "vlen-utf8" for filt in array.filters]
):
dtype = coding.strings.create_vlen_dtype(str)
else:
dtype = array.dtype

self.dtype = dtype

def get_array(self):
Expand Down
3 changes: 2 additions & 1 deletion xarray/coding/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ def check_vlen_dtype(dtype):
if dtype.kind != "O" or dtype.metadata is None:
return None
else:
return dtype.metadata.get("element_type")
# check xarray (element_type) as well as h5py (vlen)
return dtype.metadata.get("element_type", dtype.metadata.get("vlen"))


def is_unicode_dtype(dtype):
Expand Down
4 changes: 4 additions & 0 deletions xarray/conventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ def ensure_dtype_not_object(var: Variable, name: T_Name = None) -> Variable:
if var.dtype.kind == "O":
dims, data, attrs, encoding = _var_as_tuple(var)

# leave vlen dtypes unchanged
if strings.check_vlen_dtype(data.dtype) is not None:
return var

if is_duck_dask_array(data):
warnings.warn(
"variable {} has data in the form of a dask array with "
Expand Down
15 changes: 15 additions & 0 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
)
from xarray.backends.pydap_ import PydapDataStore
from xarray.backends.scipy_ import ScipyBackendEntrypoint
from xarray.coding.strings import check_vlen_dtype, create_vlen_dtype
from xarray.coding.variables import SerializationWarning
from xarray.conventions import encode_dataset_coordinates
from xarray.core import indexing
Expand Down Expand Up @@ -859,6 +860,20 @@ def test_roundtrip_string_with_fill_value_nchar(self) -> None:
with self.roundtrip(original) as actual:
assert_identical(expected, actual)

def test_roundtrip_empty_vlen_string_array(self) -> None:
# checks preserving vlen dtype for empty arrays GH7862
dtype = create_vlen_dtype(str)
original = Dataset({"a": np.array([], dtype=dtype)})
assert check_vlen_dtype(original["a"].dtype) == str
with self.roundtrip(original) as actual:
assert_identical(original, actual)
assert object == actual["a"].dtype
assert actual["a"].dtype == original["a"].dtype
# only check metadata for capable backends
# eg. NETCDF3 based backends do not roundtrip metadata
if actual["a"].dtype.metadata is not None:
assert check_vlen_dtype(actual["a"].dtype) == str

@pytest.mark.parametrize(
"decoded_fn, encoded_fn",
[
Expand Down
4 changes: 4 additions & 0 deletions xarray/tests/test_coding_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ def test_vlen_dtype() -> None:
assert strings.is_bytes_dtype(dtype)
assert strings.check_vlen_dtype(dtype) is bytes

# check h5py variant ("vlen")
dtype = np.dtype("O", metadata={"vlen": str}) # type: ignore[call-overload]
assert strings.check_vlen_dtype(dtype) is str

assert strings.check_vlen_dtype(np.dtype(object)) is None


Expand Down
15 changes: 15 additions & 0 deletions xarray/tests/test_conventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,3 +487,18 @@ def test_decode_cf_error_includes_variable_name():
ds = Dataset({"invalid": ([], 1e36, {"units": "days since 2000-01-01"})})
with pytest.raises(ValueError, match="Failed to decode variable 'invalid'"):
decode_cf(ds)


def test_encode_cf_variable_with_vlen_dtype() -> None:
v = Variable(
["x"], np.array(["a", "b"], dtype=coding.strings.create_vlen_dtype(str))
)
encoded_v = conventions.encode_cf_variable(v)
assert encoded_v.data.dtype.kind == "O"
assert coding.strings.check_vlen_dtype(encoded_v.data.dtype) == str

# empty array
v = Variable(["x"], np.array([], dtype=coding.strings.create_vlen_dtype(str)))
encoded_v = conventions.encode_cf_variable(v)
assert encoded_v.data.dtype.kind == "O"
assert coding.strings.check_vlen_dtype(encoded_v.data.dtype) == str

0 comments on commit 0c876e4

Please sign in to comment.