diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 4267f182bda8..97e4fcc327c3 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -28,7 +28,7 @@ use crate::Operator; use crate::{aggregate_function, ExprSchemable}; use arrow::datatypes::DataType; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{internal_err, DFSchema}; +use datafusion_common::{internal_err, DFSchema, OwnedTableReference}; use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue}; use std::collections::HashSet; use std::fmt; @@ -187,13 +187,20 @@ pub enum Expr { #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct Alias { pub expr: Box, + pub relation: Option, pub name: String, } impl Alias { - pub fn new(expr: Expr, name: impl Into) -> Self { + /// Create an alias with an optional schema/field qualifier. + pub fn new( + expr: Expr, + relation: Option>, + name: impl Into, + ) -> Self { Self { expr: Box::new(expr), + relation: relation.map(|r| r.into()), name: name.into(), } } @@ -844,7 +851,27 @@ impl Expr { asc, nulls_first, }) => Expr::Sort(Sort::new(Box::new(expr.alias(name)), asc, nulls_first)), - _ => Expr::Alias(Alias::new(self, name.into())), + _ => Expr::Alias(Alias::new(self, None::<&str>, name.into())), + } + } + + /// Return `self AS name` alias expression with a specific qualifier + pub fn alias_qualified( + self, + relation: Option>, + name: impl Into, + ) -> Expr { + match self { + Expr::Sort(Sort { + expr, + asc, + nulls_first, + }) => Expr::Sort(Sort::new( + Box::new(expr.alias_qualified(relation, name)), + asc, + nulls_first, + )), + _ => Expr::Alias(Alias::new(self, relation, name.into())), } } diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 2631708fb780..5881feece1fc 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -305,6 +305,13 @@ impl ExprSchemable for Expr { self.nullable(input_schema)?, ) .with_metadata(self.metadata(input_schema)?)), + Expr::Alias(Alias { relation, name, .. }) => Ok(DFField::new( + relation.clone(), + name, + self.get_type(input_schema)?, + self.nullable(input_schema)?, + ) + .with_metadata(self.metadata(input_schema)?)), _ => Ok(DFField::new_unqualified( &self.display_name()?, self.get_type(input_schema)?, diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 4a30f4e223bf..c4ff9fe95435 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -32,8 +32,8 @@ use crate::expr_rewriter::{ rewrite_sort_cols_by_aggs, }; use crate::logical_plan::{ - Aggregate, Analyze, CrossJoin, Distinct, EmptyRelation, Explain, Filter, Join, - JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, + Aggregate, Analyze, CrossJoin, Distinct, DistinctOn, EmptyRelation, Explain, Filter, + Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values, Window, }; @@ -551,16 +551,29 @@ impl LogicalPlanBuilder { let left_plan: LogicalPlan = self.plan; let right_plan: LogicalPlan = plan; - Ok(Self::from(LogicalPlan::Distinct(Distinct { - input: Arc::new(union(left_plan, right_plan)?), - }))) + Ok(Self::from(LogicalPlan::Distinct(Distinct::All(Arc::new( + union(left_plan, right_plan)?, + ))))) } /// Apply deduplication: Only distinct (different) values are returned) pub fn distinct(self) -> Result { - Ok(Self::from(LogicalPlan::Distinct(Distinct { - input: Arc::new(self.plan), - }))) + Ok(Self::from(LogicalPlan::Distinct(Distinct::All(Arc::new( + self.plan, + ))))) + } + + /// Project first values of the specified expression list according to the provided + /// sorting expressions grouped by the `DISTINCT ON` clause expressions. + pub fn distinct_on( + self, + on_expr: Vec, + select_expr: Vec, + sort_expr: Option>, + ) -> Result { + Ok(Self::from(LogicalPlan::Distinct(Distinct::On( + DistinctOn::try_new(on_expr, select_expr, sort_expr, Arc::new(self.plan))?, + )))) } /// Apply a join to `right` using explicitly specified columns and an diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index 8316417138bd..51d78cd721b6 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -33,10 +33,10 @@ pub use ddl::{ }; pub use dml::{DmlStatement, WriteOp}; pub use plan::{ - Aggregate, Analyze, CrossJoin, DescribeTable, Distinct, EmptyRelation, Explain, - Extension, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, - PlanType, Prepare, Projection, Repartition, Sort, StringifiedPlan, Subquery, - SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, + Aggregate, Analyze, CrossJoin, DescribeTable, Distinct, DistinctOn, EmptyRelation, + Explain, Extension, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, + Partitioning, PlanType, Prepare, Projection, Repartition, Sort, StringifiedPlan, + Subquery, SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, }; pub use statement::{ SetVariable, Statement, TransactionAccessMode, TransactionConclusion, TransactionEnd, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index d62ac8926328..b7537dc02e9d 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -25,8 +25,8 @@ use std::sync::Arc; use super::dml::CopyTo; use super::DdlStatement; use crate::dml::CopyOptions; -use crate::expr::{Alias, Exists, InSubquery, Placeholder}; -use crate::expr_rewriter::create_col_from_scalar_expr; +use crate::expr::{Alias, Exists, InSubquery, Placeholder, Sort as SortExpr}; +use crate::expr_rewriter::{create_col_from_scalar_expr, normalize_cols}; use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::logical_plan::{DmlStatement, Statement}; @@ -163,7 +163,8 @@ impl LogicalPlan { }) => projected_schema, LogicalPlan::Projection(Projection { schema, .. }) => schema, LogicalPlan::Filter(Filter { input, .. }) => input.schema(), - LogicalPlan::Distinct(Distinct { input }) => input.schema(), + LogicalPlan::Distinct(Distinct::All(input)) => input.schema(), + LogicalPlan::Distinct(Distinct::On(DistinctOn { schema, .. })) => schema, LogicalPlan::Window(Window { schema, .. }) => schema, LogicalPlan::Aggregate(Aggregate { schema, .. }) => schema, LogicalPlan::Sort(Sort { input, .. }) => input.schema(), @@ -367,6 +368,16 @@ impl LogicalPlan { LogicalPlan::Unnest(Unnest { column, .. }) => { f(&Expr::Column(column.clone())) } + LogicalPlan::Distinct(Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + .. + })) => on_expr + .iter() + .chain(select_expr.iter()) + .chain(sort_expr.clone().unwrap_or(vec![]).iter()) + .try_for_each(f), // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::Subquery(_) @@ -377,7 +388,7 @@ impl LogicalPlan { | LogicalPlan::Analyze(_) | LogicalPlan::Explain(_) | LogicalPlan::Union(_) - | LogicalPlan::Distinct(_) + | LogicalPlan::Distinct(Distinct::All(_)) | LogicalPlan::Dml(_) | LogicalPlan::Ddl(_) | LogicalPlan::Copy(_) @@ -405,7 +416,9 @@ impl LogicalPlan { LogicalPlan::Union(Union { inputs, .. }) => { inputs.iter().map(|arc| arc.as_ref()).collect() } - LogicalPlan::Distinct(Distinct { input }) => vec![input], + LogicalPlan::Distinct( + Distinct::All(input) | Distinct::On(DistinctOn { input, .. }), + ) => vec![input], LogicalPlan::Explain(explain) => vec![&explain.plan], LogicalPlan::Analyze(analyze) => vec![&analyze.input], LogicalPlan::Dml(write) => vec![&write.input], @@ -461,8 +474,11 @@ impl LogicalPlan { Ok(Some(agg.group_expr.as_slice()[0].clone())) } } + LogicalPlan::Distinct(Distinct::On(DistinctOn { select_expr, .. })) => { + Ok(Some(select_expr[0].clone())) + } LogicalPlan::Filter(Filter { input, .. }) - | LogicalPlan::Distinct(Distinct { input, .. }) + | LogicalPlan::Distinct(Distinct::All(input)) | LogicalPlan::Sort(Sort { input, .. }) | LogicalPlan::Limit(Limit { input, .. }) | LogicalPlan::Repartition(Repartition { input, .. }) @@ -823,10 +839,29 @@ impl LogicalPlan { inputs: inputs.iter().cloned().map(Arc::new).collect(), schema: schema.clone(), })), - LogicalPlan::Distinct(Distinct { .. }) => { - Ok(LogicalPlan::Distinct(Distinct { - input: Arc::new(inputs[0].clone()), - })) + LogicalPlan::Distinct(distinct) => { + let distinct = match distinct { + Distinct::All(_) => Distinct::All(Arc::new(inputs[0].clone())), + Distinct::On(DistinctOn { + on_expr, + select_expr, + .. + }) => { + let sort_expr = expr.split_off(on_expr.len() + select_expr.len()); + let select_expr = expr.split_off(on_expr.len()); + Distinct::On(DistinctOn::try_new( + expr, + select_expr, + if !sort_expr.is_empty() { + Some(sort_expr) + } else { + None + }, + Arc::new(inputs[0].clone()), + )?) + } + }; + Ok(LogicalPlan::Distinct(distinct)) } LogicalPlan::Analyze(a) => { assert!(expr.is_empty()); @@ -1064,7 +1099,9 @@ impl LogicalPlan { LogicalPlan::Subquery(_) => None, LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => input.max_rows(), LogicalPlan::Limit(Limit { fetch, .. }) => *fetch, - LogicalPlan::Distinct(Distinct { input }) => input.max_rows(), + LogicalPlan::Distinct( + Distinct::All(input) | Distinct::On(DistinctOn { input, .. }), + ) => input.max_rows(), LogicalPlan::Values(v) => Some(v.values.len()), LogicalPlan::Unnest(_) => None, LogicalPlan::Ddl(_) @@ -1667,9 +1704,21 @@ impl LogicalPlan { LogicalPlan::Statement(statement) => { write!(f, "{}", statement.display()) } - LogicalPlan::Distinct(Distinct { .. }) => { - write!(f, "Distinct:") - } + LogicalPlan::Distinct(distinct) => match distinct { + Distinct::All(_) => write!(f, "Distinct:"), + Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + .. + }) => write!( + f, + "DistinctOn: on_expr=[[{}]], select_expr=[[{}]], sort_expr=[[{}]]", + expr_vec_fmt!(on_expr), + expr_vec_fmt!(select_expr), + if let Some(sort_expr) = sort_expr { expr_vec_fmt!(sort_expr) } else { "".to_string() }, + ), + }, LogicalPlan::Explain { .. } => write!(f, "Explain"), LogicalPlan::Analyze { .. } => write!(f, "Analyze"), LogicalPlan::Union(_) => write!(f, "Union"), @@ -2132,9 +2181,93 @@ pub struct Limit { /// Removes duplicate rows from the input #[derive(Clone, PartialEq, Eq, Hash)] -pub struct Distinct { +pub enum Distinct { + /// Plain `DISTINCT` referencing all selection expressions + All(Arc), + /// The `Postgres` addition, allowing separate control over DISTINCT'd and selected columns + On(DistinctOn), +} + +/// Removes duplicate rows from the input +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct DistinctOn { + /// The `DISTINCT ON` clause expression list + pub on_expr: Vec, + /// The selected projection expression list + pub select_expr: Vec, + /// The `ORDER BY` clause, whose initial expressions must match those of the `ON` clause when + /// present. Note that those matching expressions actually wrap the `ON` expressions with + /// additional info pertaining to the sorting procedure (i.e. ASC/DESC, and NULLS FIRST/LAST). + pub sort_expr: Option>, /// The logical plan that is being DISTINCT'd pub input: Arc, + /// The schema description of the DISTINCT ON output + pub schema: DFSchemaRef, +} + +impl DistinctOn { + /// Create a new `DistinctOn` struct. + pub fn try_new( + on_expr: Vec, + select_expr: Vec, + sort_expr: Option>, + input: Arc, + ) -> Result { + if on_expr.is_empty() { + return plan_err!("No `ON` expressions provided"); + } + + let on_expr = normalize_cols(on_expr, input.as_ref())?; + + let schema = DFSchema::new_with_metadata( + exprlist_to_fields(&select_expr, &input)?, + input.schema().metadata().clone(), + )?; + + let mut distinct_on = DistinctOn { + on_expr, + select_expr, + sort_expr: None, + input, + schema: Arc::new(schema), + }; + + if let Some(sort_expr) = sort_expr { + distinct_on = distinct_on.with_sort_expr(sort_expr)?; + } + + Ok(distinct_on) + } + + /// Try to update `self` with a new sort expressions. + /// + /// Validates that the sort expressions are a super-set of the `ON` expressions. + pub fn with_sort_expr(mut self, sort_expr: Vec) -> Result { + let sort_expr = normalize_cols(sort_expr, self.input.as_ref())?; + + // Check that the left-most sort expressions are the same as the `ON` expressions. + let mut matched = true; + for (on, sort) in self.on_expr.iter().zip(sort_expr.iter()) { + match sort { + Expr::Sort(SortExpr { expr, .. }) => { + if on != &**expr { + matched = false; + break; + } + } + _ => return plan_err!("Not a sort expression: {sort}"), + } + } + + if self.on_expr.len() > sort_expr.len() || !matched { + return plan_err!( + "SELECT DISTINCT ON expressions must match initial ORDER BY expressions" + ); + } + + self.sort_expr = Some(sort_expr); + Ok(self) + } } /// Aggregates its input based on a set of grouping and aggregate diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index d6c14b86227a..6b86de37ba44 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -157,9 +157,11 @@ impl TreeNode for Expr { let mut transform = transform; Ok(match self { - Expr::Alias(Alias { expr, name, .. }) => { - Expr::Alias(Alias::new(transform(*expr)?, name)) - } + Expr::Alias(Alias { + expr, + relation, + name, + }) => Expr::Alias(Alias::new(transform(*expr)?, relation, name)), Expr::Column(_) => self, Expr::OuterReferenceColumn(_, _) => self, Expr::Exists { .. } => self, diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index a462cdb34631..8f13bf5f61be 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -800,9 +800,11 @@ pub fn columnize_expr(e: Expr, input_schema: &DFSchema) -> Expr { match e { Expr::Column(_) => e, Expr::OuterReferenceColumn(_, _) => e, - Expr::Alias(Alias { expr, name, .. }) => { - columnize_expr(*expr, input_schema).alias(name) - } + Expr::Alias(Alias { + expr, + relation, + name, + }) => columnize_expr(*expr, input_schema).alias_qualified(relation, name), Expr::Cast(Cast { expr, data_type }) => Expr::Cast(Cast { expr: Box::new(columnize_expr(*expr, input_schema)), data_type, diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 8025402ccef5..f5ad767c5016 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -238,6 +238,14 @@ impl CommonSubexprEliminate { let rewritten = pop_expr(&mut rewritten)?; if affected_id.is_empty() { + // Alias aggregation expressions if they have changed + let new_aggr_expr = new_aggr_expr + .iter() + .zip(aggr_expr.iter()) + .map(|(new_expr, old_expr)| { + new_expr.clone().alias_if_changed(old_expr.display_name()?) + }) + .collect::>>()?; // Since group_epxr changes, schema changes also. Use try_new method. Aggregate::try_new(Arc::new(new_input), new_group_expr, new_aggr_expr) .map(LogicalPlan::Aggregate) @@ -367,7 +375,7 @@ impl OptimizerRule for CommonSubexprEliminate { Ok(Some(build_recover_project_plan( &original_schema, optimized_plan, - ))) + )?)) } plan => Ok(plan), } @@ -458,16 +466,19 @@ fn build_common_expr_project_plan( /// the "intermediate" projection plan built in [build_common_expr_project_plan]. /// /// This is for those plans who don't keep its own output schema like `Filter` or `Sort`. -fn build_recover_project_plan(schema: &DFSchema, input: LogicalPlan) -> LogicalPlan { +fn build_recover_project_plan( + schema: &DFSchema, + input: LogicalPlan, +) -> Result { let col_exprs = schema .fields() .iter() .map(|field| Expr::Column(field.qualified_column())) .collect(); - LogicalPlan::Projection( - Projection::try_new(col_exprs, Arc::new(input)) - .expect("Cannot build projection plan from an invalid schema"), - ) + Ok(LogicalPlan::Projection(Projection::try_new( + col_exprs, + Arc::new(input), + )?)) } fn extract_expressions( diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs index 89bcc90bc075..5771ea2e19a2 100644 --- a/datafusion/optimizer/src/eliminate_nested_union.rs +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -52,7 +52,7 @@ impl OptimizerRule for EliminateNestedUnion { schema: schema.clone(), }))) } - LogicalPlan::Distinct(Distinct { input: plan }) => match plan.as_ref() { + LogicalPlan::Distinct(Distinct::All(plan)) => match plan.as_ref() { LogicalPlan::Union(Union { inputs, schema }) => { let inputs = inputs .iter() @@ -60,12 +60,12 @@ impl OptimizerRule for EliminateNestedUnion { .flat_map(extract_plans_from_union) .collect::>(); - Ok(Some(LogicalPlan::Distinct(Distinct { - input: Arc::new(LogicalPlan::Union(Union { + Ok(Some(LogicalPlan::Distinct(Distinct::All(Arc::new( + LogicalPlan::Union(Union { inputs, schema: schema.clone(), - })), - }))) + }), + ))))) } _ => Ok(None), }, @@ -94,7 +94,7 @@ fn extract_plans_from_union(plan: &Arc) -> Vec> { fn extract_plan_from_distinct(plan: &Arc) -> &Arc { match plan.as_ref() { - LogicalPlan::Distinct(Distinct { input: plan }) => plan, + LogicalPlan::Distinct(Distinct::All(plan)) => plan, _ => plan, } } diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 5231dc869875..e93565fef0a0 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -427,7 +427,7 @@ impl Optimizer { /// Returns an error if plans have different schemas. /// /// It ignores metadata and nullability. -fn assert_schema_is_the_same( +pub(crate) fn assert_schema_is_the_same( rule_name: &str, prev_plan: &LogicalPlan, new_plan: &LogicalPlan, @@ -438,7 +438,7 @@ fn assert_schema_is_the_same( if !equivalent { let e = DataFusionError::Internal(format!( - "Failed due to generate a different schema, original schema: {:?}, new schema: {:?}", + "Failed due to a difference in schemas, original schema: {:?}, new schema: {:?}", prev_plan.schema(), new_plan.schema() )); @@ -503,7 +503,7 @@ mod tests { let err = opt.optimize(&plan, &config, &observe).unwrap_err(); assert_eq!( "Optimizer rule 'get table_scan rule' failed\ncaused by\nget table_scan rule\ncaused by\n\ - Internal error: Failed due to generate a different schema, \ + Internal error: Failed due to a difference in schemas, \ original schema: DFSchema { fields: [], metadata: {}, functional_dependencies: FunctionalDependencies { deps: [] } }, \ new schema: DFSchema { fields: [\ DFField { qualifier: Some(Bare { table: \"test\" }), field: Field { name: \"a\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, \ diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index b05d811cb481..2c314bf7651c 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -228,7 +228,7 @@ impl OptimizerRule for PushDownProjection { // Gather all columns needed for expressions in this Aggregate let mut new_aggr_expr = vec![]; for e in agg.aggr_expr.iter() { - let column = Column::from_name(e.display_name()?); + let column = Column::from(e.display_name()?); if required_columns.contains(&column) { new_aggr_expr.push(e.clone()); } diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index 540617b77084..187e510e557d 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -20,7 +20,11 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; use datafusion_expr::utils::expand_wildcard; -use datafusion_expr::{Aggregate, Distinct, LogicalPlan}; +use datafusion_expr::{ + aggregate_function::AggregateFunction as AggregateFunctionFunc, col, + expr::AggregateFunction, LogicalPlanBuilder, +}; +use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan}; /// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]] /// @@ -32,6 +36,22 @@ use datafusion_expr::{Aggregate, Distinct, LogicalPlan}; /// ```text /// SELECT a, b FROM tab GROUP BY a, b /// ``` +/// +/// On the other hand, for a `DISTINCT ON` query the replacement is +/// a bit more involved and effectively converts +/// ```text +/// SELECT DISTINCT ON (a) b FROM tab ORDER BY a DESC, c +/// ``` +/// +/// into +/// ```text +/// SELECT b FROM ( +/// SELECT a, FIRST_VALUE(b ORDER BY a DESC, c) AS b +/// FROM tab +/// GROUP BY a +/// ) +/// ORDER BY a DESC +/// ``` /// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]] #[derive(Default)] @@ -51,7 +71,7 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { _config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Distinct(Distinct { input }) => { + LogicalPlan::Distinct(Distinct::All(input)) => { let group_expr = expand_wildcard(input.schema(), input, None)?; let aggregate = LogicalPlan::Aggregate(Aggregate::try_new( input.clone(), @@ -60,6 +80,65 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { )?); Ok(Some(aggregate)) } + LogicalPlan::Distinct(Distinct::On(DistinctOn { + select_expr, + on_expr, + sort_expr, + input, + schema, + })) => { + // Construct the aggregation expression to be used to fetch the selected expressions. + let aggr_expr = select_expr + .iter() + .map(|e| { + Expr::AggregateFunction(AggregateFunction::new( + AggregateFunctionFunc::FirstValue, + vec![e.clone()], + false, + None, + sort_expr.clone(), + )) + }) + .collect::>(); + + // Build the aggregation plan + let plan = LogicalPlanBuilder::from(input.as_ref().clone()) + .aggregate(on_expr.clone(), aggr_expr.to_vec())? + .build()?; + + let plan = if let Some(sort_expr) = sort_expr { + // While sort expressions were used in the `FIRST_VALUE` aggregation itself above, + // this on it's own isn't enough to guarantee the proper output order of the grouping + // (`ON`) expression, so we need to sort those as well. + LogicalPlanBuilder::from(plan) + .sort(sort_expr[..on_expr.len()].to_vec())? + .build()? + } else { + plan + }; + + // Whereas the aggregation plan by default outputs both the grouping and the aggregation + // expressions, for `DISTINCT ON` we only need to emit the original selection expressions. + let project_exprs = plan + .schema() + .fields() + .iter() + .skip(on_expr.len()) + .zip(schema.fields().iter()) + .map(|(new_field, old_field)| { + Ok(col(new_field.qualified_column()).alias_qualified( + old_field.qualifier().cloned(), + old_field.name(), + )) + }) + .collect::>>()?; + + let plan = LogicalPlanBuilder::from(plan) + .project(project_exprs)? + .build()?; + + Ok(Some(plan)) + } _ => Ok(None), } } @@ -98,4 +177,27 @@ mod tests { expected, ) } + + #[test] + fn replace_distinct_on() -> datafusion_common::Result<()> { + let table_scan = test_table_scan().unwrap(); + let plan = LogicalPlanBuilder::from(table_scan) + .distinct_on( + vec![col("a")], + vec![col("b")], + Some(vec![col("a").sort(false, true), col("c").sort(true, false)]), + )? + .build()?; + + let expected = "Projection: FIRST_VALUE(test.b) ORDER BY [test.a DESC NULLS FIRST, test.c ASC NULLS LAST] AS b\ + \n Sort: test.a DESC NULLS FIRST\ + \n Aggregate: groupBy=[[test.a]], aggr=[[FIRST_VALUE(test.b) ORDER BY [test.a DESC NULLS FIRST, test.c ASC NULLS LAST]]]\ + \n TableScan: test"; + + assert_optimized_plan_eq( + Arc::new(ReplaceDistinctWithAggregate::new()), + &plan, + expected, + ) + } } diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index 3eac2317b849..917ddc565c9e 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -16,7 +16,7 @@ // under the License. use crate::analyzer::{Analyzer, AnalyzerRule}; -use crate::optimizer::Optimizer; +use crate::optimizer::{assert_schema_is_the_same, Optimizer}; use crate::{OptimizerContext, OptimizerRule}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; @@ -155,7 +155,7 @@ pub fn assert_optimized_plan_eq( plan: &LogicalPlan, expected: &str, ) -> Result<()> { - let optimizer = Optimizer::with_rules(vec![rule]); + let optimizer = Optimizer::with_rules(vec![rule.clone()]); let optimized_plan = optimizer .optimize_recursively( optimizer.rules.get(0).unwrap(), @@ -163,6 +163,9 @@ pub fn assert_optimized_plan_eq( &OptimizerContext::new(), )? .unwrap_or_else(|| plan.clone()); + + // Ensure schemas always match after an optimization + assert_schema_is_the_same(rule.name(), plan, &optimized_plan)?; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 9dcd55e731bb..62b226e33339 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -73,6 +73,7 @@ message LogicalPlanNode { CustomTableScanNode custom_scan = 25; PrepareNode prepare = 26; DropViewNode drop_view = 27; + DistinctOnNode distinct_on = 28; } } @@ -308,6 +309,13 @@ message DistinctNode { LogicalPlanNode input = 1; } +message DistinctOnNode { + repeated LogicalExprNode on_expr = 1; + repeated LogicalExprNode select_expr = 2; + repeated LogicalExprNode sort_expr = 3; + LogicalPlanNode input = 4; +} + message UnionNode { repeated LogicalPlanNode inputs = 1; } @@ -485,6 +493,7 @@ message Not { message AliasNode { LogicalExprNode expr = 1; string alias = 2; + repeated OwnedTableReference relation = 3; } message BinaryExprNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 948ad0c4cedb..7602e1a36657 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -967,6 +967,9 @@ impl serde::Serialize for AliasNode { if !self.alias.is_empty() { len += 1; } + if !self.relation.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.AliasNode", len)?; if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; @@ -974,6 +977,9 @@ impl serde::Serialize for AliasNode { if !self.alias.is_empty() { struct_ser.serialize_field("alias", &self.alias)?; } + if !self.relation.is_empty() { + struct_ser.serialize_field("relation", &self.relation)?; + } struct_ser.end() } } @@ -986,12 +992,14 @@ impl<'de> serde::Deserialize<'de> for AliasNode { const FIELDS: &[&str] = &[ "expr", "alias", + "relation", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Expr, Alias, + Relation, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -1015,6 +1023,7 @@ impl<'de> serde::Deserialize<'de> for AliasNode { match value { "expr" => Ok(GeneratedField::Expr), "alias" => Ok(GeneratedField::Alias), + "relation" => Ok(GeneratedField::Relation), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -1036,6 +1045,7 @@ impl<'de> serde::Deserialize<'de> for AliasNode { { let mut expr__ = None; let mut alias__ = None; + let mut relation__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { @@ -1050,11 +1060,18 @@ impl<'de> serde::Deserialize<'de> for AliasNode { } alias__ = Some(map_.next_value()?); } + GeneratedField::Relation => { + if relation__.is_some() { + return Err(serde::de::Error::duplicate_field("relation")); + } + relation__ = Some(map_.next_value()?); + } } } Ok(AliasNode { expr: expr__, alias: alias__.unwrap_or_default(), + relation: relation__.unwrap_or_default(), }) } } @@ -6070,6 +6087,151 @@ impl<'de> serde::Deserialize<'de> for DistinctNode { deserializer.deserialize_struct("datafusion.DistinctNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for DistinctOnNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.on_expr.is_empty() { + len += 1; + } + if !self.select_expr.is_empty() { + len += 1; + } + if !self.sort_expr.is_empty() { + len += 1; + } + if self.input.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.DistinctOnNode", len)?; + if !self.on_expr.is_empty() { + struct_ser.serialize_field("onExpr", &self.on_expr)?; + } + if !self.select_expr.is_empty() { + struct_ser.serialize_field("selectExpr", &self.select_expr)?; + } + if !self.sort_expr.is_empty() { + struct_ser.serialize_field("sortExpr", &self.sort_expr)?; + } + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for DistinctOnNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "on_expr", + "onExpr", + "select_expr", + "selectExpr", + "sort_expr", + "sortExpr", + "input", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + OnExpr, + SelectExpr, + SortExpr, + Input, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "onExpr" | "on_expr" => Ok(GeneratedField::OnExpr), + "selectExpr" | "select_expr" => Ok(GeneratedField::SelectExpr), + "sortExpr" | "sort_expr" => Ok(GeneratedField::SortExpr), + "input" => Ok(GeneratedField::Input), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = DistinctOnNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.DistinctOnNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut on_expr__ = None; + let mut select_expr__ = None; + let mut sort_expr__ = None; + let mut input__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::OnExpr => { + if on_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("onExpr")); + } + on_expr__ = Some(map_.next_value()?); + } + GeneratedField::SelectExpr => { + if select_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("selectExpr")); + } + select_expr__ = Some(map_.next_value()?); + } + GeneratedField::SortExpr => { + if sort_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("sortExpr")); + } + sort_expr__ = Some(map_.next_value()?); + } + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + } + } + Ok(DistinctOnNode { + on_expr: on_expr__.unwrap_or_default(), + select_expr: select_expr__.unwrap_or_default(), + sort_expr: sort_expr__.unwrap_or_default(), + input: input__, + }) + } + } + deserializer.deserialize_struct("datafusion.DistinctOnNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for DropViewNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -13146,6 +13308,9 @@ impl serde::Serialize for LogicalPlanNode { logical_plan_node::LogicalPlanType::DropView(v) => { struct_ser.serialize_field("dropView", v)?; } + logical_plan_node::LogicalPlanType::DistinctOn(v) => { + struct_ser.serialize_field("distinctOn", v)?; + } } } struct_ser.end() @@ -13195,6 +13360,8 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { "prepare", "drop_view", "dropView", + "distinct_on", + "distinctOn", ]; #[allow(clippy::enum_variant_names)] @@ -13225,6 +13392,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { CustomScan, Prepare, DropView, + DistinctOn, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -13272,6 +13440,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { "customScan" | "custom_scan" => Ok(GeneratedField::CustomScan), "prepare" => Ok(GeneratedField::Prepare), "dropView" | "drop_view" => Ok(GeneratedField::DropView), + "distinctOn" | "distinct_on" => Ok(GeneratedField::DistinctOn), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -13474,6 +13643,13 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { return Err(serde::de::Error::duplicate_field("dropView")); } logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::DropView) +; + } + GeneratedField::DistinctOn => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("distinctOn")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::DistinctOn) ; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 93b0a05c314d..825481a18822 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -38,7 +38,7 @@ pub struct DfSchema { pub struct LogicalPlanNode { #[prost( oneof = "logical_plan_node::LogicalPlanType", - tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27" + tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28" )] pub logical_plan_type: ::core::option::Option, } @@ -99,6 +99,8 @@ pub mod logical_plan_node { Prepare(::prost::alloc::boxed::Box), #[prost(message, tag = "27")] DropView(super::DropViewNode), + #[prost(message, tag = "28")] + DistinctOn(::prost::alloc::boxed::Box), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -483,6 +485,18 @@ pub struct DistinctNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct DistinctOnNode { + #[prost(message, repeated, tag = "1")] + pub on_expr: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "2")] + pub select_expr: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "3")] + pub sort_expr: ::prost::alloc::vec::Vec, + #[prost(message, optional, boxed, tag = "4")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct UnionNode { #[prost(message, repeated, tag = "1")] pub inputs: ::prost::alloc::vec::Vec, @@ -754,6 +768,8 @@ pub struct AliasNode { pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(string, tag = "2")] pub alias: ::prost::alloc::string::String, + #[prost(message, repeated, tag = "3")] + pub relation: ::prost::alloc::vec::Vec, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index b2b66693f78d..674492edef43 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -1151,6 +1151,11 @@ pub fn parse_expr( } ExprType::Alias(alias) => Ok(Expr::Alias(Alias::new( parse_required_expr(alias.expr.as_deref(), registry, "expr")?, + alias + .relation + .first() + .map(|r| OwnedTableReference::try_from(r.clone())) + .transpose()?, alias.alias.clone(), ))), ExprType::IsNullExpr(is_null) => Ok(Expr::IsNull(Box::new(parse_required_expr( diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index e426c598523e..851f062bd51f 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -55,7 +55,7 @@ use datafusion_expr::{ EmptyRelation, Extension, Join, JoinConstraint, Limit, Prepare, Projection, Repartition, Sort, SubqueryAlias, TableScan, Values, Window, }, - DropView, Expr, LogicalPlan, LogicalPlanBuilder, + DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, }; use prost::bytes::BufMut; @@ -734,6 +734,33 @@ impl AsLogicalPlan for LogicalPlanNode { into_logical_plan!(distinct.input, ctx, extension_codec)?; LogicalPlanBuilder::from(input).distinct()?.build() } + LogicalPlanType::DistinctOn(distinct_on) => { + let input: LogicalPlan = + into_logical_plan!(distinct_on.input, ctx, extension_codec)?; + let on_expr = distinct_on + .on_expr + .iter() + .map(|expr| from_proto::parse_expr(expr, ctx)) + .collect::, _>>()?; + let select_expr = distinct_on + .select_expr + .iter() + .map(|expr| from_proto::parse_expr(expr, ctx)) + .collect::, _>>()?; + let sort_expr = match distinct_on.sort_expr.len() { + 0 => None, + _ => Some( + distinct_on + .sort_expr + .iter() + .map(|expr| from_proto::parse_expr(expr, ctx)) + .collect::, _>>()?, + ), + }; + LogicalPlanBuilder::from(input) + .distinct_on(on_expr, select_expr, sort_expr)? + .build() + } LogicalPlanType::ViewScan(scan) => { let schema: Schema = convert_required!(scan.schema)?; @@ -1005,7 +1032,7 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Distinct(Distinct { input }) => { + LogicalPlan::Distinct(Distinct::All(input)) => { let input: protobuf::LogicalPlanNode = protobuf::LogicalPlanNode::try_from_logical_plan( input.as_ref(), @@ -1019,6 +1046,42 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } + LogicalPlan::Distinct(Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + input, + .. + })) => { + let input: protobuf::LogicalPlanNode = + protobuf::LogicalPlanNode::try_from_logical_plan( + input.as_ref(), + extension_codec, + )?; + let sort_expr = match sort_expr { + None => vec![], + Some(sort_expr) => sort_expr + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + }; + Ok(protobuf::LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::DistinctOn(Box::new( + protobuf::DistinctOnNode { + on_expr: on_expr + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + select_expr: select_expr + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + sort_expr, + input: Some(Box::new(input)), + }, + ))), + }) + } LogicalPlan::Window(Window { input, window_expr, .. }) => { diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index e590731f5810..946f2c6964a5 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -476,9 +476,17 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { Expr::Column(c) => Self { expr_type: Some(ExprType::Column(c.into())), }, - Expr::Alias(Alias { expr, name, .. }) => { + Expr::Alias(Alias { + expr, + relation, + name, + }) => { let alias = Box::new(protobuf::AliasNode { expr: Some(Box::new(expr.as_ref().try_into()?)), + relation: relation + .to_owned() + .map(|r| vec![r.into()]) + .unwrap_or(vec![]), alias: name.to_owned(), }); Self { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 97c553dc04e6..cc76e8a19e98 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -300,6 +300,32 @@ async fn roundtrip_logical_plan_aggregation() -> Result<()> { Ok(()) } +#[tokio::test] +async fn roundtrip_logical_plan_distinct_on() -> Result<()> { + let ctx = SessionContext::new(); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Decimal128(15, 2), true), + ]); + + ctx.register_csv( + "t1", + "tests/testdata/test.csv", + CsvReadOptions::default().schema(&schema), + ) + .await?; + + let query = "SELECT DISTINCT ON (a % 2) a, b * 2 FROM t1 ORDER BY a % 2 DESC, b"; + let plan = ctx.sql(query).await?.into_optimized_plan()?; + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + Ok(()) +} + #[tokio::test] async fn roundtrip_single_count_distinct() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs index f32c81527925..5b890accd81f 100644 --- a/datafusion/proto/tests/cases/serialize.rs +++ b/datafusion/proto/tests/cases/serialize.rs @@ -128,6 +128,12 @@ fn exact_roundtrip_linearized_binary_expr() { } } +#[test] +fn roundtrip_qualified_alias() { + let qual_alias = col("c1").alias_qualified(Some("my_table"), "my_column"); + assert_eq!(qual_alias, roundtrip_expr(&qual_alias)); +} + #[test] fn roundtrip_deeply_nested_binary_expr() { // We need more stack space so this doesn't overflow in dev builds diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index fc2a3fb9a57b..832e2da9c6ec 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -23,7 +23,7 @@ use datafusion_common::{ not_impl_err, plan_err, sql_err, Constraints, DataFusionError, Result, ScalarValue, }; use datafusion_expr::{ - CreateMemoryTable, DdlStatement, Expr, LogicalPlan, LogicalPlanBuilder, + CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder, }; use sqlparser::ast::{ Expr as SQLExpr, Offset as SQLOffset, OrderByExpr, Query, SetExpr, Value, @@ -161,6 +161,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let order_by_rex = self.order_by_to_sort_expr(&order_by, plan.schema(), planner_context)?; - LogicalPlanBuilder::from(plan).sort(order_by_rex)?.build() + + if let LogicalPlan::Distinct(Distinct::On(ref distinct_on)) = plan { + // In case of `DISTINCT ON` we must capture the sort expressions since during the plan + // optimization we're effectively doing a `first_value` aggregation according to them. + let distinct_on = distinct_on.clone().with_sort_expr(order_by_rex)?; + Ok(LogicalPlan::Distinct(Distinct::On(distinct_on))) + } else { + LogicalPlanBuilder::from(plan).sort(order_by_rex)?.build() + } } } diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index e9a7941ab064..31333affe0af 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -76,7 +76,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_)); // process `where` clause - let plan = self.plan_selection(select.selection, plan, planner_context)?; + let base_plan = self.plan_selection(select.selection, plan, planner_context)?; // handle named windows before processing the projection expression check_conflicting_windows(&select.named_window)?; @@ -84,16 +84,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // process the SELECT expressions, with wildcards expanded. let select_exprs = self.prepare_select_exprs( - &plan, + &base_plan, select.projection, empty_from, planner_context, )?; // having and group by clause may reference aliases defined in select projection - let projected_plan = self.project(plan.clone(), select_exprs.clone())?; + let projected_plan = self.project(base_plan.clone(), select_exprs.clone())?; let mut combined_schema = (**projected_plan.schema()).clone(); - combined_schema.merge(plan.schema()); + combined_schema.merge(base_plan.schema()); // this alias map is resolved and looked up in both having exprs and group by exprs let alias_map = extract_aliases(&select_exprs); @@ -148,7 +148,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { )?; // aliases from the projection can conflict with same-named expressions in the input let mut alias_map = alias_map.clone(); - for f in plan.schema().fields() { + for f in base_plan.schema().fields() { alias_map.remove(f.name()); } let group_by_expr = @@ -158,7 +158,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .unwrap_or(group_by_expr); let group_by_expr = normalize_col(group_by_expr, &projected_plan)?; self.validate_schema_satisfies_exprs( - plan.schema(), + base_plan.schema(), &[group_by_expr.clone()], )?; Ok(group_by_expr) @@ -171,7 +171,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .iter() .filter(|select_expr| match select_expr { Expr::AggregateFunction(_) | Expr::AggregateUDF(_) => false, - Expr::Alias(Alias { expr, name: _ }) => !matches!( + Expr::Alias(Alias { expr, name: _, .. }) => !matches!( **expr, Expr::AggregateFunction(_) | Expr::AggregateUDF(_) ), @@ -187,16 +187,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { || !aggr_exprs.is_empty() { self.aggregate( - plan, + &base_plan, &select_exprs, having_expr_opt.as_ref(), - group_by_exprs, - aggr_exprs, + &group_by_exprs, + &aggr_exprs, )? } else { match having_expr_opt { Some(having_expr) => return plan_err!("HAVING clause references: {having_expr} must appear in the GROUP BY clause or be used in an aggregate function"), - None => (plan, select_exprs, having_expr_opt) + None => (base_plan.clone(), select_exprs.clone(), having_expr_opt) } }; @@ -229,19 +229,35 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let plan = project(plan, select_exprs_post_aggr)?; // process distinct clause - let distinct = select - .distinct - .map(|distinct| match distinct { - Distinct::Distinct => Ok(true), - Distinct::On(_) => not_impl_err!("DISTINCT ON Exprs not supported"), - }) - .transpose()? - .unwrap_or(false); + let plan = match select.distinct { + None => Ok(plan), + Some(Distinct::Distinct) => { + LogicalPlanBuilder::from(plan).distinct()?.build() + } + Some(Distinct::On(on_expr)) => { + if !aggr_exprs.is_empty() + || !group_by_exprs.is_empty() + || !window_func_exprs.is_empty() + { + return not_impl_err!("DISTINCT ON expressions with GROUP BY, aggregation or window functions are not supported "); + } - let plan = if distinct { - LogicalPlanBuilder::from(plan).distinct()?.build() - } else { - Ok(plan) + let on_expr = on_expr + .into_iter() + .map(|e| { + self.sql_expr_to_logical_expr( + e.clone(), + plan.schema(), + planner_context, + ) + }) + .collect::>>()?; + + // Build the final plan + return LogicalPlanBuilder::from(base_plan) + .distinct_on(on_expr, select_exprs, None)? + .build(); + } }?; // DISTRIBUTE BY @@ -471,6 +487,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .clone(); *expr = Expr::Alias(Alias { expr: Box::new(new_expr), + relation: None, name: name.clone(), }); } @@ -511,18 +528,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// the aggregate fn aggregate( &self, - input: LogicalPlan, + input: &LogicalPlan, select_exprs: &[Expr], having_expr_opt: Option<&Expr>, - group_by_exprs: Vec, - aggr_exprs: Vec, + group_by_exprs: &[Expr], + aggr_exprs: &[Expr], ) -> Result<(LogicalPlan, Vec, Option)> { let group_by_exprs = - get_updated_group_by_exprs(&group_by_exprs, select_exprs, input.schema())?; + get_updated_group_by_exprs(group_by_exprs, select_exprs, input.schema())?; // create the aggregate plan let plan = LogicalPlanBuilder::from(input.clone()) - .aggregate(group_by_exprs.clone(), aggr_exprs.clone())? + .aggregate(group_by_exprs.clone(), aggr_exprs.to_vec())? .build()?; // in this next section of code we are re-writing the projection to refer to columns @@ -549,25 +566,25 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { _ => aggr_projection_exprs.push(expr.clone()), } } - aggr_projection_exprs.extend_from_slice(&aggr_exprs); + aggr_projection_exprs.extend_from_slice(aggr_exprs); // now attempt to resolve columns and replace with fully-qualified columns let aggr_projection_exprs = aggr_projection_exprs .iter() - .map(|expr| resolve_columns(expr, &input)) + .map(|expr| resolve_columns(expr, input)) .collect::>>()?; // next we replace any expressions that are not a column with a column referencing // an output column from the aggregate schema let column_exprs_post_aggr = aggr_projection_exprs .iter() - .map(|expr| expr_as_column_expr(expr, &input)) + .map(|expr| expr_as_column_expr(expr, input)) .collect::>>()?; // next we re-write the projection let select_exprs_post_aggr = select_exprs .iter() - .map(|expr| rebase_expr(expr, &aggr_projection_exprs, &input)) + .map(|expr| rebase_expr(expr, &aggr_projection_exprs, input)) .collect::>>()?; // finally, we have some validation that the re-written projection can be resolved @@ -582,7 +599,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // aggregation. let having_expr_post_aggr = if let Some(having_expr) = having_expr_opt { let having_expr_post_aggr = - rebase_expr(having_expr, &aggr_projection_exprs, &input)?; + rebase_expr(having_expr, &aggr_projection_exprs, input)?; check_columns_satisfy_exprs( &column_exprs_post_aggr, diff --git a/datafusion/sqllogictest/test_files/distinct_on.slt b/datafusion/sqllogictest/test_files/distinct_on.slt new file mode 100644 index 000000000000..8a36b49b98c6 --- /dev/null +++ b/datafusion/sqllogictest/test_files/distinct_on.slt @@ -0,0 +1,146 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +CREATE EXTERNAL TABLE aggregate_test_100 ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT, + c5 INT, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 BIGINT UNSIGNED NOT NULL, + c10 VARCHAR NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL +) +STORED AS CSV +WITH HEADER ROW +LOCATION '../../testing/data/csv/aggregate_test_100.csv' + +# Basic example: distinct on the first column project the second one, and +# order by the third +query TI +SELECT DISTINCT ON (c1) c1, c2 FROM aggregate_test_100 ORDER BY c1, c3; +---- +a 5 +b 4 +c 2 +d 1 +e 3 + +# Basic example + reverse order of the selected column +query TI +SELECT DISTINCT ON (c1) c1, c2 FROM aggregate_test_100 ORDER BY c1, c3 DESC; +---- +a 1 +b 5 +c 4 +d 1 +e 1 + +# Basic example + reverse order of the ON column +query TI +SELECT DISTINCT ON (c1) c1, c2 FROM aggregate_test_100 ORDER BY c1 DESC, c3; +---- +e 3 +d 1 +c 2 +b 4 +a 4 + +# Basic example + reverse order of both columns + limit +query TI +SELECT DISTINCT ON (c1) c1, c2 FROM aggregate_test_100 ORDER BY c1 DESC, c3 DESC LIMIT 3; +---- +e 1 +d 1 +c 4 + +# Basic example + omit ON column from selection +query I +SELECT DISTINCT ON (c1) c2 FROM aggregate_test_100 ORDER BY c1, c3; +---- +5 +4 +2 +1 +3 + +# Test explain makes sense +query TT +EXPLAIN SELECT DISTINCT ON (c1) c3, c2 FROM aggregate_test_100 ORDER BY c1, c3; +---- +logical_plan +Projection: FIRST_VALUE(aggregate_test_100.c3) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST] AS c3, FIRST_VALUE(aggregate_test_100.c2) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST] AS c2 +--Sort: aggregate_test_100.c1 ASC NULLS LAST +----Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[FIRST_VALUE(aggregate_test_100.c3) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST], FIRST_VALUE(aggregate_test_100.c2) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST]]] +------TableScan: aggregate_test_100 projection=[c1, c2, c3] +physical_plan +ProjectionExec: expr=[FIRST_VALUE(aggregate_test_100.c3) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST]@1 as c3, FIRST_VALUE(aggregate_test_100.c2) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST]@2 as c2] +--SortPreservingMergeExec: [c1@0 ASC NULLS LAST] +----SortExec: expr=[c1@0 ASC NULLS LAST] +------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[FIRST_VALUE(aggregate_test_100.c3), FIRST_VALUE(aggregate_test_100.c2)] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([c1@0], 4), input_partitions=4 +------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[FIRST_VALUE(aggregate_test_100.c3), FIRST_VALUE(aggregate_test_100.c2)], ordering_mode=Sorted +--------------SortExec: expr=[c1@0 ASC NULLS LAST,c3@2 ASC NULLS LAST] +----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3], has_header=true + +# ON expressions are not a sub-set of the ORDER BY expressions +query error SELECT DISTINCT ON expressions must match initial ORDER BY expressions +SELECT DISTINCT ON (c2 % 2 = 0) c2, c3 - 100 FROM aggregate_test_100 ORDER BY c2, c3; + +# ON expressions are empty +query error DataFusion error: Error during planning: No `ON` expressions provided +SELECT DISTINCT ON () c1, c2 FROM aggregate_test_100 ORDER BY c1, c2; + +# Use expressions in the ON and ORDER BY clauses, as well as the selection +query II +SELECT DISTINCT ON (c2 % 2 = 0) c2, c3 - 100 FROM aggregate_test_100 ORDER BY c2 % 2 = 0, c3 DESC; +---- +1 25 +4 23 + +# Multiple complex expressions +query TIB +SELECT DISTINCT ON (chr(ascii(c1) + 3), c2 % 2) chr(ascii(upper(c1)) + 3), c2 % 2, c3 > 80 AND c2 % 2 = 1 +FROM aggregate_test_100 +WHERE c1 IN ('a', 'b') +ORDER BY chr(ascii(c1) + 3), c2 % 2, c3 DESC; +---- +D 0 false +D 1 true +E 0 false +E 1 false + +# Joins using CTEs +query II +WITH t1 AS (SELECT * FROM aggregate_test_100), +t2 AS (SELECT * FROM aggregate_test_100) +SELECT DISTINCT ON (t1.c1, t2.c2) t2.c3, t1.c4 +FROM t1 INNER JOIN t2 ON t1.c13 = t2.c13 +ORDER BY t1.c1, t2.c2, t2.c5 +LIMIT 3; +---- +-25 15295 +45 15673 +-72 -11122 diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 6fe8eca33705..9356a7753427 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -19,7 +19,7 @@ use std::collections::HashMap; use std::ops::Deref; use std::sync::Arc; -use datafusion::logical_expr::{Like, WindowFrameUnits}; +use datafusion::logical_expr::{Distinct, Like, WindowFrameUnits}; use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, error::{DataFusionError, Result}, @@ -244,11 +244,11 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::Distinct(distinct) => { + LogicalPlan::Distinct(Distinct::All(plan)) => { // Use Substrait's AggregateRel with empty measures to represent `select distinct` - let input = to_substrait_rel(distinct.input.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(plan.as_ref(), ctx, extension_info)?; // Get grouping keys from the input relation's number of output fields - let grouping = (0..distinct.input.schema().fields().len()) + let grouping = (0..plan.schema().fields().len()) .map(substrait_field_ref) .collect::>>()?;