Skip to content

Commit

Permalink
Support exceptions in MultiComm
Browse files Browse the repository at this point in the history
  • Loading branch information
mrocklin committed Mar 31, 2022
1 parent 9efb27c commit 69bed31
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 2 deletions.
11 changes: 9 additions & 2 deletions distributed/shuffle/multi_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
self._loop = loop or asyncio.get_event_loop()

self._communicate_future = asyncio.create_task(self.communicate())
self._exception = None

@property
def queue(self):
Expand All @@ -89,6 +90,8 @@ def put(self, data: dict):
If we're out of space then we block in order to enforce backpressure.
"""
if self._exception:
raise self._exception
with self.lock:
for address, shards in data.items():
size = sum(map(len, shards))
Expand Down Expand Up @@ -164,8 +167,12 @@ async def process(self, address: str, shards: list, size: int):
# while (time.time() // 5 % 4) == 0:
# await asyncio.sleep(0.1)
start = time.time()
with self.time("send"):
await self.send(address, [b"".join(shards)])
try:
with self.time("send"):
await self.send(address, [b"".join(shards)])
except Exception as e:
self._exception = e
self._done = True
stop = time.time()
self.diagnostics["avg_size"] = (
0.95 * self.diagnostics["avg_size"] + 0.05 * size
Expand Down
36 changes: 36 additions & 0 deletions distributed/shuffle/tests/test_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,42 @@ async def test_bad_disk(c, s, a, b):
assert a.local_directory in str(e.value) or b.local_directory in str(e.value)


@pytest.mark.slow
@gen_cluster(client=True)
async def test_crashed_worker(c, s, a, b):

df = dask.datasets.timeseries(
start="2000-01-01",
end="2000-01-10",
dtypes={"x": float, "y": float},
freq="10 s",
)
out = dd.shuffle.shuffle(df, "x", shuffle="p2p")
out = out.persist()
while not a.extensions["shuffle"].shuffles:
await asyncio.sleep(0.01)
while not b.extensions["shuffle"].shuffles:
await asyncio.sleep(0.01)

while (
len(
[
ts
for ts in s.tasks.values()
if "shuffle_transfer" in ts.key and ts.state == "memory"
]
)
< 3
):
await asyncio.sleep(0.01)
await b.close()

with pytest.raises(Exception) as e:
out = await c.compute(out)

assert a.address in str(e.value) or b.address in str(e.value)


@gen_cluster(client=True)
async def test_heartbeat(c, s, a, b):
await a.heartbeat()
Expand Down

0 comments on commit 69bed31

Please sign in to comment.