From 6cfa75429bc623a23c8fcee5ec5d68b721b8cbde Mon Sep 17 00:00:00 2001 From: Sam Kleinman Date: Mon, 4 Dec 2023 16:45:03 -0500 Subject: [PATCH] feat(kdl): kdl_match and kdl_select functions (#2145) (#2147) --- Cargo.lock | 35 ++++ crates/sqlexec/Cargo.toml | 1 + crates/sqlexec/src/functions.rs | 192 ++++++++++++++---- .../sqllogictests/functions/json_scan.slt | 3 +- testdata/sqllogictests/functions/kdl.slt | 47 +++++ 5 files changed, 242 insertions(+), 36 deletions(-) create mode 100644 testdata/sqllogictests/functions/kdl.slt diff --git a/Cargo.lock b/Cargo.lock index 509900026..fe38edb07 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3581,6 +3581,17 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "kdl" +version = "5.0.0-alpha.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11eecaf214881ac67e858089a815fa8cad941abcf21113970a6c5bc6e6e557c8" +dependencies = [ + "miette", + "nom", + "thiserror", +] + [[package]] name = "keyed_priority_queue" version = "0.4.1" @@ -4131,6 +4142,29 @@ dependencies = [ "uuid", ] +[[package]] +name = "miette" +version = "5.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59bb584eaeeab6bd0226ccf3509a69d7936d148cf3d036ad350abe35e8c6856e" +dependencies = [ + "miette-derive", + "once_cell", + "thiserror", + "unicode-width", +] + +[[package]] +name = "miette-derive" +version = "5.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49e7bc1560b95a3c4a25d03de42fe76ca718ab92d1a22a55b9b4cf67b3ae635c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.38", +] + [[package]] name = "mime" version = "0.3.17" @@ -6803,6 +6837,7 @@ dependencies = [ "datasources", "futures", "itertools 0.12.0", + "kdl", "logutil", "metastore", "num_cpus", diff --git a/crates/sqlexec/Cargo.toml b/crates/sqlexec/Cargo.toml index 926c4b76f..59b277ee8 100644 --- a/crates/sqlexec/Cargo.toml +++ b/crates/sqlexec/Cargo.toml @@ -36,6 +36,7 @@ tokio-postgres = "0.7.8" once_cell = "1.18.0" url.workspace = true parking_lot = "0.12.1" +kdl = "5.0.0-alpha.1" serde = { workspace = true } itertools = "0.12.0" reqwest = { version = "0.11.22", default-features = false, features = ["json"] } diff --git a/crates/sqlexec/src/functions.rs b/crates/sqlexec/src/functions.rs index ac1ee29ac..5e511ff36 100644 --- a/crates/sqlexec/src/functions.rs +++ b/crates/sqlexec/src/functions.rs @@ -4,6 +4,7 @@ use datafusion::common::ScalarValue; use datafusion::error::Result; use datafusion::logical_expr::{ColumnarValue, ScalarUDF, Signature, TypeSignature, Volatility}; use datafusion::prelude::Expr; +use kdl::{KdlDocument, KdlNode, KdlQuery}; use sqlbuiltins::builtins::POSTGRES_SCHEMA; use std::str::FromStr; use std::sync::Arc; @@ -30,12 +31,49 @@ pub enum BuiltinScalarFunction { /// postgres functions /// All of these functions are in the `pg_catalog` schema. Pg(BuiltinPostgresFunctions), + + // KdlMatches and KdlSelect (kdl_matches and kdl_select) allow for + // accessing KDL documents using the KQL (a CSS-inspired selector + // langauge and an analog to XPath) language. Matches is a + // predicate and can be used in `WHERE` statements while Select is + // a projection operator. + KdlMatches, + KdlSelect, +} + +impl BuiltinScalarFunction { + pub fn find_function(name: &str) -> Option { + Self::from_str(name).ok() + } + pub fn into_expr(self, args: Vec) -> Expr { + match self { + Self::ConnectionId => string_var("connection_id"), + Self::Version => string_var("version"), + Self::Pg(pg) => pg.into_expr(args), + Self::KdlMatches => udf_to_expr(kdl_matches(), args), + Self::KdlSelect => udf_to_expr(kdl_select(), args), + } + } +} + +impl FromStr for BuiltinScalarFunction { + type Err = datafusion::common::DataFusionError; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "connection_id" => Ok(Self::ConnectionId), + "version" => Ok(Self::Version), + "kdl_matches" => Ok(Self::KdlMatches), + "kdl_select" => Ok(Self::KdlSelect), + s => BuiltinPostgresFunctions::from_str(s).map(Self::Pg), + } + } } #[derive(Debug, Copy, Clone)] #[cfg_attr(test, derive(PartialEq))] pub enum BuiltinPostgresFunctions { - /// SQL function `pg_get_userbyid` + /// SQL function `pg_userbyid` /// /// `pg_get_userbyid(userid int)` -> `String` /// ```sql @@ -43,7 +81,7 @@ pub enum BuiltinPostgresFunctions { /// ``` GetUserById, /// SQL function `pg_table_is_visible` - /// + /// /// `pg_table_is_visible(table_oid int)` -> `Boolean` /// ```sql /// select pg_table_is_visible(1); @@ -69,7 +107,7 @@ pub enum BuiltinPostgresFunctions { /// ```sql HasSchemaPrivilege, /// SQL function `has_database_privilege` - /// + /// /// `has_database_privilege(user_name text, database_name text, privilege text) -> Boolean` /// ```sql /// select has_database_privilege('foo', 'bar', 'baz'); @@ -133,6 +171,12 @@ pub enum BuiltinPostgresFunctions { CurrentCatalog, } +impl From for BuiltinScalarFunction { + fn from(f: BuiltinPostgresFunctions) -> Self { + Self::Pg(f) + } +} + impl BuiltinPostgresFunctions { fn into_expr(self, args: Vec) -> Expr { match self { @@ -170,25 +214,6 @@ impl BuiltinPostgresFunctions { } } -impl BuiltinScalarFunction { - pub fn find_function(name: &str) -> Option { - Self::from_str(name).ok() - } - pub fn into_expr(self, args: Vec) -> Expr { - match self { - Self::ConnectionId => string_var("connection_id"), - Self::Version => string_var("version"), - Self::Pg(pg) => pg.into_expr(args), - } - } -} - -impl From for BuiltinScalarFunction { - fn from(f: BuiltinPostgresFunctions) -> Self { - Self::Pg(f) - } -} - impl FromStr for BuiltinPostgresFunctions { type Err = datafusion::common::DataFusionError; @@ -230,18 +255,6 @@ impl FromStr for BuiltinPostgresFunctions { } } -impl FromStr for BuiltinScalarFunction { - type Err = datafusion::common::DataFusionError; - - fn from_str(s: &str) -> Result { - match s.to_lowercase().as_str() { - "connection_id" => Ok(Self::ConnectionId), - "version" => Ok(Self::Version), - s => BuiltinPostgresFunctions::from_str(s).map(Self::Pg), - } - } -} - fn udf_to_expr(udf: ScalarUDF, args: Vec) -> Expr { Expr::ScalarUDF(datafusion::logical_expr::expr::ScalarUDF::new( udf.into(), @@ -249,6 +262,116 @@ fn udf_to_expr(udf: ScalarUDF, args: Vec) -> Expr { )) } +fn kdl_matches() -> ScalarUDF { + ScalarUDF { + name: "kdl_matches".to_string(), + signature: Signature::new( + TypeSignature::OneOf(vec![ + TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]), + TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]), + TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]), + TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]), + ]), + Volatility::Immutable, + ), + return_type: Arc::new(|_| Ok(Arc::new(DataType::Boolean))), + fun: Arc::new(move |input| { + let (doc, filter) = kdl_parse_udf_args(input)?; + + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some( + doc.query(filter) + .map_err(|e| datafusion::common::DataFusionError::Execution(e.to_string())) + .map(|val| val.is_some())?, + )))) + }), + } +} + +fn kdl_select() -> ScalarUDF { + ScalarUDF { + name: "kdl_select".to_string(), + signature: Signature::new( + // args: , + TypeSignature::OneOf(vec![ + TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]), + TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]), + TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]), + TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]), + ]), + Volatility::Immutable, + ), + return_type: Arc::new(|_| Ok(Arc::new(DataType::Utf8))), + fun: Arc::new(move |input| { + let (sdoc, filter) = kdl_parse_udf_args(input)?; + + let out: Vec<&KdlNode> = sdoc + .query_all(filter) + .map_err(|e| datafusion::common::DataFusionError::Execution(e.to_string())) + .map(|iter| iter.collect())?; + + let mut doc = sdoc.clone(); + let elems = doc.nodes_mut(); + elems.clear(); + for item in &out { + elems.push(item.to_owned().clone()) + } + + // TODO: consider if we should always return LargeUtf8? + // could end up with truncation (or an error) the document + // is too long and we write the data to a table that is + // established (and mostly) shorter values. + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + doc.to_string(), + )))) + }), + } +} + +fn kdl_parse_udf_args(args: &[ColumnarValue]) -> Result<(KdlDocument, KdlQuery)> { + // parse the filter first, because it's probably shorter and + // erroring earlier would be preferable to parsing a large that we + // don't need/want. + let filter: kdl::KdlQuery = match get_nth_scalar_value(args, 1) { + Some(ScalarValue::Utf8(Some(val))) | Some(ScalarValue::LargeUtf8(Some(val))) => { + val.parse().map_err(|err: kdl::KdlError| { + datafusion::common::DataFusionError::Execution(err.to_string()) + })? + } + Some(val) => { + return Err(datafusion::common::DataFusionError::Execution(format!( + "invalid type for KQL expression {}", + val.data_type(), + ))) + } + None => { + return Err(datafusion::common::DataFusionError::Execution( + "missing KQL query".to_string(), + )) + } + }; + + let doc: kdl::KdlDocument = match get_nth_scalar_value(args, 0) { + Some(ScalarValue::Utf8(Some(val))) | Some(ScalarValue::LargeUtf8(Some(val))) => { + val.parse().map_err(|err: kdl::KdlError| { + datafusion::common::DataFusionError::Execution(err.to_string()) + })? + } + Some(val) => { + return Err(datafusion::common::DataFusionError::Execution(format!( + "invalid type for KDL value {}", + val.data_type(), + ))) + } + None => { + return Err(datafusion::common::DataFusionError::Execution( + "invalid field for KDL".to_string(), + )) + } + }; + + Ok((doc, filter)) +} + fn pg_get_userbyid() -> ScalarUDF { ScalarUDF { name: "pg_get_userbyid".to_string(), @@ -391,6 +514,7 @@ mod tests { ("connection_id", ConnectionId), ("current_schemas", CurrentSchemas.into()), ("current_catalog", CurrentCatalog.into()), + ("kdl_matches", KdlMatches), ("pg_get_userbyid", GetUserById.into()), ("pg_table_is_visible", TableIsVisible.into()), ("pg_encoding_to_char", EncodingToChar.into()), diff --git a/testdata/sqllogictests/functions/json_scan.slt b/testdata/sqllogictests/functions/json_scan.slt index d5f9f4ead..104388f53 100644 --- a/testdata/sqllogictests/functions/json_scan.slt +++ b/testdata/sqllogictests/functions/json_scan.slt @@ -40,7 +40,7 @@ select * from ndjson_scan([]); # Glob patterns not supported on HTTP -statement error Unexpected status code '404 Not Found' +statement error Unexpected status code '404 Not Found' select * from ndjson_scan( 'https://raw.githubusercontent.com/GlareDB/glaredb/main/testdata/sqllogictests_datasources_common/data/*.ndjson' ); @@ -49,4 +49,3 @@ statement error Note that globbing is not supported for HTTP. select * from ndjson_scan( 'https://raw.githubusercontent.com/GlareDB/glaredb/main/testdata/sqllogictests_datasources_common/data/*.ndjson' ); - \ No newline at end of file diff --git a/testdata/sqllogictests/functions/kdl.slt b/testdata/sqllogictests/functions/kdl.slt new file mode 100644 index 000000000..38c4bf03b --- /dev/null +++ b/testdata/sqllogictests/functions/kdl.slt @@ -0,0 +1,47 @@ +statement ok +create table cows (id int, docs string); + +statement ok + +insert into cows values (0, 'name "betsy" status="amazing" age=89; details "classic" a=1 b=100;'); +insert into cows values (1, 'name "zabby" status="electric" age=120; details "zesty" a=1 b=400;'); + +statement error +select * from cows where kdl_matches('name == betsy;', cows.docs); + +statement error +select * from cows where kdl_matches(cows.docs, '+;'); + +statement error +select kdl_select('foo', cows.docs) from cows where kdl_matches(cows.docs, '[b=100]'); + +statement error +select kdl_select(cows.docs, '?$*;') from cows where kdl_matches(cows.docs, '[b=100]'); + +statement error +select count(*) from cows where kdl_matches(cows.docs, 'details[a] == 100;'); + +query I +select count(*) from cows where kdl_matches(cows.docs, '[a]'); +---- +2 + +query I +select count(*) from cows where kdl_matches(cows.docs, '[b=100]'); +---- +1 + +query I +select count(*) from cows where kdl_matches(cows.docs, '[status]'); +---- +2 + +query +select cows.docs from cows where kdl_matches(cows.docs, '[b=100]'); +---- +name "betsy" status="amazing" age=89; details "classic" a=1 b=100; + +query +select kdl_select(cows.docs, '[age=120]') from cows where kdl_matches(cows.docs, '[b=100]'); +---- +(empty) \ No newline at end of file