Skip to content

Commit

Permalink
SPARK-26187 self left/right outer join should not return outer nulls …
Browse files Browse the repository at this point in the history
…for already matched rows

* Remove unused method
* Address backward compatibility with old state
* Do some refactoring
* address backward compatibility via introducing state format version
* Introduce helper object to deduplicate long-code
* Add 'matched' field in value type of state store instead of adding one more state store
* Normalize names as left & right for join tests
  • Loading branch information
HeartSaVioR committed May 13, 2019
1 parent 126310c commit 976e9ba
Show file tree
Hide file tree
Showing 33 changed files with 1,269 additions and 609 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,16 @@ object SQLConf {
.checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2")
.createWithDefault(2)

val STREAMING_JOIN_STATE_FORMAT_VERSION =
buildConf("spark.sql.streaming.join.stateFormatVersion")
.internal()
.doc("State format version used by streaming join operations in a streaming query. " +
"State between versions are tend to be incompatible, so state format version shouldn't " +
"be modified after running.")
.intConf
.checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2")
.createWithDefault(2)

val UNSUPPORTED_OPERATION_CHECK_ENABLED =
buildConf("spark.sql.streaming.unsupportedOperationCheck")
.internal()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -460,8 +460,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _)
if left.isStreaming && right.isStreaming =>

new StreamingSymmetricHashJoinExec(
leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
val stateVersion = conf.getConf(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION)

new StreamingSymmetricHashJoinExec(leftKeys, rightKeys, joinType, condition,
stateVersion, planLater(left), planLater(right)) :: Nil

case Join(left, right, _, _, _) if left.isStreaming && right.isStreaming =>
throw new AnalysisException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ import org.json4s.jackson.Serialization
import org.apache.spark.internal.Logging
import org.apache.spark.sql.RuntimeConfig
import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StreamingAggregationStateManager}
import org.apache.spark.sql.execution.streaming.state.join.StreamingJoinStateManager
import org.apache.spark.sql.internal.SQLConf.{FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, _}
import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2, SparkDataStream}


/**
* An ordered collection of offsets, used to track the progress of processing data from one or more
* [[Source]]s that are present in a streaming query. This is similar to simplified, single-instance
Expand Down Expand Up @@ -91,7 +91,8 @@ object OffsetSeqMetadata extends Logging {
private implicit val format = Serialization.formats(NoTypeHints)
private val relevantSQLConfs = Seq(
SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY,
FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, STREAMING_AGGREGATION_STATE_FORMAT_VERSION)
FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, STREAMING_AGGREGATION_STATE_FORMAT_VERSION,
STREAMING_JOIN_STATE_FORMAT_VERSION)

/**
* Default values of relevant configurations that are used for backward compatibility.
Expand All @@ -108,7 +109,9 @@ object OffsetSeqMetadata extends Logging {
FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key ->
FlatMapGroupsWithStateExecHelper.legacyVersion.toString,
STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key ->
StreamingAggregationStateManager.legacyVersion.toString
StreamingAggregationStateManager.legacyVersion.toString,
STREAMING_JOIN_STATE_FORMAT_VERSION.key ->
StreamingJoinStateManager.legacyVersion.toString
)

def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.execution.streaming.state.join.StreamingJoinStateManager
import org.apache.spark.sql.execution.streaming.state.join.StreamingJoinStateManager._
import org.apache.spark.sql.internal.SessionState
import org.apache.spark.util.{CompletionIterator, SerializableConfiguration}

Expand Down Expand Up @@ -131,6 +133,7 @@ case class StreamingSymmetricHashJoinExec(
stateInfo: Option[StatefulOperatorStateInfo],
eventTimeWatermark: Option[Long],
stateWatermarkPredicates: JoinStateWatermarkPredicates,
stateFormatVersion: Int,
left: SparkPlan,
right: SparkPlan) extends SparkPlan with BinaryExecNode with StateStoreWriter {

Expand All @@ -139,13 +142,14 @@ case class StreamingSymmetricHashJoinExec(
rightKeys: Seq[Expression],
joinType: JoinType,
condition: Option[Expression],
stateFormatVersion: Int,
left: SparkPlan,
right: SparkPlan) = {

this(
leftKeys, rightKeys, joinType, JoinConditionSplitPredicates(condition, left, right),
stateInfo = None, eventTimeWatermark = None,
stateWatermarkPredicates = JoinStateWatermarkPredicates(), left, right)
stateWatermarkPredicates = JoinStateWatermarkPredicates(), stateFormatVersion, left, right)
}

private def throwBadJoinTypeException(): Nothing = {
Expand Down Expand Up @@ -200,7 +204,8 @@ case class StreamingSymmetricHashJoinExec(

protected override def doExecute(): RDD[InternalRow] = {
val stateStoreCoord = sqlContext.sessionState.streamingQueryManager.stateStoreCoordinator
val stateStoreNames = SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
val stateStoreNames = StreamingJoinStateManager.allStateStoreNames(stateFormatVersion,
LeftSide, RightSide)
left.execute().stateStoreAwareZipPartitions(
right.execute(), stateInfo.get, stateStoreNames, stateStoreCoord)(processPartitions)
}
Expand All @@ -223,7 +228,6 @@ case class StreamingSymmetricHashJoinExec(
val updateStartTimeNs = System.nanoTime
val joinedRow = new JoinedRow


val postJoinFilter =
newPredicate(condition.bothSides.getOrElse(Literal(true)), left.output ++ right.output).eval _
val leftSideJoiner = new OneSideHashJoiner(
Expand Down Expand Up @@ -261,7 +265,6 @@ case class StreamingSymmetricHashJoinExec(
val innerOutputIter = CompletionIterator[InternalRow, Iterator[InternalRow]](
(leftOutputIter ++ rightOutputIter), onInnerOutputCompletion)


val outputIter: Iterator[InternalRow] = joinType match {
case Inner =>
innerOutputIter
Expand All @@ -280,10 +283,17 @@ case class StreamingSymmetricHashJoinExec(
postJoinFilter(joinedRow.withLeft(leftKeyValue.value).withRight(rightValue))
}
}

val removedRowIter = leftSideJoiner.removeOldState()
val outerOutputIter = removedRowIter
.filterNot(pair => matchesWithRightSideState(pair))
.map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
val outerOutputIter = removedRowIter.filterNot { kvAndMatched =>
stateFormatVersion match {
case 1 => matchesWithRightSideState(
new UnsafeRowPair(kvAndMatched.key, kvAndMatched.value))
case 2 => kvAndMatched.matched.get
case _ => throw new IllegalStateException("Incorrect state format version! " +
s"version $stateFormatVersion")
}
}.map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))

innerOutputIter ++ outerOutputIter
case RightOuter =>
Expand All @@ -293,10 +303,17 @@ case class StreamingSymmetricHashJoinExec(
postJoinFilter(joinedRow.withLeft(leftValue).withRight(rightKeyValue.value))
}
}

val removedRowIter = rightSideJoiner.removeOldState()
val outerOutputIter = removedRowIter
.filterNot(pair => matchesWithLeftSideState(pair))
.map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))
val outerOutputIter = removedRowIter.filterNot { kvAndMatched =>
stateFormatVersion match {
case 1 => matchesWithLeftSideState(
new UnsafeRowPair(kvAndMatched.key, kvAndMatched.value))
case 2 => kvAndMatched.matched.get
case _ => throw new IllegalStateException("Incorrect state format version! " +
s"version $stateFormatVersion")
}
}.map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))

innerOutputIter ++ outerOutputIter
case _ => throwBadJoinTypeException()
Expand Down Expand Up @@ -394,8 +411,10 @@ case class StreamingSymmetricHashJoinExec(
val preJoinFilter =
newPredicate(preJoinFilterExpr.getOrElse(Literal(true)), inputAttributes).eval _

private val joinStateManager = new SymmetricHashJoinStateManager(
joinSide, inputAttributes, joinKeys, stateInfo, storeConf, hadoopConfBcast.value.value)
private val joinStateManager = StreamingJoinStateManager.createStateManager(
joinSide, inputAttributes, joinKeys, stateInfo, storeConf, hadoopConfBcast.value.value,
stateFormatVersion)

private[this] val keyGenerator = UnsafeProjection.create(joinKeys, inputAttributes)

private[this] val stateKeyWatermarkPredicateFunc = stateWatermarkPredicate match {
Expand Down Expand Up @@ -445,16 +464,11 @@ case class StreamingSymmetricHashJoinExec(
// the case of inner join).
if (preJoinFilter(thisRow)) {
val key = keyGenerator(thisRow)
val outputIter = otherSideJoiner.joinStateManager.get(key).map { thatRow =>
generateJoinedRow(thisRow, thatRow)
}.filter(postJoinFilter)
val shouldAddToState = // add only if both removal predicates do not match
!stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow)
if (shouldAddToState) {
joinStateManager.append(key, thisRow)
updatedStateRowsCount += 1
}
outputIter

val outputIter: Iterator[JoinedRow] = otherSideJoiner.joinStateManager
.getJoinedRows(key, thatRow => generateJoinedRow(thisRow, thatRow), postJoinFilter)

new AddingProcessedRowToStateCompletionIterator(key, thisRow, outputIter)
} else {
joinSide match {
case LeftSide if joinType == LeftOuter =>
Expand All @@ -467,6 +481,31 @@ case class StreamingSymmetricHashJoinExec(
}
}

private class AddingProcessedRowToStateCompletionIterator(
key: UnsafeRow,
thisRow: UnsafeRow,
subIter: Iterator[JoinedRow])
extends CompletionIterator[JoinedRow, Iterator[JoinedRow]](subIter) {
private var iteratorNotEmpty: Boolean = false

override def hasNext: Boolean = {
val ret = super.hasNext
if (ret && !iteratorNotEmpty) {
iteratorNotEmpty = true
}
ret
}

override def completion(): Unit = {
val shouldAddToState = // add only if both removal predicates do not match
!stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow)
if (shouldAddToState) {
joinStateManager.append(key, thisRow, matched = iteratorNotEmpty)
updatedStateRowsCount += 1
}
}
}

/**
* Get an iterator over the values stored in this joiner's state manager for the given key.
*
Expand All @@ -486,7 +525,7 @@ case class StreamingSymmetricHashJoinExec(
* We do this to avoid requiring either two passes or full materialization when
* processing the rows for outer join.
*/
def removeOldState(): Iterator[UnsafeRowPair] = {
def removeOldState(): Iterator[KeyToValueAndMatched] = {
stateWatermarkPredicate match {
case Some(JoinStateKeyWatermarkPredicate(expr)) =>
joinStateManager.removeByKeyCondition(stateKeyWatermarkPredicateFunc)
Expand Down
Loading

0 comments on commit 976e9ba

Please sign in to comment.