
Commit 2e7fb76
feat: fix unsaferow for window (#1298)
* Support config of debugShowNodeDf for openmldb-batch
* Add rowToString in spark row util and add unit tests
* Update offset for string and non-string columns in row codec
* Add unit tests for unsafe row opt
* Add data util for openmldb-batch unit tests
1 parent 5cc52af commit 2e7fb76

File tree: 11 files changed (+239 −64 lines)

Makefile (+1)

@@ -118,6 +118,7 @@ configure: thirdparty-fast
 
 openmldb-clean:
 	rm -rf "$(OPENMLDB_BUILD_DIR)"
+	@cd java && ./mvnw clean
 
 THIRD_PARTY_BUILD_DIR ?= $(MAKEFILE_DIR)/.deps
 THIRD_PARTY_SRC_DIR ?= $(MAKEFILE_DIR)/thirdsrc

hybridse/src/codec/fe_row_codec.cc (+5)

@@ -893,6 +893,11 @@ RowFormat::RowFormat(const hybridse::codec::Schema* schema)
             next_str_pos_.insert(
                 std::make_pair(string_field_cnt, string_field_cnt));
             string_field_cnt += 1;
+
+            if (FLAGS_enable_spark_unsaferow_format) {
+                // For UnsafeRowOpt, the offset should be added for string and non-string columns
+                offset += 8;
+            }
         } else {
             auto TYPE_SIZE_MAP = codec::GetTypeSizeMap();
             auto it = TYPE_SIZE_MAP.find(column.type());
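This hunk is the core fix of the commit: in Spark's UnsafeRow format, every field owns one 8-byte word in the row's fixed-length region, including variable-length fields such as strings (a string's word packs the offset and size of its payload in the variable-length tail). The codec therefore has to advance its running offset by 8 for string columns as well. Below is a minimal sketch of that addressing scheme, assuming the standard UnsafeRow layout; the object and helper names are illustrative, not part of this commit:

    // Illustrative sketch of UnsafeRow addressing:
    // [null bitset, rounded to 8-byte words][one 8-byte word per field][variable-length data]
    object UnsafeRowLayoutSketch {
      // Null-tracking bitset width: one bit per field, rounded up to whole 8-byte words.
      def bitsetWidthInBytes(numFields: Int): Int = ((numFields + 63) / 64) * 8

      // Every field advances the fixed region by 8 bytes, strings included,
      // which is exactly why the row codec above adds 8 for string columns too.
      def fieldWordOffset(numFields: Int, fieldIndex: Int): Int =
        bitsetWidthInBytes(numFields) + 8 * fieldIndex
    }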

java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/OpenmldbBatchConfig.scala (+5 −2)

@@ -29,7 +29,7 @@ class OpenmldbBatchConfig extends Serializable {
   @ConfigOption(name = "spark.sql.session.timeZone")
   var timeZone = "Asia/Shanghai"
 
-  // test mode, used to verify related issues during testing
+  // test mode
   @ConfigOption(name = "openmldb.test.tiny", doc = "Limit the number of rows read from each table; read the full data by default")
   var tinyData: Long = -1
 
@@ -42,6 +42,10 @@ class OpenmldbBatchConfig extends Serializable {
   @ConfigOption(name = "openmldb.test.print", doc = "Allow printing data during execution")
   var print: Boolean = false
 
+  // Debug options
+  @ConfigOption(name = "openmldb.debug.show_node_df", doc = "Use Spark DataFrame.show() for each physical node")
+  var debugShowNodeDf: Boolean = false
+
   // Window skew optimization
   @ConfigOption(name = "openmldb.window.skew.opt", doc = "Enable window skew optimization or not")
   var enableWindowSkewOpt: Boolean = false
@@ -66,7 +70,6 @@ class OpenmldbBatchConfig extends Serializable {
   @ConfigOption(name = "openmldb.window.skew.opt.config", doc = "The skew config for window skew optimization")
   var windowSkewOptConfig: String = ""
 
-  // Slow run mode
   @ConfigOption(name = "openmldb.slowRunCacheDir", doc =
     """
      | Slow run mode cache directory path. If specified, run OpenMLDB plan with slow mode.
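A usage sketch for the new debug option. The config name comes from this commit; the "spark." prefix is an assumption that mirrors how the unit tests below set "spark.openmldb.unsaferow.opt", and the table and query are placeholders:

    import org.apache.spark.sql.SparkSession
    import com._4paradigm.openmldb.batch.api.OpenmldbSession
    import com._4paradigm.openmldb.batch.end2end.DataUtil

    // Enable per-node DataFrame output before creating the OpenmldbSession.
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    spark.conf.set("spark.openmldb.debug.show_node_df", true)

    val sess = new OpenmldbSession(spark)
    sess.registerTable("t1", DataUtil.getTestDf(spark))
    // While the plan is visited, every physical node logs a WARN line
    // and shows its intermediate DataFrame.
    sess.sql("SELECT id, name FROM t1").getSparkDf().show()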

java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/SparkPlanner.scala (+4 −1)

@@ -273,10 +273,13 @@ class SparkPlanner(session: SparkSession, config: OpenmldbBatchConfig, sparkAppN
     // Set the output to context cache
     ctx.putPlanResult(root.GetNodeId(), outputSpatkInstance)
 
+    if (config.debugShowNodeDf) {
+      logger.warn(s"Debug and print DataFrame of nodeId: ${root.GetNodeId()}, nodeType: ${root.GetTypeName()}")
+      outputSpatkInstance.getDf().show()
+    }
     outputSpatkInstance
   }
 
-
   /**
    * Run plan slowly by storing and loading each intermediate result from external data path.
    */
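One design note: the per-node debug message is logged at WARN rather than INFO, so it stays visible under the quieter WARN root logger that this same commit configures in log4j.properties below.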

java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/WindowAggPlan.scala (−6)

@@ -279,9 +279,6 @@ object WindowAggPlan {
       val skewGroups = config.groupIdxs :+ config.partIdIdx
       computer.resetGroupKeyComparator(skewGroups)
     }
-    if (sqlConfig.print) {
-      logger.info(s"windowAggIter mode: ${sqlConfig.enableWindowSkewOpt}")
-    }
 
     val resIter = if (sqlConfig.enableWindowSkewOpt) {
       limitInputIter.flatMap(zippedRow => {
@@ -341,9 +338,6 @@ object WindowAggPlan {
       val skewGroups = config.groupIdxs :+ config.partIdIdx
       computer.resetGroupKeyComparator(skewGroups)
     }
-    if (sqlConfig.print) {
-      logger.info(s"windowAggIter mode: ${sqlConfig.enableWindowSkewOpt}")
-    }
 
     val resIter = if (sqlConfig.enableWindowSkewOpt) {
       limitInputIter.flatMap(row => {

java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/SparkRowUtil.scala (+45 −8)

@@ -16,28 +16,65 @@
 
 package com._4paradigm.openmldb.batch.utils
 
-import com._4paradigm.hybridse.sdk.{HybridSeException, UnsupportedHybridSeException}
+import com._4paradigm.hybridse.sdk.HybridSeException
 import com._4paradigm.openmldb.proto.Type
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.types.{
-  BooleanType, DataType, DateType, DoubleType, FloatType, IntegerType, LongType,
-  ShortType, StringType, TimestampType
-}
+import org.apache.spark.sql.types.{BooleanType, DataType, DateType, DoubleType, FloatType, IntegerType, LongType,
+  ShortType, StringType, StructType, TimestampType}
 
 object SparkRowUtil {
 
-  def getLongFromIndex(keyIdx: Int, sparkType: DataType, row: Row): java.lang.Long = {
+  def rowToString(schema: StructType, row: Row): String = {
+    val rowStr = new StringBuilder("Spark row: ")
+    (0 until schema.size).foreach(i => {
+      if (i == 0) {
+        rowStr ++= s"${schema(i).dataType}: ${getColumnStringValue(schema, row, i)}"
+      } else {
+        rowStr ++= s", ${schema(i).dataType}: ${getColumnStringValue(schema, row, i)}"
+      }
+    })
+    rowStr.toString()
+  }
+
+  /**
+   * Get the string value of the specified column.
+   *
+   * @param schema the schema of the input row
+   * @param row the input Spark row
+   * @param index the index of the column to read
+   * @return the column value as a string, or "null" if the column is null
+   */
+  def getColumnStringValue(schema: StructType, row: Row, index: Int): String = {
+    if (row.isNullAt(index)) {
+      "null"
+    } else {
+      val colType = schema(index).dataType
+      colType match {
+        case BooleanType => String.valueOf(row.getBoolean(index))
+        case ShortType => String.valueOf(row.getShort(index))
+        case DoubleType => String.valueOf(row.getDouble(index))
+        case IntegerType => String.valueOf(row.getInt(index))
+        case LongType => String.valueOf(row.getLong(index))
+        case TimestampType => String.valueOf(row.getTimestamp(index))
+        case DateType => String.valueOf(row.getDate(index))
+        case StringType => row.getString(index)
+        case _ =>
+          throw new HybridSeException(s"Unsupported data type: $colType")
+      }
+    }
+  }
+
+  def getLongFromIndex(keyIdx: Int, colType: DataType, row: Row): java.lang.Long = {
     if (row.isNullAt(keyIdx)) {
       null
     } else {
-      sparkType match {
+      colType match {
         case ShortType => row.getShort(keyIdx).toLong
         case IntegerType => row.getInt(keyIdx).toLong
         case LongType => row.getLong(keyIdx)
         case TimestampType => row.getTimestamp(keyIdx).getTime
         case DateType => row.getDate(keyIdx).getTime
         case _ =>
-          throw new HybridSeException(s"Illegal window key type: $sparkType")
+          throw new HybridSeException(s"Illegal window key type: $colType")
       }
     }
   }
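A quick sketch of what the new rowToString helper produces; the schema and values below are illustrative:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
    import com._4paradigm.openmldb.batch.utils.SparkRowUtil

    val schema = StructType(Seq(
      StructField("int_col", IntegerType),
      StructField("str_col", StringType)))
    // Prints: Spark row: IntegerType: 1, StringType: abc
    println(SparkRowUtil.rowToString(schema, Row(1, "abc")))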
log4j.properties (+4 −46)

@@ -1,51 +1,9 @@
-### set log levels ###
-log4j.rootLogger=stdout,warn,error
 
-# console log
+log4j.rootLogger=WARN, stdout
+
+# Console log
 log4j.appender.stdout = org.apache.log4j.ConsoleAppender
 log4j.appender.stdout.Target = System.out
-log4j.appender.stdout.Threshold = INFO
 log4j.appender.stdout.layout = org.apache.log4j.PatternLayout
 log4j.appender.stdout.Encoding=UTF-8
-log4j.appender.stdout.layout.ConversionPattern = %d{yyyy-MM-dd HH:mm:ss} [ %c.%M(%F:%L) ] - [ %p ] %m%n
-
-#info log
-log4j.logger.info=info
-log4j.appender.info=org.apache.log4j.DailyRollingFileAppender
-log4j.appender.info.DatePattern='_'yyyy-MM-dd'.log'
-log4j.appender.info.File=logs/info.log
-log4j.appender.info.Append=true
-log4j.appender.info.Threshold=INFO
-log4j.appender.info.Encoding=UTF-8
-log4j.appender.info.layout=org.apache.log4j.PatternLayout
-log4j.appender.info.layout.ConversionPattern= %d{yyyy-MM-dd HH:mm:ss} [ %c.%M(%F:%L) ] - [ %p ] %m%n
-#debug log
-log4j.logger.debug=debug
-log4j.appender.debug=org.apache.log4j.DailyRollingFileAppender
-log4j.appender.debug.DatePattern='_'yyyy-MM-dd'.log'
-log4j.appender.debug.File=logs/debug.log
-log4j.appender.debug.Append=true
-log4j.appender.debug.Threshold=DEBUG
-log4j.appender.debug.Encoding=UTF-8
-log4j.appender.debug.layout=org.apache.log4j.PatternLayout
-log4j.appender.debug.layout.ConversionPattern= %d{yyyy-MM-dd HH:mm:ss} [ %c.%M(%F:%L) ] - [ %p ] %m%n
-#warn log
-log4j.logger.warn=warn
-log4j.appender.warn=org.apache.log4j.DailyRollingFileAppender
-log4j.appender.warn.DatePattern='_'yyyy-MM-dd'.log'
-log4j.appender.warn.File=logs/warn.log
-log4j.appender.warn.Append=true
-log4j.appender.warn.Threshold=WARN
-log4j.appender.warn.Encoding=UTF-8
-log4j.appender.warn.layout=org.apache.log4j.PatternLayout
-log4j.appender.warn.layout.ConversionPattern= %d{yyyy-MM-dd HH:mm:ss} [ %c.%M(%F:%L) ] - [ %p ] %m%n
-#error log
-log4j.logger.error=error
-log4j.appender.error = org.apache.log4j.DailyRollingFileAppender
-log4j.appender.error.DatePattern='_'yyyy-MM-dd'.log'
-log4j.appender.error.File = logs/error.log
-log4j.appender.error.Append = true
-log4j.appender.error.Threshold = ERROR
-log4j.appender.error.Encoding=UTF-8
-log4j.appender.error.layout = org.apache.log4j.PatternLayout
-log4j.appender.error.layout.ConversionPattern = %d{yyyy-MM-dd HH:mm:ss} [ %c.%M(%F:%L) ] - [ %p ] %m%n
+log4j.appender.stdout.layout.ConversionPattern = %c.%M(%F:%L) - %p: %m%n
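With the simplified pattern %c.%M(%F:%L) - %p: %m%n, a console line now renders roughly like this (the class, method, line number, and node details below are made up for illustration):

    com._4paradigm.openmldb.batch.SparkPlanner.visitNode(SparkPlanner.scala:277) - WARN: Debug and print DataFrame of nodeId: 1, nodeType: kPhysicalOpProject

This drops the timestamp and bracket decoration of the old %d{yyyy-MM-dd HH:mm:ss} [ %c.%M(%F:%L) ] - [ %p ] %m%n pattern and replaces the four daily-rolling file appenders with a single WARN-level console appender.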
java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/DataUtil.scala (new file, +51)

@@ -0,0 +1,51 @@
+/*
+ * Copyright 2021 4Paradigm
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com._4paradigm.openmldb.batch.end2end
+
+import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType}
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+
+object DataUtil {
+
+  def getStringDf(spark: SparkSession): DataFrame = {
+    val data = Seq(
+      Row(1, "abc", 100)
+    )
+    val schema = StructType(List(
+      StructField("int_col", IntegerType),
+      StructField("str_col", StringType),
+      StructField("int_col2", IntegerType)
+    ))
+    spark.createDataFrame(spark.sparkContext.makeRDD(data), schema)
+  }
+
+  def getTestDf(spark: SparkSession): DataFrame = {
+    val data = Seq(
+      Row(1, "tom", 100L, 1),
+      Row(2, "tom", 200L, 2),
+      Row(3, "tom", 300L, 3),
+      Row(4, "amy", 400L, 4),
+      Row(5, "amy", 500L, 5))
+    val schema = StructType(List(
+      StructField("id", IntegerType),
+      StructField("name", StringType),
+      StructField("trans_amount", LongType),
+      StructField("trans_time", IntegerType)))
+    spark.createDataFrame(spark.sparkContext.makeRDD(data), schema)
+  }
+
+}
java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/unsafe/TestUnsafeProject.scala (new file, +52)

@@ -0,0 +1,52 @@
+/*
+ * Copyright 2021 4Paradigm
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com._4paradigm.openmldb.batch.end2end.unsafe
+
+import com._4paradigm.openmldb.batch.SparkTestSuite
+import com._4paradigm.openmldb.batch.api.OpenmldbSession
+import com._4paradigm.openmldb.batch.end2end.DataUtil
+import com._4paradigm.openmldb.batch.utils.SparkUtil
+
+class TestUnsafeProject extends SparkTestSuite {
+
+  override def customizedBefore(): Unit = {
+    val spark = getSparkSession
+    spark.conf.set("spark.openmldb.unsaferow.opt", true)
+  }
+
+  test("Test unsafe project") {
+    val spark = getSparkSession
+    val sess = new OpenmldbSession(spark)
+
+    val df = DataUtil.getStringDf(spark)
+    sess.registerTable("t1", df)
+    df.createOrReplaceTempView("t1")
+
+    val sqlText = "SELECT int_col, int_col2 + 1000 FROM t1"
+
+    val outputDf = sess.sql(sqlText)
+    val sparksqlOutputDf = sess.sparksql(sqlText)
+    assert(SparkUtil.approximateDfEqual(outputDf.getSparkDf(), sparksqlOutputDf, false))
+  }
+
+  override def customizedAfter(): Unit = {
+    val spark = getSparkSession
+    spark.conf.set("spark.openmldb.unsaferow.opt", false)
+  }
+
+}
java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/unsafe/TestUnsafeWindow.scala (new file, +57)

@@ -0,0 +1,57 @@
+/*
+ * Copyright 2021 4Paradigm
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com._4paradigm.openmldb.batch.end2end.unsafe
+
+import com._4paradigm.openmldb.batch.SparkTestSuite
+import com._4paradigm.openmldb.batch.api.OpenmldbSession
+import com._4paradigm.openmldb.batch.end2end.DataUtil
+import com._4paradigm.openmldb.batch.utils.SparkUtil
+
+class TestUnsafeWindow extends SparkTestSuite {
+
+  override def customizedBefore(): Unit = {
+    val spark = getSparkSession
+    spark.conf.set("spark.openmldb.unsaferow.opt", true)
+  }
+
+  test("Test unsafe window") {
+    val spark = getSparkSession
+    val sess = new OpenmldbSession(spark)
+
+    val df = DataUtil.getTestDf(spark)
+    sess.registerTable("t1", df)
+    df.createOrReplaceTempView("t1")
+
+    val sqlText = """
+       | SELECT id, sum(trans_amount) OVER w AS w_sum_amount FROM t1
+       | WINDOW w AS (
+       |    PARTITION BY id
+       |    ORDER BY trans_time
+       |    ROWS BETWEEN 10 PRECEDING AND CURRENT ROW);
+       """.stripMargin
+
+    val outputDf = sess.sql(sqlText)
+    val sparksqlOutputDf = sess.sparksql(sqlText)
+    assert(SparkUtil.approximateDfEqual(outputDf.getSparkDf(), sparksqlOutputDf, false))
+  }
+
+  override def customizedAfter(): Unit = {
+    val spark = getSparkSession
+    spark.conf.set("spark.openmldb.unsaferow.opt", false)
+  }
+
+}
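A note on what this window test exercises: every id in DataUtil.getTestDf is unique, so each PARTITION BY id window contains only the current row and the expected w_sum_amount equals each row's own trans_amount. The assertion then checks that the UnsafeRowOpt pipeline, driven through the row-codec offset fix above, produces the same result as Spark SQL's native window execution.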
