Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-26154][SS] Streaming left/right outer join should not return outer nulls for already matched rows #23634

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,16 @@ object SQLConf {
.checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2")
.createWithDefault(2)

val STREAMING_JOIN_STATE_FORMAT_VERSION =
buildConf("spark.sql.streaming.join.stateFormatVersion")
.internal()
.doc("State format version used by streaming join operations in a streaming query. " +
"State between versions are tend to be incompatible, so state format version shouldn't " +
"be modified after running.")
.intConf
.checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2")
.createWithDefault(2)

val UNSUPPORTED_OPERATION_CHECK_ENABLED =
buildConf("spark.sql.streaming.unsupportedOperationCheck")
.internal()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -460,8 +460,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _)
if left.isStreaming && right.isStreaming =>

new StreamingSymmetricHashJoinExec(
leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
val stateVersion = conf.getConf(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION)

new StreamingSymmetricHashJoinExec(leftKeys, rightKeys, joinType, condition,
stateVersion, planLater(left), planLater(right)) :: Nil

case Join(left, right, _, _, _) if left.isStreaming && right.isStreaming =>
throw new AnalysisException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ import org.json4s.jackson.Serialization
import org.apache.spark.internal.Logging
import org.apache.spark.sql.RuntimeConfig
import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StreamingAggregationStateManager}
import org.apache.spark.sql.execution.streaming.state.join.StreamingJoinStateManager
import org.apache.spark.sql.internal.SQLConf.{FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, _}
import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2, SparkDataStream}


/**
* An ordered collection of offsets, used to track the progress of processing data from one or more
* [[Source]]s that are present in a streaming query. This is similar to simplified, single-instance
Expand Down Expand Up @@ -91,7 +91,8 @@ object OffsetSeqMetadata extends Logging {
private implicit val format = Serialization.formats(NoTypeHints)
private val relevantSQLConfs = Seq(
SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY,
FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, STREAMING_AGGREGATION_STATE_FORMAT_VERSION)
FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, STREAMING_AGGREGATION_STATE_FORMAT_VERSION,
STREAMING_JOIN_STATE_FORMAT_VERSION)

/**
* Default values of relevant configurations that are used for backward compatibility.
Expand All @@ -108,7 +109,9 @@ object OffsetSeqMetadata extends Logging {
FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key ->
FlatMapGroupsWithStateExecHelper.legacyVersion.toString,
STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key ->
StreamingAggregationStateManager.legacyVersion.toString
StreamingAggregationStateManager.legacyVersion.toString,
STREAMING_JOIN_STATE_FORMAT_VERSION.key ->
StreamingJoinStateManager.legacyVersion.toString
)

def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.execution.streaming.state.join.StreamingJoinStateManager
import org.apache.spark.sql.execution.streaming.state.join.StreamingJoinStateManager._
import org.apache.spark.sql.internal.SessionState
import org.apache.spark.util.{CompletionIterator, SerializableConfiguration}

Expand Down Expand Up @@ -131,6 +133,7 @@ case class StreamingSymmetricHashJoinExec(
stateInfo: Option[StatefulOperatorStateInfo],
eventTimeWatermark: Option[Long],
stateWatermarkPredicates: JoinStateWatermarkPredicates,
stateFormatVersion: Int,
left: SparkPlan,
right: SparkPlan) extends SparkPlan with BinaryExecNode with StateStoreWriter {

Expand All @@ -139,13 +142,14 @@ case class StreamingSymmetricHashJoinExec(
rightKeys: Seq[Expression],
joinType: JoinType,
condition: Option[Expression],
stateFormatVersion: Int,
left: SparkPlan,
right: SparkPlan) = {

this(
leftKeys, rightKeys, joinType, JoinConditionSplitPredicates(condition, left, right),
stateInfo = None, eventTimeWatermark = None,
stateWatermarkPredicates = JoinStateWatermarkPredicates(), left, right)
stateWatermarkPredicates = JoinStateWatermarkPredicates(), stateFormatVersion, left, right)
}

private def throwBadJoinTypeException(): Nothing = {
Expand Down Expand Up @@ -200,7 +204,8 @@ case class StreamingSymmetricHashJoinExec(

protected override def doExecute(): RDD[InternalRow] = {
val stateStoreCoord = sqlContext.sessionState.streamingQueryManager.stateStoreCoordinator
val stateStoreNames = SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
val stateStoreNames = StreamingJoinStateManager.allStateStoreNames(stateFormatVersion,
LeftSide, RightSide)
left.execute().stateStoreAwareZipPartitions(
right.execute(), stateInfo.get, stateStoreNames, stateStoreCoord)(processPartitions)
}
Expand All @@ -223,7 +228,6 @@ case class StreamingSymmetricHashJoinExec(
val updateStartTimeNs = System.nanoTime
val joinedRow = new JoinedRow


val postJoinFilter =
newPredicate(condition.bothSides.getOrElse(Literal(true)), left.output ++ right.output).eval _
val leftSideJoiner = new OneSideHashJoiner(
Expand Down Expand Up @@ -261,7 +265,6 @@ case class StreamingSymmetricHashJoinExec(
val innerOutputIter = CompletionIterator[InternalRow, Iterator[InternalRow]](
(leftOutputIter ++ rightOutputIter), onInnerOutputCompletion)


val outputIter: Iterator[InternalRow] = joinType match {
case Inner =>
innerOutputIter
Expand All @@ -280,10 +283,17 @@ case class StreamingSymmetricHashJoinExec(
postJoinFilter(joinedRow.withLeft(leftKeyValue.value).withRight(rightValue))
}
}

val removedRowIter = leftSideJoiner.removeOldState()
val outerOutputIter = removedRowIter
.filterNot(pair => matchesWithRightSideState(pair))
.map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
val outerOutputIter = removedRowIter.filterNot { kvAndMatched =>
stateFormatVersion match {
case 1 => matchesWithRightSideState(
new UnsafeRowPair(kvAndMatched.key, kvAndMatched.value))
case 2 => kvAndMatched.matched.get
case _ => throw new IllegalStateException("Incorrect state format version! " +
s"version $stateFormatVersion")
}
}.map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))

innerOutputIter ++ outerOutputIter
case RightOuter =>
Expand All @@ -293,10 +303,17 @@ case class StreamingSymmetricHashJoinExec(
postJoinFilter(joinedRow.withLeft(leftValue).withRight(rightKeyValue.value))
}
}

val removedRowIter = rightSideJoiner.removeOldState()
val outerOutputIter = removedRowIter
.filterNot(pair => matchesWithLeftSideState(pair))
.map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))
val outerOutputIter = removedRowIter.filterNot { kvAndMatched =>
stateFormatVersion match {
case 1 => matchesWithLeftSideState(
new UnsafeRowPair(kvAndMatched.key, kvAndMatched.value))
case 2 => kvAndMatched.matched.get
case _ => throw new IllegalStateException("Incorrect state format version! " +
s"version $stateFormatVersion")
}
}.map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))

innerOutputIter ++ outerOutputIter
case _ => throwBadJoinTypeException()
Expand Down Expand Up @@ -394,8 +411,10 @@ case class StreamingSymmetricHashJoinExec(
val preJoinFilter =
newPredicate(preJoinFilterExpr.getOrElse(Literal(true)), inputAttributes).eval _

private val joinStateManager = new SymmetricHashJoinStateManager(
joinSide, inputAttributes, joinKeys, stateInfo, storeConf, hadoopConfBcast.value.value)
private val joinStateManager = StreamingJoinStateManager.createStateManager(
joinSide, inputAttributes, joinKeys, stateInfo, storeConf, hadoopConfBcast.value.value,
stateFormatVersion)

private[this] val keyGenerator = UnsafeProjection.create(joinKeys, inputAttributes)

private[this] val stateKeyWatermarkPredicateFunc = stateWatermarkPredicate match {
Expand Down Expand Up @@ -445,16 +464,11 @@ case class StreamingSymmetricHashJoinExec(
// the case of inner join).
if (preJoinFilter(thisRow)) {
val key = keyGenerator(thisRow)
val outputIter = otherSideJoiner.joinStateManager.get(key).map { thatRow =>
generateJoinedRow(thisRow, thatRow)
}.filter(postJoinFilter)
val shouldAddToState = // add only if both removal predicates do not match
!stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow)
if (shouldAddToState) {
joinStateManager.append(key, thisRow)
updatedStateRowsCount += 1
}
outputIter

val outputIter: Iterator[JoinedRow] = otherSideJoiner.joinStateManager
.getJoinedRows(key, thatRow => generateJoinedRow(thisRow, thatRow), postJoinFilter)

new AddingProcessedRowToStateCompletionIterator(key, thisRow, outputIter)
} else {
joinSide match {
case LeftSide if joinType == LeftOuter =>
Expand All @@ -467,6 +481,31 @@ case class StreamingSymmetricHashJoinExec(
}
}

private class AddingProcessedRowToStateCompletionIterator(
key: UnsafeRow,
thisRow: UnsafeRow,
subIter: Iterator[JoinedRow])
extends CompletionIterator[JoinedRow, Iterator[JoinedRow]](subIter) {
private var iteratorNotEmpty: Boolean = false

override def hasNext: Boolean = {
val ret = super.hasNext
if (ret && !iteratorNotEmpty) {
iteratorNotEmpty = true
}
ret
}

override def completion(): Unit = {
val shouldAddToState = // add only if both removal predicates do not match
!stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow)
if (shouldAddToState) {
joinStateManager.append(key, thisRow, matched = iteratorNotEmpty)
updatedStateRowsCount += 1
}
}
}

/**
* Get an iterator over the values stored in this joiner's state manager for the given key.
*
Expand All @@ -486,7 +525,7 @@ case class StreamingSymmetricHashJoinExec(
* We do this to avoid requiring either two passes or full materialization when
* processing the rows for outer join.
*/
def removeOldState(): Iterator[UnsafeRowPair] = {
def removeOldState(): Iterator[KeyToValueAndMatched] = {
stateWatermarkPredicate match {
case Some(JoinStateKeyWatermarkPredicate(expr)) =>
joinStateManager.removeByKeyCondition(stateKeyWatermarkPredicateFunc)
Expand Down
Loading