Skip to content

Commit

Permalink
fix: Incorrect collect_schema() for fill_null() after an aggregat…
Browse files Browse the repository at this point in the history
…ion expression in group-by context (#19993)
  • Loading branch information
nameexhaustion authored Nov 26, 2024
1 parent 0b1d520 commit d9ea1d8
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
4 changes: 2 additions & 2 deletions crates/polars-plan/src/plans/aexpr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ impl AExpr {

if options.flags.contains(FunctionFlags::RETURNS_SCALAR) {
*agg_list = false;
} else if matches!(ctx.ctx, Context::Aggregation) {
} else if !options.is_elementwise() && matches!(ctx.ctx, Context::Aggregation) {
*agg_list = true;
}

Expand All @@ -371,7 +371,7 @@ impl AExpr {

if options.flags.contains(FunctionFlags::RETURNS_SCALAR) {
*agg_list = false;
} else if matches!(ctx.ctx, Context::Aggregation) {
} else if !options.is_elementwise() && matches!(ctx.ctx, Context::Aggregation) {
*agg_list = true;
}

Expand Down
23 changes: 20 additions & 3 deletions py-polars/tests/unit/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,19 +255,36 @@ def test_lazy_agg_lit_explode() -> None:
"nan_min", "null_count", "product", "sample", "skew", "std", "sum", "upper_bound",
"var"
]) # fmt: skip
def test_lazy_agg_auto_agg_list_19752(expr_op: str) -> None:
@pytest.mark.parametrize("lhs", [pl.col("b"), pl.lit(1, dtype=pl.Int64).alias("b")])
def test_lazy_agg_to_scalar_schema_19752(lhs: pl.Expr, expr_op: str) -> None:
op = getattr(pl.Expr, expr_op)

lf = pl.LazyFrame({"a": 1, "b": 1})

q = lf.group_by("a").agg(pl.col("b").reverse().pipe(op))
q = lf.group_by("a").agg(lhs.reverse().pipe(op))
assert q.collect_schema() == q.collect().collect_schema()

q = lf.group_by("a").agg(pl.col("b").shuffle().reverse().pipe(op))
q = lf.group_by("a").agg(lhs.shuffle().reverse().pipe(op))

assert q.collect_schema() == q.collect().collect_schema()


def test_lazy_agg_schema_after_elementwise_19984() -> None:
    """Check that the planned schema equals the collected schema for each query.

    Every query groups by "a" and aggregates an expression built on top of
    `pl.col("b").first()`: one or two `fill_null(0)` calls, or an addition
    with the literal 1 on either side.
    """
    frame = pl.LazyFrame({"a": 1, "b": 1})
    first_b = pl.col("b").first()

    aggregations = [
        first_b.fill_null(0),
        first_b.fill_null(0).fill_null(0),
        first_b + 1,
        1 + first_b,
    ]

    for agg_expr in aggregations:
        query = frame.group_by("a").agg(agg_expr)
        # The schema resolved at plan time must match what execution produces.
        assert query.collect_schema() == query.collect().collect_schema()


@pytest.mark.parametrize(
"expr", [pl.col("b"), pl.col("b").sum(), pl.col("b").reverse()]
)
Expand Down

0 comments on commit d9ea1d8

Please sign in to comment.