diff --git a/tests/geospatial/test_climatology.py b/tests/geospatial/test_climatology.py index 9148e9cda5..1318943768 100644 --- a/tests/geospatial/test_climatology.py +++ b/tests/geospatial/test_climatology.py @@ -39,7 +39,8 @@ def compute_rolling_mean(ds: xr.Dataset, window_weights: xr.DataArray) -> xr.Dat stacked = stacked.fillna(stacked.sel(dayofyear=365)) stacked = stacked.pad(pad_width={"dayofyear": half_window_size}, mode="wrap") stacked = stacked.rolling(dayofyear=window_size, center=True).construct("window") - return stacked.weighted(window_weights).mean(dim=("window", "year")) + rolling = stacked.weighted(window_weights).mean(dim=("window", "year")) + return rolling.isel(dayofyear=slice(half_window_size, -half_window_size)) def create_window_weights(window_size: int) -> xr.DataArray: @@ -78,7 +79,7 @@ def test_compute_climatology(client, gcs_url, scale): # Load dataset ds = xr.open_zarr( "gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr", - ) # .drop_encoding() + ) if scale == "small": # 101.83 GiB (small) @@ -96,22 +97,24 @@ def test_compute_climatology(client, gcs_url, scale): ds = ds[variables].sel(time=time_range) ds = ds.drop_vars([k for k, v in ds.items() if "time" not in v.dims]) - input_chunks_without_time = { - dim: chunks for dim, chunks in ds.chunks.items() if dim != "time" - } - pencil_chunks = {"time": -1, "longitude": 4, "latitude": 4} - ds = ds.chunk(pencil_chunks) + pencil_chunks = {"time": -1, "longitude": "auto", "latitude": "auto"} + working = ds.chunk(pencil_chunks) hours = xr.DataArray(range(0, 24, 6), dims=["hour"]) - daysofyear = xr.DataArray(range(0, 367), dims=["dayofyear"]) + daysofyear = xr.DataArray(range(1, 367), dims=["dayofyear"]) template = ( - ds.isel(time=0) + working.isel(time=0) .drop_vars("time") .expand_dims(hour=hours, dayofyear=daysofyear) .assign_coords(hour=hours, dayofyear=daysofyear) ) - ds = ds.map_blocks(compute_hourly_climatology, template=template) + working = working.map_blocks(compute_hourly_climatology, template=template) - pancake_chunks = {"hour": 1, "dayofyear": 1, **input_chunks_without_time} - ds = ds.chunk(pancake_chunks) - ds.to_zarr(gcs_url, storage_options={"token": CoiledShippedCredentials()}) + pancake_chunks = { + "hour": 1, + "dayofyear": 1, + "latitude": ds.chunks["latitude"], + "longitude": ds.chunks["longitude"], + } + result = working.chunk(pancake_chunks) + result.to_zarr(gcs_url, storage_options={"token": CoiledShippedCredentials()})