Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade sqlparser-rs to 0.51.0, support new interval logic from sqlparse-rs #12222

Merged
merged 9 commits into from
Sep 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ rand = "0.8"
regex = "1.8"
rstest = "0.22.0"
serde_json = "1"
sqlparser = { version = "0.50.0", features = ["visitor"] }
sqlparser = { version = "0.51.0", features = ["visitor"] }
tempfile = "3"
thiserror = "1.0.44"
tokio = { version = "1.36", features = ["macros", "rt", "sync"] }
Expand Down
4 changes: 2 additions & 2 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

44 changes: 28 additions & 16 deletions datafusion/functions/src/datetime/date_part.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
// under the License.

use std::any::Any;
use std::str::FromStr;
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, Float64Array};
use arrow::compute::kernels::cast_utils::IntervalUnit;
use arrow::compute::{binary, cast, date_part, DatePart};
use arrow::datatypes::DataType::{
Date32, Date64, Float64, Time32, Time64, Timestamp, Utf8, Utf8View,
Expand Down Expand Up @@ -161,22 +163,32 @@ impl ScalarUDFImpl for DatePartFunc {
return exec_err!("Date part '{part}' not supported");
}

let arr = match part_trim.to_lowercase().as_str() {
"year" => date_part_f64(array.as_ref(), DatePart::Year)?,
"quarter" => date_part_f64(array.as_ref(), DatePart::Quarter)?,
"month" => date_part_f64(array.as_ref(), DatePart::Month)?,
"week" => date_part_f64(array.as_ref(), DatePart::Week)?,
"day" => date_part_f64(array.as_ref(), DatePart::Day)?,
"doy" => date_part_f64(array.as_ref(), DatePart::DayOfYear)?,
"dow" => date_part_f64(array.as_ref(), DatePart::DayOfWeekSunday0)?,
"hour" => date_part_f64(array.as_ref(), DatePart::Hour)?,
"minute" => date_part_f64(array.as_ref(), DatePart::Minute)?,
"second" => seconds(array.as_ref(), Second)?,
"millisecond" => seconds(array.as_ref(), Millisecond)?,
"microsecond" => seconds(array.as_ref(), Microsecond)?,
"nanosecond" => seconds(array.as_ref(), Nanosecond)?,
"epoch" => epoch(array.as_ref())?,
_ => return exec_err!("Date part '{part}' not supported"),
// using IntervalUnit here means we hand off all the work of supporting plurals (like "seconds")
// and synonyms ( like "ms,msec,msecond,millisecond") to Arrow
let arr = if let Ok(interval_unit) = IntervalUnit::from_str(part_trim) {
match interval_unit {
IntervalUnit::Year => date_part_f64(array.as_ref(), DatePart::Year)?,
IntervalUnit::Month => date_part_f64(array.as_ref(), DatePart::Month)?,
IntervalUnit::Week => date_part_f64(array.as_ref(), DatePart::Week)?,
IntervalUnit::Day => date_part_f64(array.as_ref(), DatePart::Day)?,
IntervalUnit::Hour => date_part_f64(array.as_ref(), DatePart::Hour)?,
IntervalUnit::Minute => date_part_f64(array.as_ref(), DatePart::Minute)?,
IntervalUnit::Second => seconds(array.as_ref(), Second)?,
IntervalUnit::Millisecond => seconds(array.as_ref(), Millisecond)?,
IntervalUnit::Microsecond => seconds(array.as_ref(), Microsecond)?,
IntervalUnit::Nanosecond => seconds(array.as_ref(), Nanosecond)?,
// century and decade are not supported by `DatePart`, although they are supported in postgres
_ => return exec_err!("Date part '{part}' not supported"),
}
} else {
// special cases that can be extracted (in postgres) but are not interval units
match part_trim.to_lowercase().as_str() {
"qtr" | "quarter" => date_part_f64(array.as_ref(), DatePart::Quarter)?,
"doy" => date_part_f64(array.as_ref(), DatePart::DayOfYear)?,
"dow" => date_part_f64(array.as_ref(), DatePart::DayOfWeekSunday0)?,
"epoch" => epoch(array.as_ref())?,
_ => return exec_err!("Date part '{part}' not supported"),
}
};

Ok(if is_scalar {
Expand Down
4 changes: 1 addition & 3 deletions datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}

SQLExpr::Array(arr) => self.sql_array_literal(arr.elem, schema),
SQLExpr::Interval(interval) => {
self.sql_interval_to_expr(false, interval, schema, planner_context)
}
SQLExpr::Interval(interval) => self.sql_interval_to_expr(false, interval),
SQLExpr::Identifier(id) => {
self.sql_identifier_to_expr(id, schema, planner_context)
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sql/src/expr/unary_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
self.parse_sql_number(&n, true)
}
SQLExpr::Interval(interval) => {
self.sql_interval_to_expr(true, interval, schema, planner_context)
self.sql_interval_to_expr(true, interval)
}
// not a literal, apply negative operator on expression
_ => Ok(Expr::Negative(Box::new(self.sql_expr_to_logical_expr(
Expand Down
195 changes: 72 additions & 123 deletions datafusion/sql/src/expr/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use datafusion_expr::expr::{BinaryExpr, Placeholder};
use datafusion_expr::planner::PlannerResult;
use datafusion_expr::{lit, Expr, Operator};
use log::debug;
use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, Value};
use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, UnaryOperator, Value};
use sqlparser::parser::ParserError::ParserError;
use std::borrow::Cow;

Expand Down Expand Up @@ -168,12 +168,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {

/// Convert a SQL interval expression to a DataFusion logical plan
/// expression
#[allow(clippy::only_used_in_recursion)]
pub(super) fn sql_interval_to_expr(
&self,
negative: bool,
interval: Interval,
schema: &DFSchema,
planner_context: &mut PlannerContext,
) -> Result<Expr> {
if interval.leading_precision.is_some() {
return not_impl_err!(
Expand All @@ -196,127 +195,42 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
);
}

// Only handle string exprs for now
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is great to remove all this stuff from the sql planner (as it is now in the parser)

let value = match *interval.value {
SQLExpr::Value(
Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
) => {
if negative {
format!("-{s}")
} else {
s
}
}
// Support expressions like `interval '1 month' + date/timestamp`.
// Such expressions are parsed like this by sqlparser-rs
//
// Interval
// BinaryOp
// Value(StringLiteral)
// Cast
// Value(StringLiteral)
//
// This code rewrites them to the following:
//
// BinaryOp
// Interval
// Value(StringLiteral)
// Cast
// Value(StringLiteral)
SQLExpr::BinaryOp { left, op, right } => {
let df_op = match op {
BinaryOperator::Plus => Operator::Plus,
BinaryOperator::Minus => Operator::Minus,
BinaryOperator::Eq => Operator::Eq,
BinaryOperator::NotEq => Operator::NotEq,
BinaryOperator::Gt => Operator::Gt,
BinaryOperator::GtEq => Operator::GtEq,
BinaryOperator::Lt => Operator::Lt,
BinaryOperator::LtEq => Operator::LtEq,
_ => {
return not_impl_err!("Unsupported interval operator: {op:?}");
}
};
match (
interval.leading_field.as_ref(),
left.as_ref(),
right.as_ref(),
) {
(_, _, SQLExpr::Value(_)) => {
let left_expr = self.sql_interval_to_expr(
negative,
Interval {
value: left,
leading_field: interval.leading_field.clone(),
leading_precision: None,
last_field: None,
fractional_seconds_precision: None,
},
schema,
planner_context,
)?;
let right_expr = self.sql_interval_to_expr(
false,
Interval {
value: right,
leading_field: interval.leading_field,
leading_precision: None,
last_field: None,
fractional_seconds_precision: None,
},
schema,
planner_context,
)?;
return Ok(Expr::BinaryExpr(BinaryExpr::new(
Box::new(left_expr),
df_op,
Box::new(right_expr),
)));
}
// In this case, the left node is part of the interval
// expr and the right node is an independent expr.
//
// Leading field is not supported when the right operand
// is not a value.
(None, _, _) => {
let left_expr = self.sql_interval_to_expr(
negative,
Interval {
value: left,
leading_field: None,
leading_precision: None,
last_field: None,
fractional_seconds_precision: None,
},
schema,
planner_context,
)?;
let right_expr = self.sql_expr_to_logical_expr(
*right,
schema,
planner_context,
)?;
return Ok(Expr::BinaryExpr(BinaryExpr::new(
Box::new(left_expr),
df_op,
Box::new(right_expr),
)));
}
_ => {
let value = SQLExpr::BinaryOp { left, op, right };
return not_impl_err!(
"Unsupported interval argument. Expected string literal, got: {value:?}"
);
}
if let SQLExpr::BinaryOp { left, op, right } = *interval.value {
let df_op = match op {
BinaryOperator::Plus => Operator::Plus,
BinaryOperator::Minus => Operator::Minus,
_ => {
return not_impl_err!("Unsupported interval operator: {op:?}");
}
}
_ => {
return not_impl_err!(
"Unsupported interval argument. Expected string literal, got: {:?}",
interval.value
);
}
};
};
let left_expr = self.sql_interval_to_expr(
negative,
Interval {
value: left,
leading_field: interval.leading_field.clone(),
leading_precision: None,
last_field: None,
fractional_seconds_precision: None,
},
)?;
let right_expr = self.sql_interval_to_expr(
false,
Interval {
value: right,
leading_field: interval.leading_field,
leading_precision: None,
last_field: None,
fractional_seconds_precision: None,
},
)?;
return Ok(Expr::BinaryExpr(BinaryExpr::new(
Box::new(left_expr),
df_op,
Box::new(right_expr),
)));
}

let value = interval_literal(*interval.value, negative)?;

let value = if has_units(&value) {
// If the interval already contains a unit
Expand All @@ -343,6 +257,41 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
}

fn interval_literal(interval_value: SQLExpr, negative: bool) -> Result<String> {
let s = match interval_value {
SQLExpr::Value(Value::SingleQuotedString(s) | Value::DoubleQuotedString(s)) => s,
SQLExpr::Value(Value::Number(ref v, long)) => {
if long {
return not_impl_err!(
"Unsupported interval argument. Long number not supported: {interval_value:?}"
);
} else {
v.to_string()
}
}
SQLExpr::UnaryOp { op, expr } => {
let negative = match op {
UnaryOperator::Minus => !negative,
UnaryOperator::Plus => negative,
_ => {
return not_impl_err!(
"Unsupported SQL unary operator in interval {op:?}"
);
}
};
interval_literal(*expr, negative)?
}
_ => {
return not_impl_err!("Unsupported interval argument. Expected string literal or number, got: {interval_value:?}");
}
};
if negative {
Ok(format!("-{s}"))
} else {
Ok(s)
}
}

// TODO make interval parsing better in arrow-rs / expose `IntervalType`
fn has_units(val: &str) -> bool {
let val = val.to_lowercase();
Expand Down
18 changes: 12 additions & 6 deletions datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -495,11 +495,17 @@ fn test_table_references_in_plan_to_sql() {
assert_eq!(format!("{}", sql), expected_sql)
}

test("catalog.schema.table", "SELECT catalog.\"schema\".\"table\".id, catalog.\"schema\".\"table\".\"value\" FROM catalog.\"schema\".\"table\"");
test("schema.table", "SELECT \"schema\".\"table\".id, \"schema\".\"table\".\"value\" FROM \"schema\".\"table\"");
test(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a driveby (unrelated) cleanup, right? (this is fine I am just verifying my understanding)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was necessary to get tests to pass, I assume it came from some other change in sqlparser, I didn't look too hard.

But it's not just cosmetic.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this mostly changes catalog to "catalog"
a driveby is usage or raw literals (which is a nice thing on its own)

"catalog.schema.table",
r#"SELECT "catalog"."schema"."table".id, "catalog"."schema"."table"."value" FROM "catalog"."schema"."table""#,
);
test(
"schema.table",
r#"SELECT "schema"."table".id, "schema"."table"."value" FROM "schema"."table""#,
);
test(
"table",
"SELECT \"table\".id, \"table\".\"value\" FROM \"table\"",
r#"SELECT "table".id, "table"."value" FROM "table""#,
);
}

Expand All @@ -521,10 +527,10 @@ fn test_table_scan_with_no_projection_in_plan_to_sql() {

test(
"catalog.schema.table",
"SELECT * FROM catalog.\"schema\".\"table\"",
r#"SELECT * FROM "catalog"."schema"."table""#,
);
test("schema.table", "SELECT * FROM \"schema\".\"table\"");
test("table", "SELECT * FROM \"table\"");
test("schema.table", r#"SELECT * FROM "schema"."table""#);
test("table", r#"SELECT * FROM "table""#);
}

#[test]
Expand Down
15 changes: 15 additions & 0 deletions datafusion/sqllogictest/test_files/expr.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1355,6 +1355,16 @@ SELECT date_part('second', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanose
----
50.123456789

query R
select extract(second from '2024-08-09T12:13:14')
----
14

query R
select extract(seconds from '2024-08-09T12:13:14')
----
14

query R
SELECT extract(second from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)'))
----
Expand All @@ -1381,6 +1391,11 @@ SELECT extract(microsecond from arrow_cast('23:32:50.123456789'::time, 'Time64(N
----
50123456.789000005

query R
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎉

SELECT extract(us from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)'))
----
50123456.789000005

query R
SELECT date_part('nanosecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)'))
----
Expand Down
Loading
Loading