Skip to content

Commit

Permalink
fix: Don't partition group-by with non-scalar literals in agg (#20704)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jan 14, 2025
1 parent 38dc65f commit 1cd72ff
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
3 changes: 2 additions & 1 deletion crates/polars-plan/src/plans/aexpr/properties.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,8 @@ pub fn can_pre_agg(agg: Node, expr_arena: &Arena<AExpr>, _input_schema: &Schema)
&& !has_aggregation(*falsy)
&& !has_aggregation(*predicate)
},
Column(_) | Len | Literal(_) | Cast { .. } => true,
Literal(lv) => lv.is_scalar(),
Column(_) | Len | Cast { .. } => true,
_ => false,
}
});
Expand Down
12 changes: 12 additions & 0 deletions py-polars/tests/unit/operations/test_group_by.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import typing
from collections import OrderedDict
from datetime import date, datetime, timedelta
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -1182,3 +1183,14 @@ def test_group_by_map_groups_slice_pushdown_20002() -> None:
}
),
)


@typing.no_type_check
def test_group_by_lit_series(capfd: Any, monkeypatch: Any) -> None:
    """Regression test for #20704: non-scalar literals in a group-by agg.

    Aggregates with ``dot`` against a NumPy array (a non-scalar literal) and
    checks that the verbose log written to stderr contains the
    "are not partitionable" message.
    """
    # The message asserted below is only looked for in verbose output.
    monkeypatch.setenv("POLARS_VERBOSE", "1")

    group_size = 10
    # One weight per row of a group: "y" alternates 0, 1, so each of the two
    # groups holds exactly `group_size` rows.
    weights = np.ones(group_size, dtype=float)
    frame = pl.DataFrame(
        {
            "x": np.ones(2 * group_size),
            "y": list(range(2)) * group_size,
        }
    )

    query = frame.lazy().group_by("y").agg(pl.col("x").dot(weights))
    query.collect()

    stderr_output = capfd.readouterr().err
    assert "are not partitionable" in stderr_output

0 comments on commit 1cd72ff

Please sign in to comment.