Skip to content

Commit

Permalink
Backport PR pandas-dev#57102: ENH: Add skipna to groupby.first and gr…
Browse files Browse the repository at this point in the history
…oupby.last
  • Loading branch information
rhshadrach authored and meeseeksmachine committed Jan 30, 2024
1 parent 10b5873 commit f1beec4
Show file tree
Hide file tree
Showing 10 changed files with 167 additions and 28 deletions.
3 changes: 2 additions & 1 deletion doc/source/whatsnew/v2.2.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ Bug fixes

Other
~~~~~
-
- Added the argument ``skipna`` to :meth:`DataFrameGroupBy.first`, :meth:`DataFrameGroupBy.last`, :meth:`SeriesGroupBy.first`, and :meth:`SeriesGroupBy.last`; achieving ``skipna=False`` used to be available via :meth:`DataFrameGroupBy.nth`, but the behavior was changed in pandas 2.0.0 (:issue:`57019`)
- Added the argument ``skipna`` to :meth:`Resampler.first`, :meth:`Resampler.last` (:issue:`57019`)

.. ---------------------------------------------------------------------------
.. _whatsnew_221.contributors:
Expand Down
2 changes: 2 additions & 0 deletions pandas/_libs/groupby.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def group_last(
result_mask: npt.NDArray[np.bool_] | None = ...,
min_count: int = ..., # Py_ssize_t
is_datetimelike: bool = ...,
skipna: bool = ...,
) -> None: ...
def group_nth(
out: np.ndarray, # rank_t[:, ::1]
Expand All @@ -147,6 +148,7 @@ def group_nth(
min_count: int = ..., # int64_t
rank: int = ..., # int64_t
is_datetimelike: bool = ...,
skipna: bool = ...,
) -> None: ...
def group_rank(
out: np.ndarray, # float64_t[:, ::1]
Expand Down
41 changes: 26 additions & 15 deletions pandas/_libs/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1424,6 +1424,7 @@ def group_last(
uint8_t[:, ::1] result_mask=None,
Py_ssize_t min_count=-1,
bint is_datetimelike=False,
bint skipna=True,
) -> None:
"""
Only aggregates on axis=0
Expand Down Expand Up @@ -1458,14 +1459,19 @@ def group_last(
for j in range(K):
val = values[i, j]

if uses_mask:
isna_entry = mask[i, j]
else:
isna_entry = _treat_as_na(val, is_datetimelike)
if skipna:
if uses_mask:
isna_entry = mask[i, j]
else:
isna_entry = _treat_as_na(val, is_datetimelike)
if isna_entry:
continue

if not isna_entry:
nobs[lab, j] += 1
resx[lab, j] = val
nobs[lab, j] += 1
resx[lab, j] = val

if uses_mask and not skipna:
result_mask[lab, j] = mask[i, j]

_check_below_mincount(
out, uses_mask, result_mask, ncounts, K, nobs, min_count, resx
Expand All @@ -1486,6 +1492,7 @@ def group_nth(
int64_t min_count=-1,
int64_t rank=1,
bint is_datetimelike=False,
bint skipna=True,
) -> None:
"""
Only aggregates on axis=0
Expand Down Expand Up @@ -1520,15 +1527,19 @@ def group_nth(
for j in range(K):
val = values[i, j]

if uses_mask:
isna_entry = mask[i, j]
else:
isna_entry = _treat_as_na(val, is_datetimelike)
if skipna:
if uses_mask:
isna_entry = mask[i, j]
else:
isna_entry = _treat_as_na(val, is_datetimelike)
if isna_entry:
continue

if not isna_entry:
nobs[lab, j] += 1
if nobs[lab, j] == rank:
resx[lab, j] = val
nobs[lab, j] += 1
if nobs[lab, j] == rank:
resx[lab, j] = val
if uses_mask and not skipna:
result_mask[lab, j] = mask[i, j]

_check_below_mincount(
out, uses_mask, result_mask, ncounts, K, nobs, min_count, resx
Expand Down
7 changes: 7 additions & 0 deletions pandas/_testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,11 +236,18 @@
+ TIMEDELTA_PYARROW_DTYPES
+ BOOL_PYARROW_DTYPES
)
ALL_REAL_PYARROW_DTYPES_STR_REPR = (
ALL_INT_PYARROW_DTYPES_STR_REPR + FLOAT_PYARROW_DTYPES_STR_REPR
)
else:
FLOAT_PYARROW_DTYPES_STR_REPR = []
ALL_INT_PYARROW_DTYPES_STR_REPR = []
ALL_PYARROW_DTYPES = []
ALL_REAL_PYARROW_DTYPES_STR_REPR = []

ALL_REAL_NULLABLE_DTYPES = (
FLOAT_NUMPY_DTYPES + ALL_REAL_EXTENSION_DTYPES + ALL_REAL_PYARROW_DTYPES_STR_REPR
)

arithmetic_dunder_methods = [
"__add__",
Expand Down
32 changes: 32 additions & 0 deletions pandas/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1642,6 +1642,38 @@ def any_numpy_dtype(request):
return request.param


@pytest.fixture(params=tm.ALL_REAL_NULLABLE_DTYPES)
def any_real_nullable_dtype(request):
"""
Parameterized fixture for all real dtypes that can hold NA.
* float
* 'float32'
* 'float64'
* 'Float32'
* 'Float64'
* 'UInt8'
* 'UInt16'
* 'UInt32'
* 'UInt64'
* 'Int8'
* 'Int16'
* 'Int32'
* 'Int64'
* 'uint8[pyarrow]'
* 'uint16[pyarrow]'
* 'uint32[pyarrow]'
* 'uint64[pyarrow]'
* 'int8[pyarrow]'
* 'int16[pyarrow]'
* 'int32[pyarrow]'
* 'int64[pyarrow]'
* 'float[pyarrow]'
* 'double[pyarrow]'
"""
return request.param


@pytest.fixture(params=tm.ALL_NUMERIC_DTYPES)
def any_numeric_dtype(request):
"""
Expand Down
36 changes: 28 additions & 8 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -3335,22 +3335,31 @@ def max(
)

@final
def first(self, numeric_only: bool = False, min_count: int = -1) -> NDFrameT:
def first(
self, numeric_only: bool = False, min_count: int = -1, skipna: bool = True
) -> NDFrameT:
"""
Compute the first non-null entry of each column.
Compute the first entry of each column within each group.
Defaults to skipping NA elements.
Parameters
----------
numeric_only : bool, default False
Include only float, int, boolean columns.
min_count : int, default -1
The required number of valid values to perform the operation. If fewer
than ``min_count`` non-NA values are present the result will be NA.
than ``min_count`` valid values are present the result will be NA.
skipna : bool, default True
Exclude NA/null values. If an entire row/column is NA, the result
will be NA.
.. versionadded:: 2.2.1
Returns
-------
Series or DataFrame
First non-null of values within each group.
First values within each group.
See Also
--------
Expand Down Expand Up @@ -3402,12 +3411,17 @@ def first(x: Series):
min_count=min_count,
alias="first",
npfunc=first_compat,
skipna=skipna,
)

@final
def last(self, numeric_only: bool = False, min_count: int = -1) -> NDFrameT:
def last(
self, numeric_only: bool = False, min_count: int = -1, skipna: bool = True
) -> NDFrameT:
"""
Compute the last non-null entry of each column.
Compute the last entry of each column within each group.
Defaults to skipping NA elements.
Parameters
----------
Expand All @@ -3416,12 +3430,17 @@ def last(self, numeric_only: bool = False, min_count: int = -1) -> NDFrameT:
everything, then use only numeric data.
min_count : int, default -1
The required number of valid values to perform the operation. If fewer
than ``min_count`` non-NA values are present the result will be NA.
than ``min_count`` valid values are present the result will be NA.
skipna : bool, default True
Exclude NA/null values. If an entire row/column is NA, the result
will be NA.
.. versionadded:: 2.2.1
Returns
-------
Series or DataFrame
Last non-null of values within each group.
Last of values within each group.
See Also
--------
Expand Down Expand Up @@ -3461,6 +3480,7 @@ def last(x: Series):
min_count=min_count,
alias="last",
npfunc=last_compat,
skipna=skipna,
)

@final
Expand Down
10 changes: 8 additions & 2 deletions pandas/core/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -1306,25 +1306,31 @@ def first(
self,
numeric_only: bool = False,
min_count: int = 0,
skipna: bool = True,
*args,
**kwargs,
):
maybe_warn_args_and_kwargs(type(self), "first", args, kwargs)
nv.validate_resampler_func("first", args, kwargs)
return self._downsample("first", numeric_only=numeric_only, min_count=min_count)
return self._downsample(
"first", numeric_only=numeric_only, min_count=min_count, skipna=skipna
)

@final
@doc(GroupBy.last)
def last(
self,
numeric_only: bool = False,
min_count: int = 0,
skipna: bool = True,
*args,
**kwargs,
):
maybe_warn_args_and_kwargs(type(self), "last", args, kwargs)
nv.validate_resampler_func("last", args, kwargs)
return self._downsample("last", numeric_only=numeric_only, min_count=min_count)
return self._downsample(
"last", numeric_only=numeric_only, min_count=min_count, skipna=skipna
)

@final
@doc(GroupBy.median)
Expand Down
31 changes: 31 additions & 0 deletions pandas/tests/groupby/test_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

from pandas._libs.tslibs import iNaT

from pandas.core.dtypes.common import pandas_dtype
from pandas.core.dtypes.missing import na_value_for_dtype

import pandas as pd
from pandas import (
DataFrame,
Expand Down Expand Up @@ -327,6 +330,34 @@ def test_groupby_non_arithmetic_agg_int_like_precision(method, data):
tm.assert_frame_equal(result, expected)


@pytest.mark.parametrize("how", ["first", "last"])
def test_first_last_skipna(any_real_nullable_dtype, sort, skipna, how):
# GH#57019
na_value = na_value_for_dtype(pandas_dtype(any_real_nullable_dtype))
df = DataFrame(
{
"a": [2, 1, 1, 2, 3, 3],
"b": [na_value, 3.0, na_value, 4.0, np.nan, np.nan],
"c": [na_value, 3.0, na_value, 4.0, np.nan, np.nan],
},
dtype=any_real_nullable_dtype,
)
gb = df.groupby("a", sort=sort)
method = getattr(gb, how)
result = method(skipna=skipna)

ilocs = {
("first", True): [3, 1, 4],
("first", False): [0, 1, 4],
("last", True): [3, 1, 5],
("last", False): [3, 2, 5],
}[how, skipna]
expected = df.iloc[ilocs].set_index("a")
if sort:
expected = expected.sort_index()
tm.assert_frame_equal(result, expected)


def test_idxmin_idxmax_axis1():
df = DataFrame(
np.random.default_rng(2).standard_normal((10, 4)), columns=["A", "B", "C", "D"]
Expand Down
29 changes: 29 additions & 0 deletions pandas/tests/resample/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import numpy as np
import pytest

from pandas.core.dtypes.common import is_extension_array_dtype

import pandas as pd
from pandas import (
DataFrame,
DatetimeIndex,
Expand Down Expand Up @@ -429,3 +432,29 @@ def test_resample_quantile(series):
result = ser.resample(freq).quantile(q)
expected = ser.resample(freq).agg(lambda x: x.quantile(q)).rename(ser.name)
tm.assert_series_equal(result, expected)


@pytest.mark.parametrize("how", ["first", "last"])
def test_first_last_skipna(any_real_nullable_dtype, skipna, how):
# GH#57019
if is_extension_array_dtype(any_real_nullable_dtype):
na_value = Series(dtype=any_real_nullable_dtype).dtype.na_value
else:
na_value = np.nan
df = DataFrame(
{
"a": [2, 1, 1, 2],
"b": [na_value, 3.0, na_value, 4.0],
"c": [na_value, 3.0, na_value, 4.0],
},
index=date_range("2020-01-01", periods=4, freq="D"),
dtype=any_real_nullable_dtype,
)
rs = df.resample("ME")
method = getattr(rs, how)
result = method(skipna=skipna)

gb = df.groupby(df.shape[0] * [pd.to_datetime("2020-01-31")])
expected = getattr(gb, how)(skipna=skipna)
expected.index.freq = "ME"
tm.assert_frame_equal(result, expected)
4 changes: 2 additions & 2 deletions pandas/tests/resample/test_resample_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,11 +1040,11 @@ def test_args_kwargs_depr(method, raises):
if raises:
with tm.assert_produces_warning(FutureWarning, match=warn_msg):
with pytest.raises(UnsupportedFunctionCall, match=error_msg):
func(*args, 1, 2, 3)
func(*args, 1, 2, 3, 4)
else:
with tm.assert_produces_warning(FutureWarning, match=warn_msg):
with pytest.raises(TypeError, match=error_msg_type):
func(*args, 1, 2, 3)
func(*args, 1, 2, 3, 4)


def test_df_axis_param_depr():
Expand Down

0 comments on commit f1beec4

Please sign in to comment.