Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add some CF support for lon and lat #73

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 158 additions & 0 deletions xesmf/cf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
"""
Decode longitude and latitude of an `xarray.Dataset` or a `xarray.DataArray`
according to CF conventions.
"""

import warnings

# CF-convention search specifications for geographic coordinates.
# For each coordinate, candidates are matched in decreasing order of
# priority: explicit variable names, the ``units`` attribute (mandatory
# for lon/lat under CF), the ``standard_name`` attribute, and finally
# the ``axis`` attribute.
CF_SPECS = {
    'lon': {
        'name': ['lon', 'longitude'],
        'units': ['degrees_east', 'degree_east', 'degree_E', 'degrees_E',
                  'degreeE', 'degreesE'],
        'standard_name': ['longitude'],
        'axis': 'X'
    },
    'lat': {
        'name': ['lat', 'latitude'],
        'units': ['degrees_north', 'degree_north', 'degree_N', 'degrees_N',
                  'degreeN', 'degreesN'],
        'standard_name': ['latitude'],
        'axis': 'Y'
    }
}


def get_coord_name_from_specs(ds, specs):
    '''
    Get the name of a `xarray.DataArray` according to search specifications

    Candidates are the coordinates and variables of ``ds``; they are
    matched against ``specs`` in decreasing order of priority: explicit
    names, the ``units`` attribute, the ``standard_name`` attribute
    (also accepting staggered-grid variants such as
    ``longitude_at_U_location``) and finally the ``axis`` attribute.

    Parameters
    ----------
    ds: xarray.DataArray or xarray.Dataset
    specs: dict
        Search specifications with optional keys ``'name'``, ``'units'``
        and ``'standard_name'`` (lists of accepted values) and ``'axis'``
        (a single, case-insensitive letter).

    Returns
    -------
    str or None: Name of the xarray.DataArray, or None when nothing matches
    '''
    # Candidate names: coords first (higher priority), then variables,
    # without duplicates (a Dataset's .variables also contains coords).
    ds_names = []
    for container in ('coords', 'variables'):
        for ds_name in getattr(ds, container, []):
            if ds_name not in ds_names:
                ds_names.append(ds_name)

    # Search in names first
    for specs_name in specs.get('name', []):
        if specs_name in ds_names:
            return specs_name

    # Search in units attributes (mandatory under CF conventions)
    if 'units' in specs:
        for ds_name in ds_names:
            if getattr(ds[ds_name], 'units', None) in specs['units']:
                return ds_name

    # Search in standard_name attributes, allowing the
    # '<standard_name>_at_<location>' staggered-grid convention
    if 'standard_name' in specs:
        for ds_name in ds_names:
            standard_name = getattr(ds[ds_name], 'standard_name', None)
            if standard_name is None:
                continue
            for specs_sn in specs['standard_name']:
                if (standard_name == specs_sn or
                        standard_name.startswith(specs_sn + '_at_')):
                    return ds_name

    # Search in axis attributes (case-insensitive)
    if 'axis' in specs:
        for ds_name in ds_names:
            axis = getattr(ds[ds_name], 'axis', None)
            if axis is not None and axis.lower() == specs['axis'].lower():
                return ds_name

    return None


def get_lon_name(ds):
    '''
    Get the longitude name in a `xarray.Dataset` or a `xarray.DataArray`

    Parameters
    ----------
    ds: xarray.DataArray or xarray.Dataset

    Returns
    -------
    str: Name of the xarray.DataArray
    '''
    name = get_coord_name_from_specs(ds, CF_SPECS['lon'])
    if name is None:
        # Nothing matched the CF longitude specs: warn and return None
        warnings.warn('longitude not found in dataset')
        return None
    return name


def get_lat_name(ds):
    '''
    Get the latitude name in a `xarray.Dataset` or a `xarray.DataArray`

    Parameters
    ----------
    ds: xarray.DataArray or xarray.Dataset

    Returns
    -------
    str: Name of the xarray.DataArray
    '''
    name = get_coord_name_from_specs(ds, CF_SPECS['lat'])
    if name is None:
        # Nothing matched the CF latitude specs: warn and return None
        warnings.warn('latitude not found in dataset')
        return None
    return name


def decode_cf(ds, mapping=None):
    '''
    Search for longitude and latitude and rename them

    Bounds arrays are detected by name only, trying the suffixes
    ``'_b'``, ``'_bounds'`` and ``'_bnds'`` in that order of priority.

    Parameters
    ----------
    ds: xarray.DataArray or xarray.Dataset
    mapping: None or dict
        When a `dict` is provided, it is filled with keys that are the new
        names and values that are the old names, so that the output dataset
        can have its coordinates be renamed back with
        :meth:`~xarray.Dataset.rename`.

    Returns
    -------
    ds: xarray.DataArray or xarray.Dataset
    '''
    # Suffixes that may mark bounds arrays, in decreasing priority.
    # Only the first matching suffix is renamed, so datasets holding
    # several bounds variants do not trigger a rename conflict.
    bsuffixes = ('_b', '_bounds', '_bnds')

    # Same treatment for longitude and latitude
    for new_name, get_name in (('lon', get_lon_name), ('lat', get_lat_name)):
        old_name = get_name(ds)
        if old_name is None or old_name == new_name:
            continue
        ds = ds.rename({old_name: new_name})
        if isinstance(mapping, dict):
            mapping[new_name] = old_name
        for suffix in bsuffixes:
            if old_name + suffix in ds:
                ds = ds.rename({old_name + suffix: new_name + '_b'})
                if isinstance(mapping, dict):
                    mapping[new_name + '_b'] = old_name + suffix
                break

    return ds
40 changes: 35 additions & 5 deletions xesmf/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from . backend import (esmf_grid, add_corner,
esmf_regrid_build, esmf_regrid_finalize)

from .cf import decode_cf as decode_cf_

from . smm import read_weights, apply_weights

try:
Expand All @@ -30,7 +32,8 @@ def as_2d_mesh(lon, lat):
return lon, lat


def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None):
def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None,
decode_cf=True):
'''
Convert xarray DataSet or dictionary to ESMF.Grid object.

Expand All @@ -49,12 +52,22 @@ def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None):
periodic : bool, optional
Periodic in longitude?

decode_cf: bool, dict, optional
Search for lon and lat according to CF conventions and rename them.
When a dict is provided, it is filled so it can be used to rename
lon and lat back.

Returns
-------
grid : ESMF.Grid object

'''

# rename lon and lat?
if decode_cf or isinstance(decode_cf, dict):
mapping = decode_cf if isinstance(decode_cf, dict) else None
ds = decode_cf_(ds, mapping)

# use np.asarray(dr) instead of dr.values, so it also works for dictionary
lon = np.asarray(ds['lon'])
lat = np.asarray(ds['lat'])
Expand All @@ -74,7 +87,8 @@ def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None):

class Regridder(object):
def __init__(self, ds_in, ds_out, method, periodic=False,
filename=None, reuse_weights=False, ignore_degenerate=None):
filename=None, reuse_weights=False, ignore_degenerate=None,
decode_cf=True):
"""
Make xESMF regridder

Expand Down Expand Up @@ -118,6 +132,10 @@ def __init__(self, ds_in, ds_out, method, periodic=False,
If False (default), raise error if grids contain degenerated cells
(i.e. triangles or lines, instead of quadrilaterals)


decode_cf: bool, optional
Search for lon and lat according to CF conventions and rename them

Returns
-------
regridder : xESMF regridder object
Expand All @@ -139,12 +157,16 @@ def __init__(self, ds_in, ds_out, method, periodic=False,
# construct ESMF grid, with some shape checking
self._grid_in, shape_in = ds_to_ESMFgrid(ds_in,
need_bounds=self.need_bounds,
periodic=periodic
periodic=periodic,
decode_cf=decode_cf
)
self._cf_mapping = {}
if decode_cf:
ds_out = decode_cf_(ds_out, self._cf_mapping)
self._grid_out, shape_out = ds_to_ESMFgrid(ds_out,
need_bounds=self.need_bounds
need_bounds=self.need_bounds,
decode_cf=False,
)

# record output grid and metadata
self._lon_out = np.asarray(ds_out['lon'])
self._lat_out = np.asarray(ds_out['lat'])
Expand Down Expand Up @@ -367,6 +389,10 @@ def regrid_dataarray(self, dr_in, keep_attrs=False):
dr_out.coords['lon'] = xr.DataArray(self._lon_out, dims=self.lon_dim)
dr_out.coords['lat'] = xr.DataArray(self._lat_out, dims=self.lat_dim)

# rename coords back to original output grid
if self._cf_mapping:
dr_out = dr_out.rename(self._cf_mapping)

dr_out.attrs['regrid_method'] = self.method

return dr_out
Expand Down Expand Up @@ -414,6 +440,10 @@ def regrid_dataset(self, ds_in, keep_attrs=False):
ds_out.coords['lon'] = xr.DataArray(self._lon_out, dims=self.lon_dim)
ds_out.coords['lat'] = xr.DataArray(self._lat_out, dims=self.lat_dim)

# rename coords back to original output grid
if self._cf_mapping:
ds_out = ds_out.rename(self._cf_mapping)

ds_out.attrs['regrid_method'] = self.method

return ds_out
63 changes: 63 additions & 0 deletions xesmf/tests/test_cf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import pytest
import numpy as np
import xarray as xr
import xesmf.cf as xecf


# 1-D index coordinate used as the basis of the test coordinates
coord_base = xr.DataArray(np.arange(2), dims='d')
# minimal data array carrying no geographic coordinate information
da = xr.DataArray(np.ones(2), dims='d', name='temp')


def test_cf_get_lon_name_warns():
    # A dataset without any longitude information must trigger a warning
    ds = da.to_dataset()
    with pytest.warns(UserWarning):
        xecf.get_lon_name(ds)


@pytest.mark.parametrize("name,coord_specs", [
    ('lon', coord_base),
    ('longitude', coord_base),
    ('x', ('d', coord_base, {'units': 'degrees_east'})),
    ('x', ('d', coord_base, {'standard_name': 'longitude'})),
    ('xu', ('d', coord_base, {'standard_name': 'longitude_at_U_location'})),
    ('x', ('d', coord_base, {'axis': 'X'})),
    ('x', ('d', coord_base, {'units': 'degrees_east', 'axis': 'Y'})),
])
def test_cf_get_lon_name(name, coord_specs):
    # Attach the candidate coordinate and check it is identified as longitude
    ds = da.assign_coords({name: coord_specs})
    assert xecf.get_lon_name(ds) == name


def test_cf_get_lat_name_warns():
    # A dataset without any latitude information must trigger a warning
    ds = da.to_dataset()
    with pytest.warns(UserWarning):
        xecf.get_lat_name(ds)


@pytest.mark.parametrize("name,coord_specs", [
    ('lat', coord_base),
    ('latitude', coord_base),
    ('y', ('d', coord_base, {'units': 'degrees_north'})),
    ('y', ('d', coord_base, {'standard_name': 'latitude'})),
    ('yu', ('d', coord_base, {'standard_name': 'latitude_at_V_location'})),
    ('y', ('d', coord_base, {'axis': 'Y'})),
    ('y', ('d', coord_base, {'units': 'degrees_north', 'axis': 'X'})),
])
def test_cf_get_lat_name(name, coord_specs):
    # Attach the candidate coordinate and check it is identified as latitude
    ds = da.assign_coords({name: coord_specs})
    assert xecf.get_lat_name(ds) == name


def test_cf_decode_cf():
    # Build a 2-D curvilinear grid with cell bounds; lon is identified
    # through its units attribute, lat through its name.
    yy, xx = np.mgrid[:3, :4].astype('d')
    yyb, xxb = np.mgrid[:4, :5] - .5

    ds = xr.Dataset({
        'xx': xr.DataArray(xx, dims=['ny', 'nx'],
                           attrs={'units': "degrees_east"}),
        'xx_b': xr.DataArray(xxb, dims=['nyb', 'nxb']),
        'latitude': xr.DataArray(yy, dims=['ny', 'nx']),
        'latitude_bounds': xr.DataArray(yyb, dims=['nyb', 'nxb']),
    })

    ds_decoded = xecf.decode_cf(ds)
    assert sorted(ds_decoded) == ['lat', 'lat_b', 'lon', 'lon_b']
26 changes: 26 additions & 0 deletions xesmf/tests/test_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,29 @@ def test_regrid_dataset():

# clean-up
regridder.clean_weight_file()


def test_regrid_dataset_renamed_back():
    '''Coordinates renamed by CF decoding are restored on the output.'''

    # Rename the output coords away from the canonical names and mark
    # them so they can only be found through their CF attributes
    ds_out2 = ds_out.rename(lon='X', lat='Y', lon_b='X_b', lat_b='Y_bnds')
    ds_out2.X.attrs['units'] = 'degrees_east'
    ds_out2.Y.attrs['standard_name'] = 'latitude'

    regridder = xe.Regridder(ds_in, ds_out2, 'conservative')

    ds_result = regridder(ds_in)

    # compare with analytical solution
    rel_err = (ds_out2['data_ref'] - ds_result['data'])/ds_out2['data_ref']
    assert np.max(np.abs(rel_err)) < 0.05

    # check that names are restored
    assert 'X' in ds_result.coords
    assert 'Y' in ds_result.coords
    assert 'X_b' in ds_result.variables
    assert 'Y_bnds' in ds_result.variables

    # clean-up
    regridder.clean_weight_file()