diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index fbb9a8cfae2e1..f70c5b9797716 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -871,6 +871,16 @@ object SQLConf { .intConf .createWithDefault(2) + val STREAMING_AGGREGATION_STATE_FORMAT_VERSION = + buildConf("spark.sql.streaming.aggregation.stateFormatVersion") + .internal() + .doc("State format version used by streaming aggregation operations in a streaming query. " + + "State written by different versions is not compatible, so the state format version " + + "shouldn't be modified after a query has started running.") + .intConf + .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2") + .createWithDefault(2) + val UNSUPPORTED_OPERATION_CHECK_ENABLED = buildConf("spark.sql.streaming.unsupportedOperationCheck") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 0c4ea857fd1d7..c5957a0726b4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -328,10 +328,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { "Streaming aggregation doesn't support group aggregate pandas UDF") } + val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION) + aggregate.AggUtils.planStreamingAggregation( namedGroupingExpressions, aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), rewrittenResultExpressions, + stateVersion, planLater(child)) case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index ebbdf1aaa024d..80e1f32b72226 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -256,6 +256,7 @@ object AggUtils { groupingExpressions: Seq[NamedExpression], functionsWithoutDistinct: Seq[AggregateExpression], resultExpressions: Seq[NamedExpression], + stateFormatVersion: Int, child: SparkPlan): Seq[SparkPlan] = { val groupingAttributes = groupingExpressions.map(_.toAttribute) @@ -287,7 +288,8 @@ object AggUtils { child = partialAggregate) } - val restored = StateStoreRestoreExec(groupingAttributes, None, partialMerged1) + val restored = StateStoreRestoreExec(groupingAttributes, None, stateFormatVersion, + partialMerged1) val partialMerged2: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) @@ -311,6 +313,7 @@ object AggUtils { stateInfo = None, outputMode = None, eventTimeWatermark = None, + stateFormatVersion = stateFormatVersion, partialMerged2) val finalAndCompleteAggregate: SparkPlan = {
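The new version is resolved once at planning time in SparkStrategies and threaded through AggUtils.planStreamingAggregation into the stateful operators. It is also recorded in the offset log (see the OffsetSeqMetadata change below), so whichever version a query starts with is the one its checkpoint keeps using, and checkpoints written before this change fall back to the legacy version 1. A hedged sketch of pinning the format explicitly before a query first starts; the source, sink, and checkpoint path are placeholders, not part of this patch, and the conf is internal, so most users never need to set it:

// Sketch only: assumes an existing SparkSession named `spark`.
import org.apache.spark.sql.functions.col

// Pin the state format before the query starts; the chosen value is persisted in the
// checkpoint's offset log, so later restarts keep using the same format.
spark.conf.set("spark.sql.streaming.aggregation.stateFormatVersion", "2")

val counts = spark.readStream
  .format("rate")                  // built-in test source emitting (timestamp, value)
  .load()
  .groupBy(col("value") % 10)
  .count()

counts.writeStream
  .outputMode("update")
  .format("console")
  .option("checkpointLocation", "/tmp/agg-state-format-demo")  // placeholder path
  .start()

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 6ae7f2869b0f3..ab788c5b8bcc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@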
-100,19 +100,21 @@ class IncrementalExecution( val state = new Rule[SparkPlan] { override def apply(plan: SparkPlan): SparkPlan = plan transform { - case StateStoreSaveExec(keys, None, None, None, + case StateStoreSaveExec(keys, None, None, None, stateFormatVersion, UnaryExecNode(agg, - StateStoreRestoreExec(_, None, child))) => + StateStoreRestoreExec(_, None, _, child))) => val aggStateInfo = nextStatefulOperationStateInfo StateStoreSaveExec( keys, Some(aggStateInfo), Some(outputMode), Some(offsetSeqMetadata.batchWatermarkMs), + stateFormatVersion, agg.withNewChildren( StateStoreRestoreExec( keys, Some(aggStateInfo), + stateFormatVersion, child) :: Nil)) case StreamingDeduplicateExec(keys, child, None, None) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 9847756f22d4f..73cf355dbe758 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -22,7 +22,7 @@ import org.json4s.jackson.Serialization import org.apache.spark.internal.Logging import org.apache.spark.sql.RuntimeConfig -import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper +import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StreamingAggregationStateManager} import org.apache.spark.sql.internal.SQLConf.{FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, _} /** @@ -89,7 +89,7 @@ object OffsetSeqMetadata extends Logging { private implicit val format = Serialization.formats(NoTypeHints) private val relevantSQLConfs = Seq( SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY, - FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) + FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, STREAMING_AGGREGATION_STATE_FORMAT_VERSION) /** * Default values of relevant configurations that are used for backward compatibility. @@ -104,7 +104,9 @@ object OffsetSeqMetadata extends Logging { private val relevantSQLConfDefaultValues = Map[String, String]( STREAMING_MULTIPLE_WATERMARK_POLICY.key -> MultipleWatermarkPolicy.DEFAULT_POLICY_NAME, FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> - FlatMapGroupsWithStateExecHelper.legacyVersion.toString + FlatMapGroupsWithStateExecHelper.legacyVersion.toString, + STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> + StreamingAggregationStateManager.legacyVersion.toString ) def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala new file mode 100644 index 0000000000000..9bfb9561b42a1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner} +import org.apache.spark.sql.types.StructType + +/** + * Base trait for a state manager used by streaming aggregations. + */ +sealed trait StreamingAggregationStateManager extends Serializable { + + /** Extract the key columns from the input row, and return a new row containing only the keys. */ + def getKey(row: UnsafeRow): UnsafeRow + + /** Return the schema of the state value. The schema is mainly passed to the StateStoreRDD. */ + def getStateValueSchema: StructType + + /** Get the current value of a non-null key from the target state store. */ + def get(store: StateStore, key: UnsafeRow): UnsafeRow + + /** + * Put a new value for a non-null key to the target state store. Note that the key will be + * extracted from the input row, and it will be the same as the result of getKey(inputRow). + */ + def put(store: StateStore, row: UnsafeRow): Unit + + /** + * Commit all the updates that have been made to the target state store, and return the + * new version. + */ + def commit(store: StateStore): Long + + /** Remove a single non-null key from the target state store. */ + def remove(store: StateStore, key: UnsafeRow): Unit + + /** Return an iterator containing all the key-value pairs in the target state store. */ + def iterator(store: StateStore): Iterator[UnsafeRowPair] + + /** Return an iterator containing all the keys in the target state store. */ + def keys(store: StateStore): Iterator[UnsafeRow] + + /** Return an iterator containing all the values in the target state store. */ + def values(store: StateStore): Iterator[UnsafeRow] +}
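Regardless of the format version, callers hand the manager full input rows and get full input rows back; only the layout of the stored value differs. A minimal usage sketch, assuming test scope in the org.apache.spark.sql.execution.streaming.state package so the MemoryStateStore helper added later in this patch is visible; the schema and values are made up for illustration:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}

val schema = StructType(Seq(StructField("group", IntegerType), StructField("count", LongType)))
val attrs = schema.toAttributes
val keyAttrs = attrs.take(1)                               // "group" is the grouping key

// Version 2 stores only the non-key columns, but that is invisible to the caller.
val manager = StreamingAggregationStateManager.createStateManager(keyAttrs, attrs, 2)

val row = UnsafeProjection.create(schema).apply(InternalRow(1, 10L))
val store = new MemoryStateStore()

manager.put(store, row)                                    // the key is derived from the row itself
assert(manager.get(store, manager.getKey(row)) == row)     // the original row is restored on read
manager.commit(store)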
+ +object StreamingAggregationStateManager extends Logging { + val supportedVersions = Seq(1, 2) + val legacyVersion = 1 + + def createStateManager( + keyExpressions: Seq[Attribute], + inputRowAttributes: Seq[Attribute], + stateFormatVersion: Int): StreamingAggregationStateManager = { + stateFormatVersion match { + case 1 => new StreamingAggregationStateManagerImplV1(keyExpressions, inputRowAttributes) + case 2 => new StreamingAggregationStateManagerImplV2(keyExpressions, inputRowAttributes) + case _ => throw new IllegalArgumentException(s"Version $stateFormatVersion is invalid") + } + } +} + +abstract class StreamingAggregationStateManagerBaseImpl( + protected val keyExpressions: Seq[Attribute], + protected val inputRowAttributes: Seq[Attribute]) extends StreamingAggregationStateManager { + + @transient protected lazy val keyProjector = + GenerateUnsafeProjection.generate(keyExpressions, inputRowAttributes) + + override def getKey(row: UnsafeRow): UnsafeRow = keyProjector(row) + + override def commit(store: StateStore): Long = store.commit() + + override def remove(store: StateStore, key: UnsafeRow): Unit = store.remove(key) + + override def keys(store: StateStore): Iterator[UnsafeRow] = { + // return keys only; don't convert the values, to avoid unnecessary computation + store.getRange(None, None).map(_.key) + } +} + +/** + * The implementation of StreamingAggregationStateManager for state version 1. + * In state version 1, the schemas of the key and the value in the state are as follows: + * + * - key: Same as the key expressions. + * - value: Same as the input row attributes. The value schema contains the key expressions as well. + * + * @param keyExpressions The attributes of the keys. + * @param inputRowAttributes The attributes of the input row. + */ +class StreamingAggregationStateManagerImplV1( + keyExpressions: Seq[Attribute], + inputRowAttributes: Seq[Attribute]) + extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) { + + override def getStateValueSchema: StructType = inputRowAttributes.toStructType + + override def get(store: StateStore, key: UnsafeRow): UnsafeRow = { + store.get(key) + } + + override def put(store: StateStore, row: UnsafeRow): Unit = { + store.put(getKey(row), row) + } + + override def iterator(store: StateStore): Iterator[UnsafeRowPair] = { + store.iterator() + } + + override def values(store: StateStore): Iterator[UnsafeRow] = { + store.iterator().map(_.value) + } +}
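The practical difference between the two formats is easiest to see on a concrete row. A simplified, plain-Scala illustration of the two layouts (the real code operates on UnsafeRows and generated projections; the column names and values are made up, and the example assumes the key columns happen to come first so no reordering projection is needed):

// Input row layout: (key1, key2, sum(key1), sum(key2))
val keyColumnCount = 2
val row = Seq(1, 2, 10, 20)

// V1: the stored value is the whole input row, so the key columns are written twice,
// once in the key and once again inside the value.
val v1Key   = row.take(keyColumnCount)           // Seq(1, 2)
val v1Value = row                                // Seq(1, 2, 10, 20)

// V2: the stored value keeps only the non-key columns; the original row is rebuilt on
// read by joining key and value (and reprojecting only when the joined column order
// differs from the input row order, cf. needToProjectToRestoreValue below).
val v2Key    = row.take(keyColumnCount)          // Seq(1, 2)
val v2Value  = row.drop(keyColumnCount)          // Seq(10, 20)
val restored = v2Key ++ v2Value                  // Seq(1, 2, 10, 20) == original row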
+ +/** + * The implementation of StreamingAggregationStateManager for state version 2. + * In state version 2, the schemas of the key and the value in the state are as follows: + * + * - key: Same as the key expressions. + * - value: The input row attributes minus the key expressions. + * + * The value schema is changed to reduce the memory/space usage of the state, by removing the + * columns that would otherwise be duplicated between key and value. Hence key columns are + * excluded from the value schema. + * + * @param keyExpressions The attributes of the keys. + * @param inputRowAttributes The attributes of the input row. + */ +class StreamingAggregationStateManagerImplV2( + keyExpressions: Seq[Attribute], + inputRowAttributes: Seq[Attribute]) + extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) { + + private val valueExpressions: Seq[Attribute] = inputRowAttributes.diff(keyExpressions) + private val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions + + // flag to check whether the row needs to be projected into the input row attributes after the + // join, e.g. if the fields in the joined row are not in the expected order + private val needToProjectToRestoreValue: Boolean = + keyValueJoinedExpressions != inputRowAttributes + + @transient private lazy val valueProjector = + GenerateUnsafeProjection.generate(valueExpressions, inputRowAttributes) + + @transient private lazy val joiner = + GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions), + StructType.fromAttributes(valueExpressions)) + @transient private lazy val restoreValueProjector = GenerateUnsafeProjection.generate( + inputRowAttributes, keyValueJoinedExpressions) + + override def getStateValueSchema: StructType = valueExpressions.toStructType + + override def get(store: StateStore, key: UnsafeRow): UnsafeRow = { + val savedState = store.get(key) + if (savedState == null) { + return savedState + } + + restoreOriginalRow(key, savedState) + } + + override def put(store: StateStore, row: UnsafeRow): Unit = { + val key = keyProjector(row) + val value = valueProjector(row) + store.put(key, value) + } + + override def iterator(store: StateStore): Iterator[UnsafeRowPair] = { + store.iterator().map(rowPair => new UnsafeRowPair(rowPair.key, restoreOriginalRow(rowPair))) + } + + override def values(store: StateStore): Iterator[UnsafeRow] = { + store.iterator().map(rowPair => restoreOriginalRow(rowPair)) + } + + private def restoreOriginalRow(rowPair: UnsafeRowPair): UnsafeRow = { + restoreOriginalRow(rowPair.key, rowPair.value) + } + + private def restoreOriginalRow(key: UnsafeRow, value: UnsafeRow): UnsafeRow = { + val joinedRow = joiner.join(key, value) + if (needToProjectToRestoreValue) { + restoreValueProjector(joinedRow) + } else { + joinedRow + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 6759fb42b4052..34e26d85ae2ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -20,8 +20,6 @@ package org.apache.spark.sql.execution.streaming import java.util.UUID import java.util.concurrent.TimeUnit._ -import scala.collection.JavaConverters._ - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ @@ -167,6 +165,18 @@ trait WatermarkSupport extends UnaryExecNode { } } } + + protected def removeKeysOlderThanWatermark( + storeManager: StreamingAggregationStateManager, + store: StateStore): Unit = { + if (watermarkPredicateForKeys.nonEmpty) { + storeManager.keys(store).foreach { keyRow => + if (watermarkPredicateForKeys.get.eval(keyRow)) { + storeManager.remove(store, keyRow) + } + } + } + } } object WatermarkSupport { @@ -201,20 +211,23 @@ object WatermarkSupport { case class StateStoreRestoreExec( keyExpressions: Seq[Attribute], stateInfo: Option[StatefulOperatorStateInfo], + stateFormatVersion: Int, child: SparkPlan)
extends UnaryExecNode with StateStoreReader { + private[sql] val stateManager = StreamingAggregationStateManager.createStateManager( + keyExpressions, child.output, stateFormatVersion) + override protected def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") child.execute().mapPartitionsWithStateStore( getStateInfo, keyExpressions.toStructType, - child.output.toStructType, + stateManager.getStateValueSchema, indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => - val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) val hasInput = iter.hasNext if (!hasInput && keyExpressions.isEmpty) { // If our `keyExpressions` are empty, we're getting a global aggregation. In that case @@ -224,10 +237,10 @@ case class StateStoreRestoreExec( store.iterator().map(_.value) } else { iter.flatMap { row => - val key = getKey(row) - val savedState = store.get(key) + val key = stateManager.getKey(row.asInstanceOf[UnsafeRow]) + val restoredRow = stateManager.get(store, key) numOutputRows += 1 - Option(savedState).toSeq :+ row + Option(restoredRow).toSeq :+ row } } } @@ -254,9 +267,13 @@ case class StateStoreSaveExec( stateInfo: Option[StatefulOperatorStateInfo] = None, outputMode: Option[OutputMode] = None, eventTimeWatermark: Option[Long] = None, + stateFormatVersion: Int, child: SparkPlan) extends UnaryExecNode with StateStoreWriter with WatermarkSupport { + private[sql] val stateManager = StreamingAggregationStateManager.createStateManager( + keyExpressions, child.output, stateFormatVersion) + override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver assert(outputMode.nonEmpty, @@ -265,11 +282,10 @@ case class StateStoreSaveExec( child.execute().mapPartitionsWithStateStore( getStateInfo, keyExpressions.toStructType, - child.output.toStructType, + stateManager.getStateValueSchema, indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => - val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) val numOutputRows = longMetric("numOutputRows") val numUpdatedStateRows = longMetric("numUpdatedStateRows") val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") @@ -282,19 +298,18 @@ case class StateStoreSaveExec( allUpdatesTimeMs += timeTakenMs { while (iter.hasNext) { val row = iter.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key, row) + stateManager.put(store, row) numUpdatedStateRows += 1 } } allRemovalsTimeMs += 0 commitTimeMs += timeTakenMs { - store.commit() + stateManager.commit(store) } setStoreMetrics(store) - store.iterator().map { rowPair => + stateManager.values(store).map { valueRow => numOutputRows += 1 - rowPair.value + valueRow } // Update and output only rows being evicted from the StateStore @@ -304,14 +319,13 @@ case class StateStoreSaveExec( val filteredIter = iter.filter(row => !watermarkPredicateForData.get.eval(row)) while (filteredIter.hasNext) { val row = filteredIter.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key, row) + stateManager.put(store, row) numUpdatedStateRows += 1 } } val removalStartTimeNs = System.nanoTime - val rangeIter = store.getRange(None, None) + val rangeIter = stateManager.iterator(store) new NextIterator[InternalRow] { override protected def getNext(): InternalRow = { @@ -319,7 +333,7 @@ case class StateStoreSaveExec( while(rangeIter.hasNext && removedValueRow == null) { val rowPair = 
rangeIter.next() if (watermarkPredicateForKeys.get.eval(rowPair.key)) { - store.remove(rowPair.key) + stateManager.remove(store, rowPair.key) removedValueRow = rowPair.value } } @@ -333,7 +347,7 @@ case class StateStoreSaveExec( override protected def close(): Unit = { allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs) - commitTimeMs += timeTakenMs { store.commit() } + commitTimeMs += timeTakenMs { stateManager.commit(store) } setStoreMetrics(store) } } @@ -352,8 +366,7 @@ case class StateStoreSaveExec( override protected def getNext(): InternalRow = { if (baseIterator.hasNext) { val row = baseIterator.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key, row) + stateManager.put(store, row) numOutputRows += 1 numUpdatedStateRows += 1 row @@ -367,8 +380,10 @@ case class StateStoreSaveExec( allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) // Remove old aggregates if watermark specified - allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) } - commitTimeMs += timeTakenMs { store.commit() } + allRemovalsTimeMs += timeTakenMs { + removeKeysOlderThanWatermark(stateManager, store) + } + commitTimeMs += timeTakenMs { stateManager.commit(store) } setStoreMetrics(store) } } diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/0 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/0 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/1 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/1 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/metadata b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/metadata new file mode 100644 index 0000000000000..c160d737278e1 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/metadata @@ -0,0 +1 @@ +{"id":"2f32aca2-1b97-458f-a48f-109328724f09"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/0 new file mode 100644 index 0000000000000..acdc6e69e975a --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/0 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1533784347136,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +0 \ No newline at end of file diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/1 new file mode 100644 index 0000000000000..27353e8724507 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/1 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1533784349160,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +1 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/1.delta new file mode 100644 index 0000000000000..281b21e960909 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/2.delta new file mode 100644 index 0000000000000..b701841d71535 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/2.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/2.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/1.delta new file mode 100644 index 0000000000000..6352978051846 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/1.delta differ diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/2.delta new file mode 100644 index 0000000000000..f4fb2520a4ac4 Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/2.delta differ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala new file mode 100644 index 0000000000000..98586d6492c9e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.state + +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +class MemoryStateStore extends StateStore() { + import scala.collection.JavaConverters._ + private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow] + + override def iterator(): Iterator[UnsafeRowPair] = { + map.entrySet.iterator.asScala.map { case e => new UnsafeRowPair(e.getKey, e.getValue) } + } + + override def get(key: UnsafeRow): UnsafeRow = map.get(key) + + override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = map.put(key.copy(), newValue.copy()) + + override def remove(key: UnsafeRow): Unit = map.remove(key) + + override def commit(): Long = version + 1 + + override def abort(): Unit = {} + + override def id: StateStoreId = null + + override def version: Long = 0 + + override def metrics: StateStoreMetrics = new StateStoreMetrics(map.size, 0, Map.empty) + + override def hasCommitted: Boolean = true +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala new file mode 100644 index 0000000000000..daacdfd58c7b9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.sql.catalyst.expressions.{Attribute, SpecificInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + +class StreamingAggregationStateManagerSuite extends StreamTest { + // ============================ fields and method for test data ============================ + + val testKeys: Seq[String] = Seq("key1", "key2") + val testValues: Seq[String] = Seq("sum(key1)", "sum(key2)") + + val testOutputSchema: StructType = StructType( + testKeys.map(createIntegerField) ++ testValues.map(createIntegerField)) + + val testOutputAttributes: Seq[Attribute] = testOutputSchema.toAttributes + val testKeyAttributes: Seq[Attribute] = testOutputAttributes.filter { p => + testKeys.contains(p.name) + } + val testValuesAttributes: Seq[Attribute] = testOutputAttributes.filter { p => + testValues.contains(p.name) + } + val expectedTestValuesSchema: StructType = testValuesAttributes.toStructType + + val testRow: UnsafeRow = { + val unsafeRowProjection = UnsafeProjection.create(testOutputSchema) + val row = unsafeRowProjection(new SpecificInternalRow(testOutputSchema)) + (testKeys ++ testValues).zipWithIndex.foreach { case (_, index) => row.setInt(index, index) } + row + } + + val expectedTestKeyRow: UnsafeRow = { + val keyProjector = GenerateUnsafeProjection.generate(testKeyAttributes, testOutputAttributes) + keyProjector(testRow) + } + + val expectedTestValueRowForV2: UnsafeRow = { + val valueProjector = GenerateUnsafeProjection.generate(testValuesAttributes, + testOutputAttributes) + valueProjector(testRow) + } + + private def createIntegerField(name: String): StructField = { + StructField(name, IntegerType, nullable = false) + } + + // ============================ StateManagerImplV1 ============================ + + test("StateManager v1 - get, put, iter") { + val stateManager = StreamingAggregationStateManager.createStateManager(testKeyAttributes, + testOutputAttributes, 1) + + // in V1, input row is stored as value + testGetPutIterOnStateManager(stateManager, testOutputSchema, testRow, + expectedTestKeyRow, expectedStateValue = testRow) + } + + // ============================ StateManagerImplV2 ============================ + test("StateManager v2 - get, put, iter") { + val stateManager = StreamingAggregationStateManager.createStateManager(testKeyAttributes, + testOutputAttributes, 2) + + // in V2, row for values itself (excluding keys from input row) is stored as value + // so that stored value doesn't have key part, but state manager V2 will provide same output + // as V1 when getting row for key + testGetPutIterOnStateManager(stateManager, expectedTestValuesSchema, testRow, + expectedTestKeyRow, expectedTestValueRowForV2) + } + + private def testGetPutIterOnStateManager( + stateManager: StreamingAggregationStateManager, + expectedValueSchema: StructType, + inputRow: UnsafeRow, + expectedStateKey: UnsafeRow, + expectedStateValue: UnsafeRow): Unit = { + + assert(stateManager.getStateValueSchema === expectedValueSchema) + + val memoryStateStore = new MemoryStateStore() + stateManager.put(memoryStateStore, inputRow) + + assert(memoryStateStore.iterator().size === 1) + assert(stateManager.iterator(memoryStateStore).size === memoryStateStore.iterator().size) + + val keyRow = stateManager.getKey(inputRow) + assert(keyRow === 
expectedStateKey) + + // iterate state store and verify whether expected format of key and value are stored + val pair = memoryStateStore.iterator().next() + assert(pair.key === keyRow) + assert(pair.value === expectedStateValue) + + // iterate with state manager and see whether original rows are returned as values + val pairFromStateManager = stateManager.iterator(memoryStateStore).next() + assert(pairFromStateManager.key === keyRow) + assert(pairFromStateManager.value === inputRow) + + // following as keys and values + assert(stateManager.keys(memoryStateStore).next() === keyRow) + assert(stateManager.values(memoryStateStore).next() === inputRow) + + // verify the stored value once again via get + assert(memoryStateStore.get(keyRow) === expectedStateValue) + + // state manager should return row which is same as input row regardless of format version + assert(inputRow === stateManager.get(memoryStateStore, keyRow)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 82d7755aef5f0..76511ae2c8362 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.streaming import java.io.File import java.sql.Date -import java.util.concurrent.ConcurrentHashMap import org.apache.commons.io.FileUtils import org.scalatest.BeforeAndAfterAll @@ -34,7 +33,7 @@ import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.RDDScanExec import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair} +import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, MemoryStateStore, StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.types.{DataType, IntegerType} @@ -1286,27 +1285,6 @@ object FlatMapGroupsWithStateSuite { var failInTask = true - class MemoryStateStore extends StateStore() { - import scala.collection.JavaConverters._ - private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow] - - override def iterator(): Iterator[UnsafeRowPair] = { - map.entrySet.iterator.asScala.map { case e => new UnsafeRowPair(e.getKey, e.getValue) } - } - - override def get(key: UnsafeRow): UnsafeRow = map.get(key) - override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = { - map.put(key.copy(), newValue.copy()) - } - override def remove(key: UnsafeRow): Unit = { map.remove(key) } - override def commit(): Long = version + 1 - override def abort(): Unit = { } - override def id: StateStoreId = null - override def version: Long = 0 - override def metrics: StateStoreMetrics = new StateStoreMetrics(map.size, 0, Map.empty) - override def hasCommitted: Boolean = true - } - def assertCanGetProcessingTime(predicate: => Boolean): Unit = { if (!predicate) throw new TestFailedException("Could not get processing time", 20) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 382da13430781..1ae6ff3a90989 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.streaming +import java.io.File import java.util.{Locale, TimeZone} -import org.scalatest.Assertions -import org.scalatest.BeforeAndAfterAll +import org.apache.commons.io.FileUtils +import org.scalatest.{Assertions, BeforeAndAfterAll} import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.rdd.BlockRDD @@ -31,13 +32,15 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.exchange.Exchange import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.execution.streaming.state.{StateStore, StreamingAggregationStateManager} import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode._ import org.apache.spark.sql.streaming.util.{MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} +import org.apache.spark.util.Utils object FailureSingleton { var firstTime = true @@ -53,7 +56,35 @@ class StreamingAggregationSuite extends StateStoreMetricsTest import testImplicits._ - test("simple count, update mode") { + def executeFuncWithStateVersionSQLConf( + stateVersion: Int, + confPairs: Seq[(String, String)], + func: => Any): Unit = { + withSQLConf(confPairs ++ + Seq(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> stateVersion.toString): _*) { + func + } + } + + def testWithAllStateVersions(name: String, confPairs: (String, String)*) + (func: => Any): Unit = { + for (version <- StreamingAggregationStateManager.supportedVersions) { + test(s"$name - state format version $version") { + executeFuncWithStateVersionSQLConf(version, confPairs, func) + } + } + } + + def testQuietlyWithAllStateVersions(name: String, confPairs: (String, String)*) + (func: => Any): Unit = { + for (version <- StreamingAggregationStateManager.supportedVersions) { + testQuietly(s"$name - state format version $version") { + executeFuncWithStateVersionSQLConf(version, confPairs, func) + } + } + } + + testWithAllStateVersions("simple count, update mode") { val inputData = MemoryStream[Int] val aggregated = @@ -77,7 +108,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("count distinct") { + testWithAllStateVersions("count distinct") { val inputData = MemoryStream[(Int, Seq[Int])] val aggregated = @@ -93,7 +124,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("simple count, complete mode") { + testWithAllStateVersions("simple count, complete mode") { val inputData = MemoryStream[Int] val aggregated = @@ -116,7 +147,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("simple count, append mode") { + testWithAllStateVersions("simple count, append mode") { val inputData = MemoryStream[Int] val aggregated = @@ -133,7 +164,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest } } - test("sort after aggregate in complete mode") { + 
testWithAllStateVersions("sort after aggregate in complete mode") { val inputData = MemoryStream[Int] val aggregated = @@ -158,7 +189,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("state metrics") { + testWithAllStateVersions("state metrics") { val inputData = MemoryStream[Int] val aggregated = @@ -211,7 +242,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("multiple keys") { + testWithAllStateVersions("multiple keys") { val inputData = MemoryStream[Int] val aggregated = @@ -228,7 +259,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - testQuietly("midbatch failure") { + testQuietlyWithAllStateVersions("midbatch failure") { val inputData = MemoryStream[Int] FailureSingleton.firstTime = true val aggregated = @@ -254,7 +285,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("typed aggregators") { + testWithAllStateVersions("typed aggregators") { val inputData = MemoryStream[(String, Int)] val aggregated = inputData.toDS().groupByKey(_._1).agg(typed.sumLong(_._2)) @@ -264,7 +295,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("prune results by current_time, complete mode") { + testWithAllStateVersions("prune results by current_time, complete mode") { import testImplicits._ val clock = new StreamManualClock val inputData = MemoryStream[Long] @@ -316,7 +347,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("prune results by current_date, complete mode") { + testWithAllStateVersions("prune results by current_date, complete mode") { import testImplicits._ val clock = new StreamManualClock val tz = TimeZone.getDefault.getID @@ -365,7 +396,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("SPARK-19690: do not convert batch aggregation in streaming query to streaming") { + testWithAllStateVersions("SPARK-19690: do not convert batch aggregation in streaming query " + + "to streaming") { val streamInput = MemoryStream[Int] val batchDF = Seq(1, 2, 3, 4, 5) .toDF("value") @@ -429,7 +461,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest true } - test("SPARK-21977: coalesce(1) with 0 partition RDD should be repartitioned to 1") { + testWithAllStateVersions("SPARK-21977: coalesce(1) with 0 partition RDD should be " + + "repartitioned to 1") { val inputSource = new BlockRDDBackedSource(spark) MockSourceProvider.withMockSources(inputSource) { // `coalesce(1)` changes the partitioning of data to `SinglePartition` which by default @@ -467,8 +500,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest } } - test("SPARK-21977: coalesce(1) with aggregation should still be repartitioned when it " + - "has non-empty grouping keys") { + testWithAllStateVersions("SPARK-21977: coalesce(1) with aggregation should still be " + + "repartitioned when it has non-empty grouping keys") { val inputSource = new BlockRDDBackedSource(spark) MockSourceProvider.withMockSources(inputSource) { withTempDir { tempDir => @@ -520,7 +553,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest } } - test("SPARK-22230: last should change with new batches") { + testWithAllStateVersions("SPARK-22230: last should change with new batches") { val input = MemoryStream[Int] val aggregated = input.toDF().agg(last('value)) @@ -536,7 +569,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("SPARK-23004: Ensure that TypedImperativeAggregate functions do not throw errors") { + 
testWithAllStateVersions("SPARK-23004: Ensure that TypedImperativeAggregate functions " + + "do not throw errors", SQLConf.SHUFFLE_PARTITIONS.key -> "1") { // See the JIRA SPARK-23004 for more details. In short, this test reproduces the error // by ensuring the following. // - A streaming query with a streaming aggregation. @@ -545,22 +579,72 @@ class StreamingAggregationSuite extends StateStoreMetricsTest // ObjectHashAggregateExec falls back to sort-based aggregation). This is done by having a // micro-batch with 128 records that shuffle to a single partition. // This test throws the exact error reported in SPARK-23004 without the corresponding fix. - withSQLConf("spark.sql.shuffle.partitions" -> "1") { - val input = MemoryStream[Int] - val df = input.toDF().toDF("value") - .selectExpr("value as group", "value") - .groupBy("group") - .agg(collect_list("value")) - testStream(df, outputMode = OutputMode.Update)( - AddData(input, (1 to spark.sqlContext.conf.objectAggSortBasedFallbackThreshold): _*), - AssertOnQuery { q => - q.processAllAvailable() - true + val input = MemoryStream[Int] + val df = input.toDF().toDF("value") + .selectExpr("value as group", "value") + .groupBy("group") + .agg(collect_list("value")) + testStream(df, outputMode = OutputMode.Update)( + AddData(input, (1 to spark.sqlContext.conf.objectAggSortBasedFallbackThreshold): _*), + AssertOnQuery { q => + q.processAllAvailable() + true + } + ) + } + + + test("simple count, update mode - recovery from checkpoint uses state format version 1") { + val inputData = MemoryStream[Int] + + val aggregated = + inputData.toDF() + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + val resourceUri = this.getClass.getResource( + "/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/").toURI + + val checkpointDir = Utils.createTempDir().getCanonicalFile + // Copy the checkpoint to a temp dir to prevent changes to the original. + // Not doing this will lead to the test passing on the first run, but fail subsequent runs. + FileUtils.copyDirectory(new File(resourceUri), checkpointDir) + + inputData.addData(3) + inputData.addData(3, 2) + + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath, + additionalConfs = Map(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "2")), + /* + Note: The checkpoint was generated using the following input in Spark version 2.3.1 + AddData(inputData, 3), + CheckLastBatch((3, 1)), + AddData(inputData, 3, 2), + CheckLastBatch((3, 2), (2, 1)) + */ + + AddData(inputData, 3, 2, 1), + CheckLastBatch((3, 3), (2, 2), (1, 1)), + + Execute { query => + // Verify state format = 1 + val stateVersions = query.lastExecution.executedPlan.collect { + case f: StateStoreSaveExec => f.stateFormatVersion + case f: StateStoreRestoreExec => f.stateFormatVersion } - ) - } + assert(stateVersions.size == 2) + assert(stateVersions.forall(_ == 1)) + }, + + // By default we run in new tuple mode. + AddData(inputData, 4, 4, 4, 4), + CheckLastBatch((4, 4)) + ) } + /** Add blocks of data to the `BlockRDDBackedSource`. */ case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData { override def addData(query: Option[StreamExecution]): (Source, Offset) = {