diff --git a/distributed/shuffle/shuffle.py b/distributed/shuffle/shuffle.py index 95af3a9ee9..ca3fe0c3c0 100644 --- a/distributed/shuffle/shuffle.py +++ b/distributed/shuffle/shuffle.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from dask.base import tokenize from dask.highlevelgraph import HighLevelGraph @@ -91,15 +91,15 @@ def rearrange_by_column_p2p( class P2PShuffleLayer(SimpleShuffleLayer): def __init__( self, - name, - column, - npartitions, - npartitions_input, - ignore_index, - name_input, - meta_input, - parts_out=None, - annotations=None, + name: str, + column: str, + npartitions: int, + npartitions_input: int, + ignore_index: bool, + name_input: str, + meta_input: pd.DataFrame, + parts_out: list | None = None, + annotations: dict | None = None, ): annotations = annotations or {} annotations.update({"shuffle": lambda key: key[1]}) @@ -115,16 +115,16 @@ def __init__( annotations=annotations, ) - def get_split_keys(self): + def get_split_keys(self) -> list: # TODO: This is doing some funky stuff to set priorities but we don't need this return [] - def __repr__(self): + def __repr__(self) -> str: return ( f"{type(self).__name__}" ) - def _cull(self, parts_out): + def _cull(self, parts_out: list) -> P2PShuffleLayer: return P2PShuffleLayer( self.name, self.column, @@ -136,9 +136,9 @@ def _cull(self, parts_out): parts_out=parts_out, ) - def _construct_graph(self, deserializing=None): + def _construct_graph(self, deserializing: Any = None) -> dict[tuple | str, tuple]: token = tokenize(self.name_input, self.column, self.npartitions, self.parts_out) - dsk = {} + dsk: dict[tuple | str, tuple] = {} barrier_key = "shuffle-barrier-" + token name = "shuffle-transfer-" + token tranfer_keys = list()