Skip to content

Commit

Permalink
move io functionality for rasterio
Browse files Browse the repository at this point in the history
  • Loading branch information
Jaapel committed May 2, 2024
1 parent 3289aa7 commit 2ba86a4
Show file tree
Hide file tree
Showing 2 changed files with 246 additions and 38 deletions.
219 changes: 209 additions & 10 deletions hydromt/drivers/rasterio_driver.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
"""Driver using rasterio for RasterDataset."""
from glob import glob
from io import IOBase
from logging import Logger, getLogger
from os.path import basename
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import dask
import rasterio
import rioxarray
import xarray as xr
from pyproj import CRS

from hydromt import io
from hydromt._typing import Geom, StrPath, TimeRange, ZoomLevel
from hydromt._typing.error import NoDataStrategy
from hydromt._utils.uris import strip_scheme
from hydromt.config import SETTINGS
from hydromt.data_adapter.caching import cache_vrt_tiles
from hydromt.drivers import RasterDatasetDriver
from hydromt.gis.merge import merge
from hydromt.gis.utils import cellres

logger: Logger = getLogger(__name__)
Expand Down Expand Up @@ -67,25 +72,19 @@ def read_data(
if isinstance(zoom_level, int) and zoom_level > 0:
# NOTE: overview levels start at zoom_level 1, see _get_zoom_levels_and_crs
kwargs.update(overview_level=zoom_level - 1)
ds = io.open_mfraster(uris, logger=logger, **kwargs)
ds = open_mfraster(uris, logger=logger, **kwargs)
return ds

def write(self, path: StrPath, ds: xr.Dataset, **kwargs) -> None:
"""Write out a RasterDataset using rasterio."""
pass

def _get_zoom_levels_and_crs(
self, fn: Optional[StrPath] = None, logger=logger
) -> Tuple[int, int]:
def _get_zoom_levels_and_crs(self, uri: StrPath, logger=logger) -> Tuple[int, int]:
"""Get zoom levels and crs from adapter or detect from tif file if missing."""
if self.zoom_levels is not None and self.crs is not None:
return self.zoom_levels, self.crs
zoom_levels = {}
crs = None
if fn is None:
fn = self.path
try:
with rasterio.open(fn) as src:
with rasterio.open(uri) as src:
res = abs(src.res[0])
crs = src.crs
overviews = [src.overviews(i) for i in src.indexes]
Expand Down Expand Up @@ -184,3 +183,203 @@ def _parse_zoom_level(
raise TypeError(f"zoom_level not understood: {type(zoom_level)}")
logger.debug(f"Using zoom level {zl} ({dst_res:.2f})")
return zl


def open_raster(
uri: Union[StrPath, IOBase, rasterio.DatasetReader, rasterio.vrt.WarpedVRT],
mask_nodata: bool = False,
chunks: Union[int, Tuple[int, ...], Dict[str, int], None] = None,
logger: Logger = logger,
**kwargs,
) -> xr.DataArray:
"""Open a gdal-readable file with rasterio based on.
:py:meth:`rioxarray.open_rasterio`, but return squeezed DataArray.
Arguments
---------
filename : str, path, file-like, rasterio.DatasetReader, or rasterio.WarpedVRT
Path to the file to open. Or already open rasterio dataset.
mask_nodata : bool, optional
set nodata values to np.nan (xarray default nodata value)
chunks : int, tuple or dict, optional
Chunk sizes along each dimension, e.g., ``5``, ``(5, 5)`` or
``{'x': 5, 'y': 5}``. If chunks is provided, it used to load the new
DataArray into a dask array.
**kwargs:
key-word arguments are passed to :py:meth:`xarray.open_dataset` with
"rasterio" engine.
logger : logger object, optional
The logger object used for logging messages. If not provided, the default
logger will be used.
Returns
-------
data : DataArray
DataArray
"""
chunks = chunks or {}
kwargs.update(masked=mask_nodata, default_name="data", chunks=chunks)
if not mask_nodata: # if mask_and_scale by default True in xarray ?
kwargs.update(mask_and_scale=False)
if isinstance(uri, IOBase): # file-like does not handle chunks
logger.warning("Removing chunks to read and load remote data.")
kwargs.pop("chunks")
# keep only 2D DataArray
da = rioxarray.open_rasterio(uri, **kwargs).squeeze(drop=True)
# set missing _FillValue
# TODO: do this is in data adapter
# if mask_nodata:
# da.raster.set_nodata(np.nan)
# elif da.raster.nodata is None:
# if nodata is not None:
# da.raster.set_nodata(nodata)
# else:
# logger.warning(f"nodata value missing for {uri}")
# there is no option for scaling but not masking ...
scale_factor = da.attrs.pop("scale_factor", 1)
add_offset = da.attrs.pop("add_offset", 0)
if not mask_nodata and (scale_factor != 1 or add_offset != 0):
raise NotImplementedError(
"scale and offset in combination with mask_nodata==False is not supported."
)
return da


def open_mfraster(
uris: Union[str, List[StrPath]],
chunks: Union[int, Tuple[int, ...], Dict[str, int], None] = None,
concat: bool = False,
concat_dim: str = "dim0",
mosaic: bool = False,
mosaic_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> xr.Dataset:
"""Open multiple gdal-readable files as single Dataset with geospatial attributes.
Each raster is turned into a DataArray with its name inferred from the filename.
By default all DataArray are assumed to be on an identical grid and the output
dataset is a merge of the rasters.
If ``concat`` the DataArrays are concatenated along ``concat_dim`` returning a
Dataset with a single 3D DataArray.
If ``mosaic`` the DataArrays are concatenated along the the spatial dimensions
using :py:meth:`~hydromt.raster.merge`.
Arguments
---------
uris: str, list of str/Path/file-like
Paths to the rasterio/gdal files.
Paths can be provided as list of paths or a path pattern string which is
interpreted according to the rules used by the Unix shell. The variable name
is derived from the basename minus extension in case a list of paths:
``<name>.<extension>`` and based on the file basename minus pre-, postfix and
extension in a path pattern: ``<prefix><*name><postfix>.<extension>``
chunks: int, tuple or dict, optional
Chunk sizes along each dimension, e.g., 5, (5, 5) or {'x': 5, 'y': 5}.
If chunks is provided, it used to load the new DataArray into a dask array.
concat: bool, optional
If True, concatenate raster along ``concat_dim``. We destinguish the following
filenames from which the numerical index and variable name are inferred, where
the variable name is based on the first raster.
``<name>_<index>.<extension>``
``<name>*<postfix>.<index>`` (PCRaster style; requires path pattern)
``<name><index>.<extension>``
``<name>.<extension>`` (index based on order)
concat_dim: str, optional
Dimension name of concatenate index, by default 'dim0'
mosaic: bool, optional
If True create mosaic of several rasters. The variable is named based on
variable name infered from the first raster.
mosaic_kwargs: dict, optional
Mosaic key_word arguments to unify raster crs and/or resolution. See
:py:meth:`hydromt.merge.merge` for options.
**kwargs:
key-word arguments are passed to :py:meth:`hydromt.raster.open_raster`
Returns
-------
data : DataSet
The newly created DataSet.
"""
chunks = chunks or {}
mosaic_kwargs = mosaic_kwargs or {}
if concat and mosaic:
raise ValueError("Only one of 'mosaic' or 'concat' can be True.")
prefix, postfix = "", ""
if isinstance(uris, str):
if "*" in uris:
prefix, postfix = basename(uris).split(".")[0].split("*")
uris = [fn for fn in glob(uris) if not fn.endswith(".xml")]
else:
uris = [str(p) if isinstance(p, Path) else p for p in uris]
if len(uris) == 0:
raise OSError("no files to open")

da_lst, index_lst, fn_attrs = [], [], []
for i, uri in enumerate(uris):
# read file
da = open_raster(uri, chunks=chunks, **kwargs)

# get name, attrs and index (if concat)
if hasattr(uri, "path"): # file-like
bname = basename(uri.path)
else:
bname = basename(uri)
if concat:
# name based on basename until postfix or _
vname = bname.split(".")[0].replace(postfix, "").split("_")[0]
# index based on postfix behind "_"
if "_" in bname and bname.split(".")[0].split("_")[1].isdigit():
index = int(bname.split(".")[0].split("_")[1])
# index based on file extension (PCRaster style)
elif "." in bname and bname.split(".")[1].isdigit():
index = int(bname.split(".")[1])
# index based on postfix directly after prefix
elif prefix != "" and bname.split(".")[0].strip(prefix).isdigit():
index = int(bname.split(".")[0].strip(prefix))
# index based on file order
else:
index = i
index_lst.append(index)
else:
# name based on basename minus pre- & postfix
vname = bname.split(".")[0].replace(prefix, "").replace(postfix, "")
da.attrs.update(source_file=bname)
fn_attrs.append(bname)
da.name = vname

if i > 0:
if not mosaic:
# check if transform, shape and crs are close
if not da_lst[0].raster.identical_grid(da):
raise xr.MergeError("Geotransform and/or shape do not match")
# copy coordinates from first raster
da[da.raster.x_dim] = da_lst[0][da.raster.x_dim]
da[da.raster.y_dim] = da_lst[0][da.raster.y_dim]
if concat or mosaic:
# copy name from first raster
da.name = da_lst[0].name
da_lst.append(da)

if concat or mosaic:
if concat:
with dask.config.set(**{"array.slicing.split_large_chunks": False}):
da = xr.concat(da_lst, dim=concat_dim)
da.coords[concat_dim] = xr.IndexVariable(concat_dim, index_lst)
da = da.sortby(concat_dim).transpose(concat_dim, ...)
da.attrs.update(da_lst[0].attrs)
else:
da = merge(da_lst, **mosaic_kwargs) # spatial merge
da.attrs.update({"source_file": "; ".join(fn_attrs)})
ds = da.to_dataset() # dataset for consistency
else:
ds = xr.merge(
da_lst
) # seems that with rioxarray drops all datarrays atrributes not just ds
ds.attrs = {}

# update spatial attributes
if da_lst[0].rio.crs is not None:
ds.rio.write_crs(da_lst[0].rio.crs, inplace=True)
ds.rio.write_transform(inplace=True)
return ds
65 changes: 37 additions & 28 deletions hydromt/io/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
from ast import literal_eval
from glob import glob
from logging import Logger
from os.path import abspath, basename, dirname, isfile, join, splitext
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
Expand All @@ -13,6 +14,7 @@
import numpy as np
import pandas as pd
import pyproj
import rasterio
import rioxarray
import xarray as xr
from pyogrio import read_dataframe
Expand All @@ -32,6 +34,7 @@

if TYPE_CHECKING:
from hydromt._validators.model_config import HydromtModelStep

logger = logging.getLogger(__name__)

__all__ = [
Expand All @@ -49,7 +52,12 @@


def open_raster(
filename, mask_nodata=False, chunks=None, nodata=None, logger=logger, **kwargs
uri: Union[StrPath, pyio.IOBase, rasterio.DatasetReader, rasterio.vrt.WarpedVRT],
mask_nodata: bool = False,
chunks: Union[int, Tuple[int, ...], Dict[str, int], None] = None,
nodata: Optional[Union[float, int]] = None,
logger: Logger = logger,
**kwargs,
) -> xr.DataArray:
"""Open a gdal-readable file with rasterio based on.
Expand Down Expand Up @@ -83,19 +91,20 @@ def open_raster(
kwargs.update(masked=mask_nodata, default_name="data", chunks=chunks)
if not mask_nodata: # if mask_and_scale by default True in xarray ?
kwargs.update(mask_and_scale=False)
if isinstance(filename, pyio.IOBase): # file-like does not handle chunks
if isinstance(uri, pyio.IOBase): # file-like does not handle chunks
logger.warning("Removing chunks to read and load remote data.")
kwargs.pop("chunks")
# keep only 2D DataArray
da = rioxarray.open_rasterio(filename, **kwargs).squeeze(drop=True)
da = rioxarray.open_rasterio(uri, **kwargs).squeeze(drop=True)
# set missing _FillValue
if mask_nodata:
da.raster.set_nodata(np.nan)
elif da.raster.nodata is None:
if nodata is not None:
da.raster.set_nodata(nodata)
else:
logger.warning(f"nodata value missing for {filename}")
# TODO: do this is in data adapter
# if mask_nodata:
# da.raster.set_nodata(np.nan)
# elif da.raster.nodata is None:
# if nodata is not None:
# da.raster.set_nodata(nodata)
# else:
# logger.warning(f"nodata value missing for {uri}")
# there is no option for scaling but not masking ...
scale_factor = da.attrs.pop("scale_factor", 1)
add_offset = da.attrs.pop("add_offset", 0)
Expand All @@ -107,12 +116,12 @@ def open_raster(


def open_mfraster(
paths,
chunks=None,
concat=False,
concat_dim="dim0",
mosaic=False,
mosaic_kwargs=None,
uris: Union[str, List[StrPath]],
chunks: Union[int, Tuple[int, ...], Dict[str, int], None] = None,
concat: bool = False,
concat_dim: str = "dim0",
mosaic: bool = False,
mosaic_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> xr.Dataset:
"""Open multiple gdal-readable files as single Dataset with geospatial attributes.
Expand All @@ -127,7 +136,7 @@ def open_mfraster(
Arguments
---------
paths: str, list of str/Path/file-like
uris: str, list of str/Path/file-like
Paths to the rasterio/gdal files.
Paths can be provided as list of paths or a path pattern string which is
interpreted according to the rules used by the Unix shell. The variable name
Expand Down Expand Up @@ -166,25 +175,25 @@ def open_mfraster(
if concat and mosaic:
raise ValueError("Only one of 'mosaic' or 'concat' can be True.")
prefix, postfix = "", ""
if isinstance(paths, str):
if "*" in paths:
prefix, postfix = basename(paths).split(".")[0].split("*")
paths = [fn for fn in glob(paths) if not fn.endswith(".xml")]
if isinstance(uris, str):
if "*" in uris:
prefix, postfix = basename(uris).split(".")[0].split("*")
uris = [fn for fn in glob(uris) if not fn.endswith(".xml")]
else:
paths = [str(p) if isinstance(p, Path) else p for p in paths]
if len(paths) == 0:
uris = [str(p) if isinstance(p, Path) else p for p in uris]
if len(uris) == 0:
raise OSError("no files to open")

da_lst, index_lst, fn_attrs = [], [], []
for i, fn in enumerate(paths):
for i, uri in enumerate(uris):
# read file
da = open_raster(fn, chunks=chunks, **kwargs)
da = open_raster(uri, chunks=chunks, **kwargs)

# get name, attrs and index (if concat)
if hasattr(fn, "path"): # file-like
bname = basename(fn.path)
if hasattr(uri, "path"): # file-like
bname = basename(uri.path)
else:
bname = basename(fn)
bname = basename(uri)
if concat:
# name based on basename until postfix or _
vname = bname.split(".")[0].replace(postfix, "").split("_")[0]
Expand Down

0 comments on commit 2ba86a4

Please sign in to comment.