Commit: simplify

Tom Augspurger committed Jun 17, 2021
1 parent 0bfa2b6 commit 6d190ff
Showing 2 changed files with 47 additions and 107 deletions.
59 changes: 42 additions & 17 deletions pangeo_forge_recipes/recipes/base.py
@@ -115,23 +115,44 @@ def to_dask(self):
         """
         Translate the recipe to a dask.Delayed object for parallel execution.
         """
+        # This manually builds a Dask task graph with each stage of the recipe.
+        # We use a few "checkpoints" to ensure that downstream tasks depend
+        # on upstream tasks being done before starting. We use a manual task
+        # graph rather than dask.delayed to avoid some expensive tokenization
+        # in dask.delayed.
         import dask
+        from dask.delayed import Delayed
 
-        tasks = []
-        if getattr(self, "cache_inputs", False):
-            f = dask.delayed(self.cache_input)
-            for input_key in self.iter_inputs():
-                tasks.append(f(input_key))
+        # TODO: HighlevelGraph layers for each of these mapped inputs.
+        # Cache Input --------------------------------------------------------
+        dsk = {}
+        token = dask.base.tokenize(self)
+
+        if getattr(self, "cache_inputs", False):  # TODO: formalize cache_inputs
+            for i, input_key in enumerate(self.iter_inputs()):
+                dsk[(f"cache_input-{token}", i)] = (self.cache_input, input_key)
 
-        b0 = dask.delayed(_barrier)(*tasks)
-        b1 = dask.delayed(_wait_and_call)(self.prepare_target, b0)
-        tasks = []
-        for chunk_key in self.iter_chunks():
-            tasks.append(dask.delayed(_wait_and_call)(self.store_chunk, b1, chunk_key))
+        # Prepare Target ------------------------------------------------------
+        dsk[f"checkpoint_0-{token}"] = (lambda *args: None, list(dsk))
+        dsk[f"prepare_target-{token}"] = (
+            _prepare_target,
+            f"checkpoint_0-{token}",
+            self.prepare_target,
+        )
+
+        # Store Chunk --------------------------------------------------------
+        keys = []
+        for i, chunk_key in enumerate(self.iter_chunks()):
+            k = (f"store_chunk-{token}", i)
+            dsk[k] = (_store_chunk, f"prepare_target-{token}", self.store_chunk, chunk_key)
+            keys.append(k)
 
-        b2 = dask.delayed(_barrier)(*tasks)
-        b3 = dask.delayed(_wait_and_call)(self.finalize_target, b2)
-        return b3
+        # Finalize Target -----------------------------------------------------
+        dsk[f"checkpoint_1-{token}"] = (lambda *args: None, keys)
+        key = f"finalize_target-{token}"
+        dsk[key] = (_finalize_target, f"checkpoint_1-{token}", self.finalize_target)
+
+        return Delayed(key, dsk)
 
     def to_prefect(self):
         """Compile the recipe to a Prefect.Flow object."""
@@ -189,9 +210,13 @@ def wrapped(*args, **kwargs):
     return wrapped
 
 
-def _barrier(*args):
-    pass
+def _prepare_target(checkpoint, func):
+    return func()
+
+
+def _store_chunk(checkpoint, func, input_key):
+    return func(input_key)
 
 
-def _wait_and_call(func, b, *args):
-    return func(*args)
+def _finalize_target(checkpoint, func):
+    return func()
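
A note on these replacements for _barrier/_wait_and_call: the checkpoint argument carries no data, it only threads the dependency through the graph, and the helpers live at module level so they pickle by reference. A rough, self-contained illustration (the closure below is invented for contrast, it is not from the codebase):

import pickle


def _store_chunk(checkpoint, func, input_key):
    return func(input_key)


# What the scheduler effectively calls for one graph entry; the upstream
# checkpoint always resolves to None:
_store_chunk(None, print, "chunk-0")

# Module-level functions pickle by reference:
pickle.dumps(_store_chunk)


# A closure capturing `func` would not survive pickling:
def make_store_chunk(func):
    def _inner(checkpoint, input_key):
        return func(input_key)
    return _inner


try:
    pickle.dumps(make_store_chunk(print))
except (pickle.PicklingError, AttributeError) as err:
    print("closure does not pickle:", err)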
95 changes: 5 additions & 90 deletions pangeo_forge_recipes/recipes/xarray_zarr.py
@@ -14,7 +14,6 @@
 import numpy as np
 import xarray as xr
 import zarr
-from dask.delayed import Delayed
 
 from ..patterns import FilePattern, prune_pattern
 from ..storage import AbstractTarget, CacheFSSpecTarget, MetadataTarget, file_opener
@@ -24,7 +23,7 @@
     fix_scalar_attr_encoding,
     lock_for_conflicts,
 )
-from .base import BaseRecipe, closure
+from .base import BaseRecipe
 
 # use this filename to store global recipe metadata in the metadata_cache
 # it will be written once (by prepare_target) and read many times (by store_chunk)
@@ -693,7 +692,7 @@ def _set_target_chunks(self):
     # to serialize.
 
     @property
-    def _prepare_target(self):
+    def prepare_target(self):
         return functools.partial(
             prepare_target,
             target=self.target,
@@ -716,13 +715,8 @@ def _prepare_target(self):
             metadata_cache=self.metadata_cache,
         )
 
-    @property  # type: ignore
-    @closure
-    def prepare_target(self) -> None:
-        return self._prepare_target()
-
     @property
-    def _cache_input(self):
+    def cache_input(self):
         return functools.partial(
             cache_input,
             cache_inputs=self.cache_inputs,
@@ -737,13 +731,8 @@ def _cache_input(self):
             metadata_cache=self.metadata_cache,
         )
 
-    @property  # type: ignore
-    @closure
-    def cache_input(self, input_key: InputKey) -> None:  # type: ignore
-        return self._cache_input(input_key)
-
     @property
-    def _store_chunk(self):
+    def store_chunk(self):
         return functools.partial(
             store_chunk,
             target=self.target,
@@ -766,25 +755,12 @@ def _store_chunk(self):
             metadata_cache=self.metadata_cache,
         )
 
-    @property  # type: ignore
-    @closure
-    def store_chunk(self, chunk_key: ChunkKey) -> None:  # type: ignore
-        # TODO(TOM): Restore the cache lookup
-        # assert isinstance(self.target, CacheFSSpecTarget)  # TODO(mypy): check optional
-        return self._store_chunk(chunk_key)
-
     @property
-    def _finalize_target(self):
+    def finalize_target(self):
         return functools.partial(
             finalize_target, target=self.target, consolidate_zarr=self.consolidate_zarr
         )
 
-    @property  # type: ignore
-    @closure
-    def finalize_target(self) -> None:
-        # assert isinstance(self.finalize_target, CacheFSSpecTarget)  # TODO(mypy): check optional
-        return self._finalize_target()
-
     def iter_inputs(self):
         for input in self._inputs_chunks:
             yield input
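
The pattern these hunks converge on: each stage is a property that binds instance configuration into a module-level function with functools.partial, yielding a plain, picklable callable that can be dropped into a Dask graph or wrapped as a Prefect task. A minimal sketch; write_chunk, target, and consolidate_zarr are invented stand-ins, not the real signatures:

import functools


def write_chunk(chunk_key, target=None, consolidate_zarr=True):
    # Stand-in for the module-level store_chunk/finalize_target functions.
    print(f"storing {chunk_key} to {target} (consolidate={consolidate_zarr})")


class MyRecipe:
    def __init__(self, target):
        self.target = target

    @property
    def store_chunk(self):
        # Bind configuration now; the caller supplies only chunk_key later.
        return functools.partial(
            write_chunk, target=self.target, consolidate_zarr=True
        )


MyRecipe("store.zarr").store_chunk("chunk-0")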
@@ -796,64 +772,3 @@ def iter_chunks(self):
     def inputs_for_chunk(self, chunk_key: ChunkKey) -> Tuple[InputKey]:
         """Convenience function for users to introspect recipe."""
         return self._chunks_inputs[chunk_key]
-
-    def to_dask(self):
-        """Convert the Recipe to a dask.delayed.Delayed object."""
-        # This manually builds a Dask task graph with each stage of the recipe.
-        # We use a few "checkpoints" to ensure that downstream tasks depend
-        # on upstream tasks being done before starting.
-
-        # TODO: HighlevelGraph layers for each of these mapped inputs.
-        # Cache Input --------------------------------------------------------
-        dsk = {}
-        token = dask.base.tokenize(self)
-
-        for i, input_key in enumerate(self.iter_inputs()):
-            dsk[(f"cache_input-{token}", i)] = (self._cache_input, input_key)
-        dsk[f"checkpoint_0-{token}"] = (lambda *args: None, list(dsk))
-
-        # Prepare Target -----------------------------------------------------
-        def prepare_target2(checkpoint):
-            return self._prepare_target()
-
-        dsk[f"prepare_target-{token}"] = (prepare_target2, f"checkpoint_0-{token}")
-
-        # Store Chunk --------------------------------------------------------
-        def store_chunk2(checkpoint, input_key):
-            return self._store_chunk(input_key)
-
-        keys = []
-        for i, chunk_key in enumerate(self.iter_chunks()):
-            k = (f"store_chunk-{token}", i)
-            dsk[k] = (store_chunk2, f"prepare_target-{token}", chunk_key)
-            keys.append(k)
-
-        dsk[f"checkpoint_1-{token}"] = (lambda *args: None, keys)
-
-        # Finalize Target ----------------------------------------------------
-        def finalize_target2(checkpoint):
-            return self._finalize_target()
-
-        key = f"finalize_target-{token}"
-        dsk[key] = (finalize_target2, f"checkpoint_1-{token}")
-
-        return Delayed(key, dsk)
-
-    def to_prefect(self):
-        """Compile the recipe to a Prefect.Flow object."""
-        from prefect import Flow, task, unmapped
-
-        cache_input_task = task(self._cache_input, name="cache_input")
-        prepare_target_task = task(self._prepare_target, name="prepare_target")
-        store_chunk_task = task(self._store_chunk, name="store_chunk")
-        finalize_target_task = task(self._finalize_target, name="finalize_target")
-
-        with Flow("pangeo-forge-recipe") as flow:
-            cache_task = cache_input_task.map(input_key=list(self.iter_inputs()))
-            prepare_task = prepare_target_task(upstream_tasks=[cache_task])
-            store_task = store_chunk_task.map(
-                chunk_key=list(self.iter_chunks()), upstream_tasks=[unmapped(prepare_task)],
-            )
-            _ = finalize_target_task(upstream_tasks=[store_task])
-
-        return flow
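
The to_prefect deleted here survives on the base class (see the base.py context above). For reference, the barrier-style flow it builds looks like this when reduced to a runnable sketch, using the Prefect 0.x-era API shown in the removed code, with print tasks standing in for the recipe stages:

from prefect import Flow, task, unmapped


@task
def cache_input(input_key):
    print("cache", input_key)


@task
def prepare_target():
    print("prepare")


@task
def store_chunk(chunk_key):
    print("store", chunk_key)


@task
def finalize_target():
    print("finalize")


with Flow("demo") as flow:
    cached = cache_input.map(input_key=[0, 1, 2])
    prepared = prepare_target(upstream_tasks=[cached])
    stored = store_chunk.map(
        chunk_key=[0, 1], upstream_tasks=[unmapped(prepared)]
    )
    finalize_target(upstream_tasks=[stored])

flow.run()

Mapped tasks fan out over their inputs, and unmapped(prepared) makes every store_chunk invocation depend on the single prepare step, mirroring the checkpoint keys in the Dask graph.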
