diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 364c412fb02d..49d5012c6fa5 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -166,9 +166,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.57" +version = "0.1.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76464446b8bc32758d7e88ee1a804d9914cd9b1cb264c029899680b0be29826f" +checksum = "1e805d94e6b5001b651426cf4cd446b1ab5f319d27bab5c644f61de0a804360c" dependencies = [ "proc-macro2", "quote", @@ -271,9 +271,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.11.0" +version = "3.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1ad822118d20d2c234f427000d5acc36eabe1e29a348c89b63dd60b13f28e5d" +checksum = "572f695136211188308f16ad2ca5c851a712c464060ae6974944458eb83880ba" [[package]] name = "byteorder" @@ -398,9 +398,9 @@ dependencies = [ [[package]] name = "comfy-table" -version = "6.1.0" +version = "6.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85914173c2f558d61613bfbbf1911f14e630895087a7ed2fafc0f5319e1536e7" +checksum = "7b3d16bb3da60be2f7c7acfc438f2ae6f3496897ce68c291d0509bb67b4e248e" dependencies = [ "strum", "strum_macros", @@ -499,9 +499,9 @@ dependencies = [ [[package]] name = "cxx" -version = "1.0.78" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19f39818dcfc97d45b03953c1292efc4e80954e1583c4aa770bac1383e2310a4" +checksum = "3f83d0ebf42c6eafb8d7c52f7e5f2d3003b89c7aa4fd2b79229209459a849af8" dependencies = [ "cc", "cxxbridge-flags", @@ -511,9 +511,9 @@ dependencies = [ [[package]] name = "cxx-build" -version = "1.0.78" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e580d70777c116df50c390d1211993f62d40302881e54d4b79727acb83d0199" +checksum = "07d050484b55975889284352b0ffc2ecbda25c0c55978017c132b29ba0818a86" dependencies = [ "cc", "codespan-reporting", @@ -526,15 +526,15 @@ dependencies = [ [[package]] name = "cxxbridge-flags" -version = "1.0.78" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56a46460b88d1cec95112c8c363f0e2c39afdb237f60583b0b36343bf627ea9c" +checksum = "99d2199b00553eda8012dfec8d3b1c75fce747cf27c169a270b3b99e3448ab78" [[package]] name = "cxxbridge-macro" -version = "1.0.78" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "747b608fecf06b0d72d440f27acc99288207324b793be2c17991839f3d4995ea" +checksum = "dcb67a6de1f602736dd7eaead0080cf3435df806c61b24b13328db128c58868f" dependencies = [ "proc-macro2", "quote", @@ -567,7 +567,7 @@ dependencies = [ "log", "num_cpus", "object_store", - "ordered-float 3.2.0", + "ordered-float 3.3.0", "parking_lot", "parquet", "paste", @@ -605,8 +605,9 @@ name = "datafusion-common" version = "13.0.0" dependencies = [ "arrow", + "chrono", "object_store", - "ordered-float 3.2.0", + "ordered-float 3.3.0", "parquet", "sqlparser", ] @@ -651,7 +652,7 @@ dependencies = [ "hashbrown", "lazy_static", "md-5", - "ordered-float 3.2.0", + "ordered-float 3.3.0", "paste", "rand", "regex", @@ -819,7 +820,7 @@ checksum = "e11dcc7e4d79a8c89b9ab4c6f5c30b1fc4a83c420792da3542fd31179ed5f517" dependencies = [ "cfg-if", "rustix", - "windows-sys", + "windows-sys 0.36.1", ] [[package]] @@ -1120,9 +1121,9 @@ dependencies = [ [[package]] name = "iana-time-zone-haiku" -version = "0.1.0" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fde6edd6cef363e9359ed3c98ba64590ba9eecba2293eb5a723ab32aee8926aa" +checksum = "0703ae284fc167426161c2e3f1da3ea71d94b21bedbcc9494e92b28e334e3dca" dependencies = [ "cxx", "cxx-build", @@ -1298,9 +1299,9 @@ checksum = "292a948cd991e376cf75541fe5b97a1081d713c618b4f1b9500f8844e49eb565" [[package]] name = "libmimalloc-sys" -version = "0.1.25" +version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11ca136052550448f55df7898c6dbe651c6b574fe38a0d9ea687a9f8088a2e2c" +checksum = "8fc093ab289b0bfda3aa1bdfab9c9542be29c7ef385cfcbe77f8c9813588eb48" dependencies = [ "cc", ] @@ -1376,9 +1377,9 @@ checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" [[package]] name = "mimalloc" -version = "0.1.29" +version = "0.1.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f64ad83c969af2e732e907564deb0d0ed393cec4af80776f77dd77a1a427698" +checksum = "76ce6a4b40d3bff9eb3ce9881ca0737a85072f9f975886082640cd46a75cdb35" dependencies = [ "libmimalloc-sys", ] @@ -1407,7 +1408,7 @@ dependencies = [ "libc", "log", "wasi", - "windows-sys", + "windows-sys 0.36.1", ] [[package]] @@ -1582,9 +1583,9 @@ dependencies = [ [[package]] name = "ordered-float" -version = "3.2.0" +version = "3.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "129d36517b53c461acc6e1580aeb919c8ae6708a4b1eae61c4463a615d4f0411" +checksum = "1f74e330193f90ec45e2b257fa3ef6df087784157ac1ad2c1e71c62837b03aa7" dependencies = [ "num-traits", ] @@ -1607,15 +1608,15 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.3" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09a279cbf25cb0757810394fbc1e359949b59e348145c643a939a525692e6929" +checksum = "4dc9e0dc2adc1c69d09143aff38d3d30c5c3f0df0dad82e6d25547af174ebec0" dependencies = [ "cfg-if", "libc", "redox_syscall", "smallvec", - "windows-sys", + "windows-sys 0.42.0", ] [[package]] @@ -1712,9 +1713,9 @@ checksum = "dbf0c48bc1d91375ae5c3cd81e3722dff1abcf81a30960240640d223f59fe0e5" [[package]] name = "proc-macro2" -version = "1.0.46" +version = "1.0.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94e2ef8dbfc347b10c094890f778ee2e36ca9bb4262e86dc99cd217e35f3470b" +checksum = "5ea3d908b0e36316caf9e9e2c4625cdde190a7e6f440d794667ed17a1855e725" dependencies = [ "unicode-ident", ] @@ -1896,14 +1897,14 @@ dependencies = [ "io-lifetimes", "libc", "linux-raw-sys", - "windows-sys", + "windows-sys 0.36.1", ] [[package]] name = "rustls" -version = "0.20.6" +version = "0.20.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aab8ee6c7097ed6057f43c187a62418d0c05a4bd5f18b3571db50ee0f9ce033" +checksum = "539a2bfe908f471bfa933876bd1eb6a19cf2176d375f82ef7f99530a40e48c2c" dependencies = [ "log", "ring", @@ -2014,9 +2015,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.86" +version = "1.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41feea4228a6f1cd09ec7a3593a682276702cd67b5273544757dae23c096f074" +checksum = "6ce777b7b150d76b9cf60d28b55f5847135a003f7d7350c6be7a773508ce7d45" dependencies = [ "itoa 1.0.4", "ryu", @@ -2107,9 +2108,9 @@ checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" [[package]] name = "sqlparser" -version = "0.25.0" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0781f2b6bd03e5adf065c8e772b49eaea9f640d06a1b9130330fe8bd2563f4fd" +checksum = "86be66ea0b2b22749cfa157d16e2e84bf793e626a3375f4d378dc289fa03affb" dependencies = [ "log", ] @@ -2598,43 +2599,100 @@ version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ea04155a16a59f9eab786fe12a4a450e75cdb175f9e0d80da1e17db09f55b8d2" dependencies = [ - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_msvc", + "windows_aarch64_msvc 0.36.1", + "windows_i686_gnu 0.36.1", + "windows_i686_msvc 0.36.1", + "windows_x86_64_gnu 0.36.1", + "windows_x86_64_msvc 0.36.1", +] + +[[package]] +name = "windows-sys" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc 0.42.0", + "windows_i686_gnu 0.42.0", + "windows_i686_msvc 0.42.0", + "windows_x86_64_gnu 0.42.0", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc 0.42.0", ] +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d2aa71f6f0cbe00ae5167d90ef3cfe66527d6f613ca78ac8024c3ccab9a19e" + [[package]] name = "windows_aarch64_msvc" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9bb8c3fd39ade2d67e9874ac4f3db21f0d710bee00fe7cab16949ec184eeaa47" +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd0f252f5a35cac83d6311b2e795981f5ee6e67eb1f9a7f64eb4500fbc4dcdb4" + [[package]] name = "windows_i686_gnu" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "180e6ccf01daf4c426b846dfc66db1fc518f074baa793aa7d9b9aaeffad6a3b6" +[[package]] +name = "windows_i686_gnu" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbeae19f6716841636c28d695375df17562ca208b2b7d0dc47635a50ae6c5de7" + [[package]] name = "windows_i686_msvc" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2e7917148b2812d1eeafaeb22a97e4813dfa60a3f8f78ebe204bcc88f12f024" +[[package]] +name = "windows_i686_msvc" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84c12f65daa39dd2babe6e442988fc329d6243fdce47d7d2d155b8d874862246" + [[package]] name = "windows_x86_64_gnu" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4dcd171b8776c41b97521e5da127a2d86ad280114807d0b2ab1e462bc764d9e1" +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf7b1b21b5362cbc318f686150e5bcea75ecedc74dd157d874d754a2ca44b0ed" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09d525d2ba30eeb3297665bd434a54297e4170c7f1a44cad4ef58095b4cd2028" + [[package]] name = "windows_x86_64_msvc" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c811ca4a8c853ef420abd8592ba53ddbbac90410fab6903b3e79972a631f7680" +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40009d85759725a34da6d89a94e63d7bdc50a862acf0dbc7c8e488f1edcb6f5" + [[package]] name = "winreg" version = "0.10.1" diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index d69bc197c35b..9aedeac9e116 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -41,9 +41,10 @@ pyarrow = ["pyo3", "arrow/pyarrow"] [dependencies] apache-avro = { version = "0.14", default-features = false, features = ["snappy"], optional = true } arrow = { version = "25.0.0", default-features = false } +chrono = { version = "0.4", default-features = false } cranelift-module = { version = "0.89.0", optional = true } object_store = { version = "0.5.0", default-features = false, optional = true } ordered-float = "3.0" parquet = { version = "25.0.0", default-features = false, optional = true } pyo3 = { version = "0.17.1", optional = true } -sqlparser = "0.25" +sqlparser = "0.26" diff --git a/datafusion/physical-expr/src/expressions/delta.rs b/datafusion/common/src/delta.rs similarity index 98% rename from datafusion/physical-expr/src/expressions/delta.rs rename to datafusion/common/src/delta.rs index b7efdab0a48d..1de0836fc3ec 100644 --- a/datafusion/physical-expr/src/expressions/delta.rs +++ b/datafusion/common/src/delta.rs @@ -27,7 +27,7 @@ use chrono::Datelike; /// Returns true if the year is a leap-year, as naively defined in the Gregorian calendar. #[inline] -pub(crate) fn is_leap_year(year: i32) -> bool { +fn is_leap_year(year: i32) -> bool { year % 4 == 0 && (year % 100 != 0 || year % 400 == 0) } @@ -49,7 +49,7 @@ fn normalise_day(year: i32, month: u32, day: u32) -> u32 { /// Shift a date by the given number of months. /// Ambiguous month-ends are shifted backwards as necessary. -pub(crate) fn shift_months(date: D, months: i32) -> D { +pub fn shift_months(date: D, months: i32) -> D { let mut year = date.year() + (date.month() as i32 + months) / 12; let mut month = (date.month() as i32 + months) % 12; let mut day = date.day(); diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 0b6c67f88430..8330a360004a 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -17,17 +17,21 @@ pub mod bisect; mod column; +pub mod delta; mod dfschema; mod error; pub mod from_slice; +pub mod parsers; #[cfg(feature = "pyarrow")] mod pyarrow; pub mod scalar; pub mod stats; +pub mod test_util; pub use column::Column; pub use dfschema::{DFField, DFSchema, DFSchemaRef, ExprSchema, ToDFSchema}; pub use error::{field_not_found, DataFusionError, Result, SchemaError}; +pub use parsers::parse_interval; pub use scalar::{ScalarType, ScalarValue}; pub use stats::{ColumnStatistics, Statistics}; diff --git a/datafusion/sql/src/interval.rs b/datafusion/common/src/parsers.rs similarity index 95% rename from datafusion/sql/src/interval.rs rename to datafusion/common/src/parsers.rs index dbdd038aec71..7b78c92bcbdf 100644 --- a/datafusion/sql/src/interval.rs +++ b/datafusion/common/src/parsers.rs @@ -16,7 +16,7 @@ // under the License. //! Interval parsing logic -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use crate::{DataFusionError, Result, ScalarValue}; use std::str::FromStr; const SECONDS_PER_HOUR: f32 = 3_600_f32; @@ -24,7 +24,7 @@ const MILLIS_PER_SECOND: f32 = 1_000_f32; /// Parses a string with an interval like `'0.5 MONTH'` to an /// appropriately typed [`ScalarValue`] -pub(crate) fn parse_interval(leading_field: &str, value: &str) -> Result { +pub fn parse_interval(leading_field: &str, value: &str) -> Result { // We are storing parts as integers, it's why we need to align parts fractional // INTERVAL '0.5 MONTH' = 15 days, INTERVAL '1.5 MONTH' = 1 month 15 days // INTERVAL '0.5 DAY' = 12 hours, INTERVAL '1.5 DAY' = 1 day 12 hours @@ -144,9 +144,9 @@ pub(crate) fn parse_interval(leading_field: &str, value: &str) -> Result Result { match ($LHS, $RHS) { + // Binary operations on arguments with the same type: ( ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2), @@ -482,35 +486,144 @@ macro_rules! impl_op { (ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => { primitive_op!(lhs, rhs, Int8, $OPERATION) } - _ => { - impl_distinct_cases_op!($LHS, $RHS, $OPERATION) + // Binary operations on arguments with different types: + (ScalarValue::Date32(Some(days)), _) => { + let value = date32_add(*days, $RHS, get_sign!($OPERATION))?; + Ok(ScalarValue::Date32(Some(value))) + } + (ScalarValue::Date64(Some(ms)), _) => { + let value = date64_add(*ms, $RHS, get_sign!($OPERATION))?; + Ok(ScalarValue::Date64(Some(value))) + } + (ScalarValue::TimestampSecond(Some(ts_s), zone), _) => { + let value = seconds_add(*ts_s, $RHS, get_sign!($OPERATION))?; + Ok(ScalarValue::TimestampSecond(Some(value), zone.clone())) + } + (ScalarValue::TimestampMillisecond(Some(ts_ms), zone), _) => { + let value = milliseconds_add(*ts_ms, $RHS, get_sign!($OPERATION))?; + Ok(ScalarValue::TimestampMillisecond(Some(value), zone.clone())) } + (ScalarValue::TimestampMicrosecond(Some(ts_us), zone), _) => { + let value = microseconds_add(*ts_us, $RHS, get_sign!($OPERATION))?; + Ok(ScalarValue::TimestampMicrosecond(Some(value), zone.clone())) + } + (ScalarValue::TimestampNanosecond(Some(ts_ns), zone), _) => { + let value = nanoseconds_add(*ts_ns, $RHS, get_sign!($OPERATION))?; + Ok(ScalarValue::TimestampNanosecond(Some(value), zone.clone())) + } + _ => Err(DataFusionError::Internal(format!( + "Operator {} is not implemented for types {:?} and {:?}", + stringify!($OPERATION), + $LHS, + $RHS + ))), } }; } -// If we want a special implementation for an operation this is the place to implement it. -// For instance, in the future we may want to implement subtraction for dates but not addition. -// We can implement such special cases here. -macro_rules! impl_distinct_cases_op { - ($LHS:expr, $RHS:expr, +) => { - match ($LHS, $RHS) { - e => Err(DataFusionError::Internal(format!( - "Addition is not implemented for {:?}", - e - ))), - } +macro_rules! get_sign { + (+) => { + 1 }; - ($LHS:expr, $RHS:expr, -) => { - match ($LHS, $RHS) { - e => Err(DataFusionError::Internal(format!( - "Subtraction is not implemented for {:?}", - e - ))), - } + (-) => { + -1 }; } +#[inline] +pub fn date32_add(days: i32, scalar: &ScalarValue, sign: i32) -> Result { + let epoch = NaiveDate::from_ymd(1970, 1, 1); + let prior = epoch.add(Duration::days(days as i64)); + let posterior = do_date_math(prior, scalar, sign)?; + Ok(posterior.sub(epoch).num_days() as i32) +} + +#[inline] +pub fn date64_add(ms: i64, scalar: &ScalarValue, sign: i32) -> Result { + let epoch = NaiveDate::from_ymd(1970, 1, 1); + let prior = epoch.add(Duration::milliseconds(ms)); + let posterior = do_date_math(prior, scalar, sign)?; + Ok(posterior.sub(epoch).num_milliseconds()) +} + +#[inline] +pub fn seconds_add(ts_s: i64, scalar: &ScalarValue, sign: i32) -> Result { + Ok(do_date_time_math(ts_s, 0, scalar, sign)?.timestamp()) +} + +#[inline] +pub fn milliseconds_add(ts_ms: i64, scalar: &ScalarValue, sign: i32) -> Result { + let secs = ts_ms / 1000; + let nsecs = ((ts_ms % 1000) * 1_000_000) as u32; + Ok(do_date_time_math(secs, nsecs, scalar, sign)?.timestamp_millis()) +} + +#[inline] +pub fn microseconds_add(ts_us: i64, scalar: &ScalarValue, sign: i32) -> Result { + let secs = ts_us / 1_000_000; + let nsecs = ((ts_us % 1_000_000) * 1000) as u32; + Ok(do_date_time_math(secs, nsecs, scalar, sign)?.timestamp_nanos() / 1000) +} + +#[inline] +pub fn nanoseconds_add(ts_ns: i64, scalar: &ScalarValue, sign: i32) -> Result { + let secs = ts_ns / 1_000_000_000; + let nsecs = (ts_ns % 1_000_000_000) as u32; + Ok(do_date_time_math(secs, nsecs, scalar, sign)?.timestamp_nanos()) +} + +#[inline] +fn do_date_time_math( + secs: i64, + nsecs: u32, + scalar: &ScalarValue, + sign: i32, +) -> Result { + let prior = NaiveDateTime::from_timestamp(secs, nsecs); + do_date_math(prior, scalar, sign) +} + +fn do_date_math(prior: D, scalar: &ScalarValue, sign: i32) -> Result +where + D: Datelike + Add, +{ + Ok(match scalar { + ScalarValue::IntervalDayTime(Some(i)) => add_day_time(prior, *i, sign), + ScalarValue::IntervalYearMonth(Some(i)) => shift_months(prior, *i * sign), + ScalarValue::IntervalMonthDayNano(Some(i)) => add_m_d_nano(prior, *i, sign), + other => Err(DataFusionError::Execution(format!( + "DateIntervalExpr does not support non-interval type {:?}", + other + )))?, + }) +} + +// Can remove once chrono:0.4.23 is released +fn add_m_d_nano(prior: D, interval: i128, sign: i32) -> D +where + D: Datelike + Add, +{ + let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(interval); + let months = months * sign; + let days = days * sign; + let nanos = nanos * sign as i64; + let a = shift_months(prior, months); + let b = a.add(Duration::days(days as i64)); + b.add(Duration::nanoseconds(nanos)) +} + +// Can remove once chrono:0.4.23 is released +fn add_day_time(prior: D, interval: i64, sign: i32) -> D +where + D: Datelike + Add, +{ + let (days, ms) = IntervalDayTimeType::to_parts(interval); + let days = days * sign; + let ms = ms * sign; + let intermediate = prior.add(Duration::days(days as i64)); + intermediate.add(Duration::milliseconds(ms as i64)) +} + // manual implementation of `Hash` that uses OrderedFloat to // get defined behavior for floating point impl std::hash::Hash for ScalarValue { @@ -2296,44 +2409,6 @@ impl TryFrom<&DataType> for ScalarValue { } } -// TODO: Remove these coercions once the hardcoded "u64" offset is changed to a -// ScalarValue in WindowFrameBound. -pub trait TryFromValue { - fn try_from_value(datatype: &DataType, value: T) -> Result; -} - -macro_rules! impl_try_from_value { - ($NATIVE:ty, [$([$SCALAR:ident, $PRIMITIVE:ty]),+]) => { - impl TryFromValue<$NATIVE> for ScalarValue { - fn try_from_value(datatype: &DataType, value: $NATIVE) -> Result { - match datatype { - $(DataType::$SCALAR => Ok(ScalarValue::$SCALAR(Some(value as $PRIMITIVE))),)+ - _ => { - let msg = format!("Can't create a scalar from data_type \"{:?}\"", datatype); - Err(DataFusionError::NotImplemented(msg)) - } - } - } - } - }; -} - -impl_try_from_value!( - u64, - [ - [Float64, f64], - [Float32, f32], - [UInt64, u64], - [UInt32, u32], - [UInt16, u16], - [UInt8, u8], - [Int64, i64], - [Int32, i32], - [Int16, i16], - [Int8, i8] - ] -); - macro_rules! format_option { ($F:expr, $EXPR:expr) => {{ match $EXPR { @@ -3831,7 +3906,7 @@ mod tests { match lhs.$FUNCTION(&rhs) { Ok(_result) => { panic!( - "Expected summation error between lhs: '{:?}', rhs: {:?}", + "Expected binary operation error between lhs: '{:?}', rhs: {:?}", lhs, rhs ); } @@ -3849,8 +3924,8 @@ mod tests { }; } - expect_operation_error!(expect_add_error, add, "Addition is not implemented"); - expect_operation_error!(expect_sub_error, sub, "Subtraction is not implemented"); + expect_operation_error!(expect_add_error, add, "Operator + is not implemented"); + expect_operation_error!(expect_sub_error, sub, "Operator - is not implemented"); macro_rules! decimal_op_test_cases { ($OPERATION:ident, [$([$L_VALUE:expr, $L_PRECISION:expr, $L_SCALE:expr, $R_VALUE:expr, $R_PRECISION:expr, $R_SCALE:expr, $O_VALUE:expr, $O_PRECISION:expr, $O_SCALE:expr]),+]) => { diff --git a/datafusion/common/src/test_util.rs b/datafusion/common/src/test_util.rs new file mode 100644 index 000000000000..3545fd270a76 --- /dev/null +++ b/datafusion/common/src/test_util.rs @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utility functions to make testing DataFusion based crates easier + +/// A macro to assert that one string is contained within another with +/// a nice error message if they are not. +/// +/// Usage: `assert_contains!(actual, expected)` +/// +/// Is a macro so test error +/// messages are on the same line as the failure; +/// +/// Both arguments must be convertable into Strings (Into) +#[macro_export] +macro_rules! assert_contains { + ($ACTUAL: expr, $EXPECTED: expr) => { + let actual_value: String = $ACTUAL.into(); + let expected_value: String = $EXPECTED.into(); + assert!( + actual_value.contains(&expected_value), + "Can not find expected in actual.\n\nExpected:\n{}\n\nActual:\n{}", + expected_value, + actual_value + ); + }; +} + +/// A macro to assert that one string is NOT contained within another with +/// a nice error message if they are are. +/// +/// Usage: `assert_not_contains!(actual, unexpected)` +/// +/// Is a macro so test error +/// messages are on the same line as the failure; +/// +/// Both arguments must be convertable into Strings (Into) +#[macro_export] +macro_rules! assert_not_contains { + ($ACTUAL: expr, $UNEXPECTED: expr) => { + let actual_value: String = $ACTUAL.into(); + let unexpected_value: String = $UNEXPECTED.into(); + assert!( + !actual_value.contains(&unexpected_value), + "Found unexpected in actual.\n\nUnexpected:\n{}\n\nActual:\n{}", + unexpected_value, + actual_value + ); + }; +} diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 6512db0ae922..616c16b4f02e 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -89,7 +89,7 @@ pyo3 = { version = "0.17.1", optional = true } rand = "0.8" rayon = { version = "1.5", optional = true } smallvec = { version = "1.6", features = ["union"] } -sqlparser = "0.25" +sqlparser = "0.26" tempfile = "3" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } tokio-stream = "0.1" diff --git a/datafusion/core/src/physical_plan/file_format/parquet.rs b/datafusion/core/src/physical_plan/file_format/parquet.rs index f5bd890591fd..0dda94322619 100644 --- a/datafusion/core/src/physical_plan/file_format/parquet.rs +++ b/datafusion/core/src/physical_plan/file_format/parquet.rs @@ -1170,7 +1170,7 @@ mod tests { use crate::prelude::{ParquetReadOptions, SessionConfig, SessionContext}; use crate::test::object_store::local_unpartitioned_file; use crate::{ - assert_batches_sorted_eq, assert_contains, + assert_batches_sorted_eq, datasource::file_format::{parquet::ParquetFormat, FileFormat}, physical_plan::collect, }; @@ -1182,6 +1182,7 @@ mod tests { datatypes::{DataType, Field}, }; use chrono::{TimeZone, Utc}; + use datafusion_common::assert_contains; use datafusion_expr::{cast, col, lit}; use futures::StreamExt; use object_store::local::LocalFileSystem; diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 1995a6196eed..83ae71d66a18 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -62,7 +62,7 @@ use datafusion_common::{DFSchema, ScalarValue}; use datafusion_expr::expr::{Between, BinaryExpr, GetIndexedField, GroupingSet, Like}; use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::utils::{expand_wildcard, expr_to_columns}; -use datafusion_expr::WindowFrameUnits; +use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; use datafusion_optimizer::utils::unalias; use datafusion_physical_expr::expressions::Literal; use datafusion_sql::utils::window_expr_common_partition_keys; @@ -550,6 +550,16 @@ impl DefaultPhysicalPlanner { ref order_by, .. } => generate_sort_key(partition_by, order_by), + Expr::Alias(expr, _) => { + // Convert &Box to &T + match &**expr { + Expr::WindowFunction { + ref partition_by, + ref order_by, + ..} => generate_sort_key(partition_by, order_by), + _ => unreachable!(), + } + } _ => unreachable!(), }; let sort_keys = get_sort_keys(&window_expr[0]); @@ -1368,6 +1378,26 @@ fn get_physical_expr_pair( Ok((physical_expr, physical_name)) } +/// Check if window bounds are valid after schema information is available, and +/// window_frame bounds are casted to the corresponding column type. +/// queries like: +/// OVER (ORDER BY a RANGES BETWEEN 3 PRECEDING AND 5 PRECEDING) +/// OVER (ORDER BY a RANGES BETWEEN INTERVAL '3 DAY' PRECEDING AND '5 DAY' PRECEDING) are rejected +pub fn is_window_valid(window_frame: &WindowFrame) -> bool { + match (&window_frame.start_bound, &window_frame.end_bound) { + (WindowFrameBound::Following(_), WindowFrameBound::Preceding(_)) + | (WindowFrameBound::Following(_), WindowFrameBound::CurrentRow) + | (WindowFrameBound::CurrentRow, WindowFrameBound::Preceding(_)) => false, + (WindowFrameBound::Preceding(lhs), WindowFrameBound::Preceding(rhs)) => { + !rhs.is_null() && (lhs.is_null() || (lhs >= rhs)) + } + (WindowFrameBound::Following(lhs), WindowFrameBound::Following(rhs)) => { + !lhs.is_null() && (rhs.is_null() || (lhs <= rhs)) + } + _ => true, + } +} + /// Create a window expression with a name from a logical expression pub fn create_window_expr_with_name( e: &Expr, @@ -1429,21 +1459,28 @@ pub fn create_window_expr_with_name( )), }) .collect::>>()?; - if window_frame.is_some() - && window_frame.unwrap().units == WindowFrameUnits::Groups - { - return Err(DataFusionError::NotImplemented( - "Window frame definitions involving GROUPS are not supported yet" - .to_string(), - )); + if let Some(ref window_frame) = window_frame { + if window_frame.units == WindowFrameUnits::Groups { + return Err(DataFusionError::NotImplemented( + "Window frame definitions involving GROUPS are not supported yet" + .to_string(), + )); + } + if !is_window_valid(window_frame) { + return Err(DataFusionError::Execution(format!( + "Invalid window frame: start bound ({}) cannot be larger than end bound ({})", + window_frame.start_bound, window_frame.end_bound + ))); + } } + let window_frame = window_frame.clone().map(Arc::new); windows::create_window_expr( fun, name, &args, &partition_by, &order_by, - *window_frame, + window_frame, physical_input_schema, ) } @@ -1675,7 +1712,6 @@ fn tuple_err(value: (Result, Result)) -> Result<(T, R)> { #[cfg(test)] mod tests { use super::*; - use crate::assert_contains; use crate::datasource::MemTable; use crate::execution::context::TaskContext; use crate::execution::options::CsvReadOptions; @@ -1690,6 +1726,7 @@ mod tests { use arrow::array::{ArrayRef, DictionaryArray, Int32Array}; use arrow::datatypes::{DataType, Field, Int32Type, SchemaRef}; use arrow::record_batch::RecordBatch; + use datafusion_common::assert_contains; use datafusion_common::{DFField, DFSchema, DFSchemaRef}; use datafusion_expr::{col, lit, sum, Extension, GroupingSet, LogicalPlanBuilder}; use fmt::Debug; diff --git a/datafusion/core/src/physical_plan/windows/mod.rs b/datafusion/core/src/physical_plan/windows/mod.rs index 26cb14fe33a9..be9421a9de84 100644 --- a/datafusion/core/src/physical_plan/windows/mod.rs +++ b/datafusion/core/src/physical_plan/windows/mod.rs @@ -51,7 +51,7 @@ pub fn create_window_expr( args: &[Arc], partition_by: &[Arc], order_by: &[PhysicalSortExpr], - window_frame: Option, + window_frame: Option>, input_schema: &Schema, ) -> Result> { Ok(match fun { @@ -186,7 +186,7 @@ mod tests { &[col("c3", &schema)?], &[], &[], - Some(WindowFrame::default()), + Some(Arc::new(WindowFrame::default())), schema.as_ref(), )?, create_window_expr( @@ -195,7 +195,7 @@ mod tests { &[col("c3", &schema)?], &[], &[], - Some(WindowFrame::default()), + Some(Arc::new(WindowFrame::default())), schema.as_ref(), )?, create_window_expr( @@ -204,7 +204,7 @@ mod tests { &[col("c3", &schema)?], &[], &[], - Some(WindowFrame::default()), + Some(Arc::new(WindowFrame::default())), schema.as_ref(), )?, ], @@ -250,7 +250,7 @@ mod tests { &[col("a", &schema)?], &[], &[], - Some(WindowFrame::default()), + Some(Arc::new(WindowFrame::default())), schema.as_ref(), )?], blocking_exec, diff --git a/datafusion/core/src/test_util.rs b/datafusion/core/src/test_util.rs index d92b9db6082c..96ac67bcbef1 100644 --- a/datafusion/core/src/test_util.rs +++ b/datafusion/core/src/test_util.rs @@ -98,52 +98,6 @@ macro_rules! assert_batches_sorted_eq { }; } -/// A macro to assert that one string is contained within another with -/// a nice error message if they are not. -/// -/// Usage: `assert_contains!(actual, expected)` -/// -/// Is a macro so test error -/// messages are on the same line as the failure; -/// -/// Both arguments must be convertable into Strings (Into) -#[macro_export] -macro_rules! assert_contains { - ($ACTUAL: expr, $EXPECTED: expr) => { - let actual_value: String = $ACTUAL.into(); - let expected_value: String = $EXPECTED.into(); - assert!( - actual_value.contains(&expected_value), - "Can not find expected in actual.\n\nExpected:\n{}\n\nActual:\n{}", - expected_value, - actual_value - ); - }; -} - -/// A macro to assert that one string is NOT contained within another with -/// a nice error message if they are are. -/// -/// Usage: `assert_not_contains!(actual, unexpected)` -/// -/// Is a macro so test error -/// messages are on the same line as the failure; -/// -/// Both arguments must be convertable into Strings (Into) -#[macro_export] -macro_rules! assert_not_contains { - ($ACTUAL: expr, $UNEXPECTED: expr) => { - let actual_value: String = $ACTUAL.into(); - let unexpected_value: String = $UNEXPECTED.into(); - assert!( - !actual_value.contains(&unexpected_value), - "Found unexpected in actual.\n\nUnexpected:\n{}\n\nActual:\n{}", - unexpected_value, - actual_value - ); - }; -} - /// Returns the arrow test data directory, which is by default stored /// in a git submodule rooted at `testing/data`. /// diff --git a/datafusion/core/tests/sql/idenfifers.rs b/datafusion/core/tests/sql/idenfifers.rs index d50e5989cc98..e2fde56e959e 100644 --- a/datafusion/core/tests/sql/idenfifers.rs +++ b/datafusion/core/tests/sql/idenfifers.rs @@ -18,7 +18,8 @@ use std::sync::Arc; use arrow::{array::StringArray, record_batch::RecordBatch}; -use datafusion::{assert_batches_sorted_eq, assert_contains, prelude::*}; +use datafusion::{assert_batches_sorted_eq, prelude::*}; +use datafusion_common::assert_contains; use crate::sql::plan_and_collect; diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 9d1c440a15df..23ea2a2ddcac 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -25,10 +25,6 @@ use arrow::{ use chrono::prelude::*; use chrono::Duration; -use datafusion::assert_batches_eq; -use datafusion::assert_batches_sorted_eq; -use datafusion::assert_contains; -use datafusion::assert_not_contains; use datafusion::datasource::TableProvider; use datafusion::from_slice::FromSlice; use datafusion::logical_expr::{Aggregate, LogicalPlan, Projection, TableScan}; @@ -37,12 +33,14 @@ use datafusion::physical_plan::ExecutionPlan; use datafusion::physical_plan::ExecutionPlanVisitor; use datafusion::prelude::*; use datafusion::test_util; +use datafusion::{assert_batches_eq, assert_batches_sorted_eq}; use datafusion::{datasource::MemTable, physical_plan::collect}; use datafusion::{ error::{DataFusionError, Result}, physical_plan::ColumnarValue, }; use datafusion::{execution::context::SessionContext, physical_plan::displayable}; +use datafusion_common::{assert_contains, assert_not_contains}; use datafusion_expr::Volatility; use object_store::path::Path; use std::fs::File; diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 2708d91e0d75..d9ede9771858 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -523,6 +523,7 @@ async fn window_frame_rows_preceding() -> Result<()> { assert_batches_eq!(expected, &actual); Ok(()) } + #[tokio::test] async fn window_frame_rows_preceding_with_partition_unique_order_by() -> Result<()> { let ctx = SessionContext::new(); @@ -977,22 +978,22 @@ async fn window_frame_ranges_unbounded_preceding_following() -> Result<()> { let ctx = SessionContext::new(); register_aggregate_csv(&ctx).await?; let sql = "SELECT \ - SUM(c2) OVER (ORDER BY c2 RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING), \ - COUNT(*) OVER (ORDER BY c2 RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) \ + SUM(c2) OVER (ORDER BY c2 RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) as sum1, \ + COUNT(*) OVER (ORDER BY c2 RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) as cnt1 \ FROM aggregate_test_100 \ ORDER BY c9 \ LIMIT 5"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+----------------------------+-----------------+", - "| SUM(aggregate_test_100.c2) | COUNT(UInt8(1)) |", - "+----------------------------+-----------------+", - "| 285 | 100 |", - "| 123 | 63 |", - "| 285 | 100 |", - "| 123 | 63 |", - "| 123 | 63 |", - "+----------------------------+-----------------+", + "+------+------+", + "| sum1 | cnt1 |", + "+------+------+", + "| 285 | 100 |", + "| 123 | 63 |", + "| 285 | 100 |", + "| 123 | 63 |", + "| 123 | 63 |", + "+------+------+", ]; assert_batches_eq!(expected, &actual); Ok(()) @@ -1075,6 +1076,95 @@ async fn window_frame_partition_by_order_by_desc() -> Result<()> { Ok(()) } +#[tokio::test] +async fn window_frame_range_float() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT + SUM(c12) OVER (ORDER BY C12 RANGE BETWEEN 0.2 PRECEDING AND 0.2 FOLLOWING) + FROM aggregate_test_100 + ORDER BY C9 + LIMIT 5"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----------------------------+", + "| SUM(aggregate_test_100.c12) |", + "+-----------------------------+", + "| 2.5476701803634296 |", + "| 10.6299412548214 |", + "| 2.5476701803634296 |", + "| 20.349518503437288 |", + "| 21.408674363507753 |", + "+-----------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn window_frame_ranges_timestamp() -> Result<()> { + // define a schema. + let schema = Arc::new(Schema::new(vec![Field::new( + "ts", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + )])); + + // define data in two partitions + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(TimestampNanosecondArray::from_slice(&[ + 1664264591000000000, + 1664264592000000000, + 1664264592000000000, + 1664264593000000000, + 1664264594000000000, + 1664364594000000000, + 1664464594000000000, + 1664564594000000000, + ]))], + ) + .unwrap(); + + let ctx = SessionContext::new(); + // declare a new context. In spark API, this corresponds to a new spark SQLsession + // declare a table in memory. In spark API, this corresponds to createDataFrame(...). + let provider = MemTable::try_new(schema, vec![vec![batch]]).unwrap(); + // Register table + ctx.register_table("t", Arc::new(provider)).unwrap(); + + // execute the query + let df = ctx + .sql( + "SELECT + ts, + COUNT(*) OVER (ORDER BY ts RANGE BETWEEN INTERVAL '1' DAY PRECEDING AND INTERVAL '2 DAY' FOLLOWING) AS cnt1, + COUNT(*) OVER (ORDER BY ts RANGE BETWEEN '0 DAY' PRECEDING AND '0' DAY FOLLOWING) as cnt2, + COUNT(*) OVER (ORDER BY ts RANGE BETWEEN '5' SECOND PRECEDING AND CURRENT ROW) as cnt3 + FROM t + ORDER BY ts" + ) + .await?; + + let actual = df.collect().await?; + let expected = vec![ + "+---------------------+------+------+------+", + "| ts | cnt1 | cnt2 | cnt3 |", + "+---------------------+------+------+------+", + "| 2022-09-27 07:43:11 | 6 | 1 | 1 |", + "| 2022-09-27 07:43:12 | 6 | 2 | 3 |", + "| 2022-09-27 07:43:12 | 6 | 2 | 3 |", + "| 2022-09-27 07:43:13 | 6 | 1 | 4 |", + "| 2022-09-27 07:43:14 | 6 | 1 | 5 |", + "| 2022-09-28 11:29:54 | 2 | 1 | 1 |", + "| 2022-09-29 15:16:34 | 2 | 1 | 1 |", + "| 2022-09-30 19:03:14 | 1 | 1 | 1 |", + "+---------------------+------+------+------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + #[tokio::test] async fn window_frame_ranges_unbounded_preceding_err() -> Result<()> { let ctx = SessionContext::new(); @@ -1119,3 +1209,50 @@ async fn window_frame_groups_query() -> Result<()> { .contains("Window frame definitions involving GROUPS are not supported yet")); Ok(()) } + +#[tokio::test] +async fn window_frame_creation() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + // execute the query + let df = ctx + .sql( + "SELECT + COUNT(c1) OVER (ORDER BY c2 RANGE BETWEEN 1 PRECEDING AND 2 PRECEDING) + FROM aggregate_test_100;", + ) + .await?; + let results = df.collect().await; + assert_eq!( + results.err().unwrap().to_string(), + "Execution error: Invalid window frame: start bound (1 PRECEDING) cannot be larger than end bound (2 PRECEDING)" + ); + + let df = ctx + .sql( + "SELECT + COUNT(c1) OVER (ORDER BY c2 RANGE BETWEEN 2 FOLLOWING AND 1 FOLLOWING) + FROM aggregate_test_100;", + ) + .await?; + let results = df.collect().await; + assert_eq!( + results.err().unwrap().to_string(), + "Execution error: Invalid window frame: start bound (2 FOLLOWING) cannot be larger than end bound (1 FOLLOWING)" + ); + + let df = ctx + .sql( + "SELECT + COUNT(c1) OVER (ORDER BY c2 RANGE BETWEEN '1 DAY' PRECEDING AND '2 DAY' FOLLOWING) + FROM aggregate_test_100;", + ) + .await?; + let results = df.collect().await; + assert_contains!( + results.err().unwrap().to_string(), + "Arrow error: External error: Internal error: Operator - is not implemented for types UInt32(1) and Utf8(\"1 DAY\")" + ); + + Ok(()) +} diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index d3b348c4bcb3..afa2be239ffe 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -39,4 +39,4 @@ ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] arrow = { version = "25.0.0", default-features = false } datafusion-common = { path = "../common", version = "13.0.0" } log = "^0.4" -sqlparser = "0.25" +sqlparser = "0.26" diff --git a/datafusion/expr/src/type_coercion.rs b/datafusion/expr/src/type_coercion.rs index 4a006ad87a27..02702b113ba3 100644 --- a/datafusion/expr/src/type_coercion.rs +++ b/datafusion/expr/src/type_coercion.rs @@ -62,6 +62,16 @@ pub fn is_numeric(dt: &DataType) -> bool { ) } +/// Determine if a DataType is Timestamp or not +pub fn is_timestamp(dt: &DataType) -> bool { + matches!(dt, DataType::Timestamp(_, _)) +} + +/// Determine if a DataType is Date or not +pub fn is_date(dt: &DataType) -> bool { + matches!(dt, DataType::Date32 | DataType::Date64) +} + pub mod aggregates; pub mod binary; pub mod functions; diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index 7f9afd0b51a8..5bf81d165db5 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -23,19 +23,19 @@ //! - An ending frame boundary, //! - An EXCLUDE clause. -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{DataFusionError, Result, ScalarValue}; use sqlparser::ast; -use std::cmp::Ordering; +use sqlparser::parser::ParserError::ParserError; use std::convert::{From, TryFrom}; use std::fmt; -use std::hash::{Hash, Hasher}; +use std::hash::Hash; /// The frame-spec determines which output rows are read by an aggregate window function. /// /// The ending frame boundary can be omitted (if the BETWEEN and AND keywords that surround the /// starting frame boundary are also omitted), in which case the ending frame boundary defaults to /// CURRENT ROW. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct WindowFrame { /// A frame type - either ROWS, RANGE or GROUPS pub units: WindowFrameUnits, @@ -60,27 +60,22 @@ impl TryFrom for WindowFrame { type Error = DataFusionError; fn try_from(value: ast::WindowFrame) -> Result { - let start_bound = value.start_bound.into(); - let end_bound = value - .end_bound - .map(WindowFrameBound::from) - .unwrap_or(WindowFrameBound::CurrentRow); + let start_bound = value.start_bound.try_into()?; + let end_bound = match value.end_bound { + Some(value) => value.try_into()?, + None => WindowFrameBound::CurrentRow, + }; - if let WindowFrameBound::Following(None) = start_bound { + if let WindowFrameBound::Following(ScalarValue::Utf8(None)) = start_bound { Err(DataFusionError::Execution( "Invalid window frame: start bound cannot be unbounded following" .to_owned(), )) - } else if let WindowFrameBound::Preceding(None) = end_bound { + } else if let WindowFrameBound::Preceding(ScalarValue::Utf8(None)) = end_bound { Err(DataFusionError::Execution( "Invalid window frame: end bound cannot be unbounded preceding" .to_owned(), )) - } else if start_bound > end_bound { - Err(DataFusionError::Execution(format!( - "Invalid window frame: start bound ({}) cannot be larger than end bound ({})", - start_bound, end_bound - ))) } else { let units = value.units.into(); Ok(Self { @@ -96,7 +91,7 @@ impl Default for WindowFrame { fn default() -> Self { WindowFrame { units: WindowFrameUnits::Range, - start_bound: WindowFrameBound::Preceding(None), + start_bound: WindowFrameBound::Preceding(ScalarValue::Utf8(None)), end_bound: WindowFrameBound::CurrentRow, } } @@ -110,8 +105,7 @@ impl Default for WindowFrame { /// 4. FOLLOWING /// 5. UNBOUNDED FOLLOWING /// -/// in this implementation we'll only allow to be u64 (i.e. no dynamic boundary) -#[derive(Debug, Clone, Copy, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum WindowFrameBound { /// 1. UNBOUNDED PRECEDING /// The frame boundary is the first row in the partition. @@ -119,7 +113,7 @@ pub enum WindowFrameBound { /// 2. PRECEDING /// must be a non-negative constant numeric expression. The boundary is a row that /// is "units" prior to the current row. - Preceding(Option), + Preceding(ScalarValue), /// 3. The current row. /// /// For RANGE and GROUPS frame types, peers of the current row are also @@ -132,70 +126,72 @@ pub enum WindowFrameBound { /// /// 5. UNBOUNDED FOLLOWING /// The frame boundary is the last row in the partition. - Following(Option), + Following(ScalarValue), } -impl From for WindowFrameBound { - fn from(value: ast::WindowFrameBound) -> Self { - match value { - ast::WindowFrameBound::Preceding(v) => Self::Preceding(v), - ast::WindowFrameBound::Following(v) => Self::Following(v), +impl TryFrom for WindowFrameBound { + type Error = DataFusionError; + + fn try_from(value: ast::WindowFrameBound) -> Result { + Ok(match value { + ast::WindowFrameBound::Preceding(Some(v)) => { + Self::Preceding(convert_frame_bound_to_scalar_value(*v)?) + } + ast::WindowFrameBound::Preceding(None) => { + Self::Preceding(ScalarValue::Utf8(None)) + } + ast::WindowFrameBound::Following(Some(v)) => { + Self::Following(convert_frame_bound_to_scalar_value(*v)?) + } + ast::WindowFrameBound::Following(None) => { + Self::Following(ScalarValue::Utf8(None)) + } ast::WindowFrameBound::CurrentRow => Self::CurrentRow, - } + }) } } -impl fmt::Display for WindowFrameBound { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - WindowFrameBound::CurrentRow => f.write_str("CURRENT ROW"), - WindowFrameBound::Preceding(None) => f.write_str("UNBOUNDED PRECEDING"), - WindowFrameBound::Following(None) => f.write_str("UNBOUNDED FOLLOWING"), - WindowFrameBound::Preceding(Some(n)) => write!(f, "{} PRECEDING", n), - WindowFrameBound::Following(Some(n)) => write!(f, "{} FOLLOWING", n), +pub fn convert_frame_bound_to_scalar_value(v: ast::Expr) -> Result { + Ok(ScalarValue::Utf8(Some(match v { + ast::Expr::Value(ast::Value::Number(value, false)) + | ast::Expr::Value(ast::Value::SingleQuotedString(value)) => value, + ast::Expr::Interval { + value, + leading_field, + .. + } => { + let result = match *value { + ast::Expr::Value(ast::Value::SingleQuotedString(item)) => item, + e => { + let msg = format!("INTERVAL expression cannot be {:?}", e); + return Err(DataFusionError::SQL(ParserError(msg))); + } + }; + if let Some(leading_field) = leading_field { + format!("{} {}", result, leading_field) + } else { + result + } } - } -} - -impl PartialEq for WindowFrameBound { - fn eq(&self, other: &Self) -> bool { - self.cmp(other) == Ordering::Equal - } -} - -impl PartialOrd for WindowFrameBound { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for WindowFrameBound { - fn cmp(&self, other: &Self) -> Ordering { - self.get_rank().cmp(&other.get_rank()) - } -} - -impl Hash for WindowFrameBound { - fn hash(&self, state: &mut H) { - self.get_rank().hash(state) - } + e => { + let msg = format!("Window frame bound cannot be {:?}", e); + return Err(DataFusionError::Internal(msg)); + } + }))) } -impl WindowFrameBound { - /// get the rank of this window frame bound. - /// - /// the rank is a tuple of (u8, u64) because we'll firstly compare the kind and then the value - /// which requires special handling e.g. with preceding the larger the value the smaller the - /// rank and also for 0 preceding / following it is the same as current row - fn get_rank(&self) -> (u8, u64) { +impl fmt::Display for WindowFrameBound { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - WindowFrameBound::Preceding(None) => (0, 0), - WindowFrameBound::Following(None) => (4, 0), - WindowFrameBound::Preceding(Some(0)) - | WindowFrameBound::CurrentRow - | WindowFrameBound::Following(Some(0)) => (2, 0), - WindowFrameBound::Preceding(Some(v)) => (1, u64::MAX - *v), - WindowFrameBound::Following(Some(v)) => (3, *v), + WindowFrameBound::Preceding(ScalarValue::Utf8(None)) => { + f.write_str("UNBOUNDED PRECEDING") + } + WindowFrameBound::Preceding(n) => write!(f, "{} PRECEDING", n), + WindowFrameBound::CurrentRow => f.write_str("CURRENT ROW"), + WindowFrameBound::Following(ScalarValue::Utf8(None)) => { + f.write_str("UNBOUNDED FOLLOWING") + } + WindowFrameBound::Following(n) => write!(f, "{} FOLLOWING", n), } } } @@ -250,105 +246,34 @@ mod tests { start_bound: ast::WindowFrameBound::Following(None), end_bound: None, }; - let result = WindowFrame::try_from(window_frame); + let err = WindowFrame::try_from(window_frame).unwrap_err(); assert_eq!( - result.err().unwrap().to_string(), - "Execution error: Invalid window frame: start bound cannot be unbounded following" - .to_owned() - ); + err.to_string(), + "Execution error: Invalid window frame: start bound cannot be unbounded following".to_owned() + ); let window_frame = ast::WindowFrame { units: ast::WindowFrameUnits::Range, start_bound: ast::WindowFrameBound::Preceding(None), end_bound: Some(ast::WindowFrameBound::Preceding(None)), }; - let result = WindowFrame::try_from(window_frame); + let err = WindowFrame::try_from(window_frame).unwrap_err(); assert_eq!( - result.err().unwrap().to_string(), - "Execution error: Invalid window frame: end bound cannot be unbounded preceding" - .to_owned() - ); - - let window_frame = ast::WindowFrame { - units: ast::WindowFrameUnits::Range, - start_bound: ast::WindowFrameBound::Preceding(Some(1)), - end_bound: Some(ast::WindowFrameBound::Preceding(Some(2))), - }; - let result = WindowFrame::try_from(window_frame); - assert_eq!( - result.err().unwrap().to_string(), - "Execution error: Invalid window frame: start bound (1 PRECEDING) cannot be larger than end bound (2 PRECEDING)".to_owned() + err.to_string(), + "Execution error: Invalid window frame: end bound cannot be unbounded preceding".to_owned() ); let window_frame = ast::WindowFrame { units: ast::WindowFrameUnits::Rows, - start_bound: ast::WindowFrameBound::Preceding(Some(2)), - end_bound: Some(ast::WindowFrameBound::Preceding(Some(1))), + start_bound: ast::WindowFrameBound::Preceding(Some(Box::new( + ast::Expr::Value(ast::Value::Number("2".to_string(), false)), + ))), + end_bound: Some(ast::WindowFrameBound::Preceding(Some(Box::new( + ast::Expr::Value(ast::Value::Number("1".to_string(), false)), + )))), }; let result = WindowFrame::try_from(window_frame); assert!(result.is_ok()); Ok(()) } - - #[test] - fn test_eq() { - assert_eq!( - WindowFrameBound::Preceding(Some(0)), - WindowFrameBound::CurrentRow - ); - assert_eq!( - WindowFrameBound::CurrentRow, - WindowFrameBound::Following(Some(0)) - ); - assert_eq!( - WindowFrameBound::Following(Some(2)), - WindowFrameBound::Following(Some(2)) - ); - assert_eq!( - WindowFrameBound::Following(None), - WindowFrameBound::Following(None) - ); - assert_eq!( - WindowFrameBound::Preceding(Some(2)), - WindowFrameBound::Preceding(Some(2)) - ); - assert_eq!( - WindowFrameBound::Preceding(None), - WindowFrameBound::Preceding(None) - ); - } - - #[test] - fn test_ord() { - assert!(WindowFrameBound::Preceding(Some(1)) < WindowFrameBound::CurrentRow); - // ! yes this is correct! - assert!( - WindowFrameBound::Preceding(Some(2)) < WindowFrameBound::Preceding(Some(1)) - ); - assert!( - WindowFrameBound::Preceding(Some(u64::MAX)) - < WindowFrameBound::Preceding(Some(u64::MAX - 1)) - ); - assert!( - WindowFrameBound::Preceding(None) - < WindowFrameBound::Preceding(Some(1000000)) - ); - assert!( - WindowFrameBound::Preceding(None) - < WindowFrameBound::Preceding(Some(u64::MAX)) - ); - assert!(WindowFrameBound::Preceding(None) < WindowFrameBound::Following(Some(0))); - assert!( - WindowFrameBound::Preceding(Some(1)) < WindowFrameBound::Following(Some(1)) - ); - assert!(WindowFrameBound::CurrentRow < WindowFrameBound::Following(Some(1))); - assert!( - WindowFrameBound::Following(Some(1)) < WindowFrameBound::Following(Some(2)) - ); - assert!(WindowFrameBound::Following(Some(2)) < WindowFrameBound::Following(None)); - assert!( - WindowFrameBound::Following(Some(u64::MAX)) - < WindowFrameBound::Following(None) - ); - } } diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index 2833eee048b4..ae1327ed1888 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -19,8 +19,10 @@ use crate::utils::rewrite_preserving_name; use crate::{OptimizerConfig, OptimizerRule}; -use arrow::datatypes::DataType; -use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result}; +use arrow::datatypes::{DataType, IntervalUnit}; +use datafusion_common::{ + parse_interval, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::expr::{Between, BinaryExpr, Case, Like}; use datafusion_expr::expr_rewriter::{ExprRewriter, RewriteRecursion}; use datafusion_expr::logical_plan::Subquery; @@ -29,10 +31,12 @@ use datafusion_expr::type_coercion::functions::data_types; use datafusion_expr::type_coercion::other::{ get_coerce_type_for_case_when, get_coerce_type_for_list, }; +use datafusion_expr::type_coercion::{is_date, is_numeric, is_timestamp}; use datafusion_expr::utils::from_plan; use datafusion_expr::{ aggregate_function, function, is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, type_coercion, AggregateFunction, Expr, LogicalPlan, Operator, + WindowFrame, WindowFrameBound, WindowFrameUnits, }; use datafusion_expr::{ExprSchemable, Signature}; use std::sync::Arc; @@ -72,7 +76,6 @@ fn optimize_internal( .iter() .map(|p| optimize_internal(external_schema, p, optimizer_config)) .collect::>>()?; - // get schema representing all available input fields. This is used for data type // resolution only, so order does not matter here let mut schema = new_inputs.iter().map(|input| input.schema()).fold( @@ -410,11 +413,121 @@ impl ExprRewriter for TypeCoercionRewriter { }; Ok(expr) } + Expr::WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + } => { + let window_frame = + get_coerced_window_frame(window_frame, &self.schema, &order_by)?; + let expr = Expr::WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + }; + Ok(expr) + } expr => Ok(expr), } } } +/// Casts the ScalarValue `value` to coerced type. +// When coerced type is `Interval` we use `parse_interval` since `try_from_string` not +// supports conversion from string to Interval +fn convert_to_coerced_type( + coerced_type: &DataType, + value: &ScalarValue, +) -> Result { + match value { + // In here we do casting either for ScalarValue::Utf8(None) or + // ScalarValue::Utf8(Some(val)). The other types are already casted. + // The reason is that we convert the sqlparser result + // to the Utf8 for all possible cases. Hence the types other than Utf8 + // are already casted to appropriate type. Therefore they can be returned directly. + ScalarValue::Utf8(None) => ScalarValue::try_from(coerced_type), + ScalarValue::Utf8(Some(val)) => { + // we need special handling for Interval types + if let DataType::Interval(..) = coerced_type { + parse_interval("millisecond", val) + } else { + ScalarValue::try_from_string(val.clone(), coerced_type) + } + } + s => Ok(s.clone()), + } +} + +fn coerce_frame_bound( + coerced_type: &DataType, + bound: &WindowFrameBound, +) -> Result { + Ok(match bound { + WindowFrameBound::Preceding(val) => { + WindowFrameBound::Preceding(convert_to_coerced_type(coerced_type, val)?) + } + WindowFrameBound::CurrentRow => WindowFrameBound::CurrentRow, + WindowFrameBound::Following(val) => { + WindowFrameBound::Following(convert_to_coerced_type(coerced_type, val)?) + } + }) +} + +fn get_coerced_window_frame( + window_frame: Option, + schema: &DFSchemaRef, + expressions: &[Expr], +) -> Result> { + fn get_coerced_type(column_type: &DataType) -> Result { + if is_numeric(column_type) { + Ok(column_type.clone()) + } else if is_timestamp(column_type) || is_date(column_type) { + Ok(DataType::Interval(IntervalUnit::MonthDayNano)) + } else { + Err(DataFusionError::Internal(format!( + "Cannot run range queries on datatype: {:?}", + column_type + ))) + } + } + + if let Some(window_frame) = window_frame { + let mut window_frame = window_frame; + let current_types = expressions + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + match &mut window_frame.units { + WindowFrameUnits::Range => { + let col_type = current_types.first().ok_or_else(|| { + DataFusionError::Internal( + "ORDER BY column cannot be empty".to_string(), + ) + })?; + let coerced_type = get_coerced_type(col_type)?; + window_frame.start_bound = + coerce_frame_bound(&coerced_type, &window_frame.start_bound)?; + window_frame.end_bound = + coerce_frame_bound(&coerced_type, &window_frame.end_bound)?; + } + WindowFrameUnits::Rows | WindowFrameUnits::Groups => { + let coerced_type = DataType::UInt64; + window_frame.start_bound = + coerce_frame_bound(&coerced_type, &window_frame.start_bound)?; + window_frame.end_bound = + coerce_frame_bound(&coerced_type, &window_frame.end_bound)?; + } + } + + Ok(Some(window_frame)) + } else { + Ok(None) + } +} // Support the `IsTrue` `IsNotTrue` `IsFalse` `IsNotFalse` type coercion. // The above op will be rewrite to the binary op when creating the physical op. fn get_casted_expr_for_bool_op(expr: &Expr, schema: &DFSchemaRef) -> Result { diff --git a/datafusion/physical-expr/src/expressions/datetime.rs b/datafusion/physical-expr/src/expressions/datetime.rs index fa021f61a940..97f038068794 100644 --- a/datafusion/physical-expr/src/expressions/datetime.rs +++ b/datafusion/physical-expr/src/expressions/datetime.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use crate::expressions::delta::shift_months; use crate::PhysicalExpr; use arrow::array::{ Array, ArrayRef, Date32Array, Date64Array, TimestampMicrosecondArray, @@ -27,13 +26,15 @@ use arrow::datatypes::{ TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; use arrow::record_batch::RecordBatch; -use chrono::{Datelike, Duration, NaiveDate, NaiveDateTime}; +use datafusion_common::scalar::{ + date32_add, date64_add, microseconds_add, milliseconds_add, nanoseconds_add, + seconds_add, +}; use datafusion_common::Result; use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::{ColumnarValue, Operator}; use std::any::Any; use std::fmt::{Display, Formatter}; -use std::ops::{Add, Sub}; use std::sync::Arc; /// Perform DATE/TIME/TIMESTAMP +/ INTERVAL math @@ -117,26 +118,29 @@ impl PhysicalExpr for DateTimeIntervalExpr { // Unwrap interval to add let intervals = match &intervals { ColumnarValue::Scalar(interval) => interval, - _ => Err(DataFusionError::Execution( - "Columnar execution is not yet supported for DateIntervalExpr" - .to_string(), - ))?, + _ => { + let msg = "Columnar execution is not yet supported for DateIntervalExpr"; + return Err(DataFusionError::Execution(msg.to_string())); + } }; // Invert sign for subtraction - let sign = match &self.op { + let sign = match self.op { Operator::Plus => 1, Operator::Minus => -1, _ => { // this should be unreachable because we check the operators in `try_new` - Err(DataFusionError::Execution( - "Invalid operator for DateIntervalExpr".to_string(), - ))? + let msg = "Invalid operator for DateIntervalExpr"; + return Err(DataFusionError::Internal(msg.to_string())); } }; match dates { - ColumnarValue::Scalar(operand) => evaluate_scalar(operand, sign, intervals), + ColumnarValue::Scalar(operand) => Ok(ColumnarValue::Scalar(if sign > 0 { + operand.add(intervals)? + } else { + operand.sub(intervals)? + })), ColumnarValue::Array(array) => evaluate_array(array, sign, intervals), } } @@ -214,138 +218,6 @@ pub fn evaluate_array( Ok(ColumnarValue::Array(ret)) } -fn evaluate_scalar( - operand: ScalarValue, - sign: i32, - scalar: &ScalarValue, -) -> Result { - let res = match operand { - ScalarValue::Date32(Some(days)) => { - let value = date32_add(days, scalar, sign)?; - ColumnarValue::Scalar(ScalarValue::Date32(Some(value))) - } - ScalarValue::Date64(Some(ms)) => { - let value = date64_add(ms, scalar, sign)?; - ColumnarValue::Scalar(ScalarValue::Date64(Some(value))) - } - ScalarValue::TimestampSecond(Some(ts_s), zone) => { - let value = seconds_add(ts_s, scalar, sign)?; - ColumnarValue::Scalar(ScalarValue::TimestampSecond(Some(value), zone)) - } - ScalarValue::TimestampMillisecond(Some(ts_ms), zone) => { - let value = milliseconds_add(ts_ms, scalar, sign)?; - ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(Some(value), zone)) - } - ScalarValue::TimestampMicrosecond(Some(ts_us), zone) => { - let value = microseconds_add(ts_us, scalar, sign)?; - ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(value), zone)) - } - ScalarValue::TimestampNanosecond(Some(ts_ns), zone) => { - let value = nanoseconds_add(ts_ns, scalar, sign)?; - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(value), zone)) - } - _ => Err(DataFusionError::Execution(format!( - "Invalid lhs type {} for DateIntervalExpr", - operand.get_datatype() - )))?, - }; - Ok(res) -} - -#[inline] -fn date32_add(days: i32, scalar: &ScalarValue, sign: i32) -> Result { - let epoch = NaiveDate::from_ymd(1970, 1, 1); - let prior = epoch.add(Duration::days(days as i64)); - let posterior = do_date_math(prior, scalar, sign)?; - Ok(posterior.sub(epoch).num_days() as i32) -} - -#[inline] -fn date64_add(ms: i64, scalar: &ScalarValue, sign: i32) -> Result { - let epoch = NaiveDate::from_ymd(1970, 1, 1); - let prior = epoch.add(Duration::milliseconds(ms)); - let posterior = do_date_math(prior, scalar, sign)?; - Ok(posterior.sub(epoch).num_milliseconds()) -} - -#[inline] -fn seconds_add(ts_s: i64, scalar: &ScalarValue, sign: i32) -> Result { - Ok(do_date_time_math(ts_s, 0, scalar, sign)?.timestamp()) -} - -#[inline] -fn milliseconds_add(ts_ms: i64, scalar: &ScalarValue, sign: i32) -> Result { - let secs = ts_ms / 1000; - let nsecs = ((ts_ms % 1000) * 1_000_000) as u32; - Ok(do_date_time_math(secs, nsecs, scalar, sign)?.timestamp_millis()) -} - -#[inline] -fn microseconds_add(ts_us: i64, scalar: &ScalarValue, sign: i32) -> Result { - let secs = ts_us / 1_000_000; - let nsecs = ((ts_us % 1_000_000) * 1000) as u32; - Ok(do_date_time_math(secs, nsecs, scalar, sign)?.timestamp_nanos() / 1000) -} - -#[inline] -fn nanoseconds_add(ts_ns: i64, scalar: &ScalarValue, sign: i32) -> Result { - let secs = ts_ns / 1_000_000_000; - let nsecs = (ts_ns % 1_000_000_000) as u32; - Ok(do_date_time_math(secs, nsecs, scalar, sign)?.timestamp_nanos()) -} - -#[inline] -fn do_date_time_math( - secs: i64, - nsecs: u32, - scalar: &ScalarValue, - sign: i32, -) -> Result { - let prior = NaiveDateTime::from_timestamp(secs, nsecs); - do_date_math(prior, scalar, sign) -} - -fn do_date_math(prior: D, scalar: &ScalarValue, sign: i32) -> Result -where - D: Datelike + Add, -{ - Ok(match scalar { - ScalarValue::IntervalDayTime(Some(i)) => add_day_time(prior, *i, sign), - ScalarValue::IntervalYearMonth(Some(i)) => shift_months(prior, *i * sign), - ScalarValue::IntervalMonthDayNano(Some(i)) => add_m_d_nano(prior, *i, sign), - other => Err(DataFusionError::Execution(format!( - "DateIntervalExpr does not support non-interval type {:?}", - other - )))?, - }) -} - -// Can remove once https://github.com/apache/arrow-rs/pull/2031 is released -fn add_m_d_nano(prior: D, interval: i128, sign: i32) -> D -where - D: Datelike + Add, -{ - let interval = interval as u128; - let nanos = (interval >> 64) as i64 * sign as i64; - let days = (interval >> 32) as i32 * sign; - let months = interval as i32 * sign; - let a = shift_months(prior, months); - let b = a.add(Duration::days(days as i64)); - b.add(Duration::nanoseconds(nanos)) -} - -// Can remove once https://github.com/apache/arrow-rs/pull/2031 is released -fn add_day_time(prior: D, interval: i64, sign: i32) -> D -where - D: Datelike + Add, -{ - let interval = interval as u64; - let days = (interval >> 32) as i32 * sign; - let ms = interval as i32 * sign; - let intermediate = prior.add(Duration::days(days as i64)); - intermediate.add(Duration::milliseconds(ms as i64)) -} - #[cfg(test)] mod tests { use super::*; @@ -353,8 +225,11 @@ mod tests { use crate::execution_props::ExecutionProps; use arrow::array::{ArrayRef, Date32Builder}; use arrow::datatypes::*; + use chrono::{Duration, NaiveDate}; + use datafusion_common::delta::shift_months; use datafusion_common::{Column, Result, ToDFSchema}; use datafusion_expr::Expr; + use std::ops::Add; #[test] fn add_11_months() { @@ -403,8 +278,7 @@ mod tests { // setup let dt = Expr::Literal(ScalarValue::Date32(Some(0))); let op = Operator::Plus; - let interval = create_day_time(1, 0); - let interval = Expr::Literal(ScalarValue::IntervalDayTime(Some(interval))); + let interval = Expr::Literal(ScalarValue::new_interval_dt(1, 0)); // exercise let res = exercise(&dt, op, &interval)?; @@ -454,8 +328,8 @@ mod tests { // setup let dt = Expr::Literal(ScalarValue::Date64(Some(0))); let op = Operator::Plus; - let interval = create_day_time(-15, -24 * 60 * 60 * 1000); - let interval = Expr::Literal(ScalarValue::IntervalDayTime(Some(interval))); + let interval = + Expr::Literal(ScalarValue::new_interval_dt(-15, -24 * 60 * 60 * 1000)); // exercise let res = exercise(&dt, op, &interval)?; @@ -505,10 +379,7 @@ mod tests { // setup let dt = Expr::Literal(ScalarValue::Date32(Some(0))); let op = Operator::Plus; - - let interval = create_month_day_nano(-12, -15, -42); - - let interval = Expr::Literal(ScalarValue::IntervalMonthDayNano(Some(interval))); + let interval = Expr::Literal(ScalarValue::new_interval_mdn(-12, -15, -42)); // exercise let res = exercise(&dt, op, &interval)?; @@ -534,8 +405,7 @@ mod tests { let now_ts_ns = chrono::Utc::now().timestamp_nanos(); let dt = Expr::Literal(ScalarValue::TimestampNanosecond(Some(now_ts_ns), None)); let op = Operator::Plus; - let interval = create_day_time(0, 1); - let interval = Expr::Literal(ScalarValue::IntervalDayTime(Some(interval))); + let interval = Expr::Literal(ScalarValue::new_interval_dt(0, 1)); // exercise let res = exercise(&dt, op, &interval)?; @@ -558,8 +428,7 @@ mod tests { let now_ts_s = chrono::Utc::now().timestamp(); let dt = Expr::Literal(ScalarValue::TimestampSecond(Some(now_ts_s), None)); let op = Operator::Plus; - let interval = create_day_time(0, 2 * 3600 * 1_000); - let interval = Expr::Literal(ScalarValue::IntervalDayTime(Some(interval))); + let interval = Expr::Literal(ScalarValue::new_interval_dt(0, 2 * 3600 * 1_000)); // exercise let res = exercise(&dt, op, &interval)?; @@ -582,8 +451,7 @@ mod tests { let now_ts_s = chrono::Utc::now().timestamp(); let dt = Expr::Literal(ScalarValue::TimestampSecond(Some(now_ts_s), None)); let op = Operator::Minus; - let interval = create_day_time(0, 4 * 3600 * 1_000); - let interval = Expr::Literal(ScalarValue::IntervalDayTime(Some(interval))); + let interval = Expr::Literal(ScalarValue::new_interval_dt(0, 4 * 3600 * 1_000)); // exercise let res = exercise(&dt, op, &interval)?; @@ -606,8 +474,7 @@ mod tests { let now_ts_ns = chrono::Utc::now().timestamp_nanos(); let dt = Expr::Literal(ScalarValue::TimestampNanosecond(Some(now_ts_ns), None)); let op = Operator::Plus; - let interval = create_day_time(8, 0); - let interval = Expr::Literal(ScalarValue::IntervalDayTime(Some(interval))); + let interval = Expr::Literal(ScalarValue::new_interval_dt(8, 0)); // exercise let res = exercise(&dt, op, &interval)?; @@ -630,8 +497,7 @@ mod tests { let now_ts_ns = chrono::Utc::now().timestamp_nanos(); let dt = Expr::Literal(ScalarValue::TimestampNanosecond(Some(now_ts_ns), None)); let op = Operator::Minus; - let interval = create_day_time(16, 0); - let interval = Expr::Literal(ScalarValue::IntervalDayTime(Some(interval))); + let interval = Expr::Literal(ScalarValue::new_interval_dt(16, 0)); // exercise let res = exercise(&dt, op, &interval)?; @@ -660,8 +526,7 @@ mod tests { let props = ExecutionProps::new(); let dt = Expr::Column(Column::from_name("a")); - let interval = create_day_time(26, 0); - let interval = Expr::Literal(ScalarValue::IntervalDayTime(Some(interval))); + let interval = Expr::Literal(ScalarValue::new_interval_dt(26, 0)); let op = Operator::Plus; let lhs = create_physical_expr(&dt, &dfs, &schema, &props)?; @@ -754,29 +619,4 @@ mod tests { let res = cut.evaluate(&batch)?; Ok(res) } - - // Can remove once https://github.com/apache/arrow-rs/pull/2031 is released - - /// Creates an IntervalDayTime given its constituent components - /// - /// https://github.com/apache/arrow-rs/blob/e59b023480437f67e84ba2f827b58f78fd44c3a1/integration-testing/src/lib.rs#L222 - fn create_day_time(days: i32, millis: i32) -> i64 { - let m = millis as u64 & u32::MAX as u64; - let d = (days as u64 & u32::MAX as u64) << 32; - (m | d) as i64 - } - - // Can remove once https://github.com/apache/arrow-rs/pull/2031 is released - /// Creates an IntervalMonthDayNano given its constituent components - /// - /// Source: https://github.com/apache/arrow-rs/blob/e59b023480437f67e84ba2f827b58f78fd44c3a1/integration-testing/src/lib.rs#L340 - /// ((nanoseconds as i128) & 0xFFFFFFFFFFFFFFFF) << 64 - /// | ((days as i128) & 0xFFFFFFFF) << 32 - /// | ((months as i128) & 0xFFFFFFFF); - fn create_month_day_nano(months: i32, days: i32, nanos: i64) -> i128 { - let m = months as u128 & u32::MAX as u128; - let d = (days as u128 & u32::MAX as u128) << 32; - let n = (nanos as u128) << 64; - (m | d | n) as i128 - } } diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 208e6d0b51fb..ffbefbd3fd52 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -23,7 +23,6 @@ mod case; mod cast; mod column; mod datetime; -mod delta; mod get_indexed_field; mod in_list; mod is_not_null; diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index b81c0eaf243b..80cb4d10ce1a 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -29,7 +29,6 @@ use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; use datafusion_common::bisect::bisect; -use datafusion_common::scalar::TryFromValue; use datafusion_common::Result; use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::{Accumulator, WindowFrameBound}; @@ -44,7 +43,7 @@ pub struct AggregateWindowExpr { aggregate: Arc, partition_by: Vec>, order_by: Vec, - window_frame: Option, + window_frame: Option>, } impl AggregateWindowExpr { @@ -53,7 +52,7 @@ impl AggregateWindowExpr { aggregate: Arc, partition_by: &[Arc], order_by: &[PhysicalSortExpr], - window_frame: Option, + window_frame: Option>, ) -> Self { Self { aggregate, @@ -66,7 +65,7 @@ impl AggregateWindowExpr { /// create a new accumulator based on the underlying aggregation function fn create_accumulator(&self) -> Result { let accumulator = self.aggregate.create_accumulator()?; - let window_frame = self.window_frame; + let window_frame = self.window_frame.clone(); let partition_by = self.partition_by().to_vec(); let order_by = self.order_by.to_vec(); let field = self.aggregate.field()?; @@ -144,15 +143,13 @@ fn calculate_index_of_row( range_columns: &[ArrayRef], sort_options: &[SortOptions], idx: usize, - delta: u64, + delta: Option<&ScalarValue>, ) -> Result { let current_row_values = range_columns .iter() .map(|col| ScalarValue::try_from_array(col, idx)) .collect::>>()?; - let end_range = if delta == 0 { - current_row_values - } else { + let end_range = if let Some(delta) = delta { let is_descending: bool = sort_options .first() .ok_or_else(|| DataFusionError::Internal("Array is empty".to_string()))? @@ -163,19 +160,23 @@ fn calculate_index_of_row( .map(|value| { if value.is_null() { return Ok(value.clone()); - }; - let offset = ScalarValue::try_from_value(&value.get_datatype(), delta)?; + } if SEARCH_SIDE == is_descending { // TODO: Handle positive overflows - value.add(&offset) - } else if value.is_unsigned() && value < &offset { - ScalarValue::try_from_value(&value.get_datatype(), 0) + value.add(delta) + } else if value.is_unsigned() && value < delta { + // NOTE: This gets a polymorphic zero without having long coercion code for ScalarValue. + // If we decide to implement a "default" construction mechanism for ScalarValue, + // change the following statement to use that. + value.sub(value) } else { // TODO: Handle negative overflows - value.sub(&offset) + value.sub(delta) } }) .collect::>>()? + } else { + current_row_values }; // `BISECT_SIDE` true means bisect_left, false means bisect_right bisect::(range_columns, &end_range, sort_options) @@ -192,116 +193,118 @@ fn calculate_current_window( ) -> Result<(usize, usize)> { match window_frame.units { WindowFrameUnits::Range => { - let start = match window_frame.start_bound { - // UNBOUNDED PRECEDING - WindowFrameBound::Preceding(None) => Ok(0), - WindowFrameBound::Preceding(Some(n)) => { - calculate_index_of_row::( - range_columns, - sort_options, - idx, - n, - ) + let start = match &window_frame.start_bound { + WindowFrameBound::Preceding(n) => { + if n.is_null() { + // UNBOUNDED PRECEDING + Ok(0) + } else { + calculate_index_of_row::( + range_columns, + sort_options, + idx, + Some(n), + ) + } } WindowFrameBound::CurrentRow => calculate_index_of_row::( range_columns, sort_options, idx, - 0, + None, + ), + WindowFrameBound::Following(n) => calculate_index_of_row::( + range_columns, + sort_options, + idx, + Some(n), ), - WindowFrameBound::Following(Some(n)) => { - calculate_index_of_row::( - range_columns, - sort_options, - idx, - n, - ) - } - // UNBOUNDED FOLLOWING - WindowFrameBound::Following(None) => { - Err(DataFusionError::Internal(format!( - "Frame start cannot be UNBOUNDED FOLLOWING '{:?}'", - window_frame - ))) - } }; - let end = match window_frame.end_bound { - // UNBOUNDED PRECEDING - WindowFrameBound::Preceding(None) => { - Err(DataFusionError::Internal(format!( - "Frame end cannot be UNBOUNDED PRECEDING '{:?}'", - window_frame - ))) - } - WindowFrameBound::Preceding(Some(n)) => { - calculate_index_of_row::( - range_columns, - sort_options, - idx, - n, - ) - } + let end = match &window_frame.end_bound { + WindowFrameBound::Preceding(n) => calculate_index_of_row::( + range_columns, + sort_options, + idx, + Some(n), + ), WindowFrameBound::CurrentRow => calculate_index_of_row::( range_columns, sort_options, idx, - 0, + None, ), - WindowFrameBound::Following(Some(n)) => { - calculate_index_of_row::( - range_columns, - sort_options, - idx, - n, - ) + WindowFrameBound::Following(n) => { + if n.is_null() { + // UNBOUNDED FOLLOWING + Ok(length) + } else { + calculate_index_of_row::( + range_columns, + sort_options, + idx, + Some(n), + ) + } } - // UNBOUNDED FOLLOWING - WindowFrameBound::Following(None) => Ok(length), }; Ok((start?, end?)) } WindowFrameUnits::Rows => { let start = match window_frame.start_bound { // UNBOUNDED PRECEDING - WindowFrameBound::Preceding(None) => Ok(0), - WindowFrameBound::Preceding(Some(n)) => { + WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => Ok(0), + WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => { if idx >= n as usize { Ok(idx - n as usize) } else { Ok(0) } } + WindowFrameBound::Preceding(_) => { + Err(DataFusionError::Internal("Rows should be Uint".to_string())) + } WindowFrameBound::CurrentRow => Ok(idx), - WindowFrameBound::Following(Some(n)) => Ok(min(idx + n as usize, length)), // UNBOUNDED FOLLOWING - WindowFrameBound::Following(None) => { + WindowFrameBound::Following(ScalarValue::UInt64(None)) => { Err(DataFusionError::Internal(format!( "Frame start cannot be UNBOUNDED FOLLOWING '{:?}'", window_frame ))) } + WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => { + Ok(min(idx + n as usize, length)) + } + WindowFrameBound::Following(_) => { + Err(DataFusionError::Internal("Rows should be Uint".to_string())) + } }; let end = match window_frame.end_bound { // UNBOUNDED PRECEDING - WindowFrameBound::Preceding(None) => { + WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => { Err(DataFusionError::Internal(format!( "Frame end cannot be UNBOUNDED PRECEDING '{:?}'", window_frame ))) } - WindowFrameBound::Preceding(Some(n)) => { + WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => { if idx >= n as usize { Ok(idx - n as usize + 1) } else { Ok(0) } } + WindowFrameBound::Preceding(_) => { + Err(DataFusionError::Internal("Rows should be Uint".to_string())) + } WindowFrameBound::CurrentRow => Ok(idx + 1), - WindowFrameBound::Following(Some(n)) => { + // UNBOUNDED FOLLOWING + WindowFrameBound::Following(ScalarValue::UInt64(None)) => Ok(length), + WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => { Ok(min(idx + n as usize + 1, length)) } - // UNBOUNDED FOLLOWING - WindowFrameBound::Following(None) => Ok(length), + WindowFrameBound::Following(_) => { + Err(DataFusionError::Internal("Rows should be Uint".to_string())) + } }; Ok((start?, end?)) } @@ -317,23 +320,15 @@ fn calculate_current_window( #[derive(Debug)] struct AggregateWindowAccumulator { accumulator: Box, - window_frame: Option, + window_frame: Option>, partition_by: Vec>, order_by: Vec, field: Field, } impl AggregateWindowAccumulator { - /// This function constructs a simple window frame with a single ORDER BY. - fn implicit_order_by_window() -> WindowFrame { - WindowFrame { - units: WindowFrameUnits::Range, - start_bound: WindowFrameBound::Preceding(None), - end_bound: WindowFrameBound::Following(Some(0)), - } - } /// This function calculates the aggregation on all rows in `value_slice`. - /// Returns an array of size `len`. + /// Returns an array of size `length`. fn calculate_whole_table( &mut self, value_slice: &[ArrayRef], @@ -433,20 +428,23 @@ impl AggregateWindowAccumulator { .map(|v| v.slice(value_range.start, length)) .collect::>(); let order_columns = &order_bys[self.partition_by.len()..order_bys.len()].to_vec(); - match (order_columns.len(), self.window_frame) { - (0, None) => { + match (&order_columns[..], &self.window_frame) { + ([], None) => { // OVER () case self.calculate_whole_table(&value_slice, length) } - (_n, None) => { + ([column, ..], None) => { // OVER (ORDER BY a) case // We create an implicit window for ORDER BY. - self.window_frame = - Some(AggregateWindowAccumulator::implicit_order_by_window()); - + let empty_bound = ScalarValue::try_from(column.data_type())?; + self.window_frame = Some(Arc::new(WindowFrame { + units: WindowFrameUnits::Range, + start_bound: WindowFrameBound::Preceding(empty_bound), + end_bound: WindowFrameBound::CurrentRow, + })); self.calculate_running_window(&value_slice, order_columns, value_range) } - (0, Some(frame)) => { + ([], Some(frame)) => { match frame.units { WindowFrameUnits::Range => { // OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) case @@ -466,9 +464,7 @@ impl AggregateWindowAccumulator { } } // OVER (ORDER BY a ROWS/RANGE BETWEEN X PRECEDING AND Y FOLLOWING) case - (_n, _) => { - self.calculate_running_window(&value_slice, order_columns, value_range) - } + _ => self.calculate_running_window(&value_slice, order_columns, value_range), } } } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index d61f52ee7bb2..ca547b251cf5 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -638,11 +638,7 @@ enum WindowFrameBoundType { message WindowFrameBound { WindowFrameBoundType window_frame_bound_type = 1; - // "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see https://github.com/tokio-rs/prost/issues/430 and https://github.com/tokio-rs/prost/pull/455) - // this syntax is ugly but is binary compatible with the "optional" keyword (see https://stackoverflow.com/questions/42622015/how-to-define-an-optional-field-in-protobuf-3) - oneof bound_value { - uint64 value = 2; - } + ScalarValue bound_value = 2; } /////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs index 79b477b3e1ef..b7bb5ec8209c 100644 --- a/datafusion/proto/src/from_proto.rs +++ b/datafusion/proto/src/from_proto.rs @@ -1293,12 +1293,12 @@ impl TryFrom for WindowFrameBound { protobuf::WindowFrameBoundType::Preceding => { // FIXME implement bound value parsing // https://github.com/apache/arrow-datafusion/issues/361 - Ok(Self::Preceding(Some(1))) + Ok(Self::Preceding(ScalarValue::UInt64(Some(1)))) } protobuf::WindowFrameBoundType::Following => { // FIXME implement bound value parsing // https://github.com/apache/arrow-datafusion/issues/361 - Ok(Self::Following(Some(1))) + Ok(Self::Following(ScalarValue::UInt64(Some(1)))) } } } diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs index f8dab779b405..6c4fb69e2f7d 100644 --- a/datafusion/proto/src/to_proto.rs +++ b/datafusion/proto/src/to_proto.rs @@ -389,9 +389,11 @@ impl From for protobuf::WindowFrameUnits { } } -impl From for protobuf::WindowFrameBound { - fn from(bound: WindowFrameBound) -> Self { - match bound { +impl TryFrom<&WindowFrameBound> for protobuf::WindowFrameBound { + type Error = Error; + + fn try_from(bound: &WindowFrameBound) -> Result { + Ok(match bound { WindowFrameBound::CurrentRow => Self { window_frame_bound_type: protobuf::WindowFrameBoundType::CurrentRow .into(), @@ -399,25 +401,27 @@ impl From for protobuf::WindowFrameBound { }, WindowFrameBound::Preceding(v) => Self { window_frame_bound_type: protobuf::WindowFrameBoundType::Preceding.into(), - bound_value: v.map(protobuf::window_frame_bound::BoundValue::Value), + bound_value: Some(v.try_into()?), }, WindowFrameBound::Following(v) => Self { window_frame_bound_type: protobuf::WindowFrameBoundType::Following.into(), - bound_value: v.map(protobuf::window_frame_bound::BoundValue::Value), + bound_value: Some(v.try_into()?), }, - } + }) } } -impl From for protobuf::WindowFrame { - fn from(window: WindowFrame) -> Self { - Self { +impl TryFrom<&WindowFrame> for protobuf::WindowFrame { + type Error = Error; + + fn try_from(window: &WindowFrame) -> Result { + Ok(Self { window_frame_units: protobuf::WindowFrameUnits::from(window.units).into(), - start_bound: Some(window.start_bound.into()), + start_bound: Some((&window.start_bound).try_into()?), end_bound: Some(protobuf::window_frame::EndBound::Bound( - window.end_bound.into(), + (&window.end_bound).try_into()?, )), - } + }) } } @@ -528,9 +532,13 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { .iter() .map(|e| e.try_into()) .collect::, _>>()?; - let window_frame = window_frame.map(|window_frame| { - protobuf::window_expr_node::WindowFrame::Frame(window_frame.into()) - }); + + let window_frame = match window_frame { + Some(frame) => Some( + protobuf::window_expr_node::WindowFrame::Frame(frame.try_into()?) + ), + None => None + }; let window_expr = Box::new(protobuf::WindowExprNode { expr: arg_expr, window_function: Some(window_function), diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index 4a88a01e31b0..44f3860f676a 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -40,4 +40,4 @@ unicode_expressions = [] arrow = { version = "25.0.0", default-features = false } datafusion-common = { path = "../common", version = "13.0.0" } datafusion-expr = { path = "../expr", version = "13.0.0" } -sqlparser = "0.25" +sqlparser = "0.26" diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs index 20d3ed2f72be..19404419b26c 100644 --- a/datafusion/sql/src/lib.rs +++ b/datafusion/sql/src/lib.rs @@ -18,7 +18,6 @@ //! This module provides a SQL parser that translates SQL queries into an abstract syntax //! tree (AST), and a SQL query planner that creates a logical plan from the AST. -mod interval; pub mod parser; pub mod planner; mod table_reference; diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 92264f06023f..df1a20339b15 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -17,9 +17,9 @@ //! SQL Query Planner (produces logical plan from SQL AST) -use crate::interval::parse_interval; use crate::parser::{CreateExternalTable, DescribeTable, Statement as DFStatement}; use arrow::datatypes::*; +use datafusion_common::parsers::parse_interval; use datafusion_common::{context, ToDFSchema}; use datafusion_expr::expr_rewriter::normalize_col; use datafusion_expr::expr_rewriter::normalize_col_with_schemas; @@ -55,17 +55,20 @@ use datafusion_expr::expr::{Between, BinaryExpr, Case, GroupingSet, Like}; use datafusion_expr::logical_plan::builder::project_with_alias; use datafusion_expr::logical_plan::{Filter, Subquery}; use datafusion_expr::Expr::Alias; +use sqlparser::ast::TimezoneInfo; use sqlparser::ast::{ BinaryOperator, DataType as SQLDataType, DateTimeField, Expr as SQLExpr, FunctionArg, FunctionArgExpr, Ident, Join, JoinConstraint, JoinOperator, ObjectName, Offset as SQLOffset, Query, Select, SelectItem, SetExpr, SetOperator, ShowCreateObject, ShowStatementFilter, TableAlias, TableFactor, TableWithJoins, - TimezoneInfo, TrimWhereField, UnaryOperator, Value, Values as SQLValues, + TrimWhereField, UnaryOperator, Value, Values as SQLValues, }; use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; use sqlparser::ast::{ObjectType, OrderByExpr, Statement}; use sqlparser::parser::ParserError::ParserError; +use sqlparser::ast::ExactNumberInfo; + use super::{ parser::DFParser, utils::{ @@ -154,6 +157,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { analyze, format: _, describe_alias: _, + .. } => self.explain_statement_to_plan(verbose, analyze, *statement), Statement::Query(query) => self.query_to_plan(*query, &mut HashMap::new()), Statement::ShowVariable { variable } => self.show_variable_to_plan(&variable), @@ -248,6 +252,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if_exists, names, cascade: _, + restrict: _, purge: _, // We don't support cascade and purge for now. // nor do we support multiple object names @@ -1749,7 +1754,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }), SQLExpr::Array(arr) => self.sql_array_literal(arr.elem, schema), - SQLExpr::Interval { value, leading_field, @@ -1763,7 +1767,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { last_field, fractional_seconds_precision, ), - SQLExpr::Identifier(id) => { if id.value.starts_with('@') { // TODO: figure out if ScalarVariables should be insensitive. @@ -2245,6 +2248,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } + SQLExpr::Floor{expr, field: _field} => { + let fun = BuiltinScalarFunction::Floor; + let args = vec![self.sql_expr_to_logical_expr(*expr, schema, ctes)?]; + Ok(Expr::ScalarFunction { fun, args }) + } + + SQLExpr::Ceil{expr, field: _field} => { + let fun = BuiltinScalarFunction::Ceil; + let args = vec![self.sql_expr_to_logical_expr(*expr, schema, ctes)?]; + Ok(Expr::ScalarFunction { fun, args }) + } + SQLExpr::Nested(e) => self.sql_expr_to_logical_expr(*e, schema, ctes), SQLExpr::Exists{ subquery, negated } => self.parse_exists_subquery(&subquery, negated, schema, ctes), @@ -2756,7 +2771,16 @@ pub fn convert_simple_data_type(sql_type: &SQLDataType) -> Result { ))) } } - SQLDataType::Decimal(precision, scale) => make_decimal_type(*precision, *scale), + SQLDataType::Decimal(exact_number_info) => { + let (precision, scale) = match *exact_number_info { + ExactNumberInfo::None => (None, None), + ExactNumberInfo::Precision(precision) => (Some(precision), None), + ExactNumberInfo::PrecisionAndScale(precision, scale) => { + (Some(precision), Some(scale)) + } + }; + make_decimal_type(precision, scale) + } SQLDataType::Bytea => Ok(DataType::Binary), // Explicitly list all other types so that if sqlparser // adds/changes the `SQLDataType` the compiler will tell us on upgrade @@ -2775,6 +2799,11 @@ pub fn convert_simple_data_type(sql_type: &SQLDataType) -> Result { | SQLDataType::Set(_) | SQLDataType::MediumInt(_) | SQLDataType::UnsignedMediumInt(_) + | SQLDataType::Character(_) + | SQLDataType::CharacterVarying(_) + | SQLDataType::CharVarying(_) + | SQLDataType::CharacterLargeObject(_) + | SQLDataType::CharLargeObject(_) | SQLDataType::Clob(_) => Err(DataFusionError::NotImplemented(format!( "Unsupported SQL type {:?}", sql_type @@ -2814,7 +2843,7 @@ fn parse_sql_number(n: &str) -> Result { #[cfg(test)] mod tests { use super::*; - use crate::assert_contains; + use datafusion_common::assert_contains; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; use std::any::Any; @@ -5384,29 +5413,6 @@ mod tests { } } - /// A macro to assert that one string is contained within another with - /// a nice error message if they are not. - /// - /// Usage: `assert_contains!(actual, expected)` - /// - /// Is a macro so test error - /// messages are on the same line as the failure; - /// - /// Both arguments must be convertable into Strings (Into) - #[macro_export] - macro_rules! assert_contains { - ($ACTUAL: expr, $EXPECTED: expr) => { - let actual_value: String = $ACTUAL.into(); - let expected_value: String = $EXPECTED.into(); - assert!( - actual_value.contains(&expected_value), - "Can not find expected in actual.\n\nExpected:\n{}\n\nActual:\n{}", - expected_value, - actual_value - ); - }; - } - struct EmptyTable { table_schema: SchemaRef, } diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index d138897a84ed..550c5ee42b2d 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -195,7 +195,7 @@ where .iter() .map(|e| clone_with_replacement(e, replacement_fn)) .collect::>>()?, - window_frame: *window_frame, + window_frame: window_frame.clone(), }), Expr::AggregateUDF { fun, args, filter } => Ok(Expr::AggregateUDF { fun: fun.clone(), @@ -480,6 +480,16 @@ pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr .iter() .map(|expr| match expr { Expr::WindowFunction { partition_by, .. } => Ok(partition_by), + Expr::Alias(expr, _) => { + // convert &Box to &T + match &**expr { + Expr::WindowFunction { partition_by, .. } => Ok(partition_by), + expr => Err(DataFusionError::Execution(format!( + "Impossibly got non-window expr {:?}", + expr + ))), + } + } expr => Err(DataFusionError::Execution(format!( "Impossibly got non-window expr {:?}", expr