Skip to content

Commit

Permalink
Add Layer to public API
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Oct 26, 2022
1 parent 63a1115 commit 5ce0f35
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 17 deletions.
5 changes: 3 additions & 2 deletions distributed/shuffle/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

from distributed.shuffle._shuffle import rearrange_by_column_p2p
from distributed.shuffle._shuffle import P2PShuffleLayer, rearrange_by_column_p2p
from distributed.shuffle._shuffle_extension import (
ShuffleSchedulerExtension,
ShuffleWorkerExtension,
)

__all__ = [
"P2PShuffleLayer",
"rearrange_by_column_p2p",
"ShuffleSchedulerExtension",
"ShuffleWorkerExtension",
"rearrange_by_column_p2p",
]
30 changes: 15 additions & 15 deletions distributed/shuffle/_shuffle.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from dask.base import tokenize
from dask.highlevelgraph import HighLevelGraph
Expand Down Expand Up @@ -93,15 +93,15 @@ def rearrange_by_column_p2p(
class P2PShuffleLayer(SimpleShuffleLayer):
def __init__(
self,
name,
column,
npartitions,
npartitions_input,
ignore_index,
name_input,
meta_input,
parts_out=None,
annotations=None,
name: str,
column: str,
npartitions: int,
npartitions_input: int,
ignore_index: bool,
name_input: str,
meta_input: pd.DataFrame,
parts_out: list | None = None,
annotations: dict | None = None,
):
annotations = annotations or {}
annotations.update({"shuffle": lambda key: key[1]})
Expand All @@ -117,16 +117,16 @@ def __init__(
annotations=annotations,
)

def get_split_keys(self):
def get_split_keys(self) -> list:
# TODO: This is doing some funky stuff to set priorities but we don't need this
return []

def __repr__(self):
def __repr__(self) -> str:
return (
f"{type(self).__name__}<name='{self.name}', npartitions={self.npartitions}>"
)

def _cull(self, parts_out):
def _cull(self, parts_out: list) -> P2PShuffleLayer:
return P2PShuffleLayer(
self.name,
self.column,
Expand All @@ -138,9 +138,9 @@ def _cull(self, parts_out):
parts_out=parts_out,
)

def _construct_graph(self, deserializing=None):
def _construct_graph(self, deserializing: Any = None) -> dict[tuple, tuple]:
token = tokenize(self.name_input, self.column, self.npartitions, self.parts_out)
dsk = {}
dsk: dict[tuple, tuple] = {}
barrier_key = "shuffle-barrier-" + token
name = "shuffle-transfer-" + token
tranfer_keys = list()
Expand Down

0 comments on commit 5ce0f35

Please sign in to comment.