Commit

[SPARK-49829][SS] Revise the optimization on adding input to state store in stream-stream join (correctness fix)
HeartSaVioR committed Sep 30, 2024
1 parent 0c19059 commit 12ba515
Showing 3 changed files with 288 additions and 7 deletions.
@@ -668,13 +668,37 @@ case class StreamingSymmetricHashJoinExec(
private val iteratorNotEmpty: Boolean = super.hasNext

override def completion(): Unit = {
val isLeftSemiWithMatch =
joinType == LeftSemi && joinSide == LeftSide && iteratorNotEmpty
// Add to state store only if both removal predicates do not match,
// and the row is not matched for left side of left semi join.
val shouldAddToState =
!stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow) &&
!isLeftSemiWithMatch
// The criteria for whether the input has to be added to the state store:
// - Left side: the input can skip being added to the state store if it has already matched
// and the join type is left semi.
// In all other cases the input should be added, including when it is going to be evicted
// in this batch, because it has not yet been evaluated against the right side's inputs
// "for this batch". (See how each side figures out matches from the other side.)
// - Right side: at this point, evaluation against the left side's inputs "for this batch"
// is already done. Hence the input can skip being added to the state store if it is
// going to be evicted in this batch. However, the input should still be added to the
// state store for right outer and full outer joins, because unmatched output is
// handled during state eviction.
val isLeftSemiWithMatch = joinType == LeftSemi && joinSide == LeftSide && iteratorNotEmpty
val shouldAddToState = if (isLeftSemiWithMatch) {
false
} else if (joinSide == LeftSide) {
true
} else {
// joinSide == RightSide

// if the input is not evicted in this batch (hence it needs to be persisted)
val isNotEvictingInThisBatch =
!stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow)

isNotEvictingInThisBatch ||
// or if the input produces an "unmatched row" in this batch
(
(joinType == RightOuter && !iteratorNotEmpty) ||
(joinType == FullOuter && !iteratorNotEmpty)
)
}

if (shouldAddToState) {
joinStateManager.append(key, thisRow, matched = iteratorNotEmpty)
updatedStateRowsCount += 1
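
To make the revised criteria easier to scan, here is a minimal, self-contained sketch of the same decision as a pure function. The types and names below (JoinSide, JoinType, matched, evicting) are simplified stand-ins for illustration, not Spark's actual classes:

// Illustrative restatement of the shouldAddToState decision above.
// `matched` stands for iteratorNotEmpty (the row found a match this batch);
// `evicting` stands for both watermark eviction predicates firing.
sealed trait JoinSide
case object LeftSide extends JoinSide
case object RightSide extends JoinSide

sealed trait JoinType
case object LeftSemi extends JoinType
case object RightOuter extends JoinType
case object FullOuter extends JoinType
case object OtherJoinType extends JoinType

def shouldAddToState(
    joinType: JoinType,
    joinSide: JoinSide,
    matched: Boolean,
    evicting: Boolean): Boolean = joinSide match {
  // Left semi: a matched left row has already produced its output; skip it.
  case LeftSide if joinType == LeftSemi && matched => false
  // Left side otherwise: always add, even if evicted in this batch, since the
  // right side has not yet been evaluated against it "for this batch".
  case LeftSide => true
  // Right side: evaluation against this batch's left input is done, so keep
  // the row only if it survives eviction, or if a right/full outer join still
  // needs it to emit an unmatched row during state eviction.
  case RightSide =>
    !evicting || ((joinType == RightOuter || joinType == FullOuter) && !matched)
}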
@@ -878,6 +878,60 @@ class MultiStatefulOperatorsSuite
testOutputWatermarkInJoin(join3, input1, -40L * 1000 - 1)
}

// NOTE: This is a revision of the reproducer in SPARK-45637. CREDIT goes to @andrzejzera.
test("SPARK-49829 time window agg per each source followed by stream-stream join") {
val inputStream1 = MemoryStream[Long]
val inputStream2 = MemoryStream[Long]

val df1 = inputStream1.toDF()
.selectExpr("value", "timestamp_seconds(value) AS ts")
.withWatermark("ts", "5 seconds")

val df2 = inputStream2.toDF()
.selectExpr("value", "timestamp_seconds(value) AS ts")
.withWatermark("ts", "5 seconds")

val df1Window = df1.groupBy(
window($"ts", "10 seconds")
).agg(sum("value").as("sum_df1"))

val df2Window = df2.groupBy(
window($"ts", "10 seconds")
).agg(sum("value").as("sum_df2"))

val joined = df1Window.join(df2Window, "window", "inner")
.selectExpr("CAST(window.end AS long) AS window_end", "sum_df1", "sum_df2")

// The test verifies the case where both sides produce inputs as time windows (append mode)
// for a stream-stream join whose join condition is equality on the time window.
// Inputs are produced into the stream-stream join when the time windows are completed,
// meaning they will also be evicted from the stream-stream join's state in the same batch.
// (NOTE: the join condition does not delay the state watermark in stream-stream join.)
// Before SPARK-49829, the left side did not add the input to the state store if it was
// going to be evicted in this batch, which broke the matching between inputs from the
// left side and the right side "for this batch".
testStream(joined)(
MultiAddData(
(inputStream1, Seq(1L, 2L, 3L, 4L, 5L)),
(inputStream2, Seq(5L, 6L, 7L, 8L, 9L))
),
// watermark: 5 - 5 = 0
CheckNewAnswer(),
MultiAddData(
(inputStream1, Seq(11L, 12L, 13L, 14L, 15L)),
(inputStream2, Seq(15L, 16L, 17L, 18L, 19L)),
),
// watermark: 15 - 5 = 10 (windows for [0, 10) are completed)
CheckNewAnswer((10L, 15L, 35L)),
MultiAddData(
(inputStream1, Seq(100L)),
(inputStream2, Seq(101L)),
),
// watermark: 100 - 5 = 95 (windows for [0, 20) are completed)
CheckNewAnswer((20L, 65L, 85L))
)
}
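
As a sanity check of the watermark comments in the test above, here is a small sketch of the arithmetic, assuming the usual semantics: the global watermark is the minimum across sources of (max event time seen minus the delay), and a window [start, end) completes once the watermark reaches end.

// Minimal sketch (assumed semantics) of the watermark math in the comments above.
val delaySeconds = 5L
// Global watermark: the minimum over all sources of (max event time - delay).
def watermark(perSourceMaxEventTime: Seq[Long]): Long =
  perSourceMaxEventTime.map(_ - delaySeconds).min
def windowCompleted(windowEnd: Long, wm: Long): Boolean = wm >= windowEnd

assert(watermark(Seq(5L, 9L)) == 0L)      // batch 0: "5 - 5 = 0", no window completed
assert(watermark(Seq(15L, 19L)) == 10L)   // batch 1: "15 - 5 = 10", [0, 10) completes
assert(windowCompleted(10L, 10L))
assert(watermark(Seq(100L, 101L)) == 95L) // batch 2: "100 - 5 = 95", [10, 20) completes
assert(windowCompleted(20L, 95L))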

private def assertNumStateRows(numTotalRows: Seq[Long]): AssertOnQuery = AssertOnQuery { q =>
q.processAllAvailable()
val progressWithData = q.recentProgress.lastOption.get
@@ -32,6 +32,7 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinExec, StreamingSymmetricHashJoinHelper}
import org.apache.spark.sql.execution.streaming.state.{RocksDBStateStoreProvider, StateStore, StateStoreProviderId}
@@ -223,6 +224,28 @@ abstract class StreamingJoinSuite

(inputStream, select)
}


protected def assertStateStoreRows(
opId: Long,
joinSide: String,
expectedRows: Seq[Row])(projFn: DataFrame => DataFrame): AssertOnQuery = Execute { q =>
val checkpointLoc = q.resolvedCheckpointRoot

// just make sure the query has no leftover data
q.processAllAvailable()

// By default, this reads the state store from the latest committed batch.
// (Do we have a case where we have to look back when testing?)
val stateStoreDf = spark.read.format("statestore")
.option(StateSourceOptions.JOIN_SIDE, joinSide)
.option(StateSourceOptions.PATH, checkpointLoc)
.option(StateSourceOptions.OPERATOR_ID, opId)
.load()

val projectedDf = projFn(stateStoreDf)
checkAnswer(projectedDf, expectedRows)
}
}
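
For reference, a typical call of this helper (mirroring its usage in the tests below) projects the state rows before comparing; the operator id, side, and values here are illustrative:

// Example usage: assert the left side of the join operator with id 0 holds
// exactly these rows, projecting key and eventTime out of the state row.
assertStateStoreRows(0L, "left", Seq(Row("a", 1), Row("b", 2))) { df =>
  df.selectExpr("value.key", "CAST(value.eventTime AS long)")
}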

@SlowSQLTest
@@ -1559,6 +1582,60 @@ class StreamingOuterJoinSuite extends StreamingJoinSuite {
)
}
}

test("SPARK-49829 left-outer join, input being unmatched is between WM for late event and " +
"WM for eviction") {

withTempDir { checkpoint =>
// This config needs to be set; otherwise a no-data batch would be triggered, and after
// the no-data batch, the WM for late events and the WM for eviction would be the same.
withSQLConf(SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key -> "false") {
val memoryStream1 = MemoryStream[(String, Int)]
val memoryStream2 = MemoryStream[(String, Int)]

val data1 = memoryStream1.toDF()
.selectExpr("_1 AS key", "timestamp_seconds(_2) AS eventTime")
.withWatermark("eventTime", "0 seconds")
val data2 = memoryStream2.toDF()
.selectExpr("_1 AS key", "timestamp_seconds(_2) AS eventTime")
.withWatermark("eventTime", "0 seconds")

val joinedDf = data1.join(data2, Seq("key", "eventTime"), "leftOuter")
.selectExpr("key", "CAST(eventTime AS long) AS eventTime")

testStream(joinedDf)(
StartStream(checkpointLocation = checkpoint.getCanonicalPath),
// batch 0
// WM: late record = 0, eviction = 0
MultiAddData(
(memoryStream1, Seq(("a", 1), ("b", 2))),
(memoryStream2, Seq(("b", 2), ("c", 1)))
),
CheckNewAnswer(("b", 2)),
assertStateStoreRows(0L, "left", Seq(Row("a", 1), Row("b", 2))) { df =>
df.selectExpr("value.key", "CAST(value.eventTime AS long)")
},
assertStateStoreRows(0L, "right", Seq(Row("b", 2), Row("c", 1))) { df =>
df.selectExpr("value.key", "CAST(value.eventTime AS long)")
},
// batch 1
// WM: late record = 0, eviction = 2
// Before Spark introduced support for multiple stateful operators, the WM for late
// records was the same as the WM for eviction, hence ("d", 1) would have been treated
// as a late record.
// With multiple stateful operators, ("d", 1) is added in batch 1 but also evicted in
// batch 1. Before SPARK-49829, this did not produce an unmatched row; it is now fixed.
AddData(memoryStream1, ("d", 1)),
CheckNewAnswer(("a", 1), ("d", 1)),
assertStateStoreRows(0L, "left", Seq()) { df =>
df.selectExpr("value.key", "CAST(value.eventTime AS long)")
},
assertStateStoreRows(0L, "right", Seq()) { df =>
df.selectExpr("value.key", "CAST(value.eventTime AS long)")
}
)
}
}
}
}
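
To see why ("d", 1) is both accepted and evicted in batch 1, here is a tiny sketch under assumed predicates (a row is late if its event time is below the late-record WM, and evictable if at or below the eviction WM, which is consistent with the state assertions above):

// Batch 1 of the test above: late-record WM = 0, eviction WM = 2 (assumed
// predicates as described in the lead-in; not Spark's implementation).
def isLate(eventTime: Long, lateRecordWm: Long): Boolean = eventTime < lateRecordWm
def isEvictable(eventTime: Long, evictionWm: Long): Boolean = eventTime <= evictionWm

assert(!isLate(1L, 0L))     // ("d", 1) is not late, so it enters the join ...
assert(isEvictable(1L, 2L)) // ... but is evicted in the very same batch, so the
                            // left outer join must still emit the unmatched row
                            // ("d", 1), which is what SPARK-49829 fixes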

@SlowSQLTest
@@ -1966,4 +2043,130 @@ class StreamingLeftSemiJoinSuite extends StreamingJoinSuite {
assertNumStateRows(total = 9, updated = 4)
)
}

// scalastyle:off line.size.limit
// DISCLAIMER: This is a revision of the test below, which was part of a report in the dev
// mailing list. CREDIT goes to @andrzejzera.
// https://github.com/andrzejzera/spark-bugs/blob/abae7a3839326a8eafc7516a51aca5e0c79282a6/spark-3.5/src/test/scala/OuterJoinTest.scala#L86C3-L167C4
// scalastyle:on
test("SPARK-49829 two chained stream-stream left outer joins among three input streams") {
withSQLConf(SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key -> "false") {
val memoryStream1 = MemoryStream[(Long, Int)]
val memoryStream2 = MemoryStream[(Long, Int)]
val memoryStream3 = MemoryStream[(Long, Int)]

val data1 = memoryStream1.toDF()
.selectExpr("timestamp_seconds(_1) AS eventTime", "_2 AS v1")
.withWatermark("eventTime", "0 seconds")
val data2 = memoryStream2.toDF()
.selectExpr("timestamp_seconds(_1) AS eventTime", "_2 AS v2")
.withWatermark("eventTime", "0 seconds")
val data3 = memoryStream3.toDF()
.selectExpr("timestamp_seconds(_1) AS eventTime", "_2 AS v3")
.withWatermark("eventTime", "0 seconds")

val join = data1
.join(data2, Seq("eventTime"), "leftOuter")
.join(data3, Seq("eventTime"), "leftOuter")
.selectExpr("CAST(eventTime AS long) AS eventTime", "v1", "v2", "v3")

def assertLeftRowsFor1stJoin(expected: Seq[Row]): AssertOnQuery = {
assertStateStoreRows(1L, "left", expected) { df =>
df.selectExpr("CAST(value.eventTime AS long)", "value.v1")
}
}

def assertRightRowsFor1stJoin(expected: Seq[Row]): AssertOnQuery = {
assertStateStoreRows(1L, "right", expected) { df =>
df.selectExpr("CAST(value.eventTime AS long)", "value.v2")
}
}

def assertLeftRowsFor2ndJoin(expected: Seq[Row]): AssertOnQuery = {
assertStateStoreRows(0L, "left", expected) { df =>
df.selectExpr("CAST(value.eventTime AS long)", "value.v1", "value.v2")
}
}

def assertRightRowsFor2ndJoin(expected: Seq[Row]): AssertOnQuery = {
assertStateStoreRows(0L, "right", expected) { df =>
df.selectExpr("CAST(value.eventTime AS long)", "value.v3")
}
}

testStream(join)(
// batch 0
// WM: late event = 0, eviction = 0
MultiAddData(
(memoryStream1, Seq((20L, 1))),
(memoryStream2, Seq((20L, 1))),
(memoryStream3, Seq((20L, 1)))
),
CheckNewAnswer((20, 1, 1, 1)),
assertLeftRowsFor1stJoin(Seq(Row(20, 1))),
assertRightRowsFor1stJoin(Seq(Row(20, 1))),
assertLeftRowsFor2ndJoin(Seq(Row(20, 1, 1))),
assertRightRowsFor2ndJoin(Seq(Row(20, 1))),
// batch 1
// WM: late event = 0, eviction = 20
MultiAddData(
(memoryStream1, Seq((21L, 2))),
(memoryStream2, Seq((21L, 2)))
),
CheckNewAnswer(),
assertLeftRowsFor1stJoin(Seq(Row(21, 2))),
assertRightRowsFor1stJoin(Seq(Row(21, 2))),
assertLeftRowsFor2ndJoin(Seq(Row(21, 2, 2))),
assertRightRowsFor2ndJoin(Seq()),
// batch 2
// WM: late event = 20, eviction = 20 (slowest: inputStream3)
MultiAddData(
(memoryStream1, Seq((22L, 3))),
(memoryStream3, Seq((22L, 3))),
),
CheckNewAnswer(),
assertLeftRowsFor1stJoin(Seq(Row(21, 2), Row(22, 3))),
assertRightRowsFor1stJoin(Seq(Row(21, 2))),
assertLeftRowsFor2ndJoin(Seq(Row(21, 2, 2))),
assertRightRowsFor2ndJoin(Seq(Row(22, 3))),
// batch 3
// WM: late event = 20, eviction = 21 (slowest: inputStream2)
AddData(memoryStream1, (23L, 4)),
CheckNewAnswer(Row(21, 2, 2, null)),
assertLeftRowsFor1stJoin(Seq(Row(22, 3), Row(23, 4))),
assertRightRowsFor1stJoin(Seq()),
assertLeftRowsFor2ndJoin(Seq()),
assertRightRowsFor2ndJoin(Seq(Row(22, 3))),
// batch 4
// WM: late event = 21, eviction = 21 (slowest: inputStream2)
MultiAddData(
(memoryStream1, Seq((24L, 5))),
(memoryStream2, Seq((24L, 5))),
(memoryStream3, Seq((24L, 5)))
),
CheckNewAnswer(Row(24, 5, 5, 5)),
assertLeftRowsFor1stJoin(Seq(Row(22, 3), Row(23, 4), Row(24, 5))),
assertRightRowsFor1stJoin(Seq(Row(24, 5))),
assertLeftRowsFor2ndJoin(Seq(Row(24, 5, 5))),
assertRightRowsFor2ndJoin(Seq(Row(22, 3), Row(24, 5))),
// batch 5
// WM: late event = 21, eviction = 24
// just trigger a new batch with arbitrary data, as the original test relies on a no-data
// batch and we need to check the remaining unmatched outputs
AddData(memoryStream1, (100L, 6)),
CheckNewAnswer(Row(22, 3, null, 3), Row(23, 4, null, null))
)

/*
// The collection of the new answers above is the same as the expected rows in the original test:
val expected = Array(
Row(Timestamp.valueOf("2024-02-10 10:20:00"), 1, 1, 1),
Row(Timestamp.valueOf("2024-02-10 10:21:00"), 2, 2, null),
Row(Timestamp.valueOf("2024-02-10 10:22:00"), 3, null, 3),
Row(Timestamp.valueOf("2024-02-10 10:23:00"), 4, null, null),
Row(Timestamp.valueOf("2024-02-10 10:24:00"), 5, 5, 5),
)
*/
}
}
}
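
The "slowest: inputStreamN" annotations in the test above follow from the global watermark being the minimum across all sources of the max event time seen (the delay is 0 seconds here); a short sketch, assuming that semantics, reproduces the eviction watermarks in the comments:

// Sketch (assumed semantics): eviction WM for a batch = min over sources of the
// max event time seen through the previous batch. Order: stream1, stream2, stream3.
def evictionWm(maxSeenPerSource: Seq[Long]): Long = maxSeenPerSource.min

assert(evictionWm(Seq(20L, 20L, 20L)) == 20L) // batch 1: eviction = 20
assert(evictionWm(Seq(21L, 21L, 20L)) == 20L) // batch 2: eviction = 20 (slowest: inputStream3)
assert(evictionWm(Seq(22L, 21L, 22L)) == 21L) // batch 3: eviction = 21 (slowest: inputStream2)
assert(evictionWm(Seq(23L, 21L, 22L)) == 21L) // batch 4: eviction = 21 (slowest: inputStream2)
assert(evictionWm(Seq(24L, 24L, 24L)) == 24L) // batch 5: eviction = 24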
