From b8f9d8c442e7023fe4117b1fe623cc44eed1dded Mon Sep 17 00:00:00 2001 From: Stephane Raynaud Date: Mon, 28 Oct 2019 16:36:34 +0100 Subject: [PATCH 1/3] Add some CF support for lon and lat --- xesmf/cf.py | 143 +++++++++++++++++++++++++++++++++++++++++ xesmf/frontend.py | 25 +++++-- xesmf/tests/test_cf.py | 63 ++++++++++++++++++ 3 files changed, 227 insertions(+), 4 deletions(-) create mode 100755 xesmf/cf.py create mode 100755 xesmf/tests/test_cf.py diff --git a/xesmf/cf.py b/xesmf/cf.py new file mode 100755 index 0000000..f415637 --- /dev/null +++ b/xesmf/cf.py @@ -0,0 +1,143 @@ +""" +Decode longitude and latitude of an `xarray.Dataset` or a `xarray.DataArray` +according to CF conventions. +""" + +import warnings + +CF_SPECS = { + 'lon': { + 'name': ['lon', 'longitude'], + 'units': ['degrees_east', 'degree_east', 'degree_E', 'degrees_E', + 'degreeE', 'degreesE'], + 'standard_name': ['longitude'], + 'axis': 'X' + }, + 'lat': { + 'name': ['lat', 'latitude'], + 'units': ['degrees_north', 'degree_north', 'degree_N', 'degrees_N', + 'degreeN', 'degreesN'], + 'standard_name': ['latitude'], + 'axis': 'Y' + } + } + + + +def get_coord_name_from_specs(ds, specs): + ''' + Get the name of a `xarray.DataArray` according to search specifications + + Parameters + ---------- + ds: xarray.DataArray or xarray.Dataset + specs: dict + + Returns + ------- + str: Name of the xarray.DataArray + ''' + # Targets + ds_names = [] + if hasattr(ds, 'coords'): + ds_names.extend(list(ds.coords)) + if hasattr(ds, 'variables'): + ds_names.extend(list(ds.variables)) + + # Search in names first + if 'name' in specs: + for specs_name in specs['name']: + if specs_name in ds_names: + return specs_name + + # Search in units attributes + if 'units' in specs: + for ds_name in ds_names: + if hasattr(ds[ds_name], 'units'): + for specs_units in specs['units']: + if ds[ds_name].units == specs_units: + return ds_name + + # Search in standard_name attributes + if 'standard_name' in specs: + for ds_name in 
ds_names: + if hasattr(ds[ds_name], 'standard_name'): + for specs_sn in specs['standard_name']: + if (ds[ds_name].standard_name == specs_sn or + ds[ds_name].standard_name.startswith( + specs_sn+'_at_')): + return ds_name + + # Search in axis attributes + if 'axis' in specs: + for ds_name in ds_names: + if (hasattr(ds[ds_name], 'axis') and + ds[ds_name].axis.lower() == specs['axis'].lower()): + return ds_name + + +def get_lon_name(ds): + ''' + Get the longitude name in a `xarray.Dataset` or a `xarray.DataArray` + + Parameters + ---------- + ds: xarray.DataArray or xarray.Dataset + + Returns + ------- + str: Name of the xarray.DataArray + ''' + lon_name = get_coord_name_from_specs(ds, CF_SPECS['lon']) + if lon_name is not None: + return lon_name + warnings.warn('longitude not found in dataset') + + +def get_lat_name(ds): + ''' + Get the latitude name in a `xarray.Dataset` or a `xarray.DataArray` + + Parameters + ---------- + ds: xarray.DataArray or xarray.Dataset + + Returns + ------- + str: Name of the xarray.DataArray + ''' + lat_name = get_coord_name_from_specs(ds, CF_SPECS['lat']) + if lat_name is not None: + return lat_name + warnings.warn('latitude not found in dataset') + + +def decode_cf(ds): + ''' + Search for longitude and latitude and rename them + + Parameters + ---------- + ds: xarray.DataArray or xarray.Dataset + + Returns + ------- + ds: xarray.DataArray or xarray.Dataset + ''' + # Longitude + lon_name = get_lon_name(ds) + if lon_name is not None: + ds = ds.rename({lon_name: 'lon'}) + for suffix in ('_b', '_bounds'): + if lon_name+suffix in ds: + ds = ds.rename({lon_name+suffix: 'lon_b'}) + + # Latitude + lat_name = get_lat_name(ds) + if lat_name is not None: + ds = ds.rename({lat_name: 'lat'}) + for suffix in ('_b', '_bounds'): + if lat_name+suffix in ds: + ds = ds.rename({lat_name+suffix: 'lat_b'}) + + return ds diff --git a/xesmf/frontend.py b/xesmf/frontend.py index 586cc6a..c3a6191 100644 --- a/xesmf/frontend.py +++ b/xesmf/frontend.py @@ -10,6 
+10,8 @@ from . backend import (esmf_grid, add_corner, esmf_regrid_build, esmf_regrid_finalize) +from .cf import decode_cf as decode_cf_ + from . smm import read_weights, apply_weights try: @@ -30,7 +32,8 @@ def as_2d_mesh(lon, lat): return lon, lat -def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None): +def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None, + decode_cf=True): ''' Convert xarray DataSet or dictionary to ESMF.Grid object. @@ -49,12 +52,19 @@ def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None): periodic : bool, optional Periodic in longitude? + decode_cf: bool, optional + Search for lon and lat according to CF conventions and rename them + Returns ------- grid : ESMF.Grid object ''' + # rename lon and lat? + if decode_cf: + ds = decode_cf_(ds) + # use np.asarray(dr) instead of dr.values, so it also works for dictionary lon = np.asarray(ds['lon']) lat = np.asarray(ds['lat']) @@ -74,7 +84,8 @@ def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None): class Regridder(object): def __init__(self, ds_in, ds_out, method, periodic=False, - filename=None, reuse_weights=False, ignore_degenerate=None): + filename=None, reuse_weights=False, ignore_degenerate=None, + decode_cf=True): """ Make xESMF regridder @@ -118,6 +129,10 @@ def __init__(self, ds_in, ds_out, method, periodic=False, If False (default), raise error if grids contain degenerated cells (i.e. 
triangles or lines, instead of quadrilaterals) + + decode_cf: bool, optional + Search for lon and lat according to CF conventions and rename them + Returns ------- regridder : xESMF regridder object @@ -139,10 +154,12 @@ def __init__(self, ds_in, ds_out, method, periodic=False, # construct ESMF grid, with some shape checking self._grid_in, shape_in = ds_to_ESMFgrid(ds_in, need_bounds=self.need_bounds, - periodic=periodic + periodic=periodic, + decode_cf=decode_cf ) self._grid_out, shape_out = ds_to_ESMFgrid(ds_out, - need_bounds=self.need_bounds + need_bounds=self.need_bounds, + decode_cf=decode_cf ) # record output grid and metadata diff --git a/xesmf/tests/test_cf.py b/xesmf/tests/test_cf.py new file mode 100755 index 0000000..5530e78 --- /dev/null +++ b/xesmf/tests/test_cf.py @@ -0,0 +1,63 @@ +import pytest +import numpy as np +import xarray as xr +import xesmf.cf as xecf + + +coord_base = xr.DataArray(np.arange(2), dims='d') +da = xr.DataArray(np.ones(2), dims='d', name='temp') + + +def test_cf_get_lon_name_warns(): + with pytest.warns(UserWarning): + xecf.get_lon_name(da.to_dataset()) + + +@pytest.mark.parametrize("name,coord_specs", [ + ('lon', coord_base), + ('longitude', coord_base), + ('x', ('d', coord_base, {'units': 'degrees_east'})), + ('x', ('d', coord_base, {'standard_name': 'longitude'})), + ('xu', ('d', coord_base, {'standard_name': 'longitude_at_U_location'})), + ('x', ('d', coord_base, {'axis': 'X'})), + ('x', ('d', coord_base, {'units': 'degrees_east', 'axis': 'Y'})), + ]) +def test_cf_get_lon_name(name, coord_specs): + assert xecf.get_lon_name(da.assign_coords(**{name: coord_specs})) == name + + +def test_cf_get_lat_name_warns(): + with pytest.warns(UserWarning): + xecf.get_lat_name(da.to_dataset()) + + +@pytest.mark.parametrize("name,coord_specs", [ + ('lat', coord_base), + ('latitude', coord_base), + ('y', ('d', coord_base, {'units': 'degrees_north'})), + ('y', ('d', coord_base, {'standard_name': 'latitude'})), + ('yu', ('d', coord_base, 
{'standard_name': 'latitude_at_V_location'})), + ('y', ('d', coord_base, {'axis': 'Y'})), + ('y', ('d', coord_base, {'units': 'degrees_north', 'axis': 'X'})), + ]) +def test_cf_get_lat_name(name, coord_specs): + assert xecf.get_lat_name(da.assign_coords(**{name: coord_specs})) == name + + +def test_cf_decode_cf(): + + yy, xx = np.mgrid[:3, :4].astype('d') + yyb, xxb = np.mgrid[:4, :5] - .5 + + xx = xr.DataArray(xx, dims=['ny', 'nx'], attrs={'units': "degrees_east"}) + xxb = xr.DataArray(xxb, dims=['nyb', 'nxb']) + yy = xr.DataArray(yy, dims=['ny', 'nx']) + yyb = xr.DataArray(yyb, dims=['nyb', 'nxb']) + + ds = xr.Dataset({ + 'xx': xx, + 'xx_b': xxb, + 'latitude': yy, + 'latitude_bounds': yyb}) + ds_decoded = xecf.decode_cf(ds) + assert sorted(list(ds_decoded)) == ['lat', 'lat_b', 'lon', 'lon_b'] From afd5aac7c4102f2430ed01dc93ac5db7db116470 Mon Sep 17 00:00:00 2001 From: Stephane Raynaud Date: Tue, 29 Oct 2019 11:00:09 +0100 Subject: [PATCH 2/3] Add name restoration to CF renaming for outputs --- xesmf/cf.py | 27 +++++++++++++++++++++------ xesmf/frontend.py | 25 +++++++++++++++++++------ xesmf/tests/test_frontend.py | 26 ++++++++++++++++++++++++++ 3 files changed, 66 insertions(+), 12 deletions(-) diff --git a/xesmf/cf.py b/xesmf/cf.py index f415637..c7fe94f 100755 --- a/xesmf/cf.py +++ b/xesmf/cf.py @@ -23,7 +23,6 @@ } - def get_coord_name_from_specs(ds, specs): ''' Get the name of a `xarray.DataArray` according to search specifications @@ -112,32 +111,48 @@ def get_lat_name(ds): warnings.warn('latitude not found in dataset') -def decode_cf(ds): +def decode_cf(ds, mapping=None): ''' Search for longitude and latitude and rename them Parameters ---------- ds: xarray.DataArray or xarray.Dataset + mapping: None or dict + When a `dict` is provided, it is filled with keys that are the new + names and values that are the old names, so that the output dataset + can have its coordinates be renamed back with + :meth:`~xarray.Dataset.rename`. 
Returns ------- ds: xarray.DataArray or xarray.Dataset ''' + # Suffixes for bounds + bsuffixes = ('_b', '_bounds', '_bnds') + # Longitude lon_name = get_lon_name(ds) - if lon_name is not None: + if lon_name is not None and lon_name != 'lon': ds = ds.rename({lon_name: 'lon'}) - for suffix in ('_b', '_bounds'): + if isinstance(mapping, dict): + mapping['lon'] = lon_name + for suffix in bsuffixes: if lon_name+suffix in ds: ds = ds.rename({lon_name+suffix: 'lon_b'}) + if isinstance(mapping, dict): + mapping['lon_b'] = lon_name + suffix # Latitude lat_name = get_lat_name(ds) - if lat_name is not None: + if lat_name is not None and lat_name != 'lat': ds = ds.rename({lat_name: 'lat'}) - for suffix in ('_b', '_bounds'): + if isinstance(mapping, dict): + mapping['lat'] = lat_name + for suffix in bsuffixes: if lat_name+suffix in ds: ds = ds.rename({lat_name+suffix: 'lat_b'}) + if isinstance(mapping, dict): + mapping['lat_b'] = lat_name + suffix return ds diff --git a/xesmf/frontend.py b/xesmf/frontend.py index c3a6191..0388daa 100644 --- a/xesmf/frontend.py +++ b/xesmf/frontend.py @@ -52,8 +52,10 @@ def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None, periodic : bool, optional Periodic in longitude? - decode_cf: bool, optional - Search for lon and lat according to CF conventions and rename them + decode_cf: bool, dict, optional + Search for lon and lat according to CF conventions and rename them. + When a dict is provided, it is filled so it can be used to rename + lon and lat back. Returns ------- @@ -62,8 +64,9 @@ def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None, ''' # rename lon and lat? 
- if decode_cf: - ds = decode_cf_(ds) + if decode_cf or isinstance(decode_cf, dict): + mapping = decode_cf if isinstance(decode_cf, dict) else None + ds = decode_cf_(ds, mapping) # use np.asarray(dr) instead of dr.values, so it also works for dictionary lon = np.asarray(ds['lon']) @@ -157,11 +160,13 @@ def __init__(self, ds_in, ds_out, method, periodic=False, periodic=periodic, decode_cf=decode_cf ) + self._cf_mapping = {} + if decode_cf: + ds_out = decode_cf_(ds_out, self._cf_mapping) self._grid_out, shape_out = ds_to_ESMFgrid(ds_out, need_bounds=self.need_bounds, - decode_cf=decode_cf + decode_cf=False, ) - # record output grid and metadata self._lon_out = np.asarray(ds_out['lon']) self._lat_out = np.asarray(ds_out['lat']) @@ -384,6 +389,10 @@ def regrid_dataarray(self, dr_in, keep_attrs=False): dr_out.coords['lon'] = xr.DataArray(self._lon_out, dims=self.lon_dim) dr_out.coords['lat'] = xr.DataArray(self._lat_out, dims=self.lat_dim) + # rename coords back to original output grid + if self._cf_mapping: + dr_out = dr_out.rename(self._cf_mapping) + dr_out.attrs['regrid_method'] = self.method return dr_out @@ -431,6 +440,10 @@ def regrid_dataset(self, ds_in, keep_attrs=False): ds_out.coords['lon'] = xr.DataArray(self._lon_out, dims=self.lon_dim) ds_out.coords['lat'] = xr.DataArray(self._lat_out, dims=self.lat_dim) + # rename coords back to original output grid + if self._cf_mapping: + ds_out = ds_out.rename(self._cf_mapping) + ds_out.attrs['regrid_method'] = self.method return ds_out diff --git a/xesmf/tests/test_frontend.py b/xesmf/tests/test_frontend.py index 7029c02..e1d9a10 100644 --- a/xesmf/tests/test_frontend.py +++ b/xesmf/tests/test_frontend.py @@ -260,3 +260,29 @@ def test_regrid_dataset(): # clean-up regridder.clean_weight_file() + + +def test_regrid_dataset_renamed_back(): + + # renaming + ds_out2 = ds_out.rename(lon='X', lat='Y', lon_b='X_b', lat_b='Y_bnds') + ds_out2.X.attrs['units'] = 'degrees_east' + ds_out2.Y.attrs['standard_name'] = 'latitude' + 
print(ds_out2.X.units) + + regridder = xe.Regridder(ds_in, ds_out2, 'conservative') + + ds_result = regridder(ds_in) + + # compare with analytical solution + rel_err = (ds_out2['data_ref'] - ds_result['data'])/ds_out2['data_ref'] + assert np.max(np.abs(rel_err)) < 0.05 + + # check that name are restored + assert 'X' in ds_result.coords + assert 'Y' in ds_result.coords + assert 'X_b' in ds_result.variables + assert 'Y_bnds' in ds_result.variables + + # clean-up + regridder.clean_weight_file() From 9094213e1bc86045810b3ea22f1802b915fcaf50 Mon Sep 17 00:00:00 2001 From: Stephane Raynaud Date: Tue, 26 Nov 2019 14:06:54 +0100 Subject: [PATCH 3/3] Add CF support to search for bounds It first looks at the "bounds" attribute, then searches for names using the usual suffixes. --- xesmf/cf.py | 89 +++++++++++++++++++++++++++++++----------- xesmf/tests/test_cf.py | 23 ++++++++++- 2 files changed, 88 insertions(+), 24 deletions(-) diff --git a/xesmf/cf.py b/xesmf/cf.py index c7fe94f..fd1694b 100755 --- a/xesmf/cf.py +++ b/xesmf/cf.py @@ -103,7 +103,7 @@ def get_lat_name(ds): Returns ------- - str: Name of the xarray.DataArray + str or None: Name of the xarray.DataArray ''' lat_name = get_coord_name_from_specs(ds, CF_SPECS['lat']) if lat_name is not None: @@ -111,9 +111,46 @@ def get_lat_name(ds): warnings.warn('latitude not found in dataset') +def get_bounds_name_from_coord(ds, coord_name, + suffixes=['_b', '_bnds', '_bounds']): + ''' + Get the name of the bounds array from the coord array + + It first searches for the 'bounds' attributes, then search + for names built from the suffixed coord name. + + Parameters + ---------- + ds: xarray.DataArray or xarray.Dataset + coord_name: str + Name of coord DataArray. + suffixes: list of str + Prefixes appended to `coord_name` to search for the bounds array name. 
+ + Returns + ------- + str or None: Name of the xarray.DataArray + ''' + + # Inits + coord = ds[coord_name] + + # From bounds attribute (CF) + if 'bounds' in coord.attrs: + bounds_name = coord.attrs['bounds'].strip() + if bounds_name in ds: + return bounds_name + warnings.warn('invalid bounds name: ' + bounds_name) + + # From suffixed names + for suffix in suffixes: + if coord_name+suffix in ds: + return coord_name + suffix + + def decode_cf(ds, mapping=None): ''' - Search for longitude and latitude and rename them + Search for longitude and latitude coordinates and bounds and rename them Parameters ---------- @@ -128,31 +165,39 @@ def decode_cf(ds, mapping=None): ------- ds: xarray.DataArray or xarray.Dataset ''' - # Suffixes for bounds - bsuffixes = ('_b', '_bounds', '_bnds') # Longitude lon_name = get_lon_name(ds) - if lon_name is not None and lon_name != 'lon': - ds = ds.rename({lon_name: 'lon'}) - if isinstance(mapping, dict): - mapping['lon'] = lon_name - for suffix in bsuffixes: - if lon_name+suffix in ds: - ds = ds.rename({lon_name+suffix: 'lon_b'}) - if isinstance(mapping, dict): - mapping['lon_b'] = lon_name + suffix + if lon_name is not None: + + # Search for bounds and rename + lon_b_name = get_bounds_name_from_coord(ds, lon_name) + if lon_b_name is not None and lon_b_name != 'lon_b': + ds = ds.rename({lon_b_name: 'lon_b'}) + if isinstance(mapping, dict): + mapping['lon_b'] = lon_b_name + + # Rename coordinates + if lon_name != 'lon': + ds = ds.rename({lon_name: 'lon'}) + if isinstance(mapping, dict): + mapping['lon'] = lon_name # Latitude lat_name = get_lat_name(ds) - if lat_name is not None and lat_name != 'lat': - ds = ds.rename({lat_name: 'lat'}) - if isinstance(mapping, dict): - mapping['lat'] = lat_name - for suffix in bsuffixes: - if lat_name+suffix in ds: - ds = ds.rename({lat_name+suffix: 'lat_b'}) - if isinstance(mapping, dict): - mapping['lat_b'] = lat_name + suffix + if lat_name is not None: + + # Search for bounds and rename + lat_b_name = 
get_bounds_name_from_coord(ds, lat_name) + if lat_b_name is not None and lat_b_name != 'lat_b': + ds = ds.rename({lat_b_name: 'lat_b'}) + if isinstance(mapping, dict): + mapping['lat_b'] = lat_b_name + + # Rename coordinates + if lat_name != 'lat': + ds = ds.rename({lat_name: 'lat'}) + if isinstance(mapping, dict): + mapping['lat'] = lat_name return ds diff --git a/xesmf/tests/test_cf.py b/xesmf/tests/test_cf.py index 5530e78..de2810d 100755 --- a/xesmf/tests/test_cf.py +++ b/xesmf/tests/test_cf.py @@ -44,19 +44,38 @@ def test_cf_get_lat_name(name, coord_specs): assert xecf.get_lat_name(da.assign_coords(**{name: coord_specs})) == name +def test_get_bounds_name_from_coord_attribute(): + da = coord_base.copy() + dab = xr.DataArray(np.arange(da.size+1, dtype='d'), dims='xb') + da.attrs.update(bounds='xb') + ds = xr.Dataset({'x': da, 'xb': dab}).set_coords('x') + + assert xecf.get_bounds_name_from_coord(ds, 'x') == 'xb' + + +def test_get_bounds_name_from_coord_name(): + da = coord_base.copy() + dab = xr.DataArray(np.arange(da.size+1, dtype='d'), dims='x_b') + ds = xr.Dataset({'x': da, 'x_b': dab}).set_coords('x') + + assert xecf.get_bounds_name_from_coord(ds, 'x') == 'x_b' + + def test_cf_decode_cf(): yy, xx = np.mgrid[:3, :4].astype('d') yyb, xxb = np.mgrid[:4, :5] - .5 - xx = xr.DataArray(xx, dims=['ny', 'nx'], attrs={'units': "degrees_east"}) + xx = xr.DataArray(xx, dims=['ny', 'nx'], + attrs={'units': "degrees_east", + 'bounds': "xx_boonds"}) xxb = xr.DataArray(xxb, dims=['nyb', 'nxb']) yy = xr.DataArray(yy, dims=['ny', 'nx']) yyb = xr.DataArray(yyb, dims=['nyb', 'nxb']) ds = xr.Dataset({ 'xx': xx, - 'xx_b': xxb, + 'xx_boonds': xxb, 'latitude': yy, 'latitude_bounds': yyb}) ds_decoded = xecf.decode_cf(ds)