TCTracks: improve hdf5 I/O #735

Merged (8 commits) on Jun 27, 2023
Changes from 5 commits
169 changes: 60 additions & 109 deletions climada/hazard/tc_tracks.py
@@ -22,7 +22,6 @@
__all__ = ['CAT_NAMES', 'SAFFIR_SIM_CAT', 'TCTracks', 'set_category']

# standard libraries
import contextlib
import datetime as dt
import itertools
import logging
@@ -53,10 +52,6 @@
from sklearn.metrics import DistanceMetric
import statsmodels.api as sm
import xarray as xr
from xarray.backends import NetCDF4DataStore
from xarray.backends.api import dump_to_store
from xarray.backends.common import ArrayWriter
from xarray.backends.store import StoreBackendEntrypoint

# climada dependencies
from climada.util import ureg
@@ -1366,21 +1361,27 @@ def write_hdf5(self, file_name, complevel=5):
Specifies a compression level (0-9) for the zlib compression of the data.
A value of 0 or None disables compression. Default: 5
"""
# change dtype from bool to int to be NetCDF4-compliant, this is undone later
data = []
for track in self.data:
# convert "time" into a data variable and use a neutral name for the steps
track = track.rename(time="step")
track["time"] = ("step", track["step"].values)
track["step"] = np.arange(track.sizes["step"])
# change dtype from bool to int to be NetCDF4-compliant
track.attrs['orig_event_flag'] = int(track.attrs['orig_event_flag'])
try:
encoding = {
f'track{i}': {var: dict(zlib=True, complevel=complevel) for var in track.data_vars}
for i, track in enumerate(self.data)
}
ds_dict = {f'track{i}': track for i, track in enumerate(self.data)}
LOGGER.info('Writing %d tracks to %s', self.size, file_name)
_xr_to_netcdf_multi(file_name, ds_dict, encoding=encoding)
finally:
# ensure to undo the temporary change of dtype from above
for track in self.data:
track.attrs['orig_event_flag'] = bool(track.attrs['orig_event_flag'])
data.append(track)

# concatenate all data sets along new dimension "storm"
ds_combined = xr.combine_nested(data, concat_dim=["storm"])
ds_combined["storm"] = np.arange(ds_combined.sizes["storm"])

# convert attributes to data variables of combined dataset
df_attrs = pd.DataFrame([t.attrs for t in data], index=ds_combined["storm"].to_series())
ds_combined = xr.merge([ds_combined, df_attrs.to_xarray()])

encoding = {v: dict(zlib=True, complevel=complevel) for v in ds_combined.data_vars}
LOGGER.info('Writing %d tracks to %s', self.size, file_name)
ds_combined.to_netcdf(file_name, encoding=encoding)

@classmethod
def from_hdf5(cls, file_name):
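With this change, write_hdf5 stores all tracks in a single dataset: the "time" coordinate becomes a data variable indexed by a neutral "step" dimension, the tracks are concatenated (NaN-padded) along a new "storm" dimension, and per-track attributes become per-storm data variables. A minimal sketch of inspecting such a file with plain xarray (the file name is hypothetical; variable names follow the diff above):

    import xarray as xr

    ds = xr.open_dataset("tracks.hdf5")  # a file written by TCTracks.write_hdf5
    print(ds.sizes)  # e.g. {'storm': <number of tracks>, 'step': <longest track>}

    # extract one storm and drop its padding steps, mirroring from_hdf5 below
    first = ds.isel(storm=0).dropna(dim="step", how="any", subset=["time", "lat", "lon"])
    print(first["time"].values[0], float(first["lat"][0]), float(first["lon"][0]))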
@@ -1396,16 +1397,49 @@ def from_hdf5(cls, file_name):
tracks : TCTracks
TCTracks with data from the given HDF5 file.
"""
ds_dict = _xr_open_dataset_multi(file_name, prefix="track")
track_no = sorted(int(key[5:]) for key in ds_dict.keys())
ds_combined = xr.open_dataset(file_name)
if len(ds_combined.dims) == 0:
# this might be the legacy file format that is no longer supported, double-check:
try:
with xr.open_dataset(file_name, group="track0") as tr:
assert "time" in tr.dims and "max_sustained_wind" in tr.variables
is_legacy = True
except:
is_legacy = False
raise ValueError(
(
f"The file you try to read ({file_name}) is in a format that is no longer"
" supported by CLIMADA. Please store the data again using"
" TCTracks.write_hdf5. If you struggle to convert the data, please open an"
" issue on GitHub."
) if is_legacy else (
f"Unknown HDF5/NetCDF file format: {file_name}"
)
)

# when writing '<U*' and reading in again, xarray reads as dtype 'object'. undo this:
for varname in ds_combined.data_vars:
if ds_combined[varname].dtype == "object":
ds_combined[varname] = ds_combined[varname].astype(str)
data = []
for i in track_no:
track = ds_dict[f'track{i}']
for i in range(ds_combined.sizes["storm"]):
# extract a single storm and restrict to valid time steps
track = (
ds_combined
.isel(storm=i)
.dropna(dim="step", how="any", subset=["time", "lat", "lon"])
)
# convert the "time" variable to a coordinate
track = track.drop_vars(["storm", "step"]).rename(step="time")
track = track.assign_coords(time=track["time"]).compute()
# convert 0-dimensional variables to attributes:
attr_vars = [v for v in track.data_vars if track[v].ndim == 0]
track = (
track
.assign_attrs({v: track[v].item() for v in attr_vars})
.drop_vars(attr_vars)
)
track.attrs['orig_event_flag'] = bool(track.attrs['orig_event_flag'])
# when writing '<U*' and reading in again, xarray reads as dtype 'object'. undo this:
for varname in track.data_vars:
if track[varname].dtype == "object":
track[varname] = track[varname].astype(str)
data.append(track)
return cls(data)
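Since the error above tells users to rewrite legacy files with TCTracks.write_hdf5, a one-off migration could open the old per-track groups directly. A sketch only, not part of this PR; the group names follow the removed group-based reader, and the file names are hypothetical:

    import netCDF4
    import xarray as xr
    from climada.hazard import TCTracks

    legacy_path = "tctracks_hdf5_legacy.nc"  # hypothetical legacy file
    with netCDF4.Dataset(legacy_path, "r") as nc:
        groups = sorted(nc.groups, key=lambda g: int(g[5:]))  # "track0", "track1", ...
    data = [xr.open_dataset(legacy_path, group=g).load() for g in groups]
    for track in data:
        # legacy files store orig_event_flag as int; restore the bool expected by TCTracks
        track.attrs["orig_event_flag"] = bool(track.attrs["orig_event_flag"])
    TCTracks(data).write_hdf5("tctracks_converted.nc")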

@@ -1541,89 +1575,6 @@ def _one_interp_data(track, time_step_h, land_geom=None):
track_land_params(track_int, land_geom)
return track_int

def _xr_to_netcdf_multi(path, ds_dict, encoding=None):
"""Write multiple xarray Datasets to separate groups in a single NetCDF4 file

Contrary to xarray's `to_netcdf` functionality, this only supports the "NETCDF4" format and the
"netcdf4" engine since the groups feature has been introduced by NetCDF version 4.

Parameters
----------
path : str or Path
Path of the target NetCDF file.
ds_dict : dict whose keys are group names and values are xr.Dataset
Each xr.Dataset in the dict is stored in the group identified by its key in the dict.
Note that an empty string ("") is a valid group name and refers to the root group.
encoding : dict whose keys are group names and values are dict, optional
For each dataset/group, one dict that is compliant with the format of the `encoding`
keyword parameter in `xr.Dataset.to_netcdf`. Default: None
"""
# pylint: disable=protected-access
path = str(pathlib.Path(path).expanduser().absolute())
with contextlib.closing(NetCDF4DataStore.open(path, "w", "NETCDF4", None)) as store:
writer = ArrayWriter()
for group, dataset in ds_dict.items():
store._group = group
unlimited_dims = dataset.encoding.get("unlimited_dims", None)
encoding = None if encoding is None or group not in encoding else encoding[group]
dump_to_store(dataset, store, writer, encoding=encoding, unlimited_dims=unlimited_dims)

def _xr_open_dataset_multi(path, prefix=""):
"""Read multiple xarray Datasets from groups contained in a single NetCDF4 file

The data is loaded into memory.

Contrary to xarray's `open_dataset` functionality, this only supports the "netcdf4" engine
since the groups feature was introduced in NetCDF version 4.

Parameters
----------
path : str or Path
Path of the NetCDF file to read.
prefix : str, optional
If given, only read groups whose name starts with this prefix. Default: ""

Returns
-------
ds_dict : dict whose keys are group names and values are xr.Dataset
Each xr.Dataset in the dict is taken from the group identified by its key in the dict.
Note that an empty string ("") is a valid group name and refers to the root group.
"""
# pylint: disable=protected-access
path = str(pathlib.Path(path).expanduser().absolute())
ds_dict = {}
with contextlib.closing(NetCDF4DataStore.open(path, "r", "NETCDF4", None)) as store:
groups = [g for g in _xr_nc4_groups_from_store(store) if g.startswith(prefix)]
store_entrypoint = StoreBackendEntrypoint()
LOGGER.info('Reading %d datasets from %s', len(groups), path)
for group in groups:
store._group = group
ds = store_entrypoint.open_dataset(store)
ds.load()
ds_dict[group] = ds
return ds_dict

def _xr_nc4_groups_from_store(store):
"""List all groups contained in the given NetCDF4 data store

Parameters
----------
store : xarray.backend.NetCDF4DataStore

Returns
-------
list of str
"""
# pylint: disable=protected-access
def iter_groups(ds, prefix=""):
groups = [""]
for group_name, group_ds in ds.groups.items():
groups.extend([f"{prefix}{group_name}{subgroup}"
for subgroup in iter_groups(group_ds, prefix="/")])
return groups
with store._manager.acquire_context(False) as root:
return iter_groups(root)
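A round trip through the two helpers removed here looked roughly like the following (illustrative only; the datasets and file name are made up):

    import numpy as np
    import xarray as xr

    ds_dict = {
        f"track{i}": xr.Dataset({"max_sustained_wind": ("time", np.array([30.0, 45.0]) + i)})
        for i in range(2)
    }
    _xr_to_netcdf_multi("multi_group.nc", ds_dict)  # one NetCDF4 group per dict key
    ds_read = _xr_open_dataset_multi("multi_group.nc", prefix="track")
    assert sorted(ds_read) == ["track0", "track1"]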

def _read_one_gettelman(nc_data, i_track):
"""Read a single track from Andrew Gettelman's NetCDF dataset

Binary file added climada/hazard/test/data/tctracks_hdf5_legacy.nc
Binary file not shown.
6 changes: 6 additions & 0 deletions climada/hazard/test/test_tc_tracks.py
@@ -44,6 +44,7 @@
TEST_TRACK_CHAZ = DATA_DIR.joinpath('chaz_test_tracks.nc')
TEST_TRACK_STORM = DATA_DIR.joinpath('storm_test_tracks.txt')
TEST_TRACKS_ANTIMERIDIAN = DATA_DIR.joinpath('tracks-antimeridian')
TEST_TRACKS_LEGACY_HDF5 = DATA_DIR.joinpath('tctracks_hdf5_legacy.nc')


class TestIbtracs(unittest.TestCase):
@@ -339,6 +340,11 @@ def test_hdf5_io(self):
np.testing.assert_array_equal(tr[v].values, tr_read[v].values)
self.assertEqual(tr.sid, tr_read.sid)

# attempting to read the legacy file format should fail gracefully
with self.assertRaises(ValueError) as cm:
tc.TCTracks.from_hdf5(TEST_TRACKS_LEGACY_HDF5)
self.assertIn("no longer supported by CLIMADA", str(cm.exception))

def test_from_processed_ibtracs_csv(self):
tc_track = tc.TCTracks.from_processed_ibtracs_csv(TEST_TRACK)
