diff --git a/mesmer/xarray_utils/__init__.py b/mesmer/xarray_utils/__init__.py index 4c098d44..3febad11 100644 --- a/mesmer/xarray_utils/__init__.py +++ b/mesmer/xarray_utils/__init__.py @@ -1,3 +1,3 @@ # flake8: noqa -from mesmer.xarray_utils.grid import stack_lat_lon, unstack_lat_lon_and_align +from mesmer.xarray_utils import grid diff --git a/mesmer/xarray_utils/grid.py b/mesmer/xarray_utils/grid.py index b4b35919..ce333397 100644 --- a/mesmer/xarray_utils/grid.py +++ b/mesmer/xarray_utils/grid.py @@ -4,7 +4,7 @@ def stack_lat_lon( - obj, + data, *, x_dim="lon", y_dim="lat", @@ -16,7 +16,7 @@ def stack_lat_lon( Parameters ---------- - obj : xr.Dataset | xr.DataArray + data : xr.Dataset | xr.DataArray Array to convert to an 1D grid. x_dim : str, default: "lon" Name of the x-dimension. @@ -31,35 +31,35 @@ def stack_lat_lon( Returns ------- - obj : xr.Dataset | xr.DataArray + data : xr.Dataset | xr.DataArray Array converted to an 1D grid. """ dims = {stack_dim: (y_dim, x_dim)} - obj = obj.stack(dims) + data = data.stack(dims) if not multiindex: # there is a bug in xarray v2022.06 (Index refactor) if Version(xr.__version__) == Version("2022.6"): raise TypeError("There is a bug in xarray v2022.06. Please update xarray.") - obj = obj.reset_index(stack_dim) + data = data.reset_index(stack_dim) if dropna: - obj = obj.dropna(stack_dim) + data = data.dropna(stack_dim) - return obj + return data def unstack_lat_lon_and_align( - obj, coords_orig, *, x_dim="lon", y_dim="lat", stack_dim="gridcell" + data, coords_orig, *, x_dim="lon", y_dim="lat", stack_dim="gridcell" ): """unstack an 1D grid to a regular lat-lon grid and align with orignal coords Parameters ---------- - obj : xr.Dataset | xr.DataArray + data : xr.Dataset | xr.DataArray Array with 1D grid to unstack and align. coords_orig : xr.Dataset | xr.DataArray xarray object containing the original coordinates before it was converted to the @@ -73,23 +73,23 @@ def unstack_lat_lon_and_align( Returns ------- - obj : xr.Dataset | xr.DataArray + data : xr.Dataset | xr.DataArray Array converted to a regular lat-lon grid. """ - obj = unstack_lat_lon(obj, x_dim=x_dim, y_dim=y_dim, stack_dim=stack_dim) + data = unstack_lat_lon(data, x_dim=x_dim, y_dim=y_dim, stack_dim=stack_dim) - obj = align_to_coords(obj, coords_orig) + data = align_to_coords(data, coords_orig) - return obj + return data -def unstack_lat_lon(obj, *, x_dim="lon", y_dim="lat", stack_dim="gridcell"): +def unstack_lat_lon(data, *, x_dim="lon", y_dim="lat", stack_dim="gridcell"): """unstack an 1D grid to a regular lat-lon grid but do not align Parameters ---------- - obj : xr.Dataset | xr.DataArray + data : xr.Dataset | xr.DataArray Array with 1D grid to unstack and align. x_dim : str, default: "lon" Name of the x-dimension. @@ -100,23 +100,23 @@ def unstack_lat_lon(obj, *, x_dim="lon", y_dim="lat", stack_dim="gridcell"): Returns ------- - obj : xr.Dataset | xr.DataArray + data : xr.Dataset | xr.DataArray Array converted to a regular lat-lon grid (unaligned). """ # a MultiIndex is needed to unstack - if not isinstance(obj.indexes.get(stack_dim), pd.MultiIndex): - obj = obj.set_index({stack_dim: (y_dim, x_dim)}) + if not isinstance(data.indexes.get(stack_dim), pd.MultiIndex): + data = data.set_index({stack_dim: (y_dim, x_dim)}) - return obj.unstack(stack_dim) + return data.unstack(stack_dim) -def align_to_coords(obj, coords_orig): - """align an unstacked lat-lon grid with it's orignal coords +def align_to_coords(data, coords_orig): + """align an unstacked lat-lon grid with its orignal coords Parameters ---------- - obj : xr.Dataset | xr.DataArray + data : xr.Dataset | xr.DataArray Unstacked array with lat-lon to align. coords_orig : xr.Dataset | xr.DataArray xarray object containing the original coordinates before it was converted to the @@ -124,12 +124,12 @@ def align_to_coords(obj, coords_orig): Returns ------- - obj : xr.Dataset | xr.DataArray + data : xr.Dataset | xr.DataArray Array aligned with original grid. """ # ensure we don't loose entire rows/ columns - obj = xr.align(obj, coords_orig, join="right")[0] + data = xr.align(data, coords_orig, join="right")[0] # make sure non-dimension coords are correct - return obj.assign_coords(coords_orig.coords) + return data.assign_coords(coords_orig.coords) diff --git a/tests/unit/test_grid.py b/tests/unit/test_grid.py index c133147e..7a0d3df8 100644 --- a/tests/unit/test_grid.py +++ b/tests/unit/test_grid.py @@ -83,7 +83,7 @@ def data_2D_coords(as_dataset): def test_to_unstructured_defaults(as_dataset): da, expected = data_1D_coords(as_dataset) - result = mxu.stack_lat_lon(da) + result = mxu.grid.stack_lat_lon(da) xr.testing.assert_identical(result, expected) @@ -92,7 +92,7 @@ def test_to_unstructured_defaults(as_dataset): def test_to_unstructured_multiindex(as_dataset): da, expected = data_1D_coords(as_dataset) - result = mxu.stack_lat_lon(da, multiindex=True) + result = mxu.grid.stack_lat_lon(da, multiindex=True) expected = expected.set_index({"gridcell": ("lat", "lon")}) @@ -108,7 +108,7 @@ def test_to_unstructured(x_dim, y_dim, cell_dim, as_dataset): as_dataset, x_dim=x_dim, y_dim=y_dim, stack_dim=cell_dim ) - result = mxu.stack_lat_lon(da, x_dim=x_dim, y_dim=y_dim, stack_dim=cell_dim) + result = mxu.grid.stack_lat_lon(da, x_dim=x_dim, y_dim=y_dim, stack_dim=cell_dim) xr.testing.assert_identical(result, expected) @@ -117,7 +117,7 @@ def test_to_unstructured(x_dim, y_dim, cell_dim, as_dataset): def test_to_unstructured_2D_coords(as_dataset): da, expected = data_2D_coords(as_dataset) - result = mxu.stack_lat_lon(da, x_dim="x", y_dim="y") + result = mxu.grid.stack_lat_lon(da, x_dim="x", y_dim="y") xr.testing.assert_identical(result, expected) @@ -142,7 +142,7 @@ def test_to_unstructured_dropna(dropna, coords, time_pos): if dropna: expected = expected.dropna("gridcell") - result = mxu.stack_lat_lon(da, dropna=dropna, **kwargs) + result = mxu.grid.stack_lat_lon(da, dropna=dropna, **kwargs) xr.testing.assert_identical(result, expected) @@ -162,8 +162,8 @@ def test_unstructured_roundtrip_dropna_row(coords): da_structured[:, :, 0] = np.NaN expected = da_structured - da_unstructured = mxu.stack_lat_lon(da_structured, **kwargs) - result = mxu.unstack_lat_lon_and_align(da_unstructured, coords_orig, **kwargs) + da_unstructured = mxu.grid.stack_lat_lon(da_structured, **kwargs) + result = mxu.grid.unstack_lat_lon_and_align(da_unstructured, coords_orig, **kwargs) # roundtripping adds x & y coords - not sure if there is something to be done about if coords == "2D": @@ -178,7 +178,7 @@ def test_from_unstructured_defaults(as_dataset): coords_orig = expected.coords.to_dataset()[["lon", "lat"]] - result = mxu.unstack_lat_lon_and_align(da, coords_orig) + result = mxu.grid.unstack_lat_lon_and_align(da, coords_orig) xr.testing.assert_identical(result, expected) @@ -193,7 +193,7 @@ def test_from_unstructured(x_dim, y_dim, stack_dim, as_dataset): ) coords_orig = expected.coords.to_dataset()[[x_dim, y_dim]] - result = mxu.unstack_lat_lon_and_align( + result = mxu.grid.unstack_lat_lon_and_align( da, coords_orig, x_dim=x_dim, y_dim=y_dim, stack_dim=stack_dim ) @@ -207,13 +207,13 @@ def test_unstructured_roundtrip_1D_coords(as_dataset): coords_orig = da_structured.coords.to_dataset()[["lon", "lat"]] - result = mxu.unstack_lat_lon_and_align( - mxu.stack_lat_lon(da_structured), coords_orig + result = mxu.grid.unstack_lat_lon_and_align( + mxu.grid.stack_lat_lon(da_structured), coords_orig ) xr.testing.assert_identical(result, da_structured) - result = mxu.stack_lat_lon( - mxu.unstack_lat_lon_and_align(da_unstructured, coords_orig) + result = mxu.grid.stack_lat_lon( + mxu.grid.unstack_lat_lon_and_align(da_unstructured, coords_orig) ) xr.testing.assert_identical(result, da_unstructured) @@ -228,7 +228,7 @@ def test_unstructured_roundtrip_2D_coords(as_dataset): coords_orig = da_structured.coords.to_dataset()[["x", "y"]] print(coords_orig) - result = mxu.stack_lat_lon( - mxu.unstack_lat_lon_and_align(da_unstructured, coords_orig, **dims), **dims + result = mxu.grid.stack_lat_lon( + mxu.grid.unstack_lat_lon_and_align(da_unstructured, coords_orig, **dims), **dims ) xr.testing.assert_identical(result, da_unstructured)