diff --git a/doc/source/whatsnew/v1.1.0.rst b/doc/source/whatsnew/v1.1.0.rst index a797090a83444..13340c7cf11d8 100644 --- a/doc/source/whatsnew/v1.1.0.rst +++ b/doc/source/whatsnew/v1.1.0.rst @@ -603,6 +603,7 @@ Groupby/resample/rolling - Bug in :meth:`DataFrame.resample` where an ``AmbiguousTimeError`` would be raised when the resulting timezone aware :class:`DatetimeIndex` had a DST transition at midnight (:issue:`25758`) - Bug in :meth:`DataFrame.groupby` where a ``ValueError`` would be raised when grouping by a categorical column with read-only categories and ``sort=False`` (:issue:`33410`) - Bug in :meth:`GroupBy.first` and :meth:`GroupBy.last` where None is not preserved in object dtype (:issue:`32800`) +- Bug in :meth:`SeriesGroupBy.quantile` causes the quantiles to be shifted when the ``by`` axis contains ``NaN`` (:issue:`33200`) Reshaping ^^^^^^^^^ diff --git a/pandas/_libs/groupby.pyx b/pandas/_libs/groupby.pyx index 53e66c4b8723d..31333b16d590a 100644 --- a/pandas/_libs/groupby.pyx +++ b/pandas/_libs/groupby.pyx @@ -779,9 +779,10 @@ def group_quantile(ndarray[float64_t] out, non_na_counts[lab] += 1 # Get an index of values sorted by labels and then values - order = (values, labels) - sort_arr = np.lexsort(order).astype(np.int64, copy=False) - + sort_arr = np.arange(len(labels), dtype=np.int64) + mask = labels != -1 + order = (np.asarray(values)[mask], labels[mask]) + sort_arr[mask] = np.lexsort(order).astype(np.int64, copy=False) with nogil: for i in range(ngroups): # Figure out how many group elements there are @@ -819,7 +820,9 @@ def group_quantile(ndarray[float64_t] out, # Increment the index reference in sorted_arr for the next group grp_start += grp_sz - + print(out) + # out = np.roll(out, -(len(out) - np.sum(counts))) + print(out) # ---------------------------------------------------------------------- # group_nth, group_last, group_rank diff --git a/pandas/tests/groupby/test_function.py b/pandas/tests/groupby/test_function.py index 346de55f551df..e195b1d487614 100644 --- a/pandas/tests/groupby/test_function.py +++ b/pandas/tests/groupby/test_function.py @@ -1507,14 +1507,22 @@ def test_quantile_missing_group_values_no_segfaults(): grp.quantile() -def test_quantile_missing_group_values_correct_results(): - # GH 28662 - data = np.array([1.0, np.nan, 3.0, np.nan]) - df = pd.DataFrame(dict(key=data, val=range(4))) +@pytest.mark.parametrize( + "key, val, expected_key, expected_val", + [ + ([1.0, np.nan, 3.0, np.nan], range(4), [1.0, 3.0], [0.0, 1.0]), + (["a", "b", "b", np.nan], range(4), ["a", "b"], [0, 1.5]), + ], +) +def test_quantile_missing_group_values_correct_results( + key, val, expected_key, expected_val +): + # GH 28662, GH 33200 + df = pd.DataFrame({"key": key, "val": val}) result = df.groupby("key").quantile() expected = pd.DataFrame( - [1.0, 3.0], index=pd.Index([1.0, 3.0], name="key"), columns=["val"] + expected_val, index=pd.Index(expected_key, name="key"), columns=["val"] ) tm.assert_frame_equal(result, expected)