Skip to content

Commit

Permalink
Implement Salem-style horizontal coordinates with dimension renaming
Browse files Browse the repository at this point in the history
  • Loading branch information
jthielen committed Feb 11, 2022
1 parent 7b3c4cc commit 2cb6a5d
Show file tree
Hide file tree
Showing 9 changed files with 345 additions and 10 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
netCDF4>=1.5.5
xarray>=0.18,!=0.20.0,!=0.20.1
donfig>=0.6.0
pyproj>=2.4.1
27 changes: 24 additions & 3 deletions tests/test_accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,40 @@


@importorskip('cf_xarray')
@pytest.mark.parametrize('name', ['lambert_conformal'])
def test_postprocess(name):
@pytest.mark.parametrize(
'name, cf_grid_mapping_name', [('lambert_conformal', 'lambert_conformal_conic')]
)
def test_postprocess(name, cf_grid_mapping_name):

# Verify initial/raw state
raw_ds = xwrf.tutorial.open_dataset(name)
assert pd.api.types.is_string_dtype(raw_ds.Times.dtype)
assert pd.api.types.is_numeric_dtype(raw_ds.Time.dtype)
assert 'time' not in raw_ds.cf.coordinates
assert raw_ds.cf.standard_names == {}

# Postprocess
ds = raw_ds.xwrf.postprocess()

# Check for time coordinate handling
assert pd.api.types.is_datetime64_dtype(ds.Time.dtype)
assert 'time' in ds.cf.coordinates
standard_names = ds.cf.standard_names

# Check for projection handling
assert ds['wrf_projection'].attrs['grid_mapping_name'] == cf_grid_mapping_name

# Check for standard name and variable handling
standard_names = ds.cf.standard_names
assert 'x' in standard_names['projection_x_coordinate']
assert 'y' in standard_names['projection_y_coordinate']
assert 'z' in standard_names['atmosphere_hybrid_sigma_pressure_coordinate']
assert standard_names['time'] == ['Time']
assert standard_names['humidity_mixing_ratio'] == ['Q2', 'QVAPOR']
assert standard_names['air_temperature'] == ['T2']

# Check for time dimension reduction
assert ds['z'].shape == (39,)
assert ds['z_stag'].shape == (40,)
assert ds['XLAT'].shape == ds['XLONG'].shape == (29, 31)
assert ds['XLAT_U'].shape == ds['XLONG_U'].shape == (29, 32)
assert ds['XLAT_V'].shape == ds['XLONG_V'].shape == (30, 31)
82 changes: 82 additions & 0 deletions tests/test_grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import numpy as np
import pyproj
import pytest

import xwrf
from xwrf.grid import _wrf_grid_from_dataset, wgs84


@pytest.fixture(scope='session', params=['dummy'])
def dummy_dataset(request):
return xwrf.tutorial.open_dataset(request.param)


@pytest.fixture(scope='session', params=['dummy_salem_parsed'])
def dummy_salem(request):
return xwrf.tutorial.open_dataset(request.param)


@pytest.fixture(scope='session')
def test_grid(request):
return xwrf.tutorial.open_dataset(request.param)


@pytest.fixture(scope='session')
def cf_grid_mapping_name(request):
"""A no-op to allow parallel indirect use with open dataset fixture."""
return request.param


def test_grid_construction_against_salem(dummy_dataset, dummy_salem):
grid_params = _wrf_grid_from_dataset(dummy_dataset)

# Projection coordinate values
np.testing.assert_array_almost_equal(
grid_params['south_north'], dummy_salem['south_north'].values
)
np.testing.assert_array_almost_equal(grid_params['west_east'], dummy_salem['west_east'].values)

# Projection CRS
assert grid_params['crs'] == pyproj.CRS(dummy_salem['Q2'].attrs['pyproj_srs'])


@pytest.mark.parametrize(
'test_grid, cf_grid_mapping_name',
[
('polar_stereographic_1', 'polar_stereographic'),
('polar_stereographic_2', 'polar_stereographic'),
('lambert_conformal', 'lambert_conformal_conic'),
('mercator', 'mercator')
],
indirect=True,
)
def test_grid_construction_against_own_latlon(test_grid, cf_grid_mapping_name):
grid_params = _wrf_grid_from_dataset(test_grid)
trf = pyproj.Transformer.from_crs(grid_params['crs'], wgs84, always_xy=True)
recalculated = {}
recalculated['XLONG'], recalculated['XLAT'] = trf.transform(
*np.meshgrid(grid_params['west_east'], grid_params['south_north'])
)
recalculated['XLONG_U'], recalculated['XLAT_U'] = trf.transform(
*np.meshgrid(grid_params['west_east_stag'], grid_params['south_north'])
)
recalculated['XLONG_V'], recalculated['XLAT_V'] = trf.transform(
*np.meshgrid(grid_params['west_east'], grid_params['south_north_stag'])
)

assert grid_params['crs'].to_cf()['grid_mapping_name'] == cf_grid_mapping_name
for varname, recalculated_values in recalculated.items():
if varname in test_grid.data_vars:
np.testing.assert_array_almost_equal(
recalculated_values,
test_grid[varname].values[0],
decimal=2,
err_msg=f'Computed {varname} does not match with raw output',
)
elif '_' not in varname and varname + '_M' in test_grid.data_vars:
np.testing.assert_array_almost_equal(
recalculated_values,
test_grid[varname + '_M'].values[0],
decimal=2,
err_msg=f"Computed {varname + '_M'} does not match with raw output",
)
21 changes: 21 additions & 0 deletions tests/test_postprocess.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pyproj
import pytest

import xwrf
Expand Down Expand Up @@ -25,3 +26,23 @@ def test_cf_attrs_added(dummy_dataset, variable):
def test_remove_units_from_bool_arrays(dummy_attrs_only_dataset, variable):
dataset = xwrf.postprocess._remove_units_from_bool_arrays(dummy_attrs_only_dataset)
assert 'units' not in dataset[variable].attrs


def test_include_projection_coordinates(dummy_dataset):
dataset = xwrf.postprocess._include_projection_coordinates(dummy_dataset)
assert dataset['south_north'].attrs['axis'] == 'Y'
assert dataset['west_east'].attrs['axis'] == 'X'
assert isinstance(dataset['wrf_projection'].item(), pyproj.CRS)
assert dataset['Q2'].attrs['grid_mapping'] == 'wrf_projection'


def test_warning_on_projection_coordinate_failure(dummy_attrs_only_dataset):
with pytest.warns(UserWarning):
dataset = xwrf.postprocess._include_projection_coordinates(dummy_attrs_only_dataset)
assert dataset is dummy_attrs_only_dataset


def test_rename_dims(dummy_dataset):
dataset = xwrf.postprocess._rename_dims(dummy_dataset)
assert {'x', 'y'}.intersection(set(dataset.dims))
assert not {'south_north', 'west_east'}.intersection(set(dataset.dims))
7 changes: 6 additions & 1 deletion xwrf/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
import xarray as xr

from .postprocess import (
_assign_coord_to_dim_of_different_name,
_collapse_time_dim,
_decode_times,
_include_projection_coordinates,
_modify_attrs_to_cf,
_remove_units_from_bool_arrays,
_rename_dims,
)


Expand Down Expand Up @@ -43,8 +46,10 @@ def postprocess(self, decode_times=True) -> xr.Dataset:
self.xarray_obj.pipe(_modify_attrs_to_cf)
.pipe(_remove_units_from_bool_arrays)
.pipe(_collapse_time_dim)
.pipe(_include_projection_coordinates)
.pipe(_assign_coord_to_dim_of_different_name)
)
if decode_times:
ds = ds.pipe(_decode_times)

return ds
return ds.pipe(_rename_dims)
76 changes: 72 additions & 4 deletions xwrf/config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
version: 0.1

horizontal_dims:
- south_north
- west_east
- south_north_stag
- west_east_stag

latitude_coords:
- CLAT
- XLAT
Expand All @@ -16,6 +22,10 @@ longitude_coords:
- XLONG_U
- XLONG_V

vertical_coords:
- ZNU
- ZNW

time_coords:
- XTIME
- Times
Expand All @@ -27,13 +37,47 @@ boolean_units_attrs:
- flag

cf_attribute_map:
ZNW:
standard_name: atmosphere_hybrid_sigma_pressure_coordinate
axis: Z
south_north:
units: m
standard_name: projection_y_coordinate
axis: Y
west_east:
units: m
standard_name: projection_x_coordinate
axis: X
south_north_stag:
units: m
standard_name: projection_y_coordinate
axis: Y
c_grid_axis_shift: 0.5
west_east_stag:
units: m
standard_name: projection_x_coordinate
axis: X
c_grid_axis_shift: 0.5
XLAT_M:
units: degree_north
standard_name: latitude
XLAT:
standard_name: latitude
XLAT_U:
standard_name: latitude
XLAT_V:
standard_name: latitude
XLONG_M:
units: degree_east
standard_name: longitude
XLONG:
standard_name: longitude
XLONG_U:
standard_name: longitude
XLONG_V:
standard_name: longitude
ZNU:
standard_name: atmosphere_hybrid_sigma_pressure_coordinate
axis: Z
ZNW:
axis: Z
c_grid_axis_shift: 0.5
XTIME:
axis: T
U10:
Expand Down Expand Up @@ -82,3 +126,27 @@ cf_attribute_map:
standard_name: integral_of_surface_upward_heat_flux_in_air_wrt_time
ACLHF:
standard_name: integral_of_surface_upward_latent_heat_flux_in_air_wrf_time

conditional_cf_attribute_map:
HYBRID_OPT==0:
ZNU:
standard_name: atmosphere_sigma_coordinate
ZNW:
standard_name: atmosphere_sigma_coordinate
HYBRID_OPT!=0:
ZNU:
standard_name: atmosphere_hybrid_sigma_pressure_coordinate
ZNW:
standard_name: atmosphere_hybrid_sigma_pressure_coordinate

assign_coord_to_dim_map:
ZNU: bottom_top
ZNW: bottom_top_stag

rename_dim_map:
south_north: y
west_east: x
south_north_stag: y_stag
west_east_stag: x_stag
bottom_top: z
bottom_top_stag: z_stag
76 changes: 76 additions & 0 deletions xwrf/grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Provide grid-related functionality specific to WRF datasets.
This submodule contains code reused with modification from Salem under the terms of the BSD
3-Clause License. Salem is Copyright (c) 2014-2021, Fabien Maussion and Salem Development Team All
rights reserved.
"""
from __future__ import annotations # noqa: F401

from typing import Hashable, Mapping

import numpy as np
import pyproj
import xarray as xr

# Default CRS (lon/lat on WGS84, which is EPSG:4326)
wgs84 = pyproj.CRS(4326)


def _wrf_grid_from_dataset(ds: xr.Dataset) -> Mapping[Hashable, pyproj.CRS | np.ndarray]:
"""Get the WRF projection and dimension coordinates out of the file."""

# Use standards from a typical WRF file
cen_lon = ds.CEN_LON
cen_lat = ds.CEN_LAT
dx = ds.DX
dy = ds.DY
proj_id = ds.MAP_PROJ

pargs = {
'x_0': 0,
'y_0': 0,
'a': 6370000,
'b': 6370000,
'lat_1': ds.TRUELAT1,
'lat_2': getattr(ds, 'TRUELAT2', ds.TRUELAT1),
'lat_0': ds.MOAD_CEN_LAT,
'lon_0': ds.STAND_LON,
'center_lon': cen_lon,
}

if proj_id == 1:
# Lambert
pargs['proj'] = 'lcc'
del pargs['center_lon']
elif proj_id == 2:
# Polar stereo
pargs['proj'] = 'stere'
pargs['lat_ts'] = pargs['lat_1']
pargs['lat_0'] = 90.0
del pargs['lat_1'], pargs['lat_2'], pargs['center_lon']
elif proj_id == 3:
# Mercator
pargs['proj'] = 'merc'
pargs['lat_ts'] = pargs['lat_1']
pargs['lon_0'] = pargs['center_lon']
del pargs['lat_0'], pargs['lat_1'], pargs['lat_2'], pargs['center_lon']
else:
raise NotImplementedError(f'WRF proj not implemented yet: {proj_id}')

# Construct the pyproj CRS (letting errors fail through)
crs = pyproj.CRS(pargs)

# Get grid specifications
trf = pyproj.Transformer.from_crs(wgs84, crs, always_xy=True)
nx = ds.dims['west_east']
ny = ds.dims['south_north']
e, n = trf.transform(cen_lon, cen_lat)
x0 = -(nx - 1) / 2.0 * dx + e # DL corner
y0 = -(ny - 1) / 2.0 * dy + n # DL corner

return {
'crs': crs,
'south_north': y0 + np.arange(ny) * dy,
'west_east': x0 + np.arange(nx) * dx,
'south_north_stag': y0 + (np.arange(ny + 1) - 0.5) * dy,
'west_east_stag': x0 + (np.arange(nx + 1) - 0.5) * dx,
}
Loading

0 comments on commit 2cb6a5d

Please sign in to comment.