Skip to content

Commit

Permalink
BUG: Maintain the order of the bins in group_quantile. Updated tests p…
Browse files Browse the repository at this point in the history
  • Loading branch information
mabelvj committed Apr 19, 2020
1 parent c8db9b9 commit 15a27ea
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 9 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,7 @@ Groupby/resample/rolling
- Bug in :meth:`DataFrame.resample` where an ``AmbiguousTimeError`` would be raised when the resulting timezone aware :class:`DatetimeIndex` had a DST transition at midnight (:issue:`25758`)
- Bug in :meth:`DataFrame.groupby` where a ``ValueError`` would be raised when grouping by a categorical column with read-only categories and ``sort=False`` (:issue:`33410`)
- Bug in :meth:`GroupBy.first` and :meth:`GroupBy.last` where None is not preserved in object dtype (:issue:`32800`)
- Bug in :meth:`SeriesGroupBy.quantile` causes the quantiles to be shifted when the ``by`` axis contains ``NaN`` (:issue:`33200`)

Reshaping
^^^^^^^^^
Expand Down
11 changes: 7 additions & 4 deletions pandas/_libs/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -779,9 +779,10 @@ def group_quantile(ndarray[float64_t] out,
non_na_counts[lab] += 1

# Get an index of values sorted by labels and then values
order = (values, labels)
sort_arr = np.lexsort(order).astype(np.int64, copy=False)

sort_arr = np.arange(len(labels), dtype=np.int64)
mask = labels != -1
order = (np.asarray(values)[mask], labels[mask])
sort_arr[mask] = np.lexsort(order).astype(np.int64, copy=False)
with nogil:
for i in range(ngroups):
# Figure out how many group elements there are
Expand Down Expand Up @@ -819,7 +820,9 @@ def group_quantile(ndarray[float64_t] out,

# Increment the index reference in sorted_arr for the next group
grp_start += grp_sz

print(out)
# out = np.roll(out, -(len(out) - np.sum(counts)))
print(out)

# ----------------------------------------------------------------------
# group_nth, group_last, group_rank
Expand Down
18 changes: 13 additions & 5 deletions pandas/tests/groupby/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -1507,14 +1507,22 @@ def test_quantile_missing_group_values_no_segfaults():
grp.quantile()


def test_quantile_missing_group_values_correct_results():
# GH 28662
data = np.array([1.0, np.nan, 3.0, np.nan])
df = pd.DataFrame(dict(key=data, val=range(4)))
@pytest.mark.parametrize(
"key, val, expected_key, expected_val",
[
([1.0, np.nan, 3.0, np.nan], range(4), [1.0, 3.0], [0.0, 1.0]),
(["a", "b", "b", np.nan], range(4), ["a", "b"], [0, 1.5]),
],
)
def test_quantile_missing_group_values_correct_results(
key, val, expected_key, expected_val
):
# GH 28662, GH 33200
df = pd.DataFrame({"key": key, "val": val})

result = df.groupby("key").quantile()
expected = pd.DataFrame(
[1.0, 3.0], index=pd.Index([1.0, 3.0], name="key"), columns=["val"]
expected_val, index=pd.Index(expected_key, name="key"), columns=["val"]
)
tm.assert_frame_equal(result, expected)

Expand Down

0 comments on commit 15a27ea

Please sign in to comment.