Skip to content

Commit

Permalink
include #883
Browse files Browse the repository at this point in the history
  • Loading branch information
Jaapel committed May 2, 2024
1 parent f3a703d commit 89e6374
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 14 deletions.
17 changes: 9 additions & 8 deletions hydromt/data_adapter/rasterdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,19 +267,20 @@ def _slice_data(
ds : xarray.Dataset
The sliced RasterDataset.
"""
if isinstance(ds, xr.DataArray):
if isinstance(ds, xr.DataArray): # xr.DataArray has no variables
if ds.name is None:
# dummy name, required to create dataset
# renamed to variable in _single_var_as_array
ds.name = "data"
ds = ds.to_dataset()
elif variables is not None:
variables = np.atleast_1d(variables).tolist()
if len(variables) > 1 or len(ds.data_vars) > 1:
mvars = [var not in ds.data_vars for var in variables]
if any(mvars):
raise ValueError(f"RasterDataset: variables not found {mvars}")
ds = ds[variables]
elif variables is not None: # xr.Dataset has variables
# cast variables to list
variables = list([variables]) if isinstance(variables, str) else variables
mvars = [var for var in variables if var not in ds.data_vars]
if len(mvars) > 0:
raise NoDataException(f"Variables {mvars} not found in data.")
ds = ds[variables]

if time_tuple is not None:
ds = RasterDatasetAdapter._slice_temporal_dimension(
ds,
Expand Down
10 changes: 8 additions & 2 deletions hydromt/drivers/raster_xarray_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import xarray as xr

from hydromt._typing import Geom, StrPath, TimeRange, ZoomLevel
from hydromt._typing import Geom, StrPath, TimeRange, Variables, ZoomLevel
from hydromt._typing.error import NoDataStrategy
from hydromt._utils.unused_kwargs import warn_on_unused_kwargs
from hydromt.drivers.preprocessing import PREPROCESSORS
Expand All @@ -26,6 +26,7 @@ def read_data(
*,
logger: Logger,
mask: Optional[Geom] = None,
variables: Optional[Variables] = None,
time_range: Optional[TimeRange] = None,
zoom_level: Optional[ZoomLevel] = None,
handle_nodata: NoDataStrategy = NoDataStrategy.RAISE,
Expand All @@ -38,7 +39,12 @@ def read_data(
"""
warn_on_unused_kwargs(
self.__class__.__name__,
{"mask": mask, "time_range": time_range, "zoom_level": zoom_level},
{
"mask": mask,
"time_range": time_range,
"variables": variables,
"zoom_level": zoom_level,
},
logger,
)
options = copy(self.options)
Expand Down
6 changes: 4 additions & 2 deletions hydromt/drivers/rasterdataset_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import xarray as xr

from hydromt._typing import Geom, StrPath, TimeRange, ZoomLevel
from hydromt._typing import Geom, StrPath, TimeRange, Variables, ZoomLevel
from hydromt._typing.error import NoDataStrategy

from .base_driver import BaseDriver
Expand All @@ -20,7 +20,7 @@ def read(
uri: str,
*,
mask: Optional[Geom] = None,
variables: Optional[List[str]] = None,
variables: Optional[Variables] = None,
time_range: Optional[TimeRange] = None,
zoom_level: Optional[ZoomLevel] = None,
logger: Optional[Logger] = None,
Expand Down Expand Up @@ -48,6 +48,7 @@ def read(
uris,
mask=mask,
time_range=time_range,
variables=variables,
zoom_level=zoom_level,
logger=logger,
handle_nodata=handle_nodata,
Expand All @@ -59,6 +60,7 @@ def read_data(
uris: List[str],
*,
mask: Optional[Geom] = None,
variables: Optional[Variables] = None,
time_range: Optional[TimeRange] = None,
zoom_level: Optional[ZoomLevel] = None,
logger: Logger,
Expand Down
10 changes: 9 additions & 1 deletion hydromt/drivers/rasterio_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
import xarray as xr
from pyproj import CRS

from hydromt._typing import Geom, StrPath, TimeRange, ZoomLevel
from hydromt._typing import Geom, StrPath, TimeRange, Variables, ZoomLevel
from hydromt._typing.error import NoDataStrategy
from hydromt._utils.unused_kwargs import warn_on_unused_kwargs
from hydromt._utils.uris import strip_scheme
from hydromt.config import SETTINGS
from hydromt.data_adapter.caching import cache_vrt_tiles
Expand All @@ -33,12 +34,16 @@ def read_data(
*,
mask: Optional[Geom] = None,
time_range: Optional[TimeRange] = None,
variables: Optional[Variables] = None,
zoom_level: Optional[ZoomLevel] = None,
logger: Logger,
handle_nodata: NoDataStrategy = NoDataStrategy.RAISE,
) -> xr.Dataset:
"""Read data using rasterio."""
# build up kwargs for open_raster
warn_on_unused_kwargs(
self.__class__.__name__, {"time_range": time_range}, logger=logger
)
kwargs: Dict[str, Any] = {}

# get source-specific options
Expand Down Expand Up @@ -73,6 +78,9 @@ def read_data(
# NOTE: overview levels start at zoom_level 1, see _get_zoom_levels_and_crs
kwargs.update(overview_level=zoom_level - 1)
ds = open_mfraster(uris, logger=logger, **kwargs)
# rename ds with single band if single variable is requested
if variables is not None and len(variables) == 1 and len(ds.data_vars) == 1:
ds = ds.rename({list(ds.data_vars.keys())[0]: list(variables)[0]})
return ds

def write(self, path: StrPath, ds: xr.Dataset, **kwargs) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/gis/test_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from affine import Affine
from shapely.geometry import LineString, Point, box

from hydromt.drivers.rasterio_driver import open_raster
from hydromt.gis import raster, utils
from hydromt.io import open_raster

# origin, rotation, res, shape, internal_bounds
# NOTE a rotated grid with a negative dx is not supported
Expand Down

0 comments on commit 89e6374

Please sign in to comment.