diff --git a/ci/environment.yml b/ci/environment.yml
index 7b2072ff46..68f2e7177e 100644
--- a/ci/environment.yml
+++ b/ci/environment.yml
@@ -47,6 +47,7 @@ dependencies:
   - bokeh ==3.5.1
   - gilknocker ==0.4.1
   - openssl >1.1.0g
+  - rioxarray ==0.17.0
 
 ########################################################
 # PLEASE READ:
diff --git a/tests/geospatial/conftest.py b/tests/geospatial/conftest.py
index bc8fd9a0e9..99ae0718b5 100644
--- a/tests/geospatial/conftest.py
+++ b/tests/geospatial/conftest.py
@@ -1,6 +1,5 @@
 import os
 import uuid
-from typing import Any, Literal
 
 import coiled
 import pytest
@@ -20,26 +19,6 @@ def scale(request):
     return request.config.getoption("scale")
 
 
-def get_cluster_spec(scale: Literal["small", "large"]) -> dict[str, Any]:
-    everywhere = dict(
-        workspace="dask-engineering-gcp",
-        region="us-central1",
-        wait_for_workers=True,
-        spot_policy="on-demand",
-    )
-
-    if scale == "small":
-        return {
-            "n_workers": 10,
-            **everywhere,
-        }
-    elif scale == "large":
-        return {
-            "n_workers": 100,
-            **everywhere,
-        }
-
-
 @pytest.fixture(scope="module")
 def cluster_name(request, scale):
     module = os.path.basename(request.fspath).split(".")[0]
@@ -47,23 +26,20 @@ def cluster_name(request, scale):
     return f"geospatial-{module}-{scale}-{uuid.uuid4().hex[:8]}"
 
 
-@pytest.fixture(scope="module")
-def cluster(
-    cluster_name,
-    scale,
-    github_cluster_tags,
-):
-    kwargs = dict(
-        name=cluster_name,
-        tags=github_cluster_tags,
-        **get_cluster_spec(scale),
-    )
-    with coiled.Cluster(**kwargs) as cluster:
-        yield cluster
-
-
 @pytest.fixture()
-def client(cluster, benchmark_all):
-    with cluster.get_client() as client:
-        with benchmark_all(client):
-            yield client
+def client_factory(cluster_name, github_cluster_tags, benchmark_all):
+    import contextlib
+
+    @contextlib.contextmanager
+    def _(n_workers, **cluster_kwargs):
+        with coiled.Cluster(
+            name=cluster_name,
+            tags=github_cluster_tags,
+            n_workers=n_workers,
+            **cluster_kwargs,
+        ) as cluster:
+            with cluster.get_client() as client:
+                with benchmark_all(client):
+                    yield client
+
+    return _
diff --git a/tests/geospatial/test_rechunking.py b/tests/geospatial/test_rechunking.py
index 13712e0851..3782d4e0bb 100644
--- a/tests/geospatial/test_rechunking.py
+++ b/tests/geospatial/test_rechunking.py
@@ -1,32 +1,47 @@
-import pytest
 import xarray as xr
 from coiled.credentials.google import CoiledShippedCredentials
 
 
-@pytest.mark.client("era5_rechunking")
-def test_era5_rechunking(client, gcs_url, scale):
-    # Load dataset
-    ds = xr.open_zarr(
-        "gs://weatherbench2/datasets/era5/1959-2023_01_10-full_37-1h-0p25deg-chunk-1.zarr",
-    ).drop_encoding()
+def test_era5_rechunking(
+    gcs_url,
+    scale,
+    client_factory,
+    cluster_kwargs={
+        "workspace": "dask-engineering-gcp",
+        "region": "us-central1",
+        "wait_for_workers": True,
+    },
+    scale_kwargs={
+        "small": {"n_workers": 10},
+        "medium": {"n_workers": 100},
+        "large": {"n_workers": 100},
+    },
+):
+    with client_factory(
+        **scale_kwargs[scale], **cluster_kwargs
+    ) as client:  # noqa: F841
+        # Load dataset
+        ds = xr.open_zarr(
+            "gs://weatherbench2/datasets/era5/1959-2023_01_10-full_37-1h-0p25deg-chunk-1.zarr",
+        ).drop_encoding()
 
-    if scale == "small":
-        # 101.83 GiB (small)
-        time_range = slice("2020-01-01", "2023-01-01")
-        variables = ["sea_surface_temperature"]
-    elif scale == "medium":
-        # 2.12 TiB (medium)
-        time_range = slice(None)
-        variables = ["sea_surface_temperature"]
-    else:
-        # 4.24 TiB (large)
-        # This currently doesn't complete successfully.
-        time_range = slice(None)
-        variables = ["sea_surface_temperature", "snow_depth"]
-    subset = ds[variables].sel(time=time_range)
+        if scale == "small":
+            # 101.83 GiB (small)
+            time_range = slice("2020-01-01", "2023-01-01")
+            variables = ["sea_surface_temperature"]
+        elif scale == "medium":
+            # 2.12 TiB (medium)
+            time_range = slice(None)
+            variables = ["sea_surface_temperature"]
+        else:
+            # 4.24 TiB (large)
+            # This currently doesn't complete successfully.
+            time_range = slice(None)
+            variables = ["sea_surface_temperature", "snow_depth"]
+        subset = ds[variables].sel(time=time_range)
 
-    # Rechunk
-    result = subset.chunk({"time": -1, "longitude": "auto", "latitude": "auto"})
+        # Rechunk
+        result = subset.chunk({"time": -1, "longitude": "auto", "latitude": "auto"})
 
-    # Write result to cloud storage
-    result.to_zarr(gcs_url, storage_options={"token": CoiledShippedCredentials()})
+        # Write result to cloud storage
+        result.to_zarr(gcs_url, storage_options={"token": CoiledShippedCredentials()})
diff --git a/tests/geospatial/test_zonal_average.py b/tests/geospatial/test_zonal_average.py
new file mode 100644
index 0000000000..94ede84c3b
--- /dev/null
+++ b/tests/geospatial/test_zonal_average.py
@@ -0,0 +1,58 @@
+"""
+This example was adapted from https://github.com/dcherian/dask-demo/blob/main/nwm-aws.ipynb
+"""
+
+import flox.xarray
+import numpy as np
+import rioxarray
+import xarray as xr
+
+
+def test_nwm(
+    s3,
+    scale,
+    client_factory,
+    cluster_kwargs={
+        "workspace": "dask-engineering",
+        "region": "us-east-1",
+        "wait_for_workers": True,
+    },
+    scale_kwargs={
+        "small": {"n_workers": 10},
+        "large": {"n_workers": 200, "scheduler_memory": "32 GiB"},
+    },
+):
+    with client_factory(
+        **scale_kwargs[scale], **cluster_kwargs
+    ) as client:  # noqa: F841
+        ds = xr.open_zarr(
+            "s3://noaa-nwm-retrospective-2-1-zarr-pds/rtout.zarr", consolidated=True
+        )
+
+        if scale == "small":
+            # 6.03 TiB
+            time_range = slice("2020-01-01", "2020-12-31")
+        else:
+            # 252.30 TiB
+            time_range = slice("1979-02-01", "2020-12-31")
+        subset = ds.zwattablrt.sel(time=time_range)
+
+        counties = rioxarray.open_rasterio(
+            s3.open("s3://nwm-250m-us-counties/Counties_on_250m_grid.tif"),
+            chunks="auto",
+        ).squeeze()
+
+        # Remove any small floating point error in coordinate locations
+        _, counties_aligned = xr.align(subset, counties, join="override")
+        counties_aligned = counties_aligned.persist()
+
+        county_id = np.unique(counties_aligned.data).compute()
+        county_id = county_id[county_id != 0]
+        county_mean = flox.xarray.xarray_reduce(
+            subset,
+            counties_aligned.rename("county"),
+            func="mean",
+            expected_groups=(county_id,),
+        )
+
+        county_mean.compute()
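
Reviewer note, not part of the patch: a minimal sketch of how a new benchmark would consume the client_factory fixture introduced in conftest.py above. The dataset URL, workspace, and worker counts below are placeholders for illustration, not values taken from this PR:

    import xarray as xr


    def test_example(
        scale,
        client_factory,
        cluster_kwargs={"workspace": "dask-engineering-gcp", "region": "us-central1"},
        scale_kwargs={"small": {"n_workers": 10}, "large": {"n_workers": 100}},
    ):
        # client_factory spins up a Coiled cluster sized for the requested scale,
        # then yields a client wrapped in the benchmark_all instrumentation.
        with client_factory(**scale_kwargs[scale], **cluster_kwargs) as client:  # noqa: F841
            ds = xr.open_zarr("gs://example-bucket/example.zarr")  # hypothetical dataset
            ds.mean().compute()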