Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update get_bounds() to support mappable non-CF axes using "bounds" attr #708

Merged
merged 8 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,4 @@ docs: ## generate Sphinx HTML documentation, including API docs
# Build
# ----------------------
install: clean ## install the package to the active Python's site-packages
python setup.py install
python -m pip install .
15 changes: 8 additions & 7 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Below is a list of top-level API functions that are available in ``xcdat``.
compare_datasets
get_dim_coords
get_dim_keys
create_bounds
create_axis
create_gaussian_grid
create_global_mean_grid
Expand Down Expand Up @@ -83,13 +84,13 @@ Classes
.. autosummary::
:toctree: generated/

xcdat.bounds.BoundsAccessor
xcdat.spatial.SpatialAccessor
xcdat.temporal.TemporalAccessor
xcdat.regridder.accessor.RegridderAccessor
xcdat.regridder.regrid2.Regrid2Regridder
xcdat.regridder.xesmf.XESMFRegridder
xcdat.regridder.xgcm.XGCMRegridder
bounds.BoundsAccessor
spatial.SpatialAccessor
temporal.TemporalAccessor
regridder.accessor.RegridderAccessor
regridder.regrid2.Regrid2Regridder
regridder.xesmf.XESMFRegridder
regridder.xgcm.XGCMRegridder

.. currentmodule:: xarray

Expand Down
55 changes: 55 additions & 0 deletions tests/test_bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,27 @@ def test_returns_single_dataset_axis_bounds_as_a_dataarray_object(self):

assert result.identical(expected)

def test_returns_single_dataset_axis_bounds_as_a_dataarray_object_for_non_cf_axis(
self,
):
ds = xr.Dataset(
coords={
"lat": xr.DataArray(
data=np.ones(3),
dims="lat",
attrs={"bounds": "lat_bnds"},
)
},
data_vars={
"lat_bnds": xr.DataArray(data=np.ones((3, 3)), dims=["lat", "bnds"])
},
)

result = ds.bounds.get_bounds("Y")
expected = ds.lat_bnds

assert result.identical(expected)

def test_returns_multiple_dataset_axis_bounds_as_a_dataset_object(self):
ds = xr.Dataset(
coords={
Expand Down Expand Up @@ -321,6 +342,40 @@ def test_returns_multiple_dataset_axis_bounds_as_a_dataset_object(self):

assert result.identical(expected)

def test_returns_multiple_dataset_axis_bounds_as_a_dataset_object_for_non_cf_axis(
self,
):
ds = xr.Dataset(
coords={
"lat": xr.DataArray(
data=np.ones(3),
dims="lat",
attrs={
"bounds": "lat_bnds",
},
),
"latitude": xr.DataArray(
data=np.ones(3),
dims="latitude",
attrs={
"bounds": "latitude_bnds",
},
),
},
data_vars={
"var": xr.DataArray(data=np.ones(3), dims=["lat"]),
"lat_bnds": xr.DataArray(data=np.ones((3, 3)), dims=["lat", "bnds"]),
"latitude_bnds": xr.DataArray(
data=np.ones((3, 3)), dims=["latitude", "bnds"]
),
},
)

result = ds.bounds.get_bounds("Y")
expected = ds.drop_vars("var")

assert result.identical(expected)


class TestAddBounds:
@pytest.fixture(autouse=True)
Expand Down
2 changes: 1 addition & 1 deletion xcdat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
get_dim_keys,
swap_lon_axis,
)
from xcdat.bounds import BoundsAccessor # noqa: F401
from xcdat.bounds import BoundsAccessor, create_bounds # noqa: F401
from xcdat.dataset import decode_time, open_dataset, open_mfdataset # noqa: F401
from xcdat.regridder.accessor import RegridderAccessor # noqa: F401
from xcdat.regridder.grid import ( # noqa: F401
Expand Down
58 changes: 46 additions & 12 deletions xcdat/bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,18 +236,7 @@ def get_bounds(
else:
# Get the obj in the Dataset using the key.
obj = _get_data_var(self._dataset, key=var_key)

# Check if the object is a data variable or a coordinate variable.
# If it is a data variable, derive the axis coordinate variable.
if obj.name in list(self._dataset.data_vars):
coord = get_dim_coords(obj, axis)
elif obj.name in list(self._dataset.coords):
coord = obj

try:
bounds_keys = [coord.attrs["bounds"]]
except KeyError:
bounds_keys = []
bounds_keys = self._get_bounds_from_attr(obj, axis)

if len(bounds_keys) == 0:
raise KeyError(
Expand Down Expand Up @@ -505,8 +494,53 @@ def _get_bounds_keys(self, axis: CFAxisKey) -> List[str]:
except KeyError:
pass

keys_from_attr = self._get_bounds_from_attr(self._dataset, axis)
keys = keys + keys_from_attr

return list(set(keys))

def _get_bounds_from_attr(
self, obj: xr.DataArray | xr.Dataset, axis: CFAxisKey
) -> List[str]:
"""Retrieve bounds attribute keys from the given xarray object.

This method extracts the "bounds" attribute keys from the coordinates
of the specified axis in the provided xarray DataArray or Dataset.

Parameters:
-----------
obj : xr.DataArray | xr.Dataset
The xarray object from which to retrieve the bounds attribute keys.
axis : CFAxisKey
The CF axis key ("X", "Y", "T", or "Z").

Returns:
--------
List[str]
A list of bounds attribute keys found in the coordinates of the
specified axis. Otherwise, an empty list is returned.
"""
coords_obj = get_dim_coords(obj, axis)
bounds_keys: List[str] = []

if isinstance(coords_obj, xr.DataArray):
bounds_keys = self._extract_bounds_key(coords_obj, bounds_keys)
elif isinstance(coords_obj, xr.Dataset):
for coord in coords_obj.coords.values():
bounds_keys = self._extract_bounds_key(coord, bounds_keys)

return bounds_keys

def _extract_bounds_key(
self, coords_obj: xr.DataArray, bounds_keys: List[str]
) -> List[str]:
bnds_key = coords_obj.attrs.get("bounds")

if bnds_key is not None:
bounds_keys.append(bnds_key)

return bounds_keys

def _create_time_bounds( # noqa: C901
self,
time: xr.DataArray,
Expand Down
Loading