Skip to content

Commit

Permalink
[SPARK-23406][SS] Enable stream-stream self-joins
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

Solved two bugs to enable stream-stream self joins.

### Incorrect analysis due to missing MultiInstanceRelation trait
Streaming leaf nodes did not extend MultiInstanceRelation, which is necessary for the catalyst analyzer to convert the self-join logical plan DAG into a tree (by creating new instances of the leaf relations). This was causing the error `Failure when resolving conflicting references in Join:` (see JIRA for details).

### Incorrect attribute rewrite when splicing batch plans in MicroBatchExecution
When splicing the source's batch plan into the streaming plan (by replacing the StreamingExecutionPlan), we were rewriting the attribute reference in the streaming plan with the new attribute references from the batch plan. This was incorrectly handling the scenario when multiple StreamingExecutionRelation point to the same source, and therefore eventually point to the same batch plan returned by the source. Here is an example query, and its corresponding plan transformations.
```
val df = input.toDF
val join =
      df.select('value % 5 as "key", 'value).join(
        df.select('value % 5 as "key", 'value), "key")
```
Streaming logical plan before splicing the batch plan
```
Project [key#6, value#1, value#12]
+- Join Inner, (key#6 = key#9)
   :- Project [(value#1 % 5) AS key#6, value#1]
   :  +- StreamingExecutionRelation Memory[#1], value#1
   +- Project [(value#12 % 5) AS key#9, value#12]
      +- StreamingExecutionRelation Memory[#1], value#12  // two different leaves pointing to same source
```
Batch logical plan after splicing the batch plan and before rewriting
```
Project [key#6, value#1, value#12]
+- Join Inner, (key#6 = key#9)
   :- Project [(value#1 % 5) AS key#6, value#1]
   :  +- LocalRelation [value#66]           // replaces StreamingExecutionRelation Memory[#1], value#1
   +- Project [(value#12 % 5) AS key#9, value#12]
      +- LocalRelation [value#66]           // replaces StreamingExecutionRelation Memory[#1], value#12
```
Batch logical plan after rewriting the attributes. Specifically, for spliced, the new output attributes (value#66) replace the earlier output attributes (value#12, and value#1, one for each StreamingExecutionRelation).
```
Project [key#6, value#66, value#66]       // both value#1 and value#12 replaces by value#66
+- Join Inner, (key#6 = key#9)
   :- Project [(value#66 % 5) AS key#6, value#66]
   :  +- LocalRelation [value#66]
   +- Project [(value#66 % 5) AS key#9, value#66]
      +- LocalRelation [value#66]
```
This causes the optimizer to eliminate value#66 from one side of the join.
```
Project [key#6, value#66, value#66]
+- Join Inner, (key#6 = key#9)
   :- Project [(value#66 % 5) AS key#6, value#66]
   :  +- LocalRelation [value#66]
   +- Project [(value#66 % 5) AS key#9]   // this does not generate value, incorrect join results
      +- LocalRelation [value#66]
```

**Solution**: Instead of rewriting attributes, use a Project to introduce aliases between the output attribute references and the new reference generated by the spliced plans. The analyzer and optimizer will take care of the rest.
```
Project [key#6, value#1, value#12]
+- Join Inner, (key#6 = key#9)
   :- Project [(value#1 % 5) AS key#6, value#1]
   :  +- Project [value#66 AS value#1]   // solution: project with aliases
   :     +- LocalRelation [value#66]
   +- Project [(value#12 % 5) AS key#9, value#12]
      +- Project [value#66 AS value#12]    // solution: project with aliases
         +- LocalRelation [value#66]
```

## How was this patch tested?
New unit test

Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes apache#20598 from tdas/SPARK-23406.
  • Loading branch information
tdas committed Feb 14, 2018
1 parent 400a1d9 commit 658d9d9
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ import scala.collection.mutable.{ArrayBuffer, Map => MutableMap}

import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2}
import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter}
Expand Down Expand Up @@ -415,27 +415,25 @@ class MicroBatchExecution(
}
}

// A list of attributes that will need to be updated.
val replacements = new ArrayBuffer[(Attribute, Attribute)]
// Replace sources in the logical plan with data that has arrived since the last batch.
val newBatchesPlan = logicalPlan transform {
case StreamingExecutionRelation(source, output) =>
newData.get(source).map { dataPlan =>
assert(output.size == dataPlan.output.size,
s"Invalid batch: ${Utils.truncatedString(output, ",")} != " +
s"${Utils.truncatedString(dataPlan.output, ",")}")
replacements ++= output.zip(dataPlan.output)
dataPlan

val aliases = output.zip(dataPlan.output).map { case (to, from) =>
Alias(from, to.name)(exprId = to.exprId, explicitMetadata = Some(from.metadata))
}
Project(aliases, dataPlan)
}.getOrElse {
LocalRelation(output, isStreaming = true)
}
}

// Rewire the plan to use the new attributes that were returned by the source.
val replacementMap = AttributeMap(replacements)
val newAttributePlan = newBatchesPlan transformAllExpressions {
case a: Attribute if replacementMap.contains(a) =>
replacementMap(a).withMetadata(a.metadata)
case ct: CurrentTimestamp =>
CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs,
ct.dataType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ package org.apache.spark.sql.execution.streaming
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LeafNode
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.execution.LeafExecNode
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2}
Expand All @@ -42,7 +42,7 @@ object StreamingRelation {
* passing to [[StreamExecution]] to run a query.
*/
case class StreamingRelation(dataSource: DataSource, sourceName: String, output: Seq[Attribute])
extends LeafNode {
extends LeafNode with MultiInstanceRelation {
override def isStreaming: Boolean = true
override def toString: String = sourceName

Expand All @@ -53,6 +53,8 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output:
override def computeStats(): Statistics = Statistics(
sizeInBytes = BigInt(dataSource.sparkSession.sessionState.conf.defaultSizeInBytes)
)

override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))
}

/**
Expand All @@ -62,7 +64,7 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output:
case class StreamingExecutionRelation(
source: BaseStreamingSource,
output: Seq[Attribute])(session: SparkSession)
extends LeafNode {
extends LeafNode with MultiInstanceRelation {

override def isStreaming: Boolean = true
override def toString: String = source.toString
Expand All @@ -74,6 +76,8 @@ case class StreamingExecutionRelation(
override def computeStats(): Statistics = Statistics(
sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)
)

override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session)
}

// We have to pack in the V1 data source as a shim, for the case when a source implements
Expand All @@ -92,13 +96,15 @@ case class StreamingRelationV2(
extraOptions: Map[String, String],
output: Seq[Attribute],
v1Relation: Option[StreamingRelation])(session: SparkSession)
extends LeafNode {
extends LeafNode with MultiInstanceRelation {
override def isStreaming: Boolean = true
override def toString: String = sourceName

override def computeStats(): Statistics = Statistics(
sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)
)

override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session)
}

/**
Expand All @@ -108,7 +114,7 @@ case class ContinuousExecutionRelation(
source: ContinuousReadSupport,
extraOptions: Map[String, String],
output: Seq[Attribute])(session: SparkSession)
extends LeafNode {
extends LeafNode with MultiInstanceRelation {

override def isStreaming: Boolean = true
override def toString: String = source.toString
Expand All @@ -120,6 +126,8 @@ case class ContinuousExecutionRelation(
override def computeStats(): Statistics = Statistics(
sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)
)

override def newInstance(): LogicalPlan = this.copy(output = output.map(_.newInstance()))(session)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SparkSession}
import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, Filter}
import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.execution.{FileSourceScanExec, LogicalRDD}
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinHelper}
import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreProviderId}
import org.apache.spark.sql.functions._
Expand Down Expand Up @@ -323,6 +325,27 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with
assert(e.toString.contains("Stream stream joins without equality predicate is not supported"))
}

test("stream stream self join") {
val input = MemoryStream[Int]
val df = input.toDF
val join =
df.select('value % 5 as "key", 'value).join(
df.select('value % 5 as "key", 'value), "key")

testStream(join)(
AddData(input, 1, 2),
CheckAnswer((1, 1, 1), (2, 2, 2)),
StopStream,
StartStream(),
AddData(input, 3, 6),
/*
(1, 1) (1, 1)
(2, 2) x (2, 2) = (1, 1, 1), (1, 1, 6), (2, 2, 2), (1, 6, 1), (1, 6, 6)
(1, 6) (1, 6)
*/
CheckAnswer((3, 3, 3), (1, 1, 1), (1, 1, 6), (2, 2, 2), (1, 6, 1), (1, 6, 6)))
}

test("locality preferences of StateStoreAwareZippedRDD") {
import StreamingSymmetricHashJoinHelper._

Expand Down

0 comments on commit 658d9d9

Please sign in to comment.