From 69bed31502c3c8e24639d7ca91ff63f4e1d1cbf9 Mon Sep 17 00:00:00 2001
From: Matthew Rocklin
Date: Thu, 31 Mar 2022 15:00:22 -0500
Subject: [PATCH] Support exceptions in MultiComm

---
Review notes. Everything between the "---" line above and the first
"diff --git" line is commentary; it is not part of the commit message and
is skipped when the patch is applied.

What the patch does
* multi_comm.py, __init__: adds the attribute `_exception`, initialised to
  None, on the line after the `communicate()` task is created.
* multi_comm.py, put(): if `_exception` is set, raises it before entering
  `with self.lock:`, i.e. before any shard is buffered.
* multi_comm.py, process(): wraps the timed `self.send(...)` call in
  try/except Exception. On failure the exception is stored in
  `self._exception` and `self._done` is set to True; it is not re-raised
  at this point.
* test_shuffle.py: adds `test_crashed_worker`, marked `slow`. It persists a
  p2p shuffle of a `dask.datasets.timeseries` frame, polls until both
  workers' "shuffle" extensions report a non-empty `shuffles`, polls until
  at least 3 scheduler tasks whose key contains "shuffle_transfer" are in
  state "memory", closes worker `b`, and then expects `c.compute(out)` to
  raise an exception whose text contains the address of `a` or of `b`.

Points to confirm
* NOTE(review): this copy of the patch arrived with its line breaks
  collapsed. Line structure, indentation and the position of blank context
  lines below were reconstructed from the hunk headers' line counts and
  from Python syntax. Context matching is whitespace-sensitive, so verify
  the context lines against blob 60d3a4a386 (multi_comm.py) and
  074c6fe895 (test_shuffle.py) before applying. The least certain spots
  are the blank line in the first hunk, the indentation depth of the
  process() hunk, and the spacing inside the commented-out `await` line.
* NOTE(review): `self._exception = e` is a plain assignment, so when
  several sends fail only the most recent exception is kept.
* NOTE(review): after a failed send, process() presumably falls through to
  the `stop = time.time()` / `avg_size` lines shown as trailing context,
  so the diagnostics would include the failed attempt - confirm that this
  is intended.
* NOTE(review): the test uses `pytest.raises(Exception)` and rebinds `out`
  inside the `with` block without using the result; the polling loops have
  no explicit timeout of their own.

 distributed/shuffle/multi_comm.py         | 11 +++++--
 distributed/shuffle/tests/test_shuffle.py | 36 +++++++++++++++++++++++
 2 files changed, 45 insertions(+), 2 deletions(-)

diff --git a/distributed/shuffle/multi_comm.py b/distributed/shuffle/multi_comm.py
index 60d3a4a386..b24774fda9 100644
--- a/distributed/shuffle/multi_comm.py
+++ b/distributed/shuffle/multi_comm.py
@@ -68,6 +68,7 @@ def __init__(
         self._loop = loop or asyncio.get_event_loop()
 
         self._communicate_future = asyncio.create_task(self.communicate())
+        self._exception = None
 
     @property
     def queue(self):
@@ -89,6 +90,8 @@ def put(self, data: dict):
 
         If we're out of space then we block in order to enforce backpressure.
         """
+        if self._exception:
+            raise self._exception
         with self.lock:
             for address, shards in data.items():
                 size = sum(map(len, shards))
@@ -164,8 +167,12 @@ async def process(self, address: str, shards: list, size: int):
             # while (time.time() // 5 % 4) == 0:
             #     await asyncio.sleep(0.1)
             start = time.time()
-            with self.time("send"):
-                await self.send(address, [b"".join(shards)])
+            try:
+                with self.time("send"):
+                    await self.send(address, [b"".join(shards)])
+            except Exception as e:
+                self._exception = e
+                self._done = True
             stop = time.time()
             self.diagnostics["avg_size"] = (
                 0.95 * self.diagnostics["avg_size"] + 0.05 * size
diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py
index 074c6fe895..eca57aab2b 100644
--- a/distributed/shuffle/tests/test_shuffle.py
+++ b/distributed/shuffle/tests/test_shuffle.py
@@ -74,6 +74,42 @@ async def test_bad_disk(c, s, a, b):
     assert a.local_directory in str(e.value) or b.local_directory in str(e.value)
 
 
+@pytest.mark.slow
+@gen_cluster(client=True)
+async def test_crashed_worker(c, s, a, b):
+
+    df = dask.datasets.timeseries(
+        start="2000-01-01",
+        end="2000-01-10",
+        dtypes={"x": float, "y": float},
+        freq="10 s",
+    )
+    out = dd.shuffle.shuffle(df, "x", shuffle="p2p")
+    out = out.persist()
+    while not a.extensions["shuffle"].shuffles:
+        await asyncio.sleep(0.01)
+    while not b.extensions["shuffle"].shuffles:
+        await asyncio.sleep(0.01)
+
+    while (
+        len(
+            [
+                ts
+                for ts in s.tasks.values()
+                if "shuffle_transfer" in ts.key and ts.state == "memory"
+            ]
+        )
+        < 3
+    ):
+        await asyncio.sleep(0.01)
+    await b.close()
+
+    with pytest.raises(Exception) as e:
+        out = await c.compute(out)
+
+    assert a.address in str(e.value) or b.address in str(e.value)
+
+
 @gen_cluster(client=True)
 async def test_heartbeat(c, s, a, b):
     await a.heartbeat()