Skip to content

Commit

Permalink
fix: merge operation with string predicates (delta-io#1705)
Browse files Browse the repository at this point in the history
# Description
Fixes an issue when users use string predicates with the merge
operation.

Parsing a string predicate did not properly handle table references and
would always assume a bare table with a table name of the empty string.
Now the qualifier is `None` however a `DFSchema` with qualifiers can be
supplied where it makes sense.

Now users must provide source and target aliases whenever both sides
share a column name otherwise the operation will error out.

Minor refactoring of the expression parser was also done and allowed
using of case expressions.


# Related Issue(s)
- closes delta-io#1699

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
  • Loading branch information
Blajda and wjones127 authored Oct 12, 2023
1 parent 3639ac7 commit 04576f4
Show file tree
Hide file tree
Showing 4 changed files with 562 additions and 214 deletions.
121 changes: 99 additions & 22 deletions rust/src/delta_datafusion/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,82 @@

//! Utility functions for Datafusion's Expressions

use std::fmt::{self, Display, Formatter, Write};
use std::{
fmt::{self, Display, Formatter, Write},
sync::Arc,
};

use datafusion_common::ScalarValue;
use arrow_schema::DataType;
use datafusion::execution::context::SessionState;
use datafusion_common::Result as DFResult;
use datafusion_common::{config::ConfigOptions, DFSchema, Result, ScalarValue, TableReference};
use datafusion_expr::{
expr::{InList, ScalarUDF},
Between, BinaryExpr, Expr, Like,
AggregateUDF, Between, BinaryExpr, Cast, Expr, Like, TableSource,
};
use datafusion_sql::planner::{ContextProvider, SqlToRel};
use sqlparser::ast::escape_quoted_string;
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;
use sqlparser::tokenizer::Tokenizer;

use crate::{DeltaResult, DeltaTableError};

pub(crate) struct DeltaContextProvider<'a> {
state: &'a SessionState,
}

impl<'a> ContextProvider for DeltaContextProvider<'a> {
fn get_table_provider(&self, _name: TableReference) -> DFResult<Arc<dyn TableSource>> {
unimplemented!()
}

fn get_function_meta(&self, name: &str) -> Option<Arc<datafusion_expr::ScalarUDF>> {
self.state.scalar_functions().get(name).cloned()
}

fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
self.state.aggregate_functions().get(name).cloned()
}

use crate::DeltaTableError;
fn get_variable_type(&self, _var: &[String]) -> Option<DataType> {
unimplemented!()
}

fn options(&self) -> &ConfigOptions {
self.state.config_options()
}

fn get_window_meta(&self, name: &str) -> Option<Arc<datafusion_expr::WindowUDF>> {
self.state.window_functions().get(name).cloned()
}
}

/// Parse a string predicate into an `Expr`
pub(crate) fn parse_predicate_expression(
schema: &DFSchema,
expr: impl AsRef<str>,
df_state: &SessionState,
) -> DeltaResult<Expr> {
let dialect = &GenericDialect {};
let mut tokenizer = Tokenizer::new(dialect, expr.as_ref());
let tokens = tokenizer
.tokenize()
.map_err(|err| DeltaTableError::GenericError {
source: Box::new(err),
})?;
let sql = Parser::new(dialect)
.with_tokens(tokens)
.parse_expr()
.map_err(|err| DeltaTableError::GenericError {
source: Box::new(err),
})?;

let context_provider = DeltaContextProvider { state: df_state };
let sql_to_rel = SqlToRel::new(&context_provider);

Ok(sql_to_rel.sql_to_expr(sql, schema, &mut Default::default())?)
}

struct SqlFormat<'a> {
expr: &'a Expr,
Expand Down Expand Up @@ -115,6 +181,9 @@ impl<'a> Display for SqlFormat<'a> {
Expr::BinaryExpr(expr) => write!(f, "{}", BinaryExprFormat { expr }),
Expr::ScalarFunction(func) => fmt_function(f, &func.fun.to_string(), false, &func.args),
Expr::ScalarUDF(ScalarUDF { fun, args }) => fmt_function(f, &fun.name, false, args),
Expr::Cast(Cast { expr, data_type }) => {
write!(f, "arrow_cast({}, '{}')", SqlFormat { expr }, data_type)
}
Expr::Between(Between {
expr,
negated,
Expand Down Expand Up @@ -271,9 +340,10 @@ impl<'a> fmt::Display for ScalarValueFormat<'a> {
mod test {
use std::collections::HashMap;

use arrow_schema::DataType;
use datafusion::prelude::SessionContext;
use datafusion_common::{DFSchema, ScalarValue};
use datafusion_expr::{col, decode, lit, substring, Expr, ExprSchemable};
use datafusion_expr::{col, decode, lit, substring, Cast, Expr, ExprSchemable};

use crate::{DeltaOps, DeltaTable, Schema, SchemaDataType, SchemaField};

Expand Down Expand Up @@ -368,6 +438,13 @@ mod test {

// String expression that we output must be parsable for conflict resolution.
let tests = vec![
simple!(
Expr::Cast(Cast {
expr: Box::new(lit(1_i64)),
data_type: DataType::Int32
}),
"arrow_cast(1, 'Int32')".to_string()
),
simple!(col("value").eq(lit(3_i64)), "value = 3".to_string()),
simple!(col("active").is_true(), "active IS TRUE".to_string()),
simple!(col("active"), "active".to_string()),
Expand Down Expand Up @@ -443,6 +520,23 @@ mod test {
substring(col("modified"), lit(0_i64), lit(4_i64)).eq(lit("2021")),
"substr(modified, 0, 4) = '2021'".to_string()
),
simple!(
col("value")
.cast_to::<DFSchema>(
&arrow_schema::DataType::Utf8,
&table
.state
.input_schema()
.unwrap()
.as_ref()
.to_owned()
.try_into()
.unwrap()
)
.unwrap()
.eq(lit("1")),
"arrow_cast(value, 'Utf8') = '1'".to_string()
),
];

let session = SessionContext::new();
Expand Down Expand Up @@ -479,23 +573,6 @@ mod test {
))),
"".to_string()
),
simple!(
col("value")
.cast_to::<DFSchema>(
&arrow_schema::DataType::Utf8,
&table
.state
.input_schema()
.unwrap()
.as_ref()
.to_owned()
.try_into()
.unwrap()
)
.unwrap()
.eq(lit("1")),
"CAST(value as STRING) = '1'".to_string()
),
];

for test in unsupported_types {
Expand Down
Loading

0 comments on commit 04576f4

Please sign in to comment.