Skip to content

Commit

Permalink
adjust rechunking tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikmakait committed Aug 14, 2024
1 parent 74a5abb commit 4510f79
Showing 1 changed file with 30 additions and 11 deletions.
41 changes: 30 additions & 11 deletions tests/benchmarks/test_rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,34 @@ def configure_rechunking(request, memory_multiplier):


@pytest.fixture(params=["8 MiB", "128 MiB"])
def configure_chunksize(request, memory_multiplier):
if memory_multiplier > 0.4 and parse_bytes(request.param) < parse_bytes("64 MiB"):
def chunk_size(request):
return request.param


@pytest.fixture
def configure_chunksize(chunk_size, memory_multiplier):
if memory_multiplier > 0.4 and parse_bytes(chunk_size) < parse_bytes("64 MiB"):
pytest.skip("too slow")

with dask.config.set({"array.chunk-size": request.param}):
with dask.config.set({"array.chunk-size": chunk_size}):
yield


@pytest.fixture
def input_chunk_size(chunk_size):
return chunk_size


@pytest.fixture
def output_chunk_size(chunk_size):
return chunk_size


def test_tiles_to_rows(
# Order matters: don't initialize client when skipping test
memory_multiplier,
configure_chunksize,
input_chunk_size,
output_chunk_size,
configure_rechunking,
small_client,
):
Expand All @@ -58,15 +74,17 @@ def test_tiles_to_rows(
memory = cluster_memory(small_client)
shape = scaled_array_shape(memory * memory_multiplier, ("x", "x"))

a = da.random.random(shape, chunks="auto")
a = a.rechunk((-1, "auto")).sum()
a = da.random.random(shape, chunks=input_chunk_size)
a = a.rechunk((-1, output_chunk_size)).sum()
wait(a, small_client, timeout=600)


def test_swap_axes(
# Order matters: don't initialize client when skipping test
memory_multiplier,
configure_chunksize,
input_chunk_size,
output_chunk_size,
configure_rechunking,
small_client,
):
Expand All @@ -77,15 +95,16 @@ def test_swap_axes(
memory = cluster_memory(small_client)
shape = scaled_array_shape(memory * memory_multiplier, ("x", "x"))

a = da.random.random(shape, chunks=(-1, "auto"))
a = a.rechunk(("auto", -1)).sum()
a = da.random.random(shape, chunks=(-1, input_chunk_size))
a = a.rechunk((output_chunk_size, -1)).sum()
wait(a, small_client, timeout=600)


def test_adjacent_groups(
# Order matters: don't initialize client when skipping test
memory_multiplier,
configure_chunksize,
input_chunk_size,
output_chunk_size,
configure_rechunking,
small_client,
):
Expand All @@ -95,8 +114,8 @@ def test_adjacent_groups(
memory = cluster_memory(small_client)
shape = scaled_array_shape(memory * memory_multiplier, ("x", 10, 10_000))

a = da.random.random(shape, chunks=("auto", 2, 5_000))
a = a.rechunk(("auto", 5, 10_000)).sum()
a = da.random.random(shape, chunks=(input_chunk_size, 2, 5_000))
a = a.rechunk((output_chunk_size, 5, 10_000)).sum()
wait(a, small_client, timeout=600)


Expand Down

0 comments on commit 4510f79

Please sign in to comment.