From 92594cdcabfe294d1720d0c80fb6c013467ea2d0 Mon Sep 17 00:00:00 2001 From: xudong963 Date: Wed, 11 May 2022 21:25:04 +0800 Subject: [PATCH 1/3] fix(planner): fix some cases in aggregator plan --- query/src/sql/exec/expression_builder.rs | 17 ++++++- query/src/sql/exec/mod.rs | 47 ++++++++++--------- query/src/sql/planner/binder/aggregate.rs | 26 +++++++++- query/src/sql/planner/binder/bind_context.rs | 8 ++++ query/src/sql/planner/binder/project.rs | 14 +++++- query/src/sql/planner/binder/select.rs | 2 +- .../20+_others/20_0001_planner_v2.result | 27 +++++++++++ .../20+_others/20_0001_planner_v2.sql | 18 +++++++ 8 files changed, 130 insertions(+), 29 deletions(-) diff --git a/query/src/sql/exec/expression_builder.rs b/query/src/sql/exec/expression_builder.rs index 267646f0e5e3..58967d292796 100644 --- a/query/src/sql/exec/expression_builder.rs +++ b/query/src/sql/exec/expression_builder.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use common_datavalues::DataSchemaRef; use common_datavalues::DataValue; use common_exception::ErrorCode; use common_exception::Result; @@ -39,9 +40,21 @@ impl<'a> ExpressionBuilder<'a> { ExpressionBuilder { metadata } } - pub fn build_and_rename(&self, scalar: &Scalar, index: IndexType) -> Result { + pub fn build_and_rename( + &self, + scalar: &Scalar, + index: IndexType, + input_schema: &DataSchemaRef, + ) -> Result { let mut expr = self.build(scalar)?; - expr = self.normalize_aggr_to_col(expr)?; + // If the input_schema contains the field of expr, + // it means that its sub-plan has already processed the expression + // and we can directly convert it to column expression + if input_schema.has_field(expr.column_name().as_str()) { + expr = Expression::Column(expr.column_name()); + } else { + expr = self.normalize_aggr_to_col(expr)?; + } let column = self.metadata.column(index); Ok(Expression::Alias( format_field_name(column.name.as_str(), index), diff --git a/query/src/sql/exec/mod.rs b/query/src/sql/exec/mod.rs index eb2d2b8f97ee..bf1891373a18 100644 --- a/query/src/sql/exec/mod.rs +++ b/query/src/sql/exec/mod.rs @@ -215,7 +215,7 @@ impl PipelineBuilder { let expr_builder = ExpressionBuilder::create(&self.metadata); for item in project.items.iter() { let scalar = &item.expr; - let expression = expr_builder.build_and_rename(scalar, item.index)?; + let expression = expr_builder.build_and_rename(scalar, item.index, &input_schema)?; expressions.push(expression); } pipeline.add_transform(|transform_input_port, transform_output_port| { @@ -341,32 +341,36 @@ impl PipelineBuilder { let pre_input_schema = input_schema.clone(); let input_schema = schema_builder.build_group_by(input_schema, group_expressions.as_slice())?; - pipeline.add_transform(|transform_input_port, transform_output_port| { - ExpressionTransform::try_create( - transform_input_port, - transform_output_port, - pre_input_schema.clone(), - input_schema.clone(), - group_expressions.clone(), - self.ctx.clone(), - ) - })?; + if !input_schema.eq(&pre_input_schema) { + pipeline.add_transform(|transform_input_port, transform_output_port| { + ExpressionTransform::try_create( + transform_input_port, + transform_output_port, + pre_input_schema.clone(), + input_schema.clone(), + group_expressions.clone(), + self.ctx.clone(), + ) + })?; + } // Process aggregation function with non-column expression, such as sum(3) let pre_input_schema = input_schema.clone(); let res = schema_builder.build_agg_func(pre_input_schema.clone(), agg_expressions.as_slice())?; let input_schema = res.0; - pipeline.add_transform(|transform_input_port, transform_output_port| { - ExpressionTransform::try_create( - transform_input_port, - transform_output_port, - pre_input_schema.clone(), - input_schema.clone(), - res.1.clone(), - self.ctx.clone(), - ) - })?; + if !input_schema.eq(&pre_input_schema) { + pipeline.add_transform(|transform_input_port, transform_output_port| { + ExpressionTransform::try_create( + transform_input_port, + transform_output_port, + pre_input_schema.clone(), + input_schema.clone(), + res.1.clone(), + self.ctx.clone(), + ) + })?; + } // Get partial schema from agg_expressions let partial_data_fields = @@ -419,7 +423,6 @@ impl PipelineBuilder { self.ctx.clone(), ) })?; - Ok(final_schema) } diff --git a/query/src/sql/planner/binder/aggregate.rs b/query/src/sql/planner/binder/aggregate.rs index ee0d20ae55b1..a35fe165250e 100644 --- a/query/src/sql/planner/binder/aggregate.rs +++ b/query/src/sql/planner/binder/aggregate.rs @@ -23,6 +23,7 @@ use crate::sql::optimizer::SExpr; use crate::sql::plans::AggregateFunction; use crate::sql::plans::AggregatePlan; use crate::sql::plans::Scalar; +use crate::sql::plans::Scalar::BoundColumnRef; use crate::sql::BindContext; impl<'a> Binder { @@ -69,11 +70,32 @@ impl<'a> Binder { let scalar_binder = ScalarBinder::new(input_context, self.ctx.clone()); let mut group_expr = Vec::with_capacity(group_by_expr.len()); for expr in group_by_expr.iter() { - group_expr.push(scalar_binder.bind_expr(expr).await?); + let scalar_expr = scalar_binder.bind_expr(expr).await?.0; + if let BoundColumnRef(bound_column) = scalar_expr { + let col_name = bound_column.column.column_name.as_str(); + if input_context.origin_group_by.is_some() { + let origin_group_by = input_context.origin_group_by.as_ref().unwrap().clone(); + if origin_group_by.contains_key(bound_column.column.column_name.as_str()) { + // Use the origin group by expression + group_expr.push( + origin_group_by + .get(col_name) + .ok_or_else(|| { + ErrorCode::SemanticError({ + format!("Not exist alias name {}", col_name) + }) + })? + .clone(), + ); + continue; + } + } + } + group_expr.push(scalar_binder.bind_expr(expr).await?.0); } let aggregate_plan = AggregatePlan { - group_expr: group_expr.into_iter().map(|(scalar, _)| scalar).collect(), + group_expr, agg_expr: input_context.agg_scalar_exprs.clone().unwrap(), }; let new_expr = SExpr::create_unary( diff --git a/query/src/sql/planner/binder/bind_context.rs b/query/src/sql/planner/binder/bind_context.rs index e7604ec6a9b5..627040d5c8b1 100644 --- a/query/src/sql/planner/binder/bind_context.rs +++ b/query/src/sql/planner/binder/bind_context.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; + use common_ast::ast::Identifier; use common_ast::ast::TableAlias; use common_ast::parser::error::DisplayError as _; @@ -53,6 +55,11 @@ pub struct BindContext { /// Aggregation scalar expression pub agg_scalar_exprs: Option>, + /// The origin scalar expression of Group by + /// For the sql: `SELECT a%3 as a1, count(1) as ct from t GROUP BY a1`; + /// The origin scalar expression is `a%3`. + pub origin_group_by: Option>, + /// Order by columnBinding, consider the sql: select sum(a) from t group by a,b order by b;, /// Order by requires not just the columns in the selection, /// but the columns of the entire table as well as the columns of the selection @@ -70,6 +77,7 @@ impl BindContext { columns: vec![], expression: None, agg_scalar_exprs: None, + origin_group_by: None, order_by_columns: Some(Vec::new()), } } diff --git a/query/src/sql/planner/binder/project.rs b/query/src/sql/planner/binder/project.rs index f740f04480a8..9b70e97545f4 100644 --- a/query/src/sql/planner/binder/project.rs +++ b/query/src/sql/planner/binder/project.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; + use common_ast::ast::Indirection; use common_ast::ast::SelectTarget; use common_exception::ErrorCode; @@ -73,13 +75,14 @@ impl<'a> Binder { &mut self, select_list: &[SelectTarget<'a>], has_order_by: bool, - input_context: &BindContext, + input_context: &mut BindContext, ) -> Result { let mut output_context = BindContext::new(); if has_order_by { output_context.order_by_columns = Some(input_context.columns.clone()); } output_context.expression = input_context.expression.clone(); + let mut origin_group_by = HashMap::new(); for select_target in select_list { match select_target { SelectTarget::QualifiedName(names) => { @@ -128,6 +131,11 @@ impl<'a> Binder { data_type, scalar: Some(Box::new(bound_expr.clone())), }; + if alias.is_some() { + input_context.columns.push(column_binding.clone()); + origin_group_by + .insert(alias.as_ref().unwrap().name.clone(), bound_expr.clone()); + } if has_order_by && !matches!(bound_expr, Scalar::BoundColumnRef(BoundColumnRef { .. })) { @@ -143,7 +151,9 @@ impl<'a> Binder { } } } - + if !origin_group_by.is_empty() { + input_context.origin_group_by = Some(origin_group_by); + } Ok(output_context) } } diff --git a/query/src/sql/planner/binder/select.rs b/query/src/sql/planner/binder/select.rs index b45517fcfe28..a0cc0a5c6a38 100644 --- a/query/src/sql/planner/binder/select.rs +++ b/query/src/sql/planner/binder/select.rs @@ -101,7 +101,7 @@ impl<'a> Binder { // Output of current `SELECT` statement. let mut output_context = self - .normalize_select_list(&stmt.select_list, has_order_by, &input_context) + .normalize_select_list(&stmt.select_list, has_order_by, &mut input_context) .await?; self.analyze_aggregate(&output_context, &mut input_context)?; diff --git a/tests/suites/0_stateless/20+_others/20_0001_planner_v2.result b/tests/suites/0_stateless/20+_others/20_0001_planner_v2.result index 4e83c8fd60b4..4dc55c8c9d24 100644 --- a/tests/suites/0_stateless/20+_others/20_0001_planner_v2.result +++ b/tests/suites/0_stateless/20+_others/20_0001_planner_v2.result @@ -52,6 +52,33 @@ 0 0 9 +0 0 +0 1 +1 0 +1 1 +2 0 +2 1 +0 +1 +2 +0 4 +2 3 +NULL 3 +0 0 2 +0 2 2 +1 0 2 +1 2 1 +NULL NULL 3 +0 0 2 +0 1 2 +1 0 1 +1 1 2 +NULL 2 3 +0 0 2 +1 0 2 +0 1 1 +1 1 2 +2 NULL 3 ====INNER_JOIN==== 1 1 2 2 diff --git a/tests/suites/0_stateless/20+_others/20_0001_planner_v2.sql b/tests/suites/0_stateless/20+_others/20_0001_planner_v2.sql index e1254a69b962..d9078d580926 100644 --- a/tests/suites/0_stateless/20+_others/20_0001_planner_v2.sql +++ b/tests/suites/0_stateless/20+_others/20_0001_planner_v2.sql @@ -51,6 +51,24 @@ select count(null) from numbers(1000); SELECT max(number) FROM numbers_mt (10) where number > 99999999998; SELECT max(number) FROM numbers_mt (10) where number > 2; +SELECT number%3 as c1, number%2 as c2 FROM numbers_mt(10000) where number > 2 group by number%3, number%2 order by c1,c2; +SELECT number%3 as c1 FROM numbers_mt(10) where number > 2 group by number%3 order by c1; + +CREATE TABLE t(a UInt64 null, b UInt32 null, c UInt32) Engine = Fuse; +INSERT INTO t(a,b, c) SELECT if (number % 3 = 1, null, number) as a, number + 3 as b, number + 4 as c FROM numbers(10); +-- nullable(u8) +SELECT a%3 as a1, count(1) as ct from t GROUP BY a1 ORDER BY a1,ct; + +-- nullable(u8), nullable(u8) +SELECT a%2 as a1, a%3 as a2, count(0) as ct FROM t GROUP BY a1, a2 ORDER BY a1, a2; + +-- nullable(u8), u64 +SELECT a%2 as a1, to_uint64(c % 3) as c1, count(0) as ct FROM t GROUP BY a1, c1 ORDER BY a1, c1, ct; +-- u64, nullable(u8) +SELECT to_uint64(c % 3) as c1, a%2 as a1, count(0) as ct FROM t GROUP BY a1, c1 ORDER BY a1, c1, ct; + +drop table t; + -- Inner join select '====INNER_JOIN===='; create table t(a int); From 11f8c15f52c14c1fe660c92e252e4adb2fd17877 Mon Sep 17 00:00:00 2001 From: xudong963 Date: Wed, 11 May 2022 23:22:16 +0800 Subject: [PATCH 2/3] address comments --- query/src/sql/planner/binder/aggregate.rs | 35 ++++++++++++++++---- query/src/sql/planner/binder/bind_context.rs | 12 ------- query/src/sql/planner/binder/project.rs | 4 ++- query/src/sql/planner/binder/select.rs | 18 ++++++---- 4 files changed, 42 insertions(+), 27 deletions(-) diff --git a/query/src/sql/planner/binder/aggregate.rs b/query/src/sql/planner/binder/aggregate.rs index a35fe165250e..d41938b76e72 100644 --- a/query/src/sql/planner/binder/aggregate.rs +++ b/query/src/sql/planner/binder/aggregate.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; + use common_ast::ast::Expr; use common_exception::ErrorCode; use common_exception::Result; @@ -26,11 +28,31 @@ use crate::sql::plans::Scalar; use crate::sql::plans::Scalar::BoundColumnRef; use crate::sql::BindContext; +#[derive(Clone, PartialEq, Debug)] +pub struct AggregateInfo { + /// Aggregation scalar expression + pub agg_scalar_exprs: Option>, + + /// The origin scalar expression of Group by + /// For the sql: `SELECT a%3 as a1, count(1) as ct from t GROUP BY a1`; + /// The origin scalar expression is `a%3`. + pub origin_group_by: Option>, +} + +impl AggregateInfo { + pub fn new() -> Self { + AggregateInfo { + agg_scalar_exprs: None, + origin_group_by: None, + } + } +} + impl<'a> Binder { pub(crate) fn analyze_aggregate( &self, output_context: &BindContext, - input_context: &mut BindContext, + agg_info: &mut AggregateInfo, ) -> Result<()> { let mut agg_expr: Vec = Vec::new(); for agg_scalar in find_aggregate_scalars_from_bind_context(output_context)? { @@ -58,7 +80,7 @@ impl<'a> Binder { } } } - input_context.agg_scalar_exprs = Some(agg_expr); + agg_info.agg_scalar_exprs = Some(agg_expr); Ok(()) } @@ -66,15 +88,15 @@ impl<'a> Binder { &mut self, group_by_expr: &[Expr<'a>], input_context: &mut BindContext, + agg_info: &AggregateInfo, ) -> Result<()> { let scalar_binder = ScalarBinder::new(input_context, self.ctx.clone()); let mut group_expr = Vec::with_capacity(group_by_expr.len()); for expr in group_by_expr.iter() { - let scalar_expr = scalar_binder.bind_expr(expr).await?.0; + let (scalar_expr, _) = scalar_binder.bind_expr(expr).await?; if let BoundColumnRef(bound_column) = scalar_expr { let col_name = bound_column.column.column_name.as_str(); - if input_context.origin_group_by.is_some() { - let origin_group_by = input_context.origin_group_by.as_ref().unwrap().clone(); + if let Some(origin_group_by) = agg_info.origin_group_by.as_ref() { if origin_group_by.contains_key(bound_column.column.column_name.as_str()) { // Use the origin group by expression group_expr.push( @@ -93,10 +115,9 @@ impl<'a> Binder { } group_expr.push(scalar_binder.bind_expr(expr).await?.0); } - let aggregate_plan = AggregatePlan { group_expr, - agg_expr: input_context.agg_scalar_exprs.clone().unwrap(), + agg_expr: agg_info.agg_scalar_exprs.clone().unwrap(), }; let new_expr = SExpr::create_unary( aggregate_plan.into(), diff --git a/query/src/sql/planner/binder/bind_context.rs b/query/src/sql/planner/binder/bind_context.rs index 627040d5c8b1..09908f936a39 100644 --- a/query/src/sql/planner/binder/bind_context.rs +++ b/query/src/sql/planner/binder/bind_context.rs @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::HashMap; - use common_ast::ast::Identifier; use common_ast::ast::TableAlias; use common_ast::parser::error::DisplayError as _; @@ -52,14 +50,6 @@ pub struct BindContext { /// The relational operator in current context pub expression: Option, - /// Aggregation scalar expression - pub agg_scalar_exprs: Option>, - - /// The origin scalar expression of Group by - /// For the sql: `SELECT a%3 as a1, count(1) as ct from t GROUP BY a1`; - /// The origin scalar expression is `a%3`. - pub origin_group_by: Option>, - /// Order by columnBinding, consider the sql: select sum(a) from t group by a,b order by b;, /// Order by requires not just the columns in the selection, /// but the columns of the entire table as well as the columns of the selection @@ -76,8 +66,6 @@ impl BindContext { _parent: Some(parent), columns: vec![], expression: None, - agg_scalar_exprs: None, - origin_group_by: None, order_by_columns: Some(Vec::new()), } } diff --git a/query/src/sql/planner/binder/project.rs b/query/src/sql/planner/binder/project.rs index 9b70e97545f4..ff5e8e0e49af 100644 --- a/query/src/sql/planner/binder/project.rs +++ b/query/src/sql/planner/binder/project.rs @@ -19,6 +19,7 @@ use common_ast::ast::SelectTarget; use common_exception::ErrorCode; use common_exception::Result; +use crate::sql::binder::aggregate::AggregateInfo; use crate::sql::optimizer::SExpr; use crate::sql::planner::binder::scalar::ScalarBinder; use crate::sql::planner::binder::BindContext; @@ -75,6 +76,7 @@ impl<'a> Binder { &mut self, select_list: &[SelectTarget<'a>], has_order_by: bool, + agg_info: &mut AggregateInfo, input_context: &mut BindContext, ) -> Result { let mut output_context = BindContext::new(); @@ -152,7 +154,7 @@ impl<'a> Binder { } } if !origin_group_by.is_empty() { - input_context.origin_group_by = Some(origin_group_by); + agg_info.origin_group_by = Some(origin_group_by); } Ok(output_context) } diff --git a/query/src/sql/planner/binder/select.rs b/query/src/sql/planner/binder/select.rs index a0cc0a5c6a38..77fbcb3517dc 100644 --- a/query/src/sql/planner/binder/select.rs +++ b/query/src/sql/planner/binder/select.rs @@ -28,6 +28,7 @@ use common_exception::Result; use common_planners::Expression; use crate::catalogs::CATALOG_DEFAULT; +use crate::sql::binder::aggregate::AggregateInfo; use crate::sql::binder::scalar_common::split_conjunctions; use crate::sql::optimizer::SExpr; use crate::sql::planner::binder::scalar::ScalarBinder; @@ -99,16 +100,19 @@ impl<'a> Binder { } // Output of current `SELECT` statement. - + let mut agg_info = AggregateInfo::new(); let mut output_context = self - .normalize_select_list(&stmt.select_list, has_order_by, &mut input_context) + .normalize_select_list( + &stmt.select_list, + has_order_by, + &mut agg_info, + &mut input_context, + ) .await?; - self.analyze_aggregate(&output_context, &mut input_context)?; - - if !input_context.agg_scalar_exprs.as_ref().unwrap().is_empty() || !stmt.group_by.is_empty() - { - self.bind_group_by(&stmt.group_by, &mut input_context) + self.analyze_aggregate(&output_context, &mut agg_info)?; + if !agg_info.agg_scalar_exprs.as_ref().unwrap().is_empty() || !stmt.group_by.is_empty() { + self.bind_group_by(&stmt.group_by, &mut input_context, &agg_info) .await?; output_context.expression = input_context.expression.clone(); } From 1af0054c914fe1d5888dfdb214d0760818187280 Mon Sep 17 00:00:00 2001 From: xudong963 Date: Thu, 12 May 2022 11:35:07 +0800 Subject: [PATCH 3/3] remove is_some --- query/src/sql/planner/binder/project.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/query/src/sql/planner/binder/project.rs b/query/src/sql/planner/binder/project.rs index ff5e8e0e49af..cc973715ba1c 100644 --- a/query/src/sql/planner/binder/project.rs +++ b/query/src/sql/planner/binder/project.rs @@ -133,10 +133,9 @@ impl<'a> Binder { data_type, scalar: Some(Box::new(bound_expr.clone())), }; - if alias.is_some() { + if let Some(alias) = alias { input_context.columns.push(column_binding.clone()); - origin_group_by - .insert(alias.as_ref().unwrap().name.clone(), bound_expr.clone()); + origin_group_by.insert(alias.name.clone(), bound_expr.clone()); } if has_order_by && !matches!(bound_expr, Scalar::BoundColumnRef(BoundColumnRef { .. }))