Skip to content

Commit

Permalink
Merge branch 'main' into improve_subclass_barrier_task
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter authored Nov 7, 2024
2 parents 4c78021 + c38c509 commit 0a8082c
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 6 deletions.
4 changes: 4 additions & 0 deletions distributed/shuffle/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,10 @@ def id(self) -> ShuffleId:
def run_id(self) -> int:
return self.run_spec.run_id

@property
def archived(self) -> bool:
    """Return whether this object has been archived.

    Returns
    -------
    bool
        ``True`` as soon as ``self._archived_by`` holds anything other
        than ``None``, ``False`` otherwise.
    """
    # Identity check on purpose: any non-``None`` marker, even a falsy
    # one, counts as archived.
    return self._archived_by is not None

def __str__(self) -> str:
    """Return a short label of the form ``ClassName<id[run_id]>``.

    The label is built from the concrete class name of the instance
    together with its ``id`` and ``run_id`` attributes.
    """
    return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"

Expand Down
6 changes: 6 additions & 0 deletions distributed/shuffle/_scheduler_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,12 @@ async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
Expand Down
17 changes: 12 additions & 5 deletions distributed/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,9 +1052,12 @@ async def test_server_comms_mark_active_handlers():
ensure this is properly reflected and released. The sentinel for
"open comm but no active handler" is `None`
"""
in_handler = asyncio.Event()
unblock_handler = asyncio.Event()

async def long_handler(comm):
await asyncio.sleep(0.2)
in_handler.set()
await unblock_handler.wait()
return "done"

async with Server({"wait": long_handler}) as server:
Expand All @@ -1063,9 +1066,9 @@ async def long_handler(comm):

comm = await connect(server.address)
await comm.write({"op": "wait"})
while not server._comms:
await asyncio.sleep(0.05)
await in_handler.wait()
assert set(server._comms.values()) == {"wait"}
unblock_handler.set()

assert server.incoming_comms_open == 1
assert server.incoming_comms_active == 1
Expand Down Expand Up @@ -1106,9 +1109,12 @@ def validate_dict(server):

async with Server({}) as server2:
rpc_ = server2.rpc(server.address)
in_handler.clear()
unblock_handler.clear()
task = asyncio.create_task(rpc_.wait())
while not server.incoming_comms_active:
await asyncio.sleep(0.1)

await in_handler.wait()

assert server.incoming_comms_active == 1
assert server.incoming_comms_open == 1
assert server.outgoing_comms_active == 0
Expand All @@ -1120,6 +1126,7 @@ def validate_dict(server):
assert server2.outgoing_comms_open == 1
validate_dict(server)

unblock_handler.set()
await task
assert server.incoming_comms_active == 0
assert server.incoming_comms_open == 1
Expand Down
4 changes: 3 additions & 1 deletion distributed/tests/test_jupyter.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ async def test_jupyter_server():


@pytest.mark.slow
def test_jupyter_cli(loop):
def test_jupyter_cli(loop, requires_default_ports):
port = open_port()
with popen(
[
Expand All @@ -56,6 +56,8 @@ def test_jupyter_cli(loop):
"--host",
f"127.0.0.1:{port}",
],
terminate_timeout=120,
kill_timeout=60,
):
with Client(f"127.0.0.1:{port}", loop=loop):
response = requests.get("http://127.0.0.1:8787/jupyter/api/status")
Expand Down

0 comments on commit 0a8082c

Please sign in to comment.