diff --git a/docs/changelog.rst b/docs/changelog.rst index 4df4a69fa..8d333d8f1 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -22,6 +22,7 @@ Fixed - Bug in `MeshModel.get_mesh` after xugrid update to 0.9.0. (#848) - Bug in `raster.clip_bbox` when bbox doesn't overlap with raster. (#860) - Allow for string format in zoom_level path, e.g. `{zoom_level:02d}` (#851) +- Fixed incorrect renaming of single variable raster datasets (#883) v0.9.4 (2024-02-26) =================== diff --git a/hydromt/data_adapter/rasterdataset.py b/hydromt/data_adapter/rasterdataset.py index 0202d0081..d82302c19 100644 --- a/hydromt/data_adapter/rasterdataset.py +++ b/hydromt/data_adapter/rasterdataset.py @@ -301,6 +301,7 @@ def get_data( """ try: # load data + variables = list([variables]) if isinstance(variables, str) else variables fns = self._resolve_paths( time_tuple, variables, zoom_level, geom, bbox, logger ) @@ -310,6 +311,7 @@ def get_data( geom, bbox, cache_root, + variables=variables, zoom_level=zoom_level, logger=logger, ) @@ -340,9 +342,10 @@ def get_data( ds = self._set_metadata(ds) # return array if single var and single_var_as_array return self._single_var_as_array(ds, single_var_as_array, variables) - except NoDataException: + except NoDataException as e: + postfix = f"({e.message})" if e.message else "" _exec_nodata_strat( - f"No data was read from source: {self.name}", + f"No data was read from source: {self.name} {postfix}", strategy=handle_nodata, logger=logger, ) @@ -350,7 +353,7 @@ def get_data( def _resolve_paths( self, time_tuple: Optional[TimeRange] = None, - variables: Optional[Variables] = None, + variables: Optional[List] = None, zoom_level: Optional[int] = 0, geom: Optional[Geom] = None, bbox: Optional[Bbox] = None, @@ -374,6 +377,7 @@ def _read_data( bbox: Optional[Bbox], cache_root: Optional[StrPath], zoom_level: Optional[int] = None, + variables: Optional[List] = None, logger: Logger = logger, ): kwargs = self.driver_kwargs.copy() @@ -422,6 +426,9 @@ def _read_data( # NOTE: overview levels start at zoom_level 1, see _get_zoom_levels_and_crs kwargs.update(overview_level=zoom_level - 1) ds = io.open_mfraster(fns, logger=logger, **kwargs) + # rename ds with single band if single variable is requested + if variables is not None and len(variables) == 1 and len(ds.data_vars) == 1: + ds = ds.rename({list(ds.data_vars.keys())[0]: list(variables)[0]}) else: raise ValueError(f"RasterDataset: Driver {self.driver} unknown") @@ -470,7 +477,7 @@ def _set_crs(self, ds: Data, logger: Logger = logger) -> Data: @staticmethod def _slice_data( ds: Data, - variables: Optional[Variables] = None, + variables: Optional[List] = None, geom: Optional[Geom] = None, bbox: Optional[Bbox] = None, buffer: GeomBuffer = 0, @@ -511,12 +518,10 @@ def _slice_data( ds.name = "data" ds = ds.to_dataset() elif variables is not None: - variables = np.atleast_1d(variables).tolist() - if len(variables) > 1 or len(ds.data_vars) > 1: - mvars = [var not in ds.data_vars for var in variables] - if any(mvars): - raise ValueError(f"RasterDataset: variables not found {mvars}") - ds = ds[variables] + mvars = [var for var in variables if var not in ds.data_vars] + if len(mvars) > 0: + raise NoDataException(f"Variables {mvars} not found in data.") + ds = ds[variables] if time_tuple is not None: ds = RasterDatasetAdapter._slice_temporal_dimension( ds, diff --git a/tests/test_data_adapter.py b/tests/test_data_adapter.py index afded1754..f862c60f9 100644 --- a/tests/test_data_adapter.py +++ b/tests/test_data_adapter.py @@ -73,12 +73,21 @@ def test_rasterdataset(rioda, tmpdir): fn_tif = str(tmpdir.join("test.tif")) rioda_utm = rioda.raster.reproject(dst_crs="utm") rioda_utm.raster.to_raster(fn_tif) + fn_nc = str(tmpdir.join("test.nc")) + rioda_utm.to_netcdf(fn_nc) data_catalog = DataCatalog() + da1 = data_catalog.get_rasterdataset(fn_tif, bbox=rioda.raster.bounds) + assert da1.name == "test" # name is taken from file name assert np.all(da1 == rioda_utm) - geom = rioda.raster.box - da1 = data_catalog.get_rasterdataset("test.tif", geom=geom) + assert "test.tif" in data_catalog.sources + + da1 = data_catalog.get_rasterdataset("test.tif", geom=rioda.raster.box) assert np.all(da1 == rioda_utm) + + ds1 = data_catalog.get_rasterdataset(fn_tif, single_var_as_array=False) + assert isinstance(ds1, xr.Dataset) # test single_var_as_array=False + with pytest.raises(FileNotFoundError): data_catalog.get_rasterdataset("no_file.tif") with pytest.raises(NoDataException): @@ -90,9 +99,13 @@ def test_rasterdataset(rioda, tmpdir): bbox=[12.5, 12.6, 12.7, 12.8], handle_nodata=NoDataStrategy.IGNORE, ) - assert da1 is None + da1 = data_catalog.get_rasterdataset(fn_tif, variables=["temp"]) + assert da1.name == "temp" # tif is renamed to variable name + with pytest.raises(NoDataException): # nc variables are not renamed + da1 = data_catalog.get_rasterdataset(fn_nc, variables=["temp"]) + @pytest.mark.skipif(not compat.HAS_GCSFS, reason="GCSFS not installed.") def test_gcs_cmip6(tmpdir):