diff --git a/crates/polars-plan/src/plans/conversion/expr_expansion.rs b/crates/polars-plan/src/plans/conversion/expr_expansion.rs index 4f168d1080d1..fee7913dede7 100644 --- a/crates/polars-plan/src/plans/conversion/expr_expansion.rs +++ b/crates/polars-plan/src/plans/conversion/expr_expansion.rs @@ -634,7 +634,7 @@ fn find_flags(expr: &Expr) -> PolarsResult { /// In case of single col(*) -> do nothing, no selection is the same as select all /// In other cases replace the wildcard with an expression with all columns -pub(crate) fn rewrite_projections( +pub fn rewrite_projections( exprs: Vec, schema: &Schema, keys: &[Expr], diff --git a/crates/polars-plan/src/plans/conversion/mod.rs b/crates/polars-plan/src/plans/conversion/mod.rs index d0d6a41e9fb9..afdac2d300fc 100644 --- a/crates/polars-plan/src/plans/conversion/mod.rs +++ b/crates/polars-plan/src/plans/conversion/mod.rs @@ -1,6 +1,6 @@ mod convert_utils; mod dsl_to_ir; -mod expr_expansion; +pub(crate) mod expr_expansion; mod expr_to_ir; mod ir_to_dsl; #[cfg(any(feature = "ipc", feature = "parquet", feature = "csv"))] diff --git a/crates/polars-plan/src/plans/mod.rs b/crates/polars-plan/src/plans/mod.rs index ca9acc44cf53..9255c811e489 100644 --- a/crates/polars-plan/src/plans/mod.rs +++ b/crates/polars-plan/src/plans/mod.rs @@ -16,7 +16,7 @@ pub(crate) mod ir; mod apply; mod builder_dsl; mod builder_ir; -pub(crate) mod conversion; +pub mod conversion; #[cfg(feature = "debugging")] pub(crate) mod debug; pub mod expr_ir; diff --git a/crates/polars-plan/src/prelude.rs b/crates/polars-plan/src/prelude.rs index d90e032cc925..34c38cefbdab 100644 --- a/crates/polars-plan/src/prelude.rs +++ b/crates/polars-plan/src/prelude.rs @@ -11,6 +11,7 @@ pub(crate) use polars_time::prelude::*; pub use polars_utils::arena::{Arena, Node}; pub use crate::dsl::*; +pub use crate::plans::conversion::expr_expansion::rewrite_projections; #[cfg(feature = "debugging")] pub use crate::plans::debug::*; pub use crate::plans::options::*; diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 5e44f1956332..cce84fcc9596 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -6,7 +6,7 @@ use polars_lazy::prelude::*; use polars_ops::frame::JoinCoalesce; use polars_plan::prelude::*; use sqlparser::ast::{ - Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, GroupByExpr, JoinConstraint, + Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, GroupByExpr, Ident, JoinConstraint, JoinOperator, ObjectName, ObjectType, Offset, OrderByExpr, Query, Select, SelectItem, SetExpr, SetOperator, SetQuantifier, Statement, TableAlias, TableFactor, TableWithJoins, UnaryOperator, Value as SQLValue, Values, WildcardAdditionalOptions, @@ -600,36 +600,44 @@ impl SQLContext { lf = self.process_where(lf, &select_stmt.selection)?; // Column projections. - let projections: Vec<_> = select_stmt + let projections: Vec = select_stmt .projection .iter() .map(|select_item| { Ok(match select_item { - SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, self, schema.as_deref())?, + SelectItem::UnnamedExpr(expr) => { + vec![parse_sql_expr(expr, self, schema.as_deref())?] + }, SelectItem::ExprWithAlias { expr, alias } => { let expr = parse_sql_expr(expr, self, schema.as_deref())?; - expr.alias(&alias.value) + vec![expr.alias(&alias.value)] }, - SelectItem::QualifiedWildcard(oname, wildcard_options) => self - .process_qualified_wildcard( - oname, + SelectItem::QualifiedWildcard(obj_name, wildcard_options) => { + let expanded = self.process_qualified_wildcard( + obj_name, wildcard_options, &mut contains_wildcard_exclude, - )?, + schema.as_deref(), + )?; + rewrite_projections(vec![expanded], &(schema.clone().unwrap()), &[])? + }, SelectItem::Wildcard(wildcard_options) => { contains_wildcard = true; let e = col("*"); - self.process_wildcard_additional_options( + vec![self.process_wildcard_additional_options( e, wildcard_options, &mut contains_wildcard_exclude, - )? + )?] }, }) }) - .collect::>()?; + .collect::>>>()? + .into_iter() + .flatten() + .collect(); - // Check for "GROUP BY ..." (after projections, as there may be ordinal/position ints). + // Check for "GROUP BY ..." (after determining projections) let mut group_by_keys: Vec = Vec::new(); match &select_stmt.group_by { // Standard "GROUP BY x, y, z" syntax (also recognising ordinal values) @@ -1152,25 +1160,13 @@ impl SQLContext { ObjectName(idents): &ObjectName, options: &WildcardAdditionalOptions, contains_wildcard_exclude: &mut bool, + schema: Option<&Schema>, ) -> PolarsResult { - let idents = idents.as_slice(); - let e = match idents { - [tbl_name] => { - let lf = self.table_map.get_mut(&tbl_name.value).ok_or_else(|| { - polars_err!( - SQLInterface: "no table named '{}' found", - tbl_name - ) - })?; - let schema = lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; - cols(schema.iter_names()) - }, - e => polars_bail!( - SQLSyntax: "invalid wildcard expression ({:?})", - e - ), - }; - self.process_wildcard_additional_options(e, options, contains_wildcard_exclude) + let mut new_idents = idents.clone(); + new_idents.push(Ident::new("*")); + let identifier = SQLExpr::CompoundIdentifier(new_idents); + let expr = parse_sql_expr(&identifier, self, schema)?; + self.process_wildcard_additional_options(expr, options, contains_wildcard_exclude) } fn process_wildcard_additional_options( diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index 39a9eca5c818..a4ae10997715 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -390,23 +390,28 @@ impl SQLExprVisitor<'_> { } else { Schema::new() })) - }; + }?; - let mut column: PolarsResult = if lf.is_none() { + let mut column: PolarsResult = if lf.is_none() && schema.is_empty() { Ok(col(&ident_root.value)) } else { - let col_name = &remaining_idents.next().unwrap().value; - if let Some((_, name, _)) = schema?.get_full(col_name) { - let resolved = &self.ctx.resolve_name(&ident_root.value, col_name); + let name = &remaining_idents.next().unwrap().value; + if lf.is_some() && name == "*" { + Ok(cols(schema.iter_names())) + } else if let Some((_, name, _)) = schema.get_full(name) { + let resolved = &self.ctx.resolve_name(&ident_root.value, name); Ok(if name != resolved { col(resolved).alias(name) } else { col(name) }) + } else if lf.is_none() { + remaining_idents = idents.iter().skip(1); + Ok(col(&ident_root.value)) } else { polars_bail!( SQLInterface: "no column named '{}' found in table '{}'", - col_name, + name, ident_root ) } diff --git a/crates/polars-sql/tests/simple_exprs.rs b/crates/polars-sql/tests/simple_exprs.rs index 0a60b9dc7aca..77980bf54f77 100644 --- a/crates/polars-sql/tests/simple_exprs.rs +++ b/crates/polars-sql/tests/simple_exprs.rs @@ -326,21 +326,21 @@ fn test_binary_functions() { SELECT a, b, - a + b as add, - a - b as sub, - a * b as mul, - a / b as div, - a % b as rem, - a <> b as neq, - a = b as eq, - a > b as gt, - a < b as lt, - a >= b as gte, - a <= b as lte, - a and b as and, - a or b as or, - a xor b as xor, - a || b as concat + a + b AS add, + a - b AS sub, + a * b AS mul, + a / b AS div, + a % b AS rem, + a <> b AS neq, + a = b AS eq, + a > b AS gt, + a < b AS lt, + a >= b AS gte, + a <= b AS lte, + a and b AS and, + a or b AS or, + a xor b AS xor, + a || b AS concat FROM df"#; let df_sql = context.execute(sql).unwrap().collect().unwrap(); let df_pl = df.lazy().select(&[ @@ -374,18 +374,18 @@ fn test_agg_functions() { context.register("df", df.clone().lazy()); let sql = r#" SELECT - sum(a) as sum_a, - first(a) as first_a, - last(a) as last_a, - avg(a) as avg_a, - max(a) as max_a, - min(a) as min_a, - atan(a) as atan_a, - stddev(a) as stddev_a, - variance(a) as variance_a, - count(a) as count_a, - count(distinct a) as count_distinct_a, - count(*) as count_all + sum(a) AS sum_a, + first(a) AS first_a, + last(a) AS last_a, + avg(a) AS avg_a, + max(a) AS max_a, + min(a) AS min_a, + atan(a) AS atan_a, + stddev(a) AS stddev_a, + variance(a) AS variance_a, + count(a) AS count_a, + count(distinct a) AS count_distinct_a, + count(*) AS count_all FROM df"#; let df_sql = context.execute(sql).unwrap().collect().unwrap(); let df_pl = df @@ -414,6 +414,7 @@ fn test_create_table() { let df = create_sample_df().unwrap(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); + let sql = r#" CREATE TABLE df2 AS SELECT a @@ -423,14 +424,15 @@ fn test_create_table() { "Response" => ["CREATE TABLE"] } .unwrap(); + assert!(df_sql.equals(&create_tbl_res)); let df_2 = context .execute(r#"SELECT a FROM df2"#) .unwrap() .collect() .unwrap(); - let expected = df.lazy().select(&[col("a")]).collect().unwrap(); + let expected = df.lazy().select(&[col("a")]).collect().unwrap(); assert!(df_2.equals(&expected)); } @@ -450,6 +452,7 @@ fn test_unary_minus_0() { .filter(col("value").lt(lit(-1))) .collect() .unwrap(); + assert!(df_sql.equals(&df_pl)); } @@ -478,7 +481,7 @@ fn test_arr_agg() { vec![col("a").implode().alias("a")], ), ( - "SELECT ARRAY_AGG(a) AS a, ARRAY_AGG(b) as b FROM df", + "SELECT ARRAY_AGG(a) AS a, ARRAY_AGG(b) AS b FROM df", vec![col("a").implode().alias("a"), col("b").implode().alias("b")], ), ( @@ -530,6 +533,23 @@ fn test_ctes() -> PolarsResult<()> { Ok(()) } +#[test] +fn test_cte_values() -> PolarsResult<()> { + let sql = r#" + WITH + x AS (SELECT w.* FROM (VALUES(1,2), (3,4)) AS w(a, b)), + y (m, n) AS ( + WITH z(c, d) AS (SELECT a, b FROM x) + SELECT d*2 AS d2, c*3 AS c3 FROM z + ) + SELECT n, m FROM y + "#; + let mut context = SQLContext::new(); + assert!(context.execute(sql).is_ok()); + + Ok(()) +} + #[test] #[cfg(feature = "ipc")] fn test_group_by_2() -> PolarsResult<()> { @@ -543,7 +563,7 @@ fn test_group_by_2() -> PolarsResult<()> { let sql = r#" SELECT category, - count(category) as count, + count(category) AS count, max(calories), min(fats_g) FROM foods @@ -566,6 +586,7 @@ fn test_group_by_2() -> PolarsResult<()> { SortMultipleOptions::default().with_order_descending_multi([false, true]), ) .limit(2); + let expected = expected.collect()?; assert!(df_sql.equals(&expected)); Ok(()) @@ -591,6 +612,7 @@ fn test_case_expr() { .then(lit("lteq_5")) .otherwise(lit("no match")) .alias("sign"); + let df_pl = df.lazy().select(&[case_expr]).collect().unwrap(); assert!(df_sql.equals(&df_pl)); } @@ -600,6 +622,7 @@ fn test_case_expr_with_expression() { let df = create_sample_df().unwrap(); let mut context = SQLContext::new(); context.register("df", df.clone().lazy()); + let sql = r#" SELECT CASE b%2 @@ -615,6 +638,7 @@ fn test_case_expr_with_expression() { .then(lit("odd")) .otherwise(lit("No?")) .alias("parity"); + let df_pl = df.lazy().select(&[case_expr]).collect().unwrap(); assert!(df_sql.equals(&df_pl)); } @@ -630,17 +654,72 @@ fn test_sql_expr() { #[test] fn test_iss_9471() { - let sql = r#" - SELECT - ABS(a,a,a,a,1,2,3,XYZRandomLetters,"XYZRandomLetters") as "abs", - FROM df"#; let df = df! { "a" => [-4, -3, -2, -1, 0, 1, 2, 3, 4], } .unwrap() .lazy(); + let mut context = SQLContext::new(); context.register("df", df); + + let sql = r#" + SELECT + ABS(a,a,a,a,1,2,3,XYZRandomLetters,"XYZRandomLetters") AS "abs", + FROM df"#; let res = context.execute(sql); + assert!(res.is_err()) } + +#[test] +fn test_order_by_excluded_column() { + let df = df! { + "x" => [0, 1, 2, 3], + "y" => [4, 2, 0, 8], + } + .unwrap() + .lazy(); + + let mut context = SQLContext::new(); + context.register("df", df); + + for sql in [ + "SELECT * EXCLUDE y FROM df ORDER BY y", + "SELECT df.* EXCLUDE y FROM df ORDER BY y", + ] { + let df_sorted = context.execute(sql).unwrap().collect().unwrap(); + + let expected = df! {"x" => [2, 1, 0, 3],}.unwrap(); + assert!(df_sorted.equals(&expected)); + } +} + +#[test] +fn test_struct_wildcards() { + let struct_cols = vec![col("num"), col("str"), col("val")]; + let df_original = df! { + "num" => [100, 200, 300, 400], + "str" => ["d", "c", "b", "a"], + "val" => [0.0, 5.0, 3.0, 4.0], + } + .unwrap(); + + let df_struct = df_original + .clone() + .lazy() + .select([as_struct(struct_cols).alias("json_msg")]); + + let mut context = SQLContext::new(); + context.register("df", df_struct.clone().lazy()); + + for sql in [ + r#"SELECT json_msg.* FROM df"#, + r#"SELECT df.json_msg.* FROM df"#, + r#"SELECT json_msg.* FROM df ORDER BY json_msg.num"#, + r#"SELECT df.json_msg.* FROM df ORDER BY json_msg.str DESC"#, + ] { + let df_sql = context.execute(sql).unwrap().collect().unwrap(); + assert!(df_sql.equals(&df_original)); + } +} diff --git a/py-polars/tests/unit/sql/test_structs.py b/py-polars/tests/unit/sql/test_structs.py index 6f1cad494ac6..9ed6bd1a2cb0 100644 --- a/py-polars/tests/unit/sql/test_structs.py +++ b/py-polars/tests/unit/sql/test_structs.py @@ -8,7 +8,7 @@ @pytest.fixture() -def struct_df() -> pl.DataFrame: +def df_struct() -> pl.DataFrame: return pl.DataFrame( { "id": [100, 200, 300, 400], @@ -19,8 +19,8 @@ def struct_df() -> pl.DataFrame: ).select(pl.struct(pl.all()).alias("json_msg")) -def test_struct_field_selection(struct_df: pl.DataFrame) -> None: - res = struct_df.sql( +def test_struct_field_selection(df_struct: pl.DataFrame) -> None: + res = df_struct.sql( """ SELECT -- validate table alias resolution @@ -36,17 +36,39 @@ def test_struct_field_selection(struct_df: pl.DataFrame) -> None: json_msg.name DESC """ ) - expected = pl.DataFrame( - { - "ID": [400, 100], - "NAME": ["Zoe", "Alice"], - "AGE": [45, 32], - } + {"ID": [400, 100], "NAME": ["Zoe", "Alice"], "AGE": [45, 32]} ) assert_frame_equal(expected, res) +@pytest.mark.parametrize( + ("fields", "excluding"), + [ + ("json_msg.*", ""), + ("self.json_msg.*", ""), + ("json_msg.other.*", ""), + ("self.json_msg.other.*", ""), + ], +) +def test_struct_field_wildcard_selection( + fields: str, + excluding: str, + df_struct: pl.DataFrame, +) -> None: + query = f"SELECT {fields} {excluding} FROM df_struct ORDER BY json_msg.id" + print(query) + res = pl.sql(query).collect() + + expected = df_struct.unnest("json_msg") + if fields.endswith(".other.*"): + expected = expected["other"].struct.unnest() + if excluding: + expected = expected.drop(excluding.split(",")) + + assert_frame_equal(expected, res) + + @pytest.mark.parametrize( "invalid_column", [ @@ -55,6 +77,6 @@ def test_struct_field_selection(struct_df: pl.DataFrame) -> None: "self.json_msg.other.invalid_column", ], ) -def test_struct_indexing_errors(invalid_column: str, struct_df: pl.DataFrame) -> None: +def test_struct_indexing_errors(invalid_column: str, df_struct: pl.DataFrame) -> None: with pytest.raises(StructFieldNotFoundError, match="invalid_column"): - struct_df.sql(f"SELECT {invalid_column} FROM self") + df_struct.sql(f"SELECT {invalid_column} FROM self")