Skip to content

Commit

Permalink
Added conf and multi-version tests in FlatMapGroupsWithStateSuite
Browse files Browse the repository at this point in the history
  • Loading branch information
tdas committed Jul 10, 2018
1 parent 9525484 commit c9f600b
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,14 @@ object SQLConf {
.intConf
.createWithDefault(10)

// Internal setting that selects the state format version for flatMapGroupsWithState.
// Only the values 1 and 2 pass validation; 2 is used when the setting is not specified.
val FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION =
  buildConf("spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion")
    .internal()
    .doc("State format version used by flatMapGroupsWithState operation in a streaming query")
    .intConf
    .checkValue(version => version == 1 || version == 2, "Valid versions are 1 and 2")
    .createWithDefault(2)

val CHECKPOINT_LOCATION = buildConf("spark.sql.streaming.checkpointLocation")
.doc("The default location for storing checkpoint data for streaming queries.")
.stringConf
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ case class ObjectType(cls: Class[_]) extends DataType {

/** Returns this same instance; no separate nullable variant is created. */
def asNullable: DataType = this

override def simpleString: String = s"Object[${cls.getName}]"
override def simpleString: String = cls.getName

override def acceptsType(other: DataType): Boolean = other match {
case ObjectType(otherCls) => cls.isAssignableFrom(otherCls)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -485,9 +485,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case FlatMapGroupsWithState(
func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateEnc, outputMode, _,
timeout, child) =>
val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION)
val execPlan = FlatMapGroupsWithStateExec(
func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateEnc, outputMode,
timeout, batchTimestampMs = None, eventTimeWatermark = None, planLater(child))
func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateEnc, stateVersion,
outputMode, timeout, batchTimestampMs = None, eventTimeWatermark = None, planLater(child))
execPlan :: Nil
case _ =>
Nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ case class FlatMapGroupsWithStateExec(
outputObjAttr: Attribute,
stateInfo: Option[StatefulOperatorStateInfo],
stateEncoder: ExpressionEncoder[Any],
stateFormatVersion: Int,
outputMode: OutputMode,
timeoutConf: GroupStateTimeout,
batchTimestampMs: Option[Long],
Expand All @@ -65,7 +66,8 @@ case class FlatMapGroupsWithStateExec(
case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true
case _ => false
}
private[sql] val stateManager = createStateManager(stateEncoder, isTimeoutEnabled, 2)
private[sql] val stateManager =
createStateManager(stateEncoder, isTimeoutEnabled, stateFormatVersion)

/** Distribute by grouping attributes */
override def requiredChildDistribution: Seq[Distribution] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.streaming.state

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, CaseWhen, CreateNamedStruct, Expression, GenericInternalRow, GetStructField, If, IsNull, Literal, SpecificInternalRow, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.ObjectOperator
import org.apache.spark.sql.execution.streaming.GroupStateImpl
import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
Expand All @@ -27,7 +27,7 @@ import org.apache.spark.sql.types._

object FlatMapGroupsWithStateExecHelper {

val DEFAULT_STATE_MANAGER_VERSION = 2
/** State format versions for which `createStateManager` can build a state manager. */
val supportedVersions: Seq[Int] = Seq(1, 2)

/**
* Class to capture deserialized state and timestamp returned by the state manager.
Expand Down Expand Up @@ -58,24 +58,26 @@ object FlatMapGroupsWithStateExecHelper {
/**
 * Saves `state` together with `timeoutTimestamp` for the key `keyRow` in `store`.
 * NOTE(review): described from the signature only; implementations are not visible here.
 */
def putState(store: StateStore, keyRow: UnsafeRow, state: Any, timeoutTimestamp: Long): Unit
/**
 * Removes the entry for the key `keyRow` from `store`.
 * NOTE(review): described from the signature only; implementations are not visible here.
 */
def removeState(store: StateStore, keyRow: UnsafeRow): Unit
/**
 * Returns an iterator of [[StateData]] over the contents of `store`.
 * NOTE(review): presumably covers every key in the store -- confirm against implementations.
 */
def getAllState(store: StateStore): Iterator[StateData]
/** State format version implemented by this manager. */
def version: Int
}

def createStateManager(
stateEncoder: ExpressionEncoder[Any],
shouldStoreTimestamp: Boolean,
version: Int): StateManager = {
version match {
stateFormatVersion: Int): StateManager = {
stateFormatVersion match {
case 1 => new StateManagerImplV1(stateEncoder, shouldStoreTimestamp)
case 2 => new StateManagerImplV2(stateEncoder, shouldStoreTimestamp)
case _ => throw new IllegalArgumentException(s"Version $version")
case _ => throw new IllegalArgumentException(s"Version $stateFormatVersion is invalid")
}
}

// ===============================================================================================
// =========================== Private implementations of StateManager ===========================
// ===============================================================================================

private abstract class StateManagerImplBase(shouldStoreTimestamp: Boolean) extends StateManager {
private abstract class StateManagerImplBase(val version: Int, shouldStoreTimestamp: Boolean)
extends StateManager {

protected def stateSerializerExprs: Seq[Expression]
protected def stateDeserializerExpr: Expression
Expand Down Expand Up @@ -135,7 +137,7 @@ object FlatMapGroupsWithStateExecHelper {

private class StateManagerImplV1(
stateEncoder: ExpressionEncoder[Any],
shouldStoreTimestamp: Boolean) extends StateManagerImplBase(shouldStoreTimestamp) {
shouldStoreTimestamp: Boolean) extends StateManagerImplBase(1, shouldStoreTimestamp) {

// Non-nullable attribute named "timeoutTimestamp" used by the version 1 state manager.
// NOTE(review): the attribute is declared as IntegerType although putState receives the timeout
// timestamp as a Long -- confirm whether this is intentional (e.g. kept for compatibility with
// already-written version 1 state) before changing the type.
private val timestampTimeoutAttribute =
AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)()
Expand Down Expand Up @@ -175,7 +177,7 @@ object FlatMapGroupsWithStateExecHelper {

private class StateManagerImplV2(
stateEncoder: ExpressionEncoder[Any],
shouldStoreTimestamp: Boolean) extends StateManagerImplBase(shouldStoreTimestamp) {
shouldStoreTimestamp: Boolean) extends StateManagerImplBase(2, shouldStoreTimestamp) {

/** Schema of the state rows saved in the state store */
override val stateSchema: StructType = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ package org.apache.spark.sql.execution.streaming.state

import java.util.concurrent.atomic.AtomicInteger

import org.apache.spark.sql.{Encoder, QueryTest}
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.execution.streaming.GroupStateImpl._
import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite._
import org.apache.spark.sql.streaming.StreamTest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ import org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState
import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution.RDDScanExec
import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream}
import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair}
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.util.StreamManualClock
import org.apache.spark.sql.types.{DataType, IntegerType}

Expand Down Expand Up @@ -601,7 +602,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
expectedState = Some(5), // state should change
expectedTimeoutTimestamp = 5000) // timestamp should change

test("flatMapGroupsWithState - streaming") {
testWithAllStateVersions("flatMapGroupsWithState - streaming") {
// Function to maintain running count up to 2, and then remove the count
// Returns the data and the count if state is defined, otherwise does not return anything
val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
Expand Down Expand Up @@ -680,7 +681,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
)
}

test("flatMapGroupsWithState - streaming + aggregation") {
testWithAllStateVersions("flatMapGroupsWithState - streaming + aggregation") {
// Function to maintain running count up to 2, and then remove the count
// Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
Expand Down Expand Up @@ -739,7 +740,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
checkAnswer(df, Seq(("a", 2), ("b", 1)).toDF)
}

test("flatMapGroupsWithState - streaming with processing time timeout") {
testWithAllStateVersions("flatMapGroupsWithState - streaming with processing time timeout") {
// Function to maintain the count as state and set the proc. time timeout delay of 10 seconds.
// It returns the count if changed, or -1 if the state was removed by timeout.
val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
Expand Down Expand Up @@ -803,7 +804,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
)
}

test("flatMapGroupsWithState - streaming with event time timeout + watermark") {
testWithAllStateVersions("flatMapGroupsWithState - streaming with event time timeout") {
// Function to maintain the max event time as state and set the timeout timestamp based on the
// current max event time seen. It returns the max event time in the state, or -1 if the state
// was removed by timeout.
Expand Down Expand Up @@ -1135,7 +1136,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
.logicalPlan.collectFirst {
case FlatMapGroupsWithState(f, k, v, g, d, o, s, m, _, t, _) =>
FlatMapGroupsWithStateExec(
f, k, v, g, d, o, None, s, m, t,
f, k, v, g, d, o, None, s, 2, m, t,
Some(currentBatchTimestamp), Some(currentBatchWatermark), RDDScanExec(g, null, "rdd"))
}.get
}
Expand Down Expand Up @@ -1168,6 +1169,16 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
}

/** Reads the Int stored at ordinal 0 of `row`. */
def rowToInt(row: UnsafeRow): Int = row.getInt(0)

/**
 * Registers one test per entry of `FlatMapGroupsWithStateExecHelper.supportedVersions`.
 * Each registered test runs `func` with the state format version conf set to that entry, and
 * its name is `name` followed by " - state format version N".
 */
def testWithAllStateVersions(name: String)(func: => Unit): Unit = {
  FlatMapGroupsWithStateExecHelper.supportedVersions.foreach { stateVersion =>
    test(s"$name - state format version $stateVersion") {
      val versionSetting =
        SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> stateVersion.toString
      withSQLConf(versionSetting)(func)
    }
  }
}
}

object FlatMapGroupsWithStateSuite {
Expand Down

0 comments on commit c9f600b

Please sign in to comment.