Skip to content

Commit

Permalink
Remove redundant methods in P2PBarrierTask (#8924)
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter authored Nov 8, 2024
1 parent 9842ae9 commit 26b1061
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 44 deletions.
25 changes: 0 additions & 25 deletions distributed/shuffle/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import dask.config
from dask._task_spec import Task, _inline_recursively
from dask.core import flatten
from dask.sizeof import sizeof
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta

Expand Down Expand Up @@ -601,41 +600,17 @@ def __init__(
super().__init__(key, func, *args, **kwargs)

def copy(self) -> P2PBarrierTask:
self.unpack()
assert self.func is not None
return P2PBarrierTask(
self.key, self.func, *self.args, spec=self.spec, **self.kwargs
)

def __sizeof__(self) -> int:
return super().__sizeof__() + sizeof(self.spec)

def __repr__(self) -> str:
return f"P2PBarrierTask({self.key!r})"

def inline(self, dsk: dict[Key, Any]) -> P2PBarrierTask:
self.unpack()
new_args = _inline_recursively(self.args, dsk)
new_kwargs = _inline_recursively(self.kwargs, dsk)
assert self.func is not None
return P2PBarrierTask(
self.key, self.func, *new_args, spec=self.spec, **new_kwargs
)

def __getstate__(self) -> dict[str, Any]:
state = super().__getstate__()
state["spec"] = self.spec
return state

def __setstate__(self, state: dict[str, Any]) -> None:
super().__setstate__(state)
self.spec = state["spec"]

def __eq__(self, value: object) -> bool:
if not isinstance(value, P2PBarrierTask):
return False
if not super().__eq__(value):
return False
if self.spec != value.spec:
return False
return True
35 changes: 16 additions & 19 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
import dask
import dask.bag as db
from dask import delayed
from dask._task_spec import no_function_cache
from dask.optimization import SubgraphCallable
from dask.tokenize import tokenize
from dask.utils import get_default_shuffle_method, parse_timedelta, tmpfile
Expand Down Expand Up @@ -4934,29 +4933,27 @@ def __setstate__(self, state):

@gen_cluster(client=True)
async def test_robust_undeserializable_function(c, s, a, b, monkeypatch):
with no_function_cache():

class Foo:
def __getstate__(self):
return 1
class Foo:
def __getstate__(self):
return 1

def __setstate__(self, state):
raise MyException("hello")
def __setstate__(self, state):
raise MyException("hello")

def __call__(self, *args):
return 1
def __call__(self, *args):
return 1

future = c.submit(Foo(), 1)
await wait(future)
assert future.status == "error"
with raises_with_cause(RuntimeError, "deserialization", MyException, "hello"):
await future
future = c.submit(Foo(), 1)
await wait(future)
assert future.status == "error"
with raises_with_cause(RuntimeError, "deserialization", MyException, "hello"):
await future

futures = c.map(inc, range(10))
results = await c.gather(futures)
futures = c.map(inc, range(10))
results = await c.gather(futures)

assert results == list(map(inc, range(10)))
assert a.data and b.data
assert results == list(map(inc, range(10)))
assert a.data and b.data


@gen_cluster(client=True)
Expand Down

0 comments on commit 26b1061

Please sign in to comment.