Use Task class instead of tuple #8797

Merged: 23 commits, Oct 24, 2024
Changes from 15 commits
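For orientation, the core change in this PR is that low-level graph entries move from raw tuples to node objects from `dask._task_spec`. Below is a minimal, hypothetical sketch of the before/after shape of a graph entry; it uses only names imported in the diffs that follow (`Task`, `TaskRef`) and is not code from this PR.

```python
# Hedged sketch, not from this PR: how a low-level graph entry changes shape
# when moving from raw tuples to dask._task_spec node classes.
from operator import add

from dask._task_spec import Task, TaskRef

# Old style: a task is a bare tuple of (callable, *args) and dependencies
# are referenced by their raw keys.
dsk_old = {
    "x": 1,
    "y": (add, "x", 10),
}

# New style: each entry is a node object; dependencies are explicit TaskRef
# objects, so a string is never ambiguous between "key" and "literal value".
dsk_new = {
    "x": 1,
    "y": Task("y", add, TaskRef("x"), 10),
}
```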
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
@@ -161,7 +161,7 @@ jobs:
# Increase this value to reset cache if
# continuous_integration/environment-${{ matrix.environment }}.yaml has not
# changed. See also same variable in .pre-commit-config.yaml
CACHE_NUMBER: 2
CACHE_NUMBER: 0
id: cache

- name: Update environment
2 changes: 1 addition & 1 deletion continuous_integration/environment-mindeps.yaml
@@ -23,7 +23,7 @@ dependencies:
# Distributed depends on the latest version of Dask
- pip
- pip:
- git+https://github.com/dask/dask
- git+https://github.com/dask/dask
# test dependencies
- pytest
- pytest-cov
5 changes: 3 additions & 2 deletions distributed/actor.py
@@ -10,16 +10,17 @@

from tornado.ioloop import IOLoop

from dask._task_spec import TaskRef

from distributed.client import Future
from distributed.protocol import to_serialize
from distributed.utils import LateLoopEvent, iscoroutinefunction, sync, thread_state
from distributed.utils_comm import WrappedKey
from distributed.worker import get_client, get_worker

_T = TypeVar("_T")


class Actor(WrappedKey):
class Actor(TaskRef):
"""Controls an object on a remote worker

An actor allows remote control of a stateful object living on a remote
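The Actor change above mirrors the Future change in distributed/client.py below: both classes now subclass `dask._task_spec.TaskRef` instead of the old `WrappedKey`, so client-side handles can act directly as dependency references in new-style graphs. A small, hypothetical sketch of what that relationship enables; the `double` function and the bare `Client()` setup are illustrative, not from the PR:

```python
# Hedged sketch: a Future is now a TaskRef subclass, so it can be embedded
# in a new-style graph as a dependency reference (futures_of() finds it via
# an isinstance(..., TaskRef) check, shown later in this diff).
from dask._task_spec import Task, TaskRef
from distributed import Client


def double(x):
    return 2 * x


client = Client()                      # illustrative local cluster
fut = client.submit(double, 10)        # a Future, which subclasses TaskRef
assert isinstance(fut, TaskRef)

# The future itself can stand in as the dependency of another task node.
t = Task("quadruple", double, fut)
print(t.dependencies)                  # contains fut.key
client.close()
```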
50 changes: 30 additions & 20 deletions distributed/client.py
@@ -37,9 +37,6 @@
cast,
)

if TYPE_CHECKING:
from typing_extensions import TypeAlias

from packaging.version import parse as parse_version
from tlz import first, groupby, merge, partition_all, valmap

@@ -52,7 +49,6 @@
from dask.tokenize import tokenize
from dask.typing import Key, NoDefault, no_default
from dask.utils import (
apply,
ensure_dict,
format_bytes,
funcname,
@@ -74,6 +70,8 @@
from tornado import gen
from tornado.ioloop import IOLoop

from dask._task_spec import DataNode, GraphNode, Task, TaskRef

import distributed.utils
from distributed import cluster_dump, preloading
from distributed import versions as version_module
@@ -123,7 +121,6 @@
thread_state,
)
from distributed.utils_comm import (
WrappedKey,
gather_from_workers,
pack_data,
retry_operation,
@@ -132,6 +129,9 @@
)
from distributed.worker import get_client, get_worker, secede

if TYPE_CHECKING:
from typing_extensions import TypeAlias

logger = logging.getLogger(__name__)

_global_clients: weakref.WeakValueDictionary[int, Client] = (
@@ -250,7 +250,7 @@
pass


class Future(WrappedKey):
class Future(TaskRef):
"""A remotely running computation

A Future is a local proxy to a result running on a remote worker. A user
@@ -598,6 +598,9 @@
except RuntimeError: # closed event loop
pass

def __str__(self):
return repr(self)

[Codecov: added line distributed/client.py#L602 not covered by tests]

def __repr__(self):
if self.type:
return (
@@ -616,6 +619,9 @@
def __await__(self):
return self.result().__await__()

def __hash__(self):
return hash(self._id)


class FutureState:
"""A Future's internal state.
@@ -813,7 +819,7 @@
client: dict[str, dict[str, Any]]


_T_LowLevelGraph: TypeAlias = dict[Key, tuple]
_T_LowLevelGraph: TypeAlias = dict[Key, GraphNode]


def _is_nested(iterable):
@@ -905,7 +911,7 @@
def is_materialized(self) -> bool:
return hasattr(self, "_cached_dict")

def __getitem__(self, key: Key) -> tuple:
def __getitem__(self, key: Key) -> GraphNode:
return self._dict[key]

def __iter__(self) -> Iterator[Key]:
@@ -919,7 +925,7 @@

if not self.kwargs:
dsk = {
key: (self.func,) + args
key: Task(key, self.func, *args)
for key, args in zip(self._keys, zip(*self.iterables))
}

@@ -928,15 +934,15 @@
dsk = {}
for k, v in self.kwargs.items():
if sizeof(v) > 1e5:
vv = dask.delayed(v)
kwargs2[k] = vv._key
dsk.update(vv.dask)
vv = DataNode(k, v)
kwargs2[k] = vv.ref()
dsk[vv.key] = vv

[Codecov: added lines distributed/client.py#L937-L939 not covered by tests]
else:
kwargs2[k] = v

dsk.update(
{
key: (apply, self.func, (tuple, list(args)), kwargs2)
key: Task(key, self.func, *args, **kwargs2)
for key, args in zip(self._keys, zip(*self.iterables))
}
)
@@ -2158,10 +2164,14 @@
if isinstance(workers, (str, Number)):
workers = [workers]

if kwargs:
dsk = {key: (apply, func, list(args), kwargs)}
else:
dsk = {key: (func,) + tuple(args)}
dsk = {
key: Task(
key,
func,
*args,
**kwargs,
)
}
futures = self._graph_to_futures(
dsk,
[key],
@@ -3374,7 +3384,7 @@
"op": "update-graph",
"graph_header": header,
"graph_frames": frames,
"keys": list(keys),
"keys": set(keys),
"internal_priority": internal_priority,
"submitting_task": getattr(thread_state, "key", None),
"fifo_timeout": fifo_timeout,
@@ -4460,7 +4470,7 @@
self,
filename: str = "dask-cluster-dump",
write_from_scheduler: bool | None = None,
exclude: Collection[str] = ("run_spec",),
exclude: Collection[str] = (),
format: Literal["msgpack", "yaml"] = "msgpack",
**storage_options,
):
@@ -6100,7 +6110,7 @@
stack.extend(x.values())
elif type(x) is SubgraphCallable:
stack.extend(x.dsk.values())
elif isinstance(x, WrappedKey):
elif isinstance(x, TaskRef):
if x not in seen:
seen.add(x)
futures.append(x)
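The ClientMap and submit hunks above replace the old `(apply, func, args, kwargs)` tuples with `Task` nodes, and large keyword values are now stored once as `DataNode` entries and referenced via `.ref()`. A hedged sketch of that pattern follows; it uses only names appearing in the diff (`Task`, `DataNode`, the `1e5` threshold), while the helper function and key names are illustrative:

```python
# Hedged sketch of the "large kwarg becomes a DataNode" pattern used above:
# the value is stored once under its own key and every task only carries a
# lightweight reference to it instead of an inlined copy.
from dask._task_spec import DataNode, Task
from dask.sizeof import sizeof


def build_map_graph(func, keys, iterables, **kwargs):
    dsk = {}
    kwargs2 = {}
    for name, value in kwargs.items():
        if sizeof(value) > 1e5:          # threshold taken from the hunk above
            node = DataNode(name, value)
            dsk[node.key] = node
            kwargs2[name] = node.ref()   # reference instead of inline value
        else:
            kwargs2[name] = value
    for key, args in zip(keys, zip(*iterables)):
        dsk[key] = Task(key, func, *args, **kwargs2)
    return dsk
```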
4 changes: 3 additions & 1 deletion distributed/deploy/tests/test_cluster.py
@@ -36,7 +36,9 @@ async def test_repr():

@gen_test()
async def test_cluster_wait_for_worker():
async with LocalCluster(n_workers=2, asynchronous=True) as cluster:
async with LocalCluster(
n_workers=2, asynchronous=True, dashboard_address=":0"
) as cluster:
assert len(cluster.scheduler.workers) == 2
cluster.scale(4)
await cluster.wait_for_workers(4)
24 changes: 19 additions & 5 deletions distributed/deploy/tests/test_local.py
@@ -1066,7 +1066,11 @@ async def test_threads_per_worker_set_to_0():
Warning, match="Setting `threads_per_worker` to 0 has been deprecated."
):
async with LocalCluster(
n_workers=2, processes=False, threads_per_worker=0, asynchronous=True
n_workers=2,
processes=False,
threads_per_worker=0,
asynchronous=True,
dashboard_address=":0",
) as cluster:
assert len(cluster.workers) == 2
assert all(w.state.nthreads < CPU_COUNT for w in cluster.workers.values())
@@ -1170,7 +1174,10 @@ async def test_local_cluster_redundant_kwarg(nanny):
@gen_test()
async def test_cluster_info_sync():
async with LocalCluster(
processes=False, asynchronous=True, scheduler_sync_interval="1ms"
processes=False,
asynchronous=True,
scheduler_sync_interval="1ms",
dashboard_address=":0",
) as cluster:
assert cluster._cluster_info["name"] == cluster.name

Expand All @@ -1197,7 +1204,10 @@ async def test_cluster_info_sync():
@gen_test()
async def test_cluster_info_sync_is_robust_to_network_blips(monkeypatch):
async with LocalCluster(
processes=False, asynchronous=True, scheduler_sync_interval="1ms"
processes=False,
asynchronous=True,
scheduler_sync_interval="1ms",
dashboard_address=":0",
Member Author commented: "These dashboard changes are also unrelated, apologies. If it actually helps I will factor them out, but those tests are typically disjoint from the actual changes, so I hope the review process is not too difficult. This change allows the tests to run in parallel."

) as cluster:
assert cluster._cluster_info["name"] == cluster.name

@@ -1235,7 +1245,9 @@ async def error(*args, **kwargs):
@gen_test()
async def test_cluster_host_used_throughout_cluster(host, use_nanny):
"""Ensure that the `host` kwarg is propagated through scheduler, nanny, and workers"""
async with LocalCluster(host=host, asynchronous=True) as cluster:
async with LocalCluster(
host=host, asynchronous=True, dashboard_address=":0"
) as cluster:
url = urlparse(cluster.scheduler_address)
assert url.hostname == "127.0.0.1"
for worker in cluster.workers.values():
@@ -1249,7 +1261,9 @@ async def test_cluster_host_used_throughout_cluster(host, use_nanny):

@gen_test()
async def test_connect_to_closed_cluster():
async with LocalCluster(processes=False, asynchronous=True) as cluster:
async with LocalCluster(
processes=False, asynchronous=True, dashboard_address=":0"
) as cluster:
async with Client(cluster, asynchronous=True) as c1:
assert await c1.submit(inc, 1) == 2

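The repeated dashboard_address=":0" additions in the test diffs above match the author's comment: binding the dashboard to port 0 lets the OS pick a free port, so test clusters started concurrently do not collide on the default dashboard port. A minimal illustrative sketch (the function name is hypothetical):

```python
# Hedged sketch: dashboard_address=":0" requests an ephemeral port, which
# avoids port clashes when several test clusters run at the same time.
from distributed import LocalCluster


async def start_cluster_for_test():
    async with LocalCluster(
        n_workers=2, asynchronous=True, dashboard_address=":0"
    ) as cluster:
        print(cluster.dashboard_link)   # reflects the randomly assigned port
        return cluster.scheduler_address
```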
37 changes: 14 additions & 23 deletions distributed/recreate_tasks.py
@@ -7,7 +7,6 @@
from distributed.client import Future, futures_of, wait
from distributed.protocol.serialize import ToPickle
from distributed.utils import sync
from distributed.utils_comm import pack_data

logger = logging.getLogger(__name__)

@@ -40,11 +39,8 @@

def get_runspec(self, *args, key=None, **kwargs):
key = self._process_key(key)
ts = self.scheduler.tasks.get(key)
return {
"task": ToPickle(ts.run_spec),
"deps": [dts.key for dts in ts.dependencies],
}
ts = self.scheduler.tasks[key]
return ToPickle(ts.run_spec)

[Codecov: added lines distributed/recreate_tasks.py#L42-L43 not covered by tests]


class ReplayTaskClient:
@@ -61,10 +57,8 @@
self.client = client
self.client.extensions["replay-tasks"] = self
# monkey patch
self.client._get_raw_components_from_future = (
self._get_raw_components_from_future
)
self.client._prepare_raw_components = self._prepare_raw_components
self.client._get_raw_components_from_future = self._get_task_runspec
self.client._prepare_raw_components = self._get_dependencies
self.client._get_components_from_future = self._get_components_from_future
self.client._get_errored_future = self._get_errored_future
self.client.recreate_task_locally = self.recreate_task_locally
@@ -74,7 +68,7 @@
def scheduler(self):
return self.client.scheduler

async def _get_raw_components_from_future(self, future):
async def _get_task_runspec(self, future):
"""
For a given future return the func, args and kwargs and future
deps that would be executed remotely.
@@ -85,28 +79,25 @@
else:
validate_key(future)
key = future
spec = await self.scheduler.get_runspec(key=key)
return (*spec["task"], spec["deps"])
run_spec = await self.scheduler.get_runspec(key=key)
return run_spec

[Codecov: added lines distributed/recreate_tasks.py#L82-L83 not covered by tests]

async def _prepare_raw_components(self, raw_components):
async def _get_dependencies(self, dependencies):
"""
Take raw components and resolve future dependencies.
"""
function, args, kwargs, deps = raw_components
futures = self.client._graph_to_futures({}, deps, span_metadata={})
futures = self.client._graph_to_futures({}, dependencies, span_metadata={})

[Codecov: added line distributed/recreate_tasks.py#L89 not covered by tests]
data = await self.client._gather(futures)
args = pack_data(args, data)
kwargs = pack_data(kwargs, data)
return (function, args, kwargs)
return data

[Codecov: added line distributed/recreate_tasks.py#L91 not covered by tests]

async def _get_components_from_future(self, future):
"""
For a given future return the func, args and kwargs that would be
executed remotely. Any args/kwargs that are themselves futures will
be resolved to the return value of those futures.
"""
raw_components = await self._get_raw_components_from_future(future)
return await self._prepare_raw_components(raw_components)
runspec = await self._get_task_runspec(future)
return runspec, await self._get_dependencies(runspec.dependencies)

[Codecov: added lines distributed/recreate_tasks.py#L99-L100 not covered by tests]

def recreate_task_locally(self, future):
"""
Expand Down Expand Up @@ -137,10 +128,10 @@
-------
Any; will return the result of the task future.
"""
func, args, kwargs = sync(
[Codecov: added line distributed/recreate_tasks.py#L131 not covered by tests]
runspec, dependencies = sync(
self.client.loop, self._get_components_from_future, future
)
return func(*args, **kwargs)
return runspec(dependencies)

[Codecov: added line distributed/recreate_tasks.py#L134 not covered by tests]

async def _get_errored_future(self, future):
"""
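The replay extension above now fetches a task's run_spec as a single pickled node and executes it by calling it with a mapping of resolved dependency values, as in `runspec(dependencies)`. A hedged sketch of that calling convention, assuming the `dask._task_spec` behavior implied by the diff; the keys and values are illustrative:

```python
# Hedged sketch of the calling convention used by recreate_task_locally:
# a Task node is executed by handing it the already-gathered values of its
# dependencies, keyed by dependency key.
from operator import add

from dask._task_spec import Task, TaskRef

t = Task("total", add, TaskRef("a"), TaskRef("b"))
print(t.dependencies)        # the keys this task needs, here "a" and "b"

# In the extension above these values come from gathering the dependency
# futures; here they are supplied directly for illustration.
result = t({"a": 1, "b": 2})
print(result)                # 3
```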