diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
index 62ab1d7404000..dab54c7becfc7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
@@ -25,10 +25,13 @@ import org.scalatest.concurrent.Timeouts
 import org.scalatest.time.SpanSugar._
 
+import scala.collection.mutable.ArrayBuffer
+import scala.util.Random
+
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.encoders.{RowEncoder, encoderFor}
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder, encoderFor}
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.streaming._
 
 /**
  * A framework for implementing tests for streaming queries and sources.
  *
@@ -286,4 +289,54 @@ trait StreamTest extends QueryTest with Timeouts {
       }
     }
   }
+
+  /**
+   * Creates a stress test that randomly starts/stops the stream, adds data, and checks the result.
+   *
+   * @param ds a Dataset that computes + 1 on a stream of integers, returning the result
+   * @param addData a function producing an action that adds the given numbers to the stream,
+   *                encoding them as needed
+   */
+  def createStressTest(ds: Dataset[Int], addData: Seq[Int] => StreamAction): Unit = {
+    implicit val intEncoder = ExpressionEncoder[Int]()
+    var dataPos = 0
+    var running = true
+    val actions = new ArrayBuffer[StreamAction]()
+
+    def addCheck() = { actions += CheckAnswer(1 to dataPos: _*) }
+
+    (1 to 500).foreach { _ =>
+      val rand = Random.nextDouble()
+      if (!running) {
+        rand match {
+          case r if r < 0.7 => // AddData
+            val numItems = Random.nextInt(10)
+            val data = dataPos until (dataPos + numItems)
+            dataPos += numItems
+            actions += addData(data)
+          case _ => // StartStream
+            actions += StartStream
+            running = true
+        }
+      } else {
+        rand match {
+          case r if r < 0.1 => // CheckAnswer
+            addCheck()
+
+          case r if r < 0.7 => // AddData
+            val numItems = Random.nextInt(10)
+            val data = dataPos until (dataPos + numItems)
+            dataPos += numItems
+            actions += addData(data)
+
+          case _ => // StopStream
+            actions += StopStream
+            running = false
+        }
+      }
+    }
+    if (!running) { actions += StartStream }
+    addCheck()
+    testStream(ds)(actions: _*)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySourceStressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySourceStressSuite.scala
new file mode 100644
index 0000000000000..b6e1fdad8d43f
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySourceStressSuite.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.apache.spark.sql.StreamTest
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.test.SharedSQLContext
+
+class MemorySourceStressSuite extends StreamTest with SharedSQLContext {
+  import testImplicits._
+
+  test("memory stress test") {
+    val input = MemoryStream[Int]
+    val mapped = input.toDS().map(_ + 1)
+
+    createStressTest(mapped, AddData(input, _: _*))
+  }
+}
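The randomized schedule that createStressTest builds can be previewed outside of Spark. The standalone sketch below is illustrative only: the Action hierarchy and the StressPlanSketch object are hypothetical stand-ins for StreamTest's StreamAction types, but the start/stop/add/check logic mirrors the patch, so running it prints the kind of action plan that would ultimately be handed to testStream.

import scala.collection.mutable.ArrayBuffer
import scala.util.Random

// Illustrative stand-ins for StreamTest's StreamAction hierarchy (not the real types).
sealed trait Action
case object StartStream extends Action
case object StopStream extends Action
case class AddData(data: Seq[Int]) extends Action
case class CheckAnswer(expected: Seq[Int]) extends Action

object StressPlanSketch {
  def main(args: Array[String]): Unit = {
    var dataPos = 0
    var running = true
    val actions = new ArrayBuffer[Action]()

    // Expect every input seen so far, shifted by the + 1 the query applies.
    def addCheck(): Unit = actions += CheckAnswer(1 to dataPos)

    (1 to 20).foreach { _ => // 20 steps instead of 500, to keep the printed plan short
      val r = Random.nextDouble()
      if (!running) {
        if (r < 0.7) { // add data while the stream is stopped
          val numItems = Random.nextInt(10)
          actions += AddData(dataPos until (dataPos + numItems))
          dataPos += numItems
        } else { // restart the stream
          actions += StartStream
          running = true
        }
      } else {
        if (r < 0.1) { // occasionally verify all results produced so far
          addCheck()
        } else if (r < 0.7) { // add data while the stream is running
          val numItems = Random.nextInt(10)
          actions += AddData(dataPos until (dataPos + numItems))
          dataPos += numItems
        } else { // stop the stream mid-run
          actions += StopStream
          running = false
        }
      }
    }
    if (!running) actions += StartStream // the final check needs a live stream
    addCheck()
    actions.foreach(println)
  }
}

Because the plan always ends with a StartStream (when needed) followed by a final CheckAnswer, every generated sequence verifies the complete 1 to dataPos output regardless of where the random walk left off.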