From b3df480e54947589c473c9c1acbfd4a7b7c496bd Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Mon, 27 Apr 2020 14:36:52 -0700 Subject: [PATCH] Bug fix, count(distinct field) should transalte to cardinality aggregation (#442) --- .../sql/query/maker/AggMaker.java | 3 ++- .../planner/converter/SQLAggregationParser.java | 9 +++++++-- .../sql/esintgtest/AggregationIT.java | 16 ++++++++++++++++ .../converter/SQLAggregationParserTest.java | 14 ++++++++++++++ 4 files changed, 39 insertions(+), 3 deletions(-) diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/maker/AggMaker.java b/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/maker/AggMaker.java index eee1c69baf..8c918728c0 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/maker/AggMaker.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/maker/AggMaker.java @@ -15,6 +15,7 @@ package com.amazon.opendistroforelasticsearch.sql.query.maker; +import com.alibaba.druid.sql.ast.expr.SQLAggregateOption; import com.amazon.opendistroforelasticsearch.sql.domain.Condition; import com.amazon.opendistroforelasticsearch.sql.domain.Field; import com.amazon.opendistroforelasticsearch.sql.domain.KVValue; @@ -697,7 +698,7 @@ private RangeAggregationBuilder rangeBuilder(MethodField field) { private ValuesSourceAggregationBuilder makeCountAgg(MethodField field) { // Cardinality is approximate DISTINCT. - if ("DISTINCT".equals(field.getOption())) { + if (SQLAggregateOption.DISTINCT.equals(field.getOption())) { if (field.getParams().size() == 1) { return AggregationBuilders.cardinality(field.getAlias()).field(field.getParams().get(0).value diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLAggregationParser.java b/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLAggregationParser.java index e847fca83a..c4c5fa44f6 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLAggregationParser.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLAggregationParser.java @@ -17,6 +17,7 @@ import com.alibaba.druid.sql.ast.SQLExpr; import com.alibaba.druid.sql.ast.expr.SQLAggregateExpr; +import com.alibaba.druid.sql.ast.expr.SQLAggregateOption; import com.alibaba.druid.sql.ast.expr.SQLCastExpr; import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr; import com.alibaba.druid.sql.ast.expr.SQLMethodInvokeExpr; @@ -259,8 +260,12 @@ public AggregationExpr(SQLAggregateExpr expr) { public static String nameOfExpr(SQLExpr expr) { String exprName = expr.toString().toLowerCase(); if (expr instanceof SQLAggregateExpr) { - exprName = String.format("%s(%s)", ((SQLAggregateExpr) expr).getMethodName(), - ((SQLAggregateExpr) expr).getArguments().get(0)); + SQLAggregateExpr aggExpr = (SQLAggregateExpr) expr; + SQLAggregateOption option = aggExpr.getOption(); + exprName = option == null + ? String.format("%s(%s)", aggExpr.getMethodName(), aggExpr.getArguments().get(0)) + : String.format("%s(%s %s)", aggExpr.getMethodName(), option.name(), + aggExpr.getArguments().get(0)); } else if (expr instanceof SQLMethodInvokeExpr) { exprName = String.format("%s(%s)", ((SQLMethodInvokeExpr) expr).getMethodName(), nameOfExpr(((SQLMethodInvokeExpr) expr).getParameters().get(0))); diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/sql/esintgtest/AggregationIT.java b/src/test/java/com/amazon/opendistroforelasticsearch/sql/esintgtest/AggregationIT.java index 7ab11ae7df..70f75d3027 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/sql/esintgtest/AggregationIT.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/sql/esintgtest/AggregationIT.java @@ -38,6 +38,10 @@ import static com.amazon.opendistroforelasticsearch.sql.esintgtest.TestsConstants.TEST_INDEX_GAME_OF_THRONES; import static com.amazon.opendistroforelasticsearch.sql.esintgtest.TestsConstants.TEST_INDEX_NESTED_TYPE; import static com.amazon.opendistroforelasticsearch.sql.esintgtest.TestsConstants.TEST_INDEX_ONLINE; +import static com.amazon.opendistroforelasticsearch.sql.util.MatcherUtils.rows; +import static com.amazon.opendistroforelasticsearch.sql.util.MatcherUtils.schema; +import static com.amazon.opendistroforelasticsearch.sql.util.MatcherUtils.verifyDataRows; +import static com.amazon.opendistroforelasticsearch.sql.util.MatcherUtils.verifySchema; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo; @@ -64,6 +68,14 @@ public void countTest() throws IOException { Assert.assertThat(getIntAggregationValue(result, "COUNT(*)", "value"), equalTo(1000)); } + @Test + public void countDistinctTest() { + JSONObject response = executeJdbcRequest(String.format("SELECT COUNT(distinct gender) FROM %s", TEST_INDEX_ACCOUNT)); + + verifySchema(response, schema("COUNT(DISTINCT gender)", null, "integer")); + verifyDataRows(response, rows(2)); + } + @Test public void countWithDocsHintTest() throws Exception { @@ -1238,4 +1250,8 @@ private double getDoubleAggregationValue(final JSONObject queryResult, final Str return targetField.getDouble(subFieldName); } + + private JSONObject executeJdbcRequest(String query) { + return new JSONObject(executeQuery(query, "jdbc")); + } } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/sql/unittest/planner/converter/SQLAggregationParserTest.java b/src/test/java/com/amazon/opendistroforelasticsearch/sql/unittest/planner/converter/SQLAggregationParserTest.java index 0c825b5476..2d089012f0 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/sql/unittest/planner/converter/SQLAggregationParserTest.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/sql/unittest/planner/converter/SQLAggregationParserTest.java @@ -263,6 +263,20 @@ public void noGroupKeyInSelectShouldPass() { columnNode("avg(age)", null, ExpressionFactory.ref("avg_0")))); } + @Test + public void aggWithDistinctShouldPass() { + String sql = "SELECT count(distinct gender) FROM t GROUP BY age"; + SQLAggregationParser parser = new SQLAggregationParser(new ColumnTypeProvider()); + parser.parse(mYSqlSelectQueryBlock(sql)); + List sqlSelectItems = parser.selectItemList(); + List columnNodes = parser.getColumnNodes(); + + assertThat(sqlSelectItems, containsInAnyOrder( + agg("count", "gender", "count_0"))); + assertThat(columnNodes, containsInAnyOrder( + columnNode("count(distinct gender)", null, ExpressionFactory.ref("count_0")))); + } + /** * TermQueryExplainIT.testNestedSingleGroupBy */