Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding ESMF.LocStream capabilities to xESMF #81

Merged
merged 13 commits into from
Mar 6, 2020
32 changes: 32 additions & 0 deletions xesmf/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,38 @@ def esmf_grid(lon, lat, periodic=False):
return grid


def esmf_locstream(lon, lat):
'''
Create an ESMF.LocStream object, for contrusting ESMF.Field and ESMF.Regrid

Parameters
----------
lon, lat : 1D numpy array
Longitute/Latitude of cell centers.

Returns
-------
locstream : ESMF.LocStream object
'''

if len(lon.shape) > 1:
raise ValueError("lon can only be 1d")
if len(lat.shape) > 1:
raise ValueError("lat can only be 1d")

assert lon.shape == lat.shape

location_count = len(lon)

locstream = ESMF.LocStream(location_count,
coord_sys=ESMF.CoordSys.SPH_DEG)

locstream["ESMF:Lon"] = lon.astype(np.dtype('f8'))
locstream["ESMF:Lat"] = lat.astype(np.dtype('f8'))

return locstream


def add_corner(grid, lon_b, lat_b):
'''
Add corner information to ESMF.Grid for conservative regridding.
Expand Down
141 changes: 114 additions & 27 deletions xesmf/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os
import warnings

from . backend import (esmf_grid, add_corner,
from . backend import (esmf_grid, esmf_locstream, add_corner,
esmf_regrid_build, esmf_regrid_finalize)

from . smm import read_weights, apply_weights
Expand Down Expand Up @@ -72,9 +72,40 @@ def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None):
return grid, lon.shape


def ds_to_ESMFlocstream(ds):
'''
Convert xarray DataSet or dictionary to ESMF.LocStream object.

Parameters
----------
ds : xarray DataSet or dictionary
Contains variables ``lon``, ``lat``.

Returns
-------
locstream : ESMF.LocStream object

'''

lon = np.asarray(ds['lon'])
lat = np.asarray(ds['lat'])

if len(lon.shape) > 1:
raise ValueError("lon can only be 1d")
if len(lat.shape) > 1:
raise ValueError("lat can only be 1d")

assert lon.shape == lat.shape

locstream = esmf_locstream(lon, lat)

return locstream, (1,) + lon.shape


class Regridder(object):
def __init__(self, ds_in, ds_out, method, periodic=False,
filename=None, reuse_weights=False, ignore_degenerate=None):
filename=None, reuse_weights=False, ignore_degenerate=None,
locstream_in=False, locstream_out=False):
"""
Make xESMF regridder

Expand Down Expand Up @@ -118,6 +149,12 @@ def __init__(self, ds_in, ds_out, method, periodic=False,
If False (default), raise error if grids contain degenerated cells
(i.e. triangles or lines, instead of quadrilaterals)

locstream_in: bool, optional
input is a LocStream (list of locations)

locstream_out: bool, optional
output is a LocStream (list of locations)

Returns
-------
regridder : xESMF regridder object
Expand All @@ -135,15 +172,31 @@ def __init__(self, ds_in, ds_out, method, periodic=False,
self.periodic = periodic
self.reuse_weights = reuse_weights
self.ignore_degenerate = ignore_degenerate
self.locstream_in = locstream_in
self.locstream_out = locstream_out

methods_avail_ls_in = ['nearest_s2d', 'nearest_d2s']
methods_avail_ls_out = ['bilinear', 'patch'] + methods_avail_ls_in

if locstream_in and self.method not in methods_avail_ls_in:
raise ValueError(f'locstream input is only available for method in {methods_avail_ls_in}')
if locstream_out and self.method not in methods_avail_ls_out:
raise ValueError(f'locstream output is only available for method in {methods_avail_ls_out}')

# construct ESMF grid, with some shape checking
self._grid_in, shape_in = ds_to_ESMFgrid(ds_in,
need_bounds=self.need_bounds,
periodic=periodic
)
self._grid_out, shape_out = ds_to_ESMFgrid(ds_out,
need_bounds=self.need_bounds
)
if locstream_in:
self._grid_in, shape_in = ds_to_ESMFlocstream(ds_in)
else:
self._grid_in, shape_in = ds_to_ESMFgrid(ds_in,
need_bounds=self.need_bounds,
periodic=periodic
)
if locstream_out:
self._grid_out, shape_out = ds_to_ESMFlocstream(ds_out)
else:
self._grid_out, shape_out = ds_to_ESMFgrid(ds_out,
need_bounds=self.need_bounds
)

# record output grid and metadata
self._lon_out = np.asarray(ds_out['lon'])
Expand Down Expand Up @@ -312,6 +365,9 @@ def __call__(self, indata, keep_attrs=False):
def regrid_numpy(self, indata):
"""See __call__()."""

if self.locstream_in:
indata = indata[np.newaxis, :]

outdata = apply_weights(self.weights, indata,
self.shape_in, self.shape_out)
return outdata
Expand Down Expand Up @@ -341,7 +397,14 @@ def regrid_dataarray(self, dr_in, keep_attrs=False):
# apply_ufunc needs a different name for output_core_dims
# example: ('lat', 'lon') -> ('lat_new', 'lon_new')
# https://github.com/pydata/xarray/issues/1931#issuecomment-367417542
temp_horiz_dims = [s + '_new' for s in input_horiz_dims]
if self.locstream_out:
temp_horiz_dims = ['dummy', 'locations']
else:
temp_horiz_dims = [s + '_new' for s in input_horiz_dims]

if self.locstream_in:
temp_horiz_dims = ['dummy_new'] + temp_horiz_dims


dr_out = xr.apply_ufunc(
self.regrid_numpy, dr_in,
Expand All @@ -355,20 +418,28 @@ def regrid_dataarray(self, dr_in, keep_attrs=False):
keep_attrs=keep_attrs
)

# rename dimension name to match output grid
dr_out = dr_out.rename(
{temp_horiz_dims[0]: self.out_horiz_dims[0],
temp_horiz_dims[1]: self.out_horiz_dims[1]
}
)
if not self.locstream_out:
# rename dimension name to match output grid
dr_out = dr_out.rename(
{temp_horiz_dims[0]: self.out_horiz_dims[0],
temp_horiz_dims[1]: self.out_horiz_dims[1]
}
)

# append output horizontal coordinate values
# extra coordinates are automatically tracked by apply_ufunc
dr_out.coords['lon'] = xr.DataArray(self._lon_out, dims=self.lon_dim)
dr_out.coords['lat'] = xr.DataArray(self._lat_out, dims=self.lat_dim)
if self.locstream_out:
dr_out.coords['lon'] = xr.DataArray(self._lon_out, dims=('locations',))
dr_out.coords['lat'] = xr.DataArray(self._lat_out, dims=('locations',))
else:
dr_out.coords['lon'] = xr.DataArray(self._lon_out, dims=self.lon_dim)
dr_out.coords['lat'] = xr.DataArray(self._lat_out, dims=self.lat_dim)

dr_out.attrs['regrid_method'] = self.method

if self.locstream_out:
dr_out = dr_out.squeeze(dim='dummy')

return dr_out

def regrid_dataset(self, ds_in, keep_attrs=False):
Expand All @@ -382,7 +453,14 @@ def regrid_dataset(self, ds_in, keep_attrs=False):
name, dr_in = next(iter(ds_in.items()))

input_horiz_dims = dr_in.dims[-2:]
temp_horiz_dims = [s + '_new' for s in input_horiz_dims]

if self.locstream_out:
temp_horiz_dims = ['dummy', 'locations']
else:
temp_horiz_dims = [s + '_new' for s in input_horiz_dims]

if self.locstream_in:
temp_horiz_dims = ['dummy_new'] + temp_horiz_dims

# help user debugging invalid horizontal dimensions
print('using dimensions {} from data variable {} '
Expand All @@ -402,18 +480,27 @@ def regrid_dataset(self, ds_in, keep_attrs=False):
keep_attrs=keep_attrs
)

# rename dimension name to match output grid
ds_out = ds_out.rename(
{temp_horiz_dims[0]: self.out_horiz_dims[0],
temp_horiz_dims[1]: self.out_horiz_dims[1]
}
)
if not self.locstream_out:
# rename dimension name to match output grid
ds_out = ds_out.rename(
{temp_horiz_dims[0]: self.out_horiz_dims[0],
temp_horiz_dims[1]: self.out_horiz_dims[1]
}
)

# append output horizontal coordinate values
# extra coordinates are automatically tracked by apply_ufunc
ds_out.coords['lon'] = xr.DataArray(self._lon_out, dims=self.lon_dim)
ds_out.coords['lat'] = xr.DataArray(self._lat_out, dims=self.lat_dim)
if self.locstream_out:
ds_out.coords['lon'] = xr.DataArray(self._lon_out, dims=('locations',))
ds_out.coords['lat'] = xr.DataArray(self._lat_out, dims=('locations',))
else:
if not self.locstream_in:
ds_out.coords['lon'] = xr.DataArray(self._lon_out, dims=self.lon_dim)
ds_out.coords['lat'] = xr.DataArray(self._lat_out, dims=self.lat_dim)

ds_out.attrs['regrid_method'] = self.method

if self.locstream_out:
ds_out = ds_out.squeeze(dim='dummy')

return ds_out
18 changes: 17 additions & 1 deletion xesmf/tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import ESMF
import xesmf as xe
from xesmf.backend import (warn_f_contiguous, warn_lat_range,
esmf_grid, add_corner,
esmf_grid, add_corner, esmf_locstream,
esmf_regrid_build, esmf_regrid_apply,
esmf_regrid_finalize)
from xesmf.smm import read_weights, apply_weights
Expand Down Expand Up @@ -207,3 +207,19 @@ def test_regrid_periodic_correct():
assert np.max(np.abs(rel_err)) < 0.065
# clean-up
esmf_regrid_finalize(regrid)


def test_esmf_locstream():
lon = np.arange(5)
lat = np.arange(5)

ls = esmf_locstream(lon, lat)
assert isinstance(ls, ESMF.LocStream)

lon2d, lat2d = np.meshgrid(lon, lat)
with pytest.raises(ValueError):
ls = esmf_locstream(lon2d, lat2d)
with pytest.raises(ValueError):
ls = esmf_locstream(lon, lat2d)
with pytest.raises(ValueError):
ls = esmf_locstream(lon2d, lat)
Loading