diff --git a/crates/polars-mem-engine/src/executors/projection_utils.rs b/crates/polars-mem-engine/src/executors/projection_utils.rs index 125b774cf935..1ca7e085bfa3 100644 --- a/crates/polars-mem-engine/src/executors/projection_utils.rs +++ b/crates/polars-mem-engine/src/executors/projection_utils.rs @@ -263,7 +263,7 @@ pub(super) fn check_expand_literals( if duplicate_check && !names.insert(name) { let msg = format!( - "the name: '{}' is duplicate\n\n\ + "the name '{}' is duplicate\n\n\ It's possible that multiple expressions are returning the same default column \ name. If this is the case, try renaming the columns with \ `.alias(\"new_name\")` to avoid duplicate column names.", diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index 0dd79141ea02..d596f8f95c15 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -103,7 +103,7 @@ pub(super) use self::rolling_by::RollingFunctionBy; #[cfg(feature = "strings")] pub use self::strings::StringFunction; #[cfg(feature = "dtype-struct")] -pub(crate) use self::struct_::StructFunction; +pub use self::struct_::StructFunction; #[cfg(feature = "trigonometry")] pub(super) use self::trigonometry::TrigonometricFunction; use super::*; diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs index cd273a6c6ddc..a05417dbfd24 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs @@ -707,7 +707,7 @@ fn resolve_with_columns( if !output_names.insert(field.name().clone()) { let msg = format!( - "the name: '{}' passed to `LazyFrame.with_columns` is duplicate\n\n\ + "the name '{}' passed to `LazyFrame.with_columns` is duplicate\n\n\ It's possible that multiple expressions are returning the same default column name. \ If this is the case, try renaming the columns with `.alias(\"new_name\")` to avoid \ duplicate column names.", diff --git a/crates/polars-sql/Cargo.toml b/crates/polars-sql/Cargo.toml index 83dcd5b98a3c..0c8f883daf50 100644 --- a/crates/polars-sql/Cargo.toml +++ b/crates/polars-sql/Cargo.toml @@ -12,7 +12,7 @@ description = "SQL transpiler for Polars. Converts SQL to Polars logical plans" arrow = { workspace = true } polars-core = { workspace = true, features = ["rows"] } polars-error = { workspace = true } -polars-lazy = { workspace = true, features = ["abs", "binary_encoding", "concat_str", "cross_join", "cum_agg", "dtype-date", "dtype-decimal", "is_in", "list_eval", "log", "meta", "regex", "round_series", "sign", "string_reverse", "strings", "timezones", "trigonometry"] } +polars-lazy = { workspace = true, features = ["abs", "binary_encoding", "concat_str", "cross_join", "cum_agg", "dtype-date", "dtype-decimal", "dtype-struct", "is_in", "list_eval", "log", "meta", "regex", "round_series", "sign", "string_reverse", "strings", "timezones", "trigonometry"] } polars-ops = { workspace = true } polars-plan = { workspace = true } polars-time = { workspace = true } diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 496b0d0070d0..ea86de06f4e5 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -1,22 +1,25 @@ use std::cell::RefCell; +use std::ops::Deref; use polars_core::frame::row::Row; use polars_core::prelude::*; use polars_lazy::prelude::*; use polars_ops::frame::JoinCoalesce; +use polars_plan::dsl::function_expr::StructFunction; use polars_plan::prelude::*; use sqlparser::ast::{ - Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, GroupByExpr, JoinConstraint, - JoinOperator, ObjectName, ObjectType, Offset, OrderByExpr, Query, Select, SelectItem, SetExpr, - SetOperator, SetQuantifier, Statement, TableAlias, TableFactor, TableWithJoins, UnaryOperator, - Value as SQLValue, Values, WildcardAdditionalOptions, + Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, GroupByExpr, Ident, JoinConstraint, + JoinOperator, ObjectName, ObjectType, Offset, OrderByExpr, Query, RenameSelectItem, Select, + SelectItem, SetExpr, SetOperator, SetQuantifier, Statement, TableAlias, TableFactor, + TableWithJoins, UnaryOperator, Value as SQLValue, Values, WildcardAdditionalOptions, }; use sqlparser::dialect::GenericDialect; use sqlparser::parser::{Parser, ParserOptions}; use crate::function_registry::{DefaultFunctionRegistry, FunctionRegistry}; use crate::sql_expr::{ - parse_sql_array, parse_sql_expr, process_join_constraint, to_sql_interface_err, + parse_sql_array, parse_sql_expr, process_join_constraint, resolve_compound_identifier, + to_sql_interface_err, }; use crate::table_functions::PolarsTableFunctions; @@ -200,8 +203,9 @@ impl SQLContext { fn expr_or_ordinal( &mut self, e: &SQLExpr, - schema: Option<&Schema>, exprs: &[Expr], + selected: Option<&[Expr]>, + schema: Option<&Schema>, clause: &str, ) -> PolarsResult { match e { @@ -230,7 +234,14 @@ impl SQLContext { idx ) })?; - Ok(exprs + // note: "selected" cols represent final projection order, so we use those for + // ordinal resolution. "exprs" may include cols that are subsequently dropped. + let cols = if let Some(cols) = selected { + cols + } else { + exprs + }; + Ok(cols .get(idx - 1) .ok_or_else(|| { polars_err!( @@ -579,57 +590,72 @@ impl SQLContext { /// Execute the 'SELECT' part of the query. fn execute_select(&mut self, select_stmt: &Select, query: &Query) -> PolarsResult { - // Determine involved dataframes. - // Note: implicit joins require more work in query parsing, - // explicit joins are preferred for now (ref: #16662) - let mut lf = if select_stmt.from.is_empty() { DataFrame::empty().lazy() } else { + // Note: implicit joins need more work to support properly, + // explicit joins are preferred for now (ref: #16662) let from = select_stmt.clone().from; if from.len() > 1 { polars_bail!(SQLInterface: "multiple tables in FROM clause are not currently supported (found {}); use explicit JOIN syntax instead", from.len()) } self.execute_from_statement(from.first().unwrap())? }; - let mut contains_wildcard = false; - let mut contains_wildcard_exclude = false; - // Filter expression. - let schema = Some(lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?); + // Filter expression (WHERE clause) + let schema = lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; lf = self.process_where(lf, &select_stmt.selection)?; - // Column projections. - let projections: Vec<_> = select_stmt + // 'SELECT *' modifiers + let mut excluded_cols = vec![]; + let mut replace_exprs = vec![]; + let mut rename_cols = (&mut vec![], &mut vec![]); + + // Column projections (SELECT clause) + let projections: Vec = select_stmt .projection .iter() .map(|select_item| { Ok(match select_item { - SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, self, schema.as_deref())?, + SelectItem::UnnamedExpr(expr) => { + vec![parse_sql_expr(expr, self, Some(schema.deref()))?] + }, SelectItem::ExprWithAlias { expr, alias } => { - let expr = parse_sql_expr(expr, self, schema.as_deref())?; - expr.alias(&alias.value) + let expr = parse_sql_expr(expr, self, Some(schema.deref()))?; + vec![expr.alias(&alias.value)] }, - SelectItem::QualifiedWildcard(oname, wildcard_options) => self + SelectItem::QualifiedWildcard(obj_name, wildcard_options) => self .process_qualified_wildcard( - oname, + obj_name, wildcard_options, - &mut contains_wildcard_exclude, + &mut excluded_cols, + &mut rename_cols, + &mut replace_exprs, + Some(schema.deref()), )?, SelectItem::Wildcard(wildcard_options) => { - contains_wildcard = true; - let e = col("*"); + let cols = schema + .iter_names() + .map(|name| col(name)) + .collect::>(); + self.process_wildcard_additional_options( - e, + cols, wildcard_options, - &mut contains_wildcard_exclude, + &mut excluded_cols, + &mut rename_cols, + &mut replace_exprs, + Some(schema.deref()), )? }, }) }) - .collect::>()?; + .collect::>>>()? + .into_iter() + .flatten() + .collect(); - // Check for "GROUP BY ..." (after projections, as there may be ordinal/position ints). + // Check for "GROUP BY ..." (after determining projections) let mut group_by_keys: Vec = Vec::new(); match &select_stmt.group_by { // Standard "GROUP BY x, y, z" syntax (also recognising ordinal values) @@ -637,7 +663,15 @@ impl SQLContext { // translate the group expressions, allowing ordinal values group_by_keys = group_by_exprs .iter() - .map(|e| self.expr_or_ordinal(e, schema.as_deref(), &projections, "GROUP BY")) + .map(|e| { + self.expr_or_ordinal( + e, + &projections, + None, + Some(schema.deref()), + "GROUP BY", + ) + }) .collect::>()? }, // "GROUP BY ALL" syntax; automatically adds expressions that do not contain @@ -669,73 +703,38 @@ impl SQLContext { }; lf = if group_by_keys.is_empty() { - if query.order_by.is_empty() { + lf = if query.order_by.is_empty() { + // No sort, select cols as given lf.select(projections) - } else if !contains_wildcard { - let mut retained_names = PlIndexSet::with_capacity(projections.len()); - let schema = lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; + } else { + // Add projections to the base frame as any of the + // original columns may be required for the sort + lf = lf.with_columns(projections.clone()); - projections.iter().for_each(|expr| match expr { - Expr::Alias(_, name) => { - retained_names.insert(name.clone()); - }, - Expr::Column(name) => { - retained_names.insert(name.clone()); - }, - Expr::Columns(names) => names.iter().for_each(|name| { - retained_names.insert(name.clone()); - }), - Expr::Exclude(inner_expr, excludes) => { - if let Expr::Columns(names) = (*inner_expr).as_ref() { - names.iter().for_each(|name| { - retained_names.insert(name.clone()); - }) - } - excludes.iter().for_each(|excluded| { - if let Excluded::Name(name) = excluded { - retained_names.shift_remove(name); - } - }) - }, - _ => { - let field = expr.to_field(&schema, Context::Default).unwrap(); - retained_names.insert(ColumnName::from(field.name.as_str())); - }, - }); - let retained_columns: Vec<_> = - retained_names.into_iter().map(|name| col(&name)).collect(); - lf = lf.with_columns(projections); - lf = self.process_order_by(lf, &query.order_by)?; - lf.select(&retained_columns) - } else if contains_wildcard_exclude { - let mut dropped_names = Vec::with_capacity(projections.len()); - let exclude_expr = projections.iter().find(|expr| { - if let Expr::Exclude(_, excludes) = expr { - for excluded in excludes.iter() { - if let Excluded::Name(name) = excluded { - dropped_names.push(name.to_string()); - } - } - true - } else { - false - } - }); - if exclude_expr.is_some() { - lf = lf.with_columns(projections); - lf = self.process_order_by(lf, &query.order_by)?; - lf.drop(dropped_names) - } else { - lf = lf.select(projections); - self.process_order_by(lf, &query.order_by)? - } + // Final/selected cols (also ensures accurate ordinal position refs) + let retained_cols = projections + .iter() + .map(|e| { + col(e + .to_field(schema.deref(), Context::Default) + .unwrap() + .name + .as_str()) + }) + .collect::>(); + + lf = self.process_order_by(lf, &query.order_by, Some(&retained_cols))?; + lf.select(retained_cols) + }; + // Discard any excluded cols + if !excluded_cols.is_empty() { + lf.drop(excluded_cols) } else { - lf = lf.select(projections); - self.process_order_by(lf, &query.order_by)? + lf } } else { - lf = self.process_group_by(lf, contains_wildcard, &group_by_keys, &projections)?; - lf = self.process_order_by(lf, &query.order_by)?; + lf = self.process_group_by(lf, &group_by_keys, &projections)?; + lf = self.process_order_by(lf, &query.order_by, None)?; // Apply optional 'having' clause, post-aggregation. let schema = Some(lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?); @@ -745,7 +744,7 @@ impl SQLContext { } }; - // Apply optional 'distinct' clause. + // Apply optional DISTINCT clause. lf = match &select_stmt.distinct { Some(Distinct::Distinct) => lf.unique_stable(None, UniqueKeepStrategy::Any), Some(Distinct::On(exprs)) => { @@ -763,15 +762,22 @@ impl SQLContext { }) .collect::>>()?; - // DISTINCT ON applies the ORDER BY before the operation. + // DISTINCT ON has to apply the ORDER BY before the operation. if !query.order_by.is_empty() { - lf = self.process_order_by(lf, &query.order_by)?; + lf = self.process_order_by(lf, &query.order_by, None)?; } return Ok(lf.unique_stable(Some(cols), UniqueKeepStrategy::First)); }, None => lf, }; + // Apply final 'SELECT *' modifiers + if !replace_exprs.is_empty() { + lf = lf.with_columns(replace_exprs); + } + if !rename_cols.0.is_empty() { + lf = lf.rename(rename_cols.0, rename_cols.1); + } Ok(lf) } @@ -994,13 +1000,14 @@ impl SQLContext { &mut self, mut lf: LazyFrame, order_by: &[OrderByExpr], + selected: Option<&[Expr]>, ) -> PolarsResult { let mut by = Vec::with_capacity(order_by.len()); let mut descending = Vec::with_capacity(order_by.len()); let mut nulls_last = Vec::with_capacity(order_by.len()); let schema = Some(lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?); - let column_names = schema + let columns = schema .clone() .unwrap() .iter_names() @@ -1015,7 +1022,13 @@ impl SQLContext { descending.push(desc_order); // translate order expression, allowing ordinal values - by.push(self.expr_or_ordinal(&ob.expr, schema.as_deref(), &column_names, "ORDER BY")?) + by.push(self.expr_or_ordinal( + &ob.expr, + &columns, + selected, + schema.as_deref(), + "ORDER BY", + )?) } Ok(lf.sort_by_exprs( &by, @@ -1029,20 +1042,16 @@ impl SQLContext { fn process_group_by( &mut self, mut lf: LazyFrame, - contains_wildcard: bool, group_by_keys: &[Expr], projections: &[Expr], ) -> PolarsResult { - polars_ensure!( - !contains_wildcard, - SQLSyntax: "GROUP BY error (cannot process wildcard in group_by)" - ); let schema_before = lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; let group_by_keys_schema = expressions_to_schema(group_by_keys, &schema_before, Context::Default)?; // Remove the group_by keys as polars adds those implicitly. let mut aggregation_projection = Vec::with_capacity(projections.len()); + let mut projection_overrides = PlHashMap::with_capacity(projections.len()); let mut projection_aliases = PlHashSet::new(); let mut group_key_aliases = PlHashSet::new(); @@ -1057,6 +1066,12 @@ impl SQLContext { if e.clone().meta().is_simple_projection() { group_key_aliases.insert(alias.as_ref()); e = expr + } else if let Expr::Function { + function: FunctionExpr::StructExpr(StructFunction::FieldByName(name)), + .. + } = expr.deref() + { + projection_overrides.insert(alias.as_ref(), col(name).alias(alias)); } else if !is_agg_or_window && !group_by_keys_schema.contains(alias) { projection_aliases.insert(alias.as_ref()); } @@ -1072,7 +1087,12 @@ impl SQLContext { } } aggregation_projection.push(e); - } else if let Expr::Column(_) = e { + } else if let Expr::Column(_) + | Expr::Function { + function: FunctionExpr::StructExpr(StructFunction::FieldByName(_)), + .. + } = e + { // Non-aggregated columns must be part of the GROUP BY clause if !group_by_keys_schema.contains(&field.name) { polars_bail!(SQLSyntax: "'{}' should participate in the GROUP BY clause or an aggregate function", &field.name); @@ -1088,7 +1108,9 @@ impl SQLContext { .iter_names() .zip(projections) .map(|(name, projection_expr)| { - if group_by_keys_schema.get(name).is_some() + if let Some(expr) = projection_overrides.get(name.as_str()) { + expr.clone() + } else if group_by_keys_schema.get(name).is_some() || projection_aliases.contains(name.as_str()) || group_key_aliases.contains(name.as_str()) { @@ -1151,48 +1173,73 @@ impl SQLContext { &mut self, ObjectName(idents): &ObjectName, options: &WildcardAdditionalOptions, - contains_wildcard_exclude: &mut bool, - ) -> PolarsResult { - let idents = idents.as_slice(); - let e = match idents { - [tbl_name] => { - let lf = self.table_map.get_mut(&tbl_name.value).ok_or_else(|| { - polars_err!( - SQLInterface: "no table named '{}' found", - tbl_name - ) - })?; - let schema = lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; - cols(schema.iter_names()) - }, - e => polars_bail!( - SQLSyntax: "invalid wildcard expression ({:?})", - e - ), - }; - self.process_wildcard_additional_options(e, options, contains_wildcard_exclude) + excluded_cols: &mut Vec, + rename_cols: &mut (&mut Vec, &mut Vec), + replace_exprs: &mut Vec, + schema: Option<&Schema>, + ) -> PolarsResult> { + let mut new_idents = idents.clone(); + new_idents.push(Ident::new("*")); + + let expr = resolve_compound_identifier(self, new_idents.deref(), schema); + self.process_wildcard_additional_options( + expr?, + options, + excluded_cols, + rename_cols, + replace_exprs, + schema, + ) } fn process_wildcard_additional_options( &mut self, - expr: Expr, + exprs: Vec, options: &WildcardAdditionalOptions, - contains_wildcard_exclude: &mut bool, - ) -> PolarsResult { + excluded_cols: &mut Vec, + rename_cols: &mut (&mut Vec, &mut Vec), + replace_exprs: &mut Vec, + schema: Option<&Schema>, + ) -> PolarsResult> { + // bail on (currently) unsupported wildcard options if options.opt_except.is_some() { - polars_bail!(SQLSyntax: "EXCEPT not supported (use EXCLUDE instead)") + polars_bail!(SQLInterface: "EXCEPT wildcard option is unsupported (use EXCLUDE instead)") + } else if options.opt_ilike.is_some() { + polars_bail!(SQLInterface: "ILIKE wildcard option is currently unsupported") + } else if options.opt_rename.is_some() && options.opt_replace.is_some() { + // pending an upstream fix: https://github.com/sqlparser-rs/sqlparser-rs/pull/1321 + polars_bail!(SQLInterface: "RENAME and REPLACE wildcard options cannot (yet) be used simultaneously") } - Ok(match &options.opt_exclude { - Some(ExcludeSelectItem::Single(ident)) => { - *contains_wildcard_exclude = true; - expr.exclude(vec![&ident.value]) - }, - Some(ExcludeSelectItem::Multiple(idents)) => { - *contains_wildcard_exclude = true; - expr.exclude(idents.iter().map(|i| &i.value)) - }, - _ => expr, - }) + + if let Some(items) = &options.opt_exclude { + *excluded_cols = match items { + ExcludeSelectItem::Single(ident) => vec![ident.value.clone()], + ExcludeSelectItem::Multiple(idents) => { + idents.iter().map(|i| i.value.clone()).collect() + }, + }; + } + if let Some(items) = &options.opt_rename { + match items { + RenameSelectItem::Single(rename) => { + rename_cols.0.push(rename.ident.value.clone()); + rename_cols.1.push(rename.alias.value.clone()); + }, + RenameSelectItem::Multiple(renames) => { + for rn in renames { + rename_cols.0.push(rn.ident.value.clone()); + rename_cols.1.push(rn.alias.value.clone()); + } + }, + } + } + if let Some(replacements) = &options.opt_replace { + for rp in &replacements.items { + let replacement_expr = parse_sql_expr(&rp.expr, self, schema); + replace_exprs.push(replacement_expr?.alias(rp.column_name.value.as_str())); + } + } + Ok(exprs) } fn rename_columns_from_table_alias( diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index 447f0486baab..bcf95d75f8d4 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -375,42 +375,9 @@ impl SQLExprVisitor<'_> { /// Visit a compound SQL identifier /// - /// e.g. df.column or "df"."column" + /// e.g. tbl.column, struct.field, tbl.struct.field (inc. nested struct fields) fn visit_compound_identifier(&mut self, idents: &[Ident]) -> PolarsResult { - match idents { - [tbl_name, column_name] => { - let mut lf = self - .ctx - .get_table_from_current_scope(&tbl_name.value) - .ok_or_else(|| { - polars_err!( - SQLInterface: "no table or alias named '{}' found", - tbl_name - ) - })?; - - let schema = - lf.schema_with_arenas(&mut self.ctx.lp_arena, &mut self.ctx.expr_arena)?; - if let Some((_, name, _)) = schema.get_full(&column_name.value) { - let resolved = &self.ctx.resolve_name(&tbl_name.value, &column_name.value); - Ok(if name != resolved { - col(resolved).alias(name) - } else { - col(name) - }) - } else { - polars_bail!( - SQLInterface: "no column named '{}' found in table '{}'", - column_name, - tbl_name - ) - } - }, - _ => polars_bail!( - SQLInterface: "invalid identifier {:?}", - idents - ), - } + Ok(resolve_compound_identifier(self.ctx, idents, self.active_schema)?[0].clone()) } fn visit_interval(&self, interval: &Interval) -> PolarsResult { @@ -1240,3 +1207,83 @@ fn bitstring_to_bytes_literal(b: &String) -> PolarsResult { _ => u64::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(), })) } + +pub(crate) fn resolve_compound_identifier( + ctx: &mut SQLContext, + idents: &[Ident], + active_schema: Option<&Schema>, +) -> PolarsResult> { + // inference priority: table > struct > column + let ident_root = &idents[0]; + let mut remaining_idents = idents.iter().skip(1); + let mut lf = ctx.get_table_from_current_scope(&ident_root.value); + + let schema = if let Some(ref mut lf) = lf { + lf.schema_with_arenas(&mut ctx.lp_arena, &mut ctx.expr_arena) + } else { + Ok(Arc::new(if let Some(active_schema) = active_schema { + active_schema.clone() + } else { + Schema::new() + })) + }?; + + let col_dtype: PolarsResult<(Expr, Option<&DataType>)> = if lf.is_none() && schema.is_empty() { + Ok((col(&ident_root.value), None)) + } else { + let name = &remaining_idents.next().unwrap().value; + if lf.is_some() && name == "*" { + return Ok(schema + .iter_names() + .map(|name| col(name)) + .collect::>()); + } else if let Some((_, name, dtype)) = schema.get_full(name) { + let resolved = &ctx.resolve_name(&ident_root.value, name); + Ok(( + if name != resolved { + col(resolved).alias(name) + } else { + col(name) + }, + Some(dtype), + )) + } else if lf.is_none() { + remaining_idents = idents.iter().skip(1); + Ok((col(&ident_root.value), schema.get(&ident_root.value))) + } else { + polars_bail!( + SQLInterface: "no column named '{}' found in table '{}'", + name, + ident_root + ) + } + }; + + // additional ident levels index into struct fields + let (mut column, mut dtype) = col_dtype?; + for ident in remaining_idents { + let name = ident.value.as_str(); + match dtype { + Some(DataType::Struct(fields)) if name == "*" => { + return Ok(fields + .iter() + .map(|fld| column.clone().struct_().field_by_name(&fld.name)) + .collect()) + }, + Some(DataType::Struct(fields)) => { + dtype = fields + .iter() + .find(|fld| fld.name == name) + .map(|fld| &fld.dtype); + }, + Some(dtype) if name == "*" => { + polars_bail!(SQLSyntax: "cannot expand '*' on non-Struct dtype; found {:?}", dtype) + }, + _ => { + dtype = None; + }, + } + column = column.struct_().field_by_name(name); + } + Ok(vec![column]) +} diff --git a/crates/polars-sql/tests/simple_exprs.rs b/crates/polars-sql/tests/simple_exprs.rs index 0a60b9dc7aca..274eb1a7af6f 100644 --- a/crates/polars-sql/tests/simple_exprs.rs +++ b/crates/polars-sql/tests/simple_exprs.rs @@ -3,10 +3,29 @@ use polars_lazy::prelude::*; use polars_sql::*; use polars_time::Duration; -fn create_sample_df() -> PolarsResult { +fn create_sample_df() -> DataFrame { let a = Series::new("a", (1..10000i64).map(|i| i / 100).collect::>()); let b = Series::new("b", 1..10000i64); - DataFrame::new(vec![a, b]) + DataFrame::new(vec![a, b]).unwrap() +} + +fn create_struct_df() -> (DataFrame, DataFrame) { + let struct_cols = vec![col("num"), col("str"), col("val")]; + let df = df! { + "num" => [100, 250, 300, 350], + "str" => ["b", "a", "b", "a"], + "val" => [4.0, 3.5, 2.0, 1.5], + } + .unwrap(); + + ( + df.clone() + .lazy() + .select([as_struct(struct_cols).alias("json_msg")]) + .collect() + .unwrap(), + df, + ) } fn assert_sql_to_polars(df: &DataFrame, sql: &str, f: impl FnOnce(LazyFrame) -> LazyFrame) { @@ -19,7 +38,7 @@ fn assert_sql_to_polars(df: &DataFrame, sql: &str, f: impl FnOnce(LazyFrame) -> #[test] fn test_simple_select() -> PolarsResult<()> { - let df = create_sample_df()?; + let df = create_sample_df(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); let df_sql = context @@ -44,7 +63,7 @@ fn test_simple_select() -> PolarsResult<()> { #[test] fn test_nested_expr() -> PolarsResult<()> { - let df = create_sample_df()?; + let df = create_sample_df(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); let df_sql = context @@ -57,7 +76,7 @@ fn test_nested_expr() -> PolarsResult<()> { #[test] fn test_group_by_simple() -> PolarsResult<()> { - let df = create_sample_df()?; + let df = create_sample_df(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); let df_sql = context @@ -134,7 +153,7 @@ fn test_group_by_expression_key() -> PolarsResult<()> { #[test] fn test_cast_exprs() { - let df = create_sample_df().unwrap(); + let df = create_sample_df(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); let sql = r#" @@ -164,7 +183,7 @@ fn test_cast_exprs() { #[test] fn test_literal_exprs() { - let df = create_sample_df().unwrap(); + let df = create_sample_df(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); let sql = r#" @@ -225,7 +244,7 @@ fn test_implicit_date_string() { #[test] fn test_prefixed_column_names() { - let df = create_sample_df().unwrap(); + let df = create_sample_df(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); let sql = r#" @@ -244,7 +263,7 @@ fn test_prefixed_column_names() { #[test] fn test_prefixed_column_names_2() { - let df = create_sample_df().unwrap(); + let df = create_sample_df(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); let sql = r#" @@ -263,7 +282,7 @@ fn test_prefixed_column_names_2() { #[test] fn test_null_exprs() { - let df = create_sample_df().unwrap(); + let df = create_sample_df(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); let sql = r#" @@ -319,28 +338,28 @@ fn test_null_exprs_in_where() { #[test] fn test_binary_functions() { - let df = create_sample_df().unwrap(); + let df = create_sample_df(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); let sql = r#" SELECT a, b, - a + b as add, - a - b as sub, - a * b as mul, - a / b as div, - a % b as rem, - a <> b as neq, - a = b as eq, - a > b as gt, - a < b as lt, - a >= b as gte, - a <= b as lte, - a and b as and, - a or b as or, - a xor b as xor, - a || b as concat + a + b AS add, + a - b AS sub, + a * b AS mul, + a / b AS div, + a % b AS rem, + a <> b AS neq, + a = b AS eq, + a > b AS gt, + a < b AS lt, + a >= b AS gte, + a <= b AS lte, + a and b AS and, + a or b AS or, + a xor b AS xor, + a || b AS concat FROM df"#; let df_sql = context.execute(sql).unwrap().collect().unwrap(); let df_pl = df.lazy().select(&[ @@ -369,23 +388,23 @@ fn test_binary_functions() { #[test] #[ignore = "TODO: non deterministic"] fn test_agg_functions() { - let df = create_sample_df().unwrap(); + let df = create_sample_df(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); let sql = r#" SELECT - sum(a) as sum_a, - first(a) as first_a, - last(a) as last_a, - avg(a) as avg_a, - max(a) as max_a, - min(a) as min_a, - atan(a) as atan_a, - stddev(a) as stddev_a, - variance(a) as variance_a, - count(a) as count_a, - count(distinct a) as count_distinct_a, - count(*) as count_all + sum(a) AS sum_a, + first(a) AS first_a, + last(a) AS last_a, + avg(a) AS avg_a, + max(a) AS max_a, + min(a) AS min_a, + atan(a) AS atan_a, + stddev(a) AS stddev_a, + variance(a) AS variance_a, + count(a) AS count_a, + count(distinct a) AS count_distinct_a, + count(*) AS count_all FROM df"#; let df_sql = context.execute(sql).unwrap().collect().unwrap(); let df_pl = df @@ -411,9 +430,10 @@ fn test_agg_functions() { #[test] fn test_create_table() { - let df = create_sample_df().unwrap(); + let df = create_sample_df(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); + let sql = r#" CREATE TABLE df2 AS SELECT a @@ -423,14 +443,15 @@ fn test_create_table() { "Response" => ["CREATE TABLE"] } .unwrap(); + assert!(df_sql.equals(&create_tbl_res)); let df_2 = context .execute(r#"SELECT a FROM df2"#) .unwrap() .collect() .unwrap(); - let expected = df.lazy().select(&[col("a")]).collect().unwrap(); + let expected = df.lazy().select(&[col("a")]).collect().unwrap(); assert!(df_2.equals(&expected)); } @@ -450,6 +471,7 @@ fn test_unary_minus_0() { .filter(col("value").lt(lit(-1))) .collect() .unwrap(); + assert!(df_sql.equals(&df_pl)); } @@ -471,14 +493,14 @@ fn test_unary_minus_1() { #[test] fn test_arr_agg() { - let df = create_sample_df().unwrap(); + let df = create_sample_df(); let exprs = vec![ ( "SELECT ARRAY_AGG(a) AS a FROM df", vec![col("a").implode().alias("a")], ), ( - "SELECT ARRAY_AGG(a) AS a, ARRAY_AGG(b) as b FROM df", + "SELECT ARRAY_AGG(a) AS a, ARRAY_AGG(b) AS b FROM df", vec![col("a").implode().alias("a"), col("b").implode().alias("b")], ), ( @@ -512,7 +534,7 @@ fn test_arr_agg() { #[test] fn test_ctes() -> PolarsResult<()> { - let df = create_sample_df()?; + let df = create_sample_df(); let mut context = SQLContext::new(); context.register("df", df.lazy()); @@ -530,6 +552,23 @@ fn test_ctes() -> PolarsResult<()> { Ok(()) } +#[test] +fn test_cte_values() -> PolarsResult<()> { + let sql = r#" + WITH + x AS (SELECT w.* FROM (VALUES(1,2), (3,4)) AS w(a, b)), + y (m, n) AS ( + WITH z(c, d) AS (SELECT a, b FROM x) + SELECT d*2 AS d2, c*3 AS c3 FROM z + ) + SELECT n, m FROM y + "#; + let mut context = SQLContext::new(); + assert!(context.execute(sql).is_ok()); + + Ok(()) +} + #[test] #[cfg(feature = "ipc")] fn test_group_by_2() -> PolarsResult<()> { @@ -543,7 +582,7 @@ fn test_group_by_2() -> PolarsResult<()> { let sql = r#" SELECT category, - count(category) as count, + count(category) AS count, max(calories), min(fats_g) FROM foods @@ -566,6 +605,7 @@ fn test_group_by_2() -> PolarsResult<()> { SortMultipleOptions::default().with_order_descending_multi([false, true]), ) .limit(2); + let expected = expected.collect()?; assert!(df_sql.equals(&expected)); Ok(()) @@ -573,7 +613,7 @@ fn test_group_by_2() -> PolarsResult<()> { #[test] fn test_case_expr() { - let df = create_sample_df().unwrap().head(Some(10)); + let df = create_sample_df().head(Some(10)); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); let sql = r#" @@ -591,15 +631,17 @@ fn test_case_expr() { .then(lit("lteq_5")) .otherwise(lit("no match")) .alias("sign"); + let df_pl = df.lazy().select(&[case_expr]).collect().unwrap(); assert!(df_sql.equals(&df_pl)); } #[test] fn test_case_expr_with_expression() { - let df = create_sample_df().unwrap(); + let df = create_sample_df(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); + let sql = r#" SELECT CASE b%2 @@ -615,13 +657,14 @@ fn test_case_expr_with_expression() { .then(lit("odd")) .otherwise(lit("No?")) .alias("parity"); + let df_pl = df.lazy().select(&[case_expr]).collect().unwrap(); assert!(df_sql.equals(&df_pl)); } #[test] fn test_sql_expr() { - let df = create_sample_df().unwrap(); + let df = create_sample_df(); let expr = sql_expr("MIN(a)").unwrap(); let actual = df.clone().lazy().select(&[expr]).collect().unwrap(); let expected = df.lazy().select(&[col("a").min()]).collect().unwrap(); @@ -630,17 +673,76 @@ fn test_sql_expr() { #[test] fn test_iss_9471() { - let sql = r#" - SELECT - ABS(a,a,a,a,1,2,3,XYZRandomLetters,"XYZRandomLetters") as "abs", - FROM df"#; let df = df! { "a" => [-4, -3, -2, -1, 0, 1, 2, 3, 4], } .unwrap() .lazy(); + let mut context = SQLContext::new(); context.register("df", df); + + let sql = r#" + SELECT + ABS(a,a,a,a,1,2,3,XYZRandomLetters,"XYZRandomLetters") AS "abs", + FROM df"#; let res = context.execute(sql); + assert!(res.is_err()) } + +#[test] +fn test_order_by_excluded_column() { + let df = df! { + "x" => [0, 1, 2, 3], + "y" => [4, 2, 0, 8], + } + .unwrap() + .lazy(); + + let mut context = SQLContext::new(); + context.register("df", df); + + for sql in [ + "SELECT * EXCLUDE y FROM df ORDER BY y", + "SELECT df.* EXCLUDE y FROM df ORDER BY y", + ] { + let df_sorted = context.execute(sql).unwrap().collect().unwrap(); + let expected = df! {"x" => [2, 1, 0, 3],}.unwrap(); + assert!(df_sorted.equals(&expected)); + } +} + +#[test] +fn test_struct_field_selection() { + let (df_struct, df_original) = create_struct_df(); + + let mut context = SQLContext::new(); + context.register("df", df_struct.clone().lazy()); + + for sql in [ + r#"SELECT json_msg.* FROM df ORDER BY 1"#, + r#"SELECT df.json_msg.* FROM df ORDER BY 3 DESC"#, + r#"SELECT json_msg.* FROM df ORDER BY json_msg.num"#, + r#"SELECT df.json_msg.* FROM df ORDER BY json_msg.val DESC"#, + ] { + let df_sql = context.execute(sql).unwrap().collect().unwrap(); + assert!(df_sql.equals(&df_original)); + } + + let sql = r#" + SELECT + json_msg.str AS id, + SUM(json_msg.num) AS sum_n + FROM df + GROUP BY json_msg.str + ORDER BY 1 + "#; + let df_sql = context.execute(sql).unwrap().collect().unwrap(); + let df_expected = df! { + "id" => ["a", "b"], + "sum_n" => [600, 400], + } + .unwrap(); + assert!(df_sql.equals(&df_expected)); +} diff --git a/crates/polars-sql/tests/statements.rs b/crates/polars-sql/tests/statements.rs index 0af4ae64fa86..dd1f89027c46 100644 --- a/crates/polars-sql/tests/statements.rs +++ b/crates/polars-sql/tests/statements.rs @@ -419,7 +419,7 @@ fn test_resolve_join_column_select_13618() { } #[test] -fn test_compound_join_nested_and_with_brackets() { +fn test_compound_join_and_select_exclude_rename_replace() { let df1 = df! { "a" => [1, 2, 3, 4, 5], "b" => [1, 2, 3, 4, 5], @@ -442,10 +442,13 @@ fn test_compound_join_nested_and_with_brackets() { ctx.register("df2", df2.lazy()); let sql = r#" - SELECT df1.* EXCLUDE "e", df2.e - FROM df1 - INNER JOIN df2 ON df1.a = df2.a AND - ((df1.b = df2.b AND df1.c = df2.c) AND df1.d = df2.d) + SELECT * RENAME ("ee" AS "e") + FROM ( + SELECT df1.* EXCLUDE "e", df2.e AS "ee" + FROM df1 + INNER JOIN df2 ON df1.a = df2.a AND + ((df1.b = df2.b AND df1.c = df2.c) AND df1.d = df2.d) + ) tbl "#; let actual = ctx.execute(sql).unwrap().collect().unwrap(); let expected = df! { @@ -465,10 +468,13 @@ fn test_compound_join_nested_and_with_brackets() { ); let sql = r#" - SELECT * EXCLUDE ("e", "e:df2"), df1.e - FROM df1 - INNER JOIN df2 ON df1.a = df2.a AND - ((df1.b = df2.b AND df1.c = df2.c) AND df1.d = df2.d) + SELECT * REPLACE ("ee" || "ee" AS "ee") + FROM ( + SELECT * EXCLUDE ("e", "e:df2"), df1.e AS "ee" + FROM df1 + INNER JOIN df2 ON df1.a = df2.a AND + ((df1.b = df2.b AND df1.c = df2.c) AND df1.d = df2.d) + ) tbl "#; let actual = ctx.execute(sql).unwrap().collect().unwrap(); @@ -481,7 +487,7 @@ fn test_compound_join_nested_and_with_brackets() { "b:df2" => [1, 3], "c:df2" => [0, 4], "d:df2" => [0, 4], - "e" => ["a", "c"], + "ee" => ["aa", "cc"], } .unwrap(); diff --git a/py-polars/tests/unit/sql/test_order_by.py b/py-polars/tests/unit/sql/test_order_by.py index 691d6895be7b..8fb470508f31 100644 --- a/py-polars/tests/unit/sql/test_order_by.py +++ b/py-polars/tests/unit/sql/test_order_by.py @@ -27,17 +27,19 @@ def test_order_by_basic(foods_ipc_path: Path) -> None: "category": ["vegetables", "seafood", "meat", "fruit"] } - order_by_group_by_res = foods.sql( - """ - SELECT category - FROM self - GROUP BY category - ORDER BY category DESC - """ - ).collect() - assert order_by_group_by_res.to_dict(as_series=False) == { - "category": ["vegetables", "seafood", "meat", "fruit"] - } + for category in ("category", "category AS cat"): + category_col = category.split(" ")[-1] + order_by_group_by_res = foods.sql( + f""" + SELECT {category} + FROM self + GROUP BY category + ORDER BY {category_col} DESC + """ + ).collect() + assert order_by_group_by_res.to_dict(as_series=False) == { + category_col: ["vegetables", "seafood", "meat", "fruit"] + } order_by_constructed_group_by_res = foods.sql( """ @@ -108,8 +110,8 @@ def test_order_by_misc_selection() -> None: assert res.to_dict(as_series=False) == {"x": [1, None, 3, 2]} # order by expression - res = df.sql("SELECT (x % y) AS xmy FROM self ORDER BY x % y") - assert res.to_dict(as_series=False) == {"xmy": [1, 3, None, None]} + res = df.sql("SELECT (x % y) AS xmy FROM self ORDER BY -(x % y)") + assert res.to_dict(as_series=False) == {"xmy": [3, 1, None, None]} res = df.sql("SELECT (x % y) AS xmy FROM self ORDER BY x % y NULLS FIRST") assert res.to_dict(as_series=False) == {"xmy": [None, None, 1, 3]} diff --git a/py-polars/tests/unit/sql/test_structs.py b/py-polars/tests/unit/sql/test_structs.py new file mode 100644 index 000000000000..db965efcd86d --- /dev/null +++ b/py-polars/tests/unit/sql/test_structs.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +import pytest + +import polars as pl +from polars.exceptions import SQLSyntaxError, StructFieldNotFoundError +from polars.testing import assert_frame_equal + + +@pytest.fixture() +def df_struct() -> pl.DataFrame: + return pl.DataFrame( + { + "id": [200, 300, 400], + "name": ["Bob", "David", "Zoe"], + "age": [45, 19, 45], + "other": [{"n": 1.5}, {"n": None}, {"n": -0.5}], + } + ).select(pl.struct(pl.all()).alias("json_msg")) + + +def test_struct_field_selection(df_struct: pl.DataFrame) -> None: + res = df_struct.sql( + """ + SELECT + -- validate table alias resolution + frame.json_msg.id AS ID, + self.json_msg.name AS NAME, + json_msg.age AS AGE + FROM + self AS frame + WHERE + json_msg.age > 20 AND + json_msg.other.n IS NOT NULL -- note: nested struct field + ORDER BY + json_msg.name DESC + """ + ) + expected = pl.DataFrame({"ID": [400, 200], "NAME": ["Zoe", "Bob"], "AGE": [45, 45]}) + assert_frame_equal(expected, res) + + +def test_struct_field_group_by(df_struct: pl.DataFrame) -> None: + res = pl.sql( + """ + SELECT + COUNT(json_msg.age) AS n, + ARRAY_AGG(json_msg.name) AS names + FROM df_struct + GROUP BY json_msg.age + ORDER BY 1 DESC + """ + ).collect() + + expected = pl.DataFrame( + data={"n": [2, 1], "names": [["Bob", "Zoe"], ["David"]]}, + schema_overrides={"n": pl.UInt32}, + ) + assert_frame_equal(expected, res) + + +def test_struct_field_group_by_errors(df_struct: pl.DataFrame) -> None: + with pytest.raises( + SQLSyntaxError, + match="'name' should participate in the GROUP BY clause or an aggregate function", + ): + pl.sql( + """ + SELECT + json_msg.name, + SUM(json_msg.age) AS sum_age + FROM df_struct + GROUP BY json_msg.age + """ + ).collect() + + +@pytest.mark.parametrize( + ("fields", "excluding"), + [ + ("json_msg.*", ""), + ("self.json_msg.*", ""), + ("json_msg.other.*", ""), + ("self.json_msg.other.*", ""), + ], +) +def test_struct_field_wildcard_selection( + fields: str, + excluding: str, + df_struct: pl.DataFrame, +) -> None: + query = f"SELECT {fields} {excluding} FROM df_struct ORDER BY json_msg.id" + print(query) + res = pl.sql(query).collect() + + expected = df_struct.unnest("json_msg") + if fields.endswith(".other.*"): + expected = expected["other"].struct.unnest() + if excluding: + expected = expected.drop(excluding.split(",")) + + assert_frame_equal(expected, res) + + +@pytest.mark.parametrize( + "invalid_column", + [ + "json_msg.invalid_column", + "json_msg.other.invalid_column", + "self.json_msg.other.invalid_column", + ], +) +def test_struct_field_selection_errors( + invalid_column: str, df_struct: pl.DataFrame +) -> None: + with pytest.raises(StructFieldNotFoundError, match="invalid_column"): + df_struct.sql(f"SELECT {invalid_column} FROM self") diff --git a/py-polars/tests/unit/sql/test_wildcard_opts.py b/py-polars/tests/unit/sql/test_wildcard_opts.py new file mode 100644 index 000000000000..ad17a215f7da --- /dev/null +++ b/py-polars/tests/unit/sql/test_wildcard_opts.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +import polars as pl +from polars.exceptions import DuplicateError + + +@pytest.fixture() +def df() -> pl.DataFrame: + return pl.DataFrame({"num": [999, 666], "str": ["b", "a"], "val": [2.0, 0.5]}) + + +@pytest.mark.parametrize( + ("excluded", "expected"), + [ + ("num", ["str", "val"]), + ("(val, num)", ["str"]), + ("(str, num)", ["val"]), + ("(str, val, num)", []), + ], +) +def test_select_exclude( + excluded: str, + expected: list[str], + df: pl.DataFrame, +) -> None: + assert df.sql(f"SELECT * EXCLUDE {excluded} FROM self").columns == expected + + +def test_select_exclude_error(df: pl.DataFrame) -> None: + with pytest.raises(DuplicateError, match="the name 'num' is duplicate"): + # note: missing "()" around the exclude option results in dupe col + assert df.sql("SELECT * EXCLUDE val, num FROM self") + + +@pytest.mark.parametrize( + ("renames", "expected"), + [ + ("val AS value", ["num", "str", "value"]), + ("(num AS flt)", ["flt", "str", "val"]), + ("(val AS value, num AS flt)", ["flt", "str", "value"]), + ], +) +def test_select_rename( + renames: str, + expected: list[str], + df: pl.DataFrame, +) -> None: + assert df.sql(f"SELECT * RENAME {renames} FROM self").columns == expected + + +@pytest.mark.parametrize( + ("replacements", "check_cols", "expected"), + [ + ( + "(num // 3 AS num)", + ["num"], + [(333,), (222,)], + ), + ( + "((str || str) AS str, num / 3 AS num)", + ["num", "str"], + [(333, "bb"), (222, "aa")], + ), + ], +) +def test_select_replace( + replacements: str, + check_cols: list[str], + expected: list[tuple[Any]], + df: pl.DataFrame, +) -> None: + res = df.sql(f"SELECT * REPLACE {replacements} FROM self") + + assert res.select(check_cols).rows() == expected + assert res.columns == df.columns diff --git a/py-polars/tests/unit/test_errors.py b/py-polars/tests/unit/test_errors.py index ebd3ec4e73ff..dc89f0e43f2e 100644 --- a/py-polars/tests/unit/test_errors.py +++ b/py-polars/tests/unit/test_errors.py @@ -474,7 +474,7 @@ def test_with_column_duplicates() -> None: df = pl.DataFrame({"a": [0, None, 2, 3, None], "b": [None, 1, 2, 3, None]}) with pytest.raises( ComputeError, - match=r"the name: 'same' passed to `LazyFrame.with_columns` is duplicate.*", + match=r"the name 'same' passed to `LazyFrame.with_columns` is duplicate.*", ): assert df.with_columns([pl.all().alias("same")]).columns == ["a", "b", "same"]