From 8a3b44ce91e363053eb42aef31fecd81d1d167bc Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Fri, 10 Mar 2023 13:11:03 +0100
Subject: [PATCH] Share thread pool among P2P shuffle runs (#7621)

---
 distributed/shuffle/_worker_extension.py  |  33 ++--
 distributed/shuffle/tests/test_rechunk.py | 116 ++++++-----
 distributed/shuffle/tests/test_shuffle.py | 230 +++++++++++-----------
 3 files changed, 201 insertions(+), 178 deletions(-)

diff --git a/distributed/shuffle/_worker_extension.py b/distributed/shuffle/_worker_extension.py
index 9a219fd17a..751d69e1f5 100644
--- a/distributed/shuffle/_worker_extension.py
+++ b/distributed/shuffle/_worker_extension.py
@@ -63,7 +63,7 @@ def __init__(
         output_workers: set[str],
         local_address: str,
         directory: str,
-        nthreads: int,
+        executor: ThreadPoolExecutor,
         rpc: Callable[[str], PooledRPCCall],
         scheduler: PooledRPCCall,
         memory_limiter_disk: ResourceLimiter,
@@ -73,7 +73,7 @@ def __init__(
         self.run_id = run_id
         self.output_workers = output_workers
         self.local_address = local_address
-        self.executor = ThreadPoolExecutor(nthreads)
+        self.executor = executor
         self.rpc = rpc
         self.scheduler = scheduler
         self.closed = False
@@ -191,10 +191,6 @@ async def close(self) -> None:
         self.closed = True
         await self._comm_buffer.close()
         await self._disk_buffer.close()
-        try:
-            self.executor.shutdown(cancel_futures=True)
-        except Exception:  # pragma: no cover
-            self.executor.shutdown()
         self._closed_event.set()
 
     def fail(self, exception: Exception) -> None:
@@ -255,8 +251,8 @@ class ArrayRechunkRun(ShuffleRun[ArrayRechunkShardID, NIndex, "np.ndarray"]):
         The local address this Shuffle can be contacted by using `rpc`.
     directory:
         The scratch directory to buffer data in.
-    nthreads:
-        How many background threads to use for compute.
+    executor:
+        Thread pool to use for offloading compute.
     loop:
         The event loop.
     rpc:
@@ -280,7 +276,7 @@ def __init__(
         run_id: int,
         local_address: str,
         directory: str,
-        nthreads: int,
+        executor: ThreadPoolExecutor,
         rpc: Callable[[str], PooledRPCCall],
         scheduler: PooledRPCCall,
         memory_limiter_disk: ResourceLimiter,
@@ -294,7 +290,7 @@ def __init__(
             output_workers=output_workers,
             local_address=local_address,
             directory=directory,
-            nthreads=nthreads,
+            executor=executor,
             rpc=rpc,
             scheduler=scheduler,
             memory_limiter_comms=memory_limiter_comms,
@@ -410,8 +406,8 @@ class DataFrameShuffleRun(ShuffleRun[int, int, "pd.DataFrame"]):
         The local address this Shuffle can be contacted by using `rpc`.
     directory:
         The scratch directory to buffer data in.
-    nthreads:
-        How many background threads to use for compute.
+    executor:
+        Thread pool to use for offloading compute.
     loop:
         The event loop.
     rpc:
@@ -435,7 +431,7 @@ def __init__(
         run_id: int,
         local_address: str,
         directory: str,
-        nthreads: int,
+        executor: ThreadPoolExecutor,
         rpc: Callable[[str], PooledRPCCall],
         scheduler: PooledRPCCall,
         memory_limiter_disk: ResourceLimiter,
@@ -449,7 +445,7 @@ def __init__(
             output_workers=output_workers,
             local_address=local_address,
             directory=directory,
-            nthreads=nthreads,
+            executor=executor,
             rpc=rpc,
             scheduler=scheduler,
             memory_limiter_comms=memory_limiter_comms,
@@ -569,6 +565,7 @@ def __init__(self, worker: Worker) -> None:
         self.memory_limiter_comms = ResourceLimiter(parse_bytes("100 MiB"))
         self.memory_limiter_disk = ResourceLimiter(parse_bytes("1 GiB"))
         self.closed = False
+        self._executor = ThreadPoolExecutor(self.worker.state.nthreads)
 
     # Handlers
     ##########
@@ -815,7 +812,7 @@ async def _(
                     self.worker.local_directory,
                     f"shuffle-{shuffle_id}-{result['run_id']}",
                 ),
-                nthreads=self.worker.state.nthreads,
+                executor=self._executor,
                 local_address=self.worker.address,
                 rpc=self.worker.rpc,
                 scheduler=self.worker.scheduler,
@@ -834,7 +831,7 @@ async def _(
                     self.worker.local_directory,
                     f"shuffle-{shuffle_id}-{result['run_id']}",
                 ),
-                nthreads=self.worker.state.nthreads,
+                executor=self._executor,
                 local_address=self.worker.address,
                 rpc=self.worker.rpc,
                 scheduler=self.worker.scheduler,
@@ -855,6 +852,10 @@ async def close(self) -> None:
             _, shuffle = self.shuffles.popitem()
             await shuffle.close()
             self._runs.remove(shuffle)
+        try:
+            self._executor.shutdown(cancel_futures=True)
+        except Exception:  # pragma: no cover
+            self._executor.shutdown()
 
     #############################
     # Methods for worker thread #
diff --git a/distributed/shuffle/tests/test_rechunk.py b/distributed/shuffle/tests/test_rechunk.py
index db7c85d304..1bde629d84 100644
--- a/distributed/shuffle/tests/test_rechunk.py
+++ b/distributed/shuffle/tests/test_rechunk.py
@@ -9,6 +9,8 @@
 np = pytest.importorskip("numpy")
 da = pytest.importorskip("dask.array")
 
+from concurrent.futures import ThreadPoolExecutor
+
 import dask
 from dask.array.core import concatenate3
 from dask.array.rechunk import normalize_chunks, rechunk
@@ -26,6 +28,16 @@ class ArrayRechunkTestPool(AbstractShuffleTestPool):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self._executor = ThreadPoolExecutor(2)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        try:
+            self._executor.shutdown(cancel_futures=True)
+        except Exception:  # pragma: no cover
+            self._executor.shutdown()
 
     def new_shuffle(
         self,
@@ -47,7 +59,7 @@ def new_shuffle(
             id=ShuffleId(name),
             run_id=next(AbstractShuffleTestPool._shuffle_run_id_iterator),
             local_address=name,
-            nthreads=2,
+            executor=self._executor,
             rpc=self,
             scheduler=self,
             memory_limiter_disk=ResourceLimiter(10000000),
@@ -85,58 +97,60 @@ async def test_lowlevel_rechunk(
 
     assert len(set(worker_for_mapping.values())) == min(n_workers, len(new_indices))
 
-    local_shuffle_pool = ArrayRechunkTestPool()
-    shuffles = []
-    for i in range(n_workers):
-        shuffles.append(
-            local_shuffle_pool.new_shuffle(
-                name=workers[i],
-                worker_for_mapping=worker_for_mapping,
-                old=old,
-                new=new,
-                directory=tmp_path,
-                loop=loop_in_thread,
+    with ArrayRechunkTestPool() as local_shuffle_pool:
+        shuffles = []
+        for i in range(n_workers):
+            shuffles.append(
+                local_shuffle_pool.new_shuffle(
+                    name=workers[i],
+                    worker_for_mapping=worker_for_mapping,
+                    old=old,
+                    new=new,
+                    directory=tmp_path,
+                    loop=loop_in_thread,
+                )
             )
+        random.seed(42)
+        if barrier_first_worker:
+            barrier_worker = shuffles[0]
+        else:
+            barrier_worker = random.sample(shuffles, k=1)[0]
+
+        try:
+            for i, (idx, arr) in enumerate(old_chunks.items()):
+                s = shuffles[i % len(shuffles)]
+                await s.add_partition(arr, idx)
+
+            await barrier_worker.barrier()
+
+            total_bytes_sent = 0
+            total_bytes_recvd = 0
+            total_bytes_recvd_shuffle = 0
+            for s in shuffles:
+                metrics = s.heartbeat()
+                assert metrics["comm"]["total"] == metrics["comm"]["written"]
+                total_bytes_sent += metrics["comm"]["written"]
+                total_bytes_recvd += metrics["disk"]["total"]
+                total_bytes_recvd_shuffle += s.total_recvd
+
+            assert total_bytes_recvd_shuffle == total_bytes_sent
+
+            all_chunks = np.empty(tuple(len(dim) for dim in new), dtype="O")
+            for ix, worker in worker_for_mapping.items():
+                s = local_shuffle_pool.shuffles[worker]
+                all_chunks[ix] = await s.get_output_partition(ix)
+
+        finally:
+            await asyncio.gather(*[s.close() for s in shuffles])
+
+        old_cs = np.empty(tuple(len(dim) for dim in old), dtype="O")
+        for ix, arr in old_chunks.items():
+            old_cs[ix] = arr
+        np.testing.assert_array_equal(
+            concatenate3(old_cs.tolist()),
+            concatenate3(all_chunks.tolist()),
+            strict=True,
         )
-    random.seed(42)
-    if barrier_first_worker:
-        barrier_worker = shuffles[0]
-    else:
-        barrier_worker = random.sample(shuffles, k=1)[0]
-
-    try:
-        for i, (idx, arr) in enumerate(old_chunks.items()):
-            s = shuffles[i % len(shuffles)]
-            await s.add_partition(arr, idx)
-
-        await barrier_worker.barrier()
-
-        total_bytes_sent = 0
-        total_bytes_recvd = 0
-        total_bytes_recvd_shuffle = 0
-        for s in shuffles:
-            metrics = s.heartbeat()
-            assert metrics["comm"]["total"] == metrics["comm"]["written"]
-            total_bytes_sent += metrics["comm"]["written"]
-            total_bytes_recvd += metrics["disk"]["total"]
-            total_bytes_recvd_shuffle += s.total_recvd
-
-    assert total_bytes_recvd_shuffle == total_bytes_sent
-
-        all_chunks = np.empty(tuple(len(dim) for dim in new), dtype="O")
-        for ix, worker in worker_for_mapping.items():
-            s = local_shuffle_pool.shuffles[worker]
-            all_chunks[ix] = await s.get_output_partition(ix)
-
-    finally:
-        await asyncio.gather(*[s.close() for s in shuffles])
-
-    old_cs = np.empty(tuple(len(dim) for dim in old), dtype="O")
-    for ix, arr in old_chunks.items():
-        old_cs[ix] = arr
-    np.testing.assert_array_equal(
-        concatenate3(old_cs.tolist()), concatenate3(all_chunks.tolist()), strict=True
-    )
 
 
 def test_raise_on_fuse_optimization():
diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py
index 9f04fa7b3a..02604f0638 100644
--- a/distributed/shuffle/tests/test_shuffle.py
+++ b/distributed/shuffle/tests/test_shuffle.py
@@ -7,6 +7,7 @@
 import random
 import shutil
 from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor
 from itertools import count
 from typing import Any, Mapping
 from unittest import mock
@@ -1107,6 +1108,16 @@ class DataFrameShuffleTestPool(AbstractShuffleTestPool):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self._executor = ThreadPoolExecutor(2)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        try:
+            self._executor.shutdown(cancel_futures=True)
+        except Exception:  # pragma: no cover
+            self._executor.shutdown()
 
     def new_shuffle(
         self,
@@ -1127,7 +1138,7 @@ def new_shuffle(
             id=ShuffleId(name),
             run_id=next(AbstractShuffleTestPool._shuffle_run_id_iterator),
             local_address=name,
-            nthreads=2,
+            executor=self._executor,
             rpc=self,
             scheduler=self,
             memory_limiter_disk=ResourceLimiter(10000000),
@@ -1172,54 +1183,54 @@ async def test_basic_lowlevel_shuffle(
 
     assert len(set(worker_for_mapping.values())) == min(n_workers, npartitions)
 
     schema = pa.Schema.from_pandas(dfs[0])
 
-    local_shuffle_pool = DataFrameShuffleTestPool()
-    shuffles = []
-    for ix in range(n_workers):
-        shuffles.append(
-            local_shuffle_pool.new_shuffle(
-                name=workers[ix],
-                worker_for_mapping=worker_for_mapping,
-                schema=schema,
-                directory=tmp_path,
-                loop=loop_in_thread,
+    with DataFrameShuffleTestPool() as local_shuffle_pool:
+        shuffles = []
+        for ix in range(n_workers):
+            shuffles.append(
+                local_shuffle_pool.new_shuffle(
+                    name=workers[ix],
+                    worker_for_mapping=worker_for_mapping,
+                    schema=schema,
+                    directory=tmp_path,
+                    loop=loop_in_thread,
+                )
             )
-        )
-    random.seed(42)
-    if barrier_first_worker:
-        barrier_worker = shuffles[0]
-    else:
-        barrier_worker = random.sample(shuffles, k=1)[0]
+        random.seed(42)
+        if barrier_first_worker:
+            barrier_worker = shuffles[0]
+        else:
+            barrier_worker = random.sample(shuffles, k=1)[0]
 
-    try:
-        for ix, df in enumerate(dfs):
-            s = shuffles[ix % len(shuffles)]
-            await s.add_partition(df, ix)
+        try:
+            for ix, df in enumerate(dfs):
+                s = shuffles[ix % len(shuffles)]
+                await s.add_partition(df, ix)
 
-        await barrier_worker.barrier()
+            await barrier_worker.barrier()
 
-        total_bytes_sent = 0
-        total_bytes_recvd = 0
-        total_bytes_recvd_shuffle = 0
-        for s in shuffles:
-            metrics = s.heartbeat()
-            assert metrics["comm"]["total"] == metrics["comm"]["written"]
-            total_bytes_sent += metrics["comm"]["written"]
-            total_bytes_recvd += metrics["disk"]["total"]
-            total_bytes_recvd_shuffle += s.total_recvd
+            total_bytes_sent = 0
+            total_bytes_recvd = 0
+            total_bytes_recvd_shuffle = 0
+            for s in shuffles:
+                metrics = s.heartbeat()
+                assert metrics["comm"]["total"] == metrics["comm"]["written"]
+                total_bytes_sent += metrics["comm"]["written"]
+                total_bytes_recvd += metrics["disk"]["total"]
+                total_bytes_recvd_shuffle += s.total_recvd
 
-        assert total_bytes_recvd_shuffle == total_bytes_sent
+            assert total_bytes_recvd_shuffle == total_bytes_sent
 
-        all_parts = []
-        for part, worker in worker_for_mapping.items():
-            s = local_shuffle_pool.shuffles[worker]
-            all_parts.append(s.get_output_partition(part))
+            all_parts = []
+            for part, worker in worker_for_mapping.items():
+                s = local_shuffle_pool.shuffles[worker]
+                all_parts.append(s.get_output_partition(part))
 
-        all_parts = await asyncio.gather(*all_parts)
+            all_parts = await asyncio.gather(*all_parts)
 
-        df_after = pd.concat(all_parts)
-    finally:
-        await asyncio.gather(*[s.close() for s in shuffles])
-    assert len(df_after) == len(pd.concat(dfs))
+            df_after = pd.concat(all_parts)
+        finally:
+            await asyncio.gather(*[s.close() for s in shuffles])
+        assert len(df_after) == len(pd.concat(dfs))
 
 
 @gen_test()
@@ -1247,34 +1258,33 @@ async def test_error_offload(tmp_path, loop_in_thread):
         partitions_for_worker[w].append(part)
 
     schema = pa.Schema.from_pandas(dfs[0])
 
-    local_shuffle_pool = DataFrameShuffleTestPool()
-
     class ErrorOffload(DataFrameShuffleRun):
         async def offload(self, func, *args):
             raise RuntimeError("Error during deserialization")
 
-    sA = local_shuffle_pool.new_shuffle(
-        name="A",
-        worker_for_mapping=worker_for_mapping,
-        schema=schema,
-        directory=tmp_path,
-        loop=loop_in_thread,
-        Shuffle=ErrorOffload,
-    )
-    sB = local_shuffle_pool.new_shuffle(
-        name="B",
-        worker_for_mapping=worker_for_mapping,
-        schema=schema,
-        directory=tmp_path,
-        loop=loop_in_thread,
-    )
-    try:
-        await sB.add_partition(dfs[0], 0)
-        with pytest.raises(RuntimeError, match="Error during deserialization"):
-            await sB.add_partition(dfs[1], 1)
-            await sB.barrier()
-    finally:
-        await asyncio.gather(*[s.close() for s in [sA, sB]])
+    with DataFrameShuffleTestPool() as local_shuffle_pool:
+        sA = local_shuffle_pool.new_shuffle(
+            name="A",
+            worker_for_mapping=worker_for_mapping,
+            schema=schema,
+            directory=tmp_path,
+            loop=loop_in_thread,
+            Shuffle=ErrorOffload,
+        )
+        sB = local_shuffle_pool.new_shuffle(
+            name="B",
+            worker_for_mapping=worker_for_mapping,
+            schema=schema,
+            directory=tmp_path,
+            loop=loop_in_thread,
+        )
+        try:
+            await sB.add_partition(dfs[0], 0)
+            with pytest.raises(RuntimeError, match="Error during deserialization"):
+                await sB.add_partition(dfs[1], 1)
+                await sB.barrier()
+        finally:
+            await asyncio.gather(*[s.close() for s in [sA, sB]])
 
 
 @gen_test()
@@ -1302,33 +1312,32 @@ async def test_error_send(tmp_path, loop_in_thread):
         partitions_for_worker[w].append(part)
 
     schema = pa.Schema.from_pandas(dfs[0])
 
-    local_shuffle_pool = DataFrameShuffleTestPool()
-
     class ErrorSend(DataFrameShuffleRun):
         async def send(self, *args: Any, **kwargs: Any) -> None:
             raise RuntimeError("Error during send")
 
-    sA = local_shuffle_pool.new_shuffle(
-        name="A",
-        worker_for_mapping=worker_for_mapping,
-        schema=schema,
-        directory=tmp_path,
-        loop=loop_in_thread,
-        Shuffle=ErrorSend,
-    )
-    sB = local_shuffle_pool.new_shuffle(
-        name="B",
-        worker_for_mapping=worker_for_mapping,
-        schema=schema,
-        directory=tmp_path,
-        loop=loop_in_thread,
-    )
-    try:
-        await sA.add_partition(dfs[0], 0)
-        with pytest.raises(RuntimeError, match="Error during send"):
-            await sA.barrier()
-    finally:
-        await asyncio.gather(*[s.close() for s in [sA, sB]])
+    with DataFrameShuffleTestPool() as local_shuffle_pool:
+        sA = local_shuffle_pool.new_shuffle(
+            name="A",
+            worker_for_mapping=worker_for_mapping,
+            schema=schema,
+            directory=tmp_path,
+            loop=loop_in_thread,
+            Shuffle=ErrorSend,
+        )
+        sB = local_shuffle_pool.new_shuffle(
+            name="B",
+            worker_for_mapping=worker_for_mapping,
+            schema=schema,
+            directory=tmp_path,
+            loop=loop_in_thread,
+        )
+        try:
+            await sA.add_partition(dfs[0], 0)
+            with pytest.raises(RuntimeError, match="Error during send"):
+                await sA.barrier()
+        finally:
+            await asyncio.gather(*[s.close() for s in [sA, sB]])
 
 
 @gen_test()
@@ -1356,33 +1365,32 @@ async def test_error_receive(tmp_path, loop_in_thread):
         partitions_for_worker[w].append(part)
 
     schema = pa.Schema.from_pandas(dfs[0])
 
-    local_shuffle_pool = DataFrameShuffleTestPool()
-
     class ErrorReceive(DataFrameShuffleRun):
         async def receive(self, data: list[tuple[int, bytes]]) -> None:
             raise RuntimeError("Error during receive")
 
-    sA = local_shuffle_pool.new_shuffle(
-        name="A",
-        worker_for_mapping=worker_for_mapping,
-        schema=schema,
-        directory=tmp_path,
-        loop=loop_in_thread,
-        Shuffle=ErrorReceive,
-    )
-    sB = local_shuffle_pool.new_shuffle(
-        name="B",
-        worker_for_mapping=worker_for_mapping,
-        schema=schema,
-        directory=tmp_path,
-        loop=loop_in_thread,
-    )
-    try:
-        await sB.add_partition(dfs[0], 0)
-        with pytest.raises(RuntimeError, match="Error during receive"):
-            await sB.barrier()
-    finally:
-        await asyncio.gather(*[s.close() for s in [sA, sB]])
+    with DataFrameShuffleTestPool() as local_shuffle_pool:
+        sA = local_shuffle_pool.new_shuffle(
+            name="A",
+            worker_for_mapping=worker_for_mapping,
+            schema=schema,
+            directory=tmp_path,
+            loop=loop_in_thread,
+            Shuffle=ErrorReceive,
+        )
+        sB = local_shuffle_pool.new_shuffle(
+            name="B",
+            worker_for_mapping=worker_for_mapping,
+            schema=schema,
+            directory=tmp_path,
+            loop=loop_in_thread,
+        )
+        try:
+            await sB.add_partition(dfs[0], 0)
+            with pytest.raises(RuntimeError, match="Error during receive"):
+                await sB.barrier()
+        finally:
+            await asyncio.gather(*[s.close() for s in [sA, sB]])
 
 
 from distributed.worker import DEFAULT_EXTENSIONS
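
For illustration, here is a minimal, self-contained sketch (not part of the patch) of the lifecycle this change establishes: the extension owns a single ThreadPoolExecutor, hands it to every shuffle run, and shuts it down exactly once when the extension closes, instead of each run creating and tearing down its own pool. The Extension and Run names are placeholders, not the real ShuffleWorkerExtension and ShuffleRun classes.

from concurrent.futures import ThreadPoolExecutor


class Run:
    """One shuffle run. It borrows the executor; it never owns it."""

    def __init__(self, executor: ThreadPoolExecutor) -> None:
        self.executor = executor

    def offload(self, fn, *args):
        # CPU-heavy work (e.g. (de)serialization) goes to the shared pool.
        return self.executor.submit(fn, *args).result()

    def close(self) -> None:
        # Before this patch, each run shut down its own pool here.
        pass


class Extension:
    """Owns a single thread pool shared by all runs on one worker."""

    def __init__(self, nthreads: int) -> None:
        self._executor = ThreadPoolExecutor(nthreads)
        self._runs: list[Run] = []

    def new_run(self) -> Run:
        run = Run(self._executor)
        self._runs.append(run)
        return run

    def close(self) -> None:
        for run in self._runs:
            run.close()
        # Shut the shared pool down exactly once; the fallback presumably
        # guards interpreters whose shutdown() lacks the cancel_futures
        # keyword (added in Python 3.9), as in the patch above.
        try:
            self._executor.shutdown(cancel_futures=True)
        except Exception:
            self._executor.shutdown()


ext = Extension(nthreads=2)
r1, r2 = ext.new_run(), ext.new_run()
assert r1.executor is r2.executor  # both runs share the same pool
print(r1.offload(sum, [1, 2, 3]))  # 6
ext.close()

Sharing one pool avoids spawning and tearing down a fresh set of threads for every shuffle run on a worker; the test pools above mirror the same pattern by becoming context managers that shut their executor down on exit.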