Skip to content

Commit

Permalink
please mkdocstrings by specifying of all variables (they won't help m…
Browse files Browse the repository at this point in the history
…uch with type checks, but appear in documentation)
  • Loading branch information
arogozhnikov committed Sep 17, 2024
1 parent 4c4070e commit af8b3b7
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions einops/einops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Tensor = TypeVar("Tensor")
ReductionCallable = Callable[[Tensor, Tuple[int, ...]], Tensor]
Reduction = Union[str, ReductionCallable]
Size = typing.Any

_reductions = ("min", "max", "sum", "mean", "prod", "any", "all")

Expand Down Expand Up @@ -456,7 +457,7 @@ def _prepare_recipes_for_all_dims(
return {ndim: _prepare_transformation_recipe(pattern, operation, axes_names, ndim=ndim) for ndim in dims}


def reduce(tensor: Union[Tensor, List[Tensor]], pattern: str, reduction: Reduction, **axes_lengths: int) -> Tensor:
def reduce(tensor: Union[Tensor, List[Tensor]], pattern: str, reduction: Reduction, **axes_lengths: Size) -> Tensor:
"""
einops.reduce provides combination of reordering and reduction using reader-friendly notation.
Expand Down Expand Up @@ -533,7 +534,7 @@ def reduce(tensor: Union[Tensor, List[Tensor]], pattern: str, reduction: Reducti
raise EinopsError(message + "\n {}".format(e))


def rearrange(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths) -> Tensor:
def rearrange(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths: Size) -> Tensor:
"""
einops.rearrange is a reader-friendly smart element reordering for multidimensional tensors.
This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze,
Expand All @@ -549,6 +550,10 @@ def rearrange(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths)
>>> rearrange(images, 'b h w c -> b h w c').shape
(32, 30, 40, 3)
# stacked and reordered axes to "b c h w" format
>>> rearrange(images, 'b h w c -> b c h w').shape
(32, 3, 30, 40)
# concatenate images along height (vertical axis), 960 = 32 * 30
>>> rearrange(images, 'b h w c -> (b h) w c').shape
(960, 40, 3)
Expand All @@ -557,10 +562,6 @@ def rearrange(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths)
>>> rearrange(images, 'b h w c -> h (b w) c').shape
(30, 1280, 3)
# reordered axes to "b c h w" format for deep learning
>>> rearrange(images, 'b h w c -> b c h w').shape
(32, 3, 30, 40)
# flattened each image into a vector, 3600 = 30 * 40 * 3
>>> rearrange(images, 'b h w c -> b (c h w)').shape
(32, 3600)
Expand Down Expand Up @@ -591,7 +592,7 @@ def rearrange(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths)
return reduce(tensor, pattern, reduction="rearrange", **axes_lengths)


def repeat(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths) -> Tensor:
def repeat(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths: Size) -> Tensor:
"""
einops.repeat allows reordering elements and repeating them in arbitrary combinations.
This operation includes functionality of repeat, tile, broadcast functions.
Expand Down Expand Up @@ -641,7 +642,7 @@ def repeat(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths) ->
return reduce(tensor, pattern, reduction="repeat", **axes_lengths)


def parse_shape(x, pattern: str) -> dict:
def parse_shape(x: Tensor, pattern: str) -> dict:
"""
Parse a tensor shape to dictionary mapping axes names to their lengths.
Expand Down Expand Up @@ -730,7 +731,7 @@ def _enumerate_directions(x):
np_ndarray = Any


def asnumpy(tensor) -> np_ndarray:
def asnumpy(tensor: Tensor) -> np_ndarray:
"""
Convert a tensor of an imperative framework (i.e. numpy/cupy/torch/jax/etc.) to `numpy.ndarray`
Expand Down

0 comments on commit af8b3b7

Please sign in to comment.