Skip to content

Commit

Permalink
update rewrite_layer_chains with new dask Tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Dec 6, 2024
1 parent 01b0a45 commit 536dfc4
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 21 deletions.
2 changes: 2 additions & 0 deletions src/dask_awkward/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
ImplementsIOFunction,
ImplementsProjection,
IOFunctionWithMocking,
_dask_uses_tasks,
io_func_implements_projection,
)

Expand All @@ -18,4 +19,5 @@
"ImplementsIOFunction",
"IOFunctionWithMocking",
"io_func_implements_projection",
"_dask_uses_tasks",
)
23 changes: 9 additions & 14 deletions src/dask_awkward/layers/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,9 @@
from collections.abc import Callable, Mapping
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, Union, cast

_dask_uses_tasks = True
try:
from dask._task_spec import convert_legacy_graph
except ModuleNotFoundError as _:
import dask

def convert_legacy_graph(dsk, all_keys=None): # type: ignore
return dsk

_dask_uses_tasks = False
# Whether the installed dask uses the new Task objects (``dask._task_spec``)
# for Blockwise layers; this code treats dask >= 2024.12 as the cut-over.
# Compare the (year, month) components numerically: a plain string comparison
# is lexicographic and would order e.g. "2024.2.0" *after* "2024.12.0",
# wrongly enabling the Task code path on older dask releases.
# Non-numeric components (e.g. an unknown/dev version string) are dropped,
# which makes the flag fall back to False.
_dask_uses_tasks = tuple(
    int(part) for part in dask.__version__.split(".")[:2] if part.isdigit()
) >= (2024, 12)

from dask.blockwise import Blockwise, BlockwiseDepDict, blockwise_token
from dask.highlevelgraph import MaterializedLayer
Expand Down Expand Up @@ -170,21 +164,22 @@ def __init__(
produces_tasks=self.produces_tasks,
)

super_kwargs = {
super_kwargs: dict[str, Any] = {
"output": self.name,
"output_indices": "i",
"dsk": {name: (self.io_func, blockwise_token(0))},
"indices": [(io_arg_map, "i")],
"numblocks": {},
"annotations": None,
}

if _dask_uses_tasks:
task = convert_legacy_graph(super_kwargs["dsk"]) # type: ignore
super_kwargs["task"] = task
super_kwargs.pop("dsk")
from dask._task_spec import Task, TaskRef

super_kwargs["task"] = Task(name, self.io_func, TaskRef(blockwise_token(0)))
else:
super_kwargs["dsk"] = {name: (self.io_func, blockwise_token(0))}

super().__init__(**super_kwargs) # type: ignore
super().__init__(**super_kwargs)

def __repr__(self) -> str:
return f"AwkwardInputLayer<{self.output}>"
Expand Down
39 changes: 32 additions & 7 deletions src/dask_awkward/lib/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
from dask.highlevelgraph import HighLevelGraph
from dask.local import get_sync

from dask_awkward.layers import AwkwardBlockwiseLayer, AwkwardInputLayer
from dask_awkward.layers import (
AwkwardBlockwiseLayer,
AwkwardInputLayer,
_dask_uses_tasks,
)
from dask_awkward.lib.utils import typetracer_nochecks
from dask_awkward.utils import first

Expand Down Expand Up @@ -340,7 +344,10 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG
deps[outkey] = deps[chain[0]]
[deps.pop(ch) for ch in chain[:-1]]

subgraph = layer0.dsk.copy() # mypy: ignore
if _dask_uses_tasks:
all_tasks = [layer0.task]
else:
subgraph = layer0.dsk.copy()
indices = list(layer0.indices)
parent = chain[0]

Expand All @@ -349,14 +356,27 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG
layer = dsk.layers[chain_member]
for k in layer.io_deps: # mypy: ignore
outlayer.io_deps[k] = layer.io_deps[k]
func, *args = layer.dsk[chain_member] # mypy: ignore
args2 = _recursive_replace(args, layer, parent, indices)
subgraph[chain_member] = (func,) + tuple(args2)

if _dask_uses_tasks:
from dask._task_spec import Task

func = layer.task.func
args = layer.task.dependencies
# how to do this with `.substitute(...)`?
args2 = _recursive_replace(args, layer, parent, indices)
all_tasks.append(Task(chain_member, func, *args2))
else:
func, *args = layer.dsk[chain_member] # mypy: ignore
args2 = _recursive_replace(args, layer, parent, indices)
subgraph[chain_member] = (func,) + tuple(args2)
parent = chain_member
outlayer.numblocks = {
i[0]: (numblocks,) for i in indices if i[1] is not None
} # mypy: ignore
outlayer.dsk = subgraph # mypy: ignore
if _dask_uses_tasks:
outlayer.task = Task.fuse(*all_tasks)
else:
outlayer.dsk = subgraph # mypy: ignore
if hasattr(outlayer, "_dims"):
del outlayer._dims
outlayer.indices = tuple( # mypy: ignore
Expand All @@ -379,7 +399,12 @@ def _recursive_replace(args, layer, parent, indices):
args2.append(layer.indices[ind][0])
elif layer.indices[ind][0] == parent:
# arg refers to output of previous layer
args2.append(parent)
if _dask_uses_tasks:
from dask._task_spec import TaskRef

args2.append(TaskRef(parent))
else:
args2.append(parent)
else:
# arg refers to things defined in io_deps
indices.append(layer.indices[ind])
Expand Down

0 comments on commit 536dfc4

Please sign in to comment.