From 267a42fdd0fb9ac36e7eb416413c12ccb74e22f2 Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Sat, 8 Jun 2024 14:11:10 +0400 Subject: [PATCH 1/2] fix: Ensure that SQL parser errors are consistently mapped to `SQLInterfaceError` --- crates/polars-sql/src/context.rs | 11 ++++++----- crates/polars-sql/src/functions.rs | 18 +++++++++--------- crates/polars-sql/src/sql_expr.rs | 19 ++++++++++++++----- py-polars/tests/unit/sql/test_temporal.py | 8 +++++++- 4 files changed, 36 insertions(+), 20 deletions(-) diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 7bb9c921ef8d..70e23e75ad09 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -1,7 +1,6 @@ use std::cell::RefCell; use polars_core::prelude::*; -use polars_error::to_compute_err; use polars_lazy::prelude::*; use polars_ops::frame::JoinCoalesce; use polars_plan::prelude::*; @@ -15,7 +14,9 @@ use sqlparser::dialect::GenericDialect; use sqlparser::parser::{Parser, ParserOptions}; use crate::function_registry::{DefaultFunctionRegistry, FunctionRegistry}; -use crate::sql_expr::{parse_sql_array, parse_sql_expr, process_join_constraint}; +use crate::sql_expr::{ + parse_sql_array, parse_sql_expr, process_join_constraint, to_sql_interface_err, +}; use crate::table_functions::PolarsTableFunctions; /// The SQLContext is the main entry point for executing SQL queries. @@ -115,9 +116,9 @@ impl SQLContext { let ast = parser .try_with_sql(query) - .map_err(to_compute_err)? + .map_err(to_sql_interface_err)? .parse_statements() - .map_err(to_compute_err)?; + .map_err(to_sql_interface_err)?; polars_ensure!(ast.len() == 1, SQLInterface: "one (and only one) statement can be parsed at a time"); let res = self.execute_statement(ast.first().unwrap())?; @@ -913,7 +914,7 @@ impl SQLContext { ) -> PolarsResult { polars_ensure!( !contains_wildcard, - SQLSyntax: "GROUP BY error: can't process wildcard in group_by" + SQLSyntax: "GROUP BY error: cannot process wildcard in group_by" ); let schema_before = lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; let group_by_keys_schema = diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index 37934dee065c..74ba07ac7830 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -1343,10 +1343,10 @@ impl FromSQLExpr for f64 { SQLExpr::Value(v) => match v { SQLValue::Number(s, _) => s .parse() - .map_err(|_| polars_err!(SQLInterface: "can't parse literal {:?}", s)), - _ => polars_bail!(SQLInterface: "can't parse literal {:?}", v), + .map_err(|_| polars_err!(SQLInterface: "cannot parse literal {:?}", s)), + _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", v), }, - _ => polars_bail!(SQLInterface: "can't parse literal {:?}", expr), + _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", expr), } } } @@ -1359,9 +1359,9 @@ impl FromSQLExpr for bool { match expr { SQLExpr::Value(v) => match v { SQLValue::Boolean(v) => Ok(*v), - _ => polars_bail!(SQLInterface: "can't parse boolean {:?}", v), + _ => polars_bail!(SQLInterface: "cannot parse boolean {:?}", v), }, - _ => polars_bail!(SQLInterface: "can't parse boolean {:?}", expr), + _ => polars_bail!(SQLInterface: "cannot parse boolean {:?}", expr), } } } @@ -1374,9 +1374,9 @@ impl FromSQLExpr for String { match expr { SQLExpr::Value(v) => match v { SQLValue::SingleQuotedString(s) => Ok(s.clone()), - _ => polars_bail!(SQLInterface: "can't parse literal {:?}", v), + _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", v), }, - _ => polars_bail!(SQLInterface: "can't parse literal {:?}", expr), + _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", expr), } } } @@ -1392,9 +1392,9 @@ impl FromSQLExpr for StrptimeOptions { format: Some(s.clone()), ..StrptimeOptions::default() }), - _ => polars_bail!(SQLInterface: "can't parse literal {:?}", v), + _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", v), }, - _ => polars_bail!(SQLInterface: "can't parse literal {:?}", expr), + _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", expr), } } } diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index 0f59a87f56d3..227b1d7fa4fd 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -1,8 +1,8 @@ +use std::fmt::Display; use std::ops::Div; use polars_core::export::regex; use polars_core::prelude::*; -use polars_error::to_compute_err; use polars_lazy::prelude::*; use polars_plan::prelude::typed_lit; use polars_plan::prelude::LiteralValue::Null; @@ -29,6 +29,13 @@ use crate::SQLContext; static DATE_LITERAL_RE: std::sync::OnceLock = std::sync::OnceLock::new(); static TIME_LITERAL_RE: std::sync::OnceLock = std::sync::OnceLock::new(); +#[inline] +#[cold] +#[must_use] +pub fn to_sql_interface_err(err: impl Display) -> PolarsError { + PolarsError::SQLInterface(err.to_string().into()) +} + fn timeunit_from_precision(prec: &Option) -> PolarsResult { Ok(match prec { None => TimeUnit::Microseconds, @@ -804,7 +811,7 @@ impl SQLExprVisitor<'_> { .map(|n: i64| AnyValue::Int64(if negate { -n } else { n })) .map_err(|_| ()) } - .map_err(|_| polars_err!(SQLInterface: "cannot parse literal: {s:?}"))? + .map_err(|_| polars_err!(SQLInterface: "cannot parse literal: {:?}", s))? }, #[cfg(feature = "binary_encoding")] SQLValue::HexStringLiteral(x) => { @@ -1138,8 +1145,10 @@ pub fn sql_expr>(s: S) -> PolarsResult { ..Default::default() }); - let mut ast = parser.try_with_sql(s.as_ref()).map_err(to_compute_err)?; - let expr = ast.parse_select_item().map_err(to_compute_err)?; + let mut ast = parser + .try_with_sql(s.as_ref()) + .map_err(to_sql_interface_err)?; + let expr = ast.parse_select_item().map_err(to_sql_interface_err)?; Ok(match &expr { SelectItem::ExprWithAlias { expr, alias } => { @@ -1169,7 +1178,7 @@ pub(crate) fn parse_sql_array(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsRes }; visitor.array_expr_to_series(arr.elem.as_slice()) }, - _ => polars_bail!(ComputeError: "Expected array expression, found {:?}", expr), + _ => polars_bail!(SQLSyntax: "Expected array expression, found {:?}", expr), } } diff --git a/py-polars/tests/unit/sql/test_temporal.py b/py-polars/tests/unit/sql/test_temporal.py index 3233369a5ccd..a2496d7becad 100644 --- a/py-polars/tests/unit/sql/test_temporal.py +++ b/py-polars/tests/unit/sql/test_temporal.py @@ -6,7 +6,7 @@ import pytest import polars as pl -from polars.exceptions import ComputeError, SQLSyntaxError +from polars.exceptions import ComputeError, SQLInterfaceError, SQLSyntaxError from polars.testing import assert_frame_equal @@ -275,3 +275,9 @@ def test_timestamp_time_unit_errors() -> None: match=f"invalid temporal type precision; expected 1-9, found {prec}", ): ctx.execute(f"SELECT ts::timestamp({prec}) FROM frame_data") + + with pytest.raises( + SQLInterfaceError, + match="sql parser error: Expected literal int, found: - ", + ): + ctx.execute("SELECT ts::timestamp(-3) FROM frame_data") From c109a9292eb29f274272fa8c205e7eb4abe5be3a Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Sun, 9 Jun 2024 12:30:35 +0400 Subject: [PATCH 2/2] consolidate some tests --- crates/polars-sql/tests/iss_7436.rs | 40 ------- crates/polars-sql/tests/iss_7437.rs | 38 ------- crates/polars-sql/tests/iss_7440.rs | 27 ----- crates/polars-sql/tests/iss_8395.rs | 26 ----- crates/polars-sql/tests/iss_8419.rs | 45 -------- crates/polars-sql/tests/issues.rs | 161 ++++++++++++++++++++++++++++ 6 files changed, 161 insertions(+), 176 deletions(-) delete mode 100644 crates/polars-sql/tests/iss_7436.rs delete mode 100644 crates/polars-sql/tests/iss_7437.rs delete mode 100644 crates/polars-sql/tests/iss_7440.rs delete mode 100644 crates/polars-sql/tests/iss_8395.rs delete mode 100644 crates/polars-sql/tests/iss_8419.rs create mode 100644 crates/polars-sql/tests/issues.rs diff --git a/crates/polars-sql/tests/iss_7436.rs b/crates/polars-sql/tests/iss_7436.rs deleted file mode 100644 index 6ea1c1c4b650..000000000000 --- a/crates/polars-sql/tests/iss_7436.rs +++ /dev/null @@ -1,40 +0,0 @@ -#[test] -#[cfg(feature = "csv")] -fn iss_7436() { - use polars_lazy::prelude::*; - use polars_sql::*; - - let mut context = SQLContext::new(); - let sql = r#" - CREATE TABLE foods AS - SELECT * - FROM read_csv('../../examples/datasets/foods1.csv')"#; - context.execute(sql).unwrap().collect().unwrap(); - let df_sql = context - .execute( - r#" - SELECT - "fats_g" AS fats, - AVG(calories) OVER (PARTITION BY "category") AS avg_calories_by_category - FROM foods - LIMIT 5 - "#, - ) - .unwrap() - .collect() - .unwrap(); - let expected = LazyCsvReader::new("../../examples/datasets/foods1.csv") - .finish() - .unwrap() - .select(&[ - col("fats_g").alias("fats"), - col("calories") - .mean() - .over(vec![col("category")]) - .alias("avg_calories_by_category"), - ]) - .limit(5) - .collect() - .unwrap(); - assert!(df_sql.equals(&expected)); -} diff --git a/crates/polars-sql/tests/iss_7437.rs b/crates/polars-sql/tests/iss_7437.rs deleted file mode 100644 index 1db92a33d992..000000000000 --- a/crates/polars-sql/tests/iss_7437.rs +++ /dev/null @@ -1,38 +0,0 @@ -#[cfg(feature = "csv")] -use polars_core::prelude::*; -#[cfg(feature = "csv")] -use polars_lazy::prelude::*; -#[cfg(feature = "csv")] -use polars_sql::*; - -#[test] -#[cfg(feature = "csv")] -fn iss_7437() -> PolarsResult<()> { - let mut context = SQLContext::new(); - let sql = r#" - CREATE TABLE foods AS - SELECT * - FROM read_csv('../../examples/datasets/foods1.csv')"#; - context.execute(sql)?.collect()?; - - let df_sql = context - .execute( - r#" - SELECT "category" as category - FROM foods - GROUP BY "category" - "#, - )? - .collect()? - .sort(["category"], SortMultipleOptions::default())?; - - let expected = LazyCsvReader::new("../../examples/datasets/foods1.csv") - .finish()? - .group_by(vec![col("category").alias("category")]) - .agg(vec![]) - .collect()? - .sort(["category"], Default::default())?; - - assert!(df_sql.equals(&expected)); - Ok(()) -} diff --git a/crates/polars-sql/tests/iss_7440.rs b/crates/polars-sql/tests/iss_7440.rs deleted file mode 100644 index b0649a86f3e5..000000000000 --- a/crates/polars-sql/tests/iss_7440.rs +++ /dev/null @@ -1,27 +0,0 @@ -use polars_core::prelude::*; -use polars_lazy::prelude::*; -use polars_sql::*; - -#[test] -fn iss_7440() { - let df = df! { - "a" => [2.0, -2.5] - } - .unwrap() - .lazy(); - let sql = r#"SELECT a, FLOOR(a) AS floor, CEIL(a) AS ceil FROM df"#; - let mut context = SQLContext::new(); - context.register("df", df.clone()); - - let df_sql = context.execute(sql).unwrap().collect().unwrap(); - - let df_pl = df - .select(&[ - col("a"), - col("a").floor().alias("floor"), - col("a").ceil().alias("ceil"), - ]) - .collect() - .unwrap(); - assert!(df_sql.equals_missing(&df_pl)); -} diff --git a/crates/polars-sql/tests/iss_8395.rs b/crates/polars-sql/tests/iss_8395.rs deleted file mode 100644 index 24fcaa7de3b3..000000000000 --- a/crates/polars-sql/tests/iss_8395.rs +++ /dev/null @@ -1,26 +0,0 @@ -#[cfg(feature = "csv")] -use polars_core::prelude::*; -#[cfg(feature = "csv")] -use polars_sql::*; - -#[test] -#[cfg(feature = "csv")] -fn iss_8395() -> PolarsResult<()> { - use polars_core::series::Series; - - let mut context = SQLContext::new(); - let sql = r#" - with foods as ( - SELECT * - FROM read_csv('../../examples/datasets/foods1.csv') - ) - select * from foods where category IN ('vegetables', 'seafood')"#; - let res = context.execute(sql)?; - let df = res.collect()?; - - // assert that the df only contains [vegetables, seafood] - let s = df.column("category")?.unique()?.sort(Default::default())?; - let expected = Series::new("category", &["seafood", "vegetables"]); - assert!(s.equals(&expected)); - Ok(()) -} diff --git a/crates/polars-sql/tests/iss_8419.rs b/crates/polars-sql/tests/iss_8419.rs deleted file mode 100644 index d967eefbe487..000000000000 --- a/crates/polars-sql/tests/iss_8419.rs +++ /dev/null @@ -1,45 +0,0 @@ -use polars_core::prelude::*; -use polars_lazy::prelude::*; -use polars_sql::*; - -#[test] -fn iss_8419() { - let df = df! { - "Year"=> [2018, 2018, 2019, 2019, 2020, 2020], - "Country"=> ["US", "UK", "US", "UK", "US", "UK"], - "Sales"=> [1000, 2000, 3000, 4000, 5000, 6000] - } - .unwrap() - .lazy(); - let expected = df - .clone() - .select(&[ - col("Year"), - col("Country"), - col("Sales"), - col("Sales") - .sort(SortOptions::default().with_order_descending(true)) - .cum_sum(false) - .alias("SalesCumulative"), - ]) - .sort(["SalesCumulative"], Default::default()) - .collect() - .unwrap(); - let mut ctx = SQLContext::new(); - ctx.register("df", df); - - let query = r#" - SELECT - Year, - Country, - Sales, - SUM(Sales) OVER (ORDER BY Sales DESC) as SalesCumulative - FROM - df - ORDER BY - SalesCumulative - "#; - let df = ctx.execute(query).unwrap().collect().unwrap(); - - assert!(df.equals(&expected)) -} diff --git a/crates/polars-sql/tests/issues.rs b/crates/polars-sql/tests/issues.rs new file mode 100644 index 000000000000..31c0a89e84ff --- /dev/null +++ b/crates/polars-sql/tests/issues.rs @@ -0,0 +1,161 @@ +use polars_core::prelude::*; +use polars_lazy::prelude::*; +use polars_sql::*; + +#[test] +#[cfg(feature = "csv")] +fn iss_7437() -> PolarsResult<()> { + let mut context = SQLContext::new(); + let sql = r#" + CREATE TABLE foods AS + SELECT * + FROM read_csv('../../examples/datasets/foods1.csv')"#; + context.execute(sql)?.collect()?; + + let df_sql = context + .execute( + r#" + SELECT "category" as category + FROM foods + GROUP BY "category" + "#, + )? + .collect()? + .sort(["category"], SortMultipleOptions::default())?; + + let expected = LazyCsvReader::new("../../examples/datasets/foods1.csv") + .finish()? + .group_by(vec![col("category").alias("category")]) + .agg(vec![]) + .collect()? + .sort(["category"], Default::default())?; + + assert!(df_sql.equals(&expected)); + Ok(()) +} + +#[test] +#[cfg(feature = "csv")] +fn iss_7436() { + let mut context = SQLContext::new(); + let sql = r#" + CREATE TABLE foods AS + SELECT * + FROM read_csv('../../examples/datasets/foods1.csv')"#; + context.execute(sql).unwrap().collect().unwrap(); + let df_sql = context + .execute( + r#" + SELECT + "fats_g" AS fats, + AVG(calories) OVER (PARTITION BY "category") AS avg_calories_by_category + FROM foods + LIMIT 5 + "#, + ) + .unwrap() + .collect() + .unwrap(); + let expected = LazyCsvReader::new("../../examples/datasets/foods1.csv") + .finish() + .unwrap() + .select(&[ + col("fats_g").alias("fats"), + col("calories") + .mean() + .over(vec![col("category")]) + .alias("avg_calories_by_category"), + ]) + .limit(5) + .collect() + .unwrap(); + assert!(df_sql.equals(&expected)); +} + +#[test] +fn iss_7440() { + let df = df! { + "a" => [2.0, -2.5] + } + .unwrap() + .lazy(); + let sql = r#"SELECT a, FLOOR(a) AS floor, CEIL(a) AS ceil FROM df"#; + let mut context = SQLContext::new(); + context.register("df", df.clone()); + + let df_sql = context.execute(sql).unwrap().collect().unwrap(); + + let df_pl = df + .select(&[ + col("a"), + col("a").floor().alias("floor"), + col("a").ceil().alias("ceil"), + ]) + .collect() + .unwrap(); + assert!(df_sql.equals_missing(&df_pl)); +} + +#[test] +#[cfg(feature = "csv")] +fn iss_8395() -> PolarsResult<()> { + use polars_core::series::Series; + + let mut context = SQLContext::new(); + let sql = r#" + with foods as ( + SELECT * + FROM read_csv('../../examples/datasets/foods1.csv') + ) + select * from foods where category IN ('vegetables', 'seafood')"#; + let res = context.execute(sql)?; + let df = res.collect()?; + + // assert that the df only contains [vegetables, seafood] + let s = df.column("category")?.unique()?.sort(Default::default())?; + let expected = Series::new("category", &["seafood", "vegetables"]); + assert!(s.equals(&expected)); + Ok(()) +} + +#[test] +fn iss_8419() { + let df = df! { + "Year"=> [2018, 2018, 2019, 2019, 2020, 2020], + "Country"=> ["US", "UK", "US", "UK", "US", "UK"], + "Sales"=> [1000, 2000, 3000, 4000, 5000, 6000] + } + .unwrap() + .lazy(); + let expected = df + .clone() + .select(&[ + col("Year"), + col("Country"), + col("Sales"), + col("Sales") + .sort(SortOptions::default().with_order_descending(true)) + .cum_sum(false) + .alias("SalesCumulative"), + ]) + .sort(["SalesCumulative"], Default::default()) + .collect() + .unwrap(); + let mut ctx = SQLContext::new(); + ctx.register("df", df); + + let query = r#" + SELECT + Year, + Country, + Sales, + SUM(Sales) OVER (ORDER BY Sales DESC) as SalesCumulative + FROM + df + ORDER BY + SalesCumulative + "#; + let df = ctx.execute(query).unwrap().collect().unwrap(); + + assert!(df.equals(&expected)) +}