
feat(python)!: Update group-by iteration to always return tuple keys (#…
stinodego authored Jun 6, 2024
1 parent fd4c71e commit efac81c
Showing 7 changed files with 33 additions and 92 deletions.
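
The breaking change in brief: iterating a `group_by` (including rolling and dynamic group-bys) and indexing the dict returned by `partition_by(..., as_dict=True)` now always use tuples as group keys, even when grouping by a single column. A minimal sketch of the new behavior (the frame below is illustrative, not taken from this commit):

    import polars as pl

    df = pl.DataFrame({"foo": ["a", "a", "b"], "bar": [1, 2, 3]})

    # Group keys are now always tuples, even for a single key column.
    for key, data in df.group_by("foo", maintain_order=True):
        print(key)  # ("a",), then ("b",) -- previously "a", then "b"

    # The same applies to partition_by(..., as_dict=True): dict keys are tuples.
    partitions = df.partition_by("foo", as_dict=True)
    print(partitions[("a",)].shape)  # (2, 2)
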
36 changes: 16 additions & 20 deletions py-polars/polars/dataframe/frame.py
@@ -7840,7 +7840,7 @@ def unstack(
     def partition_by(
         self,
         by: ColumnNameOrSelector | Sequence[ColumnNameOrSelector],
-        *more_by: str,
+        *more_by: ColumnNameOrSelector,
         maintain_order: bool = ...,
         include_key: bool = ...,
         as_dict: Literal[False] = ...,
@@ -7850,11 +7850,21 @@ def partition_by(
     def partition_by(
         self,
         by: ColumnNameOrSelector | Sequence[ColumnNameOrSelector],
-        *more_by: str,
+        *more_by: ColumnNameOrSelector,
         maintain_order: bool = ...,
         include_key: bool = ...,
         as_dict: Literal[True],
-    ) -> dict[Any, Self]: ...
+    ) -> dict[tuple[object, ...], Self]: ...
+
+    @overload
+    def partition_by(
+        self,
+        by: ColumnNameOrSelector | Sequence[ColumnNameOrSelector],
+        *more_by: ColumnNameOrSelector,
+        maintain_order: bool = ...,
+        include_key: bool = ...,
+        as_dict: bool,
+    ) -> list[Self] | dict[tuple[object, ...], Self]: ...

     def partition_by(
         self,
@@ -7863,7 +7873,7 @@ def partition_by(
         maintain_order: bool = True,
         include_key: bool = True,
         as_dict: bool = False,
-    ) -> list[Self] | dict[Any, Self]:
+    ) -> list[Self] | dict[tuple[object, ...], Self]:
         """
         Group by the given columns and return the groups as separate dataframes.
@@ -7999,27 +8009,13 @@ def partition_by(
         ]

         if as_dict:
-            key_as_single_value = isinstance(by, str) and not more_by
-            if key_as_single_value:
-                issue_deprecation_warning(
-                    "`partition_by(..., as_dict=True)` will change to always return tuples as dictionary keys."
-                    f" Pass `by` as a list to silence this warning, e.g. `partition_by([{by!r}], as_dict=True)`.",
-                    version="0.20.4",
-                )
-
             if include_key:
-                if key_as_single_value:
-                    names = [p.get_column(by)[0] for p in partitions]  # type: ignore[arg-type]
-                else:
-                    names = [p.select(by_parsed).row(0) for p in partitions]
+                names = [p.select(by_parsed).row(0) for p in partitions]
             else:
                 if not maintain_order:  # Group keys cannot be matched to partitions
                     msg = "cannot use `partition_by` with `maintain_order=False, include_key=False, as_dict=True`"
                     raise ValueError(msg)
-                if key_as_single_value:
-                    names = self.get_column(by).unique(maintain_order=True).to_list()  # type: ignore[arg-type]
-                else:
-                    names = self.select(by_parsed).unique(maintain_order=True).rows()
+                names = self.select(by_parsed).unique(maintain_order=True).rows()

             return dict(zip(names, partitions))

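The overload set above gains a third variant for callers that pass `as_dict` as a runtime `bool` rather than a literal; type checkers then fall back to the union return type. A hedged sketch of how the annotations resolve (the comments show the expected inferred types, which may vary by type checker):

    import polars as pl

    df = pl.DataFrame({"a": ["one", "two", "one"], "b": [1, 2, 3]})

    parts_list = df.partition_by("a")                # list[DataFrame]
    parts_dict = df.partition_by("a", as_dict=True)  # dict[tuple[object, ...], DataFrame]

    flag: bool = True  # not a literal, so the new third overload applies
    parts_either = df.partition_by("a", as_dict=flag)  # list[DataFrame] | dict[tuple[object, ...], DataFrame]
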
58 changes: 8 additions & 50 deletions py-polars/polars/dataframe/group_by.py
@@ -1,13 +1,10 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING, Callable, Iterable, Iterator
+from typing import TYPE_CHECKING, Callable, Iterable

 from polars import functions as F
 from polars._utils.convert import parse_as_duration_string
-from polars._utils.deprecation import (
-    deprecate_renamed_function,
-    issue_deprecation_warning,
-)
+from polars._utils.deprecation import deprecate_renamed_function

 if TYPE_CHECKING:
     import sys
@@ -105,30 +102,13 @@ def __iter__(self) -> Self:
             .collect(no_optimization=True)
         )

-        group_names = groups_df.select(F.all().exclude(temp_col))
-
-        self._group_names: Iterator[object] | Iterator[tuple[object, ...]]
-        key_as_single_value = (
-            len(self.by) == 1 and isinstance(self.by[0], str) and not self.named_by
-        )
-        if key_as_single_value:
-            issue_deprecation_warning(
-                "`group_by` iteration will change to always return group identifiers as tuples."
-                f" Pass `by` as a list to silence this warning, e.g. `group_by([{self.by[0]!r}])`.",
-                version="0.20.4",
-            )
-            self._group_names = iter(group_names.to_series())
-        else:
-            self._group_names = group_names.iter_rows()
-
+        self._group_names = groups_df.select(F.all().exclude(temp_col)).iter_rows()
         self._group_indices = groups_df.select(temp_col).to_series()
         self._current_index = 0

         return self

-    def __next__(
-        self,
-    ) -> tuple[object, DataFrame] | tuple[tuple[object, ...], DataFrame]:
+    def __next__(self) -> tuple[tuple[object, ...], DataFrame]:
         if self._current_index >= len(self._group_indices):
             raise StopIteration
@@ -817,24 +797,13 @@ def __iter__(self) -> Self:
             .collect(no_optimization=True)
         )

-        group_names = groups_df.select(F.all().exclude(temp_col))
-
-        # When grouping by a single column, group name is a single value
-        # When grouping by multiple columns, group name is a tuple of values
-        self._group_names: Iterator[object] | Iterator[tuple[object, ...]]
-        if self.group_by is None:
-            self._group_names = iter(group_names.to_series())
-        else:
-            self._group_names = group_names.iter_rows()
-
+        self._group_names = groups_df.select(F.all().exclude(temp_col)).iter_rows()
         self._group_indices = groups_df.select(temp_col).to_series()
         self._current_index = 0

         return self

-    def __next__(
-        self,
-    ) -> tuple[object, DataFrame] | tuple[tuple[object, ...], DataFrame]:
+    def __next__(self) -> tuple[tuple[object, ...], DataFrame]:
         if self._current_index >= len(self._group_indices):
             raise StopIteration
@@ -974,24 +943,13 @@ def __iter__(self) -> Self:
             .collect(no_optimization=True)
         )

-        group_names = groups_df.select(F.all().exclude(temp_col))
-
-        # When grouping by a single column, group name is a single value
-        # When grouping by multiple columns, group name is a tuple of values
-        self._group_names: Iterator[object] | Iterator[tuple[object, ...]]
-        if self.group_by is None:
-            self._group_names = iter(group_names.to_series())
-        else:
-            self._group_names = group_names.iter_rows()
-
+        self._group_names = groups_df.select(F.all().exclude(temp_col)).iter_rows()
         self._group_indices = groups_df.select(temp_col).to_series()
         self._current_index = 0

         return self

-    def __next__(
-        self,
-    ) -> tuple[object, DataFrame] | tuple[tuple[object, ...], DataFrame]:
+    def __next__(self) -> tuple[tuple[object, ...], DataFrame]:
         if self._current_index >= len(self._group_indices):
             raise StopIteration
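
With the deprecation branch removed, all three iterators (`GroupBy`, `RollingGroupBy`, `DynamicGroupBy`) now share one `__next__` signature returning `tuple[tuple[object, ...], DataFrame]`. Single-key callers can migrate by indexing or destructuring the key tuple, as the updated tests below do; a small sketch with illustrative data:

    import polars as pl

    df = pl.DataFrame({"w": ["a", "b", "a"], "x": [1, 2, 3]})

    # Index the one-element key tuple...
    rows_by_key = {key[0]: grp.rows() for key, grp in df.group_by("w")}

    # ...or destructure it directly in the loop target.
    for (w,), grp in df.group_by("w", maintain_order=True):
        print(w, grp.height)
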
14 changes: 2 additions & 12 deletions py-polars/tests/unit/dataframe/test_partition_by.py
@@ -56,28 +56,18 @@ def test_partition_by_as_dict() -> None:
     result_first = result[("one", 1)]
     assert result_first.to_dict(as_series=False) == {"a": ["one"], "b": [1]}

-    result = df.partition_by(["a"], as_dict=True)
+    result = df.partition_by("a", as_dict=True)
     result_first = result[("one",)]
     assert result_first.to_dict(as_series=False) == {"a": ["one", "one"], "b": [1, 3]}

-    with pytest.deprecated_call():
-        result = df.partition_by("a", as_dict=True)
-    result_first = result["one"]
-    assert result_first.to_dict(as_series=False) == {"a": ["one", "one"], "b": [1, 3]}
-

 def test_partition_by_as_dict_include_keys_false() -> None:
     df = pl.DataFrame({"a": ["one", "two", "one", "two"], "b": [1, 2, 3, 4]})

-    result = df.partition_by(["a"], include_key=False, as_dict=True)
+    result = df.partition_by("a", include_key=False, as_dict=True)
     result_first = result[("one",)]
     assert result_first.to_dict(as_series=False) == {"b": [1, 3]}

-    with pytest.deprecated_call():
-        result = df.partition_by("a", include_key=False, as_dict=True)
-    result_first = result["one"]
-    assert result_first.to_dict(as_series=False) == {"b": [1, 3]}
-

 def test_partition_by_as_dict_include_keys_false_maintain_order_false() -> None:
     df = pl.DataFrame({"a": ["one", "two", "one", "two"], "b": [1, 2, 3, 4]})

2 changes: 1 addition & 1 deletion py-polars/tests/unit/operations/rolling/test_rolling.py
@@ -523,7 +523,7 @@ def test_rolling_iter() -> None:

     # Without 'by' argument
     result1 = [
-        (name, data.shape)
+        (name[0], data.shape)
         for name, data in df.rolling(index_column="date", period="2d")
     ]
     expected1 = [

5 changes: 2 additions & 3 deletions py-polars/tests/unit/operations/test_group_by.py
@@ -329,10 +329,9 @@ def test_group_by_iteration() -> None:
         [("b", 2, 5), ("b", 4, 3), ("b", 5, 2)],
         [("c", 6, 1)],
     ]
-    with pytest.deprecated_call():
-        gb_iter = enumerate(df.group_by("foo", maintain_order=True))
+    gb_iter = enumerate(df.group_by("foo", maintain_order=True))
     for i, (group, data) in gb_iter:
-        assert group == expected_names[i]
+        assert group == (expected_names[i],)
         assert data.rows() == expected_rows[i]

     # Grouped by ALL columns should give groups of a single row

4 changes: 2 additions & 2 deletions py-polars/tests/unit/operations/test_group_by_dynamic.py
@@ -835,8 +835,8 @@ def test_group_by_dynamic_iter(every: str | timedelta, tzinfo: ZoneInfo | None)
         for name, data in df.group_by_dynamic("datetime", every=every, closed="left")
     ]
     expected1 = [
-        (datetime(2020, 1, 1, 10, tzinfo=tzinfo), (2, 3)),
-        (datetime(2020, 1, 1, 11, tzinfo=tzinfo), (1, 3)),
+        ((datetime(2020, 1, 1, 10, tzinfo=tzinfo),), (2, 3)),
+        ((datetime(2020, 1, 1, 11, tzinfo=tzinfo),), (1, 3)),
     ]
     assert result1 == expected1

6 changes: 2 additions & 4 deletions py-polars/tests/unit/test_rows.py
@@ -93,8 +93,7 @@ def test_rows_by_key() -> None:
         "b": [("b", "q", 2.5, 8), ("b", "q", 3.0, 7)],
     }
     assert df.rows_by_key("w", include_key=True) == {
-        key[0]: grp.rows()  # type: ignore[index]
-        for key, grp in df.group_by(["w"])
+        key[0]: grp.rows() for key, grp in df.group_by(["w"])
     }
     assert df.rows_by_key("w", include_key=True, unique=True) == {
         "a": ("a", "k", 4.5, 6),
@@ -136,8 +135,7 @@ def test_rows_by_key() -> None:
         ],
     }
     assert df.rows_by_key("w", named=True, include_key=True) == {
-        key[0]: grp.rows(named=True)  # type: ignore[index]
-        for key, grp in df.group_by(["w"])
+        key[0]: grp.rows(named=True) for key, grp in df.group_by(["w"])
     }
     assert df.rows_by_key("w", named=True, include_key=True, unique=True) == {
         "a": {"w": "a", "x": "k", "y": 4.5, "z": 6},
