Skip to content

Commit

Permalink
c
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion committed Nov 26, 2024
1 parent f0d087d commit 1906d64
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
7 changes: 5 additions & 2 deletions crates/polars-plan/src/plans/aexpr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ impl AExpr {

if options.flags.contains(FunctionFlags::RETURNS_SCALAR) {
*agg_list = false;
} else if matches!(ctx.ctx, Context::Aggregation) {
} else if !options.is_elementwise() && matches!(ctx.ctx, Context::Aggregation) {
*agg_list = true;
}

Expand All @@ -369,9 +369,12 @@ impl AExpr {
polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", function);
let out = function.get_field(ctx.schema, ctx.ctx, &fields)?;

// Note: Here in schema resolution we use `is_elementwise()`. During execution the
// scalar returns are identified using `is_scalar_ae()`

if options.flags.contains(FunctionFlags::RETURNS_SCALAR) {
*agg_list = false;
} else if matches!(ctx.ctx, Context::Aggregation) {
} else if !options.is_elementwise() && matches!(ctx.ctx, Context::Aggregation) {
*agg_list = true;
}

Expand Down
23 changes: 20 additions & 3 deletions py-polars/tests/unit/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,19 +255,36 @@ def test_lazy_agg_lit_explode() -> None:
"nan_min", "null_count", "product", "sample", "skew", "std", "sum", "upper_bound",
"var"
]) # fmt: skip
def test_lazy_agg_auto_agg_list_19752(expr_op: str) -> None:
@pytest.mark.parametrize("lhs", [pl.col("b"), pl.lit(1, dtype=pl.Int64).alias("b")])
def test_lazy_agg_to_scalar_schema_19752(lhs: pl.Expr, expr_op: str) -> None:
op = getattr(pl.Expr, expr_op)

lf = pl.LazyFrame({"a": 1, "b": 1})

q = lf.group_by("a").agg(pl.col("b").reverse().pipe(op))
q = lf.group_by("a").agg(lhs.reverse().pipe(op))
assert q.collect_schema() == q.collect().collect_schema()

q = lf.group_by("a").agg(pl.col("b").shuffle().reverse().pipe(op))
q = lf.group_by("a").agg(lhs.shuffle().reverse().pipe(op))

assert q.collect_schema() == q.collect().collect_schema()


def test_lazy_agg_schema_after_elementwise_19984() -> None:
lf = pl.LazyFrame({"a": 1, "b": 1})

q = lf.group_by("a").agg(pl.col("b").first().fill_null(0))
assert q.collect_schema() == q.collect().collect_schema()

q = lf.group_by("a").agg(pl.col("b").first().fill_null(0).fill_null(0))
assert q.collect_schema() == q.collect().collect_schema()

q = lf.group_by("a").agg(pl.col("b").first() + 1)
assert q.collect_schema() == q.collect().collect_schema()

q = lf.group_by("a").agg(1 + pl.col("b").first())
assert q.collect_schema() == q.collect().collect_schema()


@pytest.mark.parametrize(
"expr", [pl.col("b"), pl.col("b").sum(), pl.col("b").reverse()]
)
Expand Down

0 comments on commit 1906d64

Please sign in to comment.