Commit 0a06919

feat: add orderby and unit test for window union to handle same timestamp (#1834)
* Add orderby and unit test for window union
* Move window union order by in sortWithinPartition
1 parent 583fd0c commit 0a06919
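
What the change does: when a window aggregation unions extra tables (`WINDOW w AS (UNION t2 ...)`), rows from the union tables can carry exactly the same ORDER BY timestamp as the current row from the primary table, and without a tie-breaker the sort order inside a partition is nondeterministic. The patch threads a unique flag column name (`_WINDOW_UNION_FLAG_` plus a timestamp) from the planner into `windowUnionTables` and appends that column to the `sortWithinPartitions` keys, so primary-table rows are placed last among rows with equal timestamps. Below is a minimal self-contained sketch of the idea; the table contents and the assumption that primary rows are flagged `true` and union rows `false` are illustrative, not taken verbatim from the commit:

    import org.apache.spark.sql.{Column, SparkSession}
    import org.apache.spark.sql.functions.lit

    object WindowUnionSortSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
        import spark.implicits._

        // Unique flag column name, mirroring the patch
        val uniqueColName = "_WINDOW_UNION_FLAG_" + System.currentTimeMillis()

        // Assumed flag convention: primary rows true, union rows false, so an
        // ascending sort places primary rows last on timestamp ties
        val primary = Seq((1, 1L)).toDF("int_col", "long_col").withColumn(uniqueColName, lit(true))
        val union = Seq((1, 1L), (1, 1L)).toDF("int_col", "long_col").withColumn(uniqueColName, lit(false))
        val unioned = primary.union(union)

        // Same shape as the patch: partition keys ++ order-by keys ++ the flag column
        val sortCols: Seq[Column] =
          Seq(unioned.col("int_col"), unioned.col("long_col"), unioned.col(uniqueColName))
        unioned.repartition(unioned.col("int_col"))
          .sortWithinPartitions(sortCols: _*)
          .show()
      }
    }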

3 files changed: +98, -13 lines

java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/WindowAggPlan.scala (+30, -11)

@@ -24,11 +24,11 @@ import com._4paradigm.openmldb.batch.window.{WindowAggPlanUtil, WindowComputer}
 import com._4paradigm.openmldb.batch.{OpenmldbBatchConfig, PlanContext, SparkInstance}
 import com._4paradigm.openmldb.common.codec.CodecUtil
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.JoinedRow
 import org.apache.spark.sql.types.{DateType, LongType, StructType, TimestampType}
-import org.apache.spark.sql.{DataFrame, Row, functions}
+import org.apache.spark.sql.{Column, DataFrame, Row, functions}
 import org.apache.spark.util.SerializableConfiguration
 import org.slf4j.LoggerFactory
+
 import scala.collection.mutable
 
 /** The planner which implements window agg physical node.
@@ -67,17 +67,26 @@ object WindowAggPlan {
     val dfWithIndex = inputTable.getDfConsideringIndex(ctx, physicalNode.GetNodeId())
 
     // Do union if physical node has union flag
+    val uniqueColName = "_WINDOW_UNION_FLAG_" + System.currentTimeMillis()
     val unionTable = if (isWindowWithUnion) {
-      WindowAggPlanUtil.windowUnionTables(ctx, physicalNode, dfWithIndex)
+      WindowAggPlanUtil.windowUnionTables(ctx, physicalNode, dfWithIndex, uniqueColName)
     } else {
       dfWithIndex
     }
 
-    // Do groupby and sort with window skew optimization or not
+    // Use order by to make sure that rows with the same timestamp from the primary table are placed last
+    // TODO(tobe): support desc if we get config from physical plan
+    val unionSparkCol: Option[Column] = if (isWindowWithUnion) {
+      Some(unionTable.col(uniqueColName))
+    } else {
+      None
+    }
+
+    // Do group by and sort with window skew optimization or not
     val repartitionDf = if (isWindowSkewOptimization) {
-      windowPartitionWithSkewOpt(ctx, physicalNode, unionTable, windowAggConfig)
+      windowPartitionWithSkewOpt(ctx, physicalNode, unionTable, windowAggConfig, unionSparkCol)
     } else {
-      windowPartition(ctx, physicalNode, unionTable)
+      windowPartition(ctx, physicalNode, unionTable, unionSparkCol)
     }
 
     // Get the output schema which may add the index column
@@ -179,7 +188,8 @@ object WindowAggPlan {
   def windowPartitionWithSkewOpt(ctx: PlanContext,
                                  windowAggNode: PhysicalWindowAggrerationNode,
                                  inputDf: DataFrame,
-                                 windowAggConfig: WindowAggConfig): DataFrame = {
+                                 windowAggConfig: WindowAggConfig,
+                                 unionSparkCol: Option[Column]): DataFrame = {
     val uniqueNamePostfix = ctx.getConf.windowSkewOptPostfix
 
     // Cache the input table which may be used multiple times
@@ -274,7 +284,12 @@
     }
 
     val sortedByCol = PhysicalNodeUtil.getOrderbyColumns(windowAggNode, addColumnsDf)
-    val sortedByCols = repartitionCols ++ sortedByCol
+
+    val sortedByCols = if (unionSparkCol.isEmpty) {
+      repartitionCols ++ sortedByCol
+    } else {
+      repartitionCols ++ sortedByCol ++ Array(unionSparkCol.get)
+    }
 
     // Notice that we should make sure the keys in the same partition are ordered as well
     val sortedDf = repartitionDf.sortWithinPartitions(sortedByCols: _*)
@@ -289,7 +304,8 @@
    * 1. Repartition the table with the "partition by" keys.
    * 2. Sort the data within partitions with the "order by" keys.
    */
-  def windowPartition(ctx: PlanContext, windowAggNode: PhysicalWindowAggrerationNode, inputDf: DataFrame): DataFrame = {
+  def windowPartition(ctx: PlanContext, windowAggNode: PhysicalWindowAggrerationNode, inputDf: DataFrame,
+                      unionSparkCol: Option[Column]): DataFrame = {
 
     // Repartition the table with window keys
     val repartitionCols = PhysicalNodeUtil.getRepartitionColumns(windowAggNode, inputDf)
@@ -302,9 +318,12 @@
     // Sort with the window orderby keys
     val orderbyCols = PhysicalNodeUtil.getOrderbyColumns(windowAggNode, inputDf)
 
+    val sortedDf = if (unionSparkCol.isEmpty) {
+      repartitionDf.sortWithinPartitions(repartitionCols ++ orderbyCols: _*)
+    } else {
+      repartitionDf.sortWithinPartitions(repartitionCols ++ orderbyCols ++ Array(unionSparkCol.get): _*)
+    }
     // Notice that we should make sure the keys in the same partition are ordered as well
-    val sortedDf = repartitionDf.sortWithinPartitions(repartitionCols ++ orderbyCols: _*)
-
     sortedDf
   }
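
A note on the plumbing: the flag is threaded as an `Option[Column]`, so call sites without a window union pass `None` and the sort keys are unchanged. Since `Option.toSeq` yields zero or one element, the two if/else branches above could also be collapsed; a hedged equivalent sketch, assuming `repartitionCols` and `orderbyCols` are sequences of `Column` as in the surrounding code:

    // Equivalent formulation (sketch, not in the commit): Option.toSeq appends
    // the flag column only when a window union is present
    val sortedDf = repartitionDf.sortWithinPartitions(
      repartitionCols ++ orderbyCols ++ unionSparkCol.toSeq: _*)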

java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/window/WindowAggPlanUtil.scala (+3, -2)

@@ -24,6 +24,7 @@ import com._4paradigm.openmldb.batch.utils.{HybridseUtil, SparkColumnUtil, Spark
 import com._4paradigm.openmldb.batch.{OpenmldbBatchConfig, PlanContext, SparkInstance}
 import com._4paradigm.openmldb.sdk.impl.SqlClusterExecutor
 import org.apache.hadoop.fs.FileSystem
+import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.{DataFrame, functions}
 import org.apache.spark.sql.types.{LongType, StructType}
 import org.apache.spark.util.SerializableConfiguration
@@ -48,10 +49,10 @@ object WindowAggPlanUtil {
    */
   def windowUnionTables(ctx: PlanContext,
                         physicalNode: PhysicalWindowAggrerationNode,
-                        inputDf: DataFrame): DataFrame = {
+                        inputDf: DataFrame,
+                        uniqueColName: String): DataFrame = {
 
     val isKeepIndexColumn = SparkInstance.keepIndexColumn(ctx, physicalNode.GetNodeId())
-    val uniqueColName = "_WINDOW_UNION_FLAG_" + System.currentTimeMillis()
     val unionNum = physicalNode.window_unions().GetSize().toInt
 
     val rightTables = (0 until unionNum).map(i => {
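
For context, `windowUnionTables` is where the flag column is materialized on each table before the union, using the caller-supplied `uniqueColName`; the diff does not show that tagging code. A hedged sketch of the likely shape, with `tagAndUnion`, `primaryDf`, and `unionDf` as illustrative names and the true/false convention assumed:

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions.lit

    // Sketch (assumed flag convention): tag primary rows true and union-table
    // rows false so an ascending sort on the flag puts primary rows last
    def tagAndUnion(primaryDf: DataFrame, unionDf: DataFrame, uniqueColName: String): DataFrame = {
      primaryDf.withColumn(uniqueColName, lit(true))
        .union(unionDf.withColumn(uniqueColName, lit(false)))
    }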
java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/TestWindowUnionWithSameTimestamp.scala (+65, -0)

@@ -0,0 +1,65 @@
+/*
+ * Copyright 2021 4Paradigm
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com._4paradigm.openmldb.batch.end2end
+
+import com._4paradigm.openmldb.batch.SparkTestSuite
+import com._4paradigm.openmldb.batch.api.OpenmldbSession
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}
+
+class TestWindowUnionWithSameTimestamp extends SparkTestSuite {
+
+  test("Test window union with same timestamp") {
+
+    val spark = getSparkSession
+    val sess = new OpenmldbSession(spark)
+
+    val data = Seq[Row](
+      Row(1, 1L)
+    )
+    val schema = StructType(List(
+      StructField("int_col", IntegerType),
+      StructField("long_col", LongType)
+    ))
+    val df = spark.createDataFrame(spark.sparkContext.makeRDD(data), schema)
+    sess.registerTable("t1", df)
+
+    val data2 = Seq[Row](
+      Row(1, 1L),
+      Row(1, 1L)
+    )
+    val schema2 = StructType(List(
+      StructField("int_col", IntegerType),
+      StructField("long_col", LongType)
+    ))
+    val df2 = spark.createDataFrame(spark.sparkContext.makeRDD(data2), schema2)
+    sess.registerTable("t2", df2)
+
+    val sqlText =
+      """
+        | SELECT count(int_col) OVER w
+        | FROM t1
+        | WINDOW w AS (UNION t2 PARTITION BY int_col ORDER BY long_col ROWS BETWEEN 10 PRECEDING AND CURRENT ROW)
+        |""".stripMargin
+
+    val outputDf = sess.sql(sqlText)
+    val outputRow = outputDf.collect()(0)
+    // The output of count(int_col) should contain the current row from the primary table and the rows from union tables
+    assert(outputRow.getLong(0) == 3)
+  }
+
+}
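
Why the expected count is 3: under `ROWS BETWEEN 10 PRECEDING AND CURRENT ROW`, the window for t1's single row (1, 1L) should contain the two t2 rows that share its timestamp plus the current row itself. That only holds if the primary-table row sorts after the union rows on ties, which is exactly what the new flag column guarantees; without the tie-break, the current row could land before one or both t2 rows and the count would be flaky.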
