Skip to content

Commit

Permalink
Make things work
Browse files (browse the repository at this point in the history)
  • Loading branch information
hendrikmakait committed Sep 20, 2024
1 parent b01a554 commit 4cff7b5
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions tests/geospatial/test_climatology.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -39,7 +39,8 @@ def compute_rolling_mean(ds: xr.Dataset, window_weights: xr.DataArray) -> xr.Dat
stacked = stacked.fillna(stacked.sel(dayofyear=365))
stacked = stacked.pad(pad_width={"dayofyear": half_window_size}, mode="wrap")
stacked = stacked.rolling(dayofyear=window_size, center=True).construct("window")
return stacked.weighted(window_weights).mean(dim=("window", "year"))
rolling = stacked.weighted(window_weights).mean(dim=("window", "year"))
return rolling.isel(dayofyear=slice(half_window_size, -half_window_size))


def create_window_weights(window_size: int) -> xr.DataArray:
Expand Down Expand Up @@ -78,7 +79,7 @@ def test_compute_climatology(client, gcs_url, scale):
# Load dataset
ds = xr.open_zarr(
"gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr",
) # .drop_encoding()
)

if scale == "small":
# 101.83 GiB (small)
Expand All @@ -96,22 +97,24 @@ def test_compute_climatology(client, gcs_url, scale):
ds = ds[variables].sel(time=time_range)

ds = ds.drop_vars([k for k, v in ds.items() if "time" not in v.dims])
input_chunks_without_time = {
dim: chunks for dim, chunks in ds.chunks.items() if dim != "time"
}
pencil_chunks = {"time": -1, "longitude": 4, "latitude": 4}
ds = ds.chunk(pencil_chunks)
pencil_chunks = {"time": -1, "longitude": "auto", "latitude": "auto"}

working = ds.chunk(pencil_chunks)
hours = xr.DataArray(range(0, 24, 6), dims=["hour"])
daysofyear = xr.DataArray(range(0, 367), dims=["dayofyear"])
daysofyear = xr.DataArray(range(1, 367), dims=["dayofyear"])
template = (
ds.isel(time=0)
working.isel(time=0)
.drop_vars("time")
.expand_dims(hour=hours, dayofyear=daysofyear)
.assign_coords(hour=hours, dayofyear=daysofyear)
)
ds = ds.map_blocks(compute_hourly_climatology, template=template)
working = working.map_blocks(compute_hourly_climatology, template=template)

pancake_chunks = {"hour": 1, "dayofyear": 1, **input_chunks_without_time}
ds = ds.chunk(pancake_chunks)
ds.to_zarr(gcs_url, storage_options={"token": CoiledShippedCredentials()})
pancake_chunks = {
"hour": 1,
"dayofyear": 1,
"latitude": ds.chunks["latitude"],
"longitude": ds.chunks["longitude"],
}
result = working.chunk(pancake_chunks)
result.to_zarr(gcs_url, storage_options={"token": CoiledShippedCredentials()})

0 comments on commit 4cff7b5

Please sign in to comment.