diff --git a/distributed/shuffle/_buffer.py b/distributed/shuffle/_buffer.py index f4dc261ab1..97e639f623 100644 --- a/distributed/shuffle/_buffer.py +++ b/distributed/shuffle/_buffer.py @@ -51,8 +51,8 @@ class ShardsBuffer(Generic[ShardType]): bytes_written: int bytes_read: int - _closed: bool - _done: bool + _accepts_input: bool + _inputs_done: bool _exception: None | Exception _tasks: list[asyncio.Task] _shards_available: asyncio.Condition @@ -64,12 +64,12 @@ def __init__( concurrency_limit: int = 2, max_message_size: int = -1, ) -> None: - self._closed = False + self._accepts_input = True self.shards = defaultdict(list) self.sizes = defaultdict(int) self._exception = None self.concurrency_limit = concurrency_limit - self._done = False + self._inputs_done = False self.memory_limiter = memory_limiter self.diagnostics: dict[str, float] = defaultdict(float) self._tasks = [ @@ -105,7 +105,7 @@ async def process(self, id: str, shards: list[pa.Table], size: int) -> None: except Exception as e: self._exception = e - self._done = True + self._inputs_done = True stop = time() self.diagnostics["avg_size"] = ( @@ -131,12 +131,12 @@ def empty(self) -> bool: async def _background_task(self) -> None: def _continue() -> bool: - return bool(self.shards or self._done) + return bool(self.shards or self._inputs_done) while True: async with self._shards_available: await self._shards_available.wait_for(_continue) - if self._done and not self.shards: + if self._inputs_done and not self.shards: break part_id = max(self.sizes, key=self.sizes.__getitem__) if self.max_message_size > 0: @@ -175,7 +175,7 @@ async def write(self, data: dict[str, list[ShardType]]) -> None: if self._exception: raise self._exception - if self._closed or self._done: + if not self._accepts_input or self._inputs_done: raise RuntimeError(f"Trying to put data in closed {self}.") if not data: @@ -215,13 +215,13 @@ async def flush(self) -> None: This closes the buffer such that no new writes are allowed """ async with self._flush_lock: - self._closed = True + self._accepts_input = False async with self._shards_available: self._shards_available.notify_all() await self._shards_available.wait_for( - lambda: not self.shards or self._exception or self._done + lambda: not self.shards or self._exception or self._inputs_done ) - self._done = True + self._inputs_done = True self._shards_available.notify_all() await asyncio.gather(*self._tasks) @@ -238,8 +238,8 @@ async def close(self) -> None: assert not self.bytes_memory, (type(self), self.bytes_memory) for t in self._tasks: t.cancel() - self._closed = True - self._done = True + self._accepts_input = False + self._inputs_done = True self.shards.clear() self.bytes_memory = 0 async with self._shards_available: diff --git a/distributed/shuffle/_disk.py b/distributed/shuffle/_disk.py index eef61ac932..71b9ae8b53 100644 --- a/distributed/shuffle/_disk.py +++ b/distributed/shuffle/_disk.py @@ -95,7 +95,7 @@ async def _process(self, id: str, shards: list[pa.Buffer]) -> None: def read(self, id: int | str) -> pa.Table: """Read a complete file back into memory""" self.raise_on_exception() - if not self._done: + if not self._inputs_done: raise RuntimeError("Tried to read from file before done.") parts = [] diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index 36c2e64853..e252003d4b 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -41,19 +41,28 @@ def shuffle_transfer( npartitions: int, column: str, ) -> None: - 
_get_worker_extension().add_partition( - input, id, npartitions=npartitions, column=column - ) + try: + _get_worker_extension().add_partition( + input, id, npartitions=npartitions, column=column + ) + except Exception: + raise RuntimeError(f"shuffle_transfer failed during shuffle {id}") def shuffle_unpack( id: ShuffleId, output_partition: int, barrier: object ) -> pd.DataFrame: - return _get_worker_extension().get_output_partition(id, output_partition) + try: + return _get_worker_extension().get_output_partition(id, output_partition) + except Exception: + raise RuntimeError(f"shuffle_unpack failed during shuffle {id}") def shuffle_barrier(id: ShuffleId, transfers: list[None]) -> None: - return _get_worker_extension().barrier(id) + try: + return _get_worker_extension().barrier(id) + except Exception: + raise RuntimeError(f"shuffle_barrier failed during shuffle {id}") def rearrange_by_column_p2p( diff --git a/distributed/shuffle/_shuffle_extension.py b/distributed/shuffle/_shuffle_extension.py index e2339727b4..b810c84d9e 100644 --- a/distributed/shuffle/_shuffle_extension.py +++ b/distributed/shuffle/_shuffle_extension.py @@ -2,6 +2,7 @@ import asyncio import contextlib +import functools import logging import os import time @@ -15,6 +16,7 @@ from dask.utils import parse_bytes from distributed.core import PooledRPCCall +from distributed.diagnostics.plugin import SchedulerPlugin from distributed.protocol import to_serialize from distributed.shuffle._arrow import ( deserialize_schema, @@ -31,7 +33,7 @@ import pandas as pd import pyarrow as pa - from distributed.scheduler import Scheduler, WorkerState + from distributed.scheduler import Recs, Scheduler, TaskStateState, WorkerState from distributed.worker import Worker ShuffleId = NewType("ShuffleId", str) @@ -40,6 +42,10 @@ logger = logging.getLogger(__name__) +class ShuffleClosedError(RuntimeError): + pass + + class Shuffle: """State for a single active shuffle @@ -115,6 +121,7 @@ def __init__( partitions_of[addr].append(part) self.partitions_of = dict(partitions_of) self.worker_for = pd.Series(worker_for, name="_workers").astype("category") + self.closed = False def _dump_batch(batch: pa.Buffer, file: BinaryIO) -> None: return dump_batch(batch, file, self.schema) @@ -138,6 +145,7 @@ def _dump_batch(batch: pa.Buffer, file: BinaryIO) -> None: self.total_recvd = 0 self.start_time = time.time() self._exception: Exception | None = None + self._closed_event = asyncio.Event() def __repr__(self) -> str: return f"" @@ -150,29 +158,20 @@ def time(self, name: str) -> Iterator[None]: self.diagnostics[name] += stop - start async def barrier(self) -> None: - # FIXME: This should restrict communication to only workers - # participating in this specific shuffle. This will not only reduce the - # number of workers we need to contact but will also simplify error - # handling, e.g. if a non-participating worker is not reachable in time + self.raise_if_closed() # TODO: Consider broadcast pinging once when the shuffle starts to warm # up the comm pool on scheduler side - out = await self.broadcast( - msg={"op": "shuffle_inputs_done", "shuffle_id": self.id} - ) - if not self.output_workers.issubset(set(out)): - raise ValueError( - "Some critical workers have left", - set(self.output_workers) - set(out), - ) - # TODO handle errors from workers and scheduler, and cancellation. 
+ await self.broadcast(msg={"op": "shuffle_inputs_done", "shuffle_id": self.id}) async def send(self, address: str, shards: list[bytes]) -> None: + self.raise_if_closed() return await self.rpc(address).shuffle_receive( data=to_serialize(shards), shuffle_id=self.id, ) async def offload(self, func: Callable[..., T], *args: Any) -> T: + self.raise_if_closed() with self.time("cpu"): return await asyncio.get_running_loop().run_in_executor( self.executor, @@ -194,37 +193,40 @@ async def receive(self, data: list[bytes]) -> None: await self._receive(data) async def _receive(self, data: list[bytes]) -> None: - if self._exception: - raise self._exception + self.raise_if_closed() try: self.total_recvd += sum(map(len, data)) - # TODO: Is it actually a good idea to dispatch multiple times instead of - # only once? - # An ugly way of turning these batches back into an arrow table - data = await self.offload( - list_of_buffers_to_table, - data, - self.schema, - ) + groups = await self.offload(self._repartition_buffers, data) + await self._write_to_disk(groups) + except Exception as e: + self._exception = e + raise - groups = await self.offload(split_by_partition, data, self.column) + def _repartition_buffers(self, data: list[bytes]) -> dict[str, list[bytes]]: + table = list_of_buffers_to_table(data, self.schema) + groups = split_by_partition(table, self.column) + assert len(table) == sum(map(len, groups.values())) + del data + return { + k: [batch.serialize() for batch in v.to_batches()] + for k, v in groups.items() + } - assert len(data) == sum(map(len, groups.values())) - del data + async def _write_to_disk(self, data: dict[str, list[bytes]]) -> None: + self.raise_if_closed() + await self._disk_buffer.write(data) - groups = await self.offload( - lambda: { - k: [batch.serialize() for batch in v.to_batches()] - for k, v in groups.items() - } + def raise_if_closed(self) -> None: + if self.closed: + if self._exception: + raise self._exception + raise ShuffleClosedError( + f"Shuffle {self.id} has been closed on {self.local_address}" ) - await self._disk_buffer.write(groups) - except Exception as e: - self._exception = e - raise async def add_partition(self, data: pd.DataFrame) -> None: + self.raise_if_closed() if self.transferred: raise RuntimeError(f"Cannot add more partitions to shuffle {self}") @@ -241,9 +243,14 @@ def _() -> dict[str, list[bytes]]: return out out = await self.offload(_) - await self._comm_buffer.write(out) + await self._write_to_comm(out) + + async def _write_to_comm(self, data: dict[str, list[bytes]]) -> None: + self.raise_if_closed() + await self._comm_buffer.write(data) async def get_output_partition(self, i: int) -> pd.DataFrame: + self.raise_if_closed() assert self.transferred, "`get_output_partition` called before barrier task" assert self.worker_for[i] == self.local_address, ( @@ -258,7 +265,7 @@ async def get_output_partition(self, i: int) -> pd.DataFrame: ), f"No outputs remaining, but requested output partition {i} on {self.local_address}." 
await self.flush_receive() try: - df = self._disk_buffer.read(i) + df = self._read_from_disk(i) with self.time("cpu"): out = df.to_pandas() except KeyError: @@ -266,31 +273,49 @@ async def get_output_partition(self, i: int) -> pd.DataFrame: self.output_partitions_left -= 1 return out + def _read_from_disk(self, id: int | str) -> pa.Table: + self.raise_if_closed() + return self._disk_buffer.read(id) + async def inputs_done(self) -> None: + self.raise_if_closed() assert not self.transferred, "`inputs_done` called multiple times" self.transferred = True - await self._comm_buffer.flush() + await self._flush_comm() try: self._comm_buffer.raise_on_exception() except Exception as e: self._exception = e raise + async def _flush_comm(self) -> None: + self.raise_if_closed() + await self._comm_buffer.flush() + def done(self) -> bool: return self.transferred and self.output_partitions_left == 0 async def flush_receive(self) -> None: - if self._exception: - raise self._exception + self.raise_if_closed() await self._disk_buffer.flush() async def close(self) -> None: + if self.closed: + await self._closed_event.wait() + return + + self.closed = True await self._comm_buffer.close() await self._disk_buffer.close() try: self.executor.shutdown(cancel_futures=True) except Exception: self.executor.shutdown() + self._closed_event.set() + + def fail(self, exception: Exception) -> None: + if not self.closed: + self._exception = exception class ShuffleWorkerExtension: @@ -305,17 +330,26 @@ class ShuffleWorkerExtension: - collecting instrumentation of ongoing shuffles and route to scheduler/worker """ + worker: Worker + shuffles: dict[ShuffleId, Shuffle] + memory_limiter_comms: ResourceLimiter + memory_limiter_disk: ResourceLimiter + closed: bool + def __init__(self, worker: Worker) -> None: # Attach to worker worker.handlers["shuffle_receive"] = self.shuffle_receive worker.handlers["shuffle_inputs_done"] = self.shuffle_inputs_done + worker.handlers["shuffle_fail"] = self.shuffle_fail + worker.stream_handlers["shuffle-fail"] = self.shuffle_fail worker.extensions["shuffle"] = self # Initialize - self.worker: Worker = worker - self.shuffles: dict[ShuffleId, Shuffle] = {} - self.memory_limiter_disk = ResourceLimiter(parse_bytes("1 GiB")) + self.worker = worker + self.shuffles = {} self.memory_limiter_comms = ResourceLimiter(parse_bytes("100 MiB")) + self.memory_limiter_disk = ResourceLimiter(parse_bytes("1 GiB")) + self.closed = False # Handlers ########## @@ -349,9 +383,19 @@ async def shuffle_inputs_done(self, shuffle_id: ShuffleId) -> None: # `get_output_partition` will never be called. # This happens when there are fewer output partitions than workers. assert shuffle._disk_buffer.empty - del self.shuffles[shuffle_id] - logger.critical(f"Shuffle inputs done {shuffle}") + logger.info(f"Shuffle inputs done {shuffle}") await self._register_complete(shuffle) + del self.shuffles[shuffle_id] + + async def shuffle_fail(self, shuffle_id: ShuffleId, message: str) -> None: + try: + shuffle = self.shuffles[shuffle_id] + except KeyError: + return + exception = RuntimeError(message) + shuffle.fail(exception) + await shuffle.close() + del self.shuffles[shuffle_id] def add_partition( self, @@ -379,6 +423,8 @@ async def _barrier(self, shuffle_id: ShuffleId) -> None: async def _register_complete(self, shuffle: Shuffle) -> None: await shuffle.close() + # All the relevant work has already succeeded if we reached this point, + # so we do not need to check if the extension is closed. 
await self.worker.scheduler.shuffle_register_complete( id=shuffle.id, worker=self.worker.address, @@ -412,7 +458,7 @@ async def _get_shuffle( import pyarrow as pa try: - return self.shuffles[shuffle_id] + shuffle = self.shuffles[shuffle_id] except KeyError: try: result = await self.worker.scheduler.shuffle_get( @@ -422,7 +468,11 @@ async def _get_shuffle( else None, npartitions=npartitions, column=column, + worker=self.worker.address, ) + if result["status"] == "ERROR": + raise RuntimeError(result["message"]) + assert result["status"] == "OK" except KeyError: # Even the scheduler doesn't know about this shuffle # Let's hand this back to the scheduler and let it figure @@ -434,6 +484,10 @@ async def _get_shuffle( raise Reschedule() else: + if self.closed: + raise ShuffleClosedError( + f"{self.__class__.__name__} already closed on {self.worker.address}" + ) if shuffle_id not in self.shuffles: shuffle = Shuffle( column=result["column"], @@ -447,14 +501,31 @@ async def _get_shuffle( nthreads=self.worker.state.nthreads, local_address=self.worker.address, rpc=self.worker.rpc, - broadcast=self.worker.scheduler.broadcast, + broadcast=functools.partial( + self._broadcast_to_participants, shuffle_id + ), memory_limiter_disk=self.memory_limiter_disk, memory_limiter_comms=self.memory_limiter_comms, ) self.shuffles[shuffle_id] = shuffle return self.shuffles[shuffle_id] + else: + if shuffle._exception: + raise shuffle._exception + return shuffle + + async def _broadcast_to_participants(self, id: ShuffleId, msg: dict) -> dict: + participating_workers = ( + await self.worker.scheduler.shuffle_get_participating_workers(id=id) + ) + return await self.worker.scheduler.broadcast( + msg=msg, workers=participating_workers + ) async def close(self) -> None: + assert not self.closed + + self.closed = True while self.shuffles: _, shuffle = self.shuffles.popitem() await shuffle.close() @@ -507,8 +578,7 @@ def get_output_partition( Calling this for a ``shuffle_id`` which is unknown or incomplete is an error. 
""" - assert shuffle_id in self.shuffles, "Shuffle worker restrictions misbehaving" - shuffle = self.shuffles[shuffle_id] + shuffle = self.get_shuffle(shuffle_id) output = sync(self.worker.loop, shuffle.get_output_partition, output_partition) # key missing if another thread got to it first if shuffle.done() and shuffle_id in self.shuffles: @@ -517,7 +587,7 @@ def get_output_partition( return output -class ShuffleSchedulerExtension: +class ShuffleSchedulerExtension(SchedulerPlugin): """ Shuffle extension for the scheduler @@ -536,12 +606,17 @@ class ShuffleSchedulerExtension: columns: dict[ShuffleId, str] output_workers: dict[ShuffleId, set[str]] completed_workers: dict[ShuffleId, set[str]] + participating_workers: dict[ShuffleId, set[str]] + tombstones: set[ShuffleId] + erred_shuffles: dict[ShuffleId, Exception] + barriers: dict[ShuffleId, str] def __init__(self, scheduler: Scheduler): self.scheduler = scheduler self.scheduler.handlers.update( { "shuffle_get": self.get, + "shuffle_get_participating_workers": self.get_participating_workers, "shuffle_register_complete": self.register_complete, } ) @@ -551,10 +626,28 @@ def __init__(self, scheduler: Scheduler): self.columns = {} self.output_workers = {} self.completed_workers = {} + self.participating_workers = {} + self.tombstones = set() + self.erred_shuffles = {} + self.barriers = {} + self.scheduler.add_plugin(self) + + def shuffle_ids(self) -> set[ShuffleId]: + return set(self.worker_for) def heartbeat(self, ws: WorkerState, data: dict) -> None: for shuffle_id, d in data.items(): - self.heartbeats[shuffle_id][ws.address].update(d) + if shuffle_id in self.output_workers: + self.heartbeats[shuffle_id][ws.address].update(d) + + @classmethod + def barrier_key(cls, shuffle_id: ShuffleId) -> str: + return "shuffle-barrier-" + shuffle_id + + @classmethod + def id_from_key(cls, key: str) -> ShuffleId: + assert "shuffle-barrier-" in key + return ShuffleId(key.replace("shuffle-barrier-", "")) def get( self, @@ -562,7 +655,17 @@ def get( schema: bytes | None, column: str | None, npartitions: int | None, + worker: str, ) -> dict: + + if id in self.tombstones: + return { + "status": "ERROR", + "message": f"Shuffle {id} has already been forgotten", + } + if exception := self.erred_shuffles.get(id): + return {"status": "ERROR", "message": str(exception)} + if id not in self.worker_for: assert schema is not None assert column is not None @@ -570,47 +673,142 @@ def get( workers = list(self.scheduler.workers) output_workers = set() - name = "shuffle-barrier-" + id # TODO single-source task name + name = self.barrier_key(id) + self.barriers[id] = name mapping = {} for ts in self.scheduler.tasks[name].dependents: part = ts.annotations["shuffle"] if ts.worker_restrictions: - worker = list(ts.worker_restrictions)[0] + output_worker = list(ts.worker_restrictions)[0] else: - worker = get_worker_for(part, workers, npartitions) - mapping[part] = worker - output_workers.add(worker) - self.scheduler.set_restrictions({ts.key: {worker}}) + output_worker = get_worker_for(part, workers, npartitions) + mapping[part] = output_worker + output_workers.add(output_worker) + self.scheduler.set_restrictions({ts.key: {output_worker}}) self.worker_for[id] = mapping self.schemas[id] = schema self.columns[id] = column self.output_workers[id] = output_workers self.completed_workers[id] = set() + self.participating_workers[id] = output_workers.copy() + self.participating_workers[id].add(worker) return { + "status": "OK", "worker_for": self.worker_for[id], "column": self.columns[id], 
"schema": self.schemas[id], "output_workers": self.output_workers[id], } + def get_participating_workers(self, id: ShuffleId) -> list[str]: + return list(self.participating_workers[id]) + + async def remove_worker(self, scheduler: Scheduler, worker: str) -> None: + affected_shuffles = set() + broadcasts = [] + from time import time + + recs: Recs = {} + stimulus_id = f"shuffle-failed-worker-left-{time()}" + barriers = [] + for shuffle_id, shuffle_workers in self.participating_workers.items(): + if worker not in shuffle_workers: + continue + exception = RuntimeError( + f"Worker {worker} left during active shuffle {shuffle_id}" + ) + self.erred_shuffles[shuffle_id] = exception + contact_workers = shuffle_workers.copy() + contact_workers.discard(worker) + affected_shuffles.add(shuffle_id) + name = self.barriers[shuffle_id] + barrier_task = self.scheduler.tasks.get(name) + if barrier_task: + barriers.append(barrier_task) + broadcasts.append( + scheduler.broadcast( + msg={ + "op": "shuffle_fail", + "message": str(exception), + "shuffle_id": shuffle_id, + }, + workers=list(contact_workers), + ) + ) + + results = await asyncio.gather(*broadcasts, return_exceptions=True) + for barrier_task in barriers: + if barrier_task.state == "memory": + for dt in barrier_task.dependents: + if worker not in dt.worker_restrictions: + continue + dt.worker_restrictions.clear() + recs.update({dt.key: "waiting"}) + # TODO: Do we need to handle other states? + self.scheduler.transitions(recs, stimulus_id=stimulus_id) + + # Assumption: No new shuffle tasks scheduled on the worker + # + no existing tasks anymore + # All task-finished/task-errer are queued up in batched stream + + exceptions = [result for result in results if isinstance(result, Exception)] + if exceptions: + # TODO: Do we need to handle errors here? 
+ raise RuntimeError(exceptions) + + def transition( + self, + key: str, + start: TaskStateState, + finish: TaskStateState, + *args: Any, + **kwargs: Any, + ) -> None: + if finish != "forgotten": + return + if key not in self.barriers.values(): + + return + + shuffle_id = ShuffleSchedulerExtension.id_from_key(key) + participating_workers = self.participating_workers[shuffle_id] + worker_msgs = { + worker: [ + { + "op": "shuffle-fail", + "shuffle_id": shuffle_id, + "message": f"Shuffle {shuffle_id} forgotten", + } + ] + for worker in participating_workers + } + self._clean_on_scheduler(shuffle_id) + self.scheduler.send_all({}, worker_msgs) + def register_complete(self, id: ShuffleId, worker: str) -> None: """Learn from a worker that it has completed all reads of a shuffle""" + if exception := self.erred_shuffles.get(id): + raise exception if id not in self.completed_workers: logger.info("Worker shuffle reported complete after shuffle was removed") return self.completed_workers[id].add(worker) - if self.output_workers[id].issubset(self.completed_workers[id]): - del self.worker_for[id] - del self.schemas[id] - del self.columns[id] - del self.output_workers[id] - del self.completed_workers[id] - with contextlib.suppress(KeyError): - del self.heartbeats[id] + def _clean_on_scheduler(self, id: ShuffleId) -> None: + self.tombstones.add(id) + del self.worker_for[id] + del self.schemas[id] + del self.columns[id] + del self.output_workers[id] + del self.completed_workers[id] + del self.participating_workers[id] + self.erred_shuffles.pop(id, None) + del self.barriers[id] + with contextlib.suppress(KeyError): + del self.heartbeats[id] def get_worker_for(output_partition: int, workers: list[str], npartitions: int) -> str: diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 96e70f5efd..7f397c8c1f 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -6,21 +6,25 @@ import random import shutil from collections import defaultdict -from typing import Any +from typing import Any, Mapping +from unittest import mock import pandas as pd import pytest import dask import dask.dataframe as dd -from dask.distributed import Worker +from dask.distributed import Event, Nanny, Worker from dask.utils import stringify from distributed.core import PooledRPCCall +from distributed.scheduler import Scheduler +from distributed.scheduler import TaskState as SchedulerTaskState from distributed.shuffle._limiter import ResourceLimiter from distributed.shuffle._shuffle_extension import ( Shuffle, ShuffleId, + ShuffleWorkerExtension, dump_batch, get_worker_for, list_of_buffers_to_table, @@ -28,28 +32,43 @@ split_by_partition, split_by_worker, ) -from distributed.utils_test import gen_cluster, gen_test +from distributed.utils import Deadline +from distributed.utils_test import gen_cluster, gen_test, wait_for_state +from distributed.worker_state_machine import TaskState as WorkerTaskState pa = pytest.importorskip("pyarrow") -def clean_worker(worker): +async def clean_worker( + worker: Worker, interval: float = 0.01, timeout: int | None = None +) -> None: """Assert that the worker has no shuffle state""" - assert not worker.extensions["shuffle"].shuffles + deadline = Deadline.after(timeout) + extension = worker.extensions["shuffle"] + + while extension.shuffles and not deadline.expired: + await asyncio.sleep(interval) for dirpath, dirnames, filenames in os.walk(worker.local_directory): assert "shuffle" not in dirpath for fn in 
dirnames + filenames: assert "shuffle" not in fn -def clean_scheduler(scheduler): +async def clean_scheduler( + scheduler: Scheduler, interval: float = 0.01, timeout: int | None = None +) -> None: """Assert that the scheduler has no shuffle state""" - assert not scheduler.extensions["shuffle"].worker_for - assert not scheduler.extensions["shuffle"].heartbeats - assert not scheduler.extensions["shuffle"].schemas - assert not scheduler.extensions["shuffle"].columns - assert not scheduler.extensions["shuffle"].output_workers - assert not scheduler.extensions["shuffle"].completed_workers + deadline = Deadline.after(timeout) + extension = scheduler.extensions["shuffle"] + while extension.output_workers and not deadline.expired: + await asyncio.sleep(interval) + assert not extension.worker_for + assert not extension.heartbeats + assert not extension.schemas + assert not extension.columns + assert not extension.output_workers + assert not extension.completed_workers + assert not extension.participating_workers @gen_cluster(client=True) @@ -66,9 +85,9 @@ async def test_basic_integration(c, s, a, b): y = await y assert x == y - clean_worker(a) - clean_worker(b) - clean_scheduler(s) + await clean_worker(a) + await clean_worker(b) + await clean_scheduler(s) @gen_cluster(client=True) @@ -86,9 +105,9 @@ async def test_concurrent(c, s, a, b): y = await y assert x == y - clean_worker(a) - clean_worker(b) - clean_scheduler(s) + await clean_worker(a) + await clean_worker(b) + await clean_scheduler(s) @gen_cluster(client=True) @@ -102,6 +121,7 @@ async def test_bad_disk(c, s, a, b): ) out = dd.shuffle.shuffle(df, "x", shuffle="p2p") out = out.persist() + shuffle_id = await get_shuffle_id(s) while not a.extensions["shuffle"].shuffles: await asyncio.sleep(0.01) shutil.rmtree(a.local_directory) @@ -109,23 +129,246 @@ async def test_bad_disk(c, s, a, b): while not b.extensions["shuffle"].shuffles: await asyncio.sleep(0.01) shutil.rmtree(b.local_directory) - with pytest.raises(FileNotFoundError) as e: + with pytest.raises(RuntimeError, match=f"shuffle_transfer failed .* {shuffle_id}"): out = await c.compute(out) - assert os.path.split(a.local_directory)[-1] in str(e.value) or os.path.split( - b.local_directory - )[-1] in str(e.value) + await c.close() + # await clean_worker(a) + # await clean_worker(b) + # await clean_scheduler(s) - # clean_worker(a) # TODO: clean up on exception - # clean_worker(b) # TODO: clean up on exception - # clean_scheduler(s) + +async def wait_until_worker_has_tasks( + prefix: str, worker: str, count: int, scheduler: Scheduler, interval: float = 0.01 +) -> None: + ws = scheduler.workers[worker] + while ( + len( + [ + key + for key, ts in scheduler.tasks.items() + if prefix in key and ts.state == "memory" and ws in ts.who_has + ] + ) + < count + ): + await asyncio.sleep(interval) + + +async def wait_for_tasks_in_state( + prefix: str, + state: str, + count: int, + dask_worker: Worker | Scheduler, + interval: float = 0.01, +) -> None: + tasks: Mapping[str, SchedulerTaskState | WorkerTaskState] + + if isinstance(dask_worker, Worker): + tasks = dask_worker.state.tasks + elif isinstance(dask_worker, Scheduler): + tasks = dask_worker.tasks + else: + raise TypeError(dask_worker) + + while ( + len([key for key, ts in tasks.items() if prefix in key and ts.state == state]) + < count + ): + await asyncio.sleep(interval) + + +async def get_shuffle_id(scheduler: Scheduler) -> ShuffleId: + scheduler_extension = scheduler.extensions["shuffle"] + while not scheduler_extension.shuffle_ids(): + await 
asyncio.sleep(0.01) + shuffle_ids = scheduler_extension.shuffle_ids() + assert len(shuffle_ids) == 1 + return next(iter(shuffle_ids)) + + +@gen_cluster(client=True, nthreads=[("", 1)] * 2) +async def test_closed_worker_during_transfer(c, s, a, b): + + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-03-01", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + out = out.persist() + await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b) + await b.close() + + with pytest.raises(RuntimeError): + out = await c.compute(out) + + await c.close() + await clean_worker(a) + await clean_worker(b) + await clean_scheduler(s) -@pytest.mark.skip @pytest.mark.slow -@gen_cluster(client=True) -async def test_crashed_worker(c, s, a, b): +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_crashed_worker_during_transfer(c, s, a): + async with Nanny(s.address, nthreads=1) as n: + killed_worker_address = n.worker_address + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-03-01", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + out = out.persist() + await wait_until_worker_has_tasks( + "shuffle-transfer", killed_worker_address, 1, s + ) + await n.process.process.kill() + + with pytest.raises(RuntimeError): + out = await c.compute(out) + + await c.close() + await clean_worker(a) + await clean_scheduler(s) + + +# TODO: Deduplicate instead of failing: distributed#7324 +@gen_cluster(client=True, nthreads=[("", 1)] * 2) +async def test_closed_input_only_worker_during_transfer(c, s, a, b): + def mock_get_worker_for( + output_partition: int, workers: list[str], npartitions: int + ) -> str: + return a.address + + with mock.patch( + "distributed.shuffle._shuffle_extension.get_worker_for", mock_get_worker_for + ): + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-05-01", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + out = out.persist() + await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b, 0.001) + await b.close() + + with pytest.raises(RuntimeError): + out = await c.compute(out) + + await c.close() + await clean_worker(a) + await clean_worker(b) + await clean_scheduler(s) + +# TODO: Deduplicate instead of failing: distributed#7324 +@pytest.mark.slow +@gen_cluster(client=True, nthreads=[("", 1)], clean_kwargs={"processes": False}) +async def test_crashed_input_only_worker_during_transfer(c, s, a): + def mock_get_worker_for( + output_partition: int, workers: list[str], npartitions: int + ) -> str: + return a.address + + with mock.patch( + "distributed.shuffle._shuffle_extension.get_worker_for", mock_get_worker_for + ): + async with Nanny(s.address, nthreads=1) as n: + killed_worker_address = n.worker_address + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-03-01", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + out = out.persist() + await wait_until_worker_has_tasks( + "shuffle-transfer", n.worker_address, 1, s + ) + await n.process.process.kill() + + with pytest.raises(RuntimeError): + out = await c.compute(out) + + await c.close() + await clean_worker(a) + await clean_scheduler(s) + + +@pytest.mark.slow +@gen_cluster(client=True, nthreads=[("", 1)] * 3) +async def test_closed_bystanding_worker_during_shuffle(c, s, w1, w2, w3): + with dask.annotate(workers=[w1.address, w2.address], 
allow_other_workers=False): + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-02-01", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + out = out.persist() + await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, w1) + await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, w2) + await w3.close() + + await c.compute(out) + del out + + await clean_worker(w1) + await clean_worker(w2) + await clean_worker(w3) + await clean_scheduler(s) + + +class BlockedInputsDoneShuffle(Shuffle): + def __init__( + self, + worker_for, + output_workers, + column, + schema, + id, + local_address, + directory, + nthreads, + rpc, + broadcast, + memory_limiter_disk, + memory_limiter_comms, + ): + super().__init__( + worker_for, + output_workers, + column, + schema, + id, + local_address, + directory, + nthreads, + rpc, + broadcast, + memory_limiter_disk, + memory_limiter_comms, + ) + self.in_inputs_done = asyncio.Event() + self.block_inputs_done = asyncio.Event() + + async def inputs_done(self) -> None: + self.in_inputs_done.set() + await self.block_inputs_done.wait() + await super().inputs_done() + + +@mock.patch("distributed.shuffle._shuffle_extension.Shuffle", BlockedInputsDoneShuffle) +@gen_cluster(client=True, nthreads=[("", 1)] * 2) +async def test_closed_worker_during_barrier(c, s, a, b): df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -134,34 +377,259 @@ async def test_crashed_worker(c, s, a, b): ) out = dd.shuffle.shuffle(df, "x", shuffle="p2p") out = out.persist() + shuffle_id = await get_shuffle_id(s) + barrier_key = s.extensions["shuffle"].barrier_key(shuffle_id) + await wait_for_state(barrier_key, "processing", s) + shuffleA = a.extensions["shuffle"].shuffles[shuffle_id] + shuffleB = b.extensions["shuffle"].shuffles[shuffle_id] + await shuffleA.in_inputs_done.wait() + await shuffleB.in_inputs_done.wait() + + ts = s.tasks[barrier_key] + processing_worker = a if ts.processing_on.address == a.address else b + if processing_worker == a: + close_worker = a + alive_shuffle = shuffleB - while ( - len( - [ - ts - for ts in s.tasks.values() - if "shuffle_transfer" in ts.key and ts.state == "memory" - ] + else: + close_worker, alive_worker = b, a + alive_shuffle = shuffleA + await close_worker.close() + + alive_shuffle.block_inputs_done.set() + + with pytest.raises(RuntimeError): + out = await c.compute(out) + + await c.close() + await clean_worker(a) + await clean_worker(b) + await clean_scheduler(s) + + +@mock.patch("distributed.shuffle._shuffle_extension.Shuffle", BlockedInputsDoneShuffle) +@gen_cluster(client=True, nthreads=[("", 1)] * 2) +async def test_closed_other_worker_during_barrier(c, s, a, b): + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-01-10", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + out = out.persist() + shuffle_id = await get_shuffle_id(s) + + barrier_key = s.extensions["shuffle"].barrier_key(shuffle_id) + await wait_for_state(barrier_key, "processing", s, interval=0) + + shuffleA = a.extensions["shuffle"].shuffles[shuffle_id] + shuffleB = b.extensions["shuffle"].shuffles[shuffle_id] + await shuffleA.in_inputs_done.wait() + await shuffleB.in_inputs_done.wait() + + ts = s.tasks[barrier_key] + processing_worker = a if ts.processing_on.address == a.address else b + if processing_worker == a: + close_worker = b + alive_shuffle = shuffleA + + else: + close_worker = a + alive_shuffle = shuffleB + await 
close_worker.close() + + alive_shuffle.block_inputs_done.set() + + with pytest.raises(RuntimeError, match="shuffle_barrier failed"): + out = await c.compute(out) + + await c.close() + await clean_worker(a) + await clean_worker(b) + await clean_scheduler(s) + + +@pytest.mark.slow +@mock.patch("distributed.shuffle._shuffle_extension.Shuffle", BlockedInputsDoneShuffle) +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_crashed_other_worker_during_barrier(c, s, a): + async with Nanny(s.address, nthreads=1) as n: + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-01-10", + dtypes={"x": float, "y": float}, + freq="10 s", ) - < 3 - ): - await asyncio.sleep(0.01) + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + out = out.persist() + shuffle_id = await get_shuffle_id(s) + barrier_key = s.extensions["shuffle"].barrier_key(shuffle_id) + # Ensure that barrier is not executed on the nanny + s.set_restrictions({barrier_key: {a.address}}) + await wait_for_state(barrier_key, "processing", s, interval=0) + shuffle = a.extensions["shuffle"].shuffles[shuffle_id] + await shuffle.in_inputs_done.wait() + await n.process.process.kill() + shuffle.block_inputs_done.set() + + with pytest.raises(RuntimeError, match="shuffle"): + out = await c.compute(out) + + await c.close() + await clean_worker(a) + await clean_scheduler(s) + + +@gen_cluster(client=True, nthreads=[("", 1)] * 2) +async def test_closed_worker_during_unpack(c, s, a, b): + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-03-01", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + out = out.persist() + await wait_for_tasks_in_state("shuffle-p2p", "memory", 1, b) await b.close() - with pytest.raises(Exception) as e: + with pytest.raises(RuntimeError): out = await c.compute(out) - assert b.address in str(e.value) + await c.close() + await clean_worker(a) + await clean_worker(b) + await clean_scheduler(s) + + +@pytest.mark.slow +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_crashed_worker_during_unpack(c, s, a): + async with Nanny(s.address, nthreads=2) as n: + killed_worker_address = n.worker_address + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-03-01", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + out = out.persist() + await wait_until_worker_has_tasks("shuffle-p2p", killed_worker_address, 1, s) + await n.process.process.kill() + with pytest.raises( + RuntimeError, + ): + out = await c.compute(out) + + await c.close() + await clean_worker(a) + await clean_scheduler(s) + + +class BlockedRegisterCompleteShuffleWorkerExtension(ShuffleWorkerExtension): + def __init__(self, worker: Worker) -> None: + super().__init__(worker) + self.in_register_complete = asyncio.Event() + self.block_register_complete = asyncio.Event() + + async def _register_complete(self, shuffle: Shuffle) -> None: + self.in_register_complete.set() + await super()._register_complete(shuffle) + await self.block_register_complete.wait() + + +@pytest.mark.parametrize("kill_barrier", [True, False]) +@gen_cluster( + client=True, + worker_kwargs={ + "extensions": {"shuffle": BlockedRegisterCompleteShuffleWorkerExtension} + }, + nthreads=[("", 1)] * 2, +) +async def test_closed_worker_during_final_register_complete(c, s, a, b, kill_barrier): + + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-01-10", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + out = dd.shuffle.shuffle(df, 
"x", shuffle="p2p") + out = out.persist() + shuffle_ext_a = a.extensions["shuffle"] + shuffle_ext_b = b.extensions["shuffle"] + await shuffle_ext_a.in_register_complete.wait() + await shuffle_ext_b.in_register_complete.wait() + + shuffle_id = await get_shuffle_id(s) + barrier_key = s.extensions["shuffle"].barrier_key(shuffle_id) + # TODO: properly parametrize over kill_barrier + if barrier_key in b.state.tasks: + shuffle_ext_a.block_register_complete.set() + while a.state.executing: + await asyncio.sleep(0.01) + b.batched_stream.abort() + else: + shuffle_ext_b.block_register_complete.set() + while b.state.executing: + await asyncio.sleep(0.01) + a.batched_stream.abort() + + with pytest.raises(RuntimeError, match="shuffle_unpack failed"): + out = await c.compute(out) + + shuffle_ext_b.block_register_complete.set() + + # something is holding on to refs of out s.t. we cannot release the futures. + # The shuffle will only be cleaned up once the tasks area released + await c.close() + await clean_worker(a) + await clean_worker(b) + await clean_scheduler(s) + + +@gen_cluster( + client=True, + worker_kwargs={ + "extensions": {"shuffle": BlockedRegisterCompleteShuffleWorkerExtension} + }, + nthreads=[("", 1)] * 2, +) +async def test_closed_other_worker_during_final_register_complete(c, s, a, b): + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-01-10", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + out = out.persist() + shuffle_ext_a = a.extensions["shuffle"] + shuffle_ext_b = b.extensions["shuffle"] + await shuffle_ext_a.in_register_complete.wait() + await shuffle_ext_b.in_register_complete.wait() + + shuffle_ext_b.block_register_complete.set() + while b.state.executing: + await asyncio.sleep(0.01) + await b.close() + + shuffle_ext_a.block_register_complete.set() + with pytest.raises(RuntimeError): + out = await c.compute(out) - # clean_worker(a) # TODO: clean up on exception - # clean_worker(b) - # clean_scheduler(s) + await c.close() + await clean_worker(a) + await clean_worker(b) + await clean_scheduler(s) @gen_cluster(client=True) async def test_heartbeat(c, s, a, b): await a.heartbeat() - clean_scheduler(s) + await clean_scheduler(s) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -178,9 +646,10 @@ async def test_heartbeat(c, s, a, b): assert s.extensions["shuffle"].heartbeats.values() await out - clean_worker(a) - clean_worker(b) - clean_scheduler(s) + await clean_worker(a) + await clean_worker(b) + del out + await clean_scheduler(s) def test_processing_chain(): @@ -272,9 +741,10 @@ async def test_head(c, s, a, b): assert list(os.walk(a.local_directory)) == a_files # cleaned up files? 
assert list(os.walk(b.local_directory)) == b_files - clean_worker(a) - clean_worker(b) - clean_scheduler(s) + await clean_worker(a) + await clean_worker(b) + del out + await clean_scheduler(s) def test_split_by_worker(): @@ -286,6 +756,24 @@ def test_split_by_worker(): s = pd.Series(worker_for, name="_worker").astype("category") +@gen_cluster(client=True, nthreads=[("", 1)] * 2) +async def test_clean_after_forgotten_early(c, s, a, b): + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-03-01", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + out = out.persist() + await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, a) + await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b) + del out + await clean_worker(a, timeout=2) + await clean_worker(b, timeout=2) + await clean_scheduler(s, timeout=2) + + @gen_cluster(client=True) async def test_tail(c, s, a, b): df = dask.datasets.timeseries( @@ -305,11 +793,12 @@ async def test_tail(c, s, a, b): assert len(s.tasks) < ntasks_full del partial - clean_worker(a) - clean_worker(b) - clean_scheduler(s) + await clean_worker(a) + await clean_worker(b) + await clean_scheduler(s) +@pytest.mark.xfail(reason="Tombstone prohibits multiple calls to head") @gen_cluster(client=True, nthreads=[("127.0.0.1", 4)] * 2) async def test_repeat(c, s, a, b): df = dask.datasets.timeseries( @@ -321,21 +810,113 @@ async def test_repeat(c, s, a, b): out = dd.shuffle.shuffle(df, "x", shuffle="p2p") await c.compute(out.head(compute=False)) - clean_worker(a) - clean_worker(b) - clean_scheduler(s) + await clean_worker(a, timeout=2) + await clean_worker(b, timeout=2) + await clean_scheduler(s, timeout=2) await c.compute(out.tail(compute=False)) - clean_worker(a) - clean_worker(b) - clean_scheduler(s) + await clean_worker(a, timeout=2) + await clean_worker(b, timeout=2) + await clean_scheduler(s, timeout=2) await c.compute(out.head(compute=False)) - clean_worker(a) - clean_worker(b) - clean_scheduler(s) + await clean_worker(a, timeout=2) + await clean_worker(b, timeout=2) + await clean_scheduler(s, timeout=2) + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_crashed_worker_after_shuffle(c, s, a): + in_event = Event() + block_event = Event() + + @dask.delayed + def block(df, in_event, block_event): + in_event.set() + block_event.wait() + return df + + async with Nanny(s.address, nthreads=1) as n: + df = df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-03-01", + dtypes={"x": float, "y": float}, + freq="100 s", + seed=42, + ) + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + in_event = Event() + block_event = Event() + with dask.annotate(workers=[n.worker_address], allow_other_workers=True): + out = block(out, in_event, block_event) + fut = c.compute(out) + + await in_event.wait() + await n.process.process.kill() + block_event.set() + with pytest.raises(RuntimeError): + await fut + + await c.close() + await clean_worker(a) + await clean_scheduler(s) + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_crashed_worker_after_shuffle_persisted(c, s, a): + async with Nanny(s.address, nthreads=1) as n: + df = df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-01-10", + dtypes={"x": float, "y": float}, + freq="10 s", + seed=42, + ) + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + out = out.persist() + await out + + await n.process.process.kill() + + with pytest.raises(RuntimeError): + await c.compute(out.sum()) + + await c.close() + 
await clean_worker(a) + await clean_scheduler(s) + + +@pytest.mark.xfail(reason="Tombstone prohibits multiple calls to head") +@gen_cluster(client=True, nthreads=[("", 1)] * 3) +async def test_closed_worker_between_repeats(c, s, w1, w2, w3): + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-01-10", + dtypes={"x": float, "y": float}, + freq="100 s", + seed=42, + ) + out = dd.shuffle.shuffle(df, "x", shuffle="p2p") + await c.compute(out.head(compute=False)) + + await clean_worker(w1) + await clean_worker(w2) + await clean_worker(w3) + await clean_scheduler(s) + + await w3.close() + await c.compute(out.tail(compute=False)) + + await clean_worker(w1) + await clean_worker(w2) + await clean_scheduler(s) + + await w2.close() + await c.compute(out.head(compute=False)) + await clean_worker(w1) + await clean_scheduler(s) @gen_cluster(client=True) @@ -353,12 +934,13 @@ async def test_new_worker(c, s, a, b): async with Worker(s.address) as w: - out = await c.compute(persisted) + await c.compute(persisted) - clean_worker(a) - clean_worker(b) - clean_worker(w) - clean_scheduler(s) + await clean_worker(a) + await clean_worker(b) + await clean_worker(w) + del persisted + await clean_scheduler(s) @gen_cluster(client=True) @@ -382,9 +964,9 @@ async def test_multi(c, s, a, b): out = await c.compute(out.size) assert out - clean_worker(a) - clean_worker(b) - clean_scheduler(s) + await clean_worker(a) + await clean_worker(b) + await clean_scheduler(s) @gen_cluster(client=True) @@ -411,10 +993,9 @@ async def test_restrictions(c, s, a, b): assert all(stringify(key) in a.data for key in y.__dask_keys__()) -@pytest.mark.xfail(reason="Don't clean up forgotten shuffles") +@pytest.mark.skip(reason="Fails on CI with unknown cause") @gen_cluster(client=True) async def test_delete_some_results(c, s, a, b): - # FIXME: This works but not reliably. It fails every ~25% of runs df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -428,10 +1009,10 @@ async def test_delete_some_results(c, s, a, b): x = x.partitions[: x.npartitions // 2].persist() await c.compute(x.size) - - clean_worker(a) - clean_worker(b) - clean_scheduler(s) + del x + await clean_worker(a) + await clean_worker(b) + await clean_scheduler(s) @gen_cluster(client=True) @@ -452,9 +1033,11 @@ async def test_add_some_results(c, s, a, b): await c.compute(x.size) - clean_worker(a) - clean_worker(b) - clean_scheduler(s) + await clean_worker(a) + await clean_worker(b) + del x + del y + await clean_scheduler(s) @pytest.mark.slow @@ -472,10 +1055,7 @@ async def test_clean_after_close(c, s, a, b): await asyncio.sleep(0.01) await a.close() - clean_worker(a) - - # clean_scheduler(s) # TODO - # clean_worker(b) # TODO + await clean_worker(a) class PooledRPCShuffle(PooledRPCCall): diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index ae7d93d039..050ca2560e 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -1477,6 +1477,10 @@ def _purge_state(self, ts: TaskState) -> None: ts.next = None ts.done = False ts.coming_from = None + ts.exception = None + ts.traceback = None + ts.traceback_text = "" + ts.traceback_text = "" self.missing_dep_flight.discard(ts) self.ready.discard(ts)
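
Note on the error-handling pattern above: the three task-level wrappers in distributed/shuffle/_shuffle.py (shuffle_transfer, shuffle_barrier, shuffle_unpack) all convert any failure inside the worker extension into a RuntimeError that names the shuffle id, which is what the new tests match against (for example "shuffle_barrier failed"). Below is a minimal, self-contained sketch of that pattern; the helper name and signature are illustrative only and do not exist in the codebase.

    from typing import Any, Callable, TypeVar

    T = TypeVar("T")

    def _wrap_shuffle_op(
        op_name: str, shuffle_id: str, func: Callable[..., T], *args: Any
    ) -> T:
        # Re-raise any extension-level failure as a RuntimeError carrying the
        # shuffle id, mirroring the shuffle_transfer/shuffle_barrier/
        # shuffle_unpack wrappers in the patch above.
        try:
            return func(*args)
        except Exception:
            raise RuntimeError(f"{op_name} failed during shuffle {shuffle_id}")

For example, wrapping a callable that raises ValueError surfaces as RuntimeError("shuffle_barrier failed during shuffle <id>"), with the original exception still reachable via __context__ thanks to Python's implicit exception chaining.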