Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for Substrait ExtendedExpression #12728

Merged
merged 3 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 151 additions & 52 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use datafusion::logical_expr::{
ExprSchemable, LogicalPlan, Operator, Projection, SortExpr, Values,
};
use substrait::proto::expression::subquery::set_predicate::PredicateOp;
use substrait::proto::expression_reference::ExprType;
use url::Url;

use crate::extensions::Extensions;
Expand Down Expand Up @@ -96,7 +97,7 @@ use substrait::proto::{
sort_field::{SortDirection, SortKind::*},
AggregateFunction, Expression, NamedStruct, Plan, Rel, Type,
};
use substrait::proto::{FunctionArgument, SortField};
use substrait::proto::{ExtendedExpression, FunctionArgument, SortField};

// Substrait PrecisionTimestampTz indicates that the timestamp is relative to UTC, which
// is the same as the expectation for any non-empty timezone in DF, so any non-empty timezone
Expand Down Expand Up @@ -251,6 +252,81 @@ pub async fn from_substrait_plan(
}
}

/// An ExprContainer is a container for a collection of expressions with a common input schema
///
/// In addition, each expression is associated with a field, which defines the
/// expression's output. The data type and nullability of the field are calculated from the
/// expression and the input schema. However the names of the field (and its nested fields) are
/// derived from the Substrait message.
///
/// `Debug` and `Clone` are derived so that a container can be logged, inspected in test
/// failures, and duplicated by callers.
#[derive(Debug, Clone)]
pub struct ExprContainer {
    /// The input schema for the expressions
    pub input_schema: DFSchemaRef,
    /// The expressions
    ///
    /// Each item contains an expression and the field that defines the expected nullability and
    /// name of the expr's output
    pub exprs: Vec<(Expr, Field)>,
}

/// Convert Substrait ExtendedExpression to ExprContainer
///
/// A Substrait ExtendedExpression message contains one or more expressions,
/// with names for the outputs, and an input schema. These pieces are all included
/// in the ExprContainer.
///
/// This is a top-level message and can be used to send expressions (not plans)
/// between systems. This is often useful for scenarios like pushdown where filter
/// expressions need to be sent to remote systems.
pub async fn from_substrait_extended_expr(
ctx: &SessionContext,
extended_expr: &ExtendedExpression,
) -> Result<ExprContainer> {
// Register function extension
let extensions = Extensions::try_from(&extended_expr.extensions)?;
if !extensions.type_variations.is_empty() {
return not_impl_err!("Type variation extensions are not supported");
}

let input_schema = DFSchemaRef::new(match &extended_expr.base_schema {
Some(base_schema) => from_substrait_named_struct(base_schema, &extensions),
None => {
plan_err!("required property `base_schema` missing from Substrait ExtendedExpression message")
}
}?);

// Parse expressions
let mut exprs = Vec::with_capacity(extended_expr.referred_expr.len());
for (expr_idx, substrait_expr) in extended_expr.referred_expr.iter().enumerate() {
let scalar_expr = match &substrait_expr.expr_type {
Some(ExprType::Expression(scalar_expr)) => Ok(scalar_expr),
Some(ExprType::Measure(_)) => {
not_impl_err!("Measure expressions are not yet supported")
}
None => {
plan_err!("required property `expr_type` missing from Substrait ExpressionReference message")
}
}?;
let expr =
from_substrait_rex(ctx, scalar_expr, &input_schema, &extensions).await?;
let (output_type, expected_nullability) =
expr.data_type_and_nullable(&input_schema)?;
let output_field = Field::new("", output_type, expected_nullability);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could have just picked the first name here, then started names_idx from 1 instead 😅 I'm overall not a super fan of the "rename_self" option but I see why you did it, I guess it's alright.

let mut names_idx = 0;
let output_field = rename_field(
&output_field,
&substrait_expr.output_names,
expr_idx,
&mut names_idx,
/*rename_self=*/ true,
)?;
exprs.push((expr, output_field));
}

Ok(ExprContainer {
input_schema,
exprs,
})
}

/// parse projection
pub fn extract_projection(
t: LogicalPlan,
Expand Down Expand Up @@ -334,6 +410,68 @@ fn rename_expressions(
.collect()
}

fn rename_field(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice, and thanks for adding the comments!

field: &Field,
dfs_names: &Vec<String>,
unnamed_field_suffix: usize, // If Substrait doesn't provide a name, we'll use this "c{unnamed_field_suffix}"
name_idx: &mut usize, // Index into dfs_names
rename_self: bool, // Some fields (e.g. list items) don't have names in Substrait and this will be false to keep old name
) -> Result<Field> {
let name = if rename_self {
next_struct_field_name(unnamed_field_suffix, dfs_names, name_idx)?
} else {
field.name().to_string()
};
match field.data_type() {
DataType::Struct(children) => {
let children = children
.iter()
.enumerate()
.map(|(child_idx, f)| {
rename_field(
f.as_ref(),
dfs_names,
child_idx,
name_idx,
/*rename_self=*/ true,
)
})
.collect::<Result<_>>()?;
Ok(field
.to_owned()
.with_name(name)
.with_data_type(DataType::Struct(children)))
}
DataType::List(inner) => {
let renamed_inner = rename_field(
inner.as_ref(),
dfs_names,
0,
name_idx,
/*rename_self=*/ false,
)?;
Ok(field
.to_owned()
.with_data_type(DataType::List(FieldRef::new(renamed_inner)))
.with_name(name))
}
DataType::LargeList(inner) => {
let renamed_inner = rename_field(
inner.as_ref(),
dfs_names,
0,
name_idx,
/*rename_self= */ false,
)?;
Ok(field
.to_owned()
.with_data_type(DataType::LargeList(FieldRef::new(renamed_inner)))
.with_name(name))
}
_ => Ok(field.to_owned().with_name(name)),
}
}

/// Produce a version of the given schema with names matching the given list of names.
/// Substrait doesn't deal with column (incl. nested struct field) names within the schema,
/// but it does give us the list of expected names at the end of the plan, so we use this
Expand All @@ -342,59 +480,20 @@ fn make_renamed_schema(
schema: &DFSchemaRef,
dfs_names: &Vec<String>,
) -> Result<DFSchema> {
fn rename_inner_fields(
dtype: &DataType,
dfs_names: &Vec<String>,
name_idx: &mut usize,
) -> Result<DataType> {
match dtype {
DataType::Struct(fields) => {
let fields = fields
.iter()
.map(|f| {
let name = next_struct_field_name(0, dfs_names, name_idx)?;
Ok((**f).to_owned().with_name(name).with_data_type(
rename_inner_fields(f.data_type(), dfs_names, name_idx)?,
))
})
.collect::<Result<_>>()?;
Ok(DataType::Struct(fields))
}
DataType::List(inner) => Ok(DataType::List(FieldRef::new(
(**inner).to_owned().with_data_type(rename_inner_fields(
inner.data_type(),
dfs_names,
name_idx,
)?),
))),
DataType::LargeList(inner) => Ok(DataType::LargeList(FieldRef::new(
(**inner).to_owned().with_data_type(rename_inner_fields(
inner.data_type(),
dfs_names,
name_idx,
)?),
))),
_ => Ok(dtype.to_owned()),
}
}

let mut name_idx = 0;

let (qualifiers, fields): (_, Vec<Field>) = schema
.iter()
.map(|(q, f)| {
let name = next_struct_field_name(0, dfs_names, &mut name_idx)?;
Ok((
q.cloned(),
(**f)
.to_owned()
.with_name(name)
.with_data_type(rename_inner_fields(
f.data_type(),
dfs_names,
&mut name_idx,
)?),
))
.enumerate()
.map(|(field_idx, (q, f))| {
let renamed_f = rename_field(
f.as_ref(),
dfs_names,
field_idx,
&mut name_idx,
/*rename_self=*/ true,
)?;
Ok((q.cloned(), renamed_f))
})
.collect::<Result<Vec<_>>>()?
.into_iter()
Expand Down Expand Up @@ -1681,14 +1780,14 @@ fn from_substrait_struct_type(
}

fn next_struct_field_name(
i: usize,
column_idx: usize,
dfs_names: &[String],
name_idx: &mut usize,
) -> Result<String> {
if dfs_names.is_empty() {
// If names are not given, create dummy names
// c0, c1, ... align with e.g. SqlToRel::create_named_struct
Ok(format!("c{i}"))
Ok(format!("c{column_idx}"))
} else {
let name = dfs_names.get(*name_idx).cloned().ok_or_else(|| {
substrait_datafusion_err!("Named schema must contain names for all fields")
Expand Down
Loading