diff --git a/query/src/sql/exec/expression_builder.rs b/query/src/sql/exec/expression_builder.rs index 267646f0e5e3..58967d292796 100644 --- a/query/src/sql/exec/expression_builder.rs +++ b/query/src/sql/exec/expression_builder.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use common_datavalues::DataSchemaRef; use common_datavalues::DataValue; use common_exception::ErrorCode; use common_exception::Result; @@ -39,9 +40,21 @@ impl<'a> ExpressionBuilder<'a> { ExpressionBuilder { metadata } } - pub fn build_and_rename(&self, scalar: &Scalar, index: IndexType) -> Result { + pub fn build_and_rename( + &self, + scalar: &Scalar, + index: IndexType, + input_schema: &DataSchemaRef, + ) -> Result { let mut expr = self.build(scalar)?; - expr = self.normalize_aggr_to_col(expr)?; + // If the input_schema contains the field of expr, + // it means that its sub-plan has already processed the expression + // and we can directly convert it to column expression + if input_schema.has_field(expr.column_name().as_str()) { + expr = Expression::Column(expr.column_name()); + } else { + expr = self.normalize_aggr_to_col(expr)?; + } let column = self.metadata.column(index); Ok(Expression::Alias( format_field_name(column.name.as_str(), index), diff --git a/query/src/sql/exec/mod.rs b/query/src/sql/exec/mod.rs index eb2d2b8f97ee..bf1891373a18 100644 --- a/query/src/sql/exec/mod.rs +++ b/query/src/sql/exec/mod.rs @@ -215,7 +215,7 @@ impl PipelineBuilder { let expr_builder = ExpressionBuilder::create(&self.metadata); for item in project.items.iter() { let scalar = &item.expr; - let expression = expr_builder.build_and_rename(scalar, item.index)?; + let expression = expr_builder.build_and_rename(scalar, item.index, &input_schema)?; expressions.push(expression); } pipeline.add_transform(|transform_input_port, transform_output_port| { @@ -341,32 +341,36 @@ impl PipelineBuilder { let pre_input_schema = input_schema.clone(); let input_schema = schema_builder.build_group_by(input_schema, group_expressions.as_slice())?; - pipeline.add_transform(|transform_input_port, transform_output_port| { - ExpressionTransform::try_create( - transform_input_port, - transform_output_port, - pre_input_schema.clone(), - input_schema.clone(), - group_expressions.clone(), - self.ctx.clone(), - ) - })?; + if !input_schema.eq(&pre_input_schema) { + pipeline.add_transform(|transform_input_port, transform_output_port| { + ExpressionTransform::try_create( + transform_input_port, + transform_output_port, + pre_input_schema.clone(), + input_schema.clone(), + group_expressions.clone(), + self.ctx.clone(), + ) + })?; + } // Process aggregation function with non-column expression, such as sum(3) let pre_input_schema = input_schema.clone(); let res = schema_builder.build_agg_func(pre_input_schema.clone(), agg_expressions.as_slice())?; let input_schema = res.0; - pipeline.add_transform(|transform_input_port, transform_output_port| { - ExpressionTransform::try_create( - transform_input_port, - transform_output_port, - pre_input_schema.clone(), - input_schema.clone(), - res.1.clone(), - self.ctx.clone(), - ) - })?; + if !input_schema.eq(&pre_input_schema) { + pipeline.add_transform(|transform_input_port, transform_output_port| { + ExpressionTransform::try_create( + transform_input_port, + transform_output_port, + pre_input_schema.clone(), + input_schema.clone(), + res.1.clone(), + self.ctx.clone(), + ) + })?; + } // Get partial schema from agg_expressions let partial_data_fields = @@ -419,7 +423,6 @@ impl PipelineBuilder { self.ctx.clone(), ) })?; - Ok(final_schema) } diff --git a/query/src/sql/planner/binder/aggregate.rs b/query/src/sql/planner/binder/aggregate.rs index ee0d20ae55b1..d41938b76e72 100644 --- a/query/src/sql/planner/binder/aggregate.rs +++ b/query/src/sql/planner/binder/aggregate.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; + use common_ast::ast::Expr; use common_exception::ErrorCode; use common_exception::Result; @@ -23,13 +25,34 @@ use crate::sql::optimizer::SExpr; use crate::sql::plans::AggregateFunction; use crate::sql::plans::AggregatePlan; use crate::sql::plans::Scalar; +use crate::sql::plans::Scalar::BoundColumnRef; use crate::sql::BindContext; +#[derive(Clone, PartialEq, Debug)] +pub struct AggregateInfo { + /// Aggregation scalar expression + pub agg_scalar_exprs: Option>, + + /// The origin scalar expression of Group by + /// For the sql: `SELECT a%3 as a1, count(1) as ct from t GROUP BY a1`; + /// The origin scalar expression is `a%3`. + pub origin_group_by: Option>, +} + +impl AggregateInfo { + pub fn new() -> Self { + AggregateInfo { + agg_scalar_exprs: None, + origin_group_by: None, + } + } +} + impl<'a> Binder { pub(crate) fn analyze_aggregate( &self, output_context: &BindContext, - input_context: &mut BindContext, + agg_info: &mut AggregateInfo, ) -> Result<()> { let mut agg_expr: Vec = Vec::new(); for agg_scalar in find_aggregate_scalars_from_bind_context(output_context)? { @@ -57,7 +80,7 @@ impl<'a> Binder { } } } - input_context.agg_scalar_exprs = Some(agg_expr); + agg_info.agg_scalar_exprs = Some(agg_expr); Ok(()) } @@ -65,16 +88,36 @@ impl<'a> Binder { &mut self, group_by_expr: &[Expr<'a>], input_context: &mut BindContext, + agg_info: &AggregateInfo, ) -> Result<()> { let scalar_binder = ScalarBinder::new(input_context, self.ctx.clone()); let mut group_expr = Vec::with_capacity(group_by_expr.len()); for expr in group_by_expr.iter() { - group_expr.push(scalar_binder.bind_expr(expr).await?); + let (scalar_expr, _) = scalar_binder.bind_expr(expr).await?; + if let BoundColumnRef(bound_column) = scalar_expr { + let col_name = bound_column.column.column_name.as_str(); + if let Some(origin_group_by) = agg_info.origin_group_by.as_ref() { + if origin_group_by.contains_key(bound_column.column.column_name.as_str()) { + // Use the origin group by expression + group_expr.push( + origin_group_by + .get(col_name) + .ok_or_else(|| { + ErrorCode::SemanticError({ + format!("Not exist alias name {}", col_name) + }) + })? + .clone(), + ); + continue; + } + } + } + group_expr.push(scalar_binder.bind_expr(expr).await?.0); } - let aggregate_plan = AggregatePlan { - group_expr: group_expr.into_iter().map(|(scalar, _)| scalar).collect(), - agg_expr: input_context.agg_scalar_exprs.clone().unwrap(), + group_expr, + agg_expr: agg_info.agg_scalar_exprs.clone().unwrap(), }; let new_expr = SExpr::create_unary( aggregate_plan.into(), diff --git a/query/src/sql/planner/binder/bind_context.rs b/query/src/sql/planner/binder/bind_context.rs index e7604ec6a9b5..09908f936a39 100644 --- a/query/src/sql/planner/binder/bind_context.rs +++ b/query/src/sql/planner/binder/bind_context.rs @@ -50,9 +50,6 @@ pub struct BindContext { /// The relational operator in current context pub expression: Option, - /// Aggregation scalar expression - pub agg_scalar_exprs: Option>, - /// Order by columnBinding, consider the sql: select sum(a) from t group by a,b order by b;, /// Order by requires not just the columns in the selection, /// but the columns of the entire table as well as the columns of the selection @@ -69,7 +66,6 @@ impl BindContext { _parent: Some(parent), columns: vec![], expression: None, - agg_scalar_exprs: None, order_by_columns: Some(Vec::new()), } } diff --git a/query/src/sql/planner/binder/project.rs b/query/src/sql/planner/binder/project.rs index f740f04480a8..cc973715ba1c 100644 --- a/query/src/sql/planner/binder/project.rs +++ b/query/src/sql/planner/binder/project.rs @@ -12,11 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; + use common_ast::ast::Indirection; use common_ast::ast::SelectTarget; use common_exception::ErrorCode; use common_exception::Result; +use crate::sql::binder::aggregate::AggregateInfo; use crate::sql::optimizer::SExpr; use crate::sql::planner::binder::scalar::ScalarBinder; use crate::sql::planner::binder::BindContext; @@ -73,13 +76,15 @@ impl<'a> Binder { &mut self, select_list: &[SelectTarget<'a>], has_order_by: bool, - input_context: &BindContext, + agg_info: &mut AggregateInfo, + input_context: &mut BindContext, ) -> Result { let mut output_context = BindContext::new(); if has_order_by { output_context.order_by_columns = Some(input_context.columns.clone()); } output_context.expression = input_context.expression.clone(); + let mut origin_group_by = HashMap::new(); for select_target in select_list { match select_target { SelectTarget::QualifiedName(names) => { @@ -128,6 +133,10 @@ impl<'a> Binder { data_type, scalar: Some(Box::new(bound_expr.clone())), }; + if let Some(alias) = alias { + input_context.columns.push(column_binding.clone()); + origin_group_by.insert(alias.name.clone(), bound_expr.clone()); + } if has_order_by && !matches!(bound_expr, Scalar::BoundColumnRef(BoundColumnRef { .. })) { @@ -143,7 +152,9 @@ impl<'a> Binder { } } } - + if !origin_group_by.is_empty() { + agg_info.origin_group_by = Some(origin_group_by); + } Ok(output_context) } } diff --git a/query/src/sql/planner/binder/select.rs b/query/src/sql/planner/binder/select.rs index b45517fcfe28..77fbcb3517dc 100644 --- a/query/src/sql/planner/binder/select.rs +++ b/query/src/sql/planner/binder/select.rs @@ -28,6 +28,7 @@ use common_exception::Result; use common_planners::Expression; use crate::catalogs::CATALOG_DEFAULT; +use crate::sql::binder::aggregate::AggregateInfo; use crate::sql::binder::scalar_common::split_conjunctions; use crate::sql::optimizer::SExpr; use crate::sql::planner::binder::scalar::ScalarBinder; @@ -99,16 +100,19 @@ impl<'a> Binder { } // Output of current `SELECT` statement. - + let mut agg_info = AggregateInfo::new(); let mut output_context = self - .normalize_select_list(&stmt.select_list, has_order_by, &input_context) + .normalize_select_list( + &stmt.select_list, + has_order_by, + &mut agg_info, + &mut input_context, + ) .await?; - self.analyze_aggregate(&output_context, &mut input_context)?; - - if !input_context.agg_scalar_exprs.as_ref().unwrap().is_empty() || !stmt.group_by.is_empty() - { - self.bind_group_by(&stmt.group_by, &mut input_context) + self.analyze_aggregate(&output_context, &mut agg_info)?; + if !agg_info.agg_scalar_exprs.as_ref().unwrap().is_empty() || !stmt.group_by.is_empty() { + self.bind_group_by(&stmt.group_by, &mut input_context, &agg_info) .await?; output_context.expression = input_context.expression.clone(); } diff --git a/tests/suites/0_stateless/20+_others/20_0001_planner_v2.result b/tests/suites/0_stateless/20+_others/20_0001_planner_v2.result index 4e83c8fd60b4..4dc55c8c9d24 100644 --- a/tests/suites/0_stateless/20+_others/20_0001_planner_v2.result +++ b/tests/suites/0_stateless/20+_others/20_0001_planner_v2.result @@ -52,6 +52,33 @@ 0 0 9 +0 0 +0 1 +1 0 +1 1 +2 0 +2 1 +0 +1 +2 +0 4 +2 3 +NULL 3 +0 0 2 +0 2 2 +1 0 2 +1 2 1 +NULL NULL 3 +0 0 2 +0 1 2 +1 0 1 +1 1 2 +NULL 2 3 +0 0 2 +1 0 2 +0 1 1 +1 1 2 +2 NULL 3 ====INNER_JOIN==== 1 1 2 2 diff --git a/tests/suites/0_stateless/20+_others/20_0001_planner_v2.sql b/tests/suites/0_stateless/20+_others/20_0001_planner_v2.sql index e1254a69b962..d9078d580926 100644 --- a/tests/suites/0_stateless/20+_others/20_0001_planner_v2.sql +++ b/tests/suites/0_stateless/20+_others/20_0001_planner_v2.sql @@ -51,6 +51,24 @@ select count(null) from numbers(1000); SELECT max(number) FROM numbers_mt (10) where number > 99999999998; SELECT max(number) FROM numbers_mt (10) where number > 2; +SELECT number%3 as c1, number%2 as c2 FROM numbers_mt(10000) where number > 2 group by number%3, number%2 order by c1,c2; +SELECT number%3 as c1 FROM numbers_mt(10) where number > 2 group by number%3 order by c1; + +CREATE TABLE t(a UInt64 null, b UInt32 null, c UInt32) Engine = Fuse; +INSERT INTO t(a,b, c) SELECT if (number % 3 = 1, null, number) as a, number + 3 as b, number + 4 as c FROM numbers(10); +-- nullable(u8) +SELECT a%3 as a1, count(1) as ct from t GROUP BY a1 ORDER BY a1,ct; + +-- nullable(u8), nullable(u8) +SELECT a%2 as a1, a%3 as a2, count(0) as ct FROM t GROUP BY a1, a2 ORDER BY a1, a2; + +-- nullable(u8), u64 +SELECT a%2 as a1, to_uint64(c % 3) as c1, count(0) as ct FROM t GROUP BY a1, c1 ORDER BY a1, c1, ct; +-- u64, nullable(u8) +SELECT to_uint64(c % 3) as c1, a%2 as a1, count(0) as ct FROM t GROUP BY a1, c1 ORDER BY a1, c1, ct; + +drop table t; + -- Inner join select '====INNER_JOIN===='; create table t(a int);