diff --git a/.rat-excludes b/.rat-excludes index c20226f233a93..3c2451103e4b6 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -25,6 +25,7 @@ log4j-defaults.properties bootstrap-tooltip.js jquery-1.11.1.min.js sorttable.js +.*avsc .*txt .*json .*data diff --git a/README.md b/README.md index f87e07aa5cc90..a1a48f5bd0819 100644 --- a/README.md +++ b/README.md @@ -115,6 +115,15 @@ If your project is built with Maven, add this to your POM file's ` +## A Note About Thrift JDBC server and CLI for Spark SQL + +Spark SQL supports Thrift JDBC server and CLI. +See sql-programming-guide.md for more information about those features. +You can use those features by setting `-Phive-thriftserver` when building Spark as follows. + + $ sbt/sbt -Phive-thriftserver assembly + + ## Configuration Please refer to the [Configuration guide](http://spark.apache.org/docs/latest/configuration.html) diff --git a/bin/pyspark b/bin/pyspark index 39a20e2a24a3c..01d42025c978e 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -23,12 +23,18 @@ FWDIR="$(cd `dirname $0`/..; pwd)" # Export this as SPARK_HOME export SPARK_HOME="$FWDIR" +source $FWDIR/bin/utils.sh + SCALA_VERSION=2.10 -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then +function usage() { echo "Usage: ./bin/pyspark [options]" 1>&2 $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 exit 0 +} + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + usage fi # Exit if the user hasn't compiled Spark @@ -66,10 +72,11 @@ fi # Build up arguments list manually to preserve quotes and backslashes. # We export Spark submit arguments as an environment variable because shell.py must run as a # PYTHONSTARTUP script, which does not take in arguments. This is required for IPython notebooks. - +SUBMIT_USAGE_FUNCTION=usage +gatherSparkSubmitOpts "$@" PYSPARK_SUBMIT_ARGS="" whitespace="[[:space:]]" -for i in "$@"; do +for i in "${SUBMISSION_OPTS[@]}"; do if [[ $i =~ \" ]]; then i=$(echo $i | sed 's/\"/\\\"/g'); fi if [[ $i =~ $whitespace ]]; then i=\"$i\"; fi PYSPARK_SUBMIT_ARGS="$PYSPARK_SUBMIT_ARGS $i" @@ -90,7 +97,10 @@ fi if [[ "$1" =~ \.py$ ]]; then echo -e "\nWARNING: Running python applications through ./bin/pyspark is deprecated as of Spark 1.0." 1>&2 echo -e "Use ./bin/spark-submit \n" 1>&2 - exec $FWDIR/bin/spark-submit "$@" + primary=$1 + shift + gatherSparkSubmitOpts "$@" + exec $FWDIR/bin/spark-submit "${SUBMISSION_OPTS[@]}" $primary "${APPLICATION_OPTS[@]}" else # Only use ipython if no command line arguments were provided [SPARK-1134] if [[ "$IPYTHON" = "1" ]]; then diff --git a/bin/spark-shell b/bin/spark-shell index 756c8179d12b6..8b7ccd7439551 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -31,13 +31,21 @@ set -o posix ## Global script variables FWDIR="$(cd `dirname $0`/..; pwd)" +function usage() { + echo "Usage: ./bin/spark-shell [options]" + $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + exit 0 +} + if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - echo "Usage: ./bin/spark-shell [options]" - $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - exit 0 + usage fi -function main(){ +source $FWDIR/bin/utils.sh +SUBMIT_USAGE_FUNCTION=usage +gatherSparkSubmitOpts "$@" + +function main() { if $cygwin; then # Workaround for issue involving JLine and Cygwin # (see http://sourceforge.net/p/jline/bugs/40/). @@ -46,11 +54,11 @@ function main(){ # (see https://github.com/sbt/sbt/issues/562). stty -icanon min 1 -echo > /dev/null 2>&1 export SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Djline.terminal=unix" - $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main spark-shell "$@" + $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main "${SUBMISSION_OPTS[@]}" spark-shell "${APPLICATION_OPTS[@]}" stty icanon echo > /dev/null 2>&1 else export SPARK_SUBMIT_OPTS - $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main spark-shell "$@" + $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main "${SUBMISSION_OPTS[@]}" spark-shell "${APPLICATION_OPTS[@]}" fi } diff --git a/bin/spark-shell.cmd b/bin/spark-shell.cmd index b56d69801171c..2ee60b4e2a2b3 100755 --- a/bin/spark-shell.cmd +++ b/bin/spark-shell.cmd @@ -19,4 +19,4 @@ rem set SPARK_HOME=%~dp0.. -cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd spark-shell --class org.apache.spark.repl.Main %* +cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd --class org.apache.spark.repl.Main %* spark-shell diff --git a/bin/spark-sql b/bin/spark-sql index 7813ccc361415..564f1f419060f 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -65,30 +65,30 @@ while (($#)); do case $1 in -d | --define | --database | -f | -h | --hiveconf | --hivevar | -i | -p) ensure_arg_number $# 2 - CLI_ARGS+=($1); shift - CLI_ARGS+=($1); shift + CLI_ARGS+=("$1"); shift + CLI_ARGS+=("$1"); shift ;; -e) ensure_arg_number $# 2 - CLI_ARGS+=($1); shift - CLI_ARGS+=(\"$1\"); shift + CLI_ARGS+=("$1"); shift + CLI_ARGS+=("$1"); shift ;; -s | --silent) - CLI_ARGS+=($1); shift + CLI_ARGS+=("$1"); shift ;; -v | --verbose) # Both SparkSubmit and SparkSQLCLIDriver recognizes -v | --verbose - CLI_ARGS+=($1) - SUBMISSION_ARGS+=($1); shift + CLI_ARGS+=("$1") + SUBMISSION_ARGS+=("$1"); shift ;; *) - SUBMISSION_ARGS+=($1); shift + SUBMISSION_ARGS+=("$1"); shift ;; esac done -eval exec "$FWDIR"/bin/spark-submit --class $CLASS ${SUBMISSION_ARGS[*]} spark-internal ${CLI_ARGS[*]} +exec "$FWDIR"/bin/spark-submit --class $CLASS "${SUBMISSION_ARGS[@]}" spark-internal "${CLI_ARGS[@]}" diff --git a/bin/utils.sh b/bin/utils.sh index 4b3c09fbe2537..1eaff0e503651 100755 --- a/bin/utils.sh +++ b/bin/utils.sh @@ -24,7 +24,7 @@ # http://docs.oracle.com/javase/7/docs/api/java/util/Properties.html#load(java.io.Reader). # This accepts the name of the config as an argument, and expects the path of the property # file to be found in PROPERTIES_FILE. The value is returned through JAVA_PROPERTY_VALUE. -parse_java_property() { +function parse_java_property() { JAVA_PROPERTY_VALUE="" # return value config_buffer="" # buffer for collecting parts of a config value multi_line=0 # whether this config is spanning multiple lines @@ -58,7 +58,7 @@ parse_java_property() { # Properly split java options, dealing with whitespace, double quotes and backslashes. # This accepts a string and returns the resulting list through SPLIT_JAVA_OPTS. -split_java_options() { +function split_java_options() { SPLIT_JAVA_OPTS=() # return value option_buffer="" # buffer for collecting parts of an option opened_quotes=0 # whether we are expecting a closing double quotes @@ -93,7 +93,7 @@ split_java_options() { # Put double quotes around each of the given java options that is a system property. # This accepts a list and returns the quoted list through QUOTED_JAVA_OPTS -quote_java_property() { +function quote_java_property() { QUOTED_JAVA_OPTS=() for opt in "$@"; do is_system_property=$(echo "$opt" | grep -e "^-D") @@ -106,3 +106,43 @@ quote_java_property() { export QUOTED_JAVA_OPTS } +# Gather all all spark-submit options into SUBMISSION_OPTS +function gatherSparkSubmitOpts() { + + if [ -z "$SUBMIT_USAGE_FUNCTION" ]; then + echo "Function for printing usage of $0 is not set." 1>&2 + echo "Please set usage function to shell variable 'SUBMIT_USAGE_FUNCTION' in $0" 1>&2 + exit 1 + fi + + # NOTE: If you add or remove spark-sumbmit options, + # modify NOT ONLY this script but also SparkSubmitArgument.scala + SUBMISSION_OPTS=() + APPLICATION_OPTS=() + while (($#)); do + case "$1" in + --master | --deploy-mode | --class | --name | --jars | --py-files | --files | \ + --conf | --properties-file | --driver-memory | --driver-java-options | \ + --driver-library-path | --driver-class-path | --executor-memory | --driver-cores | \ + --total-executor-cores | --executor-cores | --queue | --num-executors | --archives) + if [[ $# -lt 2 ]]; then + "$SUBMIT_USAGE_FUNCTION" + exit 1; + fi + SUBMISSION_OPTS+=("$1"); shift + SUBMISSION_OPTS+=("$1"); shift + ;; + + --verbose | -v | --supervise) + SUBMISSION_OPTS+=("$1"); shift + ;; + + *) + APPLICATION_OPTS+=("$1"); shift + ;; + esac + done + + export SUBMISSION_OPTS + export APPLICATION_OPTS +} diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClient.java b/core/src/main/java/org/apache/spark/network/netty/FileClient.java deleted file mode 100644 index 0d31894d6ec7a..0000000000000 --- a/core/src/main/java/org/apache/spark/network/netty/FileClient.java +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty; - -import java.util.concurrent.TimeUnit; - -import io.netty.bootstrap.Bootstrap; -import io.netty.channel.Channel; -import io.netty.channel.ChannelOption; -import io.netty.channel.EventLoopGroup; -import io.netty.channel.oio.OioEventLoopGroup; -import io.netty.channel.socket.oio.OioSocketChannel; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -class FileClient { - - private static final Logger LOG = LoggerFactory.getLogger(FileClient.class.getName()); - - private final FileClientHandler handler; - private Channel channel = null; - private Bootstrap bootstrap = null; - private EventLoopGroup group = null; - private final int connectTimeout; - private final int sendTimeout = 60; // 1 min - - FileClient(FileClientHandler handler, int connectTimeout) { - this.handler = handler; - this.connectTimeout = connectTimeout; - } - - public void init() { - group = new OioEventLoopGroup(); - bootstrap = new Bootstrap(); - bootstrap.group(group) - .channel(OioSocketChannel.class) - .option(ChannelOption.SO_KEEPALIVE, true) - .option(ChannelOption.TCP_NODELAY, true) - .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout) - .handler(new FileClientChannelInitializer(handler)); - } - - public void connect(String host, int port) { - try { - // Start the connection attempt. - channel = bootstrap.connect(host, port).sync().channel(); - // ChannelFuture cf = channel.closeFuture(); - //cf.addListener(new ChannelCloseListener(this)); - } catch (InterruptedException e) { - LOG.warn("FileClient interrupted while trying to connect", e); - close(); - } - } - - public void waitForClose() { - try { - channel.closeFuture().sync(); - } catch (InterruptedException e) { - LOG.warn("FileClient interrupted", e); - } - } - - public void sendRequest(String file) { - //assert(file == null); - //assert(channel == null); - try { - // Should be able to send the message to network link channel. - boolean bSent = channel.writeAndFlush(file + "\r\n").await(sendTimeout, TimeUnit.SECONDS); - if (!bSent) { - throw new RuntimeException("Failed to send"); - } - } catch (InterruptedException e) { - LOG.error("Error", e); - } - } - - public void close() { - if (group != null) { - group.shutdownGracefully(); - group = null; - bootstrap = null; - } - } -} diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServer.java b/core/src/main/java/org/apache/spark/network/netty/FileServer.java deleted file mode 100644 index c93425e2787dc..0000000000000 --- a/core/src/main/java/org/apache/spark/network/netty/FileServer.java +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty; - -import java.net.InetSocketAddress; - -import io.netty.bootstrap.ServerBootstrap; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelOption; -import io.netty.channel.EventLoopGroup; -import io.netty.channel.oio.OioEventLoopGroup; -import io.netty.channel.socket.oio.OioServerSocketChannel; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Server that accept the path of a file an echo back its content. - */ -class FileServer { - - private static final Logger LOG = LoggerFactory.getLogger(FileServer.class.getName()); - - private EventLoopGroup bossGroup = null; - private EventLoopGroup workerGroup = null; - private ChannelFuture channelFuture = null; - private int port = 0; - - FileServer(PathResolver pResolver, int port) { - InetSocketAddress addr = new InetSocketAddress(port); - - // Configure the server. - bossGroup = new OioEventLoopGroup(); - workerGroup = new OioEventLoopGroup(); - - ServerBootstrap bootstrap = new ServerBootstrap(); - bootstrap.group(bossGroup, workerGroup) - .channel(OioServerSocketChannel.class) - .option(ChannelOption.SO_BACKLOG, 100) - .option(ChannelOption.SO_RCVBUF, 1500) - .childHandler(new FileServerChannelInitializer(pResolver)); - // Start the server. - channelFuture = bootstrap.bind(addr); - try { - // Get the address we bound to. - InetSocketAddress boundAddress = - ((InetSocketAddress) channelFuture.sync().channel().localAddress()); - this.port = boundAddress.getPort(); - } catch (InterruptedException ie) { - this.port = 0; - } - } - - /** - * Start the file server asynchronously in a new thread. - */ - public void start() { - Thread blockingThread = new Thread() { - @Override - public void run() { - try { - channelFuture.channel().closeFuture().sync(); - LOG.info("FileServer exiting"); - } catch (InterruptedException e) { - LOG.error("File server start got interrupted", e); - } - // NOTE: bootstrap is shutdown in stop() - } - }; - blockingThread.setDaemon(true); - blockingThread.start(); - } - - public int getPort() { - return port; - } - - public void stop() { - // Close the bound channel. - if (channelFuture != null) { - channelFuture.channel().close().awaitUninterruptibly(); - channelFuture = null; - } - - // Shutdown event groups - if (bossGroup != null) { - bossGroup.shutdownGracefully(); - bossGroup = null; - } - - if (workerGroup != null) { - workerGroup.shutdownGracefully(); - workerGroup = null; - } - // TODO: Shutdown all accepted channels as well ? - } -} diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java deleted file mode 100644 index c0133e19c7f79..0000000000000 --- a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty; - -import java.io.File; -import java.io.FileInputStream; - -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.SimpleChannelInboundHandler; -import io.netty.channel.DefaultFileRegion; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.storage.BlockId; -import org.apache.spark.storage.FileSegment; - -class FileServerHandler extends SimpleChannelInboundHandler { - - private static final Logger LOG = LoggerFactory.getLogger(FileServerHandler.class.getName()); - - private final PathResolver pResolver; - - FileServerHandler(PathResolver pResolver){ - this.pResolver = pResolver; - } - - @Override - public void channelRead0(ChannelHandlerContext ctx, String blockIdString) { - BlockId blockId = BlockId.apply(blockIdString); - FileSegment fileSegment = pResolver.getBlockLocation(blockId); - // if getBlockLocation returns null, close the channel - if (fileSegment == null) { - //ctx.close(); - return; - } - File file = fileSegment.file(); - if (file.exists()) { - if (!file.isFile()) { - ctx.write(new FileHeader(0, blockId).buffer()); - ctx.flush(); - return; - } - long length = fileSegment.length(); - if (length > Integer.MAX_VALUE || length <= 0) { - ctx.write(new FileHeader(0, blockId).buffer()); - ctx.flush(); - return; - } - int len = (int) length; - ctx.write((new FileHeader(len, blockId)).buffer()); - try { - ctx.write(new DefaultFileRegion(new FileInputStream(file) - .getChannel(), fileSegment.offset(), fileSegment.length())); - } catch (Exception e) { - LOG.error("Exception: ", e); - } - } else { - ctx.write(new FileHeader(0, blockId).buffer()); - } - ctx.flush(); - } - - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { - LOG.error("Exception: ", cause); - ctx.close(); - } -} diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 24ccce21b62ca..83ae57b7f1516 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -21,6 +21,7 @@ import akka.actor.Actor import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId import org.apache.spark.scheduler.TaskScheduler +import org.apache.spark.util.ActorLogReceive /** * A heartbeat from executors to the driver. This is a shared message used by several internal @@ -36,8 +37,10 @@ private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) /** * Lives in the driver to receive heartbeats from executors.. */ -private[spark] class HeartbeatReceiver(scheduler: TaskScheduler) extends Actor { - override def receive = { +private[spark] class HeartbeatReceiver(scheduler: TaskScheduler) + extends Actor with ActorLogReceive with Logging { + + override def receiveWithLogging = { case Heartbeat(executorId, taskMetrics, blockManagerId) => val response = HeartbeatResponse( !scheduler.executorHeartbeatReceived(executorId, taskMetrics, blockManagerId)) diff --git a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala index f40baa8e43592..5c262bcbddf76 100644 --- a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala +++ b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala @@ -33,7 +33,7 @@ class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator // is allowed. The assumption is that Thread.interrupted does not have a memory fence in read // (just a volatile field in C), while context.interrupted is a volatile in the JVM, which // introduces an expensive read fence. - if (context.interrupted) { + if (context.isInterrupted) { throw new TaskKilledException } else { delegate.hasNext diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 894091761485d..51705c895a55c 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -38,10 +38,10 @@ private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage /** Actor class for MapOutputTrackerMaster */ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf) - extends Actor with Logging { + extends Actor with ActorLogReceive with Logging { val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) - def receive = { + override def receiveWithLogging = { case GetMapOutputStatuses(shuffleId: Int) => val hostPort = sender.path.address.hostPort logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 51f40c339d13c..2b99b8a5af250 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -21,10 +21,18 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics +import org.apache.spark.util.TaskCompletionListener + /** * :: DeveloperApi :: * Contextual information about a task which can be read or mutated during execution. + * + * @param stageId stage id + * @param partitionId index of the partition + * @param attemptId the number of attempts to execute this task + * @param runningLocally whether the task is running locally in the driver JVM + * @param taskMetrics performance metrics of the task */ @DeveloperApi class TaskContext( @@ -39,13 +47,45 @@ class TaskContext( def splitId = partitionId // List of callback functions to execute when the task completes. - @transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit] + @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener] // Whether the corresponding task has been killed. - @volatile var interrupted: Boolean = false + @volatile private var interrupted: Boolean = false + + // Whether the task has completed. + @volatile private var completed: Boolean = false + + /** Checks whether the task has completed. */ + def isCompleted: Boolean = completed - // Whether the task has completed, before the onCompleteCallbacks are executed. - @volatile var completed: Boolean = false + /** Checks whether the task has been killed. */ + def isInterrupted: Boolean = interrupted + + // TODO: Also track whether the task has completed successfully or with exception. + + /** + * Add a (Java friendly) listener to be executed on task completion. + * This will be called in all situation - success, failure, or cancellation. + * + * An example use is for HadoopRDD to register a callback to close the input stream. + */ + def addTaskCompletionListener(listener: TaskCompletionListener): this.type = { + onCompleteCallbacks += listener + this + } + + /** + * Add a listener in the form of a Scala closure to be executed on task completion. + * This will be called in all situation - success, failure, or cancellation. + * + * An example use is for HadoopRDD to register a callback to close the input stream. + */ + def addTaskCompletionListener(f: TaskContext => Unit): this.type = { + onCompleteCallbacks += new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = f(context) + } + this + } /** * Add a callback function to be executed on task completion. An example use @@ -53,13 +93,22 @@ class TaskContext( * Will be called in any situation - success, failure, or cancellation. * @param f Callback function. */ + @deprecated("use addTaskCompletionListener", "1.1.0") def addOnCompleteCallback(f: () => Unit) { - onCompleteCallbacks += f + onCompleteCallbacks += new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = f() + } } - def executeOnCompleteCallbacks() { + /** Marks the task as completed and triggers the listeners. */ + private[spark] def markTaskCompleted(): Unit = { completed = true // Process complete callbacks in the reverse order of registration - onCompleteCallbacks.reverse.foreach { _() } + onCompleteCallbacks.reverse.foreach { _.onTaskCompletion(this) } + } + + /** Marks the task for interruption, i.e. cancellation. */ + private[spark] def markInterrupted(): Unit = { + interrupted = true } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 76d4193e96aea..feeb6c02caa78 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -133,68 +133,62 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Return a subset of this RDD sampled by key (via stratified sampling). * * Create a sample of this RDD using variable sampling rates for different keys as specified by - * `fractions`, a key to sampling rate map. - * - * If `exact` is set to false, create the sample via simple random sampling, with one pass - * over the RDD, to produce a sample of size that's approximately equal to the sum of - * math.ceil(numItems * samplingRate) over all key values; otherwise, use additional passes over - * the RDD to create a sample size that's exactly equal to the sum of + * `fractions`, a key to sampling rate map, via simple random sampling with one pass over the + * RDD, to produce a sample of size that's approximately equal to the sum of * math.ceil(numItems * samplingRate) over all key values. */ def sampleByKey(withReplacement: Boolean, fractions: JMap[K, Double], - exact: Boolean, seed: Long): JavaPairRDD[K, V] = - new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, exact, seed)) + new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, seed)) /** * Return a subset of this RDD sampled by key (via stratified sampling). * * Create a sample of this RDD using variable sampling rates for different keys as specified by - * `fractions`, a key to sampling rate map. - * - * If `exact` is set to false, create the sample via simple random sampling, with one pass - * over the RDD, to produce a sample of size that's approximately equal to the sum of - * math.ceil(numItems * samplingRate) over all key values; otherwise, use additional passes over - * the RDD to create a sample size that's exactly equal to the sum of + * `fractions`, a key to sampling rate map, via simple random sampling with one pass over the + * RDD, to produce a sample of size that's approximately equal to the sum of * math.ceil(numItems * samplingRate) over all key values. * - * Use Utils.random.nextLong as the default seed for the random number generator + * Use Utils.random.nextLong as the default seed for the random number generator. */ def sampleByKey(withReplacement: Boolean, - fractions: JMap[K, Double], - exact: Boolean): JavaPairRDD[K, V] = - sampleByKey(withReplacement, fractions, exact, Utils.random.nextLong) + fractions: JMap[K, Double]): JavaPairRDD[K, V] = + sampleByKey(withReplacement, fractions, Utils.random.nextLong) /** - * Return a subset of this RDD sampled by key (via stratified sampling). - * - * Create a sample of this RDD using variable sampling rates for different keys as specified by - * `fractions`, a key to sampling rate map. + * ::Experimental:: + * Return a subset of this RDD sampled by key (via stratified sampling) containing exactly + * math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key). * - * Produce a sample of size that's approximately equal to the sum of - * math.ceil(numItems * samplingRate) over all key values with one pass over the RDD via - * simple random sampling. + * This method differs from [[sampleByKey]] in that we make additional passes over the RDD to + * create a sample size that's exactly equal to the sum of math.ceil(numItems * samplingRate) + * over all key values with a 99.99% confidence. When sampling without replacement, we need one + * additional pass over the RDD to guarantee sample size; when sampling with replacement, we need + * two additional passes. */ - def sampleByKey(withReplacement: Boolean, + @Experimental + def sampleByKeyExact(withReplacement: Boolean, fractions: JMap[K, Double], seed: Long): JavaPairRDD[K, V] = - sampleByKey(withReplacement, fractions, false, seed) + new JavaPairRDD[K, V](rdd.sampleByKeyExact(withReplacement, fractions, seed)) /** - * Return a subset of this RDD sampled by key (via stratified sampling). + * ::Experimental:: + * Return a subset of this RDD sampled by key (via stratified sampling) containing exactly + * math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key). * - * Create a sample of this RDD using variable sampling rates for different keys as specified by - * `fractions`, a key to sampling rate map. - * - * Produce a sample of size that's approximately equal to the sum of - * math.ceil(numItems * samplingRate) over all key values with one pass over the RDD via - * simple random sampling. + * This method differs from [[sampleByKey]] in that we make additional passes over the RDD to + * create a sample size that's exactly equal to the sum of math.ceil(numItems * samplingRate) + * over all key values with a 99.99% confidence. When sampling without replacement, we need one + * additional pass over the RDD to guarantee sample size; when sampling with replacement, we need + * two additional passes. * - * Use Utils.random.nextLong as the default seed for the random number generator + * Use Utils.random.nextLong as the default seed for the random number generator. */ - def sampleByKey(withReplacement: Boolean, fractions: JMap[K, Double]): JavaPairRDD[K, V] = - sampleByKey(withReplacement, fractions, false, Utils.random.nextLong) + @Experimental + def sampleByKeyExact(withReplacement: Boolean, fractions: JMap[K, Double]): JavaPairRDD[K, V] = + sampleByKeyExact(withReplacement, fractions, Utils.random.nextLong) /** * Return the union of this RDD and another one. Any identical elements will appear multiple diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index f3b05e1243045..49dc95f349eac 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -19,6 +19,7 @@ package org.apache.spark.api.python import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils import org.apache.spark.{Logging, SerializableWritable, SparkException} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io._ @@ -42,7 +43,7 @@ private[python] object Converter extends Logging { defaultConverter: Converter[Any, Any]): Converter[Any, Any] = { converterClass.map { cc => Try { - val c = Class.forName(cc).newInstance().asInstanceOf[Converter[Any, Any]] + val c = Utils.classForName(cc).newInstance().asInstanceOf[Converter[Any, Any]] logInfo(s"Loaded converter: $cc") c } match { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 0b5322c6fb965..9f5c5bd30f0c9 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -68,7 +68,7 @@ private[spark] class PythonRDD( // Start a thread to feed the process input from our parent's iterator val writerThread = new WriterThread(env, worker, split, context) - context.addOnCompleteCallback { () => + context.addTaskCompletionListener { context => writerThread.shutdownOnTaskCompletion() // Cleanup the worker socket. This will also cause the Python worker to exit. @@ -137,7 +137,7 @@ private[spark] class PythonRDD( } } catch { - case e: Exception if context.interrupted => + case e: Exception if context.isInterrupted => logDebug("Exception thrown after task interruption", e) throw new TaskKilledException @@ -176,7 +176,7 @@ private[spark] class PythonRDD( /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */ def shutdownOnTaskCompletion() { - assert(context.completed) + assert(context.isCompleted) this.interrupt() } @@ -209,7 +209,7 @@ private[spark] class PythonRDD( PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut) dataOut.flush() } catch { - case e: Exception if context.completed || context.interrupted => + case e: Exception if context.isCompleted || context.isInterrupted => logDebug("Exception thrown after task completion (likely due to cleanup)", e) case e: Exception => @@ -235,10 +235,10 @@ private[spark] class PythonRDD( override def run() { // Kill the worker if it is interrupted, checking until task completion. // TODO: This has a race condition if interruption occurs, as completed may still become true. - while (!context.interrupted && !context.completed) { + while (!context.isInterrupted && !context.isCompleted) { Thread.sleep(2000) } - if (!context.completed) { + if (!context.isCompleted) { try { logWarning("Incomplete task interrupted: Attempting to kill Python Worker") env.destroyPythonWorker(pythonExec, envVars.toMap, worker) @@ -372,8 +372,8 @@ private[spark] object PythonRDD extends Logging { batchSize: Int) = { val keyClass = Option(keyClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") val valueClass = Option(valueClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") - val kc = Class.forName(keyClass).asInstanceOf[Class[K]] - val vc = Class.forName(valueClass).asInstanceOf[Class[V]] + val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] + val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]] val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(sc.hadoopConfiguration())) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, @@ -440,9 +440,9 @@ private[spark] object PythonRDD extends Logging { keyClass: String, valueClass: String, conf: Configuration) = { - val kc = Class.forName(keyClass).asInstanceOf[Class[K]] - val vc = Class.forName(valueClass).asInstanceOf[Class[V]] - val fc = Class.forName(inputFormatClass).asInstanceOf[Class[F]] + val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] + val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]] + val fc = Utils.classForName(inputFormatClass).asInstanceOf[Class[F]] if (path.isDefined) { sc.sc.newAPIHadoopFile[K, V, F](path.get, fc, kc, vc, conf) } else { @@ -509,9 +509,9 @@ private[spark] object PythonRDD extends Logging { keyClass: String, valueClass: String, conf: Configuration) = { - val kc = Class.forName(keyClass).asInstanceOf[Class[K]] - val vc = Class.forName(valueClass).asInstanceOf[Class[V]] - val fc = Class.forName(inputFormatClass).asInstanceOf[Class[F]] + val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] + val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]] + val fc = Utils.classForName(inputFormatClass).asInstanceOf[Class[F]] if (path.isDefined) { sc.sc.hadoopFile(path.get, fc, kc, vc) } else { @@ -558,7 +558,7 @@ private[spark] object PythonRDD extends Logging { for { k <- Option(keyClass) v <- Option(valueClass) - } yield (Class.forName(k), Class.forName(v)) + } yield (Utils.classForName(k), Utils.classForName(v)) } private def getKeyValueConverters(keyConverterClass: String, valueConverterClass: String, @@ -621,10 +621,10 @@ private[spark] object PythonRDD extends Logging { val (kc, vc) = getKeyValueTypes(keyClass, valueClass).getOrElse( inferKeyValueTypes(rdd, keyConverterClass, valueConverterClass)) val mergedConf = getMergedConf(confAsMap, pyRDD.context.hadoopConfiguration) - val codec = Option(compressionCodecClass).map(Class.forName(_).asInstanceOf[Class[C]]) + val codec = Option(compressionCodecClass).map(Utils.classForName(_).asInstanceOf[Class[C]]) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new JavaToWritableConverter) - val fc = Class.forName(outputFormatClass).asInstanceOf[Class[F]] + val fc = Utils.classForName(outputFormatClass).asInstanceOf[Class[F]] converted.saveAsHadoopFile(path, kc, vc, fc, new JobConf(mergedConf), codec=codec) } @@ -653,7 +653,7 @@ private[spark] object PythonRDD extends Logging { val mergedConf = getMergedConf(confAsMap, pyRDD.context.hadoopConfiguration) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new JavaToWritableConverter) - val fc = Class.forName(outputFormatClass).asInstanceOf[Class[F]] + val fc = Utils.classForName(outputFormatClass).asInstanceOf[Class[F]] converted.saveAsNewAPIHadoopFile(path, kc, vc, fc, mergedConf) } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 7af260d0b7f26..bf716a8ab025b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -68,7 +68,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String val socket = new Socket(daemonHost, daemonPort) val pid = new DataInputStream(socket.getInputStream).readInt() if (pid < 0) { - throw new IllegalStateException("Python daemon failed to launch worker") + throw new IllegalStateException("Python daemon failed to launch worker with code " + pid) } daemonWorkers.put(socket, pid) socket diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index c07003784e8ac..065ddda50e65e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -27,12 +27,14 @@ import org.apache.log4j.{Level, Logger} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, Utils} /** * Proxy that relays messages to the driver. */ -private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends Actor with Logging { +private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) + extends Actor with ActorLogReceive with Logging { + var masterActor: ActorSelection = _ val timeout = AkkaUtils.askTimeout(conf) @@ -114,7 +116,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends } } - override def receive = { + override def receiveWithLogging = { case SubmitDriverResponse(success, driverId, message) => println(message) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 5cdee4e4f9d0a..28d4e0f65a560 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -233,6 +233,10 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { // Delineates parsing of Spark options from parsing of user options. parse(opts) + /** + * NOTE: If you add or remove spark-submit options, + * modify NOT ONLY this file but also utils.sh + */ def parse(opts: Seq[String]): Unit = opts match { case ("--name") :: value :: tail => name = value diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index d38e9e79204c2..32790053a6be8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -30,7 +30,7 @@ import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master -import org.apache.spark.util.{Utils, AkkaUtils} +import org.apache.spark.util.{ActorLogReceive, Utils, AkkaUtils} /** * Interface allowing applications to speak with a Spark deploy cluster. Takes a master URL, @@ -56,7 +56,7 @@ private[spark] class AppClient( var registered = false var activeMasterUrl: String = null - class ClientActor extends Actor with Logging { + class ClientActor extends Actor with ActorLogReceive with Logging { var master: ActorSelection = null var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times var alreadyDead = false // To avoid calling listener.dead() multiple times @@ -119,7 +119,7 @@ private[spark] class AppClient( .contains(remoteUrl.hostPort) } - override def receive = { + override def receiveWithLogging = { case RegisteredApplication(appId_, masterUrl) => appId = appId_ registered = true diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index 72d0589689e71..d3674427b1271 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -46,6 +46,11 @@ private[spark] class ApplicationInfo( init() + private def readObject(in: java.io.ObjectInputStream): Unit = { + in.defaultReadObject() + init() + } + private def init() { state = ApplicationState.WAITING executors = new mutable.HashMap[Int, ExecutorInfo] diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala index c87b66f047dc8..38db02cd2421b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala @@ -22,8 +22,8 @@ import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.metrics.source.Source class ApplicationSource(val application: ApplicationInfo) extends Source { - val metricRegistry = new MetricRegistry() - val sourceName = "%s.%s.%s".format("application", application.desc.name, + override val metricRegistry = new MetricRegistry() + override val sourceName = "%s.%s.%s".format("application", application.desc.name, System.currentTimeMillis()) metricRegistry.register(MetricRegistry.name("status"), new Gauge[String] { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index a70ecdb375373..cfa2c028a807b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -42,14 +42,14 @@ import org.apache.spark.deploy.master.ui.MasterWebUI import org.apache.spark.metrics.MetricsSystem import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} private[spark] class Master( host: String, port: Int, webUiPort: Int, val securityMgr: SecurityManager) - extends Actor with Logging { + extends Actor with ActorLogReceive with Logging { import context.dispatcher // to use Akka's scheduler.schedule() @@ -167,7 +167,7 @@ private[spark] class Master( context.stop(leaderElectionAgent) } - override def receive = { + override def receiveWithLogging = { case ElectedLeader => { val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData() state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala index 36c1b87b7f684..9c3f79f1244b7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala @@ -22,8 +22,8 @@ import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.metrics.source.Source private[spark] class MasterSource(val master: Master) extends Source { - val metricRegistry = new MetricRegistry() - val sourceName = "master" + override val metricRegistry = new MetricRegistry() + override val sourceName = "master" // Gauge for worker numbers in cluster metricRegistry.register(MetricRegistry.name("workers"), new Gauge[Int] { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 458d9947bd873..80fde7e4b2624 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -34,7 +34,7 @@ import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} /** * @param masterUrls Each url should look like spark://host:port. @@ -51,7 +51,7 @@ private[spark] class Worker( workDirPath: String = null, val conf: SparkConf, val securityMgr: SecurityManager) - extends Actor with Logging { + extends Actor with ActorLogReceive with Logging { import context.dispatcher Utils.checkHost(host, "Expected hostname") @@ -136,7 +136,7 @@ private[spark] class Worker( logInfo("Spark home: " + sparkHome) createWorkDir() context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - webUi = new WorkerWebUI(this, workDir, Some(webUiPort)) + webUi = new WorkerWebUI(this, workDir, webUiPort) webUi.bind() registerWithMaster() @@ -187,7 +187,7 @@ private[spark] class Worker( } } - override def receive = { + override def receiveWithLogging = { case RegisteredWorker(masterUrl, masterWebUiUrl) => logInfo("Successfully registered with master " + masterUrl) registered = true @@ -373,7 +373,8 @@ private[spark] class Worker( private[spark] object Worker extends Logging { def main(argStrings: Array[String]) { SignalLogger.register(log) - val args = new WorkerArguments(argStrings) + val conf = new SparkConf + val args = new WorkerArguments(argStrings, conf) val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores, args.memory, args.masters, args.workDir) actorSystem.awaitTermination() diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index dc5158102054e..1e295aaa48c30 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -20,11 +20,12 @@ package org.apache.spark.deploy.worker import java.lang.management.ManagementFactory import org.apache.spark.util.{IntParam, MemoryParam, Utils} +import org.apache.spark.SparkConf /** * Command-line parser for the worker. */ -private[spark] class WorkerArguments(args: Array[String]) { +private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { var host = Utils.localHostName() var port = 0 var webUiPort = 8081 @@ -46,6 +47,9 @@ private[spark] class WorkerArguments(args: Array[String]) { if (System.getenv("SPARK_WORKER_WEBUI_PORT") != null) { webUiPort = System.getenv("SPARK_WORKER_WEBUI_PORT").toInt } + if (conf.contains("spark.worker.ui.port")) { + webUiPort = conf.get("spark.worker.ui.port").toInt + } if (System.getenv("SPARK_WORKER_DIR") != null) { workDir = System.getenv("SPARK_WORKER_DIR") } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala index b7ddd8c816cbc..df1e01b23b932 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala @@ -22,8 +22,8 @@ import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.metrics.source.Source private[spark] class WorkerSource(val worker: Worker) extends Source { - val sourceName = "worker" - val metricRegistry = new MetricRegistry() + override val sourceName = "worker" + override val metricRegistry = new MetricRegistry() metricRegistry.register(MetricRegistry.name("executors"), new Gauge[Int] { override def getValue: Int = worker.executors.size diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 530c147000904..6d0d0bbe5ecec 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -22,13 +22,15 @@ import akka.remote.{AssociatedEvent, AssociationErrorEvent, AssociationEvent, Di import org.apache.spark.Logging import org.apache.spark.deploy.DeployMessages.SendHeartbeat +import org.apache.spark.util.ActorLogReceive /** * Actor which connects to a worker process and terminates the JVM if the connection is severed. * Provides fate sharing between a worker and its associated child processes. */ -private[spark] class WorkerWatcher(workerUrl: String) extends Actor - with Logging { +private[spark] class WorkerWatcher(workerUrl: String) + extends Actor with ActorLogReceive with Logging { + override def preStart() { context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) @@ -48,7 +50,7 @@ private[spark] class WorkerWatcher(workerUrl: String) extends Actor def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1) - override def receive = { + override def receiveWithLogging = { case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => logInfo(s"Successfully connected to $workerUrl") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 47fbda600bea7..b07942a9ca729 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -34,8 +34,8 @@ private[spark] class WorkerWebUI( val worker: Worker, val workDir: File, - port: Option[Int] = None) - extends WebUI(worker.securityMgr, getUIPort(port, worker.conf), worker.conf, name = "WorkerUI") + requestedPort: Int) + extends WebUI(worker.securityMgr, requestedPort, worker.conf, name = "WorkerUI") with Logging { val timeout = AkkaUtils.askTimeout(worker.conf) @@ -55,10 +55,5 @@ class WorkerWebUI( } private[spark] object WorkerWebUI { - val DEFAULT_PORT = 8081 val STATIC_RESOURCE_BASE = SparkUI.STATIC_RESOURCE_DIR - - def getUIPort(requestedPort: Option[Int], conf: SparkConf): Int = { - requestedPort.getOrElse(conf.getInt("spark.worker.ui.port", WorkerWebUI.DEFAULT_PORT)) - } } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 1f46a0f176490..13af5b6f5812d 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -31,14 +31,15 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} private[spark] class CoarseGrainedExecutorBackend( driverUrl: String, executorId: String, hostPort: String, cores: Int, - sparkProperties: Seq[(String, String)]) extends Actor with ExecutorBackend with Logging { + sparkProperties: Seq[(String, String)]) + extends Actor with ActorLogReceive with ExecutorBackend with Logging { Utils.checkHostPort(hostPort, "Expected hostport") @@ -52,7 +53,7 @@ private[spark] class CoarseGrainedExecutorBackend( context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) } - override def receive = { + override def receiveWithLogging = { case RegisteredExecutor => logInfo("Successfully registered with driver") // Make this host instead of hostPort ? diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index eac1f2326a29d..fb3f7bd54bbfa 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -99,6 +99,9 @@ private[spark] class Executor( private val urlClassLoader = createClassLoader() private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader) + // Set the classloader for serializer + env.serializer.setDefaultClassLoader(urlClassLoader) + // Akka's message frame size. If task result is bigger than this, we use the block manager // to send the result back. private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala index 0ed52cfe9df61..d6721586566c2 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala @@ -35,9 +35,10 @@ private[spark] class ExecutorSource(val executor: Executor, executorId: String) }) } - val metricRegistry = new MetricRegistry() + override val metricRegistry = new MetricRegistry() + // TODO: It would be nice to pass the application name here - val sourceName = "executor.%s".format(executorId) + override val sourceName = "executor.%s".format(executorId) // Gauge for executor thread pool's actively executing task counts metricRegistry.register(MetricRegistry.name("threadpool", "activeTasks"), new Gauge[Int] { diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 1b66218d86dd9..ef9c43ecf14f6 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -46,17 +46,24 @@ trait CompressionCodec { private[spark] object CompressionCodec { + + private val shortCompressionCodecNames = Map( + "lz4" -> classOf[LZ4CompressionCodec].getName, + "lzf" -> classOf[LZFCompressionCodec].getName, + "snappy" -> classOf[SnappyCompressionCodec].getName) + def createCodec(conf: SparkConf): CompressionCodec = { createCodec(conf, conf.get("spark.io.compression.codec", DEFAULT_COMPRESSION_CODEC)) } def createCodec(conf: SparkConf, codecName: String): CompressionCodec = { - val ctor = Class.forName(codecName, true, Utils.getContextOrSparkClassLoader) + val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase, codecName) + val ctor = Class.forName(codecClass, true, Utils.getContextOrSparkClassLoader) .getConstructor(classOf[SparkConf]) ctor.newInstance(conf).asInstanceOf[CompressionCodec] } - val DEFAULT_COMPRESSION_CODEC = classOf[SnappyCompressionCodec].getName + val DEFAULT_COMPRESSION_CODEC = "snappy" } diff --git a/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala b/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala index f865f9648a91e..635bff2cd7ec8 100644 --- a/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala +++ b/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala @@ -21,12 +21,9 @@ import com.codahale.metrics.MetricRegistry import com.codahale.metrics.jvm.{GarbageCollectorMetricSet, MemoryUsageGaugeSet} private[spark] class JvmSource extends Source { - val sourceName = "jvm" - val metricRegistry = new MetricRegistry() + override val sourceName = "jvm" + override val metricRegistry = new MetricRegistry() - val gcMetricSet = new GarbageCollectorMetricSet - val memGaugeSet = new MemoryUsageGaugeSet - - metricRegistry.registerAll(gcMetricSet) - metricRegistry.registerAll(memGaugeSet) + metricRegistry.registerAll(new GarbageCollectorMetricSet) + metricRegistry.registerAll(new MemoryUsageGaugeSet) } diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileClient.scala b/core/src/main/scala/org/apache/spark/network/netty/FileClient.scala new file mode 100644 index 0000000000000..c6d35f73db545 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/FileClient.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.util.concurrent.TimeUnit + +import io.netty.bootstrap.Bootstrap +import io.netty.channel.{Channel, ChannelOption, EventLoopGroup} +import io.netty.channel.oio.OioEventLoopGroup +import io.netty.channel.socket.oio.OioSocketChannel + +import org.apache.spark.Logging + +class FileClient(handler: FileClientHandler, connectTimeout: Int) extends Logging { + + private var channel: Channel = _ + private var bootstrap: Bootstrap = _ + private var group: EventLoopGroup = _ + private val sendTimeout = 60 + + def init(): Unit = { + group = new OioEventLoopGroup + bootstrap = new Bootstrap + bootstrap.group(group) + .channel(classOf[OioSocketChannel]) + .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE) + .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, Integer.valueOf(connectTimeout)) + .handler(new FileClientChannelInitializer(handler)) + } + + def connect(host: String, port: Int) { + try { + channel = bootstrap.connect(host, port).sync().channel() + } catch { + case e: InterruptedException => + logWarning("FileClient interrupted while trying to connect", e) + close() + } + } + + def waitForClose(): Unit = { + try { + channel.closeFuture.sync() + } catch { + case e: InterruptedException => + logWarning("FileClient interrupted", e) + } + } + + def sendRequest(file: String): Unit = { + try { + val bSent = channel.writeAndFlush(file + "\r\n").await(sendTimeout, TimeUnit.SECONDS) + if (!bSent) { + throw new RuntimeException("Failed to send") + } + } catch { + case e: InterruptedException => + logError("Error", e) + } + } + + def close(): Unit = { + if (group != null) { + group.shutdownGracefully() + group = null + bootstrap = null + } + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileClientChannelInitializer.scala b/core/src/main/scala/org/apache/spark/network/netty/FileClientChannelInitializer.scala new file mode 100644 index 0000000000000..f4261c13f70a8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/FileClientChannelInitializer.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import io.netty.channel.ChannelInitializer +import io.netty.channel.socket.SocketChannel +import io.netty.handler.codec.string.StringEncoder + + +class FileClientChannelInitializer(handler: FileClientHandler) + extends ChannelInitializer[SocketChannel] { + + def initChannel(channel: SocketChannel) { + channel.pipeline.addLast("encoder", new StringEncoder).addLast("handler", handler) + } +} diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java b/core/src/main/scala/org/apache/spark/network/netty/FileClientHandler.scala similarity index 51% rename from core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java rename to core/src/main/scala/org/apache/spark/network/netty/FileClientHandler.scala index 63d3d927255f9..017302ec7d33d 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java +++ b/core/src/main/scala/org/apache/spark/network/netty/FileClientHandler.scala @@ -15,41 +15,36 @@ * limitations under the License. */ -package org.apache.spark.network.netty; +package org.apache.spark.network.netty -import io.netty.buffer.ByteBuf; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.buffer.ByteBuf +import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} -import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.BlockId -abstract class FileClientHandler extends SimpleChannelInboundHandler { - private FileHeader currentHeader = null; +abstract class FileClientHandler extends SimpleChannelInboundHandler[ByteBuf] { - private volatile boolean handlerCalled = false; + private var currentHeader: FileHeader = null - public boolean isComplete() { - return handlerCalled; - } + @volatile + private var handlerCalled: Boolean = false + + def isComplete: Boolean = handlerCalled + + def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) - public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header); - public abstract void handleError(BlockId blockId); + def handleError(blockId: BlockId) - @Override - public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) { - // get header - if (currentHeader == null && in.readableBytes() >= FileHeader.HEADER_SIZE()) { - currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE())); + override def channelRead0(ctx: ChannelHandlerContext, in: ByteBuf) { + if (currentHeader == null && in.readableBytes >= FileHeader.HEADER_SIZE) { + currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE)) } - // get file - if(in.readableBytes() >= currentHeader.fileLen()) { - handle(ctx, in, currentHeader); - handlerCalled = true; - currentHeader = null; - ctx.close(); + if (in.readableBytes >= currentHeader.fileLen) { + handle(ctx, in, currentHeader) + handlerCalled = true + currentHeader = null + ctx.close() } } - } - diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala index 136c1912045aa..607e560ff277f 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala @@ -26,7 +26,7 @@ private[spark] class FileHeader ( val fileLen: Int, val blockId: BlockId) extends Logging { - lazy val buffer = { + lazy val buffer: ByteBuf = { val buf = Unpooled.buffer() buf.capacity(FileHeader.HEADER_SIZE) buf.writeInt(fileLen) @@ -62,11 +62,10 @@ private[spark] object FileHeader { new FileHeader(length, blockId) } - def main (args:Array[String]) { + def main(args:Array[String]) { val header = new FileHeader(25, TestBlockId("my_block")) val buf = header.buffer val newHeader = FileHeader.create(buf) System.out.println("id=" + newHeader.blockId + ",size=" + newHeader.fileLen) } } - diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileServer.scala b/core/src/main/scala/org/apache/spark/network/netty/FileServer.scala new file mode 100644 index 0000000000000..dff77950659af --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/FileServer.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.net.InetSocketAddress + +import io.netty.bootstrap.ServerBootstrap +import io.netty.channel.{ChannelFuture, ChannelOption, EventLoopGroup} +import io.netty.channel.oio.OioEventLoopGroup +import io.netty.channel.socket.oio.OioServerSocketChannel + +import org.apache.spark.Logging + +/** + * Server that accept the path of a file an echo back its content. + */ +class FileServer(pResolver: PathResolver, private var port: Int) extends Logging { + + private val addr: InetSocketAddress = new InetSocketAddress(port) + private var bossGroup: EventLoopGroup = new OioEventLoopGroup + private var workerGroup: EventLoopGroup = new OioEventLoopGroup + + private var channelFuture: ChannelFuture = { + val bootstrap = new ServerBootstrap + bootstrap.group(bossGroup, workerGroup) + .channel(classOf[OioServerSocketChannel]) + .option(ChannelOption.SO_BACKLOG, java.lang.Integer.valueOf(100)) + .option(ChannelOption.SO_RCVBUF, java.lang.Integer.valueOf(1500)) + .childHandler(new FileServerChannelInitializer(pResolver)) + bootstrap.bind(addr) + } + + try { + val boundAddress = channelFuture.sync.channel.localAddress.asInstanceOf[InetSocketAddress] + port = boundAddress.getPort + } catch { + case ie: InterruptedException => + port = 0 + } + + /** Start the file server asynchronously in a new thread. */ + def start(): Unit = { + val blockingThread: Thread = new Thread { + override def run(): Unit = { + try { + channelFuture.channel.closeFuture.sync + logInfo("FileServer exiting") + } catch { + case e: InterruptedException => + logError("File server start got interrupted", e) + } + // NOTE: bootstrap is shutdown in stop() + } + } + blockingThread.setDaemon(true) + blockingThread.start() + } + + def getPort: Int = port + + def stop(): Unit = { + if (channelFuture != null) { + channelFuture.channel().close().awaitUninterruptibly() + channelFuture = null + } + if (bossGroup != null) { + bossGroup.shutdownGracefully() + bossGroup = null + } + if (workerGroup != null) { + workerGroup.shutdownGracefully() + workerGroup = null + } + } +} + diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServerChannelInitializer.java b/core/src/main/scala/org/apache/spark/network/netty/FileServerChannelInitializer.scala similarity index 54% rename from core/src/main/java/org/apache/spark/network/netty/FileServerChannelInitializer.java rename to core/src/main/scala/org/apache/spark/network/netty/FileServerChannelInitializer.scala index 46efec8f8d963..aaa2f913d0269 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileServerChannelInitializer.java +++ b/core/src/main/scala/org/apache/spark/network/netty/FileServerChannelInitializer.scala @@ -15,27 +15,20 @@ * limitations under the License. */ -package org.apache.spark.network.netty; +package org.apache.spark.network.netty -import io.netty.channel.ChannelInitializer; -import io.netty.channel.socket.SocketChannel; -import io.netty.handler.codec.DelimiterBasedFrameDecoder; -import io.netty.handler.codec.Delimiters; -import io.netty.handler.codec.string.StringDecoder; +import io.netty.channel.ChannelInitializer +import io.netty.channel.socket.SocketChannel +import io.netty.handler.codec.{DelimiterBasedFrameDecoder, Delimiters} +import io.netty.handler.codec.string.StringDecoder -class FileServerChannelInitializer extends ChannelInitializer { +class FileServerChannelInitializer(pResolver: PathResolver) + extends ChannelInitializer[SocketChannel] { - private final PathResolver pResolver; - - FileServerChannelInitializer(PathResolver pResolver) { - this.pResolver = pResolver; - } - - @Override - public void initChannel(SocketChannel channel) { - channel.pipeline() - .addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter())) - .addLast("stringDecoder", new StringDecoder()) - .addLast("handler", new FileServerHandler(pResolver)); + override def initChannel(channel: SocketChannel): Unit = { + channel.pipeline + .addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter : _*)) + .addLast("stringDecoder", new StringDecoder) + .addLast("handler", new FileServerHandler(pResolver)) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/FileServerHandler.scala new file mode 100644 index 0000000000000..96f60b2883ad9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/FileServerHandler.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.io.FileInputStream + +import io.netty.channel.{DefaultFileRegion, ChannelHandlerContext, SimpleChannelInboundHandler} + +import org.apache.spark.Logging +import org.apache.spark.storage.{BlockId, FileSegment} + + +class FileServerHandler(pResolver: PathResolver) + extends SimpleChannelInboundHandler[String] with Logging { + + override def channelRead0(ctx: ChannelHandlerContext, blockIdString: String): Unit = { + val blockId: BlockId = BlockId(blockIdString) + val fileSegment: FileSegment = pResolver.getBlockLocation(blockId) + if (fileSegment == null) { + return + } + val file = fileSegment.file + if (file.exists) { + if (!file.isFile) { + ctx.write(new FileHeader(0, blockId).buffer) + ctx.flush() + return + } + val length: Long = fileSegment.length + if (length > Integer.MAX_VALUE || length <= 0) { + ctx.write(new FileHeader(0, blockId).buffer) + ctx.flush() + return + } + ctx.write(new FileHeader(length.toInt, blockId).buffer) + try { + val channel = new FileInputStream(file).getChannel + ctx.write(new DefaultFileRegion(channel, fileSegment.offset, fileSegment.length)) + } catch { + case e: Exception => + logError("Exception: ", e) + } + } else { + ctx.write(new FileHeader(0, blockId).buffer) + } + ctx.flush() + } + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + logError("Exception: ", cause) + ctx.close() + } +} diff --git a/core/src/main/java/org/apache/spark/network/netty/PathResolver.java b/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala old mode 100755 new mode 100644 similarity index 80% rename from core/src/main/java/org/apache/spark/network/netty/PathResolver.java rename to core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala index 7ad8d03efbadc..0d7695072a7b1 --- a/core/src/main/java/org/apache/spark/network/netty/PathResolver.java +++ b/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala @@ -15,12 +15,11 @@ * limitations under the License. */ -package org.apache.spark.network.netty; +package org.apache.spark.network.netty -import org.apache.spark.storage.BlockId; -import org.apache.spark.storage.FileSegment; +import org.apache.spark.storage.{BlockId, FileSegment} -public interface PathResolver { +trait PathResolver { /** Get the file segment in which the given block resides. */ - FileSegment getBlockLocation(BlockId blockId); + def getBlockLocation(blockId: BlockId): FileSegment } diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala index 7ef7aecc6a9fb..95958e30f7eeb 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala @@ -32,7 +32,7 @@ private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) ext server.stop() } - def port: Int = server.getPort() + def port: Int = server.getPort } diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 34c51b833025e..20938781ac694 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -141,7 +141,7 @@ private[spark] object CheckpointRDD extends Logging { val deserializeStream = serializer.deserializeStream(fileInputStream) // Register an on-task-completion callback to close the input stream. - context.addOnCompleteCallback(() => deserializeStream.close()) + context.addTaskCompletionListener(context => deserializeStream.close()) deserializeStream.asIterator.asInstanceOf[Iterator[T]] } diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index 9ca971c8a4c27..f233544d128f5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -119,11 +119,11 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { /** * Compute a histogram using the provided buckets. The buckets are all open - * to the left except for the last which is closed + * to the right except for the last which is closed * e.g. for the array * [1, 10, 20, 50] the buckets are [1, 10) [10, 20) [20, 50] - * e.g 1<=x<10 , 10<=x<20, 20<=x<50 - * And on the input of 1 and 50 we would have a histogram of 1, 0, 0 + * e.g 1<=x<10 , 10<=x<20, 20<=x<=50 + * And on the input of 1 and 50 we would have a histogram of 1, 0, 1 * * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 8d92ea01d9a3f..c8623314c98eb 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -197,7 +197,7 @@ class HadoopRDD[K, V]( reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) // Register an on-task-completion callback to close the input stream. - context.addOnCompleteCallback{ () => closeIfNeeded() } + context.addTaskCompletionListener{ context => closeIfNeeded() } val key: K = reader.createKey() val value: V = reader.createValue() diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index 8947e66f4577c..0e38f224ac81d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -68,7 +68,7 @@ class JdbcRDD[T: ClassTag]( } override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] { - context.addOnCompleteCallback{ () => closeIfNeeded() } + context.addTaskCompletionListener{ context => closeIfNeeded() } val part = thePart.asInstanceOf[JdbcPartition] val conn = getConnection() val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 7dfec9a18ec67..58f707b9b4634 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -129,7 +129,7 @@ class NewHadoopRDD[K, V]( context.taskMetrics.inputMetrics = Some(inputMetrics) // Register an on-task-completion callback to close the input stream. - context.addOnCompleteCallback(() => close()) + context.addTaskCompletionListener(context => close()) var havePair = false var finished = false diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 93af50c0a9cd1..f6d9d12fe9006 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -197,33 +197,56 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Return a subset of this RDD sampled by key (via stratified sampling). * * Create a sample of this RDD using variable sampling rates for different keys as specified by - * `fractions`, a key to sampling rate map. - * - * If `exact` is set to false, create the sample via simple random sampling, with one pass - * over the RDD, to produce a sample of size that's approximately equal to the sum of - * math.ceil(numItems * samplingRate) over all key values; otherwise, use - * additional passes over the RDD to create a sample size that's exactly equal to the sum of - * math.ceil(numItems * samplingRate) over all key values with a 99.99% confidence. When sampling - * without replacement, we need one additional pass over the RDD to guarantee sample size; - * when sampling with replacement, we need two additional passes. + * `fractions`, a key to sampling rate map, via simple random sampling with one pass over the + * RDD, to produce a sample of size that's approximately equal to the sum of + * math.ceil(numItems * samplingRate) over all key values. * * @param withReplacement whether to sample with or without replacement * @param fractions map of specific keys to sampling rates * @param seed seed for the random number generator - * @param exact whether sample size needs to be exactly math.ceil(fraction * size) per key * @return RDD containing the sampled subset */ def sampleByKey(withReplacement: Boolean, fractions: Map[K, Double], - exact: Boolean = false, - seed: Long = Utils.random.nextLong): RDD[(K, V)]= { + seed: Long = Utils.random.nextLong): RDD[(K, V)] = { require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.") val samplingFunc = if (withReplacement) { - StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, exact, seed) + StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, false, seed) } else { - StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, exact, seed) + StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, false, seed) + } + self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true) + } + + /** + * ::Experimental:: + * Return a subset of this RDD sampled by key (via stratified sampling) containing exactly + * math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key). + * + * This method differs from [[sampleByKey]] in that we make additional passes over the RDD to + * create a sample size that's exactly equal to the sum of math.ceil(numItems * samplingRate) + * over all key values with a 99.99% confidence. When sampling without replacement, we need one + * additional pass over the RDD to guarantee sample size; when sampling with replacement, we need + * two additional passes. + * + * @param withReplacement whether to sample with or without replacement + * @param fractions map of specific keys to sampling rates + * @param seed seed for the random number generator + * @return RDD containing the sampled subset + */ + @Experimental + def sampleByKeyExact(withReplacement: Boolean, + fractions: Map[K, Double], + seed: Long = Utils.random.nextLong): RDD[(K, V)] = { + + require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.") + + val samplingFunc = if (withReplacement) { + StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, true, seed) + } else { + StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, true, seed) } self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true) } @@ -237,6 +260,25 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) combineByKey[V]((v: V) => v, func, func, partitioner) } + /** + * Merge the values for each key using an associative reduce function. This will also perform + * the merging locally on each mapper before sending results to a reducer, similarly to a + * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions. + */ + def reduceByKey(func: (V, V) => V, numPartitions: Int): RDD[(K, V)] = { + reduceByKey(new HashPartitioner(numPartitions), func) + } + + /** + * Merge the values for each key using an associative reduce function. This will also perform + * the merging locally on each mapper before sending results to a reducer, similarly to a + * "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/ + * parallelism level. + */ + def reduceByKey(func: (V, V) => V): RDD[(K, V)] = { + reduceByKey(defaultPartitioner(self), func) + } + /** * Merge the values for each key using an associative reduce function, but return the results * immediately to the master as a Map. This will also perform the merging locally on each mapper @@ -374,15 +416,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) countApproxDistinctByKey(relativeSD, defaultPartitioner(self)) } - /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions. - */ - def reduceByKey(func: (V, V) => V, numPartitions: Int): RDD[(K, V)] = { - reduceByKey(new HashPartitioner(numPartitions), func) - } - /** * Group the values for each key in the RDD into a single sequence. Allows controlling the * partitioning of the resulting key-value pair RDD by passing a Partitioner. @@ -482,16 +515,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self)) } - /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/ - * parallelism level. - */ - def reduceByKey(func: (V, V) => V): RDD[(K, V)] = { - reduceByKey(defaultPartitioner(self), func) - } - /** * Group the values for each key in the RDD into a single sequence. Hash-partitions the * resulting RDD with the existing partitioner/parallelism level. diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 19e10bd04681b..daea2617e62ea 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1299,6 +1299,19 @@ abstract class RDD[T: ClassTag]( /** A description of this RDD and its recursive dependencies for debugging. */ def toDebugString: String = { + // Get a debug description of an rdd without its children + def debugSelf (rdd: RDD[_]): Seq[String] = { + import Utils.bytesToString + + val persistence = storageLevel.description + val storageInfo = rdd.context.getRDDStorageInfo.filter(_.id == rdd.id).map(info => + " CachedPartitions: %d; MemorySize: %s; TachyonSize: %s; DiskSize: %s".format( + info.numCachedPartitions, bytesToString(info.memSize), + bytesToString(info.tachyonSize), bytesToString(info.diskSize))) + + s"$rdd [$persistence]" +: storageInfo + } + // Apply a different rule to the last child def debugChildren(rdd: RDD[_], prefix: String): Seq[String] = { val len = rdd.dependencies.length @@ -1324,7 +1337,11 @@ abstract class RDD[T: ClassTag]( val partitionStr = "(" + rdd.partitions.size + ")" val leftOffset = (partitionStr.length - 1) / 2 val nextPrefix = (" " * leftOffset) + "|" + (" " * (partitionStr.length - leftOffset)) - Seq(partitionStr + " " + rdd) ++ debugChildren(rdd, nextPrefix) + + debugSelf(rdd).zipWithIndex.map{ + case (desc: String, 0) => s"$partitionStr $desc" + case (desc: String, _) => s"$nextPrefix $desc" + } ++ debugChildren(rdd, nextPrefix) } def shuffleDebugString(rdd: RDD[_], prefix: String = "", isLastChild: Boolean): Seq[String] = { val partitionStr = "(" + rdd.partitions.size + ")" @@ -1334,7 +1351,11 @@ abstract class RDD[T: ClassTag]( thisPrefix + (if (isLastChild) " " else "| ") + (" " * leftOffset) + "|" + (" " * (partitionStr.length - leftOffset))) - Seq(thisPrefix + "+-" + partitionStr + " " + rdd) ++ debugChildren(rdd, nextPrefix) + + debugSelf(rdd).zipWithIndex.map{ + case (desc: String, 0) => s"$thisPrefix+-$partitionStr $desc" + case (desc: String, _) => s"$nextPrefix$desc" + } ++ debugChildren(rdd, nextPrefix) } def debugString(rdd: RDD[_], prefix: String = "", @@ -1342,9 +1363,8 @@ abstract class RDD[T: ClassTag]( isLastChild: Boolean = false): Seq[String] = { if (isShuffle) { shuffleDebugString(rdd, prefix, isLastChild) - } - else { - Seq(prefix + rdd) ++ debugChildren(rdd, prefix) + } else { + debugSelf(rdd).map(prefix + _) ++ debugChildren(rdd, prefix) } } firstDebugString(this).mkString("\n") diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index 197167ecad0bd..0c97eb0aaa51f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -83,8 +83,7 @@ class UnionRDD[T: ClassTag]( override def compute(s: Partition, context: TaskContext): Iterator[T] = { val part = s.asInstanceOf[UnionPartition[T]] - val parentRdd = dependencies(part.parentRddIndex).rdd.asInstanceOf[RDD[T]] - parentRdd.iterator(part.parentPartition, context) + parent[T](part.parentRddIndex).iterator(part.parentPartition, context) } override def getPreferredLocations(s: Partition): Seq[String] = diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 430e45ada5808..b86cfbfa48fbe 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -121,6 +121,9 @@ class DAGScheduler( private[scheduler] var eventProcessActor: ActorRef = _ + /** If enabled, we may run certain actions like take() and first() locally. */ + private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false) + private def initializeEventProcessActor() { // blocking the thread until supervisor is started, which ensures eventProcessActor is // not null before any job is submitted @@ -631,7 +634,7 @@ class DAGScheduler( val result = job.func(taskContext, rdd.iterator(split, taskContext)) job.listener.taskSucceeded(0, result) } finally { - taskContext.executeOnCompleteCallbacks() + taskContext.markTaskCompleted() } } catch { case e: Exception => @@ -732,7 +735,9 @@ class DAGScheduler( logInfo("Final stage: " + finalStage + "(" + finalStage.name + ")") logInfo("Parents of final stage: " + finalStage.parents) logInfo("Missing parents: " + getMissingParentStages(finalStage)) - if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) { + val shouldRunLocally = + localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1 + if (shouldRunLocally) { // Compute very short actions like first() or take() with no parent stages locally. listenerBus.post(SparkListenerJobStart(job.jobId, Array[Int](), properties)) runLocally(job) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala index 5878e733908f5..94944399b134a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala @@ -24,8 +24,8 @@ import org.apache.spark.metrics.source.Source private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler, sc: SparkContext) extends Source { - val metricRegistry = new MetricRegistry() - val sourceName = "%s.DAGScheduler".format(sc.appName) + override val metricRegistry = new MetricRegistry() + override val sourceName = "%s.DAGScheduler".format(sc.appName) metricRegistry.register(MetricRegistry.name("stage", "failedStages"), new Gauge[Int] { override def getValue: Int = dagScheduler.failedStages.size diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 406147f167bf3..7378ce923f0ae 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -127,6 +127,8 @@ private[spark] class EventLoggingListener( logEvent(event, flushLogger = true) override def onApplicationEnd(event: SparkListenerApplicationEnd) = logEvent(event, flushLogger = true) + // No-op because logging every update would be overkill + override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate) { } /** * Stop logging events. diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 47dd112f68325..4d6b5c81883b6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -162,6 +162,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime + " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime + val gcTime = " GC_TIME=" + taskMetrics.jvmGCTime val inputMetrics = taskMetrics.inputMetrics match { case Some(metrics) => " READ_METHOD=" + metrics.readMethod.toString + @@ -179,11 +180,13 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener case None => "" } val writeMetrics = taskMetrics.shuffleWriteMetrics match { - case Some(metrics) => " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten + case Some(metrics) => + " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten + + " SHUFFLE_WRITE_TIME=" + metrics.shuffleWriteTime case None => "" } - stageLogInfo(stageId, status + info + executorRunTime + inputMetrics + shuffleReadMetrics + - writeMetrics) + stageLogInfo(stageId, status + info + executorRunTime + gcTime + inputMetrics + + shuffleReadMetrics + writeMetrics) } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index d09fd7aa57642..2ccbd8edeb028 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -61,7 +61,7 @@ private[spark] class ResultTask[T, U]( try { func(context, rdd.iterator(partition, context)) } finally { - context.executeOnCompleteCallbacks() + context.markTaskCompleted() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 11255c07469d4..381eff2147e95 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -74,7 +74,7 @@ private[spark] class ShuffleMapTask( } throw e } finally { - context.executeOnCompleteCallbacks() + context.markTaskCompleted() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 5c5e421404a21..6aa0cca06878d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -46,7 +46,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex final def run(attemptId: Long): T = { context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false) - context.taskMetrics.hostname = Utils.localHostName(); + context.taskMetrics.hostname = Utils.localHostName() taskThread = Thread.currentThread() if (_killed) { kill(interruptThread = false) @@ -87,7 +87,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex def kill(interruptThread: Boolean) { _killed = true if (context != null) { - context.interrupted = true + context.markInterrupted() } if (interruptThread && taskThread != null) { taskThread.interrupt() diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 20a4bd12f93f6..d9d53faf843ff 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -690,8 +690,7 @@ private[spark] class TaskSetManager( handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure) } // recalculate valid locality levels and waits when executor is lost - myLocalityLevels = computeValidLocalityLevels() - localityWaits = myLocalityLevels.map(getLocalityWait) + recomputeLocality() } /** @@ -775,9 +774,15 @@ private[spark] class TaskSetManager( levels.toArray } - def executorAdded() { + def recomputeLocality() { + val previousLocalityLevel = myLocalityLevels(currentLocalityIndex) myLocalityLevels = computeValidLocalityLevels() localityWaits = myLocalityLevels.map(getLocalityWait) + currentLocalityIndex = getLocalityIndex(previousLocalityLevel) + } + + def executorAdded() { + recomputeLocality() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 9f085eef46720..2a3711ae2a78c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -30,7 +30,7 @@ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import org.apache.spark.{SparkEnv, Logging, SparkException, TaskState} import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, TaskSchedulerImpl, WorkerOffer} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{SerializableBuffer, AkkaUtils, Utils} +import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Utils} import org.apache.spark.ui.JettyUtils /** @@ -47,21 +47,24 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A { // Use an atomic variable to track total number of cores in the cluster for simplicity and speed var totalCoreCount = new AtomicInteger(0) - var totalExpectedExecutors = new AtomicInteger(0) + var totalRegisteredExecutors = new AtomicInteger(0) val conf = scheduler.sc.conf private val timeout = AkkaUtils.askTimeout(conf) private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) - // Submit tasks only after (registered executors / total expected executors) + // Submit tasks only after (registered resources / total expected resources) // is equal to at least this value, that is double between 0 and 1. - var minRegisteredRatio = conf.getDouble("spark.scheduler.minRegisteredExecutorsRatio", 0) - if (minRegisteredRatio > 1) minRegisteredRatio = 1 - // Whatever minRegisteredExecutorsRatio is arrived, submit tasks after the time(milliseconds). + var minRegisteredRatio = + math.min(1, conf.getDouble("spark.scheduler.minRegisteredResourcesRatio", 0)) + // Submit tasks after maxRegisteredWaitingTime milliseconds + // if minRegisteredRatio has not yet been reached val maxRegisteredWaitingTime = - conf.getInt("spark.scheduler.maxRegisteredExecutorsWaitingTime", 30000) + conf.getInt("spark.scheduler.maxRegisteredResourcesWaitingTime", 30000) val createTime = System.currentTimeMillis() - var ready = if (minRegisteredRatio <= 0) true else false - class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor { + class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor with ActorLogReceive { + + override protected def log = CoarseGrainedSchedulerBackend.this.log + private val executorActor = new HashMap[String, ActorRef] private val executorAddress = new HashMap[String, Address] private val executorHost = new HashMap[String, String] @@ -79,7 +82,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers) } - def receive = { + def receiveWithLogging = { case RegisterExecutor(executorId, hostPort, cores) => Utils.checkHostPort(hostPort, "Host port expected " + hostPort) if (executorActor.contains(executorId)) { @@ -94,12 +97,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A executorAddress(executorId) = sender.path.address addressToExecutorId(sender.path.address) = executorId totalCoreCount.addAndGet(cores) - if (executorActor.size >= totalExpectedExecutors.get() * minRegisteredRatio && !ready) { - ready = true - logInfo("SchedulerBackend is ready for scheduling beginning, registered executors: " + - executorActor.size + ", total expected executors: " + totalExpectedExecutors.get() + - ", minRegisteredExecutorsRatio: " + minRegisteredRatio) - } + totalRegisteredExecutors.addAndGet(1) makeOffers() } @@ -268,14 +266,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A } } + def sufficientResourcesRegistered(): Boolean = true + override def isReady(): Boolean = { - if (ready) { + if (sufficientResourcesRegistered) { + logInfo("SchedulerBackend is ready for scheduling beginning after " + + s"reached minRegisteredResourcesRatio: $minRegisteredRatio") return true } if ((System.currentTimeMillis() - createTime) >= maxRegisteredWaitingTime) { - ready = true logInfo("SchedulerBackend is ready for scheduling beginning after waiting " + - "maxRegisteredExecutorsWaitingTime: " + maxRegisteredWaitingTime) + s"maxRegisteredResourcesWaitingTime: $maxRegisteredWaitingTime(ms)") return true } false diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index a28446f6c8a6b..589dba2e40d20 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -36,6 +36,7 @@ private[spark] class SparkDeploySchedulerBackend( var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _ val maxCores = conf.getOption("spark.cores.max").map(_.toInt) + val totalExpectedCores = maxCores.getOrElse(0) override def start() { super.start() @@ -97,7 +98,6 @@ private[spark] class SparkDeploySchedulerBackend( override def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int) { - totalExpectedExecutors.addAndGet(1) logInfo("Granted executor ID %s on hostPort %s with %d cores, %s RAM".format( fullId, hostPort, cores, Utils.megabytesToString(memory))) } @@ -110,4 +110,8 @@ private[spark] class SparkDeploySchedulerBackend( logInfo("Executor %s removed: %s".format(fullId, message)) removeExecutor(fullId.split("/")(1), reason.toString) } + + override def sufficientResourcesRegistered(): Boolean = { + totalCoreCount.get() >= totalExpectedCores * minRegisteredRatio + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 3d1cf312ccc97..bec9502f20466 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -23,9 +23,9 @@ import akka.actor.{Actor, ActorRef, Props} import org.apache.spark.{Logging, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState -import org.apache.spark.executor.{TaskMetrics, Executor, ExecutorBackend} +import org.apache.spark.executor.{Executor, ExecutorBackend} import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.ActorLogReceive private case class ReviveOffers() @@ -43,7 +43,7 @@ private case class StopExecutor() private[spark] class LocalActor( scheduler: TaskSchedulerImpl, executorBackend: LocalBackend, - private val totalCores: Int) extends Actor with Logging { + private val totalCores: Int) extends Actor with ActorLogReceive with Logging { private var freeCores = totalCores @@ -53,7 +53,7 @@ private[spark] class LocalActor( val executor = new Executor( localExecutorId, localExecutorHostname, scheduler.conf.getAll, isLocal = true) - def receive = { + override def receiveWithLogging = { case ReviveOffers => reviveOffers() diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 34bc3124097bb..af33a2f2ca3e1 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -63,7 +63,9 @@ extends DeserializationStream { def close() { objIn.close() } } -private[spark] class JavaSerializerInstance(counterReset: Int) extends SerializerInstance { +private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoader: ClassLoader) + extends SerializerInstance { + def serialize[T: ClassTag](t: T): ByteBuffer = { val bos = new ByteArrayOutputStream() val out = serializeStream(bos) @@ -109,7 +111,10 @@ private[spark] class JavaSerializerInstance(counterReset: Int) extends Serialize class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable { private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100) - def newInstance(): SerializerInstance = new JavaSerializerInstance(counterReset) + override def newInstance(): SerializerInstance = { + val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader) + new JavaSerializerInstance(counterReset, classLoader) + } override def writeExternal(out: ObjectOutput) { out.writeInt(counterReset) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 407cb9db6ee9a..99682220b4ab5 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -61,7 +61,9 @@ class KryoSerializer(conf: SparkConf) val instantiator = new EmptyScalaKryoInstantiator val kryo = instantiator.newKryo() kryo.setRegistrationRequired(registrationRequired) - val classLoader = Thread.currentThread.getContextClassLoader + + val oldClassLoader = Thread.currentThread.getContextClassLoader + val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader) // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops. // Do this before we invoke the user registrator so the user registrator can override this. @@ -79,15 +81,21 @@ class KryoSerializer(conf: SparkConf) kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) // Allow the user to register their own classes by setting spark.kryo.registrator - try { - for (regCls <- registrator) { - logDebug("Running user registrator: " + regCls) + for (regCls <- registrator) { + logDebug("Running user registrator: " + regCls) + try { val reg = Class.forName(regCls, true, classLoader).newInstance() .asInstanceOf[KryoRegistrator] + + // Use the default classloader when calling the user registrator. + Thread.currentThread.setContextClassLoader(classLoader) reg.registerClasses(kryo) + } catch { + case e: Exception => + throw new SparkException(s"Failed to invoke $regCls", e) + } finally { + Thread.currentThread.setContextClassLoader(oldClassLoader) } - } catch { - case e: Exception => logError("Failed to run spark.kryo.registrator", e) } // Register Chill's classes; we do this after our ranges and the user's own classes to let diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index f2f5cea469c61..e674438c8176c 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -44,6 +44,23 @@ import org.apache.spark.util.{ByteBufferInputStream, NextIterator} */ @DeveloperApi trait Serializer { + + /** + * Default ClassLoader to use in deserialization. Implementations of [[Serializer]] should + * make sure it is using this when set. + */ + @volatile protected var defaultClassLoader: Option[ClassLoader] = None + + /** + * Sets a class loader for the serializer to use in deserialization. + * + * @return this Serializer object + */ + def setDefaultClassLoader(classLoader: ClassLoader): Serializer = { + defaultClassLoader = Some(classLoader) + this + } + def newInstance(): SerializerInstance } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index bd31e3c5a187f..3ab07703b6f85 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -31,7 +31,7 @@ import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, Utils} /** * BlockManagerMasterActor is an actor on the master node to track statuses of @@ -39,7 +39,7 @@ import org.apache.spark.util.{AkkaUtils, Utils} */ private[spark] class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus: LiveListenerBus) - extends Actor with Logging { + extends Actor with ActorLogReceive with Logging { // Mapping from block manager id to the block manager's information. private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo] @@ -55,8 +55,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus val slaveTimeout = conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", math.max(conf.getInt("spark.executor.heartbeatInterval", 10000) * 3, 45000)) - val checkTimeoutInterval = conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", - 60000) + val checkTimeoutInterval = conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000) var timeoutCheckingTask: Cancellable = null @@ -67,9 +66,8 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus super.preStart() } - def receive = { + override def receiveWithLogging = { case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => - logInfo("received a register") register(blockManagerId, maxMemSize, slaveActor) sender ! true @@ -118,7 +116,6 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus sender ! true case StopBlockManagerMaster => - logInfo("Stopping BlockManagerMaster") sender ! true if (timeoutCheckingTask != null) { timeoutCheckingTask.cancel() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 10b65286fb7db..2ba16b8476600 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -53,7 +53,7 @@ private[spark] object BlockManagerMessages { sender: ActorRef) extends ToBlockManagerMaster - class UpdateBlockInfo( + case class UpdateBlockInfo( var blockManagerId: BlockManagerId, var blockId: BlockId, var storageLevel: StorageLevel, @@ -84,24 +84,6 @@ private[spark] object BlockManagerMessages { } } - object UpdateBlockInfo { - def apply( - blockManagerId: BlockManagerId, - blockId: BlockId, - storageLevel: StorageLevel, - memSize: Long, - diskSize: Long, - tachyonSize: Long): UpdateBlockInfo = { - new UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize, tachyonSize) - } - - // For pattern-matching - def unapply(h: UpdateBlockInfo) - : Option[(BlockManagerId, BlockId, StorageLevel, Long, Long, Long)] = { - Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize, h.tachyonSize)) - } - } - case class GetLocations(blockId: BlockId) extends ToBlockManagerMaster case class GetLocationsMultipleBlockIds(blockIds: Array[BlockId]) extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 6d4db064dff58..c194e0fed3367 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -23,6 +23,7 @@ import akka.actor.{ActorRef, Actor} import org.apache.spark.{Logging, MapOutputTracker} import org.apache.spark.storage.BlockManagerMessages._ +import org.apache.spark.util.ActorLogReceive /** * An actor to take commands from the master to execute options. For example, @@ -32,12 +33,12 @@ private[storage] class BlockManagerSlaveActor( blockManager: BlockManager, mapOutputTracker: MapOutputTracker) - extends Actor with Logging { + extends Actor with ActorLogReceive with Logging { import context.dispatcher // Operations that involve removing blocks may be slow and should be done asynchronously - override def receive = { + override def receiveWithLogging = { case RemoveBlock(blockId) => doAsync[Boolean]("removing block " + blockId, sender) { blockManager.removeBlock(blockId) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala index 3f14c40ec61cb..49fea6d9e2a76 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala @@ -24,8 +24,8 @@ import org.apache.spark.metrics.source.Source private[spark] class BlockManagerSource(val blockManager: BlockManager, sc: SparkContext) extends Source { - val metricRegistry = new MetricRegistry() - val sourceName = "%s.BlockManager".format(sc.appName) + override val metricRegistry = new MetricRegistry() + override val sourceName = "%s.BlockManager".format(sc.appName) metricRegistry.register(MetricRegistry.name("memory", "maxMem_MB"), new Gauge[Long] { override def getValue: Long = { diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 28f675c2bbb1e..0a09c24d61879 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -238,7 +238,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // If our vector's size has exceeded the threshold, request more memory val currentSize = vector.estimateSize() if (currentSize >= memoryThreshold) { - val amountToRequest = (currentSize * (memoryGrowthFactor - 1)).toLong + val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong // Hold the accounting lock, in case another thread concurrently puts a block that // takes up the unrolling space we just ensured here accountingLock.synchronized { @@ -254,7 +254,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } // New threshold is currentSize * memoryGrowthFactor - memoryThreshold = currentSize + amountToRequest + memoryThreshold += amountToRequest } } elementsUnrolled += 1 diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 29e9cf947856f..6b4689291097f 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -93,7 +93,7 @@ private[spark] object JettyUtils extends Logging { def createServletHandler( path: String, servlet: HttpServlet, - basePath: String = ""): ServletContextHandler = { + basePath: String): ServletContextHandler = { val prefixedPath = attachPrefix(basePath, path) val contextHandler = new ServletContextHandler val holder = new ServletHolder(servlet) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index a57a354620163..a3e9566832d06 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -153,6 +153,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val (errorMessage, metrics): (Option[String], Option[TaskMetrics]) = taskEnd.reason match { case org.apache.spark.Success => + stageData.completedIndices.add(info.index) stageData.numCompleteTasks += 1 (None, Option(taskEnd.taskMetrics)) case e: ExceptionFailure => // Handle ExceptionFailure because we might have metrics diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 3dcfaf76e4aba..15998404ed612 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -168,7 +168,7 @@ private[ui] class StageTableBase( {submissionTime} {formattedDuration} - {makeProgressBar(stageData.numActiveTasks, stageData.numCompleteTasks, + {makeProgressBar(stageData.numActiveTasks, stageData.completedIndices.size, stageData.numFailedTasks, s.numTasks)} {inputReadWithUnit} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index 85db15472a00c..a336bf7e1ed02 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -19,6 +19,7 @@ package org.apache.spark.ui.jobs import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} +import org.apache.spark.util.collection.OpenHashSet import scala.collection.mutable.HashMap @@ -38,6 +39,7 @@ private[jobs] object UIData { class StageUIData { var numActiveTasks: Int = _ var numCompleteTasks: Int = _ + var completedIndices = new OpenHashSet[Int]() var numFailedTasks: Int = _ var executorRunTime: Long = _ diff --git a/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala b/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala new file mode 100644 index 0000000000000..332d0cbb2dc0c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import akka.actor.Actor +import org.slf4j.Logger + +/** + * A trait to enable logging all Akka actor messages. Here's an example of using this: + * + * {{{ + * class BlockManagerMasterActor extends Actor with ActorLogReceive with Logging { + * ... + * override def receiveWithLogging = { + * case GetLocations(blockId) => + * sender ! getLocations(blockId) + * ... + * } + * ... + * } + * }}} + * + */ +private[spark] trait ActorLogReceive { + self: Actor => + + override def receive: Actor.Receive = new Actor.Receive { + + private val _receiveWithLogging = receiveWithLogging + + override def isDefinedAt(o: Any): Boolean = _receiveWithLogging.isDefinedAt(o) + + override def apply(o: Any): Unit = { + if (log.isDebugEnabled) { + log.debug(s"[actor] received message $o from ${self.sender}") + } + val start = System.nanoTime + _receiveWithLogging.apply(o) + val timeTaken = (System.nanoTime - start).toDouble / 1000000 + if (log.isDebugEnabled) { + log.debug(s"[actor] handled message ($timeTaken ms) $o from ${self.sender}") + } + } + } + + def receiveWithLogging: Actor.Receive + + protected def log: Logger +} diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 6f8eb1ee12634..1e18ec688c40d 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -72,8 +72,9 @@ private[spark] object JsonProtocol { case applicationEnd: SparkListenerApplicationEnd => applicationEndToJson(applicationEnd) - // Not used, but keeps compiler happy + // These aren't used, but keeps compiler happy case SparkListenerShutdown => JNothing + case SparkListenerExecutorMetricsUpdate(_, _) => JNothing } } diff --git a/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala b/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala new file mode 100644 index 0000000000000..c1b8bf052c0ca --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.EventListener + +import org.apache.spark.TaskContext +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * + * Listener providing a callback function to invoke when a task's execution completes. + */ +@DeveloperApi +trait TaskCompletionListener extends EventListener { + def onTaskCompletion(context: TaskContext) +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index c60be4f8a11d2..019f68b160894 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -146,6 +146,9 @@ private[spark] object Utils extends Logging { Try { Class.forName(clazz, false, getContextOrSparkClassLoader) }.isSuccess } + /** Preferred alternative to Class.forName(className) */ + def classForName(className: String) = Class.forName(className, true, getContextOrSparkClassLoader) + /** * Primitive often used when writing {@link java.nio.ByteBuffer} to {@link java.io.DataOutput}. */ @@ -284,17 +287,32 @@ private[spark] object Utils extends Logging { /** Copy all data from an InputStream to an OutputStream */ def copyStream(in: InputStream, out: OutputStream, - closeStreams: Boolean = false) + closeStreams: Boolean = false): Long = { + var count = 0L try { - val buf = new Array[Byte](8192) - var n = 0 - while (n != -1) { - n = in.read(buf) - if (n != -1) { - out.write(buf, 0, n) + if (in.isInstanceOf[FileInputStream] && out.isInstanceOf[FileOutputStream]) { + // When both streams are File stream, use transferTo to improve copy performance. + val inChannel = in.asInstanceOf[FileInputStream].getChannel() + val outChannel = out.asInstanceOf[FileOutputStream].getChannel() + val size = inChannel.size() + + // In case transferTo method transferred less data than we have required. + while (count < size) { + count += inChannel.transferTo(count, size - count, outChannel) + } + } else { + val buf = new Array[Byte](8192) + var n = 0 + while (n != -1) { + n = in.read(buf) + if (n != -1) { + out.write(buf, 0, n) + count += n + } } } + count } finally { if (closeStreams) { try { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index b73d5e0cf1714..5d8a648d9551e 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -745,12 +745,11 @@ private[spark] class ExternalSorter[K, V, C]( try { out = new FileOutputStream(outputFile) for (i <- 0 until numPartitions) { - val file = partitionWriters(i).fileSegment().file - in = new FileInputStream(file) - org.apache.spark.util.Utils.copyStream(in, out) + in = new FileInputStream(partitionWriters(i).fileSegment().file) + val size = org.apache.spark.util.Utils.copyStream(in, out, false) in.close() in = null - lengths(i) = file.length() + lengths(i) = size offsets(i + 1) = offsets(i) + lengths(i) } } finally { diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 56150caa5d6ba..e1c13de04a0be 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -1239,12 +1239,28 @@ public Tuple2 call(Integer i) { Assert.assertTrue(worCounts.size() == 2); Assert.assertTrue(worCounts.get(0) > 0); Assert.assertTrue(worCounts.get(1) > 0); - JavaPairRDD wrExact = rdd2.sampleByKey(true, fractions, true, 1L); + } + + @Test + @SuppressWarnings("unchecked") + public void sampleByKeyExact() { + JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 3); + JavaPairRDD rdd2 = rdd1.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(Integer i) { + return new Tuple2(i % 2, 1); + } + }); + Map fractions = Maps.newHashMap(); + fractions.put(0, 0.5); + fractions.put(1, 1.0); + JavaPairRDD wrExact = rdd2.sampleByKeyExact(true, fractions, 1L); Map wrExactCounts = (Map) (Object) wrExact.countByKey(); Assert.assertTrue(wrExactCounts.size() == 2); Assert.assertTrue(wrExactCounts.get(0) == 2); Assert.assertTrue(wrExactCounts.get(1) == 4); - JavaPairRDD worExact = rdd2.sampleByKey(false, fractions, true, 1L); + JavaPairRDD worExact = rdd2.sampleByKeyExact(false, fractions, 1L); Map worExactCounts = (Map) (Object) worExact.countByKey(); Assert.assertTrue(worExactCounts.size() == 2); Assert.assertTrue(worExactCounts.get(0) == 2); diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClientChannelInitializer.java b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java similarity index 58% rename from core/src/main/java/org/apache/spark/network/netty/FileClientChannelInitializer.java rename to core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java index 264cf97d0209f..af34cdb03e4d1 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileClientChannelInitializer.java +++ b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java @@ -15,25 +15,25 @@ * limitations under the License. */ -package org.apache.spark.network.netty; +package org.apache.spark.util; -import io.netty.channel.ChannelInitializer; -import io.netty.channel.socket.SocketChannel; -import io.netty.handler.codec.string.StringEncoder; +import org.apache.spark.TaskContext; -class FileClientChannelInitializer extends ChannelInitializer { - private final FileClientHandler fhandler; - - FileClientChannelInitializer(FileClientHandler handler) { - fhandler = handler; - } +/** + * A simple implementation of TaskCompletionListener that makes sure TaskCompletionListener and + * TaskContext is Java friendly. + */ +public class JavaTaskCompletionListenerImpl implements TaskCompletionListener { @Override - public void initChannel(SocketChannel channel) { - // file no more than 2G - channel.pipeline() - .addLast("encoder", new StringEncoder()) - .addLast("handler", fhandler); + public void onTaskCompletion(TaskContext context) { + context.isCompleted(); + context.isInterrupted(); + context.stageId(); + context.partitionId(); + context.runningLocally(); + context.taskMetrics(); + context.addTaskCompletionListener(this); } } diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala index 3f882a724b047..25be7f25c21bb 100644 --- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala @@ -56,15 +56,33 @@ class CompressionCodecSuite extends FunSuite { testCodec(codec) } + test("lz4 compression codec short form") { + val codec = CompressionCodec.createCodec(conf, "lz4") + assert(codec.getClass === classOf[LZ4CompressionCodec]) + testCodec(codec) + } + test("lzf compression codec") { val codec = CompressionCodec.createCodec(conf, classOf[LZFCompressionCodec].getName) assert(codec.getClass === classOf[LZFCompressionCodec]) testCodec(codec) } + test("lzf compression codec short form") { + val codec = CompressionCodec.createCodec(conf, "lzf") + assert(codec.getClass === classOf[LZFCompressionCodec]) + testCodec(codec) + } + test("snappy compression codec") { val codec = CompressionCodec.createCodec(conf, classOf[SnappyCompressionCodec].getName) assert(codec.getClass === classOf[SnappyCompressionCodec]) testCodec(codec) } + + test("snappy compression codec short form") { + val codec = CompressionCodec.createCodec(conf, "snappy") + assert(codec.getClass === classOf[SnappyCompressionCodec]) + testCodec(codec) + } } diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 4f49d4a1d4d34..63d3ddb4af98a 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -84,118 +84,81 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { } test("sampleByKey") { - def stratifier (fractionPositive: Double) = { - (x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0" - } - def checkSize(exact: Boolean, - withReplacement: Boolean, - expected: Long, - actual: Long, - p: Double): Boolean = { - if (exact) { - return expected == actual - } - val stdev = if (withReplacement) math.sqrt(expected) else math.sqrt(expected * p * (1 - p)) - // Very forgiving margin since we're dealing with very small sample sizes most of the time - math.abs(actual - expected) <= 6 * stdev + val defaultSeed = 1L + + // vary RDD size + for (n <- List(100, 1000, 1000000)) { + val data = sc.parallelize(1 to n, 2) + val fractionPositive = 0.3 + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + val samplingRate = 0.1 + StratifiedAuxiliary.testSample(stratifiedData, samplingRate, defaultSeed, n) } - // Without replacement validation - def takeSampleAndValidateBernoulli(stratifiedData: RDD[(String, Int)], - exact: Boolean, - samplingRate: Double, - seed: Long, - n: Long) = { - val expectedSampleSize = stratifiedData.countByKey() - .mapValues(count => math.ceil(count * samplingRate).toInt) - val fractions = Map("1" -> samplingRate, "0" -> samplingRate) - val sample = stratifiedData.sampleByKey(false, fractions, exact, seed) - val sampleCounts = sample.countByKey() - val takeSample = sample.collect() - sampleCounts.foreach { case(k, v) => - assert(checkSize(exact, false, expectedSampleSize(k), v, samplingRate)) } - assert(takeSample.size === takeSample.toSet.size) - takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]") } + // vary fractionPositive + for (fractionPositive <- List(0.1, 0.3, 0.5, 0.7, 0.9)) { + val n = 100 + val data = sc.parallelize(1 to n, 2) + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + val samplingRate = 0.1 + StratifiedAuxiliary.testSample(stratifiedData, samplingRate, defaultSeed, n) } - // With replacement validation - def takeSampleAndValidatePoisson(stratifiedData: RDD[(String, Int)], - exact: Boolean, - samplingRate: Double, - seed: Long, - n: Long) = { - val expectedSampleSize = stratifiedData.countByKey().mapValues(count => - math.ceil(count * samplingRate).toInt) - val fractions = Map("1" -> samplingRate, "0" -> samplingRate) - val sample = stratifiedData.sampleByKey(true, fractions, exact, seed) - val sampleCounts = sample.countByKey() - val takeSample = sample.collect() - sampleCounts.foreach { case(k, v) => - assert(checkSize(exact, true, expectedSampleSize(k), v, samplingRate)) } - val groupedByKey = takeSample.groupBy(_._1) - for ((key, v) <- groupedByKey) { - if (expectedSampleSize(key) >= 100 && samplingRate >= 0.1) { - // sample large enough for there to be repeats with high likelihood - assert(v.toSet.size < expectedSampleSize(key)) - } else { - if (exact) { - assert(v.toSet.size <= expectedSampleSize(key)) - } else { - assert(checkSize(false, true, expectedSampleSize(key), v.toSet.size, samplingRate)) - } - } - } - takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]") } + // Use the same data for the rest of the tests + val fractionPositive = 0.3 + val n = 100 + val data = sc.parallelize(1 to n, 2) + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + + // vary seed + for (seed <- defaultSeed to defaultSeed + 5L) { + val samplingRate = 0.1 + StratifiedAuxiliary.testSample(stratifiedData, samplingRate, seed, n) } - def checkAllCombos(stratifiedData: RDD[(String, Int)], - samplingRate: Double, - seed: Long, - n: Long) = { - takeSampleAndValidateBernoulli(stratifiedData, true, samplingRate, seed, n) - takeSampleAndValidateBernoulli(stratifiedData, false, samplingRate, seed, n) - takeSampleAndValidatePoisson(stratifiedData, true, samplingRate, seed, n) - takeSampleAndValidatePoisson(stratifiedData, false, samplingRate, seed, n) + // vary sampling rate + for (samplingRate <- List(0.01, 0.05, 0.1, 0.5)) { + StratifiedAuxiliary.testSample(stratifiedData, samplingRate, defaultSeed, n) } + } + test("sampleByKeyExact") { val defaultSeed = 1L // vary RDD size for (n <- List(100, 1000, 1000000)) { val data = sc.parallelize(1 to n, 2) val fractionPositive = 0.3 - val stratifiedData = data.keyBy(stratifier(fractionPositive)) - + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) val samplingRate = 0.1 - checkAllCombos(stratifiedData, samplingRate, defaultSeed, n) + StratifiedAuxiliary.testSampleExact(stratifiedData, samplingRate, defaultSeed, n) } // vary fractionPositive for (fractionPositive <- List(0.1, 0.3, 0.5, 0.7, 0.9)) { val n = 100 val data = sc.parallelize(1 to n, 2) - val stratifiedData = data.keyBy(stratifier(fractionPositive)) - + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) val samplingRate = 0.1 - checkAllCombos(stratifiedData, samplingRate, defaultSeed, n) + StratifiedAuxiliary.testSampleExact(stratifiedData, samplingRate, defaultSeed, n) } // Use the same data for the rest of the tests val fractionPositive = 0.3 val n = 100 val data = sc.parallelize(1 to n, 2) - val stratifiedData = data.keyBy(stratifier(fractionPositive)) + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) // vary seed for (seed <- defaultSeed to defaultSeed + 5L) { val samplingRate = 0.1 - checkAllCombos(stratifiedData, samplingRate, seed, n) + StratifiedAuxiliary.testSampleExact(stratifiedData, samplingRate, seed, n) } // vary sampling rate for (samplingRate <- List(0.01, 0.05, 0.1, 0.5)) { - checkAllCombos(stratifiedData, samplingRate, defaultSeed, n) + StratifiedAuxiliary.testSampleExact(stratifiedData, samplingRate, defaultSeed, n) } } @@ -556,6 +519,98 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { intercept[IllegalArgumentException] {shuffled.lookup(-1)} } + private object StratifiedAuxiliary { + def stratifier (fractionPositive: Double) = { + (x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0" + } + + def checkSize(exact: Boolean, + withReplacement: Boolean, + expected: Long, + actual: Long, + p: Double): Boolean = { + if (exact) { + return expected == actual + } + val stdev = if (withReplacement) math.sqrt(expected) else math.sqrt(expected * p * (1 - p)) + // Very forgiving margin since we're dealing with very small sample sizes most of the time + math.abs(actual - expected) <= 6 * stdev + } + + def testSampleExact(stratifiedData: RDD[(String, Int)], + samplingRate: Double, + seed: Long, + n: Long) = { + testBernoulli(stratifiedData, true, samplingRate, seed, n) + testPoisson(stratifiedData, true, samplingRate, seed, n) + } + + def testSample(stratifiedData: RDD[(String, Int)], + samplingRate: Double, + seed: Long, + n: Long) = { + testBernoulli(stratifiedData, false, samplingRate, seed, n) + testPoisson(stratifiedData, false, samplingRate, seed, n) + } + + // Without replacement validation + def testBernoulli(stratifiedData: RDD[(String, Int)], + exact: Boolean, + samplingRate: Double, + seed: Long, + n: Long) = { + val expectedSampleSize = stratifiedData.countByKey() + .mapValues(count => math.ceil(count * samplingRate).toInt) + val fractions = Map("1" -> samplingRate, "0" -> samplingRate) + val sample = if (exact) { + stratifiedData.sampleByKeyExact(false, fractions, seed) + } else { + stratifiedData.sampleByKey(false, fractions, seed) + } + val sampleCounts = sample.countByKey() + val takeSample = sample.collect() + sampleCounts.foreach { case(k, v) => + assert(checkSize(exact, false, expectedSampleSize(k), v, samplingRate)) } + assert(takeSample.size === takeSample.toSet.size) + takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]") } + } + + // With replacement validation + def testPoisson(stratifiedData: RDD[(String, Int)], + exact: Boolean, + samplingRate: Double, + seed: Long, + n: Long) = { + val expectedSampleSize = stratifiedData.countByKey().mapValues(count => + math.ceil(count * samplingRate).toInt) + val fractions = Map("1" -> samplingRate, "0" -> samplingRate) + val sample = if (exact) { + stratifiedData.sampleByKeyExact(true, fractions, seed) + } else { + stratifiedData.sampleByKey(true, fractions, seed) + } + val sampleCounts = sample.countByKey() + val takeSample = sample.collect() + sampleCounts.foreach { case (k, v) => + assert(checkSize(exact, true, expectedSampleSize(k), v, samplingRate)) + } + val groupedByKey = takeSample.groupBy(_._1) + for ((key, v) <- groupedByKey) { + if (expectedSampleSize(key) >= 100 && samplingRate >= 0.1) { + // sample large enough for there to be repeats with high likelihood + assert(v.toSet.size < expectedSampleSize(key)) + } else { + if (exact) { + assert(v.toSet.size <= expectedSampleSize(key)) + } else { + assert(checkSize(false, true, expectedSampleSize(key), v.toSet.size, samplingRate)) + } + } + } + takeSample.foreach(x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]")) + } + } + } /* diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 8c1b0fed11f72..bd829752eb401 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -141,7 +141,9 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F } before { - sc = new SparkContext("local", "DAGSchedulerSuite") + // Enable local execution for this test + val conf = new SparkConf().set("spark.localExecution.enabled", "true") + sc = new SparkContext("local", "DAGSchedulerSuite", conf) sparkListener.successfulStages.clear() sparkListener.failedStages.clear() failure = null diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 270f7e661045a..db2ad829a48f9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -32,7 +32,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte val rdd = new RDD[String](sc, List()) { override def getPartitions = Array[Partition](StubPartition(0)) override def compute(split: Partition, context: TaskContext) = { - context.addOnCompleteCallback(() => TaskContextSuite.completed = true) + context.addTaskCompletionListener(context => TaskContextSuite.completed = true) sys.error("failed") } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index ffd23380a886f..93e8ddacf8865 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -154,6 +154,11 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { val LOCALITY_WAIT = conf.getLong("spark.locality.wait", 3000) val MAX_TASK_FAILURES = 4 + override def beforeEach() { + super.beforeEach() + FakeRackUtil.cleanUp() + } + test("TaskSet with no preferences") { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) @@ -471,7 +476,6 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("new executors get added and lost") { // Assign host2 to rack2 - FakeRackUtil.cleanUp() FakeRackUtil.assignHostToRack("host2", "rack2") sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc) @@ -504,7 +508,6 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { } test("test RACK_LOCAL tasks") { - FakeRackUtil.cleanUp() // Assign host1 to rack1 FakeRackUtil.assignHostToRack("host1", "rack1") // Assign host2 to rack1 @@ -607,6 +610,39 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(manager.resourceOffer("execA", "host3", NO_PREF).get.index === 2) } + test("Ensure TaskSetManager is usable after addition of levels") { + // Regression test for SPARK-2931 + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc) + val taskSet = FakeTask.createTaskSet(2, + Seq(TaskLocation("host1", "execA")), + Seq(TaskLocation("host2", "execB.1"))) + val clock = new FakeClock + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + // Only ANY is valid + assert(manager.myLocalityLevels.sameElements(Array(ANY))) + // Add a new executor + sched.addExecutor("execA", "host1") + sched.addExecutor("execB.2", "host2") + manager.executorAdded() + assert(manager.pendingTasksWithNoPrefs.size === 0) + // Valid locality should contain PROCESS_LOCAL, NODE_LOCAL and ANY + assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY))) + assert(manager.resourceOffer("execA", "host1", ANY) !== None) + clock.advance(LOCALITY_WAIT * 4) + assert(manager.resourceOffer("execB.2", "host2", ANY) !== None) + sched.removeExecutor("execA") + sched.removeExecutor("execB.2") + manager.executorLost("execA", "host1") + manager.executorLost("execB.2", "host2") + clock.advance(LOCALITY_WAIT * 4) + sched.addExecutor("execC", "host3") + manager.executorAdded() + // Prior to the fix, this line resulted in an ArrayIndexOutOfBoundsException: + assert(manager.resourceOffer("execC", "host3", ANY) !== None) + } + + def createTaskResult(id: Int): DirectTaskResult[Int] = { val valueSer = SparkEnv.get.serializer.newInstance() new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala new file mode 100644 index 0000000000000..11e8c9c4cb37f --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import org.apache.spark.util.Utils + +import com.esotericsoftware.kryo.Kryo +import org.scalatest.FunSuite + +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, TestUtils} +import org.apache.spark.SparkContext._ +import org.apache.spark.serializer.KryoDistributedTest._ + +class KryoSerializerDistributedSuite extends FunSuite { + + test("kryo objects are serialised consistently in different processes") { + val conf = new SparkConf(false) + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + conf.set("spark.kryo.registrator", classOf[AppJarRegistrator].getName) + conf.set("spark.task.maxFailures", "1") + + val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName)) + conf.setJars(List(jar.getPath)) + + val sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + val original = Thread.currentThread.getContextClassLoader + val loader = new java.net.URLClassLoader(Array(jar), Utils.getContextOrSparkClassLoader) + SparkEnv.get.serializer.setDefaultClassLoader(loader) + + val cachedRDD = sc.parallelize((0 until 10).map((_, new MyCustomClass)), 3).cache() + + // Randomly mix the keys so that the join below will require a shuffle with each partition + // sending data to multiple other partitions. + val shuffledRDD = cachedRDD.map { case (i, o) => (i * i * i - 10 * i * i, o)} + + // Join the two RDDs, and force evaluation + assert(shuffledRDD.join(cachedRDD).collect().size == 1) + + LocalSparkContext.stop(sc) + } +} + +object KryoDistributedTest { + class MyCustomClass + + class AppJarRegistrator extends KryoRegistrator { + override def registerClasses(k: Kryo) { + val classLoader = Thread.currentThread.getContextClassLoader + k.register(Class.forName(AppJarRegistrator.customClassName, true, classLoader)) + } + } + + object AppJarRegistrator { + val customClassName = "KryoSerializerDistributedSuiteCustomClass" + } +} diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 789b773bae316..a579fd50bd9e4 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -23,7 +23,7 @@ import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo import org.scalatest.FunSuite -import org.apache.spark.SharedSparkContext +import org.apache.spark.{SparkConf, SharedSparkContext} import org.apache.spark.serializer.KryoTest._ class KryoSerializerSuite extends FunSuite with SharedSparkContext { @@ -207,8 +207,39 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { .fold(new ClassWithoutNoArgConstructor(10))((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x assert(10 + control.sum === result) } + + test("kryo with nonexistent custom registrator should fail") { + import org.apache.spark.{SparkConf, SparkException} + + val conf = new SparkConf(false) + conf.set("spark.kryo.registrator", "this.class.does.not.exist") + + val thrown = intercept[SparkException](new KryoSerializer(conf).newInstance()) + assert(thrown.getMessage.contains("Failed to invoke this.class.does.not.exist")) + } + + test("default class loader can be set by a different thread") { + val ser = new KryoSerializer(new SparkConf) + + // First serialize the object + val serInstance = ser.newInstance() + val bytes = serInstance.serialize(new ClassLoaderTestingObject) + + // Deserialize the object to make sure normal deserialization works + serInstance.deserialize[ClassLoaderTestingObject](bytes) + + // Set a special, broken ClassLoader and make sure we get an exception on deserialization + ser.setDefaultClassLoader(new ClassLoader() { + override def loadClass(name: String) = throw new UnsupportedOperationException + }) + intercept[UnsupportedOperationException] { + ser.newInstance().deserialize[ClassLoaderTestingObject](bytes) + } + } } +class ClassLoaderTestingObject + class KryoSerializerResizableOutputSuite extends FunSuite { import org.apache.spark.SparkConf import org.apache.spark.SparkContext diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 53df9b5a3f1d5..d48c8bde12905 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -74,8 +74,10 @@ def fail(msg): def run_cmd(cmd): if isinstance(cmd, list): + print " ".join(cmd) return subprocess.check_output(cmd) else: + print cmd return subprocess.check_output(cmd.split(" ")) diff --git a/dev/mima b/dev/mima index 4c3e65039b160..09e4482af5f3d 100755 --- a/dev/mima +++ b/dev/mima @@ -26,7 +26,9 @@ cd "$FWDIR" echo -e "q\n" | sbt/sbt oldDeps/update -export SPARK_CLASSPATH=`find lib_managed \( -name '*spark*jar' -a -type f \) -printf "%p:" ` +export SPARK_CLASSPATH=`find lib_managed \( -name '*spark*jar' -a -type f \) | tr "\\n" ":"` +echo "SPARK_CLASSPATH=$SPARK_CLASSPATH" + ./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore echo -e "q\n" | sbt/sbt mima-report-binary-issues | grep -v -e "info.*Resolving" ret_val=$? diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 3076eb847b420..721f09be5be6d 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -19,67 +19,156 @@ # Wrapper script that runs the Spark tests then reports QA results # to github via its API. +# Environment variables are populated by the code here: +#+ https://github.com/jenkinsci/ghprb-plugin/blob/master/src/main/java/org/jenkinsci/plugins/ghprb/GhprbTrigger.java#L139 # Go to the Spark project root directory FWDIR="$(cd `dirname $0`/..; pwd)" cd "$FWDIR" +function get_jq () { + # Get jq so we can parse some JSON, man. + # Essential if we want to do anything with the GitHub API responses. + local JQ_EXECUTABLE_URL="http://stedolan.github.io/jq/download/linux64/jq" + + echo "Fetching jq from ${JQ_EXECUTABLE_URL}" + + curl --silent --output "$FWDIR/dev/jq" "$JQ_EXECUTABLE_URL" + local curl_status=$? + + if [ $curl_status -ne 0 ]; then + echo "Failed to get jq." >&2 + return $curl_status + fi + + chmod u+x "$FWDIR/dev/jq" +} + COMMENTS_URL="https://api.github.com/repos/apache/spark/issues/$ghprbPullId/comments" +PULL_REQUEST_URL="https://github.com/apache/spark/pull/$ghprbPullId" + +function post_message () { + local message=$1 + local data="{\"body\": \"$message\"}" + local HTTP_CODE_HEADER="HTTP Response Code: " + + echo "Attempting to post to Github..." + + local curl_output=$( + curl `#--dump-header -` \ + --silent \ + --user x-oauth-basic:$GITHUB_OAUTH_KEY \ + --request POST \ + --data "$data" \ + --write-out "${HTTP_CODE_HEADER}%{http_code}\n" \ + --header "Content-Type: application/json" \ + "$COMMENTS_URL" #> /dev/null #| "$FWDIR/dev/jq" .id #| head -n 8 + ) + local curl_status=${PIPESTATUS[0]} + + if [ "$curl_status" -ne 0 ]; then + echo "Failed to post message to GitHub." >&2 + echo " > curl_status: ${curl_status}" >&2 + echo " > curl_output: ${curl_output}" >&2 + echo " > data: ${data}" >&2 + # exit $curl_status + fi + + local api_response=$( + echo "${curl_output}" \ + | grep -v -e "^${HTTP_CODE_HEADER}" + ) + + local http_code=$( + echo "${curl_output}" \ + | grep -e "^${HTTP_CODE_HEADER}" \ + | sed -r -e "s/^${HTTP_CODE_HEADER}//g" + ) + + if [ -n "$http_code" ] && [ "$http_code" -ne "201" ]; then + echo " > http_code: ${http_code}." >&2 + echo " > api_response: ${api_response}" >&2 + echo " > data: ${data}" >&2 + fi + + if [ "$curl_status" -eq 0 ] && [ "$http_code" -eq "201" ]; then + echo " > Post successful." + fi +} + +COMMIT_URL="https://github.com/apache/spark/commit/${ghprbActualCommit}" +# GitHub doesn't auto-link short hashes when submitted via the API, unfortunately. :( +short_commit_hash=${ghprbActualCommit:0:7} + +# check PR merge-ability and check for new public classes +{ + if [ "$sha1" == "$ghprbActualCommit" ]; then + merge_note=" * This patch **does not** merge cleanly!" + else + merge_note=" * This patch merges cleanly." + + non_test_files=$(git diff master --name-only | grep -v "\/test" | tr "\n" " ") + new_public_classes=$( + git diff master ${non_test_files} `# diff this patch against master and...` \ + | grep "^\+" `# filter in only added lines` \ + | sed -r -e "s/^\+//g" `# remove the leading +` \ + | grep -e "trait " -e "class " `# filter in lines with these key words` \ + | grep -e "{" -e "(" `# filter in lines with these key words, too` \ + | grep -v -e "\@\@" -e "private" `# exclude lines with these words` \ + | grep -v -e "^// " -e "^/\*" -e "^ \* " `# exclude comment lines` \ + | sed -r -e "s/\{.*//g" `# remove from the { onwards` \ + | sed -r -e "s/\}//g" `# just in case, remove }; they mess the JSON` \ + | sed -r -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \ + | sed -r -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \ + | sed -r -e "s/^/ \* /g" `# prepend ' *' to start of line` \ + | sed -r -e "s/$/\\\n/g" `# append newline to end of line` \ + | tr -d "\n" `# remove actual LF characters` + ) -function post_message { - message=$1 - data="{\"body\": \"$message\"}" - echo "Attempting to post to Github:" - echo "$data" + if [ "$new_public_classes" == "" ]; then + public_classes_note=" * This patch adds no public classes." + else + public_classes_note=" * This patch adds the following public classes _(experimental)_:" + public_classes_note="${public_classes_note}\n${new_public_classes}" + fi + fi +} - curl -D- -u x-oauth-basic:$GITHUB_OAUTH_KEY -X POST --data "$data" -H \ - "Content-Type: application/json" \ - $COMMENTS_URL | head -n 8 +# post start message +{ + start_message="\ + [QA tests have started](${BUILD_URL}consoleFull) for \ + PR $ghprbPullId at commit [\`${short_commit_hash}\`](${COMMIT_URL})." + + start_message="${start_message}\n${merge_note}" + # start_message="${start_message}\n${public_classes_note}" + + post_message "$start_message" } -start_message="QA tests have started for PR $ghprbPullId." -if [ "$sha1" == "$ghprbActualCommit" ]; then - start_message="$start_message This patch DID NOT merge cleanly! " -else - start_message="$start_message This patch merges cleanly. " -fi -start_message="$start_message
View progress: " -start_message="$start_message${BUILD_URL}consoleFull" - -post_message "$start_message" - -./dev/run-tests -test_result="$?" - -result_message="QA results for PR $ghprbPullId:
" - -if [ "$test_result" -eq "0" ]; then - result_message="$result_message- This patch PASSES unit tests.
" -else - result_message="$result_message- This patch FAILED unit tests.
" -fi - -if [ "$sha1" != "$ghprbActualCommit" ]; then - result_message="$result_message- This patch merges cleanly
" - non_test_files=$(git diff master --name-only | grep -v "\/test" | tr "\n" " ") - new_public_classes=$(git diff master $non_test_files \ - | grep -e "trait " -e "class " \ - | grep -e "{" -e "(" \ - | grep -v -e \@\@ -e private \ - | grep \+ \ - | sed "s/\+ *//" \ - | tr "\n" "~" \ - | sed "s/~/
/g") - if [ "$new_public_classes" == "" ]; then - result_message="$result_message- This patch adds no public classes
" +# run tests +{ + ./dev/run-tests + test_result="$?" + + if [ "$test_result" -eq "0" ]; then + test_result_note=" * This patch **passes** unit tests." else - result_message="$result_message- This patch adds the following public classes (experimental):
" - result_message="$result_message$new_public_classes" + test_result_note=" * This patch **fails** unit tests." fi -fi -result_message="${result_message}
For more information see test ouptut:" -result_message="${result_message}
${BUILD_URL}consoleFull" +} -post_message "$result_message" +# post end message +{ + result_message="\ + [QA tests have finished](${BUILD_URL}consoleFull) for \ + PR $ghprbPullId at commit [\`${short_commit_hash}\`](${COMMIT_URL})." + + result_message="${result_message}\n${test_result_note}" + result_message="${result_message}\n${merge_note}" + result_message="${result_message}\n${public_classes_note}" + + post_message "$result_message" +} exit $test_result diff --git a/docs/building-with-maven.md b/docs/building-with-maven.md index 672d0ef114f6d..4d87ab92cec5b 100644 --- a/docs/building-with-maven.md +++ b/docs/building-with-maven.md @@ -96,6 +96,15 @@ mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -DskipTests clean package mvn -Pyarn-alpha -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=0.23.7 -DskipTests clean package {% endhighlight %} +# Building Thrift JDBC server and CLI for Spark SQL + +Spark SQL supports Thrift JDBC server and CLI. +See sql-programming-guide.md for more information about those features. +You can use those features by setting `-Phive-thriftserver` when building Spark as follows. +{% highlight bash %} +mvn -Phive-thriftserver assembly +{% endhighlight %} + # Spark Tests in Maven Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.org/user_guide/using_the_scalatest_maven_plugin). diff --git a/docs/configuration.md b/docs/configuration.md index 4d27c5a918fe0..c408c468dcd94 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -373,10 +373,12 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.codec - org.apache.spark.io.
SnappyCompressionCodec + snappy - The codec used to compress internal data such as RDD partitions and shuffle outputs. - By default, Spark provides three codecs: org.apache.spark.io.LZ4CompressionCodec, + The codec used to compress internal data such as RDD partitions and shuffle outputs. By default, + Spark provides three codecs: lz4, lzf, and snappy. You + can also use fully qualified class names to specify the codec, e.g. + org.apache.spark.io.LZ4CompressionCodec, org.apache.spark.io.LZFCompressionCodec, and org.apache.spark.io.SnappyCompressionCodec. @@ -560,7 +562,7 @@ Apart from these, the following properties are also available, and may be useful - spark.hadoop.validateOutputSpecs + spark.hadoop.validateOutputSpecs true If set to true, validates the output specification (e.g. checking if the output directory already exists) used in saveAsHadoopFile and other variants. This can be disabled to silence exceptions due to pre-existing @@ -568,7 +570,7 @@ Apart from these, the following properties are also available, and may be useful previous versions of Spark. Simply use Hadoop's FileSystem API to delete output directories by hand. - spark.executor.heartbeatInterval + spark.executor.heartbeatInterval 10000 Interval (milliseconds) between each executor's heartbeats to the driver. Heartbeats let the driver know that the executor is still alive and update it with metrics for in-progress @@ -825,24 +827,34 @@ Apart from these, the following properties are also available, and may be useful - spark.scheduler.minRegisteredExecutorsRatio + spark.scheduler.minRegisteredResourcesRatio 0 - The minimum ratio of registered executors (registered executors / total expected executors) + The minimum ratio of registered resources (registered resources / total expected resources) + (resources are executors in yarn mode, CPU cores in standalone mode) to wait for before scheduling begins. Specified as a double between 0 and 1. - Regardless of whether the minimum ratio of executors has been reached, + Regardless of whether the minimum ratio of resources has been reached, the maximum amount of time it will wait before scheduling begins is controlled by config - spark.scheduler.maxRegisteredExecutorsWaitingTime + spark.scheduler.maxRegisteredResourcesWaitingTime - spark.scheduler.maxRegisteredExecutorsWaitingTime + spark.scheduler.maxRegisteredResourcesWaitingTime 30000 - Maximum amount of time to wait for executors to register before scheduling begins + Maximum amount of time to wait for resources to register before scheduling begins (in milliseconds). + + spark.localExecution.enabled + false + + Enables Spark to run certain jobs, such as first() or take() on the driver, without sending + tasks to the cluster. This can make certain jobs execute very quickly, but may require + shipping a whole partition of data to the driver. + + #### Security diff --git a/docs/mllib-basics.md b/docs/mllib-basics.md index f9585251fafac..8752df412950a 100644 --- a/docs/mllib-basics.md +++ b/docs/mllib-basics.md @@ -9,17 +9,17 @@ displayTitle: MLlib - Basics MLlib supports local vectors and matrices stored on a single machine, as well as distributed matrices backed by one or more RDDs. -In the current implementation, local vectors and matrices are simple data models -to serve public interfaces. The underlying linear algebra operations are provided by +Local vectors and local matrices are simple data models +that serve as public interfaces. The underlying linear algebra operations are provided by [Breeze](http://www.scalanlp.org/) and [jblas](http://jblas.org/). -A training example used in supervised learning is called "labeled point" in MLlib. +A training example used in supervised learning is called a "labeled point" in MLlib. ## Local vector A local vector has integer-typed and 0-based indices and double-typed values, stored on a single machine. MLlib supports two types of local vectors: dense and sparse. A dense vector is backed by a double array representing its entry values, while a sparse vector is backed by two parallel -arrays: indices and values. For example, a vector $(1.0, 0.0, 3.0)$ can be represented in dense +arrays: indices and values. For example, a vector `(1.0, 0.0, 3.0)` can be represented in dense format as `[1.0, 0.0, 3.0]` or in sparse format as `(3, [0, 2], [1.0, 3.0])`, where `3` is the size of the vector. @@ -44,8 +44,7 @@ val sv1: Vector = Vectors.sparse(3, Array(0, 2), Array(1.0, 3.0)) val sv2: Vector = Vectors.sparse(3, Seq((0, 1.0), (2, 3.0))) {% endhighlight %} -***Note*** - +***Note:*** Scala imports `scala.collection.immutable.Vector` by default, so you have to import `org.apache.spark.mllib.linalg.Vector` explicitly to use MLlib's `Vector`. @@ -110,8 +109,8 @@ sv2 = sps.csc_matrix((np.array([1.0, 3.0]), np.array([0, 2]), np.array([0, 2])), A labeled point is a local vector, either dense or sparse, associated with a label/response. In MLlib, labeled points are used in supervised learning algorithms. We use a double to store a label, so we can use labeled points in both regression and classification. -For binary classification, label should be either $0$ (negative) or $1$ (positive). -For multiclass classification, labels should be class indices staring from zero: $0, 1, 2, \ldots$. +For binary classification, a label should be either `0` (negative) or `1` (positive). +For multiclass classification, labels should be class indices starting from zero: `0, 1, 2, ...`.
@@ -172,7 +171,7 @@ neg = LabeledPoint(0.0, SparseVector(3, [0, 2], [1.0, 3.0])) It is very common in practice to have sparse training data. MLlib supports reading training examples stored in `LIBSVM` format, which is the default format used by [`LIBSVM`](http://www.csie.ntu.edu.tw/~cjlin/libsvm/) and -[`LIBLINEAR`](http://www.csie.ntu.edu.tw/~cjlin/liblinear/). It is a text format. Each line +[`LIBLINEAR`](http://www.csie.ntu.edu.tw/~cjlin/liblinear/). It is a text format in which each line represents a labeled sparse feature vector using the following format: ~~~ @@ -226,7 +225,7 @@ examples = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") ## Local matrix A local matrix has integer-typed row and column indices and double-typed values, stored on a single -machine. MLlib supports dense matrix, whose entry values are stored in a single double array in +machine. MLlib supports dense matrices, whose entry values are stored in a single double array in column major. For example, the following matrix `\[ \begin{pmatrix} 1.0 & 2.0 \\ 3.0 & 4.0 \\ @@ -234,7 +233,6 @@ column major. For example, the following matrix `\[ \begin{pmatrix} \end{pmatrix} \]` is stored in a one-dimensional array `[1.0, 3.0, 5.0, 2.0, 4.0, 6.0]` with the matrix size `(3, 2)`. -We are going to add sparse matrix in the next release.
@@ -242,7 +240,7 @@ We are going to add sparse matrix in the next release. The base class of local matrices is [`Matrix`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrix), and we provide one implementation: [`DenseMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.DenseMatrix). -Sparse matrix will be added in the next release. We recommend using the factory methods implemented +We recommend using the factory methods implemented in [`Matrices`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrices) to create local matrices. @@ -259,7 +257,7 @@ val dm: Matrix = Matrices.dense(3, 2, Array(1.0, 3.0, 5.0, 2.0, 4.0, 6.0)) The base class of local matrices is [`Matrix`](api/java/org/apache/spark/mllib/linalg/Matrix.html), and we provide one implementation: [`DenseMatrix`](api/java/org/apache/spark/mllib/linalg/DenseMatrix.html). -Sparse matrix will be added in the next release. We recommend using the factory methods implemented +We recommend using the factory methods implemented in [`Matrices`](api/java/org/apache/spark/mllib/linalg/Matrices.html) to create local matrices. @@ -279,28 +277,30 @@ Matrix dm = Matrices.dense(3, 2, new double[] {1.0, 3.0, 5.0, 2.0, 4.0, 6.0}); A distributed matrix has long-typed row and column indices and double-typed values, stored distributively in one or more RDDs. It is very important to choose the right format to store large and distributed matrices. Converting a distributed matrix to a different format may require a -global shuffle, which is quite expensive. We implemented three types of distributed matrices in -this release and will add more types in the future. +global shuffle, which is quite expensive. Three types of distributed matrices have been implemented +so far. The basic type is called `RowMatrix`. A `RowMatrix` is a row-oriented distributed matrix without meaningful row indices, e.g., a collection of feature vectors. It is backed by an RDD of its rows, where each row is a local vector. -We assume that the number of columns is not huge for a `RowMatrix`. +We assume that the number of columns is not huge for a `RowMatrix` so that a single +local vector can be reasonably communicated to the driver and can also be stored / +operated on using a single node. An `IndexedRowMatrix` is similar to a `RowMatrix` but with row indices, -which can be used for identifying rows and joins. -A `CoordinateMatrix` is a distributed matrix stored in [coordinate list (COO)](https://en.wikipedia.org/wiki/Sparse_matrix) format, +which can be used for identifying rows and executing joins. +A `CoordinateMatrix` is a distributed matrix stored in [coordinate list (COO)](https://en.wikipedia.org/wiki/Sparse_matrix#Coordinate_list_.28COO.29) format, backed by an RDD of its entries. ***Note*** The underlying RDDs of a distributed matrix must be deterministic, because we cache the matrix size. -It is always error-prone to have non-deterministic RDDs. +In general the use of non-deterministic RDDs can lead to errors. ### RowMatrix A `RowMatrix` is a row-oriented distributed matrix without meaningful row indices, backed by an RDD -of its rows, where each row is a local vector. This is similar to `data matrix` in the context of -multivariate statistics. Since each row is represented by a local vector, the number of columns is +of its rows, where each row is a local vector. +Since each row is represented by a local vector, the number of columns is limited by the integer range but it should be much smaller in practice.
@@ -344,70 +344,10 @@ long n = mat.numCols();
-#### Multivariate summary statistics - -We provide column summary statistics for `RowMatrix`. -If the number of columns is not large, say, smaller than 3000, you can also compute -the covariance matrix as a local matrix, which requires $\mathcal{O}(n^2)$ storage where $n$ is the -number of columns. The total CPU time is $\mathcal{O}(m n^2)$, where $m$ is the number of rows, -which could be faster if the rows are sparse. - -
-
- -[`RowMatrix#computeColumnSummaryStatistics`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.RowMatrix) returns an instance of -[`MultivariateStatisticalSummary`](api/scala/index.html#org.apache.spark.mllib.stat.MultivariateStatisticalSummary), -which contains the column-wise max, min, mean, variance, and number of nonzeros, as well as the -total count. - -{% highlight scala %} -import org.apache.spark.mllib.linalg.Matrix -import org.apache.spark.mllib.linalg.distributed.RowMatrix -import org.apache.spark.mllib.stat.MultivariateStatisticalSummary - -val mat: RowMatrix = ... // a RowMatrix - -// Compute column summary statistics. -val summary: MultivariateStatisticalSummary = mat.computeColumnSummaryStatistics() -println(summary.mean) // a dense vector containing the mean value for each column -println(summary.variance) // column-wise variance -println(summary.numNonzeros) // number of nonzeros in each column - -// Compute the covariance matrix. -val cov: Matrix = mat.computeCovariance() -{% endhighlight %} -
- -
- -[`RowMatrix#computeColumnSummaryStatistics`](api/java/org/apache/spark/mllib/linalg/distributed/RowMatrix.html#computeColumnSummaryStatistics()) returns an instance of -[`MultivariateStatisticalSummary`](api/java/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.html), -which contains the column-wise max, min, mean, variance, and number of nonzeros, as well as the -total count. - -{% highlight java %} -import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.mllib.linalg.distributed.RowMatrix; -import org.apache.spark.mllib.stat.MultivariateStatisticalSummary; - -RowMatrix mat = ... // a RowMatrix - -// Compute column summary statistics. -MultivariateStatisticalSummary summary = mat.computeColumnSummaryStatistics(); -System.out.println(summary.mean()); // a dense vector containing the mean value for each column -System.out.println(summary.variance()); // column-wise variance -System.out.println(summary.numNonzeros()); // number of nonzeros in each column - -// Compute the covariance matrix. -Matrix cov = mat.computeCovariance(); -{% endhighlight %} -
-
- ### IndexedRowMatrix An `IndexedRowMatrix` is similar to a `RowMatrix` but with meaningful row indices. It is backed by -an RDD of indexed rows, which each row is represented by its index (long-typed) and a local vector. +an RDD of indexed rows, so that each row is represented by its index (long-typed) and a local vector.
@@ -467,7 +407,7 @@ RowMatrix rowMat = mat.toRowMatrix(); A `CoordinateMatrix` is a distributed matrix backed by an RDD of its entries. Each entry is a tuple of `(i: Long, j: Long, value: Double)`, where `i` is the row index, `j` is the column index, and -`value` is the entry value. A `CoordinateMatrix` should be used only in the case when both +`value` is the entry value. A `CoordinateMatrix` should be used only when both dimensions of the matrix are huge and the matrix is very sparse.
@@ -477,9 +417,9 @@ A [`CoordinateMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.CoordinateMatrix) can be created from an `RDD[MatrixEntry]` instance, where [`MatrixEntry`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.MatrixEntry) is a -wrapper over `(Long, Long, Double)`. A `CoordinateMatrix` can be converted to a `IndexedRowMatrix` -with sparse rows by calling `toIndexedRowMatrix`. In this release, we do not provide other -computation for `CoordinateMatrix`. +wrapper over `(Long, Long, Double)`. A `CoordinateMatrix` can be converted to an `IndexedRowMatrix` +with sparse rows by calling `toIndexedRowMatrix`. Other computations for +`CoordinateMatrix` are not currently supported. {% highlight scala %} import org.apache.spark.mllib.linalg.distributed.{CoordinateMatrix, MatrixEntry} @@ -503,8 +443,9 @@ A [`CoordinateMatrix`](api/java/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.html) can be created from a `JavaRDD` instance, where [`MatrixEntry`](api/java/org/apache/spark/mllib/linalg/distributed/MatrixEntry.html) is a -wrapper over `(long, long, double)`. A `CoordinateMatrix` can be converted to a `IndexedRowMatrix` -with sparse rows by calling `toIndexedRowMatrix`. +wrapper over `(long, long, double)`. A `CoordinateMatrix` can be converted to an `IndexedRowMatrix` +with sparse rows by calling `toIndexedRowMatrix`. Other computations for +`CoordinateMatrix` are not currently supported. {% highlight java %} import org.apache.spark.api.java.JavaRDD; diff --git a/docs/mllib-classification-regression.md b/docs/mllib-classification-regression.md new file mode 100644 index 0000000000000..719cc95767b00 --- /dev/null +++ b/docs/mllib-classification-regression.md @@ -0,0 +1,37 @@ +--- +layout: global +title: Classification and Regression - MLlib +displayTitle: MLlib - Classification and Regression +--- + +MLlib supports various methods for +[binary classification](http://en.wikipedia.org/wiki/Binary_classification), +[multiclass +classification](http://en.wikipedia.org/wiki/Multiclass_classification), and +[regression analysis](http://en.wikipedia.org/wiki/Regression_analysis). The table below outlines +the supported algorithms for each type of problem. + + + + + + + + + + + + + + + + +
Problem TypeSupported Methods
Binary Classificationlinear SVMs, logistic regression, decision trees, naive Bayes
Multiclass Classificationdecision trees, naive Bayes
Regressionlinear least squares, Lasso, ridge regression, decision trees
+ +More details for these methods can be found here: + +* [Linear models](mllib-linear-methods.html) + * [binary classification (SVMs, logistic regression)](mllib-linear-methods.html#binary-classification) + * [linear regression (least squares, Lasso, ridge)](mllib-linear-methods.html#linear-least-squares-lasso-and-ridge-regression) +* [Decision trees](mllib-decision-tree.html) +* [Naive Bayes](mllib-naive-bayes.html) diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 561de48910132..dfd9cd572888c 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -38,7 +38,7 @@ a given dataset, the algorithm returns the best clustering result).
-Following code snippets can be executed in `spark-shell`. +The following code snippets can be executed in `spark-shell`. In the following example after loading and parsing data, we use the [`KMeans`](api/scala/index.html#org.apache.spark.mllib.clustering.KMeans) object to cluster the data @@ -70,7 +70,7 @@ All of MLlib's methods use Java-friendly types, so you can import and call them way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by calling `.rdd()` on your `JavaRDD` object. A standalone application example -that is equivalent to the provided example in Scala is given bellow: +that is equivalent to the provided example in Scala is given below: {% highlight java %} import org.apache.spark.api.java.*; @@ -113,14 +113,15 @@ public class KMeansExample { } {% endhighlight %} -In order to run the above standalone application using Spark framework make -sure that you follow the instructions provided at section [Standalone -Applications](quick-start.html) of the quick-start guide. What is more, you -should include to your build file *spark-mllib* as a dependency. +In order to run the above standalone application, follow the instructions +provided in the [Standalone +Applications](quick-start.html#standalone-applications) section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency.
-Following examples can be tested in the PySpark shell. +The following examples can be tested in the PySpark shell. In the following example after loading and parsing data, we use the KMeans object to cluster the data into two clusters. The number of desired clusters is passed to the algorithm. We then compute diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 0d28b5f7c89b3..ab10b2f01f87b 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -14,13 +14,13 @@ is commonly used for recommender systems. These techniques aim to fill in the missing entries of a user-item association matrix. MLlib currently supports model-based collaborative filtering, in which users and products are described by a small set of latent factors that can be used to predict missing entries. -In particular, we implement the [alternating least squares +MLlib uses the [alternating least squares (ALS)](http://dl.acm.org/citation.cfm?id=1608614) algorithm to learn these latent factors. The implementation in MLlib has the following parameters: * *numBlocks* is the number of blocks used to parallelize computation (set to -1 to auto-configure). -* *rank* is the number of latent factors in our model. +* *rank* is the number of latent factors in the model. * *iterations* is the number of iterations to run. * *lambda* specifies the regularization parameter in ALS. * *implicitPrefs* specifies whether to use the *explicit feedback* ALS variant or one adapted for @@ -86,8 +86,8 @@ val MSE = ratesAndPreds.map { case ((user, product), (r1, r2)) => println("Mean Squared Error = " + MSE) {% endhighlight %} -If the rating matrix is derived from other source of information (i.e., it is inferred from -other signals), you can use the trainImplicit method to get better results. +If the rating matrix is derived from another source of information (e.g., it is inferred from +other signals), you can use the `trainImplicit` method to get better results. {% highlight scala %} val alpha = 0.01 @@ -174,10 +174,11 @@ public class CollaborativeFiltering { } {% endhighlight %} -In order to run the above standalone application using Spark framework make -sure that you follow the instructions provided at section [Standalone -Applications](quick-start.html) of the quick-start guide. What is more, you -should include to your build file *spark-mllib* as a dependency. +In order to run the above standalone application, follow the instructions +provided in the [Standalone +Applications](quick-start.html#standalone-applications) section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency.
@@ -219,5 +220,5 @@ model = ALS.trainImplicit(ratings, rank, numIterations, alpha = 0.01) ## Tutorial -[AMP Camp](http://ampcamp.berkeley.edu/) provides a hands-on tutorial for -[personalized movie recommendation with MLlib](http://ampcamp.berkeley.edu/big-data-mini-course/movie-recommendation-with-mllib.html). +The [training exercises](https://databricks-training.s3.amazonaws.com/index.html) from the Spark Summit 2014 include a hands-on tutorial for +[personalized movie recommendation with MLlib](https://databricks-training.s3.amazonaws.com/movie-recommendation-with-mllib.html). diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md index 8e434998c15ea..065d646496131 100644 --- a/docs/mllib-dimensionality-reduction.md +++ b/docs/mllib-dimensionality-reduction.md @@ -9,9 +9,9 @@ displayTitle: MLlib - Dimensionality Reduction [Dimensionality reduction](http://en.wikipedia.org/wiki/Dimensionality_reduction) is the process of reducing the number of variables under consideration. -It is used to extract latent features from raw and noisy features, +It can be used to extract latent features from raw and noisy features or compress data while maintaining the structure. -In this release, we provide preliminary support for dimensionality reduction on tall-and-skinny matrices. +MLlib provides support for dimensionality reduction on tall-and-skinny matrices. ## Singular value decomposition (SVD) @@ -30,17 +30,17 @@ where * $V$ is an orthonormal matrix, whose columns are called right singular vectors. For large matrices, usually we don't need the complete factorization but only the top singular -values and its associated singular vectors. This can save storage, and more importantly, de-noise +values and its associated singular vectors. This can save storage, de-noise and recover the low-rank structure of the matrix. -If we keep the top $k$ singular values, then the dimensions of the return will be: +If we keep the top $k$ singular values, then the dimensions of the resulting low-rank matrix will be: * `$U$`: `$m \times k$`, * `$\Sigma$`: `$k \times k$`, * `$V$`: `$n \times k$`. -In this release, we provide SVD computation to row-oriented matrices that have only a few columns, -say, less than $1000$, but many rows, which we call *tall-and-skinny*. +MLlib provides SVD functionality to row-oriented matrices that have only a few columns, +say, less than $1000$, but many rows, i.e., *tall-and-skinny* matrices.
@@ -58,15 +58,10 @@ val s: Vector = svd.s // The singular values are stored in a local dense vector. val V: Matrix = svd.V // The V factor is a local dense matrix. {% endhighlight %} -Same code applies to `IndexedRowMatrix`. -The only difference that the `U` matrix becomes an `IndexedRowMatrix`. +The same code applies to `IndexedRowMatrix` if `U` is defined as an +`IndexedRowMatrix`.
-In order to run the following standalone application using Spark framework make -sure that you follow the instructions provided at section [Standalone -Applications](quick-start.html) of the quick-start guide. What is more, you -should include to your build file *spark-mllib* as a dependency. - {% highlight java %} import java.util.LinkedList; @@ -104,8 +99,16 @@ public class SVD { } } {% endhighlight %} -Same code applies to `IndexedRowMatrix`. -The only difference that the `U` matrix becomes an `IndexedRowMatrix`. + +The same code applies to `IndexedRowMatrix` if `U` is defined as an +`IndexedRowMatrix`. + +In order to run the above standalone application, follow the instructions +provided in the [Standalone +Applications](quick-start.html#standalone-applications) section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency. +
@@ -116,7 +119,7 @@ statistical method to find a rotation such that the first coordinate has the lar possible, and each succeeding coordinate in turn has the largest variance possible. The columns of the rotation matrix are called principal components. PCA is used widely in dimensionality reduction. -In this release, we implement PCA for tall-and-skinny matrices stored in row-oriented format. +MLlib supports PCA for tall-and-skinny matrices stored in row-oriented format.
@@ -180,9 +183,10 @@ public class PCA { } {% endhighlight %} -In order to run the above standalone application using Spark framework make -sure that you follow the instructions provided at section [Standalone -Applications](quick-start.html) of the quick-start guide. What is more, you -should include to your build file *spark-mllib* as a dependency. +In order to run the above standalone application, follow the instructions +provided in the [Standalone +Applications](quick-start.html#standalone-applications) section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency.
diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md new file mode 100644 index 0000000000000..21453cb9cd8c9 --- /dev/null +++ b/docs/mllib-feature-extraction.md @@ -0,0 +1,12 @@ +--- +layout: global +title: Feature Extraction - MLlib +displayTitle: MLlib - Feature Extraction +--- + +* Table of contents +{:toc} + +## Word2Vec + +## TFIDF diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 95ee6bc96801f..23d5a0c4607af 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -3,18 +3,19 @@ layout: global title: Machine Learning Library (MLlib) --- -MLlib is a Spark implementation of some common machine learning algorithms and utilities, +MLlib is Spark's scalable machine learning library consisting of common learning algorithms and utilities, including classification, regression, clustering, collaborative -filtering, dimensionality reduction, as well as underlying optimization primitives: +filtering, dimensionality reduction, as well as underlying optimization primitives, as outlined below: -* [Basics](mllib-basics.html) - * data types +* [Data types](mllib-basics.html) +* [Basic statistics](mllib-stats.html) + * data generators + * stratified sampling * summary statistics -* Classification and regression - * [linear support vector machine (SVM)](mllib-linear-methods.html#linear-support-vector-machine-svm) - * [logistic regression](mllib-linear-methods.html#logistic-regression) - * [linear least squares, Lasso, and ridge regression](mllib-linear-methods.html#linear-least-squares-lasso-and-ridge-regression) - * [decision tree](mllib-decision-tree.html) + * hypothesis testing +* [Classification and regression](mllib-classification-regression.html) + * [linear models (SVMs, logistic regression, linear regression)](mllib-linear-methods.html) + * [decision trees](mllib-decision-tree.html) * [naive Bayes](mllib-naive-bayes.html) * [Collaborative filtering](mllib-collaborative-filtering.html) * alternating least squares (ALS) @@ -23,17 +24,18 @@ filtering, dimensionality reduction, as well as underlying optimization primitiv * [Dimensionality reduction](mllib-dimensionality-reduction.html) * singular value decomposition (SVD) * principal component analysis (PCA) -* [Optimization](mllib-optimization.html) +* [Feature extraction and transformation](mllib-feature-extraction.html) +* [Optimization (developer)](mllib-optimization.html) * stochastic gradient descent * limited-memory BFGS (L-BFGS) -MLlib is a new component under active development. +MLlib is under active development. The APIs marked `Experimental`/`DeveloperApi` may change in future releases, -and we will provide migration guide between releases. +and the migration guide below will explain all changes between releases. # Dependencies -MLlib uses linear algebra packages [Breeze](http://www.scalanlp.org/), which depends on +MLlib uses the linear algebra package [Breeze](http://www.scalanlp.org/), which depends on [netlib-java](https://github.com/fommil/netlib-java), and [jblas](https://github.com/mikiobraun/jblas). `netlib-java` and `jblas` depend on native Fortran routines. @@ -56,7 +58,7 @@ To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.4 In MLlib v1.0, we support both dense and sparse input in a unified way, which introduces a few breaking changes. If your data is sparse, please store it in a sparse format instead of dense to -take advantage of sparsity in both storage and computation. +take advantage of sparsity in both storage and computation. Details are described below.
diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 254201147edc1..e504cd7f0f578 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -33,24 +33,24 @@ the task of finding a minimizer of a convex function `$f$` that depends on a var Formally, we can write this as the optimization problem `$\min_{\wv \in\R^d} \; f(\wv)$`, where the objective function is of the form `\begin{equation} - f(\wv) := - \frac1n \sum_{i=1}^n L(\wv;\x_i,y_i) + - \lambda\, R(\wv_i) + f(\wv) := \lambda\, R(\wv) + + \frac1n \sum_{i=1}^n L(\wv;\x_i,y_i) \label{eq:regPrimal} \ . \end{equation}` Here the vectors `$\x_i\in\R^d$` are the training data examples, for `$1\le i\le n$`, and `$y_i\in\R$` are their corresponding labels, which we want to predict. We call the method *linear* if $L(\wv; \x, y)$ can be expressed as a function of $\wv^T x$ and $y$. -Several MLlib's classification and regression algorithms fall into this category, +Several of MLlib's classification and regression algorithms fall into this category, and are discussed here. The objective function `$f$` has two parts: -the loss that measures the error of the model on the training data, -and the regularizer that measures the complexity of the model. -The loss function `$L(\wv;.)$` must be a convex function in `$\wv$`. -The fixed regularization parameter `$\lambda \ge 0$` (`regParam` in the code) defines the trade-off -between the two goals of small loss and small model complexity. +the regularizer that controls the complexity of the model, +and the loss that measures the error of the model on the training data. +The loss function `$L(\wv;.)$` is typically a convex function in `$\wv$`. The +fixed regularization parameter `$\lambda \ge 0$` (`regParam` in the code) +defines the trade-off between the two goals of minimizing the loss (i.e., +training error) and minimizing model complexity (i.e., to avoid overfitting). ### Loss functions @@ -80,10 +80,10 @@ methods MLlib supports: ### Regularizers -The purpose of the [regularizer](http://en.wikipedia.org/wiki/Regularization_(mathematics)) is to -encourage simple models, by punishing the complexity of the model `$\wv$`, in order to e.g. avoid -over-fitting. -We support the following regularizers in MLlib: +The purpose of the +[regularizer](http://en.wikipedia.org/wiki/Regularization_(mathematics)) is to +encourage simple models and avoid overfitting. We support the following +regularizers in MLlib: @@ -106,27 +106,28 @@ Here `$\mathrm{sign}(\wv)$` is the vector consisting of the signs (`$\pm1$`) of of `$\wv$`. L2-regularized problems are generally easier to solve than L1-regularized due to smoothness. -However, L1 regularization can help promote sparsity in weights, leading to simpler models, which is -also used for feature selection. It is not recommended to train models without any regularization, +However, L1 regularization can help promote sparsity in weights leading to smaller and more interpretable models, the latter of which can be useful for feature selection. +It is not recommended to train models without any regularization, especially when the number of training examples is small. ## Binary classification -[Binary classification](http://en.wikipedia.org/wiki/Binary_classification) is to divide items into -two categories: positive and negative. MLlib supports two linear methods for binary classification: -linear support vector machine (SVM) and logistic regression. The training data set is represented -by an RDD of [LabeledPoint](mllib-data-types.html) in MLlib. Note that, in the mathematical -formulation, a training label $y$ is either $+1$ (positive) or $-1$ (negative), which is convenient -for the formulation. *However*, the negative label is represented by $0$ in MLlib instead of $-1$, -to be consistent with multiclass labeling. +[Binary classification](http://en.wikipedia.org/wiki/Binary_classification) +aims to divide items into two categories: positive and negative. MLlib +supports two linear methods for binary classification: linear support vector +machines (SVMs) and logistic regression. For both methods, MLlib supports +L1 and L2 regularized variants. The training data set is represented by an RDD +of [LabeledPoint](mllib-data-types.html) in MLlib. Note that, in the +mathematical formulation in this guide, a training label $y$ is denoted as +either $+1$ (positive) or $-1$ (negative), which is convenient for the +formulation. *However*, the negative label is represented by $0$ in MLlib +instead of $-1$, to be consistent with multiclass labeling. -### Linear support vector machine (SVM) +### Linear support vector machines (SVMs) The [linear SVM](http://en.wikipedia.org/wiki/Support_vector_machine#Linear_SVM) -has become a standard choice for large-scale classification tasks. -The name "linear SVM" is actually ambiguous. -By "linear SVM", we mean specifically the linear method with the loss function in formulation -`$\eqref{eq:regPrimal}$` given by the hinge loss +is a standard method for large-scale classification tasks. It is a linear method as described above in equation `$\eqref{eq:regPrimal}$`, with the loss function in the formulation given by the hinge loss: + `\[ L(\wv;\x,y) := \max \{0, 1-y \wv^T \x \}. \]` @@ -134,39 +135,44 @@ By default, linear SVMs are trained with an L2 regularization. We also support alternative L1 regularization. In this case, the problem becomes a [linear program](http://en.wikipedia.org/wiki/Linear_programming). -Linear SVM algorithm outputs a SVM model, which makes predictions based on the value of $\wv^T \x$. -By the default, if $\wv^T \x \geq 0$, the outcome is positive, or negative otherwise. -However, quite often in practice, the default threshold $0$ is not a good choice. -The threshold should be determined via model evaluation. +The linear SVMs algorithm outputs an SVM model. Given a new data point, +denoted by $\x$, the model makes predictions based on the value of $\wv^T \x$. +By the default, if $\wv^T \x \geq 0$ then the outcome is positive, and negative +otherwise. ### Logistic regression [Logistic regression](http://en.wikipedia.org/wiki/Logistic_regression) is widely used to predict a -binary response. It is a linear method with the loss function in formulation -`$\eqref{eq:regPrimal}$` given by the logistic loss +binary response. +It is a linear method as described above in equation `$\eqref{eq:regPrimal}$`, with the loss +function in the formulation given by the logistic loss: `\[ L(\wv;\x,y) := \log(1+\exp( -y \wv^T \x)). \]` -Logistic regression algorithm outputs a logistic regression model, which makes predictions by +The logistic regression algorithm outputs a logistic regression model. Given a +new data point, denoted by $\x$, the model makes predictions by applying the logistic function `\[ \mathrm{f}(z) = \frac{1}{1 + e^{-z}} \]` where $z = \wv^T \x$. -By default, if $\mathrm{f}(\wv^T x) > 0.5$, the outcome is positive, or negative otherwise. -For the same reason mentioned above, quite often in practice, this default threshold is not a good choice. -The threshold should be determined via model evaluation. +By default, if $\mathrm{f}(\wv^T x) > 0.5$, the outcome is positive, or +negative otherwise, though unlike linear SVMs, the raw output of the logistic regression +model, $\mathrm{f}(z)$, has a probabilistic interpretation (i.e., the probability +that $\x$ is positive). ### Evaluation metrics -MLlib supports common evaluation metrics for binary classification (not available in Python). This +MLlib supports common evaluation metrics for binary classification (not available in PySpark). +This includes precision, recall, [F-measure](http://en.wikipedia.org/wiki/F1_score), [receiver operating characteristic (ROC)](http://en.wikipedia.org/wiki/Receiver_operating_characteristic), precision-recall curve, and [area under the curves (AUC)](http://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve). -Among the metrics, area under ROC is commonly used to compare models and precision/recall/F-measure -can help determine the threshold to use. +AUC is commonly used to compare the performance of various models while +precision/recall/F-measure can help determine the appropriate threshold to use +for prediction purposes. ### Examples @@ -233,8 +239,7 @@ svmAlg.optimizer. val modelL1 = svmAlg.run(training) {% endhighlight %} -Similarly, you can use replace `SVMWithSGD` by -[`LogisticRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithSGD). +[`LogisticRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithSGD) can be used in a similar fashion as `SVMWithSGD`. @@ -318,10 +323,11 @@ svmAlg.optimizer() final SVMModel modelL1 = svmAlg.run(training.rdd()); {% endhighlight %} -In order to run the above standalone application using Spark framework make -sure that you follow the instructions provided at section [Standalone -Applications](quick-start.html) of the quick-start guide. What is more, you -should include to your build file *spark-mllib* as a dependency. +In order to run the above standalone application, follow the instructions +provided in the [Standalone +Applications](quick-start.html#standalone-applications) section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency.
@@ -354,24 +360,22 @@ print("Training Error = " + str(trainErr)) ## Linear least squares, Lasso, and ridge regression -Linear least squares is a family of linear methods with the loss function in formulation -`$\eqref{eq:regPrimal}$` given by the squared loss +Linear least squares is the most common formulation for regression problems. +It is a linear method as described above in equation `$\eqref{eq:regPrimal}$`, with the loss +function in the formulation given by the squared loss: `\[ L(\wv;\x,y) := \frac{1}{2} (\wv^T \x - y)^2. \]` -Depending on the regularization type, we call the method -[*ordinary least squares*](http://en.wikipedia.org/wiki/Ordinary_least_squares) or simply -[*linear least squares*](http://en.wikipedia.org/wiki/Linear_least_squares_(mathematics)) if there -is no regularization, [*ridge regression*](http://en.wikipedia.org/wiki/Ridge_regression) if L2 -regularization is used, and [*Lasso*](http://en.wikipedia.org/wiki/Lasso_(statistics)) if L1 -regularization is used. This average loss $\frac{1}{n} \sum_{i=1}^n (\wv^T x_i - y_i)^2$ is also +Various related regression methods are derived by using different types of regularization: +[*ordinary least squares*](http://en.wikipedia.org/wiki/Ordinary_least_squares) or +[*linear least squares*](http://en.wikipedia.org/wiki/Linear_least_squares_(mathematics)) uses + no regularization; [*ridge regression*](http://en.wikipedia.org/wiki/Ridge_regression) uses L2 +regularization; and [*Lasso*](http://en.wikipedia.org/wiki/Lasso_(statistics)) uses L1 +regularization. For all of these models, the average loss or training error, $\frac{1}{n} \sum_{i=1}^n (\wv^T x_i - y_i)^2$, is known as the [mean squared error](http://en.wikipedia.org/wiki/Mean_squared_error). -Note that the squared loss is sensitive to outliers. -Regularization or a robust alternative (e.g., $\ell_1$ regression) is usually necessary in practice. - ### Examples
@@ -379,7 +383,7 @@ Regularization or a robust alternative (e.g., $\ell_1$ regression) is usually ne
The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint. The example then uses LinearRegressionWithSGD to build a simple linear model to predict label -values. We compute the Mean Squared Error at the end to evaluate +values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). {% highlight scala %} @@ -407,9 +411,8 @@ val MSE = valuesAndPreds.map{case(v, p) => math.pow((v - p), 2)}.mean() println("training Mean Squared Error = " + MSE) {% endhighlight %} -Similarly you can use [`RidgeRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.RidgeRegressionWithSGD) -and [`LassoWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.LassoWithSGD). +and [`LassoWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.LassoWithSGD) can be used in a similar fashion as `LinearRegressionWithSGD`.
@@ -479,16 +482,17 @@ public class LinearRegression { } {% endhighlight %} -In order to run the above standalone application using Spark framework make -sure that you follow the instructions provided at section [Standalone -Applications](quick-start.html) of the quick-start guide. What is more, you -should include to your build file *spark-mllib* as a dependency. +In order to run the above standalone application, follow the instructions +provided in the [Standalone +Applications](quick-start.html#standalone-applications) section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency.
The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint. The example then uses LinearRegressionWithSGD to build a simple linear model to predict label -values. We compute the Mean Squared Error at the end to evaluate +values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). {% highlight python %} diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index b1650c83c98b9..86d94aebd9442 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -4,23 +4,23 @@ title: Naive Bayes - MLlib displayTitle: MLlib - Naive Bayes --- -Naive Bayes is a simple multiclass classification algorithm with the assumption of independence -between every pair of features. Naive Bayes can be trained very efficiently. Within a single pass to -the training data, it computes the conditional probability distribution of each feature given label, -and then it applies Bayes' theorem to compute the conditional probability distribution of label -given an observation and use it for prediction. For more details, please visit the Wikipedia page -[Naive Bayes classifier](http://en.wikipedia.org/wiki/Naive_Bayes_classifier). - -In MLlib, we implemented multinomial naive Bayes, which is typically used for document -classification. Within that context, each observation is a document, each feature represents a term, -whose value is the frequency of the term. For its formulation, please visit the Wikipedia page -[Multinomial Naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes) -or the section -[Naive Bayes text classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html) -from the book Introduction to Information -Retrieval. [Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by +[Naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier) is a simple +multiclass classification algorithm with the assumption of independence between +every pair of features. Naive Bayes can be trained very efficiently. Within a +single pass to the training data, it computes the conditional probability +distribution of each feature given label, and then it applies Bayes' theorem to +compute the conditional probability distribution of label given an observation +and use it for prediction. + +MLlib supports [multinomial naive +Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes), +which is typically used for [document +classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). +Within that context, each observation is a document and each +feature represents a term whose value is the frequency of the term. +[Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by setting the parameter $\lambda$ (default to $1.0$). For document classification, the input feature -vectors are usually sparse. Please supply sparse vectors as input to take advantage of +vectors are usually sparse, and sparse vectors should be supplied as input to take advantage of sparsity. Since the training data is only used once, it is not necessary to cache it. ## Examples diff --git a/docs/mllib-stats.md b/docs/mllib-stats.md new file mode 100644 index 0000000000000..ca9ef46c15186 --- /dev/null +++ b/docs/mllib-stats.md @@ -0,0 +1,95 @@ +--- +layout: global +title: Statistics Functionality - MLlib +displayTitle: MLlib - Statistics Functionality +--- + +* Table of contents +{:toc} + + +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + +## Data Generators + +## Stratified Sampling + +## Summary Statistics + +### Multivariate summary statistics + +We provide column summary statistics for `RowMatrix` (note: this functionality is not currently supported in `IndexedRowMatrix` or `CoordinateMatrix`). +If the number of columns is not large, e.g., on the order of thousands, then the +covariance matrix can also be computed as a local matrix, which requires $\mathcal{O}(n^2)$ storage where $n$ is the +number of columns. The total CPU time is $\mathcal{O}(m n^2)$, where $m$ is the number of rows, +and is faster if the rows are sparse. + +
+
+ +[`computeColumnSummaryStatistics()`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.RowMatrix) returns an instance of +[`MultivariateStatisticalSummary`](api/scala/index.html#org.apache.spark.mllib.stat.MultivariateStatisticalSummary), +which contains the column-wise max, min, mean, variance, and number of nonzeros, as well as the +total count. + +{% highlight scala %} +import org.apache.spark.mllib.linalg.Matrix +import org.apache.spark.mllib.linalg.distributed.RowMatrix +import org.apache.spark.mllib.stat.MultivariateStatisticalSummary + +val mat: RowMatrix = ... // a RowMatrix + +// Compute column summary statistics. +val summary: MultivariateStatisticalSummary = mat.computeColumnSummaryStatistics() +println(summary.mean) // a dense vector containing the mean value for each column +println(summary.variance) // column-wise variance +println(summary.numNonzeros) // number of nonzeros in each column + +// Compute the covariance matrix. +val cov: Matrix = mat.computeCovariance() +{% endhighlight %} +
+ +
+ +[`RowMatrix#computeColumnSummaryStatistics`](api/java/org/apache/spark/mllib/linalg/distributed/RowMatrix.html#computeColumnSummaryStatistics()) returns an instance of +[`MultivariateStatisticalSummary`](api/java/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.html), +which contains the column-wise max, min, mean, variance, and number of nonzeros, as well as the +total count. + +{% highlight java %} +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.linalg.distributed.RowMatrix; +import org.apache.spark.mllib.stat.MultivariateStatisticalSummary; + +RowMatrix mat = ... // a RowMatrix + +// Compute column summary statistics. +MultivariateStatisticalSummary summary = mat.computeColumnSummaryStatistics(); +System.out.println(summary.mean()); // a dense vector containing the mean value for each column +System.out.println(summary.variance()); // column-wise variance +System.out.println(summary.numNonzeros()); // number of nonzeros in each column + +// Compute the covariance matrix. +Matrix cov = mat.computeCovariance(); +{% endhighlight %} +
+
+ + +## Hypothesis Testing diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py new file mode 100644 index 0000000000000..e902ae29753c0 --- /dev/null +++ b/examples/src/main/python/avro_inputformat.py @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys + +from pyspark import SparkContext + +""" +Read data file users.avro in local Spark distro: + +$ cd $SPARK_HOME +$ ./bin/spark-submit --driver-class-path /path/to/example/jar ./examples/src/main/python/avro_inputformat.py \ +> examples/src/main/resources/users.avro +{u'favorite_color': None, u'name': u'Alyssa', u'favorite_numbers': [3, 9, 15, 20]} +{u'favorite_color': u'red', u'name': u'Ben', u'favorite_numbers': []} + +To read name and favorite_color fields only, specify the following reader schema: + +$ cat examples/src/main/resources/user.avsc +{"namespace": "example.avro", + "type": "record", + "name": "User", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "favorite_color", "type": ["string", "null"]} + ] +} + +$ ./bin/spark-submit --driver-class-path /path/to/example/jar ./examples/src/main/python/avro_inputformat.py \ +> examples/src/main/resources/users.avro examples/src/main/resources/user.avsc +{u'favorite_color': None, u'name': u'Alyssa'} +{u'favorite_color': u'red', u'name': u'Ben'} +""" +if __name__ == "__main__": + if len(sys.argv) != 2 and len(sys.argv) != 3: + print >> sys.stderr, """ + Usage: avro_inputformat [reader_schema_file] + + Run with example jar: + ./bin/spark-submit --driver-class-path /path/to/example/jar /path/to/examples/avro_inputformat.py [reader_schema_file] + Assumes you have Avro data stored in . Reader schema can be optionally specified in [reader_schema_file]. + """ + exit(-1) + + path = sys.argv[1] + sc = SparkContext(appName="AvroKeyInputFormat") + + conf = None + if len(sys.argv) == 3: + schema_rdd = sc.textFile(sys.argv[2], 1).collect() + conf = {"avro.schema.input.key" : reduce(lambda x, y: x+y, schema_rdd)} + + avro_rdd = sc.newAPIHadoopFile(path, + "org.apache.avro.mapreduce.AvroKeyInputFormat", + "org.apache.avro.mapred.AvroKey", + "org.apache.hadoop.io.NullWritable", + keyConverter="org.apache.spark.examples.pythonconverters.AvroWrapperToJavaConverter", + conf=conf) + output = avro_rdd.map(lambda x: x[0]).collect() + for k in output: + print k diff --git a/examples/src/main/resources/user.avsc b/examples/src/main/resources/user.avsc new file mode 100644 index 0000000000000..4995357ab3736 --- /dev/null +++ b/examples/src/main/resources/user.avsc @@ -0,0 +1,8 @@ +{"namespace": "example.avro", + "type": "record", + "name": "User", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "favorite_color", "type": ["string", "null"]} + ] +} diff --git a/examples/src/main/resources/users.avro b/examples/src/main/resources/users.avro new file mode 100644 index 0000000000000..27c526ab114b2 Binary files /dev/null and b/examples/src/main/resources/users.avro differ diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala new file mode 100644 index 0000000000000..1b25983a38453 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.pythonconverters + +import java.util.{Collection => JCollection, Map => JMap} + +import scala.collection.JavaConversions._ + +import org.apache.avro.generic.{GenericFixed, IndexedRecord} +import org.apache.avro.mapred.AvroWrapper +import org.apache.avro.Schema +import org.apache.avro.Schema.Type._ + +import org.apache.spark.api.python.Converter +import org.apache.spark.SparkException + + +/** + * Implementation of [[org.apache.spark.api.python.Converter]] that converts + * an Avro Record wrapped in an AvroKey (or AvroValue) to a Java Map. It tries + * to work with all 3 Avro data mappings (Generic, Specific and Reflect). + */ +class AvroWrapperToJavaConverter extends Converter[Any, Any] { + override def convert(obj: Any): Any = { + if (obj == null) { + return null + } + obj.asInstanceOf[AvroWrapper[_]].datum() match { + case null => null + case record: IndexedRecord => unpackRecord(record) + case other => throw new SparkException( + s"Unsupported top-level Avro data type ${other.getClass.getName}") + } + } + + def unpackRecord(obj: Any): JMap[String, Any] = { + val map = new java.util.HashMap[String, Any] + obj match { + case record: IndexedRecord => + record.getSchema.getFields.zipWithIndex.foreach { case (f, i) => + map.put(f.name, fromAvro(record.get(i), f.schema)) + } + case other => throw new SparkException( + s"Unsupported RECORD type ${other.getClass.getName}") + } + map + } + + def unpackMap(obj: Any, schema: Schema): JMap[String, Any] = { + obj.asInstanceOf[JMap[_, _]].map { case (key, value) => + (key.toString, fromAvro(value, schema.getValueType)) + } + } + + def unpackFixed(obj: Any, schema: Schema): Array[Byte] = { + unpackBytes(obj.asInstanceOf[GenericFixed].bytes()) + } + + def unpackBytes(obj: Any): Array[Byte] = { + val bytes: Array[Byte] = obj match { + case buf: java.nio.ByteBuffer => buf.array() + case arr: Array[Byte] => arr + case other => throw new SparkException( + s"Unknown BYTES type ${other.getClass.getName}") + } + val bytearray = new Array[Byte](bytes.length) + System.arraycopy(bytes, 0, bytearray, 0, bytes.length) + bytearray + } + + def unpackArray(obj: Any, schema: Schema): JCollection[Any] = obj match { + case c: JCollection[_] => + c.map(fromAvro(_, schema.getElementType)) + case arr: Array[_] if arr.getClass.getComponentType.isPrimitive => + arr.toSeq + case arr: Array[_] => + arr.map(fromAvro(_, schema.getElementType)).toSeq + case other => throw new SparkException( + s"Unknown ARRAY type ${other.getClass.getName}") + } + + def unpackUnion(obj: Any, schema: Schema): Any = { + schema.getTypes.toList match { + case List(s) => fromAvro(obj, s) + case List(n, s) if n.getType == NULL => fromAvro(obj, s) + case List(s, n) if n.getType == NULL => fromAvro(obj, s) + case _ => throw new SparkException( + "Unions may only consist of a concrete type and null") + } + } + + def fromAvro(obj: Any, schema: Schema): Any = { + if (obj == null) { + return null + } + schema.getType match { + case UNION => unpackUnion(obj, schema) + case ARRAY => unpackArray(obj, schema) + case FIXED => unpackFixed(obj, schema) + case MAP => unpackMap(obj, schema) + case BYTES => unpackBytes(obj) + case RECORD => unpackRecord(obj) + case STRING => obj.toString + case ENUM => obj.toString + case NULL => obj + case BOOLEAN => obj + case DOUBLE => obj + case FLOAT => obj + case INT => obj + case LONG => obj + case other => throw new SparkException( + s"Unknown Avro schema type ${other.getName}") + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index ba7ccd8ce4b8b..18dc087856785 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -34,7 +34,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.DecisionTree import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model.DecisionTreeModel -import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD @@ -48,182 +48,7 @@ import org.apache.spark.util.Utils */ @DeveloperApi class PythonMLLibAPI extends Serializable { - private val DENSE_VECTOR_MAGIC: Byte = 1 - private val SPARSE_VECTOR_MAGIC: Byte = 2 - private val DENSE_MATRIX_MAGIC: Byte = 3 - private val LABELED_POINT_MAGIC: Byte = 4 - - private[python] def deserializeDoubleVector(bytes: Array[Byte], offset: Int = 0): Vector = { - require(bytes.length - offset >= 5, "Byte array too short") - val magic = bytes(offset) - if (magic == DENSE_VECTOR_MAGIC) { - deserializeDenseVector(bytes, offset) - } else if (magic == SPARSE_VECTOR_MAGIC) { - deserializeSparseVector(bytes, offset) - } else { - throw new IllegalArgumentException("Magic " + magic + " is wrong.") - } - } - - private[python] def deserializeDouble(bytes: Array[Byte], offset: Int = 0): Double = { - require(bytes.length - offset == 8, "Wrong size byte array for Double") - val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) - bb.order(ByteOrder.nativeOrder()) - bb.getDouble - } - private def deserializeDenseVector(bytes: Array[Byte], offset: Int = 0): Vector = { - val packetLength = bytes.length - offset - require(packetLength >= 5, "Byte array too short") - val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) - bb.order(ByteOrder.nativeOrder()) - val magic = bb.get() - require(magic == DENSE_VECTOR_MAGIC, "Invalid magic: " + magic) - val length = bb.getInt() - require (packetLength == 5 + 8 * length, "Invalid packet length: " + packetLength) - val db = bb.asDoubleBuffer() - val ans = new Array[Double](length.toInt) - db.get(ans) - Vectors.dense(ans) - } - - private def deserializeSparseVector(bytes: Array[Byte], offset: Int = 0): Vector = { - val packetLength = bytes.length - offset - require(packetLength >= 9, "Byte array too short") - val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) - bb.order(ByteOrder.nativeOrder()) - val magic = bb.get() - require(magic == SPARSE_VECTOR_MAGIC, "Invalid magic: " + magic) - val size = bb.getInt() - val nonZeros = bb.getInt() - require (packetLength == 9 + 12 * nonZeros, "Invalid packet length: " + packetLength) - val ib = bb.asIntBuffer() - val indices = new Array[Int](nonZeros) - ib.get(indices) - bb.position(bb.position() + 4 * nonZeros) - val db = bb.asDoubleBuffer() - val values = new Array[Double](nonZeros) - db.get(values) - Vectors.sparse(size, indices, values) - } - - /** - * Returns an 8-byte array for the input Double. - * - * Note: we currently do not use a magic byte for double for storage efficiency. - * This should be reconsidered when we add Ser/De for other 8-byte types (e.g. Long), for safety. - * The corresponding deserializer, deserializeDouble, needs to be modified as well if the - * serialization scheme changes. - */ - private[python] def serializeDouble(double: Double): Array[Byte] = { - val bytes = new Array[Byte](8) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.putDouble(double) - bytes - } - - private def serializeDenseVector(doubles: Array[Double]): Array[Byte] = { - val len = doubles.length - val bytes = new Array[Byte](5 + 8 * len) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.put(DENSE_VECTOR_MAGIC) - bb.putInt(len) - val db = bb.asDoubleBuffer() - db.put(doubles) - bytes - } - - private def serializeSparseVector(vector: SparseVector): Array[Byte] = { - val nonZeros = vector.indices.length - val bytes = new Array[Byte](9 + 12 * nonZeros) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.put(SPARSE_VECTOR_MAGIC) - bb.putInt(vector.size) - bb.putInt(nonZeros) - val ib = bb.asIntBuffer() - ib.put(vector.indices) - bb.position(bb.position() + 4 * nonZeros) - val db = bb.asDoubleBuffer() - db.put(vector.values) - bytes - } - - private[python] def serializeDoubleVector(vector: Vector): Array[Byte] = vector match { - case s: SparseVector => - serializeSparseVector(s) - case _ => - serializeDenseVector(vector.toArray) - } - - private def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = { - val packetLength = bytes.length - if (packetLength < 9) { - throw new IllegalArgumentException("Byte array too short.") - } - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - val magic = bb.get() - if (magic != DENSE_MATRIX_MAGIC) { - throw new IllegalArgumentException("Magic " + magic + " is wrong.") - } - val rows = bb.getInt() - val cols = bb.getInt() - if (packetLength != 9 + 8 * rows * cols) { - throw new IllegalArgumentException("Size " + rows + "x" + cols + " is wrong.") - } - val db = bb.asDoubleBuffer() - val ans = new Array[Array[Double]](rows.toInt) - for (i <- 0 until rows.toInt) { - ans(i) = new Array[Double](cols.toInt) - db.get(ans(i)) - } - ans - } - - private def serializeDoubleMatrix(doubles: Array[Array[Double]]): Array[Byte] = { - val rows = doubles.length - var cols = 0 - if (rows > 0) { - cols = doubles(0).length - } - val bytes = new Array[Byte](9 + 8 * rows * cols) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.put(DENSE_MATRIX_MAGIC) - bb.putInt(rows) - bb.putInt(cols) - val db = bb.asDoubleBuffer() - for (i <- 0 until rows) { - db.put(doubles(i)) - } - bytes - } - - private[python] def serializeLabeledPoint(p: LabeledPoint): Array[Byte] = { - val fb = serializeDoubleVector(p.features) - val bytes = new Array[Byte](1 + 8 + fb.length) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.put(LABELED_POINT_MAGIC) - bb.putDouble(p.label) - bb.put(fb) - bytes - } - - private[python] def deserializeLabeledPoint(bytes: Array[Byte]): LabeledPoint = { - require(bytes.length >= 9, "Byte array too short") - val magic = bytes(0) - if (magic != LABELED_POINT_MAGIC) { - throw new IllegalArgumentException("Magic " + magic + " is wrong.") - } - val labelBytes = ByteBuffer.wrap(bytes, 1, 8) - labelBytes.order(ByteOrder.nativeOrder()) - val label = labelBytes.asDoubleBuffer().get(0) - LabeledPoint(label, deserializeDoubleVector(bytes, 9)) - } /** * Loads and serializes labeled points saved with `RDD#saveAsTextFile`. @@ -236,17 +61,17 @@ class PythonMLLibAPI extends Serializable { jsc: JavaSparkContext, path: String, minPartitions: Int): JavaRDD[Array[Byte]] = - MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions).map(serializeLabeledPoint) + MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions).map(SerDe.serializeLabeledPoint) private def trainRegressionModel( trainFunc: (RDD[LabeledPoint], Vector) => GeneralizedLinearModel, dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = { - val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint) - val initialWeights = deserializeDoubleVector(initialWeightsBA) + val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint) + val initialWeights = SerDe.deserializeDoubleVector(initialWeightsBA) val model = trainFunc(data, initialWeights) val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(serializeDoubleVector(model.weights)) + ret.add(SerDe.serializeDoubleVector(model.weights)) ret.add(model.intercept: java.lang.Double) ret } @@ -405,12 +230,12 @@ class PythonMLLibAPI extends Serializable { def trainNaiveBayes( dataBytesJRDD: JavaRDD[Array[Byte]], lambda: Double): java.util.List[java.lang.Object] = { - val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint) + val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint) val model = NaiveBayes.train(data, lambda) val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(serializeDoubleVector(Vectors.dense(model.labels))) - ret.add(serializeDoubleVector(Vectors.dense(model.pi))) - ret.add(serializeDoubleMatrix(model.theta)) + ret.add(SerDe.serializeDoubleVector(Vectors.dense(model.labels))) + ret.add(SerDe.serializeDoubleVector(Vectors.dense(model.pi))) + ret.add(SerDe.serializeDoubleMatrix(model.theta)) ret } @@ -423,52 +248,13 @@ class PythonMLLibAPI extends Serializable { maxIterations: Int, runs: Int, initializationMode: String): java.util.List[java.lang.Object] = { - val data = dataBytesJRDD.rdd.map(bytes => deserializeDoubleVector(bytes)) + val data = dataBytesJRDD.rdd.map(bytes => SerDe.deserializeDoubleVector(bytes)) val model = KMeans.train(data, k, maxIterations, runs, initializationMode) val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(serializeDoubleMatrix(model.clusterCenters.map(_.toArray))) + ret.add(SerDe.serializeDoubleMatrix(model.clusterCenters.map(_.toArray))) ret } - /** Unpack a Rating object from an array of bytes */ - private def unpackRating(ratingBytes: Array[Byte]): Rating = { - val bb = ByteBuffer.wrap(ratingBytes) - bb.order(ByteOrder.nativeOrder()) - val user = bb.getInt() - val product = bb.getInt() - val rating = bb.getDouble() - new Rating(user, product, rating) - } - - /** Unpack a tuple of Ints from an array of bytes */ - private[spark] def unpackTuple(tupleBytes: Array[Byte]): (Int, Int) = { - val bb = ByteBuffer.wrap(tupleBytes) - bb.order(ByteOrder.nativeOrder()) - val v1 = bb.getInt() - val v2 = bb.getInt() - (v1, v2) - } - - /** - * Serialize a Rating object into an array of bytes. - * It can be deserialized using RatingDeserializer(). - * - * @param rate the Rating object to serialize - * @return - */ - private[spark] def serializeRating(rate: Rating): Array[Byte] = { - val len = 3 - val bytes = new Array[Byte](4 + 8 * len) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.putInt(len) - val db = bb.asDoubleBuffer() - db.put(rate.user.toDouble) - db.put(rate.product.toDouble) - db.put(rate.rating) - bytes - } - /** * Java stub for Python mllib ALS.train(). This stub returns a handle * to the Java object instead of the content of the Java object. Extra care @@ -481,7 +267,7 @@ class PythonMLLibAPI extends Serializable { iterations: Int, lambda: Double, blocks: Int): MatrixFactorizationModel = { - val ratings = ratingsBytesJRDD.rdd.map(unpackRating) + val ratings = ratingsBytesJRDD.rdd.map(SerDe.unpackRating) ALS.train(ratings, rank, iterations, lambda, blocks) } @@ -498,7 +284,7 @@ class PythonMLLibAPI extends Serializable { lambda: Double, blocks: Int, alpha: Double): MatrixFactorizationModel = { - val ratings = ratingsBytesJRDD.rdd.map(unpackRating) + val ratings = ratingsBytesJRDD.rdd.map(SerDe.unpackRating) ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha) } @@ -519,7 +305,7 @@ class PythonMLLibAPI extends Serializable { maxDepth: Int, maxBins: Int): DecisionTreeModel = { - val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint) + val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint) val algo = Algo.fromString(algoStr) val impurity = Impurities.fromString(impurityStr) @@ -545,7 +331,7 @@ class PythonMLLibAPI extends Serializable { def predictDecisionTreeModel( model: DecisionTreeModel, featuresBytes: Array[Byte]): Double = { - val features: Vector = deserializeDoubleVector(featuresBytes) + val features: Vector = SerDe.deserializeDoubleVector(featuresBytes) model.predict(features) } @@ -559,8 +345,17 @@ class PythonMLLibAPI extends Serializable { def predictDecisionTreeModel( model: DecisionTreeModel, dataJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = { - val data = dataJRDD.rdd.map(xBytes => deserializeDoubleVector(xBytes)) - model.predict(data).map(serializeDouble) + val data = dataJRDD.rdd.map(xBytes => SerDe.deserializeDoubleVector(xBytes)) + model.predict(data).map(SerDe.serializeDouble) + } + + /** + * Java stub for mllib Statistics.colStats(X: RDD[Vector]). + * TODO figure out return type. + */ + def colStats(X: JavaRDD[Array[Byte]]): MultivariateStatisticalSummarySerialized = { + val cStats = Statistics.colStats(X.rdd.map(SerDe.deserializeDoubleVector(_))) + new MultivariateStatisticalSummarySerialized(cStats) } /** @@ -569,17 +364,17 @@ class PythonMLLibAPI extends Serializable { * pyspark. */ def corr(X: JavaRDD[Array[Byte]], method: String): Array[Byte] = { - val inputMatrix = X.rdd.map(deserializeDoubleVector(_)) + val inputMatrix = X.rdd.map(SerDe.deserializeDoubleVector(_)) val result = Statistics.corr(inputMatrix, getCorrNameOrDefault(method)) - serializeDoubleMatrix(to2dArray(result)) + SerDe.serializeDoubleMatrix(SerDe.to2dArray(result)) } /** * Java stub for mllib Statistics.corr(x: RDD[Double], y: RDD[Double], method: String). */ def corr(x: JavaRDD[Array[Byte]], y: JavaRDD[Array[Byte]], method: String): Double = { - val xDeser = x.rdd.map(deserializeDouble(_)) - val yDeser = y.rdd.map(deserializeDouble(_)) + val xDeser = x.rdd.map(SerDe.deserializeDouble(_)) + val yDeser = y.rdd.map(SerDe.deserializeDouble(_)) Statistics.corr(xDeser, yDeser, getCorrNameOrDefault(method)) } @@ -588,12 +383,6 @@ class PythonMLLibAPI extends Serializable { if (method == null) CorrelationNames.defaultCorrName else method } - // Reformat a Matrix into Array[Array[Double]] for serialization - private[python] def to2dArray(matrix: Matrix): Array[Array[Double]] = { - val values = matrix.toArray - Array.tabulate(matrix.numRows, matrix.numCols)((i, j) => values(i + j * matrix.numRows)) - } - // Used by the *RDD methods to get default seed if not passed in from pyspark private def getSeedOrDefault(seed: java.lang.Long): Long = { if (seed == null) Utils.random.nextLong else seed @@ -621,7 +410,7 @@ class PythonMLLibAPI extends Serializable { seed: java.lang.Long): JavaRDD[Array[Byte]] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.uniformRDD(jsc.sc, size, parts, s).map(serializeDouble) + RG.uniformRDD(jsc.sc, size, parts, s).map(SerDe.serializeDouble) } /** @@ -633,7 +422,7 @@ class PythonMLLibAPI extends Serializable { seed: java.lang.Long): JavaRDD[Array[Byte]] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.normalRDD(jsc.sc, size, parts, s).map(serializeDouble) + RG.normalRDD(jsc.sc, size, parts, s).map(SerDe.serializeDouble) } /** @@ -646,7 +435,7 @@ class PythonMLLibAPI extends Serializable { seed: java.lang.Long): JavaRDD[Array[Byte]] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.poissonRDD(jsc.sc, mean, size, parts, s).map(serializeDouble) + RG.poissonRDD(jsc.sc, mean, size, parts, s).map(SerDe.serializeDouble) } /** @@ -659,7 +448,7 @@ class PythonMLLibAPI extends Serializable { seed: java.lang.Long): JavaRDD[Array[Byte]] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.uniformVectorRDD(jsc.sc, numRows, numCols, parts, s).map(serializeDoubleVector) + RG.uniformVectorRDD(jsc.sc, numRows, numCols, parts, s).map(SerDe.serializeDoubleVector) } /** @@ -672,7 +461,7 @@ class PythonMLLibAPI extends Serializable { seed: java.lang.Long): JavaRDD[Array[Byte]] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.normalVectorRDD(jsc.sc, numRows, numCols, parts, s).map(serializeDoubleVector) + RG.normalVectorRDD(jsc.sc, numRows, numCols, parts, s).map(SerDe.serializeDoubleVector) } /** @@ -686,7 +475,256 @@ class PythonMLLibAPI extends Serializable { seed: java.lang.Long): JavaRDD[Array[Byte]] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts, s).map(serializeDoubleVector) + RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts, s).map(SerDe.serializeDoubleVector) + } + +} + +/** + * :: DeveloperApi :: + * MultivariateStatisticalSummary with Vector fields serialized. + */ +@DeveloperApi +class MultivariateStatisticalSummarySerialized(val summary: MultivariateStatisticalSummary) + extends Serializable { + + def mean: Array[Byte] = SerDe.serializeDoubleVector(summary.mean) + + def variance: Array[Byte] = SerDe.serializeDoubleVector(summary.variance) + + def count: Long = summary.count + + def numNonzeros: Array[Byte] = SerDe.serializeDoubleVector(summary.numNonzeros) + + def max: Array[Byte] = SerDe.serializeDoubleVector(summary.max) + + def min: Array[Byte] = SerDe.serializeDoubleVector(summary.min) +} + +/** + * SerDe utility functions for PythonMLLibAPI. + */ +private[spark] object SerDe extends Serializable { + private val DENSE_VECTOR_MAGIC: Byte = 1 + private val SPARSE_VECTOR_MAGIC: Byte = 2 + private val DENSE_MATRIX_MAGIC: Byte = 3 + private val LABELED_POINT_MAGIC: Byte = 4 + + private[python] def deserializeDoubleVector(bytes: Array[Byte], offset: Int = 0): Vector = { + require(bytes.length - offset >= 5, "Byte array too short") + val magic = bytes(offset) + if (magic == DENSE_VECTOR_MAGIC) { + deserializeDenseVector(bytes, offset) + } else if (magic == SPARSE_VECTOR_MAGIC) { + deserializeSparseVector(bytes, offset) + } else { + throw new IllegalArgumentException("Magic " + magic + " is wrong.") + } } + private[python] def deserializeDouble(bytes: Array[Byte], offset: Int = 0): Double = { + require(bytes.length - offset == 8, "Wrong size byte array for Double") + val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) + bb.order(ByteOrder.nativeOrder()) + bb.getDouble + } + + private[python] def deserializeDenseVector(bytes: Array[Byte], offset: Int = 0): Vector = { + val packetLength = bytes.length - offset + require(packetLength >= 5, "Byte array too short") + val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) + bb.order(ByteOrder.nativeOrder()) + val magic = bb.get() + require(magic == DENSE_VECTOR_MAGIC, "Invalid magic: " + magic) + val length = bb.getInt() + require (packetLength == 5 + 8 * length, "Invalid packet length: " + packetLength) + val db = bb.asDoubleBuffer() + val ans = new Array[Double](length.toInt) + db.get(ans) + Vectors.dense(ans) + } + + private[python] def deserializeSparseVector(bytes: Array[Byte], offset: Int = 0): Vector = { + val packetLength = bytes.length - offset + require(packetLength >= 9, "Byte array too short") + val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) + bb.order(ByteOrder.nativeOrder()) + val magic = bb.get() + require(magic == SPARSE_VECTOR_MAGIC, "Invalid magic: " + magic) + val size = bb.getInt() + val nonZeros = bb.getInt() + require (packetLength == 9 + 12 * nonZeros, "Invalid packet length: " + packetLength) + val ib = bb.asIntBuffer() + val indices = new Array[Int](nonZeros) + ib.get(indices) + bb.position(bb.position() + 4 * nonZeros) + val db = bb.asDoubleBuffer() + val values = new Array[Double](nonZeros) + db.get(values) + Vectors.sparse(size, indices, values) + } + + /** + * Returns an 8-byte array for the input Double. + * + * Note: we currently do not use a magic byte for double for storage efficiency. + * This should be reconsidered when we add Ser/De for other 8-byte types (e.g. Long), for safety. + * The corresponding deserializer, deserializeDouble, needs to be modified as well if the + * serialization scheme changes. + */ + private[python] def serializeDouble(double: Double): Array[Byte] = { + val bytes = new Array[Byte](8) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.putDouble(double) + bytes + } + + private[python] def serializeDenseVector(doubles: Array[Double]): Array[Byte] = { + val len = doubles.length + val bytes = new Array[Byte](5 + 8 * len) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.put(DENSE_VECTOR_MAGIC) + bb.putInt(len) + val db = bb.asDoubleBuffer() + db.put(doubles) + bytes + } + + private[python] def serializeSparseVector(vector: SparseVector): Array[Byte] = { + val nonZeros = vector.indices.length + val bytes = new Array[Byte](9 + 12 * nonZeros) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.put(SPARSE_VECTOR_MAGIC) + bb.putInt(vector.size) + bb.putInt(nonZeros) + val ib = bb.asIntBuffer() + ib.put(vector.indices) + bb.position(bb.position() + 4 * nonZeros) + val db = bb.asDoubleBuffer() + db.put(vector.values) + bytes + } + + private[python] def serializeDoubleVector(vector: Vector): Array[Byte] = vector match { + case s: SparseVector => + serializeSparseVector(s) + case _ => + serializeDenseVector(vector.toArray) + } + + private[python] def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = { + val packetLength = bytes.length + if (packetLength < 9) { + throw new IllegalArgumentException("Byte array too short.") + } + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + val magic = bb.get() + if (magic != DENSE_MATRIX_MAGIC) { + throw new IllegalArgumentException("Magic " + magic + " is wrong.") + } + val rows = bb.getInt() + val cols = bb.getInt() + if (packetLength != 9 + 8 * rows * cols) { + throw new IllegalArgumentException("Size " + rows + "x" + cols + " is wrong.") + } + val db = bb.asDoubleBuffer() + val ans = new Array[Array[Double]](rows.toInt) + for (i <- 0 until rows.toInt) { + ans(i) = new Array[Double](cols.toInt) + db.get(ans(i)) + } + ans + } + + private[python] def serializeDoubleMatrix(doubles: Array[Array[Double]]): Array[Byte] = { + val rows = doubles.length + var cols = 0 + if (rows > 0) { + cols = doubles(0).length + } + val bytes = new Array[Byte](9 + 8 * rows * cols) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.put(DENSE_MATRIX_MAGIC) + bb.putInt(rows) + bb.putInt(cols) + val db = bb.asDoubleBuffer() + for (i <- 0 until rows) { + db.put(doubles(i)) + } + bytes + } + + private[python] def serializeLabeledPoint(p: LabeledPoint): Array[Byte] = { + val fb = serializeDoubleVector(p.features) + val bytes = new Array[Byte](1 + 8 + fb.length) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.put(LABELED_POINT_MAGIC) + bb.putDouble(p.label) + bb.put(fb) + bytes + } + + private[python] def deserializeLabeledPoint(bytes: Array[Byte]): LabeledPoint = { + require(bytes.length >= 9, "Byte array too short") + val magic = bytes(0) + if (magic != LABELED_POINT_MAGIC) { + throw new IllegalArgumentException("Magic " + magic + " is wrong.") + } + val labelBytes = ByteBuffer.wrap(bytes, 1, 8) + labelBytes.order(ByteOrder.nativeOrder()) + val label = labelBytes.asDoubleBuffer().get(0) + LabeledPoint(label, deserializeDoubleVector(bytes, 9)) + } + + // Reformat a Matrix into Array[Array[Double]] for serialization + private[python] def to2dArray(matrix: Matrix): Array[Array[Double]] = { + val values = matrix.toArray + Array.tabulate(matrix.numRows, matrix.numCols)((i, j) => values(i + j * matrix.numRows)) + } + + + /** Unpack a Rating object from an array of bytes */ + private[python] def unpackRating(ratingBytes: Array[Byte]): Rating = { + val bb = ByteBuffer.wrap(ratingBytes) + bb.order(ByteOrder.nativeOrder()) + val user = bb.getInt() + val product = bb.getInt() + val rating = bb.getDouble() + new Rating(user, product, rating) + } + + /** Unpack a tuple of Ints from an array of bytes */ + def unpackTuple(tupleBytes: Array[Byte]): (Int, Int) = { + val bb = ByteBuffer.wrap(tupleBytes) + bb.order(ByteOrder.nativeOrder()) + val v1 = bb.getInt() + val v2 = bb.getInt() + (v1, v2) + } + + /** + * Serialize a Rating object into an array of bytes. + * It can be deserialized using RatingDeserializer(). + * + * @param rate the Rating object to serialize + * @return + */ + def serializeRating(rate: Rating): Array[Byte] = { + val len = 3 + val bytes = new Array[Byte](4 + 8 * len) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.putInt(len) + val db = bb.asDoubleBuffer() + db.put(rate.user.toDouble) + db.put(rate.product.toDouble) + db.put(rate.rating) + bytes + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 2242329b7918e..6790c86f651b4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -62,7 +62,7 @@ class LogisticRegressionModel ( override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector, intercept: Double) = { val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept - val score = 1.0/ (1.0 + math.exp(-margin)) + val score = 1.0 / (1.0 + math.exp(-margin)) threshold match { case Some(t) => if (score < t) 0.0 else 1.0 case None => score @@ -101,7 +101,7 @@ class LogisticRegressionWithSGD private ( } /** - * Top-level methods for calling Logistic Regression. + * Top-level methods for calling Logistic Regression using Stochastic Gradient Descent. * NOTE: Labels used in Logistic Regression should be {0, 1} */ object LogisticRegressionWithSGD { @@ -188,3 +188,54 @@ object LogisticRegressionWithSGD { train(input, numIterations, 1.0, 1.0) } } + +/** + * Train a classification model for Logistic Regression using Limited-memory BFGS. + * NOTE: Labels used in Logistic Regression should be {0, 1} + */ +class LogisticRegressionWithLBFGS private ( + private var convergenceTol: Double, + private var maxNumIterations: Int, + private var regParam: Double) + extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable { + + /** + * Construct a LogisticRegression object with default parameters + */ + def this() = this(1E-4, 100, 0.0) + + this.setFeatureScaling(true) + + private val gradient = new LogisticGradient() + private val updater = new SimpleUpdater() + // Have to return new LBFGS object every time since users can reset the parameters anytime. + override def optimizer = new LBFGS(gradient, updater) + .setNumCorrections(10) + .setConvergenceTol(convergenceTol) + .setMaxNumIterations(maxNumIterations) + .setRegParam(regParam) + + override protected val validators = List(DataValidators.binaryLabelValidator) + + /** + * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4. + * Smaller value will lead to higher accuracy with the cost of more iterations. + */ + def setConvergenceTol(convergenceTol: Double): this.type = { + this.convergenceTol = convergenceTol + this + } + + /** + * Set the maximal number of iterations for L-BFGS. Default 100. + */ + def setNumIterations(numIterations: Int): this.type = { + this.maxNumIterations = numIterations + this + } + + override protected def createModel(weights: Vector, intercept: Double) = { + new LogisticRegressionModel(weights, intercept) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 395037e1ec47c..ecd49ea2ff533 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -119,7 +119,6 @@ class Word2Vec extends Serializable with Logging { private val MAX_EXP = 6 private val MAX_CODE_LENGTH = 40 private val MAX_SENTENCE_LENGTH = 1000 - private val layer1Size = vectorSize /** context words from [-window, window] */ private val window = 5 @@ -131,7 +130,6 @@ class Word2Vec extends Serializable with Logging { private var vocabSize = 0 private var vocab: Array[VocabWord] = null private var vocabHash = mutable.HashMap.empty[String, Int] - private var alpha = startingAlpha private def learnVocab(words: RDD[String]): Unit = { vocab = words.map(w => (w, 1)) @@ -287,9 +285,10 @@ class Word2Vec extends Serializable with Logging { val newSentences = sentences.repartition(numPartitions).cache() val initRandom = new XORShiftRandom(seed) var syn0Global = - Array.fill[Float](vocabSize * layer1Size)((initRandom.nextFloat() - 0.5f) / layer1Size) - var syn1Global = new Array[Float](vocabSize * layer1Size) + Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) + var syn1Global = new Array[Float](vocabSize * vectorSize) + var alpha = startingAlpha for (k <- 1 to numIterations) { val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) @@ -317,24 +316,24 @@ class Word2Vec extends Serializable with Logging { val c = pos - window + a if (c >= 0 && c < sentence.size) { val lastWord = sentence(c) - val l1 = lastWord * layer1Size - val neu1e = new Array[Float](layer1Size) + val l1 = lastWord * vectorSize + val neu1e = new Array[Float](vectorSize) // Hierarchical softmax var d = 0 while (d < bcVocab.value(word).codeLen) { - val l2 = bcVocab.value(word).point(d) * layer1Size + val l2 = bcVocab.value(word).point(d) * vectorSize // Propagate hidden -> output - var f = blas.sdot(layer1Size, syn0, l1, 1, syn1, l2, 1) + var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1) if (f > -MAX_EXP && f < MAX_EXP) { val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt f = expTable.value(ind) val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat - blas.saxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1) - blas.saxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1) + blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1) + blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1) } d += 1 } - blas.saxpy(layer1Size, 1.0f, neu1e, 0, 1, syn0, l1, 1) + blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1) } } a += 1 @@ -365,8 +364,8 @@ class Word2Vec extends Serializable with Logging { var i = 0 while (i < vocabSize) { val word = bcVocab.value(i).word - val vector = new Array[Float](layer1Size) - Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size) + val vector = new Array[Float](vectorSize) + Array.copy(syn0Global, i * vectorSize, vector, 0, vectorSize) word2VecMap += word -> vector i += 1 } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala new file mode 100644 index 0000000000000..70e23033c8754 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg + +import com.github.fommil.netlib.{BLAS => NetlibBLAS, F2jBLAS} + +/** + * BLAS routines for MLlib's vectors and matrices. + */ +private[mllib] object BLAS extends Serializable { + + @transient private var _f2jBLAS: NetlibBLAS = _ + + // For level-1 routines, we use Java implementation. + private def f2jBLAS: NetlibBLAS = { + if (_f2jBLAS == null) { + _f2jBLAS = new F2jBLAS + } + _f2jBLAS + } + + /** + * y += a * x + */ + def axpy(a: Double, x: Vector, y: Vector): Unit = { + require(x.size == y.size) + y match { + case dy: DenseVector => + x match { + case sx: SparseVector => + axpy(a, sx, dy) + case dx: DenseVector => + axpy(a, dx, dy) + case _ => + throw new UnsupportedOperationException( + s"axpy doesn't support x type ${x.getClass}.") + } + case _ => + throw new IllegalArgumentException( + s"axpy only supports adding to a dense vector but got type ${y.getClass}.") + } + } + + /** + * y += a * x + */ + private def axpy(a: Double, x: DenseVector, y: DenseVector): Unit = { + val n = x.size + f2jBLAS.daxpy(n, a, x.values, 1, y.values, 1) + } + + /** + * y += a * x + */ + private def axpy(a: Double, x: SparseVector, y: DenseVector): Unit = { + val nnz = x.indices.size + if (a == 1.0) { + var k = 0 + while (k < nnz) { + y.values(x.indices(k)) += x.values(k) + k += 1 + } + } else { + var k = 0 + while (k < nnz) { + y.values(x.indices(k)) += a * x.values(k) + k += 1 + } + } + } + + /** + * dot(x, y) + */ + def dot(x: Vector, y: Vector): Double = { + require(x.size == y.size) + (x, y) match { + case (dx: DenseVector, dy: DenseVector) => + dot(dx, dy) + case (sx: SparseVector, dy: DenseVector) => + dot(sx, dy) + case (dx: DenseVector, sy: SparseVector) => + dot(sy, dx) + case (sx: SparseVector, sy: SparseVector) => + dot(sx, sy) + case _ => + throw new IllegalArgumentException(s"dot doesn't support (${x.getClass}, ${y.getClass}).") + } + } + + /** + * dot(x, y) + */ + private def dot(x: DenseVector, y: DenseVector): Double = { + val n = x.size + f2jBLAS.ddot(n, x.values, 1, y.values, 1) + } + + /** + * dot(x, y) + */ + private def dot(x: SparseVector, y: DenseVector): Double = { + val nnz = x.indices.size + var sum = 0.0 + var k = 0 + while (k < nnz) { + sum += x.values(k) * y.values(x.indices(k)) + k += 1 + } + sum + } + + /** + * dot(x, y) + */ + private def dot(x: SparseVector, y: SparseVector): Double = { + var kx = 0 + val nnzx = x.indices.size + var ky = 0 + val nnzy = y.indices.size + var sum = 0.0 + // y catching x + while (kx < nnzx && ky < nnzy) { + val ix = x.indices(kx) + while (ky < nnzy && y.indices(ky) < ix) { + ky += 1 + } + if (ky < nnzy && y.indices(ky) == ix) { + sum += x.values(kx) * y.values(ky) + ky += 1 + } + kx += 1 + } + sum + } + + /** + * y = x + */ + def copy(x: Vector, y: Vector): Unit = { + val n = y.size + require(x.size == n) + y match { + case dy: DenseVector => + x match { + case sx: SparseVector => + var i = 0 + var k = 0 + val nnz = sx.indices.size + while (k < nnz) { + val j = sx.indices(k) + while (i < j) { + dy.values(i) = 0.0 + i += 1 + } + dy.values(i) = sx.values(k) + i += 1 + k += 1 + } + while (i < n) { + dy.values(i) = 0.0 + i += 1 + } + case dx: DenseVector => + Array.copy(dx.values, 0, dy.values, 0, n) + } + case _ => + throw new IllegalArgumentException(s"y must be dense in copy but got ${y.getClass}") + } + } + + /** + * x = a * x + */ + def scal(a: Double, x: Vector): Unit = { + x match { + case sx: SparseVector => + f2jBLAS.dscal(sx.values.size, a, sx.values, 1) + case dx: DenseVector => + f2jBLAS.dscal(dx.values.size, a, dx.values, 1) + case _ => + throw new IllegalArgumentException(s"scal doesn't support vector type ${x.getClass}.") + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 77b3e8c714997..a45781d12e41e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.linalg import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable} -import java.util.Arrays +import java.util import scala.annotation.varargs import scala.collection.JavaConverters._ @@ -30,6 +30,8 @@ import org.apache.spark.SparkException /** * Represents a numeric vector, whose index type is Int and value type is Double. + * + * Note: Users should not implement this interface. */ trait Vector extends Serializable { @@ -46,12 +48,12 @@ trait Vector extends Serializable { override def equals(other: Any): Boolean = { other match { case v: Vector => - Arrays.equals(this.toArray, v.toArray) + util.Arrays.equals(this.toArray, v.toArray) case _ => false } } - override def hashCode(): Int = Arrays.hashCode(this.toArray) + override def hashCode(): Int = util.Arrays.hashCode(this.toArray) /** * Converts the instance to a breeze vector. @@ -63,6 +65,13 @@ trait Vector extends Serializable { * @param i index */ def apply(i: Int): Double = toBreeze(i) + + /** + * Makes a deep copy of this vector. + */ + def copy: Vector = { + throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.") + } } /** @@ -127,6 +136,16 @@ object Vectors { }.toSeq) } + /** + * Creates a dense vector of all zeros. + * + * @param size vector size + * @return a zero vector + */ + def zeros(size: Int): Vector = { + new DenseVector(new Array[Double](size)) + } + /** * Parses a string resulted from `Vector#toString` into * an [[org.apache.spark.mllib.linalg.Vector]]. @@ -142,7 +161,7 @@ object Vectors { case Seq(size: Double, indices: Array[Double], values: Array[Double]) => Vectors.sparse(size.toInt, indices.map(_.toInt), values) case other => - throw new SparkException(s"Cannot parse $other.") + throw new SparkException(s"Cannot parse $other.") } } @@ -183,6 +202,10 @@ class DenseVector(val values: Array[Double]) extends Vector { private[mllib] override def toBreeze: BV[Double] = new BDV[Double](values) override def apply(i: Int) = values(i) + + override def copy: DenseVector = { + new DenseVector(values.clone()) + } } /** @@ -213,5 +236,9 @@ class SparseVector( data } + override def copy: SparseVector = { + new SparseVector(size, indices.clone(), values.clone()) + } + private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala index 9d82f011e674a..fdd67160114ca 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -17,10 +17,9 @@ package org.apache.spark.mllib.optimization -import breeze.linalg.{axpy => brzAxpy} - import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal} /** * :: DeveloperApi :: @@ -61,11 +60,10 @@ abstract class Gradient extends Serializable { @DeveloperApi class LogisticGradient extends Gradient { override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { - val brzData = data.toBreeze - val brzWeights = weights.toBreeze - val margin: Double = -1.0 * brzWeights.dot(brzData) + val margin = -1.0 * dot(data, weights) val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label - val gradient = brzData * gradientMultiplier + val gradient = data.copy + scal(gradientMultiplier, gradient) val loss = if (label > 0) { math.log1p(math.exp(margin)) // log1p is log(1+p) but more accurate for small p @@ -73,7 +71,7 @@ class LogisticGradient extends Gradient { math.log1p(math.exp(margin)) - margin } - (Vectors.fromBreeze(gradient), loss) + (gradient, loss) } override def compute( @@ -81,13 +79,9 @@ class LogisticGradient extends Gradient { label: Double, weights: Vector, cumGradient: Vector): Double = { - val brzData = data.toBreeze - val brzWeights = weights.toBreeze - val margin: Double = -1.0 * brzWeights.dot(brzData) + val margin = -1.0 * dot(data, weights) val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label - - brzAxpy(gradientMultiplier, brzData, cumGradient.toBreeze) - + axpy(gradientMultiplier, data, cumGradient) if (label > 0) { math.log1p(math.exp(margin)) } else { @@ -106,13 +100,11 @@ class LogisticGradient extends Gradient { @DeveloperApi class LeastSquaresGradient extends Gradient { override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { - val brzData = data.toBreeze - val brzWeights = weights.toBreeze - val diff = brzWeights.dot(brzData) - label + val diff = dot(data, weights) - label val loss = diff * diff - val gradient = brzData * (2.0 * diff) - - (Vectors.fromBreeze(gradient), loss) + val gradient = data.copy + scal(2.0 * diff, gradient) + (gradient, loss) } override def compute( @@ -120,12 +112,8 @@ class LeastSquaresGradient extends Gradient { label: Double, weights: Vector, cumGradient: Vector): Double = { - val brzData = data.toBreeze - val brzWeights = weights.toBreeze - val diff = brzWeights.dot(brzData) - label - - brzAxpy(2.0 * diff, brzData, cumGradient.toBreeze) - + val diff = dot(data, weights) - label + axpy(2.0 * diff, data, cumGradient) diff * diff } } @@ -139,18 +127,16 @@ class LeastSquaresGradient extends Gradient { @DeveloperApi class HingeGradient extends Gradient { override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { - val brzData = data.toBreeze - val brzWeights = weights.toBreeze - val dotProduct = brzWeights.dot(brzData) - + val dotProduct = dot(data, weights) // Our loss function with {0, 1} labels is max(0, 1 - (2y – 1) (f_w(x))) // Therefore the gradient is -(2y - 1)*x val labelScaled = 2 * label - 1.0 - if (1.0 > labelScaled * dotProduct) { - (Vectors.fromBreeze(brzData * (-labelScaled)), 1.0 - labelScaled * dotProduct) + val gradient = data.copy + scal(-labelScaled, gradient) + (gradient, 1.0 - labelScaled * dotProduct) } else { - (Vectors.dense(new Array[Double](weights.size)), 0.0) + (Vectors.sparse(weights.size, Array.empty, Array.empty), 0.0) } } @@ -159,16 +145,12 @@ class HingeGradient extends Gradient { label: Double, weights: Vector, cumGradient: Vector): Double = { - val brzData = data.toBreeze - val brzWeights = weights.toBreeze - val dotProduct = brzWeights.dot(brzData) - + val dotProduct = dot(data, weights) // Our loss function with {0, 1} labels is max(0, 1 - (2y – 1) (f_w(x))) // Therefore the gradient is -(2y - 1)*x val labelScaled = 2 * label - 1.0 - if (1.0 > labelScaled * dotProduct) { - brzAxpy(-labelScaled, brzData, cumGradient.toBreeze) + axpy(-labelScaled, data, cumGradient) 1.0 - labelScaled * dotProduct } else { 0.0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index 26a2b62e76ed0..033fe44f34f3c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -19,14 +19,15 @@ package org.apache.spark.mllib.optimization import scala.collection.mutable.ArrayBuffer -import breeze.linalg.{DenseVector => BDV, axpy} +import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS} -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.Logging -import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.BLAS.axpy import org.apache.spark.mllib.rdd.RDDFunctions._ +import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: @@ -192,31 +193,29 @@ object LBFGS extends Logging { regParam: Double, numExamples: Long) extends DiffFunction[BDV[Double]] { - private var i = 0 - - override def calculate(weights: BDV[Double]) = { + override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = { // Have a local copy to avoid the serialization of CostFun object which is not serializable. + val w = Vectors.fromBreeze(weights) + val n = w.size + val bcW = data.context.broadcast(w) val localGradient = gradient - val n = weights.length - val bcWeights = data.context.broadcast(weights) - val (gradientSum, lossSum) = data.treeAggregate((BDV.zeros[Double](n), 0.0))( + val (gradientSum, lossSum) = data.treeAggregate((Vectors.zeros(n), 0.0))( seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) => val l = localGradient.compute( - features, label, Vectors.fromBreeze(bcWeights.value), Vectors.fromBreeze(grad)) + features, label, bcW.value, grad) (grad, loss + l) }, combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) => - (grad1 += grad2, loss1 + loss2) + axpy(1.0, grad2, grad1) + (grad1, loss1 + loss2) }) /** * regVal is sum of weight squares if it's L2 updater; * for other updater, the same logic is followed. */ - val regVal = updater.compute( - Vectors.fromBreeze(weights), - Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2 + val regVal = updater.compute(w, Vectors.zeros(n), 0, 1, regParam)._2 val loss = lossSum / numExamples + regVal /** @@ -236,17 +235,13 @@ object LBFGS extends Logging { */ // The following gradientTotal is actually the regularization part of gradient. // Will add the gradientSum computed from the data with weights in the next step. - val gradientTotal = weights - updater.compute( - Vectors.fromBreeze(weights), - Vectors.dense(new Array[Double](weights.size)), 1, 1, regParam)._1.toBreeze + val gradientTotal = w.copy + axpy(-1.0, updater.compute(w, Vectors.zeros(n), 1, 1, regParam)._1, gradientTotal) // gradientTotal = gradientSum / numExamples + gradientTotal axpy(1.0 / numExamples, gradientSum, gradientTotal) - i += 1 - - (loss, gradientTotal) + (loss, gradientTotal.toBreeze.asInstanceOf[BDV[Double]]) } } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala index c8db3910c6eab..910eff9540a47 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala @@ -105,16 +105,16 @@ private[mllib] object RandomRDD { def getPointIterator[T: ClassTag](partition: RandomRDDPartition[T]): Iterator[T] = { val generator = partition.generator.copy() generator.setSeed(partition.seed) - Array.fill(partition.size)(generator.nextValue()).toIterator + Iterator.fill(partition.size)(generator.nextValue()) } // The RNG has to be reset every time the iterator is requested to guarantee same data // every time the content of the RDD is examined. - def getVectorIterator(partition: RandomRDDPartition[Double], - vectorSize: Int): Iterator[Vector] = { + def getVectorIterator( + partition: RandomRDDPartition[Double], + vectorSize: Int): Iterator[Vector] = { val generator = partition.generator.copy() generator.setSeed(partition.seed) - Array.fill(partition.size)(new DenseVector( - (0 until vectorSize).map { _ => generator.nextValue() }.toArray)).toIterator + Iterator.fill(partition.size)(new DenseVector(Array.fill(vectorSize)(generator.nextValue()))) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 8ebc7e27ed4dd..84d192db53e26 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -111,11 +111,17 @@ class ALS private ( */ def this() = this(-1, -1, 10, 10, 0.01, false, 1.0) + /** If true, do alternating nonnegative least squares. */ + private var nonnegative = false + + /** storage level for user/product in/out links */ + private var intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK + /** * Set the number of blocks for both user blocks and product blocks to parallelize the computation * into; pass -1 for an auto-configured number of blocks. Default: -1. */ - def setBlocks(numBlocks: Int): ALS = { + def setBlocks(numBlocks: Int): this.type = { this.numUserBlocks = numBlocks this.numProductBlocks = numBlocks this @@ -124,7 +130,7 @@ class ALS private ( /** * Set the number of user blocks to parallelize the computation. */ - def setUserBlocks(numUserBlocks: Int): ALS = { + def setUserBlocks(numUserBlocks: Int): this.type = { this.numUserBlocks = numUserBlocks this } @@ -132,31 +138,31 @@ class ALS private ( /** * Set the number of product blocks to parallelize the computation. */ - def setProductBlocks(numProductBlocks: Int): ALS = { + def setProductBlocks(numProductBlocks: Int): this.type = { this.numProductBlocks = numProductBlocks this } /** Set the rank of the feature matrices computed (number of features). Default: 10. */ - def setRank(rank: Int): ALS = { + def setRank(rank: Int): this.type = { this.rank = rank this } /** Set the number of iterations to run. Default: 10. */ - def setIterations(iterations: Int): ALS = { + def setIterations(iterations: Int): this.type = { this.iterations = iterations this } /** Set the regularization parameter, lambda. Default: 0.01. */ - def setLambda(lambda: Double): ALS = { + def setLambda(lambda: Double): this.type = { this.lambda = lambda this } /** Sets whether to use implicit preference. Default: false. */ - def setImplicitPrefs(implicitPrefs: Boolean): ALS = { + def setImplicitPrefs(implicitPrefs: Boolean): this.type = { this.implicitPrefs = implicitPrefs this } @@ -166,29 +172,38 @@ class ALS private ( * Sets the constant used in computing confidence in implicit ALS. Default: 1.0. */ @Experimental - def setAlpha(alpha: Double): ALS = { + def setAlpha(alpha: Double): this.type = { this.alpha = alpha this } /** Sets a random seed to have deterministic results. */ - def setSeed(seed: Long): ALS = { + def setSeed(seed: Long): this.type = { this.seed = seed this } - /** If true, do alternating nonnegative least squares. */ - private var nonnegative = false - /** * Set whether the least-squares problems solved at each iteration should have * nonnegativity constraints. */ - def setNonnegative(b: Boolean): ALS = { + def setNonnegative(b: Boolean): this.type = { this.nonnegative = b this } + /** + * :: DeveloperApi :: + * Sets storage level for intermediate RDDs (user/product in/out links). The default value is + * `MEMORY_AND_DISK`. Users can change it to a serialized storage, e.g., `MEMORY_AND_DISK_SER` and + * set `spark.rdd.compress` to `true` to reduce the space requirement, at the cost of speed. + */ + @DeveloperApi + def setIntermediateRDDStorageLevel(storageLevel: StorageLevel): this.type = { + this.intermediateRDDStorageLevel = storageLevel + this + } + /** * Run ALS with the configured parameters on an input RDD of (user, product, rating) triples. * Returns a MatrixFactorizationModel with feature vectors for each user and product. @@ -441,8 +456,8 @@ class ALS private ( }, preservesPartitioning = true) val inLinks = links.mapValues(_._1) val outLinks = links.mapValues(_._2) - inLinks.persist(StorageLevel.MEMORY_AND_DISK) - outLinks.persist(StorageLevel.MEMORY_AND_DISK) + inLinks.persist(intermediateRDDStorageLevel) + outLinks.persist(intermediateRDDStorageLevel) (inLinks, outLinks) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index a1a76fcbe9f9c..478c6485052b6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.api.python.PythonMLLibAPI +import org.apache.spark.mllib.api.python.SerDe /** * Model representing the result of matrix factorization. @@ -117,9 +117,8 @@ class MatrixFactorizationModel private[mllib] ( */ @DeveloperApi def predict(usersProductsJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = { - val pythonAPI = new PythonMLLibAPI() - val usersProducts = usersProductsJRDD.rdd.map(xBytes => pythonAPI.unpackTuple(xBytes)) - predict(usersProducts).map(rate => pythonAPI.serializeRating(rate)) + val usersProducts = usersProductsJRDD.rdd.map(xBytes => SerDe.unpackTuple(xBytes)) + predict(usersProducts).map(rate => SerDe.serializeRating(rate)) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 54854252d7477..20c1fdd2269ce 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.regression import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.feature.StandardScaler import org.apache.spark.{Logging, SparkException} import org.apache.spark.rdd.RDD import org.apache.spark.mllib.optimization._ @@ -94,6 +95,22 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] protected var validateData: Boolean = true + /** + * Whether to perform feature scaling before model training to reduce the condition numbers + * which can significantly help the optimizer converging faster. The scaling correction will be + * translated back to resulting model weights, so it's transparent to users. + * Note: This technique is used in both libsvm and glmnet packages. Default false. + */ + private var useFeatureScaling = false + + /** + * Set if the algorithm should use feature scaling to improve the convergence during optimization. + */ + private[mllib] def setFeatureScaling(useFeatureScaling: Boolean): this.type = { + this.useFeatureScaling = useFeatureScaling + this + } + /** * Create a model given the weights and intercept */ @@ -137,11 +154,45 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] throw new SparkException("Input validation failed.") } + /** + * Scaling columns to unit variance as a heuristic to reduce the condition number: + * + * During the optimization process, the convergence (rate) depends on the condition number of + * the training dataset. Scaling the variables often reduces this condition number + * heuristically, thus improving the convergence rate. Without reducing the condition number, + * some training datasets mixing the columns with different scales may not be able to converge. + * + * GLMNET and LIBSVM packages perform the scaling to reduce the condition number, and return + * the weights in the original scale. + * See page 9 in http://cran.r-project.org/web/packages/glmnet/glmnet.pdf + * + * Here, if useFeatureScaling is enabled, we will standardize the training features by dividing + * the variance of each column (without subtracting the mean), and train the model in the + * scaled space. Then we transform the coefficients from the scaled space to the original scale + * as GLMNET and LIBSVM do. + * + * Currently, it's only enabled in LogisticRegressionWithLBFGS + */ + val scaler = if (useFeatureScaling) { + (new StandardScaler).fit(input.map(x => x.features)) + } else { + null + } + // Prepend an extra variable consisting of all 1.0's for the intercept. val data = if (addIntercept) { - input.map(labeledPoint => (labeledPoint.label, appendBias(labeledPoint.features))) + if(useFeatureScaling) { + input.map(labeledPoint => + (labeledPoint.label, appendBias(scaler.transform(labeledPoint.features)))) + } else { + input.map(labeledPoint => (labeledPoint.label, appendBias(labeledPoint.features))) + } } else { - input.map(labeledPoint => (labeledPoint.label, labeledPoint.features)) + if (useFeatureScaling) { + input.map(labeledPoint => (labeledPoint.label, scaler.transform(labeledPoint.features))) + } else { + input.map(labeledPoint => (labeledPoint.label, labeledPoint.features)) + } } val initialWeightsWithIntercept = if (addIntercept) { @@ -153,13 +204,25 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept) val intercept = if (addIntercept) weightsWithIntercept(weightsWithIntercept.size - 1) else 0.0 - val weights = + var weights = if (addIntercept) { Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1)) } else { weightsWithIntercept } + /** + * The weights and intercept are trained in the scaled space; we're converting them back to + * the original scale. + * + * Math shows that if we only perform standardization without subtracting means, the intercept + * will not be changed. w_i = w_i' / v_i where w_i' is the coefficient in the scaled space, w_i + * is the coefficient in the original space, and v_i is the variance of the column i. + */ + if (useFeatureScaling) { + weights = scaler.transform(weights) + } + createModel(weights, intercept) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index f416a9fbb323d..3cf1028fbc725 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -18,8 +18,11 @@ package org.apache.spark.mllib.stat import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.linalg.{Matrix, Vector} +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.correlation.Correlations +import org.apache.spark.mllib.stat.test.{ChiSqTest, ChiSqTestResult} import org.apache.spark.rdd.RDD /** @@ -28,6 +31,18 @@ import org.apache.spark.rdd.RDD @Experimental object Statistics { + /** + * :: Experimental :: + * Computes column-wise summary statistics for the input RDD[Vector]. + * + * @param X an RDD[Vector] for which column-wise summary statistics are to be computed. + * @return [[MultivariateStatisticalSummary]] object containing column-wise summary statistics. + */ + @Experimental + def colStats(X: RDD[Vector]): MultivariateStatisticalSummary = { + new RowMatrix(X).computeColumnSummaryStatistics() + } + /** * :: Experimental :: * Compute the Pearson correlation matrix for the input RDD of Vectors. @@ -89,4 +104,66 @@ object Statistics { */ @Experimental def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method) + + /** + * :: Experimental :: + * Conduct Pearson's chi-squared goodness of fit test of the observed data against the + * expected distribution. + * + * Note: the two input Vectors need to have the same size. + * `observed` cannot contain negative values. + * `expected` cannot contain nonpositive values. + * + * @param observed Vector containing the observed categorical counts/relative frequencies. + * @param expected Vector containing the expected categorical counts/relative frequencies. + * `expected` is rescaled if the `expected` sum differs from the `observed` sum. + * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, + * the method used, and the null hypothesis. + */ + @Experimental + def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = { + ChiSqTest.chiSquared(observed, expected) + } + + /** + * :: Experimental :: + * Conduct Pearson's chi-squared goodness of fit test of the observed data against the uniform + * distribution, with each category having an expected frequency of `1 / observed.size`. + * + * Note: `observed` cannot contain negative values. + * + * @param observed Vector containing the observed categorical counts/relative frequencies. + * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, + * the method used, and the null hypothesis. + */ + @Experimental + def chiSqTest(observed: Vector): ChiSqTestResult = ChiSqTest.chiSquared(observed) + + /** + * :: Experimental :: + * Conduct Pearson's independence test on the input contingency matrix, which cannot contain + * negative entries or columns or rows that sum up to 0. + * + * @param observed The contingency matrix (containing either counts or relative frequencies). + * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, + * the method used, and the null hypothesis. + */ + @Experimental + def chiSqTest(observed: Matrix): ChiSqTestResult = ChiSqTest.chiSquaredMatrix(observed) + + /** + * :: Experimental :: + * Conduct Pearson's independence test for every feature against the label across the input RDD. + * For each feature, the (feature, label) pairs are converted into a contingency matrix for which + * the chi-squared statistic is computed. + * + * @param data an `RDD[LabeledPoint]` containing the labeled dataset with categorical features. + * Real-valued features will be treated as categorical for each distinct value. + * @return an array containing the ChiSquaredTestResult for every feature against the label. + * The order of the elements in the returned array reflects the order of input features. + */ + @Experimental + def chiSqTest(data: RDD[LabeledPoint]): Array[ChiSqTestResult] = { + ChiSqTest.chiSquaredFeatures(data) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala new file mode 100644 index 0000000000000..8f6752737402e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.test + +import breeze.linalg.{DenseMatrix => BDM} +import cern.jet.stat.Probability.chiSquareComplemented + +import org.apache.spark.Logging +import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD + +/** + * Conduct the chi-squared test for the input RDDs using the specified method. + * Goodness-of-fit test is conducted on two `Vectors`, whereas test of independence is conducted + * on an input of type `Matrix` in which independence between columns is assessed. + * We also provide a method for computing the chi-squared statistic between each feature and the + * label for an input `RDD[LabeledPoint]`, return an `Array[ChiSquaredTestResult]` of size = + * number of features in the inpuy RDD. + * + * Supported methods for goodness of fit: `pearson` (default) + * Supported methods for independence: `pearson` (default) + * + * More information on Chi-squared test: http://en.wikipedia.org/wiki/Chi-squared_test + */ +private[stat] object ChiSqTest extends Logging { + + /** + * @param name String name for the method. + * @param chiSqFunc Function for computing the statistic given the observed and expected counts. + */ + case class Method(name: String, chiSqFunc: (Double, Double) => Double) + + // Pearson's chi-squared test: http://en.wikipedia.org/wiki/Pearson%27s_chi-squared_test + val PEARSON = new Method("pearson", (observed: Double, expected: Double) => { + val dev = observed - expected + dev * dev / expected + }) + + // Null hypothesis for the two different types of chi-squared tests to be included in the result. + object NullHypothesis extends Enumeration { + type NullHypothesis = Value + val goodnessOfFit = Value("observed follows the same distribution as expected.") + val independence = Value("observations in each column are statistically independent.") + } + + // Method identification based on input methodName string + private def methodFromString(methodName: String): Method = { + methodName match { + case PEARSON.name => PEARSON + case _ => throw new IllegalArgumentException("Unrecognized method for Chi squared test.") + } + } + + /** + * Conduct Pearson's independence test for each feature against the label across the input RDD. + * The contingency table is constructed from the raw (feature, label) pairs and used to conduct + * the independence test. + * Returns an array containing the ChiSquaredTestResult for every feature against the label. + */ + def chiSquaredFeatures(data: RDD[LabeledPoint], + methodName: String = PEARSON.name): Array[ChiSqTestResult] = { + val numCols = data.first().features.size + val results = new Array[ChiSqTestResult](numCols) + var labels: Map[Double, Int] = null + // At most 100 columns at a time + val batchSize = 100 + var batch = 0 + while (batch * batchSize < numCols) { + // The following block of code can be cleaned up and made public as + // chiSquared(data: RDD[(V1, V2)]) + val startCol = batch * batchSize + val endCol = startCol + math.min(batchSize, numCols - startCol) + val pairCounts = data.flatMap { p => + // assume dense vectors + p.features.toArray.slice(startCol, endCol).zipWithIndex.map { case (feature, col) => + (col, feature, p.label) + } + }.countByValue() + + if (labels == null) { + // Do this only once for the first column since labels are invariant across features. + labels = + pairCounts.keys.filter(_._1 == startCol).map(_._3).toArray.distinct.zipWithIndex.toMap + } + val numLabels = labels.size + pairCounts.keys.groupBy(_._1).map { case (col, keys) => + val features = keys.map(_._2).toArray.distinct.zipWithIndex.toMap + val numRows = features.size + val contingency = new BDM(numRows, numLabels, new Array[Double](numRows * numLabels)) + keys.foreach { case (_, feature, label) => + val i = features(feature) + val j = labels(label) + contingency(i, j) += pairCounts((col, feature, label)) + } + results(col) = chiSquaredMatrix(Matrices.fromBreeze(contingency), methodName) + } + batch += 1 + } + results + } + + /* + * Pearon's goodness of fit test on the input observed and expected counts/relative frequencies. + * Uniform distribution is assumed when `expected` is not passed in. + */ + def chiSquared(observed: Vector, + expected: Vector = Vectors.dense(Array[Double]()), + methodName: String = PEARSON.name): ChiSqTestResult = { + + // Validate input arguments + val method = methodFromString(methodName) + if (expected.size != 0 && observed.size != expected.size) { + throw new IllegalArgumentException("observed and expected must be of the same size.") + } + val size = observed.size + if (size > 1000) { + logWarning("Chi-squared approximation may not be accurate due to low expected frequencies " + + s" as a result of a large number of categories: $size.") + } + val obsArr = observed.toArray + val expArr = if (expected.size == 0) Array.tabulate(size)(_ => 1.0 / size) else expected.toArray + if (!obsArr.forall(_ >= 0.0)) { + throw new IllegalArgumentException("Negative entries disallowed in the observed vector.") + } + if (expected.size != 0 && ! expArr.forall(_ >= 0.0)) { + throw new IllegalArgumentException("Negative entries disallowed in the expected vector.") + } + + // Determine the scaling factor for expected + val obsSum = obsArr.sum + val expSum = if (expected.size == 0.0) 1.0 else expArr.sum + val scale = if (math.abs(obsSum - expSum) < 1e-7) 1.0 else obsSum / expSum + + // compute chi-squared statistic + val statistic = obsArr.zip(expArr).foldLeft(0.0) { case (stat, (obs, exp)) => + if (exp == 0.0) { + if (obs == 0.0) { + throw new IllegalArgumentException("Chi-squared statistic undefined for input vectors due" + + " to 0.0 values in both observed and expected.") + } else { + return new ChiSqTestResult(0.0, size - 1, Double.PositiveInfinity, PEARSON.name, + NullHypothesis.goodnessOfFit.toString) + } + } + if (scale == 1.0) { + stat + method.chiSqFunc(obs, exp) + } else { + stat + method.chiSqFunc(obs, exp * scale) + } + } + val df = size - 1 + val pValue = chiSquareComplemented(df, statistic) + new ChiSqTestResult(pValue, df, statistic, PEARSON.name, NullHypothesis.goodnessOfFit.toString) + } + + /* + * Pearon's independence test on the input contingency matrix. + * TODO: optimize for SparseMatrix when it becomes supported. + */ + def chiSquaredMatrix(counts: Matrix, methodName:String = PEARSON.name): ChiSqTestResult = { + val method = methodFromString(methodName) + val numRows = counts.numRows + val numCols = counts.numCols + + // get row and column sums + val colSums = new Array[Double](numCols) + val rowSums = new Array[Double](numRows) + val colMajorArr = counts.toArray + var i = 0 + while (i < colMajorArr.size) { + val elem = colMajorArr(i) + if (elem < 0.0) { + throw new IllegalArgumentException("Contingency table cannot contain negative entries.") + } + colSums(i / numRows) += elem + rowSums(i % numRows) += elem + i += 1 + } + val total = colSums.sum + + // second pass to collect statistic + var statistic = 0.0 + var j = 0 + while (j < colMajorArr.size) { + val col = j / numRows + val colSum = colSums(col) + if (colSum == 0.0) { + throw new IllegalArgumentException("Chi-squared statistic undefined for input matrix due to" + + s"0 sum in column [$col].") + } + val row = j % numRows + val rowSum = rowSums(row) + if (rowSum == 0.0) { + throw new IllegalArgumentException("Chi-squared statistic undefined for input matrix due to" + + s"0 sum in row [$row].") + } + val expected = colSum * rowSum / total + statistic += method.chiSqFunc(colMajorArr(j), expected) + j += 1 + } + val df = (numCols - 1) * (numRows - 1) + val pValue = chiSquareComplemented(df, statistic) + new ChiSqTestResult(pValue, df, statistic, methodName, NullHypothesis.independence.toString) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala new file mode 100644 index 0000000000000..2f278621335e1 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.test + +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * Trait for hypothesis test results. + * @tparam DF Return type of `degreesOfFreedom`. + */ +@Experimental +trait TestResult[DF] { + + /** + * The probability of obtaining a test statistic result at least as extreme as the one that was + * actually observed, assuming that the null hypothesis is true. + */ + def pValue: Double + + /** + * Returns the degree(s) of freedom of the hypothesis test. + * Return type should be Number(e.g. Int, Double) or tuples of Numbers for toString compatibility. + */ + def degreesOfFreedom: DF + + /** + * Test statistic. + */ + def statistic: Double + + /** + * String explaining the hypothesis test result. + * Specific classes implementing this trait should override this method to output test-specific + * information. + */ + override def toString: String = { + + // String explaining what the p-value indicates. + val pValueExplain = if (pValue <= 0.01) { + "Very strong presumption against null hypothesis." + } else if (0.01 < pValue && pValue <= 0.05) { + "Strong presumption against null hypothesis." + } else if (0.05 < pValue && pValue <= 0.01) { + "Low presumption against null hypothesis." + } else { + "No presumption against null hypothesis." + } + + s"degrees of freedom = ${degreesOfFreedom.toString} \n" + + s"statistic = $statistic \n" + + s"pValue = $pValue \n" + pValueExplain + } +} + +/** + * :: Experimental :: + * Object containing the test results for the chi squared hypothesis test. + */ +@Experimental +class ChiSqTestResult(override val pValue: Double, + override val degreesOfFreedom: Int, + override val statistic: Double, + val method: String, + val nullHypothesis: String) extends TestResult[Int] { + + override def toString: String = { + "Chi squared test summary: \n" + + s"method: $method \n" + + s"null hypothesis: $nullHypothesis \n" + + super.toString + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index bb50f07be5d7b..2a3107a13e916 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -17,22 +17,24 @@ package org.apache.spark.mllib.tree -import org.apache.spark.api.java.JavaRDD - import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD import org.apache.spark.Logging import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} +import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.impurity.{Impurities, Gini, Entropy, Impurity} +import org.apache.spark.mllib.tree.impl.{TimeTracker, TreePoint} +import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity} import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.random.XORShiftRandom + /** * :: Experimental :: * A class which implements a decision tree learning algorithm for classification and regression. @@ -53,16 +55,27 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo */ def train(input: RDD[LabeledPoint]): DecisionTreeModel = { - // Cache input RDD for speedup during multiple passes. - val retaggedInput = input.retag(classOf[LabeledPoint]).cache() + val timer = new TimeTracker() + + timer.start("total") + + timer.start("init") + + val retaggedInput = input.retag(classOf[LabeledPoint]) logDebug("algo = " + strategy.algo) // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. + timer.start("findSplitsBins") val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, strategy) val numBins = bins(0).length + timer.stop("findSplitsBins") logDebug("numBins = " + numBins) + // Cache input RDD for speedup during multiple passes. + val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins) + .persist(StorageLevel.MEMORY_AND_DISK) + // depth of the decision tree val maxDepth = strategy.maxDepth // the max number of nodes possible given the depth of the tree @@ -76,7 +89,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // dummy value for top node (updated during first split calculation) val nodes = new Array[Node](maxNumNodes) // num features - val numFeatures = retaggedInput.take(1)(0).features.size + val numFeatures = treeInput.take(1)(0).binnedFeatures.size // Calculate level for single group construction @@ -96,6 +109,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo (math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt, 0) logDebug("max level for single group = " + maxLevelForSingleGroup) + timer.stop("init") + /* * The main idea here is to perform level-wise training of the decision tree nodes thus * reducing the passes over the data from l to log2(l) where l is the total number of nodes. @@ -113,15 +128,21 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo logDebug("#####################################") // Find best split for all nodes at a level. - val splitsStatsForLevel = DecisionTree.findBestSplits(retaggedInput, parentImpurities, - strategy, level, filters, splits, bins, maxLevelForSingleGroup) + timer.start("findBestSplits") + val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities, + strategy, level, filters, splits, bins, maxLevelForSingleGroup, timer) + timer.stop("findBestSplits") for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { + timer.start("extractNodeInfo") // Extract info for nodes at the current level. extractNodeInfo(nodeSplitStats, level, index, nodes) + timer.stop("extractNodeInfo") + timer.start("extractInfoForLowerLevels") // Extract info for nodes at the next lower level. extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, filters) + timer.stop("extractInfoForLowerLevels") logDebug("final best split = " + nodeSplitStats._1) } require(math.pow(2, level) == splitsStatsForLevel.length) @@ -144,6 +165,11 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Build the full tree using the node info calculated in the level-wise best split calculations. topNode.build(nodes) + timer.stop("total") + + logInfo("Internal timing for DecisionTree:") + logInfo(s"$timer") + new DecisionTreeModel(topNode, strategy.algo) } @@ -406,7 +432,7 @@ object DecisionTree extends Serializable with Logging { * Returns an array of optimal splits for all nodes at a given level. Splits the task into * multiple groups if the level-wise training task could lead to memory overflow. * - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] * @param parentImpurities Impurities for all parent nodes for the current level * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing * parameters for constructing the DecisionTree @@ -415,44 +441,45 @@ object DecisionTree extends Serializable with Logging { * @param splits possible splits for all features * @param bins possible bins for all features * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation. - * @return array of splits with best splits for all nodes at a given level. + * @return array (over nodes) of splits with best split for each node at a given level. */ protected[tree] def findBestSplits( - input: RDD[LabeledPoint], + input: RDD[TreePoint], parentImpurities: Array[Double], strategy: Strategy, level: Int, filters: Array[List[Filter]], splits: Array[Array[Split]], bins: Array[Array[Bin]], - maxLevelForSingleGroup: Int): Array[(Split, InformationGainStats)] = { + maxLevelForSingleGroup: Int, + timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = { // split into groups to avoid memory overflow during aggregation if (level > maxLevelForSingleGroup) { // When information for all nodes at a given level cannot be stored in memory, // the nodes are divided into multiple groups at each level with the number of groups // increasing exponentially per level. For example, if maxLevelForSingleGroup is 10, // numGroups is equal to 2 at level 11 and 4 at level 12, respectively. - val numGroups = math.pow(2, (level - maxLevelForSingleGroup)).toInt + val numGroups = math.pow(2, level - maxLevelForSingleGroup).toInt logDebug("numGroups = " + numGroups) var bestSplits = new Array[(Split, InformationGainStats)](0) // Iterate over each group of nodes at a level. var groupIndex = 0 while (groupIndex < numGroups) { val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level, - filters, splits, bins, numGroups, groupIndex) + filters, splits, bins, timer, numGroups, groupIndex) bestSplits = Array.concat(bestSplits, bestSplitsForGroup) groupIndex += 1 } bestSplits } else { - findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins) + findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins, timer) } } /** * Returns an array of optimal splits for a group of nodes at a given level * - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] * @param parentImpurities Impurities for all parent nodes for the current level * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing * parameters for constructing the DecisionTree @@ -465,13 +492,14 @@ object DecisionTree extends Serializable with Logging { * @return array of splits with best splits for all nodes at a given level. */ private def findBestSplitsPerGroup( - input: RDD[LabeledPoint], + input: RDD[TreePoint], parentImpurities: Array[Double], strategy: Strategy, level: Int, filters: Array[List[Filter]], splits: Array[Array[Split]], bins: Array[Array[Bin]], + timer: TimeTracker, numGroups: Int = 1, groupIndex: Int = 0): Array[(Split, InformationGainStats)] = { @@ -507,7 +535,7 @@ object DecisionTree extends Serializable with Logging { logDebug("numNodes = " + numNodes) // Find the number of features by looking at the first sample. - val numFeatures = input.first().features.size + val numFeatures = input.first().binnedFeatures.size logDebug("numFeatures = " + numFeatures) // numBins: Number of bins = 1 + number of possible splits @@ -542,33 +570,43 @@ object DecisionTree extends Serializable with Logging { * Find whether the sample is valid input for the current node, i.e., whether it passes through * all the filters for the current node. */ - def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = { + def isSampleValid(parentFilters: List[Filter], treePoint: TreePoint): Boolean = { // leaf if ((level > 0) && (parentFilters.length == 0)) { return false } // Apply each filter and check sample validity. Return false when invalid condition found. - for (filter <- parentFilters) { - val features = labeledPoint.features + parentFilters.foreach { filter => val featureIndex = filter.split.feature - val threshold = filter.split.threshold val comparison = filter.comparison - val categories = filter.split.categories val isFeatureContinuous = filter.split.featureType == Continuous - val feature = features(featureIndex) if (isFeatureContinuous) { + val binId = treePoint.binnedFeatures(featureIndex) + val bin = bins(featureIndex)(binId) + val featureValue = bin.highSplit.threshold + val threshold = filter.split.threshold comparison match { - case -1 => if (feature > threshold) return false - case 1 => if (feature <= threshold) return false + case -1 => if (featureValue > threshold) return false + case 1 => if (featureValue <= threshold) return false } } else { - val containsFeature = categories.contains(feature) + val numFeatureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val isSpaceSufficientForAllCategoricalSplits = + numBins > math.pow(2, numFeatureCategories.toInt - 1) - 1 + val isUnorderedFeature = + isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits + val featureValue = if (isUnorderedFeature) { + treePoint.binnedFeatures(featureIndex) + } else { + val binId = treePoint.binnedFeatures(featureIndex) + bins(featureIndex)(binId).category + } + val containsFeature = filter.split.categories.contains(featureValue) comparison match { case -1 => if (!containsFeature) return false case 1 => if (containsFeature) return false } - } } @@ -576,103 +614,6 @@ object DecisionTree extends Serializable with Logging { true } - /** - * Find bin for one (labeledPoint, feature). - */ - def findBin( - featureIndex: Int, - labeledPoint: LabeledPoint, - isFeatureContinuous: Boolean, - isSpaceSufficientForAllCategoricalSplits: Boolean): Int = { - val binForFeatures = bins(featureIndex) - val feature = labeledPoint.features(featureIndex) - - /** - * Binary search helper method for continuous feature. - */ - def binarySearchForBins(): Int = { - var left = 0 - var right = binForFeatures.length - 1 - while (left <= right) { - val mid = left + (right - left) / 2 - val bin = binForFeatures(mid) - val lowThreshold = bin.lowSplit.threshold - val highThreshold = bin.highSplit.threshold - if ((lowThreshold < feature) && (highThreshold >= feature)) { - return mid - } - else if (lowThreshold >= feature) { - right = mid - 1 - } - else { - left = mid + 1 - } - } - -1 - } - - /** - * Sequential search helper method to find bin for categorical feature in multiclass - * classification. The category is returned since each category can belong to multiple - * splits. The actual left/right child allocation per split is performed in the - * sequential phase of the bin aggregate operation. - */ - def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = { - labeledPoint.features(featureIndex).toInt - } - - /** - * Sequential search helper method to find bin for categorical feature - * (for classification and regression). - */ - def sequentialBinSearchForOrderedCategoricalFeature(): Int = { - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val featureValue = labeledPoint.features(featureIndex) - var binIndex = 0 - while (binIndex < featureCategories) { - val bin = bins(featureIndex)(binIndex) - val categories = bin.highSplit.categories - if (categories.contains(featureValue)) { - return binIndex - } - binIndex += 1 - } - if (featureValue < 0 || featureValue >= featureCategories) { - throw new IllegalArgumentException( - s"DecisionTree given invalid data:" + - s" Feature $featureIndex is categorical with values in" + - s" {0,...,${featureCategories - 1}," + - s" but a data point gives it value $featureValue.\n" + - " Bad data point: " + labeledPoint.toString) - } - -1 - } - - if (isFeatureContinuous) { - // Perform binary search for finding bin for continuous features. - val binIndex = binarySearchForBins() - if (binIndex == -1) { - throw new UnknownError("no bin was found for continuous variable.") - } - binIndex - } else { - // Perform sequential search to find bin for categorical features. - val binIndex = { - val isUnorderedFeature = - isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits - if (isUnorderedFeature) { - sequentialBinSearchForUnorderedCategoricalFeatureInClassification() - } else { - sequentialBinSearchForOrderedCategoricalFeature() - } - } - if (binIndex == -1) { - throw new UnknownError("no bin was found for categorical variable.") - } - binIndex - } - } - /** * Finds bins for all nodes (and all features) at a given level. * For l nodes, k features the storage is as follows: @@ -689,17 +630,17 @@ object DecisionTree extends Serializable with Logging { * bin index for this labeledPoint * (or InvalidBinIndex if labeledPoint is not handled by this node) */ - def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { + def findBinsForLevel(treePoint: TreePoint): Array[Double] = { // Calculate bin index and label per feature per node. val arr = new Array[Double](1 + (numFeatures * numNodes)) // First element of the array is the label of the instance. - arr(0) = labeledPoint.label + arr(0) = treePoint.label // Iterate over nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { val parentFilters = findParentFilters(nodeIndex) // Find out whether the sample qualifies for the particular node. - val sampleValid = isSampleValid(parentFilters, labeledPoint) + val sampleValid = isSampleValid(parentFilters, treePoint) val shift = 1 + numFeatures * nodeIndex if (!sampleValid) { // Mark one bin as -1 is sufficient. @@ -707,19 +648,7 @@ object DecisionTree extends Serializable with Logging { } else { var featureIndex = 0 while (featureIndex < numFeatures) { - val featureInfo = strategy.categoricalFeaturesInfo.get(featureIndex) - val isFeatureContinuous = featureInfo.isEmpty - if (isFeatureContinuous) { - arr(shift + featureIndex) - = findBin(featureIndex, labeledPoint, isFeatureContinuous, false) - } else { - val featureCategories = featureInfo.get - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - arr(shift + featureIndex) - = findBin(featureIndex, labeledPoint, isFeatureContinuous, - isSpaceSufficientForAllCategoricalSplits) - } + arr(shift + featureIndex) = treePoint.binnedFeatures(featureIndex) featureIndex += 1 } } @@ -728,7 +657,8 @@ object DecisionTree extends Serializable with Logging { arr } - // Find feature bins for all nodes at a level. + // Find feature bins for all nodes at a level. + timer.start("aggregation") val binMappedRDD = input.map(x => findBinsForLevel(x)) /** @@ -830,6 +760,8 @@ object DecisionTree extends Serializable with Logging { } } + val rightChildShift = numClasses * numBins * numFeatures * numNodes + /** * Helper for binSeqOp. * @@ -853,7 +785,6 @@ object DecisionTree extends Serializable with Logging { val validSignalIndex = 1 + numFeatures * nodeIndex val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex if (isSampleValidForNode) { - val rightChildShift = numClasses * numBins * numFeatures * numNodes // actual class label val label = arr(0) // Iterate over all features. @@ -912,7 +843,7 @@ object DecisionTree extends Serializable with Logging { val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3 agg(aggIndex) = agg(aggIndex) + 1 agg(aggIndex + 1) = agg(aggIndex + 1) + label - agg(aggIndex + 2) = agg(aggIndex + 2) + label*label + agg(aggIndex + 2) = agg(aggIndex + 2) + label * label featureIndex += 1 } } @@ -977,6 +908,7 @@ object DecisionTree extends Serializable with Logging { val binAggregates = { binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp) } + timer.stop("aggregation") logDebug("binAggregates.length = " + binAggregates.length) /** @@ -1031,10 +963,17 @@ object DecisionTree extends Serializable with Logging { def indexOfLargestArrayElement(array: Array[Double]): Int = { val result = array.foldLeft(-1, Double.MinValue, 0) { case ((maxIndex, maxValue, currentIndex), currentValue) => - if(currentValue > maxValue) (currentIndex, currentValue, currentIndex + 1) - else (maxIndex, maxValue, currentIndex + 1) + if (currentValue > maxValue) { + (currentIndex, currentValue, currentIndex + 1) + } else { + (maxIndex, maxValue, currentIndex + 1) + } + } + if (result._1 < 0) { + throw new RuntimeException("DecisionTree internal error:" + + " calculateGainForSplit failed in indexOfLargestArrayElement") } - if (result._1 < 0) 0 else result._1 + result._1 } val predict = indexOfLargestArrayElement(leftRightCounts) @@ -1057,6 +996,7 @@ object DecisionTree extends Serializable with Logging { val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) + case Regression => val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0) val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1) @@ -1280,15 +1220,41 @@ object DecisionTree extends Serializable with Logging { nodeImpurity: Double): Array[Array[InformationGainStats]] = { val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1) - for (featureIndex <- 0 until numFeatures) { - for (splitIndex <- 0 until numBins - 1) { + var featureIndex = 0 + while (featureIndex < numFeatures) { + val numSplitsForFeature = getNumSplitsForFeature(featureIndex) + var splitIndex = 0 + while (splitIndex < numSplitsForFeature) { gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex, splitIndex, rightNodeAgg, nodeImpurity) + splitIndex += 1 } + featureIndex += 1 } gains } + /** + * Get the number of splits for a feature. + */ + def getNumSplitsForFeature(featureIndex: Int): Int = { + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { + numBins - 1 + } else { + // Categorical feature + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val isSpaceSufficientForAllCategoricalSplits = + numBins > math.pow(2, featureCategories.toInt - 1) - 1 + if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { + math.pow(2.0, featureCategories - 1).toInt - 1 + } else { + // Ordered features + featureCategories + } + } + } + /** * Find the best split for a node. * @param binData Bin data slice for this node, given by getBinDataForNode. @@ -1307,7 +1273,7 @@ object DecisionTree extends Serializable with Logging { // Calculate gains for all splits. val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) - val (bestFeatureIndex,bestSplitIndex, gainStats) = { + val (bestFeatureIndex, bestSplitIndex, gainStats) = { // Initialize with infeasible values. var bestFeatureIndex = Int.MinValue var bestSplitIndex = Int.MinValue @@ -1317,22 +1283,8 @@ object DecisionTree extends Serializable with Logging { while (featureIndex < numFeatures) { // Iterate over all splits. var splitIndex = 0 - val maxSplitIndex: Double = { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinuous) { - numBins - 1 - } else { // Categorical feature - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { - math.pow(2.0, featureCategories - 1).toInt - 1 - } else { // Binary classification - featureCategories - } - } - } - while (splitIndex < maxSplitIndex) { + val numSplitsForFeature = getNumSplitsForFeature(featureIndex) + while (splitIndex < numSplitsForFeature) { val gainStats = gains(featureIndex)(splitIndex) if (gainStats.gain > bestGainStats.gain) { bestGainStats = gainStats @@ -1383,6 +1335,7 @@ object DecisionTree extends Serializable with Logging { } // Calculate best splits for all nodes at a given level + timer.start("chooseSplits") val bestSplits = new Array[(Split, InformationGainStats)](numNodes) // Iterating over all nodes at this level var node = 0 @@ -1395,6 +1348,8 @@ object DecisionTree extends Serializable with Logging { bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) node += 1 } + timer.stop("chooseSplits") + bestSplits } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index f31a503608b22..cfc8192a85abd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -27,22 +27,30 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ /** * :: Experimental :: * Stores all the configuration options for tree construction - * @param algo classification or regression - * @param impurity criterion used for information gain calculation + * @param algo Learning goal. Supported: + * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], + * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] + * @param impurity Criterion used for information gain calculation. + * Supported for Classification: [[org.apache.spark.mllib.tree.impurity.Gini]], + * [[org.apache.spark.mllib.tree.impurity.Entropy]]. + * Supported for Regression: [[org.apache.spark.mllib.tree.impurity.Variance]]. * @param maxDepth Maximum depth of the tree. * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * @param numClassesForClassification number of classes for classification. Default value is 2 - * leads to binary classification - * @param maxBins maximum number of bins used for splitting features - * @param quantileCalculationStrategy algorithm for calculating quantiles + * @param numClassesForClassification Number of classes for classification. + * (Ignored for regression.) + * Default value is 2 (binary classification). + * @param maxBins Maximum number of bins used for discretizing continuous features and + * for choosing how to split on features at each node. + * More bins give higher granularity. + * @param quantileCalculationStrategy Algorithm for calculating quantiles. Supported: + * [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]] * @param categoricalFeaturesInfo A map storing information about the categorical variables and the * number of discrete values they take. For example, an entry (n -> * k) implies the feature n is categorical with k categories 0, * 1, 2, ... , k-1. It's important to note that features are * zero-indexed. - * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is + * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is * 128 MB. - * */ @Experimental class Strategy ( @@ -64,20 +72,7 @@ class Strategy ( = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) /** - * Java-friendly constructor. - * - * @param algo classification or regression - * @param impurity criterion used for information gain calculation - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * @param numClassesForClassification number of classes for classification. Default value is 2 - * leads to binary classification - * @param maxBins maximum number of bins used for splitting features - * @param categoricalFeaturesInfo A map storing information about the categorical variables and - * the number of discrete values they take. For example, an entry - * (n -> k) implies the feature n is categorical with k categories - * 0, 1, 2, ... , k-1. It's important to note that features are - * zero-indexed. + * Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]] */ def this( algo: Algo, @@ -90,6 +85,10 @@ class Strategy ( categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap) } + /** + * Check validity of parameters. + * Throws exception if invalid. + */ private[tree] def assertValid(): Unit = { algo match { case Classification => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala new file mode 100644 index 0000000000000..d215d68c4279e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impl + +import scala.collection.mutable.{HashMap => MutableHashMap} + +import org.apache.spark.annotation.Experimental + +/** + * Time tracker implementation which holds labeled timers. + */ +@Experimental +private[tree] class TimeTracker extends Serializable { + + private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]() + + private val totals: MutableHashMap[String, Long] = new MutableHashMap[String, Long]() + + /** + * Starts a new timer, or re-starts a stopped timer. + */ + def start(timerLabel: String): Unit = { + val currentTime = System.nanoTime() + if (starts.contains(timerLabel)) { + throw new RuntimeException(s"TimeTracker.start(timerLabel) called again on" + + s" timerLabel = $timerLabel before that timer was stopped.") + } + starts(timerLabel) = currentTime + } + + /** + * Stops a timer and returns the elapsed time in seconds. + */ + def stop(timerLabel: String): Double = { + val currentTime = System.nanoTime() + if (!starts.contains(timerLabel)) { + throw new RuntimeException(s"TimeTracker.stop(timerLabel) called on" + + s" timerLabel = $timerLabel, but that timer was not started.") + } + val elapsed = currentTime - starts(timerLabel) + starts.remove(timerLabel) + if (totals.contains(timerLabel)) { + totals(timerLabel) += elapsed + } else { + totals(timerLabel) = elapsed + } + elapsed / 1e9 + } + + /** + * Print all timing results in seconds. + */ + override def toString: String = { + totals.map { case (label, elapsed) => + s" $label: ${elapsed / 1e9}" + }.mkString("\n") + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala new file mode 100644 index 0000000000000..ccac1031fd9d9 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impl + +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.model.Bin +import org.apache.spark.rdd.RDD + + +/** + * Internal representation of LabeledPoint for DecisionTree. + * This bins feature values based on a subsampled of data as follows: + * (a) Continuous features are binned into ranges. + * (b) Unordered categorical features are binned based on subsets of feature values. + * "Unordered categorical features" are categorical features with low arity used in + * multiclass classification. + * (c) Ordered categorical features are binned based on feature values. + * "Ordered categorical features" are categorical features with high arity, + * or any categorical feature used in regression or binary classification. + * + * @param label Label from LabeledPoint + * @param binnedFeatures Binned feature values. + * Same length as LabeledPoint.features, but values are bin indices. + */ +private[tree] class TreePoint(val label: Double, val binnedFeatures: Array[Int]) + extends Serializable { +} + +private[tree] object TreePoint { + + /** + * Convert an input dataset into its TreePoint representation, + * binning feature values in preparation for DecisionTree training. + * @param input Input dataset. + * @param strategy DecisionTree training info, used for dataset metadata. + * @param bins Bins for features, of size (numFeatures, numBins). + * @return TreePoint dataset representation + */ + def convertToTreeRDD( + input: RDD[LabeledPoint], + strategy: Strategy, + bins: Array[Array[Bin]]): RDD[TreePoint] = { + input.map { x => + TreePoint.labeledPointToTreePoint(x, strategy.isMulticlassClassification, bins, + strategy.categoricalFeaturesInfo) + } + } + + /** + * Convert one LabeledPoint into its TreePoint representation. + * @param bins Bins for features, of size (numFeatures, numBins). + * @param categoricalFeaturesInfo Map over categorical features: feature index --> feature arity + */ + private def labeledPointToTreePoint( + labeledPoint: LabeledPoint, + isMulticlassClassification: Boolean, + bins: Array[Array[Bin]], + categoricalFeaturesInfo: Map[Int, Int]): TreePoint = { + + val numFeatures = labeledPoint.features.size + val numBins = bins(0).size + val arr = new Array[Int](numFeatures) + var featureIndex = 0 + while (featureIndex < numFeatures) { + val featureInfo = categoricalFeaturesInfo.get(featureIndex) + val isFeatureContinuous = featureInfo.isEmpty + if (isFeatureContinuous) { + arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous, false, + bins, categoricalFeaturesInfo) + } else { + val featureCategories = featureInfo.get + val isSpaceSufficientForAllCategoricalSplits + = numBins > math.pow(2, featureCategories.toInt - 1) - 1 + val isUnorderedFeature = + isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits + arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous, + isUnorderedFeature, bins, categoricalFeaturesInfo) + } + featureIndex += 1 + } + + new TreePoint(labeledPoint.label, arr) + } + + /** + * Find bin for one (labeledPoint, feature). + * + * @param isUnorderedFeature (only applies if feature is categorical) + * @param bins Bins for features, of size (numFeatures, numBins). + * @param categoricalFeaturesInfo Map over categorical features: feature index --> feature arity + */ + private def findBin( + featureIndex: Int, + labeledPoint: LabeledPoint, + isFeatureContinuous: Boolean, + isUnorderedFeature: Boolean, + bins: Array[Array[Bin]], + categoricalFeaturesInfo: Map[Int, Int]): Int = { + + /** + * Binary search helper method for continuous feature. + */ + def binarySearchForBins(): Int = { + val binForFeatures = bins(featureIndex) + val feature = labeledPoint.features(featureIndex) + var left = 0 + var right = binForFeatures.length - 1 + while (left <= right) { + val mid = left + (right - left) / 2 + val bin = binForFeatures(mid) + val lowThreshold = bin.lowSplit.threshold + val highThreshold = bin.highSplit.threshold + if ((lowThreshold < feature) && (highThreshold >= feature)) { + return mid + } else if (lowThreshold >= feature) { + right = mid - 1 + } else { + left = mid + 1 + } + } + -1 + } + + /** + * Sequential search helper method to find bin for categorical feature in multiclass + * classification. The category is returned since each category can belong to multiple + * splits. The actual left/right child allocation per split is performed in the + * sequential phase of the bin aggregate operation. + */ + def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = { + labeledPoint.features(featureIndex).toInt + } + + /** + * Sequential search helper method to find bin for categorical feature + * (for classification and regression). + */ + def sequentialBinSearchForOrderedCategoricalFeature(): Int = { + val featureCategories = categoricalFeaturesInfo(featureIndex) + val featureValue = labeledPoint.features(featureIndex) + var binIndex = 0 + while (binIndex < featureCategories) { + val bin = bins(featureIndex)(binIndex) + val categories = bin.highSplit.categories + if (categories.contains(featureValue)) { + return binIndex + } + binIndex += 1 + } + if (featureValue < 0 || featureValue >= featureCategories) { + throw new IllegalArgumentException( + s"DecisionTree given invalid data:" + + s" Feature $featureIndex is categorical with values in" + + s" {0,...,${featureCategories - 1}," + + s" but a data point gives it value $featureValue.\n" + + " Bad data point: " + labeledPoint.toString) + } + -1 + } + + if (isFeatureContinuous) { + // Perform binary search for finding bin for continuous features. + val binIndex = binarySearchForBins() + if (binIndex == -1) { + throw new RuntimeException("No bin was found for continuous feature." + + " This error can occur when given invalid data values (such as NaN)." + + s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}") + } + binIndex + } else { + // Perform sequential search to find bin for categorical features. + val binIndex = if (isUnorderedFeature) { + sequentialBinSearchForUnorderedCategoricalFeatureInClassification() + } else { + sequentialBinSearchForOrderedCategoricalFeature() + } + if (binIndex == -1) { + throw new RuntimeException("No bin was found for categorical feature." + + " This error can occur when given invalid data values (such as NaN)." + + s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}") + } + binIndex + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala index bd413a80f5107..092d67bbc5238 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala @@ -23,7 +23,6 @@ import org.apache.spark.mllib.linalg.{Matrices, Vectors} import org.apache.spark.mllib.regression.LabeledPoint class PythonMLLibAPISuite extends FunSuite { - val py = new PythonMLLibAPI test("vector serialization") { val vectors = Seq( @@ -34,8 +33,8 @@ class PythonMLLibAPISuite extends FunSuite { Vectors.sparse(1, Array.empty[Int], Array.empty[Double]), Vectors.sparse(2, Array(1), Array(-2.0))) vectors.foreach { v => - val bytes = py.serializeDoubleVector(v) - val u = py.deserializeDoubleVector(bytes) + val bytes = SerDe.serializeDoubleVector(v) + val u = SerDe.deserializeDoubleVector(bytes) assert(u.getClass === v.getClass) assert(u === v) } @@ -50,8 +49,8 @@ class PythonMLLibAPISuite extends FunSuite { LabeledPoint(1.0, Vectors.sparse(1, Array.empty[Int], Array.empty[Double])), LabeledPoint(-0.5, Vectors.sparse(2, Array(1), Array(-2.0)))) points.foreach { p => - val bytes = py.serializeLabeledPoint(p) - val q = py.deserializeLabeledPoint(bytes) + val bytes = SerDe.serializeLabeledPoint(p) + val q = SerDe.deserializeLabeledPoint(bytes) assert(q.label === p.label) assert(q.features.getClass === p.features.getClass) assert(q.features === p.features) @@ -60,8 +59,8 @@ class PythonMLLibAPISuite extends FunSuite { test("double serialization") { for (x <- List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue, Double.NaN)) { - val bytes = py.serializeDouble(x) - val deser = py.deserializeDouble(bytes) + val bytes = SerDe.serializeDouble(x) + val deser = SerDe.deserializeDouble(bytes) // We use `equals` here for comparison because we cannot use `==` for NaN assert(x.equals(deser)) } @@ -70,14 +69,14 @@ class PythonMLLibAPISuite extends FunSuite { test("matrix to 2D array") { val values = Array[Double](0, 1.2, 3, 4.56, 7, 8) val matrix = Matrices.dense(2, 3, values) - val arr = py.to2dArray(matrix) + val arr = SerDe.to2dArray(matrix) val expected = Array(Array[Double](0, 3, 7), Array[Double](1.2, 4.56, 8)) assert(arr === expected) // Test conversion for empty matrix val empty = Array[Double]() val emptyMatrix = Matrices.dense(0, 0, empty) - val empty2D = py.to2dArray(emptyMatrix) + val empty2D = SerDe.to2dArray(emptyMatrix) assert(empty2D === Array[Array[Double]]()) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index da7c633bbd2af..bc05b2046878f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -67,7 +67,7 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match } // Test if we can correctly learn A, B where Y = logistic(A + B*X) - test("logistic regression") { + test("logistic regression with SGD") { val nPoints = 10000 val A = 2.0 val B = -1.5 @@ -94,7 +94,36 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } - test("logistic regression with initial weights") { + // Test if we can correctly learn A, B where Y = logistic(A + B*X) + test("logistic regression with LBFGS") { + val nPoints = 10000 + val A = 2.0 + val B = -1.5 + + val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42) + + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + val lr = new LogisticRegressionWithLBFGS().setIntercept(true) + + val model = lr.run(testRDD) + + // Test the weights + assert(model.weights(0) ~== -1.52 relTol 0.01) + assert(model.intercept ~== 2.00 relTol 0.01) + assert(model.weights(0) ~== model.weights(0) relTol 0.01) + assert(model.intercept ~== model.intercept relTol 0.01) + + val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } + + test("logistic regression with initial weights with SGD") { val nPoints = 10000 val A = 2.0 val B = -1.5 @@ -125,11 +154,99 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } + + test("logistic regression with initial weights with LBFGS") { + val nPoints = 10000 + val A = 2.0 + val B = -1.5 + + val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42) + + val initialB = -1.0 + val initialWeights = Vectors.dense(initialB) + + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + + // Use half as many iterations as the previous test. + val lr = new LogisticRegressionWithLBFGS().setIntercept(true) + + val model = lr.run(testRDD, initialWeights) + + // Test the weights + assert(model.weights(0) ~== -1.50 relTol 0.02) + assert(model.intercept ~== 1.97 relTol 0.02) + + val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } + + test("numerical stability of scaling features using logistic regression with LBFGS") { + /** + * If we rescale the features, the condition number will be changed so the convergence rate + * and the solution will not equal to the original solution multiple by the scaling factor + * which it should be. + * + * However, since in the LogisticRegressionWithLBFGS, we standardize the training dataset first, + * no matter how we multiple a scaling factor into the dataset, the convergence rate should be + * the same, and the solution should equal to the original solution multiple by the scaling + * factor. + */ + + val nPoints = 10000 + val A = 2.0 + val B = -1.5 + + val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42) + + val initialWeights = Vectors.dense(0.0) + + val testRDD1 = sc.parallelize(testData, 2) + + val testRDD2 = sc.parallelize( + testData.map(x => LabeledPoint(x.label, Vectors.fromBreeze(x.features.toBreeze * 1.0E3))), 2) + + val testRDD3 = sc.parallelize( + testData.map(x => LabeledPoint(x.label, Vectors.fromBreeze(x.features.toBreeze * 1.0E6))), 2) + + testRDD1.cache() + testRDD2.cache() + testRDD3.cache() + + val lrA = new LogisticRegressionWithLBFGS().setIntercept(true) + val lrB = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(false) + + val modelA1 = lrA.run(testRDD1, initialWeights) + val modelA2 = lrA.run(testRDD2, initialWeights) + val modelA3 = lrA.run(testRDD3, initialWeights) + + val modelB1 = lrB.run(testRDD1, initialWeights) + val modelB2 = lrB.run(testRDD2, initialWeights) + val modelB3 = lrB.run(testRDD3, initialWeights) + + // For model trained with feature standardization, the weights should + // be the same in the scaled space. Note that the weights here are already + // in the original space, we transform back to scaled space to compare. + assert(modelA1.weights(0) ~== modelA2.weights(0) * 1.0E3 absTol 0.01) + assert(modelA1.weights(0) ~== modelA3.weights(0) * 1.0E6 absTol 0.01) + + // Training data with different scales without feature standardization + // will not yield the same result in the scaled space due to poor + // convergence rate. + assert(modelB1.weights(0) !~== modelB2.weights(0) * 1.0E3 absTol 0.1) + assert(modelB1.weights(0) !~== modelB3.weights(0) * 1.0E6 absTol 0.1) + } + } class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { - test("task size should be small in both training and prediction") { + test("task size should be small in both training and prediction using SGD optimizer") { val m = 4 val n = 200000 val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) => @@ -139,6 +256,29 @@ class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkCont // If we serialize data directly in the task closure, the size of the serialized task would be // greater than 1MB and hence Spark would throw an error. val model = LogisticRegressionWithSGD.train(points, 2) + val predictions = model.predict(points.map(_.features)) + + // Materialize the RDDs + predictions.count() } + + test("task size should be small in both training and prediction using LBFGS optimizer") { + val m = 4 + val n = 200000 + val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) => + val random = new Random(idx) + iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble())))) + }.cache() + // If we serialize data directly in the task closure, the size of the serialized task would be + // greater than 1MB and hence Spark would throw an error. + val model = + (new LogisticRegressionWithLBFGS().setIntercept(true).setNumIterations(2)).run(points) + + val predictions = model.predict(points.map(_.features)) + + // Materialize the RDDs + predictions.count() + } + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala new file mode 100644 index 0000000000000..1952e6734ecf7 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.mllib.linalg.BLAS._ + +class BLASSuite extends FunSuite { + + test("copy") { + val sx = Vectors.sparse(4, Array(0, 2), Array(1.0, -2.0)) + val dx = Vectors.dense(1.0, 0.0, -2.0, 0.0) + val sy = Vectors.sparse(4, Array(0, 1, 3), Array(2.0, 1.0, 1.0)) + val dy = Array(2.0, 1.0, 0.0, 1.0) + + val dy1 = Vectors.dense(dy.clone()) + copy(sx, dy1) + assert(dy1 ~== dx absTol 1e-15) + + val dy2 = Vectors.dense(dy.clone()) + copy(dx, dy2) + assert(dy2 ~== dx absTol 1e-15) + + intercept[IllegalArgumentException] { + copy(sx, sy) + } + + intercept[IllegalArgumentException] { + copy(dx, sy) + } + + withClue("vector sizes must match") { + intercept[Exception] { + copy(sx, Vectors.dense(0.0, 1.0, 2.0)) + } + } + } + + test("scal") { + val a = 0.1 + val sx = Vectors.sparse(3, Array(0, 2), Array(1.0, -2.0)) + val dx = Vectors.dense(1.0, 0.0, -2.0) + + scal(a, sx) + assert(sx ~== Vectors.sparse(3, Array(0, 2), Array(0.1, -0.2)) absTol 1e-15) + + scal(a, dx) + assert(dx ~== Vectors.dense(0.1, 0.0, -0.2) absTol 1e-15) + } + + test("axpy") { + val alpha = 0.1 + val sx = Vectors.sparse(3, Array(0, 2), Array(1.0, -2.0)) + val dx = Vectors.dense(1.0, 0.0, -2.0) + val dy = Array(2.0, 1.0, 0.0) + val expected = Vectors.dense(2.1, 1.0, -0.2) + + val dy1 = Vectors.dense(dy.clone()) + axpy(alpha, sx, dy1) + assert(dy1 ~== expected absTol 1e-15) + + val dy2 = Vectors.dense(dy.clone()) + axpy(alpha, dx, dy2) + assert(dy2 ~== expected absTol 1e-15) + + val sy = Vectors.sparse(4, Array(0, 1), Array(2.0, 1.0)) + + intercept[IllegalArgumentException] { + axpy(alpha, sx, sy) + } + + intercept[IllegalArgumentException] { + axpy(alpha, dx, sy) + } + + withClue("vector sizes must match") { + intercept[Exception] { + axpy(alpha, sx, Vectors.dense(1.0, 2.0)) + } + } + } + + test("dot") { + val sx = Vectors.sparse(3, Array(0, 2), Array(1.0, -2.0)) + val dx = Vectors.dense(1.0, 0.0, -2.0) + val sy = Vectors.sparse(3, Array(0, 1), Array(2.0, 1.0)) + val dy = Vectors.dense(2.0, 1.0, 0.0) + + assert(dot(sx, sy) ~== 2.0 absTol 1e-15) + assert(dot(sy, sx) ~== 2.0 absTol 1e-15) + assert(dot(sx, dy) ~== 2.0 absTol 1e-15) + assert(dot(dy, sx) ~== 2.0 absTol 1e-15) + assert(dot(dx, dy) ~== 2.0 absTol 1e-15) + assert(dot(dy, dx) ~== 2.0 absTol 1e-15) + + assert(dot(sx, sx) ~== 5.0 absTol 1e-15) + assert(dot(dx, dx) ~== 5.0 absTol 1e-15) + assert(dot(sx, dx) ~== 5.0 absTol 1e-15) + assert(dot(dx, sx) ~== 5.0 absTol 1e-15) + + val sx1 = Vectors.sparse(10, Array(0, 3, 5, 7, 8), Array(1.0, 2.0, 3.0, 4.0, 5.0)) + val sx2 = Vectors.sparse(10, Array(1, 3, 6, 7, 9), Array(1.0, 2.0, 3.0, 4.0, 5.0)) + assert(dot(sx1, sx2) ~== 20.0 absTol 1e-15) + assert(dot(sx2, sx1) ~== 20.0 absTol 1e-15) + + withClue("vector sizes must match") { + intercept[Exception] { + dot(sx, Vectors.dense(2.0, 1.0)) + } + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 7972ceea1fe8a..cd651fe2d2ddf 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -125,4 +125,34 @@ class VectorsSuite extends FunSuite { } } } + + test("zeros") { + assert(Vectors.zeros(3) === Vectors.dense(0.0, 0.0, 0.0)) + } + + test("Vector.copy") { + val sv = Vectors.sparse(4, Array(0, 2), Array(1.0, 2.0)) + val svCopy = sv.copy + (sv, svCopy) match { + case (sv: SparseVector, svCopy: SparseVector) => + assert(sv.size === svCopy.size) + assert(sv.indices === svCopy.indices) + assert(sv.values === svCopy.values) + assert(!sv.indices.eq(svCopy.indices)) + assert(!sv.values.eq(svCopy.values)) + case _ => + throw new RuntimeException(s"copy returned ${svCopy.getClass} on ${sv.getClass}.") + } + + val dv = Vectors.dense(1.0, 0.0, 2.0) + val dvCopy = dv.copy + (dv, dvCopy) match { + case (dv: DenseVector, dvCopy: DenseVector) => + assert(dv.size === dvCopy.size) + assert(dv.values === dvCopy.values) + assert(!dv.values.eq(dvCopy.values)) + case _ => + throw new RuntimeException(s"copy returned ${dvCopy.getClass} on ${dv.getClass}.") + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala new file mode 100644 index 0000000000000..5bd0521298c14 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{DenseVector, Matrices, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.stat.test.ChiSqTest +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class HypothesisTestSuite extends FunSuite with LocalSparkContext { + + test("chi squared pearson goodness of fit") { + + val observed = new DenseVector(Array[Double](4, 6, 5)) + val pearson = Statistics.chiSqTest(observed) + + // Results validated against the R command `chisq.test(c(4, 6, 5), p=c(1/3, 1/3, 1/3))` + assert(pearson.statistic === 0.4) + assert(pearson.degreesOfFreedom === 2) + assert(pearson.pValue ~== 0.8187 relTol 1e-4) + assert(pearson.method === ChiSqTest.PEARSON.name) + assert(pearson.nullHypothesis === ChiSqTest.NullHypothesis.goodnessOfFit.toString) + + // different expected and observed sum + val observed1 = new DenseVector(Array[Double](21, 38, 43, 80)) + val expected1 = new DenseVector(Array[Double](3, 5, 7, 20)) + val pearson1 = Statistics.chiSqTest(observed1, expected1) + + // Results validated against the R command + // `chisq.test(c(21, 38, 43, 80), p=c(3/35, 1/7, 1/5, 4/7))` + assert(pearson1.statistic ~== 14.1429 relTol 1e-4) + assert(pearson1.degreesOfFreedom === 3) + assert(pearson1.pValue ~== 0.002717 relTol 1e-4) + assert(pearson1.method === ChiSqTest.PEARSON.name) + assert(pearson1.nullHypothesis === ChiSqTest.NullHypothesis.goodnessOfFit.toString) + + // Vectors with different sizes + val observed3 = new DenseVector(Array(1.0, 2.0, 3.0)) + val expected3 = new DenseVector(Array(1.0, 2.0, 3.0, 4.0)) + intercept[IllegalArgumentException](Statistics.chiSqTest(observed3, expected3)) + + // negative counts in observed + val negObs = new DenseVector(Array(1.0, 2.0, 3.0, -4.0)) + intercept[IllegalArgumentException](Statistics.chiSqTest(negObs, expected1)) + + // count = 0.0 in expected but not observed + val zeroExpected = new DenseVector(Array(1.0, 0.0, 3.0)) + val inf = Statistics.chiSqTest(observed, zeroExpected) + assert(inf.statistic === Double.PositiveInfinity) + assert(inf.degreesOfFreedom === 2) + assert(inf.pValue === 0.0) + assert(inf.method === ChiSqTest.PEARSON.name) + assert(inf.nullHypothesis === ChiSqTest.NullHypothesis.goodnessOfFit.toString) + + // 0.0 in expected and observed simultaneously + val zeroObserved = new DenseVector(Array(2.0, 0.0, 1.0)) + intercept[IllegalArgumentException](Statistics.chiSqTest(zeroObserved, zeroExpected)) + } + + test("chi squared pearson matrix independence") { + val data = Array(40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0) + // [[40.0, 56.0, 31.0, 30.0], + // [24.0, 32.0, 10.0, 15.0], + // [29.0, 42.0, 0.0, 12.0]] + val chi = Statistics.chiSqTest(Matrices.dense(3, 4, data)) + // Results validated against R command + // `chisq.test(rbind(c(40, 56, 31, 30),c(24, 32, 10, 15), c(29, 42, 0, 12)))` + assert(chi.statistic ~== 21.9958 relTol 1e-4) + assert(chi.degreesOfFreedom === 6) + assert(chi.pValue ~== 0.001213 relTol 1e-4) + assert(chi.method === ChiSqTest.PEARSON.name) + assert(chi.nullHypothesis === ChiSqTest.NullHypothesis.independence.toString) + + // Negative counts + val negCounts = Array(4.0, 5.0, 3.0, -3.0) + intercept[IllegalArgumentException](Statistics.chiSqTest(Matrices.dense(2, 2, negCounts))) + + // Row sum = 0.0 + val rowZero = Array(0.0, 1.0, 0.0, 2.0) + intercept[IllegalArgumentException](Statistics.chiSqTest(Matrices.dense(2, 2, rowZero))) + + // Column sum = 0.0 + val colZero = Array(0.0, 0.0, 2.0, 2.0) + // IllegalArgumentException thrown here since it's thrown on driver, not inside a task + intercept[IllegalArgumentException](Statistics.chiSqTest(Matrices.dense(2, 2, colZero))) + } + + test("chi squared pearson RDD[LabeledPoint]") { + // labels: 1.0 (2 / 6), 0.0 (4 / 6) + // feature1: 0.5 (1 / 6), 1.5 (2 / 6), 3.5 (3 / 6) + // feature2: 10.0 (1 / 6), 20.0 (1 / 6), 30.0 (2 / 6), 40.0 (2 / 6) + val data = Array(new LabeledPoint(0.0, Vectors.dense(0.5, 10.0)), + new LabeledPoint(0.0, Vectors.dense(1.5, 20.0)), + new LabeledPoint(1.0, Vectors.dense(1.5, 30.0)), + new LabeledPoint(0.0, Vectors.dense(3.5, 30.0)), + new LabeledPoint(0.0, Vectors.dense(3.5, 40.0)), + new LabeledPoint(1.0, Vectors.dense(3.5, 40.0))) + for (numParts <- List(2, 4, 6, 8)) { + val chi = Statistics.chiSqTest(sc.parallelize(data, numParts)) + val feature1 = chi(0) + assert(feature1.statistic === 0.75) + assert(feature1.degreesOfFreedom === 2) + assert(feature1.pValue ~== 0.6873 relTol 1e-4) + assert(feature1.method === ChiSqTest.PEARSON.name) + assert(feature1.nullHypothesis === ChiSqTest.NullHypothesis.independence.toString) + val feature2 = chi(1) + assert(feature2.statistic === 1.5) + assert(feature2.degreesOfFreedom === 3) + assert(feature2.pValue ~== 0.6823 relTol 1e-4) + assert(feature2.method === ChiSqTest.PEARSON.name) + assert(feature2.nullHypothesis === ChiSqTest.NullHypothesis.independence.toString) + } + + // Test that the right number of results is returned + val numCols = 321 + val sparseData = Array(new LabeledPoint(0.0, Vectors.sparse(numCols, Seq((100, 2.0)))), + new LabeledPoint(0.0, Vectors.sparse(numCols, Seq((200, 1.0))))) + val chi = Statistics.chiSqTest(sc.parallelize(sparseData)) + assert(chi.size === numCols) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 70ca7c8a266f2..a5c49a38dc08f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -21,11 +21,12 @@ import scala.collection.JavaConverters._ import org.scalatest.FunSuite -import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} -import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split} -import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy} +import org.apache.spark.mllib.tree.impl.TreePoint +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} +import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LocalSparkContext import org.apache.spark.mllib.regression.LabeledPoint @@ -41,7 +42,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { prediction != expected.label } val accuracy = (input.length - numOffPredictions).toDouble / input.length - assert(accuracy >= requiredAccuracy) + assert(accuracy >= requiredAccuracy, + s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.") } def validateRegressor( @@ -54,7 +56,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { err * err }.sum val mse = squaredError / input.length - assert(mse <= requiredMSE) + assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.") } test("split and bin calculation") { @@ -427,7 +429,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) - val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0, Array[List[Filter]](), splits, bins, 10) val split = bestSplits(0)._1 @@ -454,7 +457,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) - val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0, Array[List[Filter]](), splits, bins, 10) val split = bestSplits(0)._1 @@ -499,7 +503,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -521,7 +526,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -544,7 +550,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -567,7 +574,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -596,7 +604,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val parentImpurities = Array(0.5, 0.5, 0.5) // Single group second level tree construction. - val bestSplits = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, filters, + val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1, filters, splits, bins, 10) assert(bestSplits.length === 2) assert(bestSplits(0)._2.gain > 0) @@ -604,7 +613,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second // level tree construction. - val bestSplitsWithGroups = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, + val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1, filters, splits, bins, 0) assert(bestSplitsWithGroups.length === 2) assert(bestSplitsWithGroups(0)._2.gain > 0) @@ -630,7 +639,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) assert(strategy.isMulticlassClassification) val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) @@ -689,7 +699,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(model.depth === 1) val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) @@ -714,7 +725,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { validateClassifier(model, arr, 0.9) val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) @@ -738,7 +750,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { validateClassifier(model, arr, 0.9) val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) @@ -757,7 +770,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) assert(strategy.isMulticlassClassification) val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b4653c72c10b5..1e3c760b845de 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -111,9 +111,15 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser$") ) ++ - Seq ( // package-private classes removed in MLlib + Seq( // package-private classes removed in MLlib ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm.org$apache$spark$mllib$regression$GeneralizedLinearAlgorithm$$prependOne") + ) ++ + Seq( // new Vector methods in MLlib (binary compatible assuming users do not implement Vector) + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.copy") + ) ++ + Seq ( // Scala 2.11 compatibility fix + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.$default$2") ) case v if v.startsWith("1.0") => Seq( diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index e73538baf0b93..22ab8d30c0ae3 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -22,7 +22,8 @@ import socket import sys import traceback -from errno import EINTR, ECHILD +import time +from errno import EINTR, ECHILD, EAGAIN from socket import AF_INET, SOCK_STREAM, SOMAXCONN from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN from pyspark.worker import main as worker_main @@ -80,6 +81,17 @@ def waitSocketClose(sock): os._exit(compute_real_exit_code(exit_code)) +# Cleanup zombie children +def cleanup_dead_children(): + try: + while True: + pid, _ = os.waitpid(0, os.WNOHANG) + if not pid: + break + except: + pass + + def manager(): # Create a new process group to corral our children os.setpgid(0, 0) @@ -102,29 +114,21 @@ def handle_sigterm(*args): signal.signal(SIGTERM, handle_sigterm) # Gracefully exit on SIGTERM signal.signal(SIGHUP, SIG_IGN) # Don't die on SIGHUP - # Cleanup zombie children - def handle_sigchld(*args): - try: - pid, status = os.waitpid(0, os.WNOHANG) - if status != 0: - msg = "worker %s crashed abruptly with exit status %s" % (pid, status) - print >> sys.stderr, msg - except EnvironmentError as err: - if err.errno not in (ECHILD, EINTR): - raise - signal.signal(SIGCHLD, handle_sigchld) - # Initialization complete sys.stdout.close() try: while True: try: - ready_fds = select.select([0, listen_sock], [], [])[0] + ready_fds = select.select([0, listen_sock], [], [], 1)[0] except select.error as ex: if ex[0] == EINTR: continue else: raise + + # cleanup in signal handler will cause deadlock + cleanup_dead_children() + if 0 in ready_fds: try: worker_pid = read_int(sys.stdin) @@ -137,29 +141,41 @@ def handle_sigchld(*args): pass # process already died if listen_sock in ready_fds: - sock, addr = listen_sock.accept() + try: + sock, _ = listen_sock.accept() + except OSError as e: + if e.errno == EINTR: + continue + raise + # Launch a worker process try: pid = os.fork() - if pid == 0: - listen_sock.close() - try: - worker(sock) - except: - traceback.print_exc() - os._exit(1) - else: - os._exit(0) + except OSError as e: + if e.errno in (EAGAIN, EINTR): + time.sleep(1) + pid = os.fork() # error here will shutdown daemon else: + outfile = sock.makefile('w') + write_int(e.errno, outfile) # Signal that the fork failed + outfile.flush() + outfile.close() sock.close() - - except OSError as e: - print >> sys.stderr, "Daemon failed to fork PySpark worker: %s" % e - outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) - write_int(-1, outfile) # Signal that the fork failed - outfile.flush() - outfile.close() + continue + + if pid == 0: + # in child process + listen_sock.close() + try: + worker(sock) + except: + traceback.print_exc() + os._exit(1) + else: + os._exit(0) + else: sock.close() + finally: shutdown(1) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 37386ab0d7d49..c7f7c1fe591b0 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -39,7 +39,7 @@ def launch_gateway(): submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS") submit_args = submit_args if submit_args is not None else "" submit_args = shlex.split(submit_args) - command = [os.path.join(SPARK_HOME, script), "pyspark-shell"] + submit_args + command = [os.path.join(SPARK_HOME, script)] + submit_args + ["pyspark-shell"] if not on_windows: # Don't send ctrl-c / SIGINT to the Java gateway: def preexec_func(): diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index db341da85f865..bb60d3d0c8463 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -16,6 +16,7 @@ # import struct +import sys import numpy from numpy import ndarray, float64, int64, int32, array_equal, array from pyspark import SparkContext, RDD @@ -78,6 +79,14 @@ LABELED_POINT_MAGIC = 4 +# Workaround for SPARK-2954: before Python 2.7, struct.unpack couldn't unpack bytearray()s. +if sys.version_info[:2] <= (2, 6): + def _unpack(fmt, string): + return struct.unpack(fmt, buffer(string)) +else: + _unpack = struct.unpack + + def _deserialize_numpy_array(shape, ba, offset, dtype=float64): """ Deserialize a numpy array of the given type from an offset in @@ -191,7 +200,7 @@ def _deserialize_double(ba, offset=0): raise TypeError("_deserialize_double called on a %s; wanted bytearray" % type(ba)) if len(ba) - offset != 8: raise TypeError("_deserialize_double called on a %d-byte array; wanted 8 bytes." % nb) - return struct.unpack("d", ba[offset:])[0] + return _unpack("d", ba[offset:])[0] def _deserialize_double_vector(ba, offset=0): diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py index 982906b9d09f0..a73abc5ff90df 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat.py @@ -22,11 +22,75 @@ from pyspark.mllib._common import \ _get_unmangled_double_vector_rdd, _get_unmangled_rdd, \ _serialize_double, _serialize_double_vector, \ - _deserialize_double, _deserialize_double_matrix + _deserialize_double, _deserialize_double_matrix, _deserialize_double_vector + + +class MultivariateStatisticalSummary(object): + + """ + Trait for multivariate statistical summary of a data matrix. + """ + + def __init__(self, sc, java_summary): + """ + :param sc: Spark context + :param java_summary: Handle to Java summary object + """ + self._sc = sc + self._java_summary = java_summary + + def __del__(self): + self._sc._gateway.detach(self._java_summary) + + def mean(self): + return _deserialize_double_vector(self._java_summary.mean()) + + def variance(self): + return _deserialize_double_vector(self._java_summary.variance()) + + def count(self): + return self._java_summary.count() + + def numNonzeros(self): + return _deserialize_double_vector(self._java_summary.numNonzeros()) + + def max(self): + return _deserialize_double_vector(self._java_summary.max()) + + def min(self): + return _deserialize_double_vector(self._java_summary.min()) class Statistics(object): + @staticmethod + def colStats(X): + """ + Computes column-wise summary statistics for the input RDD[Vector]. + + >>> from linalg import Vectors + >>> rdd = sc.parallelize([Vectors.dense([2, 0, 0, -2]), + ... Vectors.dense([4, 5, 0, 3]), + ... Vectors.dense([6, 7, 0, 8])]) + >>> cStats = Statistics.colStats(rdd) + >>> cStats.mean() + array([ 4., 4., 0., 3.]) + >>> cStats.variance() + array([ 4., 13., 0., 25.]) + >>> cStats.count() + 3L + >>> cStats.numNonzeros() + array([ 3., 2., 0., 3.]) + >>> cStats.max() + array([ 6., 7., 0., 8.]) + >>> cStats.min() + array([ 2., 0., 0., -2.]) + """ + sc = X.ctx + Xser = _get_unmangled_double_vector_rdd(X) + cStats = sc._jvm.PythonMLLibAPI().colStats(Xser._jrdd) + return MultivariateStatisticalSummary(sc, cStats) + @staticmethod def corr(x, y=None, method=None): """ diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 6f3ec8ac94bac..8a851bd35c0e8 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -19,8 +19,13 @@ Fuller unit tests for Python MLlib. """ +import sys from numpy import array, array_equal -import unittest + +if sys.version_info[:2] <= (2, 6): + import unittest2 as unittest +else: + import unittest from pyspark.mllib._common import _convert_vector, _serialize_double_vector, \ _deserialize_double_vector, _dot, _squared_distance diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 756e8f35fb03d..3934bdda0a466 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -30,6 +30,7 @@ from threading import Thread import warnings import heapq +import bisect from random import Random from math import sqrt, log @@ -574,6 +575,8 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): # noqa >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)] + >>> sc.parallelize(tmp).sortByKey(True, 1).collect() + [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)] >>> sc.parallelize(tmp).sortByKey(True, 2).collect() [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)] >>> tmp2 = [('Mary', 1), ('had', 2), ('a', 3), ('little', 4), ('lamb', 5)] @@ -584,42 +587,40 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): if numPartitions is None: numPartitions = self._defaultReducePartitions() - bounds = list() + if numPartitions == 1: + if self.getNumPartitions() > 1: + self = self.coalesce(1) + + def sort(iterator): + return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k)) + + return self.mapPartitions(sort) # first compute the boundary of each part via sampling: we want to partition # the key-space into bins such that the bins have roughly the same # number of (key, value) pairs falling into them - if numPartitions > 1: - rddSize = self.count() - # constant from Spark's RangePartitioner - maxSampleSize = numPartitions * 20.0 - fraction = min(maxSampleSize / max(rddSize, 1), 1.0) - - samples = self.sample(False, fraction, 1).map( - lambda (k, v): k).collect() - samples = sorted(samples, reverse=(not ascending), key=keyfunc) - - # we have numPartitions many parts but one of the them has - # an implicit boundary - for i in range(0, numPartitions - 1): - index = (len(samples) - 1) * (i + 1) / numPartitions - bounds.append(samples[index]) + rddSize = self.count() + maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner + fraction = min(maxSampleSize / max(rddSize, 1), 1.0) + samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect() + samples = sorted(samples, reverse=(not ascending), key=keyfunc) + + # we have numPartitions many parts but one of the them has + # an implicit boundary + bounds = [samples[len(samples) * (i + 1) / numPartitions] + for i in range(0, numPartitions - 1)] def rangePartitionFunc(k): - p = 0 - while p < len(bounds) and keyfunc(k) > bounds[p]: - p += 1 + p = bisect.bisect_left(bounds, keyfunc(k)) if ascending: return p else: return numPartitions - 1 - p def mapFunc(iterator): - yield sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k)) + return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k)) - return (self.partitionBy(numPartitions, partitionFunc=rangePartitionFunc) - .mapPartitions(mapFunc, preservesPartitioning=True) - .flatMap(lambda x: x, preservesPartitioning=True)) + return self.partitionBy(numPartitions, rangePartitionFunc).mapPartitions(mapFunc, True) def sortBy(self, keyfunc, ascending=True, numPartitions=None): """ diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index b35558db3e007..df90cafb245bf 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -314,8 +314,8 @@ def _copy_func(f): _old_namedtuple = _copy_func(collections.namedtuple) - def namedtuple(name, fields, verbose=False, rename=False): - cls = _old_namedtuple(name, fields, verbose, rename) + def namedtuple(*args, **kwargs): + cls = _old_namedtuple(*args, **kwargs) return _hack_namedtuple(cls) # replace namedtuple with new one diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 950e275adbf01..95086a2258222 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -498,10 +498,7 @@ def _infer_schema(row): def _create_converter(obj, dataType): """Create an converter to drop the names of fields in obj """ - if not _has_struct(dataType): - return lambda x: x - - elif isinstance(dataType, ArrayType): + if isinstance(dataType, ArrayType): conv = _create_converter(obj[0], dataType.elementType) return lambda row: map(conv, row) @@ -510,6 +507,9 @@ def _create_converter(obj, dataType): conv = _create_converter(value, dataType.valueType) return lambda row: dict((k, conv(v)) for k, v in row.iteritems()) + elif not isinstance(dataType, StructType): + return lambda x: x + # dataType must be StructType names = [f.name for f in dataType.fields] @@ -529,8 +529,7 @@ def _create_converter(obj, dataType): elif hasattr(obj, "__dict__"): # object conv = lambda o: [o.__dict__.get(n, None) for n in names] - nested = any(_has_struct(f.dataType) for f in dataType.fields) - if not nested: + if all(isinstance(f.dataType, PrimitiveType) for f in dataType.fields): return conv row = conv(obj) @@ -912,6 +911,8 @@ def __init__(self, sparkContext, sqlContext=None): """Create a new SQLContext. @param sparkContext: The SparkContext to wrap. + @param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new + SQLContext in the JVM, instead we make all calls to this object. >>> srdd = sqlCtx.inferSchema(rdd) >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL @@ -1035,7 +1036,8 @@ def inferSchema(self, rdd): raise ValueError("The first row in RDD is empty, " "can not infer schema") if type(first) is dict: - warnings.warn("Using RDD of dict to inferSchema is deprecated") + warnings.warn("Using RDD of dict to inferSchema is deprecated," + "please use pyspark.Row instead") schema = _infer_schema(first) rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema)) @@ -1092,7 +1094,7 @@ def applySchema(self, rdd, schema): ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " + ... "float + 1.1 as float FROM table2").collect() - [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.1)] + [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.1...)] >>> rdd = sc.parallelize([(127, -32768, 1.0, ... datetime(2010, 1, 1, 1, 1, 1), @@ -1265,7 +1267,9 @@ def func(iterator): for x in iterator: if not isinstance(x, basestring): x = unicode(x) - yield x.encode("utf-8") + if isinstance(x, unicode): + x = x.encode("utf-8") + yield x keyed = rdd.mapPartitions(func) keyed._bypass_serializer = True jrdd = keyed._jrdd.map(self._jvm.BytesToString()) @@ -1315,6 +1319,18 @@ class HiveContext(SQLContext): It supports running both SQL and HiveQL commands. """ + def __init__(self, sparkContext, hiveContext=None): + """Create a new HiveContext. + + @param sparkContext: The SparkContext to wrap. + @param hiveContext: An optional JVM Scala HiveContext. If set, we do not instatiate a new + HiveContext in the JVM, instead we make all calls to this object. + """ + SQLContext.__init__(self, sparkContext) + + if hiveContext: + self._scala_HiveContext = hiveContext + @property def _ssql_ctx(self): try: diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 88a61176e51ab..22b51110ed671 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -29,9 +29,14 @@ import sys import tempfile import time -import unittest import zipfile +if sys.version_info[:2] <= (2, 6): + import unittest2 as unittest +else: + import unittest + + from pyspark.context import SparkContext from pyspark.files import SparkFiles from pyspark.serializers import read_int @@ -605,6 +610,7 @@ def test_oldhadoop(self): conf=input_conf).collect()) self.assertEqual(old_dataset, dict_data) + @unittest.skipIf(sys.version_info[:2] <= (2, 6), "Skipped on 2.6 until SPARK-2951 is fixed") def test_newhadoop(self): basepath = self.tempdir.name # use custom ArrayWritable types and converters to handle arrays @@ -905,8 +911,9 @@ def createFileInZip(self, name, content): pattern = re.compile(r'^ *\|', re.MULTILINE) content = re.sub(pattern, '', content.strip()) path = os.path.join(self.programDir, name + ".zip") - with zipfile.ZipFile(path, 'w') as zip: - zip.writestr(name, content) + zip = zipfile.ZipFile(path, 'w') + zip.writestr(name, content) + zip.close() return path def test_single_script(self): diff --git a/python/run-tests b/python/run-tests index 48feba2f5bd63..1218edcbd7e08 100755 --- a/python/run-tests +++ b/python/run-tests @@ -48,6 +48,14 @@ function run_test() { echo "Running PySpark tests. Output is in python/unit-tests.log." +# Try to test with Python 2.6, since that's the minimum version that we support: +if [ $(which python2.6) ]; then + export PYSPARK_PYTHON="python2.6" +fi + +echo "Testing with Python version:" +$PYSPARK_PYTHON --version + run_test "pyspark/rdd.py" run_test "pyspark/context.py" run_test "pyspark/conf.py" diff --git a/sbin/start-thriftserver.sh b/sbin/start-thriftserver.sh index 603f50ae13240..2c4452473ccbc 100755 --- a/sbin/start-thriftserver.sh +++ b/sbin/start-thriftserver.sh @@ -65,14 +65,14 @@ while (($#)); do case $1 in --hiveconf) ensure_arg_number $# 2 - THRIFT_SERVER_ARGS+=($1); shift - THRIFT_SERVER_ARGS+=($1); shift + THRIFT_SERVER_ARGS+=("$1"); shift + THRIFT_SERVER_ARGS+=("$1"); shift ;; *) - SUBMISSION_ARGS+=($1); shift + SUBMISSION_ARGS+=("$1"); shift ;; esac done -eval exec "$FWDIR"/bin/spark-submit --class $CLASS ${SUBMISSION_ARGS[*]} spark-internal ${THRIFT_SERVER_ARGS[*]} +exec "$FWDIR"/bin/spark-submit --class $CLASS "${SUBMISSION_ARGS[@]}" spark-internal "${THRIFT_SERVER_ARGS[@]}" diff --git a/sql/README.md b/sql/README.md index 14d5555f0c713..31f9152344086 100644 --- a/sql/README.md +++ b/sql/README.md @@ -3,10 +3,11 @@ Spark SQL This module provides support for executing relational queries expressed in either SQL or a LINQ-like Scala DSL. -Spark SQL is broken up into three subprojects: +Spark SQL is broken up into four subprojects: - Catalyst (sql/catalyst) - An implementation-agnostic framework for manipulating trees of relational operators and expressions. - Execution (sql/core) - A query planner / execution engine for translating Catalyst’s logical query plans into Spark RDDs. This component also includes a new public interface, SQLContext, that allows users to execute SQL or LINQ statements against existing RDDs and Parquet files. - Hive Support (sql/hive) - Includes an extension of SQLContext called HiveContext that allows users to write queries using a subset of HiveQL and access data from a Hive Metastore using Hive SerDes. There are also wrappers that allows users to run queries that include Hive UDFs, UDAFs, and UDTFs. + - HiveServer and CLI support (sql/hive-thriftserver) - Includes support for the SQL CLI (bin/spark-sql) and a HiveServer2 (for JDBC/ODBC) compatible server. Other dependencies for developers diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 58d44e7923bee..830711a46a35b 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -77,28 +77,28 @@ org.apache.maven.plugins maven-jar-plugin - - - test-jar - - - - test-jar-on-compile - compile - - test-jar - - + + + test-jar + + + + test-jar-on-test-compile + test-compile + + test-jar + + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 3d41acb79e5fd..e99c5b452d183 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -86,19 +86,19 @@ case class Explode(attributeNames: Seq[String], child: Expression) (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) private lazy val elementTypes = child.dataType match { - case ArrayType(et, _) => et :: Nil - case MapType(kt,vt, _) => kt :: vt :: Nil + case ArrayType(et, containsNull) => (et, containsNull) :: Nil + case MapType(kt, vt, valueContainsNull) => (kt, false) :: (vt, valueContainsNull) :: Nil } // TODO: Move this pattern into Generator. protected def makeOutput() = if (attributeNames.size == elementTypes.size) { attributeNames.zip(elementTypes).map { - case (n, t) => AttributeReference(n, t, nullable = true)() + case (n, (t, nullable)) => AttributeReference(n, t, nullable)() } } else { elementTypes.zipWithIndex.map { - case (t, i) => AttributeReference(s"c_$i", t, nullable = true)() + case ((t, nullable), i) => AttributeReference(s"c_$i", t, nullable)() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 0fd7aaaa36eb8..90de11182e605 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -25,11 +25,13 @@ import java.util.Properties private[spark] object SQLConf { val COMPRESS_CACHED = "spark.sql.inMemoryColumnarStorage.compressed" + val COLUMN_BATCH_SIZE = "spark.sql.inMemoryColumnarStorage.batchSize" val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold" val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes" val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" val CODEGEN_ENABLED = "spark.sql.codegen" val DIALECT = "spark.sql.dialect" + val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString" object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" @@ -71,6 +73,9 @@ trait SQLConf { /** When true tables cached using the in-memory columnar caching will be compressed. */ private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED, "false").toBoolean + /** The number of rows that will be */ + private[spark] def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE, "1000").toInt + /** Number of partitions to use for shuffle operators. */ private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS, "200").toInt @@ -83,8 +88,7 @@ trait SQLConf { * * Defaults to false as this feature is currently experimental. */ - private[spark] def codegenEnabled: Boolean = - if (getConf(CODEGEN_ENABLED, "false") == "true") true else false + private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, "false").toBoolean /** * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to @@ -104,6 +108,12 @@ trait SQLConf { private[spark] def defaultSizeInBytes: Long = getConf(DEFAULT_SIZE_IN_BYTES, (autoBroadcastJoinThreshold + 1).toString).toLong + /** + * When set to true, we always treat byte arrays in Parquet files as strings. + */ + private[spark] def isParquetBinaryAsString: Boolean = + getConf(PARQUET_BINARY_AS_STRING, "false").toBoolean + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 71d338d21d0f2..af9f7c62a1d25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -273,7 +273,7 @@ class SQLContext(@transient val sparkContext: SparkContext) currentTable.logicalPlan case _ => - InMemoryRelation(useCompression, executePlan(currentTable).executedPlan) + InMemoryRelation(useCompression, columnBatchSize, executePlan(currentTable).executedPlan) } catalog.registerTable(None, tableName, asInMemoryRelation) @@ -284,7 +284,7 @@ class SQLContext(@transient val sparkContext: SparkContext) table(tableName).queryExecution.analyzed match { // This is kind of a hack to make sure that if this was just an RDD registered as a table, // we reregister the RDD as a table. - case inMem @ InMemoryRelation(_, _, e: ExistingRdd) => + case inMem @ InMemoryRelation(_, _, _, e: ExistingRdd) => inMem.cachedColumnBuffers.unpersist() catalog.unregisterTable(None, tableName) catalog.registerTable(None, tableName, SparkLogicalPlan(e)(self)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 88901debbb4e9..e63b4903041f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -20,21 +20,21 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.{SparkPlan, LeafNode} -import org.apache.spark.sql.Row -import org.apache.spark.SparkConf +import org.apache.spark.sql.execution.{LeafNode, SparkPlan} object InMemoryRelation { - def apply(useCompression: Boolean, child: SparkPlan): InMemoryRelation = - new InMemoryRelation(child.output, useCompression, child)() + def apply(useCompression: Boolean, batchSize: Int, child: SparkPlan): InMemoryRelation = + new InMemoryRelation(child.output, useCompression, batchSize, child)() } private[sql] case class InMemoryRelation( output: Seq[Attribute], useCompression: Boolean, + batchSize: Int, child: SparkPlan) (private var _cachedColumnBuffers: RDD[Array[ByteBuffer]] = null) extends LogicalPlan with MultiInstanceRelation { @@ -43,22 +43,33 @@ private[sql] case class InMemoryRelation( // As in Spark, the actual work of caching is lazy. if (_cachedColumnBuffers == null) { val output = child.output - val cached = child.execute().mapPartitions { iterator => - val columnBuilders = output.map { attribute => - ColumnBuilder(ColumnType(attribute.dataType).typeId, 0, attribute.name, useCompression) - }.toArray - - var row: Row = null - while (iterator.hasNext) { - row = iterator.next() - var i = 0 - while (i < row.length) { - columnBuilders(i).appendFrom(row, i) - i += 1 + val cached = child.execute().mapPartitions { baseIterator => + new Iterator[Array[ByteBuffer]] { + def next() = { + val columnBuilders = output.map { attribute => + val columnType = ColumnType(attribute.dataType) + val initialBufferSize = columnType.defaultSize * batchSize + ColumnBuilder(columnType.typeId, initialBufferSize, attribute.name, useCompression) + }.toArray + + var row: Row = null + var rowCount = 0 + + while (baseIterator.hasNext && rowCount < batchSize) { + row = baseIterator.next() + var i = 0 + while (i < row.length) { + columnBuilders(i).appendFrom(row, i) + i += 1 + } + rowCount += 1 + } + + columnBuilders.map(_.build()) } - } - Iterator.single(columnBuilders.map(_.build())) + def hasNext = baseIterator.hasNext + } }.cache() cached.setName(child.toString) @@ -74,6 +85,7 @@ private[sql] case class InMemoryRelation( new InMemoryRelation( output.map(_.newInstance), useCompression, + batchSize, child)( _cachedColumnBuffers).asInstanceOf[this.type] } @@ -90,22 +102,31 @@ private[sql] case class InMemoryColumnarTableScan( override def execute() = { relation.cachedColumnBuffers.mapPartitions { iterator => - val columnBuffers = iterator.next() - assert(!iterator.hasNext) + // Find the ordinals of the requested columns. If none are requested, use the first. + val requestedColumns = + if (attributes.isEmpty) { + Seq(0) + } else { + attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId)) + } new Iterator[Row] { - // Find the ordinals of the requested columns. If none are requested, use the first. - val requestedColumns = - if (attributes.isEmpty) { - Seq(0) - } else { - attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId)) - } + private[this] var columnBuffers: Array[ByteBuffer] = null + private[this] var columnAccessors: Seq[ColumnAccessor] = null + nextBatch() + + private[this] val nextRow = new GenericMutableRow(columnAccessors.length) - val columnAccessors = requestedColumns.map(columnBuffers(_)).map(ColumnAccessor(_)) - val nextRow = new GenericMutableRow(columnAccessors.length) + def nextBatch() = { + columnBuffers = iterator.next() + columnAccessors = requestedColumns.map(columnBuffers(_)).map(ColumnAccessor(_)) + } override def next() = { + if (!columnAccessors.head.hasNext) { + nextBatch() + } + var i = 0 while (i < nextRow.length) { columnAccessors(i).extractTo(nextRow, i) @@ -114,7 +135,7 @@ private[sql] case class InMemoryColumnarTableScan( nextRow } - override def hasNext = columnAccessors.head.hasNext + override def hasNext = columnAccessors.head.hasNext || iterator.hasNext } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 51bb61530744c..c86811e838bd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution +import java.util.{HashMap => JavaHashMap} + import scala.collection.mutable.{ArrayBuffer, BitSet} import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent._ @@ -136,14 +138,6 @@ trait HashJoin { } } -/** - * Constant Value for Binary Join Node - */ -object HashOuterJoin { - val DUMMY_LIST = Seq[Row](null) - val EMPTY_LIST = Seq[Row]() -} - /** * :: DeveloperApi :: * Performs a hash based outer join for two child relations by shuffling the data using @@ -168,7 +162,21 @@ case class HashOuterJoin( override def requiredChildDistribution = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - def output = left.output ++ right.output + override def output = { + joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case x => + throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + } + } + + @transient private[this] lazy val DUMMY_LIST = Seq[Row](null) + @transient private[this] lazy val EMPTY_LIST = Seq.empty[Row] // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala // iterator for performance purpose. @@ -188,8 +196,8 @@ case class HashOuterJoin( joinedRow.copy } else { Nil - }) ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => { - // HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, // as we don't know whether we need to append it until finish iterating all of the // records in right side. // If we didn't get any proper row, then append a single row with empty right @@ -213,8 +221,8 @@ case class HashOuterJoin( joinedRow.copy } else { Nil - }) ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => { - // HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, // as we don't know whether we need to append it until finish iterating all of the // records in left side. // If we didn't get any proper row, then append a single row with empty left. @@ -248,10 +256,10 @@ case class HashOuterJoin( rightMatchedSet.add(idx) joinedRow.copy } - } ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => { + } ++ DUMMY_LIST.filter(_ => !matched).map( _ => { // 2. For those unmatched records in left, append additional records with empty right. - // HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, // as we don't know whether we need to append it until finish iterating all // of the records in right side. // If we didn't get any proper row, then append a single row with empty right. @@ -276,18 +284,22 @@ case class HashOuterJoin( } private[this] def buildHashTable( - iter: Iterator[Row], keyGenerator: Projection): Map[Row, ArrayBuffer[Row]] = { - // TODO: Use Spark's HashMap implementation. - val hashTable = scala.collection.mutable.Map[Row, ArrayBuffer[Row]]() + iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, ArrayBuffer[Row]] = { + val hashTable = new JavaHashMap[Row, ArrayBuffer[Row]]() while (iter.hasNext) { val currentRow = iter.next() val rowKey = keyGenerator(currentRow) - val existingMatchList = hashTable.getOrElseUpdate(rowKey, {new ArrayBuffer[Row]()}) + var existingMatchList = hashTable.get(rowKey) + if (existingMatchList == null) { + existingMatchList = new ArrayBuffer[Row]() + hashTable.put(rowKey, existingMatchList) + } + existingMatchList += currentRow.copy() } - - hashTable.toMap[Row, ArrayBuffer[Row]] + + hashTable } def execute() = { @@ -298,21 +310,22 @@ case class HashOuterJoin( // Build HashMap for current partition in right relation val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + import scala.collection.JavaConversions._ val boundCondition = condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) joinType match { case LeftOuter => leftHashTable.keysIterator.flatMap { key => - leftOuterIterator(key, leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST), - rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST)) + leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) } case RightOuter => rightHashTable.keysIterator.flatMap { key => - rightOuterIterator(key, leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST), - rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST)) + rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) } case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => fullOuterIterator(key, - leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST), - rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST)) + leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) } case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index cc575bedd8fcb..2298a9b933df5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -201,8 +201,9 @@ object ParquetFilters { (leftFilter, rightFilter) match { case (None, Some(filter)) => Some(filter) case (Some(filter), None) => Some(filter) - case (_, _) => - Some(new AndFilter(leftFilter.get, rightFilter.get)) + case (Some(leftF), Some(rightF)) => + Some(new AndFilter(leftF, rightF)) + case _ => None } } case p @ EqualTo(left: Literal, right: NamedExpression) if !right.nullable => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index b3bae5db0edbc..053b2a154389c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -60,7 +60,11 @@ private[sql] case class ParquetRelation( .getSchema /** Attributes */ - override val output = ParquetTypesConverter.readSchemaFromFile(new Path(path), conf) + override val output = + ParquetTypesConverter.readSchemaFromFile( + new Path(path), + conf, + sqlContext.isParquetBinaryAsString) override def newInstance = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 6d4ce32ac5bfa..6a657c20fe46c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -80,9 +80,10 @@ private[parquet] class RowReadSupport extends ReadSupport[Row] with Logging { } } // if both unavailable, fall back to deducing the schema from the given Parquet schema + // TODO: Why it can be null? if (schema == null) { log.debug("falling back to Parquet read schema") - schema = ParquetTypesConverter.convertToAttributes(parquetSchema) + schema = ParquetTypesConverter.convertToAttributes(parquetSchema, false) } log.debug(s"list of attributes that will be read: $schema") new RowRecordMaterializer(parquetSchema, schema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 2867dc0a8b1f9..c79a9ac2dad81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -43,10 +43,13 @@ private[parquet] object ParquetTypesConverter extends Logging { def isPrimitiveType(ctype: DataType): Boolean = classOf[PrimitiveType] isAssignableFrom ctype.getClass - def toPrimitiveDataType(parquetType: ParquetPrimitiveType): DataType = + def toPrimitiveDataType( + parquetType: ParquetPrimitiveType, + binayAsString: Boolean): DataType = parquetType.getPrimitiveTypeName match { case ParquetPrimitiveTypeName.BINARY - if parquetType.getOriginalType == ParquetOriginalType.UTF8 => StringType + if (parquetType.getOriginalType == ParquetOriginalType.UTF8 || + binayAsString) => StringType case ParquetPrimitiveTypeName.BINARY => BinaryType case ParquetPrimitiveTypeName.BOOLEAN => BooleanType case ParquetPrimitiveTypeName.DOUBLE => DoubleType @@ -85,7 +88,7 @@ private[parquet] object ParquetTypesConverter extends Logging { * @param parquetType The type to convert. * @return The corresponding Catalyst type. */ - def toDataType(parquetType: ParquetType): DataType = { + def toDataType(parquetType: ParquetType, isBinaryAsString: Boolean): DataType = { def correspondsToMap(groupType: ParquetGroupType): Boolean = { if (groupType.getFieldCount != 1 || groupType.getFields.apply(0).isPrimitive) { false @@ -107,7 +110,7 @@ private[parquet] object ParquetTypesConverter extends Logging { } if (parquetType.isPrimitive) { - toPrimitiveDataType(parquetType.asPrimitiveType) + toPrimitiveDataType(parquetType.asPrimitiveType, isBinaryAsString) } else { val groupType = parquetType.asGroupType() parquetType.getOriginalType match { @@ -116,7 +119,7 @@ private[parquet] object ParquetTypesConverter extends Logging { case ParquetOriginalType.LIST => { // TODO: check enums! assert(groupType.getFieldCount == 1) val field = groupType.getFields.apply(0) - ArrayType(toDataType(field), containsNull = false) + ArrayType(toDataType(field, isBinaryAsString), containsNull = false) } case ParquetOriginalType.MAP => { assert( @@ -126,9 +129,9 @@ private[parquet] object ParquetTypesConverter extends Logging { assert( keyValueGroup.getFieldCount == 2, "Parquet Map type malformatted: nested group should have 2 (key, value) fields!") - val keyType = toDataType(keyValueGroup.getFields.apply(0)) + val keyType = toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString) assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) - val valueType = toDataType(keyValueGroup.getFields.apply(1)) + val valueType = toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString) assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED) // TODO: set valueContainsNull explicitly instead of assuming valueContainsNull is true // at here. @@ -138,22 +141,22 @@ private[parquet] object ParquetTypesConverter extends Logging { // Note: the order of these checks is important! if (correspondsToMap(groupType)) { // MapType val keyValueGroup = groupType.getFields.apply(0).asGroupType() - val keyType = toDataType(keyValueGroup.getFields.apply(0)) + val keyType = toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString) assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) - val valueType = toDataType(keyValueGroup.getFields.apply(1)) + val valueType = toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString) assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED) // TODO: set valueContainsNull explicitly instead of assuming valueContainsNull is true // at here. MapType(keyType, valueType) } else if (correspondsToArray(groupType)) { // ArrayType - val elementType = toDataType(groupType.getFields.apply(0)) + val elementType = toDataType(groupType.getFields.apply(0), isBinaryAsString) ArrayType(elementType, containsNull = false) } else { // everything else: StructType val fields = groupType .getFields .map(ptype => new StructField( ptype.getName, - toDataType(ptype), + toDataType(ptype, isBinaryAsString), ptype.getRepetition != Repetition.REQUIRED)) StructType(fields) } @@ -276,7 +279,7 @@ private[parquet] object ParquetTypesConverter extends Logging { } } - def convertToAttributes(parquetSchema: ParquetType): Seq[Attribute] = { + def convertToAttributes(parquetSchema: ParquetType, isBinaryAsString: Boolean): Seq[Attribute] = { parquetSchema .asGroupType() .getFields @@ -284,7 +287,7 @@ private[parquet] object ParquetTypesConverter extends Logging { field => new AttributeReference( field.getName, - toDataType(field), + toDataType(field, isBinaryAsString), field.getRepetition != Repetition.REQUIRED)()) } @@ -403,7 +406,10 @@ private[parquet] object ParquetTypesConverter extends Logging { * @param conf The Hadoop configuration to use. * @return A list of attributes that make up the schema. */ - def readSchemaFromFile(origPath: Path, conf: Option[Configuration]): Seq[Attribute] = { + def readSchemaFromFile( + origPath: Path, + conf: Option[Configuration], + isBinaryAsString: Boolean): Seq[Attribute] = { val keyValueMetadata: java.util.Map[String, String] = readMetaData(origPath, conf) .getFileMetaData @@ -412,7 +418,7 @@ private[parquet] object ParquetTypesConverter extends Logging { convertFromString(keyValueMetadata.get(RowReadSupport.SPARK_METADATA_KEY)) } else { val attributes = convertToAttributes( - readMetaData(origPath, conf).getFileMetaData.getSchema) + readMetaData(origPath, conf).getFileMetaData.getSchema, isBinaryAsString) log.info(s"Falling back to schema conversion from Parquet types; result: $attributes") attributes } diff --git a/sql/core/src/test/resources/log4j.properties b/sql/core/src/test/resources/log4j.properties index dffd15a61838b..c7e0ff1cf6494 100644 --- a/sql/core/src/test/resources/log4j.properties +++ b/sql/core/src/test/resources/log4j.properties @@ -36,6 +36,9 @@ log4j.appender.FA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %p %c{1}: %m%n log4j.appender.FA.Threshold = INFO # Some packages are noisy for no good reason. +log4j.additivity.parquet.hadoop.ParquetRecordReader=false +log4j.logger.parquet.hadoop.ParquetRecordReader=OFF + log4j.additivity.org.apache.hadoop.hive.serde2.lazy.LazyStruct=false log4j.logger.org.apache.hadoop.hive.serde2.lazy.LazyStruct=OFF diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index fbf9bd9dbcdea..befef46d93973 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -22,9 +22,19 @@ import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableSca import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ +case class BigData(s: String) + class CachedTableSuite extends QueryTest { TestData // Load test tables. + test("too big for memory") { + val data = "*" * 10000 + sparkContext.parallelize(1 to 1000000, 1).map(_ => BigData(data)).registerTempTable("bigData") + cacheTable("bigData") + assert(table("bigData").count() === 1000000L) + uncacheTable("bigData") + } + test("SPARK-1669: cacheTable should be idempotent") { assume(!table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) @@ -37,7 +47,7 @@ class CachedTableSuite extends QueryTest { cacheTable("testData") table("testData").queryExecution.analyzed match { - case InMemoryRelation(_, _, _: InMemoryColumnarTableScan) => + case InMemoryRelation(_, _, _, _: InMemoryColumnarTableScan) => fail("cacheTable is not idempotent") case _ => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index b561b44ad7ee2..736c0f8571e9e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -28,14 +28,14 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("simple columnar query") { val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan - val scan = InMemoryRelation(useCompression = true, plan) + val scan = InMemoryRelation(useCompression = true, 5, plan) checkAnswer(scan, testData.collect().toSeq) } test("projection") { val plan = TestSQLContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan - val scan = InMemoryRelation(useCompression = true, plan) + val scan = InMemoryRelation(useCompression = true, 5, plan) checkAnswer(scan, testData.collect().map { case Row(key: Int, value: String) => value -> key @@ -44,7 +44,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan - val scan = InMemoryRelation(useCompression = true, plan) + val scan = InMemoryRelation(useCompression = true, 5, plan) checkAnswer(scan, testData.collect().toSeq) checkAnswer(scan, testData.collect().toSeq) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 9933575038bd3..172dcd6aa0ee3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -21,8 +21,6 @@ import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} import parquet.hadoop.ParquetFileWriter import parquet.hadoop.util.ContextUtil -import parquet.schema.MessageTypeParser - import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.Job @@ -33,7 +31,6 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{BooleanType, IntegerType} import org.apache.spark.sql.catalyst.util.getTempFilePath -import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.util.Utils @@ -138,6 +135,57 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } } + test("Treat binary as string") { + val oldIsParquetBinaryAsString = TestSQLContext.isParquetBinaryAsString + + // Create the test file. + val file = getTempFilePath("parquet") + val path = file.toString + val range = (0 to 255) + val rowRDD = TestSQLContext.sparkContext.parallelize(range) + .map(i => org.apache.spark.sql.Row(i, s"val_$i".getBytes)) + // We need to ask Parquet to store the String column as a Binary column. + val schema = StructType( + StructField("c1", IntegerType, false) :: + StructField("c2", BinaryType, false) :: Nil) + val schemaRDD1 = applySchema(rowRDD, schema) + schemaRDD1.saveAsParquetFile(path) + val resultWithBinary = parquetFile(path).collect + range.foreach { + i => + assert(resultWithBinary(i).getInt(0) === i) + assert(resultWithBinary(i)(1) === s"val_$i".getBytes) + } + + TestSQLContext.setConf(SQLConf.PARQUET_BINARY_AS_STRING, "true") + // This ParquetRelation always use Parquet types to derive output. + val parquetRelation = new ParquetRelation( + path.toString, + Some(TestSQLContext.sparkContext.hadoopConfiguration), + TestSQLContext) { + override val output = + ParquetTypesConverter.convertToAttributes( + ParquetTypesConverter.readMetaData(new Path(path), conf).getFileMetaData.getSchema, + TestSQLContext.isParquetBinaryAsString) + } + val schemaRDD = new SchemaRDD(TestSQLContext, parquetRelation) + val resultWithString = schemaRDD.collect + range.foreach { + i => + assert(resultWithString(i).getInt(0) === i) + assert(resultWithString(i)(1) === s"val_$i") + } + + schemaRDD.registerTempTable("tmp") + checkAnswer( + sql("SELECT c1, c2 FROM tmp WHERE c2 = 'val_5' OR c2 = 'val_7'"), + (5, "val_5") :: + (7, "val_7") :: Nil) + + // Set it back. + TestSQLContext.setConf(SQLConf.PARQUET_BINARY_AS_STRING, oldIsParquetBinaryAsString.toString) + } + test("Read/Write All Types with non-primitive type") { val tempDir = getTempFilePath("parquetTest").getCanonicalPath val range = (0 to 255) @@ -381,11 +429,14 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val predicate5 = new GreaterThan(attribute1, attribute2) val badfilter = ParquetFilters.createFilter(predicate5) assert(badfilter.isDefined === false) + + val predicate6 = And(GreaterThan(attribute1, attribute2), GreaterThan(attribute1, attribute2)) + val badfilter2 = ParquetFilters.createFilter(predicate6) + assert(badfilter2.isDefined === false) } test("test filter by predicate pushdown") { for(myval <- Seq("myint", "mylong", "mydouble", "myfloat")) { - println(s"testing field $myval") val query1 = sql(s"SELECT * FROM testfiltersource WHERE $myval < 150 AND $myval >= 100") assert( query1.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 6f7942aba314a..cadf7aaf42157 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -60,7 +60,7 @@ private[hive] object HiveThriftServer2 extends Logging { Runtime.getRuntime.addShutdownHook( new Thread() { override def run() { - SparkSQLEnv.sparkContext.stop() + SparkSQLEnv.stop() } } ) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 4d0c506c5a397..c16a7d3661c66 100755 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -26,13 +26,15 @@ import jline.{ConsoleReader, History} import org.apache.commons.lang.StringUtils import org.apache.commons.logging.LogFactory import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.util.ShutdownHookManager import org.apache.hadoop.hive.cli.{CliDriver, CliSessionState, OptionsProcessor} import org.apache.hadoop.hive.common.LogUtils.LogInitializationException import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils, LogUtils} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities -import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory} +import org.apache.hadoop.hive.ql.processors.{SetProcessor, CommandProcessor, CommandProcessorFactory} import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.shims.ShimLoader import org.apache.thrift.transport.TSocket @@ -116,13 +118,17 @@ private[hive] object SparkSQLCLIDriver { SessionState.start(sessionState) // Clean up after we exit - Runtime.getRuntime.addShutdownHook( + /** + * This should be executed before shutdown hook of + * FileSystem to avoid race condition of FileSystem operation + */ + ShutdownHookManager.get.addShutdownHook( new Thread() { override def run() { SparkSQLEnv.stop() } } - ) + , FileSystem.SHUTDOWN_HOOK_PRIORITY - 1) // "-h" option has been passed, so connect to Hive thrift server. if (sessionState.getHost != null) { @@ -278,7 +284,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { val proc: CommandProcessor = CommandProcessorFactory.get(tokens(0), hconf) if (proc != null) { - if (proc.isInstanceOf[Driver]) { + if (proc.isInstanceOf[Driver] || proc.isInstanceOf[SetProcessor]) { val driver = new SparkSQLDriver driver.init() diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index dee092159dd4c..9338e8121b0fe 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -73,35 +73,10 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage var curCol = 0 while (curCol < sparkRow.length) { - dataTypes(curCol) match { - case StringType => - row.addString(sparkRow(curCol).asInstanceOf[String]) - case IntegerType => - row.addColumnValue(ColumnValue.intValue(sparkRow.getInt(curCol))) - case BooleanType => - row.addColumnValue(ColumnValue.booleanValue(sparkRow.getBoolean(curCol))) - case DoubleType => - row.addColumnValue(ColumnValue.doubleValue(sparkRow.getDouble(curCol))) - case FloatType => - row.addColumnValue(ColumnValue.floatValue(sparkRow.getFloat(curCol))) - case DecimalType => - val hiveDecimal = sparkRow.get(curCol).asInstanceOf[BigDecimal].bigDecimal - row.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal))) - case LongType => - row.addColumnValue(ColumnValue.longValue(sparkRow.getLong(curCol))) - case ByteType => - row.addColumnValue(ColumnValue.byteValue(sparkRow.getByte(curCol))) - case ShortType => - row.addColumnValue(ColumnValue.intValue(sparkRow.getShort(curCol))) - case TimestampType => - row.addColumnValue( - ColumnValue.timestampValue(sparkRow.get(curCol).asInstanceOf[Timestamp])) - case BinaryType | _: ArrayType | _: StructType | _: MapType => - val hiveString = result - .queryExecution - .asInstanceOf[HiveContext#QueryExecution] - .toHiveString((sparkRow.get(curCol), dataTypes(curCol))) - row.addColumnValue(ColumnValue.stringValue(hiveString)) + if (sparkRow.isNullAt(curCol)) { + addNullColumnValue(sparkRow, row, curCol) + } else { + addNonNullColumnValue(sparkRow, row, curCol) } curCol += 1 } @@ -112,6 +87,66 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage } } + def addNonNullColumnValue(from: SparkRow, to: Row, ordinal: Int) { + dataTypes(ordinal) match { + case StringType => + to.addString(from(ordinal).asInstanceOf[String]) + case IntegerType => + to.addColumnValue(ColumnValue.intValue(from.getInt(ordinal))) + case BooleanType => + to.addColumnValue(ColumnValue.booleanValue(from.getBoolean(ordinal))) + case DoubleType => + to.addColumnValue(ColumnValue.doubleValue(from.getDouble(ordinal))) + case FloatType => + to.addColumnValue(ColumnValue.floatValue(from.getFloat(ordinal))) + case DecimalType => + val hiveDecimal = from.get(ordinal).asInstanceOf[BigDecimal].bigDecimal + to.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal))) + case LongType => + to.addColumnValue(ColumnValue.longValue(from.getLong(ordinal))) + case ByteType => + to.addColumnValue(ColumnValue.byteValue(from.getByte(ordinal))) + case ShortType => + to.addColumnValue(ColumnValue.intValue(from.getShort(ordinal))) + case TimestampType => + to.addColumnValue( + ColumnValue.timestampValue(from.get(ordinal).asInstanceOf[Timestamp])) + case BinaryType | _: ArrayType | _: StructType | _: MapType => + val hiveString = result + .queryExecution + .asInstanceOf[HiveContext#QueryExecution] + .toHiveString((from.get(ordinal), dataTypes(ordinal))) + to.addColumnValue(ColumnValue.stringValue(hiveString)) + } + } + + def addNullColumnValue(from: SparkRow, to: Row, ordinal: Int) { + dataTypes(ordinal) match { + case StringType => + to.addString(null) + case IntegerType => + to.addColumnValue(ColumnValue.intValue(null)) + case BooleanType => + to.addColumnValue(ColumnValue.booleanValue(null)) + case DoubleType => + to.addColumnValue(ColumnValue.doubleValue(null)) + case FloatType => + to.addColumnValue(ColumnValue.floatValue(null)) + case DecimalType => + to.addColumnValue(ColumnValue.stringValue(null: HiveDecimal)) + case LongType => + to.addColumnValue(ColumnValue.longValue(null)) + case ByteType => + to.addColumnValue(ColumnValue.byteValue(null)) + case ShortType => + to.addColumnValue(ColumnValue.intValue(null)) + case TimestampType => + to.addColumnValue(ColumnValue.timestampValue(null)) + case BinaryType | _: ArrayType | _: StructType | _: MapType => + to.addColumnValue(ColumnValue.stringValue(null: String)) + } + } + def getResultSetSchema: TableSchema = { logWarning(s"Result Schema: ${result.queryExecution.analyzed.output}") if (result.queryExecution.analyzed.output.size == 0) { @@ -132,7 +167,16 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage logDebug(result.queryExecution.toString()) val groupId = round(random * 1000000).toString hiveContext.sparkContext.setJobGroup(groupId, statement) - iter = result.queryExecution.toRdd.toLocalIterator + iter = { + val resultRdd = result.queryExecution.toRdd + val useIncrementalCollect = + hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean + if (useIncrementalCollect) { + resultRdd.toLocalIterator + } else { + resultRdd.collect().iterator + } + } dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray setHasResultSet(true) } catch { diff --git a/sql/hive-thriftserver/src/test/resources/data/files/small_kv_with_null.txt b/sql/hive-thriftserver/src/test/resources/data/files/small_kv_with_null.txt new file mode 100644 index 0000000000000..ae08c640e6c13 --- /dev/null +++ b/sql/hive-thriftserver/src/test/resources/data/files/small_kv_with_null.txt @@ -0,0 +1,10 @@ +238val_238 + +311val_311 +val_27 +val_165 +val_409 +255val_255 +278val_278 +98val_98 +val_484 diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala index 78bffa2607349..aedef6ce1f5f2 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -113,22 +113,40 @@ class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUt val stmt = createStatement() stmt.execute("DROP TABLE IF EXISTS test") stmt.execute("DROP TABLE IF EXISTS test_cached") - stmt.execute("CREATE TABLE test(key int, val string)") + stmt.execute("CREATE TABLE test(key INT, val STRING)") stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test") - stmt.execute("CREATE TABLE test_cached as select * from test limit 4") + stmt.execute("CREATE TABLE test_cached AS SELECT * FROM test LIMIT 4") stmt.execute("CACHE TABLE test_cached") - var rs = stmt.executeQuery("select count(*) from test") + var rs = stmt.executeQuery("SELECT COUNT(*) FROM test") rs.next() assert(rs.getInt(1) === 5) - rs = stmt.executeQuery("select count(*) from test_cached") + rs = stmt.executeQuery("SELECT COUNT(*) FROM test_cached") rs.next() assert(rs.getInt(1) === 4) stmt.close() } + test("SPARK-3004 regression: result set containing NULL") { + Thread.sleep(5 * 1000) + val dataFilePath = getDataFile("data/files/small_kv_with_null.txt") + val stmt = createStatement() + stmt.execute("DROP TABLE IF EXISTS test_null") + stmt.execute("CREATE TABLE test_null(key INT, val STRING)") + stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test_null") + + val rs = stmt.executeQuery("SELECT * FROM test_null WHERE key IS NULL") + var count = 0 + while (rs.next()) { + count += 1 + } + assert(count === 5) + + stmt.close() + } + def getConnection: Connection = { val connectURI = s"jdbc:hive2://localhost:$PORT/" DriverManager.getConnection(connectURI, System.getProperty("user.name"), "") diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 4fef071161719..210753efe7678 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -635,6 +635,14 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "serde_regex", "serde_reported_schema", "set_variable_sub", + "show_create_table_partitioned", + "show_create_table_delimited", + "show_create_table_alter", + "show_create_table_view", + "show_create_table_serde", + "show_create_table_db_table", + "show_create_table_does_not_exist", + "show_create_table_index", "show_describe_func_quotes", "show_functions", "show_partitions", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 53f3dc11dbb9f..a8da676ffa0e0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -39,7 +39,8 @@ import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.{OverrideFunctionRegistry, Analyzer, OverrideCatalog} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateAnalysisOperators} +import org.apache.spark.sql.catalyst.analysis.{OverrideCatalog, OverrideFunctionRegistry} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.ExtractPythonUdfs import org.apache.spark.sql.execution.QueryExecutionException @@ -119,10 +120,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * in the Hive metastore. */ def analyze(tableName: String) { - val relation = catalog.lookupRelation(None, tableName) match { - case LowerCaseSchema(r) => r - case o => o - } + val relation = EliminateAnalysisOperators(catalog.lookupRelation(None, tableName)) relation match { case relation: MetastoreRelation => { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 354fcd53f303b..943bbaa8ce25e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -71,6 +71,9 @@ private[hive] trait HiveInspectors { case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType case c: Class[_] if c.isArray => ArrayType(javaClassToDataType(c.getComponentType)) + + // Hive seems to return this for struct types? + case c: Class[_] if c == classOf[java.lang.Object] => NullType } /** Converts hive types to native catalyst types. */ @@ -147,7 +150,10 @@ private[hive] trait HiveInspectors { case t: java.sql.Timestamp => t case s: Seq[_] => seqAsJavaList(s.map(wrap)) case m: Map[_,_] => - mapAsJavaMap(m.map { case (k, v) => wrap(k) -> wrap(v) }) + // Some UDFs seem to assume we pass in a HashMap. + val hashMap = new java.util.HashMap[AnyRef, AnyRef]() + hashMap.putAll(m.map { case (k, v) => wrap(k) -> wrap(v) }) + hashMap case null => null } @@ -214,6 +220,12 @@ private[hive] trait HiveInspectors { import TypeInfoFactory._ def toTypeInfo: TypeInfo = dt match { + case ArrayType(elemType, _) => + getListTypeInfo(elemType.toTypeInfo) + case StructType(fields) => + getStructTypeInfo(fields.map(_.name), fields.map(_.dataType.toTypeInfo)) + case MapType(keyType, valueType, _) => + getMapTypeInfo(keyType.toTypeInfo, valueType.toTypeInfo) case BinaryType => binaryTypeInfo case BooleanType => booleanTypeInfo case ByteType => byteTypeInfo diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 82e9c1a248626..3b371211e14cd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -137,7 +137,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with castChildOutput(p, table, child) case p @ logical.InsertIntoTable( - InMemoryRelation(_, _, + InMemoryRelation(_, _, _, HiveTableScan(_, table, _)), _, child, _) => castChildOutput(p, table, child) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 05b2f5f6cd3f7..1d9ba1b24a7a4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -53,6 +53,7 @@ private[hive] object HiveQl { protected val nativeCommands = Seq( "TOK_DESCFUNCTION", "TOK_DESCDATABASE", + "TOK_SHOW_CREATETABLE", "TOK_SHOW_TABLESTATUS", "TOK_SHOWDATABASES", "TOK_SHOWFUNCTIONS", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 85d2496a34cfb..5fcc1bd4b9adf 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -45,7 +45,7 @@ private[hive] trait HiveStrategies { case logical.InsertIntoTable(table: MetastoreRelation, partition, child, overwrite) => InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil case logical.InsertIntoTable( - InMemoryRelation(_, _, + InMemoryRelation(_, _, _, HiveTableScan(_, table, _)), partition, child, overwrite) => InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil case _ => Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index d890df866fbe5..a013f3f7a805f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -70,6 +70,13 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { setConf("hive.metastore.warehouse.dir", warehousePath) } + val testTempDir = File.createTempFile("testTempFiles", "spark.hive.tmp") + testTempDir.delete() + testTempDir.mkdir() + + // For some hive test case which contain ${system:test.tmp.dir} + System.setProperty("test.tmp.dir", testTempDir.getCanonicalPath) + configure() // Must be called before initializing the catalog below. /** The location of the compiled hive distribution */ @@ -109,6 +116,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { hiveFilesTemp.mkdir() hiveFilesTemp.deleteOnExit() + val inRepoTests = if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) { new File("src" + File.separator + "test" + File.separator + "resources" + File.separator) } else { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 179aac5cbd5cd..c6497a15efa0c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -55,7 +55,10 @@ private[hive] abstract class HiveFunctionRegistry HiveSimpleUdf( functionClassName, - children.zip(expectedDataTypes).map { case (e, t) => Cast(e, t) } + children.zip(expectedDataTypes).map { + case (e, NullType) => e + case (e, t) => Cast(e, t) + } ) } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { HiveGenericUdf(functionClassName, children) @@ -115,22 +118,26 @@ private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[ c.getParameterTypes.size == 1 && primitiveClasses.contains(c.getParameterTypes.head) } - val constructor = matchingConstructor.getOrElse( - sys.error(s"No matching wrapper found, options: ${argClass.getConstructors.toSeq}.")) - - (a: Any) => { - logDebug( - s"Wrapping $a of type ${if (a == null) "null" else a.getClass.getName} using $constructor.") - // We must make sure that primitives get boxed java style. - if (a == null) { - null - } else { - constructor.newInstance(a match { - case i: Int => i: java.lang.Integer - case bd: BigDecimal => new HiveDecimal(bd.underlying()) - case other: AnyRef => other - }).asInstanceOf[AnyRef] - } + matchingConstructor match { + case Some(constructor) => + (a: Any) => { + logDebug( + s"Wrapping $a of type ${if (a == null) "null" else a.getClass.getName} $constructor.") + // We must make sure that primitives get boxed java style. + if (a == null) { + null + } else { + constructor.newInstance(a match { + case i: Int => i: java.lang.Integer + case bd: BigDecimal => new HiveDecimal(bd.underlying()) + case other: AnyRef => other + }).asInstanceOf[AnyRef] + } + } + case None => + (a: Any) => a match { + case wrapper => wrap(wrapper) + } } } diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-0-813886d6cf0875c62e89cd1d06b8b0b4 b/sql/hive/src/test/resources/golden/show_create_table_alter-0-813886d6cf0875c62e89cd1d06b8b0b4 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-1-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_alter-1-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..3c1fc128bedce --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_alter-1-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,18 @@ +CREATE EXTERNAL TABLE tmp_showcrt1( + key smallint, + value float) +CLUSTERED BY ( + key) +SORTED BY ( + value DESC) +INTO 5 BUCKETS +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'transient_lastDdlTime'='1407132100') diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-10-259d978ed9543204c8b9c25b6e25b0de b/sql/hive/src/test/resources/golden/show_create_table_alter-10-259d978ed9543204c8b9c25b6e25b0de new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-2-928cc85c025440b731e5ee33e437e404 b/sql/hive/src/test/resources/golden/show_create_table_alter-2-928cc85c025440b731e5ee33e437e404 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..2ece813dd7d56 --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,22 @@ +CREATE TABLE tmp_showcrt1( + key smallint, + value float) +COMMENT 'temporary table' +CLUSTERED BY ( + key) +SORTED BY ( + value DESC) +INTO 5 BUCKETS +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'EXTERNAL'='FALSE', + 'last_modified_by'='tianyi', + 'last_modified_time'='1407132100', + 'transient_lastDdlTime'='1407132100') diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-4-c2cb6a7d942d4dddd1aababccb1239f9 b/sql/hive/src/test/resources/golden/show_create_table_alter-4-c2cb6a7d942d4dddd1aababccb1239f9 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-5-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_alter-5-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..2af657bd29506 --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_alter-5-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,21 @@ +CREATE EXTERNAL TABLE tmp_showcrt1( + key smallint, + value float) +COMMENT 'changed comment' +CLUSTERED BY ( + key) +SORTED BY ( + value DESC) +INTO 5 BUCKETS +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'last_modified_by'='tianyi', + 'last_modified_time'='1407132100', + 'transient_lastDdlTime'='1407132100') diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-6-fdd1bd7f9acf0b2c8c9b7503d4046cb b/sql/hive/src/test/resources/golden/show_create_table_alter-6-fdd1bd7f9acf0b2c8c9b7503d4046cb new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-7-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_alter-7-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..f793ffb7a0bfd --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_alter-7-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,21 @@ +CREATE EXTERNAL TABLE tmp_showcrt1( + key smallint, + value float) +COMMENT 'changed comment' +CLUSTERED BY ( + key) +SORTED BY ( + value DESC) +INTO 5 BUCKETS +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'last_modified_by'='tianyi', + 'last_modified_time'='1407132101', + 'transient_lastDdlTime'='1407132101') diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-8-22ab6ed5b15a018756f454dd2294847e b/sql/hive/src/test/resources/golden/show_create_table_alter-8-22ab6ed5b15a018756f454dd2294847e new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-9-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_alter-9-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..c65aff26a7fc1 --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_alter-9-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,21 @@ +CREATE EXTERNAL TABLE tmp_showcrt1( + key smallint, + value float) +COMMENT 'changed comment' +CLUSTERED BY ( + key) +SORTED BY ( + value DESC) +INTO 5 BUCKETS +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +STORED BY + 'org.apache.hadoop.hive.ql.metadata.DefaultStorageHandler' +WITH SERDEPROPERTIES ( + 'serialization.format'='1') +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'last_modified_by'='tianyi', + 'last_modified_time'='1407132101', + 'transient_lastDdlTime'='1407132101') diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-0-67509558a4b2d39b25787cca33f52635 b/sql/hive/src/test/resources/golden/show_create_table_db_table-0-67509558a4b2d39b25787cca33f52635 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-1-549981e00a3d95f03dd5a9ef6044aa20 b/sql/hive/src/test/resources/golden/show_create_table_db_table-1-549981e00a3d95f03dd5a9ef6044aa20 new file mode 100644 index 0000000000000..707b2ae3ed1df --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_db_table-1-549981e00a3d95f03dd5a9ef6044aa20 @@ -0,0 +1,2 @@ +default +tmp_feng diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-2-34ae7e611d0aedbc62b6e420347abee b/sql/hive/src/test/resources/golden/show_create_table_db_table-2-34ae7e611d0aedbc62b6e420347abee new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-3-7a9e67189d3d4151f23b12c22bde06b5 b/sql/hive/src/test/resources/golden/show_create_table_db_table-3-7a9e67189d3d4151f23b12c22bde06b5 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 b/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 new file mode 100644 index 0000000000000..b5a18368ed85e --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 @@ -0,0 +1,13 @@ +CREATE TABLE tmp_feng.tmp_showcrt( + key string, + value int) +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_feng.db/tmp_showcrt' +TBLPROPERTIES ( + 'transient_lastDdlTime'='1407132107') diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-5-964757b7e7f2a69fe36132c1a5712199 b/sql/hive/src/test/resources/golden/show_create_table_db_table-5-964757b7e7f2a69fe36132c1a5712199 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-6-ac09cf81e7e734cf10406f30b9fa566e b/sql/hive/src/test/resources/golden/show_create_table_db_table-6-ac09cf81e7e734cf10406f30b9fa566e new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_delimited-0-97228478b9925f06726ceebb6571bf34 b/sql/hive/src/test/resources/golden/show_create_table_delimited-0-97228478b9925f06726ceebb6571bf34 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..d36ad25dc8273 --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,17 @@ +CREATE TABLE tmp_showcrt1( + key int, + value string, + newvalue bigint) +ROW FORMAT DELIMITED + FIELDS TERMINATED BY ',' + COLLECTION ITEMS TERMINATED BY '|' + MAP KEYS TERMINATED BY '%' + LINES TERMINATED BY '\n' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/tmp_showcrt1' +TBLPROPERTIES ( + 'transient_lastDdlTime'='1407132730') diff --git a/sql/hive/src/test/resources/golden/show_create_table_delimited-2-259d978ed9543204c8b9c25b6e25b0de b/sql/hive/src/test/resources/golden/show_create_table_delimited-2-259d978ed9543204c8b9c25b6e25b0de new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_partitioned-0-4be9a3b1ff0840786a1f001cba170a0c b/sql/hive/src/test/resources/golden/show_create_table_partitioned-0-4be9a3b1ff0840786a1f001cba170a0c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_partitioned-1-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_partitioned-1-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..9e572c0d7df6a --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_partitioned-1-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,16 @@ +CREATE EXTERNAL TABLE tmp_showcrt1( + key string, + newvalue boolean COMMENT 'a new value') +COMMENT 'temporary table' +PARTITIONED BY ( + value bigint COMMENT 'some value') +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'transient_lastDdlTime'='1407132112') diff --git a/sql/hive/src/test/resources/golden/show_create_table_partitioned-2-259d978ed9543204c8b9c25b6e25b0de b/sql/hive/src/test/resources/golden/show_create_table_partitioned-2-259d978ed9543204c8b9c25b6e25b0de new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_serde-0-33f15d91810b75ee05c7b9dea0abb01c b/sql/hive/src/test/resources/golden/show_create_table_serde-0-33f15d91810b75ee05c7b9dea0abb01c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..69a38e1a7b20a --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,15 @@ +CREATE TABLE tmp_showcrt1( + key int, + value string, + newvalue bigint) +COMMENT 'temporary table' +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe' +STORED AS INPUTFORMAT + 'org.apache.hadoop.hive.ql.io.RCFileInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.RCFileOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'transient_lastDdlTime'='1407132115') diff --git a/sql/hive/src/test/resources/golden/show_create_table_serde-2-259d978ed9543204c8b9c25b6e25b0de b/sql/hive/src/test/resources/golden/show_create_table_serde-2-259d978ed9543204c8b9c25b6e25b0de new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_serde-3-fd12b3e0fe30f5d71c67676791b4a33b b/sql/hive/src/test/resources/golden/show_create_table_serde-3-fd12b3e0fe30f5d71c67676791b4a33b new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_serde-4-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_serde-4-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..b4e693dc622fb --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_serde-4-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,14 @@ +CREATE EXTERNAL TABLE tmp_showcrt1( + key string, + value boolean) +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe' +STORED BY + 'org.apache.hadoop.hive.ql.metadata.DefaultStorageHandler' +WITH SERDEPROPERTIES ( + 'serialization.format'='$', + 'field.delim'=',') +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'transient_lastDdlTime'='1407132115') diff --git a/sql/hive/src/test/resources/golden/show_create_table_serde-5-259d978ed9543204c8b9c25b6e25b0de b/sql/hive/src/test/resources/golden/show_create_table_serde-5-259d978ed9543204c8b9c25b6e25b0de new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_view-0-ecef6821e4e9212e553ca38142fd0250 b/sql/hive/src/test/resources/golden/show_create_table_view-0-ecef6821e4e9212e553ca38142fd0250 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_view-1-1e931ea3fa6065107859ffbb29bb0ed7 b/sql/hive/src/test/resources/golden/show_create_table_view-1-1e931ea3fa6065107859ffbb29bb0ed7 new file mode 100644 index 0000000000000..be3fb3ce30960 --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_view-1-1e931ea3fa6065107859ffbb29bb0ed7 @@ -0,0 +1 @@ +CREATE VIEW tmp_copy_src AS SELECT `src`.`key`, `src`.`value` FROM `default`.`src` diff --git a/sql/hive/src/test/resources/golden/show_create_table_view-2-ed97e9e56d95c5b3db57485cba5ad17f b/sql/hive/src/test/resources/golden/show_create_table_view-2-ed97e9e56d95c5b3db57485cba5ad17f new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 0ebaf6ffd5458..502ce8fb297e9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -161,6 +161,7 @@ abstract class HiveComparisonTest "transient_lastDdlTime", "grantTime", "lastUpdateTime", + "last_modified_by", "last_modified_time", "Owner:", // The following are hive specific schema parameters which we do not need to match exactly. diff --git a/streaming/pom.xml b/streaming/pom.xml index 1072f74aea0d9..ce35520a28609 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -81,11 +81,11 @@ org.apache.maven.plugins @@ -97,8 +97,8 @@ - test-jar-on-compile - compile + test-jar-on-test-compile + test-compile test-jar diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index e0677b795cb94..101cec1c7a7c2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -98,9 +98,15 @@ class StreamingContext private[streaming] ( * @param hadoopConf Optional, configuration object if necessary for reading from * HDFS compatible filesystems */ - def this(path: String, hadoopConf: Configuration = new Configuration) = + def this(path: String, hadoopConf: Configuration) = this(null, CheckpointReader.read(path, new SparkConf(), hadoopConf).get, null) + /** + * Recreate a StreamingContext from a checkpoint file. + * @param path Path to the directory that was specified as the checkpoint directory + */ + def this(path: String) = this(path, new Configuration) + if (sc_ == null && cp_ == null) { throw new Exception("Spark Streaming cannot be initialized with " + "both SparkContext and checkpoint as null") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala index 774adc3c23c21..75f0e8716dc7e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala @@ -23,10 +23,10 @@ import org.apache.spark.metrics.source.Source import org.apache.spark.streaming.ui.StreamingJobProgressListener private[streaming] class StreamingSource(ssc: StreamingContext) extends Source { - val metricRegistry = new MetricRegistry - val sourceName = "%s.StreamingMetrics".format(ssc.sparkContext.appName) + override val metricRegistry = new MetricRegistry + override val sourceName = "%s.StreamingMetrics".format(ssc.sparkContext.appName) - val streamingListener = ssc.uiTab.listener + private val streamingListener = ssc.uiTab.listener private def registerGauge[T](name: String, f: StreamingJobProgressListener => T, defaultValue: T) { diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index f8fb96b312f23..833e249f9f612 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -30,15 +30,15 @@ private[spark] class YarnClientSchedulerBackend( extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) with Logging { - if (conf.getOption("spark.scheduler.minRegisteredExecutorsRatio").isEmpty) { + if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { minRegisteredRatio = 0.8 - ready = false } var client: Client = null var appId: ApplicationId = null var checkerThread: Thread = null var stopping: Boolean = false + var totalExpectedExecutors = 0 private[spark] def addArg(optionName: String, envVar: String, sysProp: String, arrayBuf: ArrayBuffer[String]) { @@ -84,7 +84,7 @@ private[spark] class YarnClientSchedulerBackend( logDebug("ClientArguments called with: " + argsArrayBuf) val args = new ClientArguments(argsArrayBuf.toArray, conf) - totalExpectedExecutors.set(args.numExecutors) + totalExpectedExecutors = args.numExecutors client = new Client(args, conf) appId = client.runApp() waitForApp() @@ -150,4 +150,7 @@ private[spark] class YarnClientSchedulerBackend( logInfo("Stopped") } + override def sufficientResourcesRegistered(): Boolean = { + totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio + } } diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index 0ad1794d19538..55665220a6f96 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -27,19 +27,24 @@ private[spark] class YarnClusterSchedulerBackend( sc: SparkContext) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) { - if (conf.getOption("spark.scheduler.minRegisteredExecutorsRatio").isEmpty) { + var totalExpectedExecutors = 0 + + if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { minRegisteredRatio = 0.8 - ready = false } override def start() { super.start() - var numExecutors = ApplicationMasterArguments.DEFAULT_NUMBER_EXECUTORS + totalExpectedExecutors = ApplicationMasterArguments.DEFAULT_NUMBER_EXECUTORS if (System.getenv("SPARK_EXECUTOR_INSTANCES") != null) { - numExecutors = IntParam.unapply(System.getenv("SPARK_EXECUTOR_INSTANCES")).getOrElse(numExecutors) + totalExpectedExecutors = IntParam.unapply(System.getenv("SPARK_EXECUTOR_INSTANCES")) + .getOrElse(totalExpectedExecutors) } // System property can override environment variable. - numExecutors = sc.getConf.getInt("spark.executor.instances", numExecutors) - totalExpectedExecutors.set(numExecutors) + totalExpectedExecutors = sc.getConf.getInt("spark.executor.instances", totalExpectedExecutors) + } + + override def sufficientResourcesRegistered(): Boolean = { + totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio } }