Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass metrics to xgcm.Grid by default. #133

Merged
merged 5 commits into from
Apr 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions pop_tools/xgcm_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,17 +100,39 @@ def relabel_pop_dims(ds):
if coord in ds_new.coords:
ds_new = ds_new.drop_vars(coord)
if 'z_w_top' in ds_new.dims and 'z_w' in ds_new.dims:
ds_new = ds_new.drop('z_w_top').rename({'z_w': 'z_w_top'})
ds_new = ds_new.drop_vars('z_w_top').rename({'z_w': 'z_w_top'})
return ds_new


def to_xgcm_grid_dataset(ds, **kwargs):
def get_metrics(ds):
    """Find the metrics variables present in `ds`.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to search. Only membership (``name in ds``) is used.

    Returns
    -------
    dict
        Maps a tuple of axis names to the list of metric variable names
        found in `ds`, in a form that can be passed to xgcm. Axes for which
        no variable is present are left out, so the result is empty when
        `ds` holds none of the known metrics.
    """
    candidates = {
        ('X',): ['DXU', 'DXT'],  # X distances
        ('Y',): ['DYU', 'DYT'],  # Y distances
        ('Z',): ['DZU', 'DZT'],  # Z distances
        ('X', 'Y'): ['UAREA', 'TAREA'],  # Areas
    }
    # Keep only the variables that are present, and drop axes left with none.
    detected = {}
    for axes, names in candidates.items():
        present = [name for name in names if name in ds]
        if present:
            detected[axes] = present
    return detected


def to_xgcm_grid_dataset(ds, metrics='detect', **kwargs):
"""Modify POP model output to be compatible with xgcm.

Parameters
----------
ds : xarray.Dataset
An xarray Dataset
metrics : {"detect"} or dict, optional
Dictionary providing metrics to the `xgcm.Grid` constructor.
If ``"detect"``, will autodetect metrics that are present by searching for
variables named DXU, DXT, DYU, DYT, DZU, DZT, UAREA, TAREA.
If None, no metrics will be assigned.
kwargs:
Additional keyword arguments are passed through to `xgcm.Grid` class.

Expand Down Expand Up @@ -204,5 +226,7 @@ def to_xgcm_grid_dataset(ds, **kwargs):
"""to_xgcm_grid_dataset() function requires the `xgcm` package. \nYou can install it via PyPI or Conda"""
)
ds_new = relabel_pop_dims(ds)
grid = xgcm.Grid(ds_new, **kwargs)
if metrics == 'detect':
metrics = get_metrics(ds_new)
grid = xgcm.Grid(ds_new, metrics=metrics, **kwargs)
return grid, ds_new
52 changes: 50 additions & 2 deletions tests/test_xgcm_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
],
)
def test_to_xgcm_grid_dataset(ds, old_spatial_coords, axes):
grid, ds_new = pop_tools.to_xgcm_grid_dataset(ds, metrics=None)
grid, ds_new = pop_tools.to_xgcm_grid_dataset(ds)
assert isinstance(grid, xgcm.Grid)
assert set(axes) == set(grid.axes.keys())
new_spatial_coords = ['nlon_u', 'nlat_u', 'nlon_t', 'nlat_t']
Expand All @@ -49,4 +49,52 @@ def test_to_xgcm_grid_dataset_missing_xgcm():
with mock.patch.dict(sys.modules, {'xgcm': None}):
filepath = DATASETS.fetch('tend_zint_100m_Fe.nc')
ds = xr.open_dataset(filepath)
_, _ = pop_tools.to_xgcm_grid_dataset(ds, metrics=None)
_, _ = pop_tools.to_xgcm_grid_dataset(ds)


def test_set_metrics():
    """Check that ``pop_tools.xgcm_util.get_metrics`` keeps only present variables."""
    # Imported locally: this module also defines its own helper named
    # ``get_metrics`` that operates on a grid rather than a dataset.
    from pop_tools.xgcm_util import get_metrics

    ds = xr.Dataset({'DXU': 1, 'DYT': 2, 'DZT': 3})
    actual = get_metrics(ds)
    expected = {('X',): ['DXU'], ('Y',): ['DYT'], ('Z',): ['DZT']}
    assert actual == expected

    # An empty dataset must yield an empty (falsy) result.
    assert not get_metrics(xr.Dataset({}))


def test_metrics_assignment_no_metrics():
    """Check that the default call on ``ds_c`` leaves the grid without metrics."""
    # NOTE(review): presumably ds_c contains none of the known metric
    # variables - confirm against the fixture definition.
    grid, _ = pop_tools.to_xgcm_grid_dataset(ds_c)
    assert not grid._metrics


def get_metrics(grid):
    """Summarize ``grid._metrics`` as plain tuples and variable names.

    Each key of ``grid._metrics`` is turned into a sorted tuple and each
    metric is replaced by its ``name`` attribute, so the result can be
    compared directly against a literal dict in assertions.
    """
    return {
        tuple(sorted(key)): [metric.name for metric in metrics]
        for key, metrics in grid._metrics.items()
    }


@pytest.mark.parametrize('ds', [ds_a, ds_b])
def test_metrics_assignment(ds):
    """Check the three ``metrics`` modes: default, ``None``, and an explicit dict."""
    # Default call: the metrics found on the grid are compared to `expected`.
    grid, _ = pop_tools.to_xgcm_grid_dataset(ds)
    expected = {
        ('X',): ['DXU', 'DXT'],  # X distances
        ('Y',): ['DYU', 'DYT'],  # Y distances
        ('X', 'Y'): ['UAREA', 'TAREA'],  # Areas
    }

    # NOTE(review): 'S_FLUX_ROFF_VSF' presumably identifies the fixture that
    # carries fewer metric variables - confirm against the fixture definitions.
    if 'S_FLUX_ROFF_VSF' in ds:
        expected[('X', 'Y')] = ['TAREA']
        expected[('X',)] = ['DXU']

    actual = get_metrics(grid)
    assert actual == expected

    # metrics=None: no metrics are assigned.
    grid, _ = pop_tools.to_xgcm_grid_dataset(ds, metrics=None)
    assert not grid._metrics

    # Explicit dict: only the requested metric ends up on the grid.
    expected = {('X',): ['DXU']}
    grid, _ = pop_tools.to_xgcm_grid_dataset(ds, metrics={'X': ['DXU']})
    actual = get_metrics(grid)
    assert actual == expected