Skip to content

Commit

Permalink
BUG: Maintain column order with groupby.nth (pandas-dev#22811)
Browse files Browse the repository at this point in the history
  • Loading branch information
reidy-p authored and Pingviinituutti committed Feb 28, 2019
1 parent 9a7daf6 commit 01ea768
Show file tree
Hide file tree
Showing 14 changed files with 225 additions and 84 deletions.
2 changes: 2 additions & 0 deletions doc/source/whatsnew/v0.24.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ Other Enhancements
- Added :meth:`Interval.overlaps`, :meth:`IntervalArray.overlaps`, and :meth:`IntervalIndex.overlaps` for determining overlaps between interval-like objects (:issue:`21998`)
- :func:`~DataFrame.to_parquet` now supports writing a ``DataFrame`` as a directory of parquet files partitioned by a subset of the columns when ``engine = 'pyarrow'`` (:issue:`23283`)
- :meth:`Timestamp.tz_localize`, :meth:`DatetimeIndex.tz_localize`, and :meth:`Series.tz_localize` have gained the ``nonexistent`` argument for alternative handling of nonexistent times. See :ref:`timeseries.timezone_nonexistent` (:issue:`8917`)
- :meth:`Index.difference` now has an optional ``sort`` parameter to specify whether the results should be sorted if possible (:issue:`17839`)
- :meth:`read_excel()` now accepts ``usecols`` as a list of column names or callable (:issue:`18273`)
- :meth:`MultiIndex.to_flat_index` has been added to flatten multiple levels into a single-level :class:`Index` object.
- :meth:`DataFrame.to_stata` and :class:`pandas.io.stata.StataWriter117` can write mixed string columns to Stata strl format (:issue:`23633`)
Expand Down Expand Up @@ -1417,6 +1418,7 @@ Groupby/Resample/Rolling
- Bug in :meth:`DataFrame.resample` and :meth:`Series.resample` when resampling by a weekly offset (``'W'``) across a DST transition (:issue:`9119`, :issue:`21459`)
- Bug in :meth:`DataFrame.expanding` in which the ``axis`` argument was not being respected during aggregations (:issue:`23372`)
- Bug in :meth:`pandas.core.groupby.DataFrameGroupBy.transform` which caused missing values when the input function can accept a :class:`DataFrame` but renames it (:issue:`23455`).
- Bug in :meth:`pandas.core.groupby.GroupBy.nth` where column order was not always preserved (:issue:`20760`)

Reshaping
^^^^^^^^^
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,8 @@ def _set_group_selection(self):

if len(groupers):
# GH12839 clear selected obj cache when group selection changes
self._group_selection = ax.difference(Index(groupers)).tolist()
self._group_selection = ax.difference(Index(groupers),
sort=False).tolist()
self._reset_cache('_selected_obj')

def _set_result_index_ordered(self, result):
Expand Down
20 changes: 13 additions & 7 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2944,17 +2944,20 @@ def intersection(self, other):
taken.name = None
return taken

def difference(self, other):
def difference(self, other, sort=True):
"""
Return a new Index with elements from the index that are not in
`other`.
This is the set difference of two Index objects.
By default the result is sorted if sorting is possible; pass ``sort=False`` to keep the original order.
Parameters
----------
other : Index or array-like
sort : bool, default True
Sort the resulting index if possible
.. versionadded:: 0.24.0
Returns
-------
Expand All @@ -2963,10 +2966,12 @@ def difference(self, other):
Examples
--------
>>> idx1 = pd.Index([1, 2, 3, 4])
>>> idx1 = pd.Index([2, 1, 3, 4])
>>> idx2 = pd.Index([3, 4, 5, 6])
>>> idx1.difference(idx2)
Int64Index([1, 2], dtype='int64')
>>> idx1.difference(idx2, sort=False)
Int64Index([2, 1], dtype='int64')
"""
self._assert_can_do_setop(other)
Expand All @@ -2985,10 +2990,11 @@ def difference(self, other):
label_diff = np.setdiff1d(np.arange(this.size), indexer,
assume_unique=True)
the_diff = this.values.take(label_diff)
try:
the_diff = sorting.safe_sort(the_diff)
except TypeError:
pass
if sort:
try:
the_diff = sorting.safe_sort(the_diff)
except TypeError:
pass

return this._shallow_copy(the_diff, name=result_name, freq=None)

Expand Down
8 changes: 6 additions & 2 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,7 +1037,7 @@ def overlaps(self, other):
return self._data.overlaps(other)

def _setop(op_name):
def func(self, other):
def func(self, other, sort=True):
other = self._as_like_interval_index(other)

# GH 19016: ensure set op will not return a prohibited dtype
Expand All @@ -1048,7 +1048,11 @@ def func(self, other):
'objects that have compatible dtypes')
raise TypeError(msg.format(op=op_name))

result = getattr(self._multiindex, op_name)(other._multiindex)
if op_name == 'difference':
result = getattr(self._multiindex, op_name)(other._multiindex,
sort)
else:
result = getattr(self._multiindex, op_name)(other._multiindex)
result_name = get_op_result_name(self, other)

# GH 19101: ensure empty results have correct dtype
Expand Down
22 changes: 19 additions & 3 deletions pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2798,10 +2798,18 @@ def intersection(self, other):
return MultiIndex.from_arrays(lzip(*uniq_tuples), sortorder=0,
names=result_names)

def difference(self, other):
def difference(self, other, sort=True):
"""
Compute set difference of two MultiIndex objects (sorted by default)
Parameters
----------
other : MultiIndex
sort : bool, default True
Sort the resulting MultiIndex if possible
.. versionadded:: 0.24.0
Returns
-------
diff : MultiIndex
Expand All @@ -2817,8 +2825,16 @@ def difference(self, other):
labels=[[]] * self.nlevels,
names=result_names, verify_integrity=False)

difference = sorted(set(self._ndarray_values) -
set(other._ndarray_values))
this = self._get_unique_index()

indexer = this.get_indexer(other)
indexer = indexer.take((indexer != -1).nonzero()[0])

label_diff = np.setdiff1d(np.arange(this.size), indexer,
assume_unique=True)
difference = this.values.take(label_diff)
if sort:
difference = sorted(difference)

if len(difference) == 0:
return MultiIndex(levels=[[]] * self.nlevels,
Expand Down
24 changes: 24 additions & 0 deletions pandas/tests/groupby/test_nth.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,3 +390,27 @@ def test_nth_empty():
names=['a', 'b']),
columns=['c'])
assert_frame_equal(result, expected)


def test_nth_column_order():
# GH 20760
# Check that nth preserves column order
df = DataFrame([[1, 'b', 100],
[1, 'a', 50],
[1, 'a', np.nan],
[2, 'c', 200],
[2, 'd', 150]],
columns=['A', 'C', 'B'])
result = df.groupby('A').nth(0)
expected = DataFrame([['b', 100.0],
['c', 200.0]],
columns=['C', 'B'],
index=Index([1, 2], name='A'))
assert_frame_equal(result, expected)

result = df.groupby('A').nth(-1, dropna='any')
expected = DataFrame([['a', 50.0],
['d', 150.0]],
columns=['C', 'B'],
index=Index([1, 2], name='A'))
assert_frame_equal(result, expected)
11 changes: 6 additions & 5 deletions pandas/tests/indexes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,12 +666,13 @@ def test_union_base(self):
with pytest.raises(TypeError, match=msg):
first.union([1, 2, 3])

def test_difference_base(self):
@pytest.mark.parametrize("sort", [True, False])
def test_difference_base(self, sort):
for name, idx in compat.iteritems(self.indices):
first = idx[2:]
second = idx[:4]
answer = idx[4:]
result = first.difference(second)
result = first.difference(second, sort)

if isinstance(idx, CategoricalIndex):
pass
Expand All @@ -685,21 +686,21 @@ def test_difference_base(self):
if isinstance(idx, PeriodIndex):
msg = "can only call with other PeriodIndex-ed objects"
with pytest.raises(ValueError, match=msg):
first.difference(case)
first.difference(case, sort)
elif isinstance(idx, CategoricalIndex):
pass
elif isinstance(idx, (DatetimeIndex, TimedeltaIndex)):
assert result.__class__ == answer.__class__
tm.assert_numpy_array_equal(result.sort_values().asi8,
answer.sort_values().asi8)
else:
result = first.difference(case)
result = first.difference(case, sort)
assert tm.equalContents(result, answer)

if isinstance(idx, MultiIndex):
msg = "other must be a MultiIndex or a list of tuples"
with pytest.raises(TypeError, match=msg):
first.difference([1, 2, 3])
first.difference([1, 2, 3], sort)

def test_symmetric_difference(self):
for name, idx in compat.iteritems(self.indices):
Expand Down
34 changes: 21 additions & 13 deletions pandas/tests/indexes/datetimes/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,47 +209,55 @@ def test_intersection_bug_1708(self):
assert len(result) == 0

@pytest.mark.parametrize("tz", tz)
def test_difference(self, tz):
rng1 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
@pytest.mark.parametrize("sort", [True, False])
def test_difference(self, tz, sort):
rng_dates = ['1/2/2000', '1/3/2000', '1/1/2000', '1/4/2000',
'1/5/2000']

rng1 = pd.DatetimeIndex(rng_dates, tz=tz)
other1 = pd.date_range('1/6/2000', freq='D', periods=5, tz=tz)
expected1 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
expected1 = pd.DatetimeIndex(rng_dates, tz=tz)

rng2 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
rng2 = pd.DatetimeIndex(rng_dates, tz=tz)
other2 = pd.date_range('1/4/2000', freq='D', periods=5, tz=tz)
expected2 = pd.date_range('1/1/2000', freq='D', periods=3, tz=tz)
expected2 = pd.DatetimeIndex(rng_dates[:3], tz=tz)

rng3 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
rng3 = pd.DatetimeIndex(rng_dates, tz=tz)
other3 = pd.DatetimeIndex([], tz=tz)
expected3 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
expected3 = pd.DatetimeIndex(rng_dates, tz=tz)

for rng, other, expected in [(rng1, other1, expected1),
(rng2, other2, expected2),
(rng3, other3, expected3)]:
result_diff = rng.difference(other)
result_diff = rng.difference(other, sort)
if sort:
expected = expected.sort_values()
tm.assert_index_equal(result_diff, expected)

def test_difference_freq(self):
@pytest.mark.parametrize("sort", [True, False])
def test_difference_freq(self, sort):
# GH14323: difference of DatetimeIndex should not preserve frequency

index = date_range("20160920", "20160925", freq="D")
other = date_range("20160921", "20160924", freq="D")
expected = DatetimeIndex(["20160920", "20160925"], freq=None)
idx_diff = index.difference(other)
idx_diff = index.difference(other, sort)
tm.assert_index_equal(idx_diff, expected)
tm.assert_attr_equal('freq', idx_diff, expected)

other = date_range("20160922", "20160925", freq="D")
idx_diff = index.difference(other)
idx_diff = index.difference(other, sort)
expected = DatetimeIndex(["20160920", "20160921"], freq=None)
tm.assert_index_equal(idx_diff, expected)
tm.assert_attr_equal('freq', idx_diff, expected)

def test_datetimeindex_diff(self):
@pytest.mark.parametrize("sort", [True, False])
def test_datetimeindex_diff(self, sort):
dti1 = DatetimeIndex(freq='Q-JAN', start=datetime(1997, 12, 31),
periods=100)
dti2 = DatetimeIndex(freq='Q-JAN', start=datetime(1997, 12, 31),
periods=98)
assert len(dti1.difference(dti2)) == 2
assert len(dti1.difference(dti2, sort)) == 2

def test_datetimeindex_union_join_empty(self):
dti = DatetimeIndex(start='1/1/2001', end='2/1/2001', freq='D')
Expand Down
17 changes: 12 additions & 5 deletions pandas/tests/indexes/interval/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,19 +801,26 @@ def test_intersection(self, closed):
result = index.intersection(other)
tm.assert_index_equal(result, expected)

def test_difference(self, closed):
index = self.create_index(closed=closed)
tm.assert_index_equal(index.difference(index[:1]), index[1:])
@pytest.mark.parametrize("sort", [True, False])
def test_difference(self, closed, sort):
index = IntervalIndex.from_arrays([1, 0, 3, 2],
[1, 2, 3, 4],
closed=closed)
result = index.difference(index[:1], sort)
expected = index[1:]
if sort:
expected = expected.sort_values()
tm.assert_index_equal(result, expected)

# GH 19101: empty result, same dtype
result = index.difference(index)
result = index.difference(index, sort)
expected = IntervalIndex(np.array([], dtype='int64'), closed=closed)
tm.assert_index_equal(result, expected)

# GH 19101: empty result, different dtypes
other = IntervalIndex.from_arrays(index.left.astype('float64'),
index.right, closed=closed)
result = index.difference(other)
result = index.difference(other, sort)
tm.assert_index_equal(result, expected)

def test_symmetric_difference(self, closed):
Expand Down
Loading

0 comments on commit 01ea768

Please sign in to comment.