Skip to content

Commit

Permalink
[PYSPARK] Update py4j to version 0.10.7.
Browse files Browse the repository at this point in the history
(cherry picked from commit cc613b5)
Signed-off-by: Marcelo Vanzin <vanzin@cloudera.com>
  • Loading branch information
Marcelo Vanzin committed May 10, 2018
1 parent eab10f9 commit 323dc3a
Show file tree
Hide file tree
Showing 32 changed files with 418 additions and 116 deletions.
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt.
(New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf)
(The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net)
(The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net)
(The New BSD License) Py4J (net.sf.py4j:py4j:0.10.6 - http://py4j.sourceforge.net/)
(The New BSD License) Py4J (net.sf.py4j:py4j:0.10.7 - http://py4j.sourceforge.net/)
(Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/)
(BSD licence) sbt and sbt-launch-lib.bash
(BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE)
Expand Down
6 changes: 3 additions & 3 deletions bin/pyspark
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ source "${SPARK_HOME}"/bin/load-spark-env.sh
export _SPARK_CMD_USAGE="Usage: ./bin/pyspark [options]"

# In Spark 2.0, IPYTHON and IPYTHON_OPTS are removed and pyspark fails to launch if either option
# is set in the user's environment. Instead, users should set PYSPARK_DRIVER_PYTHON=ipython
# is set in the user's environment. Instead, users should set PYSPARK_DRIVER_PYTHON=ipython
# to use IPython and set PYSPARK_DRIVER_PYTHON_OPTS to pass options when starting the Python driver
# (e.g. PYSPARK_DRIVER_PYTHON_OPTS='notebook'). This supports full customization of the IPython
# and executor Python executables.

# Fail noisily if removed options are set
if [[ -n "$IPYTHON" || -n "$IPYTHON_OPTS" ]]; then
echo "Error in pyspark startup:"
echo "Error in pyspark startup:"
echo "IPYTHON and IPYTHON_OPTS are removed in Spark 2.0+. Remove these from the environment and set PYSPARK_DRIVER_PYTHON and PYSPARK_DRIVER_PYTHON_OPTS instead."
exit 1
fi
Expand All @@ -57,7 +57,7 @@ export PYSPARK_PYTHON

# Add the PySpark classes to the Python path:
export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH"
export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.6-src.zip:$PYTHONPATH"
export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.7-src.zip:$PYTHONPATH"

# Load the PySpark shell.py script when ./pyspark is used interactively:
export OLD_PYTHONSTARTUP="$PYTHONSTARTUP"
Expand Down
2 changes: 1 addition & 1 deletion bin/pyspark2.cmd
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" (
)

set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH%
set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.6-src.zip;%PYTHONPATH%
set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.7-src.zip;%PYTHONPATH%

set OLD_PYTHONSTARTUP=%PYTHONSTARTUP%
set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py
Expand Down
2 changes: 1 addition & 1 deletion core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@
<dependency>
<groupId>net.sf.py4j</groupId>
<artifactId>py4j</artifactId>
<version>0.10.6</version>
<version>0.10.7</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
Expand Down
11 changes: 2 additions & 9 deletions core/src/main/scala/org/apache/spark/SecurityManager.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,12 @@

package org.apache.spark

import java.lang.{Byte => JByte}
import java.net.{Authenticator, PasswordAuthentication}
import java.nio.charset.StandardCharsets.UTF_8
import java.security.{KeyStore, SecureRandom}
import java.security.KeyStore
import java.security.cert.X509Certificate
import javax.net.ssl._

import com.google.common.hash.HashCodes
import com.google.common.io.Files
import org.apache.hadoop.io.Text
import org.apache.hadoop.security.{Credentials, UserGroupInformation}
Expand Down Expand Up @@ -542,13 +540,8 @@ private[spark] class SecurityManager(
return
}

val rnd = new SecureRandom()
val length = sparkConf.getInt("spark.authenticate.secretBitLength", 256) / JByte.SIZE
val secretBytes = new Array[Byte](length)
rnd.nextBytes(secretBytes)

secretKey = Utils.createSecret(sparkConf)
val creds = new Credentials()
secretKey = HashCodes.fromBytes(secretBytes).toString()
creds.addSecretKey(SECRET_LOOKUP_KEY, secretKey.getBytes(UTF_8))
UserGroupInformation.getCurrentUser().addCredentials(creds)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,39 @@

package org.apache.spark.api.python

import java.io.DataOutputStream
import java.net.Socket
import java.io.{DataOutputStream, File, FileOutputStream}
import java.net.InetAddress
import java.nio.charset.StandardCharsets.UTF_8
import java.nio.file.Files

import py4j.GatewayServer

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils

/**
* Process that starts a Py4J GatewayServer on an ephemeral port and communicates the bound port
* back to its caller via a callback port specified by the caller.
* Process that starts a Py4J GatewayServer on an ephemeral port.
*
* This process is launched (via SparkSubmit) by the PySpark driver (see java_gateway.py).
*/
private[spark] object PythonGatewayServer extends Logging {
initializeLogIfNecessary(true)

def main(args: Array[String]): Unit = Utils.tryOrExit {
// Start a GatewayServer on an ephemeral port
val gatewayServer: GatewayServer = new GatewayServer(null, 0)
def main(args: Array[String]): Unit = {
val secret = Utils.createSecret(new SparkConf())

// Start a GatewayServer on an ephemeral port. Make sure the callback client is configured
// with the same secret, in case the app needs callbacks from the JVM to the underlying
// python processes.
val localhost = InetAddress.getLoopbackAddress()
val gatewayServer: GatewayServer = new GatewayServer.GatewayServerBuilder()
.authToken(secret)
.javaPort(0)
.javaAddress(localhost)
.callbackClient(GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
.build()

gatewayServer.start()
val boundPort: Int = gatewayServer.getListeningPort
if (boundPort == -1) {
Expand All @@ -46,15 +59,24 @@ private[spark] object PythonGatewayServer extends Logging {
logDebug(s"Started PythonGatewayServer on port $boundPort")
}

// Communicate the bound port back to the caller via the caller-specified callback port
val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST")
val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt
logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort")
val callbackSocket = new Socket(callbackHost, callbackPort)
val dos = new DataOutputStream(callbackSocket.getOutputStream)
// Communicate the connection information back to the python process by writing the
// information in the requested file. This needs to match the read side in java_gateway.py.
val connectionInfoPath = new File(sys.env("_PYSPARK_DRIVER_CONN_INFO_PATH"))
val tmpPath = Files.createTempFile(connectionInfoPath.getParentFile().toPath(),
"connection", ".info").toFile()

val dos = new DataOutputStream(new FileOutputStream(tmpPath))
dos.writeInt(boundPort)

val secretBytes = secret.getBytes(UTF_8)
dos.writeInt(secretBytes.length)
dos.write(secretBytes, 0, secretBytes.length)
dos.close()
callbackSocket.close()

if (!tmpPath.renameTo(connectionInfoPath)) {
logError(s"Unable to write connection information to $connectionInfoPath.")
System.exit(1)
}

// Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies:
while (System.in.read() != -1) {
Expand Down
29 changes: 22 additions & 7 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.input.PortableDataStream
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.security.SocketAuthHelper
import org.apache.spark.util._


Expand Down Expand Up @@ -107,6 +108,12 @@ private[spark] object PythonRDD extends Logging {
// remember the broadcasts sent to each worker
private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]()

// Authentication helper used when serving iterator data. It is created lazily on first use and
// takes its configuration from the active SparkEnv when one exists, otherwise from a freshly
// constructed default SparkConf.
private lazy val authHelper = {
// Option(...) guards against SparkEnv.get returning null; in that case fall back to a new SparkConf.
val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
new SocketAuthHelper(conf)
}

def getWorkerBroadcasts(worker: Socket): mutable.Set[Long] = {
synchronized {
workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]())
Expand All @@ -129,12 +136,13 @@ private[spark] object PythonRDD extends Logging {
* (effectively a collect()), but allows you to run on a certain subset of partitions,
* or to enable local execution.
*
* @return the port number of a local socket which serves the data collected from this job.
* @return 2-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from this job, and the secret for authentication.
*/
def runJob(
sc: SparkContext,
rdd: JavaRDD[Array[Byte]],
partitions: JArrayList[Int]): Int = {
partitions: JArrayList[Int]): Array[Any] = {
type ByteArray = Array[Byte]
type UnrolledPartition = Array[ByteArray]
val allPartitions: Array[UnrolledPartition] =
Expand All @@ -147,13 +155,14 @@ private[spark] object PythonRDD extends Logging {
/**
* A helper function to collect an RDD as an iterator, then serve it via socket.
*
* @return the port number of a local socket which serves the data collected from this job.
* @return 2-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from this job, and the secret for authentication.
*/
def collectAndServe[T](rdd: RDD[T]): Int = {
def collectAndServe[T](rdd: RDD[T]): Array[Any] = {
serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
}

def toLocalIteratorAndServe[T](rdd: RDD[T]): Int = {
def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
serveIterator(rdd.toLocalIterator, s"serve toLocalIterator")
}

Expand Down Expand Up @@ -384,8 +393,11 @@ private[spark] object PythonRDD extends Logging {
* and send them into this connection.
*
* The thread will terminate after all the data are sent or any exceptions happen.
*
* @return 2-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from this job, and the secret for authentication.
*/
def serveIterator[T](items: Iterator[T], threadName: String): Int = {
def serveIterator(items: Iterator[_], threadName: String): Array[Any] = {
val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
// Close the socket if no connection in 15 seconds
serverSocket.setSoTimeout(15000)
Expand All @@ -395,11 +407,14 @@ private[spark] object PythonRDD extends Logging {
override def run() {
try {
val sock = serverSocket.accept()
authHelper.authClient(sock)

val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
Utils.tryWithSafeFinally {
writeIteratorToStream(items, out)
} {
out.close()
sock.close()
}
} catch {
case NonFatal(e) =>
Expand All @@ -410,7 +425,7 @@ private[spark] object PythonRDD extends Logging {
}
}.start()

serverSocket.getLocalPort
Array(serverSocket.getLocalPort, authHelper.secret)
}

private def getMergedConf(confAsMap: java.util.HashMap[String, String],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ private[spark] object PythonUtils {
val pythonPath = new ArrayBuffer[String]
for (sparkHome <- sys.env.get("SPARK_HOME")) {
pythonPath += Seq(sparkHome, "python", "lib", "pyspark.zip").mkString(File.separator)
pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.6-src.zip").mkString(File.separator)
pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.7-src.zip").mkString(File.separator)
}
pythonPath ++= SparkContext.jarOfObject(this)
pythonPath.mkString(File.pathSeparator)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import scala.collection.mutable

import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.security.SocketAuthHelper
import org.apache.spark.util.{RedirectThread, Utils}

private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
Expand All @@ -45,6 +46,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
!System.getProperty("os.name").startsWith("Windows") && useDaemonEnabled
}


private val authHelper = new SocketAuthHelper(SparkEnv.get.conf)

var daemon: Process = null
val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
var daemonPort: Int = 0
Expand Down Expand Up @@ -85,6 +89,8 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
if (pid < 0) {
throw new IllegalStateException("Python daemon failed to launch worker with code " + pid)
}

authHelper.authToServer(socket)
daemonWorkers.put(socket, pid)
socket
}
Expand Down Expand Up @@ -122,25 +128,24 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
workerEnv.put("PYTHONPATH", pythonPath)
// This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
workerEnv.put("PYTHONUNBUFFERED", "YES")
workerEnv.put("PYTHON_WORKER_FACTORY_PORT", serverSocket.getLocalPort.toString)
workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret)
val worker = pb.start()

// Redirect worker stdout and stderr
redirectStreamsToStderr(worker.getInputStream, worker.getErrorStream)

// Tell the worker our port
val out = new OutputStreamWriter(worker.getOutputStream, StandardCharsets.UTF_8)
out.write(serverSocket.getLocalPort + "\n")
out.flush()

// Wait for it to connect to our socket
// Wait for it to connect to our socket, and validate the auth secret.
serverSocket.setSoTimeout(10000)

try {
val socket = serverSocket.accept()
authHelper.authClient(socket)
simpleWorkers.put(socket, worker)
return socket
} catch {
case e: Exception =>
throw new SparkException("Python worker did not connect back in time", e)
throw new SparkException("Python worker failed to connect back.", e)
}
} finally {
if (serverSocket != null) {
Expand All @@ -163,6 +168,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
val workerEnv = pb.environment()
workerEnv.putAll(envVars.asJava)
workerEnv.put("PYTHONPATH", pythonPath)
workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret)
// This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
workerEnv.put("PYTHONUNBUFFERED", "YES")
daemon = pb.start()
Expand All @@ -172,7 +178,6 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String

// Redirect daemon stdout and stderr
redirectStreamsToStderr(in, daemon.getErrorStream)

} catch {
case e: Exception =>

Expand Down
12 changes: 10 additions & 2 deletions core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.deploy

import java.io.File
import java.net.URI
import java.net.{InetAddress, URI}

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
Expand All @@ -39,6 +39,7 @@ object PythonRunner {
val pyFiles = args(1)
val otherArgs = args.slice(2, args.length)
val sparkConf = new SparkConf()
val secret = Utils.createSecret(sparkConf)
val pythonExec = sparkConf.get(PYSPARK_DRIVER_PYTHON)
.orElse(sparkConf.get(PYSPARK_PYTHON))
.orElse(sys.env.get("PYSPARK_DRIVER_PYTHON"))
Expand All @@ -51,7 +52,13 @@ object PythonRunner {

// Launch a Py4J gateway server for the process to connect to; this will let it see our
// Java system properties and such
val gatewayServer = new py4j.GatewayServer(null, 0)
val localhost = InetAddress.getLoopbackAddress()
val gatewayServer = new py4j.GatewayServer.GatewayServerBuilder()
.authToken(secret)
.javaPort(0)
.javaAddress(localhost)
.callbackClient(py4j.GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
.build()
val thread = new Thread(new Runnable() {
override def run(): Unit = Utils.logUncaughtExceptions {
gatewayServer.start()
Expand Down Expand Up @@ -82,6 +89,7 @@ object PythonRunner {
// This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string
env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
env.put("PYSPARK_GATEWAY_SECRET", secret)
// Pass conf spark.pyspark.python to the Python process; the only way to pass information to the
// Python process is through environment variables.
sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,11 @@ package object config {
.regexConf
.createOptional

// Integer config entry for the key "spark.authenticate.secretBitLength"; defaults to 256 when the
// key is not set. NOTE(review): presumably the size, in bits, of generated authentication
// secrets -- confirm against the code that reads this entry.
private[spark] val AUTH_SECRET_BIT_LENGTH =
ConfigBuilder("spark.authenticate.secretBitLength")
.intConf
.createWithDefault(256)

private[spark] val NETWORK_AUTH_ENABLED =
ConfigBuilder("spark.authenticate")
.booleanConf
Expand Down
Loading

0 comments on commit 323dc3a

Please sign in to comment.