From a5dc30da73a421487525a00cdb26c2b2d3c5d03f Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Fri, 30 Aug 2024 14:50:15 +0200 Subject: [PATCH] fix: Various schema corrections (#18474) --- crates/polars-plan/src/plans/aexpr/schema.rs | 7 +- .../src/plans/conversion/expr_to_ir.rs | 74 ++-------- .../src/plans/conversion/functions.rs | 66 +++++++++ .../polars-plan/src/plans/conversion/mod.rs | 2 +- .../plans/conversion/type_coercion/binary.rs | 8 +- .../conversion/type_coercion/functions.rs | 83 +++++++++++ .../plans/conversion/type_coercion/is_in.rs | 97 +++++++++++++ .../src/plans/conversion/type_coercion/mod.rs | 132 +++--------------- py-polars/tests/unit/test_datatypes.py | 5 + 9 files changed, 301 insertions(+), 173 deletions(-) create mode 100644 crates/polars-plan/src/plans/conversion/functions.rs create mode 100644 crates/polars-plan/src/plans/conversion/type_coercion/functions.rs create mode 100644 crates/polars-plan/src/plans/conversion/type_coercion/is_in.rs diff --git a/crates/polars-plan/src/plans/aexpr/schema.rs b/crates/polars-plan/src/plans/aexpr/schema.rs index 89eed6b70c01..00bd4b8d7c50 100644 --- a/crates/polars-plan/src/plans/aexpr/schema.rs +++ b/crates/polars-plan/src/plans/aexpr/schema.rs @@ -56,7 +56,12 @@ impl AExpr { *nested = 0; Ok(Field::new(PlSmallStr::from_static(LEN), IDX_DTYPE)) }, - Window { function, .. } => { + Window { + function, options, .. 
+ } => { + if let WindowType::Over(mapping) = options { + *nested += matches!(mapping, WindowMapping::Join) as u8; + } let e = arena.get(*function); e.to_field_impl(schema, arena, nested) }, diff --git a/crates/polars-plan/src/plans/conversion/expr_to_ir.rs b/crates/polars-plan/src/plans/conversion/expr_to_ir.rs index bcfacb7f0dc6..29654d03f596 100644 --- a/crates/polars-plan/src/plans/conversion/expr_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/expr_to_ir.rs @@ -1,4 +1,5 @@ use super::*; +use crate::plans::conversion::functions::convert_functions; pub fn to_expr_ir(expr: Expr, arena: &mut Arena) -> PolarsResult { let mut state = ConversionContext::new(); @@ -40,12 +41,12 @@ pub fn to_aexpr(expr: Expr, arena: &mut Arena) -> PolarsResult { } #[derive(Default)] -struct ConversionContext { - output_name: OutputName, +pub(super) struct ConversionContext { + pub(super) output_name: OutputName, /// Remove alias from the expressions and set as [`OutputName`]. - prune_alias: bool, + pub(super) prune_alias: bool, /// If an `alias` is encountered prune and ignore it. - ignore_alias: bool, + pub(super) ignore_alias: bool, } impl ConversionContext { @@ -68,14 +69,17 @@ fn to_aexprs( .collect() } -fn set_function_output_name(e: &[ExprIR], state: &mut ConversionContext, function_fmt: F) -where - F: FnOnce() -> Cow<'static, str>, +pub(super) fn set_function_output_name( + e: &[ExprIR], + state: &mut ConversionContext, + function_fmt: F, +) where + F: FnOnce() -> PlSmallStr, { if state.output_name.is_none() { if e.is_empty() { let s = function_fmt(); - state.output_name = OutputName::LiteralLhs(PlSmallStr::from_str(s.as_ref())); + state.output_name = OutputName::LiteralLhs(s); } else { state.output_name = e[0].output_name_inner().clone(); } @@ -117,7 +121,7 @@ fn to_aexpr_impl_materialized_lit( /// Converts expression to AExpr and adds it to the arena, which uses an arena (Vec) for allocation. 
#[recursive] -fn to_aexpr_impl( +pub(super) fn to_aexpr_impl( expr: Expr, arena: &mut Arena, state: &mut ConversionContext, @@ -281,7 +285,7 @@ fn to_aexpr_impl( options, } => { let e = to_expr_irs(input, arena)?; - set_function_output_name(&e, state, || Cow::Borrowed(options.fmt_str)); + set_function_output_name(&e, state, || PlSmallStr::from_static(options.fmt_str)); AExpr::AnonymousFunction { input: e, function, @@ -293,55 +297,7 @@ fn to_aexpr_impl( input, function, options, - } => { - match function { - // This can be created by col(*).is_null() on empty dataframes. - FunctionExpr::Boolean( - BooleanFunction::AllHorizontal | BooleanFunction::AnyHorizontal, - ) if input.is_empty() => { - return to_aexpr_impl(lit(true), arena, state); - }, - // Convert to binary expression as the optimizer understands those. - // Don't exceed 128 expressions as we might stackoverflow. - FunctionExpr::Boolean(BooleanFunction::AllHorizontal) => { - if input.len() < 128 { - let expr = input - .into_iter() - .reduce(|l, r| l.logical_and(r)) - .unwrap() - .cast(DataType::Boolean); - return to_aexpr_impl(expr, arena, state); - } - }, - FunctionExpr::Boolean(BooleanFunction::AnyHorizontal) => { - if input.len() < 128 { - let expr = input - .into_iter() - .reduce(|l, r| l.logical_or(r)) - .unwrap() - .cast(DataType::Boolean); - return to_aexpr_impl(expr, arena, state); - } - }, - _ => {}, - } - - let e = to_expr_irs(input, arena)?; - - if state.output_name.is_none() { - // Handles special case functions like `struct.field`. 
- if let Some(name) = function.output_name() { - state.output_name = name - } else { - set_function_output_name(&e, state, || Cow::Owned(format!("{}", &function))); - } - } - AExpr::Function { - input: e, - function, - options, - } - }, + } => return convert_functions(input, function, options, arena, state), Expr::Window { function, partition_by, diff --git a/crates/polars-plan/src/plans/conversion/functions.rs b/crates/polars-plan/src/plans/conversion/functions.rs new file mode 100644 index 000000000000..409da958477d --- /dev/null +++ b/crates/polars-plan/src/plans/conversion/functions.rs @@ -0,0 +1,66 @@ +use arrow::legacy::error::PolarsResult; +use polars_utils::arena::{Arena, Node}; +use polars_utils::format_pl_smallstr; + +use super::*; +use crate::dsl::{Expr, FunctionExpr}; +use crate::plans::AExpr; +use crate::prelude::FunctionOptions; + +pub(super) fn convert_functions( + input: Vec, + function: FunctionExpr, + options: FunctionOptions, + arena: &mut Arena, + state: &mut ConversionContext, +) -> PolarsResult { + match function { + // This can be created by col(*).is_null() on empty dataframes. + FunctionExpr::Boolean(BooleanFunction::AllHorizontal | BooleanFunction::AnyHorizontal) + if input.is_empty() => + { + return to_aexpr_impl(lit(true), arena, state); + }, + // Convert to binary expression as the optimizer understands those. + // Don't exceed 128 expressions as we might stackoverflow. 
+ FunctionExpr::Boolean(BooleanFunction::AllHorizontal) => { + if input.len() < 128 { + let expr = input + .into_iter() + .reduce(|l, r| l.logical_and(r)) + .unwrap() + .cast(DataType::Boolean); + return to_aexpr_impl(expr, arena, state); + } + }, + FunctionExpr::Boolean(BooleanFunction::AnyHorizontal) => { + if input.len() < 128 { + let expr = input + .into_iter() + .reduce(|l, r| l.logical_or(r)) + .unwrap() + .cast(DataType::Boolean); + return to_aexpr_impl(expr, arena, state); + } + }, + _ => {}, + } + + let e = to_expr_irs(input, arena)?; + + if state.output_name.is_none() { + // Handles special case functions like `struct.field`. + if let Some(name) = function.output_name() { + state.output_name = name + } else { + set_function_output_name(&e, state, || format_pl_smallstr!("{}", &function)); + } + } + + let ae_function = AExpr::Function { + input: e, + function, + options, + }; + Ok(arena.add(ae_function)) +} diff --git a/crates/polars-plan/src/plans/conversion/mod.rs b/crates/polars-plan/src/plans/conversion/mod.rs index 28c41039d4b8..e07f8bc2848e 100644 --- a/crates/polars-plan/src/plans/conversion/mod.rs +++ b/crates/polars-plan/src/plans/conversion/mod.rs @@ -12,7 +12,6 @@ mod ir_to_dsl; mod scans; mod stack_opt; -use std::borrow::Cow; use std::sync::{Arc, Mutex, RwLock}; pub use dsl_to_ir::*; @@ -21,6 +20,7 @@ pub use ir_to_dsl::*; use polars_core::prelude::*; use polars_utils::vec::ConvertVec; use recursive::recursive; +mod functions; pub(crate) mod type_coercion; pub(crate) use expr_expansion::{expand_selectors, is_regex_projection, prepare_projection}; diff --git a/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs b/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs index 4f8dd1ee0fb0..9bc2917c250d 100644 --- a/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs +++ b/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs @@ -296,7 +296,13 @@ pub(super) fn process_binary( st = String } - // only cast if 
the type is not already the super type. + // TODO! raise here? + // We should at least never cast to Unknown. + if matches!(st, DataType::Unknown(UnknownKind::Any)) { + return Ok(None); + } + + // Only cast if the type is not already the super type. // this can prevent an expensive flattening and subsequent aggregation // in a group_by context. To be able to cast the groups need to be // flattened diff --git a/crates/polars-plan/src/plans/conversion/type_coercion/functions.rs b/crates/polars-plan/src/plans/conversion/type_coercion/functions.rs new file mode 100644 index 000000000000..c7b738722c55 --- /dev/null +++ b/crates/polars-plan/src/plans/conversion/type_coercion/functions.rs @@ -0,0 +1,83 @@ +use either::Either; + +use super::*; + +pub(super) fn get_function_dtypes( + input: &[ExprIR], + expr_arena: &Arena, + input_schema: &Schema, + function: &FunctionExpr, + mut options: FunctionOptions, +) -> PolarsResult, AExpr>> { + let mut early_return = move || { + // Next iteration this will not hit anymore as options is updated. + options.cast_to_supertypes = None; + Ok(Either::Right(AExpr::Function { + function: function.clone(), + input: input.to_vec(), + options, + })) + }; + + let mut dtypes = Vec::with_capacity(input.len()); + let mut first = true; + for e in input { + let Some((_, dtype)) = get_aexpr_and_type(expr_arena, e.node(), input_schema) else { + return early_return(); + }; + + if first { + check_namespace(function, &dtype)?; + first = false; + } + // Ignore Unknown in the inputs. + // We will raise if we cannot find the supertype later. + match dtype { + DataType::Unknown(UnknownKind::Any) => { + return early_return(); + }, + _ => dtypes.push(dtype), + } + } + + if dtypes.iter().all_equal() { + return early_return(); + } + Ok(Either::Left(dtypes)) +} + +// `str` namespace belongs to `String` +// `cat` namespace belongs to `Categorical` etc. 
+fn check_namespace(function: &FunctionExpr, first_dtype: &DataType) -> PolarsResult<()> { + match function { + #[cfg(feature = "strings")] + FunctionExpr::StringExpr(_) => { + polars_ensure!(first_dtype == &DataType::String, InvalidOperation: "expected String type, got: {}", first_dtype) + }, + FunctionExpr::BinaryExpr(_) => { + polars_ensure!(first_dtype == &DataType::Binary, InvalidOperation: "expected Binary type, got: {}", first_dtype) + }, + #[cfg(feature = "temporal")] + FunctionExpr::TemporalExpr(_) => { + polars_ensure!(first_dtype.is_temporal(), InvalidOperation: "expected Date(time)/Duration type, got: {}", first_dtype) + }, + FunctionExpr::ListExpr(_) => { + polars_ensure!(matches!(first_dtype, DataType::List(_)), InvalidOperation: "expected List type, got: {}", first_dtype) + }, + #[cfg(feature = "dtype-array")] + FunctionExpr::ArrayExpr(_) => { + polars_ensure!(matches!(first_dtype, DataType::Array(_, _)), InvalidOperation: "expected Array type, got: {}", first_dtype) + }, + #[cfg(feature = "dtype-struct")] + FunctionExpr::StructExpr(_) => { + polars_ensure!(matches!(first_dtype, DataType::Struct(_)), InvalidOperation: "expected Struct type, got: {}", first_dtype) + }, + #[cfg(feature = "dtype-categorical")] + FunctionExpr::Categorical(_) => { + polars_ensure!(matches!(first_dtype, DataType::Categorical(_, _)), InvalidOperation: "expected Categorical type, got: {}", first_dtype) + }, + _ => {}, + } + + Ok(()) +} diff --git a/crates/polars-plan/src/plans/conversion/type_coercion/is_in.rs b/crates/polars-plan/src/plans/conversion/type_coercion/is_in.rs new file mode 100644 index 000000000000..6b906eb14567 --- /dev/null +++ b/crates/polars-plan/src/plans/conversion/type_coercion/is_in.rs @@ -0,0 +1,97 @@ +use super::*; + +pub(super) fn resolve_is_in( + input: &[ExprIR], + expr_arena: &Arena, + lp_arena: &Arena, + lp_node: Node, +) -> PolarsResult> { + let input_schema = get_schema(lp_arena, lp_node); + let other_e = &input[1]; + let (_, type_left) = 
unpack!(get_aexpr_and_type( + expr_arena, + input[0].node(), + &input_schema + )); + let (_, type_other) = unpack!(get_aexpr_and_type( + expr_arena, + other_e.node(), + &input_schema + )); + + unpack!(early_escape(&type_left, &type_other)); + + let casted_expr = match (&type_left, &type_other) { + // types are equal, do nothing + (a, b) if a == b => return Ok(None), + // all-null can represent anything (and/or empty list), so cast to target dtype + (_, DataType::Null) => AExpr::Cast { + expr: other_e.node(), + data_type: type_left, + options: CastOptions::NonStrict, + }, + #[cfg(feature = "dtype-categorical")] + (DataType::Categorical(_, _) | DataType::Enum(_, _), DataType::String) => return Ok(None), + #[cfg(feature = "dtype-categorical")] + (DataType::String, DataType::Categorical(_, _) | DataType::Enum(_, _)) => return Ok(None), + #[cfg(feature = "dtype-decimal")] + (DataType::Decimal(_, _), dt) if dt.is_numeric() => AExpr::Cast { + expr: other_e.node(), + data_type: type_left, + options: CastOptions::NonStrict, + }, + #[cfg(feature = "dtype-decimal")] + (DataType::Decimal(_, _), _) | (_, DataType::Decimal(_, _)) => { + polars_bail!(InvalidOperation: "'is_in' cannot check for {:?} values in {:?} data", &type_other, &type_left) + }, + // can't check for more granular time_unit in less-granular time_unit data, + // or we'll cast away valid/necessary precision (eg: nanosecs to millisecs) + (DataType::Datetime(lhs_unit, _), DataType::Datetime(rhs_unit, _)) => { + if lhs_unit <= rhs_unit { + return Ok(None); + } else { + polars_bail!(InvalidOperation: "'is_in' cannot check for {:?} precision values in {:?} Datetime data", &rhs_unit, &lhs_unit) + } + }, + (DataType::Duration(lhs_unit), DataType::Duration(rhs_unit)) => { + if lhs_unit <= rhs_unit { + return Ok(None); + } else { + polars_bail!(InvalidOperation: "'is_in' cannot check for {:?} precision values in {:?} Duration data", &rhs_unit, &lhs_unit) + } + }, + (_, DataType::List(other_inner)) => { + if 
other_inner.as_ref() == &type_left + || (type_left == DataType::Null) + || (other_inner.as_ref() == &DataType::Null) + || (other_inner.as_ref().is_numeric() && type_left.is_numeric()) + { + return Ok(None); + } + polars_bail!(InvalidOperation: "'is_in' cannot check for {:?} values in {:?} data", &type_left, &type_other) + }, + #[cfg(feature = "dtype-array")] + (_, DataType::Array(other_inner, _)) => { + if other_inner.as_ref() == &type_left + || (type_left == DataType::Null) + || (other_inner.as_ref() == &DataType::Null) + || (other_inner.as_ref().is_numeric() && type_left.is_numeric()) + { + return Ok(None); + } + polars_bail!(InvalidOperation: "'is_in' cannot check for {:?} values in {:?} data", &type_left, &type_other) + }, + #[cfg(feature = "dtype-struct")] + (DataType::Struct(_), _) | (_, DataType::Struct(_)) => return Ok(None), + + // don't attempt to cast between obviously mismatched types, but + // allow integer/float comparison (will use their supertypes). + (a, b) => { + if (a.is_numeric() && b.is_numeric()) || (a == &DataType::Null) { + return Ok(None); + } + polars_bail!(InvalidOperation: "'is_in' cannot check for {:?} values in {:?} data", &type_other, &type_left) + }, + }; + Ok(Some(casted_expr)) +} diff --git a/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs b/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs index 230652e101f7..46caac208b23 100644 --- a/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs +++ b/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs @@ -1,9 +1,13 @@ mod binary; +mod functions; +#[cfg(feature = "is_in")] +mod is_in; use std::borrow::Cow; use arrow::temporal_conversions::{time_unit_multiple, SECONDS_IN_DAY}; use binary::process_binary; +use either::Either; use polars_core::chunked_array::cast::CastOptions; use polars_core::prelude::*; use polars_core::utils::{get_supertype, get_supertype_with_options, materialize_dyn_int}; @@ -23,6 +27,7 @@ macro_rules! 
unpack { } }; } +pub(super) use unpack; /// determine if we use the supertype or not. For instance when we have a column Int64 and we compare with literal UInt32 /// it would be wasteful to cast the column instead of the literal. @@ -193,98 +198,12 @@ impl OptimizationRule for TypeCoercionRule { ref input, options, } => { - let input_schema = get_schema(lp_arena, lp_node); - let other_e = &input[1]; - let (_, type_left) = unpack!(get_aexpr_and_type( - expr_arena, - input[0].node(), - &input_schema - )); - let (_, type_other) = unpack!(get_aexpr_and_type( - expr_arena, - other_e.node(), - &input_schema - )); - - unpack!(early_escape(&type_left, &type_other)); - - let casted_expr = match (&type_left, &type_other) { - // types are equal, do nothing - (a, b) if a == b => return Ok(None), - // all-null can represent anything (and/or empty list), so cast to target dtype - (_, DataType::Null) => AExpr::Cast { - expr: other_e.node(), - data_type: type_left, - options: CastOptions::NonStrict, - }, - #[cfg(feature = "dtype-categorical")] - (DataType::Categorical(_, _) | DataType::Enum(_, _), DataType::String) => { - return Ok(None) - }, - #[cfg(feature = "dtype-categorical")] - (DataType::String, DataType::Categorical(_, _) | DataType::Enum(_, _)) => { - return Ok(None) - }, - #[cfg(feature = "dtype-decimal")] - (DataType::Decimal(_, _), dt) if dt.is_numeric() => AExpr::Cast { - expr: other_e.node(), - data_type: type_left, - options: CastOptions::NonStrict, - }, - #[cfg(feature = "dtype-decimal")] - (DataType::Decimal(_, _), _) | (_, DataType::Decimal(_, _)) => { - polars_bail!(InvalidOperation: "'is_in' cannot check for {:?} values in {:?} data", &type_other, &type_left) - }, - // can't check for more granular time_unit in less-granular time_unit data, - // or we'll cast away valid/necessary precision (eg: nanosecs to millisecs) - (DataType::Datetime(lhs_unit, _), DataType::Datetime(rhs_unit, _)) => { - if lhs_unit <= rhs_unit { - return Ok(None); - } else { - 
polars_bail!(InvalidOperation: "'is_in' cannot check for {:?} precision values in {:?} Datetime data", &rhs_unit, &lhs_unit) - } - }, - (DataType::Duration(lhs_unit), DataType::Duration(rhs_unit)) => { - if lhs_unit <= rhs_unit { - return Ok(None); - } else { - polars_bail!(InvalidOperation: "'is_in' cannot check for {:?} precision values in {:?} Duration data", &rhs_unit, &lhs_unit) - } - }, - (_, DataType::List(other_inner)) => { - if other_inner.as_ref() == &type_left - || (type_left == DataType::Null) - || (other_inner.as_ref() == &DataType::Null) - || (other_inner.as_ref().is_numeric() && type_left.is_numeric()) - { - return Ok(None); - } - polars_bail!(InvalidOperation: "'is_in' cannot check for {:?} values in {:?} data", &type_left, &type_other) - }, - #[cfg(feature = "dtype-array")] - (_, DataType::Array(other_inner, _)) => { - if other_inner.as_ref() == &type_left - || (type_left == DataType::Null) - || (other_inner.as_ref() == &DataType::Null) - || (other_inner.as_ref().is_numeric() && type_left.is_numeric()) - { - return Ok(None); - } - polars_bail!(InvalidOperation: "'is_in' cannot check for {:?} values in {:?} data", &type_left, &type_other) - }, - #[cfg(feature = "dtype-struct")] - (DataType::Struct(_), _) | (_, DataType::Struct(_)) => return Ok(None), - - // don't attempt to cast between obviously mismatched types, but - // allow integer/float comparison (will use their supertypes). - (a, b) => { - if (a.is_numeric() && b.is_numeric()) || (a == &DataType::Null) { - return Ok(None); - } - polars_bail!(InvalidOperation: "'is_in' cannot check for {:?} values in {:?} data", &type_other, &type_left) - }, + let Some(casted_expr) = is_in::resolve_is_in(input, expr_arena, lp_arena, lp_node)? 
+ else { + return Ok(None); }; - let mut input = input.clone(); + + let mut input = input.to_vec(); let other_input = expr_arena.add(casted_expr); input[1].set_node(other_input); @@ -300,8 +219,6 @@ impl OptimizationRule for TypeCoercionRule { ref input, options, } => { - let mut input = input.clone(); - let input_schema = get_schema(lp_arena, lp_node); let left_node = input[0].node(); let fill_value_node = input[2].node(); @@ -319,6 +236,7 @@ impl OptimizationRule for TypeCoercionRule { let super_type = modify_supertype(super_type, left, fill_value, &type_left, &type_fill_value); + let mut input = input.clone(); let new_node_left = if type_left != super_type { expr_arena.add(AExpr::Cast { expr: left_node, @@ -356,25 +274,17 @@ impl OptimizationRule for TypeCoercionRule { mut options, } if options.cast_to_supertypes.is_some() => { let input_schema = get_schema(lp_arena, lp_node); - let mut dtypes = Vec::with_capacity(input.len()); - for e in input { - let (_, dtype) = - unpack!(get_aexpr_and_type(expr_arena, e.node(), &input_schema)); - // Ignore Unknown in the inputs. - // We will raise if we cannot find the supertype later. - match dtype { - DataType::Unknown(UnknownKind::Any) => { - options.cast_to_supertypes = None; - return Ok(None); - }, - _ => dtypes.push(dtype), - } - } - if dtypes.iter().all_equal() { - options.cast_to_supertypes = None; - return Ok(None); - } + let dtypes = match functions::get_function_dtypes( + input, + expr_arena, + &input_schema, + function, + options, + )? { + Either::Left(dtypes) => dtypes, + Either::Right(ae) => return Ok(Some(ae)), + }; // TODO! use args_to_supertype. 
let self_e = input[0].clone(); diff --git a/py-polars/tests/unit/test_datatypes.py b/py-polars/tests/unit/test_datatypes.py index 8313c4203c1f..9bd545125f64 100644 --- a/py-polars/tests/unit/test_datatypes.py +++ b/py-polars/tests/unit/test_datatypes.py @@ -196,3 +196,8 @@ def test_struct_field_iter() -> None: ("b", List(Int64)), ("a", List(List(Int64))), ] + + +def test_raise_invalid_namespace() -> None: + with pytest.raises(pl.exceptions.InvalidOperationError): + pl.select(pl.lit(1.5).str.replace("1", "2"))