[SPARK-22187][SS] Update unsaferow format for saved state in flatMapGroupsWithState to allow timeouts with deleted state

## What changes were proposed in this pull request?

Currently, group state of a user-defined type is encoded as top-level columns in the UnsafeRows stored in the state store. The timeout timestamp is also saved, when needed, as the last top-level column. Since the group state is serialized into top-level columns, "null" cannot be saved as the state value (setting null in every top-level column is not equivalent). For that reason, the user is not allowed to set a timeout without first initializing the state for a key, which in practice has proven confusing.

This PR changes the row format so that the state is saved as a nested column. This allows the state to be set to null and avoids these confusing corner cases. Queries recovering from an existing checkpoint, however, keep using the previous format to maintain compatibility with existing production queries.
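For illustration, a minimal sketch of the difference between the two layouts, assuming a user state of `case class RunningCount(count: Long)` (field names and exact types are illustrative; the real schemas are derived from the state encoder):

```scala
import org.apache.spark.sql.types._

// Format 1 (legacy): user state flattened into top-level columns, with the
// timeout timestamp appended as the last top-level column. There is no way to
// represent "no state, but a timeout is set".
val stateSchemaV1 = new StructType()
  .add("count", LongType, nullable = false)
  .add("timeoutTimestamp", LongType, nullable = false)

// Format 2 (this PR): the user state becomes a single nested, nullable struct
// column, so the state can be null while the timeout column still holds a value.
val stateSchemaV2 = new StructType()
  .add("groupState", new StructType().add("count", LongType), nullable = true)
  .add("timeoutTimestamp", LongType, nullable = false)
```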

## How was this patch tested?
Refactored existing end-to-end tests and added new tests that explicitly exercise the object-to-row conversion for both state formats.

Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes #21739 from tdas/SPARK-22187-1.
tdas committed Jul 19, 2018
1 parent 8d707b0 commit b3d88ac
Showing 23 changed files with 708 additions and 180 deletions.
@@ -715,7 +715,8 @@ trait ComplexTypeMergingExpression extends Expression {
"The collection of input data types must not be empty.")
require(
TypeCoercion.haveSameType(inputTypesForMerging),
"All input types must be the same except nullable, containsNull, valueContainsNull flags.")
"All input types must be the same except nullable, containsNull, valueContainsNull flags." +
s" The input types found are\n\t${inputTypesForMerging.mkString("\n\t")}")
inputTypesForMerging.reduceLeft(TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(_, _).get)
}
}
@@ -843,6 +843,14 @@ object SQLConf {
.intConf
.createWithDefault(10)

val FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION =
buildConf("spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion")
.internal()
.doc("State format version used by flatMapGroupsWithState operation in a streaming query")
.intConf
.checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2")
.createWithDefault(2)

val CHECKPOINT_LOCATION = buildConf("spark.sql.streaming.checkpointLocation")
.doc("The default location for storing checkpoint data for streaming queries.")
.stringConf
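As a usage note (not part of this diff, and assuming a `SparkSession` named `spark`): the conf is internal, but it can be pinned when starting a fresh query if the legacy layout is ever needed; values other than 1 and 2 are rejected by the `checkValue` above.

```scala
// Illustrative only: start new queries with the legacy (v1) state layout.
// Queries restarted from an existing checkpoint keep the format recorded in the
// checkpoint's offset metadata regardless of this session setting (see below).
spark.conf.set("spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion", "1")
```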
@@ -504,9 +504,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case FlatMapGroupsWithState(
func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateEnc, outputMode, _,
timeout, child) =>
val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION)
val execPlan = FlatMapGroupsWithStateExec(
func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateEnc, outputMode,
timeout, batchTimestampMs = None, eventTimeWatermark = None, planLater(child))
func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateEnc, stateVersion,
outputMode, timeout, batchTimestampMs = None, eventTimeWatermark = None, planLater(child))
execPlan :: Nil
case _ =>
Nil
@@ -23,10 +23,8 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Attribut
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.util.CompletionIterator

/**
@@ -52,6 +50,7 @@ case class FlatMapGroupsWithStateExec(
outputObjAttr: Attribute,
stateInfo: Option[StatefulOperatorStateInfo],
stateEncoder: ExpressionEncoder[Any],
stateFormatVersion: Int,
outputMode: OutputMode,
timeoutConf: GroupStateTimeout,
batchTimestampMs: Option[Long],
@@ -60,32 +59,15 @@
) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter with WatermarkSupport {

import GroupStateImpl._
import FlatMapGroupsWithStateExecHelper._

private val isTimeoutEnabled = timeoutConf != NoTimeout
private val timestampTimeoutAttribute =
AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)()
private val stateAttributes: Seq[Attribute] = {
val encSchemaAttribs = stateEncoder.schema.toAttributes
if (isTimeoutEnabled) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs
}
// Get the serializer for the state, taking into account whether we need to save timestamps
private val stateSerializer = {
val encoderSerializer = stateEncoder.namedExpressions
if (isTimeoutEnabled) {
encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP)
} else {
encoderSerializer
}
}
// Get the deserializer for the state. Note that this must be done in the driver, as
// resolving and binding of deserializer expressions to the encoded type can be safely done
// only in the driver.
private val stateDeserializer = stateEncoder.resolveAndBind().deserializer

private val watermarkPresent = child.output.exists {
case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true
case _ => false
}
private[sql] val stateManager =
createStateManager(stateEncoder, isTimeoutEnabled, stateFormatVersion)

/** Distribute by grouping attributes */
override def requiredChildDistribution: Seq[Distribution] =
@@ -125,11 +107,11 @@ case class FlatMapGroupsWithStateExec(
child.execute().mapPartitionsWithStateStore[InternalRow](
getStateInfo,
groupingAttributes.toStructType,
stateAttributes.toStructType,
stateManager.stateSchema,
indexOrdinal = None,
sqlContext.sessionState,
Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
val updater = new StateStoreUpdater(store)
val processor = new InputProcessor(store)

// If timeout is based on event time, then filter late data based on watermark
val filteredIter = watermarkPredicateForData match {
@@ -143,7 +125,7 @@
// all the data has been processed. This is to ensure that the timeout information of all
// the keys with data is updated before they are processed for timeouts.
val outputIterator =
updater.updateStateForKeysWithData(filteredIter) ++ updater.updateStateForTimedOutKeys()
processor.processNewData(filteredIter) ++ processor.processTimedOutState()

// Return an iterator of all the rows generated by all the keys, such that when fully
// consumed, all the state updates will be committed by the state store
@@ -158,7 +140,7 @@
}

/** Helper class to update the state store */
class StateStoreUpdater(store: StateStore) {
class InputProcessor(store: StateStore) {

// Converters for translating input keys, values, output data between rows and Java objects
private val getKeyObj =
@@ -167,14 +149,6 @@
ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)

// Converters for translating state between rows and Java objects
private val getStateObjFromRow = ObjectOperator.deserializeRowToObject(
stateDeserializer, stateAttributes)
private val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer)

// Index of the additional metadata fields in the state row
private val timeoutTimestampIndex = stateAttributes.indexOf(timestampTimeoutAttribute)

// Metrics
private val numUpdatedStateRows = longMetric("numUpdatedStateRows")
private val numOutputRows = longMetric("numOutputRows")
@@ -183,20 +157,19 @@
* For every group, get the key, values and corresponding state and call the function,
* and return an iterator of rows
*/
def updateStateForKeysWithData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = {
def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = {
val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output)
groupedIter.flatMap { case (keyRow, valueRowIter) =>
val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow]
callFunctionAndUpdateState(
keyUnsafeRow,
stateManager.getState(store, keyUnsafeRow),
valueRowIter,
store.get(keyUnsafeRow),
hasTimedOut = false)
}
}

/** Find the groups that have timeout set and are timing out right now, and call the function */
def updateStateForTimedOutKeys(): Iterator[InternalRow] = {
def processTimedOutState(): Iterator[InternalRow] = {
if (isTimeoutEnabled) {
val timeoutThreshold = timeoutConf match {
case ProcessingTimeTimeout => batchTimestampMs.get
@@ -205,12 +178,11 @@
throw new IllegalStateException(
s"Cannot filter timed out keys for $timeoutConf")
}
val timingOutPairs = store.getRange(None, None).filter { rowPair =>
val timeoutTimestamp = getTimeoutTimestamp(rowPair.value)
timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold
val timingOutPairs = stateManager.getAllState(store).filter { state =>
state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold
}
timingOutPairs.flatMap { rowPair =>
callFunctionAndUpdateState(rowPair.key, Iterator.empty, rowPair.value, hasTimedOut = true)
timingOutPairs.flatMap { stateData =>
callFunctionAndUpdateState(stateData, Iterator.empty, hasTimedOut = true)
}
} else Iterator.empty
}
@@ -220,73 +192,44 @@
* iterator. Note that the store updating is lazy, that is, the store will be updated only
* after the returned iterator is fully consumed.
*
* @param keyRow Row representing the key, cannot be null
* @param stateData All the data related to the state to be updated
* @param valueRowIter Iterator of values as rows, cannot be null, but can be empty
* @param prevStateRow Row representing the previous state, can be null
* @param hasTimedOut Whether this function is being called for a key timeout
*/
private def callFunctionAndUpdateState(
keyRow: UnsafeRow,
stateData: StateData,
valueRowIter: Iterator[InternalRow],
prevStateRow: UnsafeRow,
hasTimedOut: Boolean): Iterator[InternalRow] = {

val keyObj = getKeyObj(keyRow) // convert key to objects
val keyObj = getKeyObj(stateData.keyRow) // convert key to objects
val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects
val stateObj = getStateObj(prevStateRow)
val keyedState = GroupStateImpl.createForStreaming(
Option(stateObj),
val groupState = GroupStateImpl.createForStreaming(
Option(stateData.stateObj),
batchTimestampMs.getOrElse(NO_TIMESTAMP),
eventTimeWatermark.getOrElse(NO_TIMESTAMP),
timeoutConf,
hasTimedOut,
watermarkPresent)

// Call function, get the returned objects and convert them to rows
val mappedIterator = func(keyObj, valueObjIter, keyedState).map { obj =>
val mappedIterator = func(keyObj, valueObjIter, groupState).map { obj =>
numOutputRows += 1
getOutputRow(obj)
}

// When the iterator is consumed, then write changes to state
def onIteratorCompletion: Unit = {

val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp
// If the state has not yet been set but timeout has been set, then
// we have to generate a row to save the timeout. However, attempting serialize
// null using case class encoder throws -
// java.lang.NullPointerException: Null value appeared in non-nullable field:
// If the schema is inferred from a Scala tuple / case class, or a Java bean, please
// try to use scala.Option[_] or other nullable types.
if (!keyedState.exists && currentTimeoutTimestamp != NO_TIMESTAMP) {
throw new IllegalStateException(
"Cannot set timeout when state is not defined, that is, state has not been" +
"initialized or has been removed")
}

if (keyedState.hasRemoved) {
store.remove(keyRow)
if (groupState.hasRemoved && groupState.getTimeoutTimestamp == NO_TIMESTAMP) {
stateManager.removeState(store, stateData.keyRow)
numUpdatedStateRows += 1

} else {
val previousTimeoutTimestamp = getTimeoutTimestamp(prevStateRow)
val stateRowToWrite = if (keyedState.hasUpdated) {
getStateRow(keyedState.get)
} else {
prevStateRow
}

val hasTimeoutChanged = currentTimeoutTimestamp != previousTimeoutTimestamp
val shouldWriteState = keyedState.hasUpdated || hasTimeoutChanged
val currentTimeoutTimestamp = groupState.getTimeoutTimestamp
val hasTimeoutChanged = currentTimeoutTimestamp != stateData.timeoutTimestamp
val shouldWriteState = groupState.hasUpdated || groupState.hasRemoved || hasTimeoutChanged

if (shouldWriteState) {
if (stateRowToWrite == null) {
// This should never happen because checks in GroupStateImpl should avoid cases
// where empty state would need to be written
throw new IllegalStateException("Attempting to write empty state")
}
setTimeoutTimestamp(stateRowToWrite, currentTimeoutTimestamp)
store.put(keyRow, stateRowToWrite)
val updatedStateObj = if (groupState.exists) groupState.get else null
stateManager.putState(store, stateData.keyRow, updatedStateObj, currentTimeoutTimestamp)
numUpdatedStateRows += 1
}
}
@@ -295,28 +238,5 @@
// Return an iterator of rows such that fully consumed, the updated state value will be saved
CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion)
}

/** Returns the state as Java object if defined */
def getStateObj(stateRow: UnsafeRow): Any = {
if (stateRow != null) getStateObjFromRow(stateRow) else null
}

/** Returns the row for an updated state */
def getStateRow(obj: Any): UnsafeRow = {
assert(obj != null)
getStateRowFromObj(obj)
}

/** Returns the timeout timestamp of a state row is set */
def getTimeoutTimestamp(stateRow: UnsafeRow): Long = {
if (isTimeoutEnabled && stateRow != null) {
stateRow.getLong(timeoutTimestampIndex)
} else NO_TIMESTAMP
}

/** Set the timestamp in a state row */
def setTimeoutTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = {
if (isTimeoutEnabled) stateRow.setLong(timeoutTimestampIndex, timeoutTimestamps)
}
}
}
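The `stateManager` used above replaces the inline state serializer/deserializer and the manual timeout-column handling. A sketch of the interface it appears to expose, inferred only from the calls visible in this diff (the actual definitions live in `FlatMapGroupsWithStateExecHelper`, which this excerpt does not show):

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.streaming.state.StateStore
import org.apache.spark.sql.types.StructType

// Sketch inferred from usage; not the actual Spark source.
object StateManagerSketch {

  // One keyed-state entry, with the timeout surfaced as a plain Long so the
  // exec node never touches the physical row layout directly.
  case class StateData(keyRow: UnsafeRow, stateObj: Any, timeoutTimestamp: Long)

  trait StateManager {
    def stateSchema: StructType
    def getState(store: StateStore, keyRow: UnsafeRow): StateData
    def getAllState(store: StateStore): Iterator[StateData]
    def putState(store: StateStore, keyRow: UnsafeRow, state: Any, timeoutTimestamp: Long): Unit
    def removeState(store: StateStore, keyRow: UnsafeRow): Unit
  }

  // Chooses the v1 (flat columns) or v2 (nested struct) implementation.
  def createStateManager(
      stateEncoder: ExpressionEncoder[Any],
      shouldStoreTimestamp: Boolean,
      stateFormatVersion: Int): StateManager =
    sys.error("sketch only")
}
```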
@@ -22,7 +22,8 @@ import org.json4s.jackson.Serialization

import org.apache.spark.internal.Logging
import org.apache.spark.sql.RuntimeConfig
import org.apache.spark.sql.internal.SQLConf._
import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper
import org.apache.spark.sql.internal.SQLConf.{FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, _}

/**
* An ordered collection of offsets, used to track the progress of processing data from one or more
@@ -87,7 +88,8 @@ case class OffsetSeqMetadata(
object OffsetSeqMetadata extends Logging {
private implicit val format = Serialization.formats(NoTypeHints)
private val relevantSQLConfs = Seq(
SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY)
SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY,
FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION)

/**
* Default values of relevant configurations that are used for backward compatibility.
@@ -100,7 +102,9 @@ object OffsetSeqMetadata extends Logging {
* with a specific default value for ensuring same behavior of the query as before.
*/
private val relevantSQLConfDefaultValues = Map[String, String](
STREAMING_MULTIPLE_WATERMARK_POLICY.key -> MultipleWatermarkPolicy.DEFAULT_POLICY_NAME
STREAMING_MULTIPLE_WATERMARK_POLICY.key -> MultipleWatermarkPolicy.DEFAULT_POLICY_NAME,
FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key ->
FlatMapGroupsWithStateExecHelper.legacyVersion.toString
)

def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json)
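Why the compatibility default matters: offset metadata written by an older Spark release has no entry for the new conf, so on restart the missing key must resolve to the legacy version 1 (the layout those checkpoints were written with) rather than to the new session default of 2. A minimal sketch of that resolution, assuming a conf map recovered from the checkpoint (illustrative, not the actual `OffsetSeqMetadata` code):

```scala
// Illustrative fallback: prefer the value recorded in the checkpoint, then the
// backward-compatibility default (format 1), and only then the session default.
def resolveStateFormatVersion(
    recoveredConfs: Map[String, String],   // confs read from the checkpoint's offset metadata
    compatDefaults: Map[String, String],   // relevantSQLConfDefaultValues above
    sessionDefault: String): String = {
  val key = "spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion"
  recoveredConfs.getOrElse(key, compatDefaults.getOrElse(key, sessionDefault))
}
```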