diff --git a/docs/changelog.rst b/docs/changelog.rst index 65a332863..0373c4e99 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -17,6 +17,7 @@ Fixed ----- - Bug in `raster.transform` with lazy coordinates. (#801) - Bug in `workflows.mesh.mesh2d_from_rasterdataset` with multi-dimensional coordinates. (#843) +- Bug in `MeshModel.get_mesh` after xugrid update to 0.9.0. (#848) v0.9.4 (2024-02-26) diff --git a/hydromt/models/model_mesh.py b/hydromt/models/model_mesh.py index e3ad83e2e..3eb4696c1 100644 --- a/hydromt/models/model_mesh.py +++ b/hydromt/models/model_mesh.py @@ -382,19 +382,29 @@ def get_mesh( if grid_name not in self.mesh_names: raise ValueError(f"Grid {grid_name} not found in mesh.") if include_data: - grid = self.mesh_grids[grid_name] - uds = xu.UgridDataset(grid.to_dataset(optional_attributes=True)) - uds.ugrid.grid.set_crs(grid.crs) # Look for data_vars that are defined on grid_name + variables = [] for var in self.mesh.data_vars: if hasattr(self.mesh[var], "ugrid"): - if self.mesh[var].ugrid.grid.name == grid_name: - uds[var] = self.mesh[var] - # additionnal topology properties - elif var.startswith(grid_name): - uds[var] = self.mesh[var] + if self.mesh[var].ugrid.grid.name != grid_name: + variables.append(var) + # additional topology properties + elif not var.startswith(grid_name): + variables.append(var) # else is global property (not grid specific) + if variables and len(variables) < len(self.mesh.data_vars): + uds = self.mesh.drop_vars(variables) + # Drop coords as well + drop_coords = [c for c in uds.coords if not c.startswith(grid_name)] + uds = uds.drop_vars(drop_coords) + elif variables and len(variables) == len(self.mesh.data_vars): + grid = self.mesh_grids[grid_name] + uds = xu.UgridDataset(grid.to_dataset(optional_attributes=True)) + uds.ugrid.grid.set_crs(grid.crs) + else: + uds = self.mesh.copy() + return uds else: diff --git a/tests/test_model.py b/tests/test_model.py index 790b1b3c9..975de27fc 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -840,5 +840,7 @@ def test_meshmodel_setup(griduda, world): resampling_method=["mode", "centroid"], grid_name="mesh2d", ) + ds_mesh2d = mod1.get_mesh("mesh2d", include_data=True) + assert "vito" in ds_mesh2d assert "roughness_manning" in mod1.mesh.data_vars assert np.all(mod1.mesh["landuse"].values == mod1.mesh["vito"].values)