Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TYP: tighten Axis #48612

Merged
merged 3 commits into from
Sep 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pandas/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@

NumpyIndexT = TypeVar("NumpyIndexT", np.ndarray, "Index")

Axis = Union[str, int]
AxisInt = int
Axis = Union[AxisInt, Literal["index", "columns", "rows"]]
IndexLabel = Union[Hashable, Sequence[Hashable]]
Level = Hashable
Shape = Tuple[int, ...]
Expand Down
7 changes: 5 additions & 2 deletions pandas/compat/numpy/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@
is_bool,
is_integer,
)
from pandas._typing import Axis
from pandas._typing import (
Axis,
AxisInt,
)
from pandas.errors import UnsupportedFunctionCall
from pandas.util._validators import (
validate_args,
Expand Down Expand Up @@ -413,7 +416,7 @@ def validate_resampler_func(method: str, args, kwargs) -> None:
raise TypeError("too many arguments passed in")


def validate_minmax_axis(axis: int | None, ndim: int = 1) -> None:
def validate_minmax_axis(axis: AxisInt | None, ndim: int = 1) -> None:
"""
Ensure that the axis argument passed to min, max, argmin, or argmax is zero
or None, as otherwise it will be incorrectly ignored.
Expand Down
7 changes: 4 additions & 3 deletions pandas/core/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pandas._typing import (
AnyArrayLike,
ArrayLike,
AxisInt,
DtypeObj,
IndexLabel,
TakeIndexer,
Expand Down Expand Up @@ -1105,7 +1106,7 @@ def mode(

def rank(
values: ArrayLike,
axis: int = 0,
axis: AxisInt = 0,
method: str = "average",
na_option: str = "keep",
ascending: bool = True,
Expand Down Expand Up @@ -1483,7 +1484,7 @@ def get_indexer(current_indexer, other_indexer):
def take(
arr,
indices: TakeIndexer,
axis: int = 0,
axis: AxisInt = 0,
allow_fill: bool = False,
fill_value=None,
):
Expand Down Expand Up @@ -1675,7 +1676,7 @@ def searchsorted(
_diff_special = {"float64", "float32", "int64", "int32", "int16", "int8"}


def diff(arr, n: int, axis: int = 0):
def diff(arr, n: int, axis: AxisInt = 0):
"""
difference of n between self,
analogous to s-s.shift(n)
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
AggFuncTypeDict,
AggObjType,
Axis,
AxisInt,
NDFrameT,
npt,
)
Expand Down Expand Up @@ -104,7 +105,7 @@ def frame_apply(


class Apply(metaclass=abc.ABCMeta):
axis: int
axis: AxisInt

def __init__(
self,
Expand Down
21 changes: 12 additions & 9 deletions pandas/core/array_algos/masked_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
import numpy as np

from pandas._libs import missing as libmissing
from pandas._typing import npt
from pandas._typing import (
AxisInt,
npt,
)

from pandas.core.nanops import check_below_min_count

Expand All @@ -21,7 +24,7 @@ def _reductions(
*,
skipna: bool = True,
min_count: int = 0,
axis: int | None = None,
axis: AxisInt | None = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -62,7 +65,7 @@ def sum(
*,
skipna: bool = True,
min_count: int = 0,
axis: int | None = None,
axis: AxisInt | None = None,
):
return _reductions(
np.sum, values=values, mask=mask, skipna=skipna, min_count=min_count, axis=axis
Expand All @@ -75,7 +78,7 @@ def prod(
*,
skipna: bool = True,
min_count: int = 0,
axis: int | None = None,
axis: AxisInt | None = None,
):
return _reductions(
np.prod, values=values, mask=mask, skipna=skipna, min_count=min_count, axis=axis
Expand All @@ -88,7 +91,7 @@ def _minmax(
mask: npt.NDArray[np.bool_],
*,
skipna: bool = True,
axis: int | None = None,
axis: AxisInt | None = None,
):
"""
Reduction for 1D masked array.
Expand Down Expand Up @@ -125,7 +128,7 @@ def min(
mask: npt.NDArray[np.bool_],
*,
skipna: bool = True,
axis: int | None = None,
axis: AxisInt | None = None,
):
return _minmax(np.min, values=values, mask=mask, skipna=skipna, axis=axis)

Expand All @@ -135,7 +138,7 @@ def max(
mask: npt.NDArray[np.bool_],
*,
skipna: bool = True,
axis: int | None = None,
axis: AxisInt | None = None,
):
return _minmax(np.max, values=values, mask=mask, skipna=skipna, axis=axis)

Expand All @@ -145,7 +148,7 @@ def mean(
mask: npt.NDArray[np.bool_],
*,
skipna: bool = True,
axis: int | None = None,
axis: AxisInt | None = None,
):
if not values.size or mask.all():
return libmissing.NA
Expand All @@ -157,7 +160,7 @@ def var(
mask: npt.NDArray[np.bool_],
*,
skipna: bool = True,
axis: int | None = None,
axis: AxisInt | None = None,
ddof: int = 1,
):
if not values.size or mask.all():
Expand Down
19 changes: 12 additions & 7 deletions pandas/core/array_algos/take.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from pandas._typing import (
ArrayLike,
AxisInt,
npt,
)

Expand All @@ -36,7 +37,7 @@
def take_nd(
arr: np.ndarray,
indexer,
axis: int = ...,
axis: AxisInt = ...,
fill_value=...,
allow_fill: bool = ...,
) -> np.ndarray:
Expand All @@ -47,7 +48,7 @@ def take_nd(
def take_nd(
arr: ExtensionArray,
indexer,
axis: int = ...,
axis: AxisInt = ...,
fill_value=...,
allow_fill: bool = ...,
) -> ArrayLike:
Expand All @@ -57,7 +58,7 @@ def take_nd(
def take_nd(
arr: ArrayLike,
indexer,
axis: int = 0,
axis: AxisInt = 0,
fill_value=lib.no_default,
allow_fill: bool = True,
) -> ArrayLike:
Expand Down Expand Up @@ -120,7 +121,7 @@ def take_nd(
def _take_nd_ndarray(
arr: np.ndarray,
indexer: npt.NDArray[np.intp] | None,
axis: int,
axis: AxisInt,
fill_value,
allow_fill: bool,
) -> np.ndarray:
Expand Down Expand Up @@ -287,7 +288,7 @@ def take_2d_multi(

@functools.lru_cache(maxsize=128)
def _get_take_nd_function_cached(
ndim: int, arr_dtype: np.dtype, out_dtype: np.dtype, axis: int
ndim: int, arr_dtype: np.dtype, out_dtype: np.dtype, axis: AxisInt
):
"""
Part of _get_take_nd_function below that doesn't need `mask_info` and thus
Expand Down Expand Up @@ -324,7 +325,11 @@ def _get_take_nd_function_cached(


def _get_take_nd_function(
ndim: int, arr_dtype: np.dtype, out_dtype: np.dtype, axis: int = 0, mask_info=None
ndim: int,
arr_dtype: np.dtype,
out_dtype: np.dtype,
axis: AxisInt = 0,
mask_info=None,
):
"""
Get the appropriate "take" implementation for the given dimension, axis
Expand Down Expand Up @@ -503,7 +508,7 @@ def _take_nd_object(
arr: np.ndarray,
indexer: npt.NDArray[np.intp],
out: np.ndarray,
axis: int,
axis: AxisInt,
fill_value,
mask_info,
) -> None:
Expand Down
4 changes: 3 additions & 1 deletion pandas/core/array_algos/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

import numpy as np

from pandas._typing import AxisInt

def shift(values: np.ndarray, periods: int, axis: int, fill_value) -> np.ndarray:

def shift(values: np.ndarray, periods: int, axis: AxisInt, fill_value) -> np.ndarray:
new_values = values

if periods == 0 or values.size == 0:
Expand Down
11 changes: 6 additions & 5 deletions pandas/core/arrays/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from pandas._libs.arrays import NDArrayBacked
from pandas._typing import (
ArrayLike,
AxisInt,
Dtype,
F,
PositionalIndexer2D,
Expand Down Expand Up @@ -157,7 +158,7 @@ def take(
*,
allow_fill: bool = False,
fill_value: Any = None,
axis: int = 0,
axis: AxisInt = 0,
) -> NDArrayBackedExtensionArrayT:
if allow_fill:
fill_value = self._validate_scalar(fill_value)
Expand Down Expand Up @@ -192,15 +193,15 @@ def _values_for_factorize(self):
return self._ndarray, self._internal_fill_value

# Signature of "argmin" incompatible with supertype "ExtensionArray"
def argmin(self, axis: int = 0, skipna: bool = True): # type: ignore[override]
def argmin(self, axis: AxisInt = 0, skipna: bool = True): # type: ignore[override]
# override base class by adding axis keyword
validate_bool_kwarg(skipna, "skipna")
if not skipna and self._hasna:
raise NotImplementedError
return nargminmax(self, "argmin", axis=axis)

# Signature of "argmax" incompatible with supertype "ExtensionArray"
def argmax(self, axis: int = 0, skipna: bool = True): # type: ignore[override]
def argmax(self, axis: AxisInt = 0, skipna: bool = True): # type: ignore[override]
# override base class by adding axis keyword
validate_bool_kwarg(skipna, "skipna")
if not skipna and self._hasna:
Expand All @@ -216,7 +217,7 @@ def unique(self: NDArrayBackedExtensionArrayT) -> NDArrayBackedExtensionArrayT:
def _concat_same_type(
cls: type[NDArrayBackedExtensionArrayT],
to_concat: Sequence[NDArrayBackedExtensionArrayT],
axis: int = 0,
axis: AxisInt = 0,
) -> NDArrayBackedExtensionArrayT:
dtypes = {str(x.dtype) for x in to_concat}
if len(dtypes) != 1:
Expand Down Expand Up @@ -351,7 +352,7 @@ def fillna(
# ------------------------------------------------------------------------
# Reductions

def _wrap_reduction_result(self, axis: int | None, result):
def _wrap_reduction_result(self, axis: AxisInt | None, result):
if axis is None or self.ndim == 1:
return self._box_func(result)
return self._from_backing_data(result)
Expand Down
5 changes: 3 additions & 2 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from pandas._typing import (
ArrayLike,
AstypeArg,
AxisInt,
Dtype,
FillnaOptions,
PositionalIndexer,
Expand Down Expand Up @@ -1137,7 +1138,7 @@ def factorize(
@Substitution(klass="ExtensionArray")
@Appender(_extension_array_shared_docs["repeat"])
def repeat(
self: ExtensionArrayT, repeats: int | Sequence[int], axis: int | None = None
self: ExtensionArrayT, repeats: int | Sequence[int], axis: AxisInt | None = None
) -> ExtensionArrayT:
nv.validate_repeat((), {"axis": axis})
ind = np.arange(len(self)).repeat(repeats)
Expand Down Expand Up @@ -1567,7 +1568,7 @@ def _fill_mask_inplace(
def _rank(
self,
*,
axis: int = 0,
axis: AxisInt = 0,
method: str = "average",
na_option: str = "keep",
ascending: bool = True,
Expand Down
5 changes: 3 additions & 2 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from pandas._typing import (
ArrayLike,
AstypeArg,
AxisInt,
Dtype,
NpDtype,
Ordered,
Expand Down Expand Up @@ -1988,7 +1989,7 @@ def sort_values(
def _rank(
self,
*,
axis: int = 0,
axis: AxisInt = 0,
method: str = "average",
na_option: str = "keep",
ascending: bool = True,
Expand Down Expand Up @@ -2464,7 +2465,7 @@ def equals(self, other: object) -> bool:

@classmethod
def _concat_same_type(
cls: type[CategoricalT], to_concat: Sequence[CategoricalT], axis: int = 0
cls: type[CategoricalT], to_concat: Sequence[CategoricalT], axis: AxisInt = 0
) -> CategoricalT:
from pandas.core.dtypes.concat import union_categoricals

Expand Down
Loading