\" [%s]: " %
distinct_authors[0])
if primary_author == "":
@@ -184,7 +187,7 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc):
def cherry_pick(pr_num, merge_hash, default_branch):
- pick_ref = raw_input("Enter a branch name [%s]: " % default_branch)
+ pick_ref = input("Enter a branch name [%s]: " % default_branch)
if pick_ref == "":
pick_ref = default_branch
@@ -231,7 +234,7 @@ def resolve_jira_issue(merge_branches, comment, default_jira_id=""):
asf_jira = jira.client.JIRA({'server': JIRA_API_BASE},
basic_auth=(JIRA_USERNAME, JIRA_PASSWORD))
- jira_id = raw_input("Enter a JIRA id [%s]: " % default_jira_id)
+ jira_id = input("Enter a JIRA id [%s]: " % default_jira_id)
if jira_id == "":
jira_id = default_jira_id
@@ -276,7 +279,7 @@ def resolve_jira_issue(merge_branches, comment, default_jira_id=""):
default_fix_versions = filter(lambda x: x != v, default_fix_versions)
default_fix_versions = ",".join(default_fix_versions)
- fix_versions = raw_input("Enter comma-separated fix version(s) [%s]: " % default_fix_versions)
+ fix_versions = input("Enter comma-separated fix version(s) [%s]: " % default_fix_versions)
if fix_versions == "":
fix_versions = default_fix_versions
fix_versions = fix_versions.replace(" ", "").split(",")
@@ -315,7 +318,7 @@ def choose_jira_assignee(issue, asf_jira):
if author in commentors:
annotations.append("Commentor")
print("[%d] %s (%s)" % (idx, author.displayName, ",".join(annotations)))
- raw_assignee = raw_input(
+ raw_assignee = input(
"Enter number of user, or userid, to assign to (blank to leave unassigned):")
if raw_assignee == "":
return None
@@ -428,7 +431,7 @@ def main():
# Assumes branch names can be sorted lexicographically
latest_branch = sorted(branch_names, reverse=True)[0]
- pr_num = raw_input("Which pull request would you like to merge? (e.g. 34): ")
+ pr_num = input("Which pull request would you like to merge? (e.g. 34): ")
pr = get_json("%s/pulls/%s" % (GITHUB_API_BASE, pr_num))
pr_events = get_json("%s/issues/%s/events" % (GITHUB_API_BASE, pr_num))
@@ -440,7 +443,7 @@ def main():
print("I've re-written the title as follows to match the standard format:")
print("Original: %s" % pr["title"])
print("Modified: %s" % modified_title)
- result = raw_input("Would you like to use the modified title? (y/n): ")
+ result = input("Would you like to use the modified title? (y/n): ")
if result.lower() == "y":
title = modified_title
print("Using modified title:")
@@ -491,7 +494,7 @@ def main():
merge_hash = merge_pr(pr_num, target_ref, title, body, pr_repo_desc)
pick_prompt = "Would you like to pick %s into another branch?" % merge_hash
- while raw_input("\n%s (y/n): " % pick_prompt).lower() == "y":
+ while input("\n%s (y/n): " % pick_prompt).lower() == "y":
merged_refs = merged_refs + [cherry_pick(pr_num, merge_hash, latest_branch)]
if JIRA_IMPORTED:
diff --git a/dev/requirements.txt b/dev/requirements.txt
index 79782279f8fbd..fa833ab96b8e7 100644
--- a/dev/requirements.txt
+++ b/dev/requirements.txt
@@ -2,3 +2,4 @@ jira==1.0.3
PyGithub==1.26.0
Unidecode==0.04.19
pypandoc==1.3.3
+sphinx
diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md
index 408e446ea4822..7149616e534aa 100644
--- a/docs/running-on-kubernetes.md
+++ b/docs/running-on-kubernetes.md
@@ -629,6 +629,54 @@ specific to Spark on Kubernetes.
    Add as an environment variable to the executor container with name <code>EnvName</code> (case sensitive), the value referenced by key <code>key</code>
    in the data of the referenced Kubernetes Secret. For example,
    <code>spark.kubernetes.executor.secrets.ENV_VAR=spark-secret:key</code>.
  </td>
</tr>
+<tr>
+  <td><code>spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.path</code></td>
+  <td>(none)</td>
+  <td>
+   Add the Kubernetes Volume named <code>VolumeName</code> of the <code>VolumeType</code> type to the driver pod on the path specified in the value. For example,
+   <code>spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.mount.path=/checkpoint</code>.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.readOnly</code></td>
+  <td>(none)</td>
+  <td>
+   Specify if the mounted volume is read only or not. For example,
+   <code>spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.mount.readOnly=false</code>.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].options.[OptionName]</code></td>
+  <td>(none)</td>
+  <td>
+   Configure Kubernetes Volume options passed to Kubernetes with <code>OptionName</code> as the key and the specified value; these must conform to the Kubernetes option format. For example,
+   <code>spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.options.claimName=spark-pvc-claim</code>.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].mount.path</code></td>
+  <td>(none)</td>
+  <td>
+   Add the Kubernetes Volume named <code>VolumeName</code> of the <code>VolumeType</code> type to the executor pod on the path specified in the value. For example,
+   <code>spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.mount.path=/checkpoint</code>.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].mount.readOnly</code></td>
+  <td>false</td>
+  <td>
+   Specify if the mounted volume is read only or not. For example,
+   <code>spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.mount.readOnly=false</code>.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].options.[OptionName]</code></td>
+  <td>(none)</td>
+  <td>
+   Configure Kubernetes Volume options passed to Kubernetes with <code>OptionName</code> as the key and the specified value. For example,
+   <code>spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.options.claimName=spark-pvc-claim</code>.
+  </td>
+</tr>
<tr>
  <td><code>spark.kubernetes.memoryOverheadFactor</code></td>
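As a quick illustration of how the new volume properties compose, a minimal PySpark sketch (the volume, claim and mount path names are illustrative, not taken from this patch):

    from pyspark import SparkConf

    # Mount an existing PVC into both the driver and the executors (illustrative names).
    conf = (SparkConf()
            .set("spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.mount.path", "/checkpoint")
            .set("spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.options.claimName", "spark-pvc-claim")
            .set("spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.mount.path", "/checkpoint")
            .set("spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.options.claimName", "spark-pvc-claim"))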
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index 575da7205b529..0b265b0cb1b31 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -218,9 +218,10 @@ To use a custom metrics.properties for the application master and executors, upd
  <td><code>spark.yarn.dist.forceDownloadSchemes</code></td>
  <td>(none)</td>
  <td>
-    Comma-separated list of schemes for which files will be downloaded to the local disk prior to
+    Comma-separated list of schemes for which resources will be downloaded to the local disk prior to
    being added to YARN's distributed cache. For use in cases where the YARN service does not
-    support schemes that are supported by Spark, like http, https and ftp.
+    support schemes that are supported by Spark, like http, https and ftp, or jars required to be in the
+    local YARN client's classpath. The wildcard '*' downloads resources for all schemes.
  </td>
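A minimal sketch of the new wildcard value (hypothetical usage, not part of this patch):

    from pyspark import SparkConf

    # '*' forces resources of every scheme to be downloaded to the local disk first.
    conf = SparkConf().set("spark.yarn.dist.forceDownloadSchemes", "*")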
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index cd7329b621122..ad23dae7c6b7c 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1850,6 +1850,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see
- Since Spark 2.4, writing a dataframe with an empty or nested empty schema using any file format (parquet, orc, json, text, csv etc.) is not allowed. An exception is thrown when attempting to write dataframes with an empty schema.
- Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after promoting both sides to TIMESTAMP. Setting `spark.sql.hive.compareDateTimestampInTimestamp` to `false` restores the previous behavior. This option will be removed in Spark 3.0.
- Since Spark 2.4, creating a managed table with a nonempty location is not allowed. An exception is thrown when attempting to create a managed table with a nonempty location. Setting `spark.sql.allowCreatingManagedTableUsingNonemptyLocation` to `true` restores the previous behavior. This option will be removed in Spark 3.0.
+ - Since Spark 2.4, renaming a managed table to an existing location is not allowed. An exception is thrown when attempting to rename a managed table to an existing location.
- Since Spark 2.4, the type coercion rules can automatically promote the argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest common type, regardless of the order of the input arguments. In prior Spark versions, the promotion could fail in some specific orders (e.g., TimestampType, IntegerType and StringType) and throw an exception.
- Since Spark 2.4, Spark has enabled non-cascading SQL cache invalidation in addition to the traditional cache invalidation mechanism. The non-cascading cache invalidation mechanism allows users to remove a cache without impacting its dependent caches. This new cache invalidation mechanism is used in scenarios where the data of the cache to be removed is still valid, e.g., calling unpersist() on a Dataset, or dropping a temporary view. This allows users to free up memory and keep the desired caches valid at the same time.
- In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these two functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains a timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behavior to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround.
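A small PySpark sketch of the 2.4 `from_utc_timestamp` change and the temporary flag mentioned above, assuming an active SparkSession named `spark`:

    # Returns null in 2.4 because the input string carries an explicit timezone offset.
    spark.sql("SELECT from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')").show()

    # Temporary escape hatch to restore the 2.3 behavior; removed in Spark 3.0.
    spark.conf.set("spark.sql.function.rejectTimezoneInString", "false")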
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index c30959263cdfa..118b05355c74d 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -2176,6 +2176,8 @@ the input data stream (using `inputStream.repartition()`).
This distributes the received batches of data across the specified number of machines in the cluster
before further processing.
+For the direct stream, please refer to the [Spark Streaming + Kafka Integration Guide](streaming-kafka-integration.html).
+
### Level of Parallelism in Data Processing
{:.no_toc}
Cluster resources can be under-utilized if the number of parallel tasks used in any stage of the
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java
index b6b163fa8b2cd..748bf58f30350 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java
@@ -26,7 +26,9 @@
import scala.Tuple2;
+import org.apache.kafka.clients.consumer.ConsumerConfig;
import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.common.serialization.StringDeserializer;
import org.apache.spark.SparkConf;
import org.apache.spark.streaming.api.java.*;
@@ -37,30 +39,33 @@
/**
* Consumes messages from one or more topics in Kafka and does wordcount.
- * Usage: JavaDirectKafkaWordCount <brokers> <topics>
+ * Usage: JavaDirectKafkaWordCount <brokers> <groupId> <topics>
*   <brokers> is a list of one or more Kafka brokers
+ *   <groupId> is a consumer group name to consume from topics
*   <topics> is a list of one or more kafka topics to consume from
*
* Example:
* $ bin/run-example streaming.JavaDirectKafkaWordCount broker1-host:port,broker2-host:port \
- * topic1,topic2
+ * consumer-group topic1,topic2
*/
public final class JavaDirectKafkaWordCount {
private static final Pattern SPACE = Pattern.compile(" ");
public static void main(String[] args) throws Exception {
- if (args.length < 2) {
- System.err.println("Usage: JavaDirectKafkaWordCount <brokers> <topics>\n" +
- "  <brokers> is a list of one or more Kafka brokers\n" +
- "  <topics> is a list of one or more kafka topics to consume from\n\n");
+ if (args.length < 3) {
+ System.err.println("Usage: JavaDirectKafkaWordCount <brokers> <groupId> <topics>\n" +
+ "  <brokers> is a list of one or more Kafka brokers\n" +
+ "  <groupId> is a consumer group name to consume from topics\n" +
+ "  <topics> is a list of one or more kafka topics to consume from\n\n");
System.exit(1);
}
StreamingExamples.setStreamingLogLevels();
String brokers = args[0];
- String topics = args[1];
+ String groupId = args[1];
+ String topics = args[2];
// Create context with a 2 seconds batch interval
SparkConf sparkConf = new SparkConf().setAppName("JavaDirectKafkaWordCount");
@@ -68,7 +73,10 @@ public static void main(String[] args) throws Exception {
Set<String> topicsSet = new HashSet<>(Arrays.asList(topics.split(",")));
Map<String, String> kafkaParams = new HashMap<>();
- kafkaParams.put("metadata.broker.list", brokers);
+ kafkaParams.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, brokers);
+ kafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, groupId);
+ kafkaParams.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class);
+ kafkaParams.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class);
// Create direct kafka stream with brokers and topics
JavaInputDStream<ConsumerRecord<String, String>> messages = KafkaUtils.createDirectStream(
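For readability, the consumer settings the rewritten example passes via `ConsumerConfig` constants, spelled out as a plain Python dict (broker and group names are the example's own placeholders):

    kafka_params = {
        "bootstrap.servers": "broker1-host:port,broker2-host:port",
        "group.id": "consumer-group",
        "key.deserializer": "org.apache.kafka.common.serialization.StringDeserializer",
        "value.deserializer": "org.apache.kafka.common.serialization.StringDeserializer",
    }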
diff --git a/graphx/pom.xml b/graphx/pom.xml
index fbe77fcb958d5..0f5dc548600b2 100644
--- a/graphx/pom.xml
+++ b/graphx/pom.xml
@@ -53,7 +53,7 @@
<groupId>org.apache.xbean</groupId>
- <artifactId>xbean-asm5-shaded</artifactId>
+ <artifactId>xbean-asm6-shaded</artifactId>
<groupId>com.google.guava</groupId>
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala
index d76e84ed8c9ed..a559685b1633c 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala
@@ -22,8 +22,8 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import scala.collection.mutable.HashSet
import scala.language.existentials
-import org.apache.xbean.asm5.{ClassReader, ClassVisitor, MethodVisitor}
-import org.apache.xbean.asm5.Opcodes._
+import org.apache.xbean.asm6.{ClassReader, ClassVisitor, MethodVisitor}
+import org.apache.xbean.asm6.Opcodes._
import org.apache.spark.util.Utils
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
index b7072728d48f0..55b460f1a4524 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.clustering
+import scala.collection.mutable
+
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -76,12 +78,15 @@ class PowerIterationClusteringSuite extends SparkFunSuite
.setMaxIter(40)
.setWeightCol("weight")
.assignClusters(data)
- val localAssignments = assignments
- .select('id, 'cluster)
- .as[(Long, Int)].collect().toSet
- val expectedResult = (0 until n1).map(x => (x, 1)).toSet ++
- (n1 until n).map(x => (x, 0)).toSet
- assert(localAssignments === expectedResult)
+ .select("id", "cluster")
+ .as[(Long, Int)]
+ .collect()
+
+ val predictions = Array.fill(2)(mutable.Set.empty[Long])
+ assignments.foreach {
+ case (id, cluster) => predictions(cluster) += id
+ }
+ assert(predictions.toSet === Set((0 until n1).toSet, (n1 until n).toSet))
val assignments2 = new PowerIterationClustering()
.setK(2)
@@ -89,10 +94,15 @@ class PowerIterationClusteringSuite extends SparkFunSuite
.setInitMode("degree")
.setWeightCol("weight")
.assignClusters(data)
- val localAssignments2 = assignments2
- .select('id, 'cluster)
- .as[(Long, Int)].collect().toSet
- assert(localAssignments2 === expectedResult)
+ .select("id", "cluster")
+ .as[(Long, Int)]
+ .collect()
+
+ val predictions2 = Array.fill(2)(mutable.Set.empty[Long])
+ assignments2.foreach {
+ case (id, cluster) => predictions2(cluster) += id
+ }
+ assert(predictions2.toSet === Set((0 until n1).toSet, (n1 until n).toSet))
}
test("supported input types") {
diff --git a/pom.xml b/pom.xml
index 90e64ff71d229..cd567e227f331 100644
--- a/pom.xml
+++ b/pom.xml
@@ -313,13 +313,13 @@
<artifactId>chill-java</artifactId>
<version>${chill.version}</version>
-
<groupId>org.apache.xbean</groupId>
- <artifactId>xbean-asm5-shaded</artifactId>
- <version>4.4</version>
+ <artifactId>xbean-asm6-shaded</artifactId>
+ <version>4.8</version>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index b606f9355e03b..f887e4570c85d 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -464,7 +464,8 @@ object DockerIntegrationTests {
*/
object DependencyOverrides {
lazy val settings = Seq(
- dependencyOverrides += "com.google.guava" % "guava" % "14.0.1")
+ dependencyOverrides += "com.google.guava" % "guava" % "14.0.1",
+ dependencyOverrides += "jline" % "jline" % "2.14.3")
}
/**
diff --git a/python/docs/Makefile b/python/docs/Makefile
index b8e079483c90c..1ed1f33af2326 100644
--- a/python/docs/Makefile
+++ b/python/docs/Makefile
@@ -1,19 +1,44 @@
# Makefile for Sphinx documentation
#
+ifndef SPHINXBUILD
+ifndef SPHINXPYTHON
+SPHINXBUILD = sphinx-build
+endif
+endif
+
+ifdef SPHINXBUILD
+# User-friendly check for sphinx-build.
+ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1)
+$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/)
+endif
+else
+# Note that there is an issue with Python version and Sphinx in PySpark documentation generation.
+# Please remove this check below when this issue is fixed. See SPARK-24530 for more details.
+PYTHON_VERSION_CHECK = $(shell $(SPHINXPYTHON) -c 'import sys; print(sys.version_info < (3, 0, 0))')
+ifeq ($(PYTHON_VERSION_CHECK), True)
+$(error Note that Python 3 is required to generate PySpark documentation correctly for now. Current Python executable was less than Python 3. See SPARK-24530. To force Sphinx to use a specific Python executable, please set SPHINXPYTHON to point to the Python 3 executable.)
+endif
+# Check if Sphinx is installed.
+ifeq ($(shell $(SPHINXPYTHON) -c 'import sphinx' >/dev/null 2>&1; echo $$?), 1)
+$(error Python executable '$(SPHINXPYTHON)' did not have Sphinx installed. Make sure you have Sphinx installed, then set the SPHINXPYTHON environment variable to point to the Python executable having Sphinx installed. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/)
+endif
+# Use 'SPHINXPYTHON -msphinx' instead of 'sphinx-build'. See https://github.com/sphinx-doc/sphinx/pull/3523 for more details.
+SPHINXBUILD = $(SPHINXPYTHON) -msphinx
+endif
+
# You can set these variables from the command line.
SPHINXOPTS ?=
-SPHINXBUILD ?= sphinx-build
PAPER ?=
BUILDDIR ?= _build
+# You can set SPHINXBUILD to specify the Sphinx build executable, or SPHINXPYTHON to specify the Python executable used by Sphinx.
+# They are resolved as follows:
+#   1. If SPHINXPYTHON is set, use that Python. If SPHINXBUILD is set, use sphinx-build.
+#   2. If both are set, SPHINXBUILD takes priority over SPHINXPYTHON.
+#   3. By default, SPHINXBUILD is used as 'sphinx-build'.
export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.7-src.zip)
-# User-friendly check for sphinx-build
-ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1)
-$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/)
-endif
-
# Internal variables.
PAPEROPT_a4 = -D latex_paper_size=a4
PAPEROPT_letter = -D latex_paper_size=letter
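The SPARK-24530 guard added above boils down to this one-line check, run with whatever interpreter SPHINXPYTHON points at (a standalone sketch):

    import sys

    # The Makefile aborts when this prints True, i.e. when the interpreter is not Python 3.
    print(sys.version_info < (3, 0, 0))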
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 6d77baf7349e4..2f0660040dc7c 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -1345,8 +1345,14 @@ def assignClusters(self, dataset):
if __name__ == "__main__":
import doctest
+ import numpy
import pyspark.ml.clustering
from pyspark.sql import SparkSession
+ try:
+ # Numpy 1.14+ changed its string format.
+ numpy.set_printoptions(legacy='1.13')
+ except TypeError:
+ pass
globs = pyspark.ml.clustering.__dict__.copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
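The `legacy='1.13'` guard that now appears in several `_test()` blocks pins NumPy's array repr to the pre-1.14 style so existing doctest output keeps matching; a standalone sketch (requires NumPy, outputs shown as comments):

    import numpy as np

    print(np.array([0.1, 2.0]))        # NumPy 1.14+ default style, e.g. [0.1 2. ]
    np.set_printoptions(legacy='1.13')
    print(np.array([0.1, 2.0]))        # 1.13-style output, e.g. [ 0.1  2. ]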
diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py
index 6a611a2b5b59d..2548fd0f50b33 100644
--- a/python/pyspark/ml/linalg/__init__.py
+++ b/python/pyspark/ml/linalg/__init__.py
@@ -1156,6 +1156,11 @@ def sparse(numRows, numCols, colPtrs, rowIndices, values):
def _test():
import doctest
+ try:
+ # Numpy 1.14+ changed its string format.
+ np.set_printoptions(legacy='1.13')
+ except TypeError:
+ pass
(failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS)
if failure_count:
sys.exit(-1)
diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py
index a06ab31a7a56a..370154fc6d62a 100644
--- a/python/pyspark/ml/stat.py
+++ b/python/pyspark/ml/stat.py
@@ -388,8 +388,14 @@ def summary(self, featuresCol, weightCol=None):
if __name__ == "__main__":
import doctest
+ import numpy
import pyspark.ml.stat
from pyspark.sql import SparkSession
+ try:
+ # Numpy 1.14+ changed its string format.
+ numpy.set_printoptions(legacy='1.13')
+ except TypeError:
+ pass
globs = pyspark.ml.stat.__dict__.copy()
# The small batch size here ensures that we see multiple batches,
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 080cd299f4fde..e846834761e49 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -63,7 +63,7 @@ def _randomUID(cls):
Generate a unique unicode id for the object. The default implementation
concatenates the class name, "_", and 12 random hex chars.
"""
- return unicode(cls.__name__ + "_" + uuid.uuid4().hex[12:])
+ return unicode(cls.__name__ + "_" + uuid.uuid4().hex[-12:])
@inherit_doc
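The slice fix above changes how many characters survive; a quick standalone check:

    import uuid

    h = uuid.uuid4().hex   # 32 hexadecimal characters
    print(len(h[12:]))     # 20 -- the old slice kept everything after the first 12 characters
    print(len(h[-12:]))    # 12 -- the fix keeps exactly the last 12, matching the docstring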
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index 0cbabab13a896..b09469b9f5c2d 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -1042,7 +1042,13 @@ def train(cls, rdd, k=10, maxIterations=20, docConcentration=-1.0,
def _test():
import doctest
+ import numpy
import pyspark.mllib.clustering
+ try:
+ # Numpy 1.14+ changed its string format.
+ numpy.set_printoptions(legacy='1.13')
+ except TypeError:
+ pass
globs = pyspark.mllib.clustering.__dict__.copy()
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
index 36cb03369b8c0..6c65da58e4e2b 100644
--- a/python/pyspark/mllib/evaluation.py
+++ b/python/pyspark/mllib/evaluation.py
@@ -532,8 +532,14 @@ def accuracy(self):
def _test():
import doctest
+ import numpy
from pyspark.sql import SparkSession
import pyspark.mllib.evaluation
+ try:
+ # Numpy 1.14+ changed its string format.
+ numpy.set_printoptions(legacy='1.13')
+ except TypeError:
+ pass
globs = pyspark.mllib.evaluation.__dict__.copy()
spark = SparkSession.builder\
.master("local[4]")\
diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py
index 60d96d8d5ceb8..4afd6666400b0 100644
--- a/python/pyspark/mllib/linalg/__init__.py
+++ b/python/pyspark/mllib/linalg/__init__.py
@@ -1368,6 +1368,12 @@ def R(self):
def _test():
import doctest
+ import numpy
+ try:
+ # Numpy 1.14+ changed its string format.
+ numpy.set_printoptions(legacy='1.13')
+ except TypeError:
+ pass
(failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS)
if failure_count:
sys.exit(-1)
diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py
index bba88542167ad..7e8b15056cabe 100644
--- a/python/pyspark/mllib/linalg/distributed.py
+++ b/python/pyspark/mllib/linalg/distributed.py
@@ -1364,9 +1364,15 @@ def toCoordinateMatrix(self):
def _test():
import doctest
+ import numpy
from pyspark.sql import SparkSession
from pyspark.mllib.linalg import Matrices
import pyspark.mllib.linalg.distributed
+ try:
+ # Numpy 1.14+ changed its string format.
+ numpy.set_printoptions(legacy='1.13')
+ except TypeError:
+ pass
globs = pyspark.mllib.linalg.distributed.__dict__.copy()
spark = SparkSession.builder\
.master("local[2]")\
diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py
index 3c75b132ecad2..937bb154c2356 100644
--- a/python/pyspark/mllib/stat/_statistics.py
+++ b/python/pyspark/mllib/stat/_statistics.py
@@ -303,7 +303,13 @@ def kolmogorovSmirnovTest(data, distName="norm", *params):
def _test():
import doctest
+ import numpy
from pyspark.sql import SparkSession
+ try:
+ # Numpy 1.14+ changed its string format.
+ numpy.set_printoptions(legacy='1.13')
+ except TypeError:
+ pass
globs = globals().copy()
spark = SparkSession.builder\
.master("local[4]")\
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 7e7e5822a6b20..951851804b1d8 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -1370,7 +1370,10 @@ def takeUpToNumLeft(iterator):
iterator = iter(iterator)
taken = 0
while taken < left:
- yield next(iterator)
+ try:
+ yield next(iterator)
+ except StopIteration:
+ return
taken += 1
p = range(partsScanned, min(partsScanned + numPartsToTry, totalParts))
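Why the explicit `except StopIteration` is needed: under PEP 479 (enforced from Python 3.7), a StopIteration escaping from inside a generator is converted into a RuntimeError instead of silently ending the generator. A standalone sketch of the same pattern:

    def take_up_to(iterator, n):
        it = iter(iterator)
        taken = 0
        while taken < n:
            try:
                yield next(it)     # raises StopIteration once 'it' is exhausted
            except StopIteration:
                return             # end the generator cleanly instead of raising RuntimeError
            taken += 1

    print(list(take_up_to([1, 2], 5)))  # [1, 2]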
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 9652d3e79b875..9f61e29f9cd42 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1285,11 +1285,21 @@ def from_utc_timestamp(timestamp, tz):
that time as a timestamp in the given time zone. For example, 'GMT+1' would yield
'2017-07-14 03:40:00.0'.
- >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
- >>> df.select(from_utc_timestamp(df.t, "PST").alias('local_time')).collect()
+ :param timestamp: the column that contains timestamps
+ :param tz: a string that has the ID of timezone, e.g. "GMT", "America/Los_Angeles", etc
+
+ .. versionchanged:: 2.4
+ `tz` can take a :class:`Column` containing timezone ID strings.
+
+ >>> df = spark.createDataFrame([('1997-02-28 10:30:00', 'JST')], ['ts', 'tz'])
+ >>> df.select(from_utc_timestamp(df.ts, "PST").alias('local_time')).collect()
[Row(local_time=datetime.datetime(1997, 2, 28, 2, 30))]
+ >>> df.select(from_utc_timestamp(df.ts, df.tz).alias('local_time')).collect()
+ [Row(local_time=datetime.datetime(1997, 2, 28, 19, 30))]
"""
sc = SparkContext._active_spark_context
+ if isinstance(tz, Column):
+ tz = _to_java_column(tz)
return Column(sc._jvm.functions.from_utc_timestamp(_to_java_column(timestamp), tz))
@@ -1300,11 +1310,21 @@ def to_utc_timestamp(timestamp, tz):
zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield
'2017-07-14 01:40:00.0'.
- >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['ts'])
+ :param timestamp: the column that contains timestamps
+ :param tz: a string that has the ID of timezone, e.g. "GMT", "America/Los_Angeles", etc
+
+ .. versionchanged:: 2.4
+ `tz` can take a :class:`Column` containing timezone ID strings.
+
+ >>> df = spark.createDataFrame([('1997-02-28 10:30:00', 'JST')], ['ts', 'tz'])
>>> df.select(to_utc_timestamp(df.ts, "PST").alias('utc_time')).collect()
[Row(utc_time=datetime.datetime(1997, 2, 28, 18, 30))]
+ >>> df.select(to_utc_timestamp(df.ts, df.tz).alias('utc_time')).collect()
+ [Row(utc_time=datetime.datetime(1997, 2, 28, 1, 30))]
"""
sc = SparkContext._active_spark_context
+ if isinstance(tz, Column):
+ tz = _to_java_column(tz)
return Column(sc._jvm.functions.to_utc_timestamp(_to_java_column(timestamp), tz))
@@ -2189,11 +2209,16 @@ def from_json(col, schema, options={}):
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(from_json(df.value, schema).alias("json")).collect()
[Row(json=[Row(a=1)])]
+ >>> schema = schema_of_json(lit('''{"a": 0}'''))
+ >>> df.select(from_json(df.value, schema).alias("json")).collect()
+ [Row(json=Row(a=1))]
"""
sc = SparkContext._active_spark_context
if isinstance(schema, DataType):
schema = schema.json()
+ elif isinstance(schema, Column):
+ schema = _to_java_column(schema)
jc = sc._jvm.functions.from_json(_to_java_column(col), schema, options)
return Column(jc)
@@ -2235,6 +2260,28 @@ def to_json(col, options={}):
return Column(jc)
+@ignore_unicode_prefix
+@since(2.4)
+def schema_of_json(col):
+ """
+ Parses a column containing a JSON string and infers its schema in DDL format.
+
+ :param col: string column in json format
+
+ >>> from pyspark.sql.types import *
+ >>> data = [(1, '{"a": 1}')]
+ >>> df = spark.createDataFrame(data, ("key", "value"))
+ >>> df.select(schema_of_json(df.value).alias("json")).collect()
+ [Row(json=u'struct<a:bigint>')]
+ >>> df.select(schema_of_json(lit('{"a": 0}')).alias("json")).collect()
+ [Row(json=u'struct<a:bigint>')]
+ """
+
+ sc = SparkContext._active_spark_context
+ jc = sc._jvm.functions.schema_of_json(_to_java_column(col))
+ return Column(jc)
+
+
@since(1.5)
def size(col):
"""
@@ -2463,6 +2510,28 @@ def arrays_zip(*cols):
return Column(sc._jvm.functions.arrays_zip(_to_seq(sc, cols, _to_java_column)))
+@since(2.4)
+def map_concat(*cols):
+ """Returns the union of all the given maps.
+
+ :param cols: list of column names (string) or list of :class:`Column` expressions
+
+ >>> from pyspark.sql.functions import map_concat
+ >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as map1, map(3, 'c', 1, 'd') as map2")
+ >>> df.select(map_concat("map1", "map2").alias("map3")).show(truncate=False)
+ +--------------------------------+
+ |map3 |
+ +--------------------------------+
+ |[1 -> a, 2 -> b, 3 -> c, 1 -> d]|
+ +--------------------------------+
+ """
+ sc = SparkContext._active_spark_context
+ if len(cols) == 1 and isinstance(cols[0], (list, set)):
+ cols = cols[0]
+ jc = sc._jvm.functions.map_concat(_to_seq(sc, cols, _to_java_column))
+ return Column(jc)
+
+
# ---------------------------- User Defined Function ----------------------------------
class PandasUDFType(object):
diff --git a/python/setup.py b/python/setup.py
index d309e0564530a..45eb74eb87ce7 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -219,6 +219,7 @@ def _supports_symlinks():
'Programming Language :: Python :: 3.4',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
+ 'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: Implementation :: CPython',
'Programming Language :: Python :: Implementation :: PyPy']
)
diff --git a/repl/pom.xml b/repl/pom.xml
index 6f4a863c48bc7..861bbd7c49654 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -102,7 +102,7 @@
<groupId>org.apache.xbean</groupId>
- <artifactId>xbean-asm5-shaded</artifactId>
+ <artifactId>xbean-asm6-shaded</artifactId>
@@ -166,7 +166,7 @@
-
+
scala-2.12
diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
index 4dc399827ffed..42298b06a2c86 100644
--- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
@@ -22,8 +22,8 @@ import java.net.{URI, URL, URLEncoder}
import java.nio.channels.Channels
import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.xbean.asm5._
-import org.apache.xbean.asm5.Opcodes._
+import org.apache.xbean.asm6._
+import org.apache.xbean.asm6.Opcodes._
import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.deploy.SparkHadoopUtil
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala
index bf33179ae3dab..f9a77e71ad618 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala
@@ -220,11 +220,23 @@ private[spark] object Config extends Logging {
val KUBERNETES_DRIVER_ANNOTATION_PREFIX = "spark.kubernetes.driver.annotation."
val KUBERNETES_DRIVER_SECRETS_PREFIX = "spark.kubernetes.driver.secrets."
val KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX = "spark.kubernetes.driver.secretKeyRef."
+ val KUBERNETES_DRIVER_VOLUMES_PREFIX = "spark.kubernetes.driver.volumes."
val KUBERNETES_EXECUTOR_LABEL_PREFIX = "spark.kubernetes.executor.label."
val KUBERNETES_EXECUTOR_ANNOTATION_PREFIX = "spark.kubernetes.executor.annotation."
val KUBERNETES_EXECUTOR_SECRETS_PREFIX = "spark.kubernetes.executor.secrets."
val KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX = "spark.kubernetes.executor.secretKeyRef."
+ val KUBERNETES_EXECUTOR_VOLUMES_PREFIX = "spark.kubernetes.executor.volumes."
+
+ val KUBERNETES_VOLUMES_HOSTPATH_TYPE = "hostPath"
+ val KUBERNETES_VOLUMES_PVC_TYPE = "persistentVolumeClaim"
+ val KUBERNETES_VOLUMES_EMPTYDIR_TYPE = "emptyDir"
+ val KUBERNETES_VOLUMES_MOUNT_PATH_KEY = "mount.path"
+ val KUBERNETES_VOLUMES_MOUNT_READONLY_KEY = "mount.readOnly"
+ val KUBERNETES_VOLUMES_OPTIONS_PATH_KEY = "options.path"
+ val KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY = "options.claimName"
+ val KUBERNETES_VOLUMES_OPTIONS_MEDIUM_KEY = "options.medium"
+ val KUBERNETES_VOLUMES_OPTIONS_SIZE_LIMIT_KEY = "options.sizeLimit"
val KUBERNETES_DRIVER_ENV_PREFIX = "spark.kubernetes.driverEnv."
}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala
index 69bd03d1eda6f..5ecdd3a04d77b 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala
@@ -25,9 +25,6 @@ private[spark] object Constants {
val SPARK_POD_DRIVER_ROLE = "driver"
val SPARK_POD_EXECUTOR_ROLE = "executor"
- // Annotations
- val SPARK_APP_NAME_ANNOTATION = "spark-app-name"
-
// Credentials secrets
val DRIVER_CREDENTIALS_SECRETS_BASE_DIR =
"/mnt/secrets/spark-kubernetes-credentials"
@@ -50,17 +47,14 @@ private[spark] object Constants {
val DEFAULT_BLOCKMANAGER_PORT = 7079
val DRIVER_PORT_NAME = "driver-rpc-port"
val BLOCK_MANAGER_PORT_NAME = "blockmanager"
- val EXECUTOR_PORT_NAME = "executor"
// Environment Variables
- val ENV_EXECUTOR_PORT = "SPARK_EXECUTOR_PORT"
val ENV_DRIVER_URL = "SPARK_DRIVER_URL"
val ENV_EXECUTOR_CORES = "SPARK_EXECUTOR_CORES"
val ENV_EXECUTOR_MEMORY = "SPARK_EXECUTOR_MEMORY"
val ENV_APPLICATION_ID = "SPARK_APPLICATION_ID"
val ENV_EXECUTOR_ID = "SPARK_EXECUTOR_ID"
val ENV_EXECUTOR_POD_IP = "SPARK_EXECUTOR_POD_IP"
- val ENV_MOUNTED_CLASSPATH = "SPARK_MOUNTED_CLASSPATH"
val ENV_JAVA_OPT_PREFIX = "SPARK_JAVA_OPT_"
val ENV_CLASSPATH = "SPARK_CLASSPATH"
val ENV_DRIVER_BIND_ADDRESS = "SPARK_DRIVER_BIND_ADDRESS"
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala
index b0ccaa36b01ed..51d205fdb68d1 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala
@@ -59,6 +59,7 @@ private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf](
roleSecretNamesToMountPaths: Map[String, String],
roleSecretEnvNamesToKeyRefs: Map[String, String],
roleEnvs: Map[String, String],
+ roleVolumes: Iterable[KubernetesVolumeSpec[_ <: KubernetesVolumeSpecificConf]],
sparkFiles: Seq[String]) {
def namespace(): String = sparkConf.get(KUBERNETES_NAMESPACE)
@@ -155,6 +156,12 @@ private[spark] object KubernetesConf {
sparkConf, KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX)
val driverEnvs = KubernetesUtils.parsePrefixedKeyValuePairs(
sparkConf, KUBERNETES_DRIVER_ENV_PREFIX)
+ val driverVolumes = KubernetesVolumeUtils.parseVolumesWithPrefix(
+ sparkConf, KUBERNETES_DRIVER_VOLUMES_PREFIX).map(_.get)
+ // Also parse executor volumes in order to verify configuration
+ // before the driver pod is created
+ KubernetesVolumeUtils.parseVolumesWithPrefix(
+ sparkConf, KUBERNETES_EXECUTOR_VOLUMES_PREFIX).map(_.get)
val sparkFiles = sparkConf
.getOption("spark.files")
@@ -171,6 +178,7 @@ private[spark] object KubernetesConf {
driverSecretNamesToMountPaths,
driverSecretEnvNamesToKeyRefs,
driverEnvs,
+ driverVolumes,
sparkFiles)
}
@@ -203,6 +211,8 @@ private[spark] object KubernetesConf {
val executorEnvSecrets = KubernetesUtils.parsePrefixedKeyValuePairs(
sparkConf, KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX)
val executorEnv = sparkConf.getExecutorEnv.toMap
+ val executorVolumes = KubernetesVolumeUtils.parseVolumesWithPrefix(
+ sparkConf, KUBERNETES_EXECUTOR_VOLUMES_PREFIX).map(_.get)
KubernetesConf(
sparkConf.clone(),
@@ -214,6 +224,7 @@ private[spark] object KubernetesConf {
executorMountSecrets,
executorEnvSecrets,
executorEnv,
+ executorVolumes,
Seq.empty[String])
}
}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala
index 593fb531a004d..66fff267545dc 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala
@@ -16,8 +16,6 @@
*/
package org.apache.spark.deploy.k8s
-import io.fabric8.kubernetes.api.model.LocalObjectReference
-
import org.apache.spark.SparkConf
import org.apache.spark.util.Utils
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala
new file mode 100644
index 0000000000000..b1762d1efe2ea
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s
+
+private[spark] sealed trait KubernetesVolumeSpecificConf
+
+private[spark] case class KubernetesHostPathVolumeConf(
+ hostPath: String)
+ extends KubernetesVolumeSpecificConf
+
+private[spark] case class KubernetesPVCVolumeConf(
+ claimName: String)
+ extends KubernetesVolumeSpecificConf
+
+private[spark] case class KubernetesEmptyDirVolumeConf(
+ medium: Option[String],
+ sizeLimit: Option[String])
+ extends KubernetesVolumeSpecificConf
+
+private[spark] case class KubernetesVolumeSpec[T <: KubernetesVolumeSpecificConf](
+ volumeName: String,
+ mountPath: String,
+ mountReadOnly: Boolean,
+ volumeConf: T)
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala
new file mode 100644
index 0000000000000..713df5fffc3a2
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s
+
+import java.util.NoSuchElementException
+
+import scala.util.{Failure, Success, Try}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.deploy.k8s.Config._
+
+private[spark] object KubernetesVolumeUtils {
+ /**
+ * Extract Spark volume configuration properties with a given name prefix.
+ *
+ * @param sparkConf Spark configuration
+ * @param prefix the given property name prefix
+ * @return an Iterable of parsed volume specs, each wrapped in a Try, one per volume found under the prefix
+ */
+ def parseVolumesWithPrefix(
+ sparkConf: SparkConf,
+ prefix: String): Iterable[Try[KubernetesVolumeSpec[_ <: KubernetesVolumeSpecificConf]]] = {
+ val properties = sparkConf.getAllWithPrefix(prefix).toMap
+
+ getVolumeTypesAndNames(properties).map { case (volumeType, volumeName) =>
+ val pathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_PATH_KEY"
+ val readOnlyKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_READONLY_KEY"
+
+ for {
+ path <- properties.getTry(pathKey)
+ volumeConf <- parseVolumeSpecificConf(properties, volumeType, volumeName)
+ } yield KubernetesVolumeSpec(
+ volumeName = volumeName,
+ mountPath = path,
+ mountReadOnly = properties.get(readOnlyKey).exists(_.toBoolean),
+ volumeConf = volumeConf
+ )
+ }
+ }
+
+ /**
+ * Get unique pairs of volumeType and volumeName,
+ * assuming options are formatted in this way:
+ * `volumeType`.`volumeName`.`property` = `value`
+ * @param properties flat mapping of property names to values
+ * @return Set[(volumeType, volumeName)]
+ */
+ private def getVolumeTypesAndNames(
+ properties: Map[String, String]
+ ): Set[(String, String)] = {
+ properties.keys.flatMap { k =>
+ k.split('.').toList match {
+ case tpe :: name :: _ => Some((tpe, name))
+ case _ => None
+ }
+ }.toSet
+ }
+
+ private def parseVolumeSpecificConf(
+ options: Map[String, String],
+ volumeType: String,
+ volumeName: String): Try[KubernetesVolumeSpecificConf] = {
+ volumeType match {
+ case KUBERNETES_VOLUMES_HOSTPATH_TYPE =>
+ val pathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_PATH_KEY"
+ for {
+ path <- options.getTry(pathKey)
+ } yield KubernetesHostPathVolumeConf(path)
+
+ case KUBERNETES_VOLUMES_PVC_TYPE =>
+ val claimNameKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY"
+ for {
+ claimName <- options.getTry(claimNameKey)
+ } yield KubernetesPVCVolumeConf(claimName)
+
+ case KUBERNETES_VOLUMES_EMPTYDIR_TYPE =>
+ val mediumKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_MEDIUM_KEY"
+ val sizeLimitKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_SIZE_LIMIT_KEY"
+ Success(KubernetesEmptyDirVolumeConf(options.get(mediumKey), options.get(sizeLimitKey)))
+
+ case _ =>
+ Failure(new RuntimeException(s"Kubernetes Volume type `$volumeType` is not supported"))
+ }
+ }
+
+ /**
+ * Convenience wrapper to accumulate key lookup errors
+ */
+ implicit private class MapOps[A, B](m: Map[A, B]) {
+ def getTry(key: A): Try[B] = {
+ m
+ .get(key)
+ .fold[Try[B]](Failure(new NoSuchElementException(key.toString)))(Success(_))
+ }
+ }
+}
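The flat property layout that `parseVolumesWithPrefix` consumes, written out as plain Python data with one volume of each supported type (keys are relative to the driver/executor volume prefixes; names and values are illustrative):

    volume_properties = {
        "hostPath.logs.mount.path": "/var/log/spark",
        "hostPath.logs.options.path": "/var/log",
        "persistentVolumeClaim.ckpt.mount.path": "/checkpoint",
        "persistentVolumeClaim.ckpt.mount.readOnly": "true",
        "persistentVolumeClaim.ckpt.options.claimName": "spark-pvc-claim",
        "emptyDir.scratch.mount.path": "/tmp/scratch",
        "emptyDir.scratch.options.medium": "Memory",
        "emptyDir.scratch.options.sizeLimit": "1Gi",
    }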
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala
index 143dc8a12304e..7e67b51de6e04 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala
@@ -19,10 +19,10 @@ package org.apache.spark.deploy.k8s.features
import scala.collection.JavaConverters._
import scala.collection.mutable
-import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, EnvVarSourceBuilder, HasMetadata, PodBuilder, QuantityBuilder}
+import io.fabric8.kubernetes.api.model._
import org.apache.spark.SparkException
-import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkPod}
+import org.apache.spark.deploy.k8s._
import org.apache.spark.deploy.k8s.Config._
import org.apache.spark.deploy.k8s.Constants._
import org.apache.spark.deploy.k8s.submit._
@@ -103,6 +103,7 @@ private[spark] class BasicDriverFeatureStep(
.addToImagePullSecrets(conf.imagePullSecrets(): _*)
.endSpec()
.build()
+
SparkPod(driverPod, driverContainer)
}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala
index 91c54a9776982..abaeff0313a79 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala
@@ -18,10 +18,10 @@ package org.apache.spark.deploy.k8s.features
import scala.collection.JavaConverters._
-import io.fabric8.kubernetes.api.model.{ContainerBuilder, ContainerPortBuilder, EnvVar, EnvVarBuilder, EnvVarSourceBuilder, HasMetadata, PodBuilder, QuantityBuilder}
+import io.fabric8.kubernetes.api.model._
import org.apache.spark.SparkException
-import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod}
+import org.apache.spark.deploy.k8s._
import org.apache.spark.deploy.k8s.Config._
import org.apache.spark.deploy.k8s.Constants._
import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD}
@@ -173,6 +173,7 @@ private[spark] class BasicExecutorFeatureStep(
.addToImagePullSecrets(kubernetesConf.imagePullSecrets(): _*)
.endSpec()
.build()
+
SparkPod(executorPod, containerWithLimitCores)
}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala
new file mode 100644
index 0000000000000..bb0e2b3128efd
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.features
+
+import io.fabric8.kubernetes.api.model._
+
+import org.apache.spark.deploy.k8s._
+
+private[spark] class MountVolumesFeatureStep(
+ kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf])
+ extends KubernetesFeatureConfigStep {
+
+ override def configurePod(pod: SparkPod): SparkPod = {
+ val (volumeMounts, volumes) = constructVolumes(kubernetesConf.roleVolumes).unzip
+
+ val podWithVolumes = new PodBuilder(pod.pod)
+ .editSpec()
+ .addToVolumes(volumes.toSeq: _*)
+ .endSpec()
+ .build()
+
+ val containerWithVolumeMounts = new ContainerBuilder(pod.container)
+ .addToVolumeMounts(volumeMounts.toSeq: _*)
+ .build()
+
+ SparkPod(podWithVolumes, containerWithVolumeMounts)
+ }
+
+ override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty
+
+ override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty
+
+ private def constructVolumes(
+ volumeSpecs: Iterable[KubernetesVolumeSpec[_ <: KubernetesVolumeSpecificConf]]
+ ): Iterable[(VolumeMount, Volume)] = {
+ volumeSpecs.map { spec =>
+ val volumeMount = new VolumeMountBuilder()
+ .withMountPath(spec.mountPath)
+ .withReadOnly(spec.mountReadOnly)
+ .withName(spec.volumeName)
+ .build()
+
+ val volumeBuilder = spec.volumeConf match {
+ case KubernetesHostPathVolumeConf(hostPath) =>
+ new VolumeBuilder()
+ .withHostPath(new HostPathVolumeSource(hostPath))
+
+ case KubernetesPVCVolumeConf(claimName) =>
+ new VolumeBuilder()
+ .withPersistentVolumeClaim(
+ new PersistentVolumeClaimVolumeSource(claimName, spec.mountReadOnly))
+
+ case KubernetesEmptyDirVolumeConf(medium, sizeLimit) =>
+ new VolumeBuilder()
+ .withEmptyDir(
+ new EmptyDirVolumeSource(medium.getOrElse(""),
+ new Quantity(sizeLimit.orNull)))
+ }
+
+ val volume = volumeBuilder.withName(spec.volumeName).build()
+
+ (volumeMount, volume)
+ }
+ }
+}
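For orientation, roughly the pod-spec fragment this step produces for a persistentVolumeClaim spec, sketched as plain Python data using Kubernetes pod API field names (values are illustrative):

    volume = {"name": "checkpointpvc",
              "persistentVolumeClaim": {"claimName": "spark-pvc-claim", "readOnly": True}}
    volume_mount = {"name": "checkpointpvc", "mountPath": "/checkpoint", "readOnly": True}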
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala
index 5762d8245f778..7208e3d377593 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala
@@ -17,7 +17,7 @@
package org.apache.spark.deploy.k8s.submit
import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf}
-import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, EnvSecretsFeatureStep, KubernetesFeatureConfigStep, LocalDirsFeatureStep, MountSecretsFeatureStep}
+import org.apache.spark.deploy.k8s.features._
import org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep}
private[spark] class KubernetesDriverBuilder(
@@ -33,10 +33,13 @@ private[spark] class KubernetesDriverBuilder(
new MountSecretsFeatureStep(_),
provideEnvSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]
=> EnvSecretsFeatureStep) =
- new EnvSecretsFeatureStep(_),
- provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]
- => LocalDirsFeatureStep) =
+ new EnvSecretsFeatureStep(_),
+ provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf])
+ => LocalDirsFeatureStep =
new LocalDirsFeatureStep(_),
+ provideVolumesStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]
+ => MountVolumesFeatureStep) =
+ new MountVolumesFeatureStep(_),
provideJavaStep: (
KubernetesConf[KubernetesDriverSpecificConf]
=> JavaDriverFeatureStep) =
@@ -54,22 +57,25 @@ private[spark] class KubernetesDriverBuilder(
provideServiceStep(kubernetesConf),
provideLocalDirsStep(kubernetesConf))
- val maybeRoleSecretNamesStep = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) {
- Some(provideSecretsStep(kubernetesConf)) } else None
-
- val maybeProvideSecretsStep = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) {
- Some(provideEnvSecretsStep(kubernetesConf)) } else None
+ val secretFeature = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) {
+ Seq(provideSecretsStep(kubernetesConf))
+ } else Nil
+ val envSecretFeature = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) {
+ Seq(provideEnvSecretsStep(kubernetesConf))
+ } else Nil
+ val volumesFeature = if (kubernetesConf.roleVolumes.nonEmpty) {
+ Seq(provideVolumesStep(kubernetesConf))
+ } else Nil
val bindingsStep = kubernetesConf.roleSpecificConf.mainAppResource.map {
case JavaMainAppResource(_) =>
provideJavaStep(kubernetesConf)
case PythonMainAppResource(_) =>
- providePythonStep(kubernetesConf)}.getOrElse(provideJavaStep(kubernetesConf))
+ providePythonStep(kubernetesConf)}
+ .getOrElse(provideJavaStep(kubernetesConf))
- val allFeatures: Seq[KubernetesFeatureConfigStep] =
- (baseFeatures :+ bindingsStep) ++
- maybeRoleSecretNamesStep.toSeq ++
- maybeProvideSecretsStep.toSeq
+ val allFeatures = (baseFeatures :+ bindingsStep) ++
+ secretFeature ++ envSecretFeature ++ volumesFeature
var spec = KubernetesDriverSpec.initialSpec(kubernetesConf.sparkConf.getAll.toMap)
for (feature <- allFeatures) {
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala
index c6e931a38405f..de2a52bc7a0b8 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala
@@ -48,8 +48,6 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit
sc: SparkContext,
masterURL: String,
scheduler: TaskScheduler): SchedulerBackend = {
- val executorSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs(
- sc.conf, KUBERNETES_EXECUTOR_SECRETS_PREFIX)
val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient(
KUBERNETES_MASTER_INTERNAL_URL,
Some(sc.conf.get(KUBERNETES_NAMESPACE)),
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala
index 769a0a5a63047..364b6fb367722 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala
@@ -17,37 +17,41 @@
package org.apache.spark.scheduler.cluster.k8s
import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod}
-import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, KubernetesFeatureConfigStep, LocalDirsFeatureStep, MountSecretsFeatureStep}
+import org.apache.spark.deploy.k8s.features._
+import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep}
private[spark] class KubernetesExecutorBuilder(
- provideBasicStep: (KubernetesConf[KubernetesExecutorSpecificConf]) => BasicExecutorFeatureStep =
+ provideBasicStep: (KubernetesConf [KubernetesExecutorSpecificConf])
+ => BasicExecutorFeatureStep =
new BasicExecutorFeatureStep(_),
- provideSecretsStep:
- (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => MountSecretsFeatureStep =
+ provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf])
+ => MountSecretsFeatureStep =
new MountSecretsFeatureStep(_),
provideEnvSecretsStep:
(KubernetesConf[_ <: KubernetesRoleSpecificConf] => EnvSecretsFeatureStep) =
new EnvSecretsFeatureStep(_),
provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf])
=> LocalDirsFeatureStep =
- new LocalDirsFeatureStep(_)) {
+ new LocalDirsFeatureStep(_),
+ provideVolumesStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]
+ => MountVolumesFeatureStep) =
+ new MountVolumesFeatureStep(_)) {
def buildFromFeatures(
kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = {
- val baseFeatures = Seq(
- provideBasicStep(kubernetesConf),
- provideLocalDirsStep(kubernetesConf))
- val maybeRoleSecretNamesStep = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) {
- Some(provideSecretsStep(kubernetesConf)) } else None
+ val baseFeatures = Seq(provideBasicStep(kubernetesConf), provideLocalDirsStep(kubernetesConf))
+ val secretFeature = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) {
+ Seq(provideSecretsStep(kubernetesConf))
+ } else Nil
+ val secretEnvFeature = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) {
+ Seq(provideEnvSecretsStep(kubernetesConf))
+ } else Nil
+ val volumesFeature = if (kubernetesConf.roleVolumes.nonEmpty) {
+ Seq(provideVolumesStep(kubernetesConf))
+ } else Nil
- val maybeProvideSecretsStep = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) {
- Some(provideEnvSecretsStep(kubernetesConf)) } else None
-
- val allFeatures: Seq[KubernetesFeatureConfigStep] =
- baseFeatures ++
- maybeRoleSecretNamesStep.toSeq ++
- maybeProvideSecretsStep.toSeq
+ val allFeatures = baseFeatures ++ secretFeature ++ secretEnvFeature ++ volumesFeature
var executorPod = SparkPod.initialPod()
for (feature <- allFeatures) {
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala
new file mode 100644
index 0000000000000..d795d159773a8
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+
+class KubernetesVolumeUtilsSuite extends SparkFunSuite {
+ test("Parses hostPath volumes correctly") {
+ val sparkConf = new SparkConf(false)
+ sparkConf.set("test.hostPath.volumeName.mount.path", "/path")
+ sparkConf.set("test.hostPath.volumeName.mount.readOnly", "true")
+ sparkConf.set("test.hostPath.volumeName.options.path", "/hostPath")
+
+ val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get
+ assert(volumeSpec.volumeName === "volumeName")
+ assert(volumeSpec.mountPath === "/path")
+ assert(volumeSpec.mountReadOnly === true)
+ assert(volumeSpec.volumeConf.asInstanceOf[KubernetesHostPathVolumeConf] ===
+ KubernetesHostPathVolumeConf("/hostPath"))
+ }
+
+ test("Parses persistentVolumeClaim volumes correctly") {
+ val sparkConf = new SparkConf(false)
+ sparkConf.set("test.persistentVolumeClaim.volumeName.mount.path", "/path")
+ sparkConf.set("test.persistentVolumeClaim.volumeName.mount.readOnly", "true")
+ sparkConf.set("test.persistentVolumeClaim.volumeName.options.claimName", "claimeName")
+
+ val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get
+ assert(volumeSpec.volumeName === "volumeName")
+ assert(volumeSpec.mountPath === "/path")
+ assert(volumeSpec.mountReadOnly === true)
+ assert(volumeSpec.volumeConf.asInstanceOf[KubernetesPVCVolumeConf] ===
+ KubernetesPVCVolumeConf("claimName"))
+ }
+
+ test("Parses emptyDir volumes correctly") {
+ val sparkConf = new SparkConf(false)
+ sparkConf.set("test.emptyDir.volumeName.mount.path", "/path")
+ sparkConf.set("test.emptyDir.volumeName.mount.readOnly", "true")
+ sparkConf.set("test.emptyDir.volumeName.options.medium", "medium")
+ sparkConf.set("test.emptyDir.volumeName.options.sizeLimit", "5G")
+
+ val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get
+ assert(volumeSpec.volumeName === "volumeName")
+ assert(volumeSpec.mountPath === "/path")
+ assert(volumeSpec.mountReadOnly === true)
+ assert(volumeSpec.volumeConf.asInstanceOf[KubernetesEmptyDirVolumeConf] ===
+ KubernetesEmptyDirVolumeConf(Some("medium"), Some("5G")))
+ }
+
+ test("Parses emptyDir volume options can be optional") {
+ val sparkConf = new SparkConf(false)
+ sparkConf.set("test.emptyDir.volumeName.mount.path", "/path")
+ sparkConf.set("test.emptyDir.volumeName.mount.readOnly", "true")
+
+ val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get
+ assert(volumeSpec.volumeName === "volumeName")
+ assert(volumeSpec.mountPath === "/path")
+ assert(volumeSpec.mountReadOnly === true)
+ assert(volumeSpec.volumeConf.asInstanceOf[KubernetesEmptyDirVolumeConf] ===
+ KubernetesEmptyDirVolumeConf(None, None))
+ }
+
+ test("Defaults optional readOnly to false") {
+ val sparkConf = new SparkConf(false)
+ sparkConf.set("test.hostPath.volumeName.mount.path", "/path")
+ sparkConf.set("test.hostPath.volumeName.options.path", "/hostPath")
+
+ val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get
+ assert(volumeSpec.mountReadOnly === false)
+ }
+
+ test("Gracefully fails on missing mount key") {
+ val sparkConf = new SparkConf(false)
+ sparkConf.set("test.emptyDir.volumeName.mnt.path", "/path")
+
+ val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head
+ assert(volumeSpec.isFailure === true)
+ assert(volumeSpec.failed.get.getMessage === "emptyDir.volumeName.mount.path")
+ }
+
+ test("Gracefully fails on missing option key") {
+ val sparkConf = new SparkConf(false)
+ sparkConf.set("test.hostPath.volumeName.mount.path", "/path")
+ sparkConf.set("test.hostPath.volumeName.mount.readOnly", "true")
+ sparkConf.set("test.hostPath.volumeName.options.pth", "/hostPath")
+
+ val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head
+ assert(volumeSpec.isFailure === true)
+ assert(volumeSpec.failed.get.getMessage === "hostPath.volumeName.options.path")
+ }
+}
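
For reference alongside the suite above: the same "<prefix><type>.<name>.mount.*" / "<prefix><type>.<name>.options.*" key layout is what a user-facing volume configuration would presumably reduce to once a real prefix replaces the synthetic "test." prefix. A minimal sketch, assuming a plausible executor-side prefix (the spark.kubernetes.executor.volumes.* keys, the volume name and the paths below are illustrative, not confirmed by this diff):

    import org.apache.spark.SparkConf

    // Sketch only: lays out volume keys in the shape parsed by KubernetesVolumeUtilsSuite above.
    object VolumeConfSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf(false)
          .set("spark.kubernetes.executor.volumes.hostPath.checkpoints.mount.path", "/checkpoints")
          .set("spark.kubernetes.executor.volumes.hostPath.checkpoints.mount.readOnly", "false")
          .set("spark.kubernetes.executor.volumes.hostPath.checkpoints.options.path", "/mnt/ssd/checkpoints")
        // Print the keys to show the "<prefix><type>.<name>.mount|options.<field>" structure.
        conf.getAll.sortBy(_._1).foreach { case (k, v) => println(s"$k=$v") }
      }
    }
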
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala
index 04b909db9d9f3..165f46a07df2f 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala
@@ -50,6 +50,12 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite {
TEST_IMAGE_PULL_SECRETS.map { secret =>
new LocalObjectReferenceBuilder().withName(secret).build()
}
+ private val emptyDriverSpecificConf = KubernetesDriverSpecificConf(
+ None,
+ APP_NAME,
+ MAIN_CLASS,
+ APP_ARGS)
+
test("Check the pod respects all configurations from the user.") {
val sparkConf = new SparkConf()
@@ -62,11 +68,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite {
.set(IMAGE_PULL_SECRETS, TEST_IMAGE_PULL_SECRETS.mkString(","))
val kubernetesConf = KubernetesConf(
sparkConf,
- KubernetesDriverSpecificConf(
- Some(JavaMainAppResource("")),
- APP_NAME,
- MAIN_CLASS,
- APP_ARGS),
+ emptyDriverSpecificConf,
RESOURCE_NAME_PREFIX,
APP_ID,
DRIVER_LABELS,
@@ -74,6 +76,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite {
Map.empty,
Map.empty,
DRIVER_ENVS,
+ Nil,
Seq.empty[String])
val featureStep = new BasicDriverFeatureStep(kubernetesConf)
@@ -143,6 +146,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite {
Map.empty,
Map.empty,
DRIVER_ENVS,
+ Nil,
Seq.empty[String])
val pythonKubernetesConf = KubernetesConf(
pythonSparkConf,
@@ -158,6 +162,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite {
Map.empty,
Map.empty,
DRIVER_ENVS,
+ Nil,
Seq.empty[String])
val javaFeatureStep = new BasicDriverFeatureStep(javaKubernetesConf)
val pythonFeatureStep = new BasicDriverFeatureStep(pythonKubernetesConf)
@@ -176,11 +181,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite {
.set(CONTAINER_IMAGE, "spark-driver:latest")
val kubernetesConf = KubernetesConf(
sparkConf,
- KubernetesDriverSpecificConf(
- Some(JavaMainAppResource("")),
- APP_NAME,
- MAIN_CLASS,
- APP_ARGS),
+ emptyDriverSpecificConf,
RESOURCE_NAME_PREFIX,
APP_ID,
DRIVER_LABELS,
@@ -188,7 +189,9 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite {
Map.empty,
Map.empty,
DRIVER_ENVS,
+ Nil,
allFiles)
+
val step = new BasicDriverFeatureStep(kubernetesConf)
val additionalProperties = step.getAdditionalPodSystemProperties()
val expectedSparkConf = Map(
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala
index f06030aa55c0c..a44fa1f2ffc63 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala
@@ -89,6 +89,7 @@ class BasicExecutorFeatureStepSuite
Map.empty,
Map.empty,
Map.empty,
+ Nil,
Seq.empty[String]))
val executor = step.configurePod(SparkPod.initialPod())
@@ -128,6 +129,7 @@ class BasicExecutorFeatureStepSuite
Map.empty,
Map.empty,
Map.empty,
+ Nil,
Seq.empty[String]))
assert(step.configurePod(SparkPod.initialPod()).pod.getSpec.getHostname.length === 63)
}
@@ -148,6 +150,7 @@ class BasicExecutorFeatureStepSuite
Map.empty,
Map.empty,
Map("qux" -> "quux"),
+ Nil,
Seq.empty[String]))
val executor = step.configurePod(SparkPod.initialPod())
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala
index 7cea83591f3e8..7e916b3854404 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala
@@ -61,6 +61,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef
Map.empty,
Map.empty,
Map.empty,
+ Nil,
Seq.empty[String])
val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf)
assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD)
@@ -92,6 +93,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef
Map.empty,
Map.empty,
Map.empty,
+ Nil,
Seq.empty[String])
val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf)
@@ -130,6 +132,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef
Map.empty,
Map.empty,
Map.empty,
+ Nil,
Seq.empty[String])
val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf)
val resolvedProperties = kubernetesCredentialsStep.getAdditionalPodSystemProperties()
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala
index 77d38bf19cd10..8b91e93eecd8c 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala
@@ -67,6 +67,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter {
Map.empty,
Map.empty,
Map.empty,
+ Nil,
Seq.empty[String]))
assert(configurationStep.configurePod(SparkPod.initialPod()) === SparkPod.initialPod())
assert(configurationStep.getAdditionalKubernetesResources().size === 1)
@@ -98,6 +99,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter {
Map.empty,
Map.empty,
Map.empty,
+ Nil,
Seq.empty[String]))
val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX +
DriverServiceFeatureStep.DRIVER_SVC_POSTFIX
@@ -119,6 +121,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter {
Map.empty,
Map.empty,
Map.empty,
+ Nil,
Seq.empty[String]))
val resolvedService = configurationStep
.getAdditionalKubernetesResources()
@@ -149,6 +152,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter {
Map.empty,
Map.empty,
Map.empty,
+ Nil,
Seq.empty[String]),
clock)
val driverService = configurationStep
@@ -176,6 +180,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter {
Map.empty,
Map.empty,
Map.empty,
+ Nil,
Seq.empty[String]),
clock)
fail("The driver bind address should not be allowed.")
@@ -201,6 +206,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter {
Map.empty,
Map.empty,
Map.empty,
+ Nil,
Seq.empty[String]),
clock)
fail("The driver host address should not be allowed.")
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala
index af6b35eae484a..1c8d84b76c56b 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala
@@ -45,6 +45,7 @@ class EnvSecretsFeatureStepSuite extends SparkFunSuite{
Map.empty,
envVarsToKeys,
Map.empty,
+ Nil,
Seq.empty[String])
val step = new EnvSecretsFeatureStep(kubernetesConf)
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala
index bd6ce4b42fc8e..a339827b819a9 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala
@@ -21,7 +21,7 @@ import org.mockito.Mockito
import org.scalatest.BeforeAndAfter
import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod}
+import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf, SparkPod}
class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter {
private val defaultLocalDir = "/var/data/default-local-dir"
@@ -45,6 +45,7 @@ class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter {
Map.empty,
Map.empty,
Map.empty,
+ Nil,
Seq.empty[String])
}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala
index eff75b8a15daa..2b49b72dfa569 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala
@@ -43,6 +43,7 @@ class MountSecretsFeatureStepSuite extends SparkFunSuite {
secretNamesToMountPaths,
Map.empty,
Map.empty,
+ Nil,
Seq.empty[String])
val step = new MountSecretsFeatureStep(kubernetesConf)
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala
new file mode 100644
index 0000000000000..d309aa94ec115
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.features
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.k8s._
+
+class MountVolumesFeatureStepSuite extends SparkFunSuite {
+ private val sparkConf = new SparkConf(false)
+ private val emptyKubernetesConf = KubernetesConf(
+ sparkConf = sparkConf,
+ roleSpecificConf = KubernetesDriverSpecificConf(
+ None,
+ "app-name",
+ "main",
+ Seq.empty),
+ appResourceNamePrefix = "resource",
+ appId = "app-id",
+ roleLabels = Map.empty,
+ roleAnnotations = Map.empty,
+ roleSecretNamesToMountPaths = Map.empty,
+ roleSecretEnvNamesToKeyRefs = Map.empty,
+ roleEnvs = Map.empty,
+ roleVolumes = Nil,
+ sparkFiles = Nil)
+
+ test("Mounts hostPath volumes") {
+ val volumeConf = KubernetesVolumeSpec(
+ "testVolume",
+ "/tmp",
+ false,
+ KubernetesHostPathVolumeConf("/hostPath/tmp")
+ )
+ val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil)
+ val step = new MountVolumesFeatureStep(kubernetesConf)
+ val configuredPod = step.configurePod(SparkPod.initialPod())
+
+ assert(configuredPod.pod.getSpec.getVolumes.size() === 1)
+ assert(configuredPod.pod.getSpec.getVolumes.get(0).getHostPath.getPath === "/hostPath/tmp")
+ assert(configuredPod.container.getVolumeMounts.size() === 1)
+ assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp")
+ assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume")
+ assert(configuredPod.container.getVolumeMounts.get(0).getReadOnly === false)
+ }
+
+ test("Mounts pesistentVolumeClaims") {
+ val volumeConf = KubernetesVolumeSpec(
+ "testVolume",
+ "/tmp",
+ true,
+ KubernetesPVCVolumeConf("pvcClaim")
+ )
+ val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil)
+ val step = new MountVolumesFeatureStep(kubernetesConf)
+ val configuredPod = step.configurePod(SparkPod.initialPod())
+
+ assert(configuredPod.pod.getSpec.getVolumes.size() === 1)
+ val pvcClaim = configuredPod.pod.getSpec.getVolumes.get(0).getPersistentVolumeClaim
+ assert(pvcClaim.getClaimName === "pvcClaim")
+ assert(configuredPod.container.getVolumeMounts.size() === 1)
+ assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp")
+ assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume")
+ assert(configuredPod.container.getVolumeMounts.get(0).getReadOnly === true)
+ }
+
+ test("Mounts emptyDir") {
+ val volumeConf = KubernetesVolumeSpec(
+ "testVolume",
+ "/tmp",
+ false,
+ KubernetesEmptyDirVolumeConf(Some("Memory"), Some("6G"))
+ )
+ val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil)
+ val step = new MountVolumesFeatureStep(kubernetesConf)
+ val configuredPod = step.configurePod(SparkPod.initialPod())
+
+ assert(configuredPod.pod.getSpec.getVolumes.size() === 1)
+ val emptyDir = configuredPod.pod.getSpec.getVolumes.get(0).getEmptyDir
+ assert(emptyDir.getMedium === "Memory")
+ assert(emptyDir.getSizeLimit.getAmount === "6G")
+ assert(configuredPod.container.getVolumeMounts.size() === 1)
+ assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp")
+ assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume")
+ assert(configuredPod.container.getVolumeMounts.get(0).getReadOnly === false)
+ }
+
+ test("Mounts emptyDir with no options") {
+ val volumeConf = KubernetesVolumeSpec(
+ "testVolume",
+ "/tmp",
+ false,
+ KubernetesEmptyDirVolumeConf(None, None)
+ )
+ val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil)
+ val step = new MountVolumesFeatureStep(kubernetesConf)
+ val configuredPod = step.configurePod(SparkPod.initialPod())
+
+ assert(configuredPod.pod.getSpec.getVolumes.size() === 1)
+ val emptyDir = configuredPod.pod.getSpec.getVolumes.get(0).getEmptyDir
+ assert(emptyDir.getMedium === "")
+ assert(emptyDir.getSizeLimit.getAmount === null)
+ assert(configuredPod.container.getVolumeMounts.size() === 1)
+ assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp")
+ assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume")
+ assert(configuredPod.container.getVolumeMounts.get(0).getReadOnly === false)
+ }
+
+ test("Mounts multiple volumes") {
+ val hpVolumeConf = KubernetesVolumeSpec(
+ "hpVolume",
+ "/tmp",
+ false,
+ KubernetesHostPathVolumeConf("/hostPath/tmp")
+ )
+ val pvcVolumeConf = KubernetesVolumeSpec(
+ "checkpointVolume",
+ "/checkpoints",
+ true,
+ KubernetesPVCVolumeConf("pvcClaim")
+ )
+ val volumesConf = hpVolumeConf :: pvcVolumeConf :: Nil
+ val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumesConf)
+ val step = new MountVolumesFeatureStep(kubernetesConf)
+ val configuredPod = step.configurePod(SparkPod.initialPod())
+
+ assert(configuredPod.pod.getSpec.getVolumes.size() === 2)
+ assert(configuredPod.container.getVolumeMounts.size() === 2)
+ }
+}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala
index 0f2bf2fa1d9b5..18874afe6e53a 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/JavaDriverFeatureStepSuite.scala
@@ -42,6 +42,7 @@ class JavaDriverFeatureStepSuite extends SparkFunSuite {
roleSecretNamesToMountPaths = Map.empty,
roleSecretEnvNamesToKeyRefs = Map.empty,
roleEnvs = Map.empty,
+ roleVolumes = Nil,
sparkFiles = Seq.empty[String])
val step = new JavaDriverFeatureStep(kubernetesConf)
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala
index a1f9a5d9e264e..a5dac6869327d 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/bindings/PythonDriverFeatureStepSuite.scala
@@ -52,6 +52,7 @@ class PythonDriverFeatureStepSuite extends SparkFunSuite {
roleSecretNamesToMountPaths = Map.empty,
roleSecretEnvNamesToKeyRefs = Map.empty,
roleEnvs = Map.empty,
+ roleVolumes = Nil,
sparkFiles = Seq.empty[String])
val step = new PythonDriverFeatureStep(kubernetesConf)
@@ -88,6 +89,7 @@ class PythonDriverFeatureStepSuite extends SparkFunSuite {
roleSecretNamesToMountPaths = Map.empty,
roleSecretEnvNamesToKeyRefs = Map.empty,
roleEnvs = Map.empty,
+ roleVolumes = Nil,
sparkFiles = Seq.empty[String])
val step = new PythonDriverFeatureStep(kubernetesConf)
val driverContainerwithPySpark = step.configurePod(baseDriverPod).container
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala
index d045d9ae89c07..4d8e79189ff32 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala
@@ -141,6 +141,7 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter {
Map.empty,
Map.empty,
Map.empty,
+ Nil,
Seq.empty[String])
when(driverBuilder.buildFromFeatures(kubernetesConf)).thenReturn(BUILT_KUBERNETES_SPEC)
when(kubernetesClient.pods()).thenReturn(podOperations)
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala
index 4e8c300543430..046e578b94629 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala
@@ -17,7 +17,8 @@
package org.apache.spark.deploy.k8s.submit
import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf}
+import org.apache.spark.deploy.k8s._
+import org.apache.spark.deploy.k8s.features._
import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, EnvSecretsFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep}
import org.apache.spark.deploy.k8s.features.bindings.{JavaDriverFeatureStep, PythonDriverFeatureStep}
@@ -31,6 +32,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite {
private val JAVA_STEP_TYPE = "java-bindings"
private val PYSPARK_STEP_TYPE = "pyspark-bindings"
private val ENV_SECRETS_STEP_TYPE = "env-secrets"
+ private val MOUNT_VOLUMES_STEP_TYPE = "mount-volumes"
private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
BASIC_STEP_TYPE, classOf[BasicDriverFeatureStep])
@@ -56,6 +58,9 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite {
private val envSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep])
+ private val mountVolumesStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
+ MOUNT_VOLUMES_STEP_TYPE, classOf[MountVolumesFeatureStep])
+
private val builderUnderTest: KubernetesDriverBuilder =
new KubernetesDriverBuilder(
_ => basicFeatureStep,
@@ -64,6 +69,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite {
_ => secretsStep,
_ => envSecretsStep,
_ => localDirsStep,
+ _ => mountVolumesStep,
_ => javaStep,
_ => pythonStep)
@@ -82,6 +88,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite {
Map.empty,
Map.empty,
Map.empty,
+ Nil,
Seq.empty[String])
validateStepTypesApplied(
builderUnderTest.buildFromFeatures(conf),
@@ -107,6 +114,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite {
Map("secret" -> "secretMountPath"),
Map("EnvName" -> "SecretName:secretKey"),
Map.empty,
+ Nil,
Seq.empty[String])
validateStepTypesApplied(
builderUnderTest.buildFromFeatures(conf),
@@ -134,6 +142,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite {
Map.empty,
Map.empty,
Map.empty,
+ Nil,
Seq.empty[String])
validateStepTypesApplied(
builderUnderTest.buildFromFeatures(conf),
@@ -159,6 +168,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite {
Map.empty,
Map.empty,
Map.empty,
+ Nil,
Seq.empty[String])
validateStepTypesApplied(
builderUnderTest.buildFromFeatures(conf),
@@ -169,6 +179,39 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite {
PYSPARK_STEP_TYPE)
}
+ test("Apply volumes step if mounts are present.") {
+ val volumeSpec = KubernetesVolumeSpec(
+ "volume",
+ "/tmp",
+ false,
+ KubernetesHostPathVolumeConf("/path"))
+ val conf = KubernetesConf(
+ new SparkConf(false),
+ KubernetesDriverSpecificConf(
+ None,
+ "test-app",
+ "main",
+ Seq.empty),
+ "prefix",
+ "appId",
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ volumeSpec :: Nil,
+ Seq.empty[String])
+ validateStepTypesApplied(
+ builderUnderTest.buildFromFeatures(conf),
+ BASIC_STEP_TYPE,
+ CREDENTIALS_STEP_TYPE,
+ SERVICE_STEP_TYPE,
+ LOCAL_DIRS_STEP_TYPE,
+ MOUNT_VOLUMES_STEP_TYPE,
+ JAVA_STEP_TYPE)
+ }
+
private def validateStepTypesApplied(resolvedSpec: KubernetesDriverSpec, stepTypes: String*)
: Unit = {
assert(resolvedSpec.systemProperties.size === stepTypes.size)
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala
index a6bc8bce32926..d0b4127065eb7 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala
@@ -19,14 +19,15 @@ package org.apache.spark.scheduler.cluster.k8s
import io.fabric8.kubernetes.api.model.PodBuilder
import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod}
-import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep}
+import org.apache.spark.deploy.k8s._
+import org.apache.spark.deploy.k8s.features._
class KubernetesExecutorBuilderSuite extends SparkFunSuite {
private val BASIC_STEP_TYPE = "basic"
private val SECRETS_STEP_TYPE = "mount-secrets"
private val ENV_SECRETS_STEP_TYPE = "env-secrets"
private val LOCAL_DIRS_STEP_TYPE = "local-dirs"
+ private val MOUNT_VOLUMES_STEP_TYPE = "mount-volumes"
private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
BASIC_STEP_TYPE, classOf[BasicExecutorFeatureStep])
@@ -36,12 +37,15 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite {
ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep])
private val localDirsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
LOCAL_DIRS_STEP_TYPE, classOf[LocalDirsFeatureStep])
+ private val mountVolumesStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
+ MOUNT_VOLUMES_STEP_TYPE, classOf[MountVolumesFeatureStep])
private val builderUnderTest = new KubernetesExecutorBuilder(
_ => basicFeatureStep,
_ => mountSecretsStep,
_ => envSecretsStep,
- _ => localDirsStep)
+ _ => localDirsStep,
+ _ => mountVolumesStep)
test("Basic steps are consistently applied.") {
val conf = KubernetesConf(
@@ -55,6 +59,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite {
Map.empty,
Map.empty,
Map.empty,
+ Nil,
Seq.empty[String])
validateStepTypesApplied(
builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE)
@@ -72,6 +77,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite {
Map("secret" -> "secretMountPath"),
Map("secret-name" -> "secret-key"),
Map.empty,
+ Nil,
Seq.empty[String])
validateStepTypesApplied(
builderUnderTest.buildFromFeatures(conf),
@@ -81,6 +87,32 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite {
ENV_SECRETS_STEP_TYPE)
}
+ test("Apply volumes step if mounts are present.") {
+ val volumeSpec = KubernetesVolumeSpec(
+ "volume",
+ "/tmp",
+ false,
+ KubernetesHostPathVolumeConf("/checkpoint"))
+ val conf = KubernetesConf(
+ new SparkConf(false),
+ KubernetesExecutorSpecificConf(
+ "executor-id", new PodBuilder().build()),
+ "prefix",
+ "appId",
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ volumeSpec :: Nil,
+ Seq.empty[String])
+ validateStepTypesApplied(
+ builderUnderTest.buildFromFeatures(conf),
+ BASIC_STEP_TYPE,
+ LOCAL_DIRS_STEP_TYPE,
+ MOUNT_VOLUMES_STEP_TYPE)
+ }
+
private def validateStepTypesApplied(resolvedPod: SparkPod, stepTypes: String*): Unit = {
assert(resolvedPod.pod.getMetadata.getLabels.size === stepTypes.size)
stepTypes.foreach { stepType =>
diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh
index 2f4e115e84ecd..8bdb0f7a10795 100755
--- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh
+++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh
@@ -51,12 +51,10 @@ esac
SPARK_CLASSPATH="$SPARK_CLASSPATH:${SPARK_HOME}/jars/*"
env | grep SPARK_JAVA_OPT_ | sort -t_ -k4 -n | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt
-readarray -t SPARK_JAVA_OPTS < /tmp/java_opts.txt
-if [ -n "$SPARK_MOUNTED_CLASSPATH" ]; then
- SPARK_CLASSPATH="$SPARK_CLASSPATH:$SPARK_MOUNTED_CLASSPATH"
-fi
-if [ -n "$SPARK_MOUNTED_FILES_DIR" ]; then
- cp -R "$SPARK_MOUNTED_FILES_DIR/." .
+readarray -t SPARK_EXECUTOR_JAVA_OPTS < /tmp/java_opts.txt
+
+if [ -n "$SPARK_EXTRA_CLASSPATH" ]; then
+ SPARK_CLASSPATH="$SPARK_CLASSPATH:$SPARK_EXTRA_CLASSPATH"
fi
if [ -n "$PYSPARK_FILES" ]; then
@@ -101,7 +99,7 @@ case "$SPARK_K8S_CMD" in
executor)
CMD=(
${JAVA_HOME}/bin/java
- "${SPARK_JAVA_OPTS[@]}"
+ "${SPARK_EXECUTOR_JAVA_OPTS[@]}"
-Xms$SPARK_EXECUTOR_MEMORY
-Xmx$SPARK_EXECUTOR_MEMORY
-cp "$SPARK_CLASSPATH"
diff --git a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh
index ea893fa39eede..3acd0f5cd3349 100755
--- a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh
+++ b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh
@@ -27,6 +27,8 @@ IMAGE_TAG="N/A"
SPARK_MASTER=
NAMESPACE=
SERVICE_ACCOUNT=
+INCLUDE_TAGS=
+EXCLUDE_TAGS=
# Parse arguments
while (( "$#" )); do
@@ -59,6 +61,14 @@ while (( "$#" )); do
SERVICE_ACCOUNT="$2"
shift
;;
+ --include-tags)
+ INCLUDE_TAGS="$2"
+ shift
+ ;;
+ --exclude-tags)
+ EXCLUDE_TAGS="$2"
+ shift
+ ;;
*)
break
;;
@@ -90,4 +100,14 @@ then
properties=( ${properties[@]} -Dspark.kubernetes.test.master=$SPARK_MASTER )
fi
+if [ -n "$EXCLUDE_TAGS" ];
+then
+ properties=( ${properties[@]} -Dtest.exclude.tags=$EXCLUDE_TAGS )
+fi
+
+if [ -n "$INCLUDE_TAGS" ];
+then
+ properties=( ${properties[@]} -Dtest.include.tags=$INCLUDE_TAGS )
+fi
+
../../../build/mvn integration-test ${properties[@]}
diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml
index 520bda89e034d..6a2fff891098b 100644
--- a/resource-managers/kubernetes/integration-tests/pom.xml
+++ b/resource-managers/kubernetes/integration-tests/pom.xml
@@ -40,6 +40,7 @@
minikube
docker.io/kubespark
+
jar
Spark Project Kubernetes Integration Tests
@@ -102,6 +103,15 @@
+
+
+ org.apache.maven.plugins
+ maven-surefire-plugin
+
+ true
+
+
+
@@ -126,6 +136,7 @@
${spark.kubernetes.test.serviceAccountName}
${test.exclude.tags}
+ ${test.include.tags}
diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala
index 65c513cf241a4..6e334c83fbde8 100644
--- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala
+++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala
@@ -21,17 +21,17 @@ import java.nio.file.{Path, Paths}
import java.util.UUID
import java.util.regex.Pattern
-import scala.collection.JavaConverters._
-
import com.google.common.io.PatternFilenameFilter
import io.fabric8.kubernetes.api.model.{Container, Pod}
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import org.scalatest.concurrent.{Eventually, PatienceConfiguration}
import org.scalatest.time.{Minutes, Seconds, Span}
+import scala.collection.JavaConverters._
import org.apache.spark.SparkFunSuite
import org.apache.spark.deploy.k8s.integrationtest.backend.{IntegrationTestBackend, IntegrationTestBackendFactory}
import org.apache.spark.deploy.k8s.integrationtest.config._
+import org.apache.spark.launcher.SparkLauncher
private[spark] class KubernetesSuite extends SparkFunSuite
with BeforeAndAfterAll with BeforeAndAfter {
@@ -109,6 +109,12 @@ private[spark] class KubernetesSuite extends SparkFunSuite
runSparkPiAndVerifyCompletion()
}
+ test("Use SparkLauncher.NO_RESOURCE") {
+ sparkAppConf.setJars(Seq(containerLocalSparkDistroExamplesJar))
+ runSparkPiAndVerifyCompletion(
+ appResource = SparkLauncher.NO_RESOURCE)
+ }
+
test("Run SparkPi with a master URL without a scheme.") {
val url = kubernetesTestComponents.kubernetesClient.getMasterUrl
val k8sMasterUrl = if (url.getPort < 0) {
diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala
index 48727142dd052..b2471e51116cb 100644
--- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala
+++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesTestComponents.scala
@@ -105,16 +105,13 @@ private[spark] object SparkAppLauncher extends Logging {
sparkHomeDir: Path): Unit = {
val sparkSubmitExecutable = sparkHomeDir.resolve(Paths.get("bin", "spark-submit"))
logInfo(s"Launching a spark app with arguments $appArguments and conf $appConf")
- val appArgsArray =
- if (appArguments.appArgs.length > 0) Array(appArguments.appArgs.mkString(" "))
- else Array[String]()
val commandLine = (Array(sparkSubmitExecutable.toFile.getAbsolutePath,
"--deploy-mode", "cluster",
"--class", appArguments.mainClass,
"--master", appConf.get("spark.master")
) ++ appConf.toStringArray :+
appArguments.mainAppResource) ++
- appArgsArray
+ appArguments.appArgs
ProcessUtils.executeProcess(commandLine, timeoutSecs)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 93df73ab1eaf6..6f5fbdd79e668 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -431,6 +431,12 @@ object CatalystTypeConverters {
map,
(key: Any) => convertToCatalyst(key),
(value: Any) => convertToCatalyst(value))
+ case (keys: Array[_], values: Array[_]) =>
+ // case for map data that may contain duplicate keys
+ new ArrayBasedMapData(
+ new GenericArrayData(keys.map(convertToCatalyst)),
+ new GenericArrayData(values.map(convertToCatalyst))
+ )
case other => other
}
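
A rough sketch of what the new (keys, values) branch accepts, written test-style inside the catalyst package purely to reach the internal helper (the object name and sample values are made up); it shows that a pair of key/value arrays converts to map data without de-duplicating keys:

    package org.apache.spark.sql.catalyst

    import org.apache.spark.sql.catalyst.util.MapData

    object DuplicateKeyMapSketch {
      def main(args: Array[String]): Unit = {
        // A Tuple2 of key and value arrays now hits the new branch and becomes ArrayBasedMapData.
        val converted = CatalystTypeConverters.convertToCatalyst((Array("k", "k"), Array(1, 2)))
        val mapData = converted.asInstanceOf[MapData]
        // Both entries are kept even though the keys repeat.
        assert(mapData.numElements() == 2)
      }
    }
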
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index f9acc208b715e..4543bba8f6ed4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -798,7 +798,12 @@ object ScalaReflection extends ScalaReflection {
* Whether the fields of the given type is defined entirely by its constructor parameters.
*/
def definedByConstructorParams(tpe: Type): Boolean = cleanUpReflectionObjects {
- tpe.dealias <:< localTypeOf[Product] || tpe.dealias <:< localTypeOf[DefinedByConstructorParams]
+ tpe.dealias match {
+ // `Option` is a `Product`, but we don't want to treat `Option[Int]` as a struct type.
+ case t if t <:< localTypeOf[Option[_]] => definedByConstructorParams(t.typeArgs.head)
+ case _ => tpe.dealias <:< localTypeOf[Product] ||
+ tpe.dealias <:< localTypeOf[DefinedByConstructorParams]
+ }
}
private val javaKeywords = Set("abstract", "assert", "boolean", "break", "byte", "case", "catch",
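
The special case above exists because plain Scala reflection reports Option as a Product; a tiny, self-contained illustration of that fact, independent of Spark (the object name is mine):

    import scala.reflect.runtime.universe._

    object OptionIsProductSketch {
      def main(args: Array[String]): Unit = {
        // Option[Int] is a subtype of Product, so without the special case it would
        // look like a struct type to definedByConstructorParams.
        assert(typeOf[Option[Int]] <:< typeOf[Product])
        // Its single type argument is what the new branch recurses into.
        assert(typeOf[Option[Int]].typeArgs.head =:= typeOf[Int])
      }
    }
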
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index a574d8a84d4fb..e7517e8c676e3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -422,6 +422,7 @@ object FunctionRegistry {
expression[MapValues]("map_values"),
expression[MapEntries]("map_entries"),
expression[MapFromEntries]("map_from_entries"),
+ expression[MapConcat]("map_concat"),
expression[Size]("size"),
expression[Slice]("slice"),
expression[Size]("cardinality"),
@@ -505,6 +506,7 @@ object FunctionRegistry {
// json
expression[StructsToJson]("to_json"),
expression[JsonToStructs]("from_json"),
+ expression[SchemaOfJson]("schema_of_json"),
// cast
expression[Cast]("cast"),
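
The two new registrations above expose SQL-callable functions; a hedged usage sketch (assumes a local SparkSession created by the caller; the literals are illustrative):

    import org.apache.spark.sql.SparkSession

    object NewSqlFunctionsSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
        // map_concat merges its map arguments into a single map.
        spark.sql("SELECT map_concat(map(1, 'a'), map(2, 'b'))").show(truncate = false)
        // schema_of_json returns the schema string inferred from a JSON literal.
        spark.sql("""SELECT schema_of_json('{"id": 1, "tags": ["a", "b"]}')""").show(truncate = false)
        spark.stop()
      }
    }
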
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 3ebab430ffbcd..e8331c90ea0f6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -102,25 +102,7 @@ object TypeCoercion {
case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) =>
Some(TimestampType)
- case (t1 @ StructType(fields1), t2 @ StructType(fields2)) if t1.sameType(t2) =>
- Some(StructType(fields1.zip(fields2).map { case (f1, f2) =>
- // Since `t1.sameType(t2)` is true, two StructTypes have the same DataType
- // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`.
- // - Different names: use f1.name
- // - Different nullabilities: `nullable` is true iff one of them is nullable.
- val dataType = findTightestCommonType(f1.dataType, f2.dataType).get
- StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable)
- }))
-
- case (a1 @ ArrayType(et1, hasNull1), a2 @ ArrayType(et2, hasNull2)) if a1.sameType(a2) =>
- findTightestCommonType(et1, et2).map(ArrayType(_, hasNull1 || hasNull2))
-
- case (m1 @ MapType(kt1, vt1, hasNull1), m2 @ MapType(kt2, vt2, hasNull2)) if m1.sameType(m2) =>
- val keyType = findTightestCommonType(kt1, kt2)
- val valueType = findTightestCommonType(vt1, vt2)
- Some(MapType(keyType.get, valueType.get, hasNull1 || hasNull2))
-
- case _ => None
+ case (t1, t2) => findTypeForComplex(t1, t2, findTightestCommonType)
}
/** Promotes all the way to StringType. */
@@ -166,6 +148,42 @@ object TypeCoercion {
case (l, r) => None
}
+ private def findTypeForComplex(
+ t1: DataType,
+ t2: DataType,
+ findTypeFunc: (DataType, DataType) => Option[DataType]): Option[DataType] = (t1, t2) match {
+ case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) =>
+ findTypeFunc(et1, et2).map(ArrayType(_, containsNull1 || containsNull2))
+ case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) =>
+ findTypeFunc(kt1, kt2).flatMap { kt =>
+ findTypeFunc(vt1, vt2).map { vt =>
+ MapType(kt, vt, valueContainsNull1 || valueContainsNull2)
+ }
+ }
+ case (StructType(fields1), StructType(fields2)) if fields1.length == fields2.length =>
+ val resolver = SQLConf.get.resolver
+ fields1.zip(fields2).foldLeft(Option(new StructType())) {
+ case (Some(struct), (field1, field2)) if resolver(field1.name, field2.name) =>
+ findTypeFunc(field1.dataType, field2.dataType).map {
+ dt => struct.add(field1.name, dt, field1.nullable || field2.nullable)
+ }
+ case _ => None
+ }
+ case _ => None
+ }
+
+ /**
+ * The method finds a common type for data types that differ only in nullable, containsNull
+ * and valueContainsNull flags. If the input types are too different, None is returned.
+ */
+ def findCommonTypeDifferentOnlyInNullFlags(t1: DataType, t2: DataType): Option[DataType] = {
+ if (t1 == t2) {
+ Some(t1)
+ } else {
+ findTypeForComplex(t1, t2, findCommonTypeDifferentOnlyInNullFlags)
+ }
+ }
+
/**
* Case 2 type widening (see the classdoc comment above for TypeCoercion).
*
@@ -176,11 +194,7 @@ object TypeCoercion {
findTightestCommonType(t1, t2)
.orElse(findWiderTypeForDecimal(t1, t2))
.orElse(stringPromotion(t1, t2))
- .orElse((t1, t2) match {
- case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) =>
- findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2))
- case _ => None
- })
+ .orElse(findTypeForComplex(t1, t2, findWiderTypeForTwo))
}
/**
@@ -216,12 +230,7 @@ object TypeCoercion {
t2: DataType): Option[DataType] = {
findTightestCommonType(t1, t2)
.orElse(findWiderTypeForDecimal(t1, t2))
- .orElse((t1, t2) match {
- case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) =>
- findWiderTypeWithoutStringPromotionForTwo(et1, et2)
- .map(ArrayType(_, containsNull1 || containsNull2))
- case _ => None
- })
+ .orElse(findTypeForComplex(t1, t2, findWiderTypeWithoutStringPromotionForTwo))
}
def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = {
@@ -551,6 +560,14 @@ object TypeCoercion {
case None => s
}
+ case m @ MapConcat(children) if children.forall(c => MapType.acceptsType(c.dataType)) &&
+ !haveSameType(children) =>
+ val types = children.map(_.dataType)
+ findWiderCommonType(types) match {
+ case Some(finalDataType) => MapConcat(children.map(Cast(_, finalDataType)))
+ case None => m
+ }
+
case m @ CreateMap(children) if m.keys.length == m.values.length &&
(!haveSameType(m.keys) || !haveSameType(m.values)) =>
val newKeys = if (haveSameType(m.keys)) {
@@ -655,8 +672,8 @@ object TypeCoercion {
object CaseWhenCoercion extends TypeCoercionRule {
override protected def coerceTypes(
plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual =>
- val maybeCommonType = findWiderCommonType(c.valueTypes)
+ case c: CaseWhen if c.childrenResolved && !c.areInputTypesForMergingEqual =>
+ val maybeCommonType = findWiderCommonType(c.inputTypesForMerging)
maybeCommonType.map { commonType =>
var changed = false
val newBranches = c.branches.map { case (condition, value) =>
@@ -688,10 +705,10 @@ object TypeCoercion {
plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case e if !e.childrenResolved => e
// Find tightest common type for If, if the true value and false value have different types.
- case i @ If(pred, left, right) if left.dataType != right.dataType =>
+ case i @ If(pred, left, right) if !i.areInputTypesForMergingEqual =>
findWiderTypeForTwo(left.dataType, right.dataType).map { widestType =>
- val newLeft = if (left.dataType == widestType) left else Cast(left, widestType)
- val newRight = if (right.dataType == widestType) right else Cast(right, widestType)
+ val newLeft = if (left.dataType.sameType(widestType)) left else Cast(left, widestType)
+ val newRight = if (right.dataType.sameType(widestType)) right else Cast(right, widestType)
If(pred, newLeft, newRight)
}.getOrElse(i) // If there is no applicable conversion, leave expression unchanged.
case If(Literal(null, NullType), left, right) =>
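
A small sketch of the new findCommonTypeDifferentOnlyInNullFlags helper introduced above, written test-style in the analysis package only to reach the internal object (the object name is mine):

    package org.apache.spark.sql.catalyst.analysis

    import org.apache.spark.sql.types._

    object NullFlagMergingSketch {
      def main(args: Array[String]): Unit = {
        // Same element type, different containsNull: merged with containsNull = true.
        val merged = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(
          ArrayType(IntegerType, containsNull = false),
          ArrayType(IntegerType, containsNull = true))
        assert(merged.contains(ArrayType(IntegerType, containsNull = true)))

        // Genuinely different types are rejected rather than coerced.
        assert(TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(IntegerType, StringType).isEmpty)
      }
    }
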
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index 5ced1ca200daa..f68df5d29b545 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -315,8 +315,10 @@ object UnsupportedOperationChecker {
case GroupingSets(_, _, child, _) if child.isStreaming =>
throwError("GroupingSets is not supported on streaming DataFrames/Datasets")
- case GlobalLimit(_, _) | LocalLimit(_, _) if subPlan.children.forall(_.isStreaming) =>
- throwError("Limits are not supported on streaming DataFrames/Datasets")
+ case GlobalLimit(_, _) | LocalLimit(_, _)
+ if subPlan.children.forall(_.isStreaming) && outputMode == InternalOutputModes.Update =>
+ throwError("Limits are not supported on streaming DataFrames/Datasets in Update " +
+ "output mode")
case Sort(_, _, _) if !containsCompleteData(subPlan) =>
throwError("Sorting is not supported on streaming DataFrames/Datasets, unless it is on " +
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index c390337c03ff5..c26a34528c162 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -619,6 +619,7 @@ class SessionCatalog(
requireTableExists(TableIdentifier(oldTableName, Some(db)))
requireTableNotExists(TableIdentifier(newTableName, Some(db)))
validateName(newTableName)
+ validateNewLocationOfRename(oldName, newName)
externalCatalog.renameTable(db, oldTableName, newTableName)
} else {
if (newName.database.isDefined) {
@@ -1366,4 +1367,23 @@ class SessionCatalog(
// copy over temporary views
tempViews.foreach(kv => target.tempViews.put(kv._1, kv._2))
}
+
+ /**
+ * Validates the new location before renaming a managed table; the target location must not exist yet.
+ */
+ private def validateNewLocationOfRename(
+ oldName: TableIdentifier,
+ newName: TableIdentifier): Unit = {
+ val oldTable = getTableMetadata(oldName)
+ if (oldTable.tableType == CatalogTableType.MANAGED) {
+ val databaseLocation =
+ externalCatalog.getDatabase(oldName.database.getOrElse(currentDb)).locationUri
+ val newTableLocation = new Path(new Path(databaseLocation), formatTableName(newName.table))
+ val fs = newTableLocation.getFileSystem(hadoopConf)
+ if (fs.exists(newTableLocation)) {
+ throw new AnalysisException(s"Can not rename the managed table('$oldName')" +
+ s". The associated location('$newTableLocation') already exists.")
+ }
+ }
+ }
}
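
The user-visible effect of the new check, sketched at the SQL level (table names are made up, and the failure path only triggers when a directory named after the target table already exists under the database location; this is an illustration, not a test from the PR):

    import org.apache.spark.sql.{AnalysisException, SparkSession}

    object RenameValidationSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
        spark.sql("CREATE TABLE src(id INT) USING parquet")
        try {
          // Fails fast with the AnalysisException built above if <warehouse>/dst already exists,
          // instead of silently reusing that location for the renamed managed table.
          spark.sql("ALTER TABLE src RENAME TO dst")
        } catch {
          case e: AnalysisException => println(e.getMessage)
        }
        spark.stop()
      }
    }
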
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 699ea53b5df0f..7971ae602bd37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -134,6 +134,26 @@ object Cast {
toPrecedence > 0 && fromPrecedence > toPrecedence
}
+ /**
+ * Returns true iff we can safely cast the `from` type to the `to` type without any truncation or
+ * precision loss, e.g. int -> long, date -> timestamp.
+ */
+ def canSafeCast(from: AtomicType, to: AtomicType): Boolean = (from, to) match {
+ case _ if from == to => true
+ case (from: NumericType, to: DecimalType) if to.isWiderThan(from) => true
+ case (from: DecimalType, to: NumericType) if from.isTighterThan(to) => true
+ case (from, to) if legalNumericPrecedence(from, to) => true
+ case (DateType, TimestampType) => true
+ case (_, StringType) => true
+ case _ => false
+ }
+
+ private def legalNumericPrecedence(from: DataType, to: DataType): Boolean = {
+ val fromPrecedence = TypeCoercion.numericPrecedence.indexOf(from)
+ val toPrecedence = TypeCoercion.numericPrecedence.indexOf(to)
+ fromPrecedence >= 0 && fromPrecedence < toPrecedence
+ }
+
def forceNullable(from: DataType, to: DataType): Boolean = (from, to) match {
case (NullType, _) => true
case (_, _) if from == to => false
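
A sketch of what the new canSafeCast helper accepts and rejects, placed test-style in the expressions package only to reach the internal Cast object (the object name is mine):

    package org.apache.spark.sql.catalyst.expressions

    import org.apache.spark.sql.types._

    object SafeCastSketch {
      def main(args: Array[String]): Unit = {
        // Widening numeric casts and date -> timestamp lose nothing, so they are "safe".
        assert(Cast.canSafeCast(IntegerType, LongType))
        assert(Cast.canSafeCast(DateType, TimestampType))
        // Narrowing casts can truncate, so they are not.
        assert(!Cast.canSafeCast(LongType, IntegerType))
      }
    }
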
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 9b9fa41a47d0f..44c5556ff9ccf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.util.Locale
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TreeNode
@@ -695,6 +695,41 @@ abstract class TernaryExpression extends Expression {
}
}
+/**
+ * A trait resolving the nullable, containsNull and valueContainsNull flags of the output data type.
+ * This logic is usually utilized by expressions combining data from multiple child expressions
+ * of non-primitive types (e.g. [[CaseWhen]]).
+ */
+trait ComplexTypeMergingExpression extends Expression {
+
+ /**
+ * A collection of data types used for resolving the output type of the expression. By default,
+ * it holds the data types of all child expressions. The collection must not be empty.
+ */
+ @transient
+ lazy val inputTypesForMerging: Seq[DataType] = children.map(_.dataType)
+
+ /**
+ * Returns whether the input types are equal once the nullable, containsNull and
+ * valueContainsNull flags are ignored, i.e. whether the final data type can be resolved.
+ */
+ def areInputTypesForMergingEqual: Boolean = {
+ inputTypesForMerging.length <= 1 || inputTypesForMerging.sliding(2, 1).forall {
+ case Seq(dt1, dt2) => dt1.sameType(dt2)
+ }
+ }
+
+ override def dataType: DataType = {
+ require(
+ inputTypesForMerging.nonEmpty,
+ "The collection of input data types must not be empty.")
+ require(
+ areInputTypesForMergingEqual,
+ "All input types must be the same except nullable, containsNull, valueContainsNull flags.")
+ inputTypesForMerging.reduceLeft(TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(_, _).get)
+ }
+}
+
/**
* Common base trait for user-defined functions, including UDF/UDAF/UDTF of different languages
* and Hive function wrappers.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 4cc0968911cb5..838c045d5bcce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -1415,7 +1415,7 @@ object CodeGenerator extends Logging {
* weak keys/values and thus does not respond to memory pressure.
*/
private val cache = CacheBuilder.newBuilder()
- .maximumSize(100)
+ .maximumSize(SQLConf.get.codegenCacheMaxEntries)
.build(
new CacheLoader[CodeAndComment, (GeneratedClass, Int)]() {
override def load(code: CodeAndComment): (GeneratedClass, Int) = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
index 250ce48d059e0..2f8c853e836ba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
@@ -22,6 +22,7 @@ import java.lang.{Boolean => JBool}
import scala.collection.mutable.ArrayBuffer
import scala.language.{existentials, implicitConversions}
+import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types.{BooleanType, DataType}
/**
@@ -118,10 +119,8 @@ object JavaCode {
/**
* A trait representing a block of java code.
*/
-trait Block extends JavaCode {
-
- // The expressions to be evaluated inside this block.
- def exprValues: Set[ExprValue]
+trait Block extends TreeNode[Block] with JavaCode {
+ import Block._
// Returns java code string for this code block.
override def toString: String = _marginChar match {
@@ -147,15 +146,48 @@ trait Block extends JavaCode {
this
}
+ /**
+ * Applies a map function to each java expression code (`ExprValue`) in this java code, and
+ * returns a new java code based on the mapped expression codes.
+ */
+ def transformExprValues(f: PartialFunction[ExprValue, ExprValue]): this.type = {
+ var changed = false
+
+ @inline def transform(e: ExprValue): ExprValue = {
+ val newE = f lift e
+ if (!newE.isDefined || newE.get.equals(e)) {
+ e
+ } else {
+ changed = true
+ newE.get
+ }
+ }
+
+ def doTransform(arg: Any): AnyRef = arg match {
+ case e: ExprValue => transform(e)
+ case Some(value) => Some(doTransform(value))
+ case seq: Traversable[_] => seq.map(doTransform)
+ case other: AnyRef => other
+ }
+
+ val newArgs = mapProductIterator(doTransform)
+ if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this
+ }
+
// Concatenates this block with other block.
- def + (other: Block): Block
+ def + (other: Block): Block = other match {
+ case EmptyBlock => this
+ case _ => code"$this\n$other"
+ }
+
+ override def verboseString: String = toString
}
object Block {
val CODE_BLOCK_BUFFER_LENGTH: Int = 512
- implicit def blocksToBlock(blocks: Seq[Block]): Block = Blocks(blocks)
+ implicit def blocksToBlock(blocks: Seq[Block]): Block = blocks.reduceLeft(_ + _)
implicit class BlockHelper(val sc: StringContext) extends AnyVal {
def code(args: Any*): Block = {
@@ -190,18 +222,17 @@ object Block {
while (strings.hasNext) {
val input = inputs.next
input match {
- case _: ExprValue | _: Block =>
+ case _: ExprValue | _: CodeBlock =>
codeParts += buf.toString
buf.clear
blockInputs += input.asInstanceOf[JavaCode]
+ case EmptyBlock =>
case _ =>
buf.append(input)
}
buf.append(strings.next)
}
- if (buf.nonEmpty) {
- codeParts += buf.toString
- }
+ codeParts += buf.toString
(codeParts.toSeq, blockInputs.toSeq)
}
@@ -209,15 +240,15 @@ object Block {
/**
* A block of java code. Including a sequence of code parts and some inputs to this block.
- * The actual java code is generated by embedding the inputs into the code parts.
+ * The actual java code is generated by embedding the inputs into the code parts. Here we keep
+ * the inputs as `JavaCode` rather than folding them into a plain code string, because we need
+ * to track the expressions (`ExprValue`) in this code block so that they can be manipulated
+ * later (e.g. for method splitting) without changing the behavior of this block.
*/
case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends Block {
- override lazy val exprValues: Set[ExprValue] = {
- blockInputs.flatMap {
- case b: Block => b.exprValues
- case e: ExprValue => Set(e)
- }.toSet
- }
+ override def children: Seq[Block] =
+ blockInputs.filter(_.isInstanceOf[Block]).asInstanceOf[Seq[Block]]
override lazy val code: String = {
val strings = codeParts.iterator
@@ -230,30 +261,11 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends
}
buf.toString
}
-
- override def + (other: Block): Block = other match {
- case c: CodeBlock => Blocks(Seq(this, c))
- case b: Blocks => Blocks(Seq(this) ++ b.blocks)
- case EmptyBlock => this
- }
}
-case class Blocks(blocks: Seq[Block]) extends Block {
- override lazy val exprValues: Set[ExprValue] = blocks.flatMap(_.exprValues).toSet
- override lazy val code: String = blocks.map(_.toString).mkString("\n")
-
- override def + (other: Block): Block = other match {
- case c: CodeBlock => Blocks(blocks :+ c)
- case b: Blocks => Blocks(blocks ++ b.blocks)
- case EmptyBlock => this
- }
-}
-
-object EmptyBlock extends Block with Serializable {
+case object EmptyBlock extends Block with Serializable {
override val code: String = ""
- override val exprValues: Set[ExprValue] = Set.empty
-
- override def + (other: Block): Block = other
+ override def children: Seq[Block] = Seq.empty
}
/**
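
With the change above, the `Block` hierarchy behaves like a small tree: a `CodeBlock` keeps its inputs as structured `JavaCode` values and `transformExprValues` rewrites expressions instead of doing string surgery. A rough, self-contained sketch of that pattern (names like `Blk` and `Expr` are invented for illustration and are not Spark's API):

```scala
object BlockSketch extends App {
  sealed trait Node { def render: String }
  case class Expr(name: String) extends Node { def render: String = name }
  case class Blk(parts: Seq[String], inputs: Seq[Node]) extends Node {
    // Interleave literal code parts with rendered inputs, as the code interpolator does.
    def render: String =
      parts.zipAll(inputs.map(_.render), "", "").map { case (p, i) => p + i }.mkString
    // Analogue of transformExprValues: rewrite leaf expressions, recurse into child blocks.
    def transformExprs(f: PartialFunction[Expr, Expr]): Blk = copy(inputs = inputs.map {
      case e: Expr => if (f.isDefinedAt(e)) f(e) else e
      case b: Blk  => b.transformExprs(f)
    })
  }

  val block   = Blk(Seq("int out = ", " + 1;"), Seq(Expr("expr_0")))
  val renamed = block.transformExprs { case Expr("expr_0") => Expr("aliased") }
  println(renamed.render) // int out = aliased + 1;
}
```
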
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 8b278f067749e..879603b66b314 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -503,6 +503,237 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
override def prettyName: String = "map_entries"
}
+/**
+ * Returns the union of all the given maps.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(map, ...) - Returns the union of all the given maps",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd'));
+ [[1 -> "a"], [2 -> "b"], [2 -> "c"], [3 -> "d"]]
+ """, since = "2.4.0")
+case class MapConcat(children: Seq[Expression]) extends Expression {
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val funcName = s"function $prettyName"
+ if (children.exists(!_.dataType.isInstanceOf[MapType])) {
+ TypeCheckResult.TypeCheckFailure(
+ s"input to $funcName should all be of type map, but it's " +
+ children.map(_.dataType.simpleString).mkString("[", ", ", "]"))
+ } else {
+ TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), funcName)
+ }
+ }
+
+ override def dataType: MapType = {
+ val dt = children.map(_.dataType.asInstanceOf[MapType]).headOption
+ .getOrElse(MapType(StringType, StringType))
+ val valueContainsNull = children.map(_.dataType.asInstanceOf[MapType])
+ .exists(_.valueContainsNull)
+ if (dt.valueContainsNull != valueContainsNull) {
+ dt.copy(valueContainsNull = valueContainsNull)
+ } else {
+ dt
+ }
+ }
+
+ override def nullable: Boolean = children.exists(_.nullable)
+
+ override def eval(input: InternalRow): Any = {
+ val maps = children.map(_.eval(input))
+ if (maps.contains(null)) {
+ return null
+ }
+ val keyArrayDatas = maps.map(_.asInstanceOf[MapData].keyArray())
+ val valueArrayDatas = maps.map(_.asInstanceOf[MapData].valueArray())
+
+ val numElements = keyArrayDatas.foldLeft(0L)((sum, ad) => sum + ad.numElements())
+ if (numElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+ throw new RuntimeException(s"Unsuccessful attempt to concat maps with $numElements " +
+ s"elements due to exceeding the map size limit " +
+ s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
+ }
+ val finalKeyArray = new Array[AnyRef](numElements.toInt)
+ val finalValueArray = new Array[AnyRef](numElements.toInt)
+ var position = 0
+ for (i <- keyArrayDatas.indices) {
+ val keyArray = keyArrayDatas(i).toObjectArray(dataType.keyType)
+ val valueArray = valueArrayDatas(i).toObjectArray(dataType.valueType)
+ Array.copy(keyArray, 0, finalKeyArray, position, keyArray.length)
+ Array.copy(valueArray, 0, finalValueArray, position, valueArray.length)
+ position += keyArray.length
+ }
+
+ new ArrayBasedMapData(new GenericArrayData(finalKeyArray),
+ new GenericArrayData(finalValueArray))
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val mapCodes = children.map(_.genCode(ctx))
+ val keyType = dataType.keyType
+ val valueType = dataType.valueType
+ val argsName = ctx.freshName("args")
+ val hasNullName = ctx.freshName("hasNull")
+ val mapDataClass = classOf[MapData].getName
+ val arrayBasedMapDataClass = classOf[ArrayBasedMapData].getName
+ val arrayDataClass = classOf[ArrayData].getName
+
+ val init =
+ s"""
+ |$mapDataClass[] $argsName = new $mapDataClass[${mapCodes.size}];
+ |boolean ${ev.isNull}, $hasNullName = false;
+ |$mapDataClass ${ev.value} = null;
+ """.stripMargin
+
+ val assignments = mapCodes.zipWithIndex.map { case (m, i) =>
+ s"""
+ |if (!$hasNullName) {
+ | ${m.code}
+ | $argsName[$i] = ${m.value};
+ | if (${m.isNull}) {
+ | $hasNullName = true;
+ | }
+ |}
+ """.stripMargin
+ }
+
+ val codes = ctx.splitExpressionsWithCurrentInputs(
+ expressions = assignments,
+ funcName = "getMapConcatInputs",
+ extraArguments = (s"$mapDataClass[]", argsName) :: ("boolean", hasNullName) :: Nil,
+ returnType = "boolean",
+ makeSplitFunction = body =>
+ s"""
+ |$body
+ |return $hasNullName;
+ """.stripMargin,
+ foldFunctions = _.map(funcCall => s"$hasNullName = $funcCall;").mkString("\n")
+ )
+
+ val idxName = ctx.freshName("idx")
+ val numElementsName = ctx.freshName("numElems")
+ val finKeysName = ctx.freshName("finalKeys")
+ val finValsName = ctx.freshName("finalValues")
+
+ val keyConcatenator = if (CodeGenerator.isPrimitiveType(keyType)) {
+ genCodeForPrimitiveArrays(ctx, keyType, false)
+ } else {
+ genCodeForNonPrimitiveArrays(ctx, keyType)
+ }
+
+ val valueConcatenator = if (CodeGenerator.isPrimitiveType(valueType)) {
+ genCodeForPrimitiveArrays(ctx, valueType, dataType.valueContainsNull)
+ } else {
+ genCodeForNonPrimitiveArrays(ctx, valueType)
+ }
+
+ val keyArgsName = ctx.freshName("keyArgs")
+ val valArgsName = ctx.freshName("valArgs")
+
+ val mapMerge =
+ s"""
+ |${ev.isNull} = $hasNullName;
+ |if (!${ev.isNull}) {
+ | $arrayDataClass[] $keyArgsName = new $arrayDataClass[${mapCodes.size}];
+ | $arrayDataClass[] $valArgsName = new $arrayDataClass[${mapCodes.size}];
+ | long $numElementsName = 0;
+ | for (int $idxName = 0; $idxName < $argsName.length; $idxName++) {
+ | $keyArgsName[$idxName] = $argsName[$idxName].keyArray();
+ | $valArgsName[$idxName] = $argsName[$idxName].valueArray();
+ | $numElementsName += $argsName[$idxName].numElements();
+ | }
+ | if ($numElementsName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+ | throw new RuntimeException("Unsuccessful attempt to concat maps with " +
+ | $numElementsName + " elements due to exceeding the map size limit " +
+ | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
+ | }
+ | $arrayDataClass $finKeysName = $keyConcatenator.concat($keyArgsName,
+ | (int) $numElementsName);
+ | $arrayDataClass $finValsName = $valueConcatenator.concat($valArgsName,
+ | (int) $numElementsName);
+ | ${ev.value} = new $arrayBasedMapDataClass($finKeysName, $finValsName);
+ |}
+ """.stripMargin
+
+ ev.copy(
+ code = code"""
+ |$init
+ |$codes
+ |$mapMerge
+ """.stripMargin)
+ }
+
+ private def genCodeForPrimitiveArrays(
+ ctx: CodegenContext,
+ elementType: DataType,
+ checkForNull: Boolean): String = {
+ val counter = ctx.freshName("counter")
+ val arrayData = ctx.freshName("arrayData")
+ val argsName = ctx.freshName("args")
+ val numElemName = ctx.freshName("numElements")
+ val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
+
+ val setterCode1 =
+ s"""
+ |$arrayData.set$primitiveValueTypeName(
+ | $counter,
+ | ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")}
+ |);""".stripMargin
+
+ val setterCode = if (checkForNull) {
+ s"""
+ |if ($argsName[y].isNullAt(z)) {
+ | $arrayData.setNullAt($counter);
+ |} else {
+ | $setterCode1
+ |}""".stripMargin
+ } else {
+ setterCode1
+ }
+
+ s"""
+ |new Object() {
+ | public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) {
+ | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")}
+ | int $counter = 0;
+ | for (int y = 0; y < ${children.length}; y++) {
+ | for (int z = 0; z < $argsName[y].numElements(); z++) {
+ | $setterCode
+ | $counter++;
+ | }
+ | }
+ | return $arrayData;
+ | }
+ |}""".stripMargin.stripPrefix("\n")
+ }
+
+ private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
+ val genericArrayClass = classOf[GenericArrayData].getName
+ val arrayData = ctx.freshName("arrayObjects")
+ val counter = ctx.freshName("counter")
+ val argsName = ctx.freshName("args")
+ val numElemName = ctx.freshName("numElements")
+
+ s"""
+ |new Object() {
+ | public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) {
+ | Object[] $arrayData = new Object[$numElemName];
+ | int $counter = 0;
+ | for (int y = 0; y < ${children.length}; y++) {
+ | for (int z = 0; z < $argsName[y].numElements(); z++) {
+ | $arrayData[$counter] = ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")};
+ | $counter++;
+ | }
+ | }
+ | return new $genericArrayClass($arrayData);
+ | }
+ |}""".stripMargin.stripPrefix("\n")
+ }
+
+ override def prettyName: String = "map_concat"
+}
+
/**
* Returns a map created from the given array of entries.
*/
@@ -1085,7 +1316,7 @@ case class ArrayContains(left: Expression, right: Expression)
if (right.dataType == NullType) {
TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as arguments")
} else if (!left.dataType.isInstanceOf[ArrayType]
- || left.dataType.asInstanceOf[ArrayType].elementType != right.dataType) {
+ || !left.dataType.asInstanceOf[ArrayType].elementType.sameType(right.dataType)) {
TypeCheckResult.TypeCheckFailure(
"Arguments must be an array followed by a value of same type as the array members")
} else {
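
For context, the new `map_concat` function and the relaxed `ArrayContains` type check can be exercised from SQL roughly as follows (a hedged sketch assuming a spark-shell session where `spark` is available; output omitted). Note that in this version duplicate keys coming from later maps are kept rather than deduplicated.

```scala
// map_concat keeps entries from every input map, including duplicate keys.
spark.sql("SELECT map_concat(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd')) AS merged")
  .show(truncate = false)

// ArrayContains now compares element types with sameType, so struct elements that
// differ only in field nullability are accepted.
spark.sql("SELECT array_contains(array(named_struct('a', 1)), named_struct('a', 1)) AS found")
  .show()
```
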
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 77ac6c088022e..e6377b7d87b53 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
@@ -33,7 +33,12 @@ import org.apache.spark.sql.types._
""")
// scalastyle:on line.size.limit
case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
- extends Expression {
+ extends ComplexTypeMergingExpression {
+
+ @transient
+ override lazy val inputTypesForMerging: Seq[DataType] = {
+ Seq(trueValue.dataType, falseValue.dataType)
+ }
override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil
override def nullable: Boolean = trueValue.nullable || falseValue.nullable
@@ -43,7 +48,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
TypeCheckResult.TypeCheckFailure(
"type of predicate expression in If should be boolean, " +
s"not ${predicate.dataType.simpleString}")
- } else if (!trueValue.dataType.sameType(falseValue.dataType)) {
+ } else if (!areInputTypesForMergingEqual) {
TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " +
s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).")
} else {
@@ -51,8 +56,6 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
}
}
- override def dataType: DataType = trueValue.dataType
-
override def eval(input: InternalRow): Any = {
if (java.lang.Boolean.TRUE.equals(predicate.eval(input))) {
trueValue.eval(input)
@@ -118,27 +121,24 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
case class CaseWhen(
branches: Seq[(Expression, Expression)],
elseValue: Option[Expression] = None)
- extends Expression with Serializable {
+ extends ComplexTypeMergingExpression with Serializable {
override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue
// both then and else expressions should be considered.
- def valueTypes: Seq[DataType] = branches.map(_._2.dataType) ++ elseValue.map(_.dataType)
-
- def valueTypesEqual: Boolean = valueTypes.size <= 1 || valueTypes.sliding(2, 1).forall {
- case Seq(dt1, dt2) => dt1.sameType(dt2)
+ @transient
+ override lazy val inputTypesForMerging: Seq[DataType] = {
+ branches.map(_._2.dataType) ++ elseValue.map(_.dataType)
}
- override def dataType: DataType = branches.head._2.dataType
-
override def nullable: Boolean = {
// Result is nullable if any of the branch is nullable, or if the else value is nullable
branches.exists(_._2.nullable) || elseValue.map(_.nullable).getOrElse(true)
}
override def checkInputDataTypes(): TypeCheckResult = {
- // Make sure all branch conditions are boolean types.
- if (valueTypesEqual) {
+ if (areInputTypesForMergingEqual) {
+ // Make sure all branch conditions are boolean types.
if (branches.forall(_._1.dataType == BooleanType)) {
TypeCheckResult.TypeCheckSuccess
} else {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index f6d74f5b74c8e..8cd86053a01c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
-import java.io.{ByteArrayInputStream, ByteArrayOutputStream, CharArrayWriter, InputStreamReader, StringWriter}
+import java.io._
import scala.util.parsing.combinator.RegexParsers
@@ -28,7 +28,8 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.json._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, FailFastMode, GenericArrayData, MapData}
+import org.apache.spark.sql.catalyst.json.JsonInferSchema.inferField
+import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -525,17 +526,19 @@ case class JsonToStructs(
override def nullable: Boolean = true
// Used in `FunctionRegistry`
- def this(child: Expression, schema: Expression) =
+ def this(child: Expression, schema: Expression, options: Map[String, String]) =
this(
- schema = JsonExprUtils.validateSchemaLiteral(schema),
- options = Map.empty[String, String],
+ schema = JsonExprUtils.evalSchemaExpr(schema),
+ options = options,
child = child,
timeZoneId = None,
forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA))
+ def this(child: Expression, schema: Expression) = this(child, schema, Map.empty[String, String])
+
def this(child: Expression, schema: Expression, options: Expression) =
this(
- schema = JsonExprUtils.validateSchemaLiteral(schema),
+ schema = JsonExprUtils.evalSchemaExpr(schema),
options = JsonExprUtils.convertToMapData(options),
child = child,
timeZoneId = None,
@@ -744,11 +747,44 @@ case class StructsToJson(
override def inputTypes: Seq[AbstractDataType] = TypeCollection(ArrayType, StructType) :: Nil
}
+/**
+ * A function that infers the schema of a JSON string.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(json[, options]) - Returns schema in the DDL format of JSON string.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_('[{"col":0}]');
+ array<struct<col:bigint>>
+ """,
+ since = "2.4.0")
+case class SchemaOfJson(child: Expression)
+ extends UnaryExpression with String2StringExpression with CodegenFallback {
+
+ private val jsonOptions = new JSONOptions(Map.empty, "UTC")
+ private val jsonFactory = new JsonFactory()
+ jsonOptions.setJacksonOptions(jsonFactory)
+
+ override def convert(v: UTF8String): UTF8String = {
+ val dt = Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, v)) { parser =>
+ parser.nextToken()
+ inferField(parser, jsonOptions)
+ }
+
+ UTF8String.fromString(dt.catalogString)
+ }
+}
+
object JsonExprUtils {
- def validateSchemaLiteral(exp: Expression): DataType = exp match {
+ def evalSchemaExpr(exp: Expression): DataType = exp match {
case Literal(s, StringType) => DataType.fromDDL(s.toString)
- case e => throw new AnalysisException(s"Expected a string literal instead of $e")
+ case e @ SchemaOfJson(_: Literal) =>
+ val ddlSchema = e.eval().asInstanceOf[UTF8String]
+ DataType.fromDDL(ddlSchema.toString)
+ case e => throw new AnalysisException(
+ "Schema should be specified in DDL format as a string literal" +
+ s" or output of the schema_of_json function instead of ${e.sql}")
}
def convertToMapData(exp: Expression): Map[String, String] = exp match {
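
The `schema_of_json` expression and the new `evalSchemaExpr` case are meant to compose with `from_json`. A hedged usage sketch, assuming a spark-shell session named `spark`; the commented result follows from the JSON inference rules (integers infer as bigint):

```scala
// Inspect the inferred DDL schema of a JSON sample.
spark.sql("""SELECT schema_of_json('[{"col":0}]') AS schema""").show(truncate = false)
// schema: array<struct<col:bigint>>

// Feed the inferred schema straight into from_json, which is what the new
// evalSchemaExpr case for SchemaOfJson(Literal) enables.
spark.sql("""SELECT from_json('[{"col":42}]', schema_of_json('[{"col":0}]')) AS parsed""")
  .show(truncate = false)
```
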
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
similarity index 98%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
index 8e1b430f4eb33..491ca005877f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.execution.datasources.json
+package org.apache.spark.sql.catalyst.json
import java.util.Comparator
@@ -25,7 +25,6 @@ import org.apache.spark.SparkException
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil
-import org.apache.spark.sql.catalyst.json.JSONOptions
import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -103,7 +102,7 @@ private[sql] object JsonInferSchema {
/**
* Infer the type of a json document from the parser's token stream
*/
- private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = {
+ def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = {
import com.fasterxml.jackson.core.JsonToken._
parser.getCurrentToken match {
case null | VALUE_NULL => NullType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 8a12a74a73dff..cca3a6188d3e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -883,6 +883,21 @@ object SQLConf {
.stringConf
.createWithDefault("org.apache.spark.sql.execution.streaming.ManifestFileCommitProtocol")
+ val STREAMING_MULTIPLE_WATERMARK_POLICY =
+ buildConf("spark.sql.streaming.multipleWatermarkPolicy")
+ .doc("Policy to calculate the global watermark value when there are multiple watermark " +
+ "operators in a streaming query. The default value is 'min' which chooses " +
+ "the minimum watermark reported across multiple operators. Other alternative value is" +
+ "'max' which chooses the maximum across multiple operators." +
+ "Note: This configuration cannot be changed between query restarts from the same " +
+ "checkpoint location.")
+ .stringConf
+ .checkValue(
+ str => Set("min", "max").contains(str.toLowerCase),
+ "Invalid value for 'spark.sql.streaming.multipleWatermarkPolicy'. " +
+ "Valid values are 'min' and 'max'")
+ .createWithDefault("min") // must be same as MultipleWatermarkPolicy.DEFAULT_POLICY_NAME
+
val OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD =
buildConf("spark.sql.objectHashAggregate.sortBased.fallbackThreshold")
.internal()
@@ -1516,6 +1531,8 @@ class SQLConf extends Serializable with Logging {
def tableRelationCacheSize: Int =
getConf(StaticSQLConf.FILESOURCE_TABLE_RELATION_CACHE_SIZE)
+ def codegenCacheMaxEntries: Int = getConf(StaticSQLConf.CODEGEN_CACHE_MAX_ENTRIES)
+
def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED)
def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE)
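
A hedged sketch of how the new watermark policy might be used: two streams with different watermark delays are combined into one query, which gives it multiple watermark operators (the sources and delays below are made up for illustration, and `spark` is assumed to be an active session):

```scala
// Must be set before the streaming query starts and kept stable across restarts
// from the same checkpoint location.
spark.conf.set("spark.sql.streaming.multipleWatermarkPolicy", "max")

val fast = spark.readStream.format("rate").load().withWatermark("timestamp", "10 seconds")
val slow = spark.readStream.format("rate").load().withWatermark("timestamp", "20 minutes")

// With "max" the global watermark follows the stream that is furthest ahead,
// instead of the default "min" behaviour.
val combined = fast.union(slow)
```
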
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
index 382ef28f49a7a..384b1917a1f79 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
@@ -66,6 +66,14 @@ object StaticSQLConf {
.checkValue(cacheSize => cacheSize >= 0, "The maximum size of the cache must not be negative")
.createWithDefault(1000)
+ val CODEGEN_CACHE_MAX_ENTRIES = buildStaticConf("spark.sql.codegen.cache.maxEntries")
+ .internal()
+ .doc("When nonzero, enable caching of generated classes for operators and expressions. " +
+ "All jobs share the cache that can use up to the specified number for generated classes.")
+ .intConf
+ .checkValue(maxEntries => maxEntries >= 0, "The maximum must not be negative")
+ .createWithDefault(100)
+
// When enabling the debug, Spark SQL internal table properties are not filtered out; however,
// some related DDL commands (e.g., ANALYZE TABLE and CREATE TABLE LIKE) might not work properly.
val DEBUG_MODE = buildStaticConf("spark.sql.debug")
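
Because `spark.sql.codegen.cache.maxEntries` is a static configuration, it can only be set when the session is created. An illustrative sketch (the value 400 is arbitrary; the default stays at 100):

```scala
import org.apache.spark.sql.SparkSession

// Static SQL confs must be supplied before getOrCreate(); changing them on an
// existing session has no effect.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("codegen-cache-sizing")
  .config("spark.sql.codegen.cache.maxEntries", "400")
  .getOrCreate()
```
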
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 0acd3b490447d..8cc5a23779a2a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -54,8 +54,9 @@ class TypeCoercionSuite extends AnalysisTest {
// | NullType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType | MapType | StructType | NullType | CalendarIntervalType | DecimalType(38, 18) | DoubleType | IntegerType |
// | CalendarIntervalType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | CalendarIntervalType | X | X | X |
// +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+
- // Note: MapType*, StructType* are castable only when the internal child types also match; otherwise, not castable.
+ // Note: StructType* is castable when all the internal child types are castable according to the table.
// Note: ArrayType* is castable when the element type is castable according to the table.
+ // Note: MapType* is castable when both the key type and the value type are castable according to the table.
// scalastyle:on line.size.limit
private def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = {
@@ -396,7 +397,7 @@ class TypeCoercionSuite extends AnalysisTest {
widenTest(
StructType(Seq(StructField("a", IntegerType, nullable = false))),
StructType(Seq(StructField("a", DoubleType, nullable = false))),
- None)
+ Some(StructType(Seq(StructField("a", DoubleType, nullable = false)))))
widenTest(
StructType(Seq(StructField("a", IntegerType, nullable = false))),
@@ -453,15 +454,18 @@ class TypeCoercionSuite extends AnalysisTest {
def widenTestWithStringPromotion(
t1: DataType,
t2: DataType,
- expected: Option[DataType]): Unit = {
- checkWidenType(TypeCoercion.findWiderTypeForTwo, t1, t2, expected)
+ expected: Option[DataType],
+ isSymmetric: Boolean = true): Unit = {
+ checkWidenType(TypeCoercion.findWiderTypeForTwo, t1, t2, expected, isSymmetric)
}
def widenTestWithoutStringPromotion(
t1: DataType,
t2: DataType,
- expected: Option[DataType]): Unit = {
- checkWidenType(TypeCoercion.findWiderTypeWithoutStringPromotionForTwo, t1, t2, expected)
+ expected: Option[DataType],
+ isSymmetric: Boolean = true): Unit = {
+ checkWidenType(
+ TypeCoercion.findWiderTypeWithoutStringPromotionForTwo, t1, t2, expected, isSymmetric)
}
// Decimal
@@ -487,12 +491,108 @@ class TypeCoercionSuite extends AnalysisTest {
ArrayType(ArrayType(IntegerType), containsNull = false),
ArrayType(ArrayType(LongType), containsNull = false),
Some(ArrayType(ArrayType(LongType), containsNull = false)))
+ widenTestWithStringPromotion(
+ ArrayType(MapType(IntegerType, FloatType), containsNull = false),
+ ArrayType(MapType(LongType, DoubleType), containsNull = false),
+ Some(ArrayType(MapType(LongType, DoubleType), containsNull = false)))
+ widenTestWithStringPromotion(
+ ArrayType(new StructType().add("num", ShortType), containsNull = false),
+ ArrayType(new StructType().add("num", LongType), containsNull = false),
+ Some(ArrayType(new StructType().add("num", LongType), containsNull = false)))
+
+ // MapType
+ widenTestWithStringPromotion(
+ MapType(ShortType, TimestampType, valueContainsNull = true),
+ MapType(DoubleType, StringType, valueContainsNull = false),
+ Some(MapType(DoubleType, StringType, valueContainsNull = true)))
+ widenTestWithStringPromotion(
+ MapType(IntegerType, ArrayType(TimestampType), valueContainsNull = false),
+ MapType(LongType, ArrayType(StringType), valueContainsNull = true),
+ Some(MapType(LongType, ArrayType(StringType), valueContainsNull = true)))
+ widenTestWithStringPromotion(
+ MapType(IntegerType, MapType(ShortType, TimestampType), valueContainsNull = false),
+ MapType(LongType, MapType(DoubleType, StringType), valueContainsNull = false),
+ Some(MapType(LongType, MapType(DoubleType, StringType), valueContainsNull = false)))
+ widenTestWithStringPromotion(
+ MapType(IntegerType, new StructType().add("num", ShortType), valueContainsNull = false),
+ MapType(LongType, new StructType().add("num", LongType), valueContainsNull = false),
+ Some(MapType(LongType, new StructType().add("num", LongType), valueContainsNull = false)))
+
+ // StructType
+ widenTestWithStringPromotion(
+ new StructType()
+ .add("num", ShortType, nullable = true).add("ts", StringType, nullable = false),
+ new StructType()
+ .add("num", DoubleType, nullable = false).add("ts", TimestampType, nullable = true),
+ Some(new StructType()
+ .add("num", DoubleType, nullable = true).add("ts", StringType, nullable = true)))
+ widenTestWithStringPromotion(
+ new StructType()
+ .add("arr", ArrayType(ShortType, containsNull = false), nullable = false),
+ new StructType()
+ .add("arr", ArrayType(DoubleType, containsNull = true), nullable = false),
+ Some(new StructType()
+ .add("arr", ArrayType(DoubleType, containsNull = true), nullable = false)))
+ widenTestWithStringPromotion(
+ new StructType()
+ .add("map", MapType(ShortType, TimestampType, valueContainsNull = true), nullable = false),
+ new StructType()
+ .add("map", MapType(DoubleType, StringType, valueContainsNull = false), nullable = false),
+ Some(new StructType()
+ .add("map", MapType(DoubleType, StringType, valueContainsNull = true), nullable = false)))
+
+ widenTestWithStringPromotion(
+ new StructType().add("num", IntegerType),
+ new StructType().add("num", LongType).add("str", StringType),
+ None)
+ widenTestWithoutStringPromotion(
+ new StructType().add("num", IntegerType),
+ new StructType().add("num", LongType).add("str", StringType),
+ None)
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+ widenTestWithStringPromotion(
+ new StructType().add("a", IntegerType),
+ new StructType().add("A", LongType),
+ None)
+ widenTestWithoutStringPromotion(
+ new StructType().add("a", IntegerType),
+ new StructType().add("A", LongType),
+ None)
+ }
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+ widenTestWithStringPromotion(
+ new StructType().add("a", IntegerType),
+ new StructType().add("A", LongType),
+ Some(new StructType().add("a", LongType)),
+ isSymmetric = false)
+ widenTestWithoutStringPromotion(
+ new StructType().add("a", IntegerType),
+ new StructType().add("A", LongType),
+ Some(new StructType().add("a", LongType)),
+ isSymmetric = false)
+ }
// Without string promotion
widenTestWithoutStringPromotion(IntegerType, StringType, None)
widenTestWithoutStringPromotion(StringType, TimestampType, None)
widenTestWithoutStringPromotion(ArrayType(LongType), ArrayType(StringType), None)
widenTestWithoutStringPromotion(ArrayType(StringType), ArrayType(TimestampType), None)
+ widenTestWithoutStringPromotion(
+ MapType(LongType, IntegerType), MapType(StringType, IntegerType), None)
+ widenTestWithoutStringPromotion(
+ MapType(IntegerType, LongType), MapType(IntegerType, StringType), None)
+ widenTestWithoutStringPromotion(
+ MapType(StringType, IntegerType), MapType(TimestampType, IntegerType), None)
+ widenTestWithoutStringPromotion(
+ MapType(IntegerType, StringType), MapType(IntegerType, TimestampType), None)
+ widenTestWithoutStringPromotion(
+ new StructType().add("a", IntegerType),
+ new StructType().add("a", StringType),
+ None)
+ widenTestWithoutStringPromotion(
+ new StructType().add("a", StringType),
+ new StructType().add("a", IntegerType),
+ None)
// String promotion
widenTestWithStringPromotion(IntegerType, StringType, Some(StringType))
@@ -501,6 +601,30 @@ class TypeCoercionSuite extends AnalysisTest {
ArrayType(LongType), ArrayType(StringType), Some(ArrayType(StringType)))
widenTestWithStringPromotion(
ArrayType(StringType), ArrayType(TimestampType), Some(ArrayType(StringType)))
+ widenTestWithStringPromotion(
+ MapType(LongType, IntegerType),
+ MapType(StringType, IntegerType),
+ Some(MapType(StringType, IntegerType)))
+ widenTestWithStringPromotion(
+ MapType(IntegerType, LongType),
+ MapType(IntegerType, StringType),
+ Some(MapType(IntegerType, StringType)))
+ widenTestWithStringPromotion(
+ MapType(StringType, IntegerType),
+ MapType(TimestampType, IntegerType),
+ Some(MapType(StringType, IntegerType)))
+ widenTestWithStringPromotion(
+ MapType(IntegerType, StringType),
+ MapType(IntegerType, TimestampType),
+ Some(MapType(IntegerType, StringType)))
+ widenTestWithStringPromotion(
+ new StructType().add("a", IntegerType),
+ new StructType().add("a", StringType),
+ Some(new StructType().add("a", StringType)))
+ widenTestWithStringPromotion(
+ new StructType().add("a", StringType),
+ new StructType().add("a", IntegerType),
+ Some(new StructType().add("a", StringType)))
}
private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) {
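
The widened coercion rules for nested types can be observed through set operations, which resolve column types via `findWiderTypeForTwo`. An illustrative sketch, assuming a SparkSession named `spark`:

```scala
// Two struct columns whose only difference is the numeric width of a nested field.
val narrow = spark.sql("SELECT named_struct('num', CAST(1 AS INT)) AS s")
val wide   = spark.sql("SELECT named_struct('num', CAST(2 AS BIGINT)) AS s")

// With the field-by-field widening added here, the union resolves and the
// resulting column `s` should have type struct<num:bigint>.
narrow.union(wide).printSchema()
```
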
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index d7744eb4c7dc7..173c98af323b1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -98,6 +98,132 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(MapEntries(ms2), null)
}
+ test("Map Concat") {
+ val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType,
+ valueContainsNull = false))
+ val m1 = Literal.create(Map("c" -> "3", "a" -> "4"), MapType(StringType, StringType,
+ valueContainsNull = false))
+ val m2 = Literal.create(Map("d" -> "4", "e" -> "5"), MapType(StringType, StringType))
+ val m3 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType))
+ val m4 = Literal.create(Map("a" -> null, "c" -> "3"), MapType(StringType, StringType))
+ val m5 = Literal.create(Map("a" -> 1, "b" -> 2), MapType(StringType, IntegerType))
+ val m6 = Literal.create(Map("a" -> null, "c" -> 3), MapType(StringType, IntegerType))
+ val m7 = Literal.create(Map(List(1, 2) -> 1, List(3, 4) -> 2),
+ MapType(ArrayType(IntegerType), IntegerType))
+ val m8 = Literal.create(Map(List(5, 6) -> 3, List(1, 2) -> 4),
+ MapType(ArrayType(IntegerType), IntegerType))
+ val m9 = Literal.create(Map(Map(1 -> 2, 3 -> 4) -> 1, Map(5 -> 6, 7 -> 8) -> 2),
+ MapType(MapType(IntegerType, IntegerType), IntegerType))
+ val m10 = Literal.create(Map(Map(9 -> 10, 11 -> 12) -> 3, Map(1 -> 2, 3 -> 4) -> 4),
+ MapType(MapType(IntegerType, IntegerType), IntegerType))
+ val m11 = Literal.create(Map(1 -> "1", 2 -> "2"), MapType(IntegerType, StringType,
+ valueContainsNull = false))
+ val m12 = Literal.create(Map(3 -> "3", 4 -> "4"), MapType(IntegerType, StringType,
+ valueContainsNull = false))
+ val mNull = Literal.create(null, MapType(StringType, StringType))
+
+ // overlapping maps
+ checkEvaluation(MapConcat(Seq(m0, m1)),
+ (
+ Array("a", "b", "c", "a"), // keys
+ Array("1", "2", "3", "4") // values
+ )
+ )
+
+ // maps with no overlap
+ checkEvaluation(MapConcat(Seq(m0, m2)),
+ Map("a" -> "1", "b" -> "2", "d" -> "4", "e" -> "5"))
+
+ // 3 maps
+ checkEvaluation(MapConcat(Seq(m0, m1, m2)),
+ (
+ Array("a", "b", "c", "a", "d", "e"), // keys
+ Array("1", "2", "3", "4", "4", "5") // values
+ )
+ )
+
+ // null reference values
+ checkEvaluation(MapConcat(Seq(m3, m4)),
+ (
+ Array("a", "b", "a", "c"), // keys
+ Array("1", "2", null, "3") // values
+ )
+ )
+
+ // null primitive values
+ checkEvaluation(MapConcat(Seq(m5, m6)),
+ (
+ Array("a", "b", "a", "c"), // keys
+ Array(1, 2, null, 3) // values
+ )
+ )
+
+ // keys that are primitive
+ checkEvaluation(MapConcat(Seq(m11, m12)),
+ (
+ Array(1, 2, 3, 4), // keys
+ Array("1", "2", "3", "4") // values
+ )
+ )
+
+ // keys that are arrays, with overlap
+ checkEvaluation(MapConcat(Seq(m7, m8)),
+ (
+ Array(List(1, 2), List(3, 4), List(5, 6), List(1, 2)), // keys
+ Array(1, 2, 3, 4) // values
+ )
+ )
+
+ // keys that are maps, with overlap
+ checkEvaluation(MapConcat(Seq(m9, m10)),
+ (
+ Array(Map(1 -> 2, 3 -> 4), Map(5 -> 6, 7 -> 8), Map(9 -> 10, 11 -> 12),
+ Map(1 -> 2, 3 -> 4)), // keys
+ Array(1, 2, 3, 4) // values
+ )
+ )
+
+ // null map
+ checkEvaluation(MapConcat(Seq(m0, mNull)), null)
+ checkEvaluation(MapConcat(Seq(mNull, m0)), null)
+ checkEvaluation(MapConcat(Seq(mNull, mNull)), null)
+ checkEvaluation(MapConcat(Seq(mNull)), null)
+
+ // single map
+ checkEvaluation(MapConcat(Seq(m0)), Map("a" -> "1", "b" -> "2"))
+
+ // no map
+ checkEvaluation(MapConcat(Seq.empty), Map.empty)
+
+ // force split expressions for input in generated code
+ val expectedKeys = Array.fill(65)(Seq("a", "b")).flatten ++ Array("d", "e")
+ val expectedValues = Array.fill(65)(Seq("1", "2")).flatten ++ Array("4", "5")
+ checkEvaluation(MapConcat(
+ Seq(
+ m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0,
+ m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0,
+ m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m2
+ )),
+ (expectedKeys, expectedValues))
+
+ // argument checking
+ assert(MapConcat(Seq(m0, m1)).checkInputDataTypes().isSuccess)
+ assert(MapConcat(Seq(m5, m6)).checkInputDataTypes().isSuccess)
+ assert(MapConcat(Seq(m0, m5)).checkInputDataTypes().isFailure)
+ assert(MapConcat(Seq(m0, Literal(12))).checkInputDataTypes().isFailure)
+ assert(MapConcat(Seq(m0, m1)).dataType.keyType == StringType)
+ assert(MapConcat(Seq(m0, m1)).dataType.valueType == StringType)
+ assert(!MapConcat(Seq(m0, m1)).dataType.valueContainsNull)
+ assert(MapConcat(Seq(m5, m6)).dataType.keyType == StringType)
+ assert(MapConcat(Seq(m5, m6)).dataType.valueType == IntegerType)
+ assert(MapConcat(Seq.empty).dataType.keyType == StringType)
+ assert(MapConcat(Seq.empty).dataType.valueType == StringType)
+ assert(MapConcat(Seq(m5, m6)).dataType.valueContainsNull)
+ assert(MapConcat(Seq(m6, m5)).dataType.valueContainsNull)
+ assert(!MapConcat(Seq(m1, m2)).nullable)
+ assert(MapConcat(Seq(m1, mNull)).nullable)
+ }
+
test("MapFromEntries") {
def arrayType(keyType: DataType, valueType: DataType) : DataType = {
ArrayType(
@@ -213,6 +339,8 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
val a2 = Literal.create(Seq(null), ArrayType(LongType))
val a3 = Literal.create(null, ArrayType(StringType))
+ val a4 = Literal.create(Seq(create_row(1)), ArrayType(StructType(Seq(
+ StructField("a", IntegerType, true)))))
checkEvaluation(ArrayContains(a0, Literal(1)), true)
checkEvaluation(ArrayContains(a0, Literal(0)), false)
@@ -228,6 +356,11 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ArrayContains(a3, Literal("")), null)
checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
+ checkEvaluation(ArrayContains(a4, Literal.create(create_row(1), StructType(Seq(
+ StructField("a", IntegerType, false))))), true)
+ checkEvaluation(ArrayContains(a4, Literal.create(create_row(0), StructType(Seq(
+ StructField("a", IntegerType, false))))), false)
+
// binary
val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)),
ArrayType(BinaryType))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
index a099119732e25..e068c32500cfc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
@@ -113,6 +113,76 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5))).nullable === true)
}
+ test("if/case when - null flags of non-primitive types") {
+ val arrayWithNulls = Literal.create(Seq("a", null, "b"), ArrayType(StringType, true))
+ val arrayWithoutNulls = Literal.create(Seq("c", "d"), ArrayType(StringType, false))
+ val structWithNulls = Literal.create(
+ create_row(null, null),
+ StructType(Seq(StructField("a", IntegerType, true), StructField("b", StringType, true))))
+ val structWithoutNulls = Literal.create(
+ create_row(1, "a"),
+ StructType(Seq(StructField("a", IntegerType, false), StructField("b", StringType, false))))
+ val mapWithNulls = Literal.create(Map(1 -> null), MapType(IntegerType, StringType, true))
+ val mapWithoutNulls = Literal.create(Map(1 -> "a"), MapType(IntegerType, StringType, false))
+
+ val arrayIf1 = If(Literal.FalseLiteral, arrayWithNulls, arrayWithoutNulls)
+ val arrayIf2 = If(Literal.FalseLiteral, arrayWithoutNulls, arrayWithNulls)
+ val arrayIf3 = If(Literal.TrueLiteral, arrayWithNulls, arrayWithoutNulls)
+ val arrayIf4 = If(Literal.TrueLiteral, arrayWithoutNulls, arrayWithNulls)
+ val structIf1 = If(Literal.FalseLiteral, structWithNulls, structWithoutNulls)
+ val structIf2 = If(Literal.FalseLiteral, structWithoutNulls, structWithNulls)
+ val structIf3 = If(Literal.TrueLiteral, structWithNulls, structWithoutNulls)
+ val structIf4 = If(Literal.TrueLiteral, structWithoutNulls, structWithNulls)
+ val mapIf1 = If(Literal.FalseLiteral, mapWithNulls, mapWithoutNulls)
+ val mapIf2 = If(Literal.FalseLiteral, mapWithoutNulls, mapWithNulls)
+ val mapIf3 = If(Literal.TrueLiteral, mapWithNulls, mapWithoutNulls)
+ val mapIf4 = If(Literal.TrueLiteral, mapWithoutNulls, mapWithNulls)
+
+ val arrayCaseWhen1 = CaseWhen(Seq((Literal.FalseLiteral, arrayWithNulls)), arrayWithoutNulls)
+ val arrayCaseWhen2 = CaseWhen(Seq((Literal.FalseLiteral, arrayWithoutNulls)), arrayWithNulls)
+ val arrayCaseWhen3 = CaseWhen(Seq((Literal.TrueLiteral, arrayWithNulls)), arrayWithoutNulls)
+ val arrayCaseWhen4 = CaseWhen(Seq((Literal.TrueLiteral, arrayWithoutNulls)), arrayWithNulls)
+ val structCaseWhen1 = CaseWhen(Seq((Literal.FalseLiteral, structWithNulls)), structWithoutNulls)
+ val structCaseWhen2 = CaseWhen(Seq((Literal.FalseLiteral, structWithoutNulls)), structWithNulls)
+ val structCaseWhen3 = CaseWhen(Seq((Literal.TrueLiteral, structWithNulls)), structWithoutNulls)
+ val structCaseWhen4 = CaseWhen(Seq((Literal.TrueLiteral, structWithoutNulls)), structWithNulls)
+ val mapCaseWhen1 = CaseWhen(Seq((Literal.FalseLiteral, mapWithNulls)), mapWithoutNulls)
+ val mapCaseWhen2 = CaseWhen(Seq((Literal.FalseLiteral, mapWithoutNulls)), mapWithNulls)
+ val mapCaseWhen3 = CaseWhen(Seq((Literal.TrueLiteral, mapWithNulls)), mapWithoutNulls)
+ val mapCaseWhen4 = CaseWhen(Seq((Literal.TrueLiteral, mapWithoutNulls)), mapWithNulls)
+
+ def checkResult(expectedType: DataType, expectedValue: Any, result: Expression): Unit = {
+ assert(expectedType == result.dataType)
+ checkEvaluation(result, expectedValue)
+ }
+
+ checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayIf1)
+ checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayIf2)
+ checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayIf3)
+ checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayIf4)
+ checkResult(structWithNulls.dataType, structWithoutNulls.value, structIf1)
+ checkResult(structWithNulls.dataType, structWithNulls.value, structIf2)
+ checkResult(structWithNulls.dataType, structWithNulls.value, structIf3)
+ checkResult(structWithNulls.dataType, structWithoutNulls.value, structIf4)
+ checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapIf1)
+ checkResult(mapWithNulls.dataType, mapWithNulls.value, mapIf2)
+ checkResult(mapWithNulls.dataType, mapWithNulls.value, mapIf3)
+ checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapIf4)
+
+ checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayCaseWhen1)
+ checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayCaseWhen2)
+ checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayCaseWhen3)
+ checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayCaseWhen4)
+ checkResult(structWithNulls.dataType, structWithoutNulls.value, structCaseWhen1)
+ checkResult(structWithNulls.dataType, structWithNulls.value, structCaseWhen2)
+ checkResult(structWithNulls.dataType, structWithNulls.value, structCaseWhen3)
+ checkResult(structWithNulls.dataType, structWithoutNulls.value, structCaseWhen4)
+ checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapCaseWhen1)
+ checkResult(mapWithNulls.dataType, mapWithNulls.value, mapCaseWhen2)
+ checkResult(mapWithNulls.dataType, mapWithNulls.value, mapCaseWhen3)
+ checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapCaseWhen4)
+ }
+
test("case key when") {
val row = create_row(null, 1, 2, "a", "b", "c")
val c1 = 'a.int.at(0)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
index 00e97637eee7e..52203b9e337ba 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
@@ -706,4 +706,11 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
assert(schemaToCompare == schema)
}
}
+
+ test("SPARK-24709: infer schema of json strings") {
+ checkEvaluation(SchemaOfJson(Literal.create("""{"col":0}""")), "struct<col:bigint>")
+ checkEvaluation(
+ SchemaOfJson(Literal.create("""{"col0":["a"], "col1": {"col2": "b"}}""")),
+ "struct,col1:struct>")
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala
index d2c6420eadb20..55569b6f2933e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala
@@ -65,7 +65,9 @@ class CodeBlockSuite extends SparkFunSuite {
|boolean $isNull = false;
|int $value = -1;
""".stripMargin
- val exprValues = code.exprValues
+ val exprValues = code.asInstanceOf[CodeBlock].blockInputs.collect {
+ case e: ExprValue => e
+ }.toSet
assert(exprValues.size == 2)
assert(exprValues === Set(value, isNull))
}
@@ -94,7 +96,9 @@ class CodeBlockSuite extends SparkFunSuite {
assert(code.toString == expected)
- val exprValues = code.exprValues
+ val exprValues = code.children.flatMap(_.asInstanceOf[CodeBlock].blockInputs.collect {
+ case e: ExprValue => e
+ }).toSet
assert(exprValues.size == 5)
assert(exprValues === Set(isNull1, value1, isNull2, value2, literal))
}
@@ -107,7 +111,7 @@ class CodeBlockSuite extends SparkFunSuite {
assert(e.getMessage().contains(s"Can not interpolate ${obj.getClass.getName}"))
}
- test("replace expr values in code block") {
+ test("transform expr in code block") {
val expr = JavaCode.expression("1 + 1", IntegerType)
val isNull = JavaCode.isNullVariable("expr1_isNull")
val exprInFunc = JavaCode.variable("expr1", IntegerType)
@@ -120,11 +124,11 @@ class CodeBlockSuite extends SparkFunSuite {
|}""".stripMargin
val aliasedParam = JavaCode.variable("aliased", expr.javaType)
- val aliasedInputs = code.asInstanceOf[CodeBlock].blockInputs.map {
- case _: SimpleExprValue => aliasedParam
- case other => other
+
+ // We want to replace all occurrences of `expr` with the variable `aliasedParam`.
+ val aliasedCode = code.transformExprValues {
+ case SimpleExprValue("1 + 1", java.lang.Integer.TYPE) => aliasedParam
}
- val aliasedCode = CodeBlock(code.asInstanceOf[CodeBlock].codeParts, aliasedInputs).stripMargin
val expected =
code"""
|callFunc(int $aliasedParam) {
@@ -133,4 +137,61 @@ class CodeBlockSuite extends SparkFunSuite {
|}""".stripMargin
assert(aliasedCode.toString == expected.toString)
}
+
+ test ("transform expr in nested blocks") {
+ val expr = JavaCode.expression("1 + 1", IntegerType)
+ val isNull = JavaCode.isNullVariable("expr1_isNull")
+ val exprInFunc = JavaCode.variable("expr1", IntegerType)
+
+ val funcs = Seq("callFunc1", "callFunc2", "callFunc3")
+ val subBlocks = funcs.map { funcName =>
+ code"""
+ |$funcName(int $expr) {
+ | boolean $isNull = false;
+ | int $exprInFunc = $expr + 1;
+ |}""".stripMargin
+ }
+
+ val aliasedParam = JavaCode.variable("aliased", expr.javaType)
+
+ val block = code"${subBlocks(0)}\n${subBlocks(1)}\n${subBlocks(2)}"
+ val transformedBlock = block.transform {
+ case b: Block => b.transformExprValues {
+ case SimpleExprValue("1 + 1", java.lang.Integer.TYPE) => aliasedParam
+ }
+ }.asInstanceOf[CodeBlock]
+
+ val expected1 =
+ code"""
+ |callFunc1(int aliased) {
+ | boolean expr1_isNull = false;
+ | int expr1 = aliased + 1;
+ |}""".stripMargin
+
+ val expected2 =
+ code"""
+ |callFunc2(int aliased) {
+ | boolean expr1_isNull = false;
+ | int expr1 = aliased + 1;
+ |}""".stripMargin
+
+ val expected3 =
+ code"""
+ |callFunc3(int aliased) {
+ | boolean expr1_isNull = false;
+ | int expr1 = aliased + 1;
+ |}""".stripMargin
+
+ val exprValues = transformedBlock.children.flatMap { block =>
+ block.asInstanceOf[CodeBlock].blockInputs.collect {
+ case e: ExprValue => e
+ }
+ }.toSet
+
+ assert(transformedBlock.children(0).toString == expected1.toString)
+ assert(transformedBlock.children(1).toString == expected2.toString)
+ assert(transformedBlock.children(2).toString == expected3.toString)
+ assert(transformedBlock.toString == (expected1 + expected2 + expected3).toString)
+ assert(exprValues === Set(isNull, exprInFunc, aliasedParam))
+ }
}
diff --git a/sql/core/benchmarks/FilterPushdownBenchmark-results.txt b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt
new file mode 100644
index 0000000000000..110669b69a00d
--- /dev/null
+++ b/sql/core/benchmarks/FilterPushdownBenchmark-results.txt
@@ -0,0 +1,580 @@
+================================================================================================
+Pushdown for many distinct value case
+================================================================================================
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 0 string row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 8970 / 9122 1.8 570.3 1.0X
+Parquet Vectorized (Pushdown) 471 / 491 33.4 30.0 19.0X
+Native ORC Vectorized 7661 / 7853 2.1 487.0 1.2X
+Native ORC Vectorized (Pushdown) 1134 / 1161 13.9 72.1 7.9X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 0 string row ('7864320' < value < '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 9246 / 9297 1.7 587.8 1.0X
+Parquet Vectorized (Pushdown) 480 / 488 32.8 30.5 19.3X
+Native ORC Vectorized 7838 / 7850 2.0 498.3 1.2X
+Native ORC Vectorized (Pushdown) 1054 / 1118 14.9 67.0 8.8X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 1 string row (value = '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 8989 / 9100 1.7 571.5 1.0X
+Parquet Vectorized (Pushdown) 448 / 467 35.1 28.5 20.1X
+Native ORC Vectorized 7680 / 7768 2.0 488.3 1.2X
+Native ORC Vectorized (Pushdown) 1067 / 1118 14.7 67.8 8.4X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 1 string row (value <=> '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 9115 / 9266 1.7 579.5 1.0X
+Parquet Vectorized (Pushdown) 466 / 492 33.7 29.7 19.5X
+Native ORC Vectorized 7800 / 7914 2.0 495.9 1.2X
+Native ORC Vectorized (Pushdown) 1075 / 1102 14.6 68.4 8.5X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 1 string row ('7864320' <= value <= '7864320'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 9099 / 9237 1.7 578.5 1.0X
+Parquet Vectorized (Pushdown) 462 / 475 34.1 29.3 19.7X
+Native ORC Vectorized 7847 / 7925 2.0 498.9 1.2X
+Native ORC Vectorized (Pushdown) 1078 / 1114 14.6 68.5 8.4X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select all string rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 19303 / 19547 0.8 1227.3 1.0X
+Parquet Vectorized (Pushdown) 19924 / 20089 0.8 1266.7 1.0X
+Native ORC Vectorized 18725 / 19079 0.8 1190.5 1.0X
+Native ORC Vectorized (Pushdown) 19310 / 19492 0.8 1227.7 1.0X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 0 int row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 8117 / 8323 1.9 516.1 1.0X
+Parquet Vectorized (Pushdown) 484 / 494 32.5 30.8 16.8X
+Native ORC Vectorized 6811 / 7036 2.3 433.0 1.2X
+Native ORC Vectorized (Pushdown) 1061 / 1082 14.8 67.5 7.6X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 0 int row (7864320 < value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 8105 / 8140 1.9 515.3 1.0X
+Parquet Vectorized (Pushdown) 478 / 505 32.9 30.4 17.0X
+Native ORC Vectorized 6914 / 7211 2.3 439.6 1.2X
+Native ORC Vectorized (Pushdown) 1044 / 1064 15.1 66.4 7.8X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 1 int row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 7983 / 8116 2.0 507.6 1.0X
+Parquet Vectorized (Pushdown) 464 / 487 33.9 29.5 17.2X
+Native ORC Vectorized 6703 / 6774 2.3 426.1 1.2X
+Native ORC Vectorized (Pushdown) 1017 / 1058 15.5 64.6 7.9X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 1 int row (value <=> 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 7942 / 7983 2.0 504.9 1.0X
+Parquet Vectorized (Pushdown) 468 / 479 33.6 29.7 17.0X
+Native ORC Vectorized 6677 / 6779 2.4 424.5 1.2X
+Native ORC Vectorized (Pushdown) 1021 / 1068 15.4 64.9 7.8X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 1 int row (7864320 <= value <= 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 7909 / 7958 2.0 502.8 1.0X
+Parquet Vectorized (Pushdown) 485 / 494 32.4 30.8 16.3X
+Native ORC Vectorized 6751 / 6846 2.3 429.2 1.2X
+Native ORC Vectorized (Pushdown) 1043 / 1077 15.1 66.3 7.6X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 1 int row (7864319 < value < 7864321): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 8010 / 8033 2.0 509.2 1.0X
+Parquet Vectorized (Pushdown) 472 / 489 33.3 30.0 17.0X
+Native ORC Vectorized 6655 / 6808 2.4 423.1 1.2X
+Native ORC Vectorized (Pushdown) 1015 / 1067 15.5 64.5 7.9X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 10% int rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 8983 / 9035 1.8 571.1 1.0X
+Parquet Vectorized (Pushdown) 2204 / 2231 7.1 140.1 4.1X
+Native ORC Vectorized 7864 / 8011 2.0 500.0 1.1X
+Native ORC Vectorized (Pushdown) 2674 / 2789 5.9 170.0 3.4X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 50% int rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 12723 / 12903 1.2 808.9 1.0X
+Parquet Vectorized (Pushdown) 9112 / 9282 1.7 579.3 1.4X
+Native ORC Vectorized 12090 / 12230 1.3 768.7 1.1X
+Native ORC Vectorized (Pushdown) 9242 / 9372 1.7 587.6 1.4X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 90% int rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 16453 / 16678 1.0 1046.1 1.0X
+Parquet Vectorized (Pushdown) 15997 / 16262 1.0 1017.0 1.0X
+Native ORC Vectorized 16652 / 17070 0.9 1058.7 1.0X
+Native ORC Vectorized (Pushdown) 15843 / 16112 1.0 1007.2 1.0X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select all int rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 17098 / 17254 0.9 1087.1 1.0X
+Parquet Vectorized (Pushdown) 17302 / 17529 0.9 1100.1 1.0X
+Native ORC Vectorized 16790 / 17098 0.9 1067.5 1.0X
+Native ORC Vectorized (Pushdown) 17329 / 17914 0.9 1101.7 1.0X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select all int rows (value > -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 17088 / 17392 0.9 1086.4 1.0X
+Parquet Vectorized (Pushdown) 17609 / 17863 0.9 1119.5 1.0X
+Native ORC Vectorized 18334 / 69831 0.9 1165.7 0.9X
+Native ORC Vectorized (Pushdown) 17465 / 17629 0.9 1110.4 1.0X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select all int rows (value != -1): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 16903 / 17233 0.9 1074.6 1.0X
+Parquet Vectorized (Pushdown) 16945 / 17032 0.9 1077.3 1.0X
+Native ORC Vectorized 16377 / 16762 1.0 1041.2 1.0X
+Native ORC Vectorized (Pushdown) 16950 / 17212 0.9 1077.7 1.0X
+
+
+================================================================================================
+Pushdown for few distinct value case (use dictionary encoding)
+================================================================================================
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 0 distinct string row (value IS NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 7245 / 7322 2.2 460.7 1.0X
+Parquet Vectorized (Pushdown) 378 / 389 41.6 24.0 19.2X
+Native ORC Vectorized 6720 / 6778 2.3 427.2 1.1X
+Native ORC Vectorized (Pushdown) 1009 / 1032 15.6 64.2 7.2X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 0 distinct string row ('100' < value < '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 7627 / 7795 2.1 484.9 1.0X
+Parquet Vectorized (Pushdown) 384 / 406 41.0 24.4 19.9X
+Native ORC Vectorized 6724 / 7824 2.3 427.5 1.1X
+Native ORC Vectorized (Pushdown) 968 / 986 16.3 61.5 7.9X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 1 distinct string row (value = '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 7157 / 7534 2.2 455.0 1.0X
+Parquet Vectorized (Pushdown) 542 / 565 29.0 34.5 13.2X
+Native ORC Vectorized 6716 / 7214 2.3 427.0 1.1X
+Native ORC Vectorized (Pushdown) 1212 / 1288 13.0 77.0 5.9X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 1 distinct string row (value <=> '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 7368 / 7552 2.1 468.4 1.0X
+Parquet Vectorized (Pushdown) 544 / 556 28.9 34.6 13.5X
+Native ORC Vectorized 6740 / 6867 2.3 428.5 1.1X
+Native ORC Vectorized (Pushdown) 1230 / 1426 12.8 78.2 6.0X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 1 distinct string row ('100' <= value <= '100'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 7427 / 7734 2.1 472.2 1.0X
+Parquet Vectorized (Pushdown) 556 / 568 28.3 35.4 13.3X
+Native ORC Vectorized 6847 / 7059 2.3 435.3 1.1X
+Native ORC Vectorized (Pushdown) 1226 / 1230 12.8 77.9 6.1X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select all distinct string rows (value IS NOT NULL): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 16998 / 17311 0.9 1080.7 1.0X
+Parquet Vectorized (Pushdown) 16977 / 17250 0.9 1079.4 1.0X
+Native ORC Vectorized 18447 / 19852 0.9 1172.8 0.9X
+Native ORC Vectorized (Pushdown) 16614 / 17102 0.9 1056.3 1.0X
+
+
+================================================================================================
+Pushdown benchmark for StringStartsWith
+================================================================================================
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+StringStartsWith filter: (value like '10%'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 9705 / 10814 1.6 617.0 1.0X
+Parquet Vectorized (Pushdown) 3086 / 3574 5.1 196.2 3.1X
+Native ORC Vectorized 10094 / 10695 1.6 641.8 1.0X
+Native ORC Vectorized (Pushdown) 9611 / 9999 1.6 611.0 1.0X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+StringStartsWith filter: (value like '1000%'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 8016 / 8183 2.0 509.7 1.0X
+Parquet Vectorized (Pushdown) 444 / 457 35.4 28.2 18.0X
+Native ORC Vectorized 6970 / 7169 2.3 443.2 1.2X
+Native ORC Vectorized (Pushdown) 7447 / 7503 2.1 473.5 1.1X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+StringStartsWith filter: (value like '786432%'): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 7908 / 8046 2.0 502.8 1.0X
+Parquet Vectorized (Pushdown) 408 / 429 38.6 25.9 19.4X
+Native ORC Vectorized 7021 / 7100 2.2 446.4 1.1X
+Native ORC Vectorized (Pushdown) 7310 / 7490 2.2 464.8 1.1X
+
+
+================================================================================================
+Pushdown benchmark for decimal
+================================================================================================
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 1 decimal(9, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 3785 / 3867 4.2 240.6 1.0X
+Parquet Vectorized (Pushdown) 3820 / 3928 4.1 242.9 1.0X
+Native ORC Vectorized 3981 / 4049 4.0 253.1 1.0X
+Native ORC Vectorized (Pushdown) 702 / 735 22.4 44.6 5.4X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 10% decimal(9, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 4694 / 4813 3.4 298.4 1.0X
+Parquet Vectorized (Pushdown) 4839 / 4907 3.3 307.6 1.0X
+Native ORC Vectorized 4943 / 5032 3.2 314.2 0.9X
+Native ORC Vectorized (Pushdown) 2043 / 2085 7.7 129.9 2.3X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 50% decimal(9, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 8321 / 8472 1.9 529.0 1.0X
+Parquet Vectorized (Pushdown) 8125 / 8471 1.9 516.6 1.0X
+Native ORC Vectorized 8524 / 8616 1.8 541.9 1.0X
+Native ORC Vectorized (Pushdown) 7961 / 8383 2.0 506.1 1.0X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 90% decimal(9, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 9587 / 10112 1.6 609.5 1.0X
+Parquet Vectorized (Pushdown) 9726 / 10370 1.6 618.3 1.0X
+Native ORC Vectorized 10119 / 11147 1.6 643.4 0.9X
+Native ORC Vectorized (Pushdown) 9366 / 9497 1.7 595.5 1.0X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 1 decimal(18, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 4060 / 4093 3.9 258.1 1.0X
+Parquet Vectorized (Pushdown) 4037 / 4125 3.9 256.6 1.0X
+Native ORC Vectorized 4756 / 4811 3.3 302.4 0.9X
+Native ORC Vectorized (Pushdown) 824 / 889 19.1 52.4 4.9X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 10% decimal(18, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 5157 / 5271 3.0 327.9 1.0X
+Parquet Vectorized (Pushdown) 5051 / 5141 3.1 321.1 1.0X
+Native ORC Vectorized 5723 / 6146 2.7 363.9 0.9X
+Native ORC Vectorized (Pushdown) 2198 / 2317 7.2 139.8 2.3X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 50% decimal(18, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 8608 / 8647 1.8 547.3 1.0X
+Parquet Vectorized (Pushdown) 8471 / 8584 1.9 538.6 1.0X
+Native ORC Vectorized 9249 / 10048 1.7 588.0 0.9X
+Native ORC Vectorized (Pushdown) 7645 / 8091 2.1 486.1 1.1X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 90% decimal(18, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 11658 / 11888 1.3 741.2 1.0X
+Parquet Vectorized (Pushdown) 11812 / 12098 1.3 751.0 1.0X
+Native ORC Vectorized 12943 / 13312 1.2 822.9 0.9X
+Native ORC Vectorized (Pushdown) 13139 / 13465 1.2 835.4 0.9X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 1 decimal(38, 2) row (value = 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 5491 / 5716 2.9 349.1 1.0X
+Parquet Vectorized (Pushdown) 5515 / 5615 2.9 350.6 1.0X
+Native ORC Vectorized 4582 / 4654 3.4 291.3 1.2X
+Native ORC Vectorized (Pushdown) 815 / 861 19.3 51.8 6.7X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 10% decimal(38, 2) rows (value < 1572864): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 6432 / 6527 2.4 409.0 1.0X
+Parquet Vectorized (Pushdown) 6513 / 6607 2.4 414.1 1.0X
+Native ORC Vectorized 5618 / 6085 2.8 357.2 1.1X
+Native ORC Vectorized (Pushdown) 2403 / 2443 6.5 152.8 2.7X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 50% decimal(38, 2) rows (value < 7864320): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 11041 / 11467 1.4 701.9 1.0X
+Parquet Vectorized (Pushdown) 10909 / 11484 1.4 693.5 1.0X
+Native ORC Vectorized 9860 / 10436 1.6 626.9 1.1X
+Native ORC Vectorized (Pushdown) 7908 / 8069 2.0 502.8 1.4X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 90% decimal(38, 2) rows (value < 14155776): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 14816 / 16877 1.1 942.0 1.0X
+Parquet Vectorized (Pushdown) 15383 / 15740 1.0 978.0 1.0X
+Native ORC Vectorized 14408 / 14771 1.1 916.0 1.0X
+Native ORC Vectorized (Pushdown) 13968 / 14805 1.1 888.1 1.1X
+
+
+================================================================================================
+Pushdown benchmark for InSet -> InFilters
+================================================================================================
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+InSet -> InFilters (values count: 5, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 7477 / 7587 2.1 475.4 1.0X
+Parquet Vectorized (Pushdown) 7862 / 8346 2.0 499.9 1.0X
+Native ORC Vectorized 6447 / 7021 2.4 409.9 1.2X
+Native ORC Vectorized (Pushdown) 983 / 1003 16.0 62.5 7.6X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+InSet -> InFilters (values count: 5, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 7107 / 7290 2.2 451.9 1.0X
+Parquet Vectorized (Pushdown) 7196 / 7258 2.2 457.5 1.0X
+Native ORC Vectorized 6102 / 6222 2.6 388.0 1.2X
+Native ORC Vectorized (Pushdown) 926 / 958 17.0 58.9 7.7X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+InSet -> InFilters (values count: 5, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 7374 / 7692 2.1 468.8 1.0X
+Parquet Vectorized (Pushdown) 7771 / 7848 2.0 494.1 0.9X
+Native ORC Vectorized 6184 / 6356 2.5 393.2 1.2X
+Native ORC Vectorized (Pushdown) 920 / 963 17.1 58.5 8.0X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+InSet -> InFilters (values count: 10, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 7073 / 7326 2.2 449.7 1.0X
+Parquet Vectorized (Pushdown) 7304 / 7647 2.2 464.4 1.0X
+Native ORC Vectorized 6222 / 6579 2.5 395.6 1.1X
+Native ORC Vectorized (Pushdown) 958 / 994 16.4 60.9 7.4X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+InSet -> InFilters (values count: 10, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 7121 / 7501 2.2 452.7 1.0X
+Parquet Vectorized (Pushdown) 7751 / 8334 2.0 492.8 0.9X
+Native ORC Vectorized 6225 / 6680 2.5 395.8 1.1X
+Native ORC Vectorized (Pushdown) 998 / 1020 15.8 63.5 7.1X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+InSet -> InFilters (values count: 10, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 7157 / 7399 2.2 455.1 1.0X
+Parquet Vectorized (Pushdown) 7806 / 7911 2.0 496.3 0.9X
+Native ORC Vectorized 6548 / 6720 2.4 416.3 1.1X
+Native ORC Vectorized (Pushdown) 1016 / 1050 15.5 64.6 7.0X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+InSet -> InFilters (values count: 50, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 7662 / 7805 2.1 487.1 1.0X
+Parquet Vectorized (Pushdown) 7590 / 7861 2.1 482.5 1.0X
+Native ORC Vectorized 6840 / 8073 2.3 434.9 1.1X
+Native ORC Vectorized (Pushdown) 1041 / 1075 15.1 66.2 7.4X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+InSet -> InFilters (values count: 50, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 8230 / 9266 1.9 523.2 1.0X
+Parquet Vectorized (Pushdown) 7735 / 7960 2.0 491.8 1.1X
+Native ORC Vectorized 6945 / 7109 2.3 441.6 1.2X
+Native ORC Vectorized (Pushdown) 1123 / 1144 14.0 71.4 7.3X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+InSet -> InFilters (values count: 50, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 7656 / 8058 2.1 486.7 1.0X
+Parquet Vectorized (Pushdown) 7860 / 8247 2.0 499.7 1.0X
+Native ORC Vectorized 6684 / 7003 2.4 424.9 1.1X
+Native ORC Vectorized (Pushdown) 1085 / 1172 14.5 69.0 7.1X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+InSet -> InFilters (values count: 100, distribution: 10): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 7594 / 8128 2.1 482.8 1.0X
+Parquet Vectorized (Pushdown) 7845 / 7923 2.0 498.8 1.0X
+Native ORC Vectorized 5859 / 6421 2.7 372.5 1.3X
+Native ORC Vectorized (Pushdown) 1037 / 1054 15.2 66.0 7.3X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+InSet -> InFilters (values count: 100, distribution: 50): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 6762 / 6775 2.3 429.9 1.0X
+Parquet Vectorized (Pushdown) 6911 / 6970 2.3 439.4 1.0X
+Native ORC Vectorized 5884 / 5960 2.7 374.1 1.1X
+Native ORC Vectorized (Pushdown) 1028 / 1052 15.3 65.4 6.6X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+InSet -> InFilters (values count: 100, distribution: 90): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 6718 / 6767 2.3 427.1 1.0X
+Parquet Vectorized (Pushdown) 6812 / 6909 2.3 433.1 1.0X
+Native ORC Vectorized 5842 / 5883 2.7 371.4 1.1X
+Native ORC Vectorized (Pushdown) 1040 / 1058 15.1 66.1 6.5X
+
+
+================================================================================================
+Pushdown benchmark for tinyint
+================================================================================================
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 1 tinyint row (value = CAST(63 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 3461 / 3997 4.5 220.1 1.0X
+Parquet Vectorized (Pushdown) 270 / 315 58.4 17.1 12.8X
+Native ORC Vectorized 4107 / 5372 3.8 261.1 0.8X
+Native ORC Vectorized (Pushdown) 778 / 1553 20.2 49.5 4.4X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 10% tinyint rows (value < CAST(12 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 4771 / 6655 3.3 303.3 1.0X
+Parquet Vectorized (Pushdown) 1322 / 1606 11.9 84.0 3.6X
+Native ORC Vectorized 4437 / 4572 3.5 282.1 1.1X
+Native ORC Vectorized (Pushdown) 1781 / 1976 8.8 113.2 2.7X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 50% tinyint rows (value < CAST(63 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 7433 / 7752 2.1 472.6 1.0X
+Parquet Vectorized (Pushdown) 5863 / 5913 2.7 372.8 1.3X
+Native ORC Vectorized 7986 / 8084 2.0 507.7 0.9X
+Native ORC Vectorized (Pushdown) 6522 / 6608 2.4 414.6 1.1X
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+
+Select 90% tinyint rows (value < CAST(114 AS tinyint)): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Parquet Vectorized 11190 / 11519 1.4 711.4 1.0X
+Parquet Vectorized (Pushdown) 10861 / 11206 1.4 690.5 1.0X
+Native ORC Vectorized 11622 / 12196 1.4 738.9 1.0X
+Native ORC Vectorized (Pushdown) 11377 / 11654 1.4 723.3 1.0X
+
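A note on reading these tables: the row count per scan is not printed, but the figures are mutually consistent with roughly 15.7 million rows per scan. Taking the tinyint case directly above, the Parquet Vectorized best time of 3461 ms gives about 3461 x 10^6 ns / 15.7 x 10^6 rows ~ 220 ns per row and 15.7 M rows / 3.461 s ~ 4.5 M/s, matching the Per Row(ns) and Rate(M/s) columns; Relative is the baseline's best time divided by the variant's best time, e.g. 3461 / 270 ~ 12.8X for the pushdown run.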
+
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index f270c70fbfcf0..18ae314309d7b 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -118,7 +118,7 @@
org.apache.xbean
- xbean-asm5-shaded
+ xbean-asm6-shaded
org.scalacheck
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index c7c4c7b3e7715..c8cf44b51df77 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -20,8 +20,8 @@
import java.io.IOException;
import org.apache.spark.SparkEnv;
+import org.apache.spark.TaskContext;
import org.apache.spark.internal.config.package$;
-import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
@@ -82,7 +82,7 @@ public static boolean supportsAggregationBufferSchema(StructType schema) {
* @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function)
* @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
* @param groupingKeySchema the schema of the grouping key, used for row conversion.
- * @param taskMemoryManager the memory manager used to allocate our Unsafe memory structures.
+ * @param taskContext the current task context.
* @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
* @param pageSizeBytes the data page size, in bytes; limits the maximum record size.
*/
@@ -90,19 +90,26 @@ public UnsafeFixedWidthAggregationMap(
InternalRow emptyAggregationBuffer,
StructType aggregationBufferSchema,
StructType groupingKeySchema,
- TaskMemoryManager taskMemoryManager,
+ TaskContext taskContext,
int initialCapacity,
long pageSizeBytes) {
this.aggregationBufferSchema = aggregationBufferSchema;
this.currentAggregationBuffer = new UnsafeRow(aggregationBufferSchema.length());
this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema);
this.groupingKeySchema = groupingKeySchema;
- this.map =
- new BytesToBytesMap(taskMemoryManager, initialCapacity, pageSizeBytes, true);
+ this.map = new BytesToBytesMap(
+ taskContext.taskMemoryManager(), initialCapacity, pageSizeBytes, true);
// Initialize the buffer for aggregation value
final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema);
this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes();
+
+ // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
+  // the end of the task. This is necessary to avoid memory leaks when the downstream operator
+ // does not fully consume the aggregation map's output (e.g. aggregate followed by limit).
+ taskContext.addTaskCompletionListener(context -> {
+ free();
+ });
}
/**
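To make the listener registration above concrete outside of this class, here is a minimal Scala sketch of the same pattern. The Scanner type and openScanner helper are hypothetical stand-ins for any per-task resource, not Spark APIs, and the snippet must run inside a task (TaskContext.get() is null on the driver).

    import org.apache.spark.TaskContext
    import org.apache.spark.util.TaskCompletionListener

    // Hypothetical per-task resource, standing in for the BytesToBytesMap above.
    class Scanner {
      def close(): Unit = println("per-task resources freed")
    }

    def openScanner(): Scanner = {
      val scanner = new Scanner
      // Free the resource when the task completes, even if the iterator that
      // consumes it is abandoned early (e.g. an aggregate followed by a limit).
      TaskContext.get().addTaskCompletionListener(new TaskCompletionListener {
        override def onTaskCompletion(context: TaskContext): Unit = scanner.close()
      })
      scanner
    }
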
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 2ec236fc75efc..c97246f30220d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1016,6 +1016,11 @@ class Dataset[T] private[sql](
catalyst.expressions.EqualTo(
withPlan(plan.left).resolve(a.name),
withPlan(plan.right).resolve(b.name))
+ case catalyst.expressions.EqualNullSafe(a: AttributeReference, b: AttributeReference)
+ if a.sameRef(b) =>
+ catalyst.expressions.EqualNullSafe(
+ withPlan(plan.left).resolve(a.name),
+ withPlan(plan.right).resolve(b.name))
}}
withPlan {
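The added case mirrors the EqualTo handling directly above it, so a null-safe comparison between the two sides of a self-join resolves against each side instead of collapsing into a trivially true predicate. A usage sketch (session setup shown for completeness; all names are illustrative):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").appName("self-join").getOrCreate()
    import spark.implicits._

    val df = Seq((1, Option("a")), (2, Option.empty[String])).toDF("id", "tag")
    val other = df.filter($"id" > 0)

    // Both columns come from the same underlying relation; the new case lets the
    // null-safe condition resolve against each join side separately.
    val joined = df.join(other, df("tag") <=> other("tag"))
    joined.show()
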
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index c2bf40cb22064..0c4ea857fd1d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
@@ -34,7 +35,7 @@ import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.MemoryPlanV2
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.streaming.StreamingQuery
+import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery}
import org.apache.spark.sql.types.StructType
/**
@@ -73,12 +74,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
if limit < conf.topKSortFallbackThreshold =>
TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
case Limit(IntegerLiteral(limit), child) =>
- // With whole stage codegen, Spark releases resources only when all the output data of the
- // query plan are consumed. It's possible that `CollectLimitExec` only consumes a little
- // data from child plan and finishes the query without releasing resources. Here we wrap
- // the child plan with `LocalLimitExec`, to stop the processing of whole stage codegen and
- // trigger the resource releasing work, after we consume `limit` rows.
- CollectLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil
+ CollectLimitExec(limit, planLater(child)) :: Nil
case other => planLater(other) :: Nil
}
case Limit(IntegerLiteral(limit), Sort(order, true, child))
@@ -354,6 +350,29 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}
+ /**
+ * Used to plan the streaming global limit operator for streams in append mode.
+ * We need to check for either a direct Limit or a Limit wrapped in a ReturnAnswer operator,
+ * following the example of the SpecialLimits Strategy above.
+ * Streams with limit in Append mode use the stateful StreamingGlobalLimitExec.
+ * Streams with limit in Complete mode use the stateless CollectLimitExec operator.
+ * Limit is unsupported for streams in Update mode.
+ */
+ case class StreamingGlobalLimitStrategy(outputMode: OutputMode) extends Strategy {
+ override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case ReturnAnswer(rootPlan) => rootPlan match {
+ case Limit(IntegerLiteral(limit), child)
+ if plan.isStreaming && outputMode == InternalOutputModes.Append =>
+ StreamingGlobalLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil
+ case _ => Nil
+ }
+ case Limit(IntegerLiteral(limit), child)
+ if plan.isStreaming && outputMode == InternalOutputModes.Append =>
+ StreamingGlobalLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil
+ case _ => Nil
+ }
+ }
+
object StreamingJoinStrategy extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
plan match {
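For reference, a sketch of the query shape this strategy plans: a global limit on an append-mode stream. The rate source and memory sink are illustrative choices, and the limit itself relies on the streaming support added in this change.

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").appName("stream-limit").getOrCreate()

    // A global limit on an append-mode stream is planned as the stateful
    // StreamingGlobalLimitExec described above.
    val limited = spark.readStream
      .format("rate")
      .option("rowsPerSecond", "10")
      .load()
      .limit(5)

    val query = limited.writeStream
      .format("memory")
      .queryName("limited_rows")
      .outputMode("append")
      .start()

    query.processAllAvailable()
    spark.sql("SELECT * FROM limited_rows").show()
    query.stop()
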
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 8c7b2c187cccd..2cac0cfce28de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -328,7 +328,7 @@ case class HashAggregateExec(
initialBuffer,
bufferSchema,
groupingKeySchema,
- TaskContext.get().taskMemoryManager(),
+ TaskContext.get(),
1024 * 16, // initial capacity
TaskContext.get().taskMemoryManager().pageSizeBytes
)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 9dc334c1ead3c..c1911235f8df3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -166,7 +166,7 @@ class TungstenAggregationIterator(
initialAggregationBuffer,
StructType.fromAttributes(aggregateFunctions.flatMap(_.aggBufferAttributes)),
StructType.fromAttributes(groupingExpressions.map(_.toAttribute)),
- TaskContext.get().taskMemoryManager(),
+ TaskContext.get(),
1024 * 16, // initial capacity
TaskContext.get().taskMemoryManager().pageSizeBytes
)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
index 3b6df45e949e8..2fee2128ba1f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
@@ -33,7 +33,7 @@ import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
import org.apache.spark.rdd.{BinaryFileRDD, RDD}
import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
+import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JsonInferSchema, JSONOptions}
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index 93de1faef527a..52a18abb55241 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -353,25 +353,13 @@ class ParquetFileFormat
(file: PartitionedFile) => {
assert(file.partitionValues.numFields == partitionSchema.size)
- // Try to push down filters when filter push-down is enabled.
- val pushed = if (enableParquetFilterPushDown) {
- filters
- // Collects all converted Parquet filter predicates. Notice that not all predicates can be
- // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap`
- // is used here.
- .flatMap(new ParquetFilters(pushDownDate, pushDownStringStartWith)
- .createFilter(requiredSchema, _))
- .reduceOption(FilterApi.and)
- } else {
- None
- }
-
val fileSplit =
new FileSplit(new Path(new URI(file.filePath)), file.start, file.length, Array.empty)
+ val filePath = fileSplit.getPath
val split =
new org.apache.parquet.hadoop.ParquetInputSplit(
- fileSplit.getPath,
+ filePath,
fileSplit.getStart,
fileSplit.getStart + fileSplit.getLength,
fileSplit.getLength,
@@ -379,12 +367,28 @@ class ParquetFileFormat
null)
val sharedConf = broadcastedHadoopConf.value.value
+
+ // Try to push down filters when filter push-down is enabled.
+ val pushed = if (enableParquetFilterPushDown) {
+ val parquetSchema = ParquetFileReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS)
+ .getFileMetaData.getSchema
+ filters
+ // Collects all converted Parquet filter predicates. Notice that not all predicates can be
+ // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap`
+ // is used here.
+ .flatMap(new ParquetFilters(pushDownDate, pushDownStringStartWith)
+ .createFilter(parquetSchema, _))
+ .reduceOption(FilterApi.and)
+ } else {
+ None
+ }
+
// PARQUET_INT96_TIMESTAMP_CONVERSION says to apply timezone conversions to int96 timestamps'
// *only* if the file was created by something other than "parquet-mr", so check the actual
// writer here for this file. We have to do this per-file, as each file in the table may
// have different writers.
def isCreatedByParquetMr(): Boolean = {
- val footer = ParquetFileReader.readFooter(sharedConf, fileSplit.getPath, SKIP_ROW_GROUPS)
+ val footer = ParquetFileReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS)
footer.getFileMetaData().getCreatedBy().startsWith("parquet-mr")
}
val convertTz =
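Because the pushed filters are now built against the Parquet schema read from each file's footer rather than the Catalyst schema, the footer read happens before predicate conversion. A standalone sketch of that footer read with parquet-hadoop (the file path is illustrative):

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.Path
    import org.apache.parquet.format.converter.ParquetMetadataConverter.SKIP_ROW_GROUPS
    import org.apache.parquet.hadoop.ParquetFileReader

    val conf = new Configuration()
    val filePath = new Path("/tmp/example.parquet")  // illustrative path

    // Read only the file metadata (SKIP_ROW_GROUPS skips per-row-group metadata)
    // and extract the physical Parquet schema that the rewritten ParquetFilters
    // matches on.
    val footer = ParquetFileReader.readFooter(conf, filePath, SKIP_ROW_GROUPS)
    val parquetSchema = footer.getFileMetaData.getSchema
    println(parquetSchema)
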
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
index 21c9e2e4f82b4..4c9b940db2b30 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
@@ -19,15 +19,19 @@ package org.apache.spark.sql.execution.datasources.parquet
import java.sql.Date
+import scala.collection.JavaConverters.asScalaBufferConverter
+
import org.apache.parquet.filter2.predicate._
import org.apache.parquet.filter2.predicate.FilterApi._
import org.apache.parquet.io.api.Binary
-import org.apache.parquet.schema.PrimitiveComparator
+import org.apache.parquet.schema.{DecimalMetadata, MessageType, OriginalType, PrimitiveComparator, PrimitiveType}
+import org.apache.parquet.schema.OriginalType._
+import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
+import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.DateTimeUtils.SQLDate
import org.apache.spark.sql.sources
-import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
/**
@@ -35,171 +39,190 @@ import org.apache.spark.unsafe.types.UTF8String
*/
private[parquet] class ParquetFilters(pushDownDate: Boolean, pushDownStartWith: Boolean) {
+ private case class ParquetSchemaType(
+ originalType: OriginalType,
+ primitiveTypeName: PrimitiveTypeName,
+ decimalMetadata: DecimalMetadata)
+
+ private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, null)
+ private val ParquetByteType = ParquetSchemaType(INT_8, INT32, null)
+ private val ParquetShortType = ParquetSchemaType(INT_16, INT32, null)
+ private val ParquetIntegerType = ParquetSchemaType(null, INT32, null)
+ private val ParquetLongType = ParquetSchemaType(null, INT64, null)
+ private val ParquetFloatType = ParquetSchemaType(null, FLOAT, null)
+ private val ParquetDoubleType = ParquetSchemaType(null, DOUBLE, null)
+ private val ParquetStringType = ParquetSchemaType(UTF8, BINARY, null)
+ private val ParquetBinaryType = ParquetSchemaType(null, BINARY, null)
+ private val ParquetDateType = ParquetSchemaType(DATE, INT32, null)
+
private def dateToDays(date: Date): SQLDate = {
DateTimeUtils.fromJavaDate(date)
}
- private val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
- case BooleanType =>
+ private val makeEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = {
+ case ParquetBooleanType =>
(n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean])
- case IntegerType =>
- (n: String, v: Any) => FilterApi.eq(intColumn(n), v.asInstanceOf[Integer])
- case LongType =>
+ case ParquetByteType | ParquetShortType | ParquetIntegerType =>
+ (n: String, v: Any) => FilterApi.eq(
+ intColumn(n),
+ Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull)
+ case ParquetLongType =>
(n: String, v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[java.lang.Long])
- case FloatType =>
+ case ParquetFloatType =>
(n: String, v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[java.lang.Float])
- case DoubleType =>
+ case ParquetDoubleType =>
(n: String, v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
// Binary.fromString and Binary.fromByteArray don't accept null values
- case StringType =>
+ case ParquetStringType =>
(n: String, v: Any) => FilterApi.eq(
binaryColumn(n),
Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull)
- case BinaryType =>
+ case ParquetBinaryType =>
(n: String, v: Any) => FilterApi.eq(
binaryColumn(n),
Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull)
- case DateType if pushDownDate =>
+ case ParquetDateType if pushDownDate =>
(n: String, v: Any) => FilterApi.eq(
intColumn(n),
Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull)
}
- private val makeNotEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
- case BooleanType =>
+ private val makeNotEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = {
+ case ParquetBooleanType =>
(n: String, v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean])
- case IntegerType =>
- (n: String, v: Any) => FilterApi.notEq(intColumn(n), v.asInstanceOf[Integer])
- case LongType =>
+ case ParquetByteType | ParquetShortType | ParquetIntegerType =>
+ (n: String, v: Any) => FilterApi.notEq(
+ intColumn(n),
+ Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull)
+ case ParquetLongType =>
(n: String, v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[java.lang.Long])
- case FloatType =>
+ case ParquetFloatType =>
(n: String, v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[java.lang.Float])
- case DoubleType =>
+ case ParquetDoubleType =>
(n: String, v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
- case StringType =>
+ case ParquetStringType =>
(n: String, v: Any) => FilterApi.notEq(
binaryColumn(n),
Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull)
- case BinaryType =>
+ case ParquetBinaryType =>
(n: String, v: Any) => FilterApi.notEq(
binaryColumn(n),
Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull)
- case DateType if pushDownDate =>
+ case ParquetDateType if pushDownDate =>
(n: String, v: Any) => FilterApi.notEq(
intColumn(n),
Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull)
}
- private val makeLt: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
- case IntegerType =>
- (n: String, v: Any) => FilterApi.lt(intColumn(n), v.asInstanceOf[Integer])
- case LongType =>
+ private val makeLt: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = {
+ case ParquetByteType | ParquetShortType | ParquetIntegerType =>
+ (n: String, v: Any) =>
+ FilterApi.lt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer])
+ case ParquetLongType =>
(n: String, v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[java.lang.Long])
- case FloatType =>
+ case ParquetFloatType =>
(n: String, v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[java.lang.Float])
- case DoubleType =>
+ case ParquetDoubleType =>
(n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double])
- case StringType =>
+ case ParquetStringType =>
(n: String, v: Any) =>
- FilterApi.lt(binaryColumn(n),
- Binary.fromString(v.asInstanceOf[String]))
- case BinaryType =>
+ FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
+ case ParquetBinaryType =>
(n: String, v: Any) =>
FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]]))
- case DateType if pushDownDate =>
- (n: String, v: Any) => FilterApi.lt(
- intColumn(n),
- Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull)
+ case ParquetDateType if pushDownDate =>
+ (n: String, v: Any) =>
+ FilterApi.lt(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer])
}
- private val makeLtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
- case IntegerType =>
- (n: String, v: Any) => FilterApi.ltEq(intColumn(n), v.asInstanceOf[java.lang.Integer])
- case LongType =>
+ private val makeLtEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = {
+ case ParquetByteType | ParquetShortType | ParquetIntegerType =>
+ (n: String, v: Any) =>
+ FilterApi.ltEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer])
+ case ParquetLongType =>
(n: String, v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[java.lang.Long])
- case FloatType =>
+ case ParquetFloatType =>
(n: String, v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[java.lang.Float])
- case DoubleType =>
+ case ParquetDoubleType =>
(n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
- case StringType =>
+ case ParquetStringType =>
(n: String, v: Any) =>
- FilterApi.ltEq(binaryColumn(n),
- Binary.fromString(v.asInstanceOf[String]))
- case BinaryType =>
+ FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
+ case ParquetBinaryType =>
(n: String, v: Any) =>
FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]]))
- case DateType if pushDownDate =>
- (n: String, v: Any) => FilterApi.ltEq(
- intColumn(n),
- Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull)
+ case ParquetDateType if pushDownDate =>
+ (n: String, v: Any) =>
+ FilterApi.ltEq(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer])
}
- private val makeGt: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
- case IntegerType =>
- (n: String, v: Any) => FilterApi.gt(intColumn(n), v.asInstanceOf[java.lang.Integer])
- case LongType =>
+ private val makeGt: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = {
+ case ParquetByteType | ParquetShortType | ParquetIntegerType =>
+ (n: String, v: Any) =>
+ FilterApi.gt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer])
+ case ParquetLongType =>
(n: String, v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[java.lang.Long])
- case FloatType =>
+ case ParquetFloatType =>
(n: String, v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[java.lang.Float])
- case DoubleType =>
+ case ParquetDoubleType =>
(n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double])
- case StringType =>
+ case ParquetStringType =>
(n: String, v: Any) =>
- FilterApi.gt(binaryColumn(n),
- Binary.fromString(v.asInstanceOf[String]))
- case BinaryType =>
+ FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
+ case ParquetBinaryType =>
(n: String, v: Any) =>
FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]]))
- case DateType if pushDownDate =>
- (n: String, v: Any) => FilterApi.gt(
- intColumn(n),
- Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull)
+ case ParquetDateType if pushDownDate =>
+ (n: String, v: Any) =>
+ FilterApi.gt(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer])
}
- private val makeGtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
- case IntegerType =>
- (n: String, v: Any) => FilterApi.gtEq(intColumn(n), v.asInstanceOf[java.lang.Integer])
- case LongType =>
+ private val makeGtEq: PartialFunction[ParquetSchemaType, (String, Any) => FilterPredicate] = {
+ case ParquetByteType | ParquetShortType | ParquetIntegerType =>
+ (n: String, v: Any) =>
+ FilterApi.gtEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer])
+ case ParquetLongType =>
(n: String, v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[java.lang.Long])
- case FloatType =>
+ case ParquetFloatType =>
(n: String, v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[java.lang.Float])
- case DoubleType =>
+ case ParquetDoubleType =>
(n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
- case StringType =>
+ case ParquetStringType =>
(n: String, v: Any) =>
- FilterApi.gtEq(binaryColumn(n),
- Binary.fromString(v.asInstanceOf[String]))
- case BinaryType =>
+ FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
+ case ParquetBinaryType =>
(n: String, v: Any) =>
FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]]))
- case DateType if pushDownDate =>
- (n: String, v: Any) => FilterApi.gtEq(
- intColumn(n),
- Option(v).map(date => dateToDays(date.asInstanceOf[Date]).asInstanceOf[Integer]).orNull)
+ case ParquetDateType if pushDownDate =>
+ (n: String, v: Any) =>
+ FilterApi.gtEq(intColumn(n), dateToDays(v.asInstanceOf[Date]).asInstanceOf[Integer])
}
/**
* Returns a map from name of the column to the data type, if predicate push down applies.
*/
- private def getFieldMap(dataType: DataType): Map[String, DataType] = dataType match {
- case StructType(fields) =>
+ private def getFieldMap(dataType: MessageType): Map[String, ParquetSchemaType] = dataType match {
+ case m: MessageType =>
// Here we don't flatten the fields in the nested schema but just look up through
// root fields. Currently, accessing to nested fields does not push down filters
// and it does not support to create filters for them.
- fields.map(f => f.name -> f.dataType).toMap
- case _ => Map.empty[String, DataType]
+ m.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f =>
+ f.getName -> ParquetSchemaType(
+ f.getOriginalType, f.getPrimitiveTypeName, f.getDecimalMetadata)
+ }.toMap
+ case _ => Map.empty[String, ParquetSchemaType]
}
/**
* Converts data sources filters to Parquet filter predicates.
*/
- def createFilter(schema: StructType, predicate: sources.Filter): Option[FilterPredicate] = {
+ def createFilter(schema: MessageType, predicate: sources.Filter): Option[FilterPredicate] = {
val nameToType = getFieldMap(schema)
// Parquet does not allow dots in the column name because dots are used as a column path
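One detail worth noting in the cases above: byte, short and int columns all map to the physical Parquet type INT32, so comparison values are widened through java.lang.Number before a predicate is built. A tiny standalone sketch of that conversion (the helper name is hypothetical):

    // Hypothetical helper mirroring the widening used in makeEq/makeNotEq/makeLt/...:
    // byte/short/int values are all compared against INT32 columns, and null is
    // preserved so the eq/notEq cases can still express IS NULL semantics.
    def toIntValue(v: Any): Integer =
      Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull

    toIntValue(5.toByte)   // 5
    toIntValue(5.toShort)  // 5
    toIntValue(null)       // null
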
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index c55f9b8f1a7fc..a80673c705f1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.exchange
+import java.util.concurrent.TimeoutException
+
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration._
import scala.util.control.NonFatal
@@ -140,7 +142,16 @@ case class BroadcastExchangeExec(
}
override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
- ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]]
+ try {
+ ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]]
+ } catch {
+ case ex: TimeoutException =>
+ logError(s"Could not execute broadcast in ${timeout.toSeconds} secs.", ex)
+ throw new SparkException(s"Could not execute broadcast in ${timeout.toSeconds} secs. " +
+ s"You can increase the timeout for broadcasts via ${SQLConf.BROADCAST_TIMEOUT.key} or " +
+ s"disable broadcast join by setting ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1",
+ ex)
+ }
}
}
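The new error message points at two session settings; assuming an active SparkSession named spark, they can be adjusted as follows (values are illustrative):

    // Give slow-to-build broadcast relations more time (seconds)...
    spark.conf.set("spark.sql.broadcastTimeout", "600")

    // ...or disable automatic broadcast joins so the planner falls back to a
    // shuffle-based join instead.
    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
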
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index ad95879d86f42..d96ecbaa48029 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -279,13 +279,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
*/
private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = {
plan match {
- case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left,
- right) =>
- val (reorderedLeftKeys, reorderedRightKeys) =
- reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
- BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
- left, right)
-
case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) =>
val (reorderedLeftKeys, reorderedRightKeys) =
reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index c480b96626f84..6ae7f2869b0f3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -59,7 +59,8 @@ class IncrementalExecution(
StatefulAggregationStrategy ::
FlatMapGroupsWithStateStrategy ::
StreamingRelationStrategy ::
- StreamingDeduplicationStrategy :: Nil
+ StreamingDeduplicationStrategy ::
+ StreamingGlobalLimitStrategy(outputMode) :: Nil
}
private[sql] val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key)
@@ -134,8 +135,12 @@ class IncrementalExecution(
stateWatermarkPredicates =
StreamingSymmetricHashJoinHelper.getStateWatermarkPredicates(
j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition.full,
- Some(offsetSeqMetadata.batchWatermarkMs))
- )
+ Some(offsetSeqMetadata.batchWatermarkMs)))
+
+ case l: StreamingGlobalLimitExec =>
+ l.copy(
+ stateInfo = Some(nextStatefulOperationStateInfo),
+ outputMode = Some(outputMode))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
index 17ffa2a517312..16651dd060d73 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
@@ -61,7 +61,7 @@ class MicroBatchExecution(
case _ => throw new IllegalStateException(s"Unknown type of trigger: $trigger")
}
- private val watermarkTracker = new WatermarkTracker()
+ private var watermarkTracker: WatermarkTracker = _
override lazy val logicalPlan: LogicalPlan = {
assert(queryExecutionThread eq Thread.currentThread,
@@ -257,6 +257,7 @@ class MicroBatchExecution(
OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.conf)
offsetSeqMetadata = OffsetSeqMetadata(
metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf)
+ watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf)
watermarkTracker.setWatermark(metadata.batchWatermarkMs)
}
@@ -295,6 +296,7 @@ class MicroBatchExecution(
case None => // We are starting this stream for the first time.
logInfo(s"Starting new streaming query.")
currentBatchId = 0
+ watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf)
}
}
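Building the WatermarkTracker from the session configuration is what lets it pick up the multiple-watermark policy referenced in the OffsetSeq change below. A sketch of a query where that policy matters, assuming the configuration key spark.sql.streaming.multipleWatermarkPolicy introduced alongside this change and its "min" (default) / "max" policy names:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").appName("watermarks").getOrCreate()

    // With "max", the global watermark follows the fastest input instead of the
    // slowest one (the "min" default preserves the previous behavior).
    spark.conf.set("spark.sql.streaming.multipleWatermarkPolicy", "max")

    val left = spark.readStream.format("rate").option("rowsPerSecond", "5").load()
      .withWatermark("timestamp", "10 seconds")
    val right = spark.readStream.format("rate").option("rowsPerSecond", "5").load()
      .withWatermark("timestamp", "1 minute")

    // Both inputs carry their own watermark; the policy decides how they combine.
    val unioned = left.union(right)
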
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
index 787174481ff08..1ae3f36c152cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
@@ -22,7 +22,7 @@ import org.json4s.jackson.Serialization
import org.apache.spark.internal.Logging
import org.apache.spark.sql.RuntimeConfig
-import org.apache.spark.sql.internal.SQLConf.{SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS}
+import org.apache.spark.sql.internal.SQLConf._
/**
* An ordered collection of offsets, used to track the progress of processing data from one or more
@@ -86,7 +86,22 @@ case class OffsetSeqMetadata(
object OffsetSeqMetadata extends Logging {
private implicit val format = Serialization.formats(NoTypeHints)
- private val relevantSQLConfs = Seq(SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS)
+ private val relevantSQLConfs = Seq(
+ SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY)
+
+ /**
+ * Default values of relevant configurations that are used for backward compatibility.
+ * As new configurations are added to the metadata, existing checkpoints may not have those
+ * confs. The values in this list ensure that confs without recovered values are set to
+ * default values that preserve the behavior of the streaming query as it was before
+ * the restart.
+ *
+ * Note that this is optional; set values here only if you *have* to override the existing
+ * session conf with a specific default value to ensure the query behaves as it did before.
+ */
+ private val relevantSQLConfDefaultValues = Map[String, String](
+ STREAMING_MULTIPLE_WATERMARK_POLICY.key -> MultipleWatermarkPolicy.DEFAULT_POLICY_NAME
+ )
def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json)
@@ -115,8 +130,22 @@ object OffsetSeqMetadata extends Logging {
case None =>
// For backward compatibility, if a config was not recorded in the offset log,
- // then log it, and let the existing conf value in SparkSession prevail.
- logWarning (s"Conf '$confKey' was not found in the offset log, using existing value")
+ // then either inject a default value (if specified in `relevantSQLConfDefaultValues`) or
+ // let the existing conf value in SparkSession prevail.
+ relevantSQLConfDefaultValues.get(confKey) match {
+
+ case Some(defaultValue) =>
+ sessionConf.set(confKey, defaultValue)
+ logWarning(s"Conf '$confKey' was not found in the offset log, " +
+ s"using default value '$defaultValue'")
+
+ case None =>
+ val valueStr = sessionConf.getOption(confKey).map { v =>
+ s" Using existing session conf value '$v'."
+ }.getOrElse { " No value set in session conf." }
+ logWarning(s"Conf '$confKey' was not found in the offset log. $valueStr")
+
+ }
}
}
}
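The recovery branch above injects a registered default only for confs that are missing from the offset log; everything else keeps whatever value the session already has. A minimal standalone sketch of that rule, assuming the key string `spark.sql.streaming.multipleWatermarkPolicy` backs STREAMING_MULTIPLE_WATERMARK_POLICY (all names below are illustrative, not the actual Spark internals):

// Sketch of the backward-compatibility rule: a checkpointed value always wins;
// for a missing key, apply a registered default if there is one, otherwise
// leave the existing session value untouched.
import scala.collection.mutable

object ConfRecoverySketch {
  private val relevantDefaults =
    Map("spark.sql.streaming.multipleWatermarkPolicy" -> "min")  // assumed key string

  def reconcile(
      confKey: String,
      recovered: Option[String],
      sessionConf: mutable.Map[String, String]): Unit = recovered match {
    case Some(value) =>
      sessionConf(confKey) = value                            // value recovered from the offset log
    case None =>
      relevantDefaults.get(confKey) match {
        case Some(default) => sessionConf(confKey) = default  // pin pre-upgrade behavior
        case None          => ()                              // keep current session value
      }
  }
}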
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingGlobalLimitExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingGlobalLimitExec.scala
new file mode 100644
index 0000000000000..bf4af60c8cf03
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingGlobalLimitExec.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent.TimeUnit.NANOSECONDS
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, Partitioning}
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.streaming.state.StateStoreOps
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.{LongType, NullType, StructField, StructType}
+import org.apache.spark.util.CompletionIterator
+
+/**
+ * A physical operator for executing a streaming limit, which makes sure no more than streamLimit
+ * rows are returned. This operator is meant for streams in Append mode only.
+ */
+case class StreamingGlobalLimitExec(
+ streamLimit: Long,
+ child: SparkPlan,
+ stateInfo: Option[StatefulOperatorStateInfo] = None,
+ outputMode: Option[OutputMode] = None)
+ extends UnaryExecNode with StateStoreWriter {
+
+ private val keySchema = StructType(Array(StructField("key", NullType)))
+ private val valueSchema = StructType(Array(StructField("value", LongType)))
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ metrics // force lazy init at driver
+
+ assert(outputMode.isDefined && outputMode.get == InternalOutputModes.Append,
+ "StreamingGlobalLimitExec is only valid for streams in Append output mode")
+
+ child.execute().mapPartitionsWithStateStore(
+ getStateInfo,
+ keySchema,
+ valueSchema,
+ indexOrdinal = None,
+ sqlContext.sessionState,
+ Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
+ val key = UnsafeProjection.create(keySchema)(new GenericInternalRow(Array[Any](null)))
+ val numOutputRows = longMetric("numOutputRows")
+ val numUpdatedStateRows = longMetric("numUpdatedStateRows")
+ val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
+ val commitTimeMs = longMetric("commitTimeMs")
+ val updatesStartTimeNs = System.nanoTime
+
+ val preBatchRowCount: Long = Option(store.get(key)).map(_.getLong(0)).getOrElse(0L)
+ var cumulativeRowCount = preBatchRowCount
+
+ val result = iter.filter { r =>
+ val x = cumulativeRowCount < streamLimit
+ if (x) {
+ cumulativeRowCount += 1
+ }
+ x
+ }
+
+ CompletionIterator[InternalRow, Iterator[InternalRow]](result, {
+ if (cumulativeRowCount > preBatchRowCount) {
+ numUpdatedStateRows += 1
+ numOutputRows += cumulativeRowCount - preBatchRowCount
+ store.put(key, getValueRow(cumulativeRowCount))
+ }
+ allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
+ commitTimeMs += timeTakenMs { store.commit() }
+ setStoreMetrics(store)
+ })
+ }
+ }
+
+ override def output: Seq[Attribute] = child.output
+
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+
+ override def requiredChildDistribution: Seq[Distribution] = AllTuples :: Nil
+
+ private def getValueRow(value: Long): UnsafeRow = {
+ UnsafeProjection.create(valueSchema)(new GenericInternalRow(Array[Any](value)))
+ }
+}
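Together with the StreamingGlobalLimitStrategy(outputMode) registration in IncrementalExecution above, this operator lets an Append-mode streaming query apply a global `limit`, with the running row count carried across micro-batches in the state store. A hedged usage sketch (the rate source and console sink are just convenient stand-ins):

// Illustrative only: a LIMIT on an append-mode streaming query, which the new
// strategy should plan as StreamingGlobalLimitExec.
import org.apache.spark.sql.SparkSession

object StreamingLimitSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("streaming-limit").getOrCreate()

    spark.readStream
      .format("rate")          // test source: emits (timestamp, value) rows
      .load()
      .limit(100)              // global cap, enforced across all micro-batches
      .writeStream
      .format("console")
      .outputMode("append")    // the operator asserts Append output mode
      .start()
      .awaitTermination()
  }
}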
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala
index 80865669558dd..7b30db44a2090 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala
@@ -20,15 +20,68 @@ package org.apache.spark.sql.execution.streaming
import scala.collection.mutable
import org.apache.spark.internal.Logging
+import org.apache.spark.sql.RuntimeConfig
import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.internal.SQLConf
-class WatermarkTracker extends Logging {
+/**
+ * Policy to define how to choose a new global watermark value if there are
+ * multiple watermark operators in a streaming query.
+ */
+sealed trait MultipleWatermarkPolicy {
+ def chooseGlobalWatermark(operatorWatermarks: Seq[Long]): Long
+}
+
+object MultipleWatermarkPolicy {
+ val DEFAULT_POLICY_NAME = "min"
+
+ def apply(policyName: String): MultipleWatermarkPolicy = {
+ policyName.toLowerCase match {
+ case DEFAULT_POLICY_NAME => MinWatermark
+ case "max" => MaxWatermark
+ case _ =>
+ throw new IllegalArgumentException(s"Could not recognize watermark policy '$policyName'")
+ }
+ }
+}
+
+/**
+ * Policy to choose the *min* of the operator watermark values as the global watermark value.
+ * Note that this is the safe (hence default) policy as the global watermark will advance
+ * only if all the individual operator watermarks have advanced. In other words, in a
+ * streaming query with multiple input streams and watermarks defined on all of them,
+ * the global watermark will advance as slowly as the slowest input. So if there is watermark
+ * based state cleanup or late-data dropping, then this policy is the most conservative one.
+ */
+case object MinWatermark extends MultipleWatermarkPolicy {
+ def chooseGlobalWatermark(operatorWatermarks: Seq[Long]): Long = {
+ assert(operatorWatermarks.nonEmpty)
+ operatorWatermarks.min
+ }
+}
+
+/**
+ * Policy to choose the *max* of the operator watermark values as the global watermark value. So the
+ * global watermark will advance if any of the individual operator watermarks has advanced.
+ * In other words, in a streaming query with multiple input streams and watermarks defined on all
+ * of them, the global watermark will advance as fast as the fastest input. So if there is watermark
+ * based state cleanup or late-data dropping, then this policy is the most aggressive one and
+ * may lead to unexpected behavior if the data of the slow stream is delayed.
+ */
+case object MaxWatermark extends MultipleWatermarkPolicy {
+ def chooseGlobalWatermark(operatorWatermarks: Seq[Long]): Long = {
+ assert(operatorWatermarks.nonEmpty)
+ operatorWatermarks.max
+ }
+}
+
+/** Tracks the watermark value of a streaming query based on a given `policy` */
+case class WatermarkTracker(policy: MultipleWatermarkPolicy) extends Logging {
private val operatorToWatermarkMap = mutable.HashMap[Int, Long]()
- private var watermarkMs: Long = 0
- private var updated = false
+ private var globalWatermarkMs: Long = 0
def setWatermark(newWatermarkMs: Long): Unit = synchronized {
- watermarkMs = newWatermarkMs
+ globalWatermarkMs = newWatermarkMs
}
def updateWatermark(executedPlan: SparkPlan): Unit = synchronized {
@@ -37,7 +90,6 @@ class WatermarkTracker extends Logging {
}
if (watermarkOperators.isEmpty) return
-
watermarkOperators.zipWithIndex.foreach {
case (e, index) if e.eventTimeStats.value.count > 0 =>
logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}")
@@ -58,16 +110,28 @@ class WatermarkTracker extends Logging {
// This is the safest option, because only the global watermark is fault-tolerant. Making
// it the minimum of all individual watermarks guarantees it will never advance past where
// any individual watermark operator would be if it were in a plan by itself.
- val newWatermarkMs = operatorToWatermarkMap.minBy(_._2)._2
- if (newWatermarkMs > watermarkMs) {
- logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms")
- watermarkMs = newWatermarkMs
- updated = true
+ val chosenGlobalWatermark = policy.chooseGlobalWatermark(operatorToWatermarkMap.values.toSeq)
+ if (chosenGlobalWatermark > globalWatermarkMs) {
+ logInfo(s"Updating event-time watermark from $globalWatermarkMs to $chosenGlobalWatermark ms")
+ globalWatermarkMs = chosenGlobalWatermark
} else {
- logDebug(s"Event time didn't move: $newWatermarkMs < $watermarkMs")
- updated = false
+ logDebug(s"Event time watermark didn't move: $chosenGlobalWatermark < $globalWatermarkMs")
}
}
- def currentWatermark: Long = synchronized { watermarkMs }
+ def currentWatermark: Long = synchronized { globalWatermarkMs }
+}
+
+object WatermarkTracker {
+ def apply(conf: RuntimeConfig): WatermarkTracker = {
+ // If the session has been explicitly configured to use a non-default policy then use it,
+ // otherwise use the default `min` policy, as that's the safe thing to do.
+ // When recovering from a checkpoint location, it is expected that the `conf` will already
+ // be configured with the value present in the checkpoint. If there is no policy explicitly
+ // saved in the checkpoint (e.g., old checkpoints), then the default `min` policy is enforced
+ // through defaults specified in OffsetSeqMetadata.setSessionConf().
+ val policyName = conf.get(
+ SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY, MultipleWatermarkPolicy.DEFAULT_POLICY_NAME)
+ new WatermarkTracker(MultipleWatermarkPolicy(policyName))
+ }
}
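Because WatermarkTracker(conf) reads the policy from the session conf at construction time, a query with several watermarked inputs can opt into the `max` policy before it starts. A hedged sketch; the conf key string `spark.sql.streaming.multipleWatermarkPolicy` is assumed here rather than taken from this diff:

// Illustrative: two watermarked streams with very different lateness bounds.
// Under the default "min" policy the global watermark trails the slower input;
// with "max" it follows the faster one (more aggressive state cleanup).
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, window}

object MultipleWatermarkSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("multi-watermark").getOrCreate()
    spark.conf.set("spark.sql.streaming.multipleWatermarkPolicy", "max")  // assumed key string

    val fast = spark.readStream.format("rate").load()
      .withWatermark("timestamp", "10 seconds")
    val slow = spark.readStream.format("rate").load()
      .withWatermark("timestamp", "10 minutes")

    fast.union(slow)
      .groupBy(window(col("timestamp"), "1 minute"))
      .count()
      .writeStream
      .format("console")
      .outputMode("append")
      .start()
      .awaitTermination()
  }
}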
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 7fa13c4aa2c01..b137f98045c5a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -33,7 +33,6 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
-import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow}
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
import org.apache.spark.sql.streaming.OutputMode
@@ -222,60 +221,19 @@ class MemoryStreamInputPartition(records: Array[UnsafeRow])
}
/** A common trait for MemorySinks with methods used for testing */
-trait MemorySinkBase extends BaseStreamingSink with Logging {
+trait MemorySinkBase extends BaseStreamingSink {
def allData: Seq[Row]
def latestBatchData: Seq[Row]
def dataSinceBatch(sinceBatchId: Long): Seq[Row]
def latestBatchId: Option[Long]
-
- /**
- * Truncates the given rows to return at most maxRows rows.
- * @param rows The data that may need to be truncated.
- * @param batchLimit Number of rows to keep in this batch; the rest will be truncated
- * @param sinkLimit Total number of rows kept in this sink, for logging purposes.
- * @param batchId The ID of the batch that sent these rows, for logging purposes.
- * @return Truncated rows.
- */
- protected def truncateRowsIfNeeded(
- rows: Array[Row],
- batchLimit: Int,
- sinkLimit: Int,
- batchId: Long): Array[Row] = {
- if (rows.length > batchLimit && batchLimit >= 0) {
- logWarning(s"Truncating batch $batchId to $batchLimit rows because of sink limit $sinkLimit")
- rows.take(batchLimit)
- } else {
- rows
- }
- }
-}
-
-/**
- * Companion object to MemorySinkBase.
- */
-object MemorySinkBase {
- val MAX_MEMORY_SINK_ROWS = "maxRows"
- val MAX_MEMORY_SINK_ROWS_DEFAULT = -1
-
- /**
- * Gets the max number of rows a MemorySink should store. This number is based on the memory
- * sink row limit option if it is set. If not, we use a large value so that data truncates
- * rather than causing out of memory errors.
- * @param options Options for writing from which we get the max rows option
- * @return The maximum number of rows a memorySink should store.
- */
- def getMemorySinkCapacity(options: DataSourceOptions): Int = {
- val maxRows = options.getInt(MAX_MEMORY_SINK_ROWS, MAX_MEMORY_SINK_ROWS_DEFAULT)
- if (maxRows >= 0) maxRows else Int.MaxValue - 10
- }
}
/**
* A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
* tests and does not provide durability.
*/
-class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSourceOptions)
- extends Sink with MemorySinkBase with Logging {
+class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink
+ with MemorySinkBase with Logging {
private case class AddedData(batchId: Long, data: Array[Row])
@@ -283,12 +241,6 @@ class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSo
@GuardedBy("this")
private val batches = new ArrayBuffer[AddedData]()
- /** The number of rows in this MemorySink. */
- private var numRows = 0
-
- /** The capacity in rows of this sink. */
- val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options)
-
/** Returns all rows that are stored in this [[Sink]]. */
def allData: Seq[Row] = synchronized {
batches.flatMap(_.data)
@@ -321,23 +273,14 @@ class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSo
logDebug(s"Committing batch $batchId to $this")
outputMode match {
case Append | Update =>
- var rowsToAdd = data.collect()
- synchronized {
- rowsToAdd =
- truncateRowsIfNeeded(rowsToAdd, sinkCapacity - numRows, sinkCapacity, batchId)
- val rows = AddedData(batchId, rowsToAdd)
- batches += rows
- numRows += rowsToAdd.length
- }
+ val rows = AddedData(batchId, data.collect())
+ synchronized { batches += rows }
case Complete =>
- var rowsToAdd = data.collect()
+ val rows = AddedData(batchId, data.collect())
synchronized {
- rowsToAdd = truncateRowsIfNeeded(rowsToAdd, sinkCapacity, sinkCapacity, batchId)
- val rows = AddedData(batchId, rowsToAdd)
batches.clear()
batches += rows
- numRows = rowsToAdd.length
}
case _ =>
@@ -351,7 +294,6 @@ class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSo
def clear(): Unit = synchronized {
batches.clear()
- numRows = 0
}
override def toString(): String = "MemorySink"
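With the row-cap option reverted, the memory sink goes back to accumulating every batch it receives. Its usual test-time usage is unchanged: write with format("memory") under a query name, then read the rows back through the temp view. A small hedged sketch:

// Illustrative test-time use of the in-memory sink: all appended rows are kept
// in the named in-memory table (no durability, and no row cap after this revert).
import org.apache.spark.sql.SparkSession

object MemorySinkSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("memory-sink").getOrCreate()

    val query = spark.readStream
      .format("rate")
      .load()
      .writeStream
      .format("memory")
      .queryName("rate_output")   // exposed as a temp view of the same name
      .outputMode("append")
      .start()

    query.processAllAvailable()   // wait until pending input has been written
    spark.sql("SELECT count(*) FROM rate_output").show()
    query.stop()
  }
}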
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
index 29f8cca476722..f2a35a90af24a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
@@ -46,7 +46,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB
schema: StructType,
mode: OutputMode,
options: DataSourceOptions): StreamWriter = {
- new MemoryStreamWriter(this, mode, options)
+ new MemoryStreamWriter(this, mode)
}
private case class AddedData(batchId: Long, data: Array[Row])
@@ -55,9 +55,6 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB
@GuardedBy("this")
private val batches = new ArrayBuffer[AddedData]()
- /** The number of rows in this MemorySink. */
- private var numRows = 0
-
/** Returns all rows that are stored in this [[Sink]]. */
def allData: Seq[Row] = synchronized {
batches.flatMap(_.data)
@@ -84,11 +81,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB
}.mkString("\n")
}
- def write(
- batchId: Long,
- outputMode: OutputMode,
- newRows: Array[Row],
- sinkCapacity: Int): Unit = {
+ def write(batchId: Long, outputMode: OutputMode, newRows: Array[Row]): Unit = {
val notCommitted = synchronized {
latestBatchId.isEmpty || batchId > latestBatchId.get
}
@@ -96,21 +89,14 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB
logDebug(s"Committing batch $batchId to $this")
outputMode match {
case Append | Update =>
- synchronized {
- val rowsToAdd =
- truncateRowsIfNeeded(newRows, sinkCapacity - numRows, sinkCapacity, batchId)
- val rows = AddedData(batchId, rowsToAdd)
- batches += rows
- numRows += rowsToAdd.length
- }
+ val rows = AddedData(batchId, newRows)
+ synchronized { batches += rows }
case Complete =>
+ val rows = AddedData(batchId, newRows)
synchronized {
- val rowsToAdd = truncateRowsIfNeeded(newRows, sinkCapacity, sinkCapacity, batchId)
- val rows = AddedData(batchId, rowsToAdd)
batches.clear()
batches += rows
- numRows = rowsToAdd.length
}
case _ =>
@@ -124,7 +110,6 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB
def clear(): Unit = synchronized {
batches.clear()
- numRows = 0
}
override def toString(): String = "MemorySinkV2"
@@ -132,22 +117,16 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB
case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {}
-class MemoryWriter(
- sink: MemorySinkV2,
- batchId: Long,
- outputMode: OutputMode,
- options: DataSourceOptions)
+class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode)
extends DataSourceWriter with Logging {
- val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options)
-
override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode)
def commit(messages: Array[WriterCommitMessage]): Unit = {
val newRows = messages.flatMap {
case message: MemoryWriterCommitMessage => message.data
}
- sink.write(batchId, outputMode, newRows, sinkCapacity)
+ sink.write(batchId, outputMode, newRows)
}
override def abort(messages: Array[WriterCommitMessage]): Unit = {
@@ -155,21 +134,16 @@ class MemoryWriter(
}
}
-class MemoryStreamWriter(
- val sink: MemorySinkV2,
- outputMode: OutputMode,
- options: DataSourceOptions)
+class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode)
extends StreamWriter {
- val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options)
-
override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode)
override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
val newRows = messages.flatMap {
case message: MemoryWriterCommitMessage => message.data
}
- sink.write(epochId, outputMode, newRows, sinkCapacity)
+ sink.write(epochId, outputMode, newRows)
}
override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index acca9572cb14c..89dbba10a6bf1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2934,6 +2934,17 @@ object functions {
FromUTCTimestamp(ts.expr, Literal(tz))
}
+ /**
+ * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders
+ * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield
+ * '2017-07-14 03:40:00.0'.
+ * @group datetime_funcs
+ * @since 2.4.0
+ */
+ def from_utc_timestamp(ts: Column, tz: Column): Column = withExpr {
+ FromUTCTimestamp(ts.expr, tz.expr)
+ }
+
/**
* Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in the given time
* zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield
@@ -2945,6 +2956,17 @@ object functions {
ToUTCTimestamp(ts.expr, Literal(tz))
}
+ /**
+ * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in the given time
+ * zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield
+ * '2017-07-14 01:40:00.0'.
+ * @group datetime_funcs
+ * @since 2.4.0
+ */
+ def to_utc_timestamp(ts: Column, tz: Column): Column = withExpr {
+ ToUTCTimestamp(ts.expr, tz.expr)
+ }
+
/**
* Bucketize rows into one or more time windows given a timestamp specifying column. Window
* starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window
@@ -3381,6 +3403,48 @@ object functions {
from_json(e, dataType, options)
}
+ /**
+ * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType`
+ * keys, a `StructType`, or an `ArrayType` of `StructType`s with the specified schema.
+ * Returns `null` in the case of an unparseable string.
+ *
+ * @param e a string column containing JSON data.
+ * @param schema the schema to use when parsing the json string
+ *
+ * @group collection_funcs
+ * @since 2.4.0
+ */
+ def from_json(e: Column, schema: Column): Column = {
+ from_json(e, schema, Map.empty[String, String].asJava)
+ }
+
+ /**
+ * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType`
+ * keys, a `StructType`, or an `ArrayType` of `StructType`s with the specified schema.
+ * Returns `null` in the case of an unparseable string.
+ *
+ * @param e a string column containing JSON data.
+ * @param schema the schema to use when parsing the json string
+ * @param options options to control how the json is parsed. Accepts the same options as the
+ * json data source.
+ *
+ * @group collection_funcs
+ * @since 2.4.0
+ */
+ def from_json(e: Column, schema: Column, options: java.util.Map[String, String]): Column = {
+ withExpr(new JsonToStructs(e.expr, schema.expr, options.asScala.toMap))
+ }
+
+ /**
+ * Parses a column containing a JSON string and infers its schema.
+ *
+ * @param e a string column containing JSON data.
+ *
+ * @group collection_funcs
+ * @since 2.4.0
+ */
+ def schema_of_json(e: Column): Column = withExpr(new SchemaOfJson(e.expr))
+
/**
* (Scala-specific) Converts a column containing a `StructType`, `ArrayType` of `StructType`s,
* a `MapType` or `ArrayType` of `MapType`s into a JSON string with the specified schema.
@@ -3563,6 +3627,14 @@ object functions {
@scala.annotation.varargs
def arrays_zip(e: Column*): Column = withExpr { ArraysZip(e.map(_.expr)) }
+ /**
+ * Returns the union of all the given maps.
+ * @group collection_funcs
+ * @since 2.4.0
+ */
+ @scala.annotation.varargs
+ def map_concat(cols: Column*): Column = withExpr { MapConcat(cols.map(_.expr)) }
+
//////////////////////////////////////////////////////////////////////////////////////////////
// Mask functions
//////////////////////////////////////////////////////////////////////////////////////////////
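Taken together, the `functions` additions above give column-valued time zones for the UTC conversions, JSON schema inference with `schema_of_json`, a `from_json` overload that takes the schema as a `Column`, and `map_concat`. A hedged DataFrame sketch exercising them (the sample values and column names are made up):

// Illustrative use of the newly added functions.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object NewFunctionsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("new-functions").getOrCreate()
    import spark.implicits._

    val df = Seq(("2017-07-14 02:40:00.0", "GMT+1", """{"c1":[1, 2, 3]}"""))
      .toDF("ts", "tz", "json")

    df.select(
      from_utc_timestamp($"ts", $"tz"),                           // time zone taken from a column
      to_utc_timestamp($"ts", $"tz"),
      schema_of_json(lit("""{"c1":[0]}""")),                      // infer a schema string
      from_json($"json", schema_of_json(lit("""{"c1":[0]}""")))   // schema passed as a Column
    ).show(truncate = false)

    // map_concat unions maps with a common key/value type.
    spark.range(1)
      .select(map_concat(map(lit("a"), lit(1)), map(lit("b"), lit(2))))
      .show()
  }
}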
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
index 2499e9b604f3e..bdd8c4da6bd30 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
@@ -199,7 +199,7 @@ case class StringStartsWith(attribute: String, value: String) extends Filter {
/**
* A filter that evaluates to `true` iff the attribute evaluates to
- * a string that starts with `value`.
+ * a string that ends with `value`.
*
* @since 1.3.1
*/
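For context, `StringEndsWith` is the filter a data source sees when an `endsWith` predicate is pushed down; the corrected doc now matches that semantics. A tiny hedged sketch of how a source implementation might evaluate it:

// Hypothetical row-level check a data source might perform for a pushed-down
// StringEndsWith filter (real sources usually translate it to their own predicate).
import org.apache.spark.sql.sources.StringEndsWith

object EndsWithFilterSketch {
  def main(args: Array[String]): Unit = {
    val filter = StringEndsWith("path", ".csv")
    val rows = Seq(Map("path" -> "data/a.csv"), Map("path" -> "data/a.json"))
    val kept = rows.filter(r => r(filter.attribute).endsWith(filter.value))
    println(kept)   // keeps only the ".csv" row
  }
}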
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index 926c0b69a03fd..3b9a56ffdde4b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
import org.apache.spark.sql.execution.streaming.sources._
-import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport}
+import org.apache.spark.sql.sources.v2.StreamWriteSupport
/**
* Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems,
@@ -250,7 +250,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
val r = Dataset.ofRows(df.sparkSession, new MemoryPlanV2(s, df.schema.toAttributes))
(s, r)
case _ =>
- val s = new MemorySink(df.schema, outputMode, new DataSourceOptions(extraOptions.asJava))
+ val s = new MemorySink(df.schema, outputMode)
val r = Dataset.ofRows(df.sparkSession, new MemoryPlan(s))
(s, r)
}
diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql
index dc15d13cd1dd3..79fdd5895e691 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql
@@ -35,3 +35,7 @@ DROP VIEW IF EXISTS jsonTable;
-- from_json - complex types
select from_json('{"a":1, "b":2}', 'map');
select from_json('{"a":1, "b":"2"}', 'struct');
+
+-- infer schema of json literal
+select schema_of_json('{"c1":0, "c2":[1]}');
+select from_json('{"c1":[1, 2, 3]}', schema_of_json('{"c1":[0]}'));
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql
new file mode 100644
index 0000000000000..fc26397b881b5
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql
@@ -0,0 +1,94 @@
+CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES (
+ map(true, false), map(false, true),
+ map(1Y, 2Y), map(3Y, 4Y),
+ map(1S, 2S), map(3S, 4S),
+ map(4, 6), map(7, 8),
+ map(6L, 7L), map(8L, 9L),
+ map(9223372036854775809, 9223372036854775808), map(9223372036854775808, 9223372036854775809),
+ map(1.0D, 2.0D), map(3.0D, 4.0D),
+ map(float(1.0D), float(2.0D)), map(float(3.0D), float(4.0D)),
+ map(date '2016-03-14', date '2016-03-13'), map(date '2016-03-12', date '2016-03-11'),
+ map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'),
+ map(timestamp '2016-11-11 20:54:00.000', timestamp '2016-11-09 20:54:00.000'),
+ map('a', 'b'), map('c', 'd'),
+ map(array('a', 'b'), array('c', 'd')), map(array('e'), array('f')),
+ map(struct('a', 1), struct('b', 2)), map(struct('c', 3), struct('d', 4)),
+ map(map('a', 1), map('b', 2)), map(map('c', 3), map('d', 4)),
+ map('a', 1), map('c', 2),
+ map(1, 'a'), map(2, 'c')
+) AS various_maps (
+ boolean_map1, boolean_map2,
+ tinyint_map1, tinyint_map2,
+ smallint_map1, smallint_map2,
+ int_map1, int_map2,
+ bigint_map1, bigint_map2,
+ decimal_map1, decimal_map2,
+ double_map1, double_map2,
+ float_map1, float_map2,
+ date_map1, date_map2,
+ timestamp_map1,
+ timestamp_map2,
+ string_map1, string_map2,
+ array_map1, array_map2,
+ struct_map1, struct_map2,
+ map_map1, map_map2,
+ string_int_map1, string_int_map2,
+ int_string_map1, int_string_map2
+);
+
+-- Concatenate maps of the same type
+SELECT
+ map_concat(boolean_map1, boolean_map2) boolean_map,
+ map_concat(tinyint_map1, tinyint_map2) tinyint_map,
+ map_concat(smallint_map1, smallint_map2) smallint_map,
+ map_concat(int_map1, int_map2) int_map,
+ map_concat(bigint_map1, bigint_map2) bigint_map,
+ map_concat(decimal_map1, decimal_map2) decimal_map,
+ map_concat(float_map1, float_map2) float_map,
+ map_concat(double_map1, double_map2) double_map,
+ map_concat(date_map1, date_map2) date_map,
+ map_concat(timestamp_map1, timestamp_map2) timestamp_map,
+ map_concat(string_map1, string_map2) string_map,
+ map_concat(array_map1, array_map2) array_map,
+ map_concat(struct_map1, struct_map2) struct_map,
+ map_concat(map_map1, map_map2) map_map,
+ map_concat(string_int_map1, string_int_map2) string_int_map,
+ map_concat(int_string_map1, int_string_map2) int_string_map
+FROM various_maps;
+
+-- Concatenate maps of different types
+SELECT
+ map_concat(tinyint_map1, smallint_map2) ts_map,
+ map_concat(smallint_map1, int_map2) si_map,
+ map_concat(int_map1, bigint_map2) ib_map,
+ map_concat(decimal_map1, float_map2) df_map,
+ map_concat(string_map1, date_map2) std_map,
+ map_concat(timestamp_map1, string_map2) tst_map,
+ map_concat(string_map1, int_map2) sti_map,
+ map_concat(int_string_map1, tinyint_map2) istt_map
+FROM various_maps;
+
+-- Concatenate map of incompatible types 1
+SELECT
+ map_concat(tinyint_map1, map_map2) tm_map
+FROM various_maps;
+
+-- Concatenate map of incompatible types 2
+SELECT
+ map_concat(boolean_map1, int_map2) bi_map
+FROM various_maps;
+
+-- Concatenate map of incompatible types 3
+SELECT
+ map_concat(int_map1, struct_map2) is_map
+FROM various_maps;
+
+-- Concatenate map of incompatible types 4
+SELECT
+ map_concat(map_map1, array_map2) ma_map
+FROM various_maps;
+
+-- Concatenate map of incompatible types 5
+SELECT
+ map_concat(map_map1, struct_map2) ms_map
+FROM various_maps;
diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
index 2b3288dc5a137..3d49323751a10 100644
--- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 28
+-- Number of queries: 30
-- !query 0
@@ -183,7 +183,7 @@ select from_json('{"a":1}', 1)
struct<>
-- !query 17 output
org.apache.spark.sql.AnalysisException
-Expected a string literal instead of 1;; line 1 pos 7
+Schema should be specified in DDL format as a string literal or output of the schema_of_json function instead of 1;; line 1 pos 7
-- !query 18
@@ -274,3 +274,19 @@ select from_json('{"a":1, "b":"2"}', 'struct')
struct>
-- !query 27 output
{"a":1,"b":"2"}
+
+
+-- !query 28
+select schema_of_json('{"c1":0, "c2":[1]}')
+-- !query 28 schema
+struct
+-- !query 28 output
+struct<c1:bigint,c2:array<bigint>>
+
+
+-- !query 29
+select from_json('{"c1":[1, 2, 3]}', schema_of_json('{"c1":[0]}'))
+-- !query 29 schema
+struct>>
+-- !query 29 output
+{"c1":[1,2,3]}
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out
new file mode 100644
index 0000000000000..d352b7284ae87
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out
@@ -0,0 +1,143 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 8
+
+
+-- !query 0
+CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES (
+ map(true, false), map(false, true),
+ map(1Y, 2Y), map(3Y, 4Y),
+ map(1S, 2S), map(3S, 4S),
+ map(4, 6), map(7, 8),
+ map(6L, 7L), map(8L, 9L),
+ map(9223372036854775809, 9223372036854775808), map(9223372036854775808, 9223372036854775809),
+ map(1.0D, 2.0D), map(3.0D, 4.0D),
+ map(float(1.0D), float(2.0D)), map(float(3.0D), float(4.0D)),
+ map(date '2016-03-14', date '2016-03-13'), map(date '2016-03-12', date '2016-03-11'),
+ map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'),
+ map(timestamp '2016-11-11 20:54:00.000', timestamp '2016-11-09 20:54:00.000'),
+ map('a', 'b'), map('c', 'd'),
+ map(array('a', 'b'), array('c', 'd')), map(array('e'), array('f')),
+ map(struct('a', 1), struct('b', 2)), map(struct('c', 3), struct('d', 4)),
+ map(map('a', 1), map('b', 2)), map(map('c', 3), map('d', 4)),
+ map('a', 1), map('c', 2),
+ map(1, 'a'), map(2, 'c')
+) AS various_maps (
+ boolean_map1, boolean_map2,
+ tinyint_map1, tinyint_map2,
+ smallint_map1, smallint_map2,
+ int_map1, int_map2,
+ bigint_map1, bigint_map2,
+ decimal_map1, decimal_map2,
+ double_map1, double_map2,
+ float_map1, float_map2,
+ date_map1, date_map2,
+ timestamp_map1,
+ timestamp_map2,
+ string_map1, string_map2,
+ array_map1, array_map2,
+ struct_map1, struct_map2,
+ map_map1, map_map2,
+ string_int_map1, string_int_map2,
+ int_string_map1, int_string_map2
+)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+SELECT
+ map_concat(boolean_map1, boolean_map2) boolean_map,
+ map_concat(tinyint_map1, tinyint_map2) tinyint_map,
+ map_concat(smallint_map1, smallint_map2) smallint_map,
+ map_concat(int_map1, int_map2) int_map,
+ map_concat(bigint_map1, bigint_map2) bigint_map,
+ map_concat(decimal_map1, decimal_map2) decimal_map,
+ map_concat(float_map1, float_map2) float_map,
+ map_concat(double_map1, double_map2) double_map,
+ map_concat(date_map1, date_map2) date_map,
+ map_concat(timestamp_map1, timestamp_map2) timestamp_map,
+ map_concat(string_map1, string_map2) string_map,
+ map_concat(array_map1, array_map2) array_map,
+ map_concat(struct_map1, struct_map2) struct_map,
+ map_concat(map_map1, map_map2) map_map,
+ map_concat(string_int_map1, string_int_map2) string_int_map,
+ map_concat(int_string_map1, int_string_map2) int_string_map
+FROM various_maps
+-- !query 1 schema
+struct,tinyint_map:map,smallint_map:map,int_map:map,bigint_map:map,decimal_map:map,float_map:map,double_map:map,date_map:map,timestamp_map:map,string_map:map,array_map:map,array>,struct_map:map,struct>,map_map:map