[SPARK-26762][SQL][R] Arrow optimization for conversion from Spark DataFrame to R DataFrame #23760

Closed · wants to merge 5 commits
56 changes: 56 additions & 0 deletions R/pkg/R/DataFrame.R
@@ -1177,11 +1177,67 @@ setMethod("dim",
setMethod("collect",
signature(x = "SparkDataFrame"),
function(x, stringsAsFactors = FALSE) {
connectionTimeout <- as.numeric(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000"))
useArrow <- FALSE
arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]] == "true"
if (arrowEnabled) {
useArrow <- tryCatch({
requireNamespace1 <- requireNamespace
if (!requireNamespace1("arrow", quietly = TRUE)) {
stop("'arrow' package should be installed.")
}
# Currently, Arrow optimization does not support the raw type.
# Also, it does not support explicit float type set by users.
if (inherits(schema(x), "structType")) {
if (any(sapply(schema(x)$fields(),
function(x) x$dataType.toString() == "FloatType"))) {
stop(paste0("Arrow optimization in the conversion from Spark DataFrame to R ",
"DataFrame does not support FloatType yet."))
}
if (any(sapply(schema(x)$fields(),
function(x) x$dataType.toString() == "BinaryType"))) {
stop(paste0("Arrow optimization in the conversion from Spark DataFrame to R ",
"DataFrame does not support BinaryType yet."))
}
}
TRUE
}, error = function(e) {
warning(paste0("The conversion from Spark DataFrame to R DataFrame was attempted ",
"with Arrow optimization because ",
"'spark.sql.execution.arrow.enabled' is set to true; however, it ",
"failed. Falling back to the conversion without optimization. Reason: ",
e))
FALSE
})
}

dtypes <- dtypes(x)
ncol <- length(dtypes)
if (ncol <= 0) {
# empty data.frame with 0 columns and 0 rows
data.frame()
} else if (useArrow) {
requireNamespace1 <- requireNamespace
if (requireNamespace1("arrow", quietly = TRUE)) {
read_arrow <- get("read_arrow", envir = asNamespace("arrow"), inherits = FALSE)
as_tibble <- get("as_tibble", envir = asNamespace("arrow"))

portAuth <- callJMethod(x@sdf, "collectAsArrowToR")
port <- portAuth[[1]]
authSecret <- portAuth[[2]]
conn <- socketConnection(
port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
output <- tryCatch({
doServerAuth(conn, authSecret)
arrowTable <- read_arrow(readRaw(conn))
as.data.frame(as_tibble(arrowTable), stringsAsFactors = stringsAsFactors)
}, finally = {
close(conn)
})
return(output)
} else {
stop("'arrow' package should be installed.")
}
} else {
# listCols is a list of columns
listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf)
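For context, a minimal usage sketch of the new path (not part of this diff): with `spark.sql.execution.arrow.enabled` set to true and the `arrow` R package installed, `collect()` goes through `collectAsArrowToR`; otherwise, or on failure, it falls back to the existing `dfToCols` path. The session setup below is an assumption for illustration, not taken from the PR.

library(SparkR)

# Enable Arrow optimization for this session (assumed setup, illustration only).
sparkR.session(sparkConfig = list(spark.sql.execution.arrow.enabled = "true"))

df <- createDataFrame(mtcars)    # R data.frame -> Spark DataFrame
rdf <- collect(df)               # Spark DataFrame -> R data.frame, via Arrow when available
identical(dim(rdf), dim(mtcars)) # expected TRUE: the round trip preserves the shape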
21 changes: 19 additions & 2 deletions R/pkg/tests/fulltests/test_sparkSQL.R
@@ -307,7 +307,7 @@ test_that("create DataFrame from RDD", {
unsetHiveContext()
})

test_that("createDataFrame Arrow optimization", {
test_that("createDataFrame/collect Arrow optimization", {
skip_if_not_installed("arrow")

conf <- callJMethod(sparkSession, "conf")
@@ -332,7 +332,24 @@ test_that("createDataFrame Arrow optimization", {
})
})

test_that("createDataFrame Arrow optimization - type specification", {
test_that("createDataFrame/collect Arrow optimization - many partitions (partition order test)", {
skip_if_not_installed("arrow")

conf <- callJMethod(sparkSession, "conf")
arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]]

callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "true")
tryCatch({
expect_equal(collect(createDataFrame(mtcars, numPartitions = 32)),
collect(createDataFrame(mtcars, numPartitions = 1)))
},
finally = {
# Reset the conf back to its original value
callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled)
})
})

test_that("createDataFrame/collect Arrow optimization - type specification", {
skip_if_not_installed("arrow")
rdf <- data.frame(list(list(a = 1,
b = "a",
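A hedged sketch of the fallback behavior these tests relate to (illustrative only, assuming the same Arrow-enabled session as above): a column explicitly cast to float hits the FloatType check in `collect()`, so the tryCatch branch emits a warning and the conversion proceeds through the non-Arrow path.

df <- createDataFrame(data.frame(x = c(1.1, 2.2)))
df <- withColumn(df, "x", cast(df$x, "float"))  # FloatType is not supported by the Arrow path yet
rdf <- collect(df)  # warns that FloatType is unsupported, then falls back to dfToCols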
core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -430,6 +430,12 @@ private[spark] object PythonRDD extends Logging {
*/
private[spark] def serveToStream(
threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
serveToStream(threadName, authHelper)(writeFunc)
}

private[spark] def serveToStream(
threadName: String, authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit)
: Array[Any] = {
val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { s =>
val out = new BufferedOutputStream(s.getOutputStream())
Utils.tryWithSafeFinally {
9 changes: 7 additions & 2 deletions core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -17,7 +17,7 @@

package org.apache.spark.api.r

import java.io.{DataInputStream, File}
import java.io.{DataInputStream, File, OutputStream}
import java.net.Socket
import java.nio.charset.StandardCharsets.UTF_8
import java.util.{Map => JMap}
@@ -104,7 +104,7 @@ private class StringRRDD[T: ClassTag](
lazy val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this)
}

private[r] object RRDD {
private[spark] object RRDD {
def createSparkContext(
master: String,
appName: String,
@@ -165,6 +165,11 @@ private[r] object RRDD {
JavaRDD[Array[Byte]] = {
PythonRDD.readRDDFromFile(jsc, fileName, parallelism)
}

private[spark] def serveToStream(
threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
PythonRDD.serveToStream(threadName, new RSocketAuthHelper())(writeFunc)
}
}

/**
72 changes: 65 additions & 7 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -17,7 +17,7 @@

package org.apache.spark.sql

import java.io.{CharArrayWriter, DataOutputStream}
import java.io.{ByteArrayOutputStream, CharArrayWriter, DataOutputStream}

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
@@ -31,6 +31,7 @@ import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental, Stable
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.function._
import org.apache.spark.api.python.{PythonRDD, SerDeUtil}
import org.apache.spark.api.r.RRDD
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.QueryPlanningTracker
@@ -3198,9 +3199,66 @@
}

/**
* Collect a Dataset as Arrow batches and serve stream to PySpark.
* Collect a Dataset as Arrow batches and serve the stream to SparkR. It sends the
* Arrow batches in order, with buffering, because the R Arrow API cannot yet read
* batches directly from a socket. See ARROW-4512.
* Eventually, this code should be deduplicated with `collectAsArrowToPython`.
*/
private[sql] def collectAsArrowToPython(): Array[Any] = {
private[sql] def collectAsArrowToR(): Array[Any] = {
val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone

withAction("collectAsArrowToR", queryExecution) { plan =>
RRDD.serveToStream("serve-Arrow") { outputStream =>
val buffer = new ByteArrayOutputStream()
val out = new DataOutputStream(outputStream)
val batchWriter = new ArrowBatchStreamWriter(schema, buffer, timeZoneId)
val arrowBatchRdd = toArrowBatchRdd(plan)
val numPartitions = arrowBatchRdd.partitions.length

// Store collection results for worst case of 1 to N-1 partitions
val results = new Array[Array[Array[Byte]]](numPartitions - 1)
var lastIndex = -1 // index of last partition written

// Handler to eagerly write partitions to R in order
def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = {
// If result is from next partition in order
if (index - 1 == lastIndex) {
batchWriter.writeBatches(arrowBatches.iterator)
lastIndex += 1
// Write stored partitions that come next in order
while (lastIndex < results.length && results(lastIndex) != null) {
batchWriter.writeBatches(results(lastIndex).iterator)
results(lastIndex) = null
lastIndex += 1
}
// After last batch, end the stream
if (lastIndex == results.length) {
batchWriter.end()
val batches = buffer.toByteArray
out.writeInt(batches.length)
out.write(batches)
}
} else {
// Store partitions received out of order
results(index - 1) = arrowBatches
}
}

sparkSession.sparkContext.runJob(
arrowBatchRdd,
(ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray,
0 until numPartitions,
handlePartitionBatches)
}
}
}

/**
* Collect a Dataset as Arrow batches and serve the stream to PySpark. It sends the
* Arrow batches unordered and without buffering, then sends the batch-order
* information at the end. The batches should be reordered on the Python side.
*/
private[sql] def collectAsArrowToPython: Array[Any] = {
val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone

withAction("collectAsArrowToPython", queryExecution) { plan =>
@@ -3211,7 +3269,7 @@ class Dataset[T] private[sql](
val numPartitions = arrowBatchRdd.partitions.length

// Batches ordered by (index of partition, batch index in that partition) tuple
val batchOrder = new ArrayBuffer[(Int, Int)]()
val batchOrder = ArrayBuffer.empty[(Int, Int)]
var partitionCount = 0

// Handler to eagerly write batches to Python as they arrive, un-ordered
@@ -3220,7 +3278,7 @@ class Dataset[T] private[sql](
// Write all batches (can be more than 1) in the partition, store the batch order tuple
batchWriter.writeBatches(arrowBatches.iterator)
arrowBatches.indices.foreach {
partition_batch_index => batchOrder.append((index, partition_batch_index))
partitionBatchIndex => batchOrder.append((index, partitionBatchIndex))
}
}
partitionCount += 1
@@ -3232,8 +3290,8 @@
// Sort by (index of partition, batch index in that partition) tuple to get the
// overall_batch_index from 0 to N-1 batches, which can be used to put the
// transferred batches in the correct order
batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overall_batch_index) =>
out.writeInt(overall_batch_index)
batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
out.writeInt(overallBatchIndex)
}
out.flush()
}