From a86c4f053e6fbf30ca0b2dad3a3c7a49a70e97c5 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Mon, 11 Feb 2019 17:25:41 +0800 Subject: [PATCH 1/5] Arrow optimization for conversion from Spark DataFrame to R DataFrame --- R/pkg/R/DataFrame.R | 48 +++++++++++++ R/pkg/tests/fulltests/test_sparkSQL.R | 4 +- .../apache/spark/api/python/PythonRDD.scala | 6 ++ .../scala/org/apache/spark/api/r/RRDD.scala | 9 ++- .../scala/org/apache/spark/sql/Dataset.scala | 72 +++++++++++++++++-- 5 files changed, 128 insertions(+), 11 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 24ed449f2a7d1..c0b5f3416b02a 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1177,11 +1177,59 @@ setMethod("dim", setMethod("collect", signature(x = "SparkDataFrame"), function(x, stringsAsFactors = FALSE) { + useArrow <- FALSE + arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]] == "true" + if (arrowEnabled) { + useArrow <- tryCatch({ + # Currenty Arrow optimization does not support raw for now. + # Also, it does not support explicit float type set by users. + if (inherits(schema(x), "structType")) { + if (any(sapply(schema(x)$fields(), function(x) x$dataType.toString() == "FloatType"))) { + stop(paste0("Arrow optimization in the conversion from Spark DataFrame to R ", + "DataFrame does not support FloatType yet.")) + } + if (any(sapply(schema(x)$fields(), function(x) x$dataType.toString() == "BinaryType"))) { + stop(paste0("Arrow optimization in the conversion from Spark DataFrame to R ", + "DataFrame does not support BinaryType yet.")) + } + } + TRUE + }, + error = function(e) { + warning(paste0("The conversion from Spark DataFrame to R DataFrame was attempted ", + "with Arrow optimization because 'spark.sql.execution.arrow.enabled' ", + "is set to true; however, ", + "failed, attempting non-optimization. Reason: ", + e)) + FALSE + }) + } + dtypes <- dtypes(x) ncol <- length(dtypes) if (ncol <= 0) { # empty data.frame with 0 columns and 0 rows data.frame() + } else if (useArrow) { + # This is a hack to avoid CRAN check. Arrow is not uploaded into CRAN now. See ARROW-3204. + requireNamespace1 <- requireNamespace + requireNamespace1("arrow", quietly = TRUE) + read_arrow <- get("read_arrow", envir = asNamespace("arrow"), inherits = FALSE) + as_tibble <- get("as_tibble", envir = asNamespace("arrow")) + + portAuth <- callJMethod(x@sdf, "collectAsArrowToR") + port <- portAuth[[1]] + authSecret <- portAuth[[2]] + conn <- socketConnection(port = port, blocking = TRUE, open = "wb", timeout = 1500) + output <- tryCatch({ + doServerAuth(conn, authSecret) + arrowTable <- read_arrow(readRaw(conn)) + as.data.frame(as_tibble(arrowTable), stringsAsFactors = stringsAsFactors) + }, + finally = { + close(conn) + }) + return(output) } else { # listCols is a list of columns listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 9dc699c09a1e4..0802ac83fb631 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -307,7 +307,7 @@ test_that("create DataFrame from RDD", { unsetHiveContext() }) -test_that("createDataFrame Arrow optimization", { +test_that("createDataFrame/collect Arrow optimization", { skip_if_not_installed("arrow") conf <- callJMethod(sparkSession, "conf") @@ -332,7 +332,7 @@ test_that("createDataFrame Arrow optimization", { }) }) -test_that("createDataFrame Arrow optimization - type specification", { +test_that("createDataFrame/collect Arrow optimization - type specification", { skip_if_not_installed("arrow") rdf <- data.frame(list(list(a = 1, b = "a", diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 14ea289e5f908..0937a63dad19b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -430,6 +430,12 @@ private[spark] object PythonRDD extends Logging { */ private[spark] def serveToStream( threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = { + serveToStream(threadName, authHelper)(writeFunc) + } + + private[spark] def serveToStream( + threadName: String, authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit) + : Array[Any] = { val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { s => val out = new BufferedOutputStream(s.getOutputStream()) Utils.tryWithSafeFinally { diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 1dc61c7eef33c..04fc6e18c1e5c 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.r -import java.io.{DataInputStream, File} +import java.io.{DataInputStream, File, OutputStream} import java.net.Socket import java.nio.charset.StandardCharsets.UTF_8 import java.util.{Map => JMap} @@ -104,7 +104,7 @@ private class StringRRDD[T: ClassTag]( lazy val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this) } -private[r] object RRDD { +private[spark] object RRDD { def createSparkContext( master: String, appName: String, @@ -165,6 +165,11 @@ private[r] object RRDD { JavaRDD[Array[Byte]] = { PythonRDD.readRDDFromFile(jsc, fileName, parallelism) } + + private[spark] def serveToStream( + threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = { + PythonRDD.serveToStream(threadName, new RSocketAuthHelper())(writeFunc) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 8a26152271a83..bd1ae509cf54b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.io.{CharArrayWriter, DataOutputStream} +import java.io.{ByteArrayOutputStream, CharArrayWriter, DataOutputStream} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer @@ -31,6 +31,7 @@ import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental, Stable import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.function._ import org.apache.spark.api.python.{PythonRDD, SerDeUtil} +import org.apache.spark.api.r.RRDD import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.QueryPlanningTracker @@ -3198,9 +3199,66 @@ class Dataset[T] private[sql]( } /** - * Collect a Dataset as Arrow batches and serve stream to PySpark. + * Collect a Dataset as Arrow batches and serve stream to SparkR. It sends + * arrow batches in an ordered manner with buffering. This is inevitable + * due to missing R API that reads batches from socket directly. See ARROW-4512. + * Eventually, this code should be deduplicated by `collectAsArrowToPython`. */ - private[sql] def collectAsArrowToPython(): Array[Any] = { + private[sql] def collectAsArrowToR(): Array[Any] = { + val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone + + withAction("collectAsArrowToR", queryExecution) { plan => + RRDD.serveToStream("serve-Arrow") { outputStream => + val buffer = new ByteArrayOutputStream() + val out = new DataOutputStream(outputStream) + val batchWriter = new ArrowBatchStreamWriter(schema, buffer, timeZoneId) + val arrowBatchRdd = toArrowBatchRdd(plan) + val numPartitions = arrowBatchRdd.partitions.length + + // Store collection results for worst case of 1 to N-1 partitions + val results = new Array[Array[Array[Byte]]](numPartitions - 1) + var lastIndex = -1 // index of last partition written + + // Handler to eagerly write partitions to Python in order + def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { + // If result is from next partition in order + if (index - 1 == lastIndex) { + batchWriter.writeBatches(arrowBatches.iterator) + lastIndex += 1 + // Write stored partitions that come next in order + while (lastIndex < results.length && results(lastIndex) != null) { + batchWriter.writeBatches(results(lastIndex).iterator) + results(lastIndex) = null + lastIndex += 1 + } + // After last batch, end the stream + if (lastIndex == results.length) { + batchWriter.end() + val batches = buffer.toByteArray + out.writeInt(batches.length) + out.write(batches) + } + } else { + // Store partitions received out of order + results(index - 1) = arrowBatches + } + } + + sparkSession.sparkContext.runJob( + arrowBatchRdd, + (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray, + 0 until numPartitions, + handlePartitionBatches) + } + } + } + + /** + * Collect a Dataset as Arrow batches and serve stream to PySpark. It sends + * arrow batches in an un-ordered manner without buffering, and then batch order + * information at the end. The batches should be reordered at Python side. + */ + private[sql] def collectAsArrowToPython: Array[Any] = { val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone withAction("collectAsArrowToPython", queryExecution) { plan => @@ -3211,7 +3269,7 @@ class Dataset[T] private[sql]( val numPartitions = arrowBatchRdd.partitions.length // Batches ordered by (index of partition, batch index in that partition) tuple - val batchOrder = new ArrayBuffer[(Int, Int)]() + val batchOrder = ArrayBuffer.empty[(Int, Int)] var partitionCount = 0 // Handler to eagerly write batches to Python as they arrive, un-ordered @@ -3220,7 +3278,7 @@ class Dataset[T] private[sql]( // Write all batches (can be more than 1) in the partition, store the batch order tuple batchWriter.writeBatches(arrowBatches.iterator) arrowBatches.indices.foreach { - partition_batch_index => batchOrder.append((index, partition_batch_index)) + partitionBatchIndex => batchOrder.append((index, partitionBatchIndex)) } } partitionCount += 1 @@ -3232,8 +3290,8 @@ class Dataset[T] private[sql]( // Sort by (index of partition, batch index in that partition) tuple to get the // overall_batch_index from 0 to N-1 batches, which can be used to put the // transferred batches in the correct order - batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overall_batch_index) => - out.writeInt(overall_batch_index) + batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) => + out.writeInt(overallBatchIndex) } out.flush() } From 7f327f0db705d433692fbaf9582aebdd5120aee1 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 12 Feb 2019 12:19:25 +0800 Subject: [PATCH 2/5] Fix style --- R/pkg/R/DataFrame.R | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index c0b5f3416b02a..0fc3eb7f5b21d 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1184,11 +1184,13 @@ setMethod("collect", # Currenty Arrow optimization does not support raw for now. # Also, it does not support explicit float type set by users. if (inherits(schema(x), "structType")) { - if (any(sapply(schema(x)$fields(), function(x) x$dataType.toString() == "FloatType"))) { + if (any(sapply(schema(x)$fields(), + function(x) x$dataType.toString() == "FloatType"))) { stop(paste0("Arrow optimization in the conversion from Spark DataFrame to R ", "DataFrame does not support FloatType yet.")) } - if (any(sapply(schema(x)$fields(), function(x) x$dataType.toString() == "BinaryType"))) { + if (any(sapply(schema(x)$fields(), + function(x) x$dataType.toString() == "BinaryType"))) { stop(paste0("Arrow optimization in the conversion from Spark DataFrame to R ", "DataFrame does not support BinaryType yet.")) } @@ -1197,8 +1199,8 @@ setMethod("collect", }, error = function(e) { warning(paste0("The conversion from Spark DataFrame to R DataFrame was attempted ", - "with Arrow optimization because 'spark.sql.execution.arrow.enabled' ", - "is set to true; however, ", + "with Arrow optimization because ", + "'spark.sql.execution.arrow.enabled' is set to true; however, ", "failed, attempting non-optimization. Reason: ", e)) FALSE @@ -1211,7 +1213,8 @@ setMethod("collect", # empty data.frame with 0 columns and 0 rows data.frame() } else if (useArrow) { - # This is a hack to avoid CRAN check. Arrow is not uploaded into CRAN now. See ARROW-3204. + # This is a hack to avoid CRAN check. Arrow is not uploaded into CRAN now. + # See ARROW-3204. requireNamespace1 <- requireNamespace requireNamespace1("arrow", quietly = TRUE) read_arrow <- get("read_arrow", envir = asNamespace("arrow"), inherits = FALSE) From e75ee3f2dd940a352e2be0789386b93832f5b719 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 12 Feb 2019 19:51:04 +0800 Subject: [PATCH 3/5] Show proper messages if Arrow is not installed --- R/pkg/R/DataFrame.R | 45 ++++++++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 0fc3eb7f5b21d..5d6212231618e 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1181,6 +1181,10 @@ setMethod("collect", arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]] == "true" if (arrowEnabled) { useArrow <- tryCatch({ + requireNamespace1 <- requireNamespace + if (!requireNamespace1("arrow", quietly = TRUE)) { + stop("'arrow' package should be installed.") + } # Currenty Arrow optimization does not support raw for now. # Also, it does not support explicit float type set by users. if (inherits(schema(x), "structType")) { @@ -1196,8 +1200,7 @@ setMethod("collect", } } TRUE - }, - error = function(e) { + }, error = function(e) { warning(paste0("The conversion from Spark DataFrame to R DataFrame was attempted ", "with Arrow optimization because ", "'spark.sql.execution.arrow.enabled' is set to true; however, ", @@ -1213,26 +1216,26 @@ setMethod("collect", # empty data.frame with 0 columns and 0 rows data.frame() } else if (useArrow) { - # This is a hack to avoid CRAN check. Arrow is not uploaded into CRAN now. - # See ARROW-3204. requireNamespace1 <- requireNamespace - requireNamespace1("arrow", quietly = TRUE) - read_arrow <- get("read_arrow", envir = asNamespace("arrow"), inherits = FALSE) - as_tibble <- get("as_tibble", envir = asNamespace("arrow")) - - portAuth <- callJMethod(x@sdf, "collectAsArrowToR") - port <- portAuth[[1]] - authSecret <- portAuth[[2]] - conn <- socketConnection(port = port, blocking = TRUE, open = "wb", timeout = 1500) - output <- tryCatch({ - doServerAuth(conn, authSecret) - arrowTable <- read_arrow(readRaw(conn)) - as.data.frame(as_tibble(arrowTable), stringsAsFactors = stringsAsFactors) - }, - finally = { - close(conn) - }) - return(output) + if (requireNamespace1("arrow", quietly = TRUE)) { + read_arrow <- get("read_arrow", envir = asNamespace("arrow"), inherits = FALSE) + as_tibble <- get("as_tibble", envir = asNamespace("arrow")) + + portAuth <- callJMethod(x@sdf, "collectAsArrowToR") + port <- portAuth[[1]] + authSecret <- portAuth[[2]] + conn <- socketConnection(port = port, blocking = TRUE, open = "wb", timeout = 1500) + output <- tryCatch({ + doServerAuth(conn, authSecret) + arrowTable <- read_arrow(readRaw(conn)) + as.data.frame(as_tibble(arrowTable), stringsAsFactors = stringsAsFactors) + }, finally = { + close(conn) + }) + return(output) + } else { + stop("'arrow' package should be installed.") + } } else { # listCols is a list of columns listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf) From bba8d9477223d1a0e7c0be23af08e480b335bdb4 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Wed, 13 Feb 2019 20:15:57 +0800 Subject: [PATCH 4/5] Address Felix's comments --- R/pkg/R/DataFrame.R | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 5d6212231618e..fe836bf4c1b3c 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1177,6 +1177,7 @@ setMethod("dim", setMethod("collect", signature(x = "SparkDataFrame"), function(x, stringsAsFactors = FALSE) { + connectionTimeout <- as.numeric(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000")) useArrow <- FALSE arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]] == "true" if (arrowEnabled) { @@ -1224,7 +1225,8 @@ setMethod("collect", portAuth <- callJMethod(x@sdf, "collectAsArrowToR") port <- portAuth[[1]] authSecret <- portAuth[[2]] - conn <- socketConnection(port = port, blocking = TRUE, open = "wb", timeout = 1500) + conn <- socketConnection( + port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout) output <- tryCatch({ doServerAuth(conn, authSecret) arrowTable <- read_arrow(readRaw(conn)) From cfe947cc8fe6d8c9a5c262a5d2b0e6bb28955e33 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 19 Feb 2019 10:10:12 +0800 Subject: [PATCH 5/5] Add tests for many partitions --- R/pkg/tests/fulltests/test_sparkSQL.R | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 0802ac83fb631..21eaa32f0011c 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -332,6 +332,23 @@ test_that("createDataFrame/collect Arrow optimization", { }) }) +test_that("createDataFrame/collect Arrow optimization - many partitions (partition order test)", { + skip_if_not_installed("arrow") + + conf <- callJMethod(sparkSession, "conf") + arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]] + + callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "true") + tryCatch({ + expect_equal(collect(createDataFrame(mtcars, numPartitions = 32)), + collect(createDataFrame(mtcars, numPartitions = 1))) + }, + finally = { + # Resetting the conf back to default value + callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled) + }) +}) + test_that("createDataFrame/collect Arrow optimization - type specification", { skip_if_not_installed("arrow") rdf <- data.frame(list(list(a = 1,