diff --git a/src/query/ast/src/lib.rs b/src/query/ast/src/lib.rs
index 0d6962f7b8e3..ea7564e8ff3d 100644
--- a/src/query/ast/src/lib.rs
+++ b/src/query/ast/src/lib.rs
@@ -28,6 +28,7 @@ pub use visitors::walk_expr;
 pub use visitors::walk_expr_mut;
 pub use visitors::walk_query;
 pub use visitors::walk_query_mut;
+pub use visitors::walk_statement_mut;
 pub use visitors::Visitor;
 pub use visitors::VisitorMut;
 
diff --git a/src/query/ast/src/visitors/visitor_mut.rs b/src/query/ast/src/visitors/visitor_mut.rs
index 6678d756fb69..f0ac1220a466 100644
--- a/src/query/ast/src/visitors/visitor_mut.rs
+++ b/src/query/ast/src/visitors/visitor_mut.rs
@@ -344,7 +344,10 @@ pub trait VisitorMut: Sized {
         walk_query_mut(self, query);
     }
 
-    fn visit_explain(&mut self, _kind: &mut ExplainKind, _query: &mut Statement<'_>) {}
+    /// Walks the statement wrapped by `EXPLAIN`, so visitors also reach its contents.
+    fn visit_explain(&mut self, _kind: &mut ExplainKind, stmt: &mut Statement<'_>) {
+        walk_statement_mut(self, stmt);
+    }
 
     fn visit_copy(&mut self, _copy: &mut CopyStmt<'_>) {}
 
diff --git a/src/query/sql/src/planner/planner.rs b/src/query/sql/src/planner/planner.rs
index 1c9080ffc260..14b69a79dd9e 100644
--- a/src/query/sql/src/planner/planner.rs
+++ b/src/query/sql/src/planner/planner.rs
@@ -14,16 +14,19 @@
 
 use std::sync::Arc;
 
+use common_ast::ast::Statement;
 use common_ast::parser::parse_sql;
 use common_ast::parser::token::Token;
 use common_ast::parser::token::TokenKind;
 use common_ast::parser::token::Tokenizer;
+use common_ast::walk_statement_mut;
 use common_ast::Backtrace;
 use common_catalog::catalog::CatalogManager;
 use common_catalog::table_context::TableContext;
 use common_exception::Result;
 use parking_lot::RwLock;
 
+use super::semantic::DistinctToGroupBy;
 use crate::optimizer::optimize;
 use crate::optimizer::OptimizerConfig;
 use crate::optimizer::OptimizerContext;
@@ -76,8 +79,8 @@ impl Planner {
             let res = async {
                 // Step 2: Parse the SQL.
                 let backtrace = Backtrace::new();
-                let (stmt, format) = parse_sql(&tokens, sql_dialect, &backtrace)?;
-
+                let (mut stmt, format) = parse_sql(&tokens, sql_dialect, &backtrace)?;
+                replace_stmt(&mut stmt);
                 // Step 3: Bind AST with catalog, and generate a pure logical SExpr
                 let metadata = Arc::new(RwLock::new(Metadata::default()));
                 let name_resolution_ctx = NameResolutionContext::try_from(settings.as_ref())?;
@@ -126,3 +129,10 @@ impl Planner {
         }
     }
 }
+
+/// Runs the AST-level rewrite passes over the freshly parsed statement, in place,
+/// before it is bound.
+fn replace_stmt(stmt: &mut Statement) {
+    let mut rewriter = DistinctToGroupBy::default();
+    walk_statement_mut(&mut rewriter, stmt);
+}
diff --git a/src/query/sql/src/planner/semantic/distinct_to_groupby.rs b/src/query/sql/src/planner/semantic/distinct_to_groupby.rs
new file mode 100644
index 000000000000..59ecdf001d27
--- /dev/null
+++ b/src/query/sql/src/planner/semantic/distinct_to_groupby.rs
@@ -0,0 +1,125 @@
+// Copyright 2022 Datafuse Labs.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use common_ast::ast::Expr;
+use common_ast::ast::Identifier;
+use common_ast::ast::Query;
+use common_ast::ast::SelectStmt;
+use common_ast::ast::SelectTarget;
+use common_ast::ast::SetExpr;
+use common_ast::ast::TableReference;
+use common_ast::VisitorMut;
+
+/// Statement rewriter that turns a lone `COUNT(DISTINCT ..)` / `count_distinct(..)`
+/// select into a plain `count()` over a subquery that groups by the same arguments.
+#[derive(Debug, Clone, Default)]
+pub struct DistinctToGroupBy {}
+
+impl VisitorMut for DistinctToGroupBy {
+    /// Rewrites `stmt` in place when it has exactly one select target, one `FROM`
+    /// item, no `GROUP BY`/`HAVING`, and that target is a distinct count whose
+    /// argument list is non-empty and free of literals; anything else is left untouched.
+    fn visit_select_stmt(&mut self, stmt: &mut SelectStmt<'_>) {
+        let SelectStmt {
+            select_list,
+            from,
+            selection,
+            group_by,
+            having,
+            ..
+        } = stmt;
+
+        // A `HAVING` clause is not rewritten: it would stay on the outer query, where
+        // the columns it refers to may no longer be visible behind the subquery.
+        if group_by.is_empty() && having.is_none() && select_list.len() == 1 && from.len() == 1 {
+            if let SelectTarget::AliasedExpr {
+                expr:
+                    box Expr::FunctionCall {
+                        span,
+                        distinct,
+                        name,
+                        args,
+                        ..
+                    },
+                alias,
+            } = &select_list[0]
+            {
+                let is_count_distinct = (name.name.eq_ignore_ascii_case("count") && *distinct)
+                    || name.name.eq_ignore_ascii_case("count_distinct");
+                // Without arguments there is nothing to group by, so the original call
+                // is left as is.
+                let groupable = !args.is_empty()
+                    && args.iter().all(|arg| !matches!(arg, Expr::Literal { .. }));
+
+                // NOTE(review): `COUNT(DISTINCT a)` ignores NULLs, whereas grouping by `a`
+                // yields a NULL group that `count()` then counts. Confirm how nullable
+                // arguments should be handled before relying on this rewrite.
+                if is_count_distinct && groupable {
+                    // `first()` instead of `span[0]`: an empty span skips the rewrite
+                    // rather than panicking.
+                    if let Some(name_token) = span.first() {
+                        let subquery = Query {
+                            span: &[],
+                            with: None,
+                            body: SetExpr::Select(Box::new(SelectStmt {
+                                span: &[],
+                                distinct: false,
+                                select_list: vec![],
+                                from: from.clone(),
+                                // The filter has to run before grouping, on the original
+                                // `FROM` item, so it moves into the subquery.
+                                selection: selection.clone(),
+                                group_by: args.clone(),
+                                having: None,
+                            })),
+                            order_by: vec![],
+                            limit: vec![],
+                            offset: None,
+                            ignore_result: false,
+                        };
+
+                        let new_stmt = SelectStmt {
+                            span: &[],
+                            distinct: false,
+                            select_list: vec![SelectTarget::AliasedExpr {
+                                expr: Box::new(Expr::FunctionCall {
+                                    span: &[],
+                                    distinct: false,
+                                    name: Identifier {
+                                        name: "count".to_string(),
+                                        quote: None,
+                                        span: name_token.clone(),
+                                    },
+                                    args: vec![],
+                                    params: vec![],
+                                }),
+                                alias: alias.clone(),
+                            }],
+                            from: vec![TableReference::Subquery {
+                                span: &[],
+                                subquery: Box::new(subquery),
+                                alias: None,
+                            }],
+                            selection: None,
+                            group_by: vec![],
+                            having: None,
+                        };
+
+                        *stmt = new_stmt;
+                    }
+                }
+            }
+        }
+    }
+}
diff --git a/src/query/sql/src/planner/semantic/mod.rs b/src/query/sql/src/planner/semantic/mod.rs
index 536338dfb984..a6558d17f280 100644
--- a/src/query/sql/src/planner/semantic/mod.rs
+++ b/src/query/sql/src/planner/semantic/mod.rs
@@ -12,11 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+mod distinct_to_groupby; mod grouping_check; mod lowering; mod name_resolution; mod type_check; +pub use distinct_to_groupby::DistinctToGroupBy; pub use grouping_check::GroupingChecker; pub use name_resolution::compare_table_name; pub use name_resolution::normalize_identifier; diff --git a/tests/sqllogictests/suites/mode/standalone/explain/explain.test b/tests/sqllogictests/suites/mode/standalone/explain/explain.test index e4a1c9517551..af66a77b82cc 100644 --- a/tests/sqllogictests/suites/mode/standalone/explain/explain.test +++ b/tests/sqllogictests/suites/mode/standalone/explain/explain.test @@ -768,6 +768,70 @@ HashJoin ├── push downs: [filters: [], limit: NONE] └── estimated rows: 5.00 +query T +explain select count(distinct a) from t1; +---- +EvalScalar +├── expressions: [count() (#3)] +├── estimated rows: 1.00 +└── AggregateFinal + ├── group by: [] + ├── aggregate functions: [count()] + ├── estimated rows: 1.00 + └── AggregatePartial + ├── group by: [] + ├── aggregate functions: [count()] + ├── estimated rows: 1.00 + └── AggregateFinal + ├── group by: [a] + ├── aggregate functions: [] + ├── estimated rows: 1.00 + └── AggregatePartial + ├── group by: [a] + ├── aggregate functions: [] + ├── estimated rows: 1.00 + └── TableScan + ├── table: default.default.t1 + ├── read rows: 1 + ├── read bytes: 31 + ├── partitions total: 1 + ├── partitions scanned: 1 + ├── push downs: [filters: [], limit: NONE] + ├── output columns: [0] + └── estimated rows: 1.00 + +query T +explain select count_distinct(a) from t1; +---- +EvalScalar +├── expressions: [count() (#3)] +├── estimated rows: 1.00 +└── AggregateFinal + ├── group by: [] + ├── aggregate functions: [count()] + ├── estimated rows: 1.00 + └── AggregatePartial + ├── group by: [] + ├── aggregate functions: [count()] + ├── estimated rows: 1.00 + └── AggregateFinal + ├── group by: [a] + ├── aggregate functions: [] + ├── estimated rows: 1.00 + └── AggregatePartial + ├── group by: [a] + ├── aggregate functions: [] + ├── 
estimated rows: 1.00 + └── TableScan + ├── table: default.default.t1 + ├── read rows: 1 + ├── read bytes: 31 + ├── partitions total: 1 + ├── partitions scanned: 1 + ├── push downs: [filters: [], limit: NONE] + ├── output columns: [0] + └── estimated rows: 1.00 + statement ok drop table t1