Skip to content

Commit

Permalink
fix mask_ocean for 2D grids (#314)
Browse files Browse the repository at this point in the history
* fix mask_ocean for 2D grids

* rename to coords

* CHANGELOG
  • Loading branch information
mathause authored Sep 29, 2023
1 parent 91b184b commit b30b6b1
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 6 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ New Features
- Added functions to stack regular lat-lon grids to 1D grids and unstack them again (`#217
<https://github.com/MESMER-group/mesmer/pull/217>`_). By `Mathias Hauser`_.
- Added functions to mask the ocean and Antarctica (`#219
<https://github.com/MESMER-group/mesmer/pull/219>`_). By `Mathias Hauser`_.
- Added functions to mask the ocean and Antarctica (
`#219 <https://github.com/MESMER-group/mesmer/pull/219>`_ and
`#314 <https://github.com/MESMER-group/mesmer/pull/314>`_). By `Mathias Hauser`_.
- Added functions to calculate the weighted global mean
(`#220 <https://github.com/MESMER-group/mesmer/pull/220>`_ and
`#287 <https://github.com/MESMER-group/mesmer/pull/287>`_). By `Mathias Hauser`_.
Expand Down
10 changes: 5 additions & 5 deletions mesmer/core/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import mesmer


def _where_if_dim(obj, cond, dims):
def _where_if_coords(obj, cond, coords):

# xarray applies where to all data_vars - even if they do not have the corresponding
# dimensions - we don't want that https://github.com/pydata/xarray/issues/7027

def _where(da):
if all(dim in da.dims for dim in dims):
if all(coord in da.coords for coord in coords):
return da.where(cond)
return da

Expand Down Expand Up @@ -71,7 +71,7 @@ def mask_ocean_fraction(data, threshold, *, x_coords="lon", y_coords="lat"):
mask_bool = mask_fraction > threshold

# only mask data_vars that have the coords
return _where_if_dim(data, mask_bool, [y_coords, x_coords])
return _where_if_coords(data, mask_bool, [y_coords, x_coords])


def mask_ocean(data, *, x_coords="lon", y_coords="lat"):
Expand Down Expand Up @@ -106,7 +106,7 @@ def mask_ocean(data, *, x_coords="lon", y_coords="lat"):
mask_bool = mask_bool.squeeze(drop=True)

# only mask data_vars that have the coords
return _where_if_dim(data, mask_bool, [y_coords, x_coords])
return _where_if_coords(data, mask_bool, [y_coords, x_coords])


def mask_antarctica(data, *, y_coords="lat"):
Expand All @@ -132,4 +132,4 @@ def mask_antarctica(data, *, y_coords="lat"):
mask_bool = data[y_coords] >= -60

# only mask if data has y_coords
return _where_if_dim(data, mask_bool, [y_coords])
return _where_if_coords(data, mask_bool, [y_coords])
24 changes: 24 additions & 0 deletions tests/unit/test_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,27 @@ def test_mask_antarctiva_default(
def test_mask_antarctiva(as_dataset, y_coords):

_test_mask(mesmer.mask.mask_antarctica, as_dataset, y_coords=y_coords)


def test_mask_ocean_2D_grid():

lon = lat = np.arange(0, 30)
LON, LAT = np.meshgrid(lon, lat)

dims = ("rlat", "rlon")

data = np.random.randn(*LON.shape)

data_2D_grid = xr.Dataset(
{"data": (dims, data)}, coords={"lon": (dims, LON), "lat": (dims, LAT)}
)

data_1D_grid = xr.Dataset(
{"data": (("lat", "lon"), data)}, coords={"lon": lon, "lat": lat}
)

result = mesmer.mask.mask_ocean(data_2D_grid)
expected = mesmer.mask.mask_ocean(data_1D_grid)

# the Datasets don't have equal coords but their arrays should be the same
np.testing.assert_equal(result.data.values, expected.data.values)

0 comments on commit b30b6b1

Please sign in to comment.