From d99cba3f6248e1e169e5b001ccc57c881085c034 Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Wed, 13 Nov 2024 16:13:50 +1100 Subject: [PATCH 1/4] c --- crates/polars-plan/src/plans/aexpr/schema.rs | 223 +++++++++++-------- py-polars/tests/unit/test_schema.py | 20 ++ 2 files changed, 147 insertions(+), 96 deletions(-) diff --git a/crates/polars-plan/src/plans/aexpr/schema.rs b/crates/polars-plan/src/plans/aexpr/schema.rs index 7105855636c5..ff290b814e11 100644 --- a/crates/polars-plan/src/plans/aexpr/schema.rs +++ b/crates/polars-plan/src/plans/aexpr/schema.rs @@ -32,50 +32,49 @@ impl AExpr { ctx: Context, arena: &Arena, ) -> PolarsResult { - // During aggregation a column that isn't aggregated gets an extra nesting level - // col(foo: i64) -> list[i64] - // But not if we do an aggregation: - // col(foo: i64).sum() -> i64 - // The `nested` keeps track of the nesting we need to add. - let mut nested = matches!(ctx, Context::Aggregation) as u8; - let mut field = self.to_field_impl(schema, ctx, arena, &mut nested)?; + // In some cases we need to implode the result column, this is indicated by `agg_list`. + // We initialize it to `true` if we are in an aggregation context. Functions that always + // return scalars should explicitly set it to `false` (e.g. aggregations). + let mut agg_list = matches!(ctx, Context::Aggregation); + let mut field = self.to_field_impl(schema, ctx, arena, &mut agg_list)?; - if nested >= 1 { + if agg_list { field.coerce(field.dtype().clone().implode()); } + Ok(field) } /// Get Field result of the expression. The schema is the input data. + /// + /// This is taken as `&mut bool` as for some expressions this is determined by the upper node + /// (e.g. `alias`, `cast`). #[recursive] pub fn to_field_impl( &self, schema: &Schema, ctx: Context, arena: &Arena, - nested: &mut u8, + agg_list: &mut bool, ) -> PolarsResult { use AExpr::*; use DataType::*; match self { Len => { - *nested = 0; + *agg_list = false; Ok(Field::new(PlSmallStr::from_static(LEN), IDX_DTYPE)) }, - Window { - function, options, .. - } => { - if let WindowType::Over(mapping) = options { - *nested += matches!(mapping, WindowMapping::Join) as u8; - } + Window { function, .. } => { let e = arena.get(*function); - e.to_field_impl(schema, ctx, arena, nested) + e.to_field_impl(schema, ctx, arena, &mut false) }, Explode(expr) => { // `Explode` is a "flatten" operation, which is not the same as returning a scalar. // Namely, it should be auto-imploded in the aggregation context, so we don't update - // the `nested` state here. - let field = arena.get(*expr).to_field_impl(schema, ctx, arena, &mut 0)?; + // the `agg_list` state here. + let field = arena + .get(*expr) + .to_field_impl(schema, ctx, arena, &mut false)?; if let List(inner) = field.dtype() { Ok(Field::new(field.name().clone(), *inner.clone())) @@ -87,14 +86,14 @@ impl AExpr { name.clone(), arena .get(*expr) - .to_field_impl(schema, ctx, arena, nested)? + .to_field_impl(schema, ctx, arena, agg_list)? .dtype, )), Column(name) => schema .get_field(name) .ok_or_else(|| PolarsError::ColumnNotFound(name.to_string().into())), Literal(sv) => { - *nested = 0; + *agg_list = false; Ok(match sv { LiteralValue::Series(s) => s.field().into_owned(), _ => Field::new(sv.output_name().clone(), sv.get_datatype()), @@ -116,35 +115,42 @@ impl AExpr { | Operator::LogicalOr => { let out_field; let out_name = { - out_field = - arena.get(*left).to_field_impl(schema, ctx, arena, nested)?; + out_field = arena + .get(*left) + .to_field_impl(schema, ctx, arena, agg_list)?; out_field.name() }; Field::new(out_name.clone(), Boolean) }, Operator::TrueDivide => { - return get_truediv_field(*left, *right, arena, ctx, schema, nested) + return get_truediv_field(*left, *right, arena, ctx, schema, agg_list) }, _ => { - return get_arithmetic_field(*left, *right, arena, *op, ctx, schema, nested) + return get_arithmetic_field( + *left, *right, arena, *op, ctx, schema, agg_list, + ) }, }; Ok(field) }, - Sort { expr, .. } => arena.get(*expr).to_field_impl(schema, ctx, arena, nested), + Sort { expr, .. } => arena.get(*expr).to_field_impl(schema, ctx, arena, agg_list), Gather { expr, returns_scalar, .. } => { if *returns_scalar { - *nested = nested.saturating_sub(1); + *agg_list = false; } - arena.get(*expr).to_field_impl(schema, ctx, arena, nested) + arena + .get(*expr) + .to_field_impl(schema, ctx, arena, &mut false) }, - SortBy { expr, .. } => arena.get(*expr).to_field_impl(schema, ctx, arena, nested), - Filter { input, .. } => arena.get(*input).to_field_impl(schema, ctx, arena, nested), + SortBy { expr, .. } => arena.get(*expr).to_field_impl(schema, ctx, arena, agg_list), + Filter { input, .. } => arena + .get(*input) + .to_field_impl(schema, ctx, arena, agg_list), Agg(agg) => { use IRAggExpr::*; match agg { @@ -152,13 +158,16 @@ impl AExpr { | Min { input: expr, .. } | First(expr) | Last(expr) => { - *nested = nested.saturating_sub(1); - arena.get(*expr).to_field_impl(schema, ctx, arena, nested) + *agg_list = false; + arena + .get(*expr) + .to_field_impl(schema, ctx, arena, &mut false) }, Sum(expr) => { - *nested = nested.saturating_sub(1); - let mut field = - arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?; + *agg_list = false; + let mut field = arena + .get(*expr) + .to_field_impl(schema, ctx, arena, &mut false)?; let dt = match field.dtype() { Boolean => Some(IDX_DTYPE), UInt8 | Int8 | Int16 | UInt16 => Some(Int64), @@ -170,9 +179,10 @@ impl AExpr { Ok(field) }, Median(expr) => { - *nested = nested.saturating_sub(1); - let mut field = - arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?; + *agg_list = false; + let mut field = arena + .get(*expr) + .to_field_impl(schema, ctx, arena, &mut false)?; match field.dtype { Date => field.coerce(Datetime(TimeUnit::Milliseconds, None)), _ => float_type(&mut field), @@ -180,9 +190,10 @@ impl AExpr { Ok(field) }, Mean(expr) => { - *nested = nested.saturating_sub(1); - let mut field = - arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?; + *agg_list = false; + let mut field = arena + .get(*expr) + .to_field_impl(schema, ctx, arena, &mut false)?; match field.dtype { Date => field.coerce(Datetime(TimeUnit::Milliseconds, None)), _ => float_type(&mut field), @@ -190,69 +201,80 @@ impl AExpr { Ok(field) }, Implode(expr) => { - let mut field = - arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?; + let mut field = arena + .get(*expr) + .to_field_impl(schema, ctx, arena, &mut false)?; field.coerce(DataType::List(field.dtype().clone().into())); Ok(field) }, Std(expr, _) => { - *nested = nested.saturating_sub(1); - let mut field = - arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?; + *agg_list = false; + let mut field = arena + .get(*expr) + .to_field_impl(schema, ctx, arena, &mut false)?; float_type(&mut field); Ok(field) }, Var(expr, _) => { - *nested = nested.saturating_sub(1); - let mut field = - arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?; + *agg_list = false; + let mut field = arena + .get(*expr) + .to_field_impl(schema, ctx, arena, &mut false)?; float_type(&mut field); Ok(field) }, NUnique(expr) => { - *nested = 0; - let mut field = - arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?; + *agg_list = false; + let mut field = arena + .get(*expr) + .to_field_impl(schema, ctx, arena, &mut false)?; field.coerce(IDX_DTYPE); Ok(field) }, Count(expr, _) => { - *nested = 0; - let mut field = - arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?; + *agg_list = false; + let mut field = arena + .get(*expr) + .to_field_impl(schema, ctx, arena, &mut false)?; field.coerce(IDX_DTYPE); Ok(field) }, AggGroups(expr) => { - *nested = 1; - let mut field = - arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?; + *agg_list = true; + let mut field = arena + .get(*expr) + .to_field_impl(schema, ctx, arena, &mut false)?; field.coerce(List(IDX_DTYPE.into())); Ok(field) }, Quantile { expr, .. } => { - *nested = nested.saturating_sub(1); - let mut field = - arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?; + *agg_list = false; + let mut field = arena + .get(*expr) + .to_field_impl(schema, ctx, arena, &mut false)?; float_type(&mut field); Ok(field) }, #[cfg(feature = "bitwise")] Bitwise(expr, _) => { - *nested = nested.saturating_sub(1); - let field = arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?; + *agg_list = false; + let field = arena + .get(*expr) + .to_field_impl(schema, ctx, arena, &mut false)?; // @Q? Do we need to coerce here? Ok(field) }, } }, Cast { expr, dtype, .. } => { - let field = arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?; + let field = arena + .get(*expr) + .to_field_impl(schema, ctx, arena, agg_list)?; Ok(Field::new(field.name().clone(), dtype.clone())) }, Ternary { truthy, falsy, .. } => { - let mut nested_truthy = *nested; - let mut nested_falsy = *nested; + let mut agg_list_truthy = *agg_list; + let mut agg_list_falsy = *agg_list; // During aggregation: // left: col(foo): list nesting: 1 @@ -261,11 +283,11 @@ impl AExpr { let mut truthy = arena .get(*truthy) - .to_field_impl(schema, ctx, arena, &mut nested_truthy)?; + .to_field_impl(schema, ctx, arena, &mut agg_list_truthy)?; let falsy = arena .get(*falsy) - .to_field_impl(schema, ctx, arena, &mut nested_falsy)?; + .to_field_impl(schema, ctx, arena, &mut agg_list_falsy)?; let st = if let DataType::Null = *truthy.dtype() { falsy.dtype().clone() @@ -273,7 +295,7 @@ impl AExpr { try_get_supertype(truthy.dtype(), falsy.dtype())? }; - *nested = std::cmp::max(nested_truthy, nested_falsy); + *agg_list = agg_list_truthy | agg_list_falsy; truthy.coerce(st); Ok(truthy) @@ -284,14 +306,14 @@ impl AExpr { options, .. } => { - let fields = func_args_to_fields(input, ctx, schema, arena, nested)?; + let fields = func_args_to_fields(input, ctx, schema, arena, agg_list)?; polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", options.fmt_str); let out = output_type.get_field(schema, ctx, &fields)?; if options.flags.contains(FunctionFlags::RETURNS_SCALAR) { - *nested = 0; + *agg_list = false; } else if matches!(ctx, Context::Aggregation) { - *nested += 1; + *agg_list = true; } Ok(out) @@ -301,19 +323,21 @@ impl AExpr { input, options, } => { - let fields = func_args_to_fields(input, ctx, schema, arena, nested)?; + let fields = func_args_to_fields(input, ctx, schema, arena, agg_list)?; polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", function); let out = function.get_field(schema, ctx, &fields)?; if options.flags.contains(FunctionFlags::RETURNS_SCALAR) { - *nested = 0; + *agg_list = false; } else if matches!(ctx, Context::Aggregation) { - *nested += 1; + *agg_list = true; } Ok(out) }, - Slice { input, .. } => arena.get(*input).to_field_impl(schema, ctx, arena, nested), + Slice { input, .. } => arena + .get(*input) + .to_field_impl(schema, ctx, arena, agg_list), } } } @@ -323,25 +347,28 @@ fn func_args_to_fields( ctx: Context, schema: &Schema, arena: &Arena, - nested: &mut u8, + agg_list: &mut bool, ) -> PolarsResult> { - let mut first = true; input .iter() + .enumerate() // Default context because `col()` would return a list in aggregation context - .map(|e| { - // Only mutate first nested as that is the dtype of the function. - let mut nested_tmp = *nested; - let nested = if first { - first = false; - &mut *nested - } else { - &mut nested_tmp - }; + .map(|(i, e)| { + let tmp = &mut false; arena .get(e.node()) - .to_field_impl(schema, ctx, arena, nested) + .to_field_impl( + schema, + ctx, + arena, + if i == 0 { + // Only mutate first agg_list as that is the dtype of the function. + agg_list + } else { + tmp + }, + ) .map(|mut field| { field.name = e.output_name().clone(); field @@ -357,7 +384,7 @@ fn get_arithmetic_field( op: Operator, ctx: Context, schema: &Schema, - nested: &mut u8, + agg_list: &mut bool, ) -> PolarsResult { use DataType::*; let left_ae = arena.get(left); @@ -371,11 +398,11 @@ fn get_arithmetic_field( // leading to quadratic behavior. # 4736 // // further right_type is only determined when needed. - let mut left_field = left_ae.to_field_impl(schema, ctx, arena, nested)?; + let mut left_field = left_ae.to_field_impl(schema, ctx, arena, agg_list)?; let super_type = match op { Operator::Minus => { - let right_type = right_ae.to_field_impl(schema, ctx, arena, nested)?.dtype; + let right_type = right_ae.to_field_impl(schema, ctx, arena, agg_list)?.dtype; match (&left_field.dtype, &right_type) { #[cfg(feature = "dtype-struct")] (Struct(_), Struct(_)) => { @@ -430,7 +457,7 @@ fn get_arithmetic_field( } }, Operator::Plus => { - let right_type = right_ae.to_field_impl(schema, ctx, arena, nested)?.dtype; + let right_type = right_ae.to_field_impl(schema, ctx, arena, agg_list)?.dtype; match (&left_field.dtype, &right_type) { (Duration(_), Datetime(_, _)) | (Datetime(_, _), Duration(_)) @@ -472,7 +499,7 @@ fn get_arithmetic_field( } }, _ => { - let right_type = right_ae.to_field_impl(schema, ctx, arena, nested)?.dtype; + let right_type = right_ae.to_field_impl(schema, ctx, arena, agg_list)?.dtype; match (&left_field.dtype, &right_type) { #[cfg(feature = "dtype-struct")] @@ -558,10 +585,14 @@ fn get_truediv_field( arena: &Arena, ctx: Context, schema: &Schema, - nested: &mut u8, + agg_list: &mut bool, ) -> PolarsResult { - let mut left_field = arena.get(left).to_field_impl(schema, ctx, arena, nested)?; - let right_field = arena.get(right).to_field_impl(schema, ctx, arena, nested)?; + let mut left_field = arena + .get(left) + .to_field_impl(schema, ctx, arena, agg_list)?; + let right_field = arena + .get(right) + .to_field_impl(schema, ctx, arena, agg_list)?; use DataType::*; // TODO: Re-investigate this. A lot of "_" is being used on the RHS match because this code diff --git a/py-polars/tests/unit/test_schema.py b/py-polars/tests/unit/test_schema.py index 78a277a3662f..3be7e17e5644 100644 --- a/py-polars/tests/unit/test_schema.py +++ b/py-polars/tests/unit/test_schema.py @@ -246,3 +246,23 @@ def test_lf_agg_lit_explode() -> None: schema = {"k": pl.Int64, "o": pl.List(pl.Int64)} assert q.collect_schema() == schema assert_frame_equal(q.collect(), pl.DataFrame({"k": 1, "o": [[1]]}, schema=schema)) # type: ignore[arg-type] + + +@pytest.mark.parametrize("expr_op", [ + "approx_n_unique", "arg_max", "arg_min", "bitwise_and", "bitwise_or", + "bitwise_xor", "count", "entropy", "first", "has_nulls", "implode", "kurtosis", + "last", "len", "lower_bound", "max", "mean", "median", "min", "n_unique", "nan_max", + "nan_min", "null_count", "product", "sample", "skew", "std", "sum", "upper_bound", + "var" +]) # fmt: skip +def test_lf_agg_auto_agg_list_19752(expr_op: str) -> None: + op = getattr(pl.Expr, expr_op) + + lf = pl.LazyFrame({"a": 1, "b": 1}) + + q = lf.group_by("a").agg(pl.col("b").reverse().pipe(op)) + assert q.collect_schema() == q.collect().collect_schema() + + q = lf.group_by("a").agg(pl.col("b").shuffle().reverse().pipe(op)) + + assert q.collect_schema() == q.collect().collect_schema() From 44f2a38d644fff9e991b01161225300cd8056581 Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Wed, 13 Nov 2024 16:16:17 +1100 Subject: [PATCH 2/4] c --- crates/polars-plan/src/plans/aexpr/schema.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/polars-plan/src/plans/aexpr/schema.rs b/crates/polars-plan/src/plans/aexpr/schema.rs index ff290b814e11..d4b3226a132a 100644 --- a/crates/polars-plan/src/plans/aexpr/schema.rs +++ b/crates/polars-plan/src/plans/aexpr/schema.rs @@ -32,9 +32,9 @@ impl AExpr { ctx: Context, arena: &Arena, ) -> PolarsResult { - // In some cases we need to implode the result column, this is indicated by `agg_list`. - // We initialize it to `true` if we are in an aggregation context. Functions that always - // return scalars should explicitly set it to `false` (e.g. aggregations). + // Indicates whether we should auto-implode the result. This is initialized to true if we are + // in an aggregation context, so functions that return scalars should explicitly set this + // to false in `to_field_impl`. let mut agg_list = matches!(ctx, Context::Aggregation); let mut field = self.to_field_impl(schema, ctx, arena, &mut agg_list)?; From ceba39262d586117f575aaae18e3aa73013234fe Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Wed, 13 Nov 2024 16:27:25 +1100 Subject: [PATCH 3/4] c --- crates/polars-plan/src/plans/aexpr/schema.rs | 12 ++++++++++-- py-polars/tests/unit/test_schema.py | 12 ++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/crates/polars-plan/src/plans/aexpr/schema.rs b/crates/polars-plan/src/plans/aexpr/schema.rs index d4b3226a132a..6c1b675b2bd8 100644 --- a/crates/polars-plan/src/plans/aexpr/schema.rs +++ b/crates/polars-plan/src/plans/aexpr/schema.rs @@ -64,9 +64,17 @@ impl AExpr { *agg_list = false; Ok(Field::new(PlSmallStr::from_static(LEN), IDX_DTYPE)) }, - Window { function, .. } => { + Window { + function, options, .. + } => { + if let WindowType::Over(WindowMapping::Join) = options { + // expr.over(..), defaults to agg-list unless explicitly unset + // by the `to_field_impl` of the `expr` + *agg_list = true; + } + let e = arena.get(*function); - e.to_field_impl(schema, ctx, arena, &mut false) + e.to_field_impl(schema, ctx, arena, agg_list) }, Explode(expr) => { // `Explode` is a "flatten" operation, which is not the same as returning a scalar. diff --git a/py-polars/tests/unit/test_schema.py b/py-polars/tests/unit/test_schema.py index 3be7e17e5644..a8f9e43d84c0 100644 --- a/py-polars/tests/unit/test_schema.py +++ b/py-polars/tests/unit/test_schema.py @@ -266,3 +266,15 @@ def test_lf_agg_auto_agg_list_19752(expr_op: str) -> None: q = lf.group_by("a").agg(pl.col("b").shuffle().reverse().pipe(op)) assert q.collect_schema() == q.collect().collect_schema() + + +@pytest.mark.parametrize( + "expr", [pl.col("b"), pl.col("b").sum(), pl.col("b").reverse()] +) +@pytest.mark.parametrize("mapping_strategy", ["explode", "join", "group_to_rows"]) +def test_lf_window_schema(expr: pl.Expr, mapping_strategy: str) -> None: + q = pl.LazyFrame({"a": 1, "b": 1}).select( + expr.over("a", mapping_strategy=mapping_strategy) # type: ignore[arg-type] + ) + + assert q.collect_schema() == q.collect().collect_schema() From 0ebd8aab2e9aeeaa303a00b84189e2e64aa7effd Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Wed, 13 Nov 2024 16:36:11 +1100 Subject: [PATCH 4/4] c --- .github/workflows/test-coverage.yml | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test-coverage.yml b/.github/workflows/test-coverage.yml index 232f79fb8947..531add1428e6 100644 --- a/.github/workflows/test-coverage.yml +++ b/.github/workflows/test-coverage.yml @@ -96,9 +96,13 @@ jobs: with: python-version: '3.12' - - name: Create virtual environment + - name: Install uv run: | curl -LsSf https://astral.sh/uv/install.sh | sh + echo "$HOME/.local/bin" >> "$GITHUB_PATH" + + - name: Create virtual environment + run: | uv venv echo "$GITHUB_WORKSPACE/.venv/bin" >> $GITHUB_PATH echo "VIRTUAL_ENV=$GITHUB_WORKSPACE/.venv" >> $GITHUB_ENV @@ -165,7 +169,7 @@ jobs: runs-on: ubuntu-latest steps: - # Needed to fetch the Codecov config file + # Needed to fetch the Codecov config file - uses: actions/checkout@v4 - name: Download coverage reports