Skip to content

Commit

Permalink
chore(python): split Expr.top_k and Expr.top_k_by into separate
Browse files Browse the repository at this point in the history
functions
  • Loading branch information
MarcoGorelli committed May 3, 2024
1 parent 19c46f5 commit 9607f2c
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 128 deletions.
2 changes: 2 additions & 0 deletions py-polars/docs/source/reference/expressions/modify_select.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Manipulation/selection
Expr.arg_true
Expr.backward_fill
Expr.bottom_k
Expr.bottom_k_by
Expr.cast
Expr.ceil
Expr.clip
Expand Down Expand Up @@ -61,5 +62,6 @@ Manipulation/selection
Expr.take_every
Expr.to_physical
Expr.top_k
Expr.top_k_by
Expr.upper_bound
Expr.where
214 changes: 138 additions & 76 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2035,8 +2035,6 @@ def top_k(
self,
k: int | IntoExprColumn = 5,
*,
by: IntoExpr | Iterable[IntoExpr] | None = None,
descending: bool | Sequence[bool] = False,
nulls_last: bool = False,
maintain_order: bool = False,
multithreaded: bool = True,
Expand All @@ -2046,19 +2044,12 @@ def top_k(
This has time complexity:
.. math:: O(n + k \log{}n - \frac{k}{2})
.. math:: O(n + k \log{n} - \frac{k}{2})
Parameters
----------
k
Number of elements to return.
by
Column(s) included in sort order. Accepts expression input.
Strings are parsed as column names.
If not provided, each column will be treated induvidually.
descending
Return the k smallest. Top-k by multiple columns can be specified per
column by passing a sequence of booleans.
nulls_last
Place null values last.
maintain_order
Expand All @@ -2068,7 +2059,9 @@ def top_k(
See Also
--------
top_k_by
bottom_k
bottom_k_by
Examples
--------
Expand Down Expand Up @@ -2097,15 +2090,61 @@ def top_k(
│ 3 ┆ 4 │
│ 2 ┆ 98 │
└───────┴──────────┘
"""
k = parse_as_expression(k)
return self._from_pyexpr(self._pyexpr.top_k(k, nulls_last, multithreaded))

def top_k_by(
self,
by: IntoExpr | Iterable[IntoExpr],
k: int | IntoExprColumn = 5,
*,
descending: bool | Sequence[bool] = False,
nulls_last: bool = False,
maintain_order: bool = False,
multithreaded: bool = True,
) -> Self:
r"""
Return elements corresponding to the `k` largest elements of the `by` column(s).
This has time complexity:
.. math:: O(n + k \log{n} - \frac{k}{2})
>>> df2 = pl.DataFrame(
Parameters
----------
by
Column(s) included in sort order. Accepts expression input.
Strings are parsed as column names.
k
Number of elements to return.
descending
If `True`, consider the k smallest (instead of the k largest). Top-k by
multiple columns can be specified per column by passing a sequence of
booleans.
nulls_last
Place null values last.
maintain_order
Whether the order should be maintained if elements are equal.
multithreaded
Sort using multiple threads.
See Also
--------
top_k
bottom_k
bottom_k_by
Examples
--------
>>> df = pl.DataFrame(
... {
... "a": [1, 2, 3, 4, 5, 6],
... "b": [6, 5, 4, 3, 2, 1],
... "c": ["Apple", "Orange", "Apple", "Apple", "Banana", "Banana"],
... }
... )
>>> df2
>>> df
shape: (6, 3)
┌─────┬─────┬────────┐
│ a ┆ b ┆ c │
Expand All @@ -2122,9 +2161,9 @@ def top_k(
Get the top 2 rows by column `a` or `b`.
>>> df2.select(
... pl.all().top_k(2, by="a").name.suffix("_top_by_a"),
... pl.all().top_k(2, by="b").name.suffix("_top_by_b"),
>>> df.select(
... pl.all().top_k_by("a", 2).name.suffix("_top_by_a"),
... pl.all().top_k_by("b", 2).name.suffix("_top_by_b"),
... )
shape: (2, 6)
┌────────────┬────────────┬────────────┬────────────┬────────────┬────────────┐
Expand All @@ -2138,12 +2177,12 @@ def top_k(
Get the top 2 rows by multiple columns with given order.
>>> df2.select(
>>> df.select(
... pl.all()
... .top_k(2, by=["c", "a"], descending=[False, True])
... .top_k_by(["c", "a"], 2, descending=[False, True])
... .name.suffix("_by_ca"),
... pl.all()
... .top_k(2, by=["c", "b"], descending=[False, True])
... .top_k_by(["c", "b"], 2, descending=[False, True])
... .name.suffix("_by_cb"),
... )
shape: (2, 6)
Expand All @@ -2159,8 +2198,8 @@ def top_k(
Get the top 2 rows by column `a` in each group.
>>> (
... df2.group_by("c", maintain_order=True)
... .agg(pl.all().top_k(2, by="a"))
... df.group_by("c", maintain_order=True)
... .agg(pl.all().top_k_by("a", 2))
... .explode(pl.all().exclude("c"))
... )
shape: (5, 3)
Expand All @@ -2177,32 +2216,22 @@ def top_k(
└────────┴─────┴─────┘
"""
k = parse_as_expression(k)
if by is not None:
by = parse_as_list_of_expressions(by)
if isinstance(descending, bool):
descending = [descending]
elif len(by) != len(descending):
msg = f"the length of `descending` ({len(descending)}) does not match the length of `by` ({len(by)})"
raise ValueError(msg)
return self._from_pyexpr(
self._pyexpr.top_k_by(
k, by, descending, nulls_last, maintain_order, multithreaded
)
)
else:
if not isinstance(descending, bool):
msg = "`descending` should be a boolean if no `by` is provided"
raise ValueError(msg)
return self._from_pyexpr(
self._pyexpr.top_k(k, descending, nulls_last, multithreaded)
by = parse_as_list_of_expressions(by)
if isinstance(descending, bool):
descending = [descending]
elif len(by) != len(descending):
msg = f"the length of `descending` ({len(descending)}) does not match the length of `by` ({len(by)})"
raise ValueError(msg)
return self._from_pyexpr(
self._pyexpr.top_k_by(
k, by, descending, nulls_last, maintain_order, multithreaded
)
)

def bottom_k(
self,
k: int | IntoExprColumn = 5,
*,
by: IntoExpr | Iterable[IntoExpr] | None = None,
descending: bool | Sequence[bool] = False,
nulls_last: bool = False,
maintain_order: bool = False,
multithreaded: bool = True,
Expand All @@ -2212,19 +2241,12 @@ def bottom_k(
This has time complexity:
.. math:: O(n + k \log{}n - \frac{k}{2})
.. math:: O(n + k \log{n} - \frac{k}{2})
Parameters
----------
k
Number of elements to return.
by
Column(s) included in sort order.
Accepts expression input. Strings are parsed as column names.
If not provided, each column will be treated induvidually.
descending
Return the k largest. Bottom-k by multiple columns can be specified per
column by passing a sequence of booleans.
nulls_last
Place null values last.
maintain_order
Expand All @@ -2235,6 +2257,8 @@ def bottom_k(
See Also
--------
top_k
top_k_by
bottom_k_by
Examples
--------
Expand All @@ -2261,15 +2285,61 @@ def bottom_k(
│ 3 ┆ 4 │
│ 2 ┆ 98 │
└───────┴──────────┘
"""
k = parse_as_expression(k)
return self._from_pyexpr(self._pyexpr.bottom_k(k, nulls_last, multithreaded))

def bottom_k_by(
self,
by: IntoExpr | Iterable[IntoExpr],
k: int | IntoExprColumn = 5,
*,
descending: bool | Sequence[bool] = False,
nulls_last: bool = False,
maintain_order: bool = False,
multithreaded: bool = True,
) -> Self:
r"""
Return elements corresponding to the `k` smallest elements of `by` column(s).
This has time complexity:
.. math:: O(n + k \log{n} - \frac{k}{2})
>>> df2 = pl.DataFrame(
Parameters
----------
by
Column(s) included in sort order.
Accepts expression input. Strings are parsed as column names.
k
Number of elements to return.
descending
If `True`, consider the k largest (instead of the k smallest). Bottom-k by
multiple columns can be specified per column by passing a sequence of
booleans.
nulls_last
Place null values last.
maintain_order
Whether the order should be maintained if elements are equal.
multithreaded
Sort using multiple threads.
See Also
--------
top_k
top_k_by
bottom_k
Examples
--------
>>> df = pl.DataFrame(
... {
... "a": [1, 2, 3, 4, 5, 6],
... "b": [6, 5, 4, 3, 2, 1],
... "c": ["Apple", "Orange", "Apple", "Apple", "Banana", "Banana"],
... }
... )
>>> df2
>>> df
shape: (6, 3)
┌─────┬─────┬────────┐
│ a ┆ b ┆ c │
Expand All @@ -2286,9 +2356,9 @@ def bottom_k(
Get the bottom 2 rows by column `a` or `b`.
>>> df2.select(
... pl.all().bottom_k(2, by="a").name.suffix("_btm_by_a"),
... pl.all().bottom_k(2, by="b").name.suffix("_btm_by_b"),
>>> df.select(
... pl.all().bottom_k_by("a", 2).name.suffix("_btm_by_a"),
... pl.all().bottom_k_by("b", 2).name.suffix("_btm_by_b"),
... )
shape: (2, 6)
┌────────────┬────────────┬────────────┬────────────┬────────────┬────────────┐
Expand All @@ -2302,12 +2372,12 @@ def bottom_k(
Get the bottom 2 rows by multiple columns with given order.
>>> df2.select(
>>> df.select(
... pl.all()
... .bottom_k(2, by=["c", "a"], descending=[False, True])
... .bottom_k_by(["c", "a"], 2, descending=[False, True])
... .name.suffix("_by_ca"),
... pl.all()
... .bottom_k(2, by=["c", "b"], descending=[False, True])
... .bottom_k_by(["c", "b"], 2, descending=[False, True])
... .name.suffix("_by_cb"),
... )
shape: (2, 6)
Expand All @@ -2323,8 +2393,8 @@ def bottom_k(
Get the bottom 2 rows by column `a` in each group.
>>> (
... df2.group_by("c", maintain_order=True)
... .agg(pl.all().bottom_k(2, by="a"))
... df.group_by("c", maintain_order=True)
... .agg(pl.all().bottom_k_by("a", 2))
... .explode(pl.all().exclude("c"))
... )
shape: (5, 3)
Expand All @@ -2341,25 +2411,17 @@ def bottom_k(
└────────┴─────┴─────┘
"""
k = parse_as_expression(k)
if by is not None:
by = parse_as_list_of_expressions(by)
if isinstance(descending, bool):
descending = [descending]
elif len(by) != len(descending):
msg = f"the length of `descending` ({len(descending)}) does not match the length of `by` ({len(by)})"
raise ValueError(msg)
return self._from_pyexpr(
self._pyexpr.bottom_k_by(
k, by, descending, nulls_last, maintain_order, multithreaded
)
)
else:
if not isinstance(descending, bool):
msg = "`descending` should be a boolean if no `by` is provided"
raise ValueError(msg)
return self._from_pyexpr(
self._pyexpr.bottom_k(k, descending, nulls_last, multithreaded)
by = parse_as_list_of_expressions(by)
if isinstance(descending, bool):
descending = [descending]
elif len(by) != len(descending):
msg = f"the length of `descending` ({len(descending)}) does not match the length of `by` ({len(by)})"
raise ValueError(msg)
return self._from_pyexpr(
self._pyexpr.bottom_k_by(
k, by, descending, nulls_last, maintain_order, multithreaded
)
)

def arg_sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self:
"""
Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3402,7 +3402,7 @@ def top_k(self, k: int | IntoExprColumn = 5) -> Series:
This has time complexity:
.. math:: O(n + k \log{}n - \frac{k}{2})
.. math:: O(n + k \log{n} - \frac{k}{2})
Parameters
----------
Expand Down Expand Up @@ -3432,7 +3432,7 @@ def bottom_k(self, k: int | IntoExprColumn = 5) -> Series:
This has time complexity:
.. math:: O(n + k \log{}n - \frac{k}{2})
.. math:: O(n + k \log{n} - \frac{k}{2})
Parameters
----------
Expand Down
Loading

0 comments on commit 9607f2c

Please sign in to comment.