Skip to content

Commit

Permalink
cleaner way of dealing with it
Browse files Browse the repository at this point in the history
  • Loading branch information
lgray authored Dec 4, 2024
1 parent c25e1f6 commit 7c2174c
Showing 1 changed file with 17 additions and 18 deletions.
35 changes: 17 additions & 18 deletions src/dask_awkward/layers/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

_dask_uses_tasks = True
try:
from dask._task_spec import Task
from dask._task_spec import convert_legacy_graph
except ModuleNotFoundError as _:
def convert_legacy_graph(_):
return _
_dask_uses_tasks = False

from dask.blockwise import Blockwise, BlockwiseDepDict, blockwise_token
Expand Down Expand Up @@ -166,24 +168,21 @@ def __init__(
produces_tasks=self.produces_tasks,
)

super_kwargs = {
"output": self.name,
"output_indices": "i",
"dsk": {name: (self.io_func, blockwise_token(0))},
"indices": [(io_arg_map, "i")],
"numblocks": {},
"annotations": None,
}

if _dask_uses_tasks:
super().__init__( # type: ignore
output=self.name,
output_indices="i",
task=Task(name, self.io_func, blockwise_token(0)),
indices=[(io_arg_map, "i")],
numblocks={},
annotations=None,
)
else:
super().__init__(
output=self.name,
output_indices="i",
dsk={name: (self.io_func, blockwise_token(0))},
indices=[(io_arg_map, "i")],
numblocks={},
annotations=None,
)
task = convert_legacy_graph(super_args["dsk"])
super_args["task"] = task
super_args.pop("dsk")

super().__init__(**super_kwargs)

def __repr__(self) -> str:
return f"AwkwardInputLayer<{self.output}>"
Expand Down

0 comments on commit 7c2174c

Please sign in to comment.