Skip to content

Commit

Permalink
BUG: Support rasterio complex_int16 (GDAL CInt16) dtype (#353)
Browse files Browse the repository at this point in the history
  • Loading branch information
scottyhq authored Jun 24, 2021
1 parent 6204661 commit 18af305
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 13 deletions.
1 change: 1 addition & 0 deletions docs/history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ History

Latest
------
- BUG: support GDAL CInt16, rasterio complex_int16 (pull #353)

0.4.2
------
Expand Down
17 changes: 15 additions & 2 deletions rioxarray/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ def __init__(
if not np.all(np.asarray(dtypes) == dtypes[0]):
raise ValueError("All bands should have the same dtype")

dtype = np.dtype(dtypes[0])
dtype = _rasterio_to_numpy_dtype(dtypes)

# handle unsigned case
if mask_and_scale and unsigned and dtype.kind == "i":
self._dtype = np.dtype("u%s" % dtype.itemsize)
Expand Down Expand Up @@ -334,6 +335,17 @@ def default(value):
return parsed_meta


def _rasterio_to_numpy_dtype(dtypes):
"""Numpy dtype from first entry of rasterio dataset.dtypes"""
# rasterio has some special dtype names (complex_int16 -> np.complex64)
if dtypes[0] == "complex_int16":
dtype = np.dtype("complex64")
else:
dtype = np.dtype(dtypes[0])

return dtype


def _to_numeric(value):
"""
Convert the value to a number
Expand Down Expand Up @@ -882,8 +894,9 @@ def open_rasterio(
encoding = {}
if mask_and_scale and "_Unsigned" in attrs:
unsigned = variables.pop_to(attrs, encoding, "_Unsigned") == "true"

if masked:
encoding["dtype"] = str(riods.dtypes[0])
encoding["dtype"] = _rasterio_to_numpy_dtype(riods.dtypes)

da_name = attrs.pop("NETCDF_VARNAME", default_name)
data = indexing.LazilyOuterIndexedArray(
Expand Down
19 changes: 12 additions & 7 deletions rioxarray/raster_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,18 @@ def _ensure_nodata_dtype(original_nodata, new_dtype):
Convert the nodata to the new datatype and raise warning
if the value of the nodata value changed.
"""
original_nodata = float(original_nodata)
nodata = np.dtype(new_dtype).type(original_nodata)
if not np.isnan(nodata) and original_nodata != nodata:
warnings.warn(
f"The nodata value ({original_nodata}) has been automatically "
f"changed to ({nodata}) to match the dtype of the data."
)
# Complex-valued rasters can have real-valued nodata
if str(new_dtype).startswith("c"):
nodata = original_nodata
else:
original_nodata = float(original_nodata)
nodata = np.dtype(new_dtype).type(original_nodata)
if not np.isnan(nodata) and original_nodata != nodata:
warnings.warn(
f"The nodata value ({original_nodata}) has been automatically "
f"changed to ({nodata}) to match the dtype of the data."
)

return nodata


Expand Down
15 changes: 11 additions & 4 deletions rioxarray/raster_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,18 @@ def to_raster(self, xarray_dataarray, tags, windowed, lock, compute, **kwargs):
Call ".compute()" on the Delayed object to compute the result
later. Call ``dask.compute(delayed1, delayed2)`` to save
multiple delayed files at once.
dtype: np.dtype
Numpy-compliant dtype used to save raster. If data is not already
represented by this dtype in memory it is recast. dtype='complex_int16'
is a special case to write in-memory np.complex64 to CInt16.
**kwargs
Keyword arguments to pass into writing the raster.
"""
dtype = kwargs["dtype"]
# generate initial output file
if str(kwargs["dtype"]) == "complex_int16":
numpy_dtype = "complex64"
else:
numpy_dtype = kwargs["dtype"]

with rasterio.open(self.raster_path, "w", **kwargs) as rds:
_write_metatata_to_raster(rds, xarray_dataarray, tags)
if not (lock and is_dask_collection(xarray_dataarray.data)):
Expand All @@ -170,15 +177,15 @@ def to_raster(self, xarray_dataarray, tags, windowed, lock, compute, **kwargs):
out_data = xarray_dataarray.rio.isel_window(window)
else:
out_data = xarray_dataarray
data = encode_cf_variable(out_data).values.astype(dtype)
data = encode_cf_variable(out_data).values.astype(numpy_dtype)
if data.ndim == 2:
rds.write(data, 1, window=window)
else:
rds.write(data, window=window)

if lock and is_dask_collection(xarray_dataarray.data):
return dask.array.store(
encode_cf_variable(xarray_dataarray).data.astype(dtype),
encode_cf_variable(xarray_dataarray).data.astype(numpy_dtype),
self,
lock=lock,
compute=compute,
Expand Down
84 changes: 84 additions & 0 deletions test/integration/test_integration__io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,3 +1107,87 @@ def test_non_rectilinear__skip_parse_coordinates(open_rasterio):
assert xds.rio.shape == (10, 10)
with rasterio.open(test_file) as rds:
assert rds.transform == xds.rio.transform()


@pytest.mark.xfail(
rasterio.__version__ < "1.2.4",
reason="https://github.com/mapbox/rasterio/issues/2182",
)
def test_cint16_dtype(tmp_path):
test_file = os.path.join(TEST_INPUT_DATA_DIR, "cint16.tif")
xds = rioxarray.open_rasterio(test_file)
assert xds.rio.shape == (100, 100)
assert xds.dtype == "complex64"

tmp_output = tmp_path / "tmp_cint16.tif"
with pytest.warns(NotGeoreferencedWarning):
xds.rio.to_raster(str(tmp_output), dtype="complex_int16")
with rasterio.open(str(tmp_output)) as riofh:
data = riofh.read()
assert "complex_int16" in riofh.dtypes
assert data.dtype == "complex64"


@pytest.mark.xfail(
rasterio.__version__ < "1.2.5",
reason="https://github.com/mapbox/rasterio/issues/2206",
)
def test_cint16_dtype_nodata(tmp_path):
test_file = os.path.join(TEST_INPUT_DATA_DIR, "cint16.tif")
xds = rioxarray.open_rasterio(test_file)
assert xds.rio.nodata == 0

tmp_output = tmp_path / "tmp_cint16.tif"
with pytest.warns(NotGeoreferencedWarning):
xds.rio.to_raster(str(tmp_output), dtype="complex_int16")
with rasterio.open(str(tmp_output)) as riofh:
assert riofh.nodata == 0

# Assign nodata=None
tmp_output = tmp_path / "tmp_cint16_nodata.tif"
xds.rio.write_nodata(None, inplace=True)
with pytest.warns(NotGeoreferencedWarning):
xds.rio.to_raster(str(tmp_output), dtype="complex_int16")
with rasterio.open(str(tmp_output)) as riofh:
assert riofh.nodata is None


def test_cint16_dtype_masked(tmp_path):
test_file = os.path.join(TEST_INPUT_DATA_DIR, "cint16.tif")
xds = rioxarray.open_rasterio(test_file, masked=True)
assert xds.rio.shape == (100, 100)
assert xds.dtype == "complex64"
assert xds.rio.encoded_nodata == 0
assert np.isnan(xds.rio.nodata)

tmp_output = tmp_path / "tmp_cint16.tif"
with pytest.warns(NotGeoreferencedWarning):
xds.rio.to_raster(str(tmp_output), dtype="complex_int16")
with rasterio.open(str(tmp_output)) as riofh:
data = riofh.read()
assert "complex_int16" in riofh.dtypes
assert riofh.nodata == 0
assert data.dtype == "complex64"


def test_cint16_promote_dtype(tmp_path):
test_file = os.path.join(TEST_INPUT_DATA_DIR, "cint16.tif")
xds = rioxarray.open_rasterio(test_file)

tmp_output = tmp_path / "tmp_cfloat64.tif"
with pytest.warns(NotGeoreferencedWarning):
xds.rio.to_raster(str(tmp_output))
with rasterio.open(str(tmp_output)) as riofh:
data = riofh.read()
assert "complex64" in riofh.dtypes
assert riofh.nodata == 0
assert data.dtype == "complex64"

tmp_output = tmp_path / "tmp_cfloat128.tif"
with pytest.warns(NotGeoreferencedWarning):
xds.rio.to_raster(str(tmp_output), dtype="complex128")
with rasterio.open(str(tmp_output)) as riofh:
data = riofh.read()
assert "complex128" in riofh.dtypes
assert riofh.nodata == 0
assert data.dtype == "complex128"
Binary file added test/test_data/input/cint16.tif
Binary file not shown.

0 comments on commit 18af305

Please sign in to comment.