Skip to content

Commit

Permalink
[GLUTEN-4789] [VL] Wrong preceding/following convert in window function node (#4788)
Browse files Browse the repository at this point in the history
  • Loading branch information
WangGuangxin authored Feb 27, 2024
1 parent 8bb4969 commit 285e6d3
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 28 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.glutenproject.execution

import org.apache.spark.SparkConf

/**
 * Window-expression tests for the Velox backend.
 *
 * The suite loads the not-null TPC-H tables from the parquet resources once, then runs window
 * queries over `lineitem` whose ROWS frame bounds combine preceding and following offsets,
 * checking that each plan contains a [[WindowExecTransformer]].
 */
class VeloxWindowExpressionSuite extends WholeStageTransformerSuite {

  protected val rootPath: String = getClass.getResource("/").getPath
  override protected val backend: String = "velox"
  override protected val resourcePath: String = "/tpch-data-parquet-velox"
  override protected val fileFormat: String = "parquet"

  /** Runs the parent setup first, then registers the TPC-H tables the queries read. */
  override def beforeAll(): Unit = {
    super.beforeAll()
    createTPCHNotNullTables()
  }

  /** Parent configuration plus the suite-specific settings, applied in the listed order. */
  override protected def sparkConf: SparkConf = {
    val suiteSettings = Seq(
      "spark.shuffle.manager" -> "org.apache.spark.shuffle.sort.ColumnarShuffleManager",
      "spark.sql.files.maxPartitionBytes" -> "1g",
      "spark.sql.shuffle.partitions" -> "1",
      "spark.memory.offHeap.size" -> "2g",
      "spark.unsafe.exceptionOnMemoryLeak" -> "true",
      "spark.sql.autoBroadcastJoinThreshold" -> "-1",
      "spark.sql.sources.useV1SourceList" -> "avro"
    )
    val baseConf = super.sparkConf
    suiteSettings.foldLeft(baseConf) {
      case (conf, (key, value)) => conf.set(key, value)
    }
  }

  test("window row frame with mix preceding and following") {
    // Each entry is the text placed after "rows between"; the queries run in this order.
    val frameClauses = Seq(
      // both bounds before the current row
      "2 preceding and 1 preceding",
      // both bounds after the current row
      "2 following and 3 following",
      // negative FOLLOWING offsets -- presumably treated as preceding; confirm against Spark
      "-3 following and -2 following",
      // unbounded start with a bounded following end
      "unbounded preceding and 3 following"
    )

    frameClauses.foreach {
      frameClause =>
        runQueryAndCompare(
          "select max(l_suppkey) over (partition by l_suppkey order by l_orderkey " +
            s"rows between $frameClause) from lineitem ") {
          checkOperatorMatch[WindowExecTransformer]
        }
    }
  }
}
38 changes: 14 additions & 24 deletions cpp/velox/substrait/SubstraitToVeloxPlan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -788,39 +788,29 @@ const core::WindowNode::Frame createWindowFrame(
VELOX_FAIL("the window type only support ROWS and RANGE, and the input type is ", std::to_string(type));
}

auto boundTypeConversion = [](::substrait::Expression_WindowFunction_Bound boundType) -> core::WindowNode::BoundType {
auto boundTypeConversion = [](::substrait::Expression_WindowFunction_Bound boundType)
-> std::tuple<core::WindowNode::BoundType, core::TypedExprPtr> {
// TODO: support non-literal expression.
if (boundType.has_current_row()) {
return core::WindowNode::BoundType::kCurrentRow;
return std::make_tuple(core::WindowNode::BoundType::kCurrentRow, nullptr);
} else if (boundType.has_unbounded_following()) {
return core::WindowNode::BoundType::kUnboundedFollowing;
return std::make_tuple(core::WindowNode::BoundType::kUnboundedFollowing, nullptr);
} else if (boundType.has_unbounded_preceding()) {
return core::WindowNode::BoundType::kUnboundedPreceding;
return std::make_tuple(core::WindowNode::BoundType::kUnboundedPreceding, nullptr);
} else if (boundType.has_following()) {
return core::WindowNode::BoundType::kFollowing;
return std::make_tuple(
core::WindowNode::BoundType::kFollowing,
std::make_shared<core::ConstantTypedExpr>(BIGINT(), variant(boundType.following().offset())));
} else if (boundType.has_preceding()) {
return core::WindowNode::BoundType::kPreceding;
return std::make_tuple(
core::WindowNode::BoundType::kPreceding,
std::make_shared<core::ConstantTypedExpr>(BIGINT(), variant(boundType.preceding().offset())));
} else {
VELOX_FAIL("The BoundType is not supported.");
}
};
frame.startType = boundTypeConversion(lower_bound);
switch (frame.startType) {
case core::WindowNode::BoundType::kPreceding:
// TODO: support non-literal expression.
frame.startValue = std::make_shared<core::ConstantTypedExpr>(BIGINT(), variant(lower_bound.preceding().offset()));
break;
default:
frame.startValue = nullptr;
}
frame.endType = boundTypeConversion(upper_bound);
switch (frame.endType) {
// TODO: support non-literal expression.
case core::WindowNode::BoundType::kFollowing:
frame.endValue = std::make_shared<core::ConstantTypedExpr>(BIGINT(), variant(upper_bound.following().offset()));
break;
default:
frame.endValue = nullptr;
}
std::tie(frame.startType, frame.startValue) = boundTypeConversion(lower_bound);
std::tie(frame.endType, frame.endValue) = boundTypeConversion(upper_bound);
return frame;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public class WindowFunctionNode implements Serializable {
}

private Expression.WindowFunction.Bound.Builder setBound(
Expression.WindowFunction.Bound.Builder builder, String boundType, boolean isLowerBound) {
Expression.WindowFunction.Bound.Builder builder, String boundType) {
switch (boundType) {
case ("CURRENT ROW"):
Expression.WindowFunction.Bound.CurrentRow.Builder currentRowBuilder =
Expand All @@ -77,7 +77,7 @@ private Expression.WindowFunction.Bound.Builder setBound(
default:
try {
Long offset = Long.valueOf(boundType);
if (isLowerBound) {
if (offset < 0) {
Expression.WindowFunction.Bound.Preceding.Builder offsetPrecedingBuilder =
Expression.WindowFunction.Bound.Preceding.newBuilder();
offsetPrecedingBuilder.setOffset(0 - offset);
Expand Down Expand Up @@ -129,8 +129,8 @@ public Expression.WindowFunction toProtobuf() {

Expression.WindowFunction.Bound.Builder upperBoundBuilder =
Expression.WindowFunction.Bound.newBuilder();
windowBuilder.setLowerBound(setBound(lowerBoundBuilder, lowerBound, true).build());
windowBuilder.setUpperBound(setBound(upperBoundBuilder, upperBound, false).build());
windowBuilder.setLowerBound(setBound(lowerBoundBuilder, lowerBound).build());
windowBuilder.setUpperBound(setBound(upperBoundBuilder, upperBound).build());
windowBuilder.setWindowType(getWindowType(frameType));
return windowBuilder.build();
}
Expand Down

0 comments on commit 285e6d3

Please sign in to comment.