Skip to content

Commit

Permalink
fixed tests
Browse files Browse the repository at this point in the history
Signed-off-by: chloe-zh <chloezh1102@gmail.com>
  • Loading branch information
chloe-zh committed Jun 9, 2021
1 parent 6f5350d commit e30b685
Show file tree
Hide file tree
Showing 13 changed files with 89 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,7 @@ public interface AggregationState {
*/
ExprValue result();

Set<ExprValue> distinctSet();
default Set<ExprValue> distinctValues() {
return Set.of();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ public boolean conditionValue(BindingTuple tuple) {
}

private Boolean duplicated(ExprValue value, S state) {
for (ExprValue exprValue : state.distinctSet()) {
for (ExprValue exprValue : state.distinctValues()) {
if (exprValue.compareTo(value) == 0) {
return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@

import java.util.List;
import java.util.Locale;
import java.util.Set;
import org.opensearch.sql.data.model.ExprNullValue;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
Expand Down Expand Up @@ -81,10 +80,5 @@ protected static class AvgState implements AggregationState {
public ExprValue result() {
return count == 0 ? ExprNullValue.of() : ExprValueUtils.doubleValue(total / count);
}

@Override
public Set<ExprValue> distinctSet() {
return Set.of();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

import static org.opensearch.sql.utils.ExpressionUtils.format;


import java.util.HashSet;
import java.util.List;
import java.util.Locale;
Expand Down Expand Up @@ -84,7 +83,7 @@ public ExprValue result() {
}

@Override
public Set<ExprValue> distinctSet() {
public Set<ExprValue> distinctValues() {
return set;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import static org.opensearch.sql.utils.ExpressionUtils.format;

import java.util.List;
import java.util.Set;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.expression.Expression;
Expand Down Expand Up @@ -75,10 +74,5 @@ public void max(ExprValue value) {
public ExprValue result() {
return maxResult;
}

@Override
public Set<ExprValue> distinctSet() {
return Set.of();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,5 @@ public void min(ExprValue value) {
public ExprValue result() {
return minResult;
}

@Override
public Set<ExprValue> distinctSet() {
return Set.of();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,5 @@ public void add(ExprValue value) {
public ExprValue result() {
return isEmptyCollection ? ExprNullValue.of() : sumResult;
}

@Override
public Set<ExprValue> distinctSet() {
return Set.of();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*
*/

package org.opensearch.sql.expression.aggregation;

import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

import org.junit.jupiter.api.Test;
import org.opensearch.sql.data.model.ExprIntegerValue;

public class AggregatorStateTest extends AggregationTest {

@Test
void count_distinct_values() {
CountAggregator.CountState state = new CountAggregator.CountState();
state.count(new ExprIntegerValue(1));
assertFalse(state.distinctValues().isEmpty());
}

@Test
void default_distinct_values() {
AvgAggregator.AvgState state = new AvgAggregator.AvgState();
assertTrue(state.distinctValues().isEmpty());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.opensearch.search.aggregations.metrics.CardinalityAggregationBuilder;
import org.opensearch.search.aggregations.support.ValuesSourceAggregationBuilder;
import org.opensearch.sql.exception.ExpressionEvaluationException;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.ExpressionNodeVisitor;
import org.opensearch.sql.expression.LiteralExpression;
Expand Down Expand Up @@ -92,7 +93,7 @@ public AggregationBuilder visitNamedAggregator(NamedAggregator node,
case "count":
return make(AggregationBuilders.cardinality(name), expression);
default:
throw new IllegalStateException(String.format(
throw new ExpressionEvaluationException(String.format(
"unsupported distinct aggregator %s", node.getFunctionName().getFunctionName()));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.sql.exception.ExpressionEvaluationException;
import org.opensearch.sql.expression.aggregation.AvgAggregator;
import org.opensearch.sql.expression.aggregation.CountAggregator;
import org.opensearch.sql.expression.aggregation.MaxAggregator;
Expand Down Expand Up @@ -202,6 +203,14 @@ void should_build_cardinality_aggregation() {
Collections.singletonList(ref("name", STRING)), STRING).distinct(true)))));
}

@Test
void should_throw_exception_for_unsupported_distinct_aggregator() {
assertThrows(ExpressionEvaluationException.class,
() -> buildQuery(Collections.singletonList(named("avg(distinct age)", new AvgAggregator(
Collections.singletonList(ref("name", STRING)), STRING).distinct(true)))),
"unsupported distinct aggregator avg");
}

@Test
void should_throw_exception_for_unsupported_aggregator() {
when(aggregator.getFunctionName()).thenReturn(new FunctionName("unsupported_agg"));
Expand Down
1 change: 1 addition & 0 deletions ppl/src/main/antlr/OpenSearchPPLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ statsAggTerm
statsFunction
: statsFunctionName LT_PRTHS valueExpression RT_PRTHS #statsFunctionCall
| COUNT LT_PRTHS RT_PRTHS #countAllFunctionCall
| (DISTINCT_COUNT | DC) LT_PRTHS valueExpression? RT_PRTHS #distinctCountFunctionCall
| percentileAggFunction #percentileAggFunctionCall
;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CompareExprContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CountAllFunctionCallContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DecimalLiteralContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DistinctCountFunctionCallContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalClauseContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalFunctionCallContext;
import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FieldExpressionContext;
Expand Down Expand Up @@ -203,6 +204,12 @@ public UnresolvedExpression visitCountAllFunctionCall(CountAllFunctionCallContex
return new AggregateFunction("count", AllFields.of());
}

@Override
public UnresolvedExpression visitDistinctCountFunctionCall(DistinctCountFunctionCallContext ctx) {
return new AggregateFunction("count",
ctx.valueExpression() != null ? visit(ctx.valueExpression()) : AllFields.of(), true);
}

@Override
public UnresolvedExpression visitPercentileAggFunction(PercentileAggFunctionContext ctx) {
return new AggregateFunction(ctx.PERCENTILE().getText(), visit(ctx.aggField),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import static org.opensearch.sql.ast.dsl.AstDSL.defaultFieldsArgs;
import static org.opensearch.sql.ast.dsl.AstDSL.defaultSortFieldArgs;
import static org.opensearch.sql.ast.dsl.AstDSL.defaultStatsArgs;
import static org.opensearch.sql.ast.dsl.AstDSL.distinctAggregate;
import static org.opensearch.sql.ast.dsl.AstDSL.doubleLiteral;
import static org.opensearch.sql.ast.dsl.AstDSL.equalTo;
import static org.opensearch.sql.ast.dsl.AstDSL.eval;
Expand Down Expand Up @@ -376,6 +377,35 @@ public void testCountFuncCallExpr() {
));
}

@Test
public void testDistinctCount() {
assertEqual("source=t | stats distinct_count(a)",
agg(
relation("t"),
exprList(
alias("distinct_count(a)",
distinctAggregate("count", field("a")))),
emptyList(),
emptyList(),
defaultStatsArgs()));

assertEqual("source=t | stats dc() by b",
agg(
relation("t"),
exprList(
alias(
"dc()",
distinctAggregate("count", AllFields.of())
)
),
emptyList(),
exprList(
alias("b", field("b"))
),
defaultStatsArgs()
));
}

@Test
public void testEvalFuncCallExpr() {
assertEqual("source=t | eval f=abs(a)",
Expand Down

0 comments on commit e30b685

Please sign in to comment.