diff --git a/README.md b/README.md index 622b943281..a1e0ba02ae 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ Here is a documentation list with features only available in this improved SQL q * [Aggregations](./docs/user/dql/aggregations.rst): aggregation over expression and more other features * [Complex queries](./docs/user/dql/complex.rst) * Improvement on Subqueries in FROM clause -* [Window functions](./docs/user/dql/window.rst): ranking window function support +* [Window functions](./docs/user/dql/window.rst): ranking and aggregate window function support To avoid impact on your side, normally you won't see any difference in query response. If you want to check if and why your query falls back to be handled by old SQL engine, please explain your query and check Elasticsearch log for "Request is falling back to old SQL engine due to ...". diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionAnalyzer.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionAnalyzer.java index a7ab7e9702..4738d25b74 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionAnalyzer.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionAnalyzer.java @@ -44,12 +44,15 @@ import com.amazon.opendistroforelasticsearch.sql.expression.DSL; import com.amazon.opendistroforelasticsearch.sql.expression.Expression; import com.amazon.opendistroforelasticsearch.sql.expression.ReferenceExpression; +import com.amazon.opendistroforelasticsearch.sql.expression.aggregation.AggregationState; import com.amazon.opendistroforelasticsearch.sql.expression.aggregation.Aggregator; import com.amazon.opendistroforelasticsearch.sql.expression.conditional.cases.CaseClause; import com.amazon.opendistroforelasticsearch.sql.expression.conditional.cases.WhenClause; import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName; import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionRepository; import com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionName; +import com.amazon.opendistroforelasticsearch.sql.expression.window.aggregation.AggregateWindowFunction; +import com.amazon.opendistroforelasticsearch.sql.expression.window.ranking.RankingWindowFunction; import com.google.common.collect.ImmutableSet; import java.util.ArrayList; import java.util.Arrays; @@ -166,9 +169,15 @@ public Expression visitFunction(Function node, AnalysisContext context) { return (Expression) repository.compile(functionName, arguments); } + @SuppressWarnings("unchecked") @Override public Expression visitWindowFunction(WindowFunction node, AnalysisContext context) { - return visitFunction(node.getFunction(), context); + Expression expr = node.getFunction().accept(this, context); + // Wrap regular aggregator by aggregate window function to adapt window operator use + if (expr instanceof Aggregator) { + return new AggregateWindowFunction((Aggregator) expr); + } + return expr; } @Override diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionReferenceOptimizer.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionReferenceOptimizer.java index b98c7be53e..eb837dbd26 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionReferenceOptimizer.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionReferenceOptimizer.java @@ -20,6 +20,7 @@ import com.amazon.opendistroforelasticsearch.sql.expression.Expression; import com.amazon.opendistroforelasticsearch.sql.expression.ExpressionNodeVisitor; import com.amazon.opendistroforelasticsearch.sql.expression.FunctionExpression; +import com.amazon.opendistroforelasticsearch.sql.expression.NamedExpression; import com.amazon.opendistroforelasticsearch.sql.expression.ReferenceExpression; import com.amazon.opendistroforelasticsearch.sql.expression.aggregation.Aggregator; import com.amazon.opendistroforelasticsearch.sql.expression.conditional.cases.CaseClause; @@ -89,6 +90,14 @@ public Expression visitAggregator(Aggregator node, AnalysisContext context) { return expressionMap.getOrDefault(node, node); } + @Override + public Expression visitNamed(NamedExpression node, AnalysisContext context) { + if (expressionMap.containsKey(node)) { + return expressionMap.get(node); + } + return node.getDelegated().accept(this, context); + } + /** * Implement this because Case/When is not registered in function repository. */ @@ -145,7 +154,7 @@ public Void visitAggregation(LogicalAggregation plan, Void context) { public Void visitWindow(LogicalWindow plan, Void context) { Expression windowFunc = plan.getWindowFunction(); expressionMap.put(windowFunc, - new ReferenceExpression(windowFunc.toString(), windowFunc.type())); + new ReferenceExpression(((NamedExpression) windowFunc).getName(), windowFunc.type())); return visitNode(plan, context); } } diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/SelectExpressionAnalyzer.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/SelectExpressionAnalyzer.java index 8a07d847b8..a949e6fcf3 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/SelectExpressionAnalyzer.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/SelectExpressionAnalyzer.java @@ -91,7 +91,15 @@ public List visitAlias(Alias node, AnalysisContext context) { private Expression referenceIfSymbolDefined(Alias expr, AnalysisContext context) { UnresolvedExpression delegatedExpr = expr.getDelegated(); - return optimizer.optimize(delegatedExpr.accept(expressionAnalyzer, context), context); + + // Pass named expression because expression like window function loses full name + // (OVER clause) and thus depends on name in alias to be replaced correctly + return optimizer.optimize( + DSL.named( + expr.getName(), + delegatedExpr.accept(expressionAnalyzer, context), + expr.getAlias()), + context); } @Override diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/WindowExpressionAnalyzer.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/WindowExpressionAnalyzer.java index d5fbe1b19b..acec2adfd7 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/WindowExpressionAnalyzer.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/WindowExpressionAnalyzer.java @@ -27,6 +27,7 @@ import com.amazon.opendistroforelasticsearch.sql.ast.expression.WindowFunction; import com.amazon.opendistroforelasticsearch.sql.ast.tree.Sort.SortOption; import com.amazon.opendistroforelasticsearch.sql.expression.Expression; +import com.amazon.opendistroforelasticsearch.sql.expression.NamedExpression; import com.amazon.opendistroforelasticsearch.sql.expression.window.WindowDefinition; import com.amazon.opendistroforelasticsearch.sql.planner.logical.LogicalPlan; import com.amazon.opendistroforelasticsearch.sql.planner.logical.LogicalSort; @@ -68,19 +69,26 @@ public LogicalPlan analyze(UnresolvedExpression projectItem, AnalysisContext con @Override public LogicalPlan visitAlias(Alias node, AnalysisContext context) { - return node.getDelegated().accept(this, context); - } + if (!(node.getDelegated() instanceof WindowFunction)) { + return null; + } + + WindowFunction unresolved = (WindowFunction) node.getDelegated(); + Expression windowFunction = expressionAnalyzer.analyze(unresolved, context); + List partitionByList = analyzePartitionList(unresolved, context); + List> sortList = analyzeSortList(unresolved, context); - @Override - public LogicalPlan visitWindowFunction(WindowFunction node, AnalysisContext context) { - Expression windowFunction = expressionAnalyzer.analyze(node, context); - List partitionByList = analyzePartitionList(node, context); - List> sortList = analyzeSortList(node, context); WindowDefinition windowDefinition = new WindowDefinition(partitionByList, sortList); + NamedExpression namedWindowFunction = + new NamedExpression(node.getName(), windowFunction, node.getAlias()); + List> allSortItems = windowDefinition.getAllSortItems(); + if (allSortItems.isEmpty()) { + return new LogicalWindow(child, namedWindowFunction, windowDefinition); + } return new LogicalWindow( - new LogicalSort(child, windowDefinition.getAllSortItems()), - windowFunction, + new LogicalSort(child, allSortItems), + namedWindowFunction, windowDefinition); } diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/dsl/AstDSL.java index 2c53b5aa0c..dffcd4c0f1 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/dsl/AstDSL.java @@ -236,7 +236,7 @@ public When when(UnresolvedExpression condition, UnresolvedExpression result) { return new When(condition, result); } - public UnresolvedExpression window(Function function, + public UnresolvedExpression window(UnresolvedExpression function, List partitionByList, List> sortList) { return new WindowFunction(function, partitionByList, sortList); diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/expression/WindowFunction.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/expression/WindowFunction.java index 976be0c48f..c886ebe929 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/expression/WindowFunction.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/expression/WindowFunction.java @@ -19,7 +19,7 @@ import com.amazon.opendistroforelasticsearch.sql.ast.AbstractNodeVisitor; import com.amazon.opendistroforelasticsearch.sql.ast.Node; import com.amazon.opendistroforelasticsearch.sql.ast.tree.Sort.SortOption; -import java.util.Collections; +import com.google.common.collect.ImmutableList; import java.util.List; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; @@ -35,13 +35,17 @@ @ToString public class WindowFunction extends UnresolvedExpression { - private final Function function; + private final UnresolvedExpression function; private List partitionByList; private List> sortList; @Override public List getChild() { - return Collections.singletonList(function); + ImmutableList.Builder children = ImmutableList.builder(); + children.add(function); + children.addAll(partitionByList); + sortList.forEach(pair -> children.add(pair.getRight())); + return children.build(); } @Override diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/WindowFunctionExpression.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/WindowFunctionExpression.java new file mode 100644 index 0000000000..f22dcd9ba5 --- /dev/null +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/WindowFunctionExpression.java @@ -0,0 +1,39 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + +package com.amazon.opendistroforelasticsearch.sql.expression.window; + +import com.amazon.opendistroforelasticsearch.sql.expression.Expression; +import com.amazon.opendistroforelasticsearch.sql.expression.window.frame.WindowFrame; + +/** + * Window function abstraction. + */ +public interface WindowFunctionExpression extends Expression { + + /** + * Create specific window frame based on window definition and what's current window function. + * For now two types of cumulative window frame is returned: + * 1. Ranking window functions: ignore frame definition and always operates on + * previous and current row. + * 2. Aggregate window functions: frame partition into peers and sliding window is not supported. + * + * @param definition window definition + * @return window frame + */ + WindowFrame createWindowFrame(WindowDefinition definition); + +} diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/aggregation/AggregateWindowFunction.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/aggregation/AggregateWindowFunction.java new file mode 100644 index 0000000000..8d04bf6039 --- /dev/null +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/aggregation/AggregateWindowFunction.java @@ -0,0 +1,78 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + +package com.amazon.opendistroforelasticsearch.sql.expression.window.aggregation; + +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue; +import com.amazon.opendistroforelasticsearch.sql.data.type.ExprType; +import com.amazon.opendistroforelasticsearch.sql.expression.Expression; +import com.amazon.opendistroforelasticsearch.sql.expression.ExpressionNodeVisitor; +import com.amazon.opendistroforelasticsearch.sql.expression.aggregation.AggregationState; +import com.amazon.opendistroforelasticsearch.sql.expression.aggregation.Aggregator; +import com.amazon.opendistroforelasticsearch.sql.expression.env.Environment; +import com.amazon.opendistroforelasticsearch.sql.expression.window.WindowDefinition; +import com.amazon.opendistroforelasticsearch.sql.expression.window.WindowFunctionExpression; +import com.amazon.opendistroforelasticsearch.sql.expression.window.frame.PeerRowsWindowFrame; +import com.amazon.opendistroforelasticsearch.sql.expression.window.frame.WindowFrame; +import java.util.List; +import lombok.EqualsAndHashCode; +import lombok.RequiredArgsConstructor; + +/** + * Aggregate function adapter that adapts Aggregator for window operator use. + */ +@EqualsAndHashCode +@RequiredArgsConstructor +public class AggregateWindowFunction implements WindowFunctionExpression { + + private final Aggregator aggregator; + private AggregationState state; + + @Override + public WindowFrame createWindowFrame(WindowDefinition definition) { + return new PeerRowsWindowFrame(definition); + } + + @Override + public ExprValue valueOf(Environment valueEnv) { + PeerRowsWindowFrame frame = (PeerRowsWindowFrame) valueEnv; + if (frame.isNewPartition()) { + state = aggregator.create(); + } + + List peers = frame.next(); + for (ExprValue peer : peers) { + state = aggregator.iterate(peer.bindingTuples(), state); + } + return state.result(); + } + + @Override + public ExprType type() { + return aggregator.type(); + } + + @Override + public T accept(ExpressionNodeVisitor visitor, C context) { + return aggregator.accept(visitor, context); + } + + @Override + public String toString() { + return aggregator.toString(); + } + +} diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/CumulativeWindowFrame.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/frame/CurrentRowWindowFrame.java similarity index 72% rename from core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/CumulativeWindowFrame.java rename to core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/frame/CurrentRowWindowFrame.java index 75a5b3605f..4a4d15e826 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/CumulativeWindowFrame.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/frame/CurrentRowWindowFrame.java @@ -14,13 +14,14 @@ * */ -package com.amazon.opendistroforelasticsearch.sql.expression.window; +package com.amazon.opendistroforelasticsearch.sql.expression.window.frame; -import com.amazon.opendistroforelasticsearch.sql.data.model.ExprTupleValue; import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue; import com.amazon.opendistroforelasticsearch.sql.expression.Expression; import com.amazon.opendistroforelasticsearch.sql.expression.env.Environment; -import com.amazon.opendistroforelasticsearch.sql.expression.window.frame.WindowFrame; +import com.amazon.opendistroforelasticsearch.sql.expression.window.WindowDefinition; +import com.google.common.collect.PeekingIterator; +import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.stream.Collectors; @@ -30,22 +31,21 @@ import lombok.ToString; /** - * Cumulative window frame that accumulates data row incrementally as window operator iterates - * input rows. Conceptually, cumulative window frame should hold all seen rows till next partition. + * Conceptually, cumulative window frame should hold all seen rows till next partition. * This class is actually an optimized version that only hold previous and current row. This is * efficient and sufficient for ranking and aggregate window function support for now, though need * to add "real" cumulative frame implementation in future as needed. */ @EqualsAndHashCode -@Getter @RequiredArgsConstructor @ToString -public class CumulativeWindowFrame implements WindowFrame { +public class CurrentRowWindowFrame implements WindowFrame { + @Getter private final WindowDefinition windowDefinition; - private ExprTupleValue previous; - private ExprTupleValue current; + private ExprValue previous; + private ExprValue current; @Override public boolean isNewPartition() { @@ -61,30 +61,39 @@ public boolean isNewPartition() { } @Override - public int currentIndex() { - // Current row index is always 1 since only 2 rows maintained - return 1; + public void load(PeekingIterator it) { + previous = current; + current = it.next(); } @Override - public void add(ExprTupleValue row) { - previous = current; - current = row; + public ExprValue current() { + return current; } - @Override - public ExprTupleValue get(int index) { - if (index != 0 && index != 1) { - throw new IndexOutOfBoundsException("Index is out of boundary of window frame: " + index); - } - return (index == 0) ? previous : current; + public ExprValue previous() { + return previous; } - private List resolve(List expressions, ExprTupleValue row) { + private List resolve(List expressions, ExprValue row) { Environment valueEnv = row.bindingTuples(); return expressions.stream() .map(expr -> expr.valueOf(valueEnv)) .collect(Collectors.toList()); } + /** + * Current row window frame won't pre-fetch any row ahead. + * So always return false as nothing "cached" in frame. + */ + @Override + public boolean hasNext() { + return false; + } + + @Override + public List next() { + return Collections.emptyList(); + } + } diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/frame/PeerRowsWindowFrame.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/frame/PeerRowsWindowFrame.java new file mode 100644 index 0000000000..7ba29ca014 --- /dev/null +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/frame/PeerRowsWindowFrame.java @@ -0,0 +1,157 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + +package com.amazon.opendistroforelasticsearch.sql.expression.window.frame; + +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue; +import com.amazon.opendistroforelasticsearch.sql.expression.Expression; +import com.amazon.opendistroforelasticsearch.sql.expression.env.Environment; +import com.amazon.opendistroforelasticsearch.sql.expression.window.WindowDefinition; +import com.google.common.collect.PeekingIterator; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; +import lombok.RequiredArgsConstructor; +import org.apache.commons.lang3.tuple.Pair; + +/** + * Window frame that only keep peers (tuples with same value of fields specified in sort list + * in window definition). See PeerWindowFrameTest for details about how this window frame + * interacts with window operator and window function. + */ +@RequiredArgsConstructor +public class PeerRowsWindowFrame implements WindowFrame { + + private final WindowDefinition windowDefinition; + + /** + * All peer rows (peer means rows in a partition that share same sort key + * based on sort list in window definition. + */ + private final List peers = new ArrayList<>(); + + /** + * Which row in the peer is currently being enriched by window function. + */ + private int position; + + /** + * Does row at current position represents a new partition. + */ + private boolean isNewPartition = true; + + /** + * If any more pre-fetched rows not returned to window operator yet. + */ + @Override + public boolean hasNext() { + return position < peers.size(); + } + + /** + * Move position and clear new partition flag. + * Note that because all peer rows have same result from window function, + * this is only returned at first time to change window function state. + * Afterwards, empty list is returned to avoid changes until next peer loaded. + * + * @return all rows for the peer + */ + @Override + public List next() { + isNewPartition = false; + if (position++ == 0) { + return peers; + } + return Collections.emptyList(); + } + + /** + * Current row at the position. Because rows are pre-fetched here, + * window operator needs to get them from here too. + * @return row at current position that being enriched by window function + */ + @Override + public ExprValue current() { + return peers.get(position); + } + + /** + * Preload all peer rows if last peer rows done. Note that when no more data in peeking iterator, + * there must be rows in frame (hasNext()=true), so no need to check it.hasNext() in this method. + * Load until: + * 1. Different peer found (row with different sort key) + * 2. Or new partition (row with different partition key) + * 3. Or no more rows + * @param it rows iterator + */ + @Override + public void load(PeekingIterator it) { + if (hasNext()) { + return; + } + + // Reset state: reset new partition before clearing peers + isNewPartition = !isSamePartition(it.peek()); + position = 0; + peers.clear(); + + while (it.hasNext()) { + ExprValue next = it.peek(); + if (peers.isEmpty()) { + peers.add(it.next()); + } else if (isSamePartition(next) && isPeer(next)) { + peers.add(it.next()); + } else { + break; + } + } + } + + @Override + public boolean isNewPartition() { + return isNewPartition; + } + + private boolean isPeer(ExprValue next) { + List sortFields = + windowDefinition.getSortList() + .stream() + .map(Pair::getRight) + .collect(Collectors.toList()); + + ExprValue last = peers.get(peers.size() - 1); + return resolve(sortFields, last).equals(resolve(sortFields, next)); + } + + private boolean isSamePartition(ExprValue next) { + if (peers.isEmpty()) { + return false; + } + + List partitionByList = windowDefinition.getPartitionByList(); + ExprValue last = peers.get(peers.size() - 1); + return resolve(partitionByList, last).equals(resolve(partitionByList, next)); + } + + private List resolve(List expressions, ExprValue row) { + Environment valueEnv = row.bindingTuples(); + return expressions.stream() + .map(expr -> expr.valueOf(valueEnv)) + .collect(Collectors.toList()); + } + +} diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/frame/WindowFrame.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/frame/WindowFrame.java index 4920598f69..fcc36e15fc 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/frame/WindowFrame.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/frame/WindowFrame.java @@ -16,10 +16,12 @@ package com.amazon.opendistroforelasticsearch.sql.expression.window.frame; -import com.amazon.opendistroforelasticsearch.sql.data.model.ExprTupleValue; import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue; import com.amazon.opendistroforelasticsearch.sql.expression.Expression; import com.amazon.opendistroforelasticsearch.sql.expression.env.Environment; +import com.google.common.collect.PeekingIterator; +import java.util.Iterator; +import java.util.List; /** * Window frame that represents a subset of a window which is all data accessible to @@ -30,11 +32,11 @@ * Note that which type of window frame is used is determined by both window function itself * and frame definition in a window definition. */ -public interface WindowFrame extends Environment { +public interface WindowFrame extends Environment, Iterator> { @Override default ExprValue resolve(Expression var) { - return var.valueOf(get(currentIndex()).bindingTuples()); + return var.valueOf(current().bindingTuples()); } /** @@ -44,22 +46,15 @@ default ExprValue resolve(Expression var) { boolean isNewPartition(); /** - * Get current row index in the frame. - * @return index + * Load one or more rows as window function calculation needed. + * @param iterator peeking iterator that can peek next element without moving iterator */ - int currentIndex(); + void load(PeekingIterator iterator); /** - * Add a row to the window frame. - * @param row data row - */ - void add(ExprTupleValue row); - - /** - * Get a data rows within the frame by offset. - * @param index index starting from 0 to upper boundary + * Get current data row for giving window operator chance to get rows preloaded into frame. * @return data row */ - ExprTupleValue get(int index); + ExprValue current(); } diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/ranking/DenseRankFunction.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/ranking/DenseRankFunction.java index 0eb0941fa7..bea3fa3a4e 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/ranking/DenseRankFunction.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/ranking/DenseRankFunction.java @@ -17,7 +17,7 @@ package com.amazon.opendistroforelasticsearch.sql.expression.window.ranking; import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName; -import com.amazon.opendistroforelasticsearch.sql.expression.window.CumulativeWindowFrame; +import com.amazon.opendistroforelasticsearch.sql.expression.window.frame.CurrentRowWindowFrame; /** * Dense rank window function that assigns a rank number to each row similarly as @@ -30,7 +30,7 @@ public DenseRankFunction() { } @Override - protected int rank(CumulativeWindowFrame frame) { + protected int rank(CurrentRowWindowFrame frame) { if (frame.isNewPartition()) { rank = 1; } else { diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/ranking/RankFunction.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/ranking/RankFunction.java index 2569c2ca16..eb2c45299f 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/ranking/RankFunction.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/ranking/RankFunction.java @@ -17,7 +17,7 @@ package com.amazon.opendistroforelasticsearch.sql.expression.window.ranking; import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName; -import com.amazon.opendistroforelasticsearch.sql.expression.window.CumulativeWindowFrame; +import com.amazon.opendistroforelasticsearch.sql.expression.window.frame.CurrentRowWindowFrame; /** * Rank window function that assigns a rank number to each row based on sort items @@ -36,7 +36,7 @@ public RankFunction() { } @Override - protected int rank(CumulativeWindowFrame frame) { + protected int rank(CurrentRowWindowFrame frame) { if (frame.isNewPartition()) { total = 1; rank = 1; diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/ranking/RankingWindowFunction.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/ranking/RankingWindowFunction.java index bb5419c105..0be473b7e3 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/ranking/RankingWindowFunction.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/ranking/RankingWindowFunction.java @@ -26,7 +26,9 @@ import com.amazon.opendistroforelasticsearch.sql.expression.FunctionExpression; import com.amazon.opendistroforelasticsearch.sql.expression.env.Environment; import com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionName; -import com.amazon.opendistroforelasticsearch.sql.expression.window.CumulativeWindowFrame; +import com.amazon.opendistroforelasticsearch.sql.expression.window.WindowDefinition; +import com.amazon.opendistroforelasticsearch.sql.expression.window.WindowFunctionExpression; +import com.amazon.opendistroforelasticsearch.sql.expression.window.frame.CurrentRowWindowFrame; import com.amazon.opendistroforelasticsearch.sql.expression.window.frame.WindowFrame; import com.amazon.opendistroforelasticsearch.sql.storage.bindingtuple.BindingTuple; import java.util.List; @@ -37,7 +39,8 @@ * Ranking window function base class that captures same info across different ranking functions, * such as same return type (integer), same argument list (no arg). */ -public abstract class RankingWindowFunction extends FunctionExpression { +public abstract class RankingWindowFunction extends FunctionExpression + implements WindowFunctionExpression { /** * Current rank number assigned. @@ -53,9 +56,14 @@ public ExprType type() { return ExprCoreType.INTEGER; } + @Override + public WindowFrame createWindowFrame(WindowDefinition definition) { + return new CurrentRowWindowFrame(definition); + } + @Override public ExprValue valueOf(Environment valueEnv) { - return new ExprIntegerValue(rank((CumulativeWindowFrame) valueEnv)); + return new ExprIntegerValue(rank((CurrentRowWindowFrame) valueEnv)); } /** @@ -63,14 +71,14 @@ public ExprValue valueOf(Environment valueEnv) { * @param frame window frame * @return rank number */ - protected abstract int rank(CumulativeWindowFrame frame); + protected abstract int rank(CurrentRowWindowFrame frame); /** * Check sort field to see if current value is different from previous. * @param frame window frame * @return true if different, false if same or no sort list defined */ - protected boolean isSortFieldValueDifferent(CumulativeWindowFrame frame) { + protected boolean isSortFieldValueDifferent(CurrentRowWindowFrame frame) { if (isSortItemsNotDefined(frame)) { return false; } @@ -81,17 +89,17 @@ protected boolean isSortFieldValueDifferent(CumulativeWindowFrame frame) { .map(Pair::getRight) .collect(Collectors.toList()); - List previous = resolve(frame, sortItems, frame.currentIndex() - 1); - List current = resolve(frame, sortItems, frame.currentIndex()); + List previous = resolve(frame, sortItems, frame.previous()); + List current = resolve(frame, sortItems, frame.current()); return !current.equals(previous); } - private boolean isSortItemsNotDefined(CumulativeWindowFrame frame) { + private boolean isSortItemsNotDefined(CurrentRowWindowFrame frame) { return frame.getWindowDefinition().getSortList().isEmpty(); } - private List resolve(WindowFrame frame, List expressions, int index) { - BindingTuple valueEnv = frame.get(index).bindingTuples(); + private List resolve(WindowFrame frame, List expressions, ExprValue row) { + BindingTuple valueEnv = row.bindingTuples(); return expressions.stream() .map(expr -> expr.valueOf(valueEnv)) .collect(Collectors.toList()); diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/ranking/RowNumberFunction.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/ranking/RowNumberFunction.java index e11d071ffc..bb5abaa525 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/ranking/RowNumberFunction.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/window/ranking/RowNumberFunction.java @@ -17,7 +17,7 @@ package com.amazon.opendistroforelasticsearch.sql.expression.window.ranking; import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName; -import com.amazon.opendistroforelasticsearch.sql.expression.window.CumulativeWindowFrame; +import com.amazon.opendistroforelasticsearch.sql.expression.window.frame.CurrentRowWindowFrame; /** * Row number window function that assigns row number starting from 1 to each row in a partition. @@ -29,7 +29,7 @@ public RowNumberFunction() { } @Override - protected int rank(CumulativeWindowFrame frame) { + protected int rank(CurrentRowWindowFrame frame) { if (frame.isNewPartition()) { rank = 1; } diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/planner/logical/LogicalPlanDSL.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/planner/logical/LogicalPlanDSL.java index 9f2cd274f2..f3be1955b8 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/planner/logical/LogicalPlanDSL.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/planner/logical/LogicalPlanDSL.java @@ -59,7 +59,7 @@ public static LogicalPlan project(LogicalPlan input, NamedExpression... fields) } public LogicalPlan window(LogicalPlan input, - Expression windowFunction, + NamedExpression windowFunction, WindowDefinition windowDefinition) { return new LogicalWindow(input, windowFunction, windowDefinition); } diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/planner/logical/LogicalWindow.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/planner/logical/LogicalWindow.java index 664f12686d..aa7a04c7c4 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/planner/logical/LogicalWindow.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/planner/logical/LogicalWindow.java @@ -16,7 +16,7 @@ package com.amazon.opendistroforelasticsearch.sql.planner.logical; -import com.amazon.opendistroforelasticsearch.sql.expression.Expression; +import com.amazon.opendistroforelasticsearch.sql.expression.NamedExpression; import com.amazon.opendistroforelasticsearch.sql.expression.window.WindowDefinition; import java.util.Collections; import lombok.EqualsAndHashCode; @@ -32,7 +32,7 @@ @Getter @ToString public class LogicalWindow extends LogicalPlan { - private final Expression windowFunction; + private final NamedExpression windowFunction; private final WindowDefinition windowDefinition; /** @@ -40,7 +40,7 @@ public class LogicalWindow extends LogicalPlan { */ public LogicalWindow( LogicalPlan child, - Expression windowFunction, + NamedExpression windowFunction, WindowDefinition windowDefinition) { super(Collections.singletonList(child)); this.windowFunction = windowFunction; diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/planner/physical/PhysicalPlanDSL.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/planner/physical/PhysicalPlanDSL.java index eb5442b5f9..f0a0a5be8b 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/planner/physical/PhysicalPlanDSL.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/planner/physical/PhysicalPlanDSL.java @@ -83,7 +83,7 @@ public static DedupeOperator dedupe( } public WindowOperator window(PhysicalPlan input, - Expression windowFunction, + NamedExpression windowFunction, WindowDefinition windowDefinition) { return new WindowOperator(input, windowFunction, windowDefinition); } diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/planner/physical/WindowOperator.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/planner/physical/WindowOperator.java index 92730d95b3..1286307564 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/planner/physical/WindowOperator.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/planner/physical/WindowOperator.java @@ -18,11 +18,13 @@ import com.amazon.opendistroforelasticsearch.sql.data.model.ExprTupleValue; import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue; -import com.amazon.opendistroforelasticsearch.sql.expression.Expression; -import com.amazon.opendistroforelasticsearch.sql.expression.window.CumulativeWindowFrame; +import com.amazon.opendistroforelasticsearch.sql.expression.NamedExpression; import com.amazon.opendistroforelasticsearch.sql.expression.window.WindowDefinition; +import com.amazon.opendistroforelasticsearch.sql.expression.window.WindowFunctionExpression; import com.amazon.opendistroforelasticsearch.sql.expression.window.frame.WindowFrame; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterators; +import com.google.common.collect.PeekingIterator; import java.util.Collections; import java.util.List; import lombok.EqualsAndHashCode; @@ -39,7 +41,7 @@ public class WindowOperator extends PhysicalPlan { private final PhysicalPlan input; @Getter - private final Expression windowFunction; + private final NamedExpression windowFunction; @Getter private final WindowDefinition windowDefinition; @@ -48,6 +50,15 @@ public class WindowOperator extends PhysicalPlan { @ToString.Exclude private final WindowFrame windowFrame; + /** + * Peeking iterator that can peek next element which is required + * by window frame such as peer frame to prefetch all rows related + * to same peer (of same sorting key). + */ + @EqualsAndHashCode.Exclude + @ToString.Exclude + private final PeekingIterator peekingIterator; + /** * Initialize window operator. * @param input child operator @@ -55,12 +66,13 @@ public class WindowOperator extends PhysicalPlan { * @param windowDefinition window definition */ public WindowOperator(PhysicalPlan input, - Expression windowFunction, + NamedExpression windowFunction, WindowDefinition windowDefinition) { this.input = input; this.windowFunction = windowFunction; this.windowDefinition = windowDefinition; this.windowFrame = createWindowFrame(); + this.peekingIterator = Iterators.peekingIterator(input); } @Override @@ -75,30 +87,18 @@ public List getChild() { @Override public boolean hasNext() { - return input.hasNext(); + return peekingIterator.hasNext() || windowFrame.hasNext(); } @Override public ExprValue next() { - loadRowsIntoWindowFrame(); + windowFrame.load(peekingIterator); return enrichCurrentRowByWindowFunctionResult(); } - /** - * For now cumulative window frame is returned always. When frame definition is supported: - * 1. Ranking window functions: ignore frame definition and always operates on entire window. - * 2. Aggregate window functions: operates on cumulative or sliding window based on definition. - */ private WindowFrame createWindowFrame() { - return new CumulativeWindowFrame(windowDefinition); - } - - /** - * For now always load next row into window frame. In future, how/how many rows loaded - * should be based on window frame type. - */ - private void loadRowsIntoWindowFrame() { - windowFrame.add((ExprTupleValue) input.next()); + return ((WindowFunctionExpression) windowFunction.getDelegated()) + .createWindowFrame(windowDefinition); } private ExprValue enrichCurrentRowByWindowFunctionResult() { @@ -109,13 +109,13 @@ private ExprValue enrichCurrentRowByWindowFunctionResult() { } private void preserveAllOriginalColumns(ImmutableMap.Builder mapBuilder) { - ExprTupleValue inputValue = windowFrame.get(windowFrame.currentIndex()); + ExprValue inputValue = windowFrame.current(); inputValue.tupleValue().forEach(mapBuilder::put); } private void addWindowFunctionResultColumn(ImmutableMap.Builder mapBuilder) { ExprValue exprValue = windowFunction.valueOf(windowFrame); - mapBuilder.put(windowFunction.toString(), exprValue); + mapBuilder.put(windowFunction.getName(), exprValue); } } diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/AnalyzerTest.java index b8cd8ede2f..4c97c83bdd 100644 --- a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/AnalyzerTest.java @@ -370,13 +370,15 @@ public void window_function() { LogicalPlanDSL.relation("test"), ImmutablePair.of(DEFAULT_ASC, DSL.ref("string_value", STRING)), ImmutablePair.of(DEFAULT_ASC, DSL.ref("integer_value", INTEGER))), - dsl.rowNumber(), + DSL.named("window_function", dsl.rowNumber()), new WindowDefinition( ImmutableList.of(DSL.ref("string_value", STRING)), ImmutableList.of( ImmutablePair.of(DEFAULT_ASC, DSL.ref("integer_value", INTEGER))))), DSL.named("string_value", DSL.ref("string_value", STRING)), - DSL.named("window_function", DSL.ref("row_number()", INTEGER))), + // Alias name "window_function" is used as internal symbol name to connect + // project item and window operator output + DSL.named("window_function", DSL.ref("window_function", INTEGER))), AstDSL.project( AstDSL.relation("test"), AstDSL.alias("string_value", AstDSL.qualifiedName("string_value")), diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionAnalyzerTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionAnalyzerTest.java index 87cfd118e6..c3ecc4c8be 100644 --- a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionAnalyzerTest.java +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionAnalyzerTest.java @@ -16,7 +16,6 @@ package com.amazon.opendistroforelasticsearch.sql.analysis; import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.field; -import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.filteredAggregate; import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.function; import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.intLiteral; import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.qualifiedName; @@ -25,6 +24,7 @@ import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.BOOLEAN; import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.INTEGER; import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.STRUCT; +import static java.util.Collections.emptyList; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -39,11 +39,8 @@ import com.amazon.opendistroforelasticsearch.sql.exception.SemanticCheckException; import com.amazon.opendistroforelasticsearch.sql.expression.DSL; import com.amazon.opendistroforelasticsearch.sql.expression.Expression; -import com.amazon.opendistroforelasticsearch.sql.expression.LiteralExpression; import com.amazon.opendistroforelasticsearch.sql.expression.config.ExpressionConfig; -import java.util.ArrayList; -import java.util.List; -import java.util.stream.Collectors; +import com.amazon.opendistroforelasticsearch.sql.expression.window.aggregation.AggregateWindowFunction; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.springframework.context.annotation.Configuration; @@ -99,7 +96,7 @@ public void not() { public void qualified_name() { assertAnalyzeEqual( DSL.ref("integer_value", INTEGER), - AstDSL.qualifiedName("integer_value") + qualifiedName("integer_value") ); } @@ -115,7 +112,7 @@ public void case_value() { dsl.equal(DSL.ref("integer_value", INTEGER), DSL.literal(50)), DSL.literal("Fifty"))), AstDSL.caseWhen( - AstDSL.qualifiedName("integer_value"), + qualifiedName("integer_value"), AstDSL.stringLiteral("Default value"), AstDSL.when(AstDSL.intLiteral(30), AstDSL.stringLiteral("Thirty")), AstDSL.when(AstDSL.intLiteral(50), AstDSL.stringLiteral("Fifty")))); @@ -136,11 +133,11 @@ public void case_conditions() { null, AstDSL.when( AstDSL.function(">", - AstDSL.qualifiedName("integer_value"), + qualifiedName("integer_value"), AstDSL.intLiteral(50)), AstDSL.stringLiteral("Fifty")), AstDSL.when( AstDSL.function(">", - AstDSL.qualifiedName("integer_value"), + qualifiedName("integer_value"), AstDSL.intLiteral(30)), AstDSL.stringLiteral("Thirty")))); } @@ -158,7 +155,7 @@ public void castAnalyzer() { @Test public void case_with_default_result_type_different() { UnresolvedExpression caseWhen = AstDSL.caseWhen( - AstDSL.qualifiedName("integer_value"), + qualifiedName("integer_value"), AstDSL.intLiteral(60), AstDSL.when(AstDSL.intLiteral(30), AstDSL.stringLiteral("Thirty")), AstDSL.when(AstDSL.intLiteral(50), AstDSL.stringLiteral("Fifty"))); @@ -170,19 +167,37 @@ public void case_with_default_result_type_different() { exception.getMessage()); } + @Test + public void scalar_window_function() { + assertAnalyzeEqual( + dsl.rank(), + AstDSL.window(AstDSL.function("rank"), emptyList(), emptyList())); + } + + @SuppressWarnings("unchecked") + @Test + public void aggregate_window_function() { + assertAnalyzeEqual( + new AggregateWindowFunction(dsl.avg(DSL.ref("integer_value", INTEGER))), + AstDSL.window( + AstDSL.aggregate("avg", qualifiedName("integer_value")), + emptyList(), + emptyList())); + } + @Test public void qualified_name_with_qualifier() { analysisContext.push(); analysisContext.peek().define(new Symbol(Namespace.INDEX_NAME, "index_alias"), STRUCT); assertAnalyzeEqual( DSL.ref("integer_value", INTEGER), - AstDSL.qualifiedName("index_alias", "integer_value") + qualifiedName("index_alias", "integer_value") ); analysisContext.peek().define(new Symbol(Namespace.FIELD_NAME, "nested_field"), STRUCT); SyntaxCheckException exception = assertThrows(SyntaxCheckException.class, - () -> analyze(AstDSL.qualifiedName("nested_field", "integer_value"))); + () -> analyze(qualifiedName("nested_field", "integer_value"))); assertEquals( "The qualifier [nested_field] of qualified name [nested_field.integer_value] " + "must be an index name or its alias", @@ -217,7 +232,7 @@ public void case_clause() { AstDSL.nullLiteral(), AstDSL.when( AstDSL.function("=", - AstDSL.qualifiedName("integer_value"), + qualifiedName("integer_value"), AstDSL.intLiteral(30)), AstDSL.stringLiteral("test")))); } @@ -226,7 +241,7 @@ public void case_clause() { public void skip_struct_data_type() { SyntaxCheckException exception = assertThrows(SyntaxCheckException.class, - () -> analyze(AstDSL.qualifiedName("struct_value"))); + () -> analyze(qualifiedName("struct_value"))); assertEquals( "Identifier [struct_value] of type [STRUCT] is not supported yet", exception.getMessage() @@ -237,7 +252,7 @@ public void skip_struct_data_type() { public void skip_array_data_type() { SyntaxCheckException exception = assertThrows(SyntaxCheckException.class, - () -> analyze(AstDSL.qualifiedName("array_value"))); + () -> analyze(qualifiedName("array_value"))); assertEquals( "Identifier [array_value] of type [ARRAY] is not supported yet", exception.getMessage() diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionReferenceOptimizerTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionReferenceOptimizerTest.java index 7b39cf82f6..902408043d 100644 --- a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionReferenceOptimizerTest.java +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionReferenceOptimizerTest.java @@ -149,9 +149,9 @@ void window_expression_should_be_replaced() { LogicalPlanDSL.window( LogicalPlanDSL.window( LogicalPlanDSL.relation("test"), - dsl.rank(), + DSL.named(dsl.rank()), new WindowDefinition(emptyList(), emptyList())), - dsl.denseRank(), + DSL.named(dsl.denseRank()), new WindowDefinition(emptyList(), emptyList())); assertEquals( @@ -169,7 +169,7 @@ Expression optimize(Expression expression) { Expression optimize(Expression expression, LogicalPlan logicalPlan) { final ExpressionReferenceOptimizer optimizer = new ExpressionReferenceOptimizer(functionRepository, logicalPlan); - return optimizer.optimize(expression, new AnalysisContext()); + return optimizer.optimize(DSL.named(expression), new AnalysisContext()); } LogicalPlan logicalPlan() { diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/SelectExpressionAnalyzerTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/SelectExpressionAnalyzerTest.java index f0fe2db2a5..08558520f0 100644 --- a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/SelectExpressionAnalyzerTest.java +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/SelectExpressionAnalyzerTest.java @@ -112,7 +112,8 @@ public void field_name_in_expression_with_qualifier() { } protected List analyze(UnresolvedExpression unresolvedExpression) { - doAnswer(returnsFirstArg()).when(optimizer).optimize(any(), any()); + doAnswer(invocation -> ((NamedExpression) invocation.getArgument(0)) + .getDelegated()).when(optimizer).optimize(any(), any()); return new SelectExpressionAnalyzer(expressionAnalyzer) .analyze(Arrays.asList(unresolvedExpression), analysisContext, optimizer); diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/WindowExpressionAnalyzerTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/WindowExpressionAnalyzerTest.java index 3292690b9a..ddc324cb9a 100644 --- a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/WindowExpressionAnalyzerTest.java +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/WindowExpressionAnalyzerTest.java @@ -73,7 +73,7 @@ void should_wrap_child_with_window_and_sort_operator_if_project_item_windowed() LogicalPlanDSL.relation("test"), ImmutablePair.of(DEFAULT_ASC, DSL.ref("string_value", STRING)), ImmutablePair.of(DEFAULT_DESC, DSL.ref("integer_value", INTEGER))), - dsl.rowNumber(), + DSL.named("row_number", dsl.rowNumber()), new WindowDefinition( ImmutableList.of(DSL.ref("string_value", STRING)), ImmutableList.of( @@ -89,6 +89,25 @@ void should_wrap_child_with_window_and_sort_operator_if_project_item_windowed() analysisContext)); } + @Test + void should_not_generate_sort_operator_if_no_partition_by_and_order_by_list() { + assertEquals( + LogicalPlanDSL.window( + LogicalPlanDSL.relation("test"), + DSL.named("row_number", dsl.rowNumber()), + new WindowDefinition( + ImmutableList.of(), + ImmutableList.of())), + analyzer.analyze( + AstDSL.alias( + "row_number", + AstDSL.window( + AstDSL.function("row_number"), + ImmutableList.of(), + ImmutableList.of())), + analysisContext)); + } + @Test void should_return_original_child_if_project_item_not_windowed() { assertEquals( diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/executor/ExplainTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/executor/ExplainTest.java index dd63d53e69..3314feae18 100644 --- a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/executor/ExplainTest.java +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/executor/ExplainTest.java @@ -168,7 +168,7 @@ void can_explain_window() { List> sortList = ImmutableList.of( ImmutablePair.of(DEFAULT_ASC, ref("age", INTEGER))); - PhysicalPlan plan = window(tableScan, dsl.rank(), + PhysicalPlan plan = window(tableScan, named(dsl.rank()), new WindowDefinition(partitionByList, sortList)); assertEquals( diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/window/CumulativeWindowFrameTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/window/CurrentRowWindowFrameTest.java similarity index 51% rename from core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/window/CumulativeWindowFrameTest.java rename to core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/window/CurrentRowWindowFrameTest.java index 62659710ae..64d271dec8 100644 --- a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/window/CumulativeWindowFrameTest.java +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/window/CurrentRowWindowFrameTest.java @@ -21,61 +21,88 @@ import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.STRING; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; import com.amazon.opendistroforelasticsearch.sql.data.model.ExprIntegerValue; import com.amazon.opendistroforelasticsearch.sql.data.model.ExprStringValue; import com.amazon.opendistroforelasticsearch.sql.data.model.ExprTupleValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue; import com.amazon.opendistroforelasticsearch.sql.expression.DSL; -import com.amazon.opendistroforelasticsearch.sql.expression.window.frame.WindowFrame; +import com.amazon.opendistroforelasticsearch.sql.expression.window.frame.CurrentRowWindowFrame; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterators; +import com.google.common.collect.PeekingIterator; import org.apache.commons.lang3.tuple.ImmutablePair; import org.junit.jupiter.api.Test; -class CumulativeWindowFrameTest { +class CurrentRowWindowFrameTest { - private final WindowDefinition windowDefinition = new WindowDefinition( - ImmutableList.of(DSL.ref("state", STRING)), - ImmutableList.of(ImmutablePair.of(DEFAULT_ASC, DSL.ref("age", INTEGER)))); + private final CurrentRowWindowFrame windowFrame = new CurrentRowWindowFrame( + new WindowDefinition( + ImmutableList.of(DSL.ref("state", STRING)), + ImmutableList.of(ImmutablePair.of(DEFAULT_ASC, DSL.ref("age", INTEGER))))); - private final WindowFrame windowFrame = new CumulativeWindowFrame(windowDefinition); + @Test + void test_iterator_methods() { + assertFalse(windowFrame.hasNext()); + assertTrue(windowFrame.next().isEmpty()); + } @Test void should_return_new_partition_if_partition_by_field_value_changed() { - ExprTupleValue tuple1 = ExprTupleValue.fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("WA"), - "age", new ExprIntegerValue(20))); - windowFrame.add(tuple1); + PeekingIterator iterator = Iterators.peekingIterator( + Iterators.forArray( + ExprTupleValue.fromExprValueMap(ImmutableMap.of( + "state", new ExprStringValue("WA"), + "age", new ExprIntegerValue(20))), + ExprTupleValue.fromExprValueMap(ImmutableMap.of( + "state", new ExprStringValue("WA"), + "age", new ExprIntegerValue(30))), + ExprTupleValue.fromExprValueMap(ImmutableMap.of( + "state", new ExprStringValue("CA"), + "age", new ExprIntegerValue(18))))); + + windowFrame.load(iterator); assertTrue(windowFrame.isNewPartition()); - ExprTupleValue tuple2 = ExprTupleValue.fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("WA"), - "age", new ExprIntegerValue(30))); - windowFrame.add(tuple2); + windowFrame.load(iterator); assertFalse(windowFrame.isNewPartition()); - ExprTupleValue tuple3 = ExprTupleValue.fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("CA"), - "age", new ExprIntegerValue(18))); - windowFrame.add(tuple3); + windowFrame.load(iterator); assertTrue(windowFrame.isNewPartition()); } @Test void can_resolve_single_expression_value() { - windowFrame.add(ExprTupleValue.fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("WA"), - "age", new ExprIntegerValue(20)))); + windowFrame.load(Iterators.peekingIterator( + Iterators.singletonIterator( + ExprTupleValue.fromExprValueMap(ImmutableMap.of( + "state", new ExprStringValue("WA"), + "age", new ExprIntegerValue(20)))))); assertEquals( new ExprIntegerValue(20), windowFrame.resolve(DSL.ref("age", INTEGER))); } @Test - void should_throw_exception_if_access_row_out_of_boundary() { - assertThrows(IndexOutOfBoundsException.class, () -> windowFrame.get(2)); + void can_return_previous_and_current_row() { + ExprValue row1 = ExprTupleValue.fromExprValueMap(ImmutableMap.of( + "state", new ExprStringValue("WA"), + "age", new ExprIntegerValue(20))); + ExprValue row2 = ExprTupleValue.fromExprValueMap(ImmutableMap.of( + "state", new ExprStringValue("WA"), + "age", new ExprIntegerValue(30))); + PeekingIterator iterator = Iterators.peekingIterator(Iterators.forArray(row1, row2)); + + windowFrame.load(iterator); + assertNull(windowFrame.previous()); + assertEquals(row1, windowFrame.current()); + + windowFrame.load(iterator); + assertEquals(row1, windowFrame.previous()); + assertEquals(row2, windowFrame.current()); } } \ No newline at end of file diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/window/aggregation/AggregateWindowFunctionTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/window/aggregation/AggregateWindowFunctionTest.java new file mode 100644 index 0000000000..df1eb7c25e --- /dev/null +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/window/aggregation/AggregateWindowFunctionTest.java @@ -0,0 +1,80 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + +package com.amazon.opendistroforelasticsearch.sql.expression.window.aggregation; + +import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprTupleValue.fromExprValueMap; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.INTEGER; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.LONG; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprIntegerValue; +import com.amazon.opendistroforelasticsearch.sql.expression.DSL; +import com.amazon.opendistroforelasticsearch.sql.expression.ExpressionTestBase; +import com.amazon.opendistroforelasticsearch.sql.expression.aggregation.Aggregator; +import com.amazon.opendistroforelasticsearch.sql.expression.window.frame.PeerRowsWindowFrame; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; + +/** + * Aggregate window function test collection. + */ +@SuppressWarnings("unchecked") +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +@ExtendWith(MockitoExtension.class) +class AggregateWindowFunctionTest extends ExpressionTestBase { + + @SuppressWarnings("rawtypes") + @Test + void test_delegated_methods() { + Aggregator aggregator = mock(Aggregator.class); + when(aggregator.type()).thenReturn(LONG); + when(aggregator.accept(any(), any())).thenReturn(123); + when(aggregator.toString()).thenReturn("avg(age)"); + + AggregateWindowFunction windowFunction = new AggregateWindowFunction(aggregator); + assertEquals(LONG, windowFunction.type()); + assertEquals(123, (Integer) windowFunction.accept(null, null)); + assertEquals("avg(age)", windowFunction.toString()); + } + + @Test + void should_accumulate_all_peer_values_and_not_reset_state_if_same_partition() { + PeerRowsWindowFrame windowFrame = mock(PeerRowsWindowFrame.class); + AggregateWindowFunction windowFunction = + new AggregateWindowFunction(dsl.sum(DSL.ref("age", INTEGER))); + + when(windowFrame.isNewPartition()).thenReturn(true); + when(windowFrame.next()).thenReturn(ImmutableList.of( + fromExprValueMap(ImmutableMap.of("age", new ExprIntegerValue(10))), + fromExprValueMap(ImmutableMap.of("age", new ExprIntegerValue(20))))); + assertEquals(new ExprIntegerValue(30), windowFunction.valueOf(windowFrame)); + + when(windowFrame.isNewPartition()).thenReturn(false); + when(windowFrame.next()).thenReturn(ImmutableList.of( + fromExprValueMap(ImmutableMap.of("age", new ExprIntegerValue(30))))); + assertEquals(new ExprIntegerValue(60), windowFunction.valueOf(windowFrame)); + } + +} diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/window/frame/PeerRowsWindowFrameTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/window/frame/PeerRowsWindowFrameTest.java new file mode 100644 index 0000000000..a95ba5f029 --- /dev/null +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/window/frame/PeerRowsWindowFrameTest.java @@ -0,0 +1,264 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + +package com.amazon.opendistroforelasticsearch.sql.expression.window.frame; + +import static com.amazon.opendistroforelasticsearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC; +import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprTupleValue.fromExprValueMap; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.INTEGER; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.STRING; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprIntegerValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprStringValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue; +import com.amazon.opendistroforelasticsearch.sql.expression.DSL; +import com.amazon.opendistroforelasticsearch.sql.expression.window.WindowDefinition; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterators; +import com.google.common.collect.PeekingIterator; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; + +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +@ExtendWith(MockitoExtension.class) +class PeerRowsWindowFrameTest { + + private final PeerRowsWindowFrame windowFrame = new PeerRowsWindowFrame( + new WindowDefinition( + ImmutableList.of(DSL.ref("state", STRING)), + ImmutableList.of(Pair.of(DEFAULT_ASC, DSL.ref("age", INTEGER))))); + + @Test + void test_single_row() { + PeekingIterator tuples = Iterators.peekingIterator( + Iterators.singletonIterator(tuple("WA", 10, 100))); + windowFrame.load(tuples); + assertTrue(windowFrame.isNewPartition()); + assertEquals(ImmutableList.of(tuple("WA", 10, 100)), windowFrame.next()); + } + + @Test + void test_single_partition_with_no_more_rows_after_peers() { + PeekingIterator tuples = Iterators.peekingIterator( + Iterators.forArray( + tuple("WA", 10, 100), + tuple("WA", 20, 200), + tuple("WA", 20, 50))); + + // Here we simulate how WindowFrame interacts with WindowOperator which calls load() + // and WindowFunction which calls isNewPartition() and move() + windowFrame.load(tuples); + assertTrue(windowFrame.isNewPartition()); + assertEquals(ImmutableList.of(tuple("WA", 10, 100)), windowFrame.next()); + + windowFrame.load(tuples); + assertFalse(windowFrame.isNewPartition()); + assertEquals( + ImmutableList.of(tuple("WA", 20, 200), tuple("WA", 20, 50)), + windowFrame.next()); + + windowFrame.load(tuples); + assertFalse(windowFrame.isNewPartition()); + assertEquals(ImmutableList.of(), windowFrame.next()); + } + + @Test + void test_single_partition_with_more_rows_after_peers() { + PeekingIterator tuples = Iterators.peekingIterator( + Iterators.forArray( + tuple("WA", 10, 100), + tuple("WA", 20, 200), + tuple("WA", 20, 50), + tuple("WA", 35, 150))); + + windowFrame.load(tuples); + assertTrue(windowFrame.isNewPartition()); + assertEquals( + ImmutableList.of( + tuple("WA", 10, 100)), + windowFrame.next()); + + windowFrame.load(tuples); + assertFalse(windowFrame.isNewPartition()); + assertEquals( + ImmutableList.of( + tuple("WA", 20, 200), + tuple("WA", 20, 50)), + windowFrame.next()); + + windowFrame.load(tuples); + assertFalse(windowFrame.isNewPartition()); + assertEquals( + ImmutableList.of(), + windowFrame.next()); + + windowFrame.load(tuples); + assertFalse(windowFrame.isNewPartition()); + assertEquals( + ImmutableList.of( + tuple("WA", 35, 150)), + windowFrame.next()); + } + + @Test + void test_two_partitions_with_all_same_peers_in_second_partition() { + PeekingIterator tuples = Iterators.peekingIterator( + Iterators.forArray( + tuple("WA", 10, 100), + tuple("CA", 18, 150), + tuple("CA", 18, 100))); + + windowFrame.load(tuples); + assertTrue(windowFrame.isNewPartition()); + assertEquals( + ImmutableList.of( + tuple("WA", 10, 100)), + windowFrame.next()); + + windowFrame.load(tuples); + assertTrue(windowFrame.isNewPartition()); + assertEquals( + ImmutableList.of( + tuple("CA", 18, 150), + tuple("CA", 18, 100)), + windowFrame.next()); + + windowFrame.load(tuples); + assertFalse(windowFrame.isNewPartition()); + assertEquals( + ImmutableList.of(), + windowFrame.next()); + } + + @Test + void test_two_partitions_with_single_row_in_each_partition() { + PeekingIterator tuples = Iterators.peekingIterator( + Iterators.forArray( + tuple("WA", 10, 100), + tuple("CA", 30, 200))); + + windowFrame.load(tuples); + assertTrue(windowFrame.isNewPartition()); + assertEquals( + ImmutableList.of( + tuple("WA", 10, 100)), + windowFrame.next()); + + windowFrame.load(tuples); + assertTrue(windowFrame.isNewPartition()); + assertEquals( + ImmutableList.of( + tuple("CA", 30, 200)), + windowFrame.next()); + } + + @Test + void test_window_definition_with_no_partition_by() { + PeerRowsWindowFrame windowFrame = new PeerRowsWindowFrame( + new WindowDefinition( + ImmutableList.of(), + ImmutableList.of(Pair.of(DEFAULT_ASC, DSL.ref("age", INTEGER))))); + + PeekingIterator tuples = Iterators.peekingIterator( + Iterators.forArray( + tuple("WA", 10, 100), + tuple("CA", 30, 200))); + + windowFrame.load(tuples); + assertTrue(windowFrame.isNewPartition()); + assertEquals( + ImmutableList.of( + tuple("WA", 10, 100)), + windowFrame.next()); + + windowFrame.load(tuples); + assertFalse(windowFrame.isNewPartition()); + assertEquals( + ImmutableList.of( + tuple("CA", 30, 200)), + windowFrame.next()); + } + + @Test + void test_window_definition_with_no_order_by() { + PeerRowsWindowFrame windowFrame = new PeerRowsWindowFrame( + new WindowDefinition( + ImmutableList.of(DSL.ref("state", STRING)), + ImmutableList.of())); + + PeekingIterator tuples = Iterators.peekingIterator( + Iterators.forArray( + tuple("WA", 10, 100), + tuple("CA", 30, 200))); + + windowFrame.load(tuples); + assertTrue(windowFrame.isNewPartition()); + assertEquals( + ImmutableList.of( + tuple("WA", 10, 100)), + windowFrame.next()); + + windowFrame.load(tuples); + assertTrue(windowFrame.isNewPartition()); + assertEquals( + ImmutableList.of( + tuple("CA", 30, 200)), + windowFrame.next()); + } + + @Test + void test_window_definition_with_no_partition_by_and_order_by() { + PeerRowsWindowFrame windowFrame = new PeerRowsWindowFrame( + new WindowDefinition( + ImmutableList.of(), + ImmutableList.of())); + + PeekingIterator tuples = Iterators.peekingIterator( + Iterators.forArray( + tuple("WA", 10, 100), + tuple("CA", 30, 200))); + + windowFrame.load(tuples); + assertTrue(windowFrame.isNewPartition()); + assertEquals( + ImmutableList.of( + tuple("WA", 10, 100), + tuple("CA", 30, 200)), + windowFrame.next()); + + windowFrame.load(tuples); + assertFalse(windowFrame.isNewPartition()); + assertEquals( + ImmutableList.of(), + windowFrame.next()); + } + + private ExprValue tuple(String state, int age, int balance) { + return fromExprValueMap(ImmutableMap.of( + "state", new ExprStringValue(state), + "age", new ExprIntegerValue(age), + "balance", new ExprIntegerValue(balance))); + } + +} diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/window/ranking/RankingWindowFunctionTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/window/ranking/RankingWindowFunctionTest.java index ada077ca09..83c79c3dc5 100644 --- a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/window/ranking/RankingWindowFunctionTest.java +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/window/ranking/RankingWindowFunctionTest.java @@ -24,13 +24,17 @@ import com.amazon.opendistroforelasticsearch.sql.data.model.ExprIntegerValue; import com.amazon.opendistroforelasticsearch.sql.data.model.ExprStringValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue; import com.amazon.opendistroforelasticsearch.sql.expression.DSL; import com.amazon.opendistroforelasticsearch.sql.expression.ExpressionTestBase; -import com.amazon.opendistroforelasticsearch.sql.expression.window.CumulativeWindowFrame; import com.amazon.opendistroforelasticsearch.sql.expression.window.WindowDefinition; +import com.amazon.opendistroforelasticsearch.sql.expression.window.frame.CurrentRowWindowFrame; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterators; +import com.google.common.collect.PeekingIterator; import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; @@ -44,22 +48,54 @@ @ExtendWith(MockitoExtension.class) class RankingWindowFunctionTest extends ExpressionTestBase { - private final CumulativeWindowFrame windowFrame1 = new CumulativeWindowFrame( + private final CurrentRowWindowFrame windowFrame1 = new CurrentRowWindowFrame( new WindowDefinition( ImmutableList.of(DSL.ref("state", STRING)), ImmutableList.of(Pair.of(DEFAULT_ASC, DSL.ref("age", INTEGER))))); - private final CumulativeWindowFrame windowFrame2 = new CumulativeWindowFrame( + private final CurrentRowWindowFrame windowFrame2 = new CurrentRowWindowFrame( new WindowDefinition( ImmutableList.of(DSL.ref("state", STRING)), ImmutableList.of())); // No sort items defined + private PeekingIterator iterator1; + private PeekingIterator iterator2; + + @BeforeEach + void set_up() { + iterator1 = Iterators.peekingIterator(Iterators.forArray( + fromExprValueMap(ImmutableMap.of( + "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(30))), + fromExprValueMap(ImmutableMap.of( + "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(30))), + fromExprValueMap(ImmutableMap.of( + "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(40))), + fromExprValueMap(ImmutableMap.of( + "state", new ExprStringValue("CA"), "age", new ExprIntegerValue(20))))); + + iterator2 = Iterators.peekingIterator(Iterators.forArray( + fromExprValueMap(ImmutableMap.of( + "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(30))), + fromExprValueMap(ImmutableMap.of( + "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(30))), + fromExprValueMap(ImmutableMap.of( + "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(50))), + fromExprValueMap(ImmutableMap.of( + "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(55))), + fromExprValueMap(ImmutableMap.of( + "state", new ExprStringValue("CA"), "age", new ExprIntegerValue(15))))); + } + @Test void test_value_of() { - windowFrame1.add(fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(30)))); + PeekingIterator iterator = Iterators.peekingIterator( + Iterators.singletonIterator( + fromExprValueMap(ImmutableMap.of( + "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(30))))); RankingWindowFunction rowNumber = dsl.rowNumber(); + + windowFrame1.load(iterator); assertEquals(new ExprIntegerValue(1), rowNumber.valueOf(windowFrame1)); } @@ -67,20 +103,16 @@ void test_value_of() { void test_row_number() { RankingWindowFunction rowNumber = dsl.rowNumber(); - windowFrame1.add(fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(30)))); + windowFrame1.load(iterator1); assertEquals(1, rowNumber.rank(windowFrame1)); - windowFrame1.add(fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(30)))); + windowFrame1.load(iterator1); assertEquals(2, rowNumber.rank(windowFrame1)); - windowFrame1.add(fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(40)))); + windowFrame1.load(iterator1); assertEquals(3, rowNumber.rank(windowFrame1)); - windowFrame1.add(fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("CA"), "age", new ExprIntegerValue(20)))); + windowFrame1.load(iterator1); assertEquals(1, rowNumber.rank(windowFrame1)); } @@ -88,24 +120,19 @@ void test_row_number() { void test_rank() { RankingWindowFunction rank = dsl.rank(); - windowFrame1.add(fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(30)))); + windowFrame1.load(iterator2); assertEquals(1, rank.rank(windowFrame1)); - windowFrame1.add(fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(30)))); + windowFrame1.load(iterator2); assertEquals(1, rank.rank(windowFrame1)); - windowFrame1.add(fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(50)))); + windowFrame1.load(iterator2); assertEquals(3, rank.rank(windowFrame1)); - windowFrame1.add(fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(55)))); + windowFrame1.load(iterator2); assertEquals(4, rank.rank(windowFrame1)); - windowFrame1.add(fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("CA"), "age", new ExprIntegerValue(15)))); + windowFrame1.load(iterator2); assertEquals(1, rank.rank(windowFrame1)); } @@ -113,24 +140,19 @@ void test_rank() { void test_dense_rank() { RankingWindowFunction denseRank = dsl.denseRank(); - windowFrame1.add(fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(30)))); + windowFrame1.load(iterator2); assertEquals(1, denseRank.rank(windowFrame1)); - windowFrame1.add(fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(30)))); + windowFrame1.load(iterator2); assertEquals(1, denseRank.rank(windowFrame1)); - windowFrame1.add(fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(50)))); + windowFrame1.load(iterator2); assertEquals(2, denseRank.rank(windowFrame1)); - windowFrame1.add(fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(55)))); + windowFrame1.load(iterator2); assertEquals(3, denseRank.rank(windowFrame1)); - windowFrame1.add(fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("CA"), "age", new ExprIntegerValue(15)))); + windowFrame1.load(iterator2); assertEquals(1, denseRank.rank(windowFrame1)); } @@ -138,45 +160,49 @@ void test_dense_rank() { void row_number_should_work_if_no_sort_items_defined() { RankingWindowFunction rowNumber = dsl.rowNumber(); - windowFrame2.add(fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(30)))); + windowFrame2.load(iterator1); assertEquals(1, rowNumber.rank(windowFrame2)); - windowFrame2.add(fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(30)))); + windowFrame2.load(iterator1); assertEquals(2, rowNumber.rank(windowFrame2)); - windowFrame2.add(fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(40)))); + windowFrame2.load(iterator1); assertEquals(3, rowNumber.rank(windowFrame2)); - windowFrame2.add(fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("CA"), "age", new ExprIntegerValue(20)))); + windowFrame2.load(iterator1); assertEquals(1, rowNumber.rank(windowFrame2)); } @Test void rank_should_always_return_1_if_no_sort_items_defined() { + PeekingIterator iterator = Iterators.peekingIterator( + Iterators.forArray( + fromExprValueMap(ImmutableMap.of( + "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(30))), + fromExprValueMap(ImmutableMap.of( + "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(30))), + fromExprValueMap(ImmutableMap.of( + "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(50))), + fromExprValueMap(ImmutableMap.of( + "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(55))), + fromExprValueMap(ImmutableMap.of( + "state", new ExprStringValue("CA"), "age", new ExprIntegerValue(15))))); + RankingWindowFunction rank = dsl.rank(); - windowFrame2.add(fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(30)))); + windowFrame2.load(iterator); assertEquals(1, rank.rank(windowFrame2)); - windowFrame2.add(fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(30)))); + windowFrame2.load(iterator); assertEquals(1, rank.rank(windowFrame2)); - windowFrame2.add(fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(50)))); + windowFrame2.load(iterator); assertEquals(1, rank.rank(windowFrame2)); - windowFrame2.add(fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(55)))); + windowFrame2.load(iterator); assertEquals(1, rank.rank(windowFrame2)); - windowFrame2.add(fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("CA"), "age", new ExprIntegerValue(15)))); + windowFrame2.load(iterator); assertEquals(1, rank.rank(windowFrame2)); } @@ -184,24 +210,19 @@ void rank_should_always_return_1_if_no_sort_items_defined() { void dense_rank_should_always_return_1_if_no_sort_items_defined() { RankingWindowFunction denseRank = dsl.denseRank(); - windowFrame2.add(fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(30)))); + windowFrame2.load(iterator2); assertEquals(1, denseRank.rank(windowFrame2)); - windowFrame2.add(fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(30)))); + windowFrame2.load(iterator2); assertEquals(1, denseRank.rank(windowFrame2)); - windowFrame2.add(fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(50)))); + windowFrame2.load(iterator2); assertEquals(1, denseRank.rank(windowFrame2)); - windowFrame2.add(fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("WA"), "age", new ExprIntegerValue(55)))); + windowFrame2.load(iterator2); assertEquals(1, denseRank.rank(windowFrame2)); - windowFrame2.add(fromExprValueMap(ImmutableMap.of( - "state", new ExprStringValue("CA"), "age", new ExprIntegerValue(15)))); + windowFrame2.load(iterator2); assertEquals(1, denseRank.rank(windowFrame2)); } diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/planner/DefaultImplementorTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/planner/DefaultImplementorTest.java index 7fdcc6f3e4..04ee791325 100644 --- a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/planner/DefaultImplementorTest.java +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/planner/DefaultImplementorTest.java @@ -178,7 +178,7 @@ public void visitRelationShouldThrowException() { @SuppressWarnings({"rawtypes", "unchecked"}) @Test public void visitWindowOperatorShouldReturnPhysicalWindowOperator() { - Expression windowFunction = new RowNumberFunction(); + NamedExpression windowFunction = named(new RowNumberFunction()); WindowDefinition windowDefinition = new WindowDefinition( Collections.singletonList(ref("state", STRING)), Collections.singletonList( diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java index e7a7ed590e..9b9863a5cc 100644 --- a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java @@ -114,7 +114,7 @@ public void testAbstractPlanNodeVisitorShouldReturnNull() { assertNull(dedup.accept(new LogicalPlanNodeVisitor() { }, null)); - LogicalPlan window = LogicalPlanDSL.window(relation, expression, new WindowDefinition( + LogicalPlan window = LogicalPlanDSL.window(relation, named(expression), new WindowDefinition( ImmutableList.of(ref), ImmutableList.of(Pair.of(SortOption.DEFAULT_ASC, expression)))); assertNull(window.accept(new LogicalPlanNodeVisitor() { }, null)); diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java index 34399abaf2..f97c551afe 100644 --- a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java @@ -118,7 +118,7 @@ public void test_PhysicalPlanVisitor_should_return_null() { assertNull(project.accept(new PhysicalPlanNodeVisitor() { }, null)); - PhysicalPlan window = PhysicalPlanDSL.window(plan, dsl.rowNumber(), + PhysicalPlan window = PhysicalPlanDSL.window(plan, named(dsl.rowNumber()), new WindowDefinition(emptyList(), emptyList())); assertNull(window.accept(new PhysicalPlanNodeVisitor() { }, null)); diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/planner/physical/WindowOperatorTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/planner/physical/WindowOperatorTest.java index 9063b0d1e2..61f0f0a9ae 100644 --- a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/planner/physical/WindowOperatorTest.java +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/planner/physical/WindowOperatorTest.java @@ -26,9 +26,11 @@ import com.amazon.opendistroforelasticsearch.sql.ast.tree.Sort.SortOption; import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils; +import com.amazon.opendistroforelasticsearch.sql.expression.DSL; import com.amazon.opendistroforelasticsearch.sql.expression.Expression; -import com.amazon.opendistroforelasticsearch.sql.expression.FunctionExpression; +import com.amazon.opendistroforelasticsearch.sql.expression.NamedExpression; import com.amazon.opendistroforelasticsearch.sql.expression.window.WindowDefinition; +import com.amazon.opendistroforelasticsearch.sql.expression.window.aggregation.AggregateWindowFunction; import com.google.common.collect.ImmutableMap; import java.util.ArrayList; import java.util.List; @@ -47,7 +49,7 @@ class WindowOperatorTest extends PhysicalPlanTestBase { @Test - void test() { + void test_ranking_window_function() { window(dsl.rank()) .partitionBy(ref("action", STRING)) .sortBy(DEFAULT_ASC, ref("response", INTEGER)) @@ -69,18 +71,68 @@ void test() { .done(); } - private WindowOperatorAssertion window(FunctionExpression windowFunction) { + @SuppressWarnings("unchecked") + @Test + void test_aggregate_window_function() { + window(new AggregateWindowFunction(dsl.sum(ref("response", INTEGER)))) + .partitionBy(ref("action", STRING)) + .sortBy(DEFAULT_ASC, ref("response", INTEGER)) + .expectNext(ImmutableMap.of( + "ip", "209.160.24.63", "action", "GET", "response", 200, "referer", "www.amazon.com", + "sum(response)", 400)) + .expectNext(ImmutableMap.of( + "ip", "112.111.162.4", "action", "GET", "response", 200, "referer", "www.amazon.com", + "sum(response)", 400)) + .expectNext(ImmutableMap.of( + "ip", "209.160.24.63", "action", "GET", "response", 404, "referer", "www.amazon.com", + "sum(response)", 804)) + .expectNext(ImmutableMap.of( + "ip", "74.125.19.106", "action", "POST", "response", 200, "referer", "www.google.com", + "sum(response)", 200)) + .expectNext(ImmutableMap.of( + "ip", "74.125.19.106", "action", "POST", "response", 500, + "sum(response)", 700)) + .done(); + } + + @SuppressWarnings("unchecked") + @Test + void test_aggregate_window_function_without_sort_key() { + window(new AggregateWindowFunction(dsl.sum(ref("response", INTEGER)))) + .expectNext(ImmutableMap.of( + "ip", "209.160.24.63", "action", "GET", "response", 200, "referer", "www.amazon.com", + "sum(response)", 1504)) + .expectNext(ImmutableMap.of( + "ip", "74.125.19.106", "action", "POST", "response", 500, + "sum(response)", 1504)) + .expectNext(ImmutableMap.of( + "ip", "74.125.19.106", "action", "POST", "response", 200, "referer", "www.google.com", + "sum(response)", 1504)) + .expectNext(ImmutableMap.of( + "ip", "112.111.162.4", "action", "GET", "response", 200, "referer", "www.amazon.com", + "sum(response)", 1504)) + .expectNext(ImmutableMap.of( + "ip", "209.160.24.63", "action", "GET", "response", 404, "referer", "www.amazon.com", + "sum(response)", 1504)) + .done(); + } + + private WindowOperatorAssertion window(Expression windowFunction) { return new WindowOperatorAssertion(windowFunction); } @RequiredArgsConstructor private static class WindowOperatorAssertion { - private final Expression windowFunction; + private final NamedExpression windowFunction; private final List partitionByList = new ArrayList<>(); private final List> sortList = new ArrayList<>(); private WindowOperator windowOperator; + private WindowOperatorAssertion(Expression windowFunction) { + this.windowFunction = DSL.named(windowFunction); + } + WindowOperatorAssertion partitionBy(Expression expr) { partitionByList.add(expr); return this; diff --git a/docs/dev/AggregateWindowFunction.md b/docs/dev/AggregateWindowFunction.md new file mode 100644 index 0000000000..53c01c27cb --- /dev/null +++ b/docs/dev/AggregateWindowFunction.md @@ -0,0 +1,94 @@ +# SQL Aggregate Window Functions + +## 1.Overview + +To support aggregate window functions, the following two problems need to be addressed: + +1. How to make existing aggregate functions work as window function +2. How to handle duplicate sort key (field values in ORDER BY, will elaborate shortly) + +For the first problem, a wrapper class AggregateWindowFunction is created. In particular, it extends Expression interface and reuse existing aggregate functions to calculate result based on window frame. **Now let’s examine in details how to address the second problem**. + + +## 2.Problem Statement + +First let’s check why it’s a problem to the window frame and ranking window function framework introduced earlier. In the following example, `age` as sort key in `ORDER BY age` is unique, so the running total is accumulated on each row incrementally. + +``` +mysql> SELECT + -> ROW_NUMBER() OVER () AS "no.", + -> state, age, balance, + -> SUM(balance) OVER (PARTITION BY state ORDER BY age) AS "running total" + -> FROM accounts + -> ORDER BY state DESC, age; ++-----+-------+------+---------+---------------+ +| no. | state | age | balance | running total | ++-----+-------+------+---------+---------------+ +| 1 | WA | 10 | 100 | 100 | +| 2 | WA | 20 | 200 | 300 | +| 3 | WA | 25 | 50 | 350 | +| 4 | WA | 35 | 150 | 500 | +| 5 | CA | 18 | 150 | 150 | +| 6 | CA | 25 | 100 | 250 | +| 7 | CA | 30 | 200 | 450 | ++-----+-------+------+---------+---------------+ +``` + +However, problem arises when the sort key has duplicate values. For example, the 2nd and 3rd row (called peers) in ‘WA’ partition has same value. Same for the 5th and 6th row in the ‘CA’ partition. In this case, the running total would be the same for peer rows. This looks strange at first sight, though the reason is **the fact that which row is current is defined by sort key. That’s why the existing window frame and function implementation only based on and access to current row won’t work.** + +``` +mysql> SELECT + -> ROW_NUMBER() OVER () AS "no.", + -> state, age, balance, + -> SUM(balance) OVER (PARTITION BY state ORDER BY age) AS "running total" + -> FROM accounts + -> ORDER BY state DESC, age; ++-----+-------+------+---------+---------------+ +| no. | state | age | balance | running total | ++-----+-------+------+---------+---------------+ +| 1 | WA | 10 | 100 | 100 | +| 2 | WA | 20 | 200 | 350 | +| 3 | WA | 20 | 50 | 350 | +| 4 | WA | 35 | 150 | 500 | +| 5 | CA | 18 | 150 | 250 | +| 6 | CA | 18 | 100 | 250 | +| 7 | CA | 30 | 200 | 450 | ++-----+-------+------+---------+---------------+ +``` + +## 3.Solution + +### 3.1 How It Works + +By the examples above, we should be able to understand what an aggregate window function does conceptually. To implement, first we need to figure out how aggregate window functions work from iterative thinking. Let’s review the previous example and imagine a query engine behind it doing the calculation. + +``` ++-----+-------+------+---------+---------------+ +| no. | state | age | balance | running total | ++-----+-------+------+---------+---------------+ <- initial state +| 1 | WA | 10 | 100 | 100 | <- load 100, return 100 as sum +| 2 | WA | 20 | 200 | 350 | <- load 200 and 50, return 350 +| 3 | WA | 20 | 50 | 350 | <- load nothing, return 350 again +| 4 | WA | 35 | 150 | 500 | <- load 150, return 500 +| 5 | CA | 18 | 150 | 250 | <- new partition, reset and load 100 and 150 +| 6 | CA | 18 | 100 | 250 | <- load nothing, return 250 again +| 7 | CA | 30 | 200 | 450 | <- load 200, return 450 ++-----+-------+------+---------+---------------+ <- no more data, done +``` + +### 3.2 Design + +To explain the design in more intuitive way, formal sequence diagram in UML is not present here. Instead the following informal diagram illustrates how `WindowOperator`, `PeerRowsWindowFrame` and `AggregateWindowFunction` component work together as a whole to implement the same logic in last section. + +![High Level Design](img/aggregate-window-functions.png) + +### 3.3 Performance + +For time complexity, aggregate window functions are same as ranking functions which only scan input linearly. However, as for space complexity, there seems no way to avoid this memory consumption. Because more rows needs to be pre-fetched for calculation and meanwhile window operator need access to each previous row, one by one, as output. + +In the worst case, all input data will be pulled out into window frame if: + +1. Single partition due to no PARTITION BY clause +2. (and) All values of ORDER BY fields are exactly the same + +In this case, circuit breaker needs to be enabled to protect window operator from consuming large memory. diff --git a/docs/dev/img/aggregate-window-functions.png b/docs/dev/img/aggregate-window-functions.png new file mode 100644 index 0000000000..9132280e60 Binary files /dev/null and b/docs/dev/img/aggregate-window-functions.png differ diff --git a/docs/user/dql/window.rst b/docs/user/dql/window.rst index 004e3a42df..ba105903e6 100644 --- a/docs/user/dql/window.rst +++ b/docs/user/dql/window.rst @@ -36,6 +36,117 @@ The syntax of a window function is as follows in which both ``PARTITION BY`` and ) +Aggregate Functions +=================== + +Aggregate functions are window functions that operates on a cumulative window frame to calculate an aggregated result. How cumulative data in the window frame being aggregated is exactly same as how regular aggregate functions work. So aggregate window functions can be used to perform running calculation easily, for example running average or running sum. Note that if ``PARTITION BY`` clause present and specified column value(s) changed, the state of aggregate function will be reset. + +COUNT +----- + +Here is an example for ``COUNT`` function:: + + od> SELECT + ... gender, balance, + ... COUNT(balance) OVER( + ... PARTITION BY gender ORDER BY balance + ... ) AS cnt + ... FROM accounts; + fetched rows / total rows = 4/4 + +----------+-----------+-------+ + | gender | balance | cnt | + |----------+-----------+-------| + | F | 32838 | 1 | + | M | 4180 | 1 | + | M | 5686 | 2 | + | M | 39225 | 3 | + +----------+-----------+-------+ + +MIN +--- + +Here is an example for ``MIN`` function:: + + od> SELECT + ... gender, balance, + ... MIN(balance) OVER( + ... PARTITION BY gender ORDER BY balance + ... ) AS cnt + ... FROM accounts; + fetched rows / total rows = 4/4 + +----------+-----------+-------+ + | gender | balance | cnt | + |----------+-----------+-------| + | F | 32838 | 32838 | + | M | 4180 | 4180 | + | M | 5686 | 4180 | + | M | 39225 | 4180 | + +----------+-----------+-------+ + +MAX +--- + +Here is an example for ``MAX`` function:: + + od> SELECT + ... gender, balance, + ... MAX(balance) OVER( + ... PARTITION BY gender ORDER BY balance + ... ) AS cnt + ... FROM accounts; + fetched rows / total rows = 4/4 + +----------+-----------+-------+ + | gender | balance | cnt | + |----------+-----------+-------| + | F | 32838 | 32838 | + | M | 4180 | 4180 | + | M | 5686 | 5686 | + | M | 39225 | 39225 | + +----------+-----------+-------+ + +AVG +--- + +Here is an example for ``AVG`` function:: + + od> SELECT + ... gender, balance, + ... AVG(balance) OVER( + ... PARTITION BY gender ORDER BY balance + ... ) AS cnt + ... FROM accounts; + fetched rows / total rows = 4/4 + +----------+-----------+--------------------+ + | gender | balance | cnt | + |----------+-----------+--------------------| + | F | 32838 | 32838.0 | + | M | 4180 | 4180.0 | + | M | 5686 | 4933.0 | + | M | 39225 | 16363.666666666666 | + +----------+-----------+--------------------+ + +SUM +--- + +Here is an example for ``SUM`` function:: + + od> SELECT + ... gender, balance, + ... SUM(balance) OVER( + ... PARTITION BY gender ORDER BY balance + ... ) AS cnt + ... FROM accounts; + fetched rows / total rows = 4/4 + +----------+-----------+-------+ + | gender | balance | cnt | + |----------+-----------+-------| + | F | 32838 | 32838 | + | M | 4180 | 4180 | + | M | 5686 | 9866 | + | M | 39225 | 49091 | + +----------+-----------+-------+ + + Ranking Functions ================= diff --git a/docs/user/limitations/limitations.rst b/docs/user/limitations/limitations.rst index cd39a1472a..bffa9a0c97 100644 --- a/docs/user/limitations/limitations.rst +++ b/docs/user/limitations/limitations.rst @@ -63,13 +63,24 @@ Here's a link to the Github issue - `Issue 110 sortItem = ImmutablePair.of(DEFAULT_ASC, DSL.ref("age", INTEGER)); WindowDefinition windowDefinition = diff --git a/integ-test/src/test/resources/correctness/queries/window.txt b/integ-test/src/test/resources/correctness/queries/window.txt index 53e682b0f0..8a1191d938 100644 --- a/integ-test/src/test/resources/correctness/queries/window.txt +++ b/integ-test/src/test/resources/correctness/queries/window.txt @@ -4,10 +4,29 @@ SELECT DistanceMiles, DENSE_RANK() OVER (ORDER BY DistanceMiles) AS rnk FROM kib SELECT DistanceMiles, ROW_NUMBER() OVER (ORDER BY DistanceMiles DESC) AS num FROM kibana_sample_data_flights SELECT DistanceMiles, RANK() OVER (ORDER BY DistanceMiles DESC) AS rnk FROM kibana_sample_data_flights SELECT DistanceMiles, DENSE_RANK() OVER (ORDER BY DistanceMiles DESC) AS rnk FROM kibana_sample_data_flights +SELECT DistanceMiles, COUNT(DistanceMiles) OVER () AS num FROM kibana_sample_data_flights +SELECT DistanceMiles, SUM(DistanceMiles) OVER () AS num FROM kibana_sample_data_flights +SELECT DistanceMiles, AVG(DistanceMiles) OVER () AS num FROM kibana_sample_data_flights +SELECT DistanceMiles, MAX(DistanceMiles) OVER () AS num FROM kibana_sample_data_flights +SELECT DistanceMiles, MIN(DistanceMiles) OVER () AS num FROM kibana_sample_data_flights +SELECT FlightDelayMin, DistanceMiles, SUM(DistanceMiles) OVER (ORDER BY FlightDelayMin) AS num FROM kibana_sample_data_flights +SELECT FlightDelayMin, DistanceMiles, AVG(DistanceMiles) OVER (ORDER BY FlightDelayMin) AS num FROM kibana_sample_data_flights +SELECT FlightDelayMin, DistanceMiles, MAX(DistanceMiles) OVER (ORDER BY FlightDelayMin) AS num FROM kibana_sample_data_flights +SELECT FlightDelayMin, DistanceMiles, MIN(DistanceMiles) OVER (ORDER BY FlightDelayMin) AS num FROM kibana_sample_data_flights SELECT user, RANK() OVER (ORDER BY user) AS rnk FROM kibana_sample_data_ecommerce SELECT user, DENSE_RANK() OVER (ORDER BY user) AS rnk FROM kibana_sample_data_ecommerce +SELECT user, COUNT(day_of_week_i) OVER (ORDER BY user) AS cnt FROM kibana_sample_data_ecommerce +SELECT user, SUM(day_of_week_i) OVER (ORDER BY user) AS num FROM kibana_sample_data_ecommerce +SELECT user, AVG(day_of_week_i) OVER (ORDER BY user) AS num FROM kibana_sample_data_ecommerce +SELECT user, MAX(day_of_week_i) OVER (ORDER BY user) AS num FROM kibana_sample_data_ecommerce +SELECT user, MIN(day_of_week_i) OVER (ORDER BY user) AS num FROM kibana_sample_data_ecommerce SELECT user, RANK() OVER (ORDER BY user DESC) AS rnk FROM kibana_sample_data_ecommerce SELECT user, DENSE_RANK() OVER (ORDER BY user DESC) AS rnk FROM kibana_sample_data_ecommerce +SELECT user, COUNT(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS cnt FROM kibana_sample_data_ecommerce +SELECT user, SUM(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM kibana_sample_data_ecommerce +SELECT user, AVG(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM kibana_sample_data_ecommerce +SELECT user, MAX(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM kibana_sample_data_ecommerce +SELECT user, MIN(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM kibana_sample_data_ecommerce SELECT customer_gender, user, ROW_NUMBER() OVER (PARTITION BY customer_gender ORDER BY user) AS num FROM kibana_sample_data_ecommerce SELECT customer_gender, user, RANK() OVER (PARTITION BY customer_gender ORDER BY user) AS num FROM kibana_sample_data_ecommerce SELECT customer_gender, user, DENSE_RANK() OVER (PARTITION BY customer_gender ORDER BY user) AS num FROM kibana_sample_data_ecommerce diff --git a/sql/src/main/antlr/OpenDistroSQLParser.g4 b/sql/src/main/antlr/OpenDistroSQLParser.g4 index f645174397..84ab26279f 100644 --- a/sql/src/main/antlr/OpenDistroSQLParser.g4 +++ b/sql/src/main/antlr/OpenDistroSQLParser.g4 @@ -155,12 +155,14 @@ limitClause ; // Window Function's Details -windowFunction - : function=rankingWindowFunction overClause +windowFunctionClause + : function=windowFunction overClause ; -rankingWindowFunction - : functionName=(ROW_NUMBER | RANK | DENSE_RANK) LR_BRACKET RR_BRACKET +windowFunction + : functionName=(ROW_NUMBER | RANK | DENSE_RANK) + LR_BRACKET functionArgs? RR_BRACKET #scalarWindowFunction + | aggregateFunction #aggregateWindowFunction ; overClause @@ -283,7 +285,7 @@ nullNotnull functionCall : scalarFunctionName LR_BRACKET functionArgs? RR_BRACKET #scalarFunctionCall | specificFunction #specificFunctionCall - | windowFunction #windowFunctionCall + | windowFunctionClause #windowFunctionCall | aggregateFunction #aggregateFunctionCall | aggregateFunction (orderByClause)? filterClause #filteredAggregationFunctionCall ; diff --git a/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilder.java b/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilder.java index 6942aab238..84e58d9535 100644 --- a/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilder.java +++ b/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilder.java @@ -27,7 +27,10 @@ import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.BooleanContext; import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.CaseFuncAlternativeContext; import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.CaseFunctionCallContext; +import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.ColumnFilterContext; +import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.ConvertedDataTypeContext; import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.CountStarFunctionCallContext; +import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.DataTypeFunctionCallContext; import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.DateLiteralContext; import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.IsNullPredicateContext; import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.LikePredicateContext; @@ -36,16 +39,19 @@ import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.NullLiteralContext; import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.OverClauseContext; import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.QualifiedNameContext; -import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.RankingWindowFunctionContext; import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.RegexpPredicateContext; import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.RegularAggregateFunctionCallContext; import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.ScalarFunctionCallContext; +import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.ScalarWindowFunctionContext; +import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.ShowDescribePatternContext; import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.SignedDecimalContext; import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.SignedRealContext; import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.StringContext; +import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.StringLiteralContext; +import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.TableFilterContext; import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.TimeLiteralContext; import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.TimestampLiteralContext; -import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.WindowFunctionContext; +import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.WindowFunctionClauseContext; import static com.amazon.opendistroforelasticsearch.sql.sql.parser.ParserUtils.createSortOption; import com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL; @@ -68,6 +74,7 @@ import com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser; import com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.AndExpressionContext; import com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.ColumnNameContext; +import com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.FunctionArgsContext; import com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.IdentContext; import com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.IntervalLiteralContext; import com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.NestedExpressionAtomContext; @@ -122,28 +129,18 @@ public UnresolvedExpression visitNestedExpressionAtom(NestedExpressionAtomContex @Override public UnresolvedExpression visitScalarFunctionCall(ScalarFunctionCallContext ctx) { - if (ctx.functionArgs() == null) { - return new Function(ctx.scalarFunctionName().getText(), Collections.emptyList()); - } - return new Function( - ctx.scalarFunctionName().getText(), - ctx.functionArgs() - .functionArg() - .stream() - .map(this::visitFunctionArg) - .collect(Collectors.toList()) - ); + return visitFunction(ctx.scalarFunctionName().getText(), ctx.functionArgs()); } @Override - public UnresolvedExpression visitTableFilter(OpenDistroSQLParser.TableFilterContext ctx) { + public UnresolvedExpression visitTableFilter(TableFilterContext ctx) { return new Function( LIKE.getName().getFunctionName(), Arrays.asList(qualifiedName("TABLE_NAME"), visit(ctx.showDescribePattern()))); } @Override - public UnresolvedExpression visitColumnFilter(OpenDistroSQLParser.ColumnFilterContext ctx) { + public UnresolvedExpression visitColumnFilter(ColumnFilterContext ctx) { return new Function( LIKE.getName().getFunctionName(), Arrays.asList(qualifiedName("COLUMN_NAME"), visit(ctx.showDescribePattern()))); @@ -151,7 +148,7 @@ public UnresolvedExpression visitColumnFilter(OpenDistroSQLParser.ColumnFilterCo @Override public UnresolvedExpression visitShowDescribePattern( - OpenDistroSQLParser.ShowDescribePatternContext ctx) { + ShowDescribePatternContext ctx) { if (ctx.compatibleID() != null) { return stringLiteral(ctx.compatibleID().getText()); } else { @@ -167,7 +164,7 @@ public UnresolvedExpression visitFilteredAggregationFunctionCall( } @Override - public UnresolvedExpression visitWindowFunction(WindowFunctionContext ctx) { + public UnresolvedExpression visitWindowFunctionClause(WindowFunctionClauseContext ctx) { OverClauseContext overClause = ctx.overClause(); List partitionByList = Collections.emptyList(); @@ -188,12 +185,12 @@ public UnresolvedExpression visitWindowFunction(WindowFunctionContext ctx) { createSortOption(item), visit(item.expression()))) .collect(Collectors.toList()); } - return new WindowFunction((Function) visit(ctx.function), partitionByList, sortList); + return new WindowFunction(visit(ctx.function), partitionByList, sortList); } @Override - public UnresolvedExpression visitRankingWindowFunction(RankingWindowFunctionContext ctx) { - return new Function(ctx.functionName.getText(), Collections.emptyList()); + public UnresolvedExpression visitScalarWindowFunction(ScalarWindowFunctionContext ctx) { + return visitFunction(ctx.functionName.getText(), ctx.functionArgs()); } @Override @@ -272,7 +269,7 @@ public UnresolvedExpression visitBoolean(BooleanContext ctx) { } @Override - public UnresolvedExpression visitStringLiteral(OpenDistroSQLParser.StringLiteralContext ctx) { + public UnresolvedExpression visitStringLiteral(StringLiteralContext ctx) { return AstDSL.stringLiteral(StringUtils.unquoteText(ctx.getText())); } @@ -332,16 +329,29 @@ public UnresolvedExpression visitCaseFuncAlternative(CaseFuncAlternativeContext @Override public UnresolvedExpression visitDataTypeFunctionCall( - OpenDistroSQLParser.DataTypeFunctionCallContext ctx) { + DataTypeFunctionCallContext ctx) { return new Cast(visit(ctx.expression()), visit(ctx.convertedDataType())); } @Override public UnresolvedExpression visitConvertedDataType( - OpenDistroSQLParser.ConvertedDataTypeContext ctx) { + ConvertedDataTypeContext ctx) { return AstDSL.stringLiteral(ctx.getText()); } + private Function visitFunction(String functionName, FunctionArgsContext args) { + if (args == null) { + return new Function(functionName, Collections.emptyList()); + } + return new Function( + functionName, + args.functionArg() + .stream() + .map(this::visitFunctionArg) + .collect(Collectors.toList()) + ); + } + private QualifiedName visitIdentifiers(List identifiers) { return new QualifiedName( identifiers.stream() diff --git a/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/context/QuerySpecification.java b/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/context/QuerySpecification.java index 16b518db3d..0349b2fa51 100644 --- a/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/context/QuerySpecification.java +++ b/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/context/QuerySpecification.java @@ -22,7 +22,7 @@ import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.SelectClauseContext; import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.SelectElementContext; import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.SubqueryAsRelationContext; -import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.WindowFunctionContext; +import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.WindowFunctionClauseContext; import static com.amazon.opendistroforelasticsearch.sql.sql.parser.ParserUtils.createSortOption; import static com.amazon.opendistroforelasticsearch.sql.sql.parser.ParserUtils.getTextInQuery; @@ -183,7 +183,7 @@ public Void visitSubqueryAsRelation(SubqueryAsRelationContext ctx) { } @Override - public Void visitWindowFunction(WindowFunctionContext ctx) { + public Void visitWindowFunctionClause(WindowFunctionClauseContext ctx) { // skip collecting sort items in window functions return null; } diff --git a/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilderTest.java b/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilderTest.java index 7ff33f8603..901a62c5b8 100644 --- a/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilderTest.java +++ b/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstExpressionBuilderTest.java @@ -16,6 +16,7 @@ package com.amazon.opendistroforelasticsearch.sql.sql.parser; +import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.aggregate; import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.and; import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.booleanLiteral; import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.caseWhen; @@ -294,6 +295,17 @@ public void canBuildWindowFunctionWithoutOrderBy() { buildExprAst("RANK() OVER (PARTITION BY state)")); } + @Test + public void canBuildAggregateWindowFunction() { + assertEquals( + window( + aggregate("AVG", qualifiedName("age")), + ImmutableList.of(qualifiedName("state")), + ImmutableList.of(ImmutablePair.of( + new SortOption(null, null), qualifiedName("age")))), + buildExprAst("AVG(age) OVER (PARTITION BY state ORDER BY age)")); + } + @Test public void canBuildCaseConditionStatement() { assertEquals(