diff --git a/crates/polars-plan/src/plans/aexpr/schema.rs b/crates/polars-plan/src/plans/aexpr/schema.rs index 2f946a4feb8..0e6fa368321 100644 --- a/crates/polars-plan/src/plans/aexpr/schema.rs +++ b/crates/polars-plan/src/plans/aexpr/schema.rs @@ -354,7 +354,7 @@ impl AExpr { if options.flags.contains(FunctionFlags::RETURNS_SCALAR) { *agg_list = false; - } else if matches!(ctx.ctx, Context::Aggregation) { + } else if !options.is_elementwise() && matches!(ctx.ctx, Context::Aggregation) { *agg_list = true; } @@ -369,9 +369,12 @@ impl AExpr { polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", function); let out = function.get_field(ctx.schema, ctx.ctx, &fields)?; + // Note: Here in schema resolution we use `is_elementwise()`. During execution the + // scalar returns are identified using `is_scalar_ae()` + if options.flags.contains(FunctionFlags::RETURNS_SCALAR) { *agg_list = false; - } else if matches!(ctx.ctx, Context::Aggregation) { + } else if !options.is_elementwise() && matches!(ctx.ctx, Context::Aggregation) { *agg_list = true; } diff --git a/py-polars/tests/unit/test_schema.py b/py-polars/tests/unit/test_schema.py index 021069a6108..44211da4185 100644 --- a/py-polars/tests/unit/test_schema.py +++ b/py-polars/tests/unit/test_schema.py @@ -255,19 +255,36 @@ def test_lazy_agg_lit_explode() -> None: "nan_min", "null_count", "product", "sample", "skew", "std", "sum", "upper_bound", "var" ]) # fmt: skip -def test_lazy_agg_auto_agg_list_19752(expr_op: str) -> None: +@pytest.mark.parametrize("lhs", [pl.col("b"), pl.lit(1, dtype=pl.Int64).alias("b")]) +def test_lazy_agg_to_scalar_schema_19752(lhs: pl.Expr, expr_op: str) -> None: op = getattr(pl.Expr, expr_op) lf = pl.LazyFrame({"a": 1, "b": 1}) - q = lf.group_by("a").agg(pl.col("b").reverse().pipe(op)) + q = lf.group_by("a").agg(lhs.reverse().pipe(op)) assert q.collect_schema() == q.collect().collect_schema() - q = lf.group_by("a").agg(pl.col("b").shuffle().reverse().pipe(op)) + q = lf.group_by("a").agg(lhs.shuffle().reverse().pipe(op)) assert q.collect_schema() == q.collect().collect_schema() +def test_lazy_agg_schema_after_elementwise_19984() -> None: + lf = pl.LazyFrame({"a": 1, "b": 1}) + + q = lf.group_by("a").agg(pl.col("b").first().fill_null(0)) + assert q.collect_schema() == q.collect().collect_schema() + + q = lf.group_by("a").agg(pl.col("b").first().fill_null(0).fill_null(0)) + assert q.collect_schema() == q.collect().collect_schema() + + q = lf.group_by("a").agg(pl.col("b").first() + 1) + assert q.collect_schema() == q.collect().collect_schema() + + q = lf.group_by("a").agg(1 + pl.col("b").first()) + assert q.collect_schema() == q.collect().collect_schema() + + @pytest.mark.parametrize( "expr", [pl.col("b"), pl.col("b").sum(), pl.col("b").reverse()] )