Skip to content

Commit

Permalink
move to top level
Browse files Browse the repository at this point in the history
  • Loading branch information
mathause committed Nov 10, 2022
1 parent 6d65dcd commit 5a37c18
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 41 deletions.
2 changes: 1 addition & 1 deletion mesmer/xarray_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# flake8: noqa

from mesmer.xarray_utils.grid import stack_lat_lon, unstack_lat_lon_and_align
from mesmer.xarray_utils import grid
50 changes: 25 additions & 25 deletions mesmer/xarray_utils/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def stack_lat_lon(
obj,
data,
*,
x_dim="lon",
y_dim="lat",
Expand All @@ -16,7 +16,7 @@ def stack_lat_lon(
Parameters
----------
obj : xr.Dataset | xr.DataArray
data : xr.Dataset | xr.DataArray
Array to convert to an 1D grid.
x_dim : str, default: "lon"
Name of the x-dimension.
Expand All @@ -31,35 +31,35 @@ def stack_lat_lon(
Returns
-------
obj : xr.Dataset | xr.DataArray
data : xr.Dataset | xr.DataArray
Array converted to an 1D grid.
"""

dims = {stack_dim: (y_dim, x_dim)}

obj = obj.stack(dims)
data = data.stack(dims)

if not multiindex:
# there is a bug in xarray v2022.06 (Index refactor)
if Version(xr.__version__) == Version("2022.6"):
raise TypeError("There is a bug in xarray v2022.06. Please update xarray.")

obj = obj.reset_index(stack_dim)
data = data.reset_index(stack_dim)

if dropna:
obj = obj.dropna(stack_dim)
data = data.dropna(stack_dim)

return obj
return data


def unstack_lat_lon_and_align(
obj, coords_orig, *, x_dim="lon", y_dim="lat", stack_dim="gridcell"
data, coords_orig, *, x_dim="lon", y_dim="lat", stack_dim="gridcell"
):
"""unstack an 1D grid to a regular lat-lon grid and align with orignal coords
Parameters
----------
obj : xr.Dataset | xr.DataArray
data : xr.Dataset | xr.DataArray
Array with 1D grid to unstack and align.
coords_orig : xr.Dataset | xr.DataArray
xarray object containing the original coordinates before it was converted to the
Expand All @@ -73,23 +73,23 @@ def unstack_lat_lon_and_align(
Returns
-------
obj : xr.Dataset | xr.DataArray
data : xr.Dataset | xr.DataArray
Array converted to a regular lat-lon grid.
"""

obj = unstack_lat_lon(obj, x_dim=x_dim, y_dim=y_dim, stack_dim=stack_dim)
data = unstack_lat_lon(data, x_dim=x_dim, y_dim=y_dim, stack_dim=stack_dim)

obj = align_to_coords(obj, coords_orig)
data = align_to_coords(data, coords_orig)

return obj
return data


def unstack_lat_lon(obj, *, x_dim="lon", y_dim="lat", stack_dim="gridcell"):
def unstack_lat_lon(data, *, x_dim="lon", y_dim="lat", stack_dim="gridcell"):
"""unstack an 1D grid to a regular lat-lon grid but do not align
Parameters
----------
obj : xr.Dataset | xr.DataArray
data : xr.Dataset | xr.DataArray
Array with 1D grid to unstack and align.
x_dim : str, default: "lon"
Name of the x-dimension.
Expand All @@ -100,36 +100,36 @@ def unstack_lat_lon(obj, *, x_dim="lon", y_dim="lat", stack_dim="gridcell"):
Returns
-------
obj : xr.Dataset | xr.DataArray
data : xr.Dataset | xr.DataArray
Array converted to a regular lat-lon grid (unaligned).
"""

# a MultiIndex is needed to unstack
if not isinstance(obj.indexes.get(stack_dim), pd.MultiIndex):
obj = obj.set_index({stack_dim: (y_dim, x_dim)})
if not isinstance(data.indexes.get(stack_dim), pd.MultiIndex):
data = data.set_index({stack_dim: (y_dim, x_dim)})

return obj.unstack(stack_dim)
return data.unstack(stack_dim)


def align_to_coords(obj, coords_orig):
"""align an unstacked lat-lon grid with it's orignal coords
def align_to_coords(data, coords_orig):
"""align an unstacked lat-lon grid with its orignal coords
Parameters
----------
obj : xr.Dataset | xr.DataArray
data : xr.Dataset | xr.DataArray
Unstacked array with lat-lon to align.
coords_orig : xr.Dataset | xr.DataArray
xarray object containing the original coordinates before it was converted to the
1D grid.
Returns
-------
obj : xr.Dataset | xr.DataArray
data : xr.Dataset | xr.DataArray
Array aligned with original grid.
"""

# ensure we don't loose entire rows/ columns
obj = xr.align(obj, coords_orig, join="right")[0]
data = xr.align(data, coords_orig, join="right")[0]

# make sure non-dimension coords are correct
return obj.assign_coords(coords_orig.coords)
return data.assign_coords(coords_orig.coords)
30 changes: 15 additions & 15 deletions tests/unit/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def data_2D_coords(as_dataset):
def test_to_unstructured_defaults(as_dataset):
da, expected = data_1D_coords(as_dataset)

result = mxu.stack_lat_lon(da)
result = mxu.grid.stack_lat_lon(da)

xr.testing.assert_identical(result, expected)

Expand All @@ -92,7 +92,7 @@ def test_to_unstructured_defaults(as_dataset):
def test_to_unstructured_multiindex(as_dataset):
da, expected = data_1D_coords(as_dataset)

result = mxu.stack_lat_lon(da, multiindex=True)
result = mxu.grid.stack_lat_lon(da, multiindex=True)

expected = expected.set_index({"gridcell": ("lat", "lon")})

Expand All @@ -108,7 +108,7 @@ def test_to_unstructured(x_dim, y_dim, cell_dim, as_dataset):
as_dataset, x_dim=x_dim, y_dim=y_dim, stack_dim=cell_dim
)

result = mxu.stack_lat_lon(da, x_dim=x_dim, y_dim=y_dim, stack_dim=cell_dim)
result = mxu.grid.stack_lat_lon(da, x_dim=x_dim, y_dim=y_dim, stack_dim=cell_dim)

xr.testing.assert_identical(result, expected)

Expand All @@ -117,7 +117,7 @@ def test_to_unstructured(x_dim, y_dim, cell_dim, as_dataset):
def test_to_unstructured_2D_coords(as_dataset):
da, expected = data_2D_coords(as_dataset)

result = mxu.stack_lat_lon(da, x_dim="x", y_dim="y")
result = mxu.grid.stack_lat_lon(da, x_dim="x", y_dim="y")

xr.testing.assert_identical(result, expected)

Expand All @@ -142,7 +142,7 @@ def test_to_unstructured_dropna(dropna, coords, time_pos):
if dropna:
expected = expected.dropna("gridcell")

result = mxu.stack_lat_lon(da, dropna=dropna, **kwargs)
result = mxu.grid.stack_lat_lon(da, dropna=dropna, **kwargs)

xr.testing.assert_identical(result, expected)

Expand All @@ -162,8 +162,8 @@ def test_unstructured_roundtrip_dropna_row(coords):
da_structured[:, :, 0] = np.NaN
expected = da_structured

da_unstructured = mxu.stack_lat_lon(da_structured, **kwargs)
result = mxu.unstack_lat_lon_and_align(da_unstructured, coords_orig, **kwargs)
da_unstructured = mxu.grid.stack_lat_lon(da_structured, **kwargs)
result = mxu.grid.unstack_lat_lon_and_align(da_unstructured, coords_orig, **kwargs)

# roundtripping adds x & y coords - not sure if there is something to be done about
if coords == "2D":
Expand All @@ -178,7 +178,7 @@ def test_from_unstructured_defaults(as_dataset):

coords_orig = expected.coords.to_dataset()[["lon", "lat"]]

result = mxu.unstack_lat_lon_and_align(da, coords_orig)
result = mxu.grid.unstack_lat_lon_and_align(da, coords_orig)

xr.testing.assert_identical(result, expected)

Expand All @@ -193,7 +193,7 @@ def test_from_unstructured(x_dim, y_dim, stack_dim, as_dataset):
)

coords_orig = expected.coords.to_dataset()[[x_dim, y_dim]]
result = mxu.unstack_lat_lon_and_align(
result = mxu.grid.unstack_lat_lon_and_align(
da, coords_orig, x_dim=x_dim, y_dim=y_dim, stack_dim=stack_dim
)

Expand All @@ -207,13 +207,13 @@ def test_unstructured_roundtrip_1D_coords(as_dataset):

coords_orig = da_structured.coords.to_dataset()[["lon", "lat"]]

result = mxu.unstack_lat_lon_and_align(
mxu.stack_lat_lon(da_structured), coords_orig
result = mxu.grid.unstack_lat_lon_and_align(
mxu.grid.stack_lat_lon(da_structured), coords_orig
)
xr.testing.assert_identical(result, da_structured)

result = mxu.stack_lat_lon(
mxu.unstack_lat_lon_and_align(da_unstructured, coords_orig)
result = mxu.grid.stack_lat_lon(
mxu.grid.unstack_lat_lon_and_align(da_unstructured, coords_orig)
)
xr.testing.assert_identical(result, da_unstructured)

Expand All @@ -228,7 +228,7 @@ def test_unstructured_roundtrip_2D_coords(as_dataset):
coords_orig = da_structured.coords.to_dataset()[["x", "y"]]
print(coords_orig)

result = mxu.stack_lat_lon(
mxu.unstack_lat_lon_and_align(da_unstructured, coords_orig, **dims), **dims
result = mxu.grid.stack_lat_lon(
mxu.grid.unstack_lat_lon_and_align(da_unstructured, coords_orig, **dims), **dims
)
xr.testing.assert_identical(result, da_unstructured)

0 comments on commit 5a37c18

Please sign in to comment.