Add some CF support for lon and lat #73

Closed · wants to merge 3 commits
203 changes: 203 additions & 0 deletions xesmf/cf.py
@@ -0,0 +1,203 @@
"""
Decode longitude and latitude of an `xarray.Dataset` or a `xarray.DataArray`
according to CF conventions.
"""

import warnings

CF_SPECS = {
'lon': {
'name': ['lon', 'longitude'],
'units': ['degrees_east', 'degree_east', 'degree_E', 'degrees_E',
Owner: Checking the units and other attributes seems a bit too heavy for xesmf? A CF checker seems better suited to a standalone package, or pushed upstream like xarray.decode_cf.

Author: I agree that a CF checker for xarray is missing in general. However, xarray is meant to stay general purpose and not to introduce notions of geographic or physical coordinates.

Checking these attributes is simply the strict application of the CF conventions for longitude and latitude; it avoids relying on variable names alone.
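For illustration, a minimal sketch (the coordinate name 'x' here is arbitrary) of a dataset whose longitude is only found through its CF units attribute:

```python
import numpy as np
import xarray as xr
import xesmf.cf as xecf

# The coordinate is not called 'lon' or 'longitude', but its CF
# 'units' attribute still identifies it as a longitude.
ds = xr.Dataset(coords={'x': ('x', np.arange(4.0), {'units': 'degrees_east'})})
assert xecf.get_lon_name(ds) == 'x'
```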

'degreeE', 'degreesE'],
'standard_name': ['longitude'],
'axis': 'X'
},
'lat': {
'name': ['lat', 'latitude'],
'units': ['degrees_north', 'degree_north', 'degree_N', 'degrees_N',
'degreeN', 'degreesN'],
'standard_name': ['latitude'],
'axis': 'Y'
}
}


def get_coord_name_from_specs(ds, specs):
Comment: This is certainly a more extensive search than what #74 provides. The question is: do we want this? Also, from a user perspective, all the different ways coordinates can be identified might create some confusion.

Author: It is in fact the strict application of the CF conventions, which should be considered the reference for formatting datasets, and which are increasingly being followed.

Author: Yep. Units are checked first, since they are mandatory, and then standard_name, just in case.
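As a small sketch of the resulting precedence (coordinate names 'a' and 'b' are arbitrary): an exact name match is tried first of all, and among the attributes a units match is returned before a standard_name match:

```python
import numpy as np
import xarray as xr
import xesmf.cf as xecf

base = np.arange(3.0)
# 'a' only advertises itself through standard_name, 'b' through units;
# the units pass runs before the standard_name pass, so 'b' is returned.
ds = xr.Dataset(coords={
    'a': ('d', base, {'standard_name': 'longitude'}),
    'b': ('d', base, {'units': 'degrees_east'}),
})
assert xecf.get_coord_name_from_specs(ds, xecf.CF_SPECS['lon']) == 'b'
```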

'''
Get the name of a `xarray.DataArray` according to search specifications

Parameters
----------
ds: xarray.DataArray or xarray.Dataset
specs: dict
Search specifications: lists of possible names, units and standard_names, and an axis letter.

Returns
-------
str or None: Name of the matching xarray.DataArray, or None if not found
'''
# Targets
ds_names = []
if hasattr(ds, 'coords'):
ds_names.extend(list(ds.coords))
if hasattr(ds, 'variables'):
ds_names.extend(list(ds.variables))

# Search in names first
if 'name' in specs:
for specs_name in specs['name']:
if specs_name in ds_names:
return specs_name

# Search in units attributes
if 'units' in specs:
for ds_name in ds_names:
if hasattr(ds[ds_name], 'units'):
for specs_units in specs['units']:
if ds[ds_name].units == specs_units:
return ds_name

# Search in standard_name attributes
if 'standard_name' in specs:
for ds_name in ds_names:
if hasattr(ds[ds_name], 'standard_name'):
for specs_sn in specs['standard_name']:
if (ds[ds_name].standard_name == specs_sn or
ds[ds_name].standard_name.startswith(
specs_sn+'_at_')):
return ds_name

# Search in axis attributes
if 'axis' in specs:
for ds_name in ds_names:
if (hasattr(ds[ds_name], 'axis') and
ds[ds_name].axis.lower() == specs['axis'].lower()):
return ds_name


def get_lon_name(ds):
'''
Get the longitude name in a `xarray.Dataset` or a `xarray.DataArray`

Parameters
----------
ds: xarray.DataArray or xarray.Dataset

Returns
-------
str or None: Name of the longitude xarray.DataArray, or None if not found
'''
lon_name = get_coord_name_from_specs(ds, CF_SPECS['lon'])
if lon_name is not None:
return lon_name
warnings.warn('longitude not found in dataset')


def get_lat_name(ds):
'''
Get the latitude name in a `xarray.Dataset` or a `xarray.DataArray`

Parameters
----------
ds: xarray.DataArray or xarray.Dataset

Returns
-------
str or None: Name of the xarray.DataArray
'''
lat_name = get_coord_name_from_specs(ds, CF_SPECS['lat'])
if lat_name is not None:
return lat_name
warnings.warn('latitude not found in dataset')


def get_bounds_name_from_coord(ds, coord_name,
suffixes=['_b', '_bnds', '_bounds']):
Comment: I think this could get really confusing, because this function could either return netCDF bounds or the grid corners (_b). I suggest the code clearly distinguishes bounds from corners.

Author: Yep, probably. I made no distinction between bounds and corners in this PR.
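For reference, the two conventions differ in shape: for a 1-D coordinate of length n, CF bounds have shape (n, 2) while xESMF corners have shape (n + 1,). A rough sketch of the conversion (the helper name is made up and contiguous bounds are assumed), not part of this PR:

```python
import numpy as np

def cf_bounds_to_corners_1d(bnds):
    # Contiguous CF bounds of shape (n, 2) -> xESMF-style corners of shape (n + 1,)
    return np.concatenate([bnds[:, 0], bnds[-1:, 1]])

lon = np.arange(4.0)                                   # cell centers
lon_bnds = np.stack([lon - 0.5, lon + 0.5], axis=-1)   # CF bounds, shape (4, 2)
lon_b = cf_bounds_to_corners_1d(lon_bnds)              # corners, shape (5,)
assert np.allclose(lon_b, np.arange(-0.5, 4.0))
```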

'''
Get the name of the bounds array from the coord array

It first searches for the 'bounds' attribute, then searches for names
built by appending suffixes to the coord name.

Parameters
----------
ds: xarray.DataArray or xarray.Dataset
coord_name: str
Name of coord DataArray.
suffixes: list of str
Suffixes appended to `coord_name` to search for the bounds array name.

Returns
-------
str or None: Name of the xarray.DataArray
'''

# Inits
coord = ds[coord_name]

# From bounds attribute (CF)
if 'bounds' in coord.attrs:
bounds_name = coord.attrs['bounds'].strip()
if bounds_name in ds:
return bounds_name
warnings.warn('invalid bounds name: ' + bounds_name)

# From suffixed names
Comment: I'd remove the identification by suffix.

for suffix in suffixes:
if coord_name+suffix in ds:
return coord_name + suffix


def decode_cf(ds, mapping=None):
'''
Search for longitude and latitude coordinates and bounds and rename them

Parameters
----------
ds: xarray.DataArray or xarray.Dataset
mapping: None or dict
When a `dict` is provided, it is filled with the new names as keys
and the old names as values, so that the coordinates of the output
dataset can be renamed back with :meth:`~xarray.Dataset.rename`.

Returns
-------
ds: xarray.DataArray or xarray.Dataset
'''

# Longitude
lon_name = get_lon_name(ds)
if lon_name is not None:

# Search for bounds and rename
Comment: We should not have to rename coordinates.

Author: Renaming is for internal use only, just to make the code clearer. It is possible not to rename. :)
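A minimal usage sketch of the mapping mechanism (coordinate names are illustrative):

```python
import numpy as np
import xarray as xr
import xesmf.cf as xecf

ds = xr.Dataset(coords={
    'longitude': ('nx', np.arange(4.0)),
    'latitude': ('ny', np.arange(3.0)),
})
mapping = {}
ds_std = xecf.decode_cf(ds, mapping)   # coords are now named 'lon' and 'lat'
assert mapping == {'lon': 'longitude', 'lat': 'latitude'}
ds_back = ds_std.rename(mapping)       # back to the user's original names
assert set(ds_back.coords) == {'longitude', 'latitude'}
```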

lon_b_name = get_bounds_name_from_coord(ds, lon_name)
if lon_b_name is not None and lon_b_name != 'lon_b':
ds = ds.rename({lon_b_name: 'lon_b'})
if isinstance(mapping, dict):
mapping['lon_b'] = lon_b_name

# Rename coordinates
if lon_name != 'lon':
ds = ds.rename({lon_name: 'lon'})
if isinstance(mapping, dict):
mapping['lon'] = lon_name

# Latitude
lat_name = get_lat_name(ds)
if lat_name is not None:

# Search for bounds and rename
lat_b_name = get_bounds_name_from_coord(ds, lat_name)
if lat_b_name is not None and lat_b_name != 'lat_b':
ds = ds.rename({lat_b_name: 'lat_b'})
if isinstance(mapping, dict):
mapping['lat_b'] = lat_b_name

# Rename coordinates
if lat_name != 'lat':
ds = ds.rename({lat_name: 'lat'})
if isinstance(mapping, dict):
mapping['lat'] = lat_name

return ds
40 changes: 35 additions & 5 deletions xesmf/frontend.py
@@ -10,6 +10,8 @@
from . backend import (esmf_grid, add_corner,
esmf_regrid_build, esmf_regrid_finalize)

from .cf import decode_cf as decode_cf_

from . smm import read_weights, apply_weights

try:
@@ -30,7 +32,8 @@ def as_2d_mesh(lon, lat):
return lon, lat


def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None):
def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None,
decode_cf=True):
'''
Convert xarray DataSet or dictionary to ESMF.Grid object.

@@ -49,12 +52,22 @@ def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None):
periodic : bool, optional
Periodic in longitude?

decode_cf: bool, dict, optional
Search for lon and lat according to CF conventions and rename them.
When a dict is provided, it is filled so it can be used to rename
lon and lat back.

Returns
-------
grid : ESMF.Grid object

'''

# rename lon and lat?
if decode_cf or isinstance(decode_cf, dict):
mapping = decode_cf if isinstance(decode_cf, dict) else None
ds = decode_cf_(ds, mapping)

# use np.asarray(dr) instead of dr.values, so it also works for dictionary
lon = np.asarray(ds['lon'])
lat = np.asarray(ds['lat'])
@@ -74,7 +87,8 @@

class Regridder(object):
def __init__(self, ds_in, ds_out, method, periodic=False,
filename=None, reuse_weights=False, ignore_degenerate=None):
filename=None, reuse_weights=False, ignore_degenerate=None,
decode_cf=True):
"""
Make xESMF regridder

@@ -118,6 +132,10 @@ def __init__(self, ds_in, ds_out, method, periodic=False,
If False (default), raise error if grids contain degenerated cells
(i.e. triangles or lines, instead of quadrilaterals)


decode_cf: bool, optional
Search for lon and lat according to CF conventions and rename them

Returns
-------
regridder : xESMF regridder object
@@ -139,12 +157,16 @@ def __init__(self, ds_in, ds_out, method, periodic=False,
# construct ESMF grid, with some shape checking
self._grid_in, shape_in = ds_to_ESMFgrid(ds_in,
need_bounds=self.need_bounds,
periodic=periodic
periodic=periodic,
decode_cf=decode_cf
)
self._cf_mapping = {}
if decode_cf:
ds_out = decode_cf_(ds_out, self._cf_mapping)
self._grid_out, shape_out = ds_to_ESMFgrid(ds_out,
need_bounds=self.need_bounds
need_bounds=self.need_bounds,
decode_cf=False,
)

# record output grid and metadata
self._lon_out = np.asarray(ds_out['lon'])
self._lat_out = np.asarray(ds_out['lat'])
@@ -367,6 +389,10 @@ def regrid_dataarray(self, dr_in, keep_attrs=False):
dr_out.coords['lon'] = xr.DataArray(self._lon_out, dims=self.lon_dim)
dr_out.coords['lat'] = xr.DataArray(self._lat_out, dims=self.lat_dim)

# rename coords back to original output grid
if self._cf_mapping:
dr_out = dr_out.rename(self._cf_mapping)

dr_out.attrs['regrid_method'] = self.method

return dr_out
@@ -414,6 +440,10 @@ def regrid_dataset(self, ds_in, keep_attrs=False):
ds_out.coords['lon'] = xr.DataArray(self._lon_out, dims=self.lon_dim)
ds_out.coords['lat'] = xr.DataArray(self._lat_out, dims=self.lat_dim)

# rename coords back to original output grid
if self._cf_mapping:
ds_out = ds_out.rename(self._cf_mapping)

ds_out.attrs['regrid_method'] = self.method

return ds_out
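To illustrate the behaviour this change enables, a hypothetical end-to-end sketch (grid spacings and variable names are made up):

```python
import numpy as np
import xarray as xr
import xesmf as xe

# The input grid only advertises its coordinates through CF names.
ds_in = xr.Dataset(coords={
    'longitude': ('longitude', np.arange(0.0, 360.0, 10.0)),
    'latitude': ('latitude', np.arange(-80.0, 81.0, 10.0)),
})
# The output grid already uses the names xESMF expects.
ds_out = xr.Dataset(coords={
    'lon': ('lon', np.arange(0.0, 360.0, 5.0)),
    'lat': ('lat', np.arange(-80.0, 81.0, 5.0)),
})

# With decode_cf=True (the new default), no manual renaming is needed.
regridder = xe.Regridder(ds_in, ds_out, 'bilinear', periodic=True)

field_in = xr.DataArray(
    np.random.rand(ds_in['latitude'].size, ds_in['longitude'].size),
    dims=('latitude', 'longitude'))
field_out = regridder(field_in)   # regridded onto the 5-degree output grid
```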
82 changes: 82 additions & 0 deletions xesmf/tests/test_cf.py
@@ -0,0 +1,82 @@
import pytest
import numpy as np
import xarray as xr
import xesmf.cf as xecf


coord_base = xr.DataArray(np.arange(2), dims='d')
da = xr.DataArray(np.ones(2), dims='d', name='temp')


def test_cf_get_lon_name_warns():
with pytest.warns(UserWarning):
xecf.get_lon_name(da.to_dataset())


@pytest.mark.parametrize("name,coord_specs", [
('lon', coord_base),
('longitude', coord_base),
('x', ('d', coord_base, {'units': 'degrees_east'})),
('x', ('d', coord_base, {'standard_name': 'longitude'})),
('xu', ('d', coord_base, {'standard_name': 'longitude_at_U_location'})),
('x', ('d', coord_base, {'axis': 'X'})),
('x', ('d', coord_base, {'units': 'degrees_east', 'axis': 'Y'})),
])
def test_cf_get_lon_name(name, coord_specs):
assert xecf.get_lon_name(da.assign_coords(**{name: coord_specs})) == name


def test_cf_get_lat_name_warns():
with pytest.warns(UserWarning):
xecf.get_lat_name(da.to_dataset())


@pytest.mark.parametrize("name,coord_specs", [
('lat', coord_base),
('latitude', coord_base),
('y', ('d', coord_base, {'units': 'degrees_north'})),
('y', ('d', coord_base, {'standard_name': 'latitude'})),
('yu', ('d', coord_base, {'standard_name': 'latitude_at_V_location'})),
('y', ('d', coord_base, {'axis': 'Y'})),
('y', ('d', coord_base, {'units': 'degrees_north', 'axis': 'X'})),
])
def test_cf_get_lat_name(name, coord_specs):
assert xecf.get_lat_name(da.assign_coords(**{name: coord_specs})) == name


def test_get_bounds_name_from_coord_attribute():
da = coord_base.copy()
dab = xr.DataArray(np.arange(da.size+1, dtype='d'), dims='xb')
da.attrs.update(bounds='xb')
ds = xr.Dataset({'x': da, 'xb': dab}).set_coords('x')

assert xecf.get_bounds_name_from_coord(ds, 'x') == 'xb'


def test_get_bounds_name_from_coord_name():
da = coord_base.copy()
dab = xr.DataArray(np.arange(da.size+1, dtype='d'), dims='x_b')
ds = xr.Dataset({'x': da, 'x_b': dab}).set_coords('x')

assert xecf.get_bounds_name_from_coord(ds, 'x') == 'x_b'


def test_cf_decode_cf():

yy, xx = np.mgrid[:3, :4].astype('d')
yyb, xxb = np.mgrid[:4, :5] - .5

xx = xr.DataArray(xx, dims=['ny', 'nx'],
attrs={'units': "degrees_east",
'bounds': "xx_boonds"})
xxb = xr.DataArray(xxb, dims=['nyb', 'nxb'])
yy = xr.DataArray(yy, dims=['ny', 'nx'])
yyb = xr.DataArray(yyb, dims=['nyb', 'nxb'])

ds = xr.Dataset({
'xx': xx,
'xx_boonds': xxb,
'latitude': yy,
'latitude_bounds': yyb})
ds_decoded = xecf.decode_cf(ds)
assert sorted(list(ds_decoded)) == ['lat', 'lat_b', 'lon', 'lon_b']