diff --git a/src/dask_awkward/layers/__init__.py b/src/dask_awkward/layers/__init__.py
index d4ba4c5e..098bbf2e 100644
--- a/src/dask_awkward/layers/__init__.py
+++ b/src/dask_awkward/layers/__init__.py
@@ -6,6 +6,7 @@
     ImplementsIOFunction,
     ImplementsProjection,
     IOFunctionWithMocking,
+    _dask_uses_tasks,
     io_func_implements_projection,
 )
 
@@ -18,4 +19,5 @@
     "ImplementsIOFunction",
     "IOFunctionWithMocking",
     "io_func_implements_projection",
+    "_dask_uses_tasks",
 )
diff --git a/src/dask_awkward/layers/layers.py b/src/dask_awkward/layers/layers.py
index b925e35a..24b33c99 100644
--- a/src/dask_awkward/layers/layers.py
+++ b/src/dask_awkward/layers/layers.py
@@ -4,15 +4,16 @@
 from collections.abc import Callable, Mapping
 from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, Union, cast
 
-_dask_uses_tasks = True
-try:
-    from dask._task_spec import convert_legacy_graph
-except ModuleNotFoundError as _:
+import dask
 
-    def convert_legacy_graph(dsk, all_keys=None):  # type: ignore
-        return dsk
-
-    _dask_uses_tasks = False
+# Compare the (year, month) release numerically: as plain strings
+# "2024.9.0" sorts after "2024.12.0" and would wrongly enable Task mode.
+try:
+    _dask_release = tuple(int(part) for part in dask.__version__.split(".")[:2])
+except ValueError:
+    # Unparseable version string (e.g. an untagged build): assume legacy graphs.
+    _dask_release = (0, 0)
+_dask_uses_tasks = _dask_release >= (2024, 12)
 
 from dask.blockwise import Blockwise, BlockwiseDepDict, blockwise_token
 from dask.highlevelgraph import MaterializedLayer
@@ -170,21 +171,22 @@ def __init__(
             produces_tasks=self.produces_tasks,
         )
 
-        super_kwargs = {
+        super_kwargs: dict[str, Any] = {
             "output": self.name,
             "output_indices": "i",
-            "dsk": {name: (self.io_func, blockwise_token(0))},
             "indices": [(io_arg_map, "i")],
             "numblocks": {},
             "annotations": None,
         }
 
         if _dask_uses_tasks:
-            task = convert_legacy_graph(super_kwargs["dsk"])  # type: ignore
-            super_kwargs["task"] = task
-            super_kwargs.pop("dsk")
+            from dask._task_spec import Task, TaskRef
+
+            super_kwargs["task"] = Task(name, self.io_func, TaskRef(blockwise_token(0)))
+        else:
+            super_kwargs["dsk"] = {name: (self.io_func, blockwise_token(0))}
 
-        super().__init__(**super_kwargs)  # type: ignore
+        super().__init__(**super_kwargs)
 
     def __repr__(self) -> str:
         return f"AwkwardInputLayer<{self.output}>"
diff --git a/src/dask_awkward/lib/optimize.py b/src/dask_awkward/lib/optimize.py
index 6ad2e132..d75d51e3 100644
--- a/src/dask_awkward/lib/optimize.py
+++ b/src/dask_awkward/lib/optimize.py
@@ -13,7 +13,11 @@
 from dask.highlevelgraph import HighLevelGraph
 from dask.local import get_sync
 
-from dask_awkward.layers import AwkwardBlockwiseLayer, AwkwardInputLayer
+from dask_awkward.layers import (
+    AwkwardBlockwiseLayer,
+    AwkwardInputLayer,
+    _dask_uses_tasks,
+)
 from dask_awkward.lib.utils import typetracer_nochecks
 from dask_awkward.utils import first
 
@@ -340,7 +344,14 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG
         deps[outkey] = deps[chain[0]]
         [deps.pop(ch) for ch in chain[:-1]]
 
-        subgraph = layer0.dsk.copy()  # mypy: ignore
+        if _dask_uses_tasks:
+            # Imported once here so the names are bound for both the loop
+            # below and the final `Task.fuse` call.
+            from dask._task_spec import Alias, Task, TaskRef
+
+            all_tasks = [layer0.task]
+        else:
+            subgraph = layer0.dsk.copy()
         indices = list(layer0.indices)
         parent = chain[0]
 
@@ -349,14 +360,30 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG
             layer = dsk.layers[chain_member]
             for k in layer.io_deps:  # mypy: ignore
                 outlayer.io_deps[k] = layer.io_deps[k]
-            func, *args = layer.dsk[chain_member]  # mypy: ignore
-            args2 = _recursive_replace(args, layer, parent, indices)
-            subgraph[chain_member] = (func,) + tuple(args2)
+
+            if _dask_uses_tasks:
+                func = layer.task.func
+                # Read the positional args in call order; `.dependencies` is an
+                # unordered frozenset of keys and drops literal arguments.
+                args = [
+                    arg.key if isinstance(arg, (Alias, TaskRef)) else arg
+                    for arg in layer.task.args
+                ]
+                # how to do this with `.substitute(...)`?
+                args2 = _recursive_replace(args, layer, parent, indices)
+                all_tasks.append(Task(chain_member, func, *args2))
+            else:
+                func, *args = layer.dsk[chain_member]  # mypy: ignore
+                args2 = _recursive_replace(args, layer, parent, indices)
+                subgraph[chain_member] = (func,) + tuple(args2)
             parent = chain_member
         outlayer.numblocks = {
             i[0]: (numblocks,) for i in indices if i[1] is not None
         }  # mypy: ignore
-        outlayer.dsk = subgraph  # mypy: ignore
+        if _dask_uses_tasks:
+            outlayer.task = Task.fuse(*all_tasks)
+        else:
+            outlayer.dsk = subgraph  # mypy: ignore
         if hasattr(outlayer, "_dims"):
             del outlayer._dims
         outlayer.indices = tuple(  # mypy: ignore
@@ -379,7 +406,12 @@ def _recursive_replace(args, layer, parent, indices):
                 args2.append(layer.indices[ind][0])
             elif layer.indices[ind][0] == parent:
                 # arg refers to output of previous layer
-                args2.append(parent)
+                if _dask_uses_tasks:
+                    from dask._task_spec import TaskRef
+
+                    args2.append(TaskRef(parent))
+                else:
+                    args2.append(parent)
             else:
                 # arg refers to things defined in io_deps
                 indices.append(layer.indices[ind])