Add NWM zonal averaging workflow (#1547)
jrbourbeau authored Sep 18, 2024
1 parent 98ed1d2 commit 624979c
Showing 4 changed files with 115 additions and 65 deletions.
1 change: 1 addition & 0 deletions ci/environment.yml
@@ -47,6 +47,7 @@ dependencies:
- bokeh ==3.5.1
- gilknocker ==0.4.1
- openssl >1.1.0g
+ - rioxarray ==0.17.0

########################################################
# PLEASE READ:
56 changes: 16 additions & 40 deletions tests/geospatial/conftest.py
@@ -1,6 +1,5 @@
import os
import uuid
-from typing import Any, Literal

import coiled
import pytest
@@ -20,50 +19,27 @@ def scale(request):
    return request.config.getoption("scale")


-def get_cluster_spec(scale: Literal["small", "large"]) -> dict[str, Any]:
-    everywhere = dict(
-        workspace="dask-engineering-gcp",
-        region="us-central1",
-        wait_for_workers=True,
-        spot_policy="on-demand",
-    )
-
-    if scale == "small":
-        return {
-            "n_workers": 10,
-            **everywhere,
-        }
-    elif scale == "large":
-        return {
-            "n_workers": 100,
-            **everywhere,
-        }


@pytest.fixture(scope="module")
def cluster_name(request, scale):
    module = os.path.basename(request.fspath).split(".")[0]
    module = module.replace("test_", "")
    return f"geospatial-{module}-{scale}-{uuid.uuid4().hex[:8]}"


-@pytest.fixture(scope="module")
-def cluster(
-    cluster_name,
-    scale,
-    github_cluster_tags,
-):
-    kwargs = dict(
-        name=cluster_name,
-        tags=github_cluster_tags,
-        **get_cluster_spec(scale),
-    )
-    with coiled.Cluster(**kwargs) as cluster:
-        yield cluster
-
-
@pytest.fixture()
-def client(cluster, benchmark_all):
-    with cluster.get_client() as client:
-        with benchmark_all(client):
-            yield client
+def client_factory(cluster_name, github_cluster_tags, benchmark_all):
+    import contextlib
+
+    @contextlib.contextmanager
+    def _(n_workers, **cluster_kwargs):
+        with coiled.Cluster(
+            name=cluster_name,
+            tags=github_cluster_tags,
+            n_workers=n_workers,
+            **cluster_kwargs,
+        ) as cluster:
+            with cluster.get_client() as client:
+                with benchmark_all(client):
+                    yield client
+
+    return _
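
For context, a minimal sketch of how a test might consume the new client_factory fixture; the test name and workload below are hypothetical, and the real tests also pass Coiled keyword arguments such as workspace and region:

def test_sketch(client_factory):  # hypothetical consumer of the fixture above
    # Ask the factory for a two-worker cluster; it names and tags the cluster,
    # then yields a client wrapped in the benchmark_all instrumentation.
    with client_factory(n_workers=2) as client:
        assert client.submit(sum, [1, 2, 3]).result() == 6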
65 changes: 40 additions & 25 deletions tests/geospatial/test_rechunking.py
@@ -1,32 +1,47 @@
-import pytest
import xarray as xr
from coiled.credentials.google import CoiledShippedCredentials


-@pytest.mark.client("era5_rechunking")
-def test_era5_rechunking(client, gcs_url, scale):
-    # Load dataset
-    ds = xr.open_zarr(
-        "gs://weatherbench2/datasets/era5/1959-2023_01_10-full_37-1h-0p25deg-chunk-1.zarr",
-    ).drop_encoding()
+def test_era5_rechunking(
+    gcs_url,
+    scale,
+    client_factory,
+    cluster_kwargs={
+        "workspace": "dask-engineering-gcp",
+        "region": "us-central1",
+        "wait_for_workers": True,
+    },
+    scale_kwargs={
+        "small": {"n_workers": 10},
+        "medium": {"n_workers": 100},
+        "large": {"n_workers": 100},
+    },
+):
+    with client_factory(
+        **scale_kwargs[scale], **cluster_kwargs
+    ) as client:  # noqa: F841
+        # Load dataset
+        ds = xr.open_zarr(
+            "gs://weatherbench2/datasets/era5/1959-2023_01_10-full_37-1h-0p25deg-chunk-1.zarr",
+        ).drop_encoding()

-    if scale == "small":
-        # 101.83 GiB (small)
-        time_range = slice("2020-01-01", "2023-01-01")
-        variables = ["sea_surface_temperature"]
-    elif scale == "medium":
-        # 2.12 TiB (medium)
-        time_range = slice(None)
-        variables = ["sea_surface_temperature"]
-    else:
-        # 4.24 TiB (large)
-        # This currently doesn't complete successfully.
-        time_range = slice(None)
-        variables = ["sea_surface_temperature", "snow_depth"]
-    subset = ds[variables].sel(time=time_range)
+        if scale == "small":
+            # 101.83 GiB (small)
+            time_range = slice("2020-01-01", "2023-01-01")
+            variables = ["sea_surface_temperature"]
+        elif scale == "medium":
+            # 2.12 TiB (medium)
+            time_range = slice(None)
+            variables = ["sea_surface_temperature"]
+        else:
+            # 4.24 TiB (large)
+            # This currently doesn't complete successfully.
+            time_range = slice(None)
+            variables = ["sea_surface_temperature", "snow_depth"]
+        subset = ds[variables].sel(time=time_range)

-    # Rechunk
-    result = subset.chunk({"time": -1, "longitude": "auto", "latitude": "auto"})
+        # Rechunk
+        result = subset.chunk({"time": -1, "longitude": "auto", "latitude": "auto"})

-    # Write result to cloud storage
-    result.to_zarr(gcs_url, storage_options={"token": CoiledShippedCredentials()})
+        # Write result to cloud storage
+        result.to_zarr(gcs_url, storage_options={"token": CoiledShippedCredentials()})
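
A note on the rechunk step: {"time": -1, "longitude": "auto", "latitude": "auto"} collapses the time dimension into a single chunk while letting dask pick the longitude/latitude chunk sizes near its default target chunk size. A toy illustration of the same call (the array shape and initial chunking here are made up):

import dask.array as da
import xarray as xr

toy = xr.DataArray(
    da.zeros((24, 90, 180), chunks=(1, 90, 180)),
    dims=("time", "latitude", "longitude"),
)
# One chunk along time; dask chooses the spatial chunk sizes automatically.
print(toy.chunk({"time": -1, "longitude": "auto", "latitude": "auto"}).chunks)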
58 changes: 58 additions & 0 deletions tests/geospatial/test_zonal_average.py
@@ -0,0 +1,58 @@
"""
This example was adapted from https://github.com/dcherian/dask-demo/blob/main/nwm-aws.ipynb
"""

import flox.xarray
import numpy as np
import rioxarray
import xarray as xr


def test_nwm(
    s3,
    scale,
    client_factory,
    cluster_kwargs={
        "workspace": "dask-engineering",
        "region": "us-east-1",
        "wait_for_workers": True,
    },
    scale_kwargs={
        "small": {"n_workers": 10},
        "large": {"n_workers": 200, "scheduler_memory": "32 GiB"},
    },
):
    with client_factory(
        **scale_kwargs[scale], **cluster_kwargs
    ) as client:  # noqa: F841
        ds = xr.open_zarr(
            "s3://noaa-nwm-retrospective-2-1-zarr-pds/rtout.zarr", consolidated=True
        )

        if scale == "small":
            # 6.03 TiB
            time_range = slice("2020-01-01", "2020-12-31")
        else:
            # 252.30 TiB
            time_range = slice("1979-02-01", "2020-12-31")
        subset = ds.zwattablrt.sel(time=time_range)

        counties = rioxarray.open_rasterio(
            s3.open("s3://nwm-250m-us-counties/Counties_on_250m_grid.tif"),
            chunks="auto",
        ).squeeze()

        # Remove any small floating point error in coordinate locations
        _, counties_aligned = xr.align(subset, counties, join="override")
        counties_aligned = counties_aligned.persist()

        county_id = np.unique(counties_aligned.data).compute()
        county_id = county_id[county_id != 0]
        county_mean = flox.xarray.xarray_reduce(
            subset,
            counties_aligned.rename("county"),
            func="mean",
            expected_groups=(county_id,),
        )

        county_mean.compute()
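
The grouped reduction above is the core of the zonal average: every 250 m pixel carries a county ID, and flox reduces the zwattablrt values by that ID. A self-contained sketch of the same pattern on synthetic data (all names, shapes, and values here are invented):

import flox.xarray
import numpy as np
import xarray as xr

values = xr.DataArray(np.arange(24.0).reshape(4, 6), dims=("y", "x"), name="value")
zones = xr.DataArray(np.tile([1, 1, 2, 2, 3, 3], (4, 1)), dims=("y", "x"))
# Mean of `values` within each zone label, analogous to county_mean above.
zonal_mean = flox.xarray.xarray_reduce(
    values,
    zones.rename("zone"),
    func="mean",
    expected_groups=(np.array([1, 2, 3]),),
)
print(zonal_mean)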
