From 4252f41d86cefb974ed4da5f26ea4805a0fee2e4 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Sun, 8 Jul 2018 18:37:12 +0900 Subject: [PATCH 01/13] [SPARK-24763][SS] Remove redundant key data from value in streaming aggregation * add an option to enable the new feature: remove redundant key data from value * modify code to respect the new option (turning the feature on/off) * modify tests to run with the new option both on and off * add a guard in OffsetSeqMetadata to prevent modifying the option after the query has run --- .../apache/spark/sql/internal/SQLConf.scala | 13 +++ .../sql/execution/streaming/OffsetSeq.scala | 2 +- .../streaming/statefulOperators.scala | 107 ++++++++++++++++-- .../streaming/StreamingAggregationSuite.scala | 94 +++++++++------ 4 files changed, 174 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index fbb9a8cfae2e1..c2656f84f9b42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -871,6 +871,16 @@ object SQLConf { .intConf .createWithDefault(2) + val ADVANCED_REMOVE_REDUNDANT_IN_STATEFUL_AGGREGATION = + buildConf("spark.sql.streaming.advanced.removeRedundantInStatefulAggregation") + .internal() + .doc("ADVANCED: When true, stateful aggregation tries to remove redundant data " + + "between key and value in state. Enabling this option helps minimize state size, " + + "but the state is no longer compatible with state written with this option disabled. " + + "You can't change this option after starting the query.") + .booleanConf + .createWithDefault(false) + val UNSUPPORTED_OPERATION_CHECK_ENABLED = buildConf("spark.sql.streaming.unsupportedOperationCheck") .internal() @@ -1618,6 +1628,9 @@ class SQLConf extends Serializable with Logging { def advancedPartitionPredicatePushdownEnabled: Boolean = getConf(ADVANCED_PARTITION_PREDICATE_PUSHDOWN) + def advancedRemoveRedundantInStatefulAggregation: Boolean = + getConf(ADVANCED_REMOVE_REDUNDANT_IN_STATEFUL_AGGREGATION) + def fallBackToHdfsForStatsEnabled: Boolean = getConf(ENABLE_FALL_BACK_TO_HDFS_FOR_STATS) def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 9847756f22d4f..e1d94945b8f94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -89,7 +89,7 @@ object OffsetSeqMetadata extends Logging { private implicit val format = Serialization.formats(NoTypeHints) private val relevantSQLConfs = Seq( SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY, - FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) + FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, ADVANCED_REMOVE_REDUNDANT_IN_STATEFUL_AGGREGATION) /** * Default values of relevant configurations that are used for backward compatibility. 
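For orientation (not part of the patch): the new option is read when a streaming aggregation query is planned, and the OffsetSeqMetadata change above records it so it stays fixed across restarts. A minimal usage sketch, assuming an existing SparkSession named `spark` and using the built-in rate source purely for illustration:

// Set the option before the query is started for the first time; it cannot be changed afterwards.
spark.conf.set("spark.sql.streaming.advanced.removeRedundantInStatefulAggregation", "true")

// Any streaming aggregation started afterwards stores only the non-key columns as the state value.
val counts = spark.readStream
  .format("rate")
  .load()
  .groupBy("value")
  .count()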
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 6759fb42b4052..2bbefad7e83a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -26,7 +26,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner, Predicate} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ @@ -204,30 +204,64 @@ case class StateStoreRestoreExec( child: SparkPlan) extends UnaryExecNode with StateStoreReader { + val removeRedundant: Boolean = sqlContext.conf.advancedRemoveRedundantInStatefulAggregation + if (removeRedundant) { + log.info("Advanced option removeRedundantInStatefulAggregation activated!") + } + + val valueExpressions: Seq[Attribute] = if (removeRedundant) { + child.output.diff(keyExpressions) + } else { + child.output + } + val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions + val needToProjectToRestoreValue: Boolean = keyValueJoinedExpressions != child.output + override protected def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") child.execute().mapPartitionsWithStateStore( getStateInfo, keyExpressions.toStructType, - child.output.toStructType, + valueExpressions.toStructType, indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) + val joiner = GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions), + StructType.fromAttributes(valueExpressions)) + val restoreValueProject = GenerateUnsafeProjection.generate( + keyValueJoinedExpressions, child.output) + val hasInput = iter.hasNext if (!hasInput && keyExpressions.isEmpty) { // If our `keyExpressions` are empty, we're getting a global aggregation. In that case // the `HashAggregateExec` will output a 0 value for the partial merge. We need to // restore the value, so that we don't overwrite our state with a 0 value, but rather // merge the 0 with existing state. + // In this case the value should represent origin row, so no need to restore. 
store.iterator().map(_.value) } else { iter.flatMap { row => val key = getKey(row) val savedState = store.get(key) + val restoredRow = if (removeRedundant) { + if (savedState == null) { + savedState + } else { + val joinedRow = joiner.join(key, savedState) + if (needToProjectToRestoreValue) { + restoreValueProject(joinedRow) + } else { + joinedRow + } + } + } else { + savedState + } + numOutputRows += 1 - Option(savedState).toSeq :+ row + Option(restoredRow).toSeq :+ row } } } @@ -257,6 +291,19 @@ case class StateStoreSaveExec( child: SparkPlan) extends UnaryExecNode with StateStoreWriter with WatermarkSupport { + val removeRedundant: Boolean = sqlContext.conf.advancedRemoveRedundantInStatefulAggregation + if (removeRedundant) { + log.info("Advanced option removeRedundantInStatefulAggregation activated!") + } + + val valueExpressions: Seq[Attribute] = if (removeRedundant) { + child.output.diff(keyExpressions) + } else { + child.output + } + val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions + val needToProjectToRestoreValue: Boolean = keyValueJoinedExpressions != child.output + override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver assert(outputMode.nonEmpty, @@ -265,11 +312,17 @@ case class StateStoreSaveExec( child.execute().mapPartitionsWithStateStore( getStateInfo, keyExpressions.toStructType, - child.output.toStructType, + valueExpressions.toStructType, indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) + val getValue = GenerateUnsafeProjection.generate(valueExpressions, child.output) + val joiner = GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions), + StructType.fromAttributes(valueExpressions)) + val restoreValueProject = GenerateUnsafeProjection.generate( + keyValueJoinedExpressions, child.output) + val numOutputRows = longMetric("numOutputRows") val numUpdatedStateRows = longMetric("numUpdatedStateRows") val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") @@ -283,7 +336,12 @@ case class StateStoreSaveExec( while (iter.hasNext) { val row = iter.next().asInstanceOf[UnsafeRow] val key = getKey(row) - store.put(key, row) + val value = if (removeRedundant) { + getValue(row) + } else { + row + } + store.put(key, value) numUpdatedStateRows += 1 } } @@ -294,7 +352,18 @@ case class StateStoreSaveExec( setStoreMetrics(store) store.iterator().map { rowPair => numOutputRows += 1 - rowPair.value + + if (removeRedundant) { + val joinedRow = joiner.join(rowPair.key, rowPair.value) + if (needToProjectToRestoreValue) { + restoreValueProject(joinedRow) + } else { + joinedRow + } + } else { + rowPair.value + } + } // Update and output only rows being evicted from the StateStore @@ -305,7 +374,12 @@ case class StateStoreSaveExec( while (filteredIter.hasNext) { val row = filteredIter.next().asInstanceOf[UnsafeRow] val key = getKey(row) - store.put(key, row) + val value = if (removeRedundant) { + getValue(row) + } else { + row + } + store.put(key, value) numUpdatedStateRows += 1 } } @@ -320,7 +394,17 @@ case class StateStoreSaveExec( val rowPair = rangeIter.next() if (watermarkPredicateForKeys.get.eval(rowPair.key)) { store.remove(rowPair.key) - removedValueRow = rowPair.value + + if (removeRedundant) { + val joinedRow = joiner.join(rowPair.key, rowPair.value) + removedValueRow = if (needToProjectToRestoreValue) { + restoreValueProject(joinedRow) + } else { + joinedRow + } 
+ } else { + removedValueRow = rowPair.value + } } } if (removedValueRow == null) { @@ -353,7 +437,12 @@ case class StateStoreSaveExec( if (baseIterator.hasNext) { val row = baseIterator.next().asInstanceOf[UnsafeRow] val key = getKey(row) - store.put(key, row) + val value = if (removeRedundant) { + getValue(row) + } else { + row + } + store.put(key, value) numOutputRows += 1 numUpdatedStateRows += 1 row diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 382da13430781..d1d6a36ff219f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -53,7 +53,36 @@ class StreamingAggregationSuite extends StateStoreMetricsTest import testImplicits._ - test("simple count, update mode") { + def testWithAggrOptions(testName: String, pairs: (String, String)*)(testFun: => Any): Unit = { + val confAndTestNamePostfixMatrix = List( + (Seq("spark.sql.streaming.advanced.removeRedundantInStatefulAggregation" -> "false"), ""), + (Seq("spark.sql.streaming.advanced.removeRedundantInStatefulAggregation" -> "true"), + " : enable remove redundant in stateful aggregation") + ) + + confAndTestNamePostfixMatrix.foreach { + case (conf, testNamePostfix) => withSQLConf(pairs ++ conf: _*) { + test(testName + testNamePostfix)(testFun) + } + } + } + + def testQuietlyWithAggrOptions(testName: String, pairs: (String, String)*) + (testFun: => Any): Unit = { + val confAndTestNamePostfixMatrix = List( + (Seq("spark.sql.streaming.advanced.removeRedundantInStatefulAggregation" -> "false"), ""), + (Seq("spark.sql.streaming.advanced.removeRedundantInStatefulAggregation" -> "true"), + " : enable remove redundant in stateful aggregation") + ) + + confAndTestNamePostfixMatrix.foreach { + case (conf, testNamePostfix) => withSQLConf(pairs ++ conf: _*) { + testQuietly(testName + testNamePostfix)(testFun) + } + } + } + + testWithAggrOptions("simple count, update mode") { val inputData = MemoryStream[Int] val aggregated = @@ -77,7 +106,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("count distinct") { + testWithAggrOptions("count distinct") { val inputData = MemoryStream[(Int, Seq[Int])] val aggregated = @@ -93,7 +122,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("simple count, complete mode") { + testWithAggrOptions("simple count, complete mode") { val inputData = MemoryStream[Int] val aggregated = @@ -116,7 +145,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("simple count, append mode") { + testWithAggrOptions("simple count, append mode") { val inputData = MemoryStream[Int] val aggregated = @@ -133,7 +162,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest } } - test("sort after aggregate in complete mode") { + testWithAggrOptions("sort after aggregate in complete mode") { val inputData = MemoryStream[Int] val aggregated = @@ -158,7 +187,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("state metrics") { + testWithAggrOptions("state metrics") { val inputData = MemoryStream[Int] val aggregated = @@ -211,7 +240,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("multiple keys") { + testWithAggrOptions("multiple keys") { val inputData = MemoryStream[Int] val aggregated = @@ -228,7 +257,7 @@ class 
StreamingAggregationSuite extends StateStoreMetricsTest ) } - testQuietly("midbatch failure") { + testQuietlyWithAggrOptions("midbatch failure") { val inputData = MemoryStream[Int] FailureSingleton.firstTime = true val aggregated = @@ -254,7 +283,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("typed aggregators") { + testWithAggrOptions("typed aggregators") { val inputData = MemoryStream[(String, Int)] val aggregated = inputData.toDS().groupByKey(_._1).agg(typed.sumLong(_._2)) @@ -264,7 +293,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("prune results by current_time, complete mode") { + testWithAggrOptions("prune results by current_time, complete mode") { import testImplicits._ val clock = new StreamManualClock val inputData = MemoryStream[Long] @@ -316,7 +345,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("prune results by current_date, complete mode") { + testWithAggrOptions("prune results by current_date, complete mode") { import testImplicits._ val clock = new StreamManualClock val tz = TimeZone.getDefault.getID @@ -365,7 +394,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("SPARK-19690: do not convert batch aggregation in streaming query to streaming") { + testWithAggrOptions("SPARK-19690: do not convert batch aggregation in streaming query " + + "to streaming") { val streamInput = MemoryStream[Int] val batchDF = Seq(1, 2, 3, 4, 5) .toDF("value") @@ -429,7 +459,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest true } - test("SPARK-21977: coalesce(1) with 0 partition RDD should be repartitioned to 1") { + testWithAggrOptions("SPARK-21977: coalesce(1) with 0 partition RDD should be " + + "repartitioned to 1") { val inputSource = new BlockRDDBackedSource(spark) MockSourceProvider.withMockSources(inputSource) { // `coalesce(1)` changes the partitioning of data to `SinglePartition` which by default @@ -467,8 +498,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest } } - test("SPARK-21977: coalesce(1) with aggregation should still be repartitioned when it " + - "has non-empty grouping keys") { + testWithAggrOptions("SPARK-21977: coalesce(1) with aggregation should still be repartitioned " + + "when it has non-empty grouping keys") { val inputSource = new BlockRDDBackedSource(spark) MockSourceProvider.withMockSources(inputSource) { withTempDir { tempDir => @@ -520,7 +551,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest } } - test("SPARK-22230: last should change with new batches") { + testWithAggrOptions("SPARK-22230: last should change with new batches") { val input = MemoryStream[Int] val aggregated = input.toDF().agg(last('value)) @@ -536,7 +567,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - test("SPARK-23004: Ensure that TypedImperativeAggregate functions do not throw errors") { + testWithAggrOptions("SPARK-23004: Ensure that TypedImperativeAggregate functions " + + "do not throw errors", "spark.sql.shuffle.partitions" -> "1") { // See the JIRA SPARK-23004 for more details. In short, this test reproduces the error // by ensuring the following. // - A streaming query with a streaming aggregation. @@ -545,20 +577,18 @@ class StreamingAggregationSuite extends StateStoreMetricsTest // ObjectHashAggregateExec falls back to sort-based aggregation). This is done by having a // micro-batch with 128 records that shuffle to a single partition. 
// This test throws the exact error reported in SPARK-23004 without the corresponding fix. - withSQLConf("spark.sql.shuffle.partitions" -> "1") { - val input = MemoryStream[Int] - val df = input.toDF().toDF("value") - .selectExpr("value as group", "value") - .groupBy("group") - .agg(collect_list("value")) - testStream(df, outputMode = OutputMode.Update)( - AddData(input, (1 to spark.sqlContext.conf.objectAggSortBasedFallbackThreshold): _*), - AssertOnQuery { q => - q.processAllAvailable() - true - } - ) - } + val input = MemoryStream[Int] + val df = input.toDF().toDF("value") + .selectExpr("value as group", "value") + .groupBy("group") + .agg(collect_list("value")) + testStream(df, outputMode = OutputMode.Update)( + AddData(input, (1 to spark.sqlContext.conf.objectAggSortBasedFallbackThreshold): _*), + AssertOnQuery { q => + q.processAllAvailable() + true + } + ) } /** Add blocks of data to the `BlockRDDBackedSource`. */ @@ -602,4 +632,4 @@ class StreamingAggregationSuite extends StateStoreMetricsTest blockMgr.getMatchingBlockIds(_.isInstanceOf[TestBlockId]).foreach(blockMgr.removeBlock(_)) } } -} +} \ No newline at end of file From 941b88d8c96a368a8dfcc8dada3636abfaacfea9 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 9 Jul 2018 15:31:58 +0900 Subject: [PATCH 02/13] Fix scala checkstyle issue --- .../apache/spark/sql/streaming/StreamingAggregationSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index d1d6a36ff219f..03065383862e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -632,4 +632,4 @@ class StreamingAggregationSuite extends StateStoreMetricsTest blockMgr.getMatchingBlockIds(_.isInstanceOf[TestBlockId]).foreach(blockMgr.removeBlock(_)) } } -} \ No newline at end of file +} From abec57f331bbdad6ef4689e2790d4a9fbb989715 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Mon, 9 Jul 2018 16:03:25 +0900 Subject: [PATCH 03/13] Remove duplicating code, use configuration key instead of string literal --- .../streaming/StreamingAggregationSuite.scala | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 03065383862e8..f84ae1fef9518 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode._ import org.apache.spark.sql.streaming.util.{MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType @@ -53,13 +54,13 @@ class StreamingAggregationSuite extends StateStoreMetricsTest import testImplicits._ - def testWithAggrOptions(testName: String, pairs: (String, String)*)(testFun: => Any): Unit = { - val confAndTestNamePostfixMatrix = List( - 
(Seq("spark.sql.streaming.advanced.removeRedundantInStatefulAggregation" -> "false"), ""), - (Seq("spark.sql.streaming.advanced.removeRedundantInStatefulAggregation" -> "true"), - " : enable remove redundant in stateful aggregation") - ) + val confAndTestNamePostfixMatrix = List( + (Seq(SQLConf.ADVANCED_REMOVE_REDUNDANT_IN_STATEFUL_AGGREGATION.key -> "false"), ""), + (Seq(SQLConf.ADVANCED_REMOVE_REDUNDANT_IN_STATEFUL_AGGREGATION.key -> "true"), + " : enable remove redundant in stateful aggregation") + ) + def testWithAggrOptions(testName: String, pairs: (String, String)*)(testFun: => Any): Unit = { confAndTestNamePostfixMatrix.foreach { case (conf, testNamePostfix) => withSQLConf(pairs ++ conf: _*) { test(testName + testNamePostfix)(testFun) @@ -69,12 +70,6 @@ class StreamingAggregationSuite extends StateStoreMetricsTest def testQuietlyWithAggrOptions(testName: String, pairs: (String, String)*) (testFun: => Any): Unit = { - val confAndTestNamePostfixMatrix = List( - (Seq("spark.sql.streaming.advanced.removeRedundantInStatefulAggregation" -> "false"), ""), - (Seq("spark.sql.streaming.advanced.removeRedundantInStatefulAggregation" -> "true"), - " : enable remove redundant in stateful aggregation") - ) - confAndTestNamePostfixMatrix.foreach { case (conf, testNamePostfix) => withSQLConf(pairs ++ conf: _*) { testQuietly(testName + testNamePostfix)(testFun) @@ -568,7 +563,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest } testWithAggrOptions("SPARK-23004: Ensure that TypedImperativeAggregate functions " + - "do not throw errors", "spark.sql.shuffle.partitions" -> "1") { + "do not throw errors", SQLConf.SHUFFLE_PARTITIONS.key -> "1") { // See the JIRA SPARK-23004 for more details. In short, this test reproduces the error // by ensuring the following. // - A streaming query with a streaming aggregation. From 977428cb35a6fc0a9fa7a0ca1a51e39a94447a01 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 18 Jul 2018 16:17:16 +0900 Subject: [PATCH 04/13] Refine code change: introduce trait and classes to group duplicate methods --- .../streaming/StatefulOperatorsHelper.scala | 136 ++++++++++++++++++ .../streaming/statefulOperators.scala | 121 +++------------- 2 files changed, 152 insertions(+), 105 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulOperatorsHelper.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulOperatorsHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulOperatorsHelper.scala new file mode 100644 index 0000000000000..d9b1e71e530c8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulOperatorsHelper.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner} +import org.apache.spark.sql.execution.streaming.state.{StateStore, UnsafeRowPair} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType + +object StatefulOperatorsHelper { + sealed trait StreamingAggregationStateManager extends Serializable { + def extractKey(row: InternalRow): UnsafeRow + def getValueExpressions: Seq[Attribute] + def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow + def get(store: StateStore, key: UnsafeRow): UnsafeRow + def put(store: StateStore, row: UnsafeRow): Unit + } + + object StreamingAggregationStateManager extends Logging { + def newImpl( + keyExpressions: Seq[Attribute], + childOutput: Seq[Attribute], + conf: SQLConf): StreamingAggregationStateManager = { + + if (conf.advancedRemoveRedundantInStatefulAggregation) { + log.info("Advanced option removeRedundantInStatefulAggregation activated!") + new StreamingAggregationStateManagerImplV2(keyExpressions, childOutput) + } else { + new StreamingAggregationStateManagerImplV1(keyExpressions, childOutput) + } + } + } + + abstract class StreamingAggregationStateManagerBaseImpl( + protected val keyExpressions: Seq[Attribute], + protected val childOutput: Seq[Attribute]) extends StreamingAggregationStateManager { + + @transient protected lazy val keyProjector = + GenerateUnsafeProjection.generate(keyExpressions, childOutput) + + def extractKey(row: InternalRow): UnsafeRow = keyProjector(row) + } + + class StreamingAggregationStateManagerImplV1( + keyExpressions: Seq[Attribute], + childOutput: Seq[Attribute]) + extends StreamingAggregationStateManagerBaseImpl(keyExpressions, childOutput) { + + override def getValueExpressions: Seq[Attribute] = { + childOutput + } + + override def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow = { + rowPair.value + } + + override def get(store: StateStore, key: UnsafeRow): UnsafeRow = { + store.get(key) + } + + override def put(store: StateStore, row: UnsafeRow): Unit = { + store.put(extractKey(row), row) + } + } + + class StreamingAggregationStateManagerImplV2( + keyExpressions: Seq[Attribute], + childOutput: Seq[Attribute]) + extends StreamingAggregationStateManagerBaseImpl(keyExpressions, childOutput) { + + private val valueExpressions: Seq[Attribute] = childOutput.diff(keyExpressions) + private val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions + private val needToProjectToRestoreValue: Boolean = keyValueJoinedExpressions != childOutput + + @transient private lazy val valueProjector = + GenerateUnsafeProjection.generate(valueExpressions, childOutput) + + @transient private lazy val joiner = + GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions), + StructType.fromAttributes(valueExpressions)) + @transient private lazy val restoreValueProjector = GenerateUnsafeProjection.generate( + keyValueJoinedExpressions, childOutput) + + override def getValueExpressions: Seq[Attribute] = { + valueExpressions + } + + override def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow = { + val joinedRow = joiner.join(rowPair.key, rowPair.value) + if 
(needToProjectToRestoreValue) { + restoreValueProjector(joinedRow) + } else { + joinedRow + } + } + + override def get(store: StateStore, key: UnsafeRow): UnsafeRow = { + val savedState = store.get(key) + if (savedState == null) { + return savedState + } + + val joinedRow = joiner.join(key, savedState) + if (needToProjectToRestoreValue) { + restoreValueProjector(joinedRow) + } else { + joinedRow + } + } + + override def put(store: StateStore, row: UnsafeRow): Unit = { + val key = keyProjector(row) + val value = valueProjector(row) + store.put(key, value) + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 2bbefad7e83a5..847bb7e7e6386 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -20,18 +20,17 @@ package org.apache.spark.sql.execution.streaming import java.util.UUID import java.util.concurrent.TimeUnit._ -import scala.collection.JavaConverters._ - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner, Predicate} +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.streaming.StatefulOperatorsHelper.StreamingAggregationStateManager import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress} import org.apache.spark.sql.types._ @@ -204,35 +203,18 @@ case class StateStoreRestoreExec( child: SparkPlan) extends UnaryExecNode with StateStoreReader { - val removeRedundant: Boolean = sqlContext.conf.advancedRemoveRedundantInStatefulAggregation - if (removeRedundant) { - log.info("Advanced option removeRedundantInStatefulAggregation activated!") - } - - val valueExpressions: Seq[Attribute] = if (removeRedundant) { - child.output.diff(keyExpressions) - } else { - child.output - } - val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions - val needToProjectToRestoreValue: Boolean = keyValueJoinedExpressions != child.output - override protected def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") + val stateManager = StreamingAggregationStateManager.newImpl(keyExpressions, child.output, + sqlContext.conf) child.execute().mapPartitionsWithStateStore( getStateInfo, keyExpressions.toStructType, - valueExpressions.toStructType, + stateManager.getValueExpressions.toStructType, indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => - val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) - val joiner = GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions), - StructType.fromAttributes(valueExpressions)) - val 
restoreValueProject = GenerateUnsafeProjection.generate( - keyValueJoinedExpressions, child.output) - val hasInput = iter.hasNext if (!hasInput && keyExpressions.isEmpty) { // If our `keyExpressions` are empty, we're getting a global aggregation. In that case @@ -243,23 +225,8 @@ case class StateStoreRestoreExec( store.iterator().map(_.value) } else { iter.flatMap { row => - val key = getKey(row) - val savedState = store.get(key) - val restoredRow = if (removeRedundant) { - if (savedState == null) { - savedState - } else { - val joinedRow = joiner.join(key, savedState) - if (needToProjectToRestoreValue) { - restoreValueProject(joinedRow) - } else { - joinedRow - } - } - } else { - savedState - } - + val key = stateManager.extractKey(row) + val restoredRow = stateManager.get(store, key) numOutputRows += 1 Option(restoredRow).toSeq :+ row } @@ -291,38 +258,21 @@ case class StateStoreSaveExec( child: SparkPlan) extends UnaryExecNode with StateStoreWriter with WatermarkSupport { - val removeRedundant: Boolean = sqlContext.conf.advancedRemoveRedundantInStatefulAggregation - if (removeRedundant) { - log.info("Advanced option removeRedundantInStatefulAggregation activated!") - } - - val valueExpressions: Seq[Attribute] = if (removeRedundant) { - child.output.diff(keyExpressions) - } else { - child.output - } - val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions - val needToProjectToRestoreValue: Boolean = keyValueJoinedExpressions != child.output - override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver assert(outputMode.nonEmpty, "Incorrect planning in IncrementalExecution, outputMode has not been set") + val stateManager = StreamingAggregationStateManager.newImpl(keyExpressions, child.output, + sqlContext.conf) + child.execute().mapPartitionsWithStateStore( getStateInfo, keyExpressions.toStructType, - valueExpressions.toStructType, + stateManager.getValueExpressions.toStructType, indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => - val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) - val getValue = GenerateUnsafeProjection.generate(valueExpressions, child.output) - val joiner = GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions), - StructType.fromAttributes(valueExpressions)) - val restoreValueProject = GenerateUnsafeProjection.generate( - keyValueJoinedExpressions, child.output) - val numOutputRows = longMetric("numOutputRows") val numUpdatedStateRows = longMetric("numUpdatedStateRows") val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") @@ -335,13 +285,7 @@ case class StateStoreSaveExec( allUpdatesTimeMs += timeTakenMs { while (iter.hasNext) { val row = iter.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - val value = if (removeRedundant) { - getValue(row) - } else { - row - } - store.put(key, value) + stateManager.put(store, row) numUpdatedStateRows += 1 } } @@ -352,18 +296,7 @@ case class StateStoreSaveExec( setStoreMetrics(store) store.iterator().map { rowPair => numOutputRows += 1 - - if (removeRedundant) { - val joinedRow = joiner.join(rowPair.key, rowPair.value) - if (needToProjectToRestoreValue) { - restoreValueProject(joinedRow) - } else { - joinedRow - } - } else { - rowPair.value - } - + stateManager.restoreOriginRow(rowPair) } // Update and output only rows being evicted from the StateStore @@ -373,13 +306,7 @@ case class StateStoreSaveExec( val filteredIter = iter.filter(row => 
!watermarkPredicateForData.get.eval(row)) while (filteredIter.hasNext) { val row = filteredIter.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - val value = if (removeRedundant) { - getValue(row) - } else { - row - } - store.put(key, value) + stateManager.put(store, row) numUpdatedStateRows += 1 } } @@ -394,17 +321,7 @@ case class StateStoreSaveExec( val rowPair = rangeIter.next() if (watermarkPredicateForKeys.get.eval(rowPair.key)) { store.remove(rowPair.key) - - if (removeRedundant) { - val joinedRow = joiner.join(rowPair.key, rowPair.value) - removedValueRow = if (needToProjectToRestoreValue) { - restoreValueProject(joinedRow) - } else { - joinedRow - } - } else { - removedValueRow = rowPair.value - } + removedValueRow = stateManager.restoreOriginRow(rowPair) } } if (removedValueRow == null) { @@ -436,13 +353,7 @@ case class StateStoreSaveExec( override protected def getNext(): InternalRow = { if (baseIterator.hasNext) { val row = baseIterator.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - val value = if (removeRedundant) { - getValue(row) - } else { - row - } - store.put(key, value) + stateManager.put(store, row) numOutputRows += 1 numUpdatedStateRows += 1 row From 63dfb5d2c82dfdf0a9e681fd5608f72a11dc04ed Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 20 Jul 2018 13:51:33 +0900 Subject: [PATCH 05/13] Change the strategy: "add new option" -> "apply by default, but keep backward compatible" --- .../apache/spark/sql/internal/SQLConf.scala | 19 +++-- .../spark/sql/execution/SparkStrategies.scala | 3 + .../sql/execution/aggregate/AggUtils.scala | 5 +- .../streaming/IncrementalExecution.scala | 6 +- .../sql/execution/streaming/OffsetSeq.scala | 6 +- .../streaming/StatefulOperatorsHelper.scala | 19 ++--- .../streaming/statefulOperators.scala | 13 ++-- .../streaming/StreamingAggregationSuite.scala | 70 ++++++++++--------- 8 files changed, 78 insertions(+), 63 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index c2656f84f9b42..8bd383040a901 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -871,15 +871,15 @@ object SQLConf { .intConf .createWithDefault(2) - val ADVANCED_REMOVE_REDUNDANT_IN_STATEFUL_AGGREGATION = - buildConf("spark.sql.streaming.advanced.removeRedundantInStatefulAggregation") + val STREAMING_AGGREGATION_STATE_FORMAT_VERSION = + buildConf("spark.sql.streaming.streamingAggregation.stateFormatVersion") .internal() - .doc("ADVANCED: When true, stateful aggregation tries to remove redundant data " + - "between key and value in state. Enabling this option helps minimizing state size, " + - "but no longer be compatible with state with disabling this option." + - "You can't change this option after starting the query.") - .booleanConf - .createWithDefault(false) + .doc("State format version used by streaming aggregation operations triggered " + + "explicitly or implicitly via agg() in a streaming query. 
State between versions tends to be " + "incompatible, so the state format version shouldn't be modified after the query has run.") + .intConf + .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2") + .createWithDefault(2) val UNSUPPORTED_OPERATION_CHECK_ENABLED = buildConf("spark.sql.streaming.unsupportedOperationCheck") .internal() @@ -1628,9 +1628,6 @@ class SQLConf extends Serializable with Logging { def advancedPartitionPredicatePushdownEnabled: Boolean = getConf(ADVANCED_PARTITION_PREDICATE_PUSHDOWN) - def advancedRemoveRedundantInStatefulAggregation: Boolean = - getConf(ADVANCED_REMOVE_REDUNDANT_IN_STATEFUL_AGGREGATION) - def fallBackToHdfsForStatsEnabled: Boolean = getConf(ENABLE_FALL_BACK_TO_HDFS_FOR_STATS) def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 0c4ea857fd1d7..c5957a0726b4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -328,10 +328,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { "Streaming aggregation doesn't support group aggregate pandas UDF") } + val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION) + aggregate.AggUtils.planStreamingAggregation( namedGroupingExpressions, aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), rewrittenResultExpressions, + stateVersion, planLater(child)) case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index ebbdf1aaa024d..80e1f32b72226 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -256,6 +256,7 @@ object AggUtils { groupingExpressions: Seq[NamedExpression], functionsWithoutDistinct: Seq[AggregateExpression], resultExpressions: Seq[NamedExpression], + stateFormatVersion: Int, child: SparkPlan): Seq[SparkPlan] = { val groupingAttributes = groupingExpressions.map(_.toAttribute) @@ -287,7 +288,8 @@ object AggUtils { child = partialAggregate) } - val restored = StateStoreRestoreExec(groupingAttributes, None, partialMerged1) + val restored = StateStoreRestoreExec(groupingAttributes, None, stateFormatVersion, + partialMerged1) val partialMerged2: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) @@ -311,6 +313,7 @@ object AggUtils { stateInfo = None, outputMode = None, eventTimeWatermark = None, + stateFormatVersion = stateFormatVersion, partialMerged2) val finalAndCompleteAggregate: SparkPlan = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 6ae7f2869b0f3..ab788c5b8bcc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -100,19 +100,21 @@ class IncrementalExecution( val state = new Rule[SparkPlan] { override def apply(plan: SparkPlan): SparkPlan = plan transform { - case StateStoreSaveExec(keys, None, None, None, + case 
StateStoreSaveExec(keys, None, None, None, stateFormatVersion, UnaryExecNode(agg, - StateStoreRestoreExec(_, None, child))) => + StateStoreRestoreExec(_, None, _, child))) => val aggStateInfo = nextStatefulOperationStateInfo StateStoreSaveExec( keys, Some(aggStateInfo), Some(outputMode), Some(offsetSeqMetadata.batchWatermarkMs), + stateFormatVersion, agg.withNewChildren( StateStoreRestoreExec( keys, Some(aggStateInfo), + stateFormatVersion, child) :: Nil)) case StreamingDeduplicateExec(keys, child, None, None) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index e1d94945b8f94..816e388aceba7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -89,7 +89,7 @@ object OffsetSeqMetadata extends Logging { private implicit val format = Serialization.formats(NoTypeHints) private val relevantSQLConfs = Seq( SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY, - FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, ADVANCED_REMOVE_REDUNDANT_IN_STATEFUL_AGGREGATION) + FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, STREAMING_AGGREGATION_STATE_FORMAT_VERSION) /** * Default values of relevant configurations that are used for backward compatibility. @@ -104,7 +104,9 @@ object OffsetSeqMetadata extends Logging { private val relevantSQLConfDefaultValues = Map[String, String]( STREAMING_MULTIPLE_WATERMARK_POLICY.key -> MultipleWatermarkPolicy.DEFAULT_POLICY_NAME, FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> - FlatMapGroupsWithStateExecHelper.legacyVersion.toString + FlatMapGroupsWithStateExecHelper.legacyVersion.toString, + STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> + StatefulOperatorsHelper.legacyVersion.toString ) def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulOperatorsHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulOperatorsHelper.scala index d9b1e71e530c8..658f5b65531f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulOperatorsHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulOperatorsHelper.scala @@ -22,10 +22,13 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner} import org.apache.spark.sql.execution.streaming.state.{StateStore, UnsafeRowPair} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType object StatefulOperatorsHelper { + + val supportedVersions = Seq(1, 2) + val legacyVersion = 1 + sealed trait StreamingAggregationStateManager extends Serializable { def extractKey(row: InternalRow): UnsafeRow def getValueExpressions: Seq[Attribute] @@ -35,16 +38,14 @@ object StatefulOperatorsHelper { } object StreamingAggregationStateManager extends Logging { - def newImpl( + def createStateManager( keyExpressions: Seq[Attribute], childOutput: Seq[Attribute], - conf: SQLConf): StreamingAggregationStateManager = { - - if (conf.advancedRemoveRedundantInStatefulAggregation) { - log.info("Advanced option removeRedundantInStatefulAggregation activated!") - new 
StreamingAggregationStateManagerImplV2(keyExpressions, childOutput) - } else { - new StreamingAggregationStateManagerImplV1(keyExpressions, childOutput) + stateFormatVersion: Int): StreamingAggregationStateManager = { + stateFormatVersion match { + case 1 => new StreamingAggregationStateManagerImplV1(keyExpressions, childOutput) + case 2 => new StreamingAggregationStateManagerImplV2(keyExpressions, childOutput) + case _ => throw new IllegalArgumentException(s"Version $stateFormatVersion is invalid") } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 847bb7e7e6386..048b4a64a84fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -200,13 +200,15 @@ object WatermarkSupport { case class StateStoreRestoreExec( keyExpressions: Seq[Attribute], stateInfo: Option[StatefulOperatorStateInfo], + stateFormatVersion: Int, child: SparkPlan) extends UnaryExecNode with StateStoreReader { + private[sql] val stateManager = StreamingAggregationStateManager.createStateManager( + keyExpressions, child.output, stateFormatVersion) + override protected def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - val stateManager = StreamingAggregationStateManager.newImpl(keyExpressions, child.output, - sqlContext.conf) child.execute().mapPartitionsWithStateStore( getStateInfo, @@ -255,17 +257,18 @@ case class StateStoreSaveExec( stateInfo: Option[StatefulOperatorStateInfo] = None, outputMode: Option[OutputMode] = None, eventTimeWatermark: Option[Long] = None, + stateFormatVersion: Int, child: SparkPlan) extends UnaryExecNode with StateStoreWriter with WatermarkSupport { + private[sql] val stateManager = StreamingAggregationStateManager.createStateManager( + keyExpressions, child.output, stateFormatVersion) + override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver assert(outputMode.nonEmpty, "Incorrect planning in IncrementalExecution, outputMode has not been set") - val stateManager = StreamingAggregationStateManager.newImpl(keyExpressions, child.output, - sqlContext.conf) - child.execute().mapPartitionsWithStateStore( getStateInfo, keyExpressions.toStructType, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index f84ae1fef9518..68b20846db3a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.streaming import java.util.{Locale, TimeZone} -import org.scalatest.Assertions -import org.scalatest.BeforeAndAfterAll +import org.scalatest.{Assertions, BeforeAndAfterAll} import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.rdd.BlockRDD @@ -54,30 +53,35 @@ class StreamingAggregationSuite extends StateStoreMetricsTest import testImplicits._ - val confAndTestNamePostfixMatrix = List( - (Seq(SQLConf.ADVANCED_REMOVE_REDUNDANT_IN_STATEFUL_AGGREGATION.key -> "false"), ""), - (Seq(SQLConf.ADVANCED_REMOVE_REDUNDANT_IN_STATEFUL_AGGREGATION.key -> "true"), - " : enable remove redundant in stateful aggregation") - ) + def 
executeFuncWithStateVersionSQLConf( + stateVersion: Int, + confPairs: Seq[(String, String)], + func: => Any): Unit = { + withSQLConf(confPairs ++ + Seq(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> stateVersion.toString): _*) { + func + } + } - def testWithAggrOptions(testName: String, pairs: (String, String)*)(testFun: => Any): Unit = { - confAndTestNamePostfixMatrix.foreach { - case (conf, testNamePostfix) => withSQLConf(pairs ++ conf: _*) { - test(testName + testNamePostfix)(testFun) + def testWithAllStateVersions(name: String, confPairs: (String, String)*) + (func: => Any): Unit = { + for (version <- StatefulOperatorsHelper.supportedVersions) { + test(s"$name - state format version $version") { + executeFuncWithStateVersionSQLConf(version, confPairs, func) } } } - def testQuietlyWithAggrOptions(testName: String, pairs: (String, String)*) - (testFun: => Any): Unit = { - confAndTestNamePostfixMatrix.foreach { - case (conf, testNamePostfix) => withSQLConf(pairs ++ conf: _*) { - testQuietly(testName + testNamePostfix)(testFun) + def testQuietlyWithAllStateVersions(name: String, confPairs: (String, String)*) + (func: => Any): Unit = { + for (version <- StatefulOperatorsHelper.supportedVersions) { + testQuietly(s"$name - state format version $version") { + executeFuncWithStateVersionSQLConf(version, confPairs, func) } } } - testWithAggrOptions("simple count, update mode") { + testWithAllStateVersions("simple count, update mode") { val inputData = MemoryStream[Int] val aggregated = @@ -101,7 +105,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - testWithAggrOptions("count distinct") { + testWithAllStateVersions("count distinct") { val inputData = MemoryStream[(Int, Seq[Int])] val aggregated = @@ -117,7 +121,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - testWithAggrOptions("simple count, complete mode") { + testWithAllStateVersions("simple count, complete mode") { val inputData = MemoryStream[Int] val aggregated = @@ -140,7 +144,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - testWithAggrOptions("simple count, append mode") { + testWithAllStateVersions("simple count, append mode") { val inputData = MemoryStream[Int] val aggregated = @@ -157,7 +161,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest } } - testWithAggrOptions("sort after aggregate in complete mode") { + testWithAllStateVersions("sort after aggregate in complete mode") { val inputData = MemoryStream[Int] val aggregated = @@ -182,7 +186,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - testWithAggrOptions("state metrics") { + testWithAllStateVersions("state metrics") { val inputData = MemoryStream[Int] val aggregated = @@ -235,7 +239,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - testWithAggrOptions("multiple keys") { + testWithAllStateVersions("multiple keys") { val inputData = MemoryStream[Int] val aggregated = @@ -252,7 +256,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - testQuietlyWithAggrOptions("midbatch failure") { + testQuietlyWithAllStateVersions("midbatch failure") { val inputData = MemoryStream[Int] FailureSingleton.firstTime = true val aggregated = @@ -278,7 +282,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - testWithAggrOptions("typed aggregators") { + testWithAllStateVersions("typed aggregators") { val inputData = MemoryStream[(String, Int)] val aggregated = inputData.toDS().groupByKey(_._1).agg(typed.sumLong(_._2)) 
@@ -288,7 +292,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - testWithAggrOptions("prune results by current_time, complete mode") { + testWithAllStateVersions("prune results by current_time, complete mode") { import testImplicits._ val clock = new StreamManualClock val inputData = MemoryStream[Long] @@ -340,7 +344,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - testWithAggrOptions("prune results by current_date, complete mode") { + testWithAllStateVersions("prune results by current_date, complete mode") { import testImplicits._ val clock = new StreamManualClock val tz = TimeZone.getDefault.getID @@ -389,7 +393,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - testWithAggrOptions("SPARK-19690: do not convert batch aggregation in streaming query " + + testWithAllStateVersions("SPARK-19690: do not convert batch aggregation in streaming query " + "to streaming") { val streamInput = MemoryStream[Int] val batchDF = Seq(1, 2, 3, 4, 5) @@ -454,7 +458,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest true } - testWithAggrOptions("SPARK-21977: coalesce(1) with 0 partition RDD should be " + + testWithAllStateVersions("SPARK-21977: coalesce(1) with 0 partition RDD should be " + "repartitioned to 1") { val inputSource = new BlockRDDBackedSource(spark) MockSourceProvider.withMockSources(inputSource) { @@ -493,8 +497,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest } } - testWithAggrOptions("SPARK-21977: coalesce(1) with aggregation should still be repartitioned " + - "when it has non-empty grouping keys") { + testWithAllStateVersions("SPARK-21977: coalesce(1) with aggregation should still be " + + "repartitioned when it has non-empty grouping keys") { val inputSource = new BlockRDDBackedSource(spark) MockSourceProvider.withMockSources(inputSource) { withTempDir { tempDir => @@ -546,7 +550,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest } } - testWithAggrOptions("SPARK-22230: last should change with new batches") { + testWithAllStateVersions("SPARK-22230: last should change with new batches") { val input = MemoryStream[Int] val aggregated = input.toDF().agg(last('value)) @@ -562,7 +566,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest ) } - testWithAggrOptions("SPARK-23004: Ensure that TypedImperativeAggregate functions " + + testWithAllStateVersions("SPARK-23004: Ensure that TypedImperativeAggregate functions " + "do not throw errors", SQLConf.SHUFFLE_PARTITIONS.key -> "1") { // See the JIRA SPARK-23004 for more details. In short, this test reproduces the error // by ensuring the following. 
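Before the tests introduced in the next patch, a conceptual sketch of how the two state format versions differ, using the StreamingAggregationStateManager factory added earlier in this series. This is illustrative only: `keyAttrs` and `outputAttrs` stand for the grouping attributes and the child operator's output attributes and are assumed to be defined elsewhere.

// assumes: import org.apache.spark.sql.execution.streaming.StatefulOperatorsHelper.StreamingAggregationStateManager
val v1 = StreamingAggregationStateManager.createStateManager(keyAttrs, outputAttrs, stateFormatVersion = 1)
val v2 = StreamingAggregationStateManager.createStateManager(keyAttrs, outputAttrs, stateFormatVersion = 2)
// Version 1 stores the whole output row as the state value, so the key columns are duplicated in the value.
assert(v1.getValueExpressions == outputAttrs)
// Version 2 stores only the non-key columns and joins the key back on read, which is what shrinks the state.
assert(v2.getValueExpressions == outputAttrs.diff(keyAttrs))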
From e84463607bc86403c97ebf9b155b05da86a7aa73 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 20 Jul 2018 15:55:19 +0900 Subject: [PATCH 06/13] Add tests for StatefulOperatorsHelper as well --- .../streaming/state/MemoryStateStore.scala | 53 ++++++++ .../state/StatefulOperatorsHelperSuite.scala | 121 ++++++++++++++++++ .../FlatMapGroupsWithStateSuite.scala | 24 +--- 3 files changed, 175 insertions(+), 23 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulOperatorsHelperSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala new file mode 100644 index 0000000000000..879e1228c47ff --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +class MemoryStateStore extends StateStore() { + import scala.collection.JavaConverters._ + private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow] + + override def iterator(): Iterator[UnsafeRowPair] = { + map.entrySet.iterator.asScala.map { case e => new UnsafeRowPair(e.getKey, e.getValue) } + } + + override def get(key: UnsafeRow): UnsafeRow = map.get(key) + + override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = { + map.put(key.copy(), newValue.copy()) + } + + override def remove(key: UnsafeRow): Unit = { + map.remove(key) + } + + override def commit(): Long = version + 1 + + override def abort(): Unit = {} + + override def id: StateStoreId = null + + override def version: Long = 0 + + override def metrics: StateStoreMetrics = new StateStoreMetrics(map.size, 0, Map.empty) + + override def hasCommitted: Boolean = true +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulOperatorsHelperSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulOperatorsHelperSuite.scala new file mode 100644 index 0000000000000..6c42c3ce27d7d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulOperatorsHelperSuite.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.sql.catalyst.expressions.{Attribute, SpecificInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution.streaming.StatefulOperatorsHelper.StreamingAggregationStateManager +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + +class StatefulOperatorsHelperSuite extends StreamTest { + import TestMaterial._ + + test("StateManager v1 - get, put, iter") { + val stateManager = newStateManager(KEYS_ATTRIBUTES, OUTPUT_ATTRIBUTES, 1) + + // in V1, input row is stored as value + testGetPutIterOnStateManager(stateManager, OUTPUT_ATTRIBUTES, TEST_ROW, TEST_KEY_ROW, TEST_ROW) + } + + // ============================ StateManagerImplV2 ============================ + test("StateManager v2 - get, put, iter") { + val stateManager = newStateManager(KEYS_ATTRIBUTES, OUTPUT_ATTRIBUTES, 2) + + // in V2, row for values itself (excluding keys from input row) is stored as value + // so that stored value doesn't have key part, but state manager V2 will provide same output + // as V1 when getting row for key + testGetPutIterOnStateManager(stateManager, VALUES_ATTRIBUTES, TEST_ROW, TEST_KEY_ROW, + TEST_VALUE_ROW) + } + + private def newStateManager( + keysAttributes: Seq[Attribute], + outputAttributes: Seq[Attribute], + version: Int): StreamingAggregationStateManager = { + StreamingAggregationStateManager.createStateManager(keysAttributes, outputAttributes, version) + } + + private def testGetPutIterOnStateManager( + stateManager: StreamingAggregationStateManager, + expectedValueExpressions: Seq[Attribute], + inputRow: UnsafeRow, + expectedStateKey: UnsafeRow, + expectedStateValue: UnsafeRow): Unit = { + + assert(stateManager.getValueExpressions === expectedValueExpressions) + + val memoryStateStore = new MemoryStateStore() + stateManager.put(memoryStateStore, inputRow) + + assert(memoryStateStore.iterator().size === 1) + + val keyRow = stateManager.extractKey(inputRow) + assert(keyRow === expectedStateKey) + + // iterate state store and verify whether expected format of key and value are stored + val pair = memoryStateStore.iterator().next() + assert(pair.key === keyRow) + assert(pair.value === expectedStateValue) + assert(stateManager.restoreOriginRow(pair) === inputRow) + + // verify the stored value once again via get + assert(memoryStateStore.get(keyRow) === expectedStateValue) + + // state manager should return row which is same as input row regardless of format version + assert(inputRow === stateManager.get(memoryStateStore, keyRow)) + } + +} + +object TestMaterial { + val KEYS: Seq[String] = Seq("key1", "key2") + val VALUES: Seq[String] = Seq("sum(key1)", "sum(key2)") + + val OUTPUT_SCHEMA: StructType = StructType( + KEYS.map(createIntegerField) ++ 
VALUES.map(createIntegerField)) + + val OUTPUT_ATTRIBUTES: Seq[Attribute] = OUTPUT_SCHEMA.toAttributes + val KEYS_ATTRIBUTES: Seq[Attribute] = OUTPUT_ATTRIBUTES.filter { p => + KEYS.contains(p.name) + } + val VALUES_ATTRIBUTES: Seq[Attribute] = OUTPUT_ATTRIBUTES.filter { p => + VALUES.contains(p.name) + } + + val TEST_ROW: UnsafeRow = { + val unsafeRowProjection = UnsafeProjection.create(OUTPUT_SCHEMA) + val row = unsafeRowProjection(new SpecificInternalRow(OUTPUT_SCHEMA)) + (KEYS ++ VALUES).zipWithIndex.foreach { case (_, index) => row.setInt(index, index) } + row + } + + val TEST_KEY_ROW: UnsafeRow = { + val keyProjector = GenerateUnsafeProjection.generate(KEYS_ATTRIBUTES, OUTPUT_ATTRIBUTES) + keyProjector(TEST_ROW) + } + + val TEST_VALUE_ROW: UnsafeRow = { + val valueProjector = GenerateUnsafeProjection.generate(VALUES_ATTRIBUTES, OUTPUT_ATTRIBUTES) + valueProjector(TEST_ROW) + } + + private def createIntegerField(name: String): StructField = { + StructField(name, IntegerType, nullable = false) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 82d7755aef5f0..76511ae2c8362 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.streaming import java.io.File import java.sql.Date -import java.util.concurrent.ConcurrentHashMap import org.apache.commons.io.FileUtils import org.scalatest.BeforeAndAfterAll @@ -34,7 +33,7 @@ import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.RDDScanExec import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair} +import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, MemoryStateStore, StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.types.{DataType, IntegerType} @@ -1286,27 +1285,6 @@ object FlatMapGroupsWithStateSuite { var failInTask = true - class MemoryStateStore extends StateStore() { - import scala.collection.JavaConverters._ - private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow] - - override def iterator(): Iterator[UnsafeRowPair] = { - map.entrySet.iterator.asScala.map { case e => new UnsafeRowPair(e.getKey, e.getValue) } - } - - override def get(key: UnsafeRow): UnsafeRow = map.get(key) - override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = { - map.put(key.copy(), newValue.copy()) - } - override def remove(key: UnsafeRow): Unit = { map.remove(key) } - override def commit(): Long = version + 1 - override def abort(): Unit = { } - override def id: StateStoreId = null - override def version: Long = 0 - override def metrics: StateStoreMetrics = new StateStoreMetrics(map.size, 0, Map.empty) - override def hasCommitted: Boolean = true - } - def assertCanGetProcessingTime(predicate: => Boolean): Unit = { if (!predicate) throw new TestFailedException("Could not get processing time", 20) } From 26701a3d35018e1af1574160c2f3091441cee1ce Mon Sep 17 00:00:00 2001 From: 
Jungtaek Lim Date: Wed, 1 Aug 2018 18:09:03 +0900 Subject: [PATCH 07/13] WIP Address a part of review comments from @tdas * TODO list * add docs * move StreamingAggregationStateManager to execution.streaming.state package * add iterator / remove in StreamingAggregationStateManager to remove restoreOriginRow * replace all the usages for direct call of store.xxx whenever state manager is available --- .../apache/spark/sql/internal/SQLConf.scala | 8 +- .../streaming/StatefulOperatorsHelper.scala | 43 ++++---- .../streaming/statefulOperators.scala | 6 +- .../streaming/state/MemoryStateStore.scala | 8 +- .../state/StatefulOperatorsHelperSuite.scala | 99 ++++++++++--------- 5 files changed, 80 insertions(+), 84 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 8bd383040a901..f70c5b9797716 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -872,11 +872,11 @@ object SQLConf { .createWithDefault(2) val STREAMING_AGGREGATION_STATE_FORMAT_VERSION = - buildConf("spark.sql.streaming.streamingAggregation.stateFormatVersion") + buildConf("spark.sql.streaming.aggregation.stateFormatVersion") .internal() - .doc("State format version used by streaming aggregation operations triggered " + - "explicitly or implicitly via agg() in a streaming query. State between versions are " + - "tend to be incompatible, so state format version shouldn't be modified after running.") + .doc("State format version used by streaming aggregation operations in a streaming query. " + + "State between versions are tend to be incompatible, so state format version shouldn't " + + "be modified after running.") .intConf .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2") .createWithDefault(2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulOperatorsHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulOperatorsHelper.scala index 658f5b65531f6..7830e201e0709 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulOperatorsHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulOperatorsHelper.scala @@ -30,8 +30,8 @@ object StatefulOperatorsHelper { val legacyVersion = 1 sealed trait StreamingAggregationStateManager extends Serializable { - def extractKey(row: InternalRow): UnsafeRow - def getValueExpressions: Seq[Attribute] + def getKey(row: InternalRow): UnsafeRow + def getStateValueSchema: StructType def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow def get(store: StateStore, key: UnsafeRow): UnsafeRow def put(store: StateStore, row: UnsafeRow): Unit @@ -40,11 +40,11 @@ object StatefulOperatorsHelper { object StreamingAggregationStateManager extends Logging { def createStateManager( keyExpressions: Seq[Attribute], - childOutput: Seq[Attribute], + inputRowAttributes: Seq[Attribute], stateFormatVersion: Int): StreamingAggregationStateManager = { stateFormatVersion match { - case 1 => new StreamingAggregationStateManagerImplV1(keyExpressions, childOutput) - case 2 => new StreamingAggregationStateManagerImplV2(keyExpressions, childOutput) + case 1 => new StreamingAggregationStateManagerImplV1(keyExpressions, inputRowAttributes) + case 2 => new StreamingAggregationStateManagerImplV2(keyExpressions, inputRowAttributes) case _ => throw new 
IllegalArgumentException(s"Version $stateFormatVersion is invalid") } } @@ -52,22 +52,20 @@ object StatefulOperatorsHelper { abstract class StreamingAggregationStateManagerBaseImpl( protected val keyExpressions: Seq[Attribute], - protected val childOutput: Seq[Attribute]) extends StreamingAggregationStateManager { + protected val inputRowAttributes: Seq[Attribute]) extends StreamingAggregationStateManager { @transient protected lazy val keyProjector = - GenerateUnsafeProjection.generate(keyExpressions, childOutput) + GenerateUnsafeProjection.generate(keyExpressions, inputRowAttributes) - def extractKey(row: InternalRow): UnsafeRow = keyProjector(row) + def getKey(row: InternalRow): UnsafeRow = keyProjector(row) } class StreamingAggregationStateManagerImplV1( keyExpressions: Seq[Attribute], - childOutput: Seq[Attribute]) - extends StreamingAggregationStateManagerBaseImpl(keyExpressions, childOutput) { + inputRowAttributes: Seq[Attribute]) + extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) { - override def getValueExpressions: Seq[Attribute] = { - childOutput - } + override def getStateValueSchema: StructType = inputRowAttributes.toStructType override def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow = { rowPair.value @@ -78,31 +76,30 @@ object StatefulOperatorsHelper { } override def put(store: StateStore, row: UnsafeRow): Unit = { - store.put(extractKey(row), row) + store.put(getKey(row), row) } } class StreamingAggregationStateManagerImplV2( keyExpressions: Seq[Attribute], - childOutput: Seq[Attribute]) - extends StreamingAggregationStateManagerBaseImpl(keyExpressions, childOutput) { + inputRowAttributes: Seq[Attribute]) + extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) { - private val valueExpressions: Seq[Attribute] = childOutput.diff(keyExpressions) + private val valueExpressions: Seq[Attribute] = inputRowAttributes.diff(keyExpressions) private val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions - private val needToProjectToRestoreValue: Boolean = keyValueJoinedExpressions != childOutput + private val needToProjectToRestoreValue: Boolean = + keyValueJoinedExpressions != inputRowAttributes @transient private lazy val valueProjector = - GenerateUnsafeProjection.generate(valueExpressions, childOutput) + GenerateUnsafeProjection.generate(valueExpressions, inputRowAttributes) @transient private lazy val joiner = GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions), StructType.fromAttributes(valueExpressions)) @transient private lazy val restoreValueProjector = GenerateUnsafeProjection.generate( - keyValueJoinedExpressions, childOutput) + keyValueJoinedExpressions, inputRowAttributes) - override def getValueExpressions: Seq[Attribute] = { - valueExpressions - } + override def getStateValueSchema: StructType = valueExpressions.toStructType override def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow = { val joinedRow = joiner.join(rowPair.key, rowPair.value) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 048b4a64a84fd..217ced0cc11ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -213,7 +213,7 @@ case class StateStoreRestoreExec( 
child.execute().mapPartitionsWithStateStore( getStateInfo, keyExpressions.toStructType, - stateManager.getValueExpressions.toStructType, + stateManager.getStateValueSchema, indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => @@ -227,7 +227,7 @@ case class StateStoreRestoreExec( store.iterator().map(_.value) } else { iter.flatMap { row => - val key = stateManager.extractKey(row) + val key = stateManager.getKey(row) val restoredRow = stateManager.get(store, key) numOutputRows += 1 Option(restoredRow).toSeq :+ row @@ -272,7 +272,7 @@ case class StateStoreSaveExec( child.execute().mapPartitionsWithStateStore( getStateInfo, keyExpressions.toStructType, - stateManager.getValueExpressions.toStructType, + stateManager.getStateValueSchema, indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala index 879e1228c47ff..98586d6492c9e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala @@ -31,13 +31,9 @@ class MemoryStateStore extends StateStore() { override def get(key: UnsafeRow): UnsafeRow = map.get(key) - override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = { - map.put(key.copy(), newValue.copy()) - } + override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = map.put(key.copy(), newValue.copy()) - override def remove(key: UnsafeRow): Unit = { - map.remove(key) - } + override def remove(key: UnsafeRow): Unit = map.remove(key) override def commit(): Long = version + 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulOperatorsHelperSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulOperatorsHelperSuite.scala index 6c42c3ce27d7d..2a8ef514c5779 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulOperatorsHelperSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulOperatorsHelperSuite.scala @@ -24,48 +24,88 @@ import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class StatefulOperatorsHelperSuite extends StreamTest { - import TestMaterial._ + // ============================ fields and method for test data ============================ + + val testKeys: Seq[String] = Seq("key1", "key2") + val testValues: Seq[String] = Seq("sum(key1)", "sum(key2)") + + val testOutputSchema: StructType = StructType( + testKeys.map(createIntegerField) ++ testValues.map(createIntegerField)) + + val testOutputAttributes: Seq[Attribute] = testOutputSchema.toAttributes + val testKeyAttributes: Seq[Attribute] = testOutputAttributes.filter { p => + testKeys.contains(p.name) + } + val testValuesAttributes: Seq[Attribute] = testOutputAttributes.filter { p => + testValues.contains(p.name) + } + val expectedTestValuesSchema: StructType = testValuesAttributes.toStructType + + val testRow: UnsafeRow = { + val unsafeRowProjection = UnsafeProjection.create(testOutputSchema) + val row = unsafeRowProjection(new SpecificInternalRow(testOutputSchema)) + (testKeys ++ testValues).zipWithIndex.foreach { case (_, index) => 
row.setInt(index, index) } + row + } + + val expectedTestKeyRow: UnsafeRow = { + val keyProjector = GenerateUnsafeProjection.generate(testKeyAttributes, testOutputAttributes) + keyProjector(testRow) + } + + val expectedTestValueRowForV2: UnsafeRow = { + val valueProjector = GenerateUnsafeProjection.generate(testValuesAttributes, + testOutputAttributes) + valueProjector(testRow) + } + + private def createIntegerField(name: String): StructField = { + StructField(name, IntegerType, nullable = false) + } + + // ============================ StateManagerImplV1 ============================ test("StateManager v1 - get, put, iter") { - val stateManager = newStateManager(KEYS_ATTRIBUTES, OUTPUT_ATTRIBUTES, 1) + val stateManager = newStateManager(testKeyAttributes, testOutputAttributes, 1) // in V1, input row is stored as value - testGetPutIterOnStateManager(stateManager, OUTPUT_ATTRIBUTES, TEST_ROW, TEST_KEY_ROW, TEST_ROW) + testGetPutIterOnStateManager(stateManager, testOutputSchema, testRow, + expectedTestKeyRow, testRow) } // ============================ StateManagerImplV2 ============================ test("StateManager v2 - get, put, iter") { - val stateManager = newStateManager(KEYS_ATTRIBUTES, OUTPUT_ATTRIBUTES, 2) + val stateManager = newStateManager(testKeyAttributes, testOutputAttributes, 2) // in V2, row for values itself (excluding keys from input row) is stored as value // so that stored value doesn't have key part, but state manager V2 will provide same output // as V1 when getting row for key - testGetPutIterOnStateManager(stateManager, VALUES_ATTRIBUTES, TEST_ROW, TEST_KEY_ROW, - TEST_VALUE_ROW) + testGetPutIterOnStateManager(stateManager, expectedTestValuesSchema, testRow, + expectedTestKeyRow, expectedTestValueRowForV2) } private def newStateManager( keysAttributes: Seq[Attribute], - outputAttributes: Seq[Attribute], + inputRowAttributes: Seq[Attribute], version: Int): StreamingAggregationStateManager = { - StreamingAggregationStateManager.createStateManager(keysAttributes, outputAttributes, version) + StreamingAggregationStateManager.createStateManager(keysAttributes, inputRowAttributes, version) } private def testGetPutIterOnStateManager( stateManager: StreamingAggregationStateManager, - expectedValueExpressions: Seq[Attribute], + expectedValueSchema: StructType, inputRow: UnsafeRow, expectedStateKey: UnsafeRow, expectedStateValue: UnsafeRow): Unit = { - assert(stateManager.getValueExpressions === expectedValueExpressions) + assert(stateManager.getStateValueSchema === expectedValueSchema) val memoryStateStore = new MemoryStateStore() stateManager.put(memoryStateStore, inputRow) assert(memoryStateStore.iterator().size === 1) - val keyRow = stateManager.extractKey(inputRow) + val keyRow = stateManager.getKey(inputRow) assert(keyRow === expectedStateKey) // iterate state store and verify whether expected format of key and value are stored @@ -82,40 +122,3 @@ class StatefulOperatorsHelperSuite extends StreamTest { } } - -object TestMaterial { - val KEYS: Seq[String] = Seq("key1", "key2") - val VALUES: Seq[String] = Seq("sum(key1)", "sum(key2)") - - val OUTPUT_SCHEMA: StructType = StructType( - KEYS.map(createIntegerField) ++ VALUES.map(createIntegerField)) - - val OUTPUT_ATTRIBUTES: Seq[Attribute] = OUTPUT_SCHEMA.toAttributes - val KEYS_ATTRIBUTES: Seq[Attribute] = OUTPUT_ATTRIBUTES.filter { p => - KEYS.contains(p.name) - } - val VALUES_ATTRIBUTES: Seq[Attribute] = OUTPUT_ATTRIBUTES.filter { p => - VALUES.contains(p.name) - } - - val TEST_ROW: UnsafeRow = { - val unsafeRowProjection = 
UnsafeProjection.create(OUTPUT_SCHEMA) - val row = unsafeRowProjection(new SpecificInternalRow(OUTPUT_SCHEMA)) - (KEYS ++ VALUES).zipWithIndex.foreach { case (_, index) => row.setInt(index, index) } - row - } - - val TEST_KEY_ROW: UnsafeRow = { - val keyProjector = GenerateUnsafeProjection.generate(KEYS_ATTRIBUTES, OUTPUT_ATTRIBUTES) - keyProjector(TEST_ROW) - } - - val TEST_VALUE_ROW: UnsafeRow = { - val valueProjector = GenerateUnsafeProjection.generate(VALUES_ATTRIBUTES, OUTPUT_ATTRIBUTES) - valueProjector(TEST_ROW) - } - - private def createIntegerField(name: String): StructField = { - StructField(name, IntegerType, nullable = false) - } -} From 60c231e98a550b0e439827caff75a29c23423a9c Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 1 Aug 2018 19:24:55 +0900 Subject: [PATCH 08/13] WIP Address a part of review comments from @tdas * TODO list * replace all the usages for direct call of store.xxx whenever state manager is available * add iterator / remove in StreamingAggregationStateManager to remove restoreOriginRow * add docs --- .../sql/execution/streaming/OffsetSeq.scala | 4 +- .../streaming/StatefulOperatorsHelper.scala | 134 ------------------ .../execution/streaming/state/package.scala | 110 ++++++++++++++ .../streaming/statefulOperators.scala | 1 - ...reamingAggregationStateManagerSuite.scala} | 3 +- .../streaming/StreamingAggregationSuite.scala | 6 +- 6 files changed, 116 insertions(+), 142 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulOperatorsHelper.scala rename sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/{StatefulOperatorsHelperSuite.scala => StreamingAggregationStateManagerSuite.scala} (96%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 816e388aceba7..73cf355dbe758 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -22,7 +22,7 @@ import org.json4s.jackson.Serialization import org.apache.spark.internal.Logging import org.apache.spark.sql.RuntimeConfig -import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper +import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StreamingAggregationStateManager} import org.apache.spark.sql.internal.SQLConf.{FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, _} /** @@ -106,7 +106,7 @@ object OffsetSeqMetadata extends Logging { FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> FlatMapGroupsWithStateExecHelper.legacyVersion.toString, STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> - StatefulOperatorsHelper.legacyVersion.toString + StreamingAggregationStateManager.legacyVersion.toString ) def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulOperatorsHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulOperatorsHelper.scala deleted file mode 100644 index 7830e201e0709..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulOperatorsHelper.scala +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner} -import org.apache.spark.sql.execution.streaming.state.{StateStore, UnsafeRowPair} -import org.apache.spark.sql.types.StructType - -object StatefulOperatorsHelper { - - val supportedVersions = Seq(1, 2) - val legacyVersion = 1 - - sealed trait StreamingAggregationStateManager extends Serializable { - def getKey(row: InternalRow): UnsafeRow - def getStateValueSchema: StructType - def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow - def get(store: StateStore, key: UnsafeRow): UnsafeRow - def put(store: StateStore, row: UnsafeRow): Unit - } - - object StreamingAggregationStateManager extends Logging { - def createStateManager( - keyExpressions: Seq[Attribute], - inputRowAttributes: Seq[Attribute], - stateFormatVersion: Int): StreamingAggregationStateManager = { - stateFormatVersion match { - case 1 => new StreamingAggregationStateManagerImplV1(keyExpressions, inputRowAttributes) - case 2 => new StreamingAggregationStateManagerImplV2(keyExpressions, inputRowAttributes) - case _ => throw new IllegalArgumentException(s"Version $stateFormatVersion is invalid") - } - } - } - - abstract class StreamingAggregationStateManagerBaseImpl( - protected val keyExpressions: Seq[Attribute], - protected val inputRowAttributes: Seq[Attribute]) extends StreamingAggregationStateManager { - - @transient protected lazy val keyProjector = - GenerateUnsafeProjection.generate(keyExpressions, inputRowAttributes) - - def getKey(row: InternalRow): UnsafeRow = keyProjector(row) - } - - class StreamingAggregationStateManagerImplV1( - keyExpressions: Seq[Attribute], - inputRowAttributes: Seq[Attribute]) - extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) { - - override def getStateValueSchema: StructType = inputRowAttributes.toStructType - - override def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow = { - rowPair.value - } - - override def get(store: StateStore, key: UnsafeRow): UnsafeRow = { - store.get(key) - } - - override def put(store: StateStore, row: UnsafeRow): Unit = { - store.put(getKey(row), row) - } - } - - class StreamingAggregationStateManagerImplV2( - keyExpressions: Seq[Attribute], - inputRowAttributes: Seq[Attribute]) - extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) { - - private val valueExpressions: Seq[Attribute] = inputRowAttributes.diff(keyExpressions) - private val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions - private val needToProjectToRestoreValue: Boolean = - 
keyValueJoinedExpressions != inputRowAttributes - - @transient private lazy val valueProjector = - GenerateUnsafeProjection.generate(valueExpressions, inputRowAttributes) - - @transient private lazy val joiner = - GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions), - StructType.fromAttributes(valueExpressions)) - @transient private lazy val restoreValueProjector = GenerateUnsafeProjection.generate( - keyValueJoinedExpressions, inputRowAttributes) - - override def getStateValueSchema: StructType = valueExpressions.toStructType - - override def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow = { - val joinedRow = joiner.join(rowPair.key, rowPair.value) - if (needToProjectToRestoreValue) { - restoreValueProjector(joinedRow) - } else { - joinedRow - } - } - - override def get(store: StateStore, key: UnsafeRow): UnsafeRow = { - val savedState = store.get(key) - if (savedState == null) { - return savedState - } - - val joinedRow = joiner.join(key, savedState) - if (needToProjectToRestoreValue) { - restoreValueProjector(joinedRow) - } else { - joinedRow - } - } - - override def put(store: StateStore, row: UnsafeRow): Unit = { - val key = keyProjector(row) - val value = valueProjector(row) - store.put(key, value) - } - } - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 0b32327e51dbf..429aeccd8e6d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -20,8 +20,12 @@ package org.apache.spark.sql.execution.streaming import scala.reflect.ClassTag import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner} import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.StructType @@ -81,4 +85,110 @@ package object state { storeCoordinator) } } + + sealed trait StreamingAggregationStateManager extends Serializable { + def getKey(row: InternalRow): UnsafeRow + def getStateValueSchema: StructType + def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow + def get(store: StateStore, key: UnsafeRow): UnsafeRow + def put(store: StateStore, row: UnsafeRow): Unit + } + + object StreamingAggregationStateManager extends Logging { + val supportedVersions = Seq(1, 2) + val legacyVersion = 1 + + def createStateManager( + keyExpressions: Seq[Attribute], + inputRowAttributes: Seq[Attribute], + stateFormatVersion: Int): StreamingAggregationStateManager = { + stateFormatVersion match { + case 1 => new StreamingAggregationStateManagerImplV1(keyExpressions, inputRowAttributes) + case 2 => new StreamingAggregationStateManagerImplV2(keyExpressions, inputRowAttributes) + case _ => throw new IllegalArgumentException(s"Version $stateFormatVersion is invalid") + } + } + } + + abstract class StreamingAggregationStateManagerBaseImpl( + protected val keyExpressions: Seq[Attribute], + protected val inputRowAttributes: Seq[Attribute]) extends StreamingAggregationStateManager { + + @transient protected lazy val keyProjector = + GenerateUnsafeProjection.generate(keyExpressions, 
inputRowAttributes) + + def getKey(row: InternalRow): UnsafeRow = keyProjector(row) + } + + class StreamingAggregationStateManagerImplV1( + keyExpressions: Seq[Attribute], + inputRowAttributes: Seq[Attribute]) + extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) { + + override def getStateValueSchema: StructType = inputRowAttributes.toStructType + + override def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow = { + rowPair.value + } + + override def get(store: StateStore, key: UnsafeRow): UnsafeRow = { + store.get(key) + } + + override def put(store: StateStore, row: UnsafeRow): Unit = { + store.put(getKey(row), row) + } + } + + class StreamingAggregationStateManagerImplV2( + keyExpressions: Seq[Attribute], + inputRowAttributes: Seq[Attribute]) + extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) { + + private val valueExpressions: Seq[Attribute] = inputRowAttributes.diff(keyExpressions) + private val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions + private val needToProjectToRestoreValue: Boolean = + keyValueJoinedExpressions != inputRowAttributes + + @transient private lazy val valueProjector = + GenerateUnsafeProjection.generate(valueExpressions, inputRowAttributes) + + @transient private lazy val joiner = + GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions), + StructType.fromAttributes(valueExpressions)) + @transient private lazy val restoreValueProjector = GenerateUnsafeProjection.generate( + keyValueJoinedExpressions, inputRowAttributes) + + override def getStateValueSchema: StructType = valueExpressions.toStructType + + override def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow = { + val joinedRow = joiner.join(rowPair.key, rowPair.value) + if (needToProjectToRestoreValue) { + restoreValueProjector(joinedRow) + } else { + joinedRow + } + } + + override def get(store: StateStore, key: UnsafeRow): UnsafeRow = { + val savedState = store.get(key) + if (savedState == null) { + return savedState + } + + val joinedRow = joiner.join(key, savedState) + if (needToProjectToRestoreValue) { + restoreValueProjector(joinedRow) + } else { + joinedRow + } + } + + override def put(store: StateStore, row: UnsafeRow): Unit = { + val key = keyProjector(row) + val value = valueProjector(row) + store.put(key, value) + } + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 217ced0cc11ac..2598f2da9da11 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -30,7 +30,6 @@ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistrib import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.streaming.StatefulOperatorsHelper.StreamingAggregationStateManager import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress} import org.apache.spark.sql.types._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulOperatorsHelperSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulOperatorsHelperSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala index 2a8ef514c5779..b851651951234 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulOperatorsHelperSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala @@ -19,11 +19,10 @@ package org.apache.spark.sql.execution.streaming.state import org.apache.spark.sql.catalyst.expressions.{Attribute, SpecificInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection -import org.apache.spark.sql.execution.streaming.StatefulOperatorsHelper.StreamingAggregationStateManager import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.types.{IntegerType, StructField, StructType} -class StatefulOperatorsHelperSuite extends StreamTest { +class StreamingAggregationStateManagerSuite extends StreamTest { // ============================ fields and method for test data ============================ val testKeys: Seq[String] = Seq("key1", "key2") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 68b20846db3a1..819889971d111 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.exchange.Exchange import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.execution.streaming.state.{StateStore, StreamingAggregationStateManager} import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -65,7 +65,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest def testWithAllStateVersions(name: String, confPairs: (String, String)*) (func: => Any): Unit = { - for (version <- StatefulOperatorsHelper.supportedVersions) { + for (version <- StreamingAggregationStateManager.supportedVersions) { test(s"$name - state format version $version") { executeFuncWithStateVersionSQLConf(version, confPairs, func) } @@ -74,7 +74,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest def testQuietlyWithAllStateVersions(name: String, confPairs: (String, String)*) (func: => Any): Unit = { - for (version <- StatefulOperatorsHelper.supportedVersions) { + for (version <- StreamingAggregationStateManager.supportedVersions) { testQuietly(s"$name - state format version $version") { executeFuncWithStateVersionSQLConf(version, confPairs, func) } From b4a3807631cc8e12df367eeca554749fdd81a5ef Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 2 Aug 2018 05:49:03 +0900 Subject: [PATCH 09/13] WIP Address a part of review comments from @tdas * TODO list * add docs --- .../execution/streaming/state/package.scala | 53 ++++++++++++++----- 
.../streaming/statefulOperators.scala | 28 +++++++--- ...treamingAggregationStateManagerSuite.scala | 11 +++- 3 files changed, 69 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 429aeccd8e6d5..3fa4b28a46921 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -89,9 +89,13 @@ package object state { sealed trait StreamingAggregationStateManager extends Serializable { def getKey(row: InternalRow): UnsafeRow def getStateValueSchema: StructType - def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow def get(store: StateStore, key: UnsafeRow): UnsafeRow def put(store: StateStore, row: UnsafeRow): Unit + def commit(store: StateStore): Long + def remove(store: StateStore, key: UnsafeRow): Unit + def iterator(store: StateStore): Iterator[UnsafeRowPair] + def keys(store: StateStore): Iterator[UnsafeRow] + def values(store: StateStore): Iterator[UnsafeRow] } object StreamingAggregationStateManager extends Logging { @@ -118,6 +122,15 @@ package object state { GenerateUnsafeProjection.generate(keyExpressions, inputRowAttributes) def getKey(row: InternalRow): UnsafeRow = keyProjector(row) + + override def commit(store: StateStore): Long = store.commit() + + override def remove(store: StateStore, key: UnsafeRow): Unit = store.remove(key) + + override def keys(store: StateStore): Iterator[UnsafeRow] = { + // discard and don't convert values to avoid computation + store.getRange(None, None).map(_.key) + } } class StreamingAggregationStateManagerImplV1( @@ -127,10 +140,6 @@ package object state { override def getStateValueSchema: StructType = inputRowAttributes.toStructType - override def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow = { - rowPair.value - } - override def get(store: StateStore, key: UnsafeRow): UnsafeRow = { store.get(key) } @@ -138,6 +147,14 @@ package object state { override def put(store: StateStore, row: UnsafeRow): Unit = { store.put(getKey(row), row) } + + override def iterator(store: StateStore): Iterator[UnsafeRowPair] = { + store.iterator() + } + + override def values(store: StateStore): Iterator[UnsafeRow] = { + store.iterator().map(_.value) + } } class StreamingAggregationStateManagerImplV2( @@ -161,15 +178,6 @@ package object state { override def getStateValueSchema: StructType = valueExpressions.toStructType - override def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow = { - val joinedRow = joiner.join(rowPair.key, rowPair.value) - if (needToProjectToRestoreValue) { - restoreValueProjector(joinedRow) - } else { - joinedRow - } - } - override def get(store: StateStore, key: UnsafeRow): UnsafeRow = { val savedState = store.get(key) if (savedState == null) { @@ -189,6 +197,23 @@ package object state { val value = valueProjector(row) store.put(key, value) } + + override def iterator(store: StateStore): Iterator[UnsafeRowPair] = { + store.iterator().map(rowPair => new UnsafeRowPair(rowPair.key, restoreOriginRow(rowPair))) + } + + override def values(store: StateStore): Iterator[UnsafeRow] = { + store.iterator().map(rowPair => restoreOriginRow(rowPair)) + } + + private def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow = { + val joinedRow = joiner.join(rowPair.key, rowPair.value) + if (needToProjectToRestoreValue) { + restoreValueProjector(joinedRow) + } else 
{ + joinedRow + } + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 2598f2da9da11..4c64754f2c369 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -165,6 +165,18 @@ trait WatermarkSupport extends UnaryExecNode { } } } + + protected def removeKeysOlderThanWatermark(storeManager: StreamingAggregationStateManager, + store: StateStore) + : Unit = { + if (watermarkPredicateForKeys.nonEmpty) { + storeManager.keys(store).foreach { keyRow => + if (watermarkPredicateForKeys.get.eval(keyRow)) { + store.remove(keyRow) + } + } + } + } } object WatermarkSupport { @@ -293,12 +305,12 @@ case class StateStoreSaveExec( } allRemovalsTimeMs += 0 commitTimeMs += timeTakenMs { - store.commit() + stateManager.commit(store) } setStoreMetrics(store) - store.iterator().map { rowPair => + stateManager.values(store).map { valueRow => numOutputRows += 1 - stateManager.restoreOriginRow(rowPair) + valueRow } // Update and output only rows being evicted from the StateStore @@ -314,7 +326,7 @@ case class StateStoreSaveExec( } val removalStartTimeNs = System.nanoTime - val rangeIter = store.getRange(None, None) + val rangeIter = stateManager.iterator(store) new NextIterator[InternalRow] { override protected def getNext(): InternalRow = { @@ -322,8 +334,8 @@ case class StateStoreSaveExec( while(rangeIter.hasNext && removedValueRow == null) { val rowPair = rangeIter.next() if (watermarkPredicateForKeys.get.eval(rowPair.key)) { - store.remove(rowPair.key) - removedValueRow = stateManager.restoreOriginRow(rowPair) + stateManager.remove(store, rowPair.key) + removedValueRow = rowPair.value } } if (removedValueRow == null) { @@ -336,7 +348,7 @@ case class StateStoreSaveExec( override protected def close(): Unit = { allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs) - commitTimeMs += timeTakenMs { store.commit() } + commitTimeMs += timeTakenMs { stateManager.commit(store) } setStoreMetrics(store) } } @@ -370,7 +382,7 @@ case class StateStoreSaveExec( // Remove old aggregates if watermark specified allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) } - commitTimeMs += timeTakenMs { store.commit() } + commitTimeMs += timeTakenMs { stateManager.commit(store) } setStoreMetrics(store) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala index b851651951234..1ee5c5df24941 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala @@ -103,6 +103,7 @@ class StreamingAggregationStateManagerSuite extends StreamTest { stateManager.put(memoryStateStore, inputRow) assert(memoryStateStore.iterator().size === 1) + assert(stateManager.iterator(memoryStateStore).size === memoryStateStore.iterator().size) val keyRow = stateManager.getKey(inputRow) assert(keyRow === expectedStateKey) @@ -111,7 +112,15 @@ class StreamingAggregationStateManagerSuite extends StreamTest { val pair = memoryStateStore.iterator().next() 
assert(pair.key === keyRow) assert(pair.value === expectedStateValue) - assert(stateManager.restoreOriginRow(pair) === inputRow) + + // iterate with state manager and see whether original rows are returned as values + val pairFromStateManager = stateManager.iterator(memoryStateStore).next() + assert(pairFromStateManager.key === keyRow) + assert(pairFromStateManager.value === inputRow) + + // following as keys and values + assert(stateManager.keys(memoryStateStore).next() === keyRow) + assert(stateManager.values(memoryStateStore).next() === inputRow) // verify the stored value once again via get assert(memoryStateStore.get(keyRow) === expectedStateValue) From e0ee04af4f325db4813b8bf574c0de4cfbbbaed6 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Wed, 8 Aug 2018 11:20:34 +0900 Subject: [PATCH 10/13] Address a part of review comments from @tdas * add docs --- .../execution/streaming/state/package.scala | 86 +++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 3fa4b28a46921..29fce616c2d07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -86,15 +86,74 @@ package object state { } } + /** + * Base trait for state manager purposed to be used from streaming aggregations. + */ sealed trait StreamingAggregationStateManager extends Serializable { + + /** + * Extract columns consisting key from input row, and return the new row for key columns. + * + * @param row The input row. + * @return The row instance which only contains key columns. + */ def getKey(row: InternalRow): UnsafeRow + + /** + * Calculate schema for the value of state. The schema is mainly passed to the StateStoreRDD. + * + * @return An instance of StructType representing schema for the value of state. + */ def getStateValueSchema: StructType + + /** + * Get the current value of a non-null key from the target state store. + * + * @param store The target StateStore instance. + * @param key The key whose associated value is to be returned. + * @return A non-null row if the key exists in the store, otherwise null. + */ def get(store: StateStore, key: UnsafeRow): UnsafeRow + + /** + * Put a new value for a non-null key to the target state store. Note that key will be + * extracted from the input row, and the key would be same as the result of getKey(inputRow). + * + * @param store The target StateStore instance. + * @param row The input row. + */ def put(store: StateStore, row: UnsafeRow): Unit + + /** + * Commit all the updates that have been made to the target state store, and return the + * new version. + * + * @param store The target StateStore instance. + * @return The new state version. + */ def commit(store: StateStore): Long + + /** + * Remove a single non-null key from the target state store. + * + * @param store The target StateStore instance. + * @param key The key whose associated value is to be returned. + */ def remove(store: StateStore, key: UnsafeRow): Unit + + /** + * Return an iterator containing all the key-value pairs in target state store. + */ def iterator(store: StateStore): Iterator[UnsafeRowPair] + + /** + * Return an iterator containing all the keys in target state store. 
+ */ def keys(store: StateStore): Iterator[UnsafeRow] + + /** + * Return an iterator containing all the values in target state store. + */ def values(store: StateStore): Iterator[UnsafeRow] } @@ -133,6 +192,18 @@ package object state { } } + /** + * The implementation of StreamingAggregationStateManager for state version 1. + * In state version 1, the schema of key and value in state are follow: + * + * - key: Same as key expressions. + * - value: Same as input row attributes. The schema of value contains key expressions as well. + * + * This implementation only works when input row attributes contain all the key attributes. + * + * @param keyExpressions The attributes of keys. + * @param inputRowAttributes The attributes of input row. + */ class StreamingAggregationStateManagerImplV1( keyExpressions: Seq[Attribute], inputRowAttributes: Seq[Attribute]) @@ -157,6 +228,21 @@ package object state { } } + /** + * The implementation of StreamingAggregationStateManager for state version 2. + * In state version 2, the schema of key and value in state are follow: + * + * - key: Same as key expressions. + * - value: The diff between input row attributes and key expressions. + * + * The schema of value is changed to optimize the memory/space usage in state, via removing + * duplicated columns in key-value pair. Hence key columns are excluded from the schema of value. + * + * This implementation only works when input row attributes contain all the key attributes. + * + * @param keyExpressions The attributes of keys. + * @param inputRowAttributes The attributes of input row. + */ class StreamingAggregationStateManagerImplV2( keyExpressions: Seq[Attribute], inputRowAttributes: Seq[Attribute]) From 8629f59348a06002742c671c497f9ae73ec67aa9 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 9 Aug 2018 09:58:30 +0900 Subject: [PATCH 11/13] WIP address review comments from @tdas TODO list * add test which reads checkpoint from 2.3.1 and runs query --- .../StreamingAggregationStateManager.scala | 205 ++++++++++++++++ .../execution/streaming/state/package.scala | 221 ------------------ .../streaming/statefulOperators.scala | 15 +- ...treamingAggregationStateManagerSuite.scala | 16 +- 4 files changed, 218 insertions(+), 239 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala new file mode 100644 index 0000000000000..9bfb9561b42a1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner} +import org.apache.spark.sql.types.StructType + +/** + * Base trait for state manager purposed to be used from streaming aggregations. + */ +sealed trait StreamingAggregationStateManager extends Serializable { + + /** Extract columns consisting key from input row, and return the new row for key columns. */ + def getKey(row: UnsafeRow): UnsafeRow + + /** Calculate schema for the value of state. The schema is mainly passed to the StateStoreRDD. */ + def getStateValueSchema: StructType + + /** Get the current value of a non-null key from the target state store. */ + def get(store: StateStore, key: UnsafeRow): UnsafeRow + + /** + * Put a new value for a non-null key to the target state store. Note that key will be + * extracted from the input row, and the key would be same as the result of getKey(inputRow). + */ + def put(store: StateStore, row: UnsafeRow): Unit + + /** + * Commit all the updates that have been made to the target state store, and return the + * new version. + */ + def commit(store: StateStore): Long + + /** Remove a single non-null key from the target state store. */ + def remove(store: StateStore, key: UnsafeRow): Unit + + /** Return an iterator containing all the key-value pairs in target state store. */ + def iterator(store: StateStore): Iterator[UnsafeRowPair] + + /** Return an iterator containing all the keys in target state store. */ + def keys(store: StateStore): Iterator[UnsafeRow] + + /** Return an iterator containing all the values in target state store. */ + def values(store: StateStore): Iterator[UnsafeRow] +} + +object StreamingAggregationStateManager extends Logging { + val supportedVersions = Seq(1, 2) + val legacyVersion = 1 + + def createStateManager( + keyExpressions: Seq[Attribute], + inputRowAttributes: Seq[Attribute], + stateFormatVersion: Int): StreamingAggregationStateManager = { + stateFormatVersion match { + case 1 => new StreamingAggregationStateManagerImplV1(keyExpressions, inputRowAttributes) + case 2 => new StreamingAggregationStateManagerImplV2(keyExpressions, inputRowAttributes) + case _ => throw new IllegalArgumentException(s"Version $stateFormatVersion is invalid") + } + } +} + +abstract class StreamingAggregationStateManagerBaseImpl( + protected val keyExpressions: Seq[Attribute], + protected val inputRowAttributes: Seq[Attribute]) extends StreamingAggregationStateManager { + + @transient protected lazy val keyProjector = + GenerateUnsafeProjection.generate(keyExpressions, inputRowAttributes) + + override def getKey(row: UnsafeRow): UnsafeRow = keyProjector(row) + + override def commit(store: StateStore): Long = store.commit() + + override def remove(store: StateStore, key: UnsafeRow): Unit = store.remove(key) + + override def keys(store: StateStore): Iterator[UnsafeRow] = { + // discard and don't convert values to avoid computation + store.getRange(None, None).map(_.key) + } +} + +/** + * The implementation of StreamingAggregationStateManager for state version 1. + * In state version 1, the schema of key and value in state are follow: + * + * - key: Same as key expressions. + * - value: Same as input row attributes. 
The schema of value contains key expressions as well. + * + * @param keyExpressions The attributes of keys. + * @param inputRowAttributes The attributes of input row. + */ +class StreamingAggregationStateManagerImplV1( + keyExpressions: Seq[Attribute], + inputRowAttributes: Seq[Attribute]) + extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) { + + override def getStateValueSchema: StructType = inputRowAttributes.toStructType + + override def get(store: StateStore, key: UnsafeRow): UnsafeRow = { + store.get(key) + } + + override def put(store: StateStore, row: UnsafeRow): Unit = { + store.put(getKey(row), row) + } + + override def iterator(store: StateStore): Iterator[UnsafeRowPair] = { + store.iterator() + } + + override def values(store: StateStore): Iterator[UnsafeRow] = { + store.iterator().map(_.value) + } +} + +/** + * The implementation of StreamingAggregationStateManager for state version 2. + * In state version 2, the schema of key and value in state are follow: + * + * - key: Same as key expressions. + * - value: The diff between input row attributes and key expressions. + * + * The schema of value is changed to optimize the memory/space usage in state, via removing + * duplicated columns in key-value pair. Hence key columns are excluded from the schema of value. + * + * @param keyExpressions The attributes of keys. + * @param inputRowAttributes The attributes of input row. + */ +class StreamingAggregationStateManagerImplV2( + keyExpressions: Seq[Attribute], + inputRowAttributes: Seq[Attribute]) + extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) { + + private val valueExpressions: Seq[Attribute] = inputRowAttributes.diff(keyExpressions) + private val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions + + // flag to check whether the row needs to be project into input row attributes after join + // e.g. 
if the fields in the joined row are not in the expected order + private val needToProjectToRestoreValue: Boolean = + keyValueJoinedExpressions != inputRowAttributes + + @transient private lazy val valueProjector = + GenerateUnsafeProjection.generate(valueExpressions, inputRowAttributes) + + @transient private lazy val joiner = + GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions), + StructType.fromAttributes(valueExpressions)) + @transient private lazy val restoreValueProjector = GenerateUnsafeProjection.generate( + inputRowAttributes, keyValueJoinedExpressions) + + override def getStateValueSchema: StructType = valueExpressions.toStructType + + override def get(store: StateStore, key: UnsafeRow): UnsafeRow = { + val savedState = store.get(key) + if (savedState == null) { + return savedState + } + + restoreOriginalRow(key, savedState) + } + + override def put(store: StateStore, row: UnsafeRow): Unit = { + val key = keyProjector(row) + val value = valueProjector(row) + store.put(key, value) + } + + override def iterator(store: StateStore): Iterator[UnsafeRowPair] = { + store.iterator().map(rowPair => new UnsafeRowPair(rowPair.key, restoreOriginalRow(rowPair))) + } + + override def values(store: StateStore): Iterator[UnsafeRow] = { + store.iterator().map(rowPair => restoreOriginalRow(rowPair)) + } + + private def restoreOriginalRow(rowPair: UnsafeRowPair): UnsafeRow = { + restoreOriginalRow(rowPair.key, rowPair.value) + } + + private def restoreOriginalRow(key: UnsafeRow, value: UnsafeRow): UnsafeRow = { + val joinedRow = joiner.join(key, value) + if (needToProjectToRestoreValue) { + restoreValueProjector(joinedRow) + } else { + joinedRow + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 29fce616c2d07..0b32327e51dbf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -20,12 +20,8 @@ package org.apache.spark.sql.execution.streaming import scala.reflect.ClassTag import org.apache.spark.TaskContext -import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner} import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.StructType @@ -85,221 +81,4 @@ package object state { storeCoordinator) } } - - /** - * Base trait for state manager purposed to be used from streaming aggregations. - */ - sealed trait StreamingAggregationStateManager extends Serializable { - - /** - * Extract columns consisting key from input row, and return the new row for key columns. - * - * @param row The input row. - * @return The row instance which only contains key columns. - */ - def getKey(row: InternalRow): UnsafeRow - - /** - * Calculate schema for the value of state. The schema is mainly passed to the StateStoreRDD. - * - * @return An instance of StructType representing schema for the value of state. - */ - def getStateValueSchema: StructType - - /** - * Get the current value of a non-null key from the target state store. - * - * @param store The target StateStore instance. 
- * @param key The key whose associated value is to be returned. - * @return A non-null row if the key exists in the store, otherwise null. - */ - def get(store: StateStore, key: UnsafeRow): UnsafeRow - - /** - * Put a new value for a non-null key to the target state store. Note that key will be - * extracted from the input row, and the key would be same as the result of getKey(inputRow). - * - * @param store The target StateStore instance. - * @param row The input row. - */ - def put(store: StateStore, row: UnsafeRow): Unit - - /** - * Commit all the updates that have been made to the target state store, and return the - * new version. - * - * @param store The target StateStore instance. - * @return The new state version. - */ - def commit(store: StateStore): Long - - /** - * Remove a single non-null key from the target state store. - * - * @param store The target StateStore instance. - * @param key The key whose associated value is to be returned. - */ - def remove(store: StateStore, key: UnsafeRow): Unit - - /** - * Return an iterator containing all the key-value pairs in target state store. - */ - def iterator(store: StateStore): Iterator[UnsafeRowPair] - - /** - * Return an iterator containing all the keys in target state store. - */ - def keys(store: StateStore): Iterator[UnsafeRow] - - /** - * Return an iterator containing all the values in target state store. - */ - def values(store: StateStore): Iterator[UnsafeRow] - } - - object StreamingAggregationStateManager extends Logging { - val supportedVersions = Seq(1, 2) - val legacyVersion = 1 - - def createStateManager( - keyExpressions: Seq[Attribute], - inputRowAttributes: Seq[Attribute], - stateFormatVersion: Int): StreamingAggregationStateManager = { - stateFormatVersion match { - case 1 => new StreamingAggregationStateManagerImplV1(keyExpressions, inputRowAttributes) - case 2 => new StreamingAggregationStateManagerImplV2(keyExpressions, inputRowAttributes) - case _ => throw new IllegalArgumentException(s"Version $stateFormatVersion is invalid") - } - } - } - - abstract class StreamingAggregationStateManagerBaseImpl( - protected val keyExpressions: Seq[Attribute], - protected val inputRowAttributes: Seq[Attribute]) extends StreamingAggregationStateManager { - - @transient protected lazy val keyProjector = - GenerateUnsafeProjection.generate(keyExpressions, inputRowAttributes) - - def getKey(row: InternalRow): UnsafeRow = keyProjector(row) - - override def commit(store: StateStore): Long = store.commit() - - override def remove(store: StateStore, key: UnsafeRow): Unit = store.remove(key) - - override def keys(store: StateStore): Iterator[UnsafeRow] = { - // discard and don't convert values to avoid computation - store.getRange(None, None).map(_.key) - } - } - - /** - * The implementation of StreamingAggregationStateManager for state version 1. - * In state version 1, the schema of key and value in state are follow: - * - * - key: Same as key expressions. - * - value: Same as input row attributes. The schema of value contains key expressions as well. - * - * This implementation only works when input row attributes contain all the key attributes. - * - * @param keyExpressions The attributes of keys. - * @param inputRowAttributes The attributes of input row. 
- */ - class StreamingAggregationStateManagerImplV1( - keyExpressions: Seq[Attribute], - inputRowAttributes: Seq[Attribute]) - extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) { - - override def getStateValueSchema: StructType = inputRowAttributes.toStructType - - override def get(store: StateStore, key: UnsafeRow): UnsafeRow = { - store.get(key) - } - - override def put(store: StateStore, row: UnsafeRow): Unit = { - store.put(getKey(row), row) - } - - override def iterator(store: StateStore): Iterator[UnsafeRowPair] = { - store.iterator() - } - - override def values(store: StateStore): Iterator[UnsafeRow] = { - store.iterator().map(_.value) - } - } - - /** - * The implementation of StreamingAggregationStateManager for state version 2. - * In state version 2, the schema of key and value in state are follow: - * - * - key: Same as key expressions. - * - value: The diff between input row attributes and key expressions. - * - * The schema of value is changed to optimize the memory/space usage in state, via removing - * duplicated columns in key-value pair. Hence key columns are excluded from the schema of value. - * - * This implementation only works when input row attributes contain all the key attributes. - * - * @param keyExpressions The attributes of keys. - * @param inputRowAttributes The attributes of input row. - */ - class StreamingAggregationStateManagerImplV2( - keyExpressions: Seq[Attribute], - inputRowAttributes: Seq[Attribute]) - extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) { - - private val valueExpressions: Seq[Attribute] = inputRowAttributes.diff(keyExpressions) - private val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions - private val needToProjectToRestoreValue: Boolean = - keyValueJoinedExpressions != inputRowAttributes - - @transient private lazy val valueProjector = - GenerateUnsafeProjection.generate(valueExpressions, inputRowAttributes) - - @transient private lazy val joiner = - GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions), - StructType.fromAttributes(valueExpressions)) - @transient private lazy val restoreValueProjector = GenerateUnsafeProjection.generate( - keyValueJoinedExpressions, inputRowAttributes) - - override def getStateValueSchema: StructType = valueExpressions.toStructType - - override def get(store: StateStore, key: UnsafeRow): UnsafeRow = { - val savedState = store.get(key) - if (savedState == null) { - return savedState - } - - val joinedRow = joiner.join(key, savedState) - if (needToProjectToRestoreValue) { - restoreValueProjector(joinedRow) - } else { - joinedRow - } - } - - override def put(store: StateStore, row: UnsafeRow): Unit = { - val key = keyProjector(row) - val value = valueProjector(row) - store.put(key, value) - } - - override def iterator(store: StateStore): Iterator[UnsafeRowPair] = { - store.iterator().map(rowPair => new UnsafeRowPair(rowPair.key, restoreOriginRow(rowPair))) - } - - override def values(store: StateStore): Iterator[UnsafeRow] = { - store.iterator().map(rowPair => restoreOriginRow(rowPair)) - } - - private def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow = { - val joinedRow = joiner.join(rowPair.key, rowPair.value) - if (needToProjectToRestoreValue) { - restoreValueProjector(joinedRow) - } else { - joinedRow - } - } - } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 4c64754f2c369..34e26d85ae2ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -166,13 +166,13 @@ trait WatermarkSupport extends UnaryExecNode { } } - protected def removeKeysOlderThanWatermark(storeManager: StreamingAggregationStateManager, - store: StateStore) - : Unit = { + protected def removeKeysOlderThanWatermark( + storeManager: StreamingAggregationStateManager, + store: StateStore): Unit = { if (watermarkPredicateForKeys.nonEmpty) { storeManager.keys(store).foreach { keyRow => if (watermarkPredicateForKeys.get.eval(keyRow)) { - store.remove(keyRow) + storeManager.remove(store, keyRow) } } } @@ -234,11 +234,10 @@ case class StateStoreRestoreExec( // the `HashAggregateExec` will output a 0 value for the partial merge. We need to // restore the value, so that we don't overwrite our state with a 0 value, but rather // merge the 0 with existing state. - // In this case the value should represent origin row, so no need to restore. store.iterator().map(_.value) } else { iter.flatMap { row => - val key = stateManager.getKey(row) + val key = stateManager.getKey(row.asInstanceOf[UnsafeRow]) val restoredRow = stateManager.get(store, key) numOutputRows += 1 Option(restoredRow).toSeq :+ row @@ -381,7 +380,9 @@ case class StateStoreSaveExec( allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) // Remove old aggregates if watermark specified - allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) } + allRemovalsTimeMs += timeTakenMs { + removeKeysOlderThanWatermark(stateManager, store) + } commitTimeMs += timeTakenMs { stateManager.commit(store) } setStoreMetrics(store) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala index 1ee5c5df24941..daacdfd58c7b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala @@ -65,16 +65,18 @@ class StreamingAggregationStateManagerSuite extends StreamTest { // ============================ StateManagerImplV1 ============================ test("StateManager v1 - get, put, iter") { - val stateManager = newStateManager(testKeyAttributes, testOutputAttributes, 1) + val stateManager = StreamingAggregationStateManager.createStateManager(testKeyAttributes, + testOutputAttributes, 1) // in V1, input row is stored as value testGetPutIterOnStateManager(stateManager, testOutputSchema, testRow, - expectedTestKeyRow, testRow) + expectedTestKeyRow, expectedStateValue = testRow) } // ============================ StateManagerImplV2 ============================ test("StateManager v2 - get, put, iter") { - val stateManager = newStateManager(testKeyAttributes, testOutputAttributes, 2) + val stateManager = StreamingAggregationStateManager.createStateManager(testKeyAttributes, + testOutputAttributes, 2) // in V2, row for values itself (excluding keys from input row) is stored as value // so that stored value doesn't have key part, but state manager V2 will provide same output @@ -83,13 +85,6 @@ class 
StreamingAggregationStateManagerSuite extends StreamTest { expectedTestKeyRow, expectedTestValueRowForV2) } - private def newStateManager( - keysAttributes: Seq[Attribute], - inputRowAttributes: Seq[Attribute], - version: Int): StreamingAggregationStateManager = { - StreamingAggregationStateManager.createStateManager(keysAttributes, inputRowAttributes, version) - } - private def testGetPutIterOnStateManager( stateManager: StreamingAggregationStateManager, expectedValueSchema: StructType, @@ -128,5 +123,4 @@ class StreamingAggregationStateManagerSuite extends StreamTest { // state manager should return row which is same as input row regardless of format version assert(inputRow === stateManager.get(memoryStateStore, keyRow)) } - } From 65801a60aa35449f45c44f5ee71d32292960cb88 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 9 Aug 2018 13:05:09 +0900 Subject: [PATCH 12/13] Address review comments from @tdas * add test which restores query from Spark 2.3.1 --- .../commits/0 | 2 + .../commits/1 | 2 + .../metadata | 1 + .../offsets/0 | 3 + .../offsets/1 | 3 + .../state/0/0/.1.delta.crc | Bin 0 -> 12 bytes .../state/0/0/.2.delta.crc | Bin 0 -> 12 bytes .../state/0/0/1.delta | Bin 0 -> 46 bytes .../state/0/0/2.delta | Bin 0 -> 46 bytes .../state/0/1/.1.delta.crc | Bin 0 -> 12 bytes .../state/0/1/.2.delta.crc | Bin 0 -> 12 bytes .../state/0/1/1.delta | Bin 0 -> 77 bytes .../state/0/1/2.delta | Bin 0 -> 77 bytes .../state/0/2/.1.delta.crc | Bin 0 -> 12 bytes .../state/0/2/.2.delta.crc | Bin 0 -> 12 bytes .../state/0/2/1.delta | Bin 0 -> 46 bytes .../state/0/2/2.delta | Bin 0 -> 46 bytes .../state/0/3/.1.delta.crc | Bin 0 -> 12 bytes .../state/0/3/.2.delta.crc | Bin 0 -> 12 bytes .../state/0/3/1.delta | Bin 0 -> 46 bytes .../state/0/3/2.delta | Bin 0 -> 46 bytes .../state/0/4/.1.delta.crc | Bin 0 -> 12 bytes .../state/0/4/.2.delta.crc | Bin 0 -> 12 bytes .../state/0/4/1.delta | Bin 0 -> 46 bytes .../state/0/4/2.delta | Bin 0 -> 77 bytes .../streaming/StreamingAggregationSuite.scala | 55 ++++++++++++++++++ 26 files changed, 66 insertions(+) create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/0 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/1 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/metadata create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/0 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/1 create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/.1.delta.crc create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/.2.delta.crc create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/.1.delta.crc create mode 100644 
sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/.2.delta.crc create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/.1.delta.crc create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/.2.delta.crc create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/.1.delta.crc create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/.2.delta.crc create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/2.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/.1.delta.crc create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/.2.delta.crc create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/1.delta create mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/2.delta diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/0 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/0 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/1 new file mode 100644 index 0000000000000..83321cd95eb0c --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/1 @@ -0,0 +1,2 @@ +v1 +{} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/metadata b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/metadata new file mode 100644 index 0000000000000..c160d737278e1 --- /dev/null +++ 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/metadata @@ -0,0 +1 @@ +{"id":"2f32aca2-1b97-458f-a48f-109328724f09"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/0 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/0 new file mode 100644 index 0000000000000..acdc6e69e975a --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/0 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1533784347136,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/1 b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/1 new file mode 100644 index 0000000000000..27353e8724507 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/1 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1533784349160,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}} +1 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/.1.delta.crc new file mode 100644 index 0000000000000000000000000000000000000000..cf1d68e2acee3bca2b92320c4bafc702a6539ea0 GIT binary patch literal 12 TcmYc;N@ieSU}A7peP;>)5flQ* literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/.2.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/.2.delta.crc new file mode 100644 index 0000000000000000000000000000000000000000..cf1d68e2acee3bca2b92320c4bafc702a6539ea0 GIT binary patch literal 12 TcmYc;N@ieSU}A7peP;>)5flQ* literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/2.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/.1.delta.crc new file mode 100644 index 0000000000000000000000000000000000000000..a395dee7224ea430536120f6d5af0b7d37aeb793 GIT binary patch literal 12 TcmYc;N@ieSU}C7h@Vg!W6k!Ae literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/.2.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/.2.delta.crc new file mode 100644 index 0000000000000000000000000000000000000000..b61bb872e950cb768d25284f23a3be27b3e0e132 GIT binary patch literal 12 TcmYc;N@ieSU}6w0e(D4O5jg_t literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..281b21e96090981faa965b468a31a06f73dc293a GIT binary patch literal 77 zcmeZ?GI7euPtI0VW?*120bDc literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/.1.delta.crc new file mode 100644 index 0000000000000000000000000000000000000000..cf1d68e2acee3bca2b92320c4bafc702a6539ea0 GIT binary patch literal 12 TcmYc;N@ieSU}A7peP;>)5flQ* literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/.2.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/.2.delta.crc new file mode 100644 index 0000000000000000000000000000000000000000..cf1d68e2acee3bca2b92320c4bafc702a6539ea0 GIT binary patch literal 12 TcmYc;N@ieSU}A7peP;>)5flQ* literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/2.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/.1.delta.crc new file mode 
100644 index 0000000000000000000000000000000000000000..cf1d68e2acee3bca2b92320c4bafc702a6539ea0 GIT binary patch literal 12 TcmYc;N@ieSU}A7peP;>)5flQ* literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/.2.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/.2.delta.crc new file mode 100644 index 0000000000000000000000000000000000000000..cf1d68e2acee3bca2b92320c4bafc702a6539ea0 GIT binary patch literal 12 TcmYc;N@ieSU}A7peP;>)5flQ* literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/2.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/.1.delta.crc new file mode 100644 index 0000000000000000000000000000000000000000..cf1d68e2acee3bca2b92320c4bafc702a6539ea0 GIT binary patch literal 12 TcmYc;N@ieSU}A7peP;>)5flQ* literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/.2.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/.2.delta.crc new file mode 100644 index 0000000000000000000000000000000000000000..efa0266fa0c0a2cca1437e1214e4a31dec3e5833 GIT binary patch literal 12 TcmYc;N@ieSU}88n<9Qwc6r2O@ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/1.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/1.delta new file mode 100644 index 0000000000000000000000000000000000000000..6352978051846970ca41a0ca97fd79952105726d GIT binary patch literal 46 icmeZ?GI7euPtF!)VPIeY;oA+q9RGp92POd&g989JFAHe^ literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/2.delta b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/2.delta new file mode 100644 index 0000000000000000000000000000000000000000..f4fb2520a4ac43f7ac9d87544480a1e7bb5053b6 GIT binary patch literal 77 zcmeZ?GI7euPtI0VW?*120b-T3tt`PnT7ZF(L70hy!4b%oU}InxVK~4DWP-qdAn<|e J6NLytNB|3?4Hp0a literal 0 HcmV?d00001 diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index 819889971d111..1ae6ff3a90989 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -17,8 +17,10 @@
 
 package org.apache.spark.sql.streaming
 
+import java.io.File
 import java.util.{Locale, TimeZone}
 
+import org.apache.commons.io.FileUtils
 import org.scalatest.{Assertions, BeforeAndAfterAll}
 
 import org.apache.spark.{SparkEnv, SparkException}
@@ -38,6 +40,7 @@ import org.apache.spark.sql.streaming.OutputMode._
 import org.apache.spark.sql.streaming.util.{MockSourceProvider, StreamManualClock}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId}
+import org.apache.spark.util.Utils
 
 object FailureSingleton {
   var firstTime = true
@@ -590,6 +593,58 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }
+
+  test("simple count, update mode - recovery from checkpoint uses state format version 1") {
+    val inputData = MemoryStream[Int]
+
+    val aggregated =
+      inputData.toDF()
+        .groupBy($"value")
+        .agg(count("*"))
+        .as[(Int, Long)]
+
+    val resourceUri = this.getClass.getResource(
+      "/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/").toURI
+
+    val checkpointDir = Utils.createTempDir().getCanonicalFile
+    // Copy the checkpoint to a temp dir to prevent changes to the original.
+    // Not doing this would let the test pass on the first run but fail on subsequent runs.
+    FileUtils.copyDirectory(new File(resourceUri), checkpointDir)
+
+    inputData.addData(3)
+    inputData.addData(3, 2)
+
+    testStream(aggregated, Update)(
+      StartStream(checkpointLocation = checkpointDir.getAbsolutePath,
+        additionalConfs = Map(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "2")),
+      /*
+        Note: The checkpoint was generated using the following input in Spark version 2.3.1
+        AddData(inputData, 3),
+        CheckLastBatch((3, 1)),
+        AddData(inputData, 3, 2),
+        CheckLastBatch((3, 2), (2, 1))
+       */
+
+      AddData(inputData, 3, 2, 1),
+      CheckLastBatch((3, 3), (2, 2), (1, 1)),
+
+      Execute { query =>
+        // Verify that state format version 1 is used
+        val stateVersions = query.lastExecution.executedPlan.collect {
+          case f: StateStoreSaveExec => f.stateFormatVersion
+          case f: StateStoreRestoreExec => f.stateFormatVersion
+        }
+        assert(stateVersions.size == 2)
+        assert(stateVersions.forall(_ == 1))
+      },
+
+      // By default we run in new tuple mode.
+      AddData(inputData, 4, 4, 4, 4),
+      CheckLastBatch((4, 4))
+    )
+  }
+
+
 
   /** Add blocks of data to the `BlockRDDBackedSource`.
*/ case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData { override def addData(query: Option[StreamExecution]): (Source, Offset) = { From 19888abc281d7a0689bf57e4c76bda918ad9306b Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 21 Aug 2018 16:35:25 +0900 Subject: [PATCH 13/13] Remove .crc files as following up @tdas guide --- .../state/0/0/.1.delta.crc | Bin 12 -> 0 bytes .../state/0/0/.2.delta.crc | Bin 12 -> 0 bytes .../state/0/1/.1.delta.crc | Bin 12 -> 0 bytes .../state/0/1/.2.delta.crc | Bin 12 -> 0 bytes .../state/0/2/.1.delta.crc | Bin 12 -> 0 bytes .../state/0/2/.2.delta.crc | Bin 12 -> 0 bytes .../state/0/3/.1.delta.crc | Bin 12 -> 0 bytes .../state/0/3/.2.delta.crc | Bin 12 -> 0 bytes .../state/0/4/.1.delta.crc | Bin 12 -> 0 bytes .../state/0/4/.2.delta.crc | Bin 12 -> 0 bytes 10 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/.1.delta.crc delete mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/.2.delta.crc delete mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/.1.delta.crc delete mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/.2.delta.crc delete mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/.1.delta.crc delete mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/.2.delta.crc delete mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/.1.delta.crc delete mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/.2.delta.crc delete mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/.1.delta.crc delete mode 100644 sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/.2.delta.crc diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/.1.delta.crc deleted file mode 100644 index cf1d68e2acee3bca2b92320c4bafc702a6539ea0..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 12 TcmYc;N@ieSU}A7peP;>)5flQ* diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/.2.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/.2.delta.crc deleted file mode 100644 index cf1d68e2acee3bca2b92320c4bafc702a6539ea0..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 12 TcmYc;N@ieSU}A7peP;>)5flQ* diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/.1.delta.crc 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/.1.delta.crc deleted file mode 100644 index a395dee7224ea430536120f6d5af0b7d37aeb793..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 12 TcmYc;N@ieSU}C7h@Vg!W6k!Ae diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/.2.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/.2.delta.crc deleted file mode 100644 index b61bb872e950cb768d25284f23a3be27b3e0e132..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 12 TcmYc;N@ieSU}6w0e(D4O5jg_t diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/.1.delta.crc deleted file mode 100644 index cf1d68e2acee3bca2b92320c4bafc702a6539ea0..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 12 TcmYc;N@ieSU}A7peP;>)5flQ* diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/.2.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/.2.delta.crc deleted file mode 100644 index cf1d68e2acee3bca2b92320c4bafc702a6539ea0..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 12 TcmYc;N@ieSU}A7peP;>)5flQ* diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/.1.delta.crc deleted file mode 100644 index cf1d68e2acee3bca2b92320c4bafc702a6539ea0..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 12 TcmYc;N@ieSU}A7peP;>)5flQ* diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/.2.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/.2.delta.crc deleted file mode 100644 index cf1d68e2acee3bca2b92320c4bafc702a6539ea0..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 12 TcmYc;N@ieSU}A7peP;>)5flQ* diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/.1.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/.1.delta.crc deleted file mode 100644 index cf1d68e2acee3bca2b92320c4bafc702a6539ea0..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 12 TcmYc;N@ieSU}A7peP;>)5flQ* diff --git a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/.2.delta.crc b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/.2.delta.crc deleted file mode 100644 index efa0266fa0c0a2cca1437e1214e4a31dec3e5833..0000000000000000000000000000000000000000 GIT binary patch 
literal 0 HcmV?d00001 literal 12 TcmYc;N@ieSU}88n<9Qwc6r2O@
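
The key/value split implemented by StreamingAggregationStateManagerImplV2 earlier in this series can be pictured with a small standalone sketch. This is illustrative only: it uses plain Scala collections instead of UnsafeRow, StateStore and the generated projections, and the object and method names below (StateFormatV2Sketch, split, restore) are made up for the example, not part of this patch or of Spark.

    object StateFormatV2Sketch {
      // A "row" here is just an ordered list of (columnName, value) pairs.
      type Row = Seq[(String, Any)]

      // Columns forming the grouping key, analogous to keyExpressions.
      val keyColumns: Set[String] = Set("group")

      // Split an input row into (key, value). The value keeps only the non-key columns,
      // which is what state format version 2 stores; version 1 stores the whole input row.
      def split(row: Row): (Row, Row) =
        row.partition { case (name, _) => keyColumns.contains(name) }

      // Restore the original row by re-joining key and value and putting the columns back
      // into the original order, mirroring GenerateUnsafeRowJoiner plus the optional
      // restore projection in the V2 manager.
      def restore(key: Row, value: Row, originalOrder: Seq[String]): Row = {
        val joined = (key ++ value).toMap
        originalOrder.map(name => name -> joined(name))
      }

      def main(args: Array[String]): Unit = {
        val input: Row = Seq("group" -> 3, "count" -> 2L)
        val (k, v) = split(input)
        println(s"stored key   = $k")   // List((group,3))
        println(s"stored value = $v")   // List((count,2)) -- key columns are not duplicated
        println(s"restored row = ${restore(k, v, input.map(_._1))}")
      }
    }

Because the two formats persist different value schemas, the format a query starts with is the one it must keep using; the recovery test added above relies on this, expecting state format version 1 from the 2.3.1 checkpoint to stay in effect even though the restarted query sets SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION to 2.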