Skip to content

Commit

Permalink
Merge pull request #81 from raphaeldussin/locstream
Browse files Browse the repository at this point in the history
Adding ESMF.LocStream capabilities
  • Loading branch information
JiaweiZhuang authored Mar 6, 2020
2 parents bbaa210 + b84f2d7 commit 60cd934
Show file tree
Hide file tree
Showing 6 changed files with 804 additions and 34 deletions.
1 change: 1 addition & 0 deletions doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ Contents
notebooks/Dask
notebooks/Compare_algorithms
notebooks/Reuse_regridder
notebooks/Using_LocStream

.. toctree::
:maxdepth: 1
Expand Down
489 changes: 489 additions & 0 deletions doc/notebooks/Using_LocStream.ipynb

Large diffs are not rendered by default.

32 changes: 32 additions & 0 deletions xesmf/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,38 @@ def esmf_grid(lon, lat, periodic=False):
return grid


def esmf_locstream(lon, lat):
'''
Create an ESMF.LocStream object, for contrusting ESMF.Field and ESMF.Regrid
Parameters
----------
lon, lat : 1D numpy array
Longitute/Latitude of cell centers.
Returns
-------
locstream : ESMF.LocStream object
'''

if len(lon.shape) > 1:
raise ValueError("lon can only be 1d")
if len(lat.shape) > 1:
raise ValueError("lat can only be 1d")

assert lon.shape == lat.shape

location_count = len(lon)

locstream = ESMF.LocStream(location_count,
coord_sys=ESMF.CoordSys.SPH_DEG)

locstream["ESMF:Lon"] = lon.astype(np.dtype('f8'))
locstream["ESMF:Lat"] = lat.astype(np.dtype('f8'))

return locstream


def add_corner(grid, lon_b, lat_b):
'''
Add corner information to ESMF.Grid for conservative regridding.
Expand Down
150 changes: 121 additions & 29 deletions xesmf/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os
import warnings

from . backend import (esmf_grid, add_corner,
from . backend import (esmf_grid, esmf_locstream, add_corner,
esmf_regrid_build, esmf_regrid_finalize)

from . smm import read_weights, apply_weights
Expand Down Expand Up @@ -72,9 +72,40 @@ def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None):
return grid, lon.shape


def ds_to_ESMFlocstream(ds):
'''
Convert xarray DataSet or dictionary to ESMF.LocStream object.
Parameters
----------
ds : xarray DataSet or dictionary
Contains variables ``lon``, ``lat``.
Returns
-------
locstream : ESMF.LocStream object
'''

lon = np.asarray(ds['lon'])
lat = np.asarray(ds['lat'])

if len(lon.shape) > 1:
raise ValueError("lon can only be 1d")
if len(lat.shape) > 1:
raise ValueError("lat can only be 1d")

assert lon.shape == lat.shape

locstream = esmf_locstream(lon, lat)

return locstream, (1,) + lon.shape


class Regridder(object):
def __init__(self, ds_in, ds_out, method, periodic=False,
filename=None, reuse_weights=False, ignore_degenerate=None):
filename=None, reuse_weights=False, ignore_degenerate=None,
locstream_in=False, locstream_out=False):
"""
Make xESMF regridder
Expand Down Expand Up @@ -118,6 +149,12 @@ def __init__(self, ds_in, ds_out, method, periodic=False,
If False (default), raise error if grids contain degenerated cells
(i.e. triangles or lines, instead of quadrilaterals)
locstream_in: bool, optional
input is a LocStream (list of locations)
locstream_out: bool, optional
output is a LocStream (list of locations)
Returns
-------
regridder : xESMF regridder object
Expand All @@ -135,15 +172,31 @@ def __init__(self, ds_in, ds_out, method, periodic=False,
self.periodic = periodic
self.reuse_weights = reuse_weights
self.ignore_degenerate = ignore_degenerate
self.locstream_in = locstream_in
self.locstream_out = locstream_out

methods_avail_ls_in = ['nearest_s2d', 'nearest_d2s']
methods_avail_ls_out = ['bilinear', 'patch'] + methods_avail_ls_in

if locstream_in and self.method not in methods_avail_ls_in:
raise ValueError(f'locstream input is only available for method in {methods_avail_ls_in}')
if locstream_out and self.method not in methods_avail_ls_out:
raise ValueError(f'locstream output is only available for method in {methods_avail_ls_out}')

# construct ESMF grid, with some shape checking
self._grid_in, shape_in = ds_to_ESMFgrid(ds_in,
need_bounds=self.need_bounds,
periodic=periodic
)
self._grid_out, shape_out = ds_to_ESMFgrid(ds_out,
need_bounds=self.need_bounds
)
if locstream_in:
self._grid_in, shape_in = ds_to_ESMFlocstream(ds_in)
else:
self._grid_in, shape_in = ds_to_ESMFgrid(ds_in,
need_bounds=self.need_bounds,
periodic=periodic
)
if locstream_out:
self._grid_out, shape_out = ds_to_ESMFlocstream(ds_out)
else:
self._grid_out, shape_out = ds_to_ESMFgrid(ds_out,
need_bounds=self.need_bounds
)

# record output grid and metadata
self._lon_out = np.asarray(ds_out['lon'])
Expand Down Expand Up @@ -312,6 +365,9 @@ def __call__(self, indata, keep_attrs=False):
def regrid_numpy(self, indata):
"""See __call__()."""

if self.locstream_in:
indata = np.expand_dims(indata, axis=-2)

outdata = apply_weights(self.weights, indata,
self.shape_in, self.shape_out)
return outdata
Expand All @@ -336,12 +392,22 @@ def regrid_dataarray(self, dr_in, keep_attrs=False):
"""See __call__()."""

# example: ('lat', 'lon') or ('y', 'x')
input_horiz_dims = dr_in.dims[-2:]
if self.locstream_in:
input_horiz_dims = dr_in.dims[-1:]
else:
input_horiz_dims = dr_in.dims[-2:]

# apply_ufunc needs a different name for output_core_dims
# example: ('lat', 'lon') -> ('lat_new', 'lon_new')
# https://github.com/pydata/xarray/issues/1931#issuecomment-367417542
temp_horiz_dims = [s + '_new' for s in input_horiz_dims]
if self.locstream_out:
temp_horiz_dims = ['dummy', 'locations']
else:
temp_horiz_dims = [s + '_new' for s in input_horiz_dims]

if self.locstream_in and not self.locstream_out:
temp_horiz_dims = ['dummy_new'] + temp_horiz_dims


dr_out = xr.apply_ufunc(
self.regrid_numpy, dr_in,
Expand All @@ -355,20 +421,28 @@ def regrid_dataarray(self, dr_in, keep_attrs=False):
keep_attrs=keep_attrs
)

# rename dimension name to match output grid
dr_out = dr_out.rename(
{temp_horiz_dims[0]: self.out_horiz_dims[0],
temp_horiz_dims[1]: self.out_horiz_dims[1]
}
)
if not self.locstream_out:
# rename dimension name to match output grid
dr_out = dr_out.rename(
{temp_horiz_dims[0]: self.out_horiz_dims[0],
temp_horiz_dims[1]: self.out_horiz_dims[1]
}
)

# append output horizontal coordinate values
# extra coordinates are automatically tracked by apply_ufunc
dr_out.coords['lon'] = xr.DataArray(self._lon_out, dims=self.lon_dim)
dr_out.coords['lat'] = xr.DataArray(self._lat_out, dims=self.lat_dim)
if self.locstream_out:
dr_out.coords['lon'] = xr.DataArray(self._lon_out, dims=('locations',))
dr_out.coords['lat'] = xr.DataArray(self._lat_out, dims=('locations',))
else:
dr_out.coords['lon'] = xr.DataArray(self._lon_out, dims=self.lon_dim)
dr_out.coords['lat'] = xr.DataArray(self._lat_out, dims=self.lat_dim)

dr_out.attrs['regrid_method'] = self.method

if self.locstream_out:
dr_out = dr_out.squeeze(dim='dummy')

return dr_out

def regrid_dataset(self, ds_in, keep_attrs=False):
Expand All @@ -381,8 +455,18 @@ def regrid_dataset(self, ds_in, keep_attrs=False):
# get the first data variable to infer input_core_dims
name, dr_in = next(iter(ds_in.items()))

input_horiz_dims = dr_in.dims[-2:]
temp_horiz_dims = [s + '_new' for s in input_horiz_dims]
if self.locstream_in:
input_horiz_dims = dr_in.dims[-1:]
else:
input_horiz_dims = dr_in.dims[-2:]

if self.locstream_out:
temp_horiz_dims = ['dummy', 'locations']
else:
temp_horiz_dims = [s + '_new' for s in input_horiz_dims]

if self.locstream_in and not self.locstream_out:
temp_horiz_dims = ['dummy_new'] + temp_horiz_dims

# help user debugging invalid horizontal dimensions
print('using dimensions {} from data variable {} '
Expand All @@ -402,18 +486,26 @@ def regrid_dataset(self, ds_in, keep_attrs=False):
keep_attrs=keep_attrs
)

# rename dimension name to match output grid
ds_out = ds_out.rename(
{temp_horiz_dims[0]: self.out_horiz_dims[0],
temp_horiz_dims[1]: self.out_horiz_dims[1]
}
)
if not self.locstream_out:
# rename dimension name to match output grid
ds_out = ds_out.rename(
{temp_horiz_dims[0]: self.out_horiz_dims[0],
temp_horiz_dims[1]: self.out_horiz_dims[1]
}
)

# append output horizontal coordinate values
# extra coordinates are automatically tracked by apply_ufunc
ds_out.coords['lon'] = xr.DataArray(self._lon_out, dims=self.lon_dim)
ds_out.coords['lat'] = xr.DataArray(self._lat_out, dims=self.lat_dim)
if self.locstream_out:
ds_out.coords['lon'] = xr.DataArray(self._lon_out, dims=('locations',))
ds_out.coords['lat'] = xr.DataArray(self._lat_out, dims=('locations',))
else:
ds_out.coords['lon'] = xr.DataArray(self._lon_out, dims=self.lon_dim)
ds_out.coords['lat'] = xr.DataArray(self._lat_out, dims=self.lat_dim)

ds_out.attrs['regrid_method'] = self.method

if self.locstream_out:
ds_out = ds_out.squeeze(dim='dummy')

return ds_out
18 changes: 17 additions & 1 deletion xesmf/tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import ESMF
import xesmf as xe
from xesmf.backend import (warn_f_contiguous, warn_lat_range,
esmf_grid, add_corner,
esmf_grid, add_corner, esmf_locstream,
esmf_regrid_build, esmf_regrid_apply,
esmf_regrid_finalize)
from xesmf.smm import read_weights, apply_weights
Expand Down Expand Up @@ -207,3 +207,19 @@ def test_regrid_periodic_correct():
assert np.max(np.abs(rel_err)) < 0.065
# clean-up
esmf_regrid_finalize(regrid)


def test_esmf_locstream():
lon = np.arange(5)
lat = np.arange(5)

ls = esmf_locstream(lon, lat)
assert isinstance(ls, ESMF.LocStream)

lon2d, lat2d = np.meshgrid(lon, lat)
with pytest.raises(ValueError):
ls = esmf_locstream(lon2d, lat2d)
with pytest.raises(ValueError):
ls = esmf_locstream(lon, lat2d)
with pytest.raises(ValueError):
ls = esmf_locstream(lon2d, lat)
Loading

0 comments on commit 60cd934

Please sign in to comment.