diff --git a/.github/workflows/master.yml b/.github/workflows/master.yml
index 5eb00c4aba0f9..30f3272c8b933 100644
--- a/.github/workflows/master.yml
+++ b/.github/workflows/master.yml
@@ -16,10 +16,15 @@ jobs:
matrix:
java: [ '1.8', '11' ]
hadoop: [ 'hadoop-2.7', 'hadoop-3.2' ]
+ hive: [ 'hive-1.2', 'hive-2.3' ]
exclude:
- java: '11'
hadoop: 'hadoop-2.7'
- name: Build Spark with JDK ${{ matrix.java }} and ${{ matrix.hadoop }}
+ - java: '11'
+ hive: 'hive-1.2'
+ - hadoop: 'hadoop-3.2'
+ hive: 'hive-1.2'
+ name: Build Spark - JDK${{ matrix.java }}/${{ matrix.hadoop }}/${{ matrix.hive }}
steps:
- uses: actions/checkout@master
@@ -36,6 +41,18 @@ jobs:
key: ${{ matrix.java }}-${{ matrix.hadoop }}-maven-org-${{ hashFiles('**/pom.xml') }}
restore-keys: |
${{ matrix.java }}-${{ matrix.hadoop }}-maven-org-
+ - uses: actions/cache@v1
+ with:
+ path: ~/.m2/repository/net
+ key: ${{ matrix.java }}-${{ matrix.hadoop }}-maven-net-${{ hashFiles('**/pom.xml') }}
+ restore-keys: |
+ ${{ matrix.java }}-${{ matrix.hadoop }}-maven-net-
+ - uses: actions/cache@v1
+ with:
+ path: ~/.m2/repository/io
+ key: ${{ matrix.java }}-${{ matrix.hadoop }}-maven-io-${{ hashFiles('**/pom.xml') }}
+ restore-keys: |
+ ${{ matrix.java }}-${{ matrix.hadoop }}-maven-io-
- name: Set up JDK ${{ matrix.java }}
uses: actions/setup-java@v1
with:
@@ -44,13 +61,13 @@ jobs:
run: |
export MAVEN_OPTS="-Xmx2g -XX:ReservedCodeCacheSize=1g -Dorg.slf4j.simpleLogger.defaultLogLevel=WARN"
export MAVEN_CLI_OPTS="--no-transfer-progress"
- ./build/mvn $MAVEN_CLI_OPTS -DskipTests -Pyarn -Pmesos -Pkubernetes -Phive -Phive-thriftserver -P${{ matrix.hadoop }} -Phadoop-cloud -Djava.version=${{ matrix.java }} install
+ ./build/mvn $MAVEN_CLI_OPTS -DskipTests -Pyarn -Pmesos -Pkubernetes -Phive -P${{ matrix.hive }} -Phive-thriftserver -P${{ matrix.hadoop }} -Phadoop-cloud -Djava.version=${{ matrix.java }} install
rm -rf ~/.m2/repository/org/apache/spark
lint:
runs-on: ubuntu-latest
- name: Linters
+ name: Linters (Java/Scala/Python), licenses, dependencies
steps:
- uses: actions/checkout@master
- uses: actions/setup-java@v1
@@ -72,3 +89,26 @@ jobs:
run: ./dev/check-license
- name: Dependencies
run: ./dev/test-dependencies.sh
+
+ lintr:
+ runs-on: ubuntu-latest
+ name: Linter (R)
+ steps:
+ - uses: actions/checkout@master
+ - uses: actions/setup-java@v1
+ with:
+ java-version: '11'
+ - name: install R
+ run: |
+ echo 'deb https://cloud.r-project.org/bin/linux/ubuntu bionic-cran35/' | sudo tee -a /etc/apt/sources.list
+ curl -sL "https://keyserver.ubuntu.com/pks/lookup?op=get&search=0xE298A3A825C0D65DFD57CBB651716619E084DAB9" | sudo apt-key add
+ sudo apt-get update
+ sudo apt-get install -y r-base r-base-dev libcurl4-openssl-dev
+ - name: install R packages
+ run: |
+ sudo Rscript -e "install.packages(c('curl', 'xml2', 'httr', 'devtools', 'testthat', 'knitr', 'rmarkdown', 'roxygen2', 'e1071', 'survival'), repos='https://cloud.r-project.org/')"
+ sudo Rscript -e "devtools::install_github('jimhester/lintr@v2.0.0')"
+ - name: package and install SparkR
+ run: ./R/install-dev.sh
+ - name: lint-r
+ run: ./dev/lint-r
diff --git a/R/pkg/.lintr b/R/pkg/.lintr
index c83ad2adfe0ef..67dc1218ea551 100644
--- a/R/pkg/.lintr
+++ b/R/pkg/.lintr
@@ -1,2 +1,2 @@
-linters: with_defaults(line_length_linter(100), multiple_dots_linter = NULL, object_name_linter = NULL, camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE))
+linters: with_defaults(line_length_linter(100), multiple_dots_linter = NULL, object_name_linter = NULL, camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE), object_usage_linter = NULL, cyclocomp_linter = NULL)
exclusions: list("inst/profile/general.R" = 1, "inst/profile/shell.R")
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 6f3c7c120ba3c..593d3ca16220d 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -2252,7 +2252,7 @@ setMethod("mutate",
# The last column of the same name in the specific columns takes effect
deDupCols <- list()
- for (i in 1:length(cols)) {
+ for (i in seq_len(length(cols))) {
deDupCols[[ns[[i]]]] <- alias(cols[[i]], ns[[i]])
}
@@ -2416,7 +2416,7 @@ setMethod("arrange",
# builds a list of columns of type Column
# example: [[1]] Column Species ASC
# [[2]] Column Petal_Length DESC
- jcols <- lapply(seq_len(length(decreasing)), function(i){
+ jcols <- lapply(seq_len(length(decreasing)), function(i) {
if (decreasing[[i]]) {
desc(getColumn(x, by[[i]]))
} else {
@@ -2749,7 +2749,7 @@ genAliasesForIntersectedCols <- function(x, intersectedColNames, suffix) {
col <- getColumn(x, colName)
if (colName %in% intersectedColNames) {
newJoin <- paste(colName, suffix, sep = "")
- if (newJoin %in% allColNames){
+ if (newJoin %in% allColNames) {
stop("The following column name: ", newJoin, " occurs more than once in the 'DataFrame'.",
"Please use different suffixes for the intersected columns.")
}
@@ -3475,7 +3475,7 @@ setMethod("str",
cat(paste0("'", class(object), "': ", length(names), " variables:\n"))
if (nrow(localDF) > 0) {
- for (i in 1 : ncol(localDF)) {
+ for (i in seq_len(ncol(localDF))) {
# Get the first elements for each column
firstElements <- if (types[i] == "character") {
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index f27ef4ee28f16..f48a334ed6766 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -166,9 +166,9 @@ writeToFileInArrow <- function(fileName, rdf, numPartitions) {
for (rdf_slice in rdf_slices) {
batch <- arrow::record_batch(rdf_slice)
if (is.null(stream_writer)) {
- stream <- arrow::FileOutputStream(fileName)
+ stream <- arrow::FileOutputStream$create(fileName)
schema <- batch$schema
- stream_writer <- arrow::RecordBatchStreamWriter(stream, schema)
+ stream_writer <- arrow::RecordBatchStreamWriter$create(stream, schema)
}
stream_writer$write_batch(batch)
@@ -197,7 +197,7 @@ getSchema <- function(schema, firstRow = NULL, rdd = NULL) {
as.list(schema)
}
if (is.null(names)) {
- names <- lapply(1:length(firstRow), function(x) {
+ names <- lapply(seq_len(length(firstRow)), function(x) {
paste0("_", as.character(x))
})
}
@@ -213,7 +213,7 @@ getSchema <- function(schema, firstRow = NULL, rdd = NULL) {
})
types <- lapply(firstRow, infer_type)
- fields <- lapply(1:length(firstRow), function(i) {
+ fields <- lapply(seq_len(length(firstRow)), function(i) {
structField(names[[i]], types[[i]], TRUE)
})
schema <- do.call(structType, fields)
diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R
index 93ba1307043a3..d96a287f818a2 100644
--- a/R/pkg/R/context.R
+++ b/R/pkg/R/context.R
@@ -416,7 +416,7 @@ spark.getSparkFiles <- function(fileName) {
#' @examples
#'\dontrun{
#' sparkR.session()
-#' doubled <- spark.lapply(1:10, function(x){2 * x})
+#' doubled <- spark.lapply(1:10, function(x) {2 * x})
#'}
#' @note spark.lapply since 2.0.0
spark.lapply <- function(list, func) {
diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R
index a6febb1cbd132..ca4a6e342d772 100644
--- a/R/pkg/R/deserialize.R
+++ b/R/pkg/R/deserialize.R
@@ -242,7 +242,7 @@ readDeserializeInArrow <- function(inputCon) {
# for now.
dataLen <- readInt(inputCon)
arrowData <- readBin(inputCon, raw(), as.integer(dataLen), endian = "big")
- batches <- arrow::RecordBatchStreamReader(arrowData)$batches()
+ batches <- arrow::RecordBatchStreamReader$create(arrowData)$batches()
if (useAsTibble) {
as_tibble <- get("as_tibble", envir = asNamespace("arrow"))
diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R
index 6e8f4dc3a7907..2b7995e1e37f6 100644
--- a/R/pkg/R/group.R
+++ b/R/pkg/R/group.R
@@ -162,7 +162,7 @@ methods <- c("avg", "max", "mean", "min", "sum")
#' @note pivot since 2.0.0
setMethod("pivot",
signature(x = "GroupedData", colname = "character"),
- function(x, colname, values = list()){
+ function(x, colname, values = list()) {
stopifnot(length(colname) == 1)
if (length(values) == 0) {
result <- callJMethod(x@sgd, "pivot", colname)
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index c3501977e64bc..a8c1ddb3dd20b 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -131,7 +131,7 @@ hashCode <- function(key) {
} else {
asciiVals <- sapply(charToRaw(key), function(x) { strtoi(x, 16L) })
hashC <- 0
- for (k in 1:length(asciiVals)) {
+ for (k in seq_len(length(asciiVals))) {
hashC <- mult31AndAdd(hashC, asciiVals[k])
}
as.integer(hashC)
@@ -543,10 +543,14 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F,
ifnotfound = list(list(NULL)))[[1]]
found <- sapply(funcList, function(func) {
- ifelse(identical(func, obj), TRUE, FALSE)
+ ifelse(
+ identical(func, obj) &&
+ # Also check if the parent environment is identical to current parent
+ identical(parent.env(environment(func)), func.env),
+ TRUE, FALSE)
})
if (sum(found) > 0) {
- # If function has been examined, ignore.
+          # If the function (with an identical parent environment) has been examined, ignore.
break
}
# Function has not been examined, record it and recursively clean its closure.
@@ -724,7 +728,7 @@ assignNewEnv <- function(data) {
stopifnot(length(cols) > 0)
env <- new.env()
- for (i in 1:length(cols)) {
+ for (i in seq_len(length(cols))) {
assign(x = cols[i], value = data[, cols[i], drop = F], envir = env)
}
env
@@ -750,7 +754,7 @@ launchScript <- function(script, combinedArgs, wait = FALSE, stdout = "", stderr
if (.Platform$OS.type == "windows") {
scriptWithArgs <- paste(script, combinedArgs, sep = " ")
# on Windows, intern = F seems to mean output to the console. (documentation on this is missing)
- shell(scriptWithArgs, translate = TRUE, wait = wait, intern = wait) # nolint
+ shell(scriptWithArgs, translate = TRUE, wait = wait, intern = wait)
} else {
# http://stat.ethz.ch/R-manual/R-devel/library/base/html/system2.html
# stdout = F means discard output
diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R
index dfe69b7f4f1fb..1ef05ea621e83 100644
--- a/R/pkg/inst/worker/worker.R
+++ b/R/pkg/inst/worker/worker.R
@@ -194,7 +194,7 @@ if (isEmpty != 0) {
} else {
# gapply mode
outputs <- list()
- for (i in 1:length(data)) {
+ for (i in seq_len(length(data))) {
# Timing reading input data for execution
inputElap <- elapsedSecs()
output <- compute(mode, partition, serializer, deserializer, keys[[i]],
diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
index c2b2458ec064b..cb47353d600db 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -172,7 +172,7 @@ test_that("structField type strings", {
typeList <- c(primitiveTypes, complexTypes)
typeStrings <- names(typeList)
- for (i in seq_along(typeStrings)){
+ for (i in seq_along(typeStrings)) {
typeString <- typeStrings[i]
expected <- typeList[[i]]
testField <- structField("_col", typeString)
@@ -203,7 +203,7 @@ test_that("structField type strings", {
errorList <- c(primitiveErrors, complexErrors)
typeStrings <- names(errorList)
- for (i in seq_along(typeStrings)){
+ for (i in seq_along(typeStrings)) {
typeString <- typeStrings[i]
expected <- paste0("Unsupported type for SparkDataframe: ", errorList[[i]])
expect_error(structField("_col", typeString), expected)
diff --git a/R/pkg/tests/fulltests/test_utils.R b/R/pkg/tests/fulltests/test_utils.R
index b2b6f34aaa085..c4fcbecee18e9 100644
--- a/R/pkg/tests/fulltests/test_utils.R
+++ b/R/pkg/tests/fulltests/test_utils.R
@@ -110,6 +110,15 @@ test_that("cleanClosure on R functions", {
actual <- get("y", envir = env, inherits = FALSE)
expect_equal(actual, y)
+  # Test for a combination of nested and sequential functions in a closure
+ f1 <- function(x) x + 1
+ f2 <- function(x) f1(x) + 2
+ userFunc <- function(x) { f1(x); f2(x) }
+ cUserFuncEnv <- environment(cleanClosure(userFunc))
+ expect_equal(length(cUserFuncEnv), 2)
+ innerCUserFuncEnv <- environment(cUserFuncEnv$f2)
+ expect_equal(length(innerCUserFuncEnv), 1)
+
# Test for function (and variable) definitions.
f <- function(x) {
g <- function(y) { y * 2 }
diff --git a/R/run-tests.sh b/R/run-tests.sh
index 86bd8aad5f113..51ca7d600caf0 100755
--- a/R/run-tests.sh
+++ b/R/run-tests.sh
@@ -23,7 +23,7 @@ FAILED=0
LOGFILE=$FWDIR/unit-tests.out
rm -f $LOGFILE
-SPARK_TESTING=1 NOT_CRAN=true $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE
+SPARK_TESTING=1 NOT_CRAN=true $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" --conf spark.driver.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true" --conf spark.executor.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE
FAILED=$((PIPESTATUS[0]||$FAILED))
NUM_TEST_WARNING="$(grep -c -e 'Warnings ----------------' $LOGFILE)"
diff --git a/appveyor.yml b/appveyor.yml
index b36175a787ae9..325fd67abc674 100644
--- a/appveyor.yml
+++ b/appveyor.yml
@@ -42,10 +42,7 @@ install:
# Install maven and dependencies
- ps: .\dev\appveyor-install-dependencies.ps1
# Required package for R unit tests
- - cmd: R -e "install.packages(c('knitr', 'rmarkdown', 'e1071', 'survival'), repos='https://cloud.r-project.org/')"
- # Use Arrow R 0.14.1 for now. 0.15.0 seems not working for now. See SPARK-29378.
- - cmd: R -e "install.packages(c('assertthat', 'bit64', 'fs', 'purrr', 'R6', 'tidyselect'), repos='https://cloud.r-project.org/')"
- - cmd: R -e "install.packages('https://cran.r-project.org/src/contrib/Archive/arrow/arrow_0.14.1.tar.gz', repos=NULL, type='source')"
+ - cmd: R -e "install.packages(c('knitr', 'rmarkdown', 'e1071', 'survival', 'arrow'), repos='https://cloud.r-project.org/')"
# Here, we use the fixed version of testthat. For more details, please see SPARK-22817.
# As of devtools 2.1.0, it requires testthat higher then 2.1.1 as a dependency. SparkR test requires testthat 1.0.2.
# Therefore, we don't use devtools but installs it directly from the archive including its dependencies.
@@ -56,7 +53,7 @@ install:
build_script:
# '-Djna.nosys=true' is required to avoid kernel32.dll load failure.
# See SPARK-28759.
- - cmd: mvn -DskipTests -Psparkr -Phive -Djna.nosys=true package
+ - cmd: mvn -DskipTests -Psparkr -Phive -Phive-1.2 -Djna.nosys=true package
environment:
NOT_CRAN: true
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java
index 8995bbc940f63..36ca73f6ac0f0 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java
@@ -90,7 +90,8 @@ CryptoOutputStream createOutputStream(WritableByteChannel ch) throws IOException
return new CryptoOutputStream(cipher, conf, ch, key, new IvParameterSpec(outIv));
}
- private CryptoInputStream createInputStream(ReadableByteChannel ch) throws IOException {
+ @VisibleForTesting
+ CryptoInputStream createInputStream(ReadableByteChannel ch) throws IOException {
return new CryptoInputStream(cipher, conf, ch, key, new IvParameterSpec(inIv));
}
@@ -166,34 +167,45 @@ private static class DecryptionHandler extends ChannelInboundHandlerAdapter {
@Override
public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception {
- if (!isCipherValid) {
- throw new IOException("Cipher is in invalid state.");
- }
- byteChannel.feedData((ByteBuf) data);
-
- byte[] decryptedData = new byte[byteChannel.readableBytes()];
- int offset = 0;
- while (offset < decryptedData.length) {
- // SPARK-25535: workaround for CRYPTO-141.
- try {
- offset += cis.read(decryptedData, offset, decryptedData.length - offset);
- } catch (InternalError ie) {
- isCipherValid = false;
- throw ie;
+ ByteBuf buffer = (ByteBuf) data;
+
+ try {
+ if (!isCipherValid) {
+ throw new IOException("Cipher is in invalid state.");
+ }
+ byte[] decryptedData = new byte[buffer.readableBytes()];
+ byteChannel.feedData(buffer);
+
+ int offset = 0;
+ while (offset < decryptedData.length) {
+ // SPARK-25535: workaround for CRYPTO-141.
+ try {
+ offset += cis.read(decryptedData, offset, decryptedData.length - offset);
+ } catch (InternalError ie) {
+ isCipherValid = false;
+ throw ie;
+ }
}
- }
- ctx.fireChannelRead(Unpooled.wrappedBuffer(decryptedData, 0, decryptedData.length));
+ ctx.fireChannelRead(Unpooled.wrappedBuffer(decryptedData, 0, decryptedData.length));
+ } finally {
+ buffer.release();
+ }
}
@Override
- public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+ public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
+ // We do the closing of the stream / channel in handlerRemoved(...) as
+ // this method will be called in all cases:
+ //
+ // - when the Channel becomes inactive
+ // - when the handler is removed from the ChannelPipeline
try {
if (isCipherValid) {
cis.close();
}
} finally {
- super.channelInactive(ctx);
+ super.handlerRemoved(ctx);
}
}
}
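
The `DecryptionHandler` change above follows a general Netty pattern: release the inbound `ByteBuf` in a `finally` block on every read, and move resource cleanup from `channelInactive()` to `handlerRemoved()`, which fires both when the channel becomes inactive and when the handler is removed from the pipeline. Below is a minimal sketch of that pattern, assuming Netty 4.x; the class and field names are illustrative and this is not Spark code.

```java
import java.io.Closeable;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;

// Illustrative handler: it owns the inbound ByteBuf, so it releases it on every
// code path, and it closes its resource in handlerRemoved(), which Netty calls
// both when the channel becomes inactive and when the handler leaves the pipeline.
class ReleasingInboundHandler extends ChannelInboundHandlerAdapter {
  private final Closeable resource;

  ReleasingInboundHandler(Closeable resource) {
    this.resource = resource;
  }

  @Override
  public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
    ByteBuf buf = (ByteBuf) msg;
    try {
      byte[] copy = new byte[buf.readableBytes()];
      buf.readBytes(copy);                        // may throw; buf is still released below
      ctx.fireChannelRead(Unpooled.wrappedBuffer(copy));
    } finally {
      buf.release();                              // avoid a buffer leak on error paths
    }
  }

  @Override
  public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
    try {
      resource.close();                           // cleanup happens exactly once, in all cases
    } finally {
      super.handlerRemoved(ctx);
    }
  }
}
```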
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayReadableChannel.java b/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayReadableChannel.java
index 25d103d0e316f..fe461d0b39862 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayReadableChannel.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayReadableChannel.java
@@ -19,23 +19,27 @@
import java.io.IOException;
import java.nio.ByteBuffer;
+import java.nio.channels.ClosedChannelException;
import java.nio.channels.ReadableByteChannel;
import io.netty.buffer.ByteBuf;
public class ByteArrayReadableChannel implements ReadableByteChannel {
private ByteBuf data;
+ private boolean closed;
- public int readableBytes() {
- return data.readableBytes();
- }
-
- public void feedData(ByteBuf buf) {
+ public void feedData(ByteBuf buf) throws ClosedChannelException {
+ if (closed) {
+ throw new ClosedChannelException();
+ }
data = buf;
}
@Override
public int read(ByteBuffer dst) throws IOException {
+ if (closed) {
+ throw new ClosedChannelException();
+ }
int totalRead = 0;
while (data.readableBytes() > 0 && dst.remaining() > 0) {
int bytesToRead = Math.min(data.readableBytes(), dst.remaining());
@@ -43,20 +47,16 @@ public int read(ByteBuffer dst) throws IOException {
totalRead += bytesToRead;
}
- if (data.readableBytes() == 0) {
- data.release();
- }
-
return totalRead;
}
@Override
- public void close() throws IOException {
+ public void close() {
+ closed = true;
}
@Override
public boolean isOpen() {
- return true;
+ return !closed;
}
-
}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
index 1980361a15523..cef0e415aa40a 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
@@ -184,8 +184,12 @@ private ByteBuf decodeNext() {
return null;
}
- // Reset buf and size for next frame.
+ return consumeCurrentFrameBuf();
+ }
+
+ private ByteBuf consumeCurrentFrameBuf() {
ByteBuf frame = frameBuf;
+ // Reset buf and size for next frame.
frameBuf = null;
consolidatedFrameBufSize = 0;
consolidatedNumComponents = 0;
@@ -215,13 +219,9 @@ private ByteBuf nextBufferForFrame(int bytesToRead) {
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
- for (ByteBuf b : buffers) {
- b.release();
- }
if (interceptor != null) {
interceptor.channelInactive();
}
- frameLenBuf.release();
super.channelInactive(ctx);
}
@@ -233,6 +233,24 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E
super.exceptionCaught(ctx, cause);
}
+ @Override
+ public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
+ // Release all buffers that are still in our ownership.
+ // Doing this in handlerRemoved(...) guarantees that this will happen in all cases:
+ // - When the Channel becomes inactive
+ // - When the decoder is removed from the ChannelPipeline
+ for (ByteBuf b : buffers) {
+ b.release();
+ }
+ buffers.clear();
+ frameLenBuf.release();
+ ByteBuf frame = consumeCurrentFrameBuf();
+ if (frame != null) {
+ frame.release();
+ }
+ super.handlerRemoved(ctx);
+ }
+
public void setInterceptor(Interceptor interceptor) {
Preconditions.checkState(this.interceptor == null, "Already have an interceptor.");
this.interceptor = interceptor;
diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java
new file mode 100644
index 0000000000000..6b2186f73cd0c
--- /dev/null
+++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.network.crypto;
+
+import javax.crypto.spec.SecretKeySpec;
+import java.io.IOException;
+import java.nio.channels.ReadableByteChannel;
+import java.nio.channels.WritableByteChannel;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.embedded.EmbeddedChannel;
+import org.apache.commons.crypto.stream.CryptoInputStream;
+import org.apache.commons.crypto.stream.CryptoOutputStream;
+import org.apache.spark.network.util.MapConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+import org.hamcrest.CoreMatchers;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.fail;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class TransportCipherSuite {
+
+ @Test
+ public void testBufferNotLeaksOnInternalError() throws IOException {
+ String algorithm = "TestAlgorithm";
+ TransportConf conf = new TransportConf("Test", MapConfigProvider.EMPTY);
+ TransportCipher cipher = new TransportCipher(conf.cryptoConf(), conf.cipherTransformation(),
+ new SecretKeySpec(new byte[256], algorithm), new byte[0], new byte[0]) {
+
+ @Override
+ CryptoOutputStream createOutputStream(WritableByteChannel ch) {
+ return null;
+ }
+
+ @Override
+ CryptoInputStream createInputStream(ReadableByteChannel ch) throws IOException {
+ CryptoInputStream mockInputStream = mock(CryptoInputStream.class);
+ when(mockInputStream.read(any(byte[].class), anyInt(), anyInt()))
+ .thenThrow(new InternalError());
+ return mockInputStream;
+ }
+ };
+
+ EmbeddedChannel channel = new EmbeddedChannel();
+ cipher.addToChannel(channel);
+
+ ByteBuf buffer = Unpooled.wrappedBuffer(new byte[] { 1, 2 });
+ ByteBuf buffer2 = Unpooled.wrappedBuffer(new byte[] { 1, 2 });
+
+ try {
+ channel.writeInbound(buffer);
+ fail("Should have raised InternalError");
+ } catch (InternalError expected) {
+ // expected
+ assertEquals(0, buffer.refCnt());
+ }
+
+ try {
+ channel.writeInbound(buffer2);
+ fail("Should have raised an exception");
+ } catch (Throwable expected) {
+ assertThat(expected, CoreMatchers.instanceOf(IOException.class));
+ assertEquals(0, buffer2.refCnt());
+ }
+
+ // Simulate closing the connection
+ assertFalse(channel.finish());
+ }
+}
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java
index 457805feeac45..fb67d7220a0b4 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java
@@ -28,6 +28,7 @@
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
import org.apache.spark.network.util.JavaUtils;
+import org.junit.Assert;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -122,7 +123,7 @@ private void insertFile(String filename) throws IOException {
private void insertFile(String filename, byte[] block) throws IOException {
OutputStream dataStream = null;
File file = ExecutorDiskUtils.getFile(localDirs, subDirsPerLocalDir, filename);
- assert(!file.exists()) : "this test file has been already generated";
+ Assert.assertFalse("this test file has been already generated", file.exists());
try {
dataStream = new FileOutputStream(
ExecutorDiskUtils.getFile(localDirs, subDirsPerLocalDir, filename));
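
The change above swaps a plain Java `assert` for JUnit's `Assert.assertFalse`. The distinction matters because `assert` statements are skipped unless the JVM runs with `-ea`, so the precondition could silently pass. A small illustration (hypothetical helper, not Spark code):

```java
import java.io.File;

import org.junit.Assert;

class PreconditionExample {
  static void requireAbsent(File file) {
    // Skipped entirely unless the JVM is started with -ea/-enableassertions:
    assert !file.exists() : "this test file has been already generated";

    // Always evaluated; fails the surrounding JUnit test if the file exists:
    Assert.assertFalse("this test file has been already generated", file.exists());
  }
}
```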
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index d7a498d1c1c2f..deecd4f015824 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -1063,7 +1063,7 @@ public static class IntWrapper implements Serializable {
}
/**
- * Parses this UTF8String to long.
+   * Parses this UTF8String (trimmed if needed) to long.
*
* Note that, in this method we accumulate the result in negative format, and convert it to
* positive format at the end, if this string is not started with '-'. This is because min value
@@ -1077,18 +1077,20 @@ public static class IntWrapper implements Serializable {
* @return true if the parsing was successful else false
*/
public boolean toLong(LongWrapper toLongResult) {
- if (numBytes == 0) {
- return false;
- }
+ int offset = 0;
+ while (offset < this.numBytes && getByte(offset) <= ' ') offset++;
+ if (offset == this.numBytes) return false;
- byte b = getByte(0);
+ int end = this.numBytes - 1;
+ while (end > offset && getByte(end) <= ' ') end--;
+
+ byte b = getByte(offset);
final boolean negative = b == '-';
- int offset = 0;
if (negative || b == '+') {
- offset++;
- if (numBytes == 1) {
+ if (end - offset == 0) {
return false;
}
+ offset++;
}
final byte separator = '.';
@@ -1096,7 +1098,7 @@ public boolean toLong(LongWrapper toLongResult) {
final long stopValue = Long.MIN_VALUE / radix;
long result = 0;
- while (offset < numBytes) {
+ while (offset <= end) {
b = getByte(offset);
offset++;
if (b == separator) {
@@ -1131,7 +1133,7 @@ public boolean toLong(LongWrapper toLongResult) {
// This is the case when we've encountered a decimal separator. The fractional
// part will not change the number, but we will verify that the fractional part
// is well formed.
- while (offset < numBytes) {
+ while (offset <= end) {
byte currentByte = getByte(offset);
if (currentByte < '0' || currentByte > '9') {
return false;
@@ -1151,7 +1153,7 @@ public boolean toLong(LongWrapper toLongResult) {
}
/**
- * Parses this UTF8String to int.
+   * Parses this UTF8String (trimmed if needed) to int.
*
* Note that, in this method we accumulate the result in negative format, and convert it to
* positive format at the end, if this string is not started with '-'. This is because min value
@@ -1168,18 +1170,20 @@ public boolean toLong(LongWrapper toLongResult) {
* @return true if the parsing was successful else false
*/
public boolean toInt(IntWrapper intWrapper) {
- if (numBytes == 0) {
- return false;
- }
+ int offset = 0;
+ while (offset < this.numBytes && getByte(offset) <= ' ') offset++;
+ if (offset == this.numBytes) return false;
- byte b = getByte(0);
+ int end = this.numBytes - 1;
+ while (end > offset && getByte(end) <= ' ') end--;
+
+ byte b = getByte(offset);
final boolean negative = b == '-';
- int offset = 0;
if (negative || b == '+') {
- offset++;
- if (numBytes == 1) {
+ if (end - offset == 0) {
return false;
}
+ offset++;
}
final byte separator = '.';
@@ -1187,7 +1191,7 @@ public boolean toInt(IntWrapper intWrapper) {
final int stopValue = Integer.MIN_VALUE / radix;
int result = 0;
- while (offset < numBytes) {
+ while (offset <= end) {
b = getByte(offset);
offset++;
if (b == separator) {
@@ -1222,7 +1226,7 @@ public boolean toInt(IntWrapper intWrapper) {
// This is the case when we've encountered a decimal separator. The fractional
// part will not change the number, but we will verify that the fractional part
// is well formed.
- while (offset < numBytes) {
+ while (offset <= end) {
byte currentByte = getByte(offset);
if (currentByte < '0' || currentByte > '9') {
return false;
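
The `toLong`/`toInt` changes above skip leading and trailing bytes that are `<= ' '` before accumulating digits, so inputs such as `' 1 '` now parse successfully. A simplified sketch of that trim-then-parse logic over a raw byte array (illustrative only; the real implementation accumulates in negative form to handle `Long.MIN_VALUE` and also accepts a trailing fractional part):

```java
// Simplified sketch of the trim-then-parse logic; not the UTF8String code.
// Returns null on malformed input and omits overflow and '.'-fraction handling.
class TrimmedLongParse {
  static Long parse(byte[] bytes) {
    int offset = 0;
    int end = bytes.length - 1;
    while (offset <= end && bytes[offset] <= ' ') offset++;  // skip leading whitespace/control bytes
    while (end > offset && bytes[end] <= ' ') end--;         // skip trailing whitespace/control bytes
    if (offset > end) return null;                           // empty or all-whitespace input

    boolean negative = bytes[offset] == '-';
    if (negative || bytes[offset] == '+') {
      if (offset == end) return null;                        // a bare sign is not a number
      offset++;
    }

    long result = 0;
    for (int i = offset; i <= end; i++) {
      byte b = bytes[i];
      if (b < '0' || b > '9') return null;
      result = result * 10 + (b - '0');
    }
    return negative ? -result : result;
  }

  public static void main(String[] args) {
    System.out.println(parse("  123  ".getBytes()));  // 123
    System.out.println(parse("   ".getBytes()));      // null
  }
}
```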
diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java
index 6397f26c02f3a..01bf7eb2438ad 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java
@@ -46,36 +46,6 @@ public void equalsTest() {
assertEquals(i1, i6);
}
- @Test
- public void toStringTest() {
- CalendarInterval i;
-
- i = new CalendarInterval(0, 0, 0);
- assertEquals("0 seconds", i.toString());
-
- i = new CalendarInterval(34, 0, 0);
- assertEquals("2 years 10 months", i.toString());
-
- i = new CalendarInterval(-34, 0, 0);
- assertEquals("-2 years -10 months", i.toString());
-
- i = new CalendarInterval(0, 31, 0);
- assertEquals("31 days", i.toString());
-
- i = new CalendarInterval(0, -31, 0);
- assertEquals("-31 days", i.toString());
-
- i = new CalendarInterval(0, 0, 3 * MICROS_PER_HOUR + 13 * MICROS_PER_MINUTE + 123);
- assertEquals("3 hours 13 minutes 0.000123 seconds", i.toString());
-
- i = new CalendarInterval(0, 0, -3 * MICROS_PER_HOUR - 13 * MICROS_PER_MINUTE - 123);
- assertEquals("-3 hours -13 minutes -0.000123 seconds", i.toString());
-
- i = new CalendarInterval(34, 31, 3 * MICROS_PER_HOUR + 13 * MICROS_PER_MINUTE + 123);
- assertEquals("2 years 10 months 31 days 3 hours 13 minutes 0.000123 seconds",
- i.toString());
- }
-
@Test
public void periodAndDurationTest() {
CalendarInterval interval = new CalendarInterval(120, -40, 123456);
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index 1a9453a8b3e80..e14964d68119b 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -205,6 +205,10 @@ public long getSortTimeNanos() {
}
public long getMemoryUsage() {
+ if (array == null) {
+ return 0L;
+ }
+
return array.size() * 8;
}
diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js
index 705a08f0293d3..b2cd616791734 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js
@@ -83,8 +83,9 @@ function drawApplicationTimeline(groupArray, eventObjArray, startTime, offset) {
});
}
-$(function (){
- if (window.localStorage.getItem("expand-application-timeline") == "true") {
+$(function () {
+ if ($("span.expand-application-timeline").length &&
+ window.localStorage.getItem("expand-application-timeline") == "true") {
// Set it to false so that the click function can revert it
window.localStorage.setItem("expand-application-timeline", "false");
$("span.expand-application-timeline").trigger('click');
@@ -159,8 +160,9 @@ function drawJobTimeline(groupArray, eventObjArray, startTime, offset) {
});
}
-$(function (){
- if (window.localStorage.getItem("expand-job-timeline") == "true") {
+$(function () {
+ if ($("span.expand-job-timeline").length &&
+ window.localStorage.getItem("expand-job-timeline") == "true") {
// Set it to false so that the click function can revert it
window.localStorage.setItem("expand-job-timeline", "false");
$("span.expand-job-timeline").trigger('click');
@@ -226,8 +228,9 @@ function drawTaskAssignmentTimeline(groupArray, eventObjArray, minLaunchTime, ma
});
}
-$(function (){
- if (window.localStorage.getItem("expand-task-assignment-timeline") == "true") {
+$(function () {
+ if ($("span.expand-task-assignment-timeline").length &&
+ window.localStorage.getItem("expand-task-assignment-timeline") == "true") {
// Set it to false so that the click function can revert it
window.localStorage.setItem("expand-task-assignment-timeline", "false");
$("span.expand-task-assignment-timeline").trigger('click');
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 873efa76468ed..3c6c181f9428c 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -701,8 +701,7 @@ private[spark] class MapOutputTrackerMaster(
if (shuffleStatus != null) {
shuffleStatus.withMapStatuses { statuses =>
if (mapId >= 0 && mapId < statuses.length) {
- Seq( ExecutorCacheTaskLocation(statuses(mapId).location.host,
- statuses(mapId).location.executorId).toString)
+ Seq(statuses(mapId).location.host)
} else {
Nil
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 1926a5268227c..df236ba8926c1 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -212,8 +212,13 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
try {
daemonPort = in.readInt()
} catch {
+ case _: EOFException if daemon.isAlive =>
+ throw new SparkException("EOFException occurred while reading the port number " +
+ s"from $daemonModule's stdout")
case _: EOFException =>
- throw new SparkException(s"No port number in $daemonModule's stdout")
+ throw new SparkException(
+ s"EOFException occurred while reading the port number from $daemonModule's" +
+ s" stdout and terminated with code: ${daemon.exitValue}.")
}
// test that the returned port number is within a valid range.
diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala
index 76537afd81ce0..c6eb461ad601c 100644
--- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala
@@ -39,15 +39,20 @@ private[ui] class EnvironmentPage(
"Scala Version" -> appEnv.runtime.scalaVersion)
val runtimeInformationTable = UIUtils.listingTable(
- propertyHeader, jvmRow, jvmInformation.toSeq.sorted, fixedWidth = true)
+ propertyHeader, jvmRow, jvmInformation.toSeq.sorted, fixedWidth = true,
+ headerClasses = headerClasses)
val sparkPropertiesTable = UIUtils.listingTable(propertyHeader, propertyRow,
- Utils.redact(conf, appEnv.sparkProperties.sorted), fixedWidth = true)
+ Utils.redact(conf, appEnv.sparkProperties.sorted), fixedWidth = true,
+ headerClasses = headerClasses)
val hadoopPropertiesTable = UIUtils.listingTable(propertyHeader, propertyRow,
- Utils.redact(conf, appEnv.hadoopProperties.sorted), fixedWidth = true)
+ Utils.redact(conf, appEnv.hadoopProperties.sorted), fixedWidth = true,
+ headerClasses = headerClasses)
val systemPropertiesTable = UIUtils.listingTable(propertyHeader, propertyRow,
- Utils.redact(conf, appEnv.systemProperties.sorted), fixedWidth = true)
+ Utils.redact(conf, appEnv.systemProperties.sorted), fixedWidth = true,
+ headerClasses = headerClasses)
val classpathEntriesTable = UIUtils.listingTable(
- classPathHeaders, classPathRow, appEnv.classpathEntries.sorted, fixedWidth = true)
+ classPathHeader, classPathRow, appEnv.classpathEntries.sorted, fixedWidth = true,
+ headerClasses = headerClasses)
val content =
  private def jvmRow(kv: (String, String)) = <tr><td>{kv._1}</td><td>{kv._2}</td></tr>
  private def propertyRow(kv: (String, String)) = <tr><td>{kv._1}</td><td>{kv._2}</td></tr>
  private def classPathRow(data: (String, String)) = <tr><td>{data._1}</td><td>{data._2}</td></tr>
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
index a13037b5e24db..77564f48015f1 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
@@ -89,7 +89,12 @@ private[ui] class ExecutorThreadDumpPage(
Thread ID |
Thread Name |
Thread State |
- Thread Locks |
+
+
+ Thread Locks
+
+ |
{dumpRows}
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 8173a8e545ebb..3f309819065be 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark
+import org.scalatest.Assertions._
import org.scalatest.Matchers
import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits}
import org.scalatest.time.{Millis, Span}
diff --git a/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala b/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala
index 5dbef88e73a9e..78f1246295bf8 100644
--- a/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala
+++ b/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala
@@ -17,6 +17,8 @@
package org.apache.spark
+import org.scalatest.Assertions._
+
import org.apache.spark.benchmark.Benchmark
import org.apache.spark.benchmark.BenchmarkBase
import org.apache.spark.scheduler.CompressedMapStatus
diff --git a/core/src/test/scala/org/apache/spark/benchmark/Benchmark.scala b/core/src/test/scala/org/apache/spark/benchmark/Benchmark.scala
index 022fcbb25b0af..9629f5ab1a3dd 100644
--- a/core/src/test/scala/org/apache/spark/benchmark/Benchmark.scala
+++ b/core/src/test/scala/org/apache/spark/benchmark/Benchmark.scala
@@ -26,6 +26,7 @@ import scala.util.Try
import org.apache.commons.io.output.TeeOutputStream
import org.apache.commons.lang3.SystemUtils
+import org.scalatest.Assertions._
import org.apache.spark.util.Utils
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index 2115ee8b1b723..7272a98c9770b 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -33,6 +33,7 @@ import org.mockito.ArgumentMatchers.{any, eq => meq}
import org.mockito.Mockito.{inOrder, verify, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
+import org.scalatest.Assertions._
import org.scalatest.PrivateMethodTester
import org.scalatest.concurrent.Eventually
import org.scalatestplus.mockito.MockitoSugar
diff --git a/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala b/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala
index 5e8da3e205ab0..7c65f3b126e3d 100644
--- a/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala
@@ -20,6 +20,7 @@ package org.apache.spark.rpc
import scala.collection.mutable.ArrayBuffer
import org.scalactic.TripleEquals
+import org.scalatest.Assertions._
class TestRpcEndpoint extends ThreadSafeRpcEndpoint with TripleEquals {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala
index 4f737c9499ad6..dff8975a4fe49 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala
@@ -26,6 +26,7 @@ import scala.concurrent.duration.{Duration, SECONDS}
import scala.reflect.ClassTag
import org.scalactic.TripleEquals
+import org.scalatest.Assertions
import org.scalatest.Assertions.AssertionsHelper
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._
@@ -463,7 +464,7 @@ class MockRDD(
override def toString: String = "MockRDD " + id
}
-object MockRDD extends AssertionsHelper with TripleEquals {
+object MockRDD extends AssertionsHelper with TripleEquals with Assertions {
/**
* make sure all the shuffle dependencies have a consistent number of output partitions
* (mostly to make sure the test setup makes sense, not that Spark itself would get this wrong)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
index 8439be955c738..406bd9244870e 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
@@ -29,6 +29,7 @@ import com.google.common.util.concurrent.MoreExecutors
import org.mockito.ArgumentCaptor
import org.mockito.ArgumentMatchers.{any, anyLong}
import org.mockito.Mockito.{spy, times, verify}
+import org.scalatest.Assertions._
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually._
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 89df5de97c444..34bcae8abd512 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -26,6 +26,7 @@ import org.apache.hadoop.fs.FileAlreadyExistsException
import org.mockito.ArgumentMatchers.{any, anyBoolean, anyInt, anyString}
import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
+import org.scalatest.Assertions._
import org.apache.spark._
import org.apache.spark.internal.Logging
@@ -128,7 +129,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex
def removeExecutor(execId: String): Unit = {
executors -= execId
val host = executorIdToHost.get(execId)
- assert(host != None)
+ assert(host.isDefined)
val hostId = host.get
val executorsOnHost = hostToExecutors(hostId)
executorsOnHost -= execId
diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala
index d8657ecdff676..3d70ff1fed29f 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala
@@ -21,6 +21,7 @@ import java.util.{Map => JMap}
import java.util.concurrent.atomic.AtomicBoolean
import com.google.common.collect.ImmutableMap
+import org.scalatest.Assertions._
import org.scalatest.BeforeAndAfterEach
import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite}
diff --git a/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala b/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala
index 06c2ceb68bd79..f14ec175232be 100644
--- a/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.util
import org.apache.hadoop.fs.Path
+import org.scalatest.Assertions._
import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite}
import org.apache.spark.rdd.RDD
diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh
index f1069d4490b5a..99c4b20102929 100755
--- a/dev/create-release/release-build.sh
+++ b/dev/create-release/release-build.sh
@@ -138,7 +138,8 @@ fi
# Hive-specific profiles for some builds
HIVE_PROFILES="-Phive -Phive-thriftserver"
# Profiles for publishing snapshots and release to Maven Central
-PUBLISH_PROFILES="$BASE_PROFILES $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl"
+# We use Apache Hive 2.3 for publishing
+PUBLISH_PROFILES="$BASE_PROFILES $HIVE_PROFILES -Phive-2.3 -Pspark-ganglia-lgpl -Pkinesis-asl"
# Profiles for building binary releases
BASE_RELEASE_PROFILES="$BASE_PROFILES -Psparkr"
diff --git a/dev/lint-r b/dev/lint-r
index bfda0bca15eb7..b08f5efecd5d3 100755
--- a/dev/lint-r
+++ b/dev/lint-r
@@ -17,6 +17,9 @@
# limitations under the License.
#
+set -o pipefail
+set -e
+
SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )"
SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)"
LINT_R_REPORT_FILE_NAME="$SPARK_ROOT_DIR/dev/lint-r-report.log"
@@ -24,7 +27,7 @@ LINT_R_REPORT_FILE_NAME="$SPARK_ROOT_DIR/dev/lint-r-report.log"
if ! type "Rscript" > /dev/null; then
echo "ERROR: You should install R"
- exit
+ exit 1
fi
`which Rscript` --vanilla "$SPARK_ROOT_DIR/dev/lint-r.R" "$SPARK_ROOT_DIR" | tee "$LINT_R_REPORT_FILE_NAME"
diff --git a/dev/lint-r.R b/dev/lint-r.R
index a4261d266bbc0..7e165319e316a 100644
--- a/dev/lint-r.R
+++ b/dev/lint-r.R
@@ -27,7 +27,7 @@ if (! library(SparkR, lib.loc = LOCAL_LIB_LOC, logical.return = TRUE)) {
# Installs lintr from Github in a local directory.
# NOTE: The CRAN's version is too old to adapt to our rules.
if ("lintr" %in% row.names(installed.packages()) == FALSE) {
- devtools::install_github("jimhester/lintr@5431140")
+ devtools::install_github("jimhester/lintr@v2.0.0")
}
library(lintr)
diff --git a/dev/run-tests.py b/dev/run-tests.py
index 82277720bb52f..2d52ead06a041 100755
--- a/dev/run-tests.py
+++ b/dev/run-tests.py
@@ -43,15 +43,20 @@ def determine_modules_for_files(filenames):
"""
Given a list of filenames, return the set of modules that contain those files.
If a file is not associated with a more specific submodule, then this method will consider that
- file to belong to the 'root' module.
+    file to belong to the 'root' module. GitHub Actions and AppVeyor files are ignored.
>>> sorted(x.name for x in determine_modules_for_files(["python/pyspark/a.py", "sql/core/foo"]))
['pyspark-core', 'sql']
>>> [x.name for x in determine_modules_for_files(["file_not_matched_by_any_subproject"])]
['root']
+ >>> [x.name for x in determine_modules_for_files( \
+ [".github/workflows/master.yml", "appveyor.yml"])]
+ []
"""
changed_modules = set()
for filename in filenames:
+ if filename in (".github/workflows/master.yml", "appveyor.yml"):
+ continue
matched_at_least_one_module = False
for module in modules.all_modules:
if module.contains_file(filename):
@@ -278,8 +283,8 @@ def get_hadoop_profiles(hadoop_version):
"""
sbt_maven_hadoop_profiles = {
- "hadoop2.7": ["-Phadoop-2.7"],
- "hadoop3.2": ["-Phadoop-3.2"],
+ "hadoop2.7": ["-Phadoop-2.7", "-Phive-1.2"],
+ "hadoop3.2": ["-Phadoop-3.2", "-Phive-2.3"],
}
if hadoop_version in sbt_maven_hadoop_profiles:
diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh
index cc0292e9c2ea5..7d5725aaf137e 100755
--- a/dev/test-dependencies.sh
+++ b/dev/test-dependencies.sh
@@ -67,15 +67,20 @@ $MVN -q versions:set -DnewVersion=$TEMP_VERSION -DgenerateBackupPoms=false > /de
# Generate manifests for each Hadoop profile:
for HADOOP_PROFILE in "${HADOOP_PROFILES[@]}"; do
+ if [[ $HADOOP_PROFILE == **hadoop-3** ]]; then
+ HIVE_PROFILE=hive-2.3
+ else
+ HIVE_PROFILE=hive-1.2
+ fi
echo "Performing Maven install for $HADOOP_PROFILE"
- $MVN $HADOOP2_MODULE_PROFILES -P$HADOOP_PROFILE jar:jar jar:test-jar install:install clean -q
+ $MVN $HADOOP2_MODULE_PROFILES -P$HADOOP_PROFILE -P$HIVE_PROFILE jar:jar jar:test-jar install:install clean -q
echo "Performing Maven validate for $HADOOP_PROFILE"
- $MVN $HADOOP2_MODULE_PROFILES -P$HADOOP_PROFILE validate -q
+ $MVN $HADOOP2_MODULE_PROFILES -P$HADOOP_PROFILE -P$HIVE_PROFILE validate -q
echo "Generating dependency manifest for $HADOOP_PROFILE"
mkdir -p dev/pr-deps
- $MVN $HADOOP2_MODULE_PROFILES -P$HADOOP_PROFILE dependency:build-classpath -pl assembly -am \
+ $MVN $HADOOP2_MODULE_PROFILES -P$HADOOP_PROFILE -P$HIVE_PROFILE dependency:build-classpath -pl assembly -am \
| grep "Dependencies classpath:" -A 1 \
| tail -n 1 | tr ":" "\n" | rev | cut -d "/" -f 1 | rev | sort \
| grep -v spark > dev/pr-deps/spark-deps-$HADOOP_PROFILE
diff --git a/docs/configuration.md b/docs/configuration.md
index 97ea1fb4ba041..0c7cc6022eb09 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -1857,6 +1857,51 @@ Apart from these, the following properties are also available, and may be useful
driver using more memory.
+<tr>
+  <td><code>spark.scheduler.listenerbus.eventqueue.shared.capacity</code></td>
+  <td><code>spark.scheduler.listenerbus.eventqueue.capacity</code></td>
+  <td>
+    Capacity for the shared event queue in the Spark listener bus, which holds events for external listener(s)
+    that register to the listener bus. Consider increasing this value if the listener events corresponding
+    to the shared queue are dropped. Increasing this value may result in the driver using more memory.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.scheduler.listenerbus.eventqueue.appStatus.capacity</code></td>
+  <td><code>spark.scheduler.listenerbus.eventqueue.capacity</code></td>
+  <td>
+    Capacity for the appStatus event queue, which holds events for internal application status listeners.
+    Consider increasing this value if the listener events corresponding to the appStatus queue are dropped.
+    Increasing this value may result in the driver using more memory.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.scheduler.listenerbus.eventqueue.executorManagement.capacity</code></td>
+  <td><code>spark.scheduler.listenerbus.eventqueue.capacity</code></td>
+  <td>
+    Capacity for the executorManagement event queue in the Spark listener bus, which holds events for internal
+    executor management listeners. Consider increasing this value if the listener events corresponding to the
+    executorManagement queue are dropped. Increasing this value may result in the driver using more memory.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.scheduler.listenerbus.eventqueue.eventLog.capacity</code></td>
+  <td><code>spark.scheduler.listenerbus.eventqueue.capacity</code></td>
+  <td>
+    Capacity for the eventLog queue in the Spark listener bus, which holds events for event logging listeners
+    that write events to the event log. Consider increasing this value if the listener events corresponding to
+    the eventLog queue are dropped. Increasing this value may result in the driver using more memory.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.scheduler.listenerbus.eventqueue.streams.capacity</code></td>
+  <td><code>spark.scheduler.listenerbus.eventqueue.capacity</code></td>
+  <td>
+    Capacity for the streams queue in the Spark listener bus, which holds events for internal streaming listeners.
+    Consider increasing this value if the listener events corresponding to the streams queue are dropped.
+    Increasing this value may result in the driver using more memory.
+  </td>
+</tr>
  <td><code>spark.scheduler.blacklist.unschedulableTaskSetTimeout</code></td>
  <td>120s</td>
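
The per-queue capacities documented above override the global `spark.scheduler.listenerbus.eventqueue.capacity` default for individual listener-bus queues. A hypothetical usage sketch of setting them through `SparkConf` (the numeric values are illustrative, not recommendations):

```java
import org.apache.spark.SparkConf;

public class ListenerBusQueueTuning {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf()
        .setAppName("listener-bus-queue-tuning")
        // global default used by every listener-bus queue
        .set("spark.scheduler.listenerbus.eventqueue.capacity", "10000")
        // raise only the appStatus queue, e.g. if its events are being dropped
        .set("spark.scheduler.listenerbus.eventqueue.appStatus.capacity", "30000");

    System.out.println(conf.get("spark.scheduler.listenerbus.eventqueue.appStatus.capacity"));
  }
}
```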
diff --git a/docs/img/webui-sql-dag.png b/docs/img/webui-sql-dag.png
index 4ca21092e8b39..1c83c176da325 100644
Binary files a/docs/img/webui-sql-dag.png and b/docs/img/webui-sql-dag.png differ
diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md
index b83b4ba08a5fd..05c688960f04c 100644
--- a/docs/ml-classification-regression.md
+++ b/docs/ml-classification-regression.md
@@ -478,16 +478,18 @@ it computes the conditional probability distribution of each feature given each
For prediction, it applies Bayes' theorem to compute the conditional probability distribution
of each label given an observation.
-MLlib supports both [multinomial naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes)
-and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html).
+MLlib supports [Multinomial naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes),
+[Complement naive Bayes](https://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf),
+[Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html)
+and [Gaussian naive Bayes](https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Gaussian_naive_Bayes).
*Input data*:
-These models are typically used for [document classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html).
+These Multinomial, Complement and Bernoulli models are typically used for [document classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html).
Within that context, each observation is a document and each feature represents a term.
-A feature's value is the frequency of the term (in multinomial Naive Bayes) or
+A feature's value is the frequency of the term (in Multinomial or Complement Naive Bayes) or
a zero or one indicating whether the term was found in the document (in Bernoulli Naive Bayes).
-Feature values must be *non-negative*. The model type is selected with an optional parameter
-"multinomial" or "bernoulli" with "multinomial" as the default.
+Feature values for Multinomial and Bernoulli models must be *non-negative*. The model type is selected with an optional parameter
+"multinomial", "complement", "bernoulli" or "gaussian", with "multinomial" as the default.
For document classification, the input feature vectors should usually be sparse vectors.
Since the training data is only used once, it is not necessary to cache it.
diff --git a/docs/sql-keywords.md b/docs/sql-keywords.md
index 81d7ce37af178..3117ee40a8c9b 100644
--- a/docs/sql-keywords.md
+++ b/docs/sql-keywords.md
@@ -19,15 +19,16 @@ license: |
limitations under the License.
---
-When `spark.sql.ansi.enabled` is true, Spark SQL has two kinds of keywords:
+When `spark.sql.dialect=PostgreSQL`, or when the default `spark.sql.dialect=Spark` is kept with `spark.sql.dialect.spark.ansi.enabled` set to true, Spark SQL uses the ANSI mode parser.
+In this mode, Spark SQL has two kinds of keywords:
* Reserved keywords: Keywords that are reserved and can't be used as identifiers for table, view, column, function, alias, etc.
* Non-reserved keywords: Keywords that have a special meaning only in particular contexts and can be used as identifiers in other contexts. For example, `SELECT 1 WEEK` is an interval literal, but WEEK can be used as identifiers in other places.
-When `spark.sql.ansi.enabled` is false, Spark SQL has two kinds of keywords:
-* Non-reserved keywords: Same definition as the one when `spark.sql.ansi.enabled=true`.
+When the ANSI mode is disabled, Spark SQL has two kinds of keywords:
+* Non-reserved keywords: Same definition as the one when the ANSI mode is enabled.
* Strict-non-reserved keywords: A strict version of non-reserved keywords, which can not be used as table alias.
-By default `spark.sql.ansi.enabled` is false.
+By default `spark.sql.dialect.spark.ansi.enabled` is false.
Below is a list of all the keywords in Spark SQL.
diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md
index 2d5afa919e668..6fc78893e6881 100644
--- a/docs/sql-migration-guide.md
+++ b/docs/sql-migration-guide.md
@@ -222,6 +222,8 @@ license: |
- Since Spark 3.0, when casting interval values to string type, there is no "interval" prefix, e.g. `1 days 2 hours`. In Spark version 2.4 and earlier, the string contains the "interval" prefix like `interval 1 days 2 hours`.
+  - Since Spark 3.0, when casting string values to integral types, including tinyint, smallint, int and bigint type, the leading and trailing whitespace (<= ASCII 32) will be trimmed before being converted to integral values, e.g. `cast(' 1 ' as int)` returns `1`. In Spark version 2.4 and earlier, the result is `null`.
+
## Upgrading from Spark SQL 2.4 to 2.4.1
- The value of `spark.executor.heartbeatInterval`, when specified without units like "30" rather than "30s", was
diff --git a/docs/sql-ref-syntax-aux-show-partitions.md b/docs/sql-ref-syntax-aux-show-partitions.md
index c6499de9cbb9e..216f3f0d679ec 100644
--- a/docs/sql-ref-syntax-aux-show-partitions.md
+++ b/docs/sql-ref-syntax-aux-show-partitions.md
@@ -18,5 +18,86 @@ license: |
See the License for the specific language governing permissions and
limitations under the License.
---
+### Description
-**This page is under construction**
+The `SHOW PARTITIONS` statement is used to list partitions of a table. An optional
+partition spec may be specified to return the partitions matching the supplied
+partition spec.
+
+### Syntax
+{% highlight sql %}
+SHOW PARTITIONS table_name
+ [ PARTITION ( partition_col_name [ = partition_col_val ] [ , ... ] ) ]
+{% endhighlight %}
+
+### Parameters
+
+ table_name
+ - The name of an existing table.
+
+
+ PARTITION ( partition_col_name [ = partition_col_val ] [ , ... ] )
+ - An optional parameter that specifies a comma-separated list of key and value pairs for
+ partitions. When specified, the partitions that match the partition spec are returned.
+
+
+### Examples
+{% highlight sql %}
+-- create a partitioned table and insert a few rows.
+USE salesdb;
+CREATE TABLE customer(id INT, name STRING) PARTITIONED BY (state STRING, city STRING);
+INSERT INTO customer PARTITION (state = 'CA', city = 'Fremont') VALUES (100, 'John');
+INSERT INTO customer PARTITION (state = 'CA', city = 'San Jose') VALUES (200, 'Marry');
+INSERT INTO customer PARTITION (state = 'AZ', city = 'Peoria') VALUES (300, 'Daniel');
+
+-- Lists all partitions for table `customer`
+SHOW PARTITIONS customer;
+ +----------------------+
+ |partition |
+ +----------------------+
+ |state=AZ/city=Peoria |
+ |state=CA/city=Fremont |
+ |state=CA/city=San Jose|
+ +----------------------+
+
+-- Lists all partitions for the qualified table `customer`
+SHOW PARTITIONS salesdb.customer;
+ +----------------------+
+ |partition |
+ +----------------------+
+ |state=AZ/city=Peoria |
+ |state=CA/city=Fremont |
+ |state=CA/city=San Jose|
+ +----------------------+
+
+-- Specify a full partition spec to list a specific partition
+SHOW PARTITIONS customer PARTITION (state = 'CA', city = 'Fremont');
+ +---------------------+
+ |partition |
+ +---------------------+
+ |state=CA/city=Fremont|
+ +---------------------+
+
+-- Specify a partial partition spec to list the specific partitions
+SHOW PARTITIONS customer PARTITION (state = 'CA');
+ +----------------------+
+ |partition |
+ +----------------------+
+ |state=CA/city=Fremont |
+ |state=CA/city=San Jose|
+ +----------------------+
+
+-- Specify a partial partition spec to list partitions matching only the city
+SHOW PARTITIONS customer PARTITION (city = 'San Jose');
+ +----------------------+
+ |partition |
+ +----------------------+
+ |state=CA/city=San Jose|
+ +----------------------+
+{% endhighlight %}
+
+### Related statements
+- [CREATE TABLE](sql-ref-syntax-ddl-create-table.html)
+- [INSERT STATEMENT](sql-ref-syntax-dml-insert.html)
+- [DESCRIBE TABLE](sql-ref-syntax-aux-describe-table.html)
+- [SHOW TABLE](sql-ref-syntax-aux-show-table.html)
diff --git a/docs/web-ui.md b/docs/web-ui.md
index e6025370e6796..f94e81ca67961 100644
--- a/docs/web-ui.md
+++ b/docs/web-ui.md
@@ -336,7 +336,7 @@ scala> spark.sql("select name,sum(count) from global_temp.df group by name").sho
Now the above three dataframe/SQL operators are shown in the list. If we click the
-'show at \: 24' link of the last query, we will see the DAG of the job.
+'show at \: 24' link of the last query, we will see the DAG and details of the query execution.
-We can see that details information of each stage. The first block 'WholeStageCodegen'
-compile multiple operator ('LocalTableScan' and 'HashAggregate') together into a single Java
-function to improve performance, and metrics like number of rows and spill size are listed in
-the block. The second block 'Exchange' shows the metrics on the shuffle exchange, including
+The query details page displays information about the query execution time, its duration,
+the list of associated jobs, and the query execution DAG.
+The first block 'WholeStageCodegen (1)' compiles multiple operators ('LocalTableScan' and 'HashAggregate') together into a single Java
+function to improve performance, and metrics like number of rows and spill size are listed in the block.
+The annotation '(1)' in the block name is the code generation id.
+The second block 'Exchange' shows the metrics on the shuffle exchange, including
number of written shuffle records, total data size, etc.
@@ -362,6 +364,8 @@ number of written shuffle records, total data size, etc.
Clicking the 'Details' link on the bottom displays the logical plans and the physical plan, which
illustrate how Spark parses, analyzes, optimizes and performs the query.
+Steps in the physical plan that are subject to whole-stage code generation optimization are prefixed by a star followed by
+the code generation id, for example: '*(1) LocalTableScan'
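
The same annotation can be checked outside the UI with a small sketch like the following:

```scala
// A sketch, assuming an active SparkSession named `spark` (e.g. the spark-shell session above).
// Operators that fall into a whole-stage-codegen stage are printed with the '*(<codegen id>)'
// prefix in the physical plan, matching the block names shown in the UI.
spark.range(100).groupBy("id").count().explain()
```
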
### SQL metrics
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
index f1cd3343b7925..efd7ca74c796b 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
@@ -59,7 +59,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
"""
|INSERT INTO numbers VALUES (
|0,
- |127, 32767, 2147483647, 9223372036854775807,
+ |255, 32767, 2147483647, 9223372036854775807,
|123456789012345.123456789012345, 123456789012345.123456789012345,
|123456789012345.123456789012345,
|123, 12345.12,
@@ -119,7 +119,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
val types = row.toSeq.map(x => x.getClass.toString)
assert(types.length == 12)
assert(types(0).equals("class java.lang.Boolean"))
- assert(types(1).equals("class java.lang.Byte"))
+ assert(types(1).equals("class java.lang.Integer"))
assert(types(2).equals("class java.lang.Short"))
assert(types(3).equals("class java.lang.Integer"))
assert(types(4).equals("class java.lang.Long"))
@@ -131,7 +131,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
assert(types(10).equals("class java.math.BigDecimal"))
assert(types(11).equals("class java.math.BigDecimal"))
assert(row.getBoolean(0) == false)
- assert(row.getByte(1) == 127)
+ assert(row.getInt(1) == 255)
assert(row.getShort(2) == 32767)
assert(row.getInt(3) == 2147483647)
assert(row.getLong(4) == 9223372036854775807L)
@@ -202,46 +202,4 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
df2.write.jdbc(jdbcUrl, "datescopy", new Properties)
df3.write.jdbc(jdbcUrl, "stringscopy", new Properties)
}
-
- test("SPARK-29644: Write tables with ShortType") {
- import testImplicits._
- val df = Seq(-32768.toShort, 0.toShort, 1.toShort, 38.toShort, 32768.toShort).toDF("a")
- val tablename = "shorttable"
- df.write
- .format("jdbc")
- .mode("overwrite")
- .option("url", jdbcUrl)
- .option("dbtable", tablename)
- .save()
- val df2 = spark.read
- .format("jdbc")
- .option("url", jdbcUrl)
- .option("dbtable", tablename)
- .load()
- assert(df.count == df2.count)
- val rows = df2.collect()
- val colType = rows(0).toSeq.map(x => x.getClass.toString)
- assert(colType(0) == "class java.lang.Short")
- }
-
- test("SPARK-29644: Write tables with ByteType") {
- import testImplicits._
- val df = Seq(-127.toByte, 0.toByte, 1.toByte, 38.toByte, 128.toByte).toDF("a")
- val tablename = "bytetable"
- df.write
- .format("jdbc")
- .mode("overwrite")
- .option("url", jdbcUrl)
- .option("dbtable", tablename)
- .save()
- val df2 = spark.read
- .format("jdbc")
- .option("url", jdbcUrl)
- .option("dbtable", tablename)
- .load()
- assert(df.count == df2.count)
- val rows = df2.collect()
- val colType = rows(0).toSeq.map(x => x.getClass.toString)
- assert(colType(0) == "class java.lang.Byte")
- }
}
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
index 8401b0a8a752f..bba1b5275269b 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
@@ -84,7 +84,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite {
assert(types.length == 9)
assert(types(0).equals("class java.lang.Boolean"))
assert(types(1).equals("class java.lang.Long"))
- assert(types(2).equals("class java.lang.Short"))
+ assert(types(2).equals("class java.lang.Integer"))
assert(types(3).equals("class java.lang.Integer"))
assert(types(4).equals("class java.lang.Integer"))
assert(types(5).equals("class java.lang.Long"))
@@ -93,7 +93,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite {
assert(types(8).equals("class java.lang.Double"))
assert(rows(0).getBoolean(0) == false)
assert(rows(0).getLong(1) == 0x225)
- assert(rows(0).getShort(2) == 17)
+ assert(rows(0).getInt(2) == 17)
assert(rows(0).getInt(3) == 77777)
assert(rows(0).getInt(4) == 123456789)
assert(rows(0).getLong(5) == 123456789012345L)
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchWrite.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchWrite.scala
index 8e29e38b2a644..56c0fdd7c35b7 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchWrite.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchWrite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.kafka010
import java.{util => ju}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, WriterCommitMessage}
+import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, PhysicalWriteInfo, WriterCommitMessage}
import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery
import org.apache.spark.sql.types.StructType
@@ -40,7 +40,7 @@ private[kafka010] class KafkaBatchWrite(
validateQuery(schema.toAttributes, producerParams, topic)
- override def createBatchWriterFactory(): KafkaBatchWriterFactory =
+ override def createBatchWriterFactory(info: PhysicalWriteInfo): KafkaBatchWriterFactory =
KafkaBatchWriterFactory(topic, producerParams, schema)
override def commit(messages: Array[WriterCommitMessage]): Unit = {}
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala
index 2b50b771e694e..bcf9e3416f843 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.kafka010
import java.{util => ju}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
+import org.apache.spark.sql.connector.write.{DataWriter, PhysicalWriteInfo, WriterCommitMessage}
import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery
import org.apache.spark.sql.types.StructType
@@ -41,7 +41,8 @@ private[kafka010] class KafkaStreamingWrite(
validateQuery(schema.toAttributes, producerParams, topic)
- override def createStreamingWriterFactory(): KafkaStreamWriterFactory =
+ override def createStreamingWriterFactory(
+ info: PhysicalWriteInfo): KafkaStreamWriterFactory =
KafkaStreamWriterFactory(topic, producerParams, schema)
override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala
index 7e78dfdd0dae6..28b33faabec1f 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala
@@ -18,10 +18,9 @@
package org.apache.spark.sql.kafka010
import java.io.{File, IOException}
-import java.lang.{Integer => JInt}
import java.net.{InetAddress, InetSocketAddress}
import java.nio.charset.StandardCharsets
-import java.util.{Collections, Map => JMap, Properties, UUID}
+import java.util.{Collections, Properties, UUID}
import java.util.concurrent.TimeUnit
import javax.security.auth.login.Configuration
@@ -41,13 +40,12 @@ import org.apache.kafka.clients.consumer.KafkaConsumer
import org.apache.kafka.clients.producer._
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.config.SaslConfigs
-import org.apache.kafka.common.header.Header
-import org.apache.kafka.common.header.internals.RecordHeader
import org.apache.kafka.common.network.ListenerName
import org.apache.kafka.common.security.auth.SecurityProtocol.{PLAINTEXT, SASL_PLAINTEXT}
import org.apache.kafka.common.serialization.{StringDeserializer, StringSerializer}
import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer}
import org.apache.zookeeper.server.auth.SASLAuthenticationProvider
+import org.scalatest.Assertions._
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._
diff --git a/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java
index 03becd73d1a06..7af0abe0e8d90 100644
--- a/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java
+++ b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java
@@ -18,12 +18,14 @@
package org.apache.spark.streaming.kinesis;
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream;
+import org.junit.Assert;
+import org.junit.Test;
+
import org.apache.spark.streaming.kinesis.KinesisInitialPositions.TrimHorizon;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.LocalJavaStreamingContext;
import org.apache.spark.streaming.Seconds;
-import org.junit.Test;
public class JavaKinesisInputDStreamBuilderSuite extends LocalJavaStreamingContext {
/**
@@ -49,13 +51,14 @@ public void testJavaKinesisDStreamBuilder() {
.checkpointInterval(checkpointInterval)
.storageLevel(storageLevel)
.build();
- assert(kinesisDStream.streamName() == streamName);
- assert(kinesisDStream.endpointUrl() == endpointUrl);
- assert(kinesisDStream.regionName() == region);
- assert(kinesisDStream.initialPosition().getPosition() == initialPosition.getPosition());
- assert(kinesisDStream.checkpointAppName() == appName);
- assert(kinesisDStream.checkpointInterval() == checkpointInterval);
- assert(kinesisDStream._storageLevel() == storageLevel);
+ Assert.assertEquals(streamName, kinesisDStream.streamName());
+ Assert.assertEquals(endpointUrl, kinesisDStream.endpointUrl());
+ Assert.assertEquals(region, kinesisDStream.regionName());
+ Assert.assertEquals(initialPosition.getPosition(),
+ kinesisDStream.initialPosition().getPosition());
+ Assert.assertEquals(appName, kinesisDStream.checkpointAppName());
+ Assert.assertEquals(checkpointInterval, kinesisDStream.checkpointInterval());
+ Assert.assertEquals(storageLevel, kinesisDStream._storageLevel());
ssc.stop();
}
@@ -83,13 +86,14 @@ public void testJavaKinesisDStreamBuilderOldApi() {
.checkpointInterval(checkpointInterval)
.storageLevel(storageLevel)
.build();
- assert(kinesisDStream.streamName() == streamName);
- assert(kinesisDStream.endpointUrl() == endpointUrl);
- assert(kinesisDStream.regionName() == region);
- assert(kinesisDStream.initialPosition().getPosition() == InitialPositionInStream.LATEST);
- assert(kinesisDStream.checkpointAppName() == appName);
- assert(kinesisDStream.checkpointInterval() == checkpointInterval);
- assert(kinesisDStream._storageLevel() == storageLevel);
+ Assert.assertEquals(streamName, kinesisDStream.streamName());
+ Assert.assertEquals(endpointUrl, kinesisDStream.endpointUrl());
+ Assert.assertEquals(region, kinesisDStream.regionName());
+ Assert.assertEquals(InitialPositionInStream.LATEST,
+ kinesisDStream.initialPosition().getPosition());
+ Assert.assertEquals(appName, kinesisDStream.checkpointAppName());
+ Assert.assertEquals(checkpointInterval, kinesisDStream.checkpointInterval());
+ Assert.assertEquals(storageLevel, kinesisDStream._storageLevel());
ssc.stop();
}
}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala
index e0c65e6940f66..e3471759b3a70 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.graphx.util
import org.apache.hadoop.fs.Path
+import org.scalatest.Assertions
import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.graphx.{Edge, Graph, LocalSparkContext}
@@ -88,7 +89,7 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with LocalSparkContex
}
}
-private object PeriodicGraphCheckpointerSuite {
+private object PeriodicGraphCheckpointerSuite extends Assertions {
private val defaultStorageLevel = StorageLevel.MEMORY_ONLY_SER
case class GraphToCheck(graph: Graph[Double, Double], gIndex: Int)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index bcca40d159c9b..52a0f4d9b9828 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -18,18 +18,22 @@
package org.apache.spark.ml.classification
import org.apache.hadoop.fs.Path
+import org.json4s.DefaultFormats
+import org.json4s.jackson.JsonMethods._
import org.apache.spark.annotation.Since
import org.apache.spark.ml.PredictorParams
-import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.HasWeightCol
+import org.apache.spark.ml.stat.Summarizer
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.{Dataset, Row}
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+import org.apache.spark.util.VersionUtils
/**
* Params for Naive Bayes Classifiers.
@@ -49,12 +53,13 @@ private[classification] trait NaiveBayesParams extends PredictorParams with HasW
/**
* The model type which is a string (case-sensitive).
- * Supported options: "multinomial" and "bernoulli".
+ * Supported options: "multinomial", "complement", "bernoulli", "gaussian".
* (default = multinomial)
* @group param
*/
final val modelType: Param[String] = new Param[String](this, "modelType", "The model type " +
- "which is a string (case-sensitive). Supported options: multinomial (default) and bernoulli.",
+ "which is a string (case-sensitive). Supported options: multinomial (default), complement, " +
+ "bernoulli and gaussian.",
ParamValidators.inArray[String](NaiveBayes.supportedModelTypes.toArray))
/** @group getParam */
@@ -72,7 +77,16 @@ private[classification] trait NaiveBayesParams extends PredictorParams with HasW
* binary (0/1) data, it can also be used as Bernoulli NB
* (see
* here).
- * The input feature values must be nonnegative.
+ * The input feature values for Multinomial NB and Bernoulli NB must be nonnegative.
+ * Since 3.0.0, it supports Complement NB which is an adaptation of the Multinomial NB. Specifically,
+ * Complement NB uses statistics from the complement of each class to compute the model's coefficients.
+ * The inventors of Complement NB show empirically that the parameter estimates for CNB are more stable
+ * than those for Multinomial NB. Like Multinomial NB, the input feature values for Complement NB must
+ * be nonnegative.
+ * Since 3.0.0, it also supports Gaussian NB
+ * (see
+ * here)
+ * which can handle continuous data.
*/
// scalastyle:on line.size.limit
@Since("1.5.0")
@@ -97,13 +111,13 @@ class NaiveBayes @Since("1.5.0") (
/**
* Set the model type using a string (case-sensitive).
- * Supported options: "multinomial" and "bernoulli".
+ * Supported options: "multinomial", "complement", "bernoulli", and "gaussian".
* Default is "multinomial"
* @group setParam
*/
@Since("1.5.0")
def setModelType(value: String): this.type = set(modelType, value)
- setDefault(modelType -> NaiveBayes.Multinomial)
+ setDefault(modelType -> Multinomial)
/**
* Sets the value of param [[weightCol]].
@@ -130,6 +144,9 @@ class NaiveBayes @Since("1.5.0") (
positiveLabel: Boolean): NaiveBayesModel = instrumented { instr =>
instr.logPipelineStage(this)
instr.logDataset(dataset)
+ instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, rawPredictionCol,
+ probabilityCol, modelType, smoothing, thresholds)
+
if (positiveLabel && isDefined(thresholds)) {
val numClasses = getNumClasses(dataset)
instr.logNumClasses(numClasses)
@@ -138,65 +155,174 @@ class NaiveBayes @Since("1.5.0") (
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}
- val validateInstance = $(modelType) match {
- case Multinomial =>
- (instance: Instance) => requireNonnegativeValues(instance.features)
- case Bernoulli =>
- (instance: Instance) => requireZeroOneBernoulliValues(instance.features)
+ $(modelType) match {
+ case Bernoulli | Multinomial | Complement =>
+ trainDiscreteImpl(dataset, instr)
+ case Gaussian =>
+ trainGaussianImpl(dataset, instr)
case _ =>
// This should never happen.
throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.")
}
+ }
- instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, rawPredictionCol,
- probabilityCol, modelType, smoothing, thresholds)
+ private def trainDiscreteImpl(
+ dataset: Dataset[_],
+ instr: Instrumentation): NaiveBayesModel = {
+ val spark = dataset.sparkSession
+ import spark.implicits._
- val numFeatures = dataset.select(col($(featuresCol))).head().getAs[Vector](0).size
- instr.logNumFeatures(numFeatures)
+ val validateUDF = $(modelType) match {
+ case Multinomial | Complement =>
+ udf { vector: Vector => requireNonnegativeValues(vector); vector }
+ case Bernoulli =>
+ udf { vector: Vector => requireZeroOneBernoulliValues(vector); vector }
+ }
+
+ val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
+ col($(weightCol)).cast(DoubleType)
+ } else {
+ lit(1.0)
+ }
// Aggregates term frequencies per label.
- // TODO: Calling aggregateByKey and collect creates two stages, we can implement something
- // TODO: similar to reduceByKeyLocally to save one stage.
- val aggregated = extractInstances(dataset, validateInstance).map { instance =>
- (instance.label, (instance.weight, instance.features))
- }.aggregateByKey[(Double, DenseVector, Long)]((0.0, Vectors.zeros(numFeatures).toDense, 0L))(
- seqOp = {
- case ((weightSum, featureSum, count), (weight, features)) =>
- BLAS.axpy(weight, features, featureSum)
- (weightSum + weight, featureSum, count + 1)
- },
- combOp = {
- case ((weightSum1, featureSum1, count1), (weightSum2, featureSum2, count2)) =>
- BLAS.axpy(1.0, featureSum2, featureSum1)
- (weightSum1 + weightSum2, featureSum1, count1 + count2)
- }).collect().sortBy(_._1)
-
- val numSamples = aggregated.map(_._2._3).sum
+ // TODO: Summarizer directly returns sum vector.
+ val aggregated = dataset.groupBy(col($(labelCol)))
+ .agg(sum(w).as("weightSum"), Summarizer.metrics("mean", "count")
+ .summary(validateUDF(col($(featuresCol))), w).as("summary"))
+ .select($(labelCol), "weightSum", "summary.mean", "summary.count")
+ .as[(Double, Double, Vector, Long)]
+ .map { case (label, weightSum, mean, count) =>
+ BLAS.scal(weightSum, mean)
+ (label, weightSum, mean, count)
+ }.collect().sortBy(_._1)
+
+ val numFeatures = aggregated.head._3.size
+ instr.logNumFeatures(numFeatures)
+ val numSamples = aggregated.map(_._4).sum
instr.logNumExamples(numSamples)
val numLabels = aggregated.length
instr.logNumClasses(numLabels)
- val numDocuments = aggregated.map(_._2._1).sum
+ val numDocuments = aggregated.map(_._2).sum
val labelArray = new Array[Double](numLabels)
val piArray = new Array[Double](numLabels)
val thetaArray = new Array[Double](numLabels * numFeatures)
+ val aggIter = $(modelType) match {
+ case Multinomial | Bernoulli => aggregated.iterator
+ case Complement =>
+ val featureSum = Vectors.zeros(numFeatures)
+ aggregated.foreach { case (_, _, sumTermFreqs, _) =>
+ BLAS.axpy(1.0, sumTermFreqs, featureSum)
+ }
+ aggregated.iterator.map { case (label, n, sumTermFreqs, count) =>
+ val comp = featureSum.copy
+ BLAS.axpy(-1.0, sumTermFreqs, comp)
+ (label, n, comp, count)
+ }
+ }
+
val lambda = $(smoothing)
val piLogDenom = math.log(numDocuments + numLabels * lambda)
var i = 0
- aggregated.foreach { case (label, (n, sumTermFreqs, _)) =>
+ aggIter.foreach { case (label, n, sumTermFreqs, _) =>
labelArray(i) = label
piArray(i) = math.log(n + lambda) - piLogDenom
val thetaLogDenom = $(modelType) match {
- case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda)
+ case Multinomial | Complement =>
+ math.log(sumTermFreqs.toArray.sum + numFeatures * lambda)
case Bernoulli => math.log(n + 2.0 * lambda)
- case _ =>
- // This should never happen.
- throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.")
}
var j = 0
+ val offset = i * numFeatures
while (j < numFeatures) {
- thetaArray(i * numFeatures + j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom
+ thetaArray(offset + j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom
+ j += 1
+ }
+ i += 1
+ }
+
+ val pi = Vectors.dense(piArray)
+ $(modelType) match {
+ case Multinomial | Bernoulli =>
+ val theta = new DenseMatrix(numLabels, numFeatures, thetaArray, true)
+ new NaiveBayesModel(uid, pi.compressed, theta.compressed, null)
+ .setOldLabels(labelArray)
+ case Complement =>
+ // Since CNB computes the coefficients in a complement way, negate them here.
+ val theta = new DenseMatrix(numLabels, numFeatures, thetaArray.map(v => -v), true)
+ new NaiveBayesModel(uid, pi.compressed, theta.compressed, null)
+ }
+ }
+
+ private def trainGaussianImpl(
+ dataset: Dataset[_],
+ instr: Instrumentation): NaiveBayesModel = {
+ val spark = dataset.sparkSession
+ import spark.implicits._
+
+ val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
+ col($(weightCol)).cast(DoubleType)
+ } else {
+ lit(1.0)
+ }
+
+ // Aggregates mean vector and square-sum vector per label.
+ // TODO: Summarizer directly returns square-sum vector.
+ val aggregated = dataset.groupBy(col($(labelCol)))
+ .agg(sum(w).as("weightSum"), Summarizer.metrics("mean", "normL2")
+ .summary(col($(featuresCol)), w).as("summary"))
+ .select($(labelCol), "weightSum", "summary.mean", "summary.normL2")
+ .as[(Double, Double, Vector, Vector)]
+ .map { case (label, weightSum, mean, normL2) =>
+ (label, weightSum, mean, Vectors.dense(normL2.toArray.map(v => v * v)))
+ }.collect().sortBy(_._1)
+
+ val numFeatures = aggregated.head._3.size
+ instr.logNumFeatures(numFeatures)
+
+ val numLabels = aggregated.length
+ instr.logNumClasses(numLabels)
+
+ val numInstances = aggregated.map(_._2).sum
+
+ // If the ratio of data variance between dimensions is too small, it
+ // will cause numerical errors. To address this, we artificially
+ // boost the variance by epsilon, a small fraction of the standard
+ // deviation of the largest dimension.
+ // Refer to scikit-learn's implementation
+ // [https://github.com/scikit-learn/scikit-learn/blob/0.21.X/sklearn/naive_bayes.py#L348]
+ // and discussion [https://github.com/scikit-learn/scikit-learn/pull/5349] for detail.
+ val epsilon = Iterator.range(0, numFeatures).map { j =>
+ var globalSum = 0.0
+ var globalSqrSum = 0.0
+ aggregated.foreach { case (_, weightSum, mean, squareSum) =>
+ globalSum += mean(j) * weightSum
+ globalSqrSum += squareSum(j)
+ }
+ globalSqrSum / numInstances -
+ globalSum * globalSum / numInstances / numInstances
+ }.max * 1e-9
+
+ val piArray = new Array[Double](numLabels)
+
+ // thetaArray in Gaussian NB stores the means of features per label
+ val thetaArray = new Array[Double](numLabels * numFeatures)
+
+ // sigmaArray in Gaussian NB stores the variances of features per label
+ val sigmaArray = new Array[Double](numLabels * numFeatures)
+
+ var i = 0
+ val logNumInstances = math.log(numInstances)
+ aggregated.foreach { case (_, weightSum, mean, squareSum) =>
+ piArray(i) = math.log(weightSum) - logNumInstances
+ var j = 0
+ val offset = i * numFeatures
+ while (j < numFeatures) {
+ val m = mean(j)
+ thetaArray(offset + j) = m
+ sigmaArray(offset + j) = epsilon + squareSum(j) / weightSum - m * m
j += 1
}
i += 1
@@ -204,7 +330,8 @@ class NaiveBayes @Since("1.5.0") (
val pi = Vectors.dense(piArray)
val theta = new DenseMatrix(numLabels, numFeatures, thetaArray, true)
- new NaiveBayesModel(uid, pi, theta).setOldLabels(labelArray)
+ val sigma = new DenseMatrix(numLabels, numFeatures, sigmaArray, true)
+ new NaiveBayesModel(uid, pi.compressed, theta.compressed, sigma.compressed)
}
@Since("1.5.0")
@@ -219,10 +346,17 @@ object NaiveBayes extends DefaultParamsReadable[NaiveBayes] {
/** String name for Bernoulli model type. */
private[classification] val Bernoulli: String = "bernoulli"
+ /** String name for Gaussian model type. */
+ private[classification] val Gaussian: String = "gaussian"
+
+ /** String name for Complement model type. */
+ private[classification] val Complement: String = "complement"
+
/* Set of modelTypes that NaiveBayes supports */
- private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli)
+ private[classification] val supportedModelTypes =
+ Set(Multinomial, Bernoulli, Gaussian, Complement)
- private[NaiveBayes] def requireNonnegativeValues(v: Vector): Unit = {
+ private[ml] def requireNonnegativeValues(v: Vector): Unit = {
val values = v match {
case sv: SparseVector => sv.values
case dv: DenseVector => dv.values
@@ -232,7 +366,7 @@ object NaiveBayes extends DefaultParamsReadable[NaiveBayes] {
s"Naive Bayes requires nonnegative feature values but found $v.")
}
- private[NaiveBayes] def requireZeroOneBernoulliValues(v: Vector): Unit = {
+ private[ml] def requireZeroOneBernoulliValues(v: Vector): Unit = {
val values = v match {
case sv: SparseVector => sv.values
case dv: DenseVector => dv.values
@@ -248,19 +382,24 @@ object NaiveBayes extends DefaultParamsReadable[NaiveBayes] {
/**
* Model produced by [[NaiveBayes]]
- * @param pi log of class priors, whose dimension is C (number of classes)
+ *
+ * @param pi log of class priors, whose dimension is C (number of classes)
* @param theta log of class conditional probabilities, whose dimension is C (number of classes)
* by D (number of features)
+ * @param sigma variance of each feature, whose dimension is C (number of classes)
+ * by D (number of features). This matrix is only available when modelType
+ * is set to Gaussian.
*/
@Since("1.5.0")
class NaiveBayesModel private[ml] (
@Since("1.5.0") override val uid: String,
@Since("2.0.0") val pi: Vector,
- @Since("2.0.0") val theta: Matrix)
+ @Since("2.0.0") val theta: Matrix,
+ @Since("3.0.0") val sigma: Matrix)
extends ProbabilisticClassificationModel[Vector, NaiveBayesModel]
with NaiveBayesParams with MLWritable {
- import NaiveBayes.{Bernoulli, Multinomial}
+ import NaiveBayes._
/**
* mllib NaiveBayes is a wrapper of ml implementation currently.
@@ -280,18 +419,36 @@ class NaiveBayesModel private[ml] (
* This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
* application of this condition (in predict function).
*/
- private lazy val (thetaMinusNegTheta, negThetaSum) = $(modelType) match {
- case Multinomial => (None, None)
+ @transient private lazy val (thetaMinusNegTheta, negThetaSum) = $(modelType) match {
case Bernoulli =>
val negTheta = theta.map(value => math.log1p(-math.exp(value)))
val ones = new DenseVector(Array.fill(theta.numCols) {1.0})
val thetaMinusNegTheta = theta.map { value =>
value - math.log1p(-math.exp(value))
}
- (Option(thetaMinusNegTheta), Option(negTheta.multiply(ones)))
+ (thetaMinusNegTheta, negTheta.multiply(ones))
case _ =>
// This should never happen.
- throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.")
+ throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}. " +
+ "Variables thetaMinusNegTheta and negThetaSum should only be precomputed in Bernoulli NB.")
+ }
+
+ /**
+ * Gaussian scoring requires sum of log(Variance).
+ * This precomputes sum of log(Variance) which are used for the linear algebra
+ * application of this condition (in predict function).
+ */
+ @transient private lazy val logVarSum = $(modelType) match {
+ case Gaussian =>
+ Array.tabulate(numClasses) { i =>
+ Iterator.range(0, numFeatures).map { j =>
+ math.log(sigma(i, j))
+ }.sum
+ }
+ case _ =>
+ // This should never happen.
+ throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}. " +
+ "Variable logVarSum should only be precomputed in Gaussian NB.")
}
@Since("1.6.0")
@@ -301,34 +458,77 @@ class NaiveBayesModel private[ml] (
override val numClasses: Int = pi.size
private def multinomialCalculation(features: Vector) = {
+ requireNonnegativeValues(features)
val prob = theta.multiply(features)
BLAS.axpy(1.0, pi, prob)
prob
}
+ private def complementCalculation(features: Vector) = {
+ requireNonnegativeValues(features)
+ val probArray = theta.multiply(features).toArray
+ // The following lines are equivalent to:
+ // val logSumExp = math.log(probArray.map(math.exp).sum)
+ // However, that formulation easily overflows to Infinity/NaN.
+ // Here we follow 'scipy.special.logsumexp' (which is used in Scikit-Learn's ComplementNB)
+ // to compute the log of the sum of exponentials of elements in a numerically stable way.
+ val max = probArray.max
+ var sumExp = 0.0
+ var j = 0
+ while (j < probArray.length) {
+ sumExp += math.exp(probArray(j) - max)
+ j += 1
+ }
+ val logSumExp = math.log(sumExp) + max
+
+ j = 0
+ while (j < probArray.length) {
+ probArray(j) = probArray(j) - logSumExp
+ j += 1
+ }
+ Vectors.dense(probArray)
+ }
+
private def bernoulliCalculation(features: Vector) = {
- features.foreachActive((_, value) =>
- require(value == 0.0 || value == 1.0,
- s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features.")
- )
- val prob = thetaMinusNegTheta.get.multiply(features)
+ requireZeroOneBernoulliValues(features)
+ val prob = thetaMinusNegTheta.multiply(features)
BLAS.axpy(1.0, pi, prob)
- BLAS.axpy(1.0, negThetaSum.get, prob)
+ BLAS.axpy(1.0, negThetaSum, prob)
prob
}
- override protected def predictRaw(features: Vector): Vector = {
+ private def gaussianCalculation(features: Vector) = {
+ val prob = Array.ofDim[Double](numClasses)
+ var i = 0
+ while (i < numClasses) {
+ var s = 0.0
+ var j = 0
+ while (j < numFeatures) {
+ val d = features(j) - theta(i, j)
+ s += d * d / sigma(i, j)
+ j += 1
+ }
+ prob(i) = pi(i) - (s + logVarSum(i)) / 2
+ i += 1
+ }
+ Vectors.dense(prob)
+ }
+
+ @transient private lazy val predictRawFunc = {
$(modelType) match {
case Multinomial =>
- multinomialCalculation(features)
+ features: Vector => multinomialCalculation(features)
+ case Complement =>
+ features: Vector => complementCalculation(features)
case Bernoulli =>
- bernoulliCalculation(features)
- case _ =>
- // This should never happen.
- throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.")
+ features: Vector => bernoulliCalculation(features)
+ case Gaussian =>
+ features: Vector => gaussianCalculation(features)
}
}
+ override protected def predictRaw(features: Vector): Vector = predictRawFunc(features)
+
override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
rawPrediction match {
case dv: DenseVector =>
@@ -354,7 +554,7 @@ class NaiveBayesModel private[ml] (
@Since("1.5.0")
override def copy(extra: ParamMap): NaiveBayesModel = {
- copyValues(new NaiveBayesModel(uid, pi, theta).setParent(this.parent), extra)
+ copyValues(new NaiveBayesModel(uid, pi, theta, sigma).setParent(this.parent), extra)
}
@Since("1.5.0")
@@ -378,34 +578,61 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
/** [[MLWriter]] instance for [[NaiveBayesModel]] */
private[NaiveBayesModel] class NaiveBayesModelWriter(instance: NaiveBayesModel) extends MLWriter {
+ import NaiveBayes._
private case class Data(pi: Vector, theta: Matrix)
+ private case class GaussianData(pi: Vector, theta: Matrix, sigma: Matrix)
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
- // Save model data: pi, theta
- val data = Data(instance.pi, instance.theta)
val dataPath = new Path(path, "data").toString
- sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+
+ instance.getModelType match {
+ case Multinomial | Bernoulli | Complement =>
+ // Save model data: pi, theta
+ require(instance.sigma == null)
+ val data = Data(instance.pi, instance.theta)
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+
+ case Gaussian =>
+ require(instance.sigma != null)
+ val data = GaussianData(instance.pi, instance.theta, instance.sigma)
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
}
}
private class NaiveBayesModelReader extends MLReader[NaiveBayesModel] {
+ import NaiveBayes._
/** Checked against metadata when loading model */
private val className = classOf[NaiveBayesModel].getName
override def load(path: String): NaiveBayesModel = {
+ implicit val format = DefaultFormats
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
+ val modelTypeJson = metadata.getParamValue("modelType")
+ val modelType = Param.jsonDecode[String](compact(render(modelTypeJson)))
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
val vecConverted = MLUtils.convertVectorColumnsToML(data, "pi")
- val Row(pi: Vector, theta: Matrix) = MLUtils.convertMatrixColumnsToML(vecConverted, "theta")
- .select("pi", "theta")
- .head()
- val model = new NaiveBayesModel(metadata.uid, pi, theta)
+
+ val model = if (major.toInt < 3 || modelType != Gaussian) {
+ val Row(pi: Vector, theta: Matrix) =
+ MLUtils.convertMatrixColumnsToML(vecConverted, "theta")
+ .select("pi", "theta")
+ .head()
+ new NaiveBayesModel(metadata.uid, pi, theta, null)
+ } else {
+ val Row(pi: Vector, theta: Matrix, sigma: Matrix) =
+ MLUtils.convertMatrixColumnsToML(vecConverted, "theta", "sigma")
+ .select("pi", "theta", "sigma")
+ .head()
+ new NaiveBayesModel(metadata.uid, pi, theta, sigma)
+ }
metadata.getAndSetParams(model)
model
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
index ab14227f06be1..435708186242f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
@@ -46,7 +46,7 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
* `"weightedPrecision"`, `"weightedRecall"`, `"weightedTruePositiveRate"`,
* `"weightedFalsePositiveRate"`, `"weightedFMeasure"`, `"truePositiveRateByLabel"`,
* `"falsePositiveRateByLabel"`, `"precisionByLabel"`, `"recallByLabel"`,
- * `"fMeasureByLabel"`, `"logLoss"`)
+ * `"fMeasureByLabel"`, `"logLoss"`, `"hammingLoss"`)
*
* @group param
*/
@@ -172,13 +172,15 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
case "precisionByLabel" => metrics.precision($(metricLabel))
case "recallByLabel" => metrics.recall($(metricLabel))
case "fMeasureByLabel" => metrics.fMeasure($(metricLabel), $(beta))
+ case "hammingLoss" => metrics.hammingLoss
case "logLoss" => metrics.logLoss($(eps))
}
}
@Since("1.5.0")
override def isLargerBetter: Boolean = $(metricName) match {
- case "weightedFalsePositiveRate" | "falsePositiveRateByLabel" | "logLoss" => false
+ case "weightedFalsePositiveRate" | "falsePositiveRateByLabel" | "logLoss" | "hammingLoss" =>
+ false
case _ => true
}
@@ -199,7 +201,7 @@ object MulticlassClassificationEvaluator
private val supportedMetricNames = Array("f1", "accuracy", "weightedPrecision", "weightedRecall",
"weightedTruePositiveRate", "weightedFalsePositiveRate", "weightedFMeasure",
"truePositiveRateByLabel", "falsePositiveRateByLabel", "precisionByLabel", "recallByLabel",
- "fMeasureByLabel", "logLoss")
+ "fMeasureByLabel", "logLoss", "hammingLoss")
@Since("1.6.0")
override def load(path: String): MulticlassClassificationEvaluator = super.load(path)
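
A minimal, self-contained sketch of the new metric; the toy predictions DataFrame is an assumption made for illustration:

```scala
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("hamming-loss").getOrCreate()
import spark.implicits._

// (prediction, label) pairs; 2 of the 5 rows are misclassified.
val predictions = Seq(
  (0.0, 0.0), (1.0, 1.0), (2.0, 2.0), (1.0, 0.0), (0.0, 2.0)
).toDF("prediction", "label")

val evaluator = new MulticlassClassificationEvaluator().setMetricName("hammingLoss")
println(evaluator.evaluate(predictions)) // 0.4 -- the misclassified fraction; smaller is better
```
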
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
index b20852383a6ff..4885d03220e95 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
@@ -112,7 +112,9 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
numNearestNeighbors: Int,
singleProbe: Boolean,
distCol: String): Dataset[_] = {
- require(numNearestNeighbors > 0, "The number of nearest neighbors cannot be less than 1")
+ val count = dataset.count()
+ require(numNearestNeighbors > 0 && numNearestNeighbors <= count, "The number of" +
+ " nearest neighbors cannot be less than 1 or greater than the number of elements in the dataset")
// Get Hash Value of the key
val keyHash = hashFunction(key)
val modelDataset: DataFrame = if (!dataset.columns.contains($(outputCol))) {
@@ -137,14 +139,21 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
val hashDistUDF = udf((x: Seq[Vector]) => hashDistance(x, keyHash), DataTypes.DoubleType)
val hashDistCol = hashDistUDF(col($(outputCol)))
- // Compute threshold to get exact k elements.
- // TODO: SPARK-18409: Use approxQuantile to get the threshold
- val modelDatasetSortedByHash = modelDataset.sort(hashDistCol).limit(numNearestNeighbors)
- val thresholdDataset = modelDatasetSortedByHash.select(max(hashDistCol))
- val hashThreshold = thresholdDataset.take(1).head.getDouble(0)
-
- // Filter the dataset where the hash value is less than the threshold.
- modelDataset.filter(hashDistCol <= hashThreshold)
+ // Compute threshold to get around k elements.
+ // To guarantee enough neighbors in one pass, we need (p - err) * N >= M
+ // so we pick quantile p = M / N + err
+ // M: the number of nearest neighbors; N: the number of elements in dataset
+ val relativeError = 0.05
+ val approxQuantile = numNearestNeighbors.toDouble / count + relativeError
+ val modelDatasetWithDist = modelDataset.withColumn(distCol, hashDistCol)
+ if (approxQuantile >= 1) {
+ modelDatasetWithDist
+ } else {
+ val hashThreshold = modelDatasetWithDist.stat
+ .approxQuantile(distCol, Array(approxQuantile), relativeError)
+ // Filter the dataset where the hash value is less than the threshold.
+ modelDatasetWithDist.filter(hashDistCol <= hashThreshold(0))
+ }
}
// Get the top k nearest neighbor by their distance to the key
@@ -169,11 +178,11 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
* to show the distance between each row and the key.
*/
def approxNearestNeighbors(
- dataset: Dataset[_],
- key: Vector,
- numNearestNeighbors: Int,
- distCol: String): Dataset[_] = {
- approxNearestNeighbors(dataset, key, numNearestNeighbors, true, distCol)
+ dataset: Dataset[_],
+ key: Vector,
+ numNearestNeighbors: Int,
+ distCol: String): Dataset[_] = {
+ approxNearestNeighbors(dataset, key, numNearestNeighbors, true, distCol)
}
/**
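
To make the threshold selection concrete, a small worked example of the arithmetic above (plain Scala, not Spark API calls; the sample sizes are arbitrary):

```scala
// M nearest neighbors requested out of N rows, with approxQuantile relative error err:
// picking p = M/N + err guarantees (p - err) * N >= M, i.e. enough candidates survive the filter.
val numNearestNeighbors = 20 // M
val count = 1000L            // N, i.e. dataset.count()
val relativeError = 0.05     // err, hard-coded in approxNearestNeighbors above
val approxQuantile = numNearestNeighbors.toDouble / count + relativeError
println(approxQuantile)      // 0.07 -> keep rows whose hash distance is below that quantile
// When approxQuantile >= 1 the whole dataset (with the distance column) is returned unfiltered.
```
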
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
index 9518f7e6828cf..050ebb0fa4fbd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
@@ -240,6 +240,23 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[_ <: Product])
@Since("1.1.0")
lazy val labels: Array[Double] = tpByClass.keys.toArray.sorted
+ /**
+ * Returns Hamming-loss
+ */
+ @Since("3.0.0")
+ lazy val hammingLoss: Double = {
+ var numerator = 0.0
+ var denominator = 0.0
+ confusions.iterator.foreach {
+ case ((label, prediction), weight) =>
+ if (label != prediction) {
+ numerator += weight
+ }
+ denominator += weight
+ }
+ numerator / denominator
+ }
+
/**
* Returns the log-loss, aka logistic loss or cross-entropy loss.
* @param eps log-loss is undefined for p=0 or p=1, so probabilities are
diff --git a/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java b/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java
index 830f668fe07b8..9037f6b854724 100644
--- a/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/stat/JavaKolmogorovSmirnovTestSuite.java
@@ -23,6 +23,7 @@
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.spark.sql.Encoders;
+import org.junit.Assert;
import org.junit.Test;
import org.apache.spark.SharedSparkSession;
@@ -60,7 +61,7 @@ public void testKSTestCDF() {
.test(dataset, "sample", stdNormalCDF).head();
double pValue1 = results.getDouble(0);
// Cannot reject null hypothesis
- assert(pValue1 > pThreshold);
+ Assert.assertTrue(pValue1 > pThreshold);
}
@Test
@@ -72,6 +73,6 @@ public void testKSTestNamedDistribution() {
.test(dataset, "sample", "norm", 0.0, 1.0).head();
double pValue1 = results.getDouble(0);
// Cannot reject null hypothesis
- assert(pValue1 > pThreshold);
+ Assert.assertTrue(pValue1 > pThreshold);
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
index cb9b8f9b6b472..dc38f17d296f2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.classification
import scala.util.Random
import breeze.linalg.{DenseVector => BDV}
+import org.scalatest.Assertions._
import org.apache.spark.ml.classification.LinearSVCSuite._
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 07116606dfb52..60c9cce6a4879 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -21,6 +21,8 @@ import scala.collection.JavaConverters._
import scala.util.Random
import scala.util.control.Breaks._
+import org.scalatest.Assertions._
+
import org.apache.spark.SparkException
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.classification.LogisticRegressionSuite._
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 9100ef1db6e12..4a555ad3ed071 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -22,15 +22,15 @@ import scala.util.Random
import breeze.linalg.{DenseVector => BDV, Vector => BV}
import breeze.stats.distributions.{Multinomial => BrzMultinomial, RandBasis => BrzRandBasis}
-import org.apache.spark.{SparkException, SparkFunSuite}
-import org.apache.spark.ml.classification.NaiveBayes.{Bernoulli, Multinomial}
+import org.apache.spark.SparkException
+import org.apache.spark.ml.classification.NaiveBayes._
import org.apache.spark.ml.classification.NaiveBayesSuite._
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{Dataset, Row}
class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
@@ -38,6 +38,9 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
@transient var dataset: Dataset[_] = _
@transient var bernoulliDataset: Dataset[_] = _
+ @transient var gaussianDataset: Dataset[_] = _
+ @transient var gaussianDataset2: Dataset[_] = _
+ @transient var complementDataset: Dataset[_] = _
private val seed = 42
@@ -53,6 +56,27 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
dataset = generateNaiveBayesInput(pi, theta, 100, seed).toDF()
bernoulliDataset = generateNaiveBayesInput(pi, theta, 100, seed, "bernoulli").toDF()
+
+ // theta for gaussian nb
+ val theta2 = Array(
+ Array(0.70, 0.10, 0.10, 0.10), // label 0: mean
+ Array(0.10, 0.70, 0.10, 0.10), // label 1: mean
+ Array(0.10, 0.10, 0.70, 0.10) // label 2: mean
+ )
+
+ // sigma for gaussian nb
+ val sigma = Array(
+ Array(0.10, 0.10, 0.50, 0.10), // label 0: variance
+ Array(0.50, 0.10, 0.10, 0.10), // label 1: variance
+ Array(0.10, 0.10, 0.10, 0.50) // label 2: variance
+ )
+ gaussianDataset = generateGaussianNaiveBayesInput(pi, theta2, sigma, 1000, seed).toDF()
+
+ gaussianDataset2 = spark.read.format("libsvm")
+ .load("../data/mllib/sample_multiclass_classification_data.txt")
+
+ complementDataset = spark.read.format("libsvm")
+ .load("../data/mllib/sample_libsvm_data.txt")
}
def validatePrediction(predictionAndLabels: Seq[Row]): Unit = {
@@ -67,10 +91,17 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
def validateModelFit(
piData: Vector,
thetaData: Matrix,
+ sigmaData: Matrix,
model: NaiveBayesModel): Unit = {
assert(Vectors.dense(model.pi.toArray.map(math.exp)) ~==
Vectors.dense(piData.toArray.map(math.exp)) absTol 0.05, "pi mismatch")
assert(model.theta.map(math.exp) ~== thetaData.map(math.exp) absTol 0.05, "theta mismatch")
+ if (sigmaData == null) {
+ assert(model.sigma == null, "sigma mismatch")
+ } else {
+ assert(model.sigma.map(math.exp) ~== sigmaData.map(math.exp) absTol 0.05,
+ "sigma mismatch")
+ }
}
def expectedMultinomialProbabilities(model: NaiveBayesModel, feature: Vector): Vector = {
@@ -90,6 +121,19 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
Vectors.dense(classProbs.map(_ / classProbsSum))
}
+ def expectedGaussianProbabilities(model: NaiveBayesModel, feature: Vector): Vector = {
+ val pi = model.pi.toArray.map(math.exp)
+ val classProbs = pi.indices.map { i =>
+ feature.toArray.zipWithIndex.map { case (v, j) =>
+ val mean = model.theta(i, j)
+ val variance = model.sigma(i, j)
+ math.exp(- (v - mean) * (v - mean) / variance / 2) / math.sqrt(variance * math.Pi * 2)
+ }.product * pi(i)
+ }.toArray
+ val classProbsSum = classProbs.sum
+ Vectors.dense(classProbs.map(_ / classProbsSum))
+ }
+
def validateProbabilities(
featureAndProbabilities: Seq[Row],
model: NaiveBayesModel,
@@ -102,6 +146,8 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
expectedMultinomialProbabilities(model, features)
case Bernoulli =>
expectedBernoulliProbabilities(model, features)
+ case Gaussian =>
+ expectedGaussianProbabilities(model, features)
case _ =>
throw new IllegalArgumentException(s"Invalid modelType: $modelType.")
}
@@ -112,12 +158,15 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
test("model types") {
assert(Multinomial === "multinomial")
assert(Bernoulli === "bernoulli")
+ assert(Gaussian === "gaussian")
+ assert(Complement === "complement")
}
test("params") {
ParamsSuite.checkParams(new NaiveBayes)
val model = new NaiveBayesModel("nb", pi = Vectors.dense(Array(0.2, 0.8)),
- theta = new DenseMatrix(2, 3, Array(0.1, 0.2, 0.3, 0.4, 0.6, 0.4)))
+ theta = new DenseMatrix(2, 3, Array(0.1, 0.2, 0.3, 0.4, 0.6, 0.4)),
+ sigma = null)
ParamsSuite.checkParams(model)
}
@@ -146,7 +195,7 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial")
val model = nb.fit(testDataset)
- validateModelFit(pi, theta, model)
+ validateModelFit(pi, theta, null, model)
assert(model.hasParent)
MLTestingUtils.checkCopyAndUids(nb, model)
@@ -175,8 +224,6 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
Array(0.10, 0.70, 0.10, 0.10), // label 1
Array(0.10, 0.10, 0.70, 0.10) // label 2
).map(_.map(math.log))
- val pi = Vectors.dense(piArray)
- val theta = new DenseMatrix(3, 4, thetaArray.flatten, true)
val trainDataset =
generateNaiveBayesInput(piArray, thetaArray, nPoints, seed, "multinomial").toDF()
@@ -192,12 +239,18 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
test("Naive Bayes with weighted samples") {
val numClasses = 3
def modelEquals(m1: NaiveBayesModel, m2: NaiveBayesModel): Unit = {
+ assert(m1.getModelType === m2.getModelType)
assert(m1.pi ~== m2.pi relTol 0.01)
assert(m1.theta ~== m2.theta relTol 0.01)
+ if (m1.getModelType == Gaussian) {
+ assert(m1.sigma ~== m2.sigma relTol 0.01)
+ }
}
val testParams = Seq[(String, Dataset[_])](
("bernoulli", bernoulliDataset),
- ("multinomial", dataset)
+ ("multinomial", dataset),
+ ("complement", dataset),
+ ("gaussian", gaussianDataset)
)
testParams.foreach { case (family, dataset) =>
// NaiveBayes is sensitive to constant scaling of the weights unless smoothing is set to 0
@@ -228,7 +281,7 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
val nb = new NaiveBayes().setSmoothing(1.0).setModelType("bernoulli")
val model = nb.fit(testDataset)
- validateModelFit(pi, theta, model)
+ validateModelFit(pi, theta, null, model)
assert(model.hasParent)
val validationDataset =
@@ -308,14 +361,168 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
}
}
+ test("Naive Bayes Gaussian") {
+ val piArray = Array(0.5, 0.1, 0.4).map(math.log)
+
+ val thetaArray = Array(
+ Array(0.70, 0.10, 0.10, 0.10), // label 0: mean
+ Array(0.10, 0.70, 0.10, 0.10), // label 1: mean
+ Array(0.10, 0.10, 0.70, 0.10) // label 2: mean
+ )
+
+ val sigmaArray = Array(
+ Array(0.10, 0.10, 0.50, 0.10), // label 0: variance
+ Array(0.50, 0.10, 0.10, 0.10), // label 1: variance
+ Array(0.10, 0.10, 0.10, 0.50) // label 2: variance
+ )
+
+ val pi = Vectors.dense(piArray)
+ val theta = new DenseMatrix(3, 4, thetaArray.flatten, true)
+ val sigma = new DenseMatrix(3, 4, sigmaArray.flatten, true)
+
+ val nPoints = 10000
+ val testDataset =
+ generateGaussianNaiveBayesInput(piArray, thetaArray, sigmaArray, nPoints, 42).toDF()
+ val gnb = new NaiveBayes().setModelType("gaussian")
+ val model = gnb.fit(testDataset)
+
+ validateModelFit(pi, theta, sigma, model)
+ assert(model.hasParent)
+
+ val validationDataset =
+ generateGaussianNaiveBayesInput(piArray, thetaArray, sigmaArray, nPoints, 17).toDF()
+
+ val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
+ validatePrediction(predictionAndLabels.collect())
+
+ val featureAndProbabilities = model.transform(validationDataset)
+ .select("features", "probability")
+ validateProbabilities(featureAndProbabilities.collect(), model, "gaussian")
+ }
+
+ test("Naive Bayes Gaussian - Model Coefficients") {
+ /*
+ Using the following Python code to verify the correctness.
+
+ import numpy as np
+ from sklearn.naive_bayes import GaussianNB
+ from sklearn.datasets import load_svmlight_file
+
+ path = "./data/mllib/sample_multiclass_classification_data.txt"
+ X, y = load_svmlight_file(path)
+ X = X.toarray()
+ clf = GaussianNB()
+ clf.fit(X, y)
+
+ >>> clf.class_prior_
+ array([0.33333333, 0.33333333, 0.33333333])
+ >>> clf.theta_
+ array([[ 0.27111101, -0.18833335, 0.54305072, 0.60500005],
+ [-0.60777778, 0.18166667, -0.84271174, -0.88000014],
+ [-0.09111114, -0.35833336, 0.10508474, 0.0216667 ]])
+ >>> clf.sigma_
+ array([[0.12230125, 0.07078052, 0.03430001, 0.05133607],
+ [0.03758145, 0.0988028 , 0.0033903 , 0.00782224],
+ [0.08058764, 0.06701387, 0.02486641, 0.02661392]])
+ */
+
+ val gnb = new NaiveBayes().setModelType(Gaussian)
+ val model = gnb.fit(gaussianDataset2)
+ assert(Vectors.dense(model.pi.toArray.map(math.exp)) ~=
+ Vectors.dense(0.33333333, 0.33333333, 0.33333333) relTol 1E-5)
+
+ val thetaRows = model.theta.rowIter.toArray
+ assert(thetaRows(0) ~=
+ Vectors.dense(0.27111101, -0.18833335, 0.54305072, 0.60500005) relTol 1E-5)
+ assert(thetaRows(1) ~=
+ Vectors.dense(-0.60777778, 0.18166667, -0.84271174, -0.88000014) relTol 1E-5)
+ assert(thetaRows(2) ~=
+ Vectors.dense(-0.09111114, -0.35833336, 0.10508474, 0.0216667) relTol 1E-5)
+
+ val sigmaRows = model.sigma.rowIter.toArray
+ assert(sigmaRows(0) ~=
+ Vectors.dense(0.12230125, 0.07078052, 0.03430001, 0.05133607) relTol 1E-5)
+ assert(sigmaRows(1) ~=
+ Vectors.dense(0.03758145, 0.0988028, 0.0033903, 0.00782224) relTol 1E-5)
+ assert(sigmaRows(2) ~=
+ Vectors.dense(0.08058764, 0.06701387, 0.02486641, 0.02661392) relTol 1E-5)
+ }
+
+ test("Naive Bayes Complement") {
+ /*
+ The following Python code was used to verify correctness.
+
+ import numpy as np
+ from sklearn.naive_bayes import ComplementNB
+ from sklearn.datasets import load_svmlight_file
+
+ path = "./data/mllib/sample_libsvm_data.txt"
+ X, y = load_svmlight_file(path)
+ X = X.toarray()
+ clf = ComplementNB()
+ clf.fit(X, y)
+
+ >>> clf.feature_log_prob_[:, -5:]
+ array([[ 7.2937608 , 10.26577655, 13.73151245, 13.73151245, 13.73151245],
+ [ 6.99678043, 7.51387415, 7.74399483, 8.32904552, 9.53119848]])
+ >>> clf.predict_log_proba(X[:5])
+ array([[ 0. , -74732.70765355],
+ [-36018.30169185, 0. ],
+ [-37126.4015229 , 0. ],
+ [-27649.81038619, 0. ],
+ [-28767.84075587, 0. ]])
+ >>> clf.predict_proba(X[:5])
+ array([[1., 0.],
+ [0., 1.],
+ [0., 1.],
+ [0., 1.],
+ [0., 1.]])
+ */
+
+ val cnb = new NaiveBayes().setModelType(Complement)
+ val model = cnb.fit(complementDataset)
+
+ val thetaRows = model.theta.rowIter.map(vec => Vectors.dense(vec.toArray.takeRight(5))).toArray
+ assert(thetaRows(0) ~=
+ Vectors.dense(7.2937608, 10.26577655, 13.73151245, 13.73151245, 13.73151245) relTol 1E-5)
+ assert(thetaRows(1) ~=
+ Vectors.dense(6.99678043, 7.51387415, 7.74399483, 8.32904552, 9.53119848) relTol 1E-5)
+
+ val preds = model.transform(complementDataset)
+ .select("rawPrediction", "probability")
+ .as[(Vector, Vector)]
+ .take(5)
+ assert(preds(0)._1 ~= Vectors.dense(0.0, -74732.70765355) relTol 1E-5)
+ assert(preds(0)._2 ~= Vectors.dense(1.0, 0.0) relTol 1E-5)
+ assert(preds(1)._1 ~= Vectors.dense(-36018.30169185, 0.0) relTol 1E-5)
+ assert(preds(1)._2 ~= Vectors.dense(0.0, 1.0) relTol 1E-5)
+ assert(preds(2)._1 ~= Vectors.dense(-37126.4015229, 0.0) relTol 1E-5)
+ assert(preds(2)._2 ~= Vectors.dense(0.0, 1.0) relTol 1E-5)
+ assert(preds(3)._1 ~= Vectors.dense(-27649.81038619, 0.0) relTol 1E-5)
+ assert(preds(3)._2 ~= Vectors.dense(0.0, 1.0) relTol 1E-5)
+ assert(preds(4)._1 ~= Vectors.dense(-28767.84075587, 0.0) relTol 1E-5)
+ assert(preds(4)._2 ~= Vectors.dense(0.0, 1.0) relTol 1E-5)
+ }
+
test("read/write") {
def checkModelData(model: NaiveBayesModel, model2: NaiveBayesModel): Unit = {
+ assert(model.getModelType === model2.getModelType)
assert(model.pi === model2.pi)
assert(model.theta === model2.theta)
+ if (model.getModelType == "gaussian") {
+ assert(model.sigma === model2.sigma)
+ } else {
+ assert(model.sigma === null && model2.sigma === null)
+ }
}
val nb = new NaiveBayes()
testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings,
NaiveBayesSuite.allParamSettings, checkModelData)
+
+ val gnb = new NaiveBayes().setModelType("gaussian")
+ testEstimatorAndModelReadWrite(gnb, gaussianDataset,
+ NaiveBayesSuite.allParamSettingsForGaussian,
+ NaiveBayesSuite.allParamSettingsForGaussian, checkModelData)
}
test("should support all NumericType labels and weights, and not support other types") {
@@ -324,6 +531,7 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
nb, spark) { (expected, actual) =>
assert(expected.pi === actual.pi)
assert(expected.theta === actual.theta)
+ assert(expected.sigma === null && actual.sigma === null)
}
}
}
@@ -340,6 +548,16 @@ object NaiveBayesSuite {
"smoothing" -> 0.1
)
+ /**
+ * Mapping from all Params to valid settings which differ from the defaults.
+ * This is useful for tests which need to exercise all Params, such as save/load.
+ * This excludes input columns to simplify some tests.
+ */
+ val allParamSettingsForGaussian: Map[String, Any] = Map(
+ "predictionCol" -> "myPrediction",
+ "modelType" -> "gaussian"
+ )
+
private def calcLabel(p: Double, pi: Array[Double]): Int = {
var sum = 0.0
for (j <- 0 until pi.length) {
@@ -384,4 +602,26 @@ object NaiveBayesSuite {
LabeledPoint(y, Vectors.dense(xi))
}
}
+
+ // Generate Gaussian NB input: feature j of a point with label y is drawn from N(theta(y)(j), sigma(y)(j))
+ def generateGaussianNaiveBayesInput(
+ pi: Array[Double], // 1XC
+ theta: Array[Array[Double]], // CXD
+ sigma: Array[Array[Double]], // CXD
+ nPoints: Int,
+ seed: Int): Seq[LabeledPoint] = {
+ val D = theta(0).length
+ val rnd = new Random(seed)
+ val _pi = pi.map(math.exp)
+
+ for (i <- 0 until nPoints) yield {
+ val y = calcLabel(rnd.nextDouble(), _pi)
+ val xi = Array.tabulate[Double] (D) { j =>
+ val mean = theta(y)(j)
+ val variance = sigma(y)(j)
+ mean + rnd.nextGaussian() * math.sqrt(variance)
+ }
+ LabeledPoint(y, Vectors.dense(xi))
+ }
+ }
}
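
The generateGaussianNaiveBayesInput helper above draws each feature from a per-class normal distribution with the given mean and variance. A rough Python equivalent (numpy-based, using a different RNG, so it will not reproduce the exact samples of the Scala helper):

    import numpy as np

    def generate_gaussian_nb_input(log_pi, theta, sigma, n_points, seed):
        # log_pi: log class priors; theta, sigma: per-class feature means/variances (C x D)
        rng = np.random.RandomState(seed)
        pi = np.exp(np.asarray(log_pi))
        theta, sigma = np.asarray(theta), np.asarray(sigma)
        labels = rng.choice(len(pi), size=n_points, p=pi / pi.sum())
        features = theta[labels] + rng.randn(n_points, theta.shape[1]) * np.sqrt(sigma[labels])
        return labels, features
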
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index b6e8c927403ad..adffd83ab1bd1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.classification
+import org.scalatest.Assertions._
+
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.classification.LogisticRegressionSuite._
import org.apache.spark.ml.feature.LabeledPoint
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
index 1c8c9829f18d1..87a8b345a65a3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.classification
+import org.scalatest.Assertions._
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala
index db4f56ed60d32..76a4acd798e34 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.feature
+import org.scalatest.Assertions._
+
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.util.{MLTestingUtils, SchemaUtils}
import org.apache.spark.sql.Dataset
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index d4e9da3c6263e..d96a4da46a630 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -21,6 +21,7 @@ import scala.collection.JavaConverters._
import scala.util.Random
import scala.util.control.Breaks._
+import org.scalatest.Assertions._
import org.scalatest.Matchers
import org.apache.spark.SparkFunSuite
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index 8906e52faebe5..321df05e272db 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering
import java.util.{ArrayList => JArrayList}
import breeze.linalg.{argmax, argtopk, max, DenseMatrix => BDM}
+import org.scalatest.Assertions
import org.apache.spark.SparkFunSuite
import org.apache.spark.graphx.Edge
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
index e10295c905cdb..a8c6339ba6824 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
@@ -254,4 +254,35 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
val metrics2 = new MulticlassMetrics(rdd2)
assert(metrics2.logLoss() ~== 0.9682005730687164 relTol delta)
}
+
+ test("MulticlassMetrics supports hammingLoss") {
+ /*
+ The following Python code was used to verify correctness.
+
+ from sklearn.metrics import hamming_loss
+ y_true = [2, 2, 3, 4]
+ y_pred = [1, 2, 3, 4]
+ weights = [1.5, 2.0, 1.0, 0.5]
+
+ >>> hamming_loss(y_true, y_pred)
+ 0.25
+ >>> hamming_loss(y_true, y_pred, sample_weight=weights)
+ 0.3
+ */
+
+ val preds = Seq(1.0, 2.0, 3.0, 4.0)
+ val labels = Seq(2.0, 2.0, 3.0, 4.0)
+ val weights = Seq(1.5, 2.0, 1.0, 0.5)
+
+ val rdd = sc.parallelize(preds.zip(labels))
+ val metrics = new MulticlassMetrics(rdd)
+ assert(metrics.hammingLoss ~== 0.25 relTol delta)
+
+ val rdd2 = sc.parallelize(preds.zip(labels).zip(weights))
+ .map { case ((pred, label), weight) =>
+ (pred, label, weight)
+ }
+ val metrics2 = new MulticlassMetrics(rdd2)
+ assert(metrics2.hammingLoss ~== 0.3 relTol delta)
+ }
}
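
The expected values 0.25 and 0.3 follow from the fact that, for single-label multiclass predictions, the (weighted) Hamming loss reduces to the weighted error rate. A small numpy sketch mirroring the sklearn check quoted in the comment:

    import numpy as np

    def hamming_loss(y_true, y_pred, sample_weight=None):
        # for single-label multiclass data this is just the weighted fraction of misclassified rows
        y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
        if sample_weight is None:
            sample_weight = np.ones(len(y_true))
        w = np.asarray(sample_weight, dtype=float)
        return np.sum(w * (y_true != y_pred)) / np.sum(w)

    # hamming_loss([2, 2, 3, 4], [1, 2, 3, 4])                        -> 0.25
    # hamming_loss([2, 2, 3, 4], [1, 2, 3, 4], [1.5, 2.0, 1.0, 0.5])  -> 0.3
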
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
index e04d7b7c327a8..5458a43b4f2c6 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
@@ -19,6 +19,8 @@ package org.apache.spark.mllib.tree
import scala.collection.mutable
+import org.scalatest.Assertions._
+
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.TreeEnsembleModel
diff --git a/pom.xml b/pom.xml
index 9ec7833427765..b0193a8ae5030 100644
--- a/pom.xml
+++ b/pom.xml
@@ -128,19 +128,19 @@
3.4.14
2.7.1
0.4.2
- org.spark-project.hive
-
+ org.apache.hive
+ core
- 1.2.1.spark2
+ 2.3.6
2.3.6
- 1.2.1
+ 2.3
2.3.1
10.12.1.1
1.10.1
1.5.7
- nohive
+
com.twitter
1.6.0
9.4.18.v20190429
@@ -181,7 +181,7 @@
3.8.1
2.6.2
- 3.2.10
+ 4.1.17
3.0.15
2.29
2.10.5
@@ -228,7 +228,7 @@
-->
compile
compile
- ${hive.deps.scope}
+ provided
compile
compile
test
@@ -2326,7 +2326,7 @@
**/*Suite.java
${project.build.directory}/surefire-reports
- -ea -Xmx4g -Xss4m -XX:ReservedCodeCacheSize=${CodeCacheSize}
+ -ea -Xmx4g -Xss4m -XX:ReservedCodeCacheSize=${CodeCacheSize} -Dio.netty.tryReflectionSetAccessible=true
- provided
-
- 4.1.17
+ 3.2.0
+ 2.13.0
+
+
+
+ hive-1.2
+
+ org.spark-project.hive
+
+
+ 1.2.1.spark2
+
+ 1.2
+ ${hive.deps.scope}
+ nohive
+ 3.2.10
+
+
+
+
+ hive-2.3
+
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 8830061c0d9ed..617eb173f4f49 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -118,6 +118,9 @@ object MimaExcludes {
// [SPARK-26632][Core] Separate Thread Configurations of Driver and Executor
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.SparkTransportConf.fromSparkConf"),
+ // [SPARK-16872][ML][PYSPARK] Impl Gaussian Naive Bayes Classifier
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.NaiveBayesModel.this"),
+
// [SPARK-25765][ML] Add training cost to BisectingKMeans summary
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.BisectingKMeansModel.this"),
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 91d3a75849b0c..8dda5809fa374 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -978,6 +978,7 @@ object TestSettings {
javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true",
javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=false",
javaOptions in Test += "-Dderby.system.durability=test",
+ javaOptions in Test += "-Dio.netty.tryReflectionSetAccessible=true",
javaOptions in Test ++= System.getProperties.asScala.filter(_._1.startsWith("spark"))
.map { case (k,v) => s"-D$k=$v" }.toSeq,
javaOptions in Test += "-ea",
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index a6aa3a65568e9..6cc343e3e495c 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -40,6 +40,7 @@
from pyspark.traceback_utils import CallSite, first_spark_call
from pyspark.status import StatusTracker
from pyspark.profiler import ProfilerCollector, BasicProfiler
+from pyspark.util import _warn_pin_thread
if sys.version > '3':
xrange = range
@@ -1008,30 +1009,20 @@ def setJobGroup(self, groupId, description, interruptOnCancel=False):
ensure that the tasks are actually stopped in a timely manner, but is off by default due
to HDFS-1208, where HDFS may respond to Thread.interrupt() by marking nodes as dead.
- .. note:: Currently, setting a group ID (set to local properties) with a thread does
- not properly work. Internally threads on PVM and JVM are not synced, and JVM thread
- can be reused for multiple threads on PVM, which fails to isolate local properties
- for each thread on PVM. To work around this, you can set `PYSPARK_PIN_THREAD` to
+ .. note:: Currently, setting a group ID (set to local properties) with multiple threads
+ does not properly work. Internally threads on PVM and JVM are not synced, and JVM
+ thread can be reused for multiple threads on PVM, which fails to isolate local
+ properties for each thread on PVM.
+
+ To work around this, you can set `PYSPARK_PIN_THREAD` to
`'true'` (see SPARK-22340). However, note that it cannot inherit the local properties
from the parent thread although it isolates each thread on PVM and JVM with its own
- local properties. To work around this, you should manually copy and set the local
+ local properties.
+
+ To work around this, you should manually copy and set the local
properties from the parent thread to the child thread when you create another thread.
"""
- warnings.warn(
- "Currently, setting a group ID (set to local properties) with a thread does "
- "not properly work. "
- "\n"
- "Internally threads on PVM and JVM are not synced, and JVM thread can be reused "
- "for multiple threads on PVM, which fails to isolate local properties for each "
- "thread on PVM. "
- "\n"
- "To work around this, you can set PYSPARK_PIN_THREAD to true (see SPARK-22340). "
- "However, note that it cannot inherit the local properties from the parent thread "
- "although it isolates each thread on PVM and JVM with its own local properties. "
- "\n"
- "To work around this, you should manually copy and set the local properties from "
- "the parent thread to the child thread when you create another thread.",
- UserWarning)
+ _warn_pin_thread("setJobGroup")
self._jsc.setJobGroup(groupId, description, interruptOnCancel)
def setLocalProperty(self, key, value):
@@ -1039,29 +1030,20 @@ def setLocalProperty(self, key, value):
Set a local property that affects jobs submitted from this thread, such as the
Spark fair scheduler pool.
- .. note:: Currently, setting a local property with a thread does
- not properly work. Internally threads on PVM and JVM are not synced, and JVM thread
+ .. note:: Currently, setting a local property with multiple threads does not properly work.
+ Internally threads on PVM and JVM are not synced, and JVM thread
can be reused for multiple threads on PVM, which fails to isolate local properties
- for each thread on PVM. To work around this, you can set `PYSPARK_PIN_THREAD` to
+ for each thread on PVM.
+
+ To work around this, you can set `PYSPARK_PIN_THREAD` to
`'true'` (see SPARK-22340). However, note that it cannot inherit the local properties
from the parent thread although it isolates each thread on PVM and JVM with its own
- local properties. To work around this, you should manually copy and set the local
+ local properties.
+
+ To work around this, you should manually copy and set the local
properties from the parent thread to the child thread when you create another thread.
"""
- warnings.warn(
- "Currently, setting a local property with a thread does not properly work. "
- "\n"
- "Internally threads on PVM and JVM are not synced, and JVM thread can be reused "
- "for multiple threads on PVM, which fails to isolate local properties for each "
- "thread on PVM. "
- "\n"
- "To work around this, you can set PYSPARK_PIN_THREAD to true (see SPARK-22340). "
- "However, note that it cannot inherit the local properties from the parent thread "
- "although it isolates each thread on PVM and JVM with its own local properties. "
- "\n"
- "To work around this, you should manually copy and set the local properties from "
- "the parent thread to the child thread when you create another thread.",
- UserWarning)
+ _warn_pin_thread("setLocalProperty")
self._jsc.setLocalProperty(key, value)
def getLocalProperty(self, key):
@@ -1075,30 +1057,20 @@ def setJobDescription(self, value):
"""
Set a human readable description of the current job.
- .. note:: Currently, setting a job description (set to local properties) with a thread does
- not properly work. Internally threads on PVM and JVM are not synced, and JVM thread
- can be reused for multiple threads on PVM, which fails to isolate local properties
- for each thread on PVM. To work around this, you can set `PYSPARK_PIN_THREAD` to
+ .. note:: Currently, setting a job description (set to local properties) with multiple
+ threads does not properly work. Internally threads on PVM and JVM are not synced,
+ and JVM thread can be reused for multiple threads on PVM, which fails to isolate
+ local properties for each thread on PVM.
+
+ To work around this, you can set `PYSPARK_PIN_THREAD` to
`'true'` (see SPARK-22340). However, note that it cannot inherit the local properties
from the parent thread although it isolates each thread on PVM and JVM with its own
- local properties. To work around this, you should manually copy and set the local
+ local properties.
+
+ To work around this, you should manually copy and set the local
properties from the parent thread to the child thread when you create another thread.
"""
- warnings.warn(
- "Currently, setting a job description (set to local properties) with a thread does "
- "not properly work. "
- "\n"
- "Internally threads on PVM and JVM are not synced, and JVM thread can be reused "
- "for multiple threads on PVM, which fails to isolate local properties for each "
- "thread on PVM. "
- "\n"
- "To work around this, you can set PYSPARK_PIN_THREAD to true (see SPARK-22340). "
- "However, note that it cannot inherit the local properties from the parent thread "
- "although it isolates each thread on PVM and JVM with its own local properties. "
- "\n"
- "To work around this, you should manually copy and set the local properties from "
- "the parent thread to the child thread when you create another thread.",
- UserWarning)
+ _warn_pin_thread("setJobDescription")
self._jsc.setJobDescription(value)
def sparkUser(self):
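
The updated notes above recommend manually copying local properties from the parent thread into any child thread when `PYSPARK_PIN_THREAD` is enabled. A minimal sketch of that workaround (the property name is only an example):

    import threading

    def run_in_child_thread(sc, local_props):
        # local_props: values captured in the parent thread, e.g.
        # {"spark.scheduler.pool": sc.getLocalProperty("spark.scheduler.pool")}
        def child():
            # child threads do not inherit local properties, so re-set them explicitly
            for key, value in local_props.items():
                if value is not None:
                    sc.setLocalProperty(key, value)
            sc.parallelize(range(10)).count()  # jobs here run with the copied properties
        t = threading.Thread(target=child)
        t.start()
        t.join()
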
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index f9465bffc9a1a..d6fe26dc69da8 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -192,11 +192,11 @@ class LinearSVC(JavaClassifier, _LinearSVCParams, JavaMLWritable, JavaMLReadable
0.01
>>> model = svm.fit(df)
>>> model.setPredictionCol("newPrediction")
- LinearSVC...
+ LinearSVCModel...
>>> model.getPredictionCol()
'newPrediction'
>>> model.setThreshold(0.5)
- LinearSVC...
+ LinearSVCModel...
>>> model.getThreshold()
0.5
>>> model.coefficients
@@ -812,9 +812,6 @@ def evaluate(self, dataset):
java_blr_summary = self._call_java("evaluate", dataset)
return BinaryLogisticRegressionSummary(java_blr_summary)
- def __repr__(self):
- return self._call_java("toString")
-
class LogisticRegressionSummary(JavaWrapper):
"""
@@ -1881,7 +1878,8 @@ class _NaiveBayesParams(_JavaPredictorParams, HasWeightCol):
smoothing = Param(Params._dummy(), "smoothing", "The smoothing parameter, should be >= 0, " +
"default is 1.0", typeConverter=TypeConverters.toFloat)
modelType = Param(Params._dummy(), "modelType", "The model type which is a string " +
- "(case-sensitive). Supported options: multinomial (default) and bernoulli.",
+ "(case-sensitive). Supported options: multinomial (default), bernoulli " +
+ "and gaussian.",
typeConverter=TypeConverters.toString)
@since("1.5.0")
@@ -1910,7 +1908,15 @@ class NaiveBayes(JavaProbabilisticClassifier, _NaiveBayesParams, HasThresholds,
TF-IDF vectors, it can be used for document classification. By making every vector a
binary (0/1) data, it can also be used as `Bernoulli NB
`_.
- The input feature values must be nonnegative.
+ The input feature values for Multinomial NB and Bernoulli NB must be nonnegative.
+ Since 3.0.0, it supports Complement NB, which is an adaptation of Multinomial NB.
+ Specifically, Complement NB uses statistics from the complement of each class to compute
+ the model's coefficients. The inventors of Complement NB show empirically that the parameter
+ estimates for CNB are more stable than those for Multinomial NB. Like Multinomial NB, the
+ input feature values for Complement NB must be nonnegative.
+ Since 3.0.0, it also supports Gaussian NB
+ `_, which can handle continuous data.
>>> from pyspark.sql import Row
>>> from pyspark.ml.linalg import Vectors
@@ -1921,13 +1927,15 @@ class NaiveBayes(JavaProbabilisticClassifier, _NaiveBayesParams, HasThresholds,
>>> nb = NaiveBayes(smoothing=1.0, modelType="multinomial", weightCol="weight")
>>> model = nb.fit(df)
>>> model.setFeaturesCol("features")
- NaiveBayes_...
+ NaiveBayesModel...
>>> model.getSmoothing()
1.0
>>> model.pi
DenseVector([-0.81..., -0.58...])
>>> model.theta
DenseMatrix(2, 2, [-0.91..., -0.51..., -0.40..., -1.09...], 1)
+ >>> model.sigma == None
+ True
>>> test0 = sc.parallelize([Row(features=Vectors.dense([1.0, 0.0]))]).toDF()
>>> model.predict(test0.head().features)
1.0
@@ -1958,6 +1966,20 @@ class NaiveBayes(JavaProbabilisticClassifier, _NaiveBayesParams, HasThresholds,
>>> result = model3.transform(test0).head()
>>> result.prediction
0.0
+ >>> nb3 = NaiveBayes().setModelType("gaussian")
+ >>> model4 = nb3.fit(df)
+ >>> model4.getModelType()
+ 'gaussian'
+ >>> model4.sigma
+ DenseMatrix(2, 2, [0.0, 0.25, 0.0, 0.0], 1)
+ >>> nb5 = NaiveBayes(smoothing=1.0, modelType="complement", weightCol="weight")
+ >>> model5 = nb5.fit(df)
+ >>> model5.getModelType()
+ 'complement'
+ >>> model5.theta
+ DenseMatrix(2, 2, [...], 1)
+ >>> model5.sigma == None
+ True
.. versionadded:: 1.5.0
"""
@@ -2040,6 +2062,14 @@ def theta(self):
"""
return self._call_java("theta")
+ @property
+ @since("3.0.0")
+ def sigma(self):
+ """
+ variance of each feature.
+ """
+ return self._call_java("sigma")
+
class _MultilayerPerceptronParams(_JavaProbabilisticClassifierParams, HasSeed, HasMaxIter,
HasTol, HasStepSize, HasSolver):
@@ -2114,7 +2144,7 @@ class MultilayerPerceptronClassifier(JavaProbabilisticClassifier, _MultilayerPer
100
>>> model = mlp.fit(df)
>>> model.setFeaturesCol("features")
- MultilayerPerceptronClassifier...
+ MultilayerPerceptronClassificationModel...
>>> model.layers
[2, 2, 2]
>>> model.weights.size
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 39cc62670ae88..5aab7a3f5077b 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -234,7 +234,7 @@ class GaussianMixture(JavaEstimator, _GaussianMixtureParams, JavaMLWritable, Jav
>>> model.getFeaturesCol()
'features'
>>> model.setPredictionCol("newPrediction")
- GaussianMixture...
+ GaussianMixtureModel...
>>> model.predict(df.head().features)
2
>>> model.predictProbability(df.head().features)
@@ -532,7 +532,7 @@ class KMeans(JavaEstimator, _KMeansParams, JavaMLWritable, JavaMLReadable):
>>> model.getDistanceMeasure()
'euclidean'
>>> model.setPredictionCol("newPrediction")
- KMeans...
+ KMeansModel...
>>> model.predict(df.head().features)
0
>>> centers = model.clusterCenters()
@@ -794,7 +794,7 @@ class BisectingKMeans(JavaEstimator, _BisectingKMeansParams, JavaMLWritable, Jav
>>> model.getMaxIter()
20
>>> model.setPredictionCol("newPrediction")
- BisectingKMeans...
+ BisectingKMeansModel...
>>> model.predict(df.head().features)
0
>>> centers = model.clusterCenters()
@@ -1265,6 +1265,8 @@ class LDA(JavaEstimator, _LDAParams, JavaMLReadable, JavaMLWritable):
10
>>> lda.clear(lda.maxIter)
>>> model = lda.fit(df)
+ >>> model.setSeed(1)
+ DistributedLDAModel...
>>> model.getTopicDistributionCol()
'topicDistribution'
>>> model.isDistributed()
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index 6539e2abaed12..556a2f85c708d 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -374,6 +374,10 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
>>> evaluator.evaluate(dataset, {evaluator.metricName: "truePositiveRateByLabel",
... evaluator.metricLabel: 1.0})
0.75...
+ >>> evaluator.setMetricName("hammingLoss")
+ MulticlassClassificationEvaluator...
+ >>> evaluator.evaluate(dataset)
+ 0.33...
>>> mce_path = temp_path + "/mce"
>>> evaluator.save(mce_path)
>>> evaluator2 = MulticlassClassificationEvaluator.load(mce_path)
@@ -408,7 +412,7 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
"(f1|accuracy|weightedPrecision|weightedRecall|weightedTruePositiveRate|"
"weightedFalsePositiveRate|weightedFMeasure|truePositiveRateByLabel|"
"falsePositiveRateByLabel|precisionByLabel|recallByLabel|fMeasureByLabel|"
- "logLoss)",
+ "logLoss|hammingLoss)",
typeConverter=TypeConverters.toString)
metricLabel = Param(Params._dummy(), "metricLabel",
"The class whose metric will be computed in truePositiveRateByLabel|"
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index f6e531302317b..e771221d5f06d 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -337,6 +337,8 @@ class BucketedRandomProjectionLSH(_LSH, _BucketedRandomProjectionLSHParams,
>>> model = brp.fit(df)
>>> model.getBucketLength()
1.0
+ >>> model.setOutputCol("hashes")
+ BucketedRandomProjectionLSHModel...
>>> model.transform(df).head()
Row(id=0, features=DenseVector([-1.0, -1.0]), hashes=[DenseVector([-1.0])])
>>> data2 = [(4, Vectors.dense([2.0, 2.0 ]),),
@@ -733,6 +735,8 @@ class CountVectorizer(JavaEstimator, _CountVectorizerParams, JavaMLReadable, Jav
>>> cv.setOutputCol("vectors")
CountVectorizer...
>>> model = cv.fit(df)
+ >>> model.setInputCol("raw")
+ CountVectorizerModel...
>>> model.transform(df).show(truncate=False)
+-----+---------------+-------------------------+
|label|raw |vectors |
@@ -1345,6 +1349,8 @@ class IDF(JavaEstimator, _IDFParams, JavaMLReadable, JavaMLWritable):
>>> idf.setOutputCol("idf")
IDF...
>>> model = idf.fit(df)
+ >>> model.setOutputCol("idf")
+ IDFModel...
>>> model.getMinDocFreq()
3
>>> model.idf
@@ -1519,6 +1525,8 @@ class Imputer(JavaEstimator, _ImputerParams, JavaMLReadable, JavaMLWritable):
>>> imputer.getRelativeError()
0.001
>>> model = imputer.fit(df)
+ >>> model.setInputCols(["a", "b"])
+ ImputerModel...
>>> model.getStrategy()
'mean'
>>> model.surrogateDF.show()
@@ -1810,7 +1818,7 @@ class MaxAbsScaler(JavaEstimator, _MaxAbsScalerParams, JavaMLReadable, JavaMLWri
MaxAbsScaler...
>>> model = maScaler.fit(df)
>>> model.setOutputCol("scaledOutput")
- MaxAbsScaler...
+ MaxAbsScalerModel...
>>> model.transform(df).show()
+-----+------------+
| a|scaledOutput|
@@ -1928,6 +1936,8 @@ class MinHashLSH(_LSH, HasInputCol, HasOutputCol, HasSeed, JavaMLReadable, JavaM
>>> mh.setSeed(12345)
MinHashLSH...
>>> model = mh.fit(df)
+ >>> model.setInputCol("features")
+ MinHashLSHModel...
>>> model.transform(df).head()
Row(id=0, features=SparseVector(6, {0: 1.0, 1: 1.0, 2: 1.0}), hashes=[DenseVector([6179668...
>>> data2 = [(3, Vectors.sparse(6, [1, 3, 5], [1.0, 1.0, 1.0]),),
@@ -2056,7 +2066,7 @@ class MinMaxScaler(JavaEstimator, _MinMaxScalerParams, JavaMLReadable, JavaMLWri
MinMaxScaler...
>>> model = mmScaler.fit(df)
>>> model.setOutputCol("scaledOutput")
- MinMaxScaler...
+ MinMaxScalerModel...
>>> model.originalMin
DenseVector([0.0])
>>> model.originalMax
@@ -2421,6 +2431,8 @@ class OneHotEncoder(JavaEstimator, _OneHotEncoderParams, JavaMLReadable, JavaMLW
>>> ohe.setOutputCols(["output"])
OneHotEncoder...
>>> model = ohe.fit(df)
+ >>> model.setOutputCols(["output"])
+ OneHotEncoderModel...
>>> model.getHandleInvalid()
'error'
>>> model.transform(df).head().output
@@ -2935,7 +2947,7 @@ class RobustScaler(JavaEstimator, _RobustScalerParams, JavaMLReadable, JavaMLWri
RobustScaler...
>>> model = scaler.fit(df)
>>> model.setOutputCol("output")
- RobustScaler...
+ RobustScalerModel...
>>> model.median
DenseVector([2.0, -2.0])
>>> model.range
@@ -3330,7 +3342,7 @@ class StandardScaler(JavaEstimator, _StandardScalerParams, JavaMLReadable, JavaM
>>> model.getInputCol()
'a'
>>> model.setOutputCol("output")
- StandardScaler...
+ StandardScalerModel...
>>> model.mean
DenseVector([1.0])
>>> model.std
@@ -3490,6 +3502,8 @@ class StringIndexer(JavaEstimator, _StringIndexerParams, JavaMLReadable, JavaMLW
>>> stringIndexer.setHandleInvalid("error")
StringIndexer...
>>> model = stringIndexer.fit(stringIndDf)
+ >>> model.setHandleInvalid("error")
+ StringIndexerModel...
>>> td = model.transform(stringIndDf)
>>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]),
... key=lambda x: x[0])
@@ -4166,7 +4180,7 @@ class VectorIndexer(JavaEstimator, _VectorIndexerParams, JavaMLReadable, JavaMLW
>>> indexer.getHandleInvalid()
'error'
>>> model.setOutputCol("output")
- VectorIndexer...
+ VectorIndexerModel...
>>> model.transform(df).head().output
DenseVector([1.0, 0.0])
>>> model.numFeatures
@@ -4487,6 +4501,8 @@ class Word2Vec(JavaEstimator, _Word2VecParams, JavaMLReadable, JavaMLWritable):
>>> model = word2Vec.fit(doc)
>>> model.getMinCount()
5
+ >>> model.setInputCol("sentence")
+ Word2VecModel...
>>> model.getVectors().show()
+----+--------------------+
|word| vector|
@@ -4714,7 +4730,7 @@ class PCA(JavaEstimator, _PCAParams, JavaMLReadable, JavaMLWritable):
>>> model.getK()
2
>>> model.setOutputCol("output")
- PCA...
+ PCAModel...
>>> model.transform(df).collect()[0].output
DenseVector([1.648..., -4.013...])
>>> model.explainedVariance
@@ -5139,6 +5155,8 @@ class ChiSqSelector(JavaEstimator, _ChiSqSelectorParams, JavaMLReadable, JavaMLW
>>> model = selector.fit(df)
>>> model.getFeaturesCol()
'features'
+ >>> model.setFeaturesCol("features")
+ ChiSqSelectorModel...
>>> model.transform(df).head().selectedFeatures
DenseVector([18.0])
>>> model.selectedFeatures
diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py
index 5b34d555484d1..7d933daf9e032 100644
--- a/python/pyspark/ml/fpm.py
+++ b/python/pyspark/ml/fpm.py
@@ -166,7 +166,7 @@ class FPGrowth(JavaEstimator, _FPGrowthParams, JavaMLWritable, JavaMLReadable):
>>> fp = FPGrowth(minSupport=0.2, minConfidence=0.7)
>>> fpm = fp.fit(data)
>>> fpm.setPredictionCol("newPrediction")
- FPGrowth...
+ FPGrowthModel...
>>> fpm.freqItemsets.show(5)
+---------+----+
| items|freq|
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index 57ad1e6dfb3e6..fe61f9f0fffd6 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -484,8 +484,16 @@ def _copyValues(self, to, extra=None):
:return: the target instance with param values copied
"""
paramMap = self._paramMap.copy()
- if extra is not None:
- paramMap.update(extra)
+ if isinstance(extra, dict):
+ for param, value in extra.items():
+ if isinstance(param, Param):
+ paramMap[param] = value
+ else:
+ raise TypeError("Expecting a valid instance of Param, but received: {}"
+ .format(param))
+ elif extra is not None:
+ raise TypeError("Expecting a dict, but received an object of type {}."
+ .format(type(extra)))
for param in self.params:
# copy default params
if param in self._defaultParamMap and to.hasParam(param.name):
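
With the stricter `_copyValues` above, `extra` must be a dict keyed by Param instances. A short usage sketch (assumes an active SparkSession, since pyspark estimators are JVM-backed):

    from pyspark.ml.classification import LogisticRegression

    lr = LogisticRegression(maxIter=5)
    lr2 = lr.copy({lr.maxIter: 10})   # OK: the key is a Param instance
    print(lr2.getMaxIter())           # 10

    try:
        lr.copy({"maxIter": 10})      # plain string keys are now rejected
    except TypeError as e:
        print(e)
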
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index 3ebd0ac2765f3..ee276962c898b 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -225,6 +225,8 @@ class ALS(JavaEstimator, _ALSParams, JavaMLWritable, JavaMLReadable):
>>> model = als.fit(df)
>>> model.getUserCol()
'user'
+ >>> model.setUserCol("user")
+ ALSModel...
>>> model.getItemCol()
'item'
>>> model.setPredictionCol("newPrediction")
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 84e39a035d80b..fdb04bb5115c5 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -105,9 +105,9 @@ class LinearRegression(JavaPredictor, _LinearRegressionParams, JavaMLWritable, J
LinearRegression...
>>> model = lr.fit(df)
>>> model.setFeaturesCol("features")
- LinearRegression...
+ LinearRegressionModel...
>>> model.setPredictionCol("newPrediction")
- LinearRegression...
+ LinearRegressionModel...
>>> model.getMaxIter()
5
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
@@ -591,7 +591,7 @@ class IsotonicRegression(JavaEstimator, _IsotonicRegressionParams, HasWeightCol,
>>> ir = IsotonicRegression()
>>> model = ir.fit(df)
>>> model.setFeaturesCol("features")
- IsotonicRegression...
+ IsotonicRegressionModel...
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
@@ -1546,7 +1546,7 @@ class AFTSurvivalRegression(JavaEstimator, _AFTSurvivalRegressionParams,
>>> aftsr.clear(aftsr.maxIter)
>>> model = aftsr.fit(df)
>>> model.setFeaturesCol("features")
- AFTSurvivalRegression...
+ AFTSurvivalRegressionModel...
>>> model.predict(Vectors.dense(6.3))
1.0
>>> model.predictQuantiles(Vectors.dense(6.3))
@@ -1881,7 +1881,7 @@ class GeneralizedLinearRegression(JavaPredictor, _GeneralizedLinearRegressionPar
>>> glr.clear(glr.maxIter)
>>> model = glr.fit(df)
>>> model.setFeaturesCol("features")
- GeneralizedLinearRegression...
+ GeneralizedLinearRegressionModel...
>>> model.getMaxIter()
25
>>> model.getAggregationDepth()
diff --git a/python/pyspark/ml/tests/test_param.py b/python/pyspark/ml/tests/test_param.py
index 75cd903b5d6d7..777b4930ce8c9 100644
--- a/python/pyspark/ml/tests/test_param.py
+++ b/python/pyspark/ml/tests/test_param.py
@@ -307,6 +307,10 @@ def test_copy_param_extras(self):
copied_no_extra[k] = v
self.assertEqual(tp._paramMap, copied_no_extra)
self.assertEqual(tp._defaultParamMap, tp_copy._defaultParamMap)
+ with self.assertRaises(TypeError):
+ tp.copy(extra={"unknown_parameter": None})
+ with self.assertRaises(TypeError):
+ tp.copy(extra=["must be a dict"])
def test_logistic_regression_check_thresholds(self):
self.assertIsInstance(
diff --git a/python/pyspark/ml/tests/test_tuning.py b/python/pyspark/ml/tests/test_tuning.py
index 176e99d052d30..9d8ba37c60da4 100644
--- a/python/pyspark/ml/tests/test_tuning.py
+++ b/python/pyspark/ml/tests/test_tuning.py
@@ -63,6 +63,15 @@ def _fit(self, dataset):
return model
+class ParamGridBuilderTests(SparkSessionTestCase):
+
+ def test_addGrid(self):
+ with self.assertRaises(TypeError):
+ grid = (ParamGridBuilder()
+ .addGrid("must be an instance of Param", ["not", "string"])
+ .build())
+
+
class CrossValidatorTests(SparkSessionTestCase):
def test_copy(self):
diff --git a/python/pyspark/ml/tree.py b/python/pyspark/ml/tree.py
index f38a7375c2c54..d97a950c9276e 100644
--- a/python/pyspark/ml/tree.py
+++ b/python/pyspark/ml/tree.py
@@ -56,9 +56,6 @@ def predictLeaf(self, value):
"""
return self._call_java("predictLeaf", value)
- def __repr__(self):
- return self._call_java("toString")
-
class _DecisionTreeParams(HasCheckpointInterval, HasSeed, HasWeightCol):
"""
@@ -208,9 +205,6 @@ def predictLeaf(self, value):
"""
return self._call_java("predictLeaf", value)
- def __repr__(self):
- return self._call_java("toString")
-
class _TreeEnsembleParams(_DecisionTreeParams):
"""
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 16c376296c20d..5eb8ae44d3d66 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -88,8 +88,14 @@ def __init__(self):
def addGrid(self, param, values):
"""
Sets the given parameters in this grid to fixed values.
+
+ param must be an instance of Param associated with an instance of Params
+ (such as Estimator or Transformer).
"""
- self._param_grid[param] = values
+ if isinstance(param, Param):
+ self._param_grid[param] = values
+ else:
+ raise TypeError("param must be an instance of Param")
return self
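
Similarly, `addGrid` now fails fast when the key is not a Param instance. A short sketch (again assuming an active SparkSession):

    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.tuning import ParamGridBuilder

    lr = LogisticRegression()
    grid = (ParamGridBuilder()
            .addGrid(lr.regParam, [0.1, 0.01])        # keys are Param instances
            .addGrid(lr.fitIntercept, [False, True])
            .build())
    print(len(grid))                                  # 4 parameter combinations

    # ParamGridBuilder().addGrid("regParam", [0.1, 0.01])  # would raise TypeError
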
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 47e4921541ea2..ae3a6ba24ffa5 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -372,6 +372,9 @@ def __init__(self, java_model=None):
self._resetUid(java_model.uid())
+ def __repr__(self):
+ return self._call_java("toString")
+
@inherit_doc
class _JavaPredictorParams(HasLabelCol, HasFeaturesCol, HasPredictionCol):
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index ad4cf7c9b1c8b..93137560de25e 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -19,6 +19,8 @@
import re
import sys
import traceback
+import os
+import warnings
import inspect
from py4j.protocol import Py4JJavaError
@@ -112,6 +114,33 @@ def wrapper(*args, **kwargs):
return wrapper
+def _warn_pin_thread(name):
+ if os.environ.get("PYSPARK_PIN_THREAD", "false").lower() == "true":
+ msg = (
+ "PYSPARK_PIN_THREAD feature is enabled. "
+ "However, note that it cannot inherit the local properties from the parent thread "
+ "although it isolates each thread on PVM and JVM with its own local properties. "
+ "\n"
+ "To work around this, you should manually copy and set the local properties from "
+ "the parent thread to the child thread when you create another thread.")
+ else:
+ msg = (
+ "Currently, '%s' (set to local properties) with multiple threads does "
+ "not properly work. "
+ "\n"
+ "Internally threads on PVM and JVM are not synced, and JVM thread can be reused "
+ "for multiple threads on PVM, which fails to isolate local properties for each "
+ "thread on PVM. "
+ "\n"
+ "To work around this, you can set PYSPARK_PIN_THREAD to true (see SPARK-22340). "
+ "However, note that it cannot inherit the local properties from the parent thread "
+ "although it isolates each thread on PVM and JVM with its own local properties. "
+ "\n"
+ "To work around this, you should manually copy and set the local properties from "
+ "the parent thread to the child thread when you create another thread." % name)
+ warnings.warn(msg, UserWarning)
+
+
def _print_missing_jar(lib_name, pkg_name, jar_name, spark_version):
print("""
________________________________________________________________________________________________
diff --git a/python/run-tests.py b/python/run-tests.py
index 5bcf8b0669129..88b148c6587d5 100755
--- a/python/run-tests.py
+++ b/python/run-tests.py
@@ -86,9 +86,10 @@ def run_individual_python_test(target_dir, test_name, pyspark_python):
env["TMPDIR"] = tmp_dir
# Also override the JVM's temp directory by setting driver and executor options.
+ java_options = "-Djava.io.tmpdir={0} -Dio.netty.tryReflectionSetAccessible=true".format(tmp_dir)
spark_args = [
- "--conf", "spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir),
- "--conf", "spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir),
+ "--conf", "spark.driver.extraJavaOptions='{0}'".format(java_options),
+ "--conf", "spark.executor.extraJavaOptions='{0}'".format(java_options),
"pyspark-shell"
]
env["PYSPARK_SUBMIT_ARGS"] = " ".join(spark_args)
diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala
index 65e595e3cf2bf..5a4bf1dd2d409 100644
--- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala
+++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala
@@ -28,6 +28,7 @@ import org.apache.mesos.protobuf.ByteString
import org.mockito.ArgumentCaptor
import org.mockito.ArgumentMatchers.{any, eq => meq}
import org.mockito.Mockito.{times, verify}
+import org.scalatest.Assertions._
import org.apache.spark.deploy.mesos.config.MesosSecretConfig
@@ -161,12 +162,14 @@ object Utils {
val variableOne = envVars.filter(_.getName == "USER").head
assert(variableOne.getSecret.isInitialized)
assert(variableOne.getSecret.getType == Secret.Type.VALUE)
- assert(variableOne.getSecret.getValue.getData == ByteString.copyFrom("user".getBytes))
+ assert(variableOne.getSecret.getValue.getData ==
+ ByteString.copyFrom("user".getBytes))
assert(variableOne.getType == Environment.Variable.Type.SECRET)
val variableTwo = envVars.filter(_.getName == "PASSWORD").head
assert(variableTwo.getSecret.isInitialized)
assert(variableTwo.getSecret.getType == Secret.Type.VALUE)
- assert(variableTwo.getSecret.getValue.getData == ByteString.copyFrom("password".getBytes))
+ assert(variableTwo.getSecret.getValue.getData ==
+ ByteString.copyFrom("password".getBytes))
assert(variableTwo.getType == Environment.Variable.Type.SECRET)
}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index a1b5d53b91416..696afaacb0e79 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -538,7 +538,11 @@ private[spark] class Client(
if (!Utils.isLocalUri(jar)) {
val path = getQualifiedLocalPath(Utils.resolveURI(jar), hadoopConf)
val pathFs = FileSystem.get(path.toUri(), hadoopConf)
- pathFs.globStatus(path).filter(_.isFile()).foreach { entry =>
+ val fss = pathFs.globStatus(path)
+ if (fss == null) {
+ throw new FileNotFoundException(s"Path ${path.toString} does not exist")
+ }
+ fss.filter(_.isFile()).foreach { entry =>
val uri = entry.getPath().toUri()
statCache.update(uri, entry)
distribute(uri.toString(), targetDir = Some(LOCALIZED_LIB_DIR))
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
index ba7c2dd8a1cdf..7cce908cd5fb7 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.deploy.yarn
-import java.io.{File, FileInputStream, FileOutputStream}
+import java.io.{File, FileInputStream, FileNotFoundException, FileOutputStream}
import java.net.URI
import java.util.Properties
@@ -473,6 +473,18 @@ class ClientSuite extends SparkFunSuite with Matchers {
assert(allResourceInfo.get(yarnMadeupResource).get === 5)
}
+ test("test yarn jars path not exists") {
+ withTempDir { dir =>
+ val conf = new SparkConf().set(SPARK_JARS, Seq(dir.getAbsolutePath + "/test"))
+ val client = new Client(new ClientArguments(Array()), conf, null)
+ withTempDir { distDir =>
+ intercept[FileNotFoundException] {
+ client.prepareLocalResources(new Path(distDir.getAbsolutePath), Nil)
+ }
+ }
+ }
+ }
+
private val matching = Seq(
("files URI match test1", "file:///file1", "file:///file2"),
("files URI match test2", "file:///c:file1", "file://c:file2"),
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index bcebb225dfaca..b416994195d01 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -148,7 +148,7 @@
org.scalatest
scalatest-maven-plugin
- -ea -Xmx4g -Xss4m -XX:ReservedCodeCacheSize=${CodeCacheSize}
+ -ea -Xmx4g -Xss4m -XX:ReservedCodeCacheSize=${CodeCacheSize} -Dio.netty.tryReflectionSetAccessible=true
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index 49fba6b7f35df..cc273fd36011e 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -46,9 +46,9 @@ grammar SqlBase;
}
/**
- * When true, ANSI SQL parsing mode is enabled.
+ * When true, the behavior of keywords follows ANSI SQL standard.
*/
- public boolean ansi = false;
+ public boolean SQL_standard_keyword_behavior = false;
}
singleStatement
@@ -79,10 +79,6 @@ singleTableSchema
: colTypeList EOF
;
-singleInterval
- : INTERVAL? multiUnitsInterval EOF
- ;
-
statement
: query #statementDefault
| ctes? dmlStatementNoWith #dmlStatement
@@ -91,10 +87,10 @@ statement
((COMMENT comment=STRING) |
locationSpec |
(WITH (DBPROPERTIES | PROPERTIES) tablePropertyList))* #createNamespace
- | ALTER database db=errorCapturingIdentifier
- SET DBPROPERTIES tablePropertyList #setDatabaseProperties
- | ALTER database db=errorCapturingIdentifier
- SET locationSpec #setDatabaseLocation
+ | ALTER (database | NAMESPACE) multipartIdentifier
+ SET (DBPROPERTIES | PROPERTIES) tablePropertyList #setNamespaceProperties
+ | ALTER (database | NAMESPACE) multipartIdentifier
+ SET locationSpec #setNamespaceLocation
| DROP (database | NAMESPACE) (IF EXISTS)? multipartIdentifier
(RESTRICT | CASCADE)? #dropNamespace
| SHOW (DATABASES | NAMESPACES) ((FROM | IN) multipartIdentifier)?
@@ -144,8 +140,8 @@ statement
'(' columns=multipartIdentifierList ')' #dropTableColumns
| ALTER TABLE multipartIdentifier
DROP (COLUMN | COLUMNS) columns=multipartIdentifierList #dropTableColumns
- | ALTER (TABLE | VIEW) from=tableIdentifier
- RENAME TO to=tableIdentifier #renameTable
+ | ALTER (TABLE | VIEW) from=multipartIdentifier
+ RENAME TO to=multipartIdentifier #renameTable
| ALTER (TABLE | VIEW) multipartIdentifier
SET TBLPROPERTIES tablePropertyList #setTableProperties
| ALTER (TABLE | VIEW) multipartIdentifier
@@ -187,7 +183,7 @@ statement
statement #explain
| SHOW TABLES ((FROM | IN) multipartIdentifier)?
(LIKE? pattern=STRING)? #showTables
- | SHOW TABLE EXTENDED ((FROM | IN) db=errorCapturingIdentifier)?
+ | SHOW TABLE EXTENDED ((FROM | IN) namespace=multipartIdentifier)?
LIKE pattern=STRING partitionSpec? #showTable
| SHOW TBLPROPERTIES table=multipartIdentifier
('(' key=tablePropertyKey ')')? #showTblProperties
@@ -199,7 +195,8 @@ statement
| SHOW CREATE TABLE multipartIdentifier #showCreateTable
| SHOW CURRENT NAMESPACE #showCurrentNamespace
| (DESC | DESCRIBE) FUNCTION EXTENDED? describeFuncName #describeFunction
- | (DESC | DESCRIBE) database EXTENDED? db=errorCapturingIdentifier #describeDatabase
+ | (DESC | DESCRIBE) (database | NAMESPACE) EXTENDED?
+ multipartIdentifier #describeNamespace
| (DESC | DESCRIBE) TABLE? option=(EXTENDED | FORMATTED)?
multipartIdentifier partitionSpec? describeColName? #describeTable
| (DESC | DESCRIBE) QUERY? query #describeQuery
@@ -217,14 +214,6 @@ statement
| SET ROLE .*? #failNativeCommand
| SET .*? #setConfiguration
| RESET #resetConfiguration
- | DELETE FROM multipartIdentifier tableAlias whereClause? #deleteFromTable
- | UPDATE multipartIdentifier tableAlias setClause whereClause? #updateTable
- | MERGE INTO target=multipartIdentifier targetAlias=tableAlias
- USING (source=multipartIdentifier |
- '(' sourceQuery=query')') sourceAlias=tableAlias
- ON mergeCondition=booleanExpression
- matchedClause*
- notMatchedClause* #mergeIntoTable
| unsupportedHiveNativeCommands .*? #failNativeCommand
;
@@ -401,6 +390,14 @@ resource
dmlStatementNoWith
: insertInto queryTerm queryOrganization #singleInsertQuery
| fromClause multiInsertQueryBody+ #multiInsertQuery
+ | DELETE FROM multipartIdentifier tableAlias whereClause? #deleteFromTable
+ | UPDATE multipartIdentifier tableAlias setClause whereClause? #updateTable
+ | MERGE INTO target=multipartIdentifier targetAlias=tableAlias
+ USING (source=multipartIdentifier |
+ '(' sourceQuery=query')') sourceAlias=tableAlias
+ ON mergeCondition=booleanExpression
+ matchedClause*
+ notMatchedClause* #mergeIntoTable
;
queryOrganization
@@ -747,7 +744,7 @@ primaryExpression
| qualifiedName '.' ASTERISK #star
| '(' namedExpression (',' namedExpression)+ ')' #rowConstructor
| '(' query ')' #subqueryExpression
- | qualifiedName '(' (setQuantifier? argument+=expression (',' argument+=expression)*)? ')'
+ | functionName '(' (setQuantifier? argument+=expression (',' argument+=expression)*)? ')'
(OVER windowSpec)? #functionCall
| identifier '->' expression #lambda
| '(' identifier (',' identifier)+ ')' '->' expression #lambda
@@ -767,7 +764,7 @@ primaryExpression
constant
: NULL #nullLiteral
| interval #intervalLiteral
- | negativeSign=MINUS? identifier STRING #typeConstructor
+ | identifier STRING #typeConstructor
| number #numericLiteral
| booleanValue #booleanLiteral
| STRING+ #stringLiteral
@@ -790,8 +787,8 @@ booleanValue
;
interval
- : negativeSign=MINUS? INTERVAL (errorCapturingMultiUnitsInterval | errorCapturingUnitToUnitInterval)?
- | {ansi}? (errorCapturingMultiUnitsInterval | errorCapturingUnitToUnitInterval)
+ : INTERVAL (errorCapturingMultiUnitsInterval | errorCapturingUnitToUnitInterval)?
+ | {SQL_standard_keyword_behavior}? (errorCapturingMultiUnitsInterval | errorCapturingUnitToUnitInterval)
;
errorCapturingMultiUnitsInterval
@@ -911,6 +908,12 @@ qualifiedNameList
: qualifiedName (',' qualifiedName)*
;
+functionName
+ : qualifiedName
+ | LEFT
+ | RIGHT
+ ;
+
qualifiedName
: identifier ('.' identifier)*
;
@@ -930,14 +933,14 @@ errorCapturingIdentifierExtra
identifier
: strictIdentifier
- | {!ansi}? strictNonReserved
+ | {!SQL_standard_keyword_behavior}? strictNonReserved
;
strictIdentifier
: IDENTIFIER #unquotedIdentifier
| quotedIdentifier #quotedIdentifierAlternative
- | {ansi}? ansiNonReserved #unquotedIdentifier
- | {!ansi}? nonReserved #unquotedIdentifier
+ | {SQL_standard_keyword_behavior}? ansiNonReserved #unquotedIdentifier
+ | {!SQL_standard_keyword_behavior}? nonReserved #unquotedIdentifier
;
quotedIdentifier
@@ -954,7 +957,7 @@ number
| MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral
;
-// When `spark.sql.ansi.enabled=true`, there are 2 kinds of keywords in Spark SQL.
+// When `SQL_standard_keyword_behavior=true`, there are 2 kinds of keywords in Spark SQL.
// - Reserved keywords:
// Keywords that are reserved and can't be used as identifiers for table, view, column,
// function, alias, etc.
@@ -1154,9 +1157,9 @@ ansiNonReserved
| YEARS
;
-// When `spark.sql.ansi.enabled=false`, there are 2 kinds of keywords in Spark SQL.
+// When `SQL_standard_keyword_behavior=false`, there are 2 kinds of keywords in Spark SQL.
// - Non-reserved keywords:
-// Same definition as the one when `spark.sql.ansi.enabled=true`.
+// Same definition as the one when `SQL_standard_keyword_behavior=true`.
// - Strict-non-reserved keywords:
// A strict version of non-reserved keywords, which can not be used as table alias.
// You can find the full keywords list by searching "Start of the keywords list" in this file.
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/BatchWrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/BatchWrite.java
index 37c5539d2518f..3e8b14172d6b2 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/BatchWrite.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/BatchWrite.java
@@ -23,8 +23,8 @@
* An interface that defines how to write the data to data source for batch processing.
*
* The writing procedure is:
- * 1. Create a writer factory by {@link #createBatchWriterFactory()}, serialize and send it to all
- * the partitions of the input data(RDD).
+ * 1. Create a writer factory by {@link #createBatchWriterFactory(PhysicalWriteInfo)}, serialize
+ * and send it to all the partitions of the input data(RDD).
* 2. For each partition, create the data writer, and write the data of the partition with this
* writer. If all the data are written successfully, call {@link DataWriter#commit()}. If
* exception happens during the writing, call {@link DataWriter#abort()}.
@@ -45,8 +45,10 @@ public interface BatchWrite {
*
* If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
+ *
+ * @param info Physical information about the input data that will be written to this table.
*/
- DataWriterFactory createBatchWriterFactory();
+ DataWriterFactory createBatchWriterFactory(PhysicalWriteInfo info);
/**
* Returns whether Spark should use the commit coordinator to ensure that at most one task for
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriterFactory.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriterFactory.java
index bcf8d8a59e5e5..310575df05d97 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriterFactory.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriterFactory.java
@@ -24,8 +24,9 @@
import org.apache.spark.sql.catalyst.InternalRow;
/**
- * A factory of {@link DataWriter} returned by {@link BatchWrite#createBatchWriterFactory()},
- * which is responsible for creating and initializing the actual data writer at executor side.
+ * A factory of {@link DataWriter} returned by
+ * {@link BatchWrite#createBatchWriterFactory(PhysicalWriteInfo)}, which is responsible for
+ * creating and initializing the actual data writer at executor side.
*
* Note that, the writer factory will be serialized and sent to executors, then the data writer
* will be created on executors and do the actual writing. So this interface must be
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/PhysicalWriteInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/PhysicalWriteInfo.java
new file mode 100644
index 0000000000000..55a092e39970e
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/PhysicalWriteInfo.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.write;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.write.streaming.StreamingDataWriterFactory;
+
+/**
+ * This interface contains physical write information that data sources can use when
+ * generating a {@link DataWriterFactory} or a {@link StreamingDataWriterFactory}.
+ */
+@Evolving
+public interface PhysicalWriteInfo {
+ /**
+ * The number of partitions of the input data that is going to be written.
+ */
+ int numPartitions();
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingDataWriterFactory.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingDataWriterFactory.java
index daaa18d5bc4e7..9946867e8ea65 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingDataWriterFactory.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingDataWriterFactory.java
@@ -23,11 +23,12 @@
import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.write.DataWriter;
+import org.apache.spark.sql.connector.write.PhysicalWriteInfo;
/**
* A factory of {@link DataWriter} returned by
- * {@link StreamingWrite#createStreamingWriterFactory()}, which is responsible for creating
- * and initializing the actual data writer at executor side.
+ * {@link StreamingWrite#createStreamingWriterFactory(PhysicalWriteInfo)}, which is responsible for
+ * creating and initializing the actual data writer at executor side.
*
* Note that, the writer factory will be serialized and sent to executors, then the data writer
* will be created on executors and do the actual writing. So this interface must be
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java
index 0821b34891654..4f930e1c158e5 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java
@@ -19,14 +19,15 @@
import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.write.DataWriter;
+import org.apache.spark.sql.connector.write.PhysicalWriteInfo;
import org.apache.spark.sql.connector.write.WriterCommitMessage;
/**
* An interface that defines how to write the data to data source in streaming queries.
*
* The writing procedure is:
- * 1. Create a writer factory by {@link #createStreamingWriterFactory()}, serialize and send it to
- * all the partitions of the input data(RDD).
+ * 1. Create a writer factory by {@link #createStreamingWriterFactory(PhysicalWriteInfo)},
+ * serialize and send it to all the partitions of the input data(RDD).
* 2. For each epoch in each partition, create the data writer, and write the data of the epoch in
* the partition with this writer. If all the data are written successfully, call
* {@link DataWriter#commit()}. If exception happens during the writing, call
@@ -48,8 +49,10 @@ public interface StreamingWrite {
*
* If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
+ *
+ * @param info Information about the RDD that will be written to this data writer
*/
- StreamingDataWriterFactory createStreamingWriterFactory();
+ StreamingDataWriterFactory createStreamingWriterFactory(PhysicalWriteInfo info);
/**
* Commits this writing job for the specified epoch with a list of commit messages. The commit
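
The streaming path mirrors the batch change. An equally hypothetical sketch of the new `createStreamingWriterFactory(PhysicalWriteInfo)` signature, with a no-op writer only so the example compiles:

```scala
// Hypothetical sketch of the streaming counterpart of the change above.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{DataWriter, PhysicalWriteInfo, WriterCommitMessage}
import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}

class NoOpStreamingWrite extends StreamingWrite {
  override def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory =
    new StreamingDataWriterFactory {
      override def createWriter(partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] =
        new DataWriter[InternalRow] {
          override def write(record: InternalRow): Unit = ()
          override def commit(): WriterCommitMessage = new WriterCommitMessage {}
          override def abort(): Unit = ()
          override def close(): Unit = ()
        }
    }
  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = ()
  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = ()
}
```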
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index a7443e71c0ca3..625ef2153c711 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -198,7 +198,6 @@ class Analyzer(
ResolveTableValuedFunctions ::
new ResolveCatalogs(catalogManager) ::
ResolveInsertInto ::
- ResolveTables ::
ResolveRelations ::
ResolveReferences ::
ResolveCreateNamedStruct ::
@@ -666,12 +665,26 @@ class Analyzer(
}
/**
- * Resolve table relations with concrete relations from v2 catalog.
+ * Resolve relations to temp views. This is not an actual rule, and is only called by
+ * [[ResolveTables]].
+ */
+ object ResolveTempViews extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
+ case u @ UnresolvedRelation(Seq(part1)) =>
+ v1SessionCatalog.lookupTempView(part1).getOrElse(u)
+ case u @ UnresolvedRelation(Seq(part1, part2)) =>
+ v1SessionCatalog.lookupGlobalTempView(part1, part2).getOrElse(u)
+ }
+ }
+
+ /**
+ * Resolve table relations with concrete relations from v2 catalog. This is not an actual rule,
+ * and is only called by [[ResolveRelations]].
*
* [[ResolveRelations]] still resolves v1 tables.
*/
object ResolveTables extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
+ def apply(plan: LogicalPlan): LogicalPlan = ResolveTempViews(plan).resolveOperatorsUp {
case u: UnresolvedRelation =>
lookupV2Relation(u.multipartIdentifier)
.getOrElse(u)
@@ -733,10 +746,6 @@ class Analyzer(
// Note this is compatible with the views defined by older versions of Spark(before 2.2), which
// have empty defaultDatabase and all the relations in viewText have database part defined.
def resolveRelation(plan: LogicalPlan): LogicalPlan = plan match {
- case u @ UnresolvedRelation(AsTemporaryViewIdentifier(ident))
- if v1SessionCatalog.isTemporaryTable(ident) =>
- resolveRelation(lookupTableFromCatalog(ident, u, AnalysisContext.get.defaultDatabase))
-
case u @ UnresolvedRelation(AsTableIdentifier(ident)) if !isRunningDirectlyOnFiles(ident) =>
val defaultDatabase = AnalysisContext.get.defaultDatabase
val foundRelation = lookupTableFromCatalog(ident, u, defaultDatabase)
@@ -767,7 +776,7 @@ class Analyzer(
case _ => plan
}
- def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
+ def apply(plan: LogicalPlan): LogicalPlan = ResolveTables(plan).resolveOperatorsUp {
case i @ InsertIntoStatement(u @ UnresolvedRelation(AsTableIdentifier(ident)), _, child, _, _)
if child.resolved =>
EliminateSubqueryAliases(lookupTableFromCatalog(ident, u)) match {
@@ -2839,7 +2848,6 @@ class Analyzer(
private def lookupV2RelationAndCatalog(
identifier: Seq[String]): Option[(DataSourceV2Relation, CatalogPlugin, Identifier)] =
identifier match {
- case AsTemporaryViewIdentifier(ti) if v1SessionCatalog.isTemporaryTable(ti) => None
case CatalogObjectIdentifier(catalog, ident) if !CatalogV2Util.isSessionCatalog(catalog) =>
CatalogV2Util.loadTable(catalog, ident) match {
case Some(table) => Some((DataSourceV2Relation.create(table), catalog, ident))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index cb18aa1a9479b..7cc64d43858c9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -490,7 +490,8 @@ object FunctionRegistry {
expression[CurrentDatabase]("current_database"),
expression[CallMethodViaReflection]("reflect"),
expression[CallMethodViaReflection]("java_method"),
- expression[Version]("version"),
+ expression[SparkVersion]("version"),
+ expression[TypeOf]("typeof"),
// grouping sets
expression[Cube]("cube"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala
index f1a8e5bfda4a9..2f2e4e619eb4a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala
@@ -93,6 +93,19 @@ class ResolveCatalogs(val catalogManager: CatalogManager)
s"Can not specify catalog `${catalog.name}` for view ${tableName.quoted} " +
s"because view support in catalog has not been implemented yet")
+ case AlterNamespaceSetPropertiesStatement(NonSessionCatalog(catalog, nameParts), properties) =>
+ AlterNamespaceSetProperties(catalog.asNamespaceCatalog, nameParts, properties)
+
+ case AlterNamespaceSetLocationStatement(NonSessionCatalog(catalog, nameParts), location) =>
+ AlterNamespaceSetProperties(
+ catalog.asNamespaceCatalog, nameParts, Map("location" -> location))
+
+ case RenameTableStatement(NonSessionCatalog(catalog, oldName), newNameParts, isView) =>
+ if (isView) {
+ throw new AnalysisException("Renaming view is not supported in v2 catalogs.")
+ }
+ RenameTable(catalog.asTableCatalog, oldName.asIdentifier, newNameParts.asIdentifier)
+
case DescribeTableStatement(
nameParts @ NonSessionCatalog(catalog, tableName), partitionSpec, isExtended) =>
if (partitionSpec.nonEmpty) {
@@ -172,6 +185,9 @@ class ResolveCatalogs(val catalogManager: CatalogManager)
case DropNamespaceStatement(NonSessionCatalog(catalog, nameParts), ifExists, cascade) =>
DropNamespace(catalog, nameParts, ifExists, cascade)
+ case DescribeNamespaceStatement(NonSessionCatalog(catalog, nameParts), extended) =>
+ DescribeNamespace(catalog.asNamespaceCatalog, nameParts, extended)
+
case ShowNamespacesStatement(Some(CatalogAndNamespace(catalog, namespace)), pattern) =>
ShowNamespaces(catalog.asNamespaceCatalog, namespace, pattern)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
index d904ba3aca5d5..5b77d67bd1340 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
@@ -223,7 +223,7 @@ object ResolveHints {
createRepartition(shuffle = false, hint)
case "REPARTITION_BY_RANGE" =>
createRepartitionByRange(hint)
- case _ => plan
+ case _ => hint
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index b27d6ed0efed8..83c76c2d4e2bc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -132,12 +132,9 @@ object TypeCoercion {
case (NullType, StringType) => Some(StringType)
// Cast to TimestampType when we compare DateType with TimestampType
- // if conf.compareDateTimestampInTimestamp is true
// i.e. TimeStamp('2017-03-01 00:00:00') eq Date('2017-03-01') = true
- case (TimestampType, DateType)
- => if (conf.compareDateTimestampInTimestamp) Some(TimestampType) else Some(StringType)
- case (DateType, TimestampType)
- => if (conf.compareDateTimestampInTimestamp) Some(TimestampType) else Some(StringType)
+ case (TimestampType, DateType) => Some(TimestampType)
+ case (DateType, TimestampType) => Some(TimestampType)
// There is no proper decimal type we can pick,
// using double type is the best we can do.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
index 4cff162c116a4..ae3b75dc3334b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.catalyst.catalog
import java.net.URI
-import java.util.Locale
import org.apache.hadoop.fs.Path
import org.apache.hadoop.util.Shell
@@ -26,7 +25,7 @@ import org.apache.hadoop.util.Shell
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
-import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, BoundReference, Expression, InterpretedPredicate}
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, BoundReference, Expression, Predicate}
object ExternalCatalogUtils {
// This duplicates default value of Hive `ConfVars.DEFAULTPARTITIONNAME`, since catalyst doesn't
@@ -148,7 +147,7 @@ object ExternalCatalogUtils {
}
val boundPredicate =
- InterpretedPredicate.create(predicates.reduce(And).transform {
+ Predicate.createInterpreted(predicates.reduce(And).transform {
case att: AttributeReference =>
val index = partitionSchema.indexWhere(_.name == att.name)
BoundReference(index, partitionSchema(index).dataType, nullable = true)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index be8526454f9f1..96ca1ac73e043 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -327,8 +327,7 @@ class SessionCatalog(
def validateTableLocation(table: CatalogTable): Unit = {
// SPARK-19724: the default location of a managed table should be non-existent or empty.
- if (table.tableType == CatalogTableType.MANAGED &&
- !conf.allowCreatingManagedTableUsingNonemptyLocation) {
+ if (table.tableType == CatalogTableType.MANAGED) {
val tableLocation =
new Path(table.storage.locationUri.getOrElse(defaultTablePath(table.identifier)))
val fs = tableLocation.getFileSystem(hadoopConf)
@@ -576,6 +575,10 @@ class SessionCatalog(
tempViews.get(formatTableName(name))
}
+ def getTempViewNames(): Seq[String] = synchronized {
+ tempViews.keySet.toSeq
+ }
+
/**
* Return a global temporary view exactly as it was stored.
*/
@@ -764,6 +767,25 @@ class SessionCatalog(
}
}
+ def lookupTempView(table: String): Option[SubqueryAlias] = {
+ val formattedTable = formatTableName(table)
+ getTempView(formattedTable).map { view =>
+ SubqueryAlias(formattedTable, view)
+ }
+ }
+
+ def lookupGlobalTempView(db: String, table: String): Option[SubqueryAlias] = {
+ val formattedDB = formatDatabaseName(db)
+ if (formattedDB == globalTempViewManager.database) {
+ val formattedTable = formatTableName(table)
+ getGlobalTempView(formattedTable).map { view =>
+ SubqueryAlias(formattedTable, formattedDB, view)
+ }
+ } else {
+ None
+ }
+ }
+
/**
* Return whether a table with the specified name is a temporary view.
*
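
The two lookup helpers back the new `ResolveTempViews` step in the Analyzer: one-part names are checked against local temp views, two-part names against the global temp view database. An illustrative end-to-end snippet (hypothetical application code):

```scala
// Illustrative only: shows the name forms resolved via lookupTempView (one-part names)
// and lookupGlobalTempView (two-part names whose first part matches the global temp
// view database, "global_temp" by default).
import org.apache.spark.sql.SparkSession

object TempViewResolutionExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("temp-views").getOrCreate()
    spark.range(10).createOrReplaceTempView("t")            // resolved by lookupTempView("t")
    spark.range(10).createOrReplaceGlobalTempView("g")      // stored in the global temp database
    spark.sql("SELECT count(*) FROM t").show()              // one-part name -> local temp view
    spark.sql("SELECT count(*) FROM global_temp.g").show()  // db part must match the global temp DB
    spark.stop()
  }
}
```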
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index f3b58fa3137b1..8d11f4663a3ef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -30,7 +30,9 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
+import org.apache.spark.sql.catalyst.util.IntervalUtils._
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.IntervalStyle._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.UTF8StringBuilder
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -281,6 +283,14 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
// UDFToString
private[this] def castToString(from: DataType): Any => Any = from match {
+ case CalendarIntervalType => SQLConf.get.intervalOutputStyle match {
+ case SQL_STANDARD =>
+ buildCast[CalendarInterval](_, i => UTF8String.fromString(toSqlStandardString(i)))
+ case ISO_8601 =>
+ buildCast[CalendarInterval](_, i => UTF8String.fromString(toIso8601String(i)))
+ case MULTI_UNITS =>
+ buildCast[CalendarInterval](_, i => UTF8String.fromString(toMultiUnitsString(i)))
+ }
case BinaryType => buildCast[Array[Byte]](_, UTF8String.fromBytes)
case DateType => buildCast[Int](_, d => UTF8String.fromString(dateFormatter.format(d)))
case TimestampType => buildCast[Long](_,
@@ -467,7 +477,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
// IntervalConverter
private[this] def castToInterval(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[UTF8String](_, s => IntervalUtils.stringToInterval(s))
+ buildCast[UTF8String](_, s => IntervalUtils.safeStringToInterval(s))
}
// LongConverter
@@ -592,7 +602,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
* Change the precision / scale in a given decimal to those set in `decimalType` (if any),
* modifying `value` in-place and returning it if successful. If an overflow occurs, it
* either returns null or throws an exception according to the value set for
- * `spark.sql.ansi.enabled`.
+ * `spark.sql.dialect.spark.ansi.enabled`.
*
* NOTE: this modifies `value` in-place, so don't call it on external data.
*/
@@ -611,7 +621,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
/**
* Create new `Decimal` with precision and scale given in `decimalType` (if any).
- * If overflow occurs, if `spark.sql.ansi.enabled` is false, null is returned;
+ * If overflow occurs, if `spark.sql.dialect.spark.ansi.enabled` is false, null is returned;
* otherwise, an `ArithmeticException` is thrown.
*/
private[this] def toPrecision(value: Decimal, decimalType: DecimalType): Decimal =
@@ -985,6 +995,14 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
timestampFormatter.getClass)
(c, evPrim, evNull) => code"""$evPrim = UTF8String.fromString(
org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($tf, $c));"""
+ case CalendarIntervalType =>
+ val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
+ val funcName = SQLConf.get.intervalOutputStyle match {
+ case SQL_STANDARD => "toSqlStandardString"
+ case ISO_8601 => "toIso8601String"
+ case MULTI_UNITS => "toMultiUnitsString"
+ }
+ (c, evPrim, _) => code"""$evPrim = UTF8String.fromString($iu.$funcName($c));"""
case ArrayType(et, _) =>
(c, evPrim, evNull) => {
val buffer = ctx.freshVariable("buffer", classOf[UTF8StringBuilder])
@@ -1216,7 +1234,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
case StringType =>
val util = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
(c, evPrim, evNull) =>
- code"""$evPrim = $util.stringToInterval($c);
+ code"""$evPrim = $util.safeStringToInterval($c);
if(${evPrim} == null) {
${evNull} = true;
}
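
Casting an interval to string now honours the configured interval output style. A quick driver-side illustration of the three renderings, using the `IntervalUtils` helpers referenced above (the printed strings are indicative):

```scala
// Indicative only: the three helpers correspond to the MULTI_UNITS, SQL_STANDARD and
// ISO_8601 styles handled in castToString above.
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.unsafe.types.UTF8String

val iv = IntervalUtils.stringToInterval(UTF8String.fromString("1 year 2 days 3 hours"))
println(IntervalUtils.toMultiUnitsString(iv))   // e.g. 1 years 2 days 3 hours
println(IntervalUtils.toSqlStandardString(iv))  // SQL-standard year-month / day-time form
println(IntervalUtils.toIso8601String(iv))      // e.g. P1Y2DT3H
```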
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 300f075d32763..b4a85e3e50bec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -89,14 +89,14 @@ object MutableProjection
}
/**
- * Returns an MutableProjection for given sequence of bound Expressions.
+ * Returns a MutableProjection for given sequence of bound Expressions.
*/
def create(exprs: Seq[Expression]): MutableProjection = {
createObject(exprs)
}
/**
- * Returns an MutableProjection for given sequence of Expressions, which will be bound to
+ * Returns a MutableProjection for given sequence of Expressions, which will be bound to
* `inputSchema`.
*/
def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): MutableProjection = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
index 69badb9562dc3..caacb71814f17 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_DAY
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
case class TimeWindow(
timeColumn: Expression,
@@ -103,7 +104,7 @@ object TimeWindow {
* precision.
*/
private def getIntervalInMicroSeconds(interval: String): Long = {
- val cal = IntervalUtils.fromString(interval)
+ val cal = IntervalUtils.stringToInterval(UTF8String.fromString(interval))
if (cal.months != 0) {
throw new IllegalArgumentException(
s"Intervals greater than a month is not supported ($interval).")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 82a8e6d80a0bd..7650fb07a61cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -150,7 +150,7 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
sys.error("BinaryArithmetics must override either calendarIntervalMethod or genCode")
// Name of the function for the exact version of this expression in [[Math]].
- // If the option "spark.sql.ansi.enabled" is enabled and there is corresponding
+ // If the option "spark.sql.dialect.spark.ansi.enabled" is enabled and there is corresponding
// function in [[Math]], the exact function will be called instead of evaluation with [[symbol]].
def exactMathMethod: Option[String] = None
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index b66b80ad31dc2..63bd59e7628b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -29,19 +29,11 @@ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
-/**
- * Inherits some default implementation for Java from `Ordering[Row]`
- */
-class BaseOrdering extends Ordering[InternalRow] {
- def compare(a: InternalRow, b: InternalRow): Int = {
- throw new UnsupportedOperationException
- }
-}
/**
* Generates bytecode for an [[Ordering]] of rows for a given set of expressions.
*/
-object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalRow]] with Logging {
+object GenerateOrdering extends CodeGenerator[Seq[SortOrder], BaseOrdering] with Logging {
protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] =
in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder])
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
index e0fabad6d089a..6ba646d360d2e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
@@ -20,31 +20,17 @@ package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-/**
- * Interface for generated predicate
- */
-abstract class Predicate {
- def eval(r: InternalRow): Boolean
-
- /**
- * Initializes internal states given the current partition index.
- * This is used by nondeterministic expressions to set initial states.
- * The default implementation does nothing.
- */
- def initialize(partitionIndex: Int): Unit = {}
-}
-
/**
* Generates bytecode that evaluates a boolean [[Expression]] on a given input [[InternalRow]].
*/
-object GeneratePredicate extends CodeGenerator[Expression, Predicate] {
+object GeneratePredicate extends CodeGenerator[Expression, BasePredicate] {
protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer.execute(in)
protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression =
BindReferences.bindReference(in, inputSchema)
- protected def create(predicate: Expression): Predicate = {
+ protected def create(predicate: Expression): BasePredicate = {
val ctx = newCodeGenContext()
val eval = predicate.genCode(ctx)
@@ -53,7 +39,7 @@ object GeneratePredicate extends CodeGenerator[Expression, Predicate] {
return new SpecificPredicate(references);
}
- class SpecificPredicate extends ${classOf[Predicate].getName} {
+ class SpecificPredicate extends ${classOf[BasePredicate].getName} {
private final Object[] references;
${ctx.declareMutableStates()}
@@ -79,6 +65,6 @@ object GeneratePredicate extends CodeGenerator[Expression, Predicate] {
logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}")
val (clazz, _) = CodeGenerator.compile(code)
- clazz.generate(ctx.references.toArray).asInstanceOf[Predicate]
+ clazz.generate(ctx.references.toArray).asInstanceOf[BasePredicate]
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 5d964b602e634..d5d42510842ed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -900,54 +900,6 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
override def prettyName: String = "sort_array"
}
-
-/**
- * Sorts the input array in ascending order according to the natural ordering of
- * the array elements and returns it.
- */
-// scalastyle:off line.size.limit
-@ExpressionDescription(
- usage = """
- _FUNC_(array) - Sorts the input array in ascending order. The elements of the input array must
- be orderable. Null elements will be placed at the end of the returned array.
- """,
- examples = """
- Examples:
- > SELECT _FUNC_(array('b', 'd', null, 'c', 'a'));
- ["a","b","c","d",null]
- """,
- since = "2.4.0")
-// scalastyle:on line.size.limit
-case class ArraySort(child: Expression) extends UnaryExpression with ArraySortLike {
-
- override def dataType: DataType = child.dataType
- override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
-
- override def arrayExpression: Expression = child
- override def nullOrder: NullOrder = NullOrder.Greatest
-
- override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
- case ArrayType(dt, _) if RowOrdering.isOrderable(dt) =>
- TypeCheckResult.TypeCheckSuccess
- case ArrayType(dt, _) =>
- val dtSimple = dt.catalogString
- TypeCheckResult.TypeCheckFailure(
- s"$prettyName does not support sorting array of type $dtSimple which is not orderable")
- case _ =>
- TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.")
- }
-
- override def nullSafeEval(array: Any): Any = {
- sortEval(array, true)
- }
-
- override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- nullSafeCodeGen(ctx, ev, c => sortCodegen(ctx, ev, c, "true"))
- }
-
- override def prettyName: String = "array_sort"
-}
-
/**
* Returns a random permutation of the given array.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index 317ebb62c07ec..adeda0981fe8e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
+import java.util.Comparator
import java.util.concurrent.atomic.AtomicReference
import scala.collection.mutable
@@ -285,6 +286,113 @@ case class ArrayTransform(
override def prettyName: String = "transform"
}
+/**
+ * Sorts elements in an array using a comparator function.
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """_FUNC_(expr, func) - Sorts the input array in ascending order. The elements of the
+ input array must be orderable. Null elements will be placed at the end of the returned
+ array. Since 3.0.0 this function also sorts and returns the array based on the given
+ comparator function. The comparator will take two arguments
+ representing two elements of the array.
+ It returns -1, 0, or 1 as the first element is less than, equal to, or greater
+ than the second element. If the comparator function returns other
+ values (including null), the function will fail and raise an error.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array(5, 6, 1), (left, right) -> case when left < right then -1 when left > right then 1 else 0 end);
+ [1,5,6]
+ > SELECT _FUNC_(array('bc', 'ab', 'dc'), (left, right) -> case when left is null and right is null then 0 when left is null then -1 when right is null then 1 when left < right then 1 when left > right then -1 else 0 end);
+ ["dc","bc","ab"]
+ > SELECT _FUNC_(array('b', 'd', null, 'c', 'a'));
+ ["a","b","c","d",null]
+ """,
+ since = "2.4.0")
+// scalastyle:on line.size.limit
+case class ArraySort(
+ argument: Expression,
+ function: Expression)
+ extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {
+
+ def this(argument: Expression) = this(argument, ArraySort.defaultComparator)
+
+ @transient lazy val elementType: DataType =
+ argument.dataType.asInstanceOf[ArrayType].elementType
+
+ override def dataType: ArrayType = argument.dataType.asInstanceOf[ArrayType]
+ override def checkInputDataTypes(): TypeCheckResult = {
+ checkArgumentDataTypes() match {
+ case TypeCheckResult.TypeCheckSuccess =>
+ argument.dataType match {
+ case ArrayType(dt, _) if RowOrdering.isOrderable(dt) =>
+ if (function.dataType == IntegerType) {
+ TypeCheckResult.TypeCheckSuccess
+ } else {
+ TypeCheckResult.TypeCheckFailure("Return type of the given function has to be " +
+ "IntegerType")
+ }
+ case ArrayType(dt, _) =>
+ val dtSimple = dt.catalogString
+ TypeCheckResult.TypeCheckFailure(
+ s"$prettyName does not support sorting array of type $dtSimple which is not " +
+ "orderable")
+ case _ =>
+ TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.")
+ }
+ case failure => failure
+ }
+ }
+
+ override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArraySort = {
+ val ArrayType(elementType, containsNull) = argument.dataType
+ copy(function =
+ f(function, (elementType, containsNull) :: (elementType, containsNull) :: Nil))
+ }
+
+ @transient lazy val LambdaFunction(_,
+ Seq(firstElemVar: NamedLambdaVariable, secondElemVar: NamedLambdaVariable), _) = function
+
+ def comparator(inputRow: InternalRow): Comparator[Any] = {
+ val f = functionForEval
+ (o1: Any, o2: Any) => {
+ firstElemVar.value.set(o1)
+ secondElemVar.value.set(o2)
+ f.eval(inputRow).asInstanceOf[Int]
+ }
+ }
+
+ override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
+ val arr = argumentValue.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
+ if (elementType != NullType) {
+ java.util.Arrays.sort(arr, comparator(inputRow))
+ }
+ new GenericArrayData(arr.asInstanceOf[Array[Any]])
+ }
+
+ override def prettyName: String = "array_sort"
+}
+
+object ArraySort {
+
+ def comparator(left: Expression, right: Expression): Expression = {
+ val lit0 = Literal(0)
+ val lit1 = Literal(1)
+ val litm1 = Literal(-1)
+
+ If(And(IsNull(left), IsNull(right)), lit0,
+ If(IsNull(left), lit1, If(IsNull(right), litm1,
+ If(LessThan(left, right), litm1, If(GreaterThan(left, right), lit1, lit0)))))
+ }
+
+ val defaultComparator: LambdaFunction = {
+ val left = UnresolvedNamedLambdaVariable(Seq("left"))
+ val right = UnresolvedNamedLambdaVariable(Seq("right"))
+ LambdaFunction(comparator(left, right), Seq(left, right))
+ }
+}
+
/**
* Filters entries in a map using the provided function.
*/
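
With the comparator form in place, `array_sort` keeps its old single-argument behaviour through `defaultComparator`. A short usage sketch, assuming an existing `SparkSession` named `spark`:

```scala
// Assumes a running SparkSession `spark`; mirrors the examples in the expression description.
spark.sql("SELECT array_sort(array('b', 'd', null, 'c', 'a'))").show(false)
// ["a","b","c","d",null]  (defaultComparator: ascending, nulls last)

spark.sql(
  "SELECT array_sort(array(5, 6, 1), (l, r) -> CASE WHEN l < r THEN -1 WHEN l > r THEN 1 ELSE 0 END)"
).show(false)
// [1,5,6]
```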
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index a13a6836c6be6..de7e1160185dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -515,12 +515,10 @@ case class JsonToStructs(
timeZoneId: Option[String] = None)
extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes {
- val forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)
-
// The JSON input data might be missing certain fields. We force the nullability
// of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder
// can generate incorrect files if values are missing in columns declared as non-nullable.
- val nullableSchema = if (forceNullableSchema) schema.asNullable else schema
+ val nullableSchema = schema.asNullable
override def nullable: Boolean = true
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 5a5d7a17acd99..48b8c9c0fbf8b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -407,7 +407,9 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
case (v: Long, TimestampType) =>
val formatter = TimestampFormatter.getFractionFormatter(
DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone))
- s"TIMESTAMP('${formatter.format(v)}')"
+ s"TIMESTAMP '${formatter.format(v)}'"
+ case (i: CalendarInterval, CalendarIntervalType) =>
+ s"INTERVAL '${IntervalUtils.toMultiUnitsString(i)}'"
case (v: Array[Byte], BinaryType) => s"X'${DatatypeConverter.printHexBinary(v)}'"
case _ => value.toString
}
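
Timestamp literals now render with the standard `TIMESTAMP '...'` syntax, and interval literals gain an explicit `INTERVAL '...'` form. A hedged sketch of the new `Literal.sql` output:

```scala
// Indicative sketch: the interval text comes from IntervalUtils.toMultiUnitsString.
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.unsafe.types.UTF8String

val iv = IntervalUtils.stringToInterval(UTF8String.fromString("1 month 2 days 3 seconds"))
println(Literal(iv).sql)
// e.g. INTERVAL '1 months 2 days 3 seconds'
```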
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index b8c23a1f08912..f576873829f27 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -169,7 +169,7 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Sta
usage = """_FUNC_() - Returns the Spark version. The string contains 2 fields, the first being a release version and the second being a git revision.""",
since = "3.0.0")
// scalastyle:on line.size.limit
-case class Version() extends LeafExpression with CodegenFallback {
+case class SparkVersion() extends LeafExpression with CodegenFallback {
override def nullable: Boolean = false
override def foldable: Boolean = true
override def dataType: DataType = StringType
@@ -177,3 +177,24 @@ case class Version() extends LeafExpression with CodegenFallback {
UTF8String.fromString(SPARK_VERSION_SHORT + " " + SPARK_REVISION)
}
}
+
+@ExpressionDescription(
+ usage = """_FUNC_(expr) - Return DDL-formatted type string for the data type of the input.""",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(1);
+ int
+ > SELECT _FUNC_(array(1));
+ array<int>
+ """,
+ since = "3.0.0")
+case class TypeOf(child: Expression) extends UnaryExpression {
+ override def nullable: Boolean = false
+ override def foldable: Boolean = true
+ override def dataType: DataType = StringType
+ override def eval(input: InternalRow): Any = UTF8String.fromString(child.dataType.catalogString)
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ defineCodeGen(ctx, ev, _ => s"""UTF8String.fromString("${child.dataType.catalogString}")""")
+ }
+}
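
`version()` keeps its SQL name; only the expression class is renamed to `SparkVersion`, and `typeof` is newly registered. A quick check, assuming an existing `SparkSession` named `spark` (output is indicative):

```scala
// Assumes a running SparkSession `spark`.
spark.sql("SELECT version(), typeof(1), typeof(array(1))").show(false)
// e.g. "3.0.0 <git revision>" | int | array<int>
```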
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
index c9706c09f6949..8867a03a4633b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
@@ -19,18 +19,28 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
import org.apache.spark.sql.types._
+/**
+ * A base class for generated/interpreted row ordering.
+ */
+class BaseOrdering extends Ordering[InternalRow] {
+ def compare(a: InternalRow, b: InternalRow): Int = {
+ throw new UnsupportedOperationException
+ }
+}
+
/**
* An interpreted row ordering comparator.
*/
-class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {
+class InterpretedOrdering(ordering: Seq[SortOrder]) extends BaseOrdering {
def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
this(bindReferences(ordering, inputSchema))
- def compare(a: InternalRow, b: InternalRow): Int = {
+ override def compare(a: InternalRow, b: InternalRow): Int = {
var i = 0
val size = ordering.size
while (i < size) {
@@ -67,7 +77,7 @@ class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow
}
i += 1
}
- return 0
+ 0
}
}
@@ -83,7 +93,7 @@ object InterpretedOrdering {
}
}
-object RowOrdering {
+object RowOrdering extends CodeGeneratorWithInterpretedFallback[Seq[SortOrder], BaseOrdering] {
/**
* Returns true iff the data type can be ordered (i.e. can be sorted).
@@ -102,4 +112,26 @@ object RowOrdering {
* Returns true iff outputs from the expressions can be ordered.
*/
def isOrderable(exprs: Seq[Expression]): Boolean = exprs.forall(e => isOrderable(e.dataType))
+
+ override protected def createCodeGeneratedObject(in: Seq[SortOrder]): BaseOrdering = {
+ GenerateOrdering.generate(in)
+ }
+
+ override protected def createInterpretedObject(in: Seq[SortOrder]): BaseOrdering = {
+ new InterpretedOrdering(in)
+ }
+
+ def create(order: Seq[SortOrder], inputSchema: Seq[Attribute]): BaseOrdering = {
+ createObject(bindReferences(order, inputSchema))
+ }
+
+ /**
+ * Creates a row ordering for the given schema, in natural ascending order.
+ */
+ def createNaturalAscendingOrdering(dataTypes: Seq[DataType]): BaseOrdering = {
+ val order: Seq[SortOrder] = dataTypes.zipWithIndex.map {
+ case (dt, index) => SortOrder(BoundReference(index, dt, nullable = true), Ascending)
+ }
+ create(order, Seq.empty)
+ }
}
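
`RowOrdering` is now a codegen-with-interpreted-fallback factory like the other expression factories. A small illustration of the new entry points:

```scala
// createNaturalAscendingOrdering builds a BaseOrdering that is code-generated when
// possible and interpreted otherwise.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.RowOrdering
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.unsafe.types.UTF8String

val ordering = RowOrdering.createNaturalAscendingOrdering(Seq(IntegerType, StringType))
val a = InternalRow(1, UTF8String.fromString("a"))
val b = InternalRow(2, UTF8String.fromString("b"))
assert(ordering.compare(a, b) < 0)  // ascending over both columns
```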
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 4c0998412f729..bcd442ad3cc35 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -21,8 +21,9 @@ import scala.collection.immutable.TreeSet
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate}
+import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.util.TypeUtils
@@ -30,11 +31,18 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
-object InterpretedPredicate {
- def create(expression: Expression, inputSchema: Seq[Attribute]): InterpretedPredicate =
- create(BindReferences.bindReference(expression, inputSchema))
+/**
+ * A base class for generated/interpreted predicate
+ */
+abstract class BasePredicate {
+ def eval(r: InternalRow): Boolean
- def create(expression: Expression): InterpretedPredicate = new InterpretedPredicate(expression)
+ /**
+ * Initializes internal states given the current partition index.
+ * This is used by nondeterministic expressions to set initial states.
+ * The default implementation does nothing.
+ */
+ def initialize(partitionIndex: Int): Unit = {}
}
case class InterpretedPredicate(expression: Expression) extends BasePredicate {
@@ -56,6 +64,35 @@ trait Predicate extends Expression {
override def dataType: DataType = BooleanType
}
+/**
+ * The factory object for `BasePredicate`.
+ */
+object Predicate extends CodeGeneratorWithInterpretedFallback[Expression, BasePredicate] {
+
+ override protected def createCodeGeneratedObject(in: Expression): BasePredicate = {
+ GeneratePredicate.generate(in)
+ }
+
+ override protected def createInterpretedObject(in: Expression): BasePredicate = {
+ InterpretedPredicate(in)
+ }
+
+ def createInterpreted(e: Expression): InterpretedPredicate = InterpretedPredicate(e)
+
+ /**
+ * Returns a BasePredicate for an Expression, which will be bound to `inputSchema`.
+ */
+ def create(e: Expression, inputSchema: Seq[Attribute]): BasePredicate = {
+ createObject(bindReference(e, inputSchema))
+ }
+
+ /**
+ * Returns a BasePredicate for a given bound Expression.
+ */
+ def create(e: Expression): BasePredicate = {
+ createObject(e)
+ }
+}
trait PredicateHelper {
protected def splitConjunctivePredicates(condition: Expression): Seq[Expression] = {
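
`Predicate` replaces the old `InterpretedPredicate.create` entry points and, like the other factories, tries codegen first and falls back to interpretation. A usage sketch:

```scala
// Predicate.create binds the expression to the input schema and returns a BasePredicate
// (generated or interpreted).
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterThan, Literal, Predicate}
import org.apache.spark.sql.types.IntegerType

val x = AttributeReference("x", IntegerType)()
val pred = Predicate.create(GreaterThan(x, Literal(5)), Seq(x))
pred.initialize(0)  // needed for nondeterministic expressions, harmless here
assert(pred.eval(InternalRow(10)))
assert(!pred.eval(InternalRow(1)))
```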
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index b78bdf082f333..9d0bd358aa24c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1002,12 +1002,11 @@ object EliminateSorts extends Rule[LogicalPlan] {
private def isOrderIrrelevantAggs(aggs: Seq[NamedExpression]): Boolean = {
def isOrderIrrelevantAggFunction(func: AggregateFunction): Boolean = func match {
- case _: Sum => true
- case _: Min => true
- case _: Max => true
- case _: Count => true
- case _: Average => true
- case _: CentralMomentAgg => true
+ case _: Min | _: Max | _: Count => true
+ // Arithmetic operations for floating-point values are order-sensitive
+ // (they are not associative).
+ case _: Sum | _: Average | _: CentralMomentAgg =>
+ !Seq(FloatType, DoubleType).exists(_.sameType(func.children.head.dataType))
case _ => false
}
@@ -1507,7 +1506,7 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] {
case Filter(condition, LocalRelation(output, data, isStreaming))
if !hasUnevaluableExpr(condition) =>
- val predicate = InterpretedPredicate.create(condition, output)
+ val predicate = Predicate.create(condition, output)
predicate.initialize(0)
LocalRelation(output, data.filter(row => predicate.eval(row)), isStreaming)
}
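
The guard added to `isOrderIrrelevantAggs` exists because floating-point addition is not associative, so eliminating a sort can change `sum`/`avg` results on float or double columns:

```scala
// Reordering the inputs of a floating-point sum can change the result.
val xs = Seq(1e16f, -1e16f, 1.0f)
println(xs.sum)          // 1.0  (1e16 and -1e16 cancel first)
println(xs.reverse.sum)  // 0.0  (1.0 is absorbed when added to -1e16 first)
```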
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index c623b5c4d36a5..7bec46678f58d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -102,10 +102,6 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
withOrigin(ctx)(StructType(visitColTypeList(ctx.colTypeList)))
}
- override def visitSingleInterval(ctx: SingleIntervalContext): CalendarInterval = {
- withOrigin(ctx)(visitMultiUnitsInterval(ctx.multiUnitsInterval))
- }
-
/* ********************************************************************************************
* Plan parsing
* ******************************************************************************************** */
@@ -1585,7 +1581,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
*/
override def visitFunctionCall(ctx: FunctionCallContext): Expression = withOrigin(ctx) {
// Create the function call.
- val name = ctx.qualifiedName.getText
+ val name = ctx.functionName.getText
val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null)
val arguments = ctx.argument.asScala.map(expression) match {
case Seq(UnresolvedStar(None))
@@ -1595,7 +1591,8 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
case expressions =>
expressions
}
- val function = UnresolvedFunction(visitFunctionName(ctx.qualifiedName), arguments, isDistinct)
+ val function = UnresolvedFunction(
+ getFunctionIdentifier(ctx.functionName), arguments, isDistinct)
// Check if the function is evaluated in a windowed context.
ctx.windowSpec match {
@@ -1635,6 +1632,17 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
}
}
+ /**
+ * Get a function identifier consisting of an optional database part and a name.
+ */
+ protected def getFunctionIdentifier(ctx: FunctionNameContext): FunctionIdentifier = {
+ if (ctx.qualifiedName != null) {
+ visitFunctionName(ctx.qualifiedName)
+ } else {
+ FunctionIdentifier(ctx.getText, None)
+ }
+ }
+
/**
* Create an [[LambdaFunction]].
*/
@@ -1854,7 +1862,6 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
override def visitTypeConstructor(ctx: TypeConstructorContext): Literal = withOrigin(ctx) {
val value = string(ctx.STRING)
val valueType = ctx.identifier.getText.toUpperCase(Locale.ROOT)
- val isNegative = ctx.negativeSign != null
def toLiteral[T](f: UTF8String => Option[T], t: DataType): Literal = {
f(UTF8String.fromString(value)).map(Literal(_, t)).getOrElse {
@@ -1863,23 +1870,22 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
}
try {
valueType match {
- case "DATE" if !isNegative =>
+ case "DATE" =>
toLiteral(stringToDate(_, getZoneId(SQLConf.get.sessionLocalTimeZone)), DateType)
- case "TIMESTAMP" if !isNegative =>
+ case "TIMESTAMP" =>
val zoneId = getZoneId(SQLConf.get.sessionLocalTimeZone)
toLiteral(stringToTimestamp(_, zoneId), TimestampType)
case "INTERVAL" =>
val interval = try {
- IntervalUtils.fromString(value)
+ IntervalUtils.stringToInterval(UTF8String.fromString(value))
} catch {
case e: IllegalArgumentException =>
val ex = new ParseException("Cannot parse the INTERVAL value: " + value, ctx)
ex.setStackTrace(e.getStackTrace)
throw ex
}
- val signedInterval = if (isNegative) IntervalUtils.negate(interval) else interval
- Literal(signedInterval, CalendarIntervalType)
- case "X" if !isNegative =>
+ Literal(interval, CalendarIntervalType)
+ case "X" =>
val padding = if (value.length % 2 != 0) "0" else ""
Literal(DatatypeConverter.parseHexBinary(padding + value))
case "INTEGER" =>
@@ -1891,10 +1897,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
ex.setStackTrace(e.getStackTrace)
throw ex
}
- Literal(if (isNegative) -i else i, IntegerType)
+ Literal(i, IntegerType)
case other =>
- val negativeSign: String = if (isNegative) "-" else ""
- throw new ParseException(s"Literals of type '$negativeSign$other' are currently not" +
+ throw new ParseException(s"Literals of type '$other' are currently not" +
" supported.", ctx)
}
} catch {
@@ -2024,14 +2029,6 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
}
}
- private def applyNegativeSign(sign: Token, interval: CalendarInterval): CalendarInterval = {
- if (sign != null) {
- IntervalUtils.negate(interval)
- } else {
- interval
- }
- }
-
/**
* Create a [[CalendarInterval]] literal expression. Two syntaxes are supported:
* - multiple unit value pairs, for instance: interval 2 months 2 days.
@@ -2045,10 +2042,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
"Can only have a single from-to unit in the interval literal syntax",
innerCtx.unitToUnitInterval)
}
- val interval = applyNegativeSign(
- ctx.negativeSign,
- visitMultiUnitsInterval(innerCtx.multiUnitsInterval))
- Literal(interval, CalendarIntervalType)
+ Literal(visitMultiUnitsInterval(innerCtx.multiUnitsInterval), CalendarIntervalType)
} else if (ctx.errorCapturingUnitToUnitInterval != null) {
val innerCtx = ctx.errorCapturingUnitToUnitInterval
if (innerCtx.error1 != null || innerCtx.error2 != null) {
@@ -2057,8 +2051,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
"Can only have a single from-to unit in the interval literal syntax",
errorCtx)
}
- val interval = applyNegativeSign(ctx.negativeSign, visitUnitToUnitInterval(innerCtx.body))
- Literal(interval, CalendarIntervalType)
+ Literal(visitUnitToUnitInterval(innerCtx.body), CalendarIntervalType)
} else {
throw new ParseException("at least one time unit should be given for interval literal", ctx)
}
@@ -2069,22 +2062,20 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
*/
override def visitMultiUnitsInterval(ctx: MultiUnitsIntervalContext): CalendarInterval = {
withOrigin(ctx) {
- val units = ctx.intervalUnit().asScala.map { unit =>
- val u = unit.getText.toLowerCase(Locale.ROOT)
- // Handle plural forms, e.g: yearS/monthS/weekS/dayS/hourS/minuteS/hourS/...
- if (u.endsWith("s")) u.substring(0, u.length - 1) else u
- }.map(IntervalUtils.IntervalUnit.withName).toArray
-
- val values = ctx.intervalValue().asScala.map { value =>
- if (value.STRING() != null) {
- string(value.STRING())
- } else {
- value.getText
- }
- }.toArray
-
+ val units = ctx.intervalUnit().asScala
+ val values = ctx.intervalValue().asScala
try {
- IntervalUtils.fromUnitStrings(units, values)
+ assert(units.length == values.length)
+ val kvs = units.indices.map { i =>
+ val u = units(i).getText
+ val v = if (values(i).STRING() != null) {
+ string(values(i).STRING())
+ } else {
+ values(i).getText
+ }
+ UTF8String.fromString(" " + v + " " + u)
+ }
+ IntervalUtils.stringToInterval(UTF8String.concat(kvs: _*))
} catch {
case i: IllegalArgumentException =>
val e = new ParseException(i.getMessage, ctx)
@@ -2159,12 +2150,12 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
case ("date", Nil) => DateType
case ("timestamp", Nil) => TimestampType
case ("string", Nil) => StringType
- case ("char", length :: Nil) => CharType(length.getText.toInt)
+ case ("character" | "char", length :: Nil) => CharType(length.getText.toInt)
case ("varchar", length :: Nil) => VarcharType(length.getText.toInt)
case ("binary", Nil) => BinaryType
- case ("decimal", Nil) => DecimalType.USER_DEFAULT
- case ("decimal", precision :: Nil) => DecimalType(precision.getText.toInt, 0)
- case ("decimal", precision :: scale :: Nil) =>
+ case ("decimal" | "dec", Nil) => DecimalType.USER_DEFAULT
+ case ("decimal" | "dec", precision :: Nil) => DecimalType(precision.getText.toInt, 0)
+ case ("decimal" | "dec", precision :: scale :: Nil) =>
DecimalType(precision.getText.toInt, scale.getText.toInt)
case ("interval", Nil) => CalendarIntervalType
case (dt, params) =>
@@ -2528,6 +2519,39 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
ctx.CASCADE != null)
}
+ /**
+ * Create an [[AlterNamespaceSetPropertiesStatement]] logical plan.
+ *
+ * For example:
+ * {{{
+ * ALTER (DATABASE|SCHEMA|NAMESPACE) database
+ * SET (DBPROPERTIES|PROPERTIES) (property_name=property_value, ...);
+ * }}}
+ */
+ override def visitSetNamespaceProperties(ctx: SetNamespacePropertiesContext): LogicalPlan = {
+ withOrigin(ctx) {
+ AlterNamespaceSetPropertiesStatement(
+ visitMultipartIdentifier(ctx.multipartIdentifier),
+ visitPropertyKeyValues(ctx.tablePropertyList))
+ }
+ }
+
+ /**
+ * Create an [[AlterNamespaceSetLocationStatement]] logical plan.
+ *
+ * For example:
+ * {{{
+ * ALTER (DATABASE|SCHEMA|NAMESPACE) namespace SET LOCATION path;
+ * }}}
+ */
+ override def visitSetNamespaceLocation(ctx: SetNamespaceLocationContext): LogicalPlan = {
+ withOrigin(ctx) {
+ AlterNamespaceSetLocationStatement(
+ visitMultipartIdentifier(ctx.multipartIdentifier),
+ visitLocationSpec(ctx.locationSpec))
+ }
+ }
+
/**
* Create a [[ShowNamespacesStatement]] command.
*/
@@ -2541,6 +2565,21 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
Option(ctx.pattern).map(string))
}
+ /**
+ * Create a [[DescribeNamespaceStatement]].
+ *
+ * For example:
+ * {{{
+ * DESCRIBE (DATABASE|SCHEMA|NAMESPACE) [EXTENDED] database;
+ * }}}
+ */
+ override def visitDescribeNamespace(ctx: DescribeNamespaceContext): LogicalPlan =
+ withOrigin(ctx) {
+ DescribeNamespaceStatement(
+ visitMultipartIdentifier(ctx.multipartIdentifier()),
+ ctx.EXTENDED != null)
+ }
+
/**
* Create a table, returning a [[CreateTableStatement]] logical plan.
*
@@ -2719,6 +2758,16 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
Option(ctx.pattern).map(string))
}
+ /**
+ * Create a [[ShowTableStatement]] command.
+ */
+ override def visitShowTable(ctx: ShowTableContext): LogicalPlan = withOrigin(ctx) {
+ ShowTableStatement(
+ Option(ctx.namespace).map(visitMultipartIdentifier),
+ string(ctx.pattern),
+ Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec))
+ }
+
/**
* Parse new column info from ADD COLUMN into a QualifiedColType.
*/
@@ -3193,6 +3242,22 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
query = plan(ctx.query))
}
+ /**
+ * Create a [[RenameTableStatement]] command.
+ *
+ * For example:
+ * {{{
+ * ALTER TABLE multi_part_name1 RENAME TO multi_part_name2;
+ * ALTER VIEW multi_part_name1 RENAME TO multi_part_name2;
+ * }}}
+ */
+ override def visitRenameTable(ctx: RenameTableContext): LogicalPlan = withOrigin(ctx) {
+ RenameTableStatement(
+ visitMultipartIdentifier(ctx.from),
+ visitMultipartIdentifier(ctx.to),
+ ctx.VIEW != null)
+ }
+
/**
* A command for users to list the properties for a table. If propertyKey is specified, the value
* for the propertyKey is returned. If propertyKey is not specified, all the keys and their
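
Taken together, the new AstBuilder rules accept statements such as the following (illustrative only; assumes a `SparkSession` named `spark` with a v2 catalog registered as `testcat`):

```scala
// Illustrative statements exercising the new parser rules.
spark.sql("ALTER NAMESPACE testcat.ns SET PROPERTIES ('owner' = 'dev')")
spark.sql("ALTER NAMESPACE testcat.ns SET LOCATION '/tmp/ns'")
spark.sql("DESCRIBE NAMESPACE EXTENDED testcat.ns")
spark.sql("ALTER TABLE testcat.ns.t1 RENAME TO testcat.ns.t2")
spark.sql("SHOW TABLE EXTENDED LIKE 't*'")
```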
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
index b66cae7979416..30c36598d81d6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
@@ -28,22 +28,14 @@ import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.Dialect
import org.apache.spark.sql.types.{DataType, StructType}
-import org.apache.spark.unsafe.types.CalendarInterval
/**
* Base SQL parsing infrastructure.
*/
abstract class AbstractSqlParser(conf: SQLConf) extends ParserInterface with Logging {
- /**
- * Creates [[CalendarInterval]] for a given SQL String. Throws [[ParseException]] if the SQL
- * string is not a valid interval format.
- */
- def parseInterval(sqlText: String): CalendarInterval = parse(sqlText) { parser =>
- astBuilder.visitSingleInterval(parser.singleInterval())
- }
-
/** Creates/Resolves DataType for a given SQL string. */
override def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
astBuilder.visitSingleDataType(parser.singleDataType())
@@ -97,11 +89,18 @@ abstract class AbstractSqlParser(conf: SQLConf) extends ParserInterface with Log
protected def parse[T](command: String)(toResult: SqlBaseParser => T): T = {
logDebug(s"Parsing command: $command")
+ // When the PostgreSQL dialect is used, or the Spark dialect with
+ // `spark.sql.dialect.spark.ansi.enabled=true`, the parser uses ANSI SQL standard keywords.
+ val SQLStandardKeywordBehavior = conf.dialect match {
+ case Dialect.POSTGRESQL => true
+ case Dialect.SPARK => conf.dialectSparkAnsiEnabled
+ }
+
val lexer = new SqlBaseLexer(new UpperCaseCharStream(CharStreams.fromString(command)))
lexer.removeErrorListeners()
lexer.addErrorListener(ParseErrorListener)
lexer.legacy_setops_precedence_enbled = conf.setOpsPrecedenceEnforced
- lexer.ansi = conf.ansiEnabled
+ lexer.SQL_standard_keyword_behavior = SQLStandardKeywordBehavior
val tokenStream = new CommonTokenStream(lexer)
val parser = new SqlBaseParser(tokenStream)
@@ -109,7 +108,7 @@ abstract class AbstractSqlParser(conf: SQLConf) extends ParserInterface with Log
parser.removeErrorListeners()
parser.addErrorListener(ParseErrorListener)
parser.legacy_setops_precedence_enbled = conf.setOpsPrecedenceEnforced
- parser.ansi = conf.ansiEnabled
+ parser.SQL_standard_keyword_behavior = SQLStandardKeywordBehavior
try {
try {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 51d2a73ea97b7..c2a12eda19137 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -56,7 +56,7 @@ object PhysicalOperation extends PredicateHelper {
* }}}
*/
private def collectProjectsAndFilters(plan: LogicalPlan):
- (Option[Seq[NamedExpression]], Seq[Expression], LogicalPlan, Map[Attribute, Expression]) =
+ (Option[Seq[NamedExpression]], Seq[Expression], LogicalPlan, AttributeMap[Expression]) =
plan match {
case Project(fields, child) if fields.forall(_.deterministic) =>
val (_, filters, other, aliases) = collectProjectsAndFilters(child)
@@ -72,14 +72,15 @@ object PhysicalOperation extends PredicateHelper {
collectProjectsAndFilters(h.child)
case other =>
- (None, Nil, other, Map.empty)
+ (None, Nil, other, AttributeMap(Seq()))
}
- private def collectAliases(fields: Seq[Expression]): Map[Attribute, Expression] = fields.collect {
- case a @ Alias(child, _) => a.toAttribute -> child
- }.toMap
+ private def collectAliases(fields: Seq[Expression]): AttributeMap[Expression] =
+ AttributeMap(fields.collect {
+ case a: Alias => (a.toAttribute, a.child)
+ })
- private def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = {
+ private def substitute(aliases: AttributeMap[Expression])(expr: Expression): Expression = {
expr.transform {
case a @ Alias(ref: AttributeReference, name) =>
aliases.get(ref)
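The switch from `Map` to `AttributeMap` matters because `AttributeMap` keys on the attribute's `exprId` rather than object equality, so a cosmetically different reference (for example, a re-qualified one) still resolves to its alias child. A small sketch, assuming spark-catalyst on the classpath:

    import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference, Expression, Literal}
    import org.apache.spark.sql.types.IntegerType

    val a = AttributeReference("a", IntegerType)()
    val aliases = AttributeMap[Expression](Seq(a -> Literal(1)))
    // Lookup succeeds for a re-qualified reference because the exprId is the same.
    assert(aliases.get(a.withQualifier(Seq("t"))).contains(Literal(1)))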
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
index ec373d95fad88..7d7d6bdbfdd2d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
@@ -246,6 +246,14 @@ case class AlterViewAsStatement(
originalText: String,
query: LogicalPlan) extends ParsedStatement
+/**
+ * ALTER TABLE ... RENAME TO command, as parsed from SQL.
+ */
+case class RenameTableStatement(
+ oldName: Seq[String],
+ newName: Seq[String],
+ isView: Boolean) extends ParsedStatement
+
/**
* A DROP TABLE statement, as parsed from SQL.
*/
@@ -269,6 +277,13 @@ case class DescribeTableStatement(
partitionSpec: TablePartitionSpec,
isExtended: Boolean) extends ParsedStatement
+/**
+ * A DESCRIBE NAMESPACE statement, as parsed from SQL.
+ */
+case class DescribeNamespaceStatement(
+ namespace: Seq[String],
+ extended: Boolean) extends ParsedStatement
+
/**
* A DESCRIBE TABLE tbl_name col_name statement, as parsed from SQL.
*/
@@ -313,6 +328,15 @@ case class InsertIntoStatement(
case class ShowTablesStatement(namespace: Option[Seq[String]], pattern: Option[String])
extends ParsedStatement
+/**
+ * A SHOW TABLE EXTENDED statement, as parsed from SQL.
+ */
+case class ShowTableStatement(
+ namespace: Option[Seq[String]],
+ pattern: String,
+ partitionSpec: Option[TablePartitionSpec])
+ extends ParsedStatement
+
/**
* A CREATE NAMESPACE statement, as parsed from SQL.
*/
@@ -334,6 +358,20 @@ case class DropNamespaceStatement(
ifExists: Boolean,
cascade: Boolean) extends ParsedStatement
+/**
+ * ALTER (DATABASE|SCHEMA|NAMESPACE) ... SET (DBPROPERTIES|PROPERTIES) command, as parsed from SQL.
+ */
+case class AlterNamespaceSetPropertiesStatement(
+ namespace: Seq[String],
+ properties: Map[String, String]) extends ParsedStatement
+
+/**
+ * ALTER (DATABASE|SCHEMA|NAMESPACE) ... SET LOCATION command, as parsed from SQL.
+ */
+case class AlterNamespaceSetLocationStatement(
+ namespace: Seq[String],
+ location: String) extends ParsedStatement
+
/**
* A SHOW NAMESPACES statement, as parsed from SQL.
*/
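For orientation, a REPL-style sketch (hypothetical stand-ins, not Spark code) of the payload the new namespace statements are expected to carry:

    case class DescribeNamespaceStatementSketch(namespace: Seq[String], extended: Boolean)
    case class AlterNamespaceSetPropertiesStatementSketch(namespace: Seq[String], properties: Map[String, String])

    // "DESCRIBE NAMESPACE EXTENDED ns1.ns2" is expected to carry:
    val describe = DescribeNamespaceStatementSketch(Seq("ns1", "ns2"), extended = true)
    // "ALTER NAMESPACE ns1 SET PROPERTIES ('k' = 'v')" is expected to carry:
    val setProps = AlterNamespaceSetPropertiesStatementSketch(Seq("ns1"), Map("k" -> "v"))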
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index 7d8e9a0c18f65..d87758a7df7b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.plans.DescribeTableSchema
import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, Identifier, SupportsNamespaces, TableCatalog, TableChange}
import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, ColumnChange}
import org.apache.spark.sql.connector.expressions.Transform
-import org.apache.spark.sql.types.{DataType, StringType, StructType}
+import org.apache.spark.sql.types.{DataType, MetadataBuilder, StringType, StructType}
/**
* Base trait for DataSourceV2 write commands
@@ -255,6 +255,30 @@ case class DropNamespace(
ifExists: Boolean,
cascade: Boolean) extends Command
+/**
+ * The logical plan of the DESCRIBE NAMESPACE command that works for v2 catalogs.
+ */
+case class DescribeNamespace(
+ catalog: SupportsNamespaces,
+ namespace: Seq[String],
+ extended: Boolean) extends Command {
+
+ override def output: Seq[Attribute] = Seq(
+ AttributeReference("name", StringType, nullable = false,
+ new MetadataBuilder().putString("comment", "name of the column").build())(),
+ AttributeReference("value", StringType, nullable = true,
+ new MetadataBuilder().putString("comment", "value of the column").build())())
+}
+
+/**
+ * The logical plan of the ALTER (DATABASE|SCHEMA|NAMESPACE) ... SET (DBPROPERTIES|PROPERTIES)
+ * command that works for v2 catalogs.
+ */
+case class AlterNamespaceSetProperties(
+ catalog: SupportsNamespaces,
+ namespace: Seq[String],
+ properties: Map[String, String]) extends Command
+
/**
* The logical plan of the SHOW NAMESPACES command that works for v2 catalogs.
*/
@@ -376,6 +400,14 @@ case class AlterTable(
}
}
+/**
+ * The logical plan of the ALTER TABLE RENAME command that works for v2 tables.
+ */
+case class RenameTable(
+ catalog: TableCatalog,
+ oldIdent: Identifier,
+ newIdent: Identifier) extends Command
+
/**
* The logical plan of the SHOW TABLE command that works for v2 catalogs.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
index 882c1d85267e4..9418d8eec3376 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
@@ -17,11 +17,11 @@
package org.apache.spark.sql.catalyst.util
+import java.math.BigDecimal
import java.util.concurrent.TimeUnit
import scala.util.control.NonFatal
-import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.types.Decimal
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -101,34 +101,6 @@ object IntervalUtils {
Decimal(result, 18, 6)
}
- /**
- * Converts a string to [[CalendarInterval]] case-insensitively.
- *
- * @throws IllegalArgumentException if the input string is not in valid interval format.
- */
- def fromString(str: String): CalendarInterval = {
- if (str == null) throw new IllegalArgumentException("Interval string cannot be null")
- try {
- CatalystSqlParser.parseInterval(str)
- } catch {
- case e: ParseException =>
- val ex = new IllegalArgumentException(s"Invalid interval string: $str\n" + e.message)
- ex.setStackTrace(e.getStackTrace)
- throw ex
- }
- }
-
- /**
- * A safe version of `fromString`. It returns null for invalid input string.
- */
- def safeFromString(str: String): CalendarInterval = {
- try {
- fromString(str)
- } catch {
- case _: IllegalArgumentException => null
- }
- }
-
private def toLongWithRange(
fieldName: IntervalUnit,
s: String,
@@ -250,46 +222,6 @@ object IntervalUtils {
}
}
- def fromUnitStrings(units: Array[IntervalUnit], values: Array[String]): CalendarInterval = {
- assert(units.length == values.length)
- var months: Int = 0
- var days: Int = 0
- var microseconds: Long = 0
- var i = 0
- while (i < units.length) {
- try {
- units(i) match {
- case YEAR =>
- months = Math.addExact(months, Math.multiplyExact(values(i).toInt, 12))
- case MONTH =>
- months = Math.addExact(months, values(i).toInt)
- case WEEK =>
- days = Math.addExact(days, Math.multiplyExact(values(i).toInt, 7))
- case DAY =>
- days = Math.addExact(days, values(i).toInt)
- case HOUR =>
- val hoursUs = Math.multiplyExact(values(i).toLong, MICROS_PER_HOUR)
- microseconds = Math.addExact(microseconds, hoursUs)
- case MINUTE =>
- val minutesUs = Math.multiplyExact(values(i).toLong, MICROS_PER_MINUTE)
- microseconds = Math.addExact(microseconds, minutesUs)
- case SECOND =>
- microseconds = Math.addExact(microseconds, parseSecondNano(values(i)))
- case MILLISECOND =>
- val millisUs = Math.multiplyExact(values(i).toLong, MICROS_PER_MILLIS)
- microseconds = Math.addExact(microseconds, millisUs)
- case MICROSECOND =>
- microseconds = Math.addExact(microseconds, values(i).toLong)
- }
- } catch {
- case e: Exception =>
- throw new IllegalArgumentException(s"Error parsing interval string: ${e.getMessage}", e)
- }
- i += 1
- }
- new CalendarInterval(months, days, microseconds)
- }
-
// Parses a string with nanoseconds, truncates the result and returns microseconds
private def parseNanos(nanosStr: String, isNegative: Boolean): Long = {
if (nanosStr != null) {
@@ -305,30 +237,6 @@ object IntervalUtils {
}
}
- /**
- * Parse second_nano string in ss.nnnnnnnnn format to microseconds
- */
- private def parseSecondNano(secondNano: String): Long = {
- def parseSeconds(secondsStr: String): Long = {
- toLongWithRange(
- SECOND,
- secondsStr,
- Long.MinValue / MICROS_PER_SECOND,
- Long.MaxValue / MICROS_PER_SECOND) * MICROS_PER_SECOND
- }
-
- secondNano.split("\\.") match {
- case Array(secondsStr) => parseSeconds(secondsStr)
- case Array("", nanosStr) => parseNanos(nanosStr, false)
- case Array(secondsStr, nanosStr) =>
- val seconds = parseSeconds(secondsStr)
- Math.addExact(seconds, parseNanos(nanosStr, seconds < 0))
- case _ =>
- throw new IllegalArgumentException(
- "Interval string does not match second-nano format of ss.nnnnnnnnn")
- }
- }
-
/**
* Gets interval duration
*
@@ -424,6 +332,85 @@ object IntervalUtils {
fromDoubles(interval.months / num, interval.days / num, interval.microseconds / num)
}
+ // `CalendarInterval.toString` currently produces the multi-units format.
+ def toMultiUnitsString(interval: CalendarInterval): String = interval.toString
+
+ def toSqlStandardString(interval: CalendarInterval): String = {
+ val yearMonthPart = if (interval.months < 0) {
+ val ma = math.abs(interval.months)
+ "-" + ma / 12 + "-" + ma % 12
+ } else if (interval.months > 0) {
+ "+" + interval.months / 12 + "-" + interval.months % 12
+ } else {
+ ""
+ }
+
+ val dayPart = if (interval.days < 0) {
+ interval.days.toString
+ } else if (interval.days > 0) {
+ "+" + interval.days
+ } else {
+ ""
+ }
+
+ val timePart = if (interval.microseconds != 0) {
+ val sign = if (interval.microseconds > 0) "+" else "-"
+ val sb = new StringBuilder(sign)
+ var rest = math.abs(interval.microseconds)
+ sb.append(rest / MICROS_PER_HOUR)
+ sb.append(':')
+ rest %= MICROS_PER_HOUR
+ val minutes = rest / MICROS_PER_MINUTE
+ if (minutes < 10) {
+ sb.append(0)
+ }
+ sb.append(minutes)
+ sb.append(':')
+ rest %= MICROS_PER_MINUTE
+ val bd = BigDecimal.valueOf(rest, 6)
+ if (bd.compareTo(new BigDecimal(10)) < 0) {
+ sb.append(0)
+ }
+ val s = bd.stripTrailingZeros().toPlainString
+ sb.append(s)
+ sb.toString()
+ } else {
+ ""
+ }
+
+ val intervalList = Seq(yearMonthPart, dayPart, timePart).filter(_.nonEmpty)
+ if (intervalList.nonEmpty) intervalList.mkString(" ") else "0"
+ }
+
+ def toIso8601String(interval: CalendarInterval): String = {
+ val sb = new StringBuilder("P")
+
+ val year = interval.months / 12
+ if (year != 0) sb.append(year + "Y")
+ val month = interval.months % 12
+ if (month != 0) sb.append(month + "M")
+
+ if (interval.days != 0) sb.append(interval.days + "D")
+
+ if (interval.microseconds != 0) {
+ sb.append('T')
+ var rest = interval.microseconds
+ val hour = rest / MICROS_PER_HOUR
+ if (hour != 0) sb.append(hour + "H")
+ rest %= MICROS_PER_HOUR
+ val minute = rest / MICROS_PER_MINUTE
+ if (minute != 0) sb.append(minute + "M")
+ rest %= MICROS_PER_MINUTE
+ if (rest != 0) {
+ val bd = BigDecimal.valueOf(rest, 6)
+ sb.append(bd.stripTrailingZeros().toPlainString + "S")
+ }
+ } else if (interval.days == 0 && interval.months == 0) {
+ sb.append("T0S")
+ }
+ sb.toString()
+ }
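A quick sketch of the three output styles on one interval, assuming this patch is applied and spark-catalyst is on the classpath; the expected strings follow the formatting code above:

    import org.apache.spark.sql.catalyst.util.IntervalUtils
    import org.apache.spark.unsafe.types.CalendarInterval

    // 1 year 2 months, 3 days, 2 hours 30 minutes
    val iv = new CalendarInterval(14, 3, 9000000000L)
    IntervalUtils.toMultiUnitsString(iv)   // value-unit pairs via CalendarInterval.toString
    IntervalUtils.toSqlStandardString(iv)  // expected: "+1-2 +3 +2:30:00"
    IntervalUtils.toIso8601String(iv)      // expected: "P1Y2M3DT2H30M"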
+
private object ParseState extends Enumeration {
type ParseState = Value
@@ -452,18 +439,37 @@ object IntervalUtils {
private final val millisStr = unitToUtf8(MILLISECOND)
private final val microsStr = unitToUtf8(MICROSECOND)
+ /**
+ * A safe version of `stringToInterval`. It returns null for invalid input string.
+ */
+ def safeStringToInterval(input: UTF8String): CalendarInterval = {
+ try {
+ stringToInterval(input)
+ } catch {
+ case _: IllegalArgumentException => null
+ }
+ }
+
+ /**
+ * Converts a string to [[CalendarInterval]] case-insensitively.
+ *
+ * @throws IllegalArgumentException if the input string is not in valid interval format.
+ */
def stringToInterval(input: UTF8String): CalendarInterval = {
import ParseState._
+ def throwIAE(msg: String, e: Exception = null) = {
+ throw new IllegalArgumentException(s"Error parsing '$input' to interval, $msg", e)
+ }
if (input == null) {
- return null
+ throwIAE("interval string cannot be null")
}
// scalastyle:off caselocale .toLowerCase
val s = input.trim.toLowerCase
// scalastyle:on
val bytes = s.getBytes
if (bytes.isEmpty) {
- return null
+ throwIAE("interval string cannot be empty")
}
var state = PREFIX
var i = 0
@@ -473,7 +479,9 @@ object IntervalUtils {
var days: Int = 0
var microseconds: Long = 0
var fractionScale: Int = 0
+ val initialFractionScale = (NANOS_PER_SECOND / 10).toInt
var fraction: Int = 0
+ var pointPrefixed: Boolean = false
def trimToNextState(b: Byte, next: ParseState): Unit = {
b match {
@@ -482,13 +490,19 @@ object IntervalUtils {
}
}
+ def currentWord: UTF8String = {
+ val strings = s.split(UTF8String.blankString(1), -1)
+ val lenRight = s.substring(i, s.numBytes()).split(UTF8String.blankString(1), -1).length
+ strings(strings.length - lenRight)
+ }
+
while (i < bytes.length) {
val b = bytes(i)
state match {
case PREFIX =>
if (s.startsWith(intervalStr)) {
if (s.numBytes() == intervalStr.numBytes()) {
- return null
+ throwIAE("interval string cannot be empty")
} else {
i += intervalStr.numBytes()
}
@@ -496,6 +510,18 @@ object IntervalUtils {
state = TRIM_BEFORE_SIGN
case TRIM_BEFORE_SIGN => trimToNextState(b, SIGN)
case SIGN =>
+ currentValue = 0
+ fraction = 0
+ // We preset the next state from SIGN to TRIM_BEFORE_VALUE. If we meet '.' in the SIGN state,
+ // the interval value we are dealing with is a number with only a fractional part, such as
+ // '.11 second', which parses to 0.11 seconds. In that case, we reset the next state to
+ // `VALUE_FRACTIONAL_PART` to parse the fractional part of the interval value.
+ state = TRIM_BEFORE_VALUE
+ // We preset the scale to an invalid value to track fraction presence in the UNIT_BEGIN
+ // state. If we meet '.', the scale becomes valid for the VALUE_FRACTIONAL_PART state.
+ fractionScale = -1
+ pointPrefixed = false
b match {
case '-' =>
isNegative = true
@@ -505,14 +531,14 @@ object IntervalUtils {
i += 1
case _ if '0' <= b && b <= '9' =>
isNegative = false
- case _ => return null
+ case '.' =>
+ isNegative = false
+ fractionScale = initialFractionScale
+ pointPrefixed = true
+ i += 1
+ state = VALUE_FRACTIONAL_PART
+ case _ => throwIAE(s"unrecognized number '$currentWord'")
}
- currentValue = 0
- fraction = 0
- // Sets the scale to an invalid value to track fraction presence
- // in the BEGIN_UNIT_NAME state
- fractionScale = -1
- state = TRIM_BEFORE_VALUE
case TRIM_BEFORE_VALUE => trimToNextState(b, VALUE)
case VALUE =>
b match {
@@ -520,13 +546,13 @@ object IntervalUtils {
try {
currentValue = Math.addExact(Math.multiplyExact(10, currentValue), (b - '0'))
} catch {
- case _: ArithmeticException => return null
+ case e: ArithmeticException => throwIAE(e.getMessage, e)
}
case ' ' => state = TRIM_BEFORE_UNIT
case '.' =>
- fractionScale = (NANOS_PER_SECOND / 10).toInt
+ fractionScale = initialFractionScale
state = VALUE_FRACTIONAL_PART
- case _ => return null
+ case _ => throwIAE(s"invalid value '$currentWord'")
}
i += 1
case VALUE_FRACTIONAL_PART =>
@@ -534,17 +560,20 @@ object IntervalUtils {
case _ if '0' <= b && b <= '9' && fractionScale > 0 =>
fraction += (b - '0') * fractionScale
fractionScale /= 10
- case ' ' =>
+ case ' ' if !pointPrefixed || fractionScale < initialFractionScale =>
fraction /= NANOS_PER_MICROS.toInt
state = TRIM_BEFORE_UNIT
- case _ => return null
+ case _ if '0' <= b && b <= '9' =>
+ throwIAE(s"interval can only support nanosecond precision, '$currentWord' is out" +
+ s" of range")
+ case _ => throwIAE(s"invalid value '$currentWord'")
}
i += 1
case TRIM_BEFORE_UNIT => trimToNextState(b, UNIT_BEGIN)
case UNIT_BEGIN =>
// Checks that only seconds can have the fractional part
if (b != 's' && fractionScale >= 0) {
- return null
+ throwIAE(s"'$currentWord' cannot have fractional part")
}
if (isNegative) {
currentValue = -currentValue
@@ -588,18 +617,18 @@ object IntervalUtils {
} else if (s.matchAt(microsStr, i)) {
microseconds = Math.addExact(microseconds, currentValue)
i += microsStr.numBytes()
- } else return null
- case _ => return null
+ } else throwIAE(s"invalid unit '$currentWord'")
+ case _ => throwIAE(s"invalid unit '$currentWord'")
}
} catch {
- case _: ArithmeticException => return null
+ case e: ArithmeticException => throwIAE(e.getMessage, e)
}
state = UNIT_SUFFIX
case UNIT_SUFFIX =>
b match {
case 's' => state = UNIT_END
case ' ' => state = TRIM_BEFORE_SIGN
- case _ => return null
+ case _ => throwIAE(s"invalid unit '$currentWord'")
}
i += 1
case UNIT_END =>
@@ -607,7 +636,7 @@ object IntervalUtils {
case ' ' =>
i += 1
state = TRIM_BEFORE_SIGN
- case _ => return null
+ case _ => throwIAE(s"invalid unit '$currentWord'")
}
}
}
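A short usage sketch of the new error contract (assuming the patched IntervalUtils): `stringToInterval` now throws `IllegalArgumentException` on bad input instead of returning null, while `safeStringToInterval` keeps the old null-returning behavior:

    import org.apache.spark.sql.catalyst.util.IntervalUtils.{safeStringToInterval, stringToInterval}
    import org.apache.spark.unsafe.types.UTF8String

    val ok = stringToInterval(UTF8String.fromString("1 day 2 hours"))        // 1 day, 7200000000 microseconds
    val bad = safeStringToInterval(UTF8String.fromString("not an interval")) // null, no exception
    // stringToInterval(UTF8String.fromString("not an interval")) would throw IllegalArgumentException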
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala
index d62148b2bbe45..135c180ef4000 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala
@@ -71,7 +71,7 @@ class CatalogManager(
* This happens when the source implementation extends the v2 TableProvider API and is not listed
* in the fallback configuration, spark.sql.sources.write.useV1SourceList
*/
- private def v2SessionCatalog: CatalogPlugin = {
+ private[sql] def v2SessionCatalog: CatalogPlugin = {
conf.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).map { customV2SessionCatalog =>
try {
catalogs.getOrElseUpdate(SESSION_CATALOG_NAME, loadV2SessionCatalog())
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala
index 26ba93e57fc64..613c0d1797cc6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.connector.catalog
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
/**
* A trait to encapsulate catalog lookup function and helpful extractors.
@@ -120,10 +121,22 @@ private[sql] trait LookupCatalog extends Logging {
* Extract catalog and the rest name parts from a multi-part identifier.
*/
object CatalogAndIdentifierParts {
- def unapply(nameParts: Seq[String]): Some[(CatalogPlugin, Seq[String])] = {
+ private val globalTempDB = SQLConf.get.getConf(StaticSQLConf.GLOBAL_TEMP_DATABASE)
+
+ def unapply(nameParts: Seq[String]): Option[(CatalogPlugin, Seq[String])] = {
assert(nameParts.nonEmpty)
try {
- Some((catalogManager.catalog(nameParts.head), nameParts.tail))
+ // Conceptually, global temp views live in a special reserved catalog. However, the v2 catalog
+ // API does not support views yet, and we have to use v1 commands to deal with global temp
+ // views. To simplify the implementation, we put global temp views in a special namespace
+ // in the session catalog. The special namespace has higher priority during name resolution.
+ // For example, if the name of a custom catalog is the same as `GLOBAL_TEMP_DATABASE`,
+ // that custom catalog cannot be accessed.
+ if (nameParts.head.equalsIgnoreCase(globalTempDB)) {
+ Some((catalogManager.v2SessionCatalog, nameParts))
+ } else {
+ Some((catalogManager.catalog(nameParts.head), nameParts.tail))
+ }
} catch {
case _: CatalogNotFoundException =>
Some((currentCatalog, nameParts))
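The resolution priority added here can be summarized by a stand-alone sketch with plain strings in place of `CatalogPlugin`: the global temp database name shadows any catalog with the same name, and in that case the full name parts are kept for the session catalog:

    def resolveSketch(nameParts: Seq[String], globalTempDB: String): (String, Seq[String]) =
      if (nameParts.head.equalsIgnoreCase(globalTempDB)) ("v2SessionCatalog", nameParts)
      else (nameParts.head, nameParts.tail)

    assert(resolveSketch(Seq("global_temp", "v1"), "global_temp") == ("v2SessionCatalog", Seq("global_temp", "v1")))
    assert(resolveSketch(Seq("cat1", "db", "tbl"), "global_temp") == ("cat1", Seq("db", "tbl")))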
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/PhysicalWriteInfoImpl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/PhysicalWriteInfoImpl.scala
new file mode 100644
index 0000000000000..a663822f3eb45
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/PhysicalWriteInfoImpl.scala
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.write
+
+private[sql] case class PhysicalWriteInfoImpl(numPartitions: Int) extends PhysicalWriteInfo
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 759586a2936fd..33f91d045f7d9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -393,8 +393,8 @@ object SQLConf {
"must be a positive integer.")
.createOptional
- val OPTIMIZE_LOCAL_SHUFFLE_READER_ENABLED =
- buildConf("spark.sql.adaptive.shuffle.optimizedLocalShuffleReader.enabled")
+ val LOCAL_SHUFFLE_READER_ENABLED =
+ buildConf("spark.sql.adaptive.shuffle.localShuffleReader.enabled")
.doc("When true and adaptive execution is enabled, this enables the optimization of" +
" converting the shuffle reader to local shuffle reader for the shuffle exchange" +
" of the broadcast hash join in probe side.")
@@ -720,14 +720,6 @@ object SQLConf {
.stringConf
.createWithDefault("_corrupt_record")
- val FROM_JSON_FORCE_NULLABLE_SCHEMA = buildConf("spark.sql.fromJsonForceNullableSchema")
- .internal()
- .doc("When true, force the output schema of the from_json() function to be nullable " +
- "(including all the fields). Otherwise, the schema might not be compatible with" +
- "actual data, which leads to corruptions. This config will be removed in Spark 3.0.")
- .booleanConf
- .createWithDefault(true)
-
val BROADCAST_TIMEOUT = buildConf("spark.sql.broadcastTimeout")
.doc("Timeout in seconds for the broadcast wait time in broadcast joins.")
.timeConf(TimeUnit.SECONDS)
@@ -1673,14 +1665,20 @@ object SQLConf {
.checkValues(Dialect.values.map(_.toString))
.createWithDefault(Dialect.SPARK.toString)
- val ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION =
- buildConf("spark.sql.legacy.allowCreatingManagedTableUsingNonemptyLocation")
+ val ANSI_ENABLED = buildConf("spark.sql.ansi.enabled")
.internal()
- .doc("When this option is set to true, creating managed tables with nonempty location " +
- "is allowed. Otherwise, an analysis exception is thrown. ")
+ .doc("This configuration is deprecated and will be removed in the future releases." +
+ "It is replaced by spark.sql.dialect.spark.ansi.enabled.")
.booleanConf
.createWithDefault(false)
+ val DIALECT_SPARK_ANSI_ENABLED = buildConf("spark.sql.dialect.spark.ansi.enabled")
+ .doc("When true, Spark tries to conform to the ANSI SQL specification: 1. Spark will " +
+ "throw a runtime exception if an overflow occurs in any operation on integral/decimal " +
+ "field. 2. Spark will forbid using the reserved keywords of ANSI SQL as identifiers in " +
+ "the SQL parser.")
+ .fallbackConf(ANSI_ENABLED)
+
val VALIDATE_PARTITION_COLUMNS =
buildConf("spark.sql.sources.validatePartitionColumns")
.internal()
@@ -1784,13 +1782,22 @@ object SQLConf {
.checkValues(StoreAssignmentPolicy.values.map(_.toString))
.createWithDefault(StoreAssignmentPolicy.ANSI.toString)
- val ANSI_ENABLED = buildConf("spark.sql.ansi.enabled")
- .doc("When true, Spark tries to conform to the ANSI SQL specification: 1. Spark will " +
- "throw a runtime exception if an overflow occurs in any operation on integral/decimal " +
- "field. 2. Spark will forbid using the reserved keywords of ANSI SQL as identifiers in " +
- "the SQL parser.")
- .booleanConf
- .createWithDefault(false)
+ object IntervalStyle extends Enumeration {
+ type IntervalStyle = Value
+ val SQL_STANDARD, ISO_8601, MULTI_UNITS = Value
+ }
+
+ val INTERVAL_STYLE = buildConf("spark.sql.intervalOutputStyle")
+ .doc("When converting interval values to strings (i.e. for display), this config decides the" +
+ " interval string format. The value SQL_STANDARD will produce output matching SQL standard" +
+ " interval literals (i.e. '+3-2 +10 -00:00:01'). The value ISO_8601 will produce output" +
+ " matching the ISO 8601 standard (i.e. 'P3Y2M10DT-1S'). The value MULTI_UNITS (which is the" +
+ " default) will produce output in form of value unit pairs, (i.e. '3 year 2 months 10 days" +
+ " -1 seconds'")
+ .stringConf
+ .transform(_.toUpperCase(Locale.ROOT))
+ .checkValues(IntervalStyle.values.map(_.toString))
+ .createWithDefault(IntervalStyle.MULTI_UNITS.toString)
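A hypothetical usage sketch (assumes a running `SparkSession` named `spark` and this patch applied; the rendered strings reflect the intent of the doc above rather than verified output):

    spark.conf.set("spark.sql.intervalOutputStyle", "SQL_STANDARD")
    spark.sql("SELECT CAST(interval 3 years 2 months 10 days AS STRING)").show()  // intended: +3-2 +10
    spark.conf.set("spark.sql.intervalOutputStyle", "ISO_8601")
    spark.sql("SELECT CAST(interval 3 years 2 months 10 days AS STRING)").show()  // intended: P3Y2M10D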
val SORT_BEFORE_REPARTITION =
buildConf("spark.sql.execution.sortBeforeRepartition")
@@ -1907,16 +1914,6 @@ object SQLConf {
.checkValues((1 to 9).toSet + Deflater.DEFAULT_COMPRESSION)
.createWithDefault(Deflater.DEFAULT_COMPRESSION)
- val COMPARE_DATE_TIMESTAMP_IN_TIMESTAMP =
- buildConf("spark.sql.legacy.compareDateTimestampInTimestamp")
- .internal()
- .doc("When true (default), compare Date with Timestamp after converting both sides to " +
- "Timestamp. This behavior is compatible with Hive 2.2 or later. See HIVE-15236. " +
- "When false, restore the behavior prior to Spark 2.4. Compare Date with Timestamp after " +
- "converting both sides to string. This config will be removed in Spark 3.0.")
- .booleanConf
- .createWithDefault(true)
-
val LEGACY_SIZE_OF_NULL = buildConf("spark.sql.legacy.sizeOfNull")
.doc("If it is set to true, size of null returns -1. This behavior was inherited from Hive. " +
"The size function returns null for null input if the flag is disabled.")
@@ -2230,8 +2227,6 @@ class SQLConf extends Serializable with Logging {
def caseSensitiveInferenceMode: HiveCaseSensitiveInferenceMode.Value =
HiveCaseSensitiveInferenceMode.withName(getConf(HIVE_CASE_SENSITIVE_INFERENCE))
- def compareDateTimestampInTimestamp : Boolean = getConf(COMPARE_DATE_TIMESTAMP_IN_TIMESTAMP)
-
def gatherFastStats: Boolean = getConf(GATHER_FASTSTAT)
def optimizerMetadataOnly: Boolean = getConf(OPTIMIZER_METADATA_ONLY)
@@ -2510,9 +2505,6 @@ class SQLConf extends Serializable with Logging {
def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING)
- def allowCreatingManagedTableUsingNonemptyLocation: Boolean =
- getConf(ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION)
-
def validatePartitionColumns: Boolean = getConf(VALIDATE_PARTITION_COLUMNS)
def partitionOverwriteMode: PartitionOverwriteMode.Value =
@@ -2521,9 +2513,15 @@ class SQLConf extends Serializable with Logging {
def storeAssignmentPolicy: StoreAssignmentPolicy.Value =
StoreAssignmentPolicy.withName(getConf(STORE_ASSIGNMENT_POLICY))
- def ansiEnabled: Boolean = getConf(ANSI_ENABLED)
+ def intervalOutputStyle: IntervalStyle.Value = IntervalStyle.withName(getConf(INTERVAL_STYLE))
+
+ def dialect: Dialect.Value = Dialect.withName(getConf(DIALECT))
+
+ def usePostgreSQLDialect: Boolean = dialect == Dialect.POSTGRESQL
+
+ def dialectSparkAnsiEnabled: Boolean = getConf(DIALECT_SPARK_ANSI_ENABLED)
- def usePostgreSQLDialect: Boolean = getConf(DIALECT) == Dialect.POSTGRESQL.toString()
+ def ansiEnabled: Boolean = usePostgreSQLDialect || dialectSparkAnsiEnabled
def nestedSchemaPruningEnabled: Boolean = getConf(NESTED_SCHEMA_PRUNING_ENABLED)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
index d665d16ae4195..d2f27da239016 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
@@ -161,4 +161,11 @@ object StaticSQLConf {
"defaults, dropping any overrides in its parent SparkSession.")
.booleanConf
.createWithDefault(false)
+
+ val DEFAULT_URL_STREAM_HANDLER_FACTORY_ENABLED =
+ buildStaticConf("spark.sql.defaultUrlStreamHandlerFactory.enabled")
+ .doc("When true, set FsUrlStreamHandlerFactory to support ADD JAR against HDFS locations")
+ .internal()
+ .booleanConf
+ .createWithDefault(true)
}
diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaOutputModeSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaOutputModeSuite.java
index d8845e0c838ff..ca2b18b8eed49 100644
--- a/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaOutputModeSuite.java
+++ b/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaOutputModeSuite.java
@@ -19,6 +19,7 @@
import java.util.Locale;
+import org.junit.Assert;
import org.junit.Test;
public class JavaOutputModeSuite {
@@ -26,8 +27,8 @@ public class JavaOutputModeSuite {
@Test
public void testOutputModes() {
OutputMode o1 = OutputMode.Append();
- assert(o1.toString().toLowerCase(Locale.ROOT).contains("append"));
+ Assert.assertTrue(o1.toString().toLowerCase(Locale.ROOT).contains("append"));
OutputMode o2 = OutputMode.Complete();
- assert (o2.toString().toLowerCase(Locale.ROOT).contains("complete"));
+ Assert.assertTrue(o2.toString().toLowerCase(Locale.ROOT).contains("complete"));
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 3dabbca9deeee..e0fa1f2ecb88e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.analysis
+import org.scalatest.Assertions._
+
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
index cddcddd51e38d..49ab34d2ea3a0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
@@ -245,4 +245,11 @@ class ResolveHintsSuite extends AnalysisTest {
e => e.getLevel == Level.WARN &&
e.getRenderedMessage.contains("Unrecognized hint: unknown_hint")))
}
+
+ test("SPARK-30003: Do not throw stack overflow exception in non-root unknown hint resolution") {
+ checkAnalysis(
+ Project(testRelation.output, UnresolvedHint("unknown_hint", Seq("TaBlE"), table("TaBlE"))),
+ Project(testRelation.output, testRelation),
+ caseSensitive = false)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index c7371a7911df5..567cf5ec8ebe6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -1526,26 +1526,15 @@ class TypeCoercionSuite extends AnalysisTest {
GreaterThan(Literal("1.5"), Literal(BigDecimal("0.5"))),
GreaterThan(Cast(Literal("1.5"), DoubleType), Cast(Literal(BigDecimal("0.5")),
DoubleType)))
- Seq(true, false).foreach { convertToTS =>
- withSQLConf(
- SQLConf.COMPARE_DATE_TIMESTAMP_IN_TIMESTAMP.key -> convertToTS.toString) {
- val date0301 = Literal(java.sql.Date.valueOf("2017-03-01"))
- val timestamp0301000000 = Literal(Timestamp.valueOf("2017-03-01 00:00:00"))
- val timestamp0301000001 = Literal(Timestamp.valueOf("2017-03-01 00:00:01"))
- if (convertToTS) {
- // `Date` should be treated as timestamp at 00:00:00 See SPARK-23549
- ruleTest(rule, EqualTo(date0301, timestamp0301000000),
- EqualTo(Cast(date0301, TimestampType), timestamp0301000000))
- ruleTest(rule, LessThan(date0301, timestamp0301000001),
- LessThan(Cast(date0301, TimestampType), timestamp0301000001))
- } else {
- ruleTest(rule, LessThan(date0301, timestamp0301000000),
- LessThan(Cast(date0301, StringType), Cast(timestamp0301000000, StringType)))
- ruleTest(rule, LessThan(date0301, timestamp0301000001),
- LessThan(Cast(date0301, StringType), Cast(timestamp0301000001, StringType)))
- }
- }
- }
+ // Checks that dates/timestamps are not promoted to strings
+ val date0301 = Literal(java.sql.Date.valueOf("2017-03-01"))
+ val timestamp0301000000 = Literal(Timestamp.valueOf("2017-03-01 00:00:00"))
+ val timestamp0301000001 = Literal(Timestamp.valueOf("2017-03-01 00:00:01"))
+ // `Date` should be treated as timestamp at 00:00:00 See SPARK-23549
+ ruleTest(rule, EqualTo(date0301, timestamp0301000000),
+ EqualTo(Cast(date0301, TimestampType), timestamp0301000000))
+ ruleTest(rule, LessThan(date0301, timestamp0301000001),
+ LessThan(Cast(date0301, TimestampType), timestamp0301000001))
}
test("cast WindowFrame boundaries to the type they operate upon") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index c1f1be3b30e4b..62e688e4d4bd6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -436,7 +436,7 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
testAndVerifyNotLeakingReflectionObjects(
s"overflowing $testName, ansiEnabled=$ansiEnabled") {
withSQLConf(
- SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString
+ SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> ansiEnabled.toString
) {
// Need to construct Encoder here rather than implicitly resolving it
// so that SQLConf changes are respected.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 1a1cab823d4f3..fe068f7a5f6c2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -169,7 +169,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
}
private def testDecimalOverflow(schema: StructType, row: Row): Unit = {
- withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+ withSQLConf(SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> "true") {
val encoder = RowEncoder(schema).resolveAndBind()
intercept[Exception] {
encoder.toRow(row)
@@ -182,7 +182,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
}
}
- withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
+ withSQLConf(SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> "false") {
val encoder = RowEncoder(schema).resolveAndBind()
assert(encoder.fromRow(encoder.toRow(row)).get(0) == null)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index ad8b1a1673679..6e3fc438e41ea 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -61,7 +61,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Add(positiveLongLit, negativeLongLit), -1L)
Seq("true", "false").foreach { checkOverflow =>
- withSQLConf(SQLConf.ANSI_ENABLED.key -> checkOverflow) {
+ withSQLConf(SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> checkOverflow) {
DataTypeTestUtils.numericAndInterval.foreach { tpe =>
checkConsistencyBetweenInterpretedAndCodegenAllowingException(Add, tpe, tpe)
}
@@ -80,7 +80,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(UnaryMinus(Literal(Int.MinValue)), Int.MinValue)
checkEvaluation(UnaryMinus(Literal(Short.MinValue)), Short.MinValue)
checkEvaluation(UnaryMinus(Literal(Byte.MinValue)), Byte.MinValue)
- withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+ withSQLConf(SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> "true") {
checkExceptionInExpression[ArithmeticException](
UnaryMinus(Literal(Long.MinValue)), "overflow")
checkExceptionInExpression[ArithmeticException](
@@ -122,7 +122,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Subtract(positiveLongLit, negativeLongLit), positiveLong - negativeLong)
Seq("true", "false").foreach { checkOverflow =>
- withSQLConf(SQLConf.ANSI_ENABLED.key -> checkOverflow) {
+ withSQLConf(SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> checkOverflow) {
DataTypeTestUtils.numericAndInterval.foreach { tpe =>
checkConsistencyBetweenInterpretedAndCodegenAllowingException(Subtract, tpe, tpe)
}
@@ -144,7 +144,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Multiply(positiveLongLit, negativeLongLit), positiveLong * negativeLong)
Seq("true", "false").foreach { checkOverflow =>
- withSQLConf(SQLConf.ANSI_ENABLED.key -> checkOverflow) {
+ withSQLConf(SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> checkOverflow) {
DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe =>
checkConsistencyBetweenInterpretedAndCodegenAllowingException(Multiply, tpe, tpe)
}
@@ -445,12 +445,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
val e4 = Add(minLongLiteral, minLongLiteral)
val e5 = Subtract(minLongLiteral, maxLongLiteral)
val e6 = Multiply(minLongLiteral, minLongLiteral)
- withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+ withSQLConf(SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> "true") {
Seq(e1, e2, e3, e4, e5, e6).foreach { e =>
checkExceptionInExpression[ArithmeticException](e, "overflow")
}
}
- withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
+ withSQLConf(SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> "false") {
checkEvaluation(e1, Long.MinValue)
checkEvaluation(e2, Long.MinValue)
checkEvaluation(e3, -2L)
@@ -469,12 +469,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
val e4 = Add(minIntLiteral, minIntLiteral)
val e5 = Subtract(minIntLiteral, maxIntLiteral)
val e6 = Multiply(minIntLiteral, minIntLiteral)
- withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+ withSQLConf(SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> "true") {
Seq(e1, e2, e3, e4, e5, e6).foreach { e =>
checkExceptionInExpression[ArithmeticException](e, "overflow")
}
}
- withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
+ withSQLConf(SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> "false") {
checkEvaluation(e1, Int.MinValue)
checkEvaluation(e2, Int.MinValue)
checkEvaluation(e3, -2)
@@ -493,12 +493,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
val e4 = Add(minShortLiteral, minShortLiteral)
val e5 = Subtract(minShortLiteral, maxShortLiteral)
val e6 = Multiply(minShortLiteral, minShortLiteral)
- withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+ withSQLConf(SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> "true") {
Seq(e1, e2, e3, e4, e5, e6).foreach { e =>
checkExceptionInExpression[ArithmeticException](e, "overflow")
}
}
- withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
+ withSQLConf(SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> "false") {
checkEvaluation(e1, Short.MinValue)
checkEvaluation(e2, Short.MinValue)
checkEvaluation(e3, (-2).toShort)
@@ -517,12 +517,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
val e4 = Add(minByteLiteral, minByteLiteral)
val e5 = Subtract(minByteLiteral, maxByteLiteral)
val e6 = Multiply(minByteLiteral, minByteLiteral)
- withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+ withSQLConf(SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> "true") {
Seq(e1, e2, e3, e4, e5, e6).foreach { e =>
checkExceptionInExpression[ArithmeticException](e, "overflow")
}
}
- withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
+ withSQLConf(SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> "false") {
checkEvaluation(e1, Byte.MinValue)
checkEvaluation(e2, Byte.MinValue)
checkEvaluation(e3, (-2).toByte)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index fb99fc805c45b..12ca3e798b13d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -891,7 +891,8 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
}
test("Throw exception on casting out-of-range value to decimal type") {
- withSQLConf(SQLConf.ANSI_ENABLED.key -> requiredAnsiEnabledForOverflowTestCases.toString) {
+ withSQLConf(
+ SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> requiredAnsiEnabledForOverflowTestCases.toString) {
checkExceptionInExpression[ArithmeticException](
cast(Literal("134.12"), DecimalType(3, 2)), "cannot be represented")
checkExceptionInExpression[ArithmeticException](
@@ -957,7 +958,8 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
}
test("Throw exception on casting out-of-range value to byte type") {
- withSQLConf(SQLConf.ANSI_ENABLED.key -> requiredAnsiEnabledForOverflowTestCases.toString) {
+ withSQLConf(
+ SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> requiredAnsiEnabledForOverflowTestCases.toString) {
testIntMaxAndMin(ByteType)
Seq(Byte.MaxValue + 1, Byte.MinValue - 1).foreach { value =>
checkExceptionInExpression[ArithmeticException](cast(value, ByteType), "overflow")
@@ -982,7 +984,8 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
}
test("Throw exception on casting out-of-range value to short type") {
- withSQLConf(SQLConf.ANSI_ENABLED.key -> requiredAnsiEnabledForOverflowTestCases.toString) {
+ withSQLConf(
+ SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> requiredAnsiEnabledForOverflowTestCases.toString) {
testIntMaxAndMin(ShortType)
Seq(Short.MaxValue + 1, Short.MinValue - 1).foreach { value =>
checkExceptionInExpression[ArithmeticException](cast(value, ShortType), "overflow")
@@ -1007,7 +1010,8 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
}
test("Throw exception on casting out-of-range value to int type") {
- withSQLConf(SQLConf.ANSI_ENABLED.key -> requiredAnsiEnabledForOverflowTestCases.toString) {
+ withSQLConf(
+ SQLConf.DIALECT_SPARK_ANSI_ENABLED.key ->requiredAnsiEnabledForOverflowTestCases.toString) {
testIntMaxAndMin(IntegerType)
testLongMaxAndMin(IntegerType)
@@ -1024,7 +1028,8 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
}
test("Throw exception on casting out-of-range value to long type") {
- withSQLConf(SQLConf.ANSI_ENABLED.key -> requiredAnsiEnabledForOverflowTestCases.toString) {
+ withSQLConf(
+ SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> requiredAnsiEnabledForOverflowTestCases.toString) {
testLongMaxAndMin(LongType)
Seq(Long.MaxValue, 0, Long.MinValue).foreach { value =>
@@ -1201,7 +1206,7 @@ class CastSuite extends CastSuiteBase {
}
test("SPARK-28470: Cast should honor nullOnOverflow property") {
- withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
+ withSQLConf(SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> "false") {
checkEvaluation(Cast(Literal("134.12"), DecimalType(3, 2)), null)
checkEvaluation(
Cast(Literal(Timestamp.valueOf("2019-07-25 22:04:36")), DecimalType(3, 2)), null)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 3287c83b1dd87..b4343b648110f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -32,9 +32,12 @@ import org.apache.spark.sql.catalyst.util.IntervalUtils._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
+import org.apache.spark.unsafe.types.UTF8String
class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+ implicit def stringToUTF8Str(str: String): UTF8String = UTF8String.fromString(str)
+
def testSize(sizeOfNull: Any): Unit = {
val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
@@ -364,16 +367,6 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
val arrayArrayStruct = Literal.create(Seq(aas2, aas1), typeAAS)
checkEvaluation(new SortArray(arrayArrayStruct), Seq(aas1, aas2))
-
- checkEvaluation(ArraySort(a0), Seq(1, 2, 3))
- checkEvaluation(ArraySort(a1), Seq[Integer]())
- checkEvaluation(ArraySort(a2), Seq("a", "b"))
- checkEvaluation(ArraySort(a3), Seq("a", "b", null))
- checkEvaluation(ArraySort(a4), Seq(d1, d2))
- checkEvaluation(ArraySort(a5), Seq(null, null))
- checkEvaluation(ArraySort(arrayStruct), Seq(create_row(1), create_row(2)))
- checkEvaluation(ArraySort(arrayArray), Seq(aa1, aa2))
- checkEvaluation(ArraySort(arrayArrayStruct), Seq(aas1, aas2))
}
test("Array contains") {
@@ -721,7 +714,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
Literal(Timestamp.valueOf("2018-01-02 00:00:00")),
- Literal(fromString("interval 12 hours"))),
+ Literal(stringToInterval("interval 12 hours"))),
Seq(
Timestamp.valueOf("2018-01-01 00:00:00"),
Timestamp.valueOf("2018-01-01 12:00:00"),
@@ -730,7 +723,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
Literal(Timestamp.valueOf("2018-01-02 00:00:01")),
- Literal(fromString("interval 12 hours"))),
+ Literal(stringToInterval("interval 12 hours"))),
Seq(
Timestamp.valueOf("2018-01-01 00:00:00"),
Timestamp.valueOf("2018-01-01 12:00:00"),
@@ -739,7 +732,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Timestamp.valueOf("2018-01-02 00:00:00")),
Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
- Literal(negate(fromString("interval 12 hours")))),
+ Literal(negate(stringToInterval("interval 12 hours")))),
Seq(
Timestamp.valueOf("2018-01-02 00:00:00"),
Timestamp.valueOf("2018-01-01 12:00:00"),
@@ -748,7 +741,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Timestamp.valueOf("2018-01-02 00:00:00")),
Literal(Timestamp.valueOf("2017-12-31 23:59:59")),
- Literal(negate(fromString("interval 12 hours")))),
+ Literal(negate(stringToInterval("interval 12 hours")))),
Seq(
Timestamp.valueOf("2018-01-02 00:00:00"),
Timestamp.valueOf("2018-01-01 12:00:00"),
@@ -757,7 +750,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
Literal(Timestamp.valueOf("2018-03-01 00:00:00")),
- Literal(fromString("interval 1 month"))),
+ Literal(stringToInterval("interval 1 month"))),
Seq(
Timestamp.valueOf("2018-01-01 00:00:00"),
Timestamp.valueOf("2018-02-01 00:00:00"),
@@ -766,7 +759,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Timestamp.valueOf("2018-03-01 00:00:00")),
Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
- Literal(negate(fromString("interval 1 month")))),
+ Literal(negate(stringToInterval("interval 1 month")))),
Seq(
Timestamp.valueOf("2018-03-01 00:00:00"),
Timestamp.valueOf("2018-02-01 00:00:00"),
@@ -775,7 +768,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Timestamp.valueOf("2018-03-03 00:00:00")),
Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
- Literal(negate(fromString("interval 1 month 1 day")))),
+ Literal(negate(stringToInterval("interval 1 month 1 day")))),
Seq(
Timestamp.valueOf("2018-03-03 00:00:00"),
Timestamp.valueOf("2018-02-02 00:00:00"),
@@ -784,7 +777,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Timestamp.valueOf("2018-01-31 00:00:00")),
Literal(Timestamp.valueOf("2018-04-30 00:00:00")),
- Literal(fromString("interval 1 month"))),
+ Literal(stringToInterval("interval 1 month"))),
Seq(
Timestamp.valueOf("2018-01-31 00:00:00"),
Timestamp.valueOf("2018-02-28 00:00:00"),
@@ -794,7 +787,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
Literal(Timestamp.valueOf("2018-03-01 00:00:00")),
- Literal(fromString("interval 1 month 1 second"))),
+ Literal(stringToInterval("interval 1 month 1 second"))),
Seq(
Timestamp.valueOf("2018-01-01 00:00:00"),
Timestamp.valueOf("2018-02-01 00:00:01")))
@@ -802,7 +795,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
Literal(Timestamp.valueOf("2018-03-01 00:04:06")),
- Literal(fromString("interval 1 month 2 minutes 3 seconds"))),
+ Literal(stringToInterval("interval 1 month 2 minutes 3 seconds"))),
Seq(
Timestamp.valueOf("2018-01-01 00:00:00"),
Timestamp.valueOf("2018-02-01 00:02:03"),
@@ -840,7 +833,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Timestamp.valueOf("2018-03-25 01:30:00")),
Literal(Timestamp.valueOf("2018-03-25 03:30:00")),
- Literal(fromString("interval 30 minutes"))),
+ Literal(stringToInterval("interval 30 minutes"))),
Seq(
Timestamp.valueOf("2018-03-25 01:30:00"),
Timestamp.valueOf("2018-03-25 03:00:00"),
@@ -850,7 +843,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Timestamp.valueOf("2018-10-28 01:30:00")),
Literal(Timestamp.valueOf("2018-10-28 03:30:00")),
- Literal(fromString("interval 30 minutes"))),
+ Literal(stringToInterval("interval 30 minutes"))),
Seq(
Timestamp.valueOf("2018-10-28 01:30:00"),
noDST(Timestamp.valueOf("2018-10-28 02:00:00")),
@@ -867,7 +860,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Date.valueOf("2018-01-01")),
Literal(Date.valueOf("2018-01-05")),
- Literal(fromString("interval 2 days"))),
+ Literal(stringToInterval("interval 2 days"))),
Seq(
Date.valueOf("2018-01-01"),
Date.valueOf("2018-01-03"),
@@ -876,7 +869,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Date.valueOf("2018-01-01")),
Literal(Date.valueOf("2018-03-01")),
- Literal(fromString("interval 1 month"))),
+ Literal(stringToInterval("interval 1 month"))),
Seq(
Date.valueOf("2018-01-01"),
Date.valueOf("2018-02-01"),
@@ -885,7 +878,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(new Sequence(
Literal(Date.valueOf("2018-01-31")),
Literal(Date.valueOf("2018-04-30")),
- Literal(fromString("interval 1 month"))),
+ Literal(stringToInterval("interval 1 month"))),
Seq(
Date.valueOf("2018-01-31"),
Date.valueOf("2018-02-28"),
@@ -906,14 +899,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
new Sequence(
Literal(Date.valueOf("1970-01-02")),
Literal(Date.valueOf("1970-01-01")),
- Literal(fromString("interval 1 day"))),
+ Literal(stringToInterval("interval 1 day"))),
EmptyRow, "sequence boundaries: 1 to 0 by 1")
checkExceptionInExpression[IllegalArgumentException](
new Sequence(
Literal(Date.valueOf("1970-01-01")),
Literal(Date.valueOf("1970-02-01")),
- Literal(negate(fromString("interval 1 month")))),
+ Literal(negate(stringToInterval("interval 1 month")))),
EmptyRow,
s"sequence boundaries: 0 to 2678400000000 by -${28 * MICROS_PER_DAY}")
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
index 5f043ce972bed..5cd4d11e32f7a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
@@ -1090,17 +1090,17 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(SubtractTimestamps(Literal(end), Literal(end)),
new CalendarInterval(0, 0, 0))
checkEvaluation(SubtractTimestamps(Literal(end), Literal(Instant.EPOCH)),
- IntervalUtils.fromString("interval " +
- "436163 hours 4 minutes 1 seconds 123 milliseconds 456 microseconds"))
+ IntervalUtils.stringToInterval(UTF8String.fromString("interval " +
+ "436163 hours 4 minutes 1 seconds 123 milliseconds 456 microseconds")))
checkEvaluation(SubtractTimestamps(Literal(Instant.EPOCH), Literal(end)),
- IntervalUtils.fromString("interval " +
- "-436163 hours -4 minutes -1 seconds -123 milliseconds -456 microseconds"))
+ IntervalUtils.stringToInterval(UTF8String.fromString("interval " +
+ "-436163 hours -4 minutes -1 seconds -123 milliseconds -456 microseconds")))
checkEvaluation(
SubtractTimestamps(
Literal(Instant.parse("9999-12-31T23:59:59.999999Z")),
Literal(Instant.parse("0001-01-01T00:00:00Z"))),
- IntervalUtils.fromString("interval " +
- "87649415 hours 59 minutes 59 seconds 999 milliseconds 999 microseconds"))
+ IntervalUtils.stringToInterval(UTF8String.fromString("interval " +
+ "87649415 hours 59 minutes 59 seconds 999 milliseconds 999 microseconds")))
}
test("subtract dates") {
@@ -1108,18 +1108,18 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(SubtractDates(Literal(end), Literal(end)),
new CalendarInterval(0, 0, 0))
checkEvaluation(SubtractDates(Literal(end.plusDays(1)), Literal(end)),
- IntervalUtils.fromString("interval 1 days"))
+ IntervalUtils.stringToInterval(UTF8String.fromString("interval 1 days")))
checkEvaluation(SubtractDates(Literal(end.minusDays(1)), Literal(end)),
- IntervalUtils.fromString("interval -1 days"))
+ IntervalUtils.stringToInterval(UTF8String.fromString("interval -1 days")))
val epochDate = Literal(LocalDate.ofEpochDay(0))
checkEvaluation(SubtractDates(Literal(end), epochDate),
- IntervalUtils.fromString("interval 49 years 9 months 4 days"))
+ IntervalUtils.stringToInterval(UTF8String.fromString("interval 49 years 9 months 4 days")))
checkEvaluation(SubtractDates(epochDate, Literal(end)),
- IntervalUtils.fromString("interval -49 years -9 months -4 days"))
+ IntervalUtils.stringToInterval(UTF8String.fromString("interval -49 years -9 months -4 days")))
checkEvaluation(
SubtractDates(
Literal(LocalDate.of(10000, 1, 1)),
Literal(LocalDate.of(1, 1, 1))),
- IntervalUtils.fromString("interval 9999 years"))
+ IntervalUtils.stringToInterval(UTF8String.fromString("interval 9999 years")))
}
}
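Note on the recurring change in this and the suites below: the removed `IntervalUtils.fromString(String)` helper is replaced by `IntervalUtils.stringToInterval(UTF8String)`, so call sites either wrap the literal explicitly or declare an implicit conversion. A minimal sketch of both shapes, assuming only the signatures visible in this diff:

    import org.apache.spark.sql.catalyst.util.IntervalUtils
    import org.apache.spark.unsafe.types.UTF8String

    // Explicit wrapping at the call site.
    val iv = IntervalUtils.stringToInterval(UTF8String.fromString("interval 1 day"))

    // Or via an implicit String -> UTF8String conversion, as several suites here declare:
    //   implicit def stringToUTF8Str(str: String): UTF8String = UTF8String.fromString(str)
    //   IntervalUtils.stringToInterval("interval 1 day")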
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
index 36bc3db580400..8609d888b7bc9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
@@ -32,7 +32,7 @@ class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
}
test("MakeDecimal") {
- withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
+ withSQLConf(SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> "false") {
checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1"))
checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null)
val overflowExpr = MakeDecimal(Literal.create(1000L, LongType), 3, 1)
@@ -41,7 +41,7 @@ class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
evaluateWithoutCodegen(overflowExpr, null)
checkEvaluationWithUnsafeProjection(overflowExpr, null)
}
- withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+ withSQLConf(SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> "true") {
checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1"))
checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null)
val overflowExpr = MakeDecimal(Literal.create(1000L, LongType), 3, 1)
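The `SQLConf.ANSI_ENABLED` key gives way to `SQLConf.DIALECT_SPARK_ANSI_ENABLED` throughout the test changes in this patch; the toggling pattern itself is unchanged. A minimal sketch of the shape these tests use (`withSQLConf` comes from the shared test helpers):

    withSQLConf(SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> "true") {
      // assertions expecting ANSI behaviour, e.g. overflow errors
    }
    withSQLConf(SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> "false") {
      // assertions expecting the legacy null-on-overflow behaviour
    }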
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSQLBuilderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSQLBuilderSuite.scala
index ada3f7abd7e3a..492d97ba9d524 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSQLBuilderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSQLBuilderSuite.scala
@@ -89,7 +89,7 @@ class ExpressionSQLBuilderSuite extends SparkFunSuite {
val timestamp = LocalDateTime.of(2016, 1, 1, 0, 0, 0, 987654321)
.atZone(DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone))
.toInstant
- checkSQL(Literal(timestamp), "TIMESTAMP('2016-01-01 00:00:00.987654')")
+ checkSQL(Literal(timestamp), "TIMESTAMP '2016-01-01 00:00:00.987654'")
// TODO tests for decimals
}
@@ -169,12 +169,12 @@ class ExpressionSQLBuilderSuite extends SparkFunSuite {
checkSQL(
TimeAdd('a, interval),
- "`a` + 1 hours"
+ "`a` + INTERVAL '1 hours'"
)
checkSQL(
TimeSub('a, interval),
- "`a` - 1 hours"
+ "`a` - INTERVAL '1 hours'"
)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
index 4b2da73abe562..3a68847ecb1f4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
@@ -36,6 +36,7 @@ import org.apache.spark.unsafe.types.UTF8String
class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val random = new scala.util.Random
+ implicit def stringToUTF8Str(str: String): UTF8String = UTF8String.fromString(str)
test("md5") {
checkEvaluation(Md5(Literal("ABC".getBytes(StandardCharsets.UTF_8))),
@@ -252,7 +253,8 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("hive-hash for CalendarInterval type") {
def checkHiveHashForIntervalType(interval: String, expected: Long): Unit = {
- checkHiveHash(IntervalUtils.fromString(interval), CalendarIntervalType, expected)
+ checkHiveHash(IntervalUtils.stringToInterval(UTF8String.fromString(interval)),
+ CalendarIntervalType, expected)
}
// ----- MICROSEC -----
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
index 4cdee447fa45a..e7b713840b884 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
@@ -84,6 +84,15 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
ArrayTransform(expr, createLambda(et, cn, IntegerType, false, f)).bind(validateBinding)
}
+ def arraySort(expr: Expression): Expression = {
+ arraySort(expr, ArraySort.comparator)
+ }
+
+ def arraySort(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
+ val ArrayType(et, cn) = expr.dataType
+ ArraySort(expr, createLambda(et, cn, et, cn, f)).bind(validateBinding)
+ }
+
def filter(expr: Expression, f: Expression => Expression): Expression = {
val ArrayType(et, cn) = expr.dataType
ArrayFilter(expr, createLambda(et, cn, f)).bind(validateBinding)
@@ -167,6 +176,47 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
Seq("[1, 3, 5]", null, "[4, 6]"))
}
+ test("ArraySort") {
+ val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
+ val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
+ val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType))
+ val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType))
+ val d1 = new Decimal().set(10)
+ val d2 = new Decimal().set(100)
+ val a4 = Literal.create(Seq(d2, d1), ArrayType(DecimalType(10, 0)))
+ val a5 = Literal.create(Seq(null, null), ArrayType(NullType))
+
+ val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
+ val arrayStruct = Literal.create(Seq(create_row(2), create_row(1)), typeAS)
+
+ val typeAA = ArrayType(ArrayType(IntegerType))
+ val aa1 = Array[java.lang.Integer](1, 2)
+ val aa2 = Array[java.lang.Integer](3, null, 4)
+ val arrayArray = Literal.create(Seq(aa2, aa1), typeAA)
+
+ val typeAAS = ArrayType(ArrayType(StructType(StructField("a", IntegerType) :: Nil)))
+ val aas1 = Array(create_row(1))
+ val aas2 = Array(create_row(2))
+ val arrayArrayStruct = Literal.create(Seq(aas2, aas1), typeAAS)
+
+ checkEvaluation(arraySort(a0), Seq(1, 2, 3))
+ checkEvaluation(arraySort(a1), Seq[Integer]())
+ checkEvaluation(arraySort(a2), Seq("a", "b"))
+ checkEvaluation(arraySort(a3), Seq("a", "b", null))
+ checkEvaluation(arraySort(a4), Seq(d1, d2))
+ checkEvaluation(arraySort(a5), Seq(null, null))
+ checkEvaluation(arraySort(arrayStruct), Seq(create_row(1), create_row(2)))
+ checkEvaluation(arraySort(arrayArray), Seq(aa1, aa2))
+ checkEvaluation(arraySort(arrayArrayStruct), Seq(aas1, aas2))
+
+ checkEvaluation(arraySort(a0, (left, right) => UnaryMinus(ArraySort.comparator(left, right))),
+ Seq(3, 2, 1))
+ checkEvaluation(arraySort(a3, (left, right) => UnaryMinus(ArraySort.comparator(left, right))),
+ Seq(null, "b", "a"))
+ checkEvaluation(arraySort(a4, (left, right) => UnaryMinus(ArraySort.comparator(left, right))),
+ Seq(d2, d1))
+ }
+
test("MapFilter") {
def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
val MapType(kt, vt, vcn) = expr.dataType
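The new `arraySort` helpers bind a two-argument lambda over the element type, and negating `ArraySort.comparator` flips the order. A hedged sketch of how this surfaces to users, assuming `spark.implicits._` is in scope and that the SQL lambda syntax matches the existing higher-order functions:

    import org.apache.spark.sql.functions._

    // Default ascending sort via the built-in function (no comparator).
    val sorted = Seq(Seq(2, 1, 3)).toDF("xs").select(array_sort($"xs"))

    // Custom comparator through SQL, mirroring the expression tested above.
    spark.sql(
      "SELECT array_sort(array(2, 1, 3), " +
        "(l, r) -> CASE WHEN l < r THEN 1 WHEN l > r THEN -1 ELSE 0 END)")   // [3, 2, 1]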
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala
index e483f028ffff3..ddcb6a66832af 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala
@@ -21,13 +21,15 @@ import scala.language.implicitConversions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
-import org.apache.spark.sql.catalyst.util.IntervalUtils.fromString
+import org.apache.spark.sql.catalyst.util.IntervalUtils.stringToInterval
import org.apache.spark.sql.types.Decimal
-import org.apache.spark.unsafe.types.CalendarInterval
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+ implicit def stringToUTF8Str(str: String): UTF8String = UTF8String.fromString(str)
+
implicit def interval(s: String): Literal = {
- Literal(fromString("interval " + s))
+ Literal(stringToInterval("interval " + s))
}
test("millenniums") {
@@ -197,8 +199,8 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("multiply") {
def check(interval: String, num: Double, expected: String): Unit = {
checkEvaluation(
- MultiplyInterval(Literal(fromString(interval)), Literal(num)),
- if (expected == null) null else fromString(expected))
+ MultiplyInterval(Literal(stringToInterval(interval)), Literal(num)),
+ if (expected == null) null else stringToInterval(expected))
}
check("0 seconds", 10, "0 seconds")
@@ -215,8 +217,8 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("divide") {
def check(interval: String, num: Double, expected: String): Unit = {
checkEvaluation(
- DivideInterval(Literal(fromString(interval)), Literal(num)),
- if (expected == null) null else fromString(expected))
+ DivideInterval(Literal(stringToInterval(interval)), Literal(num)),
+ if (expected == null) null else stringToInterval(expected))
}
check("0 seconds", 10, "0 seconds")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
index f8400a590606a..d5cc1d4f0fdde 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
@@ -702,26 +702,22 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
}
test("from_json missing fields") {
- for (forceJsonNullableSchema <- Seq(false, true)) {
- withSQLConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA.key -> forceJsonNullableSchema.toString) {
- val input =
- """{
- | "a": 1,
- | "c": "foo"
- |}
- |""".stripMargin
- val jsonSchema = new StructType()
- .add("a", LongType, nullable = false)
- .add("b", StringType, nullable = !forceJsonNullableSchema)
- .add("c", StringType, nullable = false)
- val output = InternalRow(1L, null, UTF8String.fromString("foo"))
- val expr = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId)
- checkEvaluation(expr, output)
- val schema = expr.dataType
- val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema
- assert(schemaToCompare == schema)
- }
- }
+ val input =
+ """{
+ | "a": 1,
+ | "c": "foo"
+ |}
+ |""".stripMargin
+ val jsonSchema = new StructType()
+ .add("a", LongType, nullable = false)
+ .add("b", StringType, nullable = false)
+ .add("c", StringType, nullable = false)
+ val output = InternalRow(1L, null, UTF8String.fromString("foo"))
+ val expr = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId)
+ checkEvaluation(expr, output)
+ val schema = expr.dataType
+ val schemaToCompare = jsonSchema.asNullable
+ assert(schemaToCompare == schema)
}
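The simplified test reflects that `from_json` now always applies the user-provided schema as nullable (the `FROM_JSON_FORCE_NULLABLE_SCHEMA` toggle disappears from this suite). A hedged sketch of the observable behaviour, assuming a `SparkSession` named `spark` with `spark.implicits._` imported:

    import org.apache.spark.sql.functions.{col, from_json}
    import org.apache.spark.sql.types._

    val schema = new StructType()
      .add("a", LongType, nullable = false)
      .add("b", StringType, nullable = false)

    // "b" is declared non-nullable, but the effective schema is schema.asNullable,
    // so the missing field simply comes back as null instead of failing.
    val parsed = Seq("""{"a": 1}""").toDF("json")
      .select(from_json(col("json"), schema).as("parsed"))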
test("SPARK-24709: infer schema of json strings") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
index 03c9cf9c8a94d..4714635a3370b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
@@ -302,7 +302,7 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
val timestamp = LocalDateTime.of(2019, 3, 21, 0, 2, 3, 456000000)
.atZone(ZoneOffset.UTC)
.toInstant
- val expected = "TIMESTAMP('2019-03-21 01:02:03.456')"
+ val expected = "TIMESTAMP '2019-03-21 01:02:03.456'"
val literalStr = Literal.create(timestamp).sql
assert(literalStr === expected)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala
index a89937068a87d..d92eb01b69bf0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala
@@ -22,6 +22,7 @@ import java.time.{Duration, Instant, LocalDate}
import java.util.concurrent.TimeUnit
import org.scalacheck.{Arbitrary, Gen}
+import org.scalatest.Assertions._
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_DAY
import org.apache.spark.sql.types._
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala
index 23ba9c6ec7388..63700a1e94a3e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -56,7 +57,8 @@ class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper {
testBothCodegenAndInterpreted("variable-length types") {
val proj = createMutableProjection(variableLengthTypes)
- val scalaValues = Seq("abc", BigDecimal(10), IntervalUtils.fromString("interval 1 day"),
+ val scalaValues = Seq("abc", BigDecimal(10),
+ IntervalUtils.stringToInterval(UTF8String.fromString("interval 1 day")),
Array[Byte](1, 2), Array("123", "456"), Map(1 -> "a", 2 -> "b"), Row(1, "a"),
new java.lang.Integer(5))
val inputRow = InternalRow.fromSeq(scalaValues.zip(variableLengthTypes).map {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index 4ccd4f7ce798d..ef7764dba1e9e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -485,7 +485,8 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
("abcd".getBytes, BinaryType),
("abcd", StringType),
(BigDecimal.valueOf(10), DecimalType.IntDecimal),
- (IntervalUtils.fromString("interval 3 day"), CalendarIntervalType),
+ (IntervalUtils.stringToInterval(UTF8String.fromString("interval 3 day")),
+ CalendarIntervalType),
(java.math.BigDecimal.valueOf(10), DecimalType.BigIntDecimal),
(Array(3, 2, 1), ArrayType(IntegerType))
).foreach { case (input, dt) =>
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 52cdd988caa2e..67a41e7cc2767 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -510,7 +510,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
}
test("Interpreted Predicate should initialize nondeterministic expressions") {
- val interpreted = InterpretedPredicate.create(LessThan(Rand(7), Literal(1.0)))
+ val interpreted = Predicate.create(LessThan(Rand(7), Literal(1.0)))
interpreted.initialize(0)
assert(interpreted.eval(new UnsafeRow()))
}
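`InterpretedPredicate.create` is replaced by the unified `Predicate.create` factory here (the non-test changes later in this patch make the same move from `newPredicate`/`newOrdering` to `Predicate.create`/`RowOrdering.create`). A minimal sketch of the call shape used in the updated test; the factory's codegen-with-interpreted-fallback behaviour is assumed, not shown:

    import org.apache.spark.sql.catalyst.expressions._

    val pred = Predicate.create(LessThan(Rand(7), Literal(1.0)))
    pred.initialize(0)          // nondeterministic expressions need per-partition init
    val matched = pred.eval(new UnsafeRow())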
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
index c5ffc381b58e2..cf6ebfb0ecefb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
@@ -57,7 +57,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
}
test("SPARK-28369: honor nullOnOverflow config for ScalaUDF") {
- withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+ withSQLConf(SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> "true") {
val udf = ScalaUDF(
(a: java.math.BigDecimal) => a.multiply(new java.math.BigDecimal(100)),
DecimalType.SYSTEM_DEFAULT,
@@ -69,7 +69,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
}
assert(e2.getCause.isInstanceOf[ArithmeticException])
}
- withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
+ withSQLConf(SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> "false") {
val udf = ScalaUDF(
(a: java.math.BigDecimal) => a.multiply(new java.math.BigDecimal(100)),
DecimalType.SYSTEM_DEFAULT,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 20e77254ecdad..b80b30a4e07ae 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -531,7 +531,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB
// Simple tests
val inputRow = InternalRow.fromSeq(Seq(
false, 3.toByte, 15.toShort, -83, 129L, 1.0f, 8.0, UTF8String.fromString("test"),
- Decimal(255), IntervalUtils.fromString("interval 1 day"), Array[Byte](1, 2)
+ Decimal(255), IntervalUtils.stringToInterval(UTF8String.fromString("interval 1 day")),
+ Array[Byte](1, 2)
))
val fields1 = Array(
BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
index 94171feba2ac7..d2575dabf847c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
@@ -623,6 +623,15 @@ class DDLParserSuite extends AnalysisTest {
}
}
+ test("alter table/view: rename table/view") {
+ comparePlans(
+ parsePlan("ALTER TABLE a.b.c RENAME TO x.y.z"),
+ RenameTableStatement(Seq("a", "b", "c"), Seq("x", "y", "z"), isView = false))
+ comparePlans(
+ parsePlan("ALTER VIEW a.b.c RENAME TO x.y.z"),
+ RenameTableStatement(Seq("a", "b", "c"), Seq("x", "y", "z"), isView = true))
+ }
+
test("describe table column") {
comparePlans(parsePlan("DESCRIBE t col"),
DescribeColumnStatement(
@@ -653,6 +662,13 @@ class DDLParserSuite extends AnalysisTest {
"DESC TABLE COLUMN for a specific partition is not supported"))
}
+ test("describe database") {
+ val sql1 = "DESCRIBE DATABASE EXTENDED a.b"
+ val sql2 = "DESCRIBE DATABASE a.b"
+ comparePlans(parsePlan(sql1), DescribeNamespaceStatement(Seq("a", "b"), extended = true))
+ comparePlans(parsePlan(sql2), DescribeNamespaceStatement(Seq("a", "b"), extended = false))
+ }
+
test("SPARK-17328 Fix NPE with EXPLAIN DESCRIBE TABLE") {
comparePlans(parsePlan("describe t"),
DescribeTableStatement(Seq("t"), Map.empty, isExtended = false))
@@ -1022,6 +1038,31 @@ class DDLParserSuite extends AnalysisTest {
ShowTablesStatement(Some(Seq("tbl")), Some("*dog*")))
}
+ test("show table extended") {
+ comparePlans(
+ parsePlan("SHOW TABLE EXTENDED LIKE '*test*'"),
+ ShowTableStatement(None, "*test*", None))
+ comparePlans(
+ parsePlan("SHOW TABLE EXTENDED FROM testcat.ns1.ns2 LIKE '*test*'"),
+ ShowTableStatement(Some(Seq("testcat", "ns1", "ns2")), "*test*", None))
+ comparePlans(
+ parsePlan("SHOW TABLE EXTENDED IN testcat.ns1.ns2 LIKE '*test*'"),
+ ShowTableStatement(Some(Seq("testcat", "ns1", "ns2")), "*test*", None))
+ comparePlans(
+ parsePlan("SHOW TABLE EXTENDED LIKE '*test*' PARTITION(ds='2008-04-09', hr=11)"),
+ ShowTableStatement(None, "*test*", Some(Map("ds" -> "2008-04-09", "hr" -> "11"))))
+ comparePlans(
+ parsePlan("SHOW TABLE EXTENDED FROM testcat.ns1.ns2 LIKE '*test*' " +
+ "PARTITION(ds='2008-04-09')"),
+ ShowTableStatement(Some(Seq("testcat", "ns1", "ns2")), "*test*",
+ Some(Map("ds" -> "2008-04-09"))))
+ comparePlans(
+ parsePlan("SHOW TABLE EXTENDED IN testcat.ns1.ns2 LIKE '*test*' " +
+ "PARTITION(ds='2008-04-09')"),
+ ShowTableStatement(Some(Seq("testcat", "ns1", "ns2")), "*test*",
+ Some(Map("ds" -> "2008-04-09"))))
+ }
+
test("create namespace -- backward compatibility with DATABASE/DBPROPERTIES") {
val expected = CreateNamespaceStatement(
Seq("a", "b", "c"),
@@ -1128,6 +1169,52 @@ class DDLParserSuite extends AnalysisTest {
DropNamespaceStatement(Seq("a", "b", "c"), ifExists = false, cascade = true))
}
+ test("set namespace properties") {
+ comparePlans(
+ parsePlan("ALTER DATABASE a.b.c SET PROPERTIES ('a'='a', 'b'='b', 'c'='c')"),
+ AlterNamespaceSetPropertiesStatement(
+ Seq("a", "b", "c"), Map("a" -> "a", "b" -> "b", "c" -> "c")))
+
+ comparePlans(
+ parsePlan("ALTER SCHEMA a.b.c SET PROPERTIES ('a'='a')"),
+ AlterNamespaceSetPropertiesStatement(
+ Seq("a", "b", "c"), Map("a" -> "a")))
+
+ comparePlans(
+ parsePlan("ALTER NAMESPACE a.b.c SET PROPERTIES ('b'='b')"),
+ AlterNamespaceSetPropertiesStatement(
+ Seq("a", "b", "c"), Map("b" -> "b")))
+
+ comparePlans(
+ parsePlan("ALTER DATABASE a.b.c SET DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')"),
+ AlterNamespaceSetPropertiesStatement(
+ Seq("a", "b", "c"), Map("a" -> "a", "b" -> "b", "c" -> "c")))
+
+ comparePlans(
+ parsePlan("ALTER SCHEMA a.b.c SET DBPROPERTIES ('a'='a')"),
+ AlterNamespaceSetPropertiesStatement(
+ Seq("a", "b", "c"), Map("a" -> "a")))
+
+ comparePlans(
+ parsePlan("ALTER NAMESPACE a.b.c SET DBPROPERTIES ('b'='b')"),
+ AlterNamespaceSetPropertiesStatement(
+ Seq("a", "b", "c"), Map("b" -> "b")))
+ }
+
+ test("set namespace location") {
+ comparePlans(
+ parsePlan("ALTER DATABASE a.b.c SET LOCATION '/home/user/db'"),
+ AlterNamespaceSetLocationStatement(Seq("a", "b", "c"), "/home/user/db"))
+
+ comparePlans(
+ parsePlan("ALTER SCHEMA a.b.c SET LOCATION '/home/user/db'"),
+ AlterNamespaceSetLocationStatement(Seq("a", "b", "c"), "/home/user/db"))
+
+ comparePlans(
+ parsePlan("ALTER NAMESPACE a.b.c SET LOCATION '/home/user/db'"),
+ AlterNamespaceSetLocationStatement(Seq("a", "b", "c"), "/home/user/db"))
+ }
+
test("show databases: basic") {
comparePlans(
parsePlan("SHOW DATABASES"),
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
index 1a6286067a618..d519fdf378786 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
@@ -51,10 +51,13 @@ class DataTypeParserSuite extends SparkFunSuite {
checkDataType("dOUBle", DoubleType)
checkDataType("decimal(10, 5)", DecimalType(10, 5))
checkDataType("decimal", DecimalType.USER_DEFAULT)
+ checkDataType("Dec(10, 5)", DecimalType(10, 5))
+ checkDataType("deC", DecimalType.USER_DEFAULT)
checkDataType("DATE", DateType)
checkDataType("timestamp", TimestampType)
checkDataType("string", StringType)
checkDataType("ChaR(5)", StringType)
+ checkDataType("ChaRacter(5)", StringType)
checkDataType("varchAr(20)", StringType)
checkDataType("cHaR(27)", StringType)
checkDataType("BINARY", BinaryType)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index a707b456c6bd1..371b702722a69 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.IntervalUtils.IntervalUnit._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.CalendarInterval
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
/**
* Test basic expression parsing.
@@ -43,6 +43,8 @@ class ExpressionParserSuite extends AnalysisTest {
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
+ implicit def stringToUTF8Str(str: String): UTF8String = UTF8String.fromString(str)
+
val defaultParser = CatalystSqlParser
def assertEqual(
@@ -434,13 +436,13 @@ class ExpressionParserSuite extends AnalysisTest {
intercept("timestamP '2016-33-11 20:54:00.000'", "Cannot parse the TIMESTAMP value")
// Interval.
- val intervalLiteral = Literal(IntervalUtils.fromString("interval 3 month 1 hour"))
+ val intervalLiteral = Literal(IntervalUtils.stringToInterval("interval 3 month 1 hour"))
assertEqual("InterVal 'interval 3 month 1 hour'", intervalLiteral)
assertEqual("INTERVAL '3 month 1 hour'", intervalLiteral)
intercept("Interval 'interval 3 monthsss 1 hoursss'", "Cannot parse the INTERVAL value")
assertEqual(
"-interval '3 month 1 hour'",
- Literal(IntervalUtils.fromString("interval -3 month -1 hour")))
+ UnaryMinus(Literal(IntervalUtils.stringToInterval("interval 3 month 1 hour"))))
// Binary.
assertEqual("X'A'", Literal(Array(0x0a).map(_.toByte)))
@@ -602,20 +604,19 @@ class ExpressionParserSuite extends AnalysisTest {
MICROSECOND)
def intervalLiteral(u: IntervalUnit, s: String): Literal = {
- Literal(IntervalUtils.fromUnitStrings(Array(u), Array(s)))
+ Literal(IntervalUtils.stringToInterval(s + " " + u.toString))
}
test("intervals") {
def checkIntervals(intervalValue: String, expected: Literal): Unit = {
Seq(
"" -> expected,
- "-" -> expected.copy(
- value = IntervalUtils.negate(expected.value.asInstanceOf[CalendarInterval]))
+ "-" -> UnaryMinus(expected)
).foreach { case (sign, expectedLiteral) =>
assertEqual(s"${sign}interval $intervalValue", expectedLiteral)
// SPARK-23264 Support interval values without INTERVAL clauses if ANSI SQL enabled
- withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+ withSQLConf(SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> "true") {
assertEqual(intervalValue, expected)
}
}
@@ -651,7 +652,8 @@ class ExpressionParserSuite extends AnalysisTest {
0,
0,
13 * MICROS_PER_SECOND + 123 * MICROS_PER_MILLIS + 456)))
- checkIntervals("1.001 second", Literal(IntervalUtils.fromString("1 second 1 millisecond")))
+ checkIntervals("1.001 second",
+ Literal(IntervalUtils.stringToInterval("1 second 1 millisecond")))
// Non Existing unit
intercept("interval 10 nanoseconds",
@@ -701,12 +703,12 @@ class ExpressionParserSuite extends AnalysisTest {
test("SPARK-23264 Interval Compatibility tests") {
def checkIntervals(intervalValue: String, expected: Literal): Unit = {
- withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+ withSQLConf(SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> "true") {
assertEqual(intervalValue, expected)
}
// Compatibility tests: If ANSI SQL disabled, `intervalValue` should be parsed as an alias
- withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
+ withSQLConf(SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> "false") {
val aliases = defaultParser.parseExpression(intervalValue).collect {
case a @ Alias(_: Literal, name)
if intervalUnits.exists { unit => name.startsWith(unit.toString) } => a
@@ -804,12 +806,12 @@ class ExpressionParserSuite extends AnalysisTest {
}
test("current date/timestamp braceless expressions") {
- withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+ withSQLConf(SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> "true") {
assertEqual("current_date", CurrentDate())
assertEqual("current_timestamp", CurrentTimestamp())
}
- withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
+ withSQLConf(SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> "false") {
assertEqual("current_date", UnresolvedAttribute.quoted("current_date"))
assertEqual("current_timestamp", UnresolvedAttribute.quoted("current_timestamp"))
}
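The updated expectations make the parser change visible: a leading minus on an interval literal is no longer folded into a negated `CalendarInterval` at parse time but kept as a `UnaryMinus` node, so negation is handled like any other expression downstream. A sketch of the before/after plans, taken from the assertions above:

    // Before this patch the parser folded the sign into the literal value:
    //   Literal(IntervalUtils.fromString("interval -3 month -1 hour"))
    // After, "-interval '3 month 1 hour'" parses to an explicit negation:
    val expected = UnaryMinus(
      Literal(IntervalUtils.stringToInterval(UTF8String.fromString("interval 3 month 1 hour"))))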
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala
index a9216174804d0..9560aec944d9a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala
@@ -658,7 +658,7 @@ class TableIdentifierParserSuite extends SparkFunSuite with SQLHelper {
}
test("table identifier - reserved/non-reserved keywords if ANSI mode enabled") {
- withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+ withSQLConf(SQLConf.DIALECT_SPARK_ANSI_ENABLED.key -> "true") {
reservedKeywordsInAnsiMode.foreach { keyword =>
val errMsg = intercept[ParseException] {
parseTableIdentifier(keyword)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala
index 8c84eb107cd30..ee3db0391ed00 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala
@@ -28,20 +28,31 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
class IntervalUtilsSuite extends SparkFunSuite {
private def checkFromString(input: String, expected: CalendarInterval): Unit = {
- assert(fromString(input) === expected)
assert(stringToInterval(UTF8String.fromString(input)) === expected)
+ assert(safeStringToInterval(UTF8String.fromString(input)) === expected)
+ }
+
+ private def checkFromStringWithFunc(
+ input: String,
+ months: Int,
+ days: Int,
+ us: Long,
+ func: CalendarInterval => CalendarInterval): Unit = {
+ val expected = new CalendarInterval(months, days, us)
+ assert(func(stringToInterval(UTF8String.fromString(input))) === expected)
+ assert(func(safeStringToInterval(UTF8String.fromString(input))) === expected)
}
private def checkFromInvalidString(input: String, errorMsg: String): Unit = {
try {
- fromString(input)
+ stringToInterval(UTF8String.fromString(input))
fail("Expected to throw an exception for the invalid input")
} catch {
case e: IllegalArgumentException =>
val msg = e.getMessage
assert(msg.contains(errorMsg))
}
- assert(stringToInterval(UTF8String.fromString(input)) === null)
+ assert(safeStringToInterval(UTF8String.fromString(input)) === null)
}
private def testSingleUnit(
@@ -69,7 +80,7 @@ class IntervalUtilsSuite extends SparkFunSuite {
checkFromInvalidString(null, "cannot be null")
for (input <- Seq("", " ", "interval", "interval1 day", "foo", "foo 1 day")) {
- checkFromInvalidString(input, "Invalid interval string")
+ checkFromInvalidString(input, "Error parsing")
}
}
@@ -93,8 +104,18 @@ class IntervalUtilsSuite extends SparkFunSuite {
// Allow duplicated units and summarize their values
checkFromString("1 day 10 day", new CalendarInterval(0, 11, 0))
// Only the seconds units can have the fractional part
- checkFromInvalidString("1.5 days", "Error parsing interval string")
- checkFromInvalidString("1. hour", "Error parsing interval string")
+ checkFromInvalidString("1.5 days", "'days' cannot have fractional part")
+ checkFromInvalidString("1. hour", "'hour' cannot have fractional part")
+ checkFromInvalidString("1 hourX", "invalid unit 'hourx'")
+ checkFromInvalidString("~1 hour", "unrecognized number '~1'")
+ checkFromInvalidString("1 Mour", "invalid unit 'mour'")
+ checkFromInvalidString("1 aour", "invalid unit 'aour'")
+ checkFromInvalidString("1a1 hour", "invalid value '1a1'")
+ checkFromInvalidString("1.1a1 seconds", "invalid value '1.1a1'")
+ checkFromInvalidString("2234567890 days", "integer overflow")
+ checkFromInvalidString("\n", "Error parsing '\n' to interval")
+ checkFromInvalidString("\t", "Error parsing '\t' to interval")
+ checkFromInvalidString(". seconds", "invalid value '.'")
}
test("string to interval: seconds with fractional part") {
@@ -106,7 +127,8 @@ class IntervalUtilsSuite extends SparkFunSuite {
checkFromString("-1.5 seconds", new CalendarInterval(0, 0, -1500000))
// truncate nanoseconds to microseconds
checkFromString("0.999999999 seconds", new CalendarInterval(0, 0, 999999))
- checkFromInvalidString("0.123456789123 seconds", "Error parsing interval string")
+ checkFromString(".999999999 seconds", new CalendarInterval(0, 0, 999999))
+ checkFromInvalidString("0.123456789123 seconds", "'0.123456789123' is out of range")
}
test("from year-month string") {
@@ -173,7 +195,7 @@ class IntervalUtilsSuite extends SparkFunSuite {
test("interval duration") {
def duration(s: String, unit: TimeUnit, daysPerMonth: Int): Long = {
- IntervalUtils.getDuration(fromString(s), unit, daysPerMonth)
+ IntervalUtils.getDuration(stringToInterval(UTF8String.fromString(s)), unit, daysPerMonth)
}
assert(duration("0 seconds", TimeUnit.MILLISECONDS, 31) === 0)
@@ -192,7 +214,7 @@ class IntervalUtilsSuite extends SparkFunSuite {
test("negative interval") {
def isNegative(s: String, daysPerMonth: Int): Boolean = {
- IntervalUtils.isNegative(fromString(s), daysPerMonth)
+ IntervalUtils.isNegative(stringToInterval(UTF8String.fromString(s)), daysPerMonth)
}
assert(isNegative("-1 months", 28))
@@ -268,33 +290,91 @@ class IntervalUtilsSuite extends SparkFunSuite {
}
test("justify days") {
- assert(justifyDays(fromString("1 month 35 day")) === new CalendarInterval(2, 5, 0))
- assert(justifyDays(fromString("-1 month 35 day")) === new CalendarInterval(0, 5, 0))
- assert(justifyDays(fromString("1 month -35 day")) === new CalendarInterval(0, -5, 0))
- assert(justifyDays(fromString("-1 month -35 day")) === new CalendarInterval(-2, -5, 0))
- assert(justifyDays(fromString("-1 month 2 day")) === new CalendarInterval(0, -28, 0))
+ checkFromStringWithFunc("1 month 35 day", 2, 5, 0, justifyDays)
+ checkFromStringWithFunc("-1 month 35 day", 0, 5, 0, justifyDays)
+ checkFromStringWithFunc("1 month -35 day", 0, -5, 0, justifyDays)
+ checkFromStringWithFunc("-1 month -35 day", -2, -5, 0, justifyDays)
+ checkFromStringWithFunc("-1 month 2 day", 0, -28, 0, justifyDays)
}
test("justify hours") {
- assert(justifyHours(fromString("29 day 25 hour")) ===
- new CalendarInterval(0, 30, 1 * MICROS_PER_HOUR))
- assert(justifyHours(fromString("29 day -25 hour")) ===
- new CalendarInterval(0, 27, 23 * MICROS_PER_HOUR))
- assert(justifyHours(fromString("-29 day 25 hour")) ===
- new CalendarInterval(0, -27, -23 * MICROS_PER_HOUR))
- assert(justifyHours(fromString("-29 day -25 hour")) ===
- new CalendarInterval(0, -30, -1 * MICROS_PER_HOUR))
+ checkFromStringWithFunc("29 day 25 hour", 0, 30, 1 * MICROS_PER_HOUR, justifyHours)
+ checkFromStringWithFunc("29 day -25 hour", 0, 27, 23 * MICROS_PER_HOUR, justifyHours)
+ checkFromStringWithFunc("-29 day 25 hour", 0, -27, -23 * MICROS_PER_HOUR, justifyHours)
+ checkFromStringWithFunc("-29 day -25 hour", 0, -30, -1 * MICROS_PER_HOUR, justifyHours)
}
test("justify interval") {
- assert(justifyInterval(fromString("1 month 29 day 25 hour")) ===
- new CalendarInterval(2, 0, 1 * MICROS_PER_HOUR))
- assert(justifyInterval(fromString("-1 month 29 day -25 hour")) ===
- new CalendarInterval(0, -2, -1 * MICROS_PER_HOUR))
- assert(justifyInterval(fromString("1 month -29 day -25 hour")) ===
- new CalendarInterval(0, 0, -1 * MICROS_PER_HOUR))
- assert(justifyInterval(fromString("-1 month -29 day -25 hour")) ===
- new CalendarInterval(-2, 0, -1 * MICROS_PER_HOUR))
+ checkFromStringWithFunc("1 month 29 day 25 hour", 2, 0, 1 * MICROS_PER_HOUR, justifyInterval)
+ checkFromStringWithFunc("-1 month 29 day -25 hour", 0, -2, -1 * MICROS_PER_HOUR,
+ justifyInterval)
+ checkFromStringWithFunc("1 month -29 day -25 hour", 0, 0, -1 * MICROS_PER_HOUR, justifyInterval)
+ checkFromStringWithFunc("-1 month -29 day -25 hour", -2, 0, -1 * MICROS_PER_HOUR,
+ justifyInterval)
intercept[ArithmeticException](justifyInterval(new CalendarInterval(2, 0, Long.MaxValue)))
}
+
+ test("to ansi sql standard string") {
+ val i1 = new CalendarInterval(0, 0, 0)
+ assert(IntervalUtils.toSqlStandardString(i1) === "0")
+ val i2 = new CalendarInterval(34, 0, 0)
+ assert(IntervalUtils.toSqlStandardString(i2) === "+2-10")
+ val i3 = new CalendarInterval(-34, 0, 0)
+ assert(IntervalUtils.toSqlStandardString(i3) === "-2-10")
+ val i4 = new CalendarInterval(0, 31, 0)
+ assert(IntervalUtils.toSqlStandardString(i4) === "+31")
+ val i5 = new CalendarInterval(0, -31, 0)
+ assert(IntervalUtils.toSqlStandardString(i5) === "-31")
+ val i6 = new CalendarInterval(0, 0, 3 * MICROS_PER_HOUR + 13 * MICROS_PER_MINUTE + 123)
+ assert(IntervalUtils.toSqlStandardString(i6) === "+3:13:00.000123")
+ val i7 = new CalendarInterval(0, 0, -3 * MICROS_PER_HOUR - 13 * MICROS_PER_MINUTE - 123)
+ assert(IntervalUtils.toSqlStandardString(i7) === "-3:13:00.000123")
+ val i8 = new CalendarInterval(-34, 31, 3 * MICROS_PER_HOUR + 13 * MICROS_PER_MINUTE + 123)
+ assert(IntervalUtils.toSqlStandardString(i8) === "-2-10 +31 +3:13:00.000123")
+ val i9 = new CalendarInterval(0, 0, -3000 * MICROS_PER_HOUR)
+ assert(IntervalUtils.toSqlStandardString(i9) === "-3000:00:00")
+ }
+
+ test("to iso 8601 string") {
+ val i1 = new CalendarInterval(0, 0, 0)
+ assert(IntervalUtils.toIso8601String(i1) === "PT0S")
+ val i2 = new CalendarInterval(34, 0, 0)
+ assert(IntervalUtils.toIso8601String(i2) === "P2Y10M")
+ val i3 = new CalendarInterval(-34, 0, 0)
+ assert(IntervalUtils.toIso8601String(i3) === "P-2Y-10M")
+ val i4 = new CalendarInterval(0, 31, 0)
+ assert(IntervalUtils.toIso8601String(i4) === "P31D")
+ val i5 = new CalendarInterval(0, -31, 0)
+ assert(IntervalUtils.toIso8601String(i5) === "P-31D")
+ val i6 = new CalendarInterval(0, 0, 3 * MICROS_PER_HOUR + 13 * MICROS_PER_MINUTE + 123)
+ assert(IntervalUtils.toIso8601String(i6) === "PT3H13M0.000123S")
+ val i7 = new CalendarInterval(0, 0, -3 * MICROS_PER_HOUR - 13 * MICROS_PER_MINUTE - 123)
+ assert(IntervalUtils.toIso8601String(i7) === "PT-3H-13M-0.000123S")
+ val i8 = new CalendarInterval(-34, 31, 3 * MICROS_PER_HOUR + 13 * MICROS_PER_MINUTE + 123)
+ assert(IntervalUtils.toIso8601String(i8) === "P-2Y-10M31DT3H13M0.000123S")
+ val i9 = new CalendarInterval(0, 0, -3000 * MICROS_PER_HOUR)
+ assert(IntervalUtils.toIso8601String(i9) === "PT-3000H")
+ }
+
+ test("to multi units string") {
+ val i1 = new CalendarInterval(0, 0, 0)
+ assert(IntervalUtils.toMultiUnitsString(i1) === "0 seconds")
+ val i2 = new CalendarInterval(34, 0, 0)
+ assert(IntervalUtils.toMultiUnitsString(i2) === "2 years 10 months")
+ val i3 = new CalendarInterval(-34, 0, 0)
+ assert(IntervalUtils.toMultiUnitsString(i3) === "-2 years -10 months")
+ val i4 = new CalendarInterval(0, 31, 0)
+ assert(IntervalUtils.toMultiUnitsString(i4) === "31 days")
+ val i5 = new CalendarInterval(0, -31, 0)
+ assert(IntervalUtils.toMultiUnitsString(i5) === "-31 days")
+ val i6 = new CalendarInterval(0, 0, 3 * MICROS_PER_HOUR + 13 * MICROS_PER_MINUTE + 123)
+ assert(IntervalUtils.toMultiUnitsString(i6) === "3 hours 13 minutes 0.000123 seconds")
+ val i7 = new CalendarInterval(0, 0, -3 * MICROS_PER_HOUR - 13 * MICROS_PER_MINUTE - 123)
+ assert(IntervalUtils.toMultiUnitsString(i7) === "-3 hours -13 minutes -0.000123 seconds")
+ val i8 = new CalendarInterval(-34, 31, 3 * MICROS_PER_HOUR + 13 * MICROS_PER_MINUTE + 123)
+ assert(IntervalUtils.toMultiUnitsString(i8) ===
+ "-2 years -10 months 31 days 3 hours 13 minutes 0.000123 seconds")
+ val i9 = new CalendarInterval(0, 0, -3000 * MICROS_PER_HOUR)
+ assert(IntervalUtils.toMultiUnitsString(i9) === "-3000 hours")
+ }
}
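The three new formatters render the same `CalendarInterval` in different styles; the values below are lifted straight from the assertions above:

    import org.apache.spark.sql.catalyst.util.IntervalUtils._
    import org.apache.spark.unsafe.types.CalendarInterval

    val iv = new CalendarInterval(34, 0, 0)   // 2 years 10 months
    toSqlStandardString(iv)   // "+2-10"
    toIso8601String(iv)       // "P2Y10M"
    toMultiUnitsString(iv)    // "2 years 10 months"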
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
index 414f9d5834868..201860e5135ba 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
@@ -22,6 +22,8 @@ import java.util
import scala.collection.JavaConverters._
import scala.collection.mutable
+import org.scalatest.Assertions._
+
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.expressions.{IdentityTransform, Transform}
@@ -122,7 +124,7 @@ class InMemoryTable(
}
private abstract class TestBatchWrite extends BatchWrite {
- override def createBatchWriterFactory(): DataWriterFactory = {
+ override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = {
BufferedRowsWriterFactory
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala
index 6107a15f5c428..082849c88669a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.instantToMicros
+import org.apache.spark.unsafe.types.UTF8String
class TimestampFormatterSuite extends SparkFunSuite with SQLHelper with Matchers {
@@ -154,4 +155,82 @@ class TimestampFormatterSuite extends SparkFunSuite with SQLHelper with Matchers
formatter.parse("Tomorrow ") should be(tomorrow +- tolerance)
}
}
+
+ test("parsing timestamp strings with various seconds fractions") {
+ DateTimeTestUtils.outstandingZoneIds.foreach { zoneId =>
+ def check(pattern: String, input: String, reference: String): Unit = {
+ val formatter = TimestampFormatter(pattern, zoneId)
+ val expected = DateTimeUtils.stringToTimestamp(
+ UTF8String.fromString(reference), zoneId).get
+ val actual = formatter.parse(input)
+ assert(actual === expected)
+ }
+
+ check("yyyy-MM-dd'T'HH:mm:ss.SSSSSSSXXX",
+ "2019-10-14T09:39:07.3220000Z", "2019-10-14T09:39:07.322Z")
+ check("yyyy-MM-dd'T'HH:mm:ss.SSSSSS",
+ "2019-10-14T09:39:07.322000", "2019-10-14T09:39:07.322")
+ check("yyyy-MM-dd'T'HH:mm:ss.SSSSSSX",
+ "2019-10-14T09:39:07.123456Z", "2019-10-14T09:39:07.123456Z")
+ check("yyyy-MM-dd'T'HH:mm:ss.SSSSSSX",
+ "2019-10-14T09:39:07.000010Z", "2019-10-14T09:39:07.00001Z")
+ check("yyyy HH:mm:ss.SSSSS", "1970 01:02:03.00004", "1970-01-01 01:02:03.00004")
+ check("yyyy HH:mm:ss.SSSS", "2019 00:00:07.0100", "2019-01-01 00:00:07.0100")
+ check("yyyy-MM-dd'T'HH:mm:ss.SSSX",
+ "2019-10-14T09:39:07.322Z", "2019-10-14T09:39:07.322Z")
+ check("yyyy-MM-dd'T'HH:mm:ss.SS",
+ "2019-10-14T09:39:07.10", "2019-10-14T09:39:07.1")
+ check("yyyy-MM-dd'T'HH:mm:ss.S",
+ "2019-10-14T09:39:07.1", "2019-10-14T09:39:07.1")
+
+ try {
+ TimestampFormatter("yyyy/MM/dd HH_mm_ss.SSSSSS", zoneId)
+ .parse("2019/11/14 20#25#30.123456")
+ fail("Expected to throw an exception for the invalid input")
+ } catch {
+ case e: java.time.format.DateTimeParseException =>
+ assert(e.getMessage.contains("could not be parsed"))
+ }
+ }
+ }
+
+ test("formatting timestamp strings up to microsecond precision") {
+ DateTimeTestUtils.outstandingZoneIds.foreach { zoneId =>
+ def check(pattern: String, input: String, expected: String): Unit = {
+ val formatter = TimestampFormatter(pattern, zoneId)
+ val timestamp = DateTimeUtils.stringToTimestamp(
+ UTF8String.fromString(input), zoneId).get
+ val actual = formatter.format(timestamp)
+ assert(actual === expected)
+ }
+
+ check(
+ "yyyy-MM-dd HH:mm:ss.SSSSSSS", "2019-10-14T09:39:07.123456",
+ "2019-10-14 09:39:07.1234560")
+ check(
+ "yyyy-MM-dd HH:mm:ss.SSSSSS", "1960-01-01T09:39:07.123456",
+ "1960-01-01 09:39:07.123456")
+ check(
+ "yyyy-MM-dd HH:mm:ss.SSSSS", "0001-10-14T09:39:07.1",
+ "0001-10-14 09:39:07.10000")
+ check(
+ "yyyy-MM-dd HH:mm:ss.SSSS", "9999-12-31T23:59:59.999",
+ "9999-12-31 23:59:59.9990")
+ check(
+ "yyyy-MM-dd HH:mm:ss.SSS", "1970-01-01T00:00:00.0101",
+ "1970-01-01 00:00:00.010")
+ check(
+ "yyyy-MM-dd HH:mm:ss.SS", "2019-10-14T09:39:07.09",
+ "2019-10-14 09:39:07.09")
+ check(
+ "yyyy-MM-dd HH:mm:ss.S", "2019-10-14T09:39:07.2",
+ "2019-10-14 09:39:07.2")
+ check(
+ "yyyy-MM-dd HH:mm:ss.S", "2019-10-14T09:39:07",
+ "2019-10-14 09:39:07.0")
+ check(
+ "yyyy-MM-dd HH:mm:ss", "2019-10-14T09:39:07.123456",
+ "2019-10-14 09:39:07")
+ }
+ }
}
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 3a8b7d22397ff..5cd7c656ea725 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -177,7 +177,7 @@
org.scalatest
scalatest-maven-plugin
- -ea -Xmx4g -Xss4m -XX:ReservedCodeCacheSize=${CodeCacheSize}
+ -ea -Xmx4g -Xss4m -XX:ReservedCodeCacheSize=${CodeCacheSize} -Dio.netty.tryReflectionSetAccessible=true
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java b/sql/core/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java
index 40c2cc806e87a..1f243406c77e0 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java
@@ -20,8 +20,13 @@
import org.apache.spark.unsafe.Platform;
import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
+import java.nio.ByteOrder;
+
public final class RecordBinaryComparator extends RecordComparator {
+ private static final boolean LITTLE_ENDIAN =
+ ByteOrder.nativeOrder().equals(ByteOrder.LITTLE_ENDIAN);
+
@Override
public int compare(
Object leftObj, long leftOff, int leftLen, Object rightObj, long rightOff, int rightLen) {
@@ -38,10 +43,10 @@ public int compare(
// check if stars align and we can get both offsets to be aligned
if ((leftOff % 8) == (rightOff % 8)) {
while ((leftOff + i) % 8 != 0 && i < leftLen) {
- final int v1 = Platform.getByte(leftObj, leftOff + i) & 0xff;
- final int v2 = Platform.getByte(rightObj, rightOff + i) & 0xff;
+ final int v1 = Platform.getByte(leftObj, leftOff + i);
+ final int v2 = Platform.getByte(rightObj, rightOff + i);
if (v1 != v2) {
- return v1 > v2 ? 1 : -1;
+ return (v1 & 0xff) > (v2 & 0xff) ? 1 : -1;
}
i += 1;
}
@@ -49,10 +54,17 @@ public int compare(
// for architectures that support unaligned accesses, chew it up 8 bytes at a time
if (Platform.unaligned() || (((leftOff + i) % 8 == 0) && ((rightOff + i) % 8 == 0))) {
while (i <= leftLen - 8) {
- final long v1 = Platform.getLong(leftObj, leftOff + i);
- final long v2 = Platform.getLong(rightObj, rightOff + i);
+ long v1 = Platform.getLong(leftObj, leftOff + i);
+ long v2 = Platform.getLong(rightObj, rightOff + i);
if (v1 != v2) {
- return v1 > v2 ? 1 : -1;
+ if (LITTLE_ENDIAN) {
+ // If read as little-endian, we have to reverse the bytes so that the long comparison result
+ // is equivalent to the byte-by-byte comparison result.
+ // See discussion in https://github.com/apache/spark/pull/26548#issuecomment-554645859
+ v1 = Long.reverseBytes(v1);
+ v2 = Long.reverseBytes(v2);
+ }
+ return Long.compareUnsigned(v1, v2);
}
i += 8;
}
@@ -60,10 +72,10 @@ public int compare(
// this will finish off the unaligned comparisons, or do the entire aligned comparison
// whichever is needed.
while (i < leftLen) {
- final int v1 = Platform.getByte(leftObj, leftOff + i) & 0xff;
- final int v2 = Platform.getByte(rightObj, rightOff + i) & 0xff;
+ final int v1 = Platform.getByte(leftObj, leftOff + i);
+ final int v2 = Platform.getByte(rightObj, rightOff + i);
if (v1 != v2) {
- return v1 > v2 ? 1 : -1;
+ return (v1 & 0xff) > (v2 & 0xff) ? 1 : -1;
}
i += 1;
}
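Why the byte reversal is needed: `Platform.getLong` reads in native order, so on a little-endian machine the byte that should dominate the ordering lands in the least significant position, and even an unsigned long comparison can disagree with byte-by-byte order. A small illustration, written in Scala for consistency with the rest of this patch (`ByteBuffer` stands in for `Platform`):

    // In memory: left = 0x01 0x00 ..., right = 0x00 0xFF ...
    // Byte-by-byte, left is larger (0x01 > 0x00).
    val left  = Array[Byte](1, 0, 0, 0, 0, 0, 0, 0)
    val right = Array[Byte](0, -1, 0, 0, 0, 0, 0, 0)            // -1 is 0xFF

    def readLE(a: Array[Byte]): Long =
      java.nio.ByteBuffer.wrap(a).order(java.nio.ByteOrder.LITTLE_ENDIAN).getLong()

    val (v1, v2) = (readLE(left), readLE(right))                 // 0x1L vs 0xFF00L
    java.lang.Long.compareUnsigned(v1, v2)                       // < 0: wrong answer
    java.lang.Long.compareUnsigned(                              // > 0: matches byte order
      java.lang.Long.reverseBytes(v1), java.lang.Long.reverseBytes(v2))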
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 09426117a24b9..acd54fe25d62d 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -29,7 +29,7 @@
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.serializer.SerializerManager;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
-import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering;
+import org.apache.spark.sql.catalyst.expressions.BaseOrdering;
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.storage.BlockManager;
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index aaa3f9dd71594..e1bca44dfccf5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -60,7 +60,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.storage.StorageLevel
import org.apache.spark.unsafe.array.ByteArrayMethods
-import org.apache.spark.unsafe.types.CalendarInterval
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
import org.apache.spark.util.Utils
private[sql] object Dataset {
@@ -586,8 +586,8 @@ class Dataset[T] private[sql](
* @group basic
* @since 2.4.0
*/
- def isEmpty: Boolean = withAction("isEmpty", limit(1).groupBy().count().queryExecution) { plan =>
- plan.executeCollect().head.getLong(0) == 0
+ def isEmpty: Boolean = withAction("isEmpty", select().queryExecution) { plan =>
+ plan.executeTake(1).isEmpty
}
/**
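The rewritten `isEmpty` avoids building a `groupBy().count()` aggregate: `select()` keeps an empty projection over the same rows and `executeTake(1)` stops as soon as a single row is found. A hedged usage sketch, assuming a `SparkSession` named `spark`:

    spark.range(0).isEmpty    // true
    spark.range(10).isEmpty   // false; only the first row needs to be produced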
@@ -725,7 +725,7 @@ class Dataset[T] private[sql](
def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = withTypedPlan {
val parsedDelay =
try {
- IntervalUtils.fromString(delayThreshold)
+ IntervalUtils.stringToInterval(UTF8String.fromString(delayThreshold))
} catch {
case e: IllegalArgumentException =>
throw new AnalysisException(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 4d4731870700c..b1ba7d4538732 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -26,6 +26,7 @@ import org.apache.spark.annotation.Stable
import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -129,6 +130,37 @@ class RelationalGroupedDataset protected[sql](
(inputExpr: Expression) => exprToFunc(inputExpr)
}
+ /**
+ * Returns a `KeyValueGroupedDataset` where the data is grouped by the grouping expressions
+ * of the current `RelationalGroupedDataset`.
+ *
+ * @since 3.0.0
+ */
+ def as[K: Encoder, T: Encoder]: KeyValueGroupedDataset[K, T] = {
+ val keyEncoder = encoderFor[K]
+ val valueEncoder = encoderFor[T]
+
+ // Resolves grouping expressions.
+ val dummyPlan = Project(groupingExprs.map(alias), LocalRelation(df.logicalPlan.output))
+ val analyzedPlan = df.sparkSession.sessionState.analyzer.execute(dummyPlan)
+ .asInstanceOf[Project]
+ df.sparkSession.sessionState.analyzer.checkAnalysis(analyzedPlan)
+ val aliasedGroupings = analyzedPlan.projectList
+
+ // Adds the grouping expressions that are not in the base DataFrame to the output.
+ val addedCols = aliasedGroupings.filter(g => !df.logicalPlan.outputSet.contains(g.toAttribute))
+ val qe = Dataset.ofRows(
+ df.sparkSession,
+ Project(df.logicalPlan.output ++ addedCols, df.logicalPlan)).queryExecution
+
+ new KeyValueGroupedDataset(
+ keyEncoder,
+ valueEncoder,
+ qe,
+ df.logicalPlan.output,
+ aliasedGroupings.map(_.toAttribute))
+ }
+
/**
* (Scala-specific) Compute aggregates by specifying the column names and
* aggregate methods. The resulting `DataFrame` will also contain the grouping columns.
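The new `as[K, T]` bridge turns an untyped `RelationalGroupedDataset` back into a `KeyValueGroupedDataset`, so typed operators such as `mapGroups` can follow a plain `groupBy`. A hedged usage sketch, assuming `spark.implicits._` is in scope and the case class is defined by the caller:

    case class Sale(region: String, amount: Long)
    val ds = Seq(Sale("eu", 10), Sale("eu", 5), Sale("us", 7)).toDS()

    val totals = ds.groupBy($"region")
      .as[String, Sale]                  // KeyValueGroupedDataset[String, Sale]
      .mapGroups((region, sales) => region -> sales.map(_.amount).sum)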
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
index 340e09ae66adb..eb53e3accc3d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
@@ -158,6 +158,30 @@ class ResolveSessionCatalog(
case AlterViewUnsetPropertiesStatement(SessionCatalog(catalog, tableName), keys, ifExists) =>
AlterTableUnsetPropertiesCommand(tableName.asTableIdentifier, keys, ifExists, isView = true)
+ case d @ DescribeNamespaceStatement(SessionCatalog(_, nameParts), _) =>
+ if (nameParts.length != 1) {
+ throw new AnalysisException(
+ s"The database name is not valid: ${nameParts.quoted}")
+ }
+ DescribeDatabaseCommand(nameParts.head, d.extended)
+
+ case AlterNamespaceSetPropertiesStatement(SessionCatalog(_, nameParts), properties) =>
+ if (nameParts.length != 1) {
+ throw new AnalysisException(
+ s"The database name is not valid: ${nameParts.quoted}")
+ }
+ AlterDatabasePropertiesCommand(nameParts.head, properties)
+
+ case AlterNamespaceSetLocationStatement(SessionCatalog(_, nameParts), location) =>
+ if (nameParts.length != 1) {
+ throw new AnalysisException(
+ s"The database name is not valid: ${nameParts.quoted}")
+ }
+ AlterDatabaseSetLocationCommand(nameParts.head, location)
+
+ case RenameTableStatement(SessionCatalog(_, oldName), newNameParts, isView) =>
+ AlterTableRenameCommand(oldName.asTableIdentifier, newNameParts.asTableIdentifier, isView)
+
case DescribeTableStatement(
nameParts @ SessionCatalog(catalog, tableName), partitionSpec, isExtended) =>
loadTable(catalog, tableName.asIdentifier).collect {
@@ -301,6 +325,15 @@ class ResolveSessionCatalog(
case ShowTablesStatement(None, pattern) if isSessionCatalog(currentCatalog) =>
ShowTablesCommand(None, pattern)
+ case ShowTableStatement(namespace, pattern, partitionsSpec) =>
+ val db = namespace match {
+ case Some(namespace) if namespace.length != 1 =>
+ throw new AnalysisException(
+ s"The database name is not valid: ${namespace.quoted}")
+ case _ => namespace.map(_.head)
+ }
+ ShowTablesCommand(db, Some(pattern), true, partitionsSpec)
+
case AnalyzeTableStatement(tableName, partitionSpec, noScan) =>
val v1TableName = parseV1Table(tableName, "ANALYZE TABLE")
if (partitionSpec.isEmpty) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index b0fe4b741479f..88f5673aa9a1e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -230,7 +230,7 @@ case class FileSourceScanExec(
// call the file index for the files matching all filters except dynamic partition filters
val predicate = dynamicPartitionFilters.reduce(And)
val partitionColumns = relation.partitionSchema
- val boundPredicate = newPredicate(predicate.transform {
+ val boundPredicate = Predicate.create(predicate.transform {
case a: AttributeReference =>
val index = partitionColumns.indexWhere(a.name == _.name)
BoundReference(index, partitionColumns(index).dataType, nullable = true)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala
index 75abac4cfd1da..d4e10b3ffc733 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala
@@ -22,9 +22,12 @@ import java.sql.{Date, Timestamp}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
+import org.apache.spark.sql.catalyst.util.IntervalUtils._
import org.apache.spark.sql.execution.command.{DescribeCommandBase, ExecutedCommandExec, ShowTablesCommand}
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.IntervalStyle._
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
/**
* Runs a query returning the result in Hive compatible form.
@@ -97,7 +100,12 @@ object HiveResult {
case (null, _) => "null"
case (s: String, StringType) => "\"" + s + "\""
case (decimal, DecimalType()) => decimal.toString
- case (interval, CalendarIntervalType) => interval.toString
+ case (interval: CalendarInterval, CalendarIntervalType) =>
+ SQLConf.get.intervalOutputStyle match {
+ case SQL_STANDARD => toSqlStandardString(interval)
+ case ISO_8601 => toIso8601String(interval)
+ case MULTI_UNITS => toMultiUnitsString(interval)
+ }
case (other, tpe) if primitiveTypes contains tpe => other.toString
}
@@ -120,6 +128,12 @@ object HiveResult {
DateTimeUtils.timestampToString(timestampFormatter, DateTimeUtils.fromJavaTimestamp(t))
case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8)
case (decimal: java.math.BigDecimal, DecimalType()) => formatDecimal(decimal)
+ case (interval: CalendarInterval, CalendarIntervalType) =>
+ SQLConf.get.intervalOutputStyle match {
+ case SQL_STANDARD => toSqlStandardString(interval)
+ case ISO_8601 => toIso8601String(interval)
+ case MULTI_UNITS => toMultiUnitsString(interval)
+ }
case (interval, CalendarIntervalType) => interval.toString
case (other, _ : UserDefinedType[_]) => other.toString
case (other, tpe) if primitiveTypes.contains(tpe) => other.toString
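
Editor's note: the interval-style match is added verbatim in both formatting functions of this file. Inside the HiveResult object, it could be factored into one helper, a sketch assuming only the imports this patch already adds (IntervalUtils._, SQLConf.IntervalStyle._, CalendarInterval):

  // Hypothetical private helper; both added cases could delegate to it.
  private def formatInterval(interval: CalendarInterval): String = {
    SQLConf.get.intervalOutputStyle match {
      case SQL_STANDARD => toSqlStandardString(interval)
      case ISO_8601 => toIso8601String(interval)
      case MULTI_UNITS => toMultiUnitsString(interval)
    }
  }
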
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index f9394473d06e0..258f9cea05b82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -141,6 +141,7 @@ class QueryExecution(
ExplainUtils.processPlan(executedPlan, concat.append)
} catch {
case e: AnalysisException => concat.append(e.toString)
+ case e: IllegalArgumentException => concat.append(e.toString)
}
} else {
QueryPlan.append(executedPlan, concat.append, verbose = false, addSuffix = false)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index 24f664ca595c7..6b6ca531c6d3b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -71,7 +71,7 @@ case class SortExec(
* should make it public.
*/
def createSorter(): UnsafeExternalRowSorter = {
- val ordering = newOrdering(sortOrder, output)
+ val ordering = RowOrdering.create(sortOrder, output)
// The comparator for comparing prefix
val boundSortExpression = BindReferences.bindReference(sortOrder.head, output)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 125f76282e3df..ef9f38b8f9927 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -21,10 +21,6 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.ExecutionContext
-
-import org.codehaus.commons.compiler.CompileException
-import org.codehaus.janino.InternalCompilerException
import org.apache.spark.{broadcast, SparkEnv}
import org.apache.spark.internal.Logging
@@ -33,13 +29,11 @@ import org.apache.spark.rdd.{RDD, RDDOperationScope}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredicate, _}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.execution.metric.SQLMetric
-import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.vectorized.ColumnarBatch
object SparkPlan {
@@ -73,16 +67,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
val id: Int = SparkPlan.newPlanId()
- // sqlContext will be null when SparkPlan nodes are created without the active sessions.
- val subexpressionEliminationEnabled: Boolean = if (sqlContext != null) {
- sqlContext.conf.subexpressionEliminationEnabled
- } else {
- false
- }
-
- // whether we should fallback when hitting compilation errors caused by codegen
- private val codeGenFallBack = (sqlContext == null) || sqlContext.conf.codegenFallback
-
/**
* Return true if this stage of the plan supports columnar execution.
*/
@@ -463,51 +447,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
buf.toArray
}
- protected def newMutableProjection(
- expressions: Seq[Expression],
- inputSchema: Seq[Attribute],
- useSubexprElimination: Boolean = false): MutableProjection = {
- log.debug(s"Creating MutableProj: $expressions, inputSchema: $inputSchema")
- MutableProjection.create(expressions, inputSchema)
- }
-
- private def genInterpretedPredicate(
- expression: Expression, inputSchema: Seq[Attribute]): InterpretedPredicate = {
- val str = expression.toString
- val logMessage = if (str.length > 256) {
- str.substring(0, 256 - 3) + "..."
- } else {
- str
- }
- logWarning(s"Codegen disabled for this expression:\n $logMessage")
- InterpretedPredicate.create(expression, inputSchema)
- }
-
- protected def newPredicate(
- expression: Expression, inputSchema: Seq[Attribute]): GenPredicate = {
- try {
- GeneratePredicate.generate(expression, inputSchema)
- } catch {
- case _ @ (_: InternalCompilerException | _: CompileException) if codeGenFallBack =>
- genInterpretedPredicate(expression, inputSchema)
- }
- }
-
- protected def newOrdering(
- order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[InternalRow] = {
- GenerateOrdering.generate(order, inputSchema)
- }
-
- /**
- * Creates a row ordering for the given schema, in natural ascending order.
- */
- protected def newNaturalAscendingOrdering(dataTypes: Seq[DataType]): Ordering[InternalRow] = {
- val order: Seq[SortOrder] = dataTypes.zipWithIndex.map {
- case (dt, index) => SortOrder(BoundReference(index, dt, nullable = true), Ascending)
- }
- newOrdering(order, Seq.empty)
- }
-
/**
* Cleans up the resources used by the physical operator (if any). In general, all the resources
* should be cleaned up when the task finishes but operators like SortMergeJoinExec and LimitExec
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
index 9351b074c6590..ac66a71fe7ec0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
@@ -62,15 +62,22 @@ private[execution] object SparkPlanInfo {
new SQLMetricInfo(metric.name.getOrElse(key), metric.id, metric.metricType)
}
+ val nodeName = plan match {
+ case physicalOperator: WholeStageCodegenExec =>
+ s"${plan.nodeName} (${physicalOperator.codegenStageId})"
+ case _ => plan.nodeName
+ }
+
// dump the file scan metadata (e.g file path) to event log
val metadata = plan match {
case fileScan: FileSourceScanExec => fileScan.metadata
case _ => Map[String, String]()
}
new SparkPlanInfo(
- plan.nodeName,
+ nodeName,
plan.simpleString(SQLConf.get.maxToStringFields),
children.map(fromSparkPlan),
- metadata, metrics)
+ metadata,
+ metrics)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index b1271ad870565..44e60767e6b1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -89,23 +89,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
ResetCommand
}
- /**
- * Create a [[ShowTablesCommand]] logical plan.
- * Example SQL :
- * {{{
- * SHOW TABLE EXTENDED [(IN|FROM) database_name] LIKE 'identifier_with_wildcards'
- * [PARTITION(partition_spec)];
- * }}}
- */
- override def visitShowTable(ctx: ShowTableContext): LogicalPlan = withOrigin(ctx) {
- val partitionSpec = Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec)
- ShowTablesCommand(
- Option(ctx.db).map(_.getText),
- Option(ctx.pattern).map(string),
- isExtended = true,
- partitionSpec = partitionSpec)
- }
-
/**
* Create a [[RefreshResource]] logical plan.
*/
@@ -244,49 +227,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
options = Option(ctx.tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty))
}
- /**
- * Create an [[AlterDatabasePropertiesCommand]] command.
- *
- * For example:
- * {{{
- * ALTER (DATABASE|SCHEMA) database SET DBPROPERTIES (property_name=property_value, ...);
- * }}}
- */
- override def visitSetDatabaseProperties(
- ctx: SetDatabasePropertiesContext): LogicalPlan = withOrigin(ctx) {
- AlterDatabasePropertiesCommand(
- ctx.db.getText,
- visitPropertyKeyValues(ctx.tablePropertyList))
- }
-
- /**
- * Create an [[AlterDatabaseSetLocationCommand]] command.
- *
- * For example:
- * {{{
- * ALTER (DATABASE|SCHEMA) database SET LOCATION path;
- * }}}
- */
- override def visitSetDatabaseLocation(
- ctx: SetDatabaseLocationContext): LogicalPlan = withOrigin(ctx) {
- AlterDatabaseSetLocationCommand(
- ctx.db.getText,
- visitLocationSpec(ctx.locationSpec)
- )
- }
-
- /**
- * Create a [[DescribeDatabaseCommand]] command.
- *
- * For example:
- * {{{
- * DESCRIBE DATABASE [EXTENDED] database;
- * }}}
- */
- override def visitDescribeDatabase(ctx: DescribeDatabaseContext): LogicalPlan = withOrigin(ctx) {
- DescribeDatabaseCommand(ctx.db.getText, ctx.EXTENDED != null)
- }
-
/**
* Create a plan for a DESCRIBE FUNCTION command.
*/
@@ -376,22 +316,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
ctx.TEMPORARY != null)
}
- /**
- * Create a [[AlterTableRenameCommand]] command.
- *
- * For example:
- * {{{
- * ALTER TABLE table1 RENAME TO table2;
- * ALTER VIEW view1 RENAME TO view2;
- * }}}
- */
- override def visitRenameTable(ctx: RenameTableContext): LogicalPlan = withOrigin(ctx) {
- AlterTableRenameCommand(
- visitTableIdentifier(ctx.from),
- visitTableIdentifier(ctx.to),
- ctx.VIEW != null)
- }
-
/**
* Convert a nested constants list into a sequence of string sequences.
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 8d4731f34ddd6..b4eea620b93a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -82,7 +82,6 @@ case class AdaptiveSparkPlanExec(
// plan should reach a final status of query stages (i.e., no more addition or removal of
// Exchange nodes) after running these rules.
private def queryStagePreparationRules: Seq[Rule[SparkPlan]] = Seq(
- OptimizeLocalShuffleReader(conf),
ensureRequirements
)
@@ -90,16 +89,10 @@ case class AdaptiveSparkPlanExec(
// optimizations should be stage-independent.
@transient private val queryStageOptimizerRules: Seq[Rule[SparkPlan]] = Seq(
ReuseAdaptiveSubquery(conf, subqueryCache),
-
- // When adding local shuffle readers in 'OptimizeLocalShuffleReader`, we revert all the local
- // readers if additional shuffles are introduced. This may be too conservative: maybe there is
- // only one local reader that introduces shuffle, and we can still keep other local readers.
- // Here we re-execute this rule with the sub-plan-tree of a query stage, to make sure necessary
- // local readers are added before executing the query stage.
- // This rule must be executed before `ReduceNumShufflePartitions`, as local shuffle readers
- // can't change number of partitions.
- OptimizeLocalShuffleReader(conf),
ReduceNumShufflePartitions(conf),
+ // The 'OptimizeLocalShuffleReader' rule needs the 'partitionStartIndices' computed by the
+ // 'ReduceNumShufflePartitions' rule, so it must run after 'ReduceNumShufflePartitions'.

+ OptimizeLocalShuffleReader(conf),
ApplyColumnarRulesAndInsertTransitions(session.sessionState.conf,
session.sessionState.columnarRules),
CollapseCodegenStages(conf)
@@ -133,10 +126,8 @@ case class AdaptiveSparkPlanExec(
override def doCanonicalize(): SparkPlan = initialPlan.canonicalized
- override def doExecute(): RDD[InternalRow] = lock.synchronized {
- if (isFinalPlan) {
- currentPhysicalPlan.execute()
- } else {
+ private def getFinalPhysicalPlan(): SparkPlan = lock.synchronized {
+ if (!isFinalPlan) {
// Make sure we only update Spark UI if this plan's `QueryExecution` object matches the one
// retrieved by the `sparkContext`'s current execution ID. Note that sub-queries do not have
// their own execution IDs and therefore rely on the main query to update UI.
@@ -217,12 +208,21 @@ case class AdaptiveSparkPlanExec(
// Run the final plan when there's no more unfinished stages.
currentPhysicalPlan = applyPhysicalRules(result.newPlan, queryStageOptimizerRules)
isFinalPlan = true
-
- val ret = currentPhysicalPlan.execute()
logDebug(s"Final plan: $currentPhysicalPlan")
- executionId.foreach(onUpdatePlan)
- ret
}
+ currentPhysicalPlan
+ }
+
+ override def executeCollect(): Array[InternalRow] = {
+ getFinalPhysicalPlan().executeCollect()
+ }
+
+ override def executeTake(n: Int): Array[InternalRow] = {
+ getFinalPhysicalPlan().executeTake(n)
+ }
+
+ override def doExecute(): RDD[InternalRow] = {
+ getFinalPhysicalPlan().execute()
}
override def verboseString(maxFields: Int): String = simpleString(maxFields)
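
Editor's note: the refactor above computes the final plan once (under the lock) and lets executeCollect/executeTake/doExecute all delegate to it, so collect and take can use the final plan's own short-circuiting instead of always materialising an RDD first. A simplified standalone sketch of that "finalize once, then delegate" pattern (names and types are illustrative, not Spark's classes):

trait Plan {
  def executeCollect(): Array[String]
  def executeTake(n: Int): Array[String]
}

class AdaptiveWrapper(buildFinalPlan: () => Plan) extends Plan {
  private val lock = new Object
  @volatile private var finalPlan: Plan = _

  // Re-optimize and cache only on first use; later calls reuse the cached final plan.
  private def getFinalPlan(): Plan = lock.synchronized {
    if (finalPlan == null) {
      finalPlan = buildFinalPlan()
    }
    finalPlan
  }

  override def executeCollect(): Array[String] = getFinalPlan().executeCollect()
  override def executeTake(n: Int): Array[String] = getFinalPlan().executeTake(n)
}
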
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LocalShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LocalShuffledRowRDD.scala
index 5fccb5ce65783..6385ea67c49fe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LocalShuffledRowRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LocalShuffledRowRDD.scala
@@ -17,20 +17,24 @@
package org.apache.spark.sql.execution.adaptive
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleReadMetricsReporter}
-
/**
- * The [[Partition]] used by [[LocalShuffledRowRDD]]. A pre-shuffle partition
- * (identified by `preShufflePartitionIndex`) contains a range of post-shuffle partitions
- * (`startPostShufflePartitionIndex` to `endPostShufflePartitionIndex - 1`, inclusive).
+ * The [[Partition]] used by [[LocalShuffledRowRDD]].
+ * @param mapIndex the index of the mapper.
+ * @param startPartition the start partition ID in the `mapIndex`-th mapper.
+ * @param endPartition the end partition ID (exclusive) in the `mapIndex`-th mapper.
*/
private final class LocalShuffledRowRDDPartition(
- val preShufflePartitionIndex: Int) extends Partition {
- override val index: Int = preShufflePartitionIndex
+ override val index: Int,
+ val mapIndex: Int,
+ val startPartition: Int,
+ val endPartition: Int) extends Partition {
}
/**
@@ -40,7 +44,7 @@ private final class LocalShuffledRowRDDPartition(
* data of another input table of the join that reads from shuffle. Each partition of the RDD reads
* the whole data from just one mapper output locally. So actually there is no data transferred
* from the network.
-
+ *
* This RDD takes a [[ShuffleDependency]] (`dependency`).
*
* The `dependency` has the parent RDD of this RDD, which represents the dataset before shuffle
@@ -49,10 +53,15 @@ private final class LocalShuffledRowRDDPartition(
* `dependency.partitioner.numPartitions` is the number of pre-shuffle partitions. (i.e. the number
* of partitions of the map output). The post-shuffle partition number is the same to the parent
* RDD's partition number.
+ *
+ * `partitionStartIndicesPerMapper` specifies how to split the shuffle blocks of each mapper into
+ * one or more partitions. For a mapper `i`, the `j`th partition includes shuffle blocks from
+ * `partitionStartIndicesPerMapper[i][j]` to `partitionStartIndicesPerMapper[i][j+1]` (exclusive).
*/
class LocalShuffledRowRDD(
var dependency: ShuffleDependency[Int, InternalRow, InternalRow],
- metrics: Map[String, SQLMetric])
+ metrics: Map[String, SQLMetric],
+ partitionStartIndicesPerMapper: Array[Array[Int]])
extends RDD[InternalRow](dependency.rdd.context, Nil) {
private[this] val numReducers = dependency.partitioner.numPartitions
@@ -61,10 +70,14 @@ class LocalShuffledRowRDD(
override def getDependencies: Seq[Dependency[_]] = List(dependency)
override def getPartitions: Array[Partition] = {
-
- Array.tabulate[Partition](numMappers) { i =>
- new LocalShuffledRowRDDPartition(i)
+ val partitions = ArrayBuffer[LocalShuffledRowRDDPartition]()
+ for (mapIndex <- 0 until numMappers) {
+ (partitionStartIndicesPerMapper(mapIndex) :+ numReducers).sliding(2, 1).foreach {
+ case Array(start, end) =>
+ partitions += new LocalShuffledRowRDDPartition(partitions.length, mapIndex, start, end)
+ }
}
+ partitions.toArray
}
override def getPreferredLocations(partition: Partition): Seq[String] = {
@@ -74,17 +87,16 @@ class LocalShuffledRowRDD(
override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
val localRowPartition = split.asInstanceOf[LocalShuffledRowRDDPartition]
- val mapIndex = localRowPartition.index
+ val mapIndex = localRowPartition.mapIndex
val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics()
// `SQLShuffleReadMetricsReporter` will update its own metrics for SQL exchange operator,
// as well as the `tempMetrics` for basic shuffle metrics.
val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics)
-
val reader = SparkEnv.get.shuffleManager.getReaderForOneMapper(
dependency.shuffleHandle,
mapIndex,
- 0,
- numReducers,
+ localRowPartition.startPartition,
+ localRowPartition.endPartition,
context,
sqlMetricsReporter)
reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2)
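
Editor's note: getPartitions above turns each mapper's start indices into (mapIndex, startPartition, endPartition) ranges by appending numReducers and sliding a window of two over the boundaries. A standalone sketch of just that splitting step, with a worked example (types simplified, not the real Partition class):

import scala.collection.mutable.ArrayBuffer

object LocalPartitionSplit {
  // (partition index in the RDD, map index, start reducer partition, end reducer partition)
  final case class Split(index: Int, mapIndex: Int, startPartition: Int, endPartition: Int)

  def buildSplits(startIndicesPerMapper: Array[Array[Int]], numReducers: Int): Array[Split] = {
    val splits = ArrayBuffer[Split]()
    for (mapIndex <- startIndicesPerMapper.indices) {
      // Append numReducers so the last range is closed, then pair up consecutive boundaries.
      (startIndicesPerMapper(mapIndex) :+ numReducers).sliding(2, 1).foreach {
        case Array(start, end) =>
          splits += Split(splits.length, mapIndex, start, end)
      }
    }
    splits.toArray
  }

  def main(args: Array[String]): Unit = {
    // 2 mappers, 10 reducer partitions, each mapper split at 0/4/7:
    val splits = buildSplits(Array(Array(0, 4, 7), Array(0, 4, 7)), numReducers = 10)
    splits.foreach(println)
    // Split(0,0,0,4), Split(1,0,4,7), Split(2,0,7,10), Split(3,1,0,4), ...
  }
}
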
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
index 87d745bf976ab..176e5ec8312e1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
@@ -27,87 +27,139 @@ import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExcha
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.internal.SQLConf
-object BroadcastJoinWithShuffleLeft {
- def unapply(plan: SparkPlan): Option[(QueryStageExec, BuildSide)] = plan match {
- case join: BroadcastHashJoinExec if ShuffleQueryStageExec.isShuffleQueryStageExec(join.left) =>
- Some((join.left.asInstanceOf[QueryStageExec], join.buildSide))
- case _ => None
- }
-}
-
-object BroadcastJoinWithShuffleRight {
- def unapply(plan: SparkPlan): Option[(QueryStageExec, BuildSide)] = plan match {
- case join: BroadcastHashJoinExec if ShuffleQueryStageExec.isShuffleQueryStageExec(join.right) =>
- Some((join.right.asInstanceOf[QueryStageExec], join.buildSide))
- case _ => None
- }
-}
-
/**
- * A rule to optimize the shuffle reader to local reader as far as possible
- * when converting the 'SortMergeJoinExec' to 'BroadcastHashJoinExec' in runtime.
- *
- * This rule can be divided into two steps:
- * Step1: Add the local reader in probe side and then check whether additional
- * shuffle introduced. If introduced, we will revert all the local
- * reader in probe side.
- * Step2: Add the local reader in build side and will not check whether
- * additional shuffle introduced. Because the build side will not introduce
- * additional shuffle.
+ * A rule to optimize the shuffle reader to a local reader iff no additional shuffles
+ * will be introduced:
+ * 1. If the input plan is a shuffle, add the local reader directly, as we can never
+ * introduce extra shuffles in this case.
+ * 2. Otherwise, add local readers to the probe side of broadcast hash joins and then
+ * run `EnsureRequirements` to check whether additional shuffles are introduced.
+ * If so, we revert all the local readers.
*/
case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {
+ import OptimizeLocalShuffleReader._
- override def apply(plan: SparkPlan): SparkPlan = {
- if (!conf.getConf(SQLConf.OPTIMIZE_LOCAL_SHUFFLE_READER_ENABLED)) {
- return plan
- }
- // Add local reader in probe side.
- val withProbeSideLocalReader = plan.transformDown {
+ private val ensureRequirements = EnsureRequirements(conf)
+
+ // The build side is a broadcast query stage which should already have been optimized with a
+ // local reader, so we only need to deal with the probe side here.
+ private def createProbeSideLocalReader(plan: SparkPlan): SparkPlan = {
+ val optimizedPlan = plan.transformDown {
case join @ BroadcastJoinWithShuffleLeft(shuffleStage, BuildRight) =>
- val localReader = LocalShuffleReaderExec(shuffleStage)
+ val localReader = createLocalReader(shuffleStage)
join.asInstanceOf[BroadcastHashJoinExec].copy(left = localReader)
case join @ BroadcastJoinWithShuffleRight(shuffleStage, BuildLeft) =>
- val localReader = LocalShuffleReaderExec(shuffleStage)
+ val localReader = createLocalReader(shuffleStage)
join.asInstanceOf[BroadcastHashJoinExec].copy(right = localReader)
}
- def numExchanges(plan: SparkPlan): Int = {
- plan.collect {
- case e: ShuffleExchangeExec => e
- }.length
- }
+ val numShuffles = ensureRequirements.apply(optimizedPlan).collect {
+ case e: ShuffleExchangeExec => e
+ }.length
+
// Check whether additional shuffle introduced. If introduced, revert the local reader.
- val numExchangeBefore = numExchanges(EnsureRequirements(conf).apply(plan))
- val numExchangeAfter = numExchanges(EnsureRequirements(conf).apply(withProbeSideLocalReader))
- val optimizedPlan = if (numExchangeAfter > numExchangeBefore) {
- logDebug("OptimizeLocalShuffleReader rule is not applied in the probe side due" +
+ if (numShuffles > 0) {
+ logDebug("OptimizeLocalShuffleReader rule is not applied due" +
" to additional shuffles will be introduced.")
plan
} else {
- withProbeSideLocalReader
+ optimizedPlan
}
- // Add the local reader in build side and and do not need to check whether
- // additional shuffle introduced.
- optimizedPlan.transformDown {
- case join @ BroadcastJoinWithShuffleLeft(shuffleStage, BuildLeft) =>
- val localReader = LocalShuffleReaderExec(shuffleStage)
- join.asInstanceOf[BroadcastHashJoinExec].copy(left = localReader)
- case join @ BroadcastJoinWithShuffleRight(shuffleStage, BuildRight) =>
- val localReader = LocalShuffleReaderExec(shuffleStage)
- join.asInstanceOf[BroadcastHashJoinExec].copy(right = localReader)
+ }
+
+ private def createLocalReader(plan: SparkPlan): LocalShuffleReaderExec = {
+ plan match {
+ case c @ CoalescedShuffleReaderExec(q: QueryStageExec, _) =>
+ LocalShuffleReaderExec(
+ q, getPartitionStartIndices(q, Some(c.partitionStartIndices.length)))
+ case q: QueryStageExec =>
+ LocalShuffleReaderExec(q, getPartitionStartIndices(q, None))
+ }
+ }
+
+ // TODO: this method assumes all shuffle blocks are the same data size. We should calculate the
+ // partition start indices based on block size to avoid data skew.
+ private def getPartitionStartIndices(
+ shuffle: QueryStageExec,
+ advisoryParallelism: Option[Int]): Array[Array[Int]] = {
+ val shuffleDep = shuffle match {
+ case s: ShuffleQueryStageExec => s.plan.shuffleDependency
+ case ReusedQueryStageExec(_, s: ShuffleQueryStageExec, _) => s.plan.shuffleDependency
}
+ val numReducers = shuffleDep.partitioner.numPartitions
+ val expectedParallelism = advisoryParallelism.getOrElse(numReducers)
+ val numMappers = shuffleDep.rdd.getNumPartitions
+ Array.fill(numMappers) {
+ equallyDivide(numReducers, math.max(1, expectedParallelism / numMappers)).toArray
+ }
+ }
+
+ /**
+ * To equally divide n elements into m buckets, each bucket should have n/m elements; for the
+ * remaining n%m elements, add one more element to each of the first n%m buckets. Returns a
+ * sequence of length numBuckets whose values are the start indices of the buckets.
+ */
+ private def equallyDivide(numElements: Int, numBuckets: Int): Seq[Int] = {
+ val elementsPerBucket = numElements / numBuckets
+ val remaining = numElements % numBuckets
+ val splitPoint = (elementsPerBucket + 1) * remaining
+ (0 until remaining).map(_ * (elementsPerBucket + 1)) ++
+ (remaining until numBuckets).map(i => splitPoint + (i - remaining) * elementsPerBucket)
+ }
+
+ override def apply(plan: SparkPlan): SparkPlan = {
+ if (!conf.getConf(SQLConf.LOCAL_SHUFFLE_READER_ENABLED)) {
+ return plan
+ }
+
+ plan match {
+ case s: SparkPlan if canUseLocalShuffleReader(s) =>
+ createLocalReader(s)
+ case s: SparkPlan =>
+ createProbeSideLocalReader(s)
+ }
+ }
+}
+
+object OptimizeLocalShuffleReader {
+
+ object BroadcastJoinWithShuffleLeft {
+ def unapply(plan: SparkPlan): Option[(SparkPlan, BuildSide)] = plan match {
+ case join: BroadcastHashJoinExec if canUseLocalShuffleReader(join.left) =>
+ Some((join.left, join.buildSide))
+ case _ => None
+ }
+ }
+
+ object BroadcastJoinWithShuffleRight {
+ def unapply(plan: SparkPlan): Option[(SparkPlan, BuildSide)] = plan match {
+ case join: BroadcastHashJoinExec if canUseLocalShuffleReader(join.right) =>
+ Some((join.right, join.buildSide))
+ case _ => None
+ }
+ }
+
+ def canUseLocalShuffleReader(plan: SparkPlan): Boolean = {
+ ShuffleQueryStageExec.isShuffleQueryStageExec(plan) ||
+ plan.isInstanceOf[CoalescedShuffleReaderExec]
}
}
/**
- * A wrapper of shuffle query stage, which submits one reduce task per mapper to read the shuffle
- * files written by one mapper. By doing this, it's very likely to read the shuffle files locally,
- * as the shuffle files that a reduce task needs to read are in one node.
+ * A wrapper of shuffle query stage, which submits one or more reduce tasks per mapper to read the
+ * shuffle files written by one mapper. By doing this, it's very likely to read the shuffle files
+ * locally, as the shuffle files that a reduce task needs to read are in one node.
*
* @param child It's usually `ShuffleQueryStageExec` or `ReusedQueryStageExec`, but can be the
* shuffle exchange node during canonicalization.
+ * @param partitionStartIndicesPerMapper A mapper usually writes many shuffle blocks, and it's
+ * better to launch multiple tasks to read shuffle blocks of
+ * one mapper. This array contains the partition start
+ * indices for each mapper.
*/
-case class LocalShuffleReaderExec(child: SparkPlan) extends UnaryExecNode {
+case class LocalShuffleReaderExec(
+ child: SparkPlan,
+ partitionStartIndicesPerMapper: Array[Array[Int]]) extends UnaryExecNode {
override def output: Seq[Attribute] = child.output
@@ -124,9 +176,9 @@ case class LocalShuffleReaderExec(child: SparkPlan) extends UnaryExecNode {
if (cachedShuffleRDD == null) {
cachedShuffleRDD = child match {
case stage: ShuffleQueryStageExec =>
- stage.plan.createLocalShuffleRDD()
+ stage.plan.createLocalShuffleRDD(partitionStartIndicesPerMapper)
case ReusedQueryStageExec(_, stage: ShuffleQueryStageExec, _) =>
- stage.plan.createLocalShuffleRDD()
+ stage.plan.createLocalShuffleRDD(partitionStartIndicesPerMapper)
}
}
cachedShuffleRDD
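
Editor's note: for a concrete feel of how the new rule picks partition start indices, the equallyDivide arithmetic is reproduced standalone below with a worked example; the surrounding Spark plumbing (shuffle dependencies, query stages) is omitted.

object EquallyDivideExample {
  // Same arithmetic as the rule's private equallyDivide: the start index of each of numBuckets
  // buckets when numElements reducer partitions are split as evenly as possible.
  def equallyDivide(numElements: Int, numBuckets: Int): Seq[Int] = {
    val elementsPerBucket = numElements / numBuckets
    val remaining = numElements % numBuckets
    val splitPoint = (elementsPerBucket + 1) * remaining
    (0 until remaining).map(_ * (elementsPerBucket + 1)) ++
      (remaining until numBuckets).map(i => splitPoint + (i - remaining) * elementsPerBucket)
  }

  def main(args: Array[String]): Unit = {
    // 10 reducer partitions into 3 buckets -> buckets of size 4, 3, 3 starting at 0, 4, 7.
    println(equallyDivide(10, 3))  // Vector(0, 4, 7)

    // getPartitionStartIndices then repeats this per mapper: with 2 mappers and an advisory
    // parallelism of 6, each mapper is split into max(1, 6 / 2) = 3 local partitions.
    val numMappers = 2
    val perMapper =
      Array.fill(numMappers)(equallyDivide(10, math.max(1, 6 / numMappers)).toArray)
    println(perMapper.map(_.mkString("[", ",", "]")).mkString(" "))  // [0,4,7] [0,4,7]
  }
}
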
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 95bef308e453d..ad8976c77b16a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -126,7 +126,7 @@ case class HashAggregateExec(
initialInputBufferOffset,
resultExpressions,
(expressions, inputSchema) =>
- newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled),
+ MutableProjection.create(expressions, inputSchema),
child.output,
iter,
testFallbackStartsAt,
@@ -486,10 +486,9 @@ case class HashAggregateExec(
// Create a MutableProjection to merge the rows of same key together
val mergeExpr = declFunctions.flatMap(_.mergeExpressions)
- val mergeProjection = newMutableProjection(
+ val mergeProjection = MutableProjection.create(
mergeExpr,
- aggregateBufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes),
- subexpressionEliminationEnabled)
+ aggregateBufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes))
val joinedRow = new JoinedRow()
var currentKey: UnsafeRow = null
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
index b88ddba8e48d3..1f325c11c9e44 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
@@ -22,7 +22,7 @@ import org.apache.spark.internal.{config, Logging}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.expressions.codegen.{BaseOrdering, GenerateOrdering}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
import org.apache.spark.sql.execution.UnsafeKVExternalSorter
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.internal.SQLConf
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
index 151da241144be..953622afebf89 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
@@ -122,7 +122,7 @@ case class ObjectHashAggregateExec(
initialInputBufferOffset,
resultExpressions,
(expressions, inputSchema) =>
- newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled),
+ MutableProjection.create(expressions, inputSchema),
child.output,
iter,
fallbackCountThreshold,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
index 7ab6ecc08a7bc..0ddf95771d5b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
@@ -93,7 +93,7 @@ case class SortAggregateExec(
initialInputBufferOffset,
resultExpressions,
(expressions, inputSchema) =>
- newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled),
+ MutableProjection.create(expressions, inputSchema),
numOutputRows)
if (!hasInput && groupingExpressions.isEmpty) {
// There is no input and there is no grouping expressions.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 3ed42f359c0a4..e128d59dca6ba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.{LongType, StructType}
@@ -227,7 +226,7 @@ case class FilterExec(condition: Expression, child: SparkPlan)
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
- val predicate = newPredicate(condition, child.output)
+ val predicate = Predicate.create(condition, child.output)
predicate.initialize(0)
iter.filter { row =>
val r = predicate.eval(row)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 8d13cfb93d270..f03c2586048bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -310,7 +310,7 @@ case class InMemoryTableScanExec(
val buffers = relation.cacheBuilder.cachedColumnBuffers
buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) =>
- val partitionFilter = newPredicate(
+ val partitionFilter = Predicate.create(
partitionFilters.reduceOption(And).getOrElse(Literal(true)),
schema)
partitionFilter.initialize(index)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
index 3adec2f790730..21ddeb6491155 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
@@ -171,7 +171,7 @@ abstract class PartitioningAwareFileIndex(
if (partitionPruningPredicates.nonEmpty) {
val predicate = partitionPruningPredicates.reduce(expressions.And)
- val boundPredicate = InterpretedPredicate.create(predicate.transform {
+ val boundPredicate = Predicate.createInterpreted(predicate.transform {
case a: AttributeReference =>
val index = partitionColumns.indexWhere(a.name == _.name)
BoundReference(index, partitionColumns(index).dataType, nullable = true)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index d4c7f005a16df..c1e1aed83bae5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -170,8 +170,8 @@ object JdbcUtils extends Logging {
case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT))
case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE))
case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT))
- case ShortType => Option(JdbcType("SMALLINT", java.sql.Types.SMALLINT))
- case ByteType => Option(JdbcType("TINYINT", java.sql.Types.TINYINT))
+ case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT))
+ case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT))
case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT))
case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB))
case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB))
@@ -235,7 +235,7 @@ object JdbcUtils extends Logging {
case java.sql.Types.REF => StringType
case java.sql.Types.REF_CURSOR => null
case java.sql.Types.ROWID => LongType
- case java.sql.Types.SMALLINT => ShortType
+ case java.sql.Types.SMALLINT => IntegerType
case java.sql.Types.SQLXML => StringType
case java.sql.Types.STRUCT => StringType
case java.sql.Types.TIME => TimestampType
@@ -244,7 +244,7 @@ object JdbcUtils extends Logging {
case java.sql.Types.TIMESTAMP => TimestampType
case java.sql.Types.TIMESTAMP_WITH_TIMEZONE
=> null
- case java.sql.Types.TINYINT => ByteType
+ case java.sql.Types.TINYINT => IntegerType
case java.sql.Types.VARBINARY => BinaryType
case java.sql.Types.VARCHAR => StringType
case _ =>
@@ -445,7 +445,7 @@ object JdbcUtils extends Logging {
case ByteType =>
(rs: ResultSet, row: InternalRow, pos: Int) =>
- row.update(pos, rs.getByte(pos + 1))
+ row.setByte(pos, rs.getByte(pos + 1))
case StringType =>
(rs: ResultSet, row: InternalRow, pos: Int) =>
@@ -546,11 +546,11 @@ object JdbcUtils extends Logging {
case ShortType =>
(stmt: PreparedStatement, row: Row, pos: Int) =>
- stmt.setShort(pos + 1, row.getShort(pos))
+ stmt.setInt(pos + 1, row.getShort(pos))
case ByteType =>
(stmt: PreparedStatement, row: Row, pos: Int) =>
- stmt.setByte(pos + 1, row.getByte(pos))
+ stmt.setInt(pos + 1, row.getByte(pos))
case BooleanType =>
(stmt: PreparedStatement, row: Row, pos: Int) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala
index 3f4f29c3e135a..03e5f43a2a0af 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability, TableProvider}
-import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, SupportsTruncate, WriteBuilder, WriterCommitMessage}
+import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, PhysicalWriteInfo, SupportsTruncate, WriteBuilder, WriterCommitMessage}
import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.types.StructType
@@ -58,7 +58,8 @@ private[noop] object NoopWriteBuilder extends WriteBuilder with SupportsTruncate
}
private[noop] object NoopBatchWrite extends BatchWrite {
- override def createBatchWriterFactory(): DataWriterFactory = NoopWriterFactory
+ override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory =
+ NoopWriterFactory
override def commit(messages: Array[WriterCommitMessage]): Unit = {}
override def abort(messages: Array[WriterCommitMessage]): Unit = {}
}
@@ -74,8 +75,8 @@ private[noop] object NoopWriter extends DataWriter[InternalRow] {
}
private[noop] object NoopStreamingWrite extends StreamingWrite {
- override def createStreamingWriterFactory(): StreamingDataWriterFactory =
- NoopStreamingDataWriterFactory
+ override def createStreamingWriterFactory(
+ info: PhysicalWriteInfo): StreamingDataWriterFactory = NoopStreamingDataWriterFactory
override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AlterNamespaceSetPropertiesExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AlterNamespaceSetPropertiesExec.scala
new file mode 100644
index 0000000000000..1eebe4cdb6a86
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AlterNamespaceSetPropertiesExec.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.connector.catalog.{NamespaceChange, SupportsNamespaces}
+
+/**
+ * Physical plan node for setting properties of namespace.
+ */
+case class AlterNamespaceSetPropertiesExec(
+ catalog: SupportsNamespaces,
+ namespace: Seq[String],
+ props: Map[String, String]) extends V2CommandExec {
+ override protected def run(): Seq[InternalRow] = {
+ val changes = props.map { case (k, v) =>
+ NamespaceChange.setProperty(k, v)
+ }.toSeq
+ catalog.alterNamespace(namespace.toArray, changes: _*)
+ Seq.empty
+ }
+
+ override def output: Seq[Attribute] = Seq.empty
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 01ff4a9303e98..a0d10f1d09e63 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -22,7 +22,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.sql.{AnalysisException, Strategy}
import org.apache.spark.sql.catalyst.expressions.{And, PredicateHelper, SubqueryExpression}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
-import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, AppendData, CreateNamespace, CreateTableAsSelect, CreateV2Table, DeleteFromTable, DescribeTable, DropNamespace, DropTable, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, RefreshTable, Repartition, ReplaceTable, ReplaceTableAsSelect, SetCatalogAndNamespace, ShowCurrentNamespace, ShowNamespaces, ShowTableProperties, ShowTables}
+import org.apache.spark.sql.catalyst.plans.logical.{AlterNamespaceSetProperties, AlterTable, AppendData, CreateNamespace, CreateTableAsSelect, CreateV2Table, DeleteFromTable, DescribeNamespace, DescribeTable, DropNamespace, DropTable, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, RefreshTable, RenameTable, Repartition, ReplaceTable, ReplaceTableAsSelect, SetCatalogAndNamespace, ShowCurrentNamespace, ShowNamespaces, ShowTableProperties, ShowTables}
import org.apache.spark.sql.connector.catalog.{StagingTableCatalog, TableCapability}
import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream}
import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
@@ -192,6 +192,9 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
Nil
}
+ case desc @ DescribeNamespace(catalog, namespace, extended) =>
+ DescribeNamespaceExec(desc.output, catalog, namespace, extended) :: Nil
+
case desc @ DescribeTable(DataSourceV2Relation(table, _, _), isExtended) =>
DescribeTableExec(desc.output, table, isExtended) :: Nil
@@ -201,6 +204,12 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
case AlterTable(catalog, ident, _, changes) =>
AlterTableExec(catalog, ident, changes) :: Nil
+ case RenameTable(catalog, oldIdent, newIdent) =>
+ RenameTableExec(catalog, oldIdent, newIdent) :: Nil
+
+ case AlterNamespaceSetProperties(catalog, namespace, properties) =>
+ AlterNamespaceSetPropertiesExec(catalog, namespace, properties) :: Nil
+
case CreateNamespace(catalog, namespace, ifNotExists, properties) =>
CreateNamespaceExec(catalog, namespace, ifNotExists, properties) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeNamespaceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeNamespaceExec.scala
new file mode 100644
index 0000000000000..7c5cfcbbc7e3c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeNamespaceExec.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericRowWithSchema}
+import org.apache.spark.sql.connector.catalog.SupportsNamespaces
+import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog.COMMENT_TABLE_PROP
+import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog.LOCATION_TABLE_PROP
+import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog.RESERVED_PROPERTIES
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Physical plan node for describing a namespace.
+ */
+case class DescribeNamespaceExec(
+ output: Seq[Attribute],
+ catalog: SupportsNamespaces,
+ namespace: Seq[String],
+ isExtended: Boolean) extends V2CommandExec {
+ private val encoder = RowEncoder(StructType.fromAttributes(output)).resolveAndBind()
+
+ override protected def run(): Seq[InternalRow] = {
+ val rows = new ArrayBuffer[InternalRow]()
+ val ns = namespace.toArray
+ val metadata = catalog.loadNamespaceMetadata(ns)
+
+ rows += toCatalystRow("Namespace Name", ns.last)
+ rows += toCatalystRow("Description", metadata.get(COMMENT_TABLE_PROP))
+ rows += toCatalystRow("Location", metadata.get(LOCATION_TABLE_PROP))
+ if (isExtended) {
+ val properties = metadata.asScala.toSeq.filter(p => !RESERVED_PROPERTIES.contains(p._1))
+ if (properties.nonEmpty) {
+ rows += toCatalystRow("Properties", properties.mkString("(", ",", ")"))
+ }
+ }
+ rows
+ }
+
+ private def toCatalystRow(strs: String*): InternalRow = {
+ encoder.toRow(new GenericRowWithSchema(strs.toArray, schema)).copy()
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala
index e7d9a247533c4..266c834909363 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala
@@ -20,7 +20,7 @@ import org.apache.hadoop.mapreduce.Job
import org.apache.spark.internal.Logging
import org.apache.spark.internal.io.FileCommitProtocol
-import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, WriterCommitMessage}
+import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, PhysicalWriteInfo, WriterCommitMessage}
import org.apache.spark.sql.execution.datasources.{WriteJobDescription, WriteTaskResult}
import org.apache.spark.sql.execution.datasources.FileFormatWriter.processStats
@@ -44,7 +44,7 @@ class FileBatchWrite(
committer.abortJob(job)
}
- override def createBatchWriterFactory(): DataWriterFactory = {
+ override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = {
FileWriterFactory(description, committer)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RenameTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RenameTableExec.scala
new file mode 100644
index 0000000000000..a650607d5f129
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RenameTableExec.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
+
+/**
+ * Physical plan node for renaming a table.
+ */
+case class RenameTableExec(
+ catalog: TableCatalog,
+ oldIdent: Identifier,
+ newIdent: Identifier) extends V2CommandExec {
+
+ override def output: Seq[Attribute] = Seq.empty
+
+ override protected def run(): Seq[InternalRow] = {
+ catalog.invalidateTable(oldIdent)
+ catalog.renameTable(oldIdent, newIdent)
+
+ Seq.empty
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index 9f4392da6ab4d..7d8a115c126eb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connector.catalog.{Identifier, StagedTable, StagingTableCatalog, SupportsWrite, TableCatalog}
import org.apache.spark.sql.connector.expressions.Transform
-import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder, WriterCommitMessage}
+import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, PhysicalWriteInfoImpl, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder, WriterCommitMessage}
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.sources.{AlwaysTrue, Filter}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -353,17 +353,20 @@ trait V2TableWriteExec extends UnaryExecNode {
override def output: Seq[Attribute] = Nil
protected def writeWithV2(batchWrite: BatchWrite): RDD[InternalRow] = {
- val writerFactory = batchWrite.createBatchWriterFactory()
- val useCommitCoordinator = batchWrite.useCommitCoordinator
- val rdd = query.execute()
- // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single
- // partition rdd to make sure we at least set up one write task to write the metadata.
- val rddWithNonEmptyPartitions = if (rdd.partitions.length == 0) {
- sparkContext.parallelize(Array.empty[InternalRow], 1)
- } else {
- rdd
+ val rdd: RDD[InternalRow] = {
+ val tempRdd = query.execute()
+ // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single
+ // partition rdd to make sure we at least set up one write task to write the metadata.
+ if (tempRdd.partitions.length == 0) {
+ sparkContext.parallelize(Array.empty[InternalRow], 1)
+ } else {
+ tempRdd
+ }
}
- val messages = new Array[WriterCommitMessage](rddWithNonEmptyPartitions.partitions.length)
+ val writerFactory = batchWrite.createBatchWriterFactory(
+ PhysicalWriteInfoImpl(rdd.getNumPartitions))
+ val useCommitCoordinator = batchWrite.useCommitCoordinator
+ val messages = new Array[WriterCommitMessage](rdd.partitions.length)
val totalNumRowsAccumulator = new LongAccumulator()
logInfo(s"Start processing data source write support: $batchWrite. " +
@@ -371,10 +374,10 @@ trait V2TableWriteExec extends UnaryExecNode {
try {
sparkContext.runJob(
- rddWithNonEmptyPartitions,
+ rdd,
(context: TaskContext, iter: Iterator[InternalRow]) =>
DataWritingSparkTask.run(writerFactory, context, iter, useCommitCoordinator),
- rddWithNonEmptyPartitions.partitions.indices,
+ rdd.partitions.indices,
(index, result: DataWritingSparkTaskResult) => {
val commitMessage = result.writerCommitMessage
messages(index) = commitMessage
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index c56a5c015f32d..866b382a1d808 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -83,7 +83,24 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
numPartitionsSet.headOption
}
- val targetNumPartitions = requiredNumPartitions.getOrElse(childrenNumPartitions.max)
+ // If there are non-shuffle children that satisfy the required distribution, we have
+ // some tradeoffs when picking the expected number of shuffle partitions:
+ // 1. We should avoid shuffling these children.
+ // 2. We should have a reasonable parallelism.
+ val nonShuffleChildrenNumPartitions =
+ childrenIndexes.map(children).filterNot(_.isInstanceOf[ShuffleExchangeExec])
+ .map(_.outputPartitioning.numPartitions)
+ val expectedChildrenNumPartitions = if (nonShuffleChildrenNumPartitions.nonEmpty) {
+ // Here we pick the max number of partitions among these non-shuffle children as the
+ // expected number of shuffle partitions. However, if it's smaller than
+ // `conf.numShufflePartitions`, we pick `conf.numShufflePartitions` as the
+ // expected number of shuffle partitions.
+ math.max(nonShuffleChildrenNumPartitions.max, conf.numShufflePartitions)
+ } else {
+ childrenNumPartitions.max
+ }
+
+ val targetNumPartitions = requiredNumPartitions.getOrElse(expectedChildrenNumPartitions)
children = children.zip(requiredChildDistributions).zipWithIndex.map {
case ((child, distribution), index) if childrenIndexes.contains(index) =>
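
Editor's note: the new branch above changes how the target partition count is chosen when some children already satisfy the required distribution without a shuffle. A standalone sketch of just that decision, with the conf lookup replaced by a plain parameter and worked values inlined:

object TargetPartitionsExample {
  // Mirrors the decision in EnsureRequirements: prefer not to re-shuffle existing non-shuffle
  // children, but never drop below the configured shuffle parallelism.
  def targetNumPartitions(
      requiredNumPartitions: Option[Int],
      nonShuffleChildrenNumPartitions: Seq[Int],
      allChildrenNumPartitions: Seq[Int],
      confNumShufflePartitions: Int): Int = {
    val expected = if (nonShuffleChildrenNumPartitions.nonEmpty) {
      math.max(nonShuffleChildrenNumPartitions.max, confNumShufflePartitions)
    } else {
      allChildrenNumPartitions.max
    }
    requiredNumPartitions.getOrElse(expected)
  }

  def main(args: Array[String]): Unit = {
    // One child is already partitioned into 5 without a shuffle; the conf asks for 200:
    println(targetNumPartitions(None, Seq(5), Seq(5, 200), 200))     // 200
    // All children are shuffles, with 50 and 200 partitions:
    println(targetNumPartitions(None, Nil, Seq(50, 200), 200))       // 200
    // An explicit required number of partitions always wins:
    println(targetNumPartitions(Some(10), Seq(5), Seq(5, 10), 200))  // 10
  }
}
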
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 2f94c522712b1..b876183c78ec2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -83,8 +83,9 @@ case class ShuffleExchangeExec(
new ShuffledRowRDD(shuffleDependency, readMetrics, partitionStartIndices)
}
- def createLocalShuffleRDD(): LocalShuffledRowRDD = {
- new LocalShuffledRowRDD(shuffleDependency, readMetrics)
+ def createLocalShuffleRDD(
+ partitionStartIndicesPerMapper: Array[Array[Int]]): LocalShuffledRowRDD = {
+ new LocalShuffledRowRDD(shuffleDependency, readMetrics, partitionStartIndicesPerMapper)
}
/**
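
Note: the local shuffle reader now takes explicit per-mapper split points. A hedged sketch of the argument's shape, assuming (as the adaptive-execution call sites suggest) that each inner array lists the reduce-partition start indices at which one mapper's output is split into separate local-read tasks; the values below are hypothetical:

    // 2 mappers, each mapper's output read by two local tasks covering
    // reduce partitions [0, 3) and [3, numReducers).
    val partitionStartIndicesPerMapper: Array[Array[Int]] = Array(
      Array(0, 3),
      Array(0, 3)
    )
    // exchange.createLocalShuffleRDD(partitionStartIndicesPerMapper)  // new call shape
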
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index f526a19876670..5517c0dcdb188 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -19,14 +19,12 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.collection.{BitSet, CompactBuffer}
case class BroadcastNestedLoopJoinExec(
@@ -84,7 +82,7 @@ case class BroadcastNestedLoopJoinExec(
@transient private lazy val boundCondition = {
if (condition.isDefined) {
- newPredicate(condition.get, streamed.output ++ broadcast.output).eval _
+ Predicate.create(condition.get, streamed.output ++ broadcast.output).eval _
} else {
(r: InternalRow) => true
}
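
Note: the SparkPlan helper newPredicate is replaced throughout by the catalyst-side Predicate.create, which returns a BasePredicate, so the existing `.eval _` call sites keep working. A quick sketch of the new entry point, using the internal expression DSL purely for illustration:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.expressions.Predicate

    val a = 'a.int                                          // AttributeReference("a", IntegerType)
    val boundCondition = Predicate.create(a > 1, Seq(a)).eval _
    boundCondition(InternalRow(2))                          // true
    boundCondition(InternalRow(0))                          // false
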
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 88d98530991c9..29645a736548c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -20,9 +20,8 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark._
import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, Predicate, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
-import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.execution.{BinaryExecNode, ExplainUtils, ExternalAppendOnlyUnsafeRowArray, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.util.CompletionIterator
@@ -93,7 +92,7 @@ case class CartesianProductExec(
pair.mapPartitionsWithIndexInternal { (index, iter) =>
val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
val filtered = if (condition.isDefined) {
- val boundCondition = newPredicate(condition.get, left.output ++ right.output)
+ val boundCondition = Predicate.create(condition.get, left.output ++ right.output)
boundCondition.initialize(index)
val joined = new JoinedRow
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index e8938cb22e890..137f0b87a2f3d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -99,7 +99,7 @@ trait HashJoin {
UnsafeProjection.create(streamedKeys)
@transient private[this] lazy val boundCondition = if (condition.isDefined) {
- newPredicate(condition.get, streamedPlan.output ++ buildPlan.output).eval _
+ Predicate.create(condition.get, streamedPlan.output ++ buildPlan.output).eval _
} else {
(r: InternalRow) => true
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 26fb0e5ffb1af..f327e84563da9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -168,14 +168,14 @@ case class SortMergeJoinExec(
left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
val boundCondition: (InternalRow) => Boolean = {
condition.map { cond =>
- newPredicate(cond, left.output ++ right.output).eval _
+ Predicate.create(cond, left.output ++ right.output).eval _
}.getOrElse {
(r: InternalRow) => true
}
}
// An ordering that can be used to compare keys from both sides.
- val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
+ val keyOrdering = RowOrdering.createNaturalAscendingOrdering(leftKeys.map(_.dataType))
val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output)
joinType match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
index 3554bdb5c9e0c..a0f23e925d237 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
@@ -113,7 +113,7 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute],
}
}.toArray
}.toArray
- val projection = newMutableProjection(allInputs, child.output)
+ val projection = MutableProjection.create(allInputs, child.output)
val schema = StructType(dataTypes.zipWithIndex.map { case (dt, i) =>
StructField(s"_$i", dt)
})
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala
index aac5da8104a8b..59ce7c3707b27 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, Processing
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.execution.streaming.GroupStateImpl._
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}
-import org.apache.spark.unsafe.types.CalendarInterval
+import org.apache.spark.unsafe.types.UTF8String
/**
@@ -160,7 +160,7 @@ private[sql] class GroupStateImpl[S] private(
def getTimeoutTimestamp: Long = timeoutTimestamp
private def parseDuration(duration: String): Long = {
- val cal = IntervalUtils.fromString(duration)
+ val cal = IntervalUtils.stringToInterval(UTF8String.fromString(duration))
if (IntervalUtils.isNegative(cal)) {
throw new IllegalArgumentException(s"Provided duration ($duration) is negative")
}
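
Note: IntervalUtils.fromString is no longer used here; the parser now works on UTF8String. The call shape used above (and in Triggers.scala below), as a stand-alone snippet:

    import org.apache.spark.sql.catalyst.util.IntervalUtils
    import org.apache.spark.unsafe.types.UTF8String

    val cal = IntervalUtils.stringToInterval(UTF8String.fromString("10 seconds"))
    IntervalUtils.isNegative(cal)   // false; GroupStateImpl rejects negative durations
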
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
index 6bb4dc1672900..f1bfe97610fed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
@@ -21,7 +21,7 @@ import java.util.concurrent.TimeUnit.NANOSECONDS
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, JoinedRow, Literal, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, JoinedRow, Literal, Predicate, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._
import org.apache.spark.sql.catalyst.plans.physical._
@@ -233,8 +233,9 @@ case class StreamingSymmetricHashJoinExec(
val joinedRow = new JoinedRow
+ val inputSchema = left.output ++ right.output
val postJoinFilter =
- newPredicate(condition.bothSides.getOrElse(Literal(true)), left.output ++ right.output).eval _
+ Predicate.create(condition.bothSides.getOrElse(Literal(true)), inputSchema).eval _
val leftSideJoiner = new OneSideHashJoiner(
LeftSide, left.output, leftKeys, leftInputIter,
condition.leftSideOnly, postJoinFilter, stateWatermarkPredicates.left)
@@ -417,7 +418,7 @@ case class StreamingSymmetricHashJoinExec(
// Filter the joined rows based on the given condition.
val preJoinFilter =
- newPredicate(preJoinFilterExpr.getOrElse(Literal(true)), inputAttributes).eval _
+ Predicate.create(preJoinFilterExpr.getOrElse(Literal(true)), inputAttributes).eval _
private val joinStateManager = new SymmetricHashJoinStateManager(
joinSide, inputAttributes, joinKeys, stateInfo, storeConf, hadoopConfBcast.value.value,
@@ -428,16 +429,16 @@ case class StreamingSymmetricHashJoinExec(
case Some(JoinStateKeyWatermarkPredicate(expr)) =>
// inputSchema can be empty as expr should only have BoundReferences and does not require
// the schema to generated predicate. See [[StreamingSymmetricHashJoinHelper]].
- newPredicate(expr, Seq.empty).eval _
+ Predicate.create(expr, Seq.empty).eval _
case _ =>
- newPredicate(Literal(false), Seq.empty).eval _ // false = do not remove if no predicate
+ Predicate.create(Literal(false), Seq.empty).eval _ // false = do not remove if no predicate
}
private[this] val stateValueWatermarkPredicateFunc = stateWatermarkPredicate match {
case Some(JoinStateValueWatermarkPredicate(expr)) =>
- newPredicate(expr, inputAttributes).eval _
+ Predicate.create(expr, inputAttributes).eval _
case _ =>
- newPredicate(Literal(false), Seq.empty).eval _ // false = do not remove if no predicate
+ Predicate.create(Literal(false), Seq.empty).eval _ // false = do not remove if no predicate
}
private[this] var updatedStateRowsCount = 0
@@ -457,7 +458,7 @@ case class StreamingSymmetricHashJoinExec(
val nonLateRows =
WatermarkSupport.watermarkExpression(watermarkAttribute, eventTimeWatermark) match {
case Some(watermarkExpr) =>
- val predicate = newPredicate(watermarkExpr, inputAttributes)
+ val predicate = Predicate.create(watermarkExpr, inputAttributes)
inputIter.filter { row => !predicate.eval(row) }
case None =>
inputIter
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala
index 2dd287cb734bf..1a27fe61d9602 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala
@@ -24,6 +24,7 @@ import scala.concurrent.duration.Duration
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_DAY
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.unsafe.types.UTF8String
private object Triggers {
def validate(intervalMs: Long): Unit = {
@@ -31,7 +32,7 @@ private object Triggers {
}
def convert(interval: String): Long = {
- val cal = IntervalUtils.fromString(interval)
+ val cal = IntervalUtils.stringToInterval(UTF8String.fromString(interval))
if (cal.months != 0) {
throw new IllegalArgumentException(s"Doesn't support month or year interval: $interval")
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
index d4e522562e914..f1898ad3f27ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
@@ -24,6 +24,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.connector.write.PhysicalWriteInfoImpl
import org.apache.spark.sql.connector.write.streaming.StreamingWrite
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.streaming.StreamExecution
@@ -38,8 +39,10 @@ case class WriteToContinuousDataSourceExec(write: StreamingWrite, query: SparkPl
override def output: Seq[Attribute] = Nil
override protected def doExecute(): RDD[InternalRow] = {
- val writerFactory = write.createStreamingWriterFactory()
- val rdd = new ContinuousWriteRDD(query.execute(), writerFactory)
+ val queryRdd = query.execute()
+ val writerFactory = write.createStreamingWriterFactory(
+ PhysicalWriteInfoImpl(queryRdd.getNumPartitions))
+ val rdd = new ContinuousWriteRDD(queryRdd, writerFactory)
logInfo(s"Start processing data source write support: $write. " +
s"The input RDD has ${rdd.partitions.length} partitions.")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWrite.scala
index 6afb811a4d998..ad5c7cf24caf7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWrite.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWrite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.sources
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.connector.write.WriterCommitMessage
+import org.apache.spark.sql.connector.write.{PhysicalWriteInfo, WriterCommitMessage}
import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -38,7 +38,8 @@ class ConsoleWrite(schema: StructType, options: CaseInsensitiveStringMap)
assert(SparkSession.getActiveSession.isDefined)
protected val spark = SparkSession.getActiveSession.get
- def createStreamingWriterFactory(): StreamingDataWriterFactory = PackedRowWriterFactory
+ def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory =
+ PackedRowWriterFactory
override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
// We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala
index bae7fa7d07356..53d4bca1a5f7e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability}
-import org.apache.spark.sql.connector.write.{DataWriter, SupportsTruncate, WriteBuilder, WriterCommitMessage}
+import org.apache.spark.sql.connector.write.{DataWriter, PhysicalWriteInfo, SupportsTruncate, WriteBuilder, WriterCommitMessage}
import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
import org.apache.spark.sql.execution.python.PythonForeachWriter
import org.apache.spark.sql.types.StructType
@@ -72,7 +72,8 @@ case class ForeachWriterTable[T](
override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
- override def createStreamingWriterFactory(): StreamingDataWriterFactory = {
+ override def createStreamingWriterFactory(
+ info: PhysicalWriteInfo): StreamingDataWriterFactory = {
val rowConverter: InternalRow => T = converter match {
case Left(enc) =>
val boundEnc = enc.resolveAndBind(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala
index 5f12832cd2550..c2adc1dd6742a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.streaming.sources
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, WriterCommitMessage}
+import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, PhysicalWriteInfo, WriterCommitMessage}
import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
/**
@@ -36,8 +36,8 @@ class MicroBatchWrite(eppchId: Long, val writeSupport: StreamingWrite) extends B
writeSupport.abort(eppchId, messages)
}
- override def createBatchWriterFactory(): DataWriterFactory = {
- new MicroBatchWriterFactory(eppchId, writeSupport.createStreamingWriterFactory())
+ override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = {
+ new MicroBatchWriterFactory(eppchId, writeSupport.createStreamingWriterFactory(info))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala
index 51ab5ce3578af..a976876b4d8e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability}
-import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory, SupportsTruncate, WriteBuilder, WriterCommitMessage}
+import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory, PhysicalWriteInfo, SupportsTruncate, WriteBuilder, WriterCommitMessage}
import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.types.StructType
@@ -140,7 +140,7 @@ class MemoryStreamingWrite(
val sink: MemorySink, schema: StructType, needTruncate: Boolean)
extends StreamingWrite {
- override def createStreamingWriterFactory: MemoryWriterFactory = {
+ override def createStreamingWriterFactory(info: PhysicalWriteInfo): MemoryWriterFactory = {
MemoryWriterFactory(schema)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index d689a6f3c9819..01b309c3cf345 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -26,7 +26,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
@@ -156,17 +156,17 @@ trait WatermarkSupport extends UnaryExecNode {
}
/** Predicate based on keys that matches data older than the watermark */
- lazy val watermarkPredicateForKeys: Option[Predicate] = watermarkExpression.flatMap { e =>
+ lazy val watermarkPredicateForKeys: Option[BasePredicate] = watermarkExpression.flatMap { e =>
if (keyExpressions.exists(_.metadata.contains(EventTimeWatermark.delayKey))) {
- Some(newPredicate(e, keyExpressions))
+ Some(Predicate.create(e, keyExpressions))
} else {
None
}
}
/** Predicate based on the child output that matches data older than the watermark. */
- lazy val watermarkPredicateForData: Option[Predicate] =
- watermarkExpression.map(newPredicate(_, child.output))
+ lazy val watermarkPredicateForData: Option[BasePredicate] =
+ watermarkExpression.map(Predicate.create(_, child.output))
protected def removeKeysOlderThanWatermark(store: StateStore): Unit = {
if (watermarkPredicateForKeys.nonEmpty) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index f898236c537a8..bd14be702a407 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -78,7 +78,7 @@ object SparkPlanGraph {
subgraph: SparkPlanGraphCluster,
exchanges: mutable.HashMap[SparkPlanInfo, SparkPlanGraphNode]): Unit = {
planInfo.nodeName match {
- case "WholeStageCodegen" =>
+ case name if name.startsWith("WholeStageCodegen") =>
val metrics = planInfo.metrics.map { metric =>
SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
}
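
Note: whole-stage-codegen nodes now carry a codegen stage id in their name (e.g. "WholeStageCodegen (1)"), so the SQL UI has to match on the prefix rather than the exact string. Minimal illustration of the new match:

    val nodeNames = Seq("WholeStageCodegen (1)", "WholeStageCodegen (2)", "Filter")
    nodeNames.collect {
      case name if name.startsWith("WholeStageCodegen") => name
    }
    // Seq("WholeStageCodegen (1)", "WholeStageCodegen (2)")
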
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala
index dcb86f48bdf32..e8248b7028757 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala
@@ -73,7 +73,7 @@ abstract class WindowExecBase(
RowBoundOrdering(offset)
case (RangeFrame, CurrentRow) =>
- val ordering = newOrdering(orderSpec, child.output)
+ val ordering = RowOrdering.create(orderSpec, child.output)
RangeBoundOrdering(ordering, IdentityProjection, IdentityProjection)
case (RangeFrame, offset: Expression) if orderSpec.size == 1 =>
@@ -82,7 +82,7 @@ abstract class WindowExecBase(
val expr = sortExpr.child
// Create the projection which returns the current 'value'.
- val current = newMutableProjection(expr :: Nil, child.output)
+ val current = MutableProjection.create(expr :: Nil, child.output)
// Flip the sign of the offset when processing the order is descending
val boundOffset = sortExpr.direction match {
@@ -97,13 +97,13 @@ abstract class WindowExecBase(
TimeAdd(expr, boundOffset, Some(timeZone))
case (a, b) if a == b => Add(expr, boundOffset)
}
- val bound = newMutableProjection(boundExpr :: Nil, child.output)
+ val bound = MutableProjection.create(boundExpr :: Nil, child.output)
// Construct the ordering. This is used to compare the result of current value projection
// to the result of bound value projection. This is done manually because we want to use
// Code Generation (if it is enabled).
val boundSortExprs = sortExpr.copy(BoundReference(0, expr.dataType, expr.nullable)) :: Nil
- val ordering = newOrdering(boundSortExprs, Nil)
+ val ordering = RowOrdering.create(boundSortExprs, Nil)
RangeBoundOrdering(ordering, current, bound)
case (RangeFrame, _) =>
@@ -167,7 +167,7 @@ abstract class WindowExecBase(
ordinal,
child.output,
(expressions, schema) =>
- newMutableProjection(expressions, schema, subexpressionEliminationEnabled))
+ MutableProjection.create(expressions, schema))
}
// Create the factory
@@ -182,7 +182,7 @@ abstract class WindowExecBase(
functions.map(_.asInstanceOf[OffsetWindowFunction]),
child.output,
(expressions, schema) =>
- newMutableProjection(expressions, schema, subexpressionEliminationEnabled),
+ MutableProjection.create(expressions, schema),
offset)
// Entire Partition Frame.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 2ba34647dbca8..72e9e337c4258 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3335,7 +3335,7 @@ object functions {
* @group collection_funcs
* @since 2.4.0
*/
- def array_sort(e: Column): Column = withExpr { ArraySort(e.expr) }
+ def array_sort(e: Column): Column = withExpr { new ArraySort(e.expr) }
/**
* Remove all elements that equal to element from the given array.
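
Note: ArraySort appears to take an optional comparator now, so the single-argument form goes through an auxiliary constructor, hence `new ArraySort(e.expr)`; the public functions.array_sort(Column) behaviour is unchanged. A quick check, assuming a local session:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{array_sort, col}

    val spark = SparkSession.builder().master("local[1]").appName("array_sort-check").getOrCreate()
    import spark.implicits._
    // Still sorts ascending with null elements placed last: Seq(3, 1, 2) => [1, 2, 3]
    Seq(Seq(3, 1, 2)).toDF("xs").select(array_sort(col("xs"))).show()
    spark.stop()
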
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
index b810bedac471d..de3805e105802 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
@@ -36,7 +36,6 @@ import org.apache.spark.sql.execution.CacheManager
import org.apache.spark.sql.execution.streaming.StreamExecution
import org.apache.spark.sql.execution.ui.{SQLAppStatusListener, SQLAppStatusStore, SQLTab}
import org.apache.spark.sql.internal.StaticSQLConf._
-import org.apache.spark.sql.streaming.StreamingQuery
import org.apache.spark.status.ElementTrackingStore
import org.apache.spark.util.Utils
@@ -52,6 +51,8 @@ private[sql] class SharedState(
initialConfigs: scala.collection.Map[String, String])
extends Logging {
+ SharedState.setFsUrlStreamHandlerFactory(sparkContext.conf)
+
// Load hive-site.xml into hadoopConf and determine the warehouse path we want to use, based on
// the config from both hive and Spark SQL. Finally set the warehouse config value to sparkConf.
val warehousePath: String = {
@@ -191,11 +192,23 @@ private[sql] class SharedState(
}
object SharedState extends Logging {
- try {
- URL.setURLStreamHandlerFactory(new FsUrlStreamHandlerFactory())
- } catch {
- case e: Error =>
- logWarning("URL.setURLStreamHandlerFactory failed to set FsUrlStreamHandlerFactory")
+ @volatile private var fsUrlStreamHandlerFactoryInitialized = false
+
+ private def setFsUrlStreamHandlerFactory(conf: SparkConf): Unit = {
+ if (!fsUrlStreamHandlerFactoryInitialized &&
+ conf.get(DEFAULT_URL_STREAM_HANDLER_FACTORY_ENABLED)) {
+ synchronized {
+ if (!fsUrlStreamHandlerFactoryInitialized) {
+ try {
+ URL.setURLStreamHandlerFactory(new FsUrlStreamHandlerFactory())
+ fsUrlStreamHandlerFactoryInitialized = true
+ } catch {
+ case NonFatal(_) =>
+ logWarning("URL.setURLStreamHandlerFactory failed to set FsUrlStreamHandlerFactory")
+ }
+ }
+ }
+ }
}
private val HIVE_EXTERNAL_CATALOG_CLASS_NAME = "org.apache.spark.sql.hive.HiveExternalCatalog"
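
Note: the eager static initializer is replaced by an opt-out, at-most-once registration, since URL.setURLStreamHandlerFactory may only ever be called once per JVM. The guard is ordinary double-checked locking; a generic sketch of the pattern, stripped of the Spark specifics:

    // Run `body` at most once, and only when the corresponding config flag is enabled.
    object RunOnce {
      @volatile private var initialized = false
      def apply(enabled: Boolean)(body: => Unit): Unit = {
        if (!initialized && enabled) {
          synchronized {
            if (!initialized) {
              body
              initialized = true
            }
          }
        }
      }
    }
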
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java
index 92dabc79d2bff..4b23615275871 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java
@@ -33,6 +33,7 @@
import org.apache.spark.util.collection.unsafe.sort.*;
import org.junit.After;
+import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
@@ -81,14 +82,14 @@ private void insertRow(UnsafeRow row) {
int recordLength = row.getSizeInBytes();
Object baseObject = dataPage.getBaseObject();
- assert(pageCursor + recordLength <= dataPage.getBaseOffset() + dataPage.size());
+ Assert.assertTrue(pageCursor + recordLength <= dataPage.getBaseOffset() + dataPage.size());
long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, pageCursor);
UnsafeAlignedOffset.putSize(baseObject, pageCursor, recordLength);
pageCursor += uaoSize;
Platform.copyMemory(recordBase, recordOffset, baseObject, pageCursor, recordLength);
pageCursor += recordLength;
- assert(pos < 2);
+ Assert.assertTrue(pos < 2);
array.set(pos, recordAddress);
pos++;
}
@@ -141,8 +142,8 @@ public void testBinaryComparatorForSingleColumnRow() throws Exception {
insertRow(row1);
insertRow(row2);
- assert(compare(0, 0) == 0);
- assert(compare(0, 1) < 0);
+ Assert.assertEquals(0, compare(0, 0));
+ Assert.assertTrue(compare(0, 1) < 0);
}
@Test
@@ -166,8 +167,8 @@ public void testBinaryComparatorForMultipleColumnRow() throws Exception {
insertRow(row1);
insertRow(row2);
- assert(compare(0, 0) == 0);
- assert(compare(0, 1) < 0);
+ Assert.assertEquals(0, compare(0, 0));
+ Assert.assertTrue(compare(0, 1) < 0);
}
@Test
@@ -193,8 +194,8 @@ public void testBinaryComparatorForArrayColumn() throws Exception {
insertRow(row1);
insertRow(row2);
- assert(compare(0, 0) == 0);
- assert(compare(0, 1) > 0);
+ Assert.assertEquals(0, compare(0, 0));
+ Assert.assertTrue(compare(0, 1) > 0);
}
@Test
@@ -226,8 +227,8 @@ public void testBinaryComparatorForMixedColumns() throws Exception {
insertRow(row1);
insertRow(row2);
- assert(compare(0, 0) == 0);
- assert(compare(0, 1) > 0);
+ Assert.assertEquals(0, compare(0, 0));
+ Assert.assertTrue(compare(0, 1) > 0);
}
@Test
@@ -252,8 +253,8 @@ public void testBinaryComparatorForNullColumns() throws Exception {
insertRow(row1);
insertRow(row2);
- assert(compare(0, 0) == 0);
- assert(compare(0, 1) > 0);
+ Assert.assertEquals(0, compare(0, 0));
+ Assert.assertTrue(compare(0, 1) > 0);
}
@Test
@@ -273,7 +274,7 @@ public void testBinaryComparatorWhenSubtractionIsDivisibleByMaxIntValue() throws
insertRow(row1);
insertRow(row2);
- assert(compare(0, 1) < 0);
+ Assert.assertTrue(compare(0, 1) > 0);
}
@Test
@@ -293,7 +294,7 @@ public void testBinaryComparatorWhenSubtractionCanOverflowLongValue() throws Exc
insertRow(row1);
insertRow(row2);
- assert(compare(0, 1) < 0);
+ Assert.assertTrue(compare(0, 1) < 0);
}
@Test
@@ -319,6 +320,50 @@ public void testBinaryComparatorWhenOnlyTheLastColumnDiffers() throws Exception
insertRow(row1);
insertRow(row2);
- assert(compare(0, 1) < 0);
+ Assert.assertTrue(compare(0, 1) < 0);
+ }
+
+ @Test
+ public void testCompareLongsAsLittleEndian() {
+ long arrayOffset = 12;
+
+ long[] arr1 = new long[2];
+ Platform.putLong(arr1, arrayOffset, 0x0100000000000000L);
+ long[] arr2 = new long[2];
+ Platform.putLong(arr2, arrayOffset + 4, 0x0000000000000001L);
+ // leftBaseOffset is not aligned while rightBaseOffset is aligned,
+ // it will start by comparing long
+ int result1 = binaryComparator.compare(arr1, arrayOffset, 8, arr2, arrayOffset + 4, 8);
+
+ long[] arr3 = new long[2];
+ Platform.putLong(arr3, arrayOffset, 0x0100000000000000L);
+ long[] arr4 = new long[2];
+ Platform.putLong(arr4, arrayOffset, 0x0000000000000001L);
+ // both left and right offset is not aligned, it will start with byte-by-byte comparison
+ int result2 = binaryComparator.compare(arr3, arrayOffset, 8, arr4, arrayOffset, 8);
+
+ Assert.assertEquals(result1, result2);
+ }
+
+ @Test
+ public void testCompareLongsAsUnsigned() {
+ long arrayOffset = 12;
+
+ long[] arr1 = new long[2];
+ Platform.putLong(arr1, arrayOffset + 4, 0xa000000000000000L);
+ long[] arr2 = new long[2];
+ Platform.putLong(arr2, arrayOffset + 4, 0x0000000000000000L);
+ // both leftBaseOffset and rightBaseOffset are aligned, so it will start by comparing long
+ int result1 = binaryComparator.compare(arr1, arrayOffset + 4, 8, arr2, arrayOffset + 4, 8);
+
+ long[] arr3 = new long[2];
+ Platform.putLong(arr3, arrayOffset, 0xa000000000000000L);
+ long[] arr4 = new long[2];
+ Platform.putLong(arr4, arrayOffset, 0x0000000000000000L);
+ // both leftBaseOffset and rightBaseOffset are not aligned,
+ // so it will start with byte-by-byte comparison
+ int result2 = binaryComparator.compare(arr3, arrayOffset, 8, arr4, arrayOffset, 8);
+
+ Assert.assertEquals(result1, result2);
}
}
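
Note on the test changes above: a bare Java `assert` statement only runs when the JVM is started with -ea, so the old checks could silently pass; the JUnit Assert calls always execute. Equivalent call shape, shown in Scala for consistency with the other sketches:

    import org.junit.Assert

    Assert.assertEquals(0, java.lang.Integer.compare(3, 3))
    Assert.assertTrue(java.lang.Integer.compare(1, 2) < 0)
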
diff --git a/sql/core/src/test/resources/sql-tests/inputs/ansi/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/ansi/higher-order-functions.sql
index 4068a27fcb2a7..1e2424fe47cad 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/ansi/higher-order-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/ansi/higher-order-functions.sql
@@ -1 +1 @@
---import higher-order-functions.sql
+--IMPORT higher-order-functions.sql
diff --git a/sql/core/src/test/resources/sql-tests/inputs/ansi/interval.sql b/sql/core/src/test/resources/sql-tests/inputs/ansi/interval.sql
index 215ee7c074fa6..087914eebb077 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/ansi/interval.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/ansi/interval.sql
@@ -1,4 +1,4 @@
---import interval.sql
+--IMPORT interval.sql
-- the `interval` keyword can be omitted with ansi mode
select 1 year 2 days;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/ansi/literals.sql b/sql/core/src/test/resources/sql-tests/inputs/ansi/literals.sql
index 170690ea699c0..698e8fa886307 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/ansi/literals.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/ansi/literals.sql
@@ -1,2 +1,2 @@
--- malformed interval literal with ansi mode
---import literals.sql
+--IMPORT literals.sql
diff --git a/sql/core/src/test/resources/sql-tests/inputs/cast.sql b/sql/core/src/test/resources/sql-tests/inputs/cast.sql
index 8a035f594be54..3c1702e6f837e 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/cast.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/cast.sql
@@ -60,3 +60,13 @@ DESC FUNCTION EXTENDED boolean;
-- cast string to interval and interval to string
SELECT CAST('interval 3 month 1 hour' AS interval);
SELECT CAST(interval 3 month 1 hour AS string);
+
+-- trim string before cast to numeric
+select cast(' 1' as tinyint);
+select cast(' 1\t' as tinyint);
+select cast(' 1' as smallint);
+select cast(' 1' as INT);
+select cast(' 1' as bigint);
+select cast(' 1' as float);
+select cast(' 1 ' as DOUBLE);
+select cast('1.0 ' as DEC);
\ No newline at end of file
diff --git a/sql/core/src/test/resources/sql-tests/inputs/comparator.sql b/sql/core/src/test/resources/sql-tests/inputs/comparator.sql
index 3e2447723e576..70af4f75ac431 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/comparator.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/comparator.sql
@@ -1,3 +1,13 @@
-- binary type
select x'00' < x'0f';
select x'00' < x'ff';
+
+-- trim string to numeric
+select '1 ' = 1Y;
+select '\t1 ' = 1Y;
+select '1 ' = 1S;
+select '1 ' = 1;
+select ' 1' = 1L;
+select ' 1' = cast(1.0 as float);
+select ' 1.0 ' = 1.0D;
+select ' 1.0 ' = 1.0BD;
\ No newline at end of file
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
index d602f63e529d1..fedf03d774e42 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
@@ -1,3 +1,8 @@
+-- Test aggregate operator with codegen on and off.
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=true
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=NO_CODEGEN
+
-- Test data.
CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES
(1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2), (null, 1), (3, null), (null, null)
diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
index 7665346f86ba8..cfa06aea82b04 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
@@ -1,3 +1,8 @@
+-- Test higher order functions with codegen on and off.
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=true
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=NO_CODEGEN
+
create or replace temporary view nested as values
(1, array(32, 97), array(array(12, 99), array(123, 42), array(1))),
(2, array(77, -76), array(array(6, 96, 65), array(-1, -2))),
diff --git a/sql/core/src/test/resources/sql-tests/inputs/inner-join.sql b/sql/core/src/test/resources/sql-tests/inputs/inner-join.sql
index e87c660cb1fe6..5623161839331 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/inner-join.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/inner-join.sql
@@ -1,7 +1,14 @@
--- List of configuration the test suite is run against:
---SET spark.sql.autoBroadcastJoinThreshold=10485760
---SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true
---SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false
+-- There are 2 dimensions we want to test
+-- 1. run with broadcast hash join, sort merge join or shuffle hash join.
+-- 2. run with whole-stage-codegen, operator codegen or no codegen.
+
+--CONFIG_DIM1 spark.sql.autoBroadcastJoinThreshold=10485760
+--CONFIG_DIM1 spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true
+--CONFIG_DIM1 spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false
+
+--CONFIG_DIM2 spark.sql.codegen.wholeStage=true
+--CONFIG_DIM2 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+--CONFIG_DIM2 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=NO_CODEGEN
CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1) AS GROUPING(a);
CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1) AS GROUPING(a);
diff --git a/sql/core/src/test/resources/sql-tests/inputs/interval-display-iso_8601.sql b/sql/core/src/test/resources/sql-tests/inputs/interval-display-iso_8601.sql
new file mode 100644
index 0000000000000..3b63c715a6aa1
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/interval-display-iso_8601.sql
@@ -0,0 +1,3 @@
+-- tests for interval output style with iso_8601 format
+--SET spark.sql.intervalOutputStyle = ISO_8601
+--IMPORT interval-display.sql
diff --git a/sql/core/src/test/resources/sql-tests/inputs/interval-display-sql_standard.sql b/sql/core/src/test/resources/sql-tests/inputs/interval-display-sql_standard.sql
new file mode 100644
index 0000000000000..d96865b160bb6
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/interval-display-sql_standard.sql
@@ -0,0 +1,3 @@
+-- tests for interval output style with sql standard format
+--SET spark.sql.intervalOutputStyle = SQL_STANDARD
+--IMPORT interval-display.sql
diff --git a/sql/core/src/test/resources/sql-tests/inputs/interval-display.sql b/sql/core/src/test/resources/sql-tests/inputs/interval-display.sql
new file mode 100644
index 0000000000000..ae19f1b6374ba
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/interval-display.sql
@@ -0,0 +1,14 @@
+-- tests for interval output style
+
+SELECT
+ cast(null as interval), -- null
+ interval '0 day', -- 0
+ interval '1 year', -- year only
+ interval '1 month', -- month only
+ interval '1 year 2 month', -- year month only
+ interval '1 day -1 hours',
+ interval '-1 day -1 hours',
+ interval '-1 day 1 hours',
+ interval '-1 days +1 hours',
+ interval '1 years 2 months -3 days 4 hours 5 minutes 6.789 seconds',
+ - interval '1 years 2 months -3 days 4 hours 5 minutes 6.789 seconds';
diff --git a/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql b/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql
index 2e6a5f362a8fa..8afa3270f4de4 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/join-empty-relation.sql
@@ -1,8 +1,3 @@
--- List of configuration the test suite is run against:
---SET spark.sql.autoBroadcastJoinThreshold=10485760
---SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true
---SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false
-
CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1) AS GROUPING(a);
CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1) AS GROUPING(a);
diff --git a/sql/core/src/test/resources/sql-tests/inputs/literals.sql b/sql/core/src/test/resources/sql-tests/inputs/literals.sql
index d1dff7bc94686..61b02d86bb51b 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/literals.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/literals.sql
@@ -107,7 +107,10 @@ select integer '2147483648';
-- awareness of the negative/positive sign before type
select -integer '7';
+select +integer '7';
+select +date '1999-01-01';
+select +timestamp '1999-01-01';
+-- can't negate date/timestamp/binary
select -date '1999-01-01';
select -timestamp '1999-01-01';
select -x'2379ACFe';
-select +integer '7';
diff --git a/sql/core/src/test/resources/sql-tests/inputs/misc-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/misc-functions.sql
new file mode 100644
index 0000000000000..95f71925e9294
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/misc-functions.sql
@@ -0,0 +1,10 @@
+-- test for misc functions
+
+-- typeof
+select typeof(null);
+select typeof(true);
+select typeof(1Y), typeof(1S), typeof(1), typeof(1L);
+select typeof(cast(1.0 as float)), typeof(1.0D), typeof(1.2);
+select typeof(date '1986-05-23'), typeof(timestamp '1986-05-23'), typeof(interval '23 days');
+select typeof(x'ABCD'), typeof('SPARK');
+select typeof(array(1, 2)), typeof(map(1, 2)), typeof(named_struct('a', 1, 'b', 'spark'));
diff --git a/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql b/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql
index e0abeda3eb44f..71a50157b766c 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql
@@ -1,8 +1,3 @@
--- List of configuration the test suite is run against:
---SET spark.sql.autoBroadcastJoinThreshold=10485760
---SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true
---SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false
-
create temporary view nt1 as select * from values
("one", 1),
("two", 2),
diff --git a/sql/core/src/test/resources/sql-tests/inputs/order-by-nulls-ordering.sql b/sql/core/src/test/resources/sql-tests/inputs/order-by-nulls-ordering.sql
index f7637b444b9fe..ad3977465c835 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/order-by-nulls-ordering.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/order-by-nulls-ordering.sql
@@ -1,3 +1,8 @@
+-- Test sort operator with codegen on and off.
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=true
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=NO_CODEGEN
+
-- Q1. testing window functions with order by
create table spark_10747(col1 int, col2 int, col3 int) using parquet;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql b/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql
index ce09c21568f13..ceb438ec34b2d 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql
@@ -1,7 +1,14 @@
--- List of configuration the test suite is run against:
---SET spark.sql.autoBroadcastJoinThreshold=10485760
---SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true
---SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false
+-- There are 2 dimensions we want to test
+-- 1. run with broadcast hash join, sort merge join or shuffle hash join.
+-- 2. run with whole-stage-codegen, operator codegen or no codegen.
+
+--CONFIG_DIM1 spark.sql.autoBroadcastJoinThreshold=10485760
+--CONFIG_DIM1 spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true
+--CONFIG_DIM1 spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false
+
+--CONFIG_DIM2 spark.sql.codegen.wholeStage=true
+--CONFIG_DIM2 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+--CONFIG_DIM2 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=NO_CODEGEN
-- SPARK-17099: Incorrect result when HAVING clause is added to group by query
CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES
@@ -29,9 +36,6 @@ CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (97) as t1(int_col1)
CREATE OR REPLACE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (0) as t2(int_col1);
--- Set the cross join enabled flag for the LEFT JOIN test since there's no join condition.
--- Ultimately the join should be optimized away.
-set spark.sql.crossJoin.enabled = true;
SELECT *
FROM (
SELECT
@@ -39,6 +43,3 @@ SELECT
FROM t1
LEFT JOIN t2 ON false
) t where (t.int_col) is not null;
-set spark.sql.crossJoin.enabled = false;
-
-
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part1.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part1.sql
index 5d54be9341148..63f80bd2efa73 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part1.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part1.sql
@@ -8,6 +8,11 @@
-- avoid bit-exact output here because operations may not be bit-exact.
-- SET extra_float_digits = 0;
+-- Test aggregate operator with codegen on and off.
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=true
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=NO_CODEGEN
+
SELECT avg(four) AS avg_1 FROM onek;
SELECT avg(a) AS avg_32 FROM aggtest WHERE a < 100;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part2.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part2.sql
index ba91366014e16..a8af1db77563c 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part2.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part2.sql
@@ -5,6 +5,11 @@
-- AGGREGATES [Part 2]
-- https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/aggregates.sql#L145-L350
+-- Test aggregate operator with codegen on and off.
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=true
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=NO_CODEGEN
+
create temporary view int4_tbl as select * from values
(0),
(123456),
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql
index 78fdbf6ae6cd2..6f5e549644bbf 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql
@@ -5,6 +5,11 @@
-- AGGREGATES [Part 3]
-- https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/aggregates.sql#L352-L605
+-- Test aggregate operator with codegen on and off.
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=true
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=NO_CODEGEN
+
-- [SPARK-28865] Table inheritance
-- try it on an inheritance tree
-- create table minmaxtest(f1 int);
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part4.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part4.sql
index 6fa2306cf1475..0d255bed24e9c 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part4.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part4.sql
@@ -5,6 +5,11 @@
-- AGGREGATES [Part 4]
-- https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/aggregates.sql#L607-L997
+-- Test aggregate operator with codegen on and off.
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=true
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=NO_CODEGEN
+
-- [SPARK-27980] Ordered-Set Aggregate Functions
-- ordered-set aggregates
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/date.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/date.sql
index d3cd46e4e6b89..0bab2f884d976 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/date.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/date.sql
@@ -7,23 +7,25 @@
CREATE TABLE DATE_TBL (f1 date) USING parquet;
-INSERT INTO DATE_TBL VALUES ('1957-04-09');
-INSERT INTO DATE_TBL VALUES ('1957-06-13');
-INSERT INTO DATE_TBL VALUES ('1996-02-28');
-INSERT INTO DATE_TBL VALUES ('1996-02-29');
-INSERT INTO DATE_TBL VALUES ('1996-03-01');
-INSERT INTO DATE_TBL VALUES ('1996-03-02');
-INSERT INTO DATE_TBL VALUES ('1997-02-28');
+-- PostgreSQL implicitly casts string literals to data with date types, but
+-- Spark does not support that kind of implicit casts.
+INSERT INTO DATE_TBL VALUES (date('1957-04-09'));
+INSERT INTO DATE_TBL VALUES (date('1957-06-13'));
+INSERT INTO DATE_TBL VALUES (date('1996-02-28'));
+INSERT INTO DATE_TBL VALUES (date('1996-02-29'));
+INSERT INTO DATE_TBL VALUES (date('1996-03-01'));
+INSERT INTO DATE_TBL VALUES (date('1996-03-02'));
+INSERT INTO DATE_TBL VALUES (date('1997-02-28'));
-- [SPARK-27923] Skip invalid date: 1997-02-29
--- INSERT INTO DATE_TBL VALUES ('1997-02-29');
-INSERT INTO DATE_TBL VALUES ('1997-03-01');
-INSERT INTO DATE_TBL VALUES ('1997-03-02');
-INSERT INTO DATE_TBL VALUES ('2000-04-01');
-INSERT INTO DATE_TBL VALUES ('2000-04-02');
-INSERT INTO DATE_TBL VALUES ('2000-04-03');
-INSERT INTO DATE_TBL VALUES ('2038-04-08');
-INSERT INTO DATE_TBL VALUES ('2039-04-09');
-INSERT INTO DATE_TBL VALUES ('2040-04-10');
+-- INSERT INTO DATE_TBL VALUES ('1997-02-29'));
+INSERT INTO DATE_TBL VALUES (date('1997-03-01'));
+INSERT INTO DATE_TBL VALUES (date('1997-03-02'));
+INSERT INTO DATE_TBL VALUES (date('2000-04-01'));
+INSERT INTO DATE_TBL VALUES (date('2000-04-02'));
+INSERT INTO DATE_TBL VALUES (date('2000-04-03'));
+INSERT INTO DATE_TBL VALUES (date('2038-04-08'));
+INSERT INTO DATE_TBL VALUES (date('2039-04-09'));
+INSERT INTO DATE_TBL VALUES (date('2040-04-10'));
SELECT f1 AS `Fifteen` FROM DATE_TBL;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/float4.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/float4.sql
index 058467695a608..2989569e219ff 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/float4.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/float4.sql
@@ -7,11 +7,13 @@
CREATE TABLE FLOAT4_TBL (f1 float) USING parquet;
-INSERT INTO FLOAT4_TBL VALUES (' 0.0');
-INSERT INTO FLOAT4_TBL VALUES ('1004.30 ');
-INSERT INTO FLOAT4_TBL VALUES (' -34.84 ');
-INSERT INTO FLOAT4_TBL VALUES ('1.2345678901234e+20');
-INSERT INTO FLOAT4_TBL VALUES ('1.2345678901234e-20');
+-- PostgreSQL implicitly casts string literals to data with floating point types, but
+-- Spark does not support that kind of implicit casts.
+INSERT INTO FLOAT4_TBL VALUES (float(' 0.0'));
+INSERT INTO FLOAT4_TBL VALUES (float('1004.30 '));
+INSERT INTO FLOAT4_TBL VALUES (float(' -34.84 '));
+INSERT INTO FLOAT4_TBL VALUES (float('1.2345678901234e+20'));
+INSERT INTO FLOAT4_TBL VALUES (float('1.2345678901234e-20'));
-- [SPARK-28024] Incorrect numeric values when out of range
-- test for over and under flow
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/float8.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/float8.sql
index 957dabdebab4e..932cdb95fcf3a 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/float8.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/float8.sql
@@ -7,11 +7,13 @@
CREATE TABLE FLOAT8_TBL(f1 double) USING parquet;
-INSERT INTO FLOAT8_TBL VALUES (' 0.0 ');
-INSERT INTO FLOAT8_TBL VALUES ('1004.30 ');
-INSERT INTO FLOAT8_TBL VALUES (' -34.84');
-INSERT INTO FLOAT8_TBL VALUES ('1.2345678901234e+200');
-INSERT INTO FLOAT8_TBL VALUES ('1.2345678901234e-200');
+-- PostgreSQL implicitly casts string literals to data with floating point types, but
+-- Spark does not support that kind of implicit casts.
+INSERT INTO FLOAT8_TBL VALUES (double(' 0.0 '));
+INSERT INTO FLOAT8_TBL VALUES (double('1004.30 '));
+INSERT INTO FLOAT8_TBL VALUES (double(' -34.84'));
+INSERT INTO FLOAT8_TBL VALUES (double('1.2345678901234e+200'));
+INSERT INTO FLOAT8_TBL VALUES (double('1.2345678901234e-200'));
-- [SPARK-28024] Incorrect numeric values when out of range
-- test for underflow and overflow handling
@@ -227,15 +229,17 @@ SELECT atanh(double('NaN'));
TRUNCATE TABLE FLOAT8_TBL;
-INSERT INTO FLOAT8_TBL VALUES ('0.0');
+-- PostgreSQL implicitly casts string literals to data with floating point types, but
+-- Spark does not support that kind of implicit casts.
+INSERT INTO FLOAT8_TBL VALUES (double('0.0'));
-INSERT INTO FLOAT8_TBL VALUES ('-34.84');
+INSERT INTO FLOAT8_TBL VALUES (double('-34.84'));
-INSERT INTO FLOAT8_TBL VALUES ('-1004.30');
+INSERT INTO FLOAT8_TBL VALUES (double('-1004.30'));
-INSERT INTO FLOAT8_TBL VALUES ('-1.2345678901234e+200');
+INSERT INTO FLOAT8_TBL VALUES (double('-1.2345678901234e+200'));
-INSERT INTO FLOAT8_TBL VALUES ('-1.2345678901234e-200');
+INSERT INTO FLOAT8_TBL VALUES (double('-1.2345678901234e-200'));
SELECT '' AS five, * FROM FLOAT8_TBL;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int2.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int2.sql
index f64ec5d75afcf..07f5976ca6d2f 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int2.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int2.sql
@@ -8,19 +8,23 @@
CREATE TABLE INT2_TBL(f1 smallint) USING parquet;
-- [SPARK-28023] Trim the string when cast string type to other types
-INSERT INTO INT2_TBL VALUES (trim('0 '));
+-- PostgreSQL implicitly casts string literals to data with integral types, but
+-- Spark does not support that kind of implicit casts.
+INSERT INTO INT2_TBL VALUES (smallint(trim('0 ')));
-INSERT INTO INT2_TBL VALUES (trim(' 1234 '));
+INSERT INTO INT2_TBL VALUES (smallint(trim(' 1234 ')));
-INSERT INTO INT2_TBL VALUES (trim(' -1234'));
+INSERT INTO INT2_TBL VALUES (smallint(trim(' -1234')));
-- [SPARK-27923] Invalid input syntax for type short throws exception at PostgreSQL
-- INSERT INTO INT2_TBL VALUES ('34.5');
-- largest and smallest values
-INSERT INTO INT2_TBL VALUES ('32767');
+-- PostgreSQL implicitly casts string literals to data with integral types, but
+-- Spark does not support that kind of implicit casts.
+INSERT INTO INT2_TBL VALUES (smallint('32767'));
-INSERT INTO INT2_TBL VALUES ('-32767');
+INSERT INTO INT2_TBL VALUES (smallint('-32767'));
-- bad input values -- should give errors
-- INSERT INTO INT2_TBL VALUES ('100000');
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
index 1c2320ff7fad6..3a409eea34837 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
@@ -9,19 +9,23 @@
CREATE TABLE INT4_TBL(f1 int) USING parquet;
-- [SPARK-28023] Trim the string when cast string type to other types
-INSERT INTO INT4_TBL VALUES (trim(' 0 '));
+-- PostgreSQL implicitly casts string literals to data with integral types, but
+-- Spark does not support that kind of implicit casts.
+INSERT INTO INT4_TBL VALUES (int(trim(' 0 ')));
-INSERT INTO INT4_TBL VALUES (trim('123456 '));
+INSERT INTO INT4_TBL VALUES (int(trim('123456 ')));
-INSERT INTO INT4_TBL VALUES (trim(' -123456'));
+INSERT INTO INT4_TBL VALUES (int(trim(' -123456')));
-- [SPARK-27923] Invalid input syntax for integer: "34.5" at PostgreSQL
-- INSERT INTO INT4_TBL(f1) VALUES ('34.5');
-- largest and smallest values
-INSERT INTO INT4_TBL VALUES ('2147483647');
+-- PostgreSQL implicitly casts string literals to integral values, but
+-- Spark does not support that kind of implicit cast.
+INSERT INTO INT4_TBL VALUES (int('2147483647'));
-INSERT INTO INT4_TBL VALUES ('-2147483647');
+INSERT INTO INT4_TBL VALUES (int('-2147483647'));
-- [SPARK-27923] Spark SQL inserts these bad inputs as NULL
-- bad input values
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql
index d29bf3bfad4ca..5fea758e73084 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql
@@ -8,11 +8,13 @@
--
CREATE TABLE INT8_TBL(q1 bigint, q2 bigint) USING parquet;
-INSERT INTO INT8_TBL VALUES(trim(' 123 '),trim(' 456'));
-INSERT INTO INT8_TBL VALUES(trim('123 '),'4567890123456789');
-INSERT INTO INT8_TBL VALUES('4567890123456789','123');
-INSERT INTO INT8_TBL VALUES(+4567890123456789,'4567890123456789');
-INSERT INTO INT8_TBL VALUES('+4567890123456789','-4567890123456789');
+-- PostgreSQL implicitly casts string literals to integral values, but
+-- Spark does not support that kind of implicit cast.
+INSERT INTO INT8_TBL VALUES(bigint(trim(' 123 ')),bigint(trim(' 456')));
+INSERT INTO INT8_TBL VALUES(bigint(trim('123 ')),bigint('4567890123456789'));
+INSERT INTO INT8_TBL VALUES(bigint('4567890123456789'),bigint('123'));
+INSERT INTO INT8_TBL VALUES(+4567890123456789,bigint('4567890123456789'));
+INSERT INTO INT8_TBL VALUES(bigint('+4567890123456789'),bigint('-4567890123456789'));
-- [SPARK-27923] Spark SQL inserts these bad inputs as NULL
-- bad inputs
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/interval.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/interval.sql
index 01df2a3fd1b21..3b25ef7334c0a 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/interval.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/interval.sql
@@ -270,10 +270,12 @@ SELECT interval '1 2:03:04' minute to second;
-- test output of couple non-standard interval values in the sql style
-- [SPARK-29406] Interval output styles
-- SET IntervalStyle TO sql_standard;
--- SELECT interval '1 day -1 hours',
--- interval '-1 days +1 hours',
--- interval '1 years 2 months -3 days 4 hours 5 minutes 6.789 seconds',
--- - interval '1 years 2 months -3 days 4 hours 5 minutes 6.789 seconds';
+set spark.sql.intervalOutputStyle=SQL_STANDARD;
+SELECT interval '1 day -1 hours',
+ interval '-1 days +1 hours',
+ interval '1 years 2 months -3 days 4 hours 5 minutes 6.789 seconds',
+ - interval '1 years 2 months -3 days 4 hours 5 minutes 6.789 seconds';
+set spark.sql.intervalOutputStyle=MULTI_UNITS;
-- test outputting iso8601 intervals
-- [SPARK-29406] Interval output styles
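The un-commented block above scopes the session conf to one query block and then flips it back, so later queries in the file are still compared against the multi-units rendering (presumably the default, given the reset here). A hedged sketch of the pattern; the exact strings each style produces are not reproduced here:

set spark.sql.intervalOutputStyle=SQL_STANDARD;
SELECT interval '1 day -1 hours';  -- rendered in the SQL standard style
set spark.sql.intervalOutputStyle=MULTI_UNITS;
SELECT interval '1 day -1 hours';  -- back to the multi-units style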
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/join.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/join.sql
index 1ada723d6ae22..cc07b00cc3670 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/join.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/join.sql
@@ -7,10 +7,17 @@
-- https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/join.sql
--
--- List of configuration the test suite is run against:
---SET spark.sql.autoBroadcastJoinThreshold=10485760
---SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true
---SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false
+-- There are 2 dimensions we want to test
+-- 1. run with broadcast hash join, sort merge join or shuffle hash join.
+-- 2. run with whole-stage-codegen, operator codegen or no codegen.
+
+--CONFIG_DIM1 spark.sql.autoBroadcastJoinThreshold=10485760
+--CONFIG_DIM1 spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true
+--CONFIG_DIM1 spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false
+
+--CONFIG_DIM2 spark.sql.codegen.wholeStage=true
+--CONFIG_DIM2 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+--CONFIG_DIM2 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=NO_CODEGEN
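Each CONFIG_DIM group above lists alternative config sets for one dimension, and the test harness is expected to run the whole file once per combination across the groups (3 x 3 = 9 runs here). One such combination would be roughly equivalent to prefixing the file with:

-- illustrative only: one of the nine combinations described by the dimensions above
SET spark.sql.autoBroadcastJoinThreshold=-1;
SET spark.sql.join.preferSortMergeJoin=true;
SET spark.sql.codegen.wholeStage=false;
SET spark.sql.codegen.factoryMode=CODEGEN_ONLY;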
CREATE OR REPLACE TEMPORARY VIEW INT4_TBL AS SELECT * FROM
(VALUES (0), (123456), (-123456), (2147483647), (-2147483647))
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/numeric.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/numeric.sql
index c447a0dc2c7f2..dbdb2cace0e0c 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/numeric.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/numeric.sql
@@ -26,466 +26,508 @@ CREATE TABLE num_result (id1 int, id2 int, result decimal(38,10)) USING parquet;
-- ******************************
-- BEGIN TRANSACTION;
-INSERT INTO num_exp_add VALUES (0,0,'0');
-INSERT INTO num_exp_sub VALUES (0,0,'0');
-INSERT INTO num_exp_mul VALUES (0,0,'0');
-INSERT INTO num_exp_div VALUES (0,0,'NaN');
-INSERT INTO num_exp_add VALUES (0,1,'0');
-INSERT INTO num_exp_sub VALUES (0,1,'0');
-INSERT INTO num_exp_mul VALUES (0,1,'0');
-INSERT INTO num_exp_div VALUES (0,1,'NaN');
-INSERT INTO num_exp_add VALUES (0,2,'-34338492.215397047');
-INSERT INTO num_exp_sub VALUES (0,2,'34338492.215397047');
-INSERT INTO num_exp_mul VALUES (0,2,'0');
-INSERT INTO num_exp_div VALUES (0,2,'0');
-INSERT INTO num_exp_add VALUES (0,3,'4.31');
-INSERT INTO num_exp_sub VALUES (0,3,'-4.31');
-INSERT INTO num_exp_mul VALUES (0,3,'0');
-INSERT INTO num_exp_div VALUES (0,3,'0');
-INSERT INTO num_exp_add VALUES (0,4,'7799461.4119');
-INSERT INTO num_exp_sub VALUES (0,4,'-7799461.4119');
-INSERT INTO num_exp_mul VALUES (0,4,'0');
-INSERT INTO num_exp_div VALUES (0,4,'0');
-INSERT INTO num_exp_add VALUES (0,5,'16397.038491');
-INSERT INTO num_exp_sub VALUES (0,5,'-16397.038491');
-INSERT INTO num_exp_mul VALUES (0,5,'0');
-INSERT INTO num_exp_div VALUES (0,5,'0');
-INSERT INTO num_exp_add VALUES (0,6,'93901.57763026');
-INSERT INTO num_exp_sub VALUES (0,6,'-93901.57763026');
-INSERT INTO num_exp_mul VALUES (0,6,'0');
-INSERT INTO num_exp_div VALUES (0,6,'0');
-INSERT INTO num_exp_add VALUES (0,7,'-83028485');
-INSERT INTO num_exp_sub VALUES (0,7,'83028485');
-INSERT INTO num_exp_mul VALUES (0,7,'0');
-INSERT INTO num_exp_div VALUES (0,7,'0');
-INSERT INTO num_exp_add VALUES (0,8,'74881');
-INSERT INTO num_exp_sub VALUES (0,8,'-74881');
-INSERT INTO num_exp_mul VALUES (0,8,'0');
-INSERT INTO num_exp_div VALUES (0,8,'0');
-INSERT INTO num_exp_add VALUES (0,9,'-24926804.045047420');
-INSERT INTO num_exp_sub VALUES (0,9,'24926804.045047420');
-INSERT INTO num_exp_mul VALUES (0,9,'0');
-INSERT INTO num_exp_div VALUES (0,9,'0');
-INSERT INTO num_exp_add VALUES (1,0,'0');
-INSERT INTO num_exp_sub VALUES (1,0,'0');
-INSERT INTO num_exp_mul VALUES (1,0,'0');
-INSERT INTO num_exp_div VALUES (1,0,'NaN');
-INSERT INTO num_exp_add VALUES (1,1,'0');
-INSERT INTO num_exp_sub VALUES (1,1,'0');
-INSERT INTO num_exp_mul VALUES (1,1,'0');
-INSERT INTO num_exp_div VALUES (1,1,'NaN');
-INSERT INTO num_exp_add VALUES (1,2,'-34338492.215397047');
-INSERT INTO num_exp_sub VALUES (1,2,'34338492.215397047');
-INSERT INTO num_exp_mul VALUES (1,2,'0');
-INSERT INTO num_exp_div VALUES (1,2,'0');
-INSERT INTO num_exp_add VALUES (1,3,'4.31');
-INSERT INTO num_exp_sub VALUES (1,3,'-4.31');
-INSERT INTO num_exp_mul VALUES (1,3,'0');
-INSERT INTO num_exp_div VALUES (1,3,'0');
-INSERT INTO num_exp_add VALUES (1,4,'7799461.4119');
-INSERT INTO num_exp_sub VALUES (1,4,'-7799461.4119');
-INSERT INTO num_exp_mul VALUES (1,4,'0');
-INSERT INTO num_exp_div VALUES (1,4,'0');
-INSERT INTO num_exp_add VALUES (1,5,'16397.038491');
-INSERT INTO num_exp_sub VALUES (1,5,'-16397.038491');
-INSERT INTO num_exp_mul VALUES (1,5,'0');
-INSERT INTO num_exp_div VALUES (1,5,'0');
-INSERT INTO num_exp_add VALUES (1,6,'93901.57763026');
-INSERT INTO num_exp_sub VALUES (1,6,'-93901.57763026');
-INSERT INTO num_exp_mul VALUES (1,6,'0');
-INSERT INTO num_exp_div VALUES (1,6,'0');
-INSERT INTO num_exp_add VALUES (1,7,'-83028485');
-INSERT INTO num_exp_sub VALUES (1,7,'83028485');
-INSERT INTO num_exp_mul VALUES (1,7,'0');
-INSERT INTO num_exp_div VALUES (1,7,'0');
-INSERT INTO num_exp_add VALUES (1,8,'74881');
-INSERT INTO num_exp_sub VALUES (1,8,'-74881');
-INSERT INTO num_exp_mul VALUES (1,8,'0');
-INSERT INTO num_exp_div VALUES (1,8,'0');
-INSERT INTO num_exp_add VALUES (1,9,'-24926804.045047420');
-INSERT INTO num_exp_sub VALUES (1,9,'24926804.045047420');
-INSERT INTO num_exp_mul VALUES (1,9,'0');
-INSERT INTO num_exp_div VALUES (1,9,'0');
-INSERT INTO num_exp_add VALUES (2,0,'-34338492.215397047');
-INSERT INTO num_exp_sub VALUES (2,0,'-34338492.215397047');
-INSERT INTO num_exp_mul VALUES (2,0,'0');
-INSERT INTO num_exp_div VALUES (2,0,'NaN');
-INSERT INTO num_exp_add VALUES (2,1,'-34338492.215397047');
-INSERT INTO num_exp_sub VALUES (2,1,'-34338492.215397047');
-INSERT INTO num_exp_mul VALUES (2,1,'0');
-INSERT INTO num_exp_div VALUES (2,1,'NaN');
-INSERT INTO num_exp_add VALUES (2,2,'-68676984.430794094');
-INSERT INTO num_exp_sub VALUES (2,2,'0');
-INSERT INTO num_exp_mul VALUES (2,2,'1179132047626883.596862135856320209');
-INSERT INTO num_exp_div VALUES (2,2,'1.00000000000000000000');
-INSERT INTO num_exp_add VALUES (2,3,'-34338487.905397047');
-INSERT INTO num_exp_sub VALUES (2,3,'-34338496.525397047');
-INSERT INTO num_exp_mul VALUES (2,3,'-147998901.44836127257');
-INSERT INTO num_exp_div VALUES (2,3,'-7967167.56737750510440835266');
-INSERT INTO num_exp_add VALUES (2,4,'-26539030.803497047');
-INSERT INTO num_exp_sub VALUES (2,4,'-42137953.627297047');
-INSERT INTO num_exp_mul VALUES (2,4,'-267821744976817.8111137106593');
-INSERT INTO num_exp_div VALUES (2,4,'-4.40267480046830116685');
-INSERT INTO num_exp_add VALUES (2,5,'-34322095.176906047');
-INSERT INTO num_exp_sub VALUES (2,5,'-34354889.253888047');
-INSERT INTO num_exp_mul VALUES (2,5,'-563049578578.769242506736077');
-INSERT INTO num_exp_div VALUES (2,5,'-2094.18866914563535496429');
-INSERT INTO num_exp_add VALUES (2,6,'-34244590.637766787');
-INSERT INTO num_exp_sub VALUES (2,6,'-34432393.793027307');
-INSERT INTO num_exp_mul VALUES (2,6,'-3224438592470.18449811926184222');
-INSERT INTO num_exp_div VALUES (2,6,'-365.68599891479766440940');
-INSERT INTO num_exp_add VALUES (2,7,'-117366977.215397047');
-INSERT INTO num_exp_sub VALUES (2,7,'48689992.784602953');
-INSERT INTO num_exp_mul VALUES (2,7,'2851072985828710.485883795');
-INSERT INTO num_exp_div VALUES (2,7,'.41357483778485235518');
-INSERT INTO num_exp_add VALUES (2,8,'-34263611.215397047');
-INSERT INTO num_exp_sub VALUES (2,8,'-34413373.215397047');
-INSERT INTO num_exp_mul VALUES (2,8,'-2571300635581.146276407');
-INSERT INTO num_exp_div VALUES (2,8,'-458.57416721727870888476');
-INSERT INTO num_exp_add VALUES (2,9,'-59265296.260444467');
-INSERT INTO num_exp_sub VALUES (2,9,'-9411688.170349627');
-INSERT INTO num_exp_mul VALUES (2,9,'855948866655588.453741509242968740');
-INSERT INTO num_exp_div VALUES (2,9,'1.37757299946438931811');
-INSERT INTO num_exp_add VALUES (3,0,'4.31');
-INSERT INTO num_exp_sub VALUES (3,0,'4.31');
-INSERT INTO num_exp_mul VALUES (3,0,'0');
-INSERT INTO num_exp_div VALUES (3,0,'NaN');
-INSERT INTO num_exp_add VALUES (3,1,'4.31');
-INSERT INTO num_exp_sub VALUES (3,1,'4.31');
-INSERT INTO num_exp_mul VALUES (3,1,'0');
-INSERT INTO num_exp_div VALUES (3,1,'NaN');
-INSERT INTO num_exp_add VALUES (3,2,'-34338487.905397047');
-INSERT INTO num_exp_sub VALUES (3,2,'34338496.525397047');
-INSERT INTO num_exp_mul VALUES (3,2,'-147998901.44836127257');
-INSERT INTO num_exp_div VALUES (3,2,'-.00000012551512084352');
-INSERT INTO num_exp_add VALUES (3,3,'8.62');
-INSERT INTO num_exp_sub VALUES (3,3,'0');
-INSERT INTO num_exp_mul VALUES (3,3,'18.5761');
-INSERT INTO num_exp_div VALUES (3,3,'1.00000000000000000000');
-INSERT INTO num_exp_add VALUES (3,4,'7799465.7219');
-INSERT INTO num_exp_sub VALUES (3,4,'-7799457.1019');
-INSERT INTO num_exp_mul VALUES (3,4,'33615678.685289');
-INSERT INTO num_exp_div VALUES (3,4,'.00000055260225961552');
-INSERT INTO num_exp_add VALUES (3,5,'16401.348491');
-INSERT INTO num_exp_sub VALUES (3,5,'-16392.728491');
-INSERT INTO num_exp_mul VALUES (3,5,'70671.23589621');
-INSERT INTO num_exp_div VALUES (3,5,'.00026285234387695504');
-INSERT INTO num_exp_add VALUES (3,6,'93905.88763026');
-INSERT INTO num_exp_sub VALUES (3,6,'-93897.26763026');
-INSERT INTO num_exp_mul VALUES (3,6,'404715.7995864206');
-INSERT INTO num_exp_div VALUES (3,6,'.00004589912234457595');
-INSERT INTO num_exp_add VALUES (3,7,'-83028480.69');
-INSERT INTO num_exp_sub VALUES (3,7,'83028489.31');
-INSERT INTO num_exp_mul VALUES (3,7,'-357852770.35');
-INSERT INTO num_exp_div VALUES (3,7,'-.00000005190989574240');
-INSERT INTO num_exp_add VALUES (3,8,'74885.31');
-INSERT INTO num_exp_sub VALUES (3,8,'-74876.69');
-INSERT INTO num_exp_mul VALUES (3,8,'322737.11');
-INSERT INTO num_exp_div VALUES (3,8,'.00005755799201399553');
-INSERT INTO num_exp_add VALUES (3,9,'-24926799.735047420');
-INSERT INTO num_exp_sub VALUES (3,9,'24926808.355047420');
-INSERT INTO num_exp_mul VALUES (3,9,'-107434525.43415438020');
-INSERT INTO num_exp_div VALUES (3,9,'-.00000017290624149854');
-INSERT INTO num_exp_add VALUES (4,0,'7799461.4119');
-INSERT INTO num_exp_sub VALUES (4,0,'7799461.4119');
-INSERT INTO num_exp_mul VALUES (4,0,'0');
-INSERT INTO num_exp_div VALUES (4,0,'NaN');
-INSERT INTO num_exp_add VALUES (4,1,'7799461.4119');
-INSERT INTO num_exp_sub VALUES (4,1,'7799461.4119');
-INSERT INTO num_exp_mul VALUES (4,1,'0');
-INSERT INTO num_exp_div VALUES (4,1,'NaN');
-INSERT INTO num_exp_add VALUES (4,2,'-26539030.803497047');
-INSERT INTO num_exp_sub VALUES (4,2,'42137953.627297047');
-INSERT INTO num_exp_mul VALUES (4,2,'-267821744976817.8111137106593');
-INSERT INTO num_exp_div VALUES (4,2,'-.22713465002993920385');
-INSERT INTO num_exp_add VALUES (4,3,'7799465.7219');
-INSERT INTO num_exp_sub VALUES (4,3,'7799457.1019');
-INSERT INTO num_exp_mul VALUES (4,3,'33615678.685289');
-INSERT INTO num_exp_div VALUES (4,3,'1809619.81714617169373549883');
-INSERT INTO num_exp_add VALUES (4,4,'15598922.8238');
-INSERT INTO num_exp_sub VALUES (4,4,'0');
-INSERT INTO num_exp_mul VALUES (4,4,'60831598315717.14146161');
-INSERT INTO num_exp_div VALUES (4,4,'1.00000000000000000000');
-INSERT INTO num_exp_add VALUES (4,5,'7815858.450391');
-INSERT INTO num_exp_sub VALUES (4,5,'7783064.373409');
-INSERT INTO num_exp_mul VALUES (4,5,'127888068979.9935054429');
-INSERT INTO num_exp_div VALUES (4,5,'475.66281046305802686061');
-INSERT INTO num_exp_add VALUES (4,6,'7893362.98953026');
-INSERT INTO num_exp_sub VALUES (4,6,'7705559.83426974');
-INSERT INTO num_exp_mul VALUES (4,6,'732381731243.745115764094');
-INSERT INTO num_exp_div VALUES (4,6,'83.05996138436129499606');
-INSERT INTO num_exp_add VALUES (4,7,'-75229023.5881');
-INSERT INTO num_exp_sub VALUES (4,7,'90827946.4119');
-INSERT INTO num_exp_mul VALUES (4,7,'-647577464846017.9715');
-INSERT INTO num_exp_div VALUES (4,7,'-.09393717604145131637');
-INSERT INTO num_exp_add VALUES (4,8,'7874342.4119');
-INSERT INTO num_exp_sub VALUES (4,8,'7724580.4119');
-INSERT INTO num_exp_mul VALUES (4,8,'584031469984.4839');
-INSERT INTO num_exp_div VALUES (4,8,'104.15808298366741897143');
-INSERT INTO num_exp_add VALUES (4,9,'-17127342.633147420');
-INSERT INTO num_exp_sub VALUES (4,9,'32726265.456947420');
-INSERT INTO num_exp_mul VALUES (4,9,'-194415646271340.1815956522980');
-INSERT INTO num_exp_div VALUES (4,9,'-.31289456112403769409');
-INSERT INTO num_exp_add VALUES (5,0,'16397.038491');
-INSERT INTO num_exp_sub VALUES (5,0,'16397.038491');
-INSERT INTO num_exp_mul VALUES (5,0,'0');
-INSERT INTO num_exp_div VALUES (5,0,'NaN');
-INSERT INTO num_exp_add VALUES (5,1,'16397.038491');
-INSERT INTO num_exp_sub VALUES (5,1,'16397.038491');
-INSERT INTO num_exp_mul VALUES (5,1,'0');
-INSERT INTO num_exp_div VALUES (5,1,'NaN');
-INSERT INTO num_exp_add VALUES (5,2,'-34322095.176906047');
-INSERT INTO num_exp_sub VALUES (5,2,'34354889.253888047');
-INSERT INTO num_exp_mul VALUES (5,2,'-563049578578.769242506736077');
-INSERT INTO num_exp_div VALUES (5,2,'-.00047751189505192446');
-INSERT INTO num_exp_add VALUES (5,3,'16401.348491');
-INSERT INTO num_exp_sub VALUES (5,3,'16392.728491');
-INSERT INTO num_exp_mul VALUES (5,3,'70671.23589621');
-INSERT INTO num_exp_div VALUES (5,3,'3804.41728329466357308584');
-INSERT INTO num_exp_add VALUES (5,4,'7815858.450391');
-INSERT INTO num_exp_sub VALUES (5,4,'-7783064.373409');
-INSERT INTO num_exp_mul VALUES (5,4,'127888068979.9935054429');
-INSERT INTO num_exp_div VALUES (5,4,'.00210232958726897192');
-INSERT INTO num_exp_add VALUES (5,5,'32794.076982');
-INSERT INTO num_exp_sub VALUES (5,5,'0');
-INSERT INTO num_exp_mul VALUES (5,5,'268862871.275335557081');
-INSERT INTO num_exp_div VALUES (5,5,'1.00000000000000000000');
-INSERT INTO num_exp_add VALUES (5,6,'110298.61612126');
-INSERT INTO num_exp_sub VALUES (5,6,'-77504.53913926');
-INSERT INTO num_exp_mul VALUES (5,6,'1539707782.76899778633766');
-INSERT INTO num_exp_div VALUES (5,6,'.17461941433576102689');
-INSERT INTO num_exp_add VALUES (5,7,'-83012087.961509');
-INSERT INTO num_exp_sub VALUES (5,7,'83044882.038491');
-INSERT INTO num_exp_mul VALUES (5,7,'-1361421264394.416135');
-INSERT INTO num_exp_div VALUES (5,7,'-.00019748690453643710');
-INSERT INTO num_exp_add VALUES (5,8,'91278.038491');
-INSERT INTO num_exp_sub VALUES (5,8,'-58483.961509');
-INSERT INTO num_exp_mul VALUES (5,8,'1227826639.244571');
-INSERT INTO num_exp_div VALUES (5,8,'.21897461960978085228');
-INSERT INTO num_exp_add VALUES (5,9,'-24910407.006556420');
-INSERT INTO num_exp_sub VALUES (5,9,'24943201.083538420');
-INSERT INTO num_exp_mul VALUES (5,9,'-408725765384.257043660243220');
-INSERT INTO num_exp_div VALUES (5,9,'-.00065780749354660427');
-INSERT INTO num_exp_add VALUES (6,0,'93901.57763026');
-INSERT INTO num_exp_sub VALUES (6,0,'93901.57763026');
-INSERT INTO num_exp_mul VALUES (6,0,'0');
-INSERT INTO num_exp_div VALUES (6,0,'NaN');
-INSERT INTO num_exp_add VALUES (6,1,'93901.57763026');
-INSERT INTO num_exp_sub VALUES (6,1,'93901.57763026');
-INSERT INTO num_exp_mul VALUES (6,1,'0');
-INSERT INTO num_exp_div VALUES (6,1,'NaN');
-INSERT INTO num_exp_add VALUES (6,2,'-34244590.637766787');
-INSERT INTO num_exp_sub VALUES (6,2,'34432393.793027307');
-INSERT INTO num_exp_mul VALUES (6,2,'-3224438592470.18449811926184222');
-INSERT INTO num_exp_div VALUES (6,2,'-.00273458651128995823');
-INSERT INTO num_exp_add VALUES (6,3,'93905.88763026');
-INSERT INTO num_exp_sub VALUES (6,3,'93897.26763026');
-INSERT INTO num_exp_mul VALUES (6,3,'404715.7995864206');
-INSERT INTO num_exp_div VALUES (6,3,'21786.90896293735498839907');
-INSERT INTO num_exp_add VALUES (6,4,'7893362.98953026');
-INSERT INTO num_exp_sub VALUES (6,4,'-7705559.83426974');
-INSERT INTO num_exp_mul VALUES (6,4,'732381731243.745115764094');
-INSERT INTO num_exp_div VALUES (6,4,'.01203949512295682469');
-INSERT INTO num_exp_add VALUES (6,5,'110298.61612126');
-INSERT INTO num_exp_sub VALUES (6,5,'77504.53913926');
-INSERT INTO num_exp_mul VALUES (6,5,'1539707782.76899778633766');
-INSERT INTO num_exp_div VALUES (6,5,'5.72674008674192359679');
-INSERT INTO num_exp_add VALUES (6,6,'187803.15526052');
-INSERT INTO num_exp_sub VALUES (6,6,'0');
-INSERT INTO num_exp_mul VALUES (6,6,'8817506281.4517452372676676');
-INSERT INTO num_exp_div VALUES (6,6,'1.00000000000000000000');
-INSERT INTO num_exp_add VALUES (6,7,'-82934583.42236974');
-INSERT INTO num_exp_sub VALUES (6,7,'83122386.57763026');
-INSERT INTO num_exp_mul VALUES (6,7,'-7796505729750.37795610');
-INSERT INTO num_exp_div VALUES (6,7,'-.00113095617281538980');
-INSERT INTO num_exp_add VALUES (6,8,'168782.57763026');
-INSERT INTO num_exp_sub VALUES (6,8,'19020.57763026');
-INSERT INTO num_exp_mul VALUES (6,8,'7031444034.53149906');
-INSERT INTO num_exp_div VALUES (6,8,'1.25401073209839612184');
-INSERT INTO num_exp_add VALUES (6,9,'-24832902.467417160');
-INSERT INTO num_exp_sub VALUES (6,9,'25020705.622677680');
-INSERT INTO num_exp_mul VALUES (6,9,'-2340666225110.29929521292692920');
-INSERT INTO num_exp_div VALUES (6,9,'-.00376709254265256789');
-INSERT INTO num_exp_add VALUES (7,0,'-83028485');
-INSERT INTO num_exp_sub VALUES (7,0,'-83028485');
-INSERT INTO num_exp_mul VALUES (7,0,'0');
-INSERT INTO num_exp_div VALUES (7,0,'NaN');
-INSERT INTO num_exp_add VALUES (7,1,'-83028485');
-INSERT INTO num_exp_sub VALUES (7,1,'-83028485');
-INSERT INTO num_exp_mul VALUES (7,1,'0');
-INSERT INTO num_exp_div VALUES (7,1,'NaN');
-INSERT INTO num_exp_add VALUES (7,2,'-117366977.215397047');
-INSERT INTO num_exp_sub VALUES (7,2,'-48689992.784602953');
-INSERT INTO num_exp_mul VALUES (7,2,'2851072985828710.485883795');
-INSERT INTO num_exp_div VALUES (7,2,'2.41794207151503385700');
-INSERT INTO num_exp_add VALUES (7,3,'-83028480.69');
-INSERT INTO num_exp_sub VALUES (7,3,'-83028489.31');
-INSERT INTO num_exp_mul VALUES (7,3,'-357852770.35');
-INSERT INTO num_exp_div VALUES (7,3,'-19264149.65197215777262180974');
-INSERT INTO num_exp_add VALUES (7,4,'-75229023.5881');
-INSERT INTO num_exp_sub VALUES (7,4,'-90827946.4119');
-INSERT INTO num_exp_mul VALUES (7,4,'-647577464846017.9715');
-INSERT INTO num_exp_div VALUES (7,4,'-10.64541262725136247686');
-INSERT INTO num_exp_add VALUES (7,5,'-83012087.961509');
-INSERT INTO num_exp_sub VALUES (7,5,'-83044882.038491');
-INSERT INTO num_exp_mul VALUES (7,5,'-1361421264394.416135');
-INSERT INTO num_exp_div VALUES (7,5,'-5063.62688881730941836574');
-INSERT INTO num_exp_add VALUES (7,6,'-82934583.42236974');
-INSERT INTO num_exp_sub VALUES (7,6,'-83122386.57763026');
-INSERT INTO num_exp_mul VALUES (7,6,'-7796505729750.37795610');
-INSERT INTO num_exp_div VALUES (7,6,'-884.20756174009028770294');
-INSERT INTO num_exp_add VALUES (7,7,'-166056970');
-INSERT INTO num_exp_sub VALUES (7,7,'0');
-INSERT INTO num_exp_mul VALUES (7,7,'6893729321395225');
-INSERT INTO num_exp_div VALUES (7,7,'1.00000000000000000000');
-INSERT INTO num_exp_add VALUES (7,8,'-82953604');
-INSERT INTO num_exp_sub VALUES (7,8,'-83103366');
-INSERT INTO num_exp_mul VALUES (7,8,'-6217255985285');
-INSERT INTO num_exp_div VALUES (7,8,'-1108.80577182462841041118');
-INSERT INTO num_exp_add VALUES (7,9,'-107955289.045047420');
-INSERT INTO num_exp_sub VALUES (7,9,'-58101680.954952580');
-INSERT INTO num_exp_mul VALUES (7,9,'2069634775752159.035758700');
-INSERT INTO num_exp_div VALUES (7,9,'3.33089171198810413382');
-INSERT INTO num_exp_add VALUES (8,0,'74881');
-INSERT INTO num_exp_sub VALUES (8,0,'74881');
-INSERT INTO num_exp_mul VALUES (8,0,'0');
-INSERT INTO num_exp_div VALUES (8,0,'NaN');
-INSERT INTO num_exp_add VALUES (8,1,'74881');
-INSERT INTO num_exp_sub VALUES (8,1,'74881');
-INSERT INTO num_exp_mul VALUES (8,1,'0');
-INSERT INTO num_exp_div VALUES (8,1,'NaN');
-INSERT INTO num_exp_add VALUES (8,2,'-34263611.215397047');
-INSERT INTO num_exp_sub VALUES (8,2,'34413373.215397047');
-INSERT INTO num_exp_mul VALUES (8,2,'-2571300635581.146276407');
-INSERT INTO num_exp_div VALUES (8,2,'-.00218067233500788615');
-INSERT INTO num_exp_add VALUES (8,3,'74885.31');
-INSERT INTO num_exp_sub VALUES (8,3,'74876.69');
-INSERT INTO num_exp_mul VALUES (8,3,'322737.11');
-INSERT INTO num_exp_div VALUES (8,3,'17373.78190255220417633410');
-INSERT INTO num_exp_add VALUES (8,4,'7874342.4119');
-INSERT INTO num_exp_sub VALUES (8,4,'-7724580.4119');
-INSERT INTO num_exp_mul VALUES (8,4,'584031469984.4839');
-INSERT INTO num_exp_div VALUES (8,4,'.00960079113741758956');
-INSERT INTO num_exp_add VALUES (8,5,'91278.038491');
-INSERT INTO num_exp_sub VALUES (8,5,'58483.961509');
-INSERT INTO num_exp_mul VALUES (8,5,'1227826639.244571');
-INSERT INTO num_exp_div VALUES (8,5,'4.56673929509287019456');
-INSERT INTO num_exp_add VALUES (8,6,'168782.57763026');
-INSERT INTO num_exp_sub VALUES (8,6,'-19020.57763026');
-INSERT INTO num_exp_mul VALUES (8,6,'7031444034.53149906');
-INSERT INTO num_exp_div VALUES (8,6,'.79744134113322314424');
-INSERT INTO num_exp_add VALUES (8,7,'-82953604');
-INSERT INTO num_exp_sub VALUES (8,7,'83103366');
-INSERT INTO num_exp_mul VALUES (8,7,'-6217255985285');
-INSERT INTO num_exp_div VALUES (8,7,'-.00090187120721280172');
-INSERT INTO num_exp_add VALUES (8,8,'149762');
-INSERT INTO num_exp_sub VALUES (8,8,'0');
-INSERT INTO num_exp_mul VALUES (8,8,'5607164161');
-INSERT INTO num_exp_div VALUES (8,8,'1.00000000000000000000');
-INSERT INTO num_exp_add VALUES (8,9,'-24851923.045047420');
-INSERT INTO num_exp_sub VALUES (8,9,'25001685.045047420');
-INSERT INTO num_exp_mul VALUES (8,9,'-1866544013697.195857020');
-INSERT INTO num_exp_div VALUES (8,9,'-.00300403532938582735');
-INSERT INTO num_exp_add VALUES (9,0,'-24926804.045047420');
-INSERT INTO num_exp_sub VALUES (9,0,'-24926804.045047420');
-INSERT INTO num_exp_mul VALUES (9,0,'0');
-INSERT INTO num_exp_div VALUES (9,0,'NaN');
-INSERT INTO num_exp_add VALUES (9,1,'-24926804.045047420');
-INSERT INTO num_exp_sub VALUES (9,1,'-24926804.045047420');
-INSERT INTO num_exp_mul VALUES (9,1,'0');
-INSERT INTO num_exp_div VALUES (9,1,'NaN');
-INSERT INTO num_exp_add VALUES (9,2,'-59265296.260444467');
-INSERT INTO num_exp_sub VALUES (9,2,'9411688.170349627');
-INSERT INTO num_exp_mul VALUES (9,2,'855948866655588.453741509242968740');
-INSERT INTO num_exp_div VALUES (9,2,'.72591434384152961526');
-INSERT INTO num_exp_add VALUES (9,3,'-24926799.735047420');
-INSERT INTO num_exp_sub VALUES (9,3,'-24926808.355047420');
-INSERT INTO num_exp_mul VALUES (9,3,'-107434525.43415438020');
-INSERT INTO num_exp_div VALUES (9,3,'-5783481.21694835730858468677');
-INSERT INTO num_exp_add VALUES (9,4,'-17127342.633147420');
-INSERT INTO num_exp_sub VALUES (9,4,'-32726265.456947420');
-INSERT INTO num_exp_mul VALUES (9,4,'-194415646271340.1815956522980');
-INSERT INTO num_exp_div VALUES (9,4,'-3.19596478892958416484');
-INSERT INTO num_exp_add VALUES (9,5,'-24910407.006556420');
-INSERT INTO num_exp_sub VALUES (9,5,'-24943201.083538420');
-INSERT INTO num_exp_mul VALUES (9,5,'-408725765384.257043660243220');
-INSERT INTO num_exp_div VALUES (9,5,'-1520.20159364322004505807');
-INSERT INTO num_exp_add VALUES (9,6,'-24832902.467417160');
-INSERT INTO num_exp_sub VALUES (9,6,'-25020705.622677680');
-INSERT INTO num_exp_mul VALUES (9,6,'-2340666225110.29929521292692920');
-INSERT INTO num_exp_div VALUES (9,6,'-265.45671195426965751280');
-INSERT INTO num_exp_add VALUES (9,7,'-107955289.045047420');
-INSERT INTO num_exp_sub VALUES (9,7,'58101680.954952580');
-INSERT INTO num_exp_mul VALUES (9,7,'2069634775752159.035758700');
-INSERT INTO num_exp_div VALUES (9,7,'.30021990699995814689');
-INSERT INTO num_exp_add VALUES (9,8,'-24851923.045047420');
-INSERT INTO num_exp_sub VALUES (9,8,'-25001685.045047420');
-INSERT INTO num_exp_mul VALUES (9,8,'-1866544013697.195857020');
-INSERT INTO num_exp_div VALUES (9,8,'-332.88556569820675471748');
-INSERT INTO num_exp_add VALUES (9,9,'-49853608.090094840');
-INSERT INTO num_exp_sub VALUES (9,9,'0');
-INSERT INTO num_exp_mul VALUES (9,9,'621345559900192.420120630048656400');
-INSERT INTO num_exp_div VALUES (9,9,'1.00000000000000000000');
+-- PostgreSQL implicitly casts string literals to decimal values, but
+-- Spark does not support that kind of implicit cast. To test all the INSERT queries below,
+-- we rewrote them using typed literals.
+INSERT INTO num_exp_add VALUES (0,0,0);
+INSERT INTO num_exp_sub VALUES (0,0,0);
+INSERT INTO num_exp_mul VALUES (0,0,0);
+-- [SPARK-28315] Decimal can not accept NaN as input
+INSERT INTO num_exp_div VALUES (0,0,double('NaN'));
+INSERT INTO num_exp_add VALUES (0,1,0);
+INSERT INTO num_exp_sub VALUES (0,1,0);
+INSERT INTO num_exp_mul VALUES (0,1,0);
+-- [SPARK-28315] Decimal can not accept NaN as input
+INSERT INTO num_exp_div VALUES (0,1,double('NaN'));
+INSERT INTO num_exp_add VALUES (0,2,-34338492.215397047);
+INSERT INTO num_exp_sub VALUES (0,2,34338492.215397047);
+INSERT INTO num_exp_mul VALUES (0,2,0);
+INSERT INTO num_exp_div VALUES (0,2,0);
+INSERT INTO num_exp_add VALUES (0,3,4.31);
+INSERT INTO num_exp_sub VALUES (0,3,-4.31);
+INSERT INTO num_exp_mul VALUES (0,3,0);
+INSERT INTO num_exp_div VALUES (0,3,0);
+INSERT INTO num_exp_add VALUES (0,4,7799461.4119);
+INSERT INTO num_exp_sub VALUES (0,4,-7799461.4119);
+INSERT INTO num_exp_mul VALUES (0,4,0);
+INSERT INTO num_exp_div VALUES (0,4,0);
+INSERT INTO num_exp_add VALUES (0,5,16397.038491);
+INSERT INTO num_exp_sub VALUES (0,5,-16397.038491);
+INSERT INTO num_exp_mul VALUES (0,5,0);
+INSERT INTO num_exp_div VALUES (0,5,0);
+INSERT INTO num_exp_add VALUES (0,6,93901.57763026);
+INSERT INTO num_exp_sub VALUES (0,6,-93901.57763026);
+INSERT INTO num_exp_mul VALUES (0,6,0);
+INSERT INTO num_exp_div VALUES (0,6,0);
+INSERT INTO num_exp_add VALUES (0,7,-83028485);
+INSERT INTO num_exp_sub VALUES (0,7,83028485);
+INSERT INTO num_exp_mul VALUES (0,7,0);
+INSERT INTO num_exp_div VALUES (0,7,0);
+INSERT INTO num_exp_add VALUES (0,8,74881);
+INSERT INTO num_exp_sub VALUES (0,8,-74881);
+INSERT INTO num_exp_mul VALUES (0,8,0);
+INSERT INTO num_exp_div VALUES (0,8,0);
+INSERT INTO num_exp_add VALUES (0,9,-24926804.045047420);
+INSERT INTO num_exp_sub VALUES (0,9,24926804.045047420);
+INSERT INTO num_exp_mul VALUES (0,9,0);
+INSERT INTO num_exp_div VALUES (0,9,0);
+INSERT INTO num_exp_add VALUES (1,0,0);
+INSERT INTO num_exp_sub VALUES (1,0,0);
+INSERT INTO num_exp_mul VALUES (1,0,0);
+-- [SPARK-28315] Decimal can not accept NaN as input
+INSERT INTO num_exp_div VALUES (1,0,double('NaN'));
+INSERT INTO num_exp_add VALUES (1,1,0);
+INSERT INTO num_exp_sub VALUES (1,1,0);
+INSERT INTO num_exp_mul VALUES (1,1,0);
+-- [SPARK-28315] Decimal can not accept NaN as input
+INSERT INTO num_exp_div VALUES (1,1,double('NaN'));
+INSERT INTO num_exp_add VALUES (1,2,-34338492.215397047);
+INSERT INTO num_exp_sub VALUES (1,2,34338492.215397047);
+INSERT INTO num_exp_mul VALUES (1,2,0);
+INSERT INTO num_exp_div VALUES (1,2,0);
+INSERT INTO num_exp_add VALUES (1,3,4.31);
+INSERT INTO num_exp_sub VALUES (1,3,-4.31);
+INSERT INTO num_exp_mul VALUES (1,3,0);
+INSERT INTO num_exp_div VALUES (1,3,0);
+INSERT INTO num_exp_add VALUES (1,4,7799461.4119);
+INSERT INTO num_exp_sub VALUES (1,4,-7799461.4119);
+INSERT INTO num_exp_mul VALUES (1,4,0);
+INSERT INTO num_exp_div VALUES (1,4,0);
+INSERT INTO num_exp_add VALUES (1,5,16397.038491);
+INSERT INTO num_exp_sub VALUES (1,5,-16397.038491);
+INSERT INTO num_exp_mul VALUES (1,5,0);
+INSERT INTO num_exp_div VALUES (1,5,0);
+INSERT INTO num_exp_add VALUES (1,6,93901.57763026);
+INSERT INTO num_exp_sub VALUES (1,6,-93901.57763026);
+INSERT INTO num_exp_mul VALUES (1,6,0);
+INSERT INTO num_exp_div VALUES (1,6,0);
+INSERT INTO num_exp_add VALUES (1,7,-83028485);
+INSERT INTO num_exp_sub VALUES (1,7,83028485);
+INSERT INTO num_exp_mul VALUES (1,7,0);
+INSERT INTO num_exp_div VALUES (1,7,0);
+INSERT INTO num_exp_add VALUES (1,8,74881);
+INSERT INTO num_exp_sub VALUES (1,8,-74881);
+INSERT INTO num_exp_mul VALUES (1,8,0);
+INSERT INTO num_exp_div VALUES (1,8,0);
+INSERT INTO num_exp_add VALUES (1,9,-24926804.045047420);
+INSERT INTO num_exp_sub VALUES (1,9,24926804.045047420);
+INSERT INTO num_exp_mul VALUES (1,9,0);
+INSERT INTO num_exp_div VALUES (1,9,0);
+INSERT INTO num_exp_add VALUES (2,0,-34338492.215397047);
+INSERT INTO num_exp_sub VALUES (2,0,-34338492.215397047);
+INSERT INTO num_exp_mul VALUES (2,0,0);
+-- [SPARK-28315] Decimal can not accept NaN as input
+INSERT INTO num_exp_div VALUES (2,0,double('NaN'));
+INSERT INTO num_exp_add VALUES (2,1,-34338492.215397047);
+INSERT INTO num_exp_sub VALUES (2,1,-34338492.215397047);
+INSERT INTO num_exp_mul VALUES (2,1,0);
+-- [SPARK-28315] Decimal can not accept NaN as input
+INSERT INTO num_exp_div VALUES (2,1,double('NaN'));
+INSERT INTO num_exp_add VALUES (2,2,-68676984.430794094);
+INSERT INTO num_exp_sub VALUES (2,2,0);
+INSERT INTO num_exp_mul VALUES (2,2,1179132047626883.596862135856320209);
+INSERT INTO num_exp_div VALUES (2,2,1.00000000000000000000);
+INSERT INTO num_exp_add VALUES (2,3,-34338487.905397047);
+INSERT INTO num_exp_sub VALUES (2,3,-34338496.525397047);
+INSERT INTO num_exp_mul VALUES (2,3,-147998901.44836127257);
+INSERT INTO num_exp_div VALUES (2,3,-7967167.56737750510440835266);
+INSERT INTO num_exp_add VALUES (2,4,-26539030.803497047);
+INSERT INTO num_exp_sub VALUES (2,4,-42137953.627297047);
+INSERT INTO num_exp_mul VALUES (2,4,-267821744976817.8111137106593);
+INSERT INTO num_exp_div VALUES (2,4,-4.40267480046830116685);
+INSERT INTO num_exp_add VALUES (2,5,-34322095.176906047);
+INSERT INTO num_exp_sub VALUES (2,5,-34354889.253888047);
+INSERT INTO num_exp_mul VALUES (2,5,-563049578578.769242506736077);
+INSERT INTO num_exp_div VALUES (2,5,-2094.18866914563535496429);
+INSERT INTO num_exp_add VALUES (2,6,-34244590.637766787);
+INSERT INTO num_exp_sub VALUES (2,6,-34432393.793027307);
+INSERT INTO num_exp_mul VALUES (2,6,-3224438592470.18449811926184222);
+INSERT INTO num_exp_div VALUES (2,6,-365.68599891479766440940);
+INSERT INTO num_exp_add VALUES (2,7,-117366977.215397047);
+INSERT INTO num_exp_sub VALUES (2,7,48689992.784602953);
+INSERT INTO num_exp_mul VALUES (2,7,2851072985828710.485883795);
+INSERT INTO num_exp_div VALUES (2,7,.41357483778485235518);
+INSERT INTO num_exp_add VALUES (2,8,-34263611.215397047);
+INSERT INTO num_exp_sub VALUES (2,8,-34413373.215397047);
+INSERT INTO num_exp_mul VALUES (2,8,-2571300635581.146276407);
+INSERT INTO num_exp_div VALUES (2,8,-458.57416721727870888476);
+INSERT INTO num_exp_add VALUES (2,9,-59265296.260444467);
+INSERT INTO num_exp_sub VALUES (2,9,-9411688.170349627);
+INSERT INTO num_exp_mul VALUES (2,9,855948866655588.453741509242968740);
+INSERT INTO num_exp_div VALUES (2,9,1.37757299946438931811);
+INSERT INTO num_exp_add VALUES (3,0,4.31);
+INSERT INTO num_exp_sub VALUES (3,0,4.31);
+INSERT INTO num_exp_mul VALUES (3,0,0);
+-- [SPARK-28315] Decimal can not accept NaN as input
+INSERT INTO num_exp_div VALUES (3,0,double('NaN'));
+INSERT INTO num_exp_add VALUES (3,1,4.31);
+INSERT INTO num_exp_sub VALUES (3,1,4.31);
+INSERT INTO num_exp_mul VALUES (3,1,0);
+-- [SPARK-28315] Decimal can not accept NaN as input
+INSERT INTO num_exp_div VALUES (3,1,double('NaN'));
+INSERT INTO num_exp_add VALUES (3,2,-34338487.905397047);
+INSERT INTO num_exp_sub VALUES (3,2,34338496.525397047);
+INSERT INTO num_exp_mul VALUES (3,2,-147998901.44836127257);
+INSERT INTO num_exp_div VALUES (3,2,-.00000012551512084352);
+INSERT INTO num_exp_add VALUES (3,3,8.62);
+INSERT INTO num_exp_sub VALUES (3,3,0);
+INSERT INTO num_exp_mul VALUES (3,3,18.5761);
+INSERT INTO num_exp_div VALUES (3,3,1.00000000000000000000);
+INSERT INTO num_exp_add VALUES (3,4,7799465.7219);
+INSERT INTO num_exp_sub VALUES (3,4,-7799457.1019);
+INSERT INTO num_exp_mul VALUES (3,4,33615678.685289);
+INSERT INTO num_exp_div VALUES (3,4,.00000055260225961552);
+INSERT INTO num_exp_add VALUES (3,5,16401.348491);
+INSERT INTO num_exp_sub VALUES (3,5,-16392.728491);
+INSERT INTO num_exp_mul VALUES (3,5,70671.23589621);
+INSERT INTO num_exp_div VALUES (3,5,.00026285234387695504);
+INSERT INTO num_exp_add VALUES (3,6,93905.88763026);
+INSERT INTO num_exp_sub VALUES (3,6,-93897.26763026);
+INSERT INTO num_exp_mul VALUES (3,6,404715.7995864206);
+INSERT INTO num_exp_div VALUES (3,6,.00004589912234457595);
+INSERT INTO num_exp_add VALUES (3,7,-83028480.69);
+INSERT INTO num_exp_sub VALUES (3,7,83028489.31);
+INSERT INTO num_exp_mul VALUES (3,7,-357852770.35);
+INSERT INTO num_exp_div VALUES (3,7,-.00000005190989574240);
+INSERT INTO num_exp_add VALUES (3,8,74885.31);
+INSERT INTO num_exp_sub VALUES (3,8,-74876.69);
+INSERT INTO num_exp_mul VALUES (3,8,322737.11);
+INSERT INTO num_exp_div VALUES (3,8,.00005755799201399553);
+INSERT INTO num_exp_add VALUES (3,9,-24926799.735047420);
+INSERT INTO num_exp_sub VALUES (3,9,24926808.355047420);
+INSERT INTO num_exp_mul VALUES (3,9,-107434525.43415438020);
+INSERT INTO num_exp_div VALUES (3,9,-.00000017290624149854);
+INSERT INTO num_exp_add VALUES (4,0,7799461.4119);
+INSERT INTO num_exp_sub VALUES (4,0,7799461.4119);
+INSERT INTO num_exp_mul VALUES (4,0,0);
+-- [SPARK-28315] Decimal can not accept NaN as input
+INSERT INTO num_exp_div VALUES (4,0,double('NaN'));
+INSERT INTO num_exp_add VALUES (4,1,7799461.4119);
+INSERT INTO num_exp_sub VALUES (4,1,7799461.4119);
+INSERT INTO num_exp_mul VALUES (4,1,0);
+-- [SPARK-28315] Decimal can not accept NaN as input
+INSERT INTO num_exp_div VALUES (4,1,double('NaN'));
+INSERT INTO num_exp_add VALUES (4,2,-26539030.803497047);
+INSERT INTO num_exp_sub VALUES (4,2,42137953.627297047);
+INSERT INTO num_exp_mul VALUES (4,2,-267821744976817.8111137106593);
+INSERT INTO num_exp_div VALUES (4,2,-.22713465002993920385);
+INSERT INTO num_exp_add VALUES (4,3,7799465.7219);
+INSERT INTO num_exp_sub VALUES (4,3,7799457.1019);
+INSERT INTO num_exp_mul VALUES (4,3,33615678.685289);
+INSERT INTO num_exp_div VALUES (4,3,1809619.81714617169373549883);
+INSERT INTO num_exp_add VALUES (4,4,15598922.8238);
+INSERT INTO num_exp_sub VALUES (4,4,0);
+INSERT INTO num_exp_mul VALUES (4,4,60831598315717.14146161);
+INSERT INTO num_exp_div VALUES (4,4,1.00000000000000000000);
+INSERT INTO num_exp_add VALUES (4,5,7815858.450391);
+INSERT INTO num_exp_sub VALUES (4,5,7783064.373409);
+INSERT INTO num_exp_mul VALUES (4,5,127888068979.9935054429);
+INSERT INTO num_exp_div VALUES (4,5,475.66281046305802686061);
+INSERT INTO num_exp_add VALUES (4,6,7893362.98953026);
+INSERT INTO num_exp_sub VALUES (4,6,7705559.83426974);
+INSERT INTO num_exp_mul VALUES (4,6,732381731243.745115764094);
+INSERT INTO num_exp_div VALUES (4,6,83.05996138436129499606);
+INSERT INTO num_exp_add VALUES (4,7,-75229023.5881);
+INSERT INTO num_exp_sub VALUES (4,7,90827946.4119);
+INSERT INTO num_exp_mul VALUES (4,7,-647577464846017.9715);
+INSERT INTO num_exp_div VALUES (4,7,-.09393717604145131637);
+INSERT INTO num_exp_add VALUES (4,8,7874342.4119);
+INSERT INTO num_exp_sub VALUES (4,8,7724580.4119);
+INSERT INTO num_exp_mul VALUES (4,8,584031469984.4839);
+INSERT INTO num_exp_div VALUES (4,8,104.15808298366741897143);
+INSERT INTO num_exp_add VALUES (4,9,-17127342.633147420);
+INSERT INTO num_exp_sub VALUES (4,9,32726265.456947420);
+INSERT INTO num_exp_mul VALUES (4,9,-194415646271340.1815956522980);
+INSERT INTO num_exp_div VALUES (4,9,-.31289456112403769409);
+INSERT INTO num_exp_add VALUES (5,0,16397.038491);
+INSERT INTO num_exp_sub VALUES (5,0,16397.038491);
+INSERT INTO num_exp_mul VALUES (5,0,0);
+-- [SPARK-28315] Decimal can not accept NaN as input
+INSERT INTO num_exp_div VALUES (5,0,double('NaN'));
+INSERT INTO num_exp_add VALUES (5,1,16397.038491);
+INSERT INTO num_exp_sub VALUES (5,1,16397.038491);
+INSERT INTO num_exp_mul VALUES (5,1,0);
+-- [SPARK-28315] Decimal can not accept NaN as input
+INSERT INTO num_exp_div VALUES (5,1,double('NaN'));
+INSERT INTO num_exp_add VALUES (5,2,-34322095.176906047);
+INSERT INTO num_exp_sub VALUES (5,2,34354889.253888047);
+INSERT INTO num_exp_mul VALUES (5,2,-563049578578.769242506736077);
+INSERT INTO num_exp_div VALUES (5,2,-.00047751189505192446);
+INSERT INTO num_exp_add VALUES (5,3,16401.348491);
+INSERT INTO num_exp_sub VALUES (5,3,16392.728491);
+INSERT INTO num_exp_mul VALUES (5,3,70671.23589621);
+INSERT INTO num_exp_div VALUES (5,3,3804.41728329466357308584);
+INSERT INTO num_exp_add VALUES (5,4,7815858.450391);
+INSERT INTO num_exp_sub VALUES (5,4,-7783064.373409);
+INSERT INTO num_exp_mul VALUES (5,4,127888068979.9935054429);
+INSERT INTO num_exp_div VALUES (5,4,.00210232958726897192);
+INSERT INTO num_exp_add VALUES (5,5,32794.076982);
+INSERT INTO num_exp_sub VALUES (5,5,0);
+INSERT INTO num_exp_mul VALUES (5,5,268862871.275335557081);
+INSERT INTO num_exp_div VALUES (5,5,1.00000000000000000000);
+INSERT INTO num_exp_add VALUES (5,6,110298.61612126);
+INSERT INTO num_exp_sub VALUES (5,6,-77504.53913926);
+INSERT INTO num_exp_mul VALUES (5,6,1539707782.76899778633766);
+INSERT INTO num_exp_div VALUES (5,6,.17461941433576102689);
+INSERT INTO num_exp_add VALUES (5,7,-83012087.961509);
+INSERT INTO num_exp_sub VALUES (5,7,83044882.038491);
+INSERT INTO num_exp_mul VALUES (5,7,-1361421264394.416135);
+INSERT INTO num_exp_div VALUES (5,7,-.00019748690453643710);
+INSERT INTO num_exp_add VALUES (5,8,91278.038491);
+INSERT INTO num_exp_sub VALUES (5,8,-58483.961509);
+INSERT INTO num_exp_mul VALUES (5,8,1227826639.244571);
+INSERT INTO num_exp_div VALUES (5,8,.21897461960978085228);
+INSERT INTO num_exp_add VALUES (5,9,-24910407.006556420);
+INSERT INTO num_exp_sub VALUES (5,9,24943201.083538420);
+INSERT INTO num_exp_mul VALUES (5,9,-408725765384.257043660243220);
+INSERT INTO num_exp_div VALUES (5,9,-.00065780749354660427);
+INSERT INTO num_exp_add VALUES (6,0,93901.57763026);
+INSERT INTO num_exp_sub VALUES (6,0,93901.57763026);
+INSERT INTO num_exp_mul VALUES (6,0,0);
+-- [SPARK-28315] Decimal can not accept NaN as input
+INSERT INTO num_exp_div VALUES (6,0,double('NaN'));
+INSERT INTO num_exp_add VALUES (6,1,93901.57763026);
+INSERT INTO num_exp_sub VALUES (6,1,93901.57763026);
+INSERT INTO num_exp_mul VALUES (6,1,0);
+-- [SPARK-28315] Decimal can not accept NaN as input
+INSERT INTO num_exp_div VALUES (6,1,double('NaN'));
+INSERT INTO num_exp_add VALUES (6,2,-34244590.637766787);
+INSERT INTO num_exp_sub VALUES (6,2,34432393.793027307);
+INSERT INTO num_exp_mul VALUES (6,2,-3224438592470.18449811926184222);
+INSERT INTO num_exp_div VALUES (6,2,-.00273458651128995823);
+INSERT INTO num_exp_add VALUES (6,3,93905.88763026);
+INSERT INTO num_exp_sub VALUES (6,3,93897.26763026);
+INSERT INTO num_exp_mul VALUES (6,3,404715.7995864206);
+INSERT INTO num_exp_div VALUES (6,3,21786.90896293735498839907);
+INSERT INTO num_exp_add VALUES (6,4,7893362.98953026);
+INSERT INTO num_exp_sub VALUES (6,4,-7705559.83426974);
+INSERT INTO num_exp_mul VALUES (6,4,732381731243.745115764094);
+INSERT INTO num_exp_div VALUES (6,4,.01203949512295682469);
+INSERT INTO num_exp_add VALUES (6,5,110298.61612126);
+INSERT INTO num_exp_sub VALUES (6,5,77504.53913926);
+INSERT INTO num_exp_mul VALUES (6,5,1539707782.76899778633766);
+INSERT INTO num_exp_div VALUES (6,5,5.72674008674192359679);
+INSERT INTO num_exp_add VALUES (6,6,187803.15526052);
+INSERT INTO num_exp_sub VALUES (6,6,0);
+INSERT INTO num_exp_mul VALUES (6,6,8817506281.4517452372676676);
+INSERT INTO num_exp_div VALUES (6,6,1.00000000000000000000);
+INSERT INTO num_exp_add VALUES (6,7,-82934583.42236974);
+INSERT INTO num_exp_sub VALUES (6,7,83122386.57763026);
+INSERT INTO num_exp_mul VALUES (6,7,-7796505729750.37795610);
+INSERT INTO num_exp_div VALUES (6,7,-.00113095617281538980);
+INSERT INTO num_exp_add VALUES (6,8,168782.57763026);
+INSERT INTO num_exp_sub VALUES (6,8,19020.57763026);
+INSERT INTO num_exp_mul VALUES (6,8,7031444034.53149906);
+INSERT INTO num_exp_div VALUES (6,8,1.25401073209839612184);
+INSERT INTO num_exp_add VALUES (6,9,-24832902.467417160);
+INSERT INTO num_exp_sub VALUES (6,9,25020705.622677680);
+INSERT INTO num_exp_mul VALUES (6,9,-2340666225110.29929521292692920);
+INSERT INTO num_exp_div VALUES (6,9,-.00376709254265256789);
+INSERT INTO num_exp_add VALUES (7,0,-83028485);
+INSERT INTO num_exp_sub VALUES (7,0,-83028485);
+INSERT INTO num_exp_mul VALUES (7,0,0);
+-- [SPARK-28315] Decimal can not accept NaN as input
+INSERT INTO num_exp_div VALUES (7,0,double('NaN'));
+INSERT INTO num_exp_add VALUES (7,1,-83028485);
+INSERT INTO num_exp_sub VALUES (7,1,-83028485);
+INSERT INTO num_exp_mul VALUES (7,1,0);
+-- [SPARK-28315] Decimal can not accept NaN as input
+INSERT INTO num_exp_div VALUES (7,1,double('NaN'));
+INSERT INTO num_exp_add VALUES (7,2,-117366977.215397047);
+INSERT INTO num_exp_sub VALUES (7,2,-48689992.784602953);
+INSERT INTO num_exp_mul VALUES (7,2,2851072985828710.485883795);
+INSERT INTO num_exp_div VALUES (7,2,2.41794207151503385700);
+INSERT INTO num_exp_add VALUES (7,3,-83028480.69);
+INSERT INTO num_exp_sub VALUES (7,3,-83028489.31);
+INSERT INTO num_exp_mul VALUES (7,3,-357852770.35);
+INSERT INTO num_exp_div VALUES (7,3,-19264149.65197215777262180974);
+INSERT INTO num_exp_add VALUES (7,4,-75229023.5881);
+INSERT INTO num_exp_sub VALUES (7,4,-90827946.4119);
+INSERT INTO num_exp_mul VALUES (7,4,-647577464846017.9715);
+INSERT INTO num_exp_div VALUES (7,4,-10.64541262725136247686);
+INSERT INTO num_exp_add VALUES (7,5,-83012087.961509);
+INSERT INTO num_exp_sub VALUES (7,5,-83044882.038491);
+INSERT INTO num_exp_mul VALUES (7,5,-1361421264394.416135);
+INSERT INTO num_exp_div VALUES (7,5,-5063.62688881730941836574);
+INSERT INTO num_exp_add VALUES (7,6,-82934583.42236974);
+INSERT INTO num_exp_sub VALUES (7,6,-83122386.57763026);
+INSERT INTO num_exp_mul VALUES (7,6,-7796505729750.37795610);
+INSERT INTO num_exp_div VALUES (7,6,-884.20756174009028770294);
+INSERT INTO num_exp_add VALUES (7,7,-166056970);
+INSERT INTO num_exp_sub VALUES (7,7,0);
+INSERT INTO num_exp_mul VALUES (7,7,6893729321395225);
+INSERT INTO num_exp_div VALUES (7,7,1.00000000000000000000);
+INSERT INTO num_exp_add VALUES (7,8,-82953604);
+INSERT INTO num_exp_sub VALUES (7,8,-83103366);
+INSERT INTO num_exp_mul VALUES (7,8,-6217255985285);
+INSERT INTO num_exp_div VALUES (7,8,-1108.80577182462841041118);
+INSERT INTO num_exp_add VALUES (7,9,-107955289.045047420);
+INSERT INTO num_exp_sub VALUES (7,9,-58101680.954952580);
+INSERT INTO num_exp_mul VALUES (7,9,2069634775752159.035758700);
+INSERT INTO num_exp_div VALUES (7,9,3.33089171198810413382);
+INSERT INTO num_exp_add VALUES (8,0,74881);
+INSERT INTO num_exp_sub VALUES (8,0,74881);
+INSERT INTO num_exp_mul VALUES (8,0,0);
+-- [SPARK-28315] Decimal can not accept NaN as input
+INSERT INTO num_exp_div VALUES (8,0,double('NaN'));
+INSERT INTO num_exp_add VALUES (8,1,74881);
+INSERT INTO num_exp_sub VALUES (8,1,74881);
+INSERT INTO num_exp_mul VALUES (8,1,0);
+-- [SPARK-28315] Decimal can not accept NaN as input
+INSERT INTO num_exp_div VALUES (8,1,double('NaN'));
+INSERT INTO num_exp_add VALUES (8,2,-34263611.215397047);
+INSERT INTO num_exp_sub VALUES (8,2,34413373.215397047);
+INSERT INTO num_exp_mul VALUES (8,2,-2571300635581.146276407);
+INSERT INTO num_exp_div VALUES (8,2,-.00218067233500788615);
+INSERT INTO num_exp_add VALUES (8,3,74885.31);
+INSERT INTO num_exp_sub VALUES (8,3,74876.69);
+INSERT INTO num_exp_mul VALUES (8,3,322737.11);
+INSERT INTO num_exp_div VALUES (8,3,17373.78190255220417633410);
+INSERT INTO num_exp_add VALUES (8,4,7874342.4119);
+INSERT INTO num_exp_sub VALUES (8,4,-7724580.4119);
+INSERT INTO num_exp_mul VALUES (8,4,584031469984.4839);
+INSERT INTO num_exp_div VALUES (8,4,.00960079113741758956);
+INSERT INTO num_exp_add VALUES (8,5,91278.038491);
+INSERT INTO num_exp_sub VALUES (8,5,58483.961509);
+INSERT INTO num_exp_mul VALUES (8,5,1227826639.244571);
+INSERT INTO num_exp_div VALUES (8,5,4.56673929509287019456);
+INSERT INTO num_exp_add VALUES (8,6,168782.57763026);
+INSERT INTO num_exp_sub VALUES (8,6,-19020.57763026);
+INSERT INTO num_exp_mul VALUES (8,6,7031444034.53149906);
+INSERT INTO num_exp_div VALUES (8,6,.79744134113322314424);
+INSERT INTO num_exp_add VALUES (8,7,-82953604);
+INSERT INTO num_exp_sub VALUES (8,7,83103366);
+INSERT INTO num_exp_mul VALUES (8,7,-6217255985285);
+INSERT INTO num_exp_div VALUES (8,7,-.00090187120721280172);
+INSERT INTO num_exp_add VALUES (8,8,149762);
+INSERT INTO num_exp_sub VALUES (8,8,0);
+INSERT INTO num_exp_mul VALUES (8,8,5607164161);
+INSERT INTO num_exp_div VALUES (8,8,1.00000000000000000000);
+INSERT INTO num_exp_add VALUES (8,9,-24851923.045047420);
+INSERT INTO num_exp_sub VALUES (8,9,25001685.045047420);
+INSERT INTO num_exp_mul VALUES (8,9,-1866544013697.195857020);
+INSERT INTO num_exp_div VALUES (8,9,-.00300403532938582735);
+INSERT INTO num_exp_add VALUES (9,0,-24926804.045047420);
+INSERT INTO num_exp_sub VALUES (9,0,-24926804.045047420);
+INSERT INTO num_exp_mul VALUES (9,0,0);
+-- [SPARK-28315] Decimal can not accept NaN as input
+INSERT INTO num_exp_div VALUES (9,0,double('NaN'));
+INSERT INTO num_exp_add VALUES (9,1,-24926804.045047420);
+INSERT INTO num_exp_sub VALUES (9,1,-24926804.045047420);
+INSERT INTO num_exp_mul VALUES (9,1,0);
+-- [SPARK-28315] Decimal can not accept NaN as input
+INSERT INTO num_exp_div VALUES (9,1,double('NaN'));
+INSERT INTO num_exp_add VALUES (9,2,-59265296.260444467);
+INSERT INTO num_exp_sub VALUES (9,2,9411688.170349627);
+INSERT INTO num_exp_mul VALUES (9,2,855948866655588.453741509242968740);
+INSERT INTO num_exp_div VALUES (9,2,.72591434384152961526);
+INSERT INTO num_exp_add VALUES (9,3,-24926799.735047420);
+INSERT INTO num_exp_sub VALUES (9,3,-24926808.355047420);
+INSERT INTO num_exp_mul VALUES (9,3,-107434525.43415438020);
+INSERT INTO num_exp_div VALUES (9,3,-5783481.21694835730858468677);
+INSERT INTO num_exp_add VALUES (9,4,-17127342.633147420);
+INSERT INTO num_exp_sub VALUES (9,4,-32726265.456947420);
+INSERT INTO num_exp_mul VALUES (9,4,-194415646271340.1815956522980);
+INSERT INTO num_exp_div VALUES (9,4,-3.19596478892958416484);
+INSERT INTO num_exp_add VALUES (9,5,-24910407.006556420);
+INSERT INTO num_exp_sub VALUES (9,5,-24943201.083538420);
+INSERT INTO num_exp_mul VALUES (9,5,-408725765384.257043660243220);
+INSERT INTO num_exp_div VALUES (9,5,-1520.20159364322004505807);
+INSERT INTO num_exp_add VALUES (9,6,-24832902.467417160);
+INSERT INTO num_exp_sub VALUES (9,6,-25020705.622677680);
+INSERT INTO num_exp_mul VALUES (9,6,-2340666225110.29929521292692920);
+INSERT INTO num_exp_div VALUES (9,6,-265.45671195426965751280);
+INSERT INTO num_exp_add VALUES (9,7,-107955289.045047420);
+INSERT INTO num_exp_sub VALUES (9,7,58101680.954952580);
+INSERT INTO num_exp_mul VALUES (9,7,2069634775752159.035758700);
+INSERT INTO num_exp_div VALUES (9,7,.30021990699995814689);
+INSERT INTO num_exp_add VALUES (9,8,-24851923.045047420);
+INSERT INTO num_exp_sub VALUES (9,8,-25001685.045047420);
+INSERT INTO num_exp_mul VALUES (9,8,-1866544013697.195857020);
+INSERT INTO num_exp_div VALUES (9,8,-332.88556569820675471748);
+INSERT INTO num_exp_add VALUES (9,9,-49853608.090094840);
+INSERT INTO num_exp_sub VALUES (9,9,0);
+INSERT INTO num_exp_mul VALUES (9,9,621345559900192.420120630048656400);
+INSERT INTO num_exp_div VALUES (9,9,1.00000000000000000000);
-- COMMIT TRANSACTION;
-- BEGIN TRANSACTION;
-INSERT INTO num_exp_sqrt VALUES (0,'0');
-INSERT INTO num_exp_sqrt VALUES (1,'0');
-INSERT INTO num_exp_sqrt VALUES (2,'5859.90547836712524903505');
-INSERT INTO num_exp_sqrt VALUES (3,'2.07605394920266944396');
-INSERT INTO num_exp_sqrt VALUES (4,'2792.75158435189147418923');
-INSERT INTO num_exp_sqrt VALUES (5,'128.05092147657509145473');
-INSERT INTO num_exp_sqrt VALUES (6,'306.43364311096782703406');
-INSERT INTO num_exp_sqrt VALUES (7,'9111.99676251039939975230');
-INSERT INTO num_exp_sqrt VALUES (8,'273.64392922189960397542');
-INSERT INTO num_exp_sqrt VALUES (9,'4992.67503899937593364766');
+-- PostgreSQL implicitly casts string literals to decimal values, but
+-- Spark does not support that kind of implicit cast. To test all the INSERT queries below,
+-- we rewrote them using typed literals.
+INSERT INTO num_exp_sqrt VALUES (0,0);
+INSERT INTO num_exp_sqrt VALUES (1,0);
+INSERT INTO num_exp_sqrt VALUES (2,5859.90547836712524903505);
+INSERT INTO num_exp_sqrt VALUES (3,2.07605394920266944396);
+INSERT INTO num_exp_sqrt VALUES (4,2792.75158435189147418923);
+INSERT INTO num_exp_sqrt VALUES (5,128.05092147657509145473);
+INSERT INTO num_exp_sqrt VALUES (6,306.43364311096782703406);
+INSERT INTO num_exp_sqrt VALUES (7,9111.99676251039939975230);
+INSERT INTO num_exp_sqrt VALUES (8,273.64392922189960397542);
+INSERT INTO num_exp_sqrt VALUES (9,4992.67503899937593364766);
-- COMMIT TRANSACTION;
-- BEGIN TRANSACTION;
-INSERT INTO num_exp_ln VALUES (0,'NaN');
-INSERT INTO num_exp_ln VALUES (1,'NaN');
-INSERT INTO num_exp_ln VALUES (2,'17.35177750493897715514');
-INSERT INTO num_exp_ln VALUES (3,'1.46093790411565641971');
-INSERT INTO num_exp_ln VALUES (4,'15.86956523951936572464');
-INSERT INTO num_exp_ln VALUES (5,'9.70485601768871834038');
-INSERT INTO num_exp_ln VALUES (6,'11.45000246622944403127');
-INSERT INTO num_exp_ln VALUES (7,'18.23469429965478772991');
-INSERT INTO num_exp_ln VALUES (8,'11.22365546576315513668');
-INSERT INTO num_exp_ln VALUES (9,'17.03145425013166006962');
+-- PostgreSQL implicitly casts string literals to decimal values, but
+-- Spark does not support that kind of implicit cast. To test all the INSERT queries below,
+-- we rewrote them using typed literals.
+-- [SPARK-28315] Decimal can not accept NaN as input
+INSERT INTO num_exp_ln VALUES (0,double('NaN'));
+INSERT INTO num_exp_ln VALUES (1,double('NaN'));
+INSERT INTO num_exp_ln VALUES (2,17.35177750493897715514);
+INSERT INTO num_exp_ln VALUES (3,1.46093790411565641971);
+INSERT INTO num_exp_ln VALUES (4,15.86956523951936572464);
+INSERT INTO num_exp_ln VALUES (5,9.70485601768871834038);
+INSERT INTO num_exp_ln VALUES (6,11.45000246622944403127);
+INSERT INTO num_exp_ln VALUES (7,18.23469429965478772991);
+INSERT INTO num_exp_ln VALUES (8,11.22365546576315513668);
+INSERT INTO num_exp_ln VALUES (9,17.03145425013166006962);
-- COMMIT TRANSACTION;
-- BEGIN TRANSACTION;
-INSERT INTO num_exp_log10 VALUES (0,'NaN');
-INSERT INTO num_exp_log10 VALUES (1,'NaN');
-INSERT INTO num_exp_log10 VALUES (2,'7.53578122160797276459');
-INSERT INTO num_exp_log10 VALUES (3,'.63447727016073160075');
-INSERT INTO num_exp_log10 VALUES (4,'6.89206461372691743345');
-INSERT INTO num_exp_log10 VALUES (5,'4.21476541614777768626');
-INSERT INTO num_exp_log10 VALUES (6,'4.97267288886207207671');
-INSERT INTO num_exp_log10 VALUES (7,'7.91922711353275546914');
-INSERT INTO num_exp_log10 VALUES (8,'4.87437163556421004138');
-INSERT INTO num_exp_log10 VALUES (9,'7.39666659961986567059');
+-- PostgreSQL implicitly casts string literals to decimal values, but
+-- Spark does not support that kind of implicit cast. To test all the INSERT queries below,
+-- we rewrote them using typed literals.
+-- [SPARK-28315] Decimal can not accept NaN as input
+INSERT INTO num_exp_log10 VALUES (0,double('NaN'));
+INSERT INTO num_exp_log10 VALUES (1,double('NaN'));
+INSERT INTO num_exp_log10 VALUES (2,7.53578122160797276459);
+INSERT INTO num_exp_log10 VALUES (3,.63447727016073160075);
+INSERT INTO num_exp_log10 VALUES (4,6.89206461372691743345);
+INSERT INTO num_exp_log10 VALUES (5,4.21476541614777768626);
+INSERT INTO num_exp_log10 VALUES (6,4.97267288886207207671);
+INSERT INTO num_exp_log10 VALUES (7,7.91922711353275546914);
+INSERT INTO num_exp_log10 VALUES (8,4.87437163556421004138);
+INSERT INTO num_exp_log10 VALUES (9,7.39666659961986567059);
-- COMMIT TRANSACTION;
-- BEGIN TRANSACTION;
-INSERT INTO num_exp_power_10_ln VALUES (0,'NaN');
-INSERT INTO num_exp_power_10_ln VALUES (1,'NaN');
-INSERT INTO num_exp_power_10_ln VALUES (2,'224790267919917955.13261618583642653184');
-INSERT INTO num_exp_power_10_ln VALUES (3,'28.90266599445155957393');
-INSERT INTO num_exp_power_10_ln VALUES (4,'7405685069594999.07733999469386277636');
-INSERT INTO num_exp_power_10_ln VALUES (5,'5068226527.32127265408584640098');
-INSERT INTO num_exp_power_10_ln VALUES (6,'281839893606.99372343357047819067');
-INSERT INTO num_exp_power_10_ln VALUES (7,'1716699575118597095.42330819910640247627');
-INSERT INTO num_exp_power_10_ln VALUES (8,'167361463828.07491320069016125952');
-INSERT INTO num_exp_power_10_ln VALUES (9,'107511333880052007.04141124673540337457');
+-- PostgreSQL implicitly casts string literals to decimal values, but
+-- Spark does not support that kind of implicit cast. To test all the INSERT queries below,
+-- we rewrote them using typed literals.
+-- [SPARK-28315] Decimal can not accept NaN as input
+INSERT INTO num_exp_power_10_ln VALUES (0,double('NaN'));
+INSERT INTO num_exp_power_10_ln VALUES (1,double('NaN'));
+INSERT INTO num_exp_power_10_ln VALUES (2,224790267919917955.13261618583642653184);
+INSERT INTO num_exp_power_10_ln VALUES (3,28.90266599445155957393);
+INSERT INTO num_exp_power_10_ln VALUES (4,7405685069594999.07733999469386277636);
+INSERT INTO num_exp_power_10_ln VALUES (5,5068226527.32127265408584640098);
+INSERT INTO num_exp_power_10_ln VALUES (6,281839893606.99372343357047819067);
+-- In Spark, decimal can only support precision up to 38
+INSERT INTO num_exp_power_10_ln VALUES (7,1716699575118597095.42330819910640247627);
+INSERT INTO num_exp_power_10_ln VALUES (8,167361463828.07491320069016125952);
+INSERT INTO num_exp_power_10_ln VALUES (9,107511333880052007.04141124673540337457);
-- COMMIT TRANSACTION;
-- BEGIN TRANSACTION;
-INSERT INTO num_data VALUES (0, '0');
-INSERT INTO num_data VALUES (1, '0');
-INSERT INTO num_data VALUES (2, '-34338492.215397047');
-INSERT INTO num_data VALUES (3, '4.31');
-INSERT INTO num_data VALUES (4, '7799461.4119');
-INSERT INTO num_data VALUES (5, '16397.038491');
-INSERT INTO num_data VALUES (6, '93901.57763026');
-INSERT INTO num_data VALUES (7, '-83028485');
-INSERT INTO num_data VALUES (8, '74881');
-INSERT INTO num_data VALUES (9, '-24926804.045047420');
+-- PostgreSQL implicitly casts string literals to values of decimal type, but
+-- Spark does not support that kind of implicit cast. To test all the INSERT queries below,
+-- we rewrote them to use explicitly typed literals.
+INSERT INTO num_data VALUES (0, 0);
+INSERT INTO num_data VALUES (1, 0);
+INSERT INTO num_data VALUES (2, -34338492.215397047);
+INSERT INTO num_data VALUES (3, 4.31);
+INSERT INTO num_data VALUES (4, 7799461.4119);
+INSERT INTO num_data VALUES (5, 16397.038491);
+INSERT INTO num_data VALUES (6, 93901.57763026);
+INSERT INTO num_data VALUES (7, -83028485);
+INSERT INTO num_data VALUES (8, 74881);
+INSERT INTO num_data VALUES (9, -24926804.045047420);
-- COMMIT TRANSACTION;
SELECT * FROM num_data;
@@ -657,16 +699,22 @@ SELECT AVG(val) FROM num_data;
-- Check for appropriate rounding and overflow
CREATE TABLE fract_only (id int, val decimal(4,4)) USING parquet;
-INSERT INTO fract_only VALUES (1, '0.0');
-INSERT INTO fract_only VALUES (2, '0.1');
+INSERT INTO fract_only VALUES (1, 0.0);
+INSERT INTO fract_only VALUES (2, 0.1);
-- [SPARK-27923] PostgreSQL throws an exception but Spark SQL returns NULL
-- INSERT INTO fract_only VALUES (3, '1.0'); -- should fail
-INSERT INTO fract_only VALUES (4, '-0.9999');
-INSERT INTO fract_only VALUES (5, '0.99994');
+-- PostgreSQL implicitly casts string literals to values of decimal type, but
+-- Spark does not support that kind of implicit cast. To test all the INSERT queries below,
+-- we rewrote them to use explicitly typed literals.
+INSERT INTO fract_only VALUES (4, -0.9999);
+INSERT INTO fract_only VALUES (5, 0.99994);
-- [SPARK-27923] PostgreSQL throws an exception but Spark SQL returns NULL (see the note below)
-- INSERT INTO fract_only VALUES (6, '0.99995'); -- should fail
-INSERT INTO fract_only VALUES (7, '0.00001');
-INSERT INTO fract_only VALUES (8, '0.00017');
+-- PostgreSQL implicitly casts string literals to values of decimal type, but
+-- Spark does not support that kind of implicit cast. To test all the INSERT queries below,
+-- we rewrote them to use explicitly typed literals.
+INSERT INTO fract_only VALUES (7, 0.00001);
+INSERT INTO fract_only VALUES (8, 0.00017);
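+-- For illustration only: val is declared as decimal(4, 4) above, so it holds values strictly
+-- between -1 and 1 with four fractional digits, and the commented-out rows overflow it.
+-- Per SPARK-27923, Spark yields NULL on such an overflow where PostgreSQL raises an error:
+--   SELECT CAST(1.0 AS DECIMAL(4, 4));   -- NULL in Spark, error in PostgreSQL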
SELECT * FROM fract_only;
DROP TABLE fract_only;
@@ -682,13 +730,16 @@ SELECT decimal(float('-Infinity'));
-- Simple check that ceil(), floor(), and round() work correctly
CREATE TABLE ceil_floor_round (a decimal(38, 18)) USING parquet;
-INSERT INTO ceil_floor_round VALUES ('-5.5');
-INSERT INTO ceil_floor_round VALUES ('-5.499999');
-INSERT INTO ceil_floor_round VALUES ('9.5');
-INSERT INTO ceil_floor_round VALUES ('9.4999999');
-INSERT INTO ceil_floor_round VALUES ('0.0');
-INSERT INTO ceil_floor_round VALUES ('0.0000001');
-INSERT INTO ceil_floor_round VALUES ('-0.000001');
+-- PostgreSQL implicitly casts string literals to values of decimal type, but
+-- Spark does not support that kind of implicit cast. To test all the INSERT queries below,
+-- we rewrote them to use explicitly typed literals.
+INSERT INTO ceil_floor_round VALUES (-5.5);
+INSERT INTO ceil_floor_round VALUES (-5.499999);
+INSERT INTO ceil_floor_round VALUES (9.5);
+INSERT INTO ceil_floor_round VALUES (9.4999999);
+INSERT INTO ceil_floor_round VALUES (0.0);
+INSERT INTO ceil_floor_round VALUES (0.0000001);
+INSERT INTO ceil_floor_round VALUES (-0.000001);
SELECT a, ceil(a), ceiling(a), floor(a), round(a) FROM ceil_floor_round;
DROP TABLE ceil_floor_round;
@@ -853,11 +904,14 @@ DROP TABLE ceil_floor_round;
CREATE TABLE num_input_test (n1 decimal(38, 18)) USING parquet;
-- good inputs
-INSERT INTO num_input_test VALUES (trim(' 123'));
-INSERT INTO num_input_test VALUES (trim(' 3245874 '));
-INSERT INTO num_input_test VALUES (trim(' -93853'));
-INSERT INTO num_input_test VALUES ('555.50');
-INSERT INTO num_input_test VALUES ('-555.50');
+-- PostgreSQL implicitly casts string literals to values of decimal type, but
+-- Spark does not support that kind of implicit cast. To test all the INSERT queries below,
+-- we rewrote them to use explicitly typed literals.
+INSERT INTO num_input_test VALUES (double(trim(' 123')));
+INSERT INTO num_input_test VALUES (double(trim(' 3245874 ')));
+INSERT INTO num_input_test VALUES (double(trim(' -93853')));
+INSERT INTO num_input_test VALUES (555.50);
+INSERT INTO num_input_test VALUES (-555.50);
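+-- For illustration only: double(...) is the function-call spelling of an explicit cast, so
+-- the first rewritten row above is equivalent to:
+--   INSERT INTO num_input_test VALUES (CAST(TRIM(' 123') AS DOUBLE));
+-- trim() strips the surrounding blanks before the cast, mirroring PostgreSQL's numeric input,
+-- which ignores leading and trailing whitespace.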
-- [SPARK-28315] Decimal can not accept NaN as input
-- INSERT INTO num_input_test VALUES (trim('NaN '));
-- INSERT INTO num_input_test VALUES (trim(' nan'));
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/text.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/text.sql
index 7abf903bc6bee..05953123da86f 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/text.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/text.sql
@@ -44,11 +44,7 @@ select concat_ws(',',10,20,null,30);
select concat_ws('',10,20,null,30);
select concat_ws(NULL,10,20,null,30) is null;
select reverse('abcde');
--- [SPARK-28036] Built-in udf left/right has inconsistent behavior
--- [SPARK-28479][SPARK-28989] Parser error when enabling ANSI mode
-set spark.sql.ansi.enabled=false;
select i, left('ahoj', i), right('ahoj', i) from range(-5, 6) t(i) order by i;
-set spark.sql.ansi.enabled=true;
-- [SPARK-28037] Add built-in String Functions: quote_literal
-- select quote_literal('');
-- select quote_literal('abc''');
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/timestamp.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/timestamp.sql
index 260e8ea93d22d..bf69da295a960 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/timestamp.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/timestamp.sql
@@ -16,19 +16,23 @@ CREATE TABLE TIMESTAMP_TBL (d1 timestamp) USING parquet;
-- block is entered exactly at local midnight; then 'now' and 'today' have
-- the same values and the counts will come out different.
-INSERT INTO TIMESTAMP_TBL VALUES ('now');
+-- PostgreSQL implicitly casts string literals to values of timestamp type, but
+-- Spark does not support that kind of implicit cast.
+INSERT INTO TIMESTAMP_TBL VALUES (timestamp('now'));
-- SELECT pg_sleep(0.1);
-- BEGIN;
-INSERT INTO TIMESTAMP_TBL VALUES ('now');
-INSERT INTO TIMESTAMP_TBL VALUES ('today');
-INSERT INTO TIMESTAMP_TBL VALUES ('yesterday');
-INSERT INTO TIMESTAMP_TBL VALUES ('tomorrow');
+-- PostgreSQL implicitly casts string literals to values of timestamp type, but
+-- Spark does not support that kind of implicit cast.
+INSERT INTO TIMESTAMP_TBL VALUES (timestamp('now'));
+INSERT INTO TIMESTAMP_TBL VALUES (timestamp('today'));
+INSERT INTO TIMESTAMP_TBL VALUES (timestamp('yesterday'));
+INSERT INTO TIMESTAMP_TBL VALUES (timestamp('tomorrow'));
-- time zone should be ignored by this data type
-INSERT INTO TIMESTAMP_TBL VALUES ('tomorrow EST');
+INSERT INTO TIMESTAMP_TBL VALUES (timestamp('tomorrow EST'));
-- [SPARK-29024] Ignore case while resolving time zones
-INSERT INTO TIMESTAMP_TBL VALUES ('tomorrow Zulu');
+INSERT INTO TIMESTAMP_TBL VALUES (timestamp('tomorrow Zulu'));
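+-- For illustration only: timestamp('now') is the function-call spelling of an explicit cast,
+-- equivalent to CAST('now' AS TIMESTAMP); the special strings 'now', 'today', 'yesterday',
+-- 'tomorrow' and 'epoch' are resolved during the cast, which these rewritten rows rely on:
+--   SELECT CAST('yesterday' AS TIMESTAMP) < CAST('tomorrow' AS TIMESTAMP);   -- true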
SELECT count(*) AS One FROM TIMESTAMP_TBL WHERE d1 = timestamp 'today';
SELECT count(*) AS Three FROM TIMESTAMP_TBL WHERE d1 = timestamp 'tomorrow';
@@ -54,7 +58,9 @@ TRUNCATE TABLE TIMESTAMP_TBL;
-- Special values
-- INSERT INTO TIMESTAMP_TBL VALUES ('-infinity');
-- INSERT INTO TIMESTAMP_TBL VALUES ('infinity');
-INSERT INTO TIMESTAMP_TBL VALUES ('epoch');
+-- PostgreSQL implicitly casts string literals to values of timestamp type, but
+-- Spark does not support that kind of implicit cast.
+INSERT INTO TIMESTAMP_TBL VALUES (timestamp('epoch'));
-- [SPARK-27923] Spark SQL inserts these obsolete special values as NULL
-- Obsolete special values
-- INSERT INTO TIMESTAMP_TBL VALUES ('invalid');
@@ -73,14 +79,16 @@ INSERT INTO TIMESTAMP_TBL VALUES ('epoch');
-- INSERT INTO TIMESTAMP_TBL VALUES ('Mon Feb 10 17:32:01.6 1997 PST');
-- ISO 8601 format
-INSERT INTO TIMESTAMP_TBL VALUES ('1997-01-02');
-INSERT INTO TIMESTAMP_TBL VALUES ('1997-01-02 03:04:05');
-INSERT INTO TIMESTAMP_TBL VALUES ('1997-02-10 17:32:01-08');
+-- PostgreSQL implicitly casts string literals to values of timestamp type, but
+-- Spark does not support that kind of implicit cast.
+INSERT INTO TIMESTAMP_TBL VALUES (timestamp('1997-01-02'));
+INSERT INTO TIMESTAMP_TBL VALUES (timestamp('1997-01-02 03:04:05'));
+INSERT INTO TIMESTAMP_TBL VALUES (timestamp('1997-02-10 17:32:01-08'));
-- INSERT INTO TIMESTAMP_TBL VALUES ('1997-02-10 17:32:01-0800');
-- INSERT INTO TIMESTAMP_TBL VALUES ('1997-02-10 17:32:01 -08:00');
-- INSERT INTO TIMESTAMP_TBL VALUES ('19970210 173201 -0800');
-- INSERT INTO TIMESTAMP_TBL VALUES ('1997-06-10 17:32:01 -07:00');
-INSERT INTO TIMESTAMP_TBL VALUES ('2001-09-22T18:19:20');
+INSERT INTO TIMESTAMP_TBL VALUES (timestamp('2001-09-22T18:19:20'));
-- POSIX format (note that the timezone abbrev is just decoration here)
-- INSERT INTO TIMESTAMP_TBL VALUES ('2000-03-15 08:14:01 GMT+8');
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part1.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part1.sql
index ae2a015ada245..087d7a5befd19 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part1.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part1.sql
@@ -3,6 +3,11 @@
-- Window Functions Testing
-- https://github.com/postgres/postgres/blob/REL_12_STABLE/src/test/regress/sql/window.sql#L1-L319
+-- Test window operator with codegen on and off.
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=true
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=NO_CODEGEN
+
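+-- A note on the CONFIG_DIM convention (as understood here; illustrative only): each CONFIG_DIM1
+-- line above is one alternative configuration, and the SQL test harness runs this file once per
+-- alternative: with whole-stage codegen, with expression codegen only, and fully interpreted.
+-- The last alternative effectively corresponds to running the queries under:
+--   SET spark.sql.codegen.wholeStage=false;
+--   SET spark.sql.codegen.factoryMode=NO_CODEGEN;
+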
CREATE TEMPORARY VIEW tenk2 AS SELECT * FROM tenk1;
-- [SPARK-29540] Thrift in some cases can't parse string to date
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part2.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part2.sql
index 728e8cab0c3ba..395149e48d5c8 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part2.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part2.sql
@@ -3,6 +3,11 @@
-- Window Functions Testing
-- https://github.com/postgres/postgres/blob/REL_12_STABLE/src/test/regress/sql/window.sql#L320-562
+-- Test window operator with codegen on and off.
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=true
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=NO_CODEGEN
+
CREATE TABLE empsalary (
depname string,
empno integer,
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part3.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part3.sql
index 205c7d391a973..8187f8a2773ff 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part3.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part3.sql
@@ -3,6 +3,11 @@
-- Window Functions Testing
-- https://github.com/postgres/postgres/blob/REL_12_STABLE/src/test/regress/sql/window.sql#L564-L911
+-- Test window operator with codegen on and off.
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=true
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=NO_CODEGEN
+
CREATE TEMPORARY VIEW tenk2 AS SELECT * FROM tenk1;
CREATE TABLE empsalary (
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part4.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part4.sql
index 456b390fca6c3..64ba8e3b7a5ad 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part4.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/window_part4.sql
@@ -3,6 +3,11 @@
-- Window Functions Testing
-- https://github.com/postgres/postgres/blob/REL_12_STABLE/src/test/regress/sql/window.sql#L913-L1278
+-- Test window operator with codegen on and off.
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=true
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=NO_CODEGEN
+
-- Spark doesn't handle UDFs in SQL
-- test user-defined window function with named args and default args
-- CREATE FUNCTION nth_value_def(val anyelement, n integer = 1) RETURNS anyelement
diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-aggregate.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-aggregate.sql
index b5f458f2cb184..ae6a9641aae66 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-aggregate.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-aggregate.sql
@@ -1,5 +1,10 @@
-- Tests aggregate expressions in outer query and EXISTS subquery.
+-- Test aggregate operator with codegen on and off.
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=true
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=NO_CODEGEN
+
CREATE TEMPORARY VIEW EMP AS SELECT * FROM VALUES
(100, "emp 1", date "2005-01-01", 100.00D, 10),
(100, "emp 1", date "2005-01-01", 100.00D, 10),
diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-joins-and-set-ops.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-joins-and-set-ops.sql
index cefc3fe6272ab..667573b30d265 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-joins-and-set-ops.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-joins-and-set-ops.sql
@@ -1,9 +1,17 @@
-- Tests EXISTS subquery support. Tests Exists subquery
-- used in Joins (both when joins occur in outer and subquery blocks)
--- List of configuration the test suite is run against:
---SET spark.sql.autoBroadcastJoinThreshold=10485760
---SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true
---SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false
+
+-- There are 2 dimensions we want to test
+-- 1. run with broadcast hash join, sort merge join or shuffle hash join.
+-- 2. run with whole-stage-codegen, operator codegen or no codegen.
+
+--CONFIG_DIM1 spark.sql.autoBroadcastJoinThreshold=10485760
+--CONFIG_DIM1 spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true
+--CONFIG_DIM1 spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false
+
+--CONFIG_DIM2 spark.sql.codegen.wholeStage=true
+--CONFIG_DIM2 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+--CONFIG_DIM2 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=NO_CODEGEN
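+
+-- A note on the CONFIG_DIM convention (as understood here; illustrative only): entries within a
+-- dimension are alternatives, and the file is run once per combination across dimensions,
+-- i.e. 3 join strategies x 3 codegen modes = 9 runs here. One such combination effectively
+-- corresponds to:
+--   SET spark.sql.autoBroadcastJoinThreshold=-1;
+--   SET spark.sql.join.preferSortMergeJoin=false;    -- prefer shuffle hash join
+--   SET spark.sql.codegen.wholeStage=false;
+--   SET spark.sql.codegen.factoryMode=NO_CODEGEN;    -- interpreted expression evaluation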
CREATE TEMPORARY VIEW EMP AS SELECT * FROM VALUES
(100, "emp 1", date "2005-01-01", 100.00D, 10),
diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-orderby-limit.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-orderby-limit.sql
index 19fc18833760c..580fc1d4162eb 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-orderby-limit.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/exists-subquery/exists-orderby-limit.sql
@@ -1,5 +1,10 @@
-- Tests EXISTS subquery support with ORDER BY and LIMIT clauses.
+-- Test sort operator with codegen on and off.
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=true
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=NO_CODEGEN
+
CREATE TEMPORARY VIEW EMP AS SELECT * FROM VALUES
(100, "emp 1", date "2005-01-01", 100.00D, 10),
(100, "emp 1", date "2005-01-01", 100.00D, 10),
diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-group-by.sql
index b1d96b32c2478..b06e1cccca5ab 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-group-by.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-group-by.sql
@@ -1,6 +1,11 @@
-- A test suite for GROUP BY in parent side, subquery, and both predicate subquery
-- It includes correlated cases.
+-- Test aggregate operator with codegen on and off.
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=true
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=NO_CODEGEN
+
create temporary view t1 as select * from values
("t1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'),
("t1b", 8S, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'),
diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql
index cd350a98e130b..200a71ebbb622 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql
@@ -1,9 +1,17 @@
-- A test suite for IN JOINS in parent side, subquery, and both predicate subquery
-- It includes correlated cases.
--- List of configuration the test suite is run against:
---SET spark.sql.autoBroadcastJoinThreshold=10485760
---SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true
---SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false
+
+-- There are 2 dimensions we want to test
+-- 1. run with broadcast hash join, sort merge join or shuffle hash join.
+-- 2. run with whole-stage-codegen, operator codegen or no codegen.
+
+--CONFIG_DIM1 spark.sql.autoBroadcastJoinThreshold=10485760
+--CONFIG_DIM1 spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true
+--CONFIG_DIM1 spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false
+
+--CONFIG_DIM2 spark.sql.codegen.wholeStage=true
+--CONFIG_DIM2 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+--CONFIG_DIM2 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=NO_CODEGEN
create temporary view t1 as select * from values
("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'),
diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-order-by.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-order-by.sql
index 892e39ff47c1f..042966b0a4e26 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-order-by.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-order-by.sql
@@ -1,6 +1,11 @@
-- A test suite for ORDER BY in parent side, subquery, and both predicate subquery
-- It includes correlated cases.
+-- Test sort operator with codegen on and off.
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=true
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=NO_CODEGEN
+
create temporary view t1 as select * from values
("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'),
("val1b", 8S, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'),
diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-group-by.sql
index 58cf109e136c5..54b74534c1162 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-group-by.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-group-by.sql
@@ -1,6 +1,11 @@
-- A test suite for NOT IN GROUP BY in parent side, subquery, and both predicate subquery
-- It includes correlated cases.
+-- Test aggregate operator with codegen on and off.
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=true
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=NO_CODEGEN
+
create temporary view t1 as select * from values
("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'),
("val1b", 8S, 16, 19L, float(17.0), 25D, 26E2, timestamp '2014-05-04 01:01:00.000', date '2014-05-04'),
diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql
index bebc18a61894b..fcdb667ad4523 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql
@@ -1,9 +1,5 @@
-- A test suite for not-in-joins in parent side, subquery, and both predicate subquery
-- It includes correlated cases.
--- List of configuration the test suite is run against:
---SET spark.sql.autoBroadcastJoinThreshold=10485760
---SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true
---SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false
create temporary view t1 as select * from values
("val1a", 6S, 8, 10L, float(15.0), 20D, 20E2, timestamp '2014-04-04 01:00:00.000', date '2014-04-04'),
diff --git a/sql/core/src/test/resources/sql-tests/inputs/udaf.sql b/sql/core/src/test/resources/sql-tests/inputs/udaf.sql
index 58613a1325dfa..0374d98feb6e6 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/udaf.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/udaf.sql
@@ -1,3 +1,8 @@
+-- Test aggregate operator and UDAF with codegen on and off.
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=true
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=NO_CODEGEN
+
CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES
(1), (2), (3), (4)
as t1(int_col1);
diff --git a/sql/core/src/test/resources/sql-tests/inputs/udf/udf-join-empty-relation.sql b/sql/core/src/test/resources/sql-tests/inputs/udf/udf-join-empty-relation.sql
index 47fb70d02394b..b46206d4530ed 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/udf/udf-join-empty-relation.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/udf/udf-join-empty-relation.sql
@@ -1,8 +1,3 @@
--- List of configuration the test suite is run against:
---SET spark.sql.autoBroadcastJoinThreshold=10485760
---SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true
---SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false
-
-- This test file was converted from join-empty-relation.sql.
CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1) AS GROUPING(a);
diff --git a/sql/core/src/test/resources/sql-tests/inputs/udf/udf-natural-join.sql b/sql/core/src/test/resources/sql-tests/inputs/udf/udf-natural-join.sql
index e5eb812d69a1c..7cf080ea1b4eb 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/udf/udf-natural-join.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/udf/udf-natural-join.sql
@@ -1,8 +1,3 @@
--- List of configuration the test suite is run against:
---SET spark.sql.autoBroadcastJoinThreshold=10485760
---SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true
---SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false
-
-- This test file was converted from natural-join.sql.
create temporary view nt1 as select * from values
diff --git a/sql/core/src/test/resources/sql-tests/inputs/udf/udf-outer-join.sql b/sql/core/src/test/resources/sql-tests/inputs/udf/udf-outer-join.sql
index 4eb0805c9cc67..4b09bcb988d25 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/udf/udf-outer-join.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/udf/udf-outer-join.sql
@@ -1,8 +1,4 @@
-- This test file was converted from outer-join.sql.
--- List of configuration the test suite is run against:
---SET spark.sql.autoBroadcastJoinThreshold=10485760
---SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=true
---SET spark.sql.autoBroadcastJoinThreshold=-1,spark.sql.join.preferSortMergeJoin=false
-- SPARK-17099: Incorrect result when HAVING clause is added to group by query
CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES
diff --git a/sql/core/src/test/resources/sql-tests/inputs/window.sql b/sql/core/src/test/resources/sql-tests/inputs/window.sql
index faab4c61c8640..e25a252418301 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/window.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/window.sql
@@ -1,3 +1,8 @@
+-- Test window operator with codegen on and off.
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=true
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=NO_CODEGEN
+
-- Test data.
CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES
(null, 1L, 1.0D, date("2017-08-01"), timestamp(1501545600), "a"),
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out
index 73bf299c509cf..bceb6bd1d2ea9 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out
@@ -5,7 +5,7 @@
-- !query 0
select interval '1 day' > interval '23 hour'
-- !query 0 schema
-struct<(1 days > 23 hours):boolean>
+struct<(INTERVAL '1 days' > INTERVAL '23 hours'):boolean>
-- !query 0 output
true
@@ -13,7 +13,7 @@ true
-- !query 1
select interval '-1 day' >= interval '-23 hour'
-- !query 1 schema
-struct<(-1 days >= -23 hours):boolean>
+struct<(INTERVAL '-1 days' >= INTERVAL '-23 hours'):boolean>
-- !query 1 output
false
@@ -21,7 +21,7 @@ false
-- !query 2
select interval '-1 day' > null
-- !query 2 schema
-struct<(-1 days > CAST(NULL AS INTERVAL)):boolean>
+struct<(INTERVAL '-1 days' > CAST(NULL AS INTERVAL)):boolean>
-- !query 2 output
NULL
@@ -29,7 +29,7 @@ NULL
-- !query 3
select null > interval '-1 day'
-- !query 3 schema
-struct<(CAST(NULL AS INTERVAL) > -1 days):boolean>
+struct<(CAST(NULL AS INTERVAL) > INTERVAL '-1 days'):boolean>
-- !query 3 output
NULL
@@ -37,7 +37,7 @@ NULL
-- !query 4
select interval '1 minutes' < interval '1 hour'
-- !query 4 schema
-struct<(1 minutes < 1 hours):boolean>
+struct<(INTERVAL '1 minutes' < INTERVAL '1 hours'):boolean>
-- !query 4 output
true
@@ -45,7 +45,7 @@ true
-- !query 5
select interval '-1 day' <= interval '-23 hour'
-- !query 5 schema
-struct<(-1 days <= -23 hours):boolean>
+struct<(INTERVAL '-1 days' <= INTERVAL '-23 hours'):boolean>
-- !query 5 output
true
@@ -53,7 +53,7 @@ true
-- !query 6
select interval '1 year' = interval '360 days'
-- !query 6 schema
-struct<(1 years = 360 days):boolean>
+struct<(INTERVAL '1 years' = INTERVAL '360 days'):boolean>
-- !query 6 output
true
@@ -61,7 +61,7 @@ true
-- !query 7
select interval '1 year 2 month' = interval '420 days'
-- !query 7 schema
-struct<(1 years 2 months = 420 days):boolean>
+struct<(INTERVAL '1 years 2 months' = INTERVAL '420 days'):boolean>
-- !query 7 output
true
@@ -69,7 +69,7 @@ true
-- !query 8
select interval '1 year' = interval '365 days'
-- !query 8 schema
-struct<(1 years = 365 days):boolean>
+struct<(INTERVAL '1 years' = INTERVAL '365 days'):boolean>
-- !query 8 output
false
@@ -77,7 +77,7 @@ false
-- !query 9
select interval '1 month' = interval '30 days'
-- !query 9 schema
-struct<(1 months = 30 days):boolean>
+struct<(INTERVAL '1 months' = INTERVAL '30 days'):boolean>
-- !query 9 output
true
@@ -85,7 +85,7 @@ true
-- !query 10
select interval '1 minutes' = interval '1 hour'
-- !query 10 schema
-struct<(1 minutes = 1 hours):boolean>
+struct<(INTERVAL '1 minutes' = INTERVAL '1 hours'):boolean>
-- !query 10 output
false
@@ -93,7 +93,7 @@ false
-- !query 11
select interval '1 minutes' = null
-- !query 11 schema
-struct<(1 minutes = CAST(NULL AS INTERVAL)):boolean>
+struct<(INTERVAL '1 minutes' = CAST(NULL AS INTERVAL)):boolean>
-- !query 11 output
NULL
@@ -101,7 +101,7 @@ NULL
-- !query 12
select null = interval '-1 day'
-- !query 12 schema
-struct<(CAST(NULL AS INTERVAL) = -1 days):boolean>
+struct<(CAST(NULL AS INTERVAL) = INTERVAL '-1 days'):boolean>
-- !query 12 output
NULL
@@ -109,7 +109,7 @@ NULL
-- !query 13
select interval '1 minutes' <=> null
-- !query 13 schema
-struct<(1 minutes <=> CAST(NULL AS INTERVAL)):boolean>
+struct<(INTERVAL '1 minutes' <=> CAST(NULL AS INTERVAL)):boolean>
-- !query 13 output
false
@@ -117,7 +117,7 @@ false
-- !query 14
select null <=> interval '1 minutes'
-- !query 14 schema
-struct<(CAST(NULL AS INTERVAL) <=> 1 minutes):boolean>
+struct<(CAST(NULL AS INTERVAL) <=> INTERVAL '1 minutes'):boolean>
-- !query 14 output
false
@@ -125,7 +125,7 @@ false
-- !query 15
select INTERVAL '9 years 1 months -1 weeks -4 days -10 hours -46 minutes' > interval '1 minutes'
-- !query 15 schema
-struct<(9 years 1 months -11 days -10 hours -46 minutes > 1 minutes):boolean>
+struct<(INTERVAL '9 years 1 months -11 days -10 hours -46 minutes' > INTERVAL '1 minutes'):boolean>
-- !query 15 output
true
@@ -143,7 +143,7 @@ struct
-- !query 17
select interval '1 month 120 days' > interval '2 month'
-- !query 17 schema
-struct<(1 months 120 days > 2 months):boolean>
+struct<(INTERVAL '1 months 120 days' > INTERVAL '2 months'):boolean>
-- !query 17 output
true
@@ -151,7 +151,7 @@ true
-- !query 18
select interval '1 month 30 days' = interval '2 month'
-- !query 18 schema
-struct<(1 months 30 days = 2 months):boolean>
+struct<(INTERVAL '1 months 30 days' = INTERVAL '2 months'):boolean>
-- !query 18 output
true
@@ -159,7 +159,7 @@ true
-- !query 19
select interval '1 month 29 days 40 hours' > interval '2 month'
-- !query 19 schema
-struct<(1 months 29 days 40 hours > 2 months):boolean>
+struct<(INTERVAL '1 months 29 days 40 hours' > INTERVAL '2 months'):boolean>
-- !query 19 output
true
@@ -183,7 +183,7 @@ struct
-- !query 22
select 3 * (timestamp'2019-10-15 10:11:12.001002' - date'2019-10-15')
-- !query 22 schema
-struct
+struct
-- !query 22 output
30 hours 33 minutes 36.003006 seconds
@@ -191,7 +191,7 @@ struct
+struct
-- !query 23 output
6 months 21 days 0.000005 seconds
@@ -199,7 +199,7 @@ struct
+struct
-- !query 24 output
16 hours
@@ -207,7 +207,7 @@ struct
+struct
-- !query 25 output
NULL
@@ -215,7 +215,7 @@ NULL
-- !query 26
select interval '2 seconds' / null
-- !query 26 schema
-struct
+struct
-- !query 26 output
NULL
@@ -223,7 +223,7 @@ NULL
-- !query 27
select interval '2 seconds' * null
-- !query 27 schema
-struct
+struct
-- !query 27 output
NULL
@@ -231,7 +231,7 @@ NULL
-- !query 28
select null * interval '2 seconds'
-- !query 28 schema
-struct
+struct
-- !query 28 output
NULL
@@ -239,7 +239,7 @@ NULL
-- !query 29
select -interval '-1 month 1 day -1 second'
-- !query 29 schema
-struct<1 months -1 days 1 seconds:interval>
+struct<(- INTERVAL '-1 months 1 days -1 seconds'):interval>
-- !query 29 output
1 months -1 days 1 seconds
@@ -247,7 +247,7 @@ struct<1 months -1 days 1 seconds:interval>
-- !query 30
select -interval -1 month 1 day -1 second
-- !query 30 schema
-struct<1 months -1 days 1 seconds:interval>
+struct<(- INTERVAL '-1 months 1 days -1 seconds'):interval>
-- !query 30 output
1 months -1 days 1 seconds
@@ -255,7 +255,7 @@ struct<1 months -1 days 1 seconds:interval>
-- !query 31
select +interval '-1 month 1 day -1 second'
-- !query 31 schema
-struct<-1 months 1 days -1 seconds:interval>
+struct
-- !query 31 output
-1 months 1 days -1 seconds
@@ -263,7 +263,7 @@ struct<-1 months 1 days -1 seconds:interval>
-- !query 32
select +interval -1 month 1 day -1 second
-- !query 32 schema
-struct<-1 months 1 days -1 seconds:interval>
+struct
-- !query 32 output
-1 months 1 days -1 seconds
@@ -407,7 +407,7 @@ NULL
-- !query 50
select justify_days(interval '1 month 59 day 25 hour')
-- !query 50 schema
-struct
+struct
-- !query 50 output
2 months 29 days 25 hours
@@ -415,7 +415,7 @@ struct
-- !query 51
select justify_hours(interval '1 month 59 day 25 hour')
-- !query 51 schema
-struct
+struct
-- !query 51 output
1 months 60 days 1 hours
@@ -423,7 +423,7 @@ struct
-- !query 52
select justify_interval(interval '1 month 59 day 25 hour')
-- !query 52 schema
-struct
+struct
-- !query 52 output
3 months 1 hours
@@ -431,7 +431,7 @@ struct
-- !query 53
select justify_days(interval '1 month -59 day 25 hour')
-- !query 53 schema
-struct
+struct
-- !query 53 output
-29 days 25 hours
@@ -439,7 +439,7 @@ struct
-- !query 54
select justify_hours(interval '1 month -59 day 25 hour')
-- !query 54 schema
-struct
+struct
-- !query 54 output
1 months -57 days -23 hours
@@ -447,7 +447,7 @@ struct
-- !query 55
select justify_interval(interval '1 month -59 day 25 hour')
-- !query 55 schema
-struct
+struct
-- !query 55 output
-27 days -23 hours
@@ -455,7 +455,7 @@ struct
-- !query 56
select justify_days(interval '1 month 59 day -25 hour')
-- !query 56 schema
-struct
+struct
-- !query 56 output
2 months 29 days -25 hours
@@ -463,7 +463,7 @@ struct
-- !query 57
select justify_hours(interval '1 month 59 day -25 hour')
-- !query 57 schema
-struct
+struct
-- !query 57 output
1 months 57 days 23 hours
@@ -471,7 +471,7 @@ struct
-- !query 58
select justify_interval(interval '1 month 59 day -25 hour')
-- !query 58 schema
-struct
+struct
-- !query 58 output
2 months 27 days 23 hours
@@ -479,7 +479,7 @@ struct
-- !query 59
select interval 13.123456789 seconds, interval -13.123456789 second
-- !query 59 schema
-struct<13.123456 seconds:interval,-13.123456 seconds:interval>
+struct
-- !query 59 output
13.123456 seconds -13.123456 seconds
@@ -487,7 +487,7 @@ struct<13.123456 seconds:interval,-13.123456 seconds:interval>
-- !query 60
select interval 1 year 2 month 3 week 4 day 5 hour 6 minute 7 seconds 8 millisecond 9 microsecond
-- !query 60 schema
-struct<1 years 2 months 25 days 5 hours 6 minutes 7.008009 seconds:interval>
+struct
-- !query 60 output
1 years 2 months 25 days 5 hours 6 minutes 7.008009 seconds
@@ -495,7 +495,7 @@ struct<1 years 2 months 25 days 5 hours 6 minutes 7.008009 seconds:interval>
-- !query 61
select interval '30' year '25' month '-100' day '40' hour '80' minute '299.889987299' second
-- !query 61 schema
-struct<32 years 1 months -100 days 41 hours 24 minutes 59.889987 seconds:interval>
+struct
-- !query 61 output
32 years 1 months -100 days 41 hours 24 minutes 59.889987 seconds
@@ -503,7 +503,7 @@ struct<32 years 1 months -100 days 41 hours 24 minutes 59.889987 seconds:interva
-- !query 62
select interval '0 0:0:0.1' day to second
-- !query 62 schema
-struct<0.1 seconds:interval>
+struct
-- !query 62 output
0.1 seconds
@@ -511,7 +511,7 @@ struct<0.1 seconds:interval>
-- !query 63
select interval '10-9' year to month
-- !query 63 schema
-struct<10 years 9 months:interval>
+struct
-- !query 63 output
10 years 9 months
@@ -519,7 +519,7 @@ struct<10 years 9 months:interval>
-- !query 64
select interval '20 15:40:32.99899999' day to hour
-- !query 64 schema
-struct<20 days 15 hours:interval>
+struct
-- !query 64 output
20 days 15 hours
@@ -527,7 +527,7 @@ struct<20 days 15 hours:interval>
-- !query 65
select interval '20 15:40:32.99899999' day to minute
-- !query 65 schema
-struct<20 days 15 hours 40 minutes:interval>
+struct
-- !query 65 output
20 days 15 hours 40 minutes
@@ -535,7 +535,7 @@ struct<20 days 15 hours 40 minutes:interval>
-- !query 66
select interval '20 15:40:32.99899999' day to second
-- !query 66 schema
-struct<20 days 15 hours 40 minutes 32.998999 seconds:interval>
+struct
-- !query 66 output
20 days 15 hours 40 minutes 32.998999 seconds
@@ -543,7 +543,7 @@ struct<20 days 15 hours 40 minutes 32.998999 seconds:interval>
-- !query 67
select interval '15:40:32.99899999' hour to minute
-- !query 67 schema
-struct<15 hours 40 minutes:interval>
+struct
-- !query 67 output
15 hours 40 minutes
@@ -551,7 +551,7 @@ struct<15 hours 40 minutes:interval>
-- !query 68
select interval '15:40.99899999' hour to second
-- !query 68 schema
-struct<15 minutes 40.998999 seconds:interval>
+struct
-- !query 68 output
15 minutes 40.998999 seconds
@@ -559,7 +559,7 @@ struct<15 minutes 40.998999 seconds:interval>
-- !query 69
select interval '15:40' hour to second
-- !query 69 schema
-struct<15 hours 40 minutes:interval>
+struct
-- !query 69 output
15 hours 40 minutes
@@ -567,7 +567,7 @@ struct<15 hours 40 minutes:interval>
-- !query 70
select interval '15:40:32.99899999' hour to second
-- !query 70 schema
-struct<15 hours 40 minutes 32.998999 seconds:interval>
+struct
-- !query 70 output
15 hours 40 minutes 32.998999 seconds
@@ -575,7 +575,7 @@ struct<15 hours 40 minutes 32.998999 seconds:interval>
-- !query 71
select interval '20 40:32.99899999' minute to second
-- !query 71 schema
-struct<20 days 40 minutes 32.998999 seconds:interval>
+struct
-- !query 71 output
20 days 40 minutes 32.998999 seconds
@@ -583,7 +583,7 @@ struct<20 days 40 minutes 32.998999 seconds:interval>
-- !query 72
select interval '40:32.99899999' minute to second
-- !query 72 schema
-struct<40 minutes 32.998999 seconds:interval>
+struct
-- !query 72 output
40 minutes 32.998999 seconds
@@ -591,7 +591,7 @@ struct<40 minutes 32.998999 seconds:interval>
-- !query 73
select interval '40:32' minute to second
-- !query 73 schema
-struct<40 minutes 32 seconds:interval>
+struct
-- !query 73 output
40 minutes 32 seconds
@@ -627,7 +627,7 @@ select interval 10 nanoseconds
-- !query 76
select map(1, interval 1 day, 2, interval 3 week)
-- !query 76 schema
-struct