Skip to content

Commit

Permalink
Merge pull request apache#147 from shivaram/sparkr-ec2-fixes
Browse files Browse the repository at this point in the history
Bunch of fixes for longer running jobs
  • Loading branch information
concretevitamin committed Feb 3, 2015
2 parents c662f29 + f34bb88 commit 554bda0
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 58 deletions.
2 changes: 1 addition & 1 deletion pkg/R/sparkRClient.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# Creates a SparkR client connection object
# if one doesn't already exist
connectBackend <- function(hostname, port, timeout = 60) {
connectBackend <- function(hostname, port, timeout = 6000) {
if (exists(".sparkRcon", envir = .sparkREnv)) {
cat("SparkRBackend client connection already exists\n")
return(get(".sparkRcon", envir = .sparkREnv))
Expand Down
123 changes: 66 additions & 57 deletions pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/RRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -117,68 +117,77 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
// Start a thread to feed the process input from our parent's iterator
new Thread("stdin writer for R") {
override def run() {
SparkEnv.set(env)
val streamStd = new BufferedOutputStream(proc.getOutputStream, bufferSize)
val printOutStd = new PrintStream(streamStd)
printOutStd.println(tempFileName)
printOutStd.println(rLibDir)
printOutStd.println(tempFileIn.getAbsolutePath())
printOutStd.flush()

streamStd.close()

val stream = new BufferedOutputStream(new FileOutputStream(tempFileIn), bufferSize)
val printOut = new PrintStream(stream)
val dataOut = new DataOutputStream(stream)

dataOut.writeInt(splitIndex)

dataOut.writeInt(func.length)
dataOut.write(func, 0, func.length)

// R worker process input serialization flag
dataOut.writeInt(if (parentSerialized) 1 else 0)
// R worker process output serialization flag
dataOut.writeInt(if (dataSerialized) 1 else 0)

dataOut.writeInt(packageNames.length)
dataOut.write(packageNames, 0, packageNames.length)

dataOut.writeInt(functionDependencies.length)
dataOut.write(functionDependencies, 0, functionDependencies.length)

dataOut.writeInt(broadcastVars.length)
broadcastVars.foreach { broadcast =>
// TODO(shivaram): Read a Long in R to avoid this cast
dataOut.writeInt(broadcast.id.toInt)
// TODO: Pass a byte array from R to avoid this cast ?
val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]]
dataOut.writeInt(broadcastByteArr.length)
dataOut.write(broadcastByteArr, 0, broadcastByteArr.length)
}

dataOut.writeInt(numPartitions)
try {
SparkEnv.set(env)
val stream = new BufferedOutputStream(new FileOutputStream(tempFileIn), bufferSize)
val printOut = new PrintStream(stream)
val dataOut = new DataOutputStream(stream)

dataOut.writeInt(splitIndex)

dataOut.writeInt(func.length)
dataOut.write(func, 0, func.length)

// R worker process input serialization flag
dataOut.writeInt(if (parentSerialized) 1 else 0)
// R worker process output serialization flag
dataOut.writeInt(if (dataSerialized) 1 else 0)

dataOut.writeInt(packageNames.length)
dataOut.write(packageNames, 0, packageNames.length)

dataOut.writeInt(functionDependencies.length)
dataOut.write(functionDependencies, 0, functionDependencies.length)

dataOut.writeInt(broadcastVars.length)
broadcastVars.foreach { broadcast =>
// TODO(shivaram): Read a Long in R to avoid this cast
dataOut.writeInt(broadcast.id.toInt)
// TODO: Pass a byte array from R to avoid this cast ?
val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]]
dataOut.writeInt(broadcastByteArr.length)
dataOut.write(broadcastByteArr, 0, broadcastByteArr.length)
}

if (!iter.hasNext) {
dataOut.writeInt(0)
} else {
dataOut.writeInt(1)
}
dataOut.writeInt(numPartitions)

for (elem <- iter) {
if (parentSerialized) {
val elemArr = elem.asInstanceOf[Array[Byte]]
dataOut.writeInt(elemArr.length)
dataOut.write(elemArr, 0, elemArr.length)
if (!iter.hasNext) {
dataOut.writeInt(0)
} else {
printOut.println(elem)
dataOut.writeInt(1)
}

for (elem <- iter) {
if (parentSerialized) {
val elemArr = elem.asInstanceOf[Array[Byte]]
dataOut.writeInt(elemArr.length)
dataOut.write(elemArr, 0, elemArr.length)
} else {
printOut.println(elem)
}
}
}

printOut.flush()
dataOut.flush()
stream.flush()
stream.close()
printOut.flush()
dataOut.flush()
stream.flush()
stream.close()

// NOTE: We need to write out the temp file before writing out the
// file name to stdin. Otherwise the R process could read partial state
val streamStd = new BufferedOutputStream(proc.getOutputStream, bufferSize)
val printOutStd = new PrintStream(streamStd)
printOutStd.println(tempFileName)
printOutStd.println(rLibDir)
printOutStd.println(tempFileIn.getAbsolutePath())
printOutStd.flush()

streamStd.close()
} catch {
// TODO: We should propagate this error to the task thread
case e: Exception =>
System.err.println("R Writer thread got an exception " + e)
e.printStackTrace()
}
}
}.start()

Expand Down

0 comments on commit 554bda0

Please sign in to comment.