From 26b106176fd83bcbefb818d1aa1437fd72a79829 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Fri, 8 Nov 2024 13:06:38 +0100 Subject: [PATCH] Remove redundant methods in P2PBarrierTask (#8924) --- distributed/shuffle/_core.py | 25 ----------------------- distributed/tests/test_client.py | 35 +++++++++++++++----------------- 2 files changed, 16 insertions(+), 44 deletions(-) diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index 5382252b44..c782d56e38 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -27,7 +27,6 @@ import dask.config from dask._task_spec import Task, _inline_recursively from dask.core import flatten -from dask.sizeof import sizeof from dask.typing import Key from dask.utils import parse_bytes, parse_timedelta @@ -601,41 +600,17 @@ def __init__( super().__init__(key, func, *args, **kwargs) def copy(self) -> P2PBarrierTask: - self.unpack() - assert self.func is not None return P2PBarrierTask( self.key, self.func, *self.args, spec=self.spec, **self.kwargs ) - def __sizeof__(self) -> int: - return super().__sizeof__() + sizeof(self.spec) - def __repr__(self) -> str: return f"P2PBarrierTask({self.key!r})" def inline(self, dsk: dict[Key, Any]) -> P2PBarrierTask: - self.unpack() new_args = _inline_recursively(self.args, dsk) new_kwargs = _inline_recursively(self.kwargs, dsk) assert self.func is not None return P2PBarrierTask( self.key, self.func, *new_args, spec=self.spec, **new_kwargs ) - - def __getstate__(self) -> dict[str, Any]: - state = super().__getstate__() - state["spec"] = self.spec - return state - - def __setstate__(self, state: dict[str, Any]) -> None: - super().__setstate__(state) - self.spec = state["spec"] - - def __eq__(self, value: object) -> bool: - if not isinstance(value, P2PBarrierTask): - return False - if not super().__eq__(value): - return False - if self.spec != value.spec: - return False - return True diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 8f6f0bcc43..016600b09e 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -42,7 +42,6 @@ import dask import dask.bag as db from dask import delayed -from dask._task_spec import no_function_cache from dask.optimization import SubgraphCallable from dask.tokenize import tokenize from dask.utils import get_default_shuffle_method, parse_timedelta, tmpfile @@ -4934,29 +4933,27 @@ def __setstate__(self, state): @gen_cluster(client=True) async def test_robust_undeserializable_function(c, s, a, b, monkeypatch): - with no_function_cache(): - - class Foo: - def __getstate__(self): - return 1 + class Foo: + def __getstate__(self): + return 1 - def __setstate__(self, state): - raise MyException("hello") + def __setstate__(self, state): + raise MyException("hello") - def __call__(self, *args): - return 1 + def __call__(self, *args): + return 1 - future = c.submit(Foo(), 1) - await wait(future) - assert future.status == "error" - with raises_with_cause(RuntimeError, "deserialization", MyException, "hello"): - await future + future = c.submit(Foo(), 1) + await wait(future) + assert future.status == "error" + with raises_with_cause(RuntimeError, "deserialization", MyException, "hello"): + await future - futures = c.map(inc, range(10)) - results = await c.gather(futures) + futures = c.map(inc, range(10)) + results = await c.gather(futures) - assert results == list(map(inc, range(10))) - assert a.data and b.data + assert results == list(map(inc, range(10))) + assert a.data and b.data @gen_cluster(client=True)