Skip to content

Commit

Permalink
feat: add support for Substrait ExtendedExpression (#12728)
Browse files Browse the repository at this point in the history
* Add support for serializing and deserializing Substrait ExtendedExpr message

* Address clippy reviews

* Reuse existing rename method
  • Loading branch information
westonpace authored Oct 7, 2024
1 parent 9d8f77d commit 583bdc2
Show file tree
Hide file tree
Showing 2 changed files with 338 additions and 95 deletions.
203 changes: 151 additions & 52 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use datafusion::logical_expr::{
ExprSchemable, LogicalPlan, Operator, Projection, SortExpr, Values,
};
use substrait::proto::expression::subquery::set_predicate::PredicateOp;
use substrait::proto::expression_reference::ExprType;
use url::Url;

use crate::extensions::Extensions;
Expand Down Expand Up @@ -96,7 +97,7 @@ use substrait::proto::{
sort_field::{SortDirection, SortKind::*},
AggregateFunction, Expression, NamedStruct, Plan, Rel, Type,
};
use substrait::proto::{FunctionArgument, SortField};
use substrait::proto::{ExtendedExpression, FunctionArgument, SortField};

// Substrait PrecisionTimestampTz indicates that the timestamp is relative to UTC, which
// is the same as the expectation for any non-empty timezone in DF, so any non-empty timezone
Expand Down Expand Up @@ -251,6 +252,81 @@ pub async fn from_substrait_plan(
}
}

/// An ExprContainer is a container for a collection of expressions with a common input schema
///
/// In addition, each expression is associated with a field, which defines the
/// expression's output. The data type and nullability of the field are calculated from the
/// expression and the input schema. However the names of the field (and its nested fields) are
/// derived from the Substrait message.
pub struct ExprContainer {
/// The input schema for the expressions
pub input_schema: DFSchemaRef,
/// The expressions
///
/// Each item contains an expression and the field that defines the expected nullability and name of the expr's output
pub exprs: Vec<(Expr, Field)>,
}

/// Convert Substrait ExtendedExpression to ExprContainer
///
/// A Substrait ExtendedExpression message contains one or more expressions,
/// with names for the outputs, and an input schema. These pieces are all included
/// in the ExprContainer.
///
/// This is a top-level message and can be used to send expressions (not plans)
/// between systems. This is often useful for scenarios like pushdown where filter
/// expressions need to be sent to remote systems.
pub async fn from_substrait_extended_expr(
ctx: &SessionContext,
extended_expr: &ExtendedExpression,
) -> Result<ExprContainer> {
// Register function extension
let extensions = Extensions::try_from(&extended_expr.extensions)?;
if !extensions.type_variations.is_empty() {
return not_impl_err!("Type variation extensions are not supported");
}

let input_schema = DFSchemaRef::new(match &extended_expr.base_schema {
Some(base_schema) => from_substrait_named_struct(base_schema, &extensions),
None => {
plan_err!("required property `base_schema` missing from Substrait ExtendedExpression message")
}
}?);

// Parse expressions
let mut exprs = Vec::with_capacity(extended_expr.referred_expr.len());
for (expr_idx, substrait_expr) in extended_expr.referred_expr.iter().enumerate() {
let scalar_expr = match &substrait_expr.expr_type {
Some(ExprType::Expression(scalar_expr)) => Ok(scalar_expr),
Some(ExprType::Measure(_)) => {
not_impl_err!("Measure expressions are not yet supported")
}
None => {
plan_err!("required property `expr_type` missing from Substrait ExpressionReference message")
}
}?;
let expr =
from_substrait_rex(ctx, scalar_expr, &input_schema, &extensions).await?;
let (output_type, expected_nullability) =
expr.data_type_and_nullable(&input_schema)?;
let output_field = Field::new("", output_type, expected_nullability);
let mut names_idx = 0;
let output_field = rename_field(
&output_field,
&substrait_expr.output_names,
expr_idx,
&mut names_idx,
/*rename_self=*/ true,
)?;
exprs.push((expr, output_field));
}

Ok(ExprContainer {
input_schema,
exprs,
})
}

/// parse projection
pub fn extract_projection(
t: LogicalPlan,
Expand Down Expand Up @@ -334,6 +410,68 @@ fn rename_expressions(
.collect()
}

fn rename_field(
field: &Field,
dfs_names: &Vec<String>,
unnamed_field_suffix: usize, // If Substrait doesn't provide a name, we'll use this "c{unnamed_field_suffix}"
name_idx: &mut usize, // Index into dfs_names
rename_self: bool, // Some fields (e.g. list items) don't have names in Substrait and this will be false to keep old name
) -> Result<Field> {
let name = if rename_self {
next_struct_field_name(unnamed_field_suffix, dfs_names, name_idx)?
} else {
field.name().to_string()
};
match field.data_type() {
DataType::Struct(children) => {
let children = children
.iter()
.enumerate()
.map(|(child_idx, f)| {
rename_field(
f.as_ref(),
dfs_names,
child_idx,
name_idx,
/*rename_self=*/ true,
)
})
.collect::<Result<_>>()?;
Ok(field
.to_owned()
.with_name(name)
.with_data_type(DataType::Struct(children)))
}
DataType::List(inner) => {
let renamed_inner = rename_field(
inner.as_ref(),
dfs_names,
0,
name_idx,
/*rename_self=*/ false,
)?;
Ok(field
.to_owned()
.with_data_type(DataType::List(FieldRef::new(renamed_inner)))
.with_name(name))
}
DataType::LargeList(inner) => {
let renamed_inner = rename_field(
inner.as_ref(),
dfs_names,
0,
name_idx,
/*rename_self= */ false,
)?;
Ok(field
.to_owned()
.with_data_type(DataType::LargeList(FieldRef::new(renamed_inner)))
.with_name(name))
}
_ => Ok(field.to_owned().with_name(name)),
}
}

/// Produce a version of the given schema with names matching the given list of names.
/// Substrait doesn't deal with column (incl. nested struct field) names within the schema,
/// but it does give us the list of expected names at the end of the plan, so we use this
Expand All @@ -342,59 +480,20 @@ fn make_renamed_schema(
schema: &DFSchemaRef,
dfs_names: &Vec<String>,
) -> Result<DFSchema> {
fn rename_inner_fields(
dtype: &DataType,
dfs_names: &Vec<String>,
name_idx: &mut usize,
) -> Result<DataType> {
match dtype {
DataType::Struct(fields) => {
let fields = fields
.iter()
.map(|f| {
let name = next_struct_field_name(0, dfs_names, name_idx)?;
Ok((**f).to_owned().with_name(name).with_data_type(
rename_inner_fields(f.data_type(), dfs_names, name_idx)?,
))
})
.collect::<Result<_>>()?;
Ok(DataType::Struct(fields))
}
DataType::List(inner) => Ok(DataType::List(FieldRef::new(
(**inner).to_owned().with_data_type(rename_inner_fields(
inner.data_type(),
dfs_names,
name_idx,
)?),
))),
DataType::LargeList(inner) => Ok(DataType::LargeList(FieldRef::new(
(**inner).to_owned().with_data_type(rename_inner_fields(
inner.data_type(),
dfs_names,
name_idx,
)?),
))),
_ => Ok(dtype.to_owned()),
}
}

let mut name_idx = 0;

let (qualifiers, fields): (_, Vec<Field>) = schema
.iter()
.map(|(q, f)| {
let name = next_struct_field_name(0, dfs_names, &mut name_idx)?;
Ok((
q.cloned(),
(**f)
.to_owned()
.with_name(name)
.with_data_type(rename_inner_fields(
f.data_type(),
dfs_names,
&mut name_idx,
)?),
))
.enumerate()
.map(|(field_idx, (q, f))| {
let renamed_f = rename_field(
f.as_ref(),
dfs_names,
field_idx,
&mut name_idx,
/*rename_self=*/ true,
)?;
Ok((q.cloned(), renamed_f))
})
.collect::<Result<Vec<_>>>()?
.into_iter()
Expand Down Expand Up @@ -1681,14 +1780,14 @@ fn from_substrait_struct_type(
}

fn next_struct_field_name(
i: usize,
column_idx: usize,
dfs_names: &[String],
name_idx: &mut usize,
) -> Result<String> {
if dfs_names.is_empty() {
// If names are not given, create dummy names
// c0, c1, ... align with e.g. SqlToRel::create_named_struct
Ok(format!("c{i}"))
Ok(format!("c{column_idx}"))
} else {
let name = dfs_names.get(*name_idx).cloned().ok_or_else(|| {
substrait_datafusion_err!("Named schema must contain names for all fields")
Expand Down
Loading

0 comments on commit 583bdc2

Please sign in to comment.