diff --git a/distributed/shuffle/__init__.py b/distributed/shuffle/__init__.py index 178b69cb9b5..64c7fe012b4 100644 --- a/distributed/shuffle/__init__.py +++ b/distributed/shuffle/__init__.py @@ -1,13 +1,14 @@ from __future__ import annotations -from distributed.shuffle._shuffle import rearrange_by_column_p2p +from distributed.shuffle._shuffle import P2PShuffleLayer, rearrange_by_column_p2p from distributed.shuffle._shuffle_extension import ( ShuffleSchedulerExtension, ShuffleWorkerExtension, ) __all__ = [ + "P2PShuffleLayer", + "rearrange_by_column_p2p", "ShuffleSchedulerExtension", "ShuffleWorkerExtension", - "rearrange_by_column_p2p", ] diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index 972bef71b3c..0435984087f 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from dask.base import tokenize from dask.highlevelgraph import HighLevelGraph @@ -93,15 +93,15 @@ def rearrange_by_column_p2p( class P2PShuffleLayer(SimpleShuffleLayer): def __init__( self, - name, - column, - npartitions, - npartitions_input, - ignore_index, - name_input, - meta_input, - parts_out=None, - annotations=None, + name: str, + column: str, + npartitions: int, + npartitions_input: int, + ignore_index: bool, + name_input: str, + meta_input: pd.DataFrame, + parts_out: list | None = None, + annotations: dict | None = None, ): annotations = annotations or {} annotations.update({"shuffle": lambda key: key[1]}) @@ -117,16 +117,16 @@ def __init__( annotations=annotations, ) - def get_split_keys(self): + def get_split_keys(self) -> list: # TODO: This is doing some funky stuff to set priorities but we don't need this return [] - def __repr__(self): + def __repr__(self) -> str: return ( f"{type(self).__name__}" ) - def _cull(self, parts_out): + def _cull(self, parts_out: list) -> P2PShuffleLayer: return P2PShuffleLayer( self.name, self.column, @@ -138,9 +138,9 @@ def _cull(self, parts_out): parts_out=parts_out, ) - def _construct_graph(self, deserializing=None): + def _construct_graph(self, deserializing: Any = None) -> dict[tuple, tuple]: token = tokenize(self.name_input, self.column, self.npartitions, self.parts_out) - dsk = {} + dsk: dict[tuple, tuple] = {} barrier_key = "shuffle-barrier-" + token name = "shuffle-transfer-" + token tranfer_keys = list()