Share thread pool among P2P shuffle runs (#7621)
hendrikmakait authored Mar 10, 2023
1 parent ab38b4c commit 8a3b44c
Showing 3 changed files with 201 additions and 178 deletions.
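
The change itself is small: instead of each ShuffleRun constructing (and later shutting down) its own ThreadPoolExecutor, the ShuffleWorkerExtension now creates a single executor sized to the worker's thread count and hands it to every run it starts. The sketch below only illustrates that ownership pattern; `Run`, `Extension`, `offload`, and `new_run` are simplified stand-in names, not the distributed API.

```python
from concurrent.futures import ThreadPoolExecutor


class Run:
    """Stand-in for ShuffleRun: borrows a shared executor, never owns it."""

    def __init__(self, run_id: int, executor: ThreadPoolExecutor) -> None:
        self.run_id = run_id
        self.executor = executor  # shared; closing this run must not shut it down

    def offload(self, fn, *args):
        # CPU-bound work is submitted to the shared pool instead of a per-run pool.
        return self.executor.submit(fn, *args)


class Extension:
    """Stand-in for ShuffleWorkerExtension: owns one pool for all runs."""

    def __init__(self, nthreads: int) -> None:
        self._executor = ThreadPoolExecutor(nthreads)

    def new_run(self, run_id: int) -> Run:
        return Run(run_id, self._executor)

    def close(self) -> None:
        # Teardown moves here: the extension, not each run, shuts the pool down.
        try:
            self._executor.shutdown(cancel_futures=True)
        except Exception:
            self._executor.shutdown()


ext = Extension(nthreads=4)
a, b = ext.new_run(1), ext.new_run(2)
assert a.executor is b.executor            # both runs share one pool
print(a.offload(sum, [1, 2, 3]).result())  # 6
ext.close()
```

Sharing one pool keeps the number of offload threads bounded by the worker's thread count even when several shuffles are active at once, rather than growing with every concurrent run.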
33 changes: 17 additions & 16 deletions distributed/shuffle/_worker_extension.py
@@ -63,7 +63,7 @@ def __init__(
output_workers: set[str],
local_address: str,
directory: str,
nthreads: int,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
@@ -73,7 +73,7 @@ def __init__(
self.run_id = run_id
self.output_workers = output_workers
self.local_address = local_address
self.executor = ThreadPoolExecutor(nthreads)
self.executor = executor
self.rpc = rpc
self.scheduler = scheduler
self.closed = False
@@ -191,10 +191,6 @@ async def close(self) -> None:
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
try:
self.executor.shutdown(cancel_futures=True)
except Exception: # pragma: no cover
self.executor.shutdown()
self._closed_event.set()

def fail(self, exception: Exception) -> None:
@@ -255,8 +251,8 @@ class ArrayRechunkRun(ShuffleRun[ArrayRechunkShardID, NIndex, "np.ndarray"]):
The local address this Shuffle can be contacted by using `rpc`.
directory:
The scratch directory to buffer data in.
nthreads:
How many background threads to use for compute.
executor:
Thread pool to use for offloading compute.
loop:
The event loop.
rpc:
@@ -280,7 +276,7 @@ def __init__(
run_id: int,
local_address: str,
directory: str,
nthreads: int,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
@@ -294,7 +290,7 @@ def __init__(
output_workers=output_workers,
local_address=local_address,
directory=directory,
nthreads=nthreads,
executor=executor,
rpc=rpc,
scheduler=scheduler,
memory_limiter_comms=memory_limiter_comms,
@@ -410,8 +406,8 @@ class DataFrameShuffleRun(ShuffleRun[int, int, "pd.DataFrame"]):
The local address this Shuffle can be contacted by using `rpc`.
directory:
The scratch directory to buffer data in.
nthreads:
How many background threads to use for compute.
executor:
Thread pool to use for offloading compute.
loop:
The event loop.
rpc:
@@ -435,7 +431,7 @@ def __init__(
run_id: int,
local_address: str,
directory: str,
nthreads: int,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
@@ -449,7 +445,7 @@ def __init__(
output_workers=output_workers,
local_address=local_address,
directory=directory,
nthreads=nthreads,
executor=executor,
rpc=rpc,
scheduler=scheduler,
memory_limiter_comms=memory_limiter_comms,
@@ -569,6 +565,7 @@ def __init__(self, worker: Worker) -> None:
self.memory_limiter_comms = ResourceLimiter(parse_bytes("100 MiB"))
self.memory_limiter_disk = ResourceLimiter(parse_bytes("1 GiB"))
self.closed = False
self._executor = ThreadPoolExecutor(self.worker.state.nthreads)

# Handlers
##########
@@ -815,7 +812,7 @@ async def _(
self.worker.local_directory,
f"shuffle-{shuffle_id}-{result['run_id']}",
),
nthreads=self.worker.state.nthreads,
executor=self._executor,
local_address=self.worker.address,
rpc=self.worker.rpc,
scheduler=self.worker.scheduler,
@@ -834,7 +831,7 @@ async def _(
self.worker.local_directory,
f"shuffle-{shuffle_id}-{result['run_id']}",
),
nthreads=self.worker.state.nthreads,
executor=self._executor,
local_address=self.worker.address,
rpc=self.worker.rpc,
scheduler=self.worker.scheduler,
@@ -855,6 +852,10 @@ async def close(self) -> None:
_, shuffle = self.shuffles.popitem()
await shuffle.close()
self._runs.remove(shuffle)
try:
self._executor.shutdown(cancel_futures=True)
except Exception: # pragma: no cover
self._executor.shutdown()

#############################
# Methods for worker thread #
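
As I read the extension changes above, the practical effect is that the number of offload threads on a worker is now capped by the single pool's size no matter how many shuffle runs are in flight, whereas previously each run built its own pool of `nthreads` threads. A toy demonstration of that bound, unrelated to the distributed codebase:

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor

# One pool of 2 threads shared by what would be several concurrent "runs".
pool = ThreadPoolExecutor(2)


def work(_: int) -> str:
    time.sleep(0.01)
    return threading.current_thread().name


# 30 submissions (say, 3 runs x 10 tasks) still only ever use 2 threads.
names = {f.result() for f in [pool.submit(work, i) for i in range(30)]}
assert len(names) <= 2
pool.shutdown()
```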
116 changes: 65 additions & 51 deletions distributed/shuffle/tests/test_rechunk.py
@@ -9,6 +9,8 @@
np = pytest.importorskip("numpy")
da = pytest.importorskip("dask.array")

from concurrent.futures import ThreadPoolExecutor

import dask
from dask.array.core import concatenate3
from dask.array.rechunk import normalize_chunks, rechunk
@@ -26,6 +28,16 @@
class ArrayRechunkTestPool(AbstractShuffleTestPool):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._executor = ThreadPoolExecutor(2)

def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
try:
self._executor.shutdown(cancel_futures=True)
except Exception: # pragma: no cover
self._executor.shutdown()

def new_shuffle(
self,
@@ -47,7 +59,7 @@ def new_shuffle(
id=ShuffleId(name),
run_id=next(AbstractShuffleTestPool._shuffle_run_id_iterator),
local_address=name,
nthreads=2,
executor=self._executor,
rpc=self,
scheduler=self,
memory_limiter_disk=ResourceLimiter(10000000),
@@ -85,58 +97,60 @@ async def test_lowlevel_rechunk(

assert len(set(worker_for_mapping.values())) == min(n_workers, len(new_indices))

local_shuffle_pool = ArrayRechunkTestPool()
shuffles = []
for i in range(n_workers):
shuffles.append(
local_shuffle_pool.new_shuffle(
name=workers[i],
worker_for_mapping=worker_for_mapping,
old=old,
new=new,
directory=tmp_path,
loop=loop_in_thread,
with ArrayRechunkTestPool() as local_shuffle_pool:
shuffles = []
for i in range(n_workers):
shuffles.append(
local_shuffle_pool.new_shuffle(
name=workers[i],
worker_for_mapping=worker_for_mapping,
old=old,
new=new,
directory=tmp_path,
loop=loop_in_thread,
)
)
random.seed(42)
if barrier_first_worker:
barrier_worker = shuffles[0]
else:
barrier_worker = random.sample(shuffles, k=1)[0]

try:
for i, (idx, arr) in enumerate(old_chunks.items()):
s = shuffles[i % len(shuffles)]
await s.add_partition(arr, idx)

await barrier_worker.barrier()

total_bytes_sent = 0
total_bytes_recvd = 0
total_bytes_recvd_shuffle = 0
for s in shuffles:
metrics = s.heartbeat()
assert metrics["comm"]["total"] == metrics["comm"]["written"]
total_bytes_sent += metrics["comm"]["written"]
total_bytes_recvd += metrics["disk"]["total"]
total_bytes_recvd_shuffle += s.total_recvd

assert total_bytes_recvd_shuffle == total_bytes_sent

all_chunks = np.empty(tuple(len(dim) for dim in new), dtype="O")
for ix, worker in worker_for_mapping.items():
s = local_shuffle_pool.shuffles[worker]
all_chunks[ix] = await s.get_output_partition(ix)

finally:
await asyncio.gather(*[s.close() for s in shuffles])

old_cs = np.empty(tuple(len(dim) for dim in old), dtype="O")
for ix, arr in old_chunks.items():
old_cs[ix] = arr
np.testing.assert_array_equal(
concatenate3(old_cs.tolist()),
concatenate3(all_chunks.tolist()),
strict=True,
)
random.seed(42)
if barrier_first_worker:
barrier_worker = shuffles[0]
else:
barrier_worker = random.sample(shuffles, k=1)[0]

try:
for i, (idx, arr) in enumerate(old_chunks.items()):
s = shuffles[i % len(shuffles)]
await s.add_partition(arr, idx)

await barrier_worker.barrier()

total_bytes_sent = 0
total_bytes_recvd = 0
total_bytes_recvd_shuffle = 0
for s in shuffles:
metrics = s.heartbeat()
assert metrics["comm"]["total"] == metrics["comm"]["written"]
total_bytes_sent += metrics["comm"]["written"]
total_bytes_recvd += metrics["disk"]["total"]
total_bytes_recvd_shuffle += s.total_recvd

assert total_bytes_recvd_shuffle == total_bytes_sent

all_chunks = np.empty(tuple(len(dim) for dim in new), dtype="O")
for ix, worker in worker_for_mapping.items():
s = local_shuffle_pool.shuffles[worker]
all_chunks[ix] = await s.get_output_partition(ix)

finally:
await asyncio.gather(*[s.close() for s in shuffles])

old_cs = np.empty(tuple(len(dim) for dim in old), dtype="O")
for ix, arr in old_chunks.items():
old_cs[ix] = arr
np.testing.assert_array_equal(
concatenate3(old_cs.tolist()), concatenate3(all_chunks.tolist()), strict=True
)


def test_raise_on_fuse_optimization():
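
The test helper follows the same ownership rule: ArrayRechunkTestPool now creates the executor its shuffles borrow and disposes of it in `__exit__`, so the test body is wrapped in a `with` block and cleanup happens even if an assertion fails partway through. A reduced sketch of that fixture pattern; `Pool` and its `submit` method are illustrative stand-ins, not the real helper:

```python
from concurrent.futures import ThreadPoolExecutor


class Pool:
    """Illustrative stand-in for ArrayRechunkTestPool."""

    def __init__(self) -> None:
        self._executor = ThreadPoolExecutor(2)

    def __enter__(self) -> "Pool":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Mirror the fixture's cleanup: shut the shared pool down on exit,
        # whether the block succeeded or raised.
        try:
            self._executor.shutdown(cancel_futures=True)
        except Exception:
            self._executor.shutdown()

    def submit(self, fn, *args):
        return self._executor.submit(fn, *args)


with Pool() as pool:
    assert pool.submit(max, 1, 2).result() == 2
# Leaving the block shuts the executor down; no dangling threads after the test.
```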