Significantly faster cohorts detection. (#272)
* Significantly faster cohorts detection.

* cleanup

* add benchmark

* update types

* fix

* single chunk optimization

* more optimization

* more optimization

* add comment

* fix benchmark
dcherian authored Oct 11, 2023
1 parent 9f82e19 commit a897034
Showing 2 changed files with 28 additions and 7 deletions.
14 changes: 13 additions & 1 deletion asv_bench/benchmarks/cohorts.py
@@ -1,8 +1,10 @@
import dask
import numpy as np
import pandas as pd
+import xarray as xr

import flox
+from flox.xarray import xarray_reduce


class Cohorts:
@@ -12,7 +14,7 @@ def setup(self, *args, **kwargs):
        raise NotImplementedError

    def time_find_group_cohorts(self):
-        flox.core.find_group_cohorts(self.by, self.array.chunks)
+        flox.core.find_group_cohorts(self.by, [self.array.chunks[ax] for ax in self.axis])
        # The cache clear fails dependably in CI
        # Not sure why
        try:
@@ -125,3 +127,13 @@ class PerfectMonthlyRechunked(PerfectMonthly):
    def setup(self, *args, **kwargs):
        super().setup()
        super().rechunk()
+
+
+def time_cohorts_era5_single():
+    TIME = 900  # 92044 in Google ARCO ERA5
+    da = xr.DataArray(
+        dask.array.ones((TIME, 721, 1440), chunks=(1, -1, -1)),
+        dims=("time", "lat", "lon"),
+        coords=dict(time=pd.date_range("1959-01-01", freq="6H", periods=TIME)),
+    )
+    xarray_reduce(da, da.time.dt.day, method="cohorts", func="any")
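
For context, a minimal runnable sketch (toy sizes and values assumed for illustration; not part of this commit) of the call that `time_find_group_cohorts` exercises. `find_group_cohorts` records which chunks each label occupies and groups labels with identical chunk footprints into cohorts:

```python
import numpy as np

import flox.core

# Hypothetical toy input: 12 labels cycling through [0, 1, 2] along a
# single axis that is split into four chunks of size 3 each.
labels = np.tile(np.arange(3), 4)
chunks = [(3, 3, 3, 3)]  # one chunk tuple per grouped axis, as in the benchmark
cohorts = flox.core.find_group_cohorts(labels, chunks)
# Here every label occurs in every chunk, so all three labels should end
# up in a single cohort keyed by the chunk indices they occupy.
print(cohorts)
```

The new `time_cohorts_era5_single` benchmark stresses the opposite extreme: with `chunks=(1, -1, -1)` every time step lives in its own chunk, the case targeted by the "single chunk optimization" commit.
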
21 changes: 15 additions & 6 deletions flox/core.py
@@ -208,15 +208,20 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
    # 1. First subset the array appropriately
    axis = range(-labels.ndim, 0)
    # Easier to create a dask array and use the .blocks property
-    array = dask.array.ones(tuple(sum(c) for c in chunks), chunks=chunks)
+    array = dask.array.empty(tuple(sum(c) for c in chunks), chunks=chunks)
    labels = np.broadcast_to(labels, array.shape[-labels.ndim :])

    # Iterate over each block and create a new block of same shape with "chunk number"
    shape = tuple(array.blocks.shape[ax] for ax in axis)
-    blocks = np.empty(math.prod(shape), dtype=object)
-    for idx, block in enumerate(array.blocks.ravel()):
-        blocks[idx] = np.full(tuple(block.shape[ax] for ax in axis), idx)
-    which_chunk = np.block(blocks.reshape(shape).tolist()).reshape(-1)
+    # Use a numpy object array to enable assignment in the loop
+    # TODO: is it possible to just use a nested list?
+    # That is what we need for `np.block`
+    blocks = np.empty(shape, dtype=object)
+    array_chunks = tuple(np.array(c) for c in array.chunks)
+    for idx, blockindex in enumerate(np.ndindex(array.numblocks)):
+        chunkshape = tuple(c[i] for c, i in zip(array_chunks, blockindex))
+        blocks[blockindex] = np.full(chunkshape, idx)
+    which_chunk = np.block(blocks.tolist()).reshape(-1)

raveled = labels.reshape(-1)
# these are chunks where a label is present
@@ -229,7 +234,11 @@ def invert(x) -> tuple[np.ndarray, ...]:

    chunks_cohorts = tlz.groupby(invert, label_chunks.keys())

-    if merge:
+    # If our dataset has chunksize one along the axis,
+    # then no merging is possible.
+    single_chunks = all((ac == 1).all() for ac in array_chunks)
+
+    if merge and not single_chunks:
        # First sort by number of chunks occupied by cohort
        sorted_chunks_cohorts = dict(
            sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0]), reverse=True)
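To see where the speedup comes from: the old loop iterated over `array.blocks.ravel()`, materializing a small dask array per block just to read its shape, while the rewrite only consults the chunk sizes through `np.ndindex`. A standalone sketch of the new chunk-numbering pass (hypothetical 2x2 block grid; variable names mirror the diff):

```python
import dask.array
import numpy as np

chunks = ((2, 2), (3, 3))  # assumed toy chunking: a 2x2 grid of blocks in a 4x6 array
array = dask.array.empty(tuple(sum(c) for c in chunks), chunks=chunks)

# Fill one small numpy array per block with that block's flat index,
# then stitch them together with np.block. No dask graph is computed;
# only the chunk shapes are used.
blocks = np.empty(array.numblocks, dtype=object)
array_chunks = tuple(np.array(c) for c in array.chunks)
for idx, blockindex in enumerate(np.ndindex(array.numblocks)):
    chunkshape = tuple(c[i] for c, i in zip(array_chunks, blockindex))
    blocks[blockindex] = np.full(chunkshape, idx)
which_chunk = np.block(blocks.tolist()).reshape(-1)

print(which_chunk.reshape(4, 6))
# [[0 0 0 1 1 1]
#  [0 0 0 1 1 1]
#  [2 2 2 3 3 3]
#  [2 2 2 3 3 3]]
```

The `single_chunks` guard added after the cohort grouping then skips the merge pass outright when every chunk has size one along the grouped axes, since, as the new comment notes, no merging is possible in that case.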
