diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index f6ea8bce1216..3b2c3c43c919 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -9,21 +9,28 @@ use polars_ops::frame::JoinCoalesce; use polars_plan::dsl::function_expr::StructFunction; use polars_plan::prelude::*; use sqlparser::ast::{ - Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, GroupByExpr, Ident, JoinConstraint, - JoinOperator, ObjectName, ObjectType, Offset, OrderByExpr, Query, RenameSelectItem, Select, - SelectItem, SetExpr, SetOperator, SetQuantifier, Statement, TableAlias, TableFactor, - TableWithJoins, UnaryOperator, Value as SQLValue, Values, WildcardAdditionalOptions, + BinaryOperator, Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, GroupByExpr, Ident, + JoinConstraint, JoinOperator, ObjectName, ObjectType, Offset, OrderByExpr, Query, + RenameSelectItem, Select, SelectItem, SetExpr, SetOperator, SetQuantifier, Statement, + TableAlias, TableFactor, TableWithJoins, UnaryOperator, Value as SQLValue, Values, + WildcardAdditionalOptions, }; use sqlparser::dialect::GenericDialect; use sqlparser::parser::{Parser, ParserOptions}; use crate::function_registry::{DefaultFunctionRegistry, FunctionRegistry}; use crate::sql_expr::{ - parse_sql_array, parse_sql_expr, process_join_constraint, resolve_compound_identifier, - to_sql_interface_err, + parse_sql_array, parse_sql_expr, resolve_compound_identifier, to_sql_interface_err, }; use crate::table_functions::PolarsTableFunctions; +#[derive(Clone)] +pub struct TableInfo { + pub(crate) frame: LazyFrame, + pub(crate) name: String, + pub(crate) schema: Arc, +} + struct SelectModifiers { exclude: PlHashSet, // SELECT * EXCLUDE ilike: Option, // SELECT * ILIKE @@ -210,9 +217,13 @@ impl SQLContext { self.process_limit_offset(lf, &query.limit, &query.offset) } + pub(crate) fn get_frame_schema(&mut self, frame: &mut LazyFrame) -> PolarsResult { + frame.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena) + } + pub(super) fn get_table_from_current_scope(&self, name: &str) -> Option { - let table_name = self.table_map.get(name).cloned(); - table_name + let table = self.table_map.get(name).cloned(); + table .or_else(|| self.cte_map.borrow().get(name).cloned()) .or_else(|| { self.table_aliases @@ -368,7 +379,7 @@ impl SQLContext { .how(join_type) .join_nulls(true); - let lf_schema = lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; + let lf_schema = self.get_frame_schema(&mut lf)?; let lf_cols: Vec<_> = lf_schema.iter_names().map(|nm| col(nm)).collect(); let joined_tbl = match quantifier { SetQuantifier::ByName | SetQuantifier::AllByName => { @@ -376,7 +387,7 @@ impl SQLContext { join.on(lf_cols).finish() }, SetQuantifier::Distinct | SetQuantifier::None => { - let rf_schema = rf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; + let rf_schema = self.get_frame_schema(&mut rf)?; let rf_cols: Vec<_> = rf_schema.iter_names().map(|nm| col(nm)).collect(); if lf_cols.len() != rf_cols.len() { polars_bail!(SQLInterface: "{} requires equal number of columns in each table (use '{} BY NAME' to combine mismatched tables)", op_name, op_name) @@ -407,8 +418,8 @@ impl SQLContext { match quantifier { // UNION [ALL | DISTINCT] SetQuantifier::All | SetQuantifier::Distinct | SetQuantifier::None => { - let lf_schema = lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; - let rf_schema = rf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; + let lf_schema = self.get_frame_schema(&mut lf)?; + let rf_schema = self.get_frame_schema(&mut rf)?; if lf_schema.len() != rf_schema.len() { polars_bail!(SQLInterface: "UNION requires equal number of columns in each table (use 'UNION BY NAME' to combine mismatched tables)") } @@ -543,51 +554,56 @@ impl SQLContext { fn execute_from_statement(&mut self, tbl_expr: &TableWithJoins) -> PolarsResult { let (l_name, mut lf) = self.get_table(&tbl_expr.relation)?; if !tbl_expr.joins.is_empty() { - for tbl in &tbl_expr.joins { - let (r_name, mut rf) = self.get_table(&tbl.relation)?; - let left_schema = - lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; - let right_schema = - rf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; - - lf = match &tbl.join_operator { - JoinOperator::FullOuter(constraint) => { - self.process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Full)? - }, - JoinOperator::Inner(constraint) => { - self.process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Inner)? - }, - JoinOperator::LeftOuter(constraint) => { - self.process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Left)? - }, - #[cfg(feature = "semi_anti_join")] - JoinOperator::LeftAnti(constraint) => { - self.process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Anti)? - }, - #[cfg(feature = "semi_anti_join")] - JoinOperator::LeftSemi(constraint) => { - self.process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Semi)? - }, - #[cfg(feature = "semi_anti_join")] - JoinOperator::RightAnti(constraint) => { - self.process_join(rf, lf, constraint, &l_name, &r_name, JoinType::Anti)? - }, - #[cfg(feature = "semi_anti_join")] - JoinOperator::RightSemi(constraint) => { - self.process_join(rf, lf, constraint, &l_name, &r_name, JoinType::Semi)? + for join in &tbl_expr.joins { + let (r_name, mut rf) = self.get_table(&join.relation)?; + let left_schema = self.get_frame_schema(&mut lf)?; + let right_schema = self.get_frame_schema(&mut rf)?; + + lf = match &join.join_operator { + op @ (JoinOperator::FullOuter(constraint) + | JoinOperator::LeftOuter(constraint) + | JoinOperator::Inner(constraint) + | JoinOperator::LeftAnti(constraint) + | JoinOperator::LeftSemi(constraint) + | JoinOperator::RightAnti(constraint) + | JoinOperator::RightSemi(constraint)) => { + let (lf, rf) = match op { + JoinOperator::RightAnti(_) | JoinOperator::RightSemi(_) => (rf, lf), + _ => (lf, rf), + }; + self.process_join( + &TableInfo { + frame: lf, + name: l_name.clone(), + schema: left_schema.clone(), + }, + &TableInfo { + frame: rf, + name: r_name.clone(), + schema: right_schema.clone(), + }, + constraint, + match op { + JoinOperator::FullOuter(_) => JoinType::Full, + JoinOperator::LeftOuter(_) => JoinType::Left, + JoinOperator::Inner(_) => JoinType::Inner, + #[cfg(feature = "semi_anti_join")] + JoinOperator::LeftAnti(_) | JoinOperator::RightAnti(_) => JoinType::Anti, + #[cfg(feature = "semi_anti_join")] + JoinOperator::LeftSemi(_) | JoinOperator::RightSemi(_) => JoinType::Semi, + join_type => polars_bail!(SQLInterface: "join type '{:?}' not currently supported", join_type), + }, + )? }, JoinOperator::CrossJoin => lf.cross_join(rf, Some(format!(":{}", r_name))), join_type => { - polars_bail!( - SQLInterface: - "join type '{:?}' not currently supported", join_type - ); + polars_bail!(SQLInterface: "join type '{:?}' not currently supported", join_type) }, }; // track join-aliased columns so we can resolve them later - let joined_schema = - lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; + let joined_schema = self.get_frame_schema(&mut lf)?; + self.joined_aliases.borrow_mut().insert( r_name.to_string(), right_schema @@ -625,7 +641,7 @@ impl SQLContext { }; // Filter expression (WHERE clause) - let schema = lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; + let schema = self.get_frame_schema(&mut lf)?; lf = self.process_where(lf, &select_stmt.selection)?; // 'SELECT *' modifiers @@ -636,45 +652,7 @@ impl SQLContext { replace: vec![], }; - // Column projections (SELECT clause) - let projections: Vec = select_stmt - .projection - .iter() - .map(|select_item| { - Ok(match select_item { - SelectItem::UnnamedExpr(expr) => { - vec![parse_sql_expr(expr, self, Some(schema.deref()))?] - }, - SelectItem::ExprWithAlias { expr, alias } => { - let expr = parse_sql_expr(expr, self, Some(schema.deref()))?; - vec![expr.alias(&alias.value)] - }, - SelectItem::QualifiedWildcard(obj_name, wildcard_options) => self - .process_qualified_wildcard( - obj_name, - wildcard_options, - &mut select_modifiers, - Some(schema.deref()), - )?, - SelectItem::Wildcard(wildcard_options) => { - let cols = schema - .iter_names() - .map(|name| col(name)) - .collect::>(); - - self.process_wildcard_additional_options( - cols, - wildcard_options, - &mut select_modifiers, - Some(schema.deref()), - )? - }, - }) - }) - .collect::>>>()? - .into_iter() - .flatten() - .collect(); + let projections = self.column_projections(select_stmt, &schema, &mut select_modifiers)?; // Check for "GROUP BY ..." (after determining projections) let mut group_by_keys: Vec = Vec::new(); @@ -775,7 +753,7 @@ impl SQLContext { lf = self.process_order_by(lf, &query.order_by, None)?; // Apply optional 'having' clause, post-aggregation. - let schema = Some(lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?); + let schema = Some(self.get_frame_schema(&mut lf)?); match select_stmt.having.as_ref() { Some(expr) => lf.filter(parse_sql_expr(expr, self, schema.as_deref())?), None => lf, @@ -787,7 +765,7 @@ impl SQLContext { Some(Distinct::Distinct) => lf.unique_stable(None, UniqueKeepStrategy::Any), Some(Distinct::On(exprs)) => { // TODO: support exprs in `unique` see https://github.com/pola-rs/polars/issues/5760 - let schema = Some(lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?); + let schema = Some(self.get_frame_schema(&mut lf)?); let cols = exprs .iter() .map(|e| { @@ -811,14 +789,66 @@ impl SQLContext { Ok(lf) } + fn column_projections( + &mut self, + select_stmt: &Select, + schema: &SchemaRef, + select_modifiers: &mut SelectModifiers, + ) -> PolarsResult> { + let parsed_items: PolarsResult>> = select_stmt + .projection + .iter() + .map(|select_item| match select_item { + SelectItem::UnnamedExpr(expr) => { + Ok(vec![parse_sql_expr(expr, self, Some(schema))?]) + }, + SelectItem::ExprWithAlias { expr, alias } => { + let expr = parse_sql_expr(expr, self, Some(schema))?; + Ok(vec![expr.alias(&alias.value)]) + }, + SelectItem::QualifiedWildcard(obj_name, wildcard_options) => self + .process_qualified_wildcard( + obj_name, + wildcard_options, + select_modifiers, + Some(schema), + ), + SelectItem::Wildcard(wildcard_options) => { + let cols = schema + .iter_names() + .map(|name| col(name)) + .collect::>(); + + self.process_wildcard_additional_options( + cols, + wildcard_options, + select_modifiers, + Some(schema), + ) + }, + }) + .collect(); + + let flattened_exprs: Vec = parsed_items? + .into_iter() + .flatten() + .flat_map(|expr| expand_exprs(expr, schema)) + .collect(); + + Ok(flattened_exprs) + } + fn process_where( &mut self, mut lf: LazyFrame, expr: &Option, ) -> PolarsResult { if let Some(expr) = expr { - let schema = Some(lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?); + let schema = Some(self.get_frame_schema(&mut lf)?); let mut filter_expression = parse_sql_expr(expr, self, schema.as_deref())?; + if filter_expression.clone().meta().has_multiple_outputs() { + filter_expression = all_horizontal([filter_expression])?; + } lf = self.process_subqueries(lf, vec![&mut filter_expression]); lf = lf.filter(filter_expression); } @@ -826,28 +856,27 @@ impl SQLContext { } pub(super) fn process_join( - &self, - left_tbl: LazyFrame, - right_tbl: LazyFrame, + &mut self, + tbl_left: &TableInfo, + tbl_right: &TableInfo, constraint: &JoinConstraint, - tbl_name: &str, - join_tbl_name: &str, join_type: JoinType, ) -> PolarsResult { - let (left_on, right_on) = process_join_constraint(constraint, tbl_name, join_tbl_name)?; + let (left_on, right_on) = process_join_constraint(constraint, tbl_left, tbl_right)?; - let joined_tbl = left_tbl + let joined = tbl_left + .frame .clone() .join_builder() - .with(right_tbl.clone()) + .with(tbl_right.frame.clone()) .left_on(left_on) .right_on(right_on) .how(join_type) - .suffix(format!(":{}", join_tbl_name)) + .suffix(format!(":{}", tbl_right.name)) .coalesce(JoinCoalesce::KeepColumns) .finish(); - Ok(joined_tbl) + Ok(joined) } fn process_subqueries(&self, lf: LazyFrame, exprs: Vec<&mut Expr>) -> LazyFrame { @@ -1032,7 +1061,7 @@ impl SQLContext { order_by: &[OrderByExpr], selected: Option<&[Expr]>, ) -> PolarsResult { - let schema = lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; + let schema = self.get_frame_schema(&mut lf)?; let columns_iter = schema.iter_names().map(|e| col(e)); let mut descending = Vec::with_capacity(order_by.len()); @@ -1084,7 +1113,7 @@ impl SQLContext { group_by_keys: &[Expr], projections: &[Expr], ) -> PolarsResult { - let schema_before = lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; + let schema_before = self.get_frame_schema(&mut lf)?; let group_by_keys_schema = expressions_to_schema(group_by_keys, &schema_before, Context::Default)?; @@ -1297,13 +1326,13 @@ impl SQLContext { fn rename_columns_from_table_alias( &mut self, - mut frame: LazyFrame, + mut lf: LazyFrame, alias: &TableAlias, ) -> PolarsResult { if alias.columns.is_empty() { - Ok(frame) + Ok(lf) } else { - let schema = frame.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; + let schema = self.get_frame_schema(&mut lf)?; if alias.columns.len() != schema.len() { polars_bail!( SQLSyntax: "number of columns ({}) in alias '{}' does not match the number of columns in the table/query ({})", @@ -1312,7 +1341,7 @@ impl SQLContext { } else { let existing_columns: Vec<_> = schema.iter_names().collect(); let new_columns: Vec<_> = alias.columns.iter().map(|c| c.value.clone()).collect(); - Ok(frame.rename(existing_columns, new_columns)) + Ok(lf.rename(existing_columns, new_columns)) } } } @@ -1332,3 +1361,131 @@ impl SQLContext { } } } + +fn collect_compound_identifiers( + left: &[Ident], + right: &[Ident], + left_name: &str, + right_name: &str, +) -> PolarsResult<(Vec, Vec)> { + if left.len() == 2 && right.len() == 2 { + let (tbl_a, col_a) = (&left[0].value, &left[1].value); + let (tbl_b, col_b) = (&right[0].value, &right[1].value); + + // switch left/right operands if the caller has them in reverse + if left_name == tbl_b || right_name == tbl_a { + Ok((vec![col(col_b)], vec![col(col_a)])) + } else { + Ok((vec![col(col_a)], vec![col(col_b)])) + } + } else { + polars_bail!(SQLInterface: "collect_compound_identifiers: Expected left.len() == 2 && right.len() == 2, but found left.len() == {:?}, right.len() == {:?}", left.len(), right.len()); + } +} + +fn expand_exprs(expr: Expr, schema: &SchemaRef) -> Vec { + match expr { + Expr::Wildcard => schema + .iter_names() + .map(|name| col(name)) + .collect::>(), + Expr::Column(nm) if is_regex_colname(nm.clone()) => { + let rx = regex::Regex::new(&nm).unwrap(); + schema + .iter_names() + .filter(|name| rx.is_match(name)) + .map(|name| col(name)) + .collect::>() + }, + Expr::Columns(names) => names.iter().map(|name| col(name)).collect::>(), + _ => vec![expr], + } +} + +fn is_regex_colname(nm: ColumnName) -> bool { + nm.starts_with('^') && nm.ends_with('$') +} + +fn process_join_on( + expression: &sqlparser::ast::Expr, + tbl_left: &TableInfo, + tbl_right: &TableInfo, +) -> PolarsResult<(Vec, Vec)> { + if let SQLExpr::BinaryOp { left, op, right } = expression { + match *op { + BinaryOperator::Eq => { + if let (SQLExpr::CompoundIdentifier(left), SQLExpr::CompoundIdentifier(right)) = + (left.as_ref(), right.as_ref()) + { + collect_compound_identifiers(left, right, &tbl_left.name, &tbl_right.name) + } else { + polars_bail!(SQLInterface: "JOIN clauses support '=' constraints on identifiers; found lhs={:?}, rhs={:?}", left, right); + } + }, + BinaryOperator::And => { + let (mut left_i, mut right_i) = process_join_on(left, tbl_left, tbl_right)?; + let (mut left_j, mut right_j) = process_join_on(right, tbl_left, tbl_right)?; + + left_i.append(&mut left_j); + right_i.append(&mut right_j); + Ok((left_i, right_i)) + }, + _ => { + polars_bail!(SQLInterface: "JOIN clauses support '=' constraints combined with 'AND'; found op = '{:?}'", op); + }, + } + } else if let SQLExpr::Nested(expr) = expression { + process_join_on(expr, tbl_left, tbl_right) + } else { + polars_bail!(SQLInterface: "JOIN clauses support '=' constraints combined with 'AND'; found expression = {:?}", expression); + } +} + +fn process_join_constraint( + constraint: &JoinConstraint, + tbl_left: &TableInfo, + tbl_right: &TableInfo, +) -> PolarsResult<(Vec, Vec)> { + if let JoinConstraint::On(SQLExpr::BinaryOp { left, op, right }) = constraint { + if op == &BinaryOperator::And { + let (mut left_on, mut right_on) = process_join_on(left, tbl_left, tbl_right)?; + let (left_on_, right_on_) = process_join_on(right, tbl_left, tbl_right)?; + left_on.extend(left_on_); + right_on.extend(right_on_); + return Ok((left_on, right_on)); + } + if op != &BinaryOperator::Eq { + polars_bail!(SQLInterface: + "only equi-join constraints are supported; found '{:?}' op in\n{:?}", op, constraint) + } + match (left.as_ref(), right.as_ref()) { + (SQLExpr::CompoundIdentifier(left), SQLExpr::CompoundIdentifier(right)) => { + return collect_compound_identifiers(left, right, &tbl_left.name, &tbl_right.name); + }, + (SQLExpr::Identifier(left), SQLExpr::Identifier(right)) => { + return Ok((vec![col(&left.value)], vec![col(&right.value)])) + }, + _ => {}, + } + }; + if let JoinConstraint::Using(idents) = constraint { + if !idents.is_empty() { + let using: Vec = idents.iter().map(|id| col(&id.value)).collect(); + return Ok((using.clone(), using.clone())); + } + }; + if let JoinConstraint::Natural = constraint { + let left_names = tbl_left.schema.iter_names().collect::>(); + let right_names = tbl_right.schema.iter_names().collect::>(); + let on = left_names + .intersection(&right_names) + .map(|name| col(name)) + .collect::>(); + if on.is_empty() { + polars_bail!(SQLInterface: "no common columns found for NATURAL JOIN") + } + Ok((on.clone(), on)) + } else { + polars_bail!(SQLInterface: "unsupported SQL join constraint:\n{:?}", constraint); + } +} diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index c5034dca0966..80af7e15f616 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -1,14 +1,13 @@ use polars_core::chunked_array::ops::{SortMultipleOptions, SortOptions}; -use polars_core::prelude::{polars_bail, polars_err, DataType, PolarsResult, TimeUnit}; +use polars_core::export::regex; +use polars_core::prelude::{polars_bail, polars_err, DataType, PolarsResult, Schema, TimeUnit}; use polars_lazy::dsl::Expr; #[cfg(feature = "list_eval")] use polars_lazy::dsl::ListNameSpaceExtension; use polars_plan::dsl::{coalesce, concat_str, len, max_horizontal, min_horizontal, when}; use polars_plan::plans::{typed_lit, LiteralValue}; -#[cfg(feature = "list_eval")] -use polars_plan::prelude::col; use polars_plan::prelude::LiteralValue::Null; -use polars_plan::prelude::{lit, StrptimeOptions}; +use polars_plan::prelude::{col, cols, lit, StrptimeOptions}; use sqlparser::ast::{ DateTimeField, DuplicateTreatment, Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, FunctionArguments, Ident, @@ -21,6 +20,7 @@ use crate::SQLContext; pub(crate) struct SQLFunctionVisitor<'a> { pub(crate) func: &'a SQLFunction, pub(crate) ctx: &'a mut SQLContext, + pub(crate) active_schema: Option<&'a Schema>, } /// SQL functions that are supported by Polars @@ -602,6 +602,15 @@ pub(crate) enum PolarsSQLFunctions { /// SELECT ARRAY_CONTAINS(column_1, 'foo') FROM df; /// ``` ArrayContains, + + // ---- + // Column selection + // ---- + Columns, + + // ---- + // User-defined + // ---- Udf(String), } @@ -635,6 +644,7 @@ impl PolarsSQLFunctions { "char_length", "character_length", "coalesce", + "columns", "concat", "concat_ws", "cos", @@ -685,11 +695,13 @@ impl PolarsSQLFunctions { "sind", "sqrt", "starts_with", - "stdev", "stddev", - "stdev_samp", "stddev_samp", + "stdev", + "stdev_samp", + "strftime", "strpos", + "strptime", "substr", "sum", "tan", @@ -697,8 +709,8 @@ impl PolarsSQLFunctions { "unnest", "upper", "var", - "variance", "var_samp", + "variance", ] } } @@ -824,6 +836,11 @@ impl PolarsSQLFunctions { "array_upper" => Self::ArrayMax, "unnest" => Self::Explode, + // ---- + // Column selection + // ---- + "columns" => Self::Columns, + other => { if ctx.function_registry.contains(other) { Self::Udf(other.to_string()) @@ -1249,6 +1266,49 @@ impl SQLFunctionVisitor<'_> { ArrayToString => self.visit_arr_to_string(), ArrayUnique => self.visit_unary(|e| e.list().unique()), Explode => self.visit_unary(|e| e.explode()), + + // ---- + // Column selection + // ---- + Columns => { + let active_schema = self.active_schema; + self.try_visit_unary(|e: Expr| { + match e { + Expr::Literal(LiteralValue::String(pat)) => { + if "*" == pat { + polars_bail!(SQLSyntax: "COLUMNS('*') is not a valid regex; did you mean COLUMNS(*)?") + }; + let pat = match pat.as_str() { + _ if pat.starts_with('^') && pat.ends_with('$') => pat.to_string(), + _ if pat.starts_with('^') => format!("{}.*$", pat), + _ if pat.ends_with('$') => format!("^.*{}", pat), + _ => format!("^.*{}.*$", pat), + }; + if let Some(active_schema) = &active_schema { + let rx = regex::Regex::new(&pat).unwrap(); + let col_names = active_schema + .iter_names() + .filter(|name| rx.is_match(name)) + .collect::>(); + + Ok(if col_names.len() == 1 { + col(col_names[0]) + } else { + cols(col_names) + }) + } else { + Ok(col(&pat)) + } + }, + Expr::Wildcard => Ok(col("*")), + _ => polars_bail!(SQLSyntax: "COLUMNS expects a regex; found {:?}", e), + } + }) + }, + + // ---- + // User-defined + // ---- Udf(func_name) => self.visit_udf(&func_name), } } @@ -1258,7 +1318,7 @@ impl SQLFunctionVisitor<'_> { .into_iter() .map(|arg| { if let FunctionArgExpr::Expr(e) = arg { - parse_sql_expr(e, self.ctx, None) + parse_sql_expr(e, self.ctx, self.active_schema) } else { polars_bail!(SQLInterface: "only expressions are supported in UDFs") } @@ -1272,32 +1332,6 @@ impl SQLFunctionVisitor<'_> { .call(args) } - fn visit_unary(&mut self, f: impl Fn(Expr) -> Expr) -> PolarsResult { - self.visit_unary_no_window(f) - .and_then(|e| self.apply_window_spec(e, &self.func.over)) - } - - /// Some functions have cumulative equivalents that can be applied to window specs - /// e.g. SUM(a) OVER (ORDER BY b DESC) -> CUMSUM(a, false) - /// visit_unary_with_cumulative_window will take in a function & a cumulative function - /// if there is a cumulative window spec, it will apply the cumulative function, - /// otherwise it will apply the function - fn visit_unary_with_opt_cumulative( - &mut self, - f: impl Fn(Expr) -> Expr, - cumulative_f: impl Fn(Expr, bool) -> Expr, - ) -> PolarsResult { - match self.func.over.as_ref() { - Some(WindowType::WindowSpec(spec)) => { - self.apply_cumulative_window(f, cumulative_f, spec) - }, - Some(WindowType::NamedWindow(named_window)) => polars_bail!( - SQLInterface: "Named windows are not currently supported; found {:?}", - named_window - ), - _ => self.visit_unary(f), - } - } /// Window specs without partition bys are essentially cumulative functions /// e.g. SUM(a) OVER (ORDER BY b DESC) -> CUMSUM(a, false) fn apply_cumulative_window( @@ -1314,7 +1348,7 @@ impl SQLFunctionVisitor<'_> { let (order_by, desc): (Vec, Vec) = order_by .iter() .map(|o| { - let expr = parse_sql_expr(&o.expr, self.ctx, None)?; + let expr = parse_sql_expr(&o.expr, self.ctx, self.active_schema)?; Ok(match o.asc { Some(b) => (expr, !b), None => (expr, false), @@ -1337,11 +1371,53 @@ impl SQLFunctionVisitor<'_> { } } + fn visit_unary(&mut self, f: impl Fn(Expr) -> Expr) -> PolarsResult { + self.try_visit_unary(|e| Ok(f(e))) + } + + fn try_visit_unary(&mut self, f: impl Fn(Expr) -> PolarsResult) -> PolarsResult { + let args = extract_args(self.func)?; + match args.as_slice() { + [FunctionArgExpr::Expr(sql_expr)] => { + f(parse_sql_expr(sql_expr, self.ctx, self.active_schema)?) + }, + [FunctionArgExpr::Wildcard] => f(parse_sql_expr( + &SQLExpr::Wildcard, + self.ctx, + self.active_schema, + )?), + _ => self.not_supported_error(), + } + .and_then(|e| self.apply_window_spec(e, &self.func.over)) + } + + /// Some functions have cumulative equivalents that can be applied to window specs + /// e.g. SUM(a) OVER (ORDER BY b DESC) -> CUMSUM(a, false) + /// visit_unary_with_cumulative_window will take in a function & a cumulative function + /// if there is a cumulative window spec, it will apply the cumulative function, + /// otherwise it will apply the function + fn visit_unary_with_opt_cumulative( + &mut self, + f: impl Fn(Expr) -> Expr, + cumulative_f: impl Fn(Expr, bool) -> Expr, + ) -> PolarsResult { + match self.func.over.as_ref() { + Some(WindowType::WindowSpec(spec)) => { + self.apply_cumulative_window(f, cumulative_f, spec) + }, + Some(WindowType::NamedWindow(named_window)) => polars_bail!( + SQLInterface: "Named windows are not currently supported; found {:?}", + named_window + ), + _ => self.visit_unary(f), + } + } + fn visit_unary_no_window(&mut self, f: impl Fn(Expr) -> Expr) -> PolarsResult { let args = extract_args(self.func)?; match args.as_slice() { [FunctionArgExpr::Expr(sql_expr)] => { - let expr = parse_sql_expr(sql_expr, self.ctx, None)?; + let expr = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?; // apply the function on the inner expr -- e.g. SUM(a) -> SUM Ok(f(expr)) }, @@ -1363,7 +1439,7 @@ impl SQLFunctionVisitor<'_> { let args = extract_args(self.func)?; match args.as_slice() { [FunctionArgExpr::Expr(sql_expr1), FunctionArgExpr::Expr(sql_expr2)] => { - let expr1 = parse_sql_expr(sql_expr1, self.ctx, None)?; + let expr1 = parse_sql_expr(sql_expr1, self.ctx, self.active_schema)?; let expr2 = Arg::from_sql_expr(sql_expr2, self.ctx)?; f(expr1, expr2) }, @@ -1383,7 +1459,7 @@ impl SQLFunctionVisitor<'_> { let mut expr_args = vec![]; for arg in args { if let FunctionArgExpr::Expr(sql_expr) = arg { - expr_args.push(parse_sql_expr(sql_expr, self.ctx, None)?); + expr_args.push(parse_sql_expr(sql_expr, self.ctx, self.active_schema)?); } else { return self.not_supported_error(); }; @@ -1399,7 +1475,7 @@ impl SQLFunctionVisitor<'_> { match args.as_slice() { [FunctionArgExpr::Expr(sql_expr1), FunctionArgExpr::Expr(sql_expr2), FunctionArgExpr::Expr(sql_expr3)] => { - let expr1 = parse_sql_expr(sql_expr1, self.ctx, None)?; + let expr1 = parse_sql_expr(sql_expr1, self.ctx, self.active_schema)?; let expr2 = Arg::from_sql_expr(sql_expr2, self.ctx)?; let expr3 = Arg::from_sql_expr(sql_expr3, self.ctx)?; f(expr1, expr2, expr3) @@ -1420,7 +1496,7 @@ impl SQLFunctionVisitor<'_> { let (args, is_distinct, clauses) = extract_args_and_clauses(self.func)?; match args.as_slice() { [FunctionArgExpr::Expr(sql_expr)] => { - let mut base = parse_sql_expr(sql_expr, self.ctx, None)?; + let mut base = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?; if is_distinct { base = base.unique_stable(); } @@ -1430,7 +1506,7 @@ impl SQLFunctionVisitor<'_> { base = self.apply_order_by(base, order_exprs.as_slice())?; }, FunctionArgumentClause::Limit(limit_expr) => { - let limit = parse_sql_expr(&limit_expr, self.ctx, None)?; + let limit = parse_sql_expr(&limit_expr, self.ctx, self.active_schema)?; match limit { Expr::Literal(LiteralValue::Int(n)) if n >= 0 => { base = base.head(Some(n as usize)) @@ -1489,13 +1565,13 @@ impl SQLFunctionVisitor<'_> { (false, [FunctionArgExpr::Wildcard] | []) => Ok(len()), // count(column_name) (false, [FunctionArgExpr::Expr(sql_expr)]) => { - let expr = parse_sql_expr(sql_expr, self.ctx, None)?; + let expr = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?; let expr = self.apply_window_spec(expr, &self.func.over)?; Ok(expr.count()) }, // count(distinct column_name) (true, [FunctionArgExpr::Expr(sql_expr)]) => { - let expr = parse_sql_expr(sql_expr, self.ctx, None)?; + let expr = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?; let expr = self.apply_window_spec(expr, &self.func.over)?; Ok(expr.n_unique()) }, @@ -1512,7 +1588,7 @@ impl SQLFunctionVisitor<'_> { // note: if not specified 'NULLS FIRST' is default for DESC, 'NULLS LAST' otherwise // https://www.postgresql.org/docs/current/queries-order.html let desc_order = !ob.asc.unwrap_or(true); - by.push(parse_sql_expr(&ob.expr, self.ctx, None)?); + by.push(parse_sql_expr(&ob.expr, self.ctx, self.active_schema)?); nulls_last.push(!ob.nulls_first.unwrap_or(desc_order)); descending.push(desc_order); } @@ -1537,7 +1613,7 @@ impl SQLFunctionVisitor<'_> { .order_by .iter() .map(|o| { - let e = parse_sql_expr(&o.expr, self.ctx, None)?; + let e = parse_sql_expr(&o.expr, self.ctx, self.active_schema)?; Ok(o.asc.map_or(e.clone(), |b| { e.sort(SortOptions::default().with_order_descending(!b)) })) @@ -1549,7 +1625,7 @@ impl SQLFunctionVisitor<'_> { let partition_by = window_spec .partition_by .iter() - .map(|p| parse_sql_expr(p, self.ctx, None)) + .map(|p| parse_sql_expr(p, self.ctx, self.active_schema)) .collect::>>()?; expr.over(partition_by) } diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index 37711b6d6492..da6f24904d5d 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -17,8 +17,8 @@ use sqlparser::ast::ExactNumberInfo; use sqlparser::ast::{ ArrayElemTypeDef, BinaryOperator as SQLBinaryOperator, BinaryOperator, CastFormat, CastKind, DataType as SQLDataType, DateTimeField, Expr as SQLExpr, Function as SQLFunction, Ident, - Interval, JoinConstraint, ObjectName, Query as Subquery, SelectItem, Subscript, TimezoneInfo, - TrimWhereField, UnaryOperator, Value as SQLValue, + Interval, ObjectName, Query as Subquery, SelectItem, Subscript, TimezoneInfo, TrimWhereField, + UnaryOperator, Value as SQLValue, }; use sqlparser::dialect::GenericDialect; use sqlparser::parser::{Parser, ParserOptions}; @@ -398,6 +398,7 @@ impl SQLExprVisitor<'_> { }, SQLExpr::UnaryOp { op, expr } => self.visit_unary_op(op, expr), SQLExpr::Value(value) => self.visit_literal(value), + SQLExpr::Wildcard => Ok(Expr::Wildcard), e @ SQLExpr::Case { .. } => self.visit_case_when_then(e), other => { polars_bail!(SQLInterface: "expression {:?} is not currently supported", other) @@ -414,7 +415,7 @@ impl SQLExprVisitor<'_> { polars_bail!(SQLSyntax: "SQL subquery cannot be a CTE 'WITH' clause"); } let mut lf = self.ctx.execute_query_no_ctes(subquery)?; - let schema = lf.schema_with_arenas(&mut self.ctx.lp_arena, &mut self.ctx.expr_arena)?; + let schema = self.ctx.get_frame_schema(&mut lf)?; if restriction == SubqueryRestriction::SingleColumn { if schema.len() != 1 { @@ -767,6 +768,7 @@ impl SQLExprVisitor<'_> { let mut visitor = SQLFunctionVisitor { func: function, ctx: self.ctx, + active_schema: self.active_schema, }; visitor.visit_function() } @@ -1118,97 +1120,6 @@ impl SQLExprVisitor<'_> { } } -fn collect_compound_identifiers( - left: &[Ident], - right: &[Ident], - left_name: &str, - right_name: &str, -) -> PolarsResult<(Vec, Vec)> { - if left.len() == 2 && right.len() == 2 { - let (tbl_a, col_a) = (&left[0].value, &left[1].value); - let (tbl_b, col_b) = (&right[0].value, &right[1].value); - - // switch left/right operands if the caller has them in reverse - if left_name == tbl_b || right_name == tbl_a { - Ok((vec![col(col_b)], vec![col(col_a)])) - } else { - Ok((vec![col(col_a)], vec![col(col_b)])) - } - } else { - polars_bail!(SQLInterface: "collect_compound_identifiers: Expected left.len() == 2 && right.len() == 2, but found left.len() == {:?}, right.len() == {:?}", left.len(), right.len()); - } -} - -fn process_join_on( - expression: &sqlparser::ast::Expr, - left_name: &str, - right_name: &str, -) -> PolarsResult<(Vec, Vec)> { - if let SQLExpr::BinaryOp { left, op, right } = expression { - match *op { - BinaryOperator::Eq => { - if let (SQLExpr::CompoundIdentifier(left), SQLExpr::CompoundIdentifier(right)) = - (left.as_ref(), right.as_ref()) - { - collect_compound_identifiers(left, right, left_name, right_name) - } else { - polars_bail!(SQLInterface: "JOIN clauses support '=' constraints on identifiers; found lhs={:?}, rhs={:?}", left, right); - } - }, - BinaryOperator::And => { - let (mut left_i, mut right_i) = process_join_on(left, left_name, right_name)?; - let (mut left_j, mut right_j) = process_join_on(right, left_name, right_name)?; - left_i.append(&mut left_j); - right_i.append(&mut right_j); - Ok((left_i, right_i)) - }, - _ => { - polars_bail!(SQLInterface: "JOIN clauses support '=' constraints combined with 'AND'; found op = '{:?}'", op); - }, - } - } else if let SQLExpr::Nested(expr) = expression { - process_join_on(expr, left_name, right_name) - } else { - polars_bail!(SQLInterface: "JOIN clauses support '=' constraints combined with 'AND'; found expression = {:?}", expression); - } -} - -pub(super) fn process_join_constraint( - constraint: &JoinConstraint, - left_name: &str, - right_name: &str, -) -> PolarsResult<(Vec, Vec)> { - if let JoinConstraint::On(SQLExpr::BinaryOp { left, op, right }) = constraint { - if op == &BinaryOperator::And { - let (mut left_on, mut right_on) = process_join_on(left, left_name, right_name)?; - let (left_on_, right_on_) = process_join_on(right, left_name, right_name)?; - left_on.extend(left_on_); - right_on.extend(right_on_); - return Ok((left_on, right_on)); - } - if op != &BinaryOperator::Eq { - polars_bail!(SQLInterface: - "only equi-join constraints are supported; found '{:?}' op in\n{:?}", op, constraint) - } - match (left.as_ref(), right.as_ref()) { - (SQLExpr::CompoundIdentifier(left), SQLExpr::CompoundIdentifier(right)) => { - return collect_compound_identifiers(left, right, left_name, right_name); - }, - (SQLExpr::Identifier(left), SQLExpr::Identifier(right)) => { - return Ok((vec![col(&left.value)], vec![col(&right.value)])) - }, - _ => {}, - } - } - if let JoinConstraint::Using(idents) = constraint { - if !idents.is_empty() { - let using: Vec = idents.iter().map(|id| col(&id.value)).collect(); - return Ok((using.clone(), using.clone())); - } - } - polars_bail!(SQLInterface: "unsupported SQL join constraint:\n{:?}", constraint); -} - /// parse a SQL expression to a polars expression /// # Example /// ```rust diff --git a/crates/polars-sql/tests/statements.rs b/crates/polars-sql/tests/statements.rs index dd1f89027c46..e5be8e598b60 100644 --- a/crates/polars-sql/tests/statements.rs +++ b/crates/polars-sql/tests/statements.rs @@ -525,6 +525,41 @@ fn test_join_on_different_keys() { ); } +#[test] +fn test_join_multi_consecutive() { + let df1 = df! { "a" => [1, 2, 3], "b" => [4, 8, 6] }.unwrap(); + let df2 = df! { "a" => [3, 2, 1], "b" => [6, 5, 4], "c" => ["x", "y", "z"] }.unwrap(); + let df3 = df! { "c" => ["w", "y", "z"], "d" => [10.5, -50.0, 25.5] }.unwrap(); + + let mut ctx = SQLContext::new(); + ctx.register("tbl_a", df1.lazy()); + ctx.register("tbl_b", df2.lazy()); + ctx.register("tbl_c", df3.lazy()); + + let sql = r#" + SELECT tbl_a.a, tbl_a.b, tbl_b.c, tbl_c.d FROM tbl_a + INNER JOIN tbl_b ON tbl_a.a = tbl_b.a AND tbl_a.b = tbl_b.b + INNER JOIN tbl_c ON tbl_a.c = tbl_c.c + ORDER BY a DESC + "#; + let actual = ctx.execute(sql).unwrap().collect().unwrap(); + + let expected = df! { + "a" => [1], + "b" => [4], + "c" => ["z"], + "d" => [25.5], + } + .unwrap(); + + assert!( + actual.equals(&expected), + "expected = {:?}\nactual={:?}", + expected, + actual + ); +} + #[test] fn test_join_utf8() { // (色) color and (野菜) vegetable diff --git a/py-polars/docs/source/reference/sql/clauses.rst b/py-polars/docs/source/reference/sql/clauses.rst index ab5353dab127..894b28ac3f2d 100644 --- a/py-polars/docs/source/reference/sql/clauses.rst +++ b/py-polars/docs/source/reference/sql/clauses.rst @@ -127,13 +127,11 @@ Combines rows from two or more tables based on a related column. **Join Types** * `CROSS JOIN` -* `FULL JOIN` -* `INNER JOIN` -* `LEFT JOIN` -* `[LEFT] ANTI JOIN` -* `[LEFT] SEMI JOIN` -* `RIGHT ANTI JOIN` -* `RIGHT SEMI JOIN` +* `[NATURAL] FULL JOIN` +* `[NATURAL] INNER JOIN` +* `[NATURAL] LEFT JOIN` +* `[LEFT | RIGHT] ANTI JOIN` +* `[LEFT | RIGHT] SEMI JOIN` **Example:** @@ -156,7 +154,6 @@ Combines rows from two or more tables based on a related column. FROM df1 FULL JOIN df2 USING (ham) """).collect() - # shape: (4, 3) # ┌──────┬───────┬─────┐ # │ foo ┆ apple ┆ ham │ @@ -169,6 +166,20 @@ Combines rows from two or more tables based on a related column. # │ 3 ┆ null ┆ c │ # └──────┴───────┴─────┘ + pl.sql(""" + SELECT COLUMNS('^\w+$') + FROM df1 NATURAL INNER JOIN df2 + """).collect() + # shape: (2, 3) + # ┌─────┬───────┬─────┐ + # │ foo ┆ apple ┆ ham │ + # │ --- ┆ --- ┆ --- │ + # │ i64 ┆ str ┆ str │ + # ╞═════╪═══════╪═════╡ + # │ 1 ┆ x ┆ a │ + # │ 2 ┆ y ┆ b │ + # └─────┴───────┴─────┘ + .. _where: WHERE diff --git a/py-polars/tests/unit/sql/test_joins.py b/py-polars/tests/unit/sql/test_joins.py index 10c7e2f71d5e..b278422973ae 100644 --- a/py-polars/tests/unit/sql/test_joins.py +++ b/py-polars/tests/unit/sql/test_joins.py @@ -6,7 +6,7 @@ import pytest import polars as pl -from polars.exceptions import SQLInterfaceError +from polars.exceptions import SQLInterfaceError, SQLSyntaxError from polars.testing import assert_frame_equal @@ -323,3 +323,145 @@ def test_implicit_joins() -> None: WHERE t1.a = t2.b """ ) + + +def test_natural_joins_01() -> None: + df1 = pl.DataFrame( + { + "CharacterID": [1, 2, 3, 4], + "FirstName": ["Jernau Morat", "Cheradenine", "Byr", "Diziet"], + "LastName": ["Gurgeh", "Zakalwe", "Genar-Hofoen", "Sma"], + } + ) + df2 = pl.DataFrame( + { + "CharacterID": [1, 2, 3, 5], + "Role": ["Protagonist", "Protagonist", "Protagonist", "Antagonist"], + "Book": [ + "Player of Games", + "Use of Weapons", + "Excession", + "Consider Phlebas", + ], + } + ) + df3 = pl.DataFrame( + { + "CharacterID": [1, 2, 3, 4], + "Affiliation": ["Culture", "Culture", "Culture", "Shellworld"], + "Species": ["Pan-human", "Human", "Human", "Oct"], + } + ) + df4 = pl.DataFrame( + { + "CharacterID": [1, 2, 3, 6], + "Ship": [ + "Limiting Factor", + "Xenophobe", + "Grey Area", + "Falling Outside The Normal Moral Constraints", + ], + "Drone": ["Flere-Imsaho", "Skaffen-Amtiskaw", "Eccentric", "Psychopath"], + } + ) + with pl.SQLContext( + {"df1": df1, "df2": df2, "df3": df3, "df4": df4}, eager=True + ) as ctx: + # note: use of 'COLUMNS' is a neat way to drop + # all non-coalesced ":" cols + res = ctx.execute( + """ + SELECT COLUMNS('^[^:]*$') + FROM df1 + NATURAL LEFT JOIN df2 + NATURAL INNER JOIN df3 + NATURAL LEFT JOIN df4 + ORDER BY ALL + """ + ) + assert res.rows(named=True) == [ + { + "CharacterID": 1, + "FirstName": "Jernau Morat", + "LastName": "Gurgeh", + "Role": "Protagonist", + "Book": "Player of Games", + "Affiliation": "Culture", + "Species": "Pan-human", + "Ship": "Limiting Factor", + "Drone": "Flere-Imsaho", + }, + { + "CharacterID": 2, + "FirstName": "Cheradenine", + "LastName": "Zakalwe", + "Role": "Protagonist", + "Book": "Use of Weapons", + "Affiliation": "Culture", + "Species": "Human", + "Ship": "Xenophobe", + "Drone": "Skaffen-Amtiskaw", + }, + { + "CharacterID": 3, + "FirstName": "Byr", + "LastName": "Genar-Hofoen", + "Role": "Protagonist", + "Book": "Excession", + "Affiliation": "Culture", + "Species": "Human", + "Ship": "Grey Area", + "Drone": "Eccentric", + }, + { + "CharacterID": 4, + "FirstName": "Diziet", + "LastName": "Sma", + "Role": None, + "Book": None, + "Affiliation": "Shellworld", + "Species": "Oct", + "Ship": None, + "Drone": None, + }, + ] + + # misc errors + with pytest.raises(SQLSyntaxError, match=r"did you mean COLUMNS\(\*\)\?"): + pl.sql("SELECT * FROM df1 NATURAL JOIN df2 WHERE COLUMNS('*') >= 5") + + with pytest.raises(SQLSyntaxError, match=r"COLUMNS expects a regex"): + pl.sql("SELECT COLUMNS(1234) FROM df1 NATURAL JOIN df2") + + +@pytest.mark.parametrize( + ("cols_constraint", "expected"), + [ + (">= 5", [(8, 8, 6)]), + ("< 7", [(5, 4, 4)]), + ("< 8", [(5, 4, 4), (7, 4, 4), (0, 7, 2)]), + ("!= 4", [(8, 8, 6), (2, 8, 6), (0, 7, 2)]), + ], +) +def test_natural_joins_02(cols_constraint: str, expected: list[tuple[int]]) -> None: + df1 = pl.DataFrame( # noqa: F841 + { + "x": [1, 5, 3, 8, 6, 7, 4, 0, 2], + "y": [3, 4, 6, 8, 3, 4, 1, 7, 8], + } + ) + df2 = pl.DataFrame( # noqa: F841 + { + "y": [0, 4, 0, 8, 0, 4, 0, 7, None], + "z": [9, 8, 7, 6, 5, 4, 3, 2, 1], + }, + ) + actual = pl.sql( + f""" + SELECT * EXCLUDE "y:df2" + FROM df1 NATURAL JOIN df2 + WHERE COLUMNS(*) {cols_constraint} + """ + ).collect() + + assert actual.rows() == expected