Skip to content

Commit

Permalink
feat(kdl): kdl_match and kdl_select functions (#2145) (#2147)
Browse files Browse the repository at this point in the history
  • Loading branch information
tychoish authored Dec 4, 2023
1 parent 777dfd9 commit 6cfa754
Show file tree
Hide file tree
Showing 5 changed files with 242 additions and 36 deletions.
35 changes: 35 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/sqlexec/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ tokio-postgres = "0.7.8"
once_cell = "1.18.0"
url.workspace = true
parking_lot = "0.12.1"
kdl = "5.0.0-alpha.1"
serde = { workspace = true }
itertools = "0.12.0"
reqwest = { version = "0.11.22", default-features = false, features = ["json"] }
Expand Down
192 changes: 158 additions & 34 deletions crates/sqlexec/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use datafusion::common::ScalarValue;
use datafusion::error::Result;
use datafusion::logical_expr::{ColumnarValue, ScalarUDF, Signature, TypeSignature, Volatility};
use datafusion::prelude::Expr;
use kdl::{KdlDocument, KdlNode, KdlQuery};
use sqlbuiltins::builtins::POSTGRES_SCHEMA;
use std::str::FromStr;
use std::sync::Arc;
Expand All @@ -30,20 +31,57 @@ pub enum BuiltinScalarFunction {
/// postgres functions
/// All of these functions are in the `pg_catalog` schema.
Pg(BuiltinPostgresFunctions),

// KdlMatches and KdlSelect (`kdl_matches` and `kdl_select`) allow
// for accessing KDL documents using KQL (a CSS-inspired selector
// language, analogous to XPath). Matches is a predicate and can
// be used in `WHERE` clauses, while Select is a projection
// operator.
KdlMatches,
KdlSelect,
}

impl BuiltinScalarFunction {
pub fn find_function(name: &str) -> Option<Self> {
Self::from_str(name).ok()
}
pub fn into_expr(self, args: Vec<Expr>) -> Expr {
match self {
Self::ConnectionId => string_var("connection_id"),
Self::Version => string_var("version"),
Self::Pg(pg) => pg.into_expr(args),
Self::KdlMatches => udf_to_expr(kdl_matches(), args),
Self::KdlSelect => udf_to_expr(kdl_select(), args),
}
}
}

impl FromStr for BuiltinScalarFunction {
    type Err = datafusion::common::DataFusionError;

    /// Resolves a function name to a builtin, ignoring case.
    ///
    /// The name is lowercased first. Names not handled here are passed
    /// (in lowercased form) to `BuiltinPostgresFunctions::from_str`, and
    /// its result is wrapped in `Self::Pg` or its error returned.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let lowered = s.to_lowercase();
        let found = match lowered.as_str() {
            "connection_id" => Self::ConnectionId,
            "version" => Self::Version,
            "kdl_matches" => Self::KdlMatches,
            "kdl_select" => Self::KdlSelect,
            other => return BuiltinPostgresFunctions::from_str(other).map(Self::Pg),
        };
        Ok(found)
    }
}

#[derive(Debug, Copy, Clone)]
#[cfg_attr(test, derive(PartialEq))]
pub enum BuiltinPostgresFunctions {
/// SQL function `pg_get_userbyid`
/// SQL function `pg_userbyid`
///
/// `pg_get_userbyid(userid int)` -> `String`
/// ```sql
/// select pg_get_userbyid(1);
/// ```
GetUserById,
/// SQL function `pg_table_is_visible`
///
///
/// `pg_table_is_visible(table_oid int)` -> `Boolean`
/// ```sql
/// select pg_table_is_visible(1);
Expand All @@ -69,7 +107,7 @@ pub enum BuiltinPostgresFunctions {
/// ```sql
HasSchemaPrivilege,
/// SQL function `has_database_privilege`
///
///
/// `has_database_privilege(user_name text, database_name text, privilege text) -> Boolean`
/// ```sql
/// select has_database_privilege('foo', 'bar', 'baz');
Expand Down Expand Up @@ -133,6 +171,12 @@ pub enum BuiltinPostgresFunctions {
CurrentCatalog,
}

impl From<BuiltinPostgresFunctions> for BuiltinScalarFunction {
fn from(f: BuiltinPostgresFunctions) -> Self {
Self::Pg(f)
}
}

impl BuiltinPostgresFunctions {
fn into_expr(self, args: Vec<Expr>) -> Expr {
match self {
Expand Down Expand Up @@ -170,25 +214,6 @@ impl BuiltinPostgresFunctions {
}
}

impl BuiltinScalarFunction {
pub fn find_function(name: &str) -> Option<Self> {
Self::from_str(name).ok()
}
pub fn into_expr(self, args: Vec<Expr>) -> Expr {
match self {
Self::ConnectionId => string_var("connection_id"),
Self::Version => string_var("version"),
Self::Pg(pg) => pg.into_expr(args),
}
}
}

impl From<BuiltinPostgresFunctions> for BuiltinScalarFunction {
fn from(f: BuiltinPostgresFunctions) -> Self {
Self::Pg(f)
}
}

impl FromStr for BuiltinPostgresFunctions {
type Err = datafusion::common::DataFusionError;

Expand Down Expand Up @@ -230,25 +255,123 @@ impl FromStr for BuiltinPostgresFunctions {
}
}

impl FromStr for BuiltinScalarFunction {
    type Err = datafusion::common::DataFusionError;

    /// Resolves a function name to a builtin, ignoring case.
    ///
    /// The name is lowercased first. Names not handled here are passed
    /// (in lowercased form) to `BuiltinPostgresFunctions::from_str`, and
    /// its result is wrapped in `Self::Pg` or its error returned.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let lowered = s.to_lowercase();
        let found = match lowered.as_str() {
            "connection_id" => Self::ConnectionId,
            "version" => Self::Version,
            other => return BuiltinPostgresFunctions::from_str(other).map(Self::Pg),
        };
        Ok(found)
    }
}

fn udf_to_expr(udf: ScalarUDF, args: Vec<Expr>) -> Expr {
Expr::ScalarUDF(datafusion::logical_expr::expr::ScalarUDF::new(
udf.into(),
args,
))
}

/// Builds the `kdl_matches(<FIELD>, <QUERY>)` scalar UDF.
///
/// Both arguments are strings (any combination of `Utf8` and
/// `LargeUtf8`). The first is parsed as a KDL document and the second as
/// a KQL query; the function returns a `Boolean` indicating whether
/// running the query against the document produced a match.
///
/// Parse failures and query errors are reported as
/// `DataFusionError::Execution`.
///
/// The result is always returned as `ColumnarValue::Scalar`.
/// NOTE(review): confirm how `get_nth_scalar_value` (used by
/// `kdl_parse_udf_args`) treats array/column inputs.
fn kdl_matches() -> ScalarUDF {
    // args: <FIELD>, <QUERY>
    let string_pairs = vec![
        TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
        TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]),
        TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
        TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
    ];

    ScalarUDF {
        name: "kdl_matches".to_string(),
        signature: Signature::new(TypeSignature::OneOf(string_pairs), Volatility::Immutable),
        return_type: Arc::new(|_| Ok(Arc::new(DataType::Boolean))),
        fun: Arc::new(move |input| {
            let (doc, filter) = kdl_parse_udf_args(input)?;

            let matched = doc
                .query(filter)
                .map_err(|e| datafusion::common::DataFusionError::Execution(e.to_string()))?
                .is_some();

            Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(matched))))
        }),
    }
}

/// Builds the `kdl_select(<FIELD>, <QUERY>)` scalar UDF.
///
/// Both arguments are strings (any combination of `Utf8` and
/// `LargeUtf8`). The first is parsed as a KDL document and the second as
/// a KQL query. The function returns, as a `Utf8` string, a KDL document
/// whose nodes are exactly the nodes selected by the query.
///
/// Parse failures and query errors are reported as
/// `DataFusionError::Execution`.
///
/// The result is always returned as `ColumnarValue::Scalar`.
/// NOTE(review): confirm how `get_nth_scalar_value` (used by
/// `kdl_parse_udf_args`) treats array/column inputs.
fn kdl_select() -> ScalarUDF {
    ScalarUDF {
        name: "kdl_select".to_string(),
        signature: Signature::new(
            // args: <FIELD>, <QUERY>
            TypeSignature::OneOf(vec![
                TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
                TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]),
                TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
                TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
            ]),
            Volatility::Immutable,
        ),
        return_type: Arc::new(|_| Ok(Arc::new(DataType::Utf8))),
        fun: Arc::new(move |input| {
            let (sdoc, filter) = kdl_parse_udf_args(input)?;

            // Run the query first so an invalid query fails before any
            // copying is done.
            let selected = sdoc
                .query_all(filter)
                .map_err(|e| datafusion::common::DataFusionError::Execution(e.to_string()))?;

            // Start from a copy of the source document (rather than an
            // empty one) so that everything except its node list is
            // carried over, then replace the nodes with the selection.
            let mut doc = sdoc.clone();
            let nodes = doc.nodes_mut();
            nodes.clear();
            nodes.extend(selected.cloned());

            // TODO: consider whether we should always return LargeUtf8.
            // If the resulting document is too long, we could end up
            // with truncation (or an error) when the data is written to
            // a table that was established with (mostly) shorter values.
            Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(
                doc.to_string(),
            ))))
        }),
    }
}

fn kdl_parse_udf_args(args: &[ColumnarValue]) -> Result<(KdlDocument, KdlQuery)> {
// parse the filter first, because it's probably shorter and
// erroring earlier would be preferable to parsing a large that we
// don't need/want.
let filter: kdl::KdlQuery = match get_nth_scalar_value(args, 1) {
Some(ScalarValue::Utf8(Some(val))) | Some(ScalarValue::LargeUtf8(Some(val))) => {
val.parse().map_err(|err: kdl::KdlError| {
datafusion::common::DataFusionError::Execution(err.to_string())
})?
}
Some(val) => {
return Err(datafusion::common::DataFusionError::Execution(format!(
"invalid type for KQL expression {}",
val.data_type(),
)))
}
None => {
return Err(datafusion::common::DataFusionError::Execution(
"missing KQL query".to_string(),
))
}
};

let doc: kdl::KdlDocument = match get_nth_scalar_value(args, 0) {
Some(ScalarValue::Utf8(Some(val))) | Some(ScalarValue::LargeUtf8(Some(val))) => {
val.parse().map_err(|err: kdl::KdlError| {
datafusion::common::DataFusionError::Execution(err.to_string())
})?
}
Some(val) => {
return Err(datafusion::common::DataFusionError::Execution(format!(
"invalid type for KDL value {}",
val.data_type(),
)))
}
None => {
return Err(datafusion::common::DataFusionError::Execution(
"invalid field for KDL".to_string(),
))
}
};

Ok((doc, filter))
}

fn pg_get_userbyid() -> ScalarUDF {
ScalarUDF {
name: "pg_get_userbyid".to_string(),
Expand Down Expand Up @@ -391,6 +514,7 @@ mod tests {
("connection_id", ConnectionId),
("current_schemas", CurrentSchemas.into()),
("current_catalog", CurrentCatalog.into()),
("kdl_matches", KdlMatches),
("pg_get_userbyid", GetUserById.into()),
("pg_table_is_visible", TableIsVisible.into()),
("pg_encoding_to_char", EncodingToChar.into()),
Expand Down
3 changes: 1 addition & 2 deletions testdata/sqllogictests/functions/json_scan.slt
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ select * from ndjson_scan([]);

# Glob patterns not supported on HTTP

statement error Unexpected status code '404 Not Found'
statement error Unexpected status code '404 Not Found'
select * from ndjson_scan(
'https://raw.githubusercontent.com/GlareDB/glaredb/main/testdata/sqllogictests_datasources_common/data/*.ndjson'
);
Expand All @@ -49,4 +49,3 @@ statement error Note that globbing is not supported for HTTP.
select * from ndjson_scan(
'https://raw.githubusercontent.com/GlareDB/glaredb/main/testdata/sqllogictests_datasources_common/data/*.ndjson'
);

47 changes: 47 additions & 0 deletions testdata/sqllogictests/functions/kdl.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
statement ok
create table cows (id int, docs string);

statement ok

insert into cows values (0, 'name "betsy" status="amazing" age=89; details "classic" a=1 b=100;');
insert into cows values (1, 'name "zabby" status="electric" age=120; details "zesty" a=1 b=400;');

statement error
select * from cows where kdl_matches('name == betsy;', cows.docs);

statement error
select * from cows where kdl_matches(cows.docs, '+;');

statement error
select kdl_select('foo', cows.docs) from cows where kdl_matches(cows.docs, '[b=100]');

statement error
select kdl_select(cows.docs, '?$*;') from cows where kdl_matches(cows.docs, '[b=100]');

statement error
select count(*) from cows where kdl_matches(cows.docs, 'details[a] == 100;');

query I
select count(*) from cows where kdl_matches(cows.docs, '[a]');
----
2

query I
select count(*) from cows where kdl_matches(cows.docs, '[b=100]');
----
1

query I
select count(*) from cows where kdl_matches(cows.docs, '[status]');
----
2

query
select cows.docs from cows where kdl_matches(cows.docs, '[b=100]');
----
name "betsy" status="amazing" age=89; details "classic" a=1 b=100;

query
select kdl_select(cows.docs, '[age=120]') from cows where kdl_matches(cows.docs, '[b=100]');
----
(empty)

0 comments on commit 6cfa754

Please sign in to comment.