Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Support ORDER BY in new SQL engine #782

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,9 @@ public LogicalPlan visitEval(Eval node, AnalysisContext context) {
@Override
public LogicalPlan visitSort(Sort node, AnalysisContext context) {
LogicalPlan child = node.getChild().get(0).accept(this, context);
ExpressionReferenceOptimizer optimizer =
new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child);

// the first option is {"count": "integer"}
Integer count = (Integer) node.getOptions().get(0).getValue().getValue();
List<Pair<SortOption, Expression>> sortList =
Expand All @@ -298,7 +301,8 @@ public LogicalPlan visitSort(Sort node, AnalysisContext context) {
sortField -> {
// the first option is {"asc": "true/false"}
Boolean asc = (Boolean) sortField.getFieldArgs().get(0).getValue().getValue();
Expression expression = expressionAnalyzer.analyze(sortField, context);
Expression expression = optimizer.optimize(
expressionAnalyzer.analyze(sortField.getField(), context), context);
return ImmutablePair.of(
asc ? SortOption.DEFAULT_ASC : SortOption.DEFAULT_DESC, expression);
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Function;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.In;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Interval;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.IntervalUnit;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Let;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Literal;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Map;
Expand Down Expand Up @@ -228,27 +227,27 @@ public static UnresolvedArgument unresolvedArg(String argName, UnresolvedExpress
return new UnresolvedArgument(argName, argValue);
}

public static UnresolvedExpression field(UnresolvedExpression field) {
public Field field(UnresolvedExpression field) {
return new Field((QualifiedName) field);
}

public static Field field(String field) {
public Field field(String field) {
return new Field(field);
}

public static UnresolvedExpression field(UnresolvedExpression field, Argument... fieldArgs) {
return new Field((QualifiedName) field, Arrays.asList(fieldArgs));
public Field field(UnresolvedExpression field, Argument... fieldArgs) {
return new Field(field, Arrays.asList(fieldArgs));
}

public static Field field(String field, Argument... fieldArgs) {
public Field field(String field, Argument... fieldArgs) {
return new Field(field, Arrays.asList(fieldArgs));
}

public static UnresolvedExpression field(UnresolvedExpression field, List<Argument> fieldArgs) {
return new Field((QualifiedName) field, fieldArgs);
public Field field(UnresolvedExpression field, List<Argument> fieldArgs) {
return new Field(field, fieldArgs);
}

public static Field field(String field, List<Argument> fieldArgs) {
public Field field(String field, List<Argument> fieldArgs) {
return new Field(field, fieldArgs);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
@EqualsAndHashCode(callSuper = false)
@AllArgsConstructor
public class Field extends UnresolvedExpression {
private QualifiedName field;
private UnresolvedExpression field;
private List<Argument> fieldArgs = Collections.emptyList();

public Field(QualifiedName field) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

import com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL;
import com.amazon.opendistroforelasticsearch.sql.ast.tree.RareTopN.CommandType;
import com.amazon.opendistroforelasticsearch.sql.ast.tree.Sort;
import com.amazon.opendistroforelasticsearch.sql.exception.SemanticCheckException;
import com.amazon.opendistroforelasticsearch.sql.expression.DSL;
import com.amazon.opendistroforelasticsearch.sql.expression.config.ExpressionConfig;
Expand All @@ -49,6 +50,7 @@
import com.google.common.collect.ImmutableMap;
import java.util.Collections;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
Expand Down Expand Up @@ -255,6 +257,42 @@ public void project_values() {
);
}

/**
 * Analyzes an AST whose sort field is an aggregate function call and checks that the
 * resulting logical plan sorts by a reference to the aggregation output instead.
 */
@SuppressWarnings("unchecked")
@Test
public void sort_with_aggregator() {
assertAnalyzeEqual(
// Expected logical plan: project over sort over aggregation over relation "test".
LogicalPlanDSL.project(
LogicalPlanDSL.sort(
LogicalPlanDSL.aggregation(
LogicalPlanDSL.relation("test"),
ImmutableList.of(
DSL.named(
"avg(integer_value)",
dsl.avg(DSL.ref("integer_value", INTEGER)))),
ImmutableList.of(DSL.named("string_value", DSL.ref("string_value", STRING)))),
// presumably the sort count, mirroring the "count" argument of the AST sort below
0,
// Aggregator in Sort AST node is replaced with reference by expression optimizer
Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("avg(integer_value)", DOUBLE))),
DSL.named("string_value", DSL.ref("string_value", STRING))),
// Input AST: the sort field is the avg(integer_value) function call itself.
AstDSL.project(
AstDSL.sort(
AstDSL.agg(
AstDSL.relation("test"),
ImmutableList.of(
AstDSL.alias(
"avg(integer_value)",
function("avg", qualifiedName("integer_value")))),
emptyList(),
ImmutableList.of(AstDSL.alias("string_value", qualifiedName("string_value"))),
emptyList()
),
ImmutableList.of(argument("count", intLiteral(0))),
// Sort field wraps the same function call as the aggregator above, ascending.
field(
function("avg", qualifiedName("integer_value")),
argument("asc", booleanLiteral(true)))),
AstDSL.alias("string_value", qualifiedName("string_value"))));
}

@SuppressWarnings("unchecked")
@Test
public void window_function() {
Expand Down
1 change: 1 addition & 0 deletions docs/category.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"user/dql/expressions.rst",
"user/general/identifiers.rst",
"user/general/values.rst",
"user/dql/basics.rst",
"user/dql/functions.rst",
"user/dql/window.rst",
"user/beyond/partiql.rst",
Expand Down
24 changes: 24 additions & 0 deletions docs/user/dql/basics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -929,6 +929,30 @@ Result set:
| Quility|
+--------+

Example 3: Ordering by Aggregate Functions
------------------------------------------

Aggregate functions can be used in the ``ORDER BY`` clause. You can reference an aggregate either by repeating the same function call, or by its alias or ordinal position in the select list::

od> SELECT gender, MAX(age) FROM accounts GROUP BY gender ORDER BY MAX(age) DESC;
fetched rows / total rows = 2/2
+----------+------------+
| gender | MAX(age) |
|----------+------------|
| M | 36 |
| F | 28 |
+----------+------------+

An aggregate function can also be used in ``ORDER BY`` even if it is not present in the ``SELECT`` clause::

od> SELECT gender, MIN(age) FROM accounts GROUP BY gender ORDER BY MAX(age) DESC;
fetched rows / total rows = 2/2
+----------+------------+
| gender | MIN(age) |
|----------+------------|
| M | 32 |
| F | 28 |
+----------+------------+

LIMIT
=====
Expand Down
3 changes: 3 additions & 0 deletions integ-test/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ task integTestWithNewEngine(type: RestIntegTestTask) {

// Skip this IT to avoid breaking tests due to inconsistency in JDBC schema
exclude 'com/amazon/opendistroforelasticsearch/sql/legacy/AggregationExpressionIT.class'

// Skip this IT because all assertions are against explain output
exclude 'com/amazon/opendistroforelasticsearch/sql/legacy/OrderIT.class'
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ public void insert(String tableName, String[] columnNames, List<Object[]> batch)
public DBResult select(String query) {
try (Statement stmt = connection.createStatement()) {
ResultSet resultSet = stmt.executeQuery(query);
DBResult result = new DBResult(databaseName);
DBResult result = isOrderByQuery(query)
? DBResult.resultInOrder(databaseName) : DBResult.result(databaseName);
populateMetaData(resultSet, result);
populateData(resultSet, result);
return result;
Expand Down Expand Up @@ -200,6 +201,10 @@ private String mapToJDBCType(String esType) {
}
}

/**
 * Check whether the given query text contains an ORDER BY clause.
 *
 * <p>The match is case-insensitive and accepts any run of whitespace (spaces, tabs,
 * line breaks) between the two keywords, so multi-line queries are recognized as well.
 * Both keywords must be whole words, so text such as "REORDER BY" does not count.
 * This is a text heuristic only: the query is not parsed, so the keywords are also found
 * inside a string literal, a comment or a nested clause.
 *
 * @param query SQL query text
 * @return true if the ORDER BY keywords are found in the query, false otherwise
 */
private boolean isOrderByQuery(String query) {
  // Fully qualified names keep this method independent of the file's import list.
  return java.util.regex.Pattern
      .compile("\\bORDER\\s+BY\\b", java.util.regex.Pattern.CASE_INSENSITIVE)
      .matcher(query)
      .find();
}

/**
* Setter for unit test mock
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@
package com.amazon.opendistroforelasticsearch.sql.correctness.runner.resultset;

import com.amazon.opendistroforelasticsearch.sql.legacy.utils.StringUtils;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
Expand Down Expand Up @@ -60,11 +59,19 @@ public class DBResult {
private final Collection<Row> dataRows;

/**
* By default treat both columns and data rows in order. This makes sense for typical query
* with specific column names in SELECT but without ORDER BY.
* In theory, a result set is a multi-set (bag) that allows duplicate and doesn't
* have order.
*/
public DBResult(String databaseName) {
this(databaseName, new ArrayList<>(), new HashSet<>());
public static DBResult result(String databaseName) {
return new DBResult(databaseName, new ArrayList<>(), HashMultiset.create());
}

/**
 * Create an empty result that keeps its data rows in a list. This is meant for queries
 * with an ORDER BY clause, where the original order of data rows has to be preserved
 * so that it can be checked.
 */
public static DBResult resultInOrder(String databaseName) {
  Collection<Type> schema = new ArrayList<>();
  Collection<Row> orderedRows = new ArrayList<>();
  return new DBResult(databaseName, schema, orderedRows);
}

public DBResult(String databaseName, Collection<Type> schema, Collection<Row> rows) {
Expand Down Expand Up @@ -97,10 +104,13 @@ public String getDatabaseName() {
}

/**
* Flatten for simplifying json generated
* Flatten for simplifying json generated.
*/
public Collection<Collection<Object>> getDataRows() {
return dataRows.stream().map(Row::getValues).collect(Collectors.toSet());
Collection<Collection<Object>> values = isDataRowOrdered()
? new ArrayList<>() : HashMultiset.create();
dataRows.stream().map(Row::getValues).forEach(values::add);
return values;
}

/**
Expand All @@ -124,6 +134,9 @@ private String diffSchema(DBResult other) {
}

private String diffDataRows(DBResult other) {
if (isDataRowOrdered()) {
return diff("Data row", (List<?>) dataRows, (List<?>) other.dataRows);
}
List<Row> thisRows = sort(dataRows);
List<Row> otherRows = sort(other.dataRows);
return diff("Data row", thisRows, otherRows);
Expand Down Expand Up @@ -160,6 +173,14 @@ private static int findFirstDifference(List<?> list1, List<?> list2) {
return -1;
}

/**
 * Tell whether the data rows are held in a {@link List}. A list represents the original
 * order of the data set, so such rows do not need to be (and should not be) sorted again.
 */
private boolean isDataRowOrdered() {
  return dataRows instanceof List;
}

/**
* Convert a collection to list and sort and return this new list.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.amazon.opendistroforelasticsearch.sql.correctness.runner.resultset.DBResult;
import com.amazon.opendistroforelasticsearch.sql.correctness.runner.resultset.Row;
import com.amazon.opendistroforelasticsearch.sql.correctness.runner.resultset.Type;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.Arrays;
Expand Down Expand Up @@ -49,6 +50,36 @@ public void dbResultWithDifferentColumnShouldNotEqual() {
assertNotEquals(result1, result2);
}

@Test
public void dbResultWithSameRowsInDifferentOrderShouldEqual() {
  // Two results built with DBResult.result(), holding the same rows added in opposite order.
  DBResult forward = DBResult.result("DB 1");
  forward.addColumn("name", "VARCHAR");
  forward.addRow(new Row(ImmutableList.of("test-1")));
  forward.addRow(new Row(ImmutableList.of("test-2")));

  DBResult backward = DBResult.result("DB 2");
  backward.addColumn("name", "VARCHAR");
  backward.addRow(new Row(ImmutableList.of("test-2")));
  backward.addRow(new Row(ImmutableList.of("test-1")));

  // Insertion order must not affect equality here.
  assertEquals(forward, backward);
}

@Test
public void dbResultInOrderWithSameRowsInDifferentOrderShouldNotEqual() {
  // Two results built with DBResult.resultInOrder(), holding the same rows in opposite order.
  DBResult forward = DBResult.resultInOrder("DB 1");
  forward.addColumn("name", "VARCHAR");
  forward.addRow(new Row(ImmutableList.of("test-1")));
  forward.addRow(new Row(ImmutableList.of("test-2")));

  DBResult backward = DBResult.resultInOrder("DB 2");
  backward.addColumn("name", "VARCHAR");
  backward.addRow(new Row(ImmutableList.of("test-2")));
  backward.addRow(new Row(ImmutableList.of("test-1")));

  // With row order preserved, a different order makes the results unequal.
  assertNotEquals(forward, backward);
}

@Test
public void dbResultWithDifferentColumnTypeShouldNotEqual() {
DBResult result1 = new DBResult("DB 1", Arrays.asList(new Type("age", "FLOAT")), emptyList());
Expand Down Expand Up @@ -89,4 +120,22 @@ public void shouldExplainDataRowsDifference() {
);
}

@Test
public void shouldExplainDataRowsOrderDifference() {
  // Same two rows in both ordered results, but in swapped positions.
  DBResult helloFirst = DBResult.resultInOrder("DB 1");
  helloFirst.addColumn("name", "VARCHAR");
  helloFirst.addRow(new Row(ImmutableList.of("hello")));
  helloFirst.addRow(new Row(ImmutableList.of("world")));

  DBResult worldFirst = DBResult.resultInOrder("DB 2");
  worldFirst.addColumn("name", "VARCHAR");
  worldFirst.addRow(new Row(ImmutableList.of("world")));
  worldFirst.addRow(new Row(ImmutableList.of("hello")));

  // The explanation is expected to point at the first position where the rows differ.
  String explanation = helloFirst.diff(worldFirst);
  assertEquals(
      "Data row at [0] is different: this=[Row(values=[hello])], other=[Row(values=[world])]",
      explanation
  );
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import com.amazon.opendistroforelasticsearch.sql.correctness.runner.connection.JDBCConnection;
import com.amazon.opendistroforelasticsearch.sql.correctness.runner.resultset.DBResult;
import com.amazon.opendistroforelasticsearch.sql.correctness.runner.resultset.Type;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;
import java.sql.Connection;
Expand Down Expand Up @@ -123,10 +125,10 @@ public void testSelectQuery() throws SQLException {
result.getSchema()
);
assertEquals(
Sets.newHashSet(
HashMultiset.create(ImmutableList.of(
Arrays.asList("John", 25),
Arrays.asList("Hank", 30)
),
)),
result.getDataRows()
);
}
Expand Down Expand Up @@ -170,11 +172,11 @@ public void testSelectQueryWithFloatInResultSet() throws SQLException {
result.getSchema()
);
assertEquals(
Sets.newHashSet(
HashMultiset.create(ImmutableList.of(
Arrays.asList("John", 25.13),
Arrays.asList("Hank", 30.46),
Arrays.asList("Allen", 15.1)
),
)),
result.getDataRows()
);
}
Expand Down
Loading