Merge branch 'master' of https://github.com/apache/spark into SPARK-1712_new

witgo committed May 8, 2014
2 parents 2a89adc + 19c8fb0 commit 86e2048
Showing 34 changed files with 1,052 additions and 318 deletions.
core/src/main/scala/org/apache/spark/SecurityManager.scala (5 changes: 3 additions & 2 deletions)
@@ -146,8 +146,9 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging {
setViewAcls(defaultAclUsers, sparkConf.get("spark.ui.view.acls", ""))

private val secretKey = generateSecretKey()
logInfo("SecurityManager, is authentication enabled: " + authOn +
" are ui acls enabled: " + uiAclsOn + " users with view permissions: " + viewAcls.toString())
logInfo("SecurityManager: authentication " + (if (authOn) "enabled" else "disabled") +
"; ui acls " + (if (uiAclsOn) "enabled" else "disabled") +
"; users with view permissions: " + viewAcls.toString())

// Set our own authenticator to properly negotiate user/password for HTTP connections.
// This is needed by the HTTP client fetching from the HttpServer. Put here so its
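
For reference, a minimal, purely illustrative sketch of what the reworded log message produces on sample values; the flag names and the Set-based ACL representation here are assumptions, not part of the diff:

    // Illustrative only: reproduces the new message format on sample values.
    object SecurityLogFormatDemo {
      def message(authOn: Boolean, uiAclsOn: Boolean, viewAcls: Set[String]): String =
        "SecurityManager: authentication " + (if (authOn) "enabled" else "disabled") +
          "; ui acls " + (if (uiAclsOn) "enabled" else "disabled") +
          "; users with view permissions: " + viewAcls.mkString(", ")

      def main(args: Array[String]): Unit = {
        // Prints: SecurityManager: authentication enabled; ui acls disabled; users with view permissions: alice
        println(message(authOn = true, uiAclsOn = false, viewAcls = Set("alice")))
      }
    }
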
core/src/main/scala/org/apache/spark/TaskContext.scala (5 changes: 5 additions & 0 deletions)
@@ -42,16 +42,21 @@ class TaskContext(
// List of callback functions to execute when the task completes.
@transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit]

// Set to true when the task is completed, before the onCompleteCallbacks are executed.
@volatile var completed: Boolean = false

/**
* Add a callback function to be executed on task completion. An example use
* is for HadoopRDD to register a callback to close the input stream.
* Will be called in any situation - success, failure, or cancellation.
* @param f Callback function.
*/
def addOnCompleteCallback(f: () => Unit) {
onCompleteCallbacks += f
}

def executeOnCompleteCallbacks() {
completed = true
// Process complete callbacks in the reverse order of registration
onCompleteCallbacks.reverse.foreach{_()}
}
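
To make the callback semantics above concrete, here is a hedged, self-contained sketch of the register-then-run-in-reverse pattern; SimpleTaskContext and the demo names are stand-ins, not the real Spark classes:

    import scala.collection.mutable.ArrayBuffer

    // Minimal stand-in for the callback machinery shown above; only the two
    // methods mirror the real TaskContext, everything else is an assumption.
    class SimpleTaskContext {
      private val onCompleteCallbacks = new ArrayBuffer[() => Unit]
      @volatile var completed: Boolean = false

      def addOnCompleteCallback(f: () => Unit): Unit = { onCompleteCallbacks += f }

      def executeOnCompleteCallbacks(): Unit = {
        completed = true // set before the callbacks run, matching the change above
        onCompleteCallbacks.reverse.foreach { _() }
      }
    }

    object CallbackDemo {
      def main(args: Array[String]): Unit = {
        val ctx = new SimpleTaskContext
        ctx.addOnCompleteCallback(() => println("close input stream"))   // e.g. what HadoopRDD registers
        ctx.addOnCompleteCallback(() => println("shut down writer thread"))
        ctx.executeOnCompleteCallbacks() // callbacks run in reverse order of registration
      }
    }
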
core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala (217 changes: 114 additions & 103 deletions)
@@ -56,122 +56,37 @@ private[spark] class PythonRDD[T: ClassTag](
val env = SparkEnv.get
val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)

// Ensure worker socket is closed on task completion. Closing sockets is idempotent.
context.addOnCompleteCallback(() =>
// Start a thread to feed the process input from our parent's iterator
val writerThread = new WriterThread(env, worker, split, context)

context.addOnCompleteCallback { () =>
writerThread.shutdownOnTaskCompletion()

// Cleanup the worker socket. This will also cause the Python worker to exit.
try {
worker.close()
} catch {
case e: Exception => logWarning("Failed to close worker socket", e)
}
)

@volatile var readerException: Exception = null

// Start a thread to feed the process input from our parent's iterator
new Thread("stdin writer for " + pythonExec) {
override def run() {
try {
SparkEnv.set(env)
val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
val dataOut = new DataOutputStream(stream)
// Partition index
dataOut.writeInt(split.index)
// sparkFilesDir
PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
// Broadcast variables
dataOut.writeInt(broadcastVars.length)
for (broadcast <- broadcastVars) {
dataOut.writeLong(broadcast.id)
dataOut.writeInt(broadcast.value.length)
dataOut.write(broadcast.value)
}
// Python includes (*.zip and *.egg files)
dataOut.writeInt(pythonIncludes.length)
for (include <- pythonIncludes) {
PythonRDD.writeUTF(include, dataOut)
}
dataOut.flush()
// Serialized command:
dataOut.writeInt(command.length)
dataOut.write(command)
// Data values
PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
dataOut.flush()
worker.shutdownOutput()
} catch {

case e: java.io.FileNotFoundException =>
readerException = e
Try(worker.shutdownOutput()) // kill Python worker process

case e: IOException =>
// This can happen for legitimate reasons if the Python code stops returning data
// before we are done passing elements through, e.g., for take(). Just log a message to
// say it happened (as it could also be hiding a real IOException from a data source).
logInfo("stdin writer to Python finished early (may not be an error)", e)

case e: Exception =>
// We must avoid throwing exceptions here, because the thread uncaught exception handler
// will kill the whole executor (see Executor).
readerException = e
Try(worker.shutdownOutput()) // kill Python worker process
}
}
}.start()

// Necessary to distinguish between a task that has failed and a task that is finished
@volatile var complete: Boolean = false

// It is necessary to have a monitor thread for python workers if the user cancels with
// interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the
// threads can block indefinitely.
new Thread(s"Worker Monitor for $pythonExec") {
override def run() {
// Kill the worker if it is interrupted or completed
// When a python task completes, the context is always set to interrupted
while (!context.interrupted) {
Thread.sleep(2000)
}
if (!complete) {
try {
logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
env.destroyPythonWorker(pythonExec, envVars.toMap)
} catch {
case e: Exception =>
logError("Exception when trying to kill worker", e)
}
}
}
}.start()

/*
* Partial fix for SPARK-1019: Attempts to stop reading the input stream since
* other completion callbacks might invalidate the input. Because interruption
* is not synchronous this still leaves a potential race where the interruption is
* processed only after the stream becomes invalid.
*/
context.addOnCompleteCallback{ () =>
complete = true // Indicate that the task has completed successfully
context.interrupted = true
}

writerThread.start()
new MonitorThread(env, worker, context).start()

// Return an iterator that reads lines from the process's stdout
val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
val stdoutIterator = new Iterator[Array[Byte]] {
def next(): Array[Byte] = {
val obj = _nextObj
if (hasNext) {
// FIXME: can deadlock if worker is waiting for us to
// respond to current message (currently irrelevant because
// output is shutdown before we read any input)
_nextObj = read()
}
obj
}

private def read(): Array[Byte] = {
if (readerException != null) {
throw readerException
if (writerThread.exception.isDefined) {
throw writerThread.exception.get
}
try {
stream.readInt() match {
@@ -190,13 +105,14 @@ private[spark] class PythonRDD[T: ClassTag](
val total = finishTime - startTime
logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
init, finish))
read
read()
case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
// Signals that an exception has been thrown in python
val exLength = stream.readInt()
val obj = new Array[Byte](exLength)
stream.readFully(obj)
throw new PythonException(new String(obj, "utf-8"), readerException)
throw new PythonException(new String(obj, "utf-8"),
writerThread.exception.getOrElse(null))
case SpecialLengths.END_OF_DATA_SECTION =>
// We've finished the data section of the output, but we can still
// read some accumulator updates:
@@ -210,10 +126,15 @@
Array.empty[Byte]
}
} catch {
case e: Exception if readerException != null =>

case e: Exception if context.interrupted =>
logDebug("Exception thrown after task interruption", e)
throw new TaskKilledException

case e: Exception if writerThread.exception.isDefined =>
logError("Python worker exited unexpectedly (crashed)", e)
logError("Python crash may have been caused by prior exception:", readerException)
throw readerException
logError("This may have been caused by a prior exception:", writerThread.exception.get)
throw writerThread.exception.get

case eof: EOFException =>
throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
@@ -224,10 +145,100 @@

def hasNext = _nextObj.length != 0
}
stdoutIterator
new InterruptibleIterator(context, stdoutIterator)
}

val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)

/**
* The thread responsible for writing the data from the PythonRDD's parent iterator to the
* Python process.
*/
class WriterThread(env: SparkEnv, worker: Socket, split: Partition, context: TaskContext)
extends Thread(s"stdout writer for $pythonExec") {

@volatile private var _exception: Exception = null

setDaemon(true)

/** Contains the exception thrown while writing the parent iterator to the Python process. */
def exception: Option[Exception] = Option(_exception)

/** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
def shutdownOnTaskCompletion() {
assert(context.completed)
this.interrupt()
}

override def run() {
try {
SparkEnv.set(env)
val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
val dataOut = new DataOutputStream(stream)
// Partition index
dataOut.writeInt(split.index)
// sparkFilesDir
PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
// Broadcast variables
dataOut.writeInt(broadcastVars.length)
for (broadcast <- broadcastVars) {
dataOut.writeLong(broadcast.id)
dataOut.writeInt(broadcast.value.length)
dataOut.write(broadcast.value)
}
// Python includes (*.zip and *.egg files)
dataOut.writeInt(pythonIncludes.length)
for (include <- pythonIncludes) {
PythonRDD.writeUTF(include, dataOut)
}
dataOut.flush()
// Serialized command:
dataOut.writeInt(command.length)
dataOut.write(command)
// Data values
PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
dataOut.flush()
} catch {
case e: Exception if context.completed || context.interrupted =>
logDebug("Exception thrown after task completion (likely due to cleanup)", e)

case e: Exception =>
// We must avoid throwing exceptions here, because the thread uncaught exception handler
// will kill the whole executor (see org.apache.spark.executor.Executor).
_exception = e
} finally {
Try(worker.shutdownOutput()) // kill Python worker process
}
}
}

/**
* It is necessary to have a monitor thread for python workers if the user cancels with
* interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the
* threads can block indefinitely.
*/
class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext)
extends Thread(s"Worker Monitor for $pythonExec") {

setDaemon(true)

override def run() {
// Kill the worker if it is interrupted, checking until task completion.
// TODO: This has a race condition if interruption occurs, as completed may still become true.
while (!context.interrupted && !context.completed) {
Thread.sleep(2000)
}
if (!context.completed) {
try {
logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
env.destroyPythonWorker(pythonExec, envVars.toMap)
} catch {
case e: Exception =>
logError("Exception when trying to kill worker", e)
}
}
}
}
}

/** Thrown for exceptions in user Python code. */
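
The central pattern introduced in this file, a daemon writer thread that records any failure in a @volatile field and exposes it as an Option so the consuming side can inspect or rethrow it, can be sketched in isolation. Everything below other than that pattern (the class names, the demo, the simulated failure) is an assumption:

    // Standalone sketch of the writer-thread exception handoff used by PythonRDD above.
    class ProducerThread(work: () => Unit) extends Thread("producer") {
      @volatile private var _exception: Exception = null
      setDaemon(true)

      /** The exception thrown while producing, if any, as seen by the consumer thread. */
      def exception: Option[Exception] = Option(_exception)

      override def run(): Unit = {
        try {
          work()
        } catch {
          // Never let the exception escape: the thread's uncaught-exception handler
          // would otherwise take down the process, so record it for the consumer instead.
          case e: Exception => _exception = e
        }
      }
    }

    object HandoffDemo {
      def main(args: Array[String]): Unit = {
        val producer = new ProducerThread(() => throw new RuntimeException("boom"))
        producer.start()
        producer.join()
        // The consumer checks for a producer failure before interpreting its own
        // errors, mirroring the writerThread.exception checks in read() above.
        producer.exception.foreach(e => println(s"producer failed: ${e.getMessage}"))
      }
    }
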
core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
@@ -17,7 +17,7 @@

package org.apache.spark.api.python

import java.io.File
import java.io.{File, InputStream, IOException, OutputStream}

import scala.collection.mutable.ArrayBuffer

@@ -40,3 +40,28 @@ private[spark] object PythonUtils {
paths.filter(_ != "").mkString(File.pathSeparator)
}
}


/**
* A utility class to redirect the child process's stdout or stderr.
*/
private[spark] class RedirectThread(
in: InputStream,
out: OutputStream,
name: String)
extends Thread(name) {

setDaemon(true)
override def run() {
scala.util.control.Exception.ignoring(classOf[IOException]) {
// FIXME: We copy the stream on the level of bytes to avoid encoding problems.
val buf = new Array[Byte](1024)
var len = in.read(buf)
while (len != -1) {
out.write(buf, 0, len)
out.flush()
len = in.read(buf)
}
}
}
}
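
A possible way to use the new RedirectThread, sketched under assumptions: the child command, the package placement, and the demo object are illustrative only (the class is private[spark], so a caller would have to live under the org.apache.spark package):

    package org.apache.spark.examples // must be under org.apache.spark: RedirectThread is private[spark]

    import org.apache.spark.api.python.RedirectThread

    // Hypothetical usage of RedirectThread: mirror a child process's stdout
    // onto this JVM's stdout. The command is an assumption for the demo.
    object RedirectDemo {
      def main(args: Array[String]): Unit = {
        val process = new ProcessBuilder("python", "--version").start()
        val redirect = new RedirectThread(process.getInputStream, System.out, "stdout redirect for child")
        redirect.start()
        process.waitFor()
        redirect.join() // the redirect ends once the child's stdout reaches EOF
      }
    }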