From 44dd57fb66bb676d753ad8d9757f9f4c03364113 Mon Sep 17 00:00:00 2001
From: Prashant Sharma
Date: Thu, 8 May 2014 10:23:05 -0700
Subject: [PATCH 01/33] SPARK-1565, update examples to be used with spark-submit script.

Commit for initial feedback. Basically, I am curious whether we should prompt the
user to provide args, especially when they are mandatory, and whether we can skip
the prompt when they are not.

A few other things did not work either, for example

`bin/spark-submit examples/target/scala-2.10/spark-examples-1.0.0-SNAPSHOT-hadoop1.0.4.jar --class org.apache.spark.examples.SparkALS --arg 100 500 10 5 2`

Not all the args get passed through properly; maybe I have messed something up and
will try to sort it out.

Author: Prashant Sharma

Closes #552 from ScrapCodes/SPARK-1565/update-examples and squashes the following commits:

669dd23 [Prashant Sharma] Review comments
2727e70 [Prashant Sharma] SPARK-1565, update examples to be used with spark-submit script.
---
 .gitignore                                    |  1 +
 .../scala/org/apache/spark/SparkContext.scala |  8 ++--
 .../org/apache/spark/examples/JavaHdfsLR.java | 13 ++++---
 .../apache/spark/examples/JavaLogQuery.java   | 13 +++----
 .../apache/spark/examples/JavaPageRank.java   | 15 +++++---
 .../apache/spark/examples/JavaSparkPi.java    | 18 ++++-----
 .../org/apache/spark/examples/JavaTC.java     | 24 ++++++------
 .../apache/spark/examples/JavaWordCount.java  | 12 +++---
 .../apache/spark/examples/mllib/JavaALS.java  | 22 +++++------
 .../spark/examples/mllib/JavaKMeans.java      | 22 +++++------
 .../apache/spark/examples/mllib/JavaLR.java   | 18 ++++-----
 .../spark/examples/sql/JavaSparkSQL.java      |  5 ++-
 .../streaming/JavaFlumeEventCount.java        | 19 ++++------
 .../streaming/JavaKafkaWordCount.java         | 27 +++++++-------
 .../streaming/JavaNetworkWordCount.java       | 25 ++++++-------
 .../examples/streaming/JavaQueueStream.java   | 22 +++++------
 .../apache/spark/examples/BroadcastTest.scala | 22 +++++------
 .../spark/examples/CassandraCQLTest.scala     | 19 +++++-----
 .../apache/spark/examples/CassandraTest.scala | 10 ++---
 .../examples/ExceptionHandlingTest.scala      | 11 ++----
 .../apache/spark/examples/GroupByTest.scala   | 25 ++++++-------
 .../org/apache/spark/examples/HBaseTest.scala |  6 +--
 .../org/apache/spark/examples/HdfsTest.scala  |  4 +-
 .../org/apache/spark/examples/LogQuery.scala  | 14 +++----
 .../spark/examples/MultiBroadcastTest.scala   | 17 ++++-----
 .../examples/SimpleSkewedGroupByTest.scala    | 24 ++++++------
 .../spark/examples/SkewedGroupByTest.scala    | 25 ++++++-------
 .../org/apache/spark/examples/SparkALS.scala  | 18 +++------
 .../apache/spark/examples/SparkHdfsLR.scala   | 13 ++++---
 .../apache/spark/examples/SparkKMeans.scala   | 18 ++++-----
 .../org/apache/spark/examples/SparkLR.scala   | 11 ++----
 .../apache/spark/examples/SparkPageRank.scala | 14 +++----
 .../org/apache/spark/examples/SparkPi.scala   | 10 ++---
 .../org/apache/spark/examples/SparkTC.scala   | 12 ++----
 .../spark/examples/SparkTachyonHdfsLR.scala   | 12 ++----
 .../spark/examples/SparkTachyonPi.scala       | 10 ++---
 .../examples/bagel/WikipediaPageRank.scala    | 10 ++---
 .../bagel/WikipediaPageRankStandalone.scala   | 10 ++---
 .../examples/graphx/LiveJournalPageRank.scala |  6 +--
 .../spark/examples/sql/RDDRelation.scala      |  5 ++-
 .../examples/sql/hive/HiveFromSpark.scala     |  5 ++-
 .../examples/streaming/ActorWordCount.scala   | 21 +++++------
 .../examples/streaming/FlumeEventCount.scala  | 14 +++----
 .../examples/streaming/HdfsWordCount.scala    | 18 ++++-----
 .../examples/streaming/KafkaWordCount.scala   | 21 +++++------
 .../examples/streaming/MQTTWordCount.scala    | 26 ++++++-------
 .../examples/streaming/NetworkWordCount.scala | 23 +++++-------
 .../examples/streaming/QueueStream.scala      | 10 ++---
 .../examples/streaming/RawNetworkGrep.scala   | 16 ++++----
 .../RecoverableNetworkWordCount.scala         | 37 ++++++++++---------
 .../streaming/StatefulNetworkWordCount.scala  | 21 +++++------
 .../streaming/TwitterAlgebirdCMS.scala        | 15 +++----
 .../streaming/TwitterAlgebirdHLL.scala        | 14 +++----
 .../streaming/TwitterPopularTags.scala        | 13 ++-----
 .../examples/streaming/ZeroMQWordCount.scala  | 23 ++++++------
 .../apache/spark/graphx/lib/Analytics.scala   | 18 +++++----
 56 files changed, 405 insertions(+), 480 deletions(-)

diff --git a/.gitignore b/.gitignore
index 32b603f1bc84f..ad72588b472d6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -49,6 +49,7 @@ unit-tests.log
 /lib/
 rat-results.txt
 scalastyle.txt
+conf/*.conf

 # For Hive
 metastore_db/

diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index eb14d87467af7..9d7c2c8d3d630 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -74,10 +74,10 @@ class SparkContext(config: SparkConf) extends Logging {
    * be generated using [[org.apache.spark.scheduler.InputFormatInfo.computePreferredLocations]]
    * from a list of input files or InputFormats for the application.
    */
-    @DeveloperApi
-    def this(config: SparkConf, preferredNodeLocationData: Map[String, Set[SplitInfo]]) = {
-    this(config)
-    this.preferredNodeLocationData = preferredNodeLocationData
+  @DeveloperApi
+  def this(config: SparkConf, preferredNodeLocationData: Map[String, Set[SplitInfo]]) = {
+    this(config)
+    this.preferredNodeLocationData = preferredNodeLocationData
   }

   /**
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java
index bd96274021756..6c177de359b60 100644
--- a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java
+++ b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java
@@ -17,6 +17,7 @@

 package org.apache.spark.examples;

+import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
@@ -103,16 +104,16 @@ public static void printWeights(double[] a) {

   public static void main(String[] args) {

-    if (args.length < 3) {
-      System.err.println("Usage: JavaHdfsLR <master> <file> <iters>");
+    if (args.length < 2) {
+      System.err.println("Usage: JavaHdfsLR <file> <iters>");
       System.exit(1);
     }

-    JavaSparkContext sc = new JavaSparkContext(args[0], "JavaHdfsLR",
-      System.getenv("SPARK_HOME"), JavaSparkContext.jarOfClass(JavaHdfsLR.class));
-    JavaRDD<String> lines = sc.textFile(args[1]);
+    SparkConf sparkConf = new SparkConf().setAppName("JavaHdfsLR");
+    JavaSparkContext sc = new JavaSparkContext(sparkConf);
+    JavaRDD<String> lines = sc.textFile(args[0]);
     JavaRDD<DataPoint> points = lines.map(new ParsePoint()).cache();
-    int ITERATIONS = Integer.parseInt(args[2]);
+    int ITERATIONS = Integer.parseInt(args[1]);

     // Initialize w to a random value
     double[] w = new double[D];
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java
index 3f7a879538016..812e9d5580cbf 100644
--- a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java
+++ b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java
@@ -20,6 +20,7 @@ import com.google.common.collect.Lists;
 import
scala.Tuple2; import scala.Tuple3; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -34,6 +35,8 @@ /** * Executes a roll up-style query against Apache logs. + * + * Usage: JavaLogQuery [logFile] */ public final class JavaLogQuery { @@ -97,15 +100,11 @@ public static Stats extractStats(String line) { } public static void main(String[] args) { - if (args.length == 0) { - System.err.println("Usage: JavaLogQuery [logFile]"); - System.exit(1); - } - JavaSparkContext jsc = new JavaSparkContext(args[0], "JavaLogQuery", - System.getenv("SPARK_HOME"), JavaSparkContext.jarOfClass(JavaLogQuery.class)); + SparkConf sparkConf = new SparkConf().setAppName("JavaLogQuery"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); - JavaRDD dataSet = (args.length == 2) ? jsc.textFile(args[1]) : jsc.parallelize(exampleApacheLogs); + JavaRDD dataSet = (args.length == 1) ? jsc.textFile(args[0]) : jsc.parallelize(exampleApacheLogs); JavaPairRDD, Stats> extracted = dataSet.mapToPair(new PairFunction, Stats>() { @Override diff --git a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java index e31f676f5fd4c..7ea6df9c17245 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java @@ -18,9 +18,12 @@ package org.apache.spark.examples; + import scala.Tuple2; import com.google.common.collect.Iterables; + +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -54,20 +57,20 @@ public Double call(Double a, Double b) { } public static void main(String[] args) throws Exception { - if (args.length < 3) { - System.err.println("Usage: JavaPageRank "); + if (args.length < 2) { + System.err.println("Usage: JavaPageRank "); System.exit(1); } - JavaSparkContext ctx = new JavaSparkContext(args[0], "JavaPageRank", - System.getenv("SPARK_HOME"), JavaSparkContext.jarOfClass(JavaPageRank.class)); + SparkConf sparkConf = new SparkConf().setAppName("JavaPageRank"); + JavaSparkContext ctx = new JavaSparkContext(sparkConf); // Loads in input file. It should be in format of: // URL neighbor URL // URL neighbor URL // URL neighbor URL // ... - JavaRDD lines = ctx.textFile(args[1], 1); + JavaRDD lines = ctx.textFile(args[0], 1); // Loads all URLs from input file and initialize their neighbors. JavaPairRDD> links = lines.mapToPair(new PairFunction() { @@ -87,7 +90,7 @@ public Double call(Iterable rs) { }); // Calculates and updates URL ranks continuously using PageRank algorithm. - for (int current = 0; current < Integer.parseInt(args[2]); current++) { + for (int current = 0; current < Integer.parseInt(args[1]); current++) { // Calculates URL contributions to the rank of other URLs. 
JavaPairRDD contribs = links.join(ranks).values() .flatMapToPair(new PairFlatMapFunction, Double>, String, Double>() { diff --git a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java index ac8df02c4630b..11157d7573fae 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java @@ -17,6 +17,7 @@ package org.apache.spark.examples; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; @@ -25,19 +26,18 @@ import java.util.ArrayList; import java.util.List; -/** Computes an approximation to pi */ +/** + * Computes an approximation to pi + * Usage: JavaSparkPi [slices] + */ public final class JavaSparkPi { + public static void main(String[] args) throws Exception { - if (args.length == 0) { - System.err.println("Usage: JavaSparkPi [slices]"); - System.exit(1); - } - - JavaSparkContext jsc = new JavaSparkContext(args[0], "JavaSparkPi", - System.getenv("SPARK_HOME"), JavaSparkContext.jarOfClass(JavaSparkPi.class)); + SparkConf sparkConf = new SparkConf().setAppName("JavaSparkPi"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); - int slices = (args.length == 2) ? Integer.parseInt(args[1]) : 2; + int slices = (args.length == 1) ? Integer.parseInt(args[0]) : 2; int n = 100000 * slices; List l = new ArrayList(n); for (int i = 0; i < n; i++) { diff --git a/examples/src/main/java/org/apache/spark/examples/JavaTC.java b/examples/src/main/java/org/apache/spark/examples/JavaTC.java index d66b9ba265fe8..2563fcdd234bb 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaTC.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaTC.java @@ -17,19 +17,22 @@ package org.apache.spark.examples; -import scala.Tuple2; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.PairFunction; - import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Random; import java.util.Set; +import scala.Tuple2; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.PairFunction; + /** * Transitive closure on a graph, implemented in Java. + * Usage: JavaTC [slices] */ public final class JavaTC { @@ -61,14 +64,9 @@ public Tuple2 call(Tuple2> t } public static void main(String[] args) { - if (args.length == 0) { - System.err.println("Usage: JavaTC []"); - System.exit(1); - } - - JavaSparkContext sc = new JavaSparkContext(args[0], "JavaTC", - System.getenv("SPARK_HOME"), JavaSparkContext.jarOfClass(JavaTC.class)); - Integer slices = (args.length > 1) ? Integer.parseInt(args[1]): 2; + SparkConf sparkConf = new SparkConf().setAppName("JavaHdfsLR"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + Integer slices = (args.length > 0) ? 
Integer.parseInt(args[0]): 2; JavaPairRDD tc = sc.parallelizePairs(generateGraph(), slices).cache(); // Linear transitive closure: each round grows paths by one edge, diff --git a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java index 87c1b80981961..9a6a944f7edef 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java @@ -18,6 +18,7 @@ package org.apache.spark.examples; import scala.Tuple2; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -33,14 +34,15 @@ public final class JavaWordCount { private static final Pattern SPACE = Pattern.compile(" "); public static void main(String[] args) throws Exception { - if (args.length < 2) { - System.err.println("Usage: JavaWordCount "); + + if (args.length < 1) { + System.err.println("Usage: JavaWordCount "); System.exit(1); } - JavaSparkContext ctx = new JavaSparkContext(args[0], "JavaWordCount", - System.getenv("SPARK_HOME"), JavaSparkContext.jarOfClass(JavaWordCount.class)); - JavaRDD lines = ctx.textFile(args[1], 1); + SparkConf sparkConf = new SparkConf().setAppName("JavaWordCount"); + JavaSparkContext ctx = new JavaSparkContext(sparkConf); + JavaRDD lines = ctx.textFile(args[0], 1); JavaRDD words = lines.flatMap(new FlatMapFunction() { @Override diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaALS.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaALS.java index 4533c4c5f241a..8d381d4e0a943 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaALS.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaALS.java @@ -17,6 +17,7 @@ package org.apache.spark.examples.mllib; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; @@ -57,23 +58,22 @@ public String call(Tuple2 element) { public static void main(String[] args) { - if (args.length != 5 && args.length != 6) { + if (args.length < 4) { System.err.println( - "Usage: JavaALS []"); + "Usage: JavaALS []"); System.exit(1); } - - int rank = Integer.parseInt(args[2]); - int iterations = Integer.parseInt(args[3]); - String outputDir = args[4]; + SparkConf sparkConf = new SparkConf().setAppName("JavaALS"); + int rank = Integer.parseInt(args[1]); + int iterations = Integer.parseInt(args[2]); + String outputDir = args[3]; int blocks = -1; - if (args.length == 6) { - blocks = Integer.parseInt(args[5]); + if (args.length == 5) { + blocks = Integer.parseInt(args[4]); } - JavaSparkContext sc = new JavaSparkContext(args[0], "JavaALS", - System.getenv("SPARK_HOME"), JavaSparkContext.jarOfClass(JavaALS.class)); - JavaRDD lines = sc.textFile(args[1]); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + JavaRDD lines = sc.textFile(args[0]); JavaRDD ratings = lines.map(new ParseRating()); diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java index 0cfb8e69ed28f..f796123a25727 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java @@ -19,6 +19,7 @@ import java.util.regex.Pattern; +import 
org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; @@ -48,24 +49,21 @@ public Vector call(String line) { } public static void main(String[] args) { - - if (args.length < 4) { + if (args.length < 3) { System.err.println( - "Usage: JavaKMeans []"); + "Usage: JavaKMeans []"); System.exit(1); } - - String inputFile = args[1]; - int k = Integer.parseInt(args[2]); - int iterations = Integer.parseInt(args[3]); + String inputFile = args[0]; + int k = Integer.parseInt(args[1]); + int iterations = Integer.parseInt(args[2]); int runs = 1; - if (args.length >= 5) { - runs = Integer.parseInt(args[4]); + if (args.length >= 4) { + runs = Integer.parseInt(args[3]); } - - JavaSparkContext sc = new JavaSparkContext(args[0], "JavaKMeans", - System.getenv("SPARK_HOME"), JavaSparkContext.jarOfClass(JavaKMeans.class)); + SparkConf sparkConf = new SparkConf().setAppName("JavaKMeans"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); JavaRDD lines = sc.textFile(inputFile); JavaRDD points = lines.map(new ParsePoint()); diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLR.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLR.java index f6e48b498727b..eceb6927d5551 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLR.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLR.java @@ -19,6 +19,7 @@ import java.util.regex.Pattern; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; @@ -51,17 +52,16 @@ public LabeledPoint call(String line) { } public static void main(String[] args) { - if (args.length != 4) { - System.err.println("Usage: JavaLR "); + if (args.length != 3) { + System.err.println("Usage: JavaLR "); System.exit(1); } - - JavaSparkContext sc = new JavaSparkContext(args[0], "JavaLR", - System.getenv("SPARK_HOME"), JavaSparkContext.jarOfClass(JavaLR.class)); - JavaRDD lines = sc.textFile(args[1]); + SparkConf sparkConf = new SparkConf().setAppName("JavaLR"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + JavaRDD lines = sc.textFile(args[0]); JavaRDD points = lines.map(new ParsePoint()).cache(); - double stepSize = Double.parseDouble(args[2]); - int iterations = Integer.parseInt(args[3]); + double stepSize = Double.parseDouble(args[1]); + int iterations = Integer.parseInt(args[2]); // Another way to configure LogisticRegression // @@ -73,7 +73,7 @@ public static void main(String[] args) { // LogisticRegressionModel model = lr.train(points.rdd()); LogisticRegressionModel model = LogisticRegressionWithSGD.train(points.rdd(), - iterations, stepSize); + iterations, stepSize); System.out.print("Final w: " + model.weights()); diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index d62a72f53443c..ad5ec84b71e69 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -20,6 +20,7 @@ import java.io.Serializable; import java.util.List; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; @@ -51,8 +52,8 @@ public void setAge(int age) { } public 
static void main(String[] args) throws Exception { - JavaSparkContext ctx = new JavaSparkContext("local", "JavaSparkSQL", - System.getenv("SPARK_HOME"), JavaSparkContext.jarOfClass(JavaSparkSQL.class)); + SparkConf sparkConf = new SparkConf().setAppName("JavaSparkSQL"); + JavaSparkContext ctx = new JavaSparkContext(sparkConf); JavaSQLContext sqlCtx = new JavaSQLContext(ctx); // Load a text file and convert each line to a Java Bean. diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java index a5ece68cef870..400b68c2215b3 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java @@ -17,6 +17,7 @@ package org.apache.spark.examples.streaming; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.Function; import org.apache.spark.examples.streaming.StreamingExamples; import org.apache.spark.streaming.*; @@ -31,9 +32,8 @@ * an Avro server on at the request host:port address and listen for requests. * Your Flume AvroSink should be pointed to this address. * - * Usage: JavaFlumeEventCount + * Usage: JavaFlumeEventCount * - * is a Spark master URL * is the host the Flume receiver will be started on - a receiver * creates a server and listens for flume events. * is the port the Flume receiver will listen on. @@ -43,22 +43,19 @@ private JavaFlumeEventCount() { } public static void main(String[] args) { - if (args.length != 3) { - System.err.println("Usage: JavaFlumeEventCount "); + if (args.length != 2) { + System.err.println("Usage: JavaFlumeEventCount "); System.exit(1); } StreamingExamples.setStreamingLogLevels(); - String master = args[0]; - String host = args[1]; - int port = Integer.parseInt(args[2]); + String host = args[0]; + int port = Integer.parseInt(args[1]); Duration batchInterval = new Duration(2000); - - JavaStreamingContext ssc = new JavaStreamingContext(master, "FlumeEventCount", batchInterval, - System.getenv("SPARK_HOME"), - JavaStreamingContext.jarOfClass(JavaFlumeEventCount.class)); + SparkConf sparkConf = new SparkConf().setAppName("JavaFlumeEventCount"); + JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, batchInterval); JavaReceiverInputDStream flumeStream = FlumeUtils.createStream(ssc, "localhost", port); flumeStream.count(); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java index da51eb189a649..6a74cc50d19ed 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java @@ -21,7 +21,11 @@ import java.util.HashMap; import java.util.regex.Pattern; + +import scala.Tuple2; + import com.google.common.collect.Lists; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; @@ -33,19 +37,18 @@ import org.apache.spark.streaming.api.java.JavaPairReceiverInputDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; import org.apache.spark.streaming.kafka.KafkaUtils; -import scala.Tuple2; /** * Consumes messages from one or more topics in Kafka and does wordcount. 
- * Usage: JavaKafkaWordCount - * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. + * Usage: JavaKafkaWordCount * is a list of one or more zookeeper servers that make quorum * is the name of kafka consumer group * is a list of one or more kafka topics to consume from * is the number of threads the kafka consumer should use * * Example: - * `./bin/run-example org.apache.spark.examples.streaming.JavaKafkaWordCount local[2] zoo01,zoo02, + * `./bin/spark-submit examples.jar \ + * --class org.apache.spark.examples.streaming.JavaKafkaWordCount zoo01,zoo02, \ * zoo03 my-consumer-group topic1,topic2 1` */ @@ -56,27 +59,25 @@ private JavaKafkaWordCount() { } public static void main(String[] args) { - if (args.length < 5) { - System.err.println("Usage: KafkaWordCount "); + if (args.length < 4) { + System.err.println("Usage: JavaKafkaWordCount "); System.exit(1); } StreamingExamples.setStreamingLogLevels(); - + SparkConf sparkConf = new SparkConf().setAppName("JavaKafkaWordCount"); // Create the context with a 1 second batch size - JavaStreamingContext jssc = new JavaStreamingContext(args[0], "KafkaWordCount", - new Duration(2000), System.getenv("SPARK_HOME"), - JavaStreamingContext.jarOfClass(JavaKafkaWordCount.class)); + JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, new Duration(2000)); - int numThreads = Integer.parseInt(args[4]); + int numThreads = Integer.parseInt(args[3]); Map topicMap = new HashMap(); - String[] topics = args[3].split(","); + String[] topics = args[2].split(","); for (String topic: topics) { topicMap.put(topic, numThreads); } JavaPairReceiverInputDStream messages = - KafkaUtils.createStream(jssc, args[1], args[2], topicMap); + KafkaUtils.createStream(jssc, args[0], args[1], topicMap); JavaDStream lines = messages.map(new Function, String>() { @Override diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java index ac84991d87b8b..e5cbd39f437c2 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java @@ -17,9 +17,10 @@ package org.apache.spark.examples.streaming; -import com.google.common.collect.Lists; -import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; import scala.Tuple2; +import com.google.common.collect.Lists; + +import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFunction; @@ -27,41 +28,39 @@ import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; import java.util.regex.Pattern; /** * Counts words in UTF8 encoded, '\n' delimited text received from the network every second. - * Usage: JavaNetworkWordCount - * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. + * Usage: JavaNetworkWordCount * and describe the TCP server that Spark Streaming would connect to receive data. 
* * To run this on your local machine, you need to first run a Netcat server * `$ nc -lk 9999` * and then run the example - * `$ ./run org.apache.spark.examples.streaming.JavaNetworkWordCount local[2] localhost 9999` + * `$ ./bin/spark-submit examples.jar \ + * --class org.apache.spark.examples.streaming.JavaNetworkWordCount localhost 9999` */ public final class JavaNetworkWordCount { private static final Pattern SPACE = Pattern.compile(" "); public static void main(String[] args) { - if (args.length < 3) { - System.err.println("Usage: JavaNetworkWordCount \n" + - "In local mode, should be 'local[n]' with n > 1"); + if (args.length < 2) { + System.err.println("Usage: JavaNetworkWordCount "); System.exit(1); } StreamingExamples.setStreamingLogLevels(); - + SparkConf sparkConf = new SparkConf().setAppName("JavaNetworkWordCount"); // Create the context with a 1 second batch size - JavaStreamingContext ssc = new JavaStreamingContext(args[0], "JavaNetworkWordCount", - new Duration(1000), System.getenv("SPARK_HOME"), - JavaStreamingContext.jarOfClass(JavaNetworkWordCount.class)); + JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, new Duration(1000)); // Create a JavaReceiverInputDStream on target ip:port and count the // words in input stream of \n delimited text (eg. generated by 'nc') - JavaReceiverInputDStream lines = ssc.socketTextStream(args[1], Integer.parseInt(args[2])); + JavaReceiverInputDStream lines = ssc.socketTextStream(args[0], Integer.parseInt(args[1])); JavaDStream words = lines.flatMap(new FlatMapFunction() { @Override public Iterable call(String x) { diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaQueueStream.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaQueueStream.java index 819311968fac5..4ce8437f82705 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaQueueStream.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaQueueStream.java @@ -17,8 +17,16 @@ package org.apache.spark.examples.streaming; -import com.google.common.collect.Lists; + +import java.util.LinkedList; +import java.util.List; +import java.util.Queue; + import scala.Tuple2; + +import com.google.common.collect.Lists; + +import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFunction; @@ -28,25 +36,17 @@ import org.apache.spark.streaming.api.java.JavaPairDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; -import java.util.LinkedList; -import java.util.List; -import java.util.Queue; - public final class JavaQueueStream { private JavaQueueStream() { } public static void main(String[] args) throws Exception { - if (args.length < 1) { - System.err.println("Usage: JavaQueueStream "); - System.exit(1); - } StreamingExamples.setStreamingLogLevels(); + SparkConf sparkConf = new SparkConf().setAppName("JavaQueueStream"); // Create the context - JavaStreamingContext ssc = new JavaStreamingContext(args[0], "QueueStream", new Duration(1000), - System.getenv("SPARK_HOME"), JavaStreamingContext.jarOfClass(JavaQueueStream.class)); + JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, new Duration(1000)); // Create the queue through which RDDs can be pushed to // a QueueInputDStream diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index 
f6dfd2c4c6217..973049b95a7bd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -17,28 +17,26 @@ package org.apache.spark.examples -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} +/** + * Usage: BroadcastTest [slices] [numElem] [broadcastAlgo] [blockSize] + */ object BroadcastTest { def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: BroadcastTest [slices] [numElem] [broadcastAlgo]" + - " [blockSize]") - System.exit(1) - } - val bcName = if (args.length > 3) args(3) else "Http" - val blockSize = if (args.length > 4) args(4) else "4096" + val bcName = if (args.length > 2) args(2) else "Http" + val blockSize = if (args.length > 3) args(3) else "4096" System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast." + bcName + "BroadcastFactory") System.setProperty("spark.broadcast.blockSize", blockSize) + val sparkConf = new SparkConf().setAppName("Broadcast Test") - val sc = new SparkContext(args(0), "Broadcast Test", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) + val sc = new SparkContext(sparkConf) - val slices = if (args.length > 1) args(1).toInt else 2 - val num = if (args.length > 2) args(2).toInt else 1000000 + val slices = if (args.length > 0) args(0).toInt else 2 + val num = if (args.length > 1) args(1).toInt else 1000000 val arr1 = new Array[Int](num) for (i <- 0 until arr1.length) { diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index 3798329fc2f41..9a00701f985f0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -30,7 +30,7 @@ import org.apache.cassandra.hadoop.cql3.CqlOutputFormat import org.apache.cassandra.utils.ByteBufferUtil import org.apache.hadoop.mapreduce.Job -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ /* @@ -65,19 +65,18 @@ import org.apache.spark.SparkContext._ /** * This example demonstrates how to read and write to cassandra column family created using CQL3 * using Spark. 
- * Parameters : - * Usage: ./bin/run-example org.apache.spark.examples.CassandraCQLTest local[2] localhost 9160 - * + * Parameters : + * Usage: ./bin/spark-submit examples.jar \ + * --class org.apache.spark.examples.CassandraCQLTest localhost 9160 */ object CassandraCQLTest { def main(args: Array[String]) { - val sc = new SparkContext(args(0), - "CQLTestApp", - System.getenv("SPARK_HOME"), - SparkContext.jarOfClass(this.getClass).toSeq) - val cHost: String = args(1) - val cPort: String = args(2) + val sparkConf = new SparkConf().setAppName("CQLTestApp") + + val sc = new SparkContext(sparkConf) + val cHost: String = args(0) + val cPort: String = args(1) val KeySpace = "retail" val InputColumnFamily = "ordercf" val OutputColumnFamily = "salecount" diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala index ed5d2f9e46f29..91ba364a346a5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala @@ -30,7 +30,7 @@ import org.apache.cassandra.thrift._ import org.apache.cassandra.utils.ByteBufferUtil import org.apache.hadoop.mapreduce.Job -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ /* @@ -38,10 +38,10 @@ import org.apache.spark.SparkContext._ * support for Hadoop. * * To run this example, run this file with the following command params - - * + * * * So if you want to run this on localhost this will be, - * local[3] localhost 9160 + * localhost 9160 * * The example makes some assumptions: * 1. You have already created a keyspace called casDemo and it has a column family named Words @@ -54,9 +54,9 @@ import org.apache.spark.SparkContext._ object CassandraTest { def main(args: Array[String]) { - + val sparkConf = new SparkConf().setAppName("casDemo") // Get a SparkContext - val sc = new SparkContext(args(0), "casDemo") + val sc = new SparkContext(sparkConf) // Build the job configuration with ConfigHelper provided by Cassandra val job = new Job() diff --git a/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala b/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala index f0dcef431b2e1..d42f63e87052e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala @@ -17,17 +17,12 @@ package org.apache.spark.examples -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} object ExceptionHandlingTest { def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: ExceptionHandlingTest ") - System.exit(1) - } - - val sc = new SparkContext(args(0), "ExceptionHandlingTest", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) + val sparkConf = new SparkConf().setAppName("ExceptionHandlingTest") + val sc = new SparkContext(sparkConf) sc.parallelize(0 until sc.defaultParallelism).foreach { i => if (math.random > 0.75) { throw new Exception("Testing exception handling") diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala index e67bb29a49405..efd91bb054981 100644 --- a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala +++ 
b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala @@ -19,24 +19,21 @@ package org.apache.spark.examples import java.util.Random -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ +/** + * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers] + */ object GroupByTest { def main(args: Array[String]) { - if (args.length == 0) { - System.err.println( - "Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers]") - System.exit(1) - } - - var numMappers = if (args.length > 1) args(1).toInt else 2 - var numKVPairs = if (args.length > 2) args(2).toInt else 1000 - var valSize = if (args.length > 3) args(3).toInt else 1000 - var numReducers = if (args.length > 4) args(4).toInt else numMappers - - val sc = new SparkContext(args(0), "GroupBy Test", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) + val sparkConf = new SparkConf().setAppName("GroupBy Test") + var numMappers = if (args.length > 0) args(0).toInt else 2 + var numKVPairs = if (args.length > 1) args(1).toInt else 1000 + var valSize = if (args.length > 2) args(2).toInt else 1000 + var numReducers = if (args.length > 3) args(3).toInt else numMappers + + val sc = new SparkContext(sparkConf) val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p => val ranGen = new Random diff --git a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala index adbd1c02fa2ea..a8c338480e6e2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala @@ -26,11 +26,9 @@ import org.apache.spark.rdd.NewHadoopRDD object HBaseTest { def main(args: Array[String]) { - val sc = new SparkContext(args(0), "HBaseTest", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) - + val sparkConf = new SparkConf().setAppName("HBaseTest") + val sc = new SparkContext(sparkConf) val conf = HBaseConfiguration.create() - // Other options for configuring scan behavior are available. 
More information available at // http://hbase.apache.org/apidocs/org/apache/hadoop/hbase/mapreduce/TableInputFormat.html conf.set(TableInputFormat.INPUT_TABLE, args(1)) diff --git a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala index c7a4884af10b7..331de3ad1ef53 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala @@ -21,8 +21,8 @@ import org.apache.spark._ object HdfsTest { def main(args: Array[String]) { - val sc = new SparkContext(args(0), "HdfsTest", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) + val sparkConf = new SparkConf().setAppName("HdfsTest") + val sc = new SparkContext(sparkConf) val file = sc.textFile(args(1)) val mapped = file.map(s => s.length).cache() for (iter <- 1 to 10) { diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala index f77a444ff7a9f..4c655b84fde2e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala @@ -17,11 +17,13 @@ package org.apache.spark.examples -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ /** * Executes a roll up-style query against Apache logs. + * + * Usage: LogQuery [logFile] */ object LogQuery { val exampleApacheLogs = List( @@ -40,16 +42,12 @@ object LogQuery { ) def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: LogQuery [logFile]") - System.exit(1) - } - val sc = new SparkContext(args(0), "Log Query", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) + val sparkConf = new SparkConf().setAppName("Log Query") + val sc = new SparkContext(sparkConf) val dataSet = - if (args.length == 2) sc.textFile(args(1)) else sc.parallelize(exampleApacheLogs) + if (args.length == 1) sc.textFile(args(0)) else sc.parallelize(exampleApacheLogs) // scalastyle:off val apacheLogRegex = """^([\d.]+) (\S+) (\S+) \[([\w\d:/]+\s[+\-]\d{4})\] "(.+?)" (\d{3}) ([\d\-]+) "([^"]+)" "([^"]+)".*""".r diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala index c8985eae33de3..2a5c0c0defe13 100644 --- a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala @@ -18,20 +18,19 @@ package org.apache.spark.examples import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} +/** + * Usage: MultiBroadcastTest [slices] [numElem] + */ object MultiBroadcastTest { def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: MultiBroadcastTest [] [numElem]") - System.exit(1) - } - val sc = new SparkContext(args(0), "Multi-Broadcast Test", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) + val sparkConf = new SparkConf().setAppName("Multi-Broadcast Test") + val sc = new SparkContext(sparkConf) - val slices = if (args.length > 1) args(1).toInt else 2 - val num = if (args.length > 2) args(2).toInt else 1000000 + val slices = if (args.length > 0) args(0).toInt else 2 + val num = if (args.length > 1) args(1).toInt else 1000000 val arr1 = 
new Array[Int](num) for (i <- 0 until arr1.length) { diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala index 54e8503711e30..5291ab81f459e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala @@ -19,25 +19,23 @@ package org.apache.spark.examples import java.util.Random -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ +/** + * Usage: SimpleSkewedGroupByTest [numMappers] [numKVPairs] [valSize] [numReducers] [ratio] + */ object SimpleSkewedGroupByTest { def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: SimpleSkewedGroupByTest " + - "[numMappers] [numKVPairs] [valSize] [numReducers] [ratio]") - System.exit(1) - } - var numMappers = if (args.length > 1) args(1).toInt else 2 - var numKVPairs = if (args.length > 2) args(2).toInt else 1000 - var valSize = if (args.length > 3) args(3).toInt else 1000 - var numReducers = if (args.length > 4) args(4).toInt else numMappers - var ratio = if (args.length > 5) args(5).toInt else 5.0 + val sparkConf = new SparkConf().setAppName("SimpleSkewedGroupByTest") + var numMappers = if (args.length > 0) args(0).toInt else 2 + var numKVPairs = if (args.length > 1) args(1).toInt else 1000 + var valSize = if (args.length > 2) args(2).toInt else 1000 + var numReducers = if (args.length > 3) args(3).toInt else numMappers + var ratio = if (args.length > 4) args(4).toInt else 5.0 - val sc = new SparkContext(args(0), "GroupBy Test", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) + val sc = new SparkContext(sparkConf) val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p => val ranGen = new Random diff --git a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala index 1c5f22e1c00bb..017d4e1e5ce13 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala @@ -19,24 +19,21 @@ package org.apache.spark.examples import java.util.Random -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ +/** + * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers] + */ object SkewedGroupByTest { def main(args: Array[String]) { - if (args.length == 0) { - System.err.println( - "Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers]") - System.exit(1) - } - - var numMappers = if (args.length > 1) args(1).toInt else 2 - var numKVPairs = if (args.length > 2) args(2).toInt else 1000 - var valSize = if (args.length > 3) args(3).toInt else 1000 - var numReducers = if (args.length > 4) args(4).toInt else numMappers - - val sc = new SparkContext(args(0), "GroupBy Test", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) + val sparkConf = new SparkConf().setAppName("GroupBy Test") + var numMappers = if (args.length > 0) args(0).toInt else 2 + var numKVPairs = if (args.length > 1) args(1).toInt else 1000 + var valSize = if (args.length > 2) args(2).toInt else 1000 + var numReducers = if (args.length > 3) args(3).toInt else numMappers + + val sc = new 
SparkContext(sparkConf) val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p => val ranGen = new Random diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala index 0dc726aecdd28..5cbc966bf06ca 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -88,32 +88,24 @@ object SparkALS { } def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: SparkALS [ ]") - System.exit(1) - } - - var host = "" var slices = 0 - val options = (0 to 5).map(i => if (i < args.length) Some(args(i)) else None) + val options = (0 to 4).map(i => if (i < args.length) Some(args(i)) else None) options.toArray match { - case Array(host_, m, u, f, iters, slices_) => - host = host_.get + case Array(m, u, f, iters, slices_) => M = m.getOrElse("100").toInt U = u.getOrElse("500").toInt F = f.getOrElse("10").toInt ITERATIONS = iters.getOrElse("5").toInt slices = slices_.getOrElse("2").toInt case _ => - System.err.println("Usage: SparkALS [ ]") + System.err.println("Usage: SparkALS [M] [U] [F] [iters] [slices]") System.exit(1) } printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS) - - val sc = new SparkContext(host, "SparkALS", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) + val sparkConf = new SparkConf().setAppName("SparkALS") + val sc = new SparkContext(sparkConf) val R = generateR() diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala index 3a6f18c33ea4b..4906a696e90a7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala @@ -49,20 +49,21 @@ object SparkHdfsLR { } def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: SparkHdfsLR ") + if (args.length < 2) { + System.err.println("Usage: SparkHdfsLR ") System.exit(1) } - val inputPath = args(1) + + val sparkConf = new SparkConf().setAppName("SparkHdfsLR") + val inputPath = args(0) val conf = SparkHadoopUtil.get.newConfiguration() - val sc = new SparkContext(args(0), "SparkHdfsLR", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq, Map(), + val sc = new SparkContext(sparkConf, InputFormatInfo.computePreferredLocations( Seq(new InputFormatInfo(conf, classOf[org.apache.hadoop.mapred.TextInputFormat], inputPath)) )) val lines = sc.textFile(inputPath) val points = lines.map(parsePoint _).cache() - val ITERATIONS = args(2).toInt + val ITERATIONS = args(1).toInt // Initialize w to a random value var w = DenseVector.fill(D){2 * rand.nextDouble - 1} diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala index dcae9591b0407..4d28e0aad6597 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala @@ -21,7 +21,7 @@ import java.util.Random import breeze.linalg.{Vector, DenseVector, squaredDistance} -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ /** @@ -52,16 +52,16 @@ object SparkKMeans { } def main(args: Array[String]) { - if (args.length < 4) { - System.err.println("Usage: 
SparkLocalKMeans ") - System.exit(1) + if (args.length < 3) { + System.err.println("Usage: SparkKMeans ") + System.exit(1) } - val sc = new SparkContext(args(0), "SparkLocalKMeans", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) - val lines = sc.textFile(args(1)) + val sparkConf = new SparkConf().setAppName("SparkKMeans") + val sc = new SparkContext(sparkConf) + val lines = sc.textFile(args(0)) val data = lines.map(parseVector _).cache() - val K = args(2).toInt - val convergeDist = args(3).toDouble + val K = args(1).toInt + val convergeDist = args(2).toDouble val kPoints = data.takeSample(withReplacement = false, K, 42).toArray var tempDist = 1.0 diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index 4f74882ccbea5..99ceb3089e9fe 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -27,6 +27,7 @@ import org.apache.spark._ /** * Logistic regression based classification. + * Usage: SparkLR [slices] */ object SparkLR { val N = 10000 // Number of data points @@ -47,13 +48,9 @@ object SparkLR { } def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: SparkLR []") - System.exit(1) - } - val sc = new SparkContext(args(0), "SparkLR", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) - val numSlices = if (args.length > 1) args(1).toInt else 2 + val sparkConf = new SparkConf().setAppName("SparkLR") + val sc = new SparkContext(sparkConf) + val numSlices = if (args.length > 0) args(0).toInt else 2 val points = sc.parallelize(generateData, numSlices).cache() // Initialize w to a random value diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala index fa41c5c560943..40b36c779afd6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala @@ -18,7 +18,7 @@ package org.apache.spark.examples import org.apache.spark.SparkContext._ -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} /** * Computes the PageRank of URLs from an input file. 
Input file should @@ -31,14 +31,10 @@ import org.apache.spark.SparkContext */ object SparkPageRank { def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: PageRank ") - System.exit(1) - } - var iters = args(2).toInt - val ctx = new SparkContext(args(0), "PageRank", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) - val lines = ctx.textFile(args(1), 1) + val sparkConf = new SparkConf().setAppName("PageRank") + var iters = args(1).toInt + val ctx = new SparkContext(sparkConf) + val lines = ctx.textFile(args(0), 1) val links = lines.map{ s => val parts = s.split("\\s+") (parts(0), parts(1)) diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala index d8f5720504223..9fbb0a800d735 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala @@ -24,13 +24,9 @@ import org.apache.spark._ /** Computes an approximation to pi */ object SparkPi { def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: SparkPi []") - System.exit(1) - } - val spark = new SparkContext(args(0), "SparkPi", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) - val slices = if (args.length > 1) args(1).toInt else 2 + val conf = new SparkConf().setAppName("Spark Pi") + val spark = new SparkContext(conf) + val slices = if (args.length > 0) args(0).toInt else 2 val n = 100000 * slices val count = spark.parallelize(1 to n, slices).map { i => val x = random * 2 - 1 diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala index 17d983cd875db..f7f83086df3db 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala @@ -20,7 +20,7 @@ package org.apache.spark.examples import scala.util.Random import scala.collection.mutable -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ /** @@ -42,13 +42,9 @@ object SparkTC { } def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: SparkTC []") - System.exit(1) - } - val spark = new SparkContext(args(0), "SparkTC", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) - val slices = if (args.length > 1) args(1).toInt else 2 + val sparkConf = new SparkConf().setAppName("SparkTC") + val spark = new SparkContext(sparkConf) + val slices = if (args.length > 0) args(0).toInt else 2 var tc = spark.parallelize(generateGraph, slices).cache() // Linear transitive closure: each round grows paths by one edge, diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala index 7e43c384bdb9d..22127621867e1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala @@ -51,20 +51,16 @@ object SparkTachyonHdfsLR { } def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: SparkTachyonHdfsLR ") - System.exit(1) - } - val inputPath = args(1) + val inputPath = args(0) val conf = SparkHadoopUtil.get.newConfiguration() - val sc = new SparkContext(args(0), "SparkTachyonHdfsLR", - 
System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq, Map(), + val sparkConf = new SparkConf().setAppName("SparkTachyonHdfsLR") + val sc = new SparkContext(sparkConf, InputFormatInfo.computePreferredLocations( Seq(new InputFormatInfo(conf, classOf[org.apache.hadoop.mapred.TextInputFormat], inputPath)) )) val lines = sc.textFile(inputPath) val points = lines.map(parsePoint _).persist(StorageLevel.OFF_HEAP) - val ITERATIONS = args(2).toInt + val ITERATIONS = args(1).toInt // Initialize w to a random value var w = DenseVector.fill(D){2 * rand.nextDouble - 1} diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala index 93459110e4e0e..7743f7968b100 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala @@ -28,14 +28,10 @@ import org.apache.spark.storage.StorageLevel */ object SparkTachyonPi { def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: SparkTachyonPi []") - System.exit(1) - } - val spark = new SparkContext(args(0), "SparkTachyonPi", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass).toSeq) + val sparkConf = new SparkConf().setAppName("SparkTachyonPi") + val spark = new SparkContext(sparkConf) - val slices = if (args.length > 1) args(1).toInt else 2 + val slices = if (args.length > 0) args(0).toInt else 2 val n = 100000 * slices val rdd = spark.parallelize(1 to n, slices) diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala index 25bd55ca88b94..235c3bf820244 100644 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala @@ -32,22 +32,22 @@ import scala.xml.{XML,NodeSeq} */ object WikipediaPageRank { def main(args: Array[String]) { - if (args.length < 5) { + if (args.length < 4) { System.err.println( - "Usage: WikipediaPageRank ") + "Usage: WikipediaPageRank ") System.exit(-1) } val sparkConf = new SparkConf() + sparkConf.setAppName("WikipediaPageRank") sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") sparkConf.set("spark.kryo.registrator", classOf[PRKryoRegistrator].getName) val inputFile = args(0) val threshold = args(1).toDouble val numPartitions = args(2).toInt - val host = args(3) - val usePartitioner = args(4).toBoolean + val usePartitioner = args(3).toBoolean - sparkConf.setMaster(host).setAppName("WikipediaPageRank") + sparkConf.setAppName("WikipediaPageRank") val sc = new SparkContext(sparkConf) // Parse the Wikipedia page data into a graph diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala index dee3cb6c0abae..a197dac87d6db 100644 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala +++ b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala @@ -30,22 +30,20 @@ import org.apache.spark.rdd.RDD object WikipediaPageRankStandalone { def main(args: Array[String]) { - if (args.length < 5) { + if (args.length < 4) { System.err.println("Usage: WikipediaPageRankStandalone " + - " ") + " ") System.exit(-1) } val sparkConf = new 
SparkConf() sparkConf.set("spark.serializer", "spark.bagel.examples.WPRSerializer") - val inputFile = args(0) val threshold = args(1).toDouble val numIterations = args(2).toInt - val host = args(3) - val usePartitioner = args(4).toBoolean + val usePartitioner = args(3).toBoolean - sparkConf.setMaster(host).setAppName("WikipediaPageRankStandalone") + sparkConf.setAppName("WikipediaPageRankStandalone") val sc = new SparkContext(sparkConf) diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala index d58fddff2b5ec..6ef3b62dcbedc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala @@ -28,9 +28,9 @@ import org.apache.spark.graphx.lib.Analytics */ object LiveJournalPageRank { def main(args: Array[String]) { - if (args.length < 2) { + if (args.length < 1) { System.err.println( - "Usage: LiveJournalPageRank \n" + + "Usage: LiveJournalPageRank \n" + " [--tol=]\n" + " The tolerance allowed at convergence (smaller => more accurate). Default is " + "0.001.\n" + @@ -44,6 +44,6 @@ object LiveJournalPageRank { System.exit(-1) } - Analytics.main(args.patch(1, List("pagerank"), 0)) + Analytics.main(args.patch(0, List("pagerank"), 0)) } } diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index ff9254b044c24..61c460c6b1de8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -17,7 +17,7 @@ package org.apache.spark.examples.sql -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.SQLContext // One method for defining the schema of an RDD is to make a case class with the desired column @@ -26,7 +26,8 @@ case class Record(key: Int, value: String) object RDDRelation { def main(args: Array[String]) { - val sc = new SparkContext("local", "RDDRelation") + val sparkConf = new SparkConf().setAppName("RDDRelation") + val sc = new SparkContext(sparkConf) val sqlContext = new SQLContext(sc) // Importing the SQL context gives access to all the SQL functions and implicit conversions. diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index 66ce93a26ef42..b262fabbe0e0d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -17,7 +17,7 @@ package org.apache.spark.examples.sql.hive -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql._ import org.apache.spark.sql.hive.LocalHiveContext @@ -25,7 +25,8 @@ object HiveFromSpark { case class Record(key: Int, value: String) def main(args: Array[String]) { - val sc = new SparkContext("local", "HiveFromSpark") + val sparkConf = new SparkConf().setAppName("HiveFromSpark") + val sc = new SparkContext(sparkConf) // A local hive context creates an instance of the Hive Metastore in process, storing the // the warehouse data in the current directory. 
This location can be overridden by diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala index 84cf43df0f96c..e29e16a9c1b17 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala @@ -126,31 +126,30 @@ object FeederActor { /** * A sample word count program demonstrating the use of plugging in * Actor as Receiver - * Usage: ActorWordCount - * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. + * Usage: ActorWordCount * and describe the AkkaSystem that Spark Sample feeder is running on. * * To run this example locally, you may run Feeder Actor as - * `$ ./bin/run-example org.apache.spark.examples.streaming.FeederActor 127.0.1.1 9999` + * `./bin/spark-submit examples.jar \ + * --class org.apache.spark.examples.streaming.FeederActor 127.0.1.1 9999` * and then run the example - * `./bin/run-example org.apache.spark.examples.streaming.ActorWordCount local[2] 127.0.1.1 9999` + * `./bin/spark-submit examples.jar --class org.apache.spark.examples.streaming.ActorWordCount \ + * 127.0.1.1 9999` */ object ActorWordCount { def main(args: Array[String]) { - if (args.length < 3) { + if (args.length < 2) { System.err.println( - "Usage: ActorWordCount " + - "In local mode, should be 'local[n]' with n > 1") + "Usage: ActorWordCount ") System.exit(1) } StreamingExamples.setStreamingLogLevels() - val Seq(master, host, port) = args.toSeq - + val Seq(host, port) = args.toSeq + val sparkConf = new SparkConf().setAppName("ActorWordCount") // Create the context and set the batch size - val ssc = new StreamingContext(master, "ActorWordCount", Seconds(2), - System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq) + val ssc = new StreamingContext(sparkConf, Seconds(2)) /* * Following is the use of actorStream to plug in custom actor as receiver diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala index 5b2a1035fc779..38362edac27f8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala @@ -17,6 +17,7 @@ package org.apache.spark.examples.streaming +import org.apache.spark.SparkConf import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.flume._ @@ -29,9 +30,8 @@ import org.apache.spark.util.IntParam * an Avro server on at the request host:port address and listen for requests. * Your Flume AvroSink should be pointed to this address. * - * Usage: FlumeEventCount + * Usage: FlumeEventCount * - * is a Spark master URL * is the host the Flume receiver will be started on - a receiver * creates a server and listens for flume events. * is the port the Flume receiver will listen on. 
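To make the streaming migration concrete before the remaining hunks, here is a minimal sketch (not part of the patch; the object name, app name, socket source and two-second batch interval are illustrative) of the construction pattern every streaming example in this change converges on: the master URL, SPARK_HOME and jar list are dropped from the code, spark-submit supplies them, and the positional arguments shift down by one.

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.StreamingContext._

// Illustrative skeleton only: cluster configuration comes from spark-submit,
// not from StreamingContext constructor arguments.
object SparkSubmitStreamingSkeleton {
  def main(args: Array[String]) {
    val sparkConf = new SparkConf().setAppName("SparkSubmitStreamingSkeleton")
    val ssc = new StreamingContext(sparkConf, Seconds(2))
    // args(0)/args(1) now play the role the former args(1)/args(2) played, e.g. a host and port.
    val lines = ssc.socketTextStream(args(0), args(1).toInt)
    val counts = lines.flatMap(_.split(" ")).map(word => (word, 1)).reduceByKey(_ + _)
    counts.print()
    ssc.start()
    ssc.awaitTermination()
  }
}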
@@ -40,21 +40,21 @@ object FlumeEventCount { def main(args: Array[String]) { if (args.length != 3) { System.err.println( - "Usage: FlumeEventCount ") + "Usage: FlumeEventCount ") System.exit(1) } StreamingExamples.setStreamingLogLevels() - val Array(master, host, IntParam(port)) = args + val Array(host, IntParam(port)) = args val batchInterval = Milliseconds(2000) + val sparkConf = new SparkConf().setAppName("FlumeEventCount") // Create the context and set the batch size - val ssc = new StreamingContext(master, "FlumeEventCount", batchInterval, - System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq) + val ssc = new StreamingContext(sparkConf, batchInterval) // Create a flume stream - val stream = FlumeUtils.createStream(ssc, host,port,StorageLevel.MEMORY_ONLY_SER_2) + val stream = FlumeUtils.createStream(ssc, host, port, StorageLevel.MEMORY_ONLY_SER_2) // Print out the count of events received from this server in each batch stream.count().map(cnt => "Received " + cnt + " flume events." ).print() diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala index b440956ba3137..55ac48cfb6d10 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/HdfsWordCount.scala @@ -17,35 +17,35 @@ package org.apache.spark.examples.streaming +import org.apache.spark.SparkConf import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.StreamingContext._ /** * Counts words in new text files created in the given directory - * Usage: HdfsWordCount - * is the Spark master URL. + * Usage: HdfsWordCount * is the directory that Spark Streaming will use to find and read new text files. * * To run this on your local machine on directory `localdir`, run this example - * `$ ./bin/run-example org.apache.spark.examples.streaming.HdfsWordCount local[2] localdir` + * `$ ./bin/spark-submit examples.jar \ + * --class org.apache.spark.examples.streaming.HdfsWordCount localdir` * Then create a text file in `localdir` and the words in the file will get counted. 
*/ object HdfsWordCount { def main(args: Array[String]) { - if (args.length < 2) { - System.err.println("Usage: HdfsWordCount ") + if (args.length < 1) { + System.err.println("Usage: HdfsWordCount ") System.exit(1) } StreamingExamples.setStreamingLogLevels() - + val sparkConf = new SparkConf().setAppName("HdfsWordCount") // Create the context - val ssc = new StreamingContext(args(0), "HdfsWordCount", Seconds(2), - System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq) + val ssc = new StreamingContext(sparkConf, Seconds(2)) // Create the FileInputDStream on the directory and use the // stream to count words in new files created - val lines = ssc.textFileStream(args(1)) + val lines = ssc.textFileStream(args(0)) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) wordCounts.print() diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala index c3aae5af05b1c..3af806981f37a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala @@ -24,34 +24,33 @@ import kafka.producer._ import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.kafka._ +import org.apache.spark.SparkConf -// scalastyle:off /** * Consumes messages from one or more topics in Kafka and does wordcount. - * Usage: KafkaWordCount - * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. + * Usage: KafkaWordCount * is a list of one or more zookeeper servers that make quorum * is the name of kafka consumer group * is a list of one or more kafka topics to consume from * is the number of threads the kafka consumer should use * * Example: - * `./bin/run-example org.apache.spark.examples.streaming.KafkaWordCount local[2] zoo01,zoo02,zoo03 my-consumer-group topic1,topic2 1` + * `./bin/spark-submit examples.jar \ + * --class org.apache.spark.examples.streaming.KafkaWordCount local[2] zoo01,zoo02,zoo03 \ + * my-consumer-group topic1,topic2 1` */ -// scalastyle:on object KafkaWordCount { def main(args: Array[String]) { - if (args.length < 5) { - System.err.println("Usage: KafkaWordCount ") + if (args.length < 4) { + System.err.println("Usage: KafkaWordCount ") System.exit(1) } StreamingExamples.setStreamingLogLevels() - val Array(master, zkQuorum, group, topics, numThreads) = args - - val ssc = new StreamingContext(master, "KafkaWordCount", Seconds(2), - System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq) + val Array(zkQuorum, group, topics, numThreads) = args + val sparkConf = new SparkConf().setAppName("KafkaWordCount") + val ssc = new StreamingContext(sparkConf, Seconds(2)) ssc.checkpoint("checkpoint") val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala index 47bf1e5a06439..3a10daa9ab84a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala @@ -24,6 +24,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Seconds, StreamingContext} import 
org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.mqtt._ +import org.apache.spark.SparkConf /** * A simple Mqtt publisher for demonstration purposes, repeatedly publishes @@ -64,7 +65,6 @@ object MQTTPublisher { } } -// scalastyle:off /** * A sample wordcount with MqttStream stream * @@ -74,30 +74,28 @@ object MQTTPublisher { * Eclipse paho project provides Java library for Mqtt Client http://www.eclipse.org/paho/ * Example Java code for Mqtt Publisher and Subscriber can be found here * https://bitbucket.org/mkjinesh/mqttclient - * Usage: MQTTWordCount - * In local mode, should be 'local[n]' with n > 1 - * and describe where Mqtt publisher is running. + * Usage: MQTTWordCount +\ * and describe where Mqtt publisher is running. * * To run this example locally, you may run publisher as - * `$ ./bin/run-example org.apache.spark.examples.streaming.MQTTPublisher tcp://localhost:1883 foo` + * `$ ./bin/spark-submit examples.jar \ + * --class org.apache.spark.examples.streaming.MQTTPublisher tcp://localhost:1883 foo` * and run the example as - * `$ ./bin/run-example org.apache.spark.examples.streaming.MQTTWordCount local[2] tcp://localhost:1883 foo` + * `$ ./bin/spark-submit examples.jar \ + * --class org.apache.spark.examples.streaming.MQTTWordCount tcp://localhost:1883 foo` */ -// scalastyle:on object MQTTWordCount { def main(args: Array[String]) { - if (args.length < 3) { + if (args.length < 2) { System.err.println( - "Usage: MQTTWordCount " + - " In local mode, should be 'local[n]' with n > 1") + "Usage: MQTTWordCount ") System.exit(1) } - val Seq(master, brokerUrl, topic) = args.toSeq - - val ssc = new StreamingContext(master, "MqttWordCount", Seconds(2), System.getenv("SPARK_HOME"), - StreamingContext.jarOfClass(this.getClass).toSeq) + val Seq(brokerUrl, topic) = args.toSeq + val sparkConf = new SparkConf().setAppName("MQTTWordCount") + val ssc = new StreamingContext(sparkConf, Seconds(2)) val lines = MQTTUtils.createStream(ssc, brokerUrl, topic, StorageLevel.MEMORY_ONLY_SER_2) val words = lines.flatMap(x => x.toString.split(" ")) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala index acfe9a4da3596..ad7a199b2c0ab 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala @@ -17,41 +17,38 @@ package org.apache.spark.examples.streaming +import org.apache.spark.SparkConf import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.storage.StorageLevel -// scalastyle:off /** * Counts words in text encoded with UTF8 received from the network every second. * - * Usage: NetworkWordCount - * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. - * and describe the TCP server that Spark Streaming would connect to receive data. + * Usage: NetworkWordCount + * and describe the TCP server that Spark Streaming would connect to receive data. 
* * To run this on your local machine, you need to first run a Netcat server * `$ nc -lk 9999` * and then run the example - * `$ ./bin/run-example org.apache.spark.examples.streaming.NetworkWordCount local[2] localhost 9999` + * `$ ./bin/spark-submit examples.jar \ + * --class org.apache.spark.examples.streaming.NetworkWordCount localhost 9999` */ -// scalastyle:on object NetworkWordCount { def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: NetworkWordCount \n" + - "In local mode, should be 'local[n]' with n > 1") + if (args.length < 2) { + System.err.println("Usage: NetworkWordCount ") System.exit(1) } StreamingExamples.setStreamingLogLevels() - + val sparkConf = new SparkConf().setAppName("NetworkWordCount"); // Create the context with a 1 second batch size - val ssc = new StreamingContext(args(0), "NetworkWordCount", Seconds(1), - System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq) + val ssc = new StreamingContext(sparkConf, Seconds(1)) // Create a NetworkInputDStream on target ip:port and count the // words in input stream of \n delimited text (eg. generated by 'nc') - val lines = ssc.socketTextStream(args(1), args(2).toInt, StorageLevel.MEMORY_ONLY_SER) + val lines = ssc.socketTextStream(args(0), args(1).toInt, StorageLevel.MEMORY_ONLY_SER) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) wordCounts.print() diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala index f92f72f2de876..4caa90659111a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala @@ -19,6 +19,7 @@ package org.apache.spark.examples.streaming import scala.collection.mutable.SynchronizedQueue +import org.apache.spark.SparkConf import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.StreamingContext._ @@ -26,16 +27,11 @@ import org.apache.spark.streaming.StreamingContext._ object QueueStream { def main(args: Array[String]) { - if (args.length < 1) { - System.err.println("Usage: QueueStream ") - System.exit(1) - } StreamingExamples.setStreamingLogLevels() - + val sparkConf = new SparkConf().setAppName("QueueStream") // Create the context - val ssc = new StreamingContext(args(0), "QueueStream", Seconds(1), - System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq) + val ssc = new StreamingContext(sparkConf, Seconds(1)) // Create the queue through which RDDs can be pushed to // a QueueInputDStream diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala index 1b0319a046433..a9aaa445bccb6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala @@ -17,6 +17,7 @@ package org.apache.spark.examples.streaming +import org.apache.spark.SparkConf import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.util.IntParam @@ -27,29 +28,26 @@ import org.apache.spark.util.IntParam * will only work with spark.streaming.util.RawTextSender running on all worker nodes * and with Spark using Kryo serialization (set Java 
property "spark.serializer" to * "org.apache.spark.serializer.KryoSerializer"). - * Usage: RawNetworkGrep - * is the Spark master URL + * Usage: RawNetworkGrep * is the number rawNetworkStreams, which should be same as number * of work nodes in the cluster * is "localhost". * is the port on which RawTextSender is running in the worker nodes. * is the Spark Streaming batch duration in milliseconds. */ - object RawNetworkGrep { def main(args: Array[String]) { - if (args.length != 5) { - System.err.println("Usage: RawNetworkGrep ") + if (args.length != 4) { + System.err.println("Usage: RawNetworkGrep ") System.exit(1) } StreamingExamples.setStreamingLogLevels() - val Array(master, IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args - + val Array(IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args + val sparkConf = new SparkConf().setAppName("RawNetworkGrep") // Create the context - val ssc = new StreamingContext(master, "RawNetworkGrep", Milliseconds(batchMillis), - System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq) + val ssc = new StreamingContext(sparkConf, Duration(batchMillis)) val rawStreams = (1 to numStreams).map(_ => ssc.rawSocketStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index b0bc31cc66ab5..ace785d9fe4c5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -17,19 +17,21 @@ package org.apache.spark.examples.streaming +import java.io.File +import java.nio.charset.Charset + +import com.google.common.io.Files + +import org.apache.spark.SparkConf +import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Time, Seconds, StreamingContext} import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.util.IntParam -import java.io.File -import org.apache.spark.rdd.RDD -import com.google.common.io.Files -import java.nio.charset.Charset /** * Counts words in text encoded with UTF8 received from the network every second. * - * Usage: NetworkWordCount - * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. + * Usage: NetworkWordCount * and describe the TCP server that Spark Streaming would connect to receive * data. directory to HDFS-compatible file system which checkpoint data * file to which the word counts will be appended @@ -44,8 +46,9 @@ import java.nio.charset.Charset * * and run the example as * - * `$ ./run-example org.apache.spark.examples.streaming.RecoverableNetworkWordCount \ - * local[2] localhost 9999 ~/checkpoint/ ~/out` + * `$ ./bin/spark-submit examples.jar \ + * --class org.apache.spark.examples.streaming.RecoverableNetworkWordCount \ + * localhost 9999 ~/checkpoint/ ~/out` * * If the directory ~/checkpoint/ does not exist (e.g. running for the first time), it will create * a new StreamingContext (will print "Creating new context" to the console). 
Otherwise, if @@ -67,17 +70,16 @@ import java.nio.charset.Charset object RecoverableNetworkWordCount { - def createContext(master: String, ip: String, port: Int, outputPath: String) = { + def createContext(ip: String, port: Int, outputPath: String) = { // If you do not see this printed, that means the StreamingContext has been loaded // from the new checkpoint println("Creating new context") val outputFile = new File(outputPath) if (outputFile.exists()) outputFile.delete() - + val sparkConf = new SparkConf().setAppName("RecoverableNetworkWordCount") // Create the context with a 1 second batch size - val ssc = new StreamingContext(master, "RecoverableNetworkWordCount", Seconds(1), - System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq) + val ssc = new StreamingContext(sparkConf, Seconds(1)) // Create a NetworkInputDStream on target ip:port and count the // words in input stream of \n delimited text (eg. generated by 'nc') @@ -94,13 +96,12 @@ object RecoverableNetworkWordCount { } def main(args: Array[String]) { - if (args.length != 5) { + if (args.length != 4) { System.err.println("You arguments were " + args.mkString("[", ", ", "]")) System.err.println( """ - |Usage: RecoverableNetworkWordCount - | is the Spark master URL. In local mode, should be - | 'local[n]' with n > 1. and describe the TCP server that Spark + |Usage: RecoverableNetworkWordCount + | . and describe the TCP server that Spark | Streaming would connect to receive data. directory to | HDFS-compatible file system which checkpoint data file to which the | word counts will be appended @@ -111,10 +112,10 @@ object RecoverableNetworkWordCount { ) System.exit(1) } - val Array(master, ip, IntParam(port), checkpointDirectory, outputPath) = args + val Array(ip, IntParam(port), checkpointDirectory, outputPath) = args val ssc = StreamingContext.getOrCreate(checkpointDirectory, () => { - createContext(master, ip, port, outputPath) + createContext(ip, port, outputPath) }) ssc.start() ssc.awaitTermination() diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index 8001d56c98d86..5e1415f3cc536 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -17,28 +17,27 @@ package org.apache.spark.examples.streaming +import org.apache.spark.SparkConf import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext._ -// scalastyle:off + /** * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every * second. - * Usage: StatefulNetworkWordCount - * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. + * Usage: StatefulNetworkWordCount * and describe the TCP server that Spark Streaming would connect to receive * data. 
* * To run this on your local machine, you need to first run a Netcat server * `$ nc -lk 9999` * and then run the example - * `$ ./bin/run-example org.apache.spark.examples.streaming.StatefulNetworkWordCount local[2] localhost 9999` + * `$ ./bin/spark-submit examples.jar + * --class org.apache.spark.examples.streaming.StatefulNetworkWordCount localhost 9999` */ -// scalastyle:on object StatefulNetworkWordCount { def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: StatefulNetworkWordCount \n" + - "In local mode, should be 'local[n]' with n > 1") + if (args.length < 2) { + System.err.println("Usage: StatefulNetworkWordCount ") System.exit(1) } @@ -52,14 +51,14 @@ object StatefulNetworkWordCount { Some(currentCount + previousCount) } + val sparkConf = new SparkConf().setAppName("NetworkWordCumulativeCountUpdateStateByKey") // Create the context with a 1 second batch size - val ssc = new StreamingContext(args(0), "NetworkWordCumulativeCountUpdateStateByKey", - Seconds(1), System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq) + val ssc = new StreamingContext(sparkConf, Seconds(1)) ssc.checkpoint(".") // Create a NetworkInputDStream on target ip:port and count the // words in input stream of \n delimited test (eg. generated by 'nc') - val lines = ssc.socketTextStream(args(1), args(2).toInt) + val lines = ssc.socketTextStream(args(0), args(1).toInt) val words = lines.flatMap(_.split(" ")) val wordDstream = words.map(x => (x, 1)) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala index b12617d881787..683752ac96241 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala @@ -19,11 +19,13 @@ package org.apache.spark.examples.streaming import com.twitter.algebird._ +import org.apache.spark.SparkConf import org.apache.spark.SparkContext._ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.twitter._ + // scalastyle:off /** * Illustrates the use of the Count-Min Sketch, from Twitter's Algebird library, to compute @@ -49,12 +51,6 @@ import org.apache.spark.streaming.twitter._ // scalastyle:on object TwitterAlgebirdCMS { def main(args: Array[String]) { - if (args.length < 1) { - System.err.println("Usage: TwitterAlgebirdCMS " + - " [filter1] [filter2] ... 
[filter n]") - System.exit(1) - } - StreamingExamples.setStreamingLogLevels() // CMS parameters @@ -65,10 +61,9 @@ object TwitterAlgebirdCMS { // K highest frequency elements to take val TOPK = 10 - val (master, filters) = (args.head, args.tail) - - val ssc = new StreamingContext(master, "TwitterAlgebirdCMS", Seconds(10), - System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq) + val filters = args + val sparkConf = new SparkConf().setAppName("TwitterAlgebirdCMS") + val ssc = new StreamingContext(sparkConf, Seconds(10)) val stream = TwitterUtils.createStream(ssc, None, filters, StorageLevel.MEMORY_ONLY_SER_2) val users = stream.map(status => status.getUser.getId) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala index 22f232c72545c..62db5e663b8af 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala @@ -23,6 +23,8 @@ import com.twitter.algebird.HyperLogLog._ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.twitter._ +import org.apache.spark.SparkConf + // scalastyle:off /** * Illustrates the use of the HyperLogLog algorithm, from Twitter's Algebird library, to compute @@ -42,20 +44,14 @@ import org.apache.spark.streaming.twitter._ // scalastyle:on object TwitterAlgebirdHLL { def main(args: Array[String]) { - if (args.length < 1) { - System.err.println("Usage: TwitterAlgebirdHLL " + - " [filter1] [filter2] ... [filter n]") - System.exit(1) - } StreamingExamples.setStreamingLogLevels() /** Bit size parameter for HyperLogLog, trades off accuracy vs size */ val BIT_SIZE = 12 - val (master, filters) = (args.head, args.tail) - - val ssc = new StreamingContext(master, "TwitterAlgebirdHLL", Seconds(5), - System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq) + val filters = args + val sparkConf = new SparkConf().setAppName("TwitterAlgebirdHLL") + val ssc = new StreamingContext(sparkConf, Seconds(5)) val stream = TwitterUtils.createStream(ssc, None, filters, StorageLevel.MEMORY_ONLY_SER) val users = stream.map(status => status.getUser.getId) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala index 5b58e94600a16..1ddff22cb8a42 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala @@ -21,6 +21,7 @@ import org.apache.spark.streaming.{Seconds, StreamingContext} import StreamingContext._ import org.apache.spark.SparkContext._ import org.apache.spark.streaming.twitter._ +import org.apache.spark.SparkConf /** * Calculates popular hashtags (topics) over sliding 10 and 60 second windows from a Twitter @@ -30,18 +31,12 @@ import org.apache.spark.streaming.twitter._ */ object TwitterPopularTags { def main(args: Array[String]) { - if (args.length < 1) { - System.err.println("Usage: TwitterPopularTags " + - " [filter1] [filter2] ... 
[filter n]") - System.exit(1) - } StreamingExamples.setStreamingLogLevels() - val (master, filters) = (args.head, args.tail) - - val ssc = new StreamingContext(master, "TwitterPopularTags", Seconds(2), - System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq) + val filters = args + val sparkConf = new SparkConf().setAppName("TwitterPopularTags") + val ssc = new StreamingContext(sparkConf, Seconds(2)) val stream = TwitterUtils.createStream(ssc, None, filters) val hashTags = stream.flatMap(status => status.getText.split(" ").filter(_.startsWith("#"))) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala index de46e5f5b10b6..7ade3f1018ee8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala @@ -28,6 +28,7 @@ import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.zeromq._ import scala.language.implicitConversions +import org.apache.spark.SparkConf /** * A simple publisher for demonstration purposes, repeatedly publishes random Messages @@ -63,30 +64,28 @@ object SimpleZeroMQPublisher { * Install zeroMQ (release 2.1) core libraries. [ZeroMQ Install guide] * (http://www.zeromq.org/intro:get-the-software) * - * Usage: ZeroMQWordCount - * In local mode, should be 'local[n]' with n > 1 + * Usage: ZeroMQWordCount * and describe where zeroMq publisher is running. * * To run this example locally, you may run publisher as - * `$ ./bin/run-example org.apache.spark.examples.streaming.SimpleZeroMQPublisher tcp://127.0.1.1:1234 foo.bar` + * `$ ./bin/spark-submit examples.jar \ + * --class org.apache.spark.examples.streaming.SimpleZeroMQPublisher tcp://127.0.1.1:1234 foo.bar` * and run the example as - * `$ ./bin/run-example org.apache.spark.examples.streaming.ZeroMQWordCount local[2] tcp://127.0.1.1:1234 foo` + * `$ ./bin/spark-submit examples.jar \ + * --class org.apache.spark.examples.streaming.ZeroMQWordCount tcp://127.0.1.1:1234 foo` */ // scalastyle:on object ZeroMQWordCount { def main(args: Array[String]) { - if (args.length < 3) { - System.err.println( - "Usage: ZeroMQWordCount " + - "In local mode, should be 'local[n]' with n > 1") + if (args.length < 2) { + System.err.println("Usage: ZeroMQWordCount ") System.exit(1) } StreamingExamples.setStreamingLogLevels() - val Seq(master, url, topic) = args.toSeq - + val Seq(url, topic) = args.toSeq + val sparkConf = new SparkConf().setAppName("ZeroMQWordCount") // Create the context and set the batch size - val ssc = new StreamingContext(master, "ZeroMQWordCount", Seconds(2), - System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq) + val ssc = new StreamingContext(sparkConf, Seconds(2)) def bytesToStringIterator(x: Seq[ByteString]) = (x.map(_.utf8String)).iterator diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala index fa533a512d53b..d901d4fe225fe 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala @@ -27,10 +27,14 @@ import org.apache.spark.graphx.PartitionStrategy._ object Analytics extends Logging { def main(args: Array[String]): Unit = { - val host = args(0) - val taskType = args(1) - val fname = args(2) - val options = 
args.drop(3).map { arg => + if (args.length < 2) { + System.err.println("Usage: Analytics [other options]") + System.exit(1) + } + + val taskType = args(0) + val fname = args(1) + val options = args.drop(2).map { arg => arg.dropWhile(_ == '-').split('=') match { case Array(opt, v) => (opt -> v) case _ => throw new IllegalArgumentException("Invalid argument: " + arg) @@ -71,7 +75,7 @@ object Analytics extends Logging { println("| PageRank |") println("======================================") - val sc = new SparkContext(host, "PageRank(" + fname + ")", conf) + val sc = new SparkContext(conf.setAppName("PageRank(" + fname + ")")) val unpartitionedGraph = GraphLoader.edgeListFile(sc, fname, minEdgePartitions = numEPart).cache() @@ -115,7 +119,7 @@ object Analytics extends Logging { println("| Connected Components |") println("======================================") - val sc = new SparkContext(host, "ConnectedComponents(" + fname + ")", conf) + val sc = new SparkContext(conf.setAppName("ConnectedComponents(" + fname + ")")) val unpartitionedGraph = GraphLoader.edgeListFile(sc, fname, minEdgePartitions = numEPart).cache() val graph = partitionStrategy.foldLeft(unpartitionedGraph)(_.partitionBy(_)) @@ -137,7 +141,7 @@ object Analytics extends Logging { println("======================================") println("| Triangle Count |") println("======================================") - val sc = new SparkContext(host, "TriangleCount(" + fname + ")", conf) + val sc = new SparkContext(conf.setAppName("TriangleCount(" + fname + ")")) val graph = GraphLoader.edgeListFile(sc, fname, canonicalOrientation = true, minEdgePartitions = numEPart).partitionBy(partitionStrategy).cache() val triangles = TriangleCount.run(graph) From c3f8b78c211df6c5adae74f37e39fb55baeff723 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 8 May 2014 12:13:07 -0700 Subject: [PATCH 02/33] [SPARK-1745] Move interrupted flag from TaskContext constructor (minor) It makes little sense to start a TaskContext that is interrupted. Indeed, I searched for all use cases of it and didn't find a single instance in which `interrupted` is true on construction. This was inspired by reviewing #640, which adds an additional `@volatile var completed` that is similar. These are not the most urgent changes, but I wanted to push them out before I forget. Author: Andrew Or Closes #675 from andrewor14/task-context and squashes the following commits: 9575e02 [Andrew Or] Add space 69455d1 [Andrew Or] Merge branch 'master' of github.com:apache/spark into task-context c471490 [Andrew Or] Oops, removed one flag too many. Adding it back. 
85311f8 [Andrew Or] Move interrupted flag from TaskContext constructor --- .../scala/org/apache/spark/TaskContext.scala | 20 ++++++++++--------- .../spark/scheduler/ShuffleMapTask.scala | 3 +-- .../java/org/apache/spark/JavaAPISuite.java | 2 +- .../org/apache/spark/CacheManagerSuite.scala | 10 +++------- .../org/apache/spark/PipedRDDSuite.scala | 4 +--- 5 files changed, 17 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index fc4812753d005..51f40c339d13c 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -28,13 +28,12 @@ import org.apache.spark.executor.TaskMetrics */ @DeveloperApi class TaskContext( - val stageId: Int, - val partitionId: Int, - val attemptId: Long, - val runningLocally: Boolean = false, - @volatile var interrupted: Boolean = false, - private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty -) extends Serializable { + val stageId: Int, + val partitionId: Int, + val attemptId: Long, + val runningLocally: Boolean = false, + private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty) + extends Serializable { @deprecated("use partitionId", "0.8.1") def splitId = partitionId @@ -42,7 +41,10 @@ class TaskContext( // List of callback functions to execute when the task completes. @transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit] - // Set to true when the task is completed, before the onCompleteCallbacks are executed. + // Whether the corresponding task has been killed. + @volatile var interrupted: Boolean = false + + // Whether the task has completed, before the onCompleteCallbacks are executed. @volatile var completed: Boolean = false /** @@ -58,6 +60,6 @@ class TaskContext( def executeOnCompleteCallbacks() { completed = true // Process complete callbacks in the reverse order of registration - onCompleteCallbacks.reverse.foreach{_()} + onCompleteCallbacks.reverse.foreach { _() } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 2259df0b56bad..4b0324f2b5447 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -23,7 +23,6 @@ import java.io._ import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.mutable.HashMap -import scala.util.Try import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics @@ -70,7 +69,7 @@ private[spark] object ShuffleMapTask { } // Since both the JarSet and FileSet have the same format this is used for both. 
- def deserializeFileSet(bytes: Array[Byte]) : HashMap[String, Long] = { + def deserializeFileSet(bytes: Array[Byte]): HashMap[String, Long] = { val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) val objIn = new ObjectInputStream(in) val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index c3e03cea917b3..1912015827927 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -597,7 +597,7 @@ public void persist() { @Test public void iterator() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - TaskContext context = new TaskContext(0, 0, 0, false, false, new TaskMetrics()); + TaskContext context = new TaskContext(0, 0, 0, false, new TaskMetrics()); Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue()); } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index fd5b0906e6765..4f178db40f638 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -23,7 +23,6 @@ import org.scalatest.{BeforeAndAfter, FunSuite} import org.scalatest.mock.EasyMockSugar import org.apache.spark.rdd.RDD -import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage._ // TODO: Test the CacheManager's thread-safety aspects @@ -59,8 +58,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } whenExecuting(blockManager) { - val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false, - taskMetrics = TaskMetrics.empty) + val context = new TaskContext(0, 0, 0) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } @@ -72,8 +70,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } whenExecuting(blockManager) { - val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false, - taskMetrics = TaskMetrics.empty) + val context = new TaskContext(0, 0, 0) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(5, 6, 7)) } @@ -86,8 +83,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } whenExecuting(blockManager) { - val context = new TaskContext(0, 0, 0, runningLocally = true, interrupted = false, - taskMetrics = TaskMetrics.empty) + val context = new TaskContext(0, 0, 0, runningLocally = true) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } diff --git a/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala index 0bb6a6b09c5b5..db56a4acdd6f5 100644 --- a/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala @@ -178,14 +178,12 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext { } val hadoopPart1 = generateFakeHadoopPartition() val pipedRdd = new PipedRDD(nums, "printenv " + varName) - val tContext = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false, - taskMetrics = TaskMetrics.empty) + val tContext = new TaskContext(0, 0, 0) val rddIter = pipedRdd.compute(hadoopPart1, 
tContext) val arr = rddIter.toArray assert(arr(0) == "/some/path") } else { // printenv isn't available so just pass the test - assert(true) } } From 5c5e7d5809d337ce41a7a90eb9201e12803aba48 Mon Sep 17 00:00:00 2001 From: Evan Sparks Date: Thu, 8 May 2014 13:07:30 -0700 Subject: [PATCH 03/33] Fixing typo in als.py XtY should be Xty. Author: Evan Sparks Closes #696 from etrain/patch-2 and squashes the following commits: 634cb8d [Evan Sparks] Fixing typo in als.py --- examples/src/main/python/als.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py index 33700ab4f8c53..01552dc1d449e 100755 --- a/examples/src/main/python/als.py +++ b/examples/src/main/python/als.py @@ -38,7 +38,7 @@ def update(i, vec, mat, ratings): ff = mat.shape[1] XtX = mat.T * mat - XtY = mat.T * ratings[i, :].T + Xty = mat.T * ratings[i, :].T for j in range(ff): XtX[j,j] += LAMBDA * uu From 322b1808d21143dc323493203929488d69e8878a Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 8 May 2014 15:31:47 -0700 Subject: [PATCH 04/33] [SPARK-1754] [SQL] Add missing arithmetic DSL operations. Add missing arithmetic DSL operations: `unary_-`, `%`. Author: Takuya UESHIN Closes #689 from ueshin/issues/SPARK-1754 and squashes the following commits: a09ef69 [Takuya UESHIN] Add also missing ! (not) operation. f73ae2c [Takuya UESHIN] Remove redundant tests. 5b3f087 [Takuya UESHIN] Add tests relating DSL operations. e09c5b8 [Takuya UESHIN] Add missing arithmetic DSL operations. --- .../apache/spark/sql/catalyst/dsl/package.scala | 4 ++++ .../expressions/ExpressionEvaluationSuite.scala | 16 +++++++++++++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index dc83485df195c..78d3a1d8096af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -57,10 +57,14 @@ package object dsl { trait ImplicitOperators { def expr: Expression + def unary_- = UnaryMinus(expr) + def unary_! = Not(expr) + def + (other: Expression) = Add(expr, other) def - (other: Expression) = Subtract(expr, other) def * (other: Expression) = Multiply(expr, other) def / (other: Expression) = Divide(expr, other) + def % (other: Expression) = Remainder(expr, other) def && (other: Expression) = And(expr, other) def || (other: Expression) = Or(expr, other) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 91605d0a260e5..344d8a304fc11 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -61,7 +61,7 @@ class ExpressionEvaluationSuite extends FunSuite { test("3VL Not") { notTrueTable.foreach { case (v, answer) => - val expr = Not(Literal(v, BooleanType)) + val expr = ! 
Literal(v, BooleanType) val result = expr.eval(null) if (result != answer) fail(s"$expr should not evaluate to $result, expected: $answer") } @@ -381,6 +381,13 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Add(c1, Literal(null, IntegerType)), null, row) checkEvaluation(Add(Literal(null, IntegerType), c2), null, row) checkEvaluation(Add(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + + checkEvaluation(-c1, -1, row) + checkEvaluation(c1 + c2, 3, row) + checkEvaluation(c1 - c2, -1, row) + checkEvaluation(c1 * c2, 2, row) + checkEvaluation(c1 / c2, 0, row) + checkEvaluation(c1 % c2, 1, row) } test("BinaryComparison") { @@ -395,6 +402,13 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(LessThan(c1, Literal(null, IntegerType)), null, row) checkEvaluation(LessThan(Literal(null, IntegerType), c2), null, row) checkEvaluation(LessThan(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) + + checkEvaluation(c1 < c2, true, row) + checkEvaluation(c1 <= c2, true, row) + checkEvaluation(c1 > c2, false, row) + checkEvaluation(c1 >= c2, false, row) + checkEvaluation(c1 === c2, false, row) + checkEvaluation(c1 !== c2, true, row) } } From d38febee46ed156b0c8ec64757db6c290e488421 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Thu, 8 May 2014 17:52:32 -0700 Subject: [PATCH 05/33] MLlib documentation fix Fixed the documentation for that `loadLibSVMData` is changed to `loadLibSVMFile`. Author: DB Tsai Closes #703 from dbtsai/dbtsai-docfix and squashes the following commits: 71dd508 [DB Tsai] loadLibSVMData is changed to loadLibSVMFile --- docs/mllib-basics.md | 8 ++++---- docs/mllib-linear-methods.md | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/mllib-basics.md b/docs/mllib-basics.md index 704308802d65b..aa9321a547097 100644 --- a/docs/mllib-basics.md +++ b/docs/mllib-basics.md @@ -184,7 +184,7 @@ After loading, the feature indices are converted to zero-based.
-[`MLUtils.loadLibSVMData`](api/mllib/index.html#org.apache.spark.mllib.util.MLUtils$) reads training +[`MLUtils.loadLibSVMFile`](api/mllib/index.html#org.apache.spark.mllib.util.MLUtils$) reads training examples stored in LIBSVM format. {% highlight scala %} @@ -192,12 +192,12 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -val training: RDD[LabeledPoint] = MLUtils.loadLibSVMData(sc, "mllib/data/sample_libsvm_data.txt") +val training: RDD[LabeledPoint] = MLUtils.loadLibSVMFile(sc, "mllib/data/sample_libsvm_data.txt") {% endhighlight %}
-[`MLUtils.loadLibSVMData`](api/mllib/index.html#org.apache.spark.mllib.util.MLUtils$) reads training +[`MLUtils.loadLibSVMFile`](api/mllib/index.html#org.apache.spark.mllib.util.MLUtils$) reads training examples stored in LIBSVM format. {% highlight java %} @@ -205,7 +205,7 @@ import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.rdd.RDDimport; -RDD training = MLUtils.loadLibSVMData(jsc, "mllib/data/sample_libsvm_data.txt"); +RDD training = MLUtils.loadLibSVMFile(jsc, "mllib/data/sample_libsvm_data.txt"); {% endhighlight %}
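As a quick illustration of the renamed loader (a hedged sketch, not part of the patch; the object name, app name and the final count/stop calls are illustrative), the Scala call now reads:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD

object LoadLibSVMFileCheck {
  def main(args: Array[String]) {
    val sc = new SparkContext(new SparkConf().setAppName("LoadLibSVMFileCheck"))
    // loadLibSVMData was renamed to loadLibSVMFile; the sample file ships with MLlib.
    val examples: RDD[LabeledPoint] =
      MLUtils.loadLibSVMFile(sc, "mllib/data/sample_libsvm_data.txt")
    println("Loaded " + examples.count() + " labeled points")
    sc.stop()
  }
}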
diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 40b7a7f80708c..eff617d8641e2 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -186,7 +186,7 @@ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLUtils // Load training data in LIBSVM format. -val data = MLUtils.loadLibSVMData(sc, "mllib/data/sample_libsvm_data.txt") +val data = MLUtils.loadLibSVMFile(sc, "mllib/data/sample_libsvm_data.txt") // Split data into training (60%) and test (40%). val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) From 910a13b3c52a6309068b4997da6df6b7d6058a1b Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Thu, 8 May 2014 17:53:22 -0700 Subject: [PATCH 06/33] [SPARK-1157][MLlib] Bug fix: lossHistory should exclude rejection steps, and remove miniBatch Getting the lossHistory from Breeze's API which already excludes the rejection steps in line search. Also, remove the miniBatch in LBFGS since those quasi-Newton methods approximate the inverse of Hessian. It doesn't make sense if the gradients are computed from a varying objective. Author: DB Tsai Closes #582 from dbtsai/dbtsai-lbfgs-bug and squashes the following commits: 9cc6cf9 [DB Tsai] Removed the miniBatch in LBFGS. 1ba6a33 [DB Tsai] Formatting the code. d72c679 [DB Tsai] Using Breeze's states to get the loss. --- .../spark/mllib/optimization/LBFGS.scala | 63 ++++++++----------- .../spark/mllib/optimization/LBFGSSuite.scala | 15 ++--- 2 files changed, 30 insertions(+), 48 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index 969a0c5f7c953..8f187c9df5102 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -42,7 +42,6 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) private var convergenceTol = 1E-4 private var maxNumIterations = 100 private var regParam = 0.0 - private var miniBatchFraction = 1.0 /** * Set the number of corrections used in the LBFGS update. Default 10. @@ -57,14 +56,6 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) this } - /** - * Set fraction of data to be used for each L-BFGS iteration. Default 1.0. - */ - def setMiniBatchFraction(fraction: Double): this.type = { - this.miniBatchFraction = fraction - this - } - /** * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4. * Smaller value will lead to higher accuracy with the cost of more iterations. @@ -110,7 +101,7 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) } override def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = { - val (weights, _) = LBFGS.runMiniBatchLBFGS( + val (weights, _) = LBFGS.runLBFGS( data, gradient, updater, @@ -118,7 +109,6 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) convergenceTol, maxNumIterations, regParam, - miniBatchFraction, initialWeights) weights } @@ -132,10 +122,8 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) @DeveloperApi object LBFGS extends Logging { /** - * Run Limited-memory BFGS (L-BFGS) in parallel using mini batches. - * In each iteration, we sample a subset (fraction miniBatchFraction) of the total data - * in order to compute a gradient estimate. 
- * Sampling, and averaging the subgradients over this subset is performed using one standard + * Run Limited-memory BFGS (L-BFGS) in parallel. + * Averaging the subgradients over different partitions is performed using one standard * spark map-reduce in each iteration. * * @param data - Input data for L-BFGS. RDD of the set of data examples, each of @@ -147,14 +135,12 @@ object LBFGS extends Logging { * @param convergenceTol - The convergence tolerance of iterations for L-BFGS * @param maxNumIterations - Maximal number of iterations that L-BFGS can be run. * @param regParam - Regularization parameter - * @param miniBatchFraction - Fraction of the input data set that should be used for - * one iteration of L-BFGS. Default value 1.0. * * @return A tuple containing two elements. The first element is a column matrix containing * weights for every feature, and the second element is an array containing the loss * computed for every iteration. */ - def runMiniBatchLBFGS( + def runLBFGS( data: RDD[(Double, Vector)], gradient: Gradient, updater: Updater, @@ -162,23 +148,33 @@ object LBFGS extends Logging { convergenceTol: Double, maxNumIterations: Int, regParam: Double, - miniBatchFraction: Double, initialWeights: Vector): (Vector, Array[Double]) = { val lossHistory = new ArrayBuffer[Double](maxNumIterations) val numExamples = data.count() - val miniBatchSize = numExamples * miniBatchFraction val costFun = - new CostFun(data, gradient, updater, regParam, miniBatchFraction, lossHistory, miniBatchSize) + new CostFun(data, gradient, updater, regParam, numExamples) val lbfgs = new BreezeLBFGS[BDV[Double]](maxNumIterations, numCorrections, convergenceTol) - val weights = Vectors.fromBreeze( - lbfgs.minimize(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector)) + val states = + lbfgs.iterations(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector) + + /** + * NOTE: lossSum and loss is computed using the weights from the previous iteration + * and regVal is the regularization value computed in the previous iteration as well. + */ + var state = states.next() + while(states.hasNext) { + lossHistory.append(state.value) + state = states.next() + } + lossHistory.append(state.value) + val weights = Vectors.fromBreeze(state.x) - logInfo("LBFGS.runMiniBatchSGD finished. Last 10 losses %s".format( + logInfo("LBFGS.runLBFGS finished. 
Last 10 losses %s".format( lossHistory.takeRight(10).mkString(", "))) (weights, lossHistory.toArray) @@ -193,9 +189,7 @@ object LBFGS extends Logging { gradient: Gradient, updater: Updater, regParam: Double, - miniBatchFraction: Double, - lossHistory: ArrayBuffer[Double], - miniBatchSize: Double) extends DiffFunction[BDV[Double]] { + numExamples: Long) extends DiffFunction[BDV[Double]] { private var i = 0 @@ -204,8 +198,7 @@ object LBFGS extends Logging { val localData = data val localGradient = gradient - val (gradientSum, lossSum) = localData.sample(false, miniBatchFraction, 42 + i) - .aggregate((BDV.zeros[Double](weights.size), 0.0))( + val (gradientSum, lossSum) = localData.aggregate((BDV.zeros[Double](weights.size), 0.0))( seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) => val l = localGradient.compute( features, label, Vectors.fromBreeze(weights), Vectors.fromBreeze(grad)) @@ -223,7 +216,7 @@ object LBFGS extends Logging { Vectors.fromBreeze(weights), Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2 - val loss = lossSum / miniBatchSize + regVal + val loss = lossSum / numExamples + regVal /** * It will return the gradient part of regularization using updater. * @@ -245,14 +238,8 @@ object LBFGS extends Logging { Vectors.fromBreeze(weights), Vectors.dense(new Array[Double](weights.size)), 1, 1, regParam)._1.toBreeze - // gradientTotal = gradientSum / miniBatchSize + gradientTotal - axpy(1.0 / miniBatchSize, gradientSum, gradientTotal) - - /** - * NOTE: lossSum and loss is computed using the weights from the previous iteration - * and regVal is the regularization value computed in the previous iteration as well. - */ - lossHistory.append(loss) + // gradientTotal = gradientSum / numExamples + gradientTotal + axpy(1.0 / numExamples, gradientSum, gradientTotal) i += 1 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index f33770aed30bd..6af1b502eb4dd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -59,7 +59,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers { val convergenceTol = 1e-12 val maxNumIterations = 10 - val (_, loss) = LBFGS.runMiniBatchLBFGS( + val (_, loss) = LBFGS.runLBFGS( dataRDD, gradient, simpleUpdater, @@ -67,7 +67,6 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers { convergenceTol, maxNumIterations, regParam, - miniBatchFrac, initialWeightsWithIntercept) // Since the cost function is convex, the loss is guaranteed to be monotonically decreasing @@ -104,7 +103,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers { val convergenceTol = 1e-12 val maxNumIterations = 10 - val (weightLBFGS, lossLBFGS) = LBFGS.runMiniBatchLBFGS( + val (weightLBFGS, lossLBFGS) = LBFGS.runLBFGS( dataRDD, gradient, squaredL2Updater, @@ -112,7 +111,6 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers { convergenceTol, maxNumIterations, regParam, - miniBatchFrac, initialWeightsWithIntercept) val numGDIterations = 50 @@ -150,7 +148,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers { val maxNumIterations = 8 var convergenceTol = 0.0 - val (_, lossLBFGS1) = LBFGS.runMiniBatchLBFGS( + val (_, lossLBFGS1) = LBFGS.runLBFGS( dataRDD, gradient, squaredL2Updater, @@ -158,7 +156,6 @@ 
class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers { convergenceTol, maxNumIterations, regParam, - miniBatchFrac, initialWeightsWithIntercept) // Note that the first loss is computed with initial weights, @@ -166,7 +163,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers { assert(lossLBFGS1.length == 9) convergenceTol = 0.1 - val (_, lossLBFGS2) = LBFGS.runMiniBatchLBFGS( + val (_, lossLBFGS2) = LBFGS.runLBFGS( dataRDD, gradient, squaredL2Updater, @@ -174,7 +171,6 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers { convergenceTol, maxNumIterations, regParam, - miniBatchFrac, initialWeightsWithIntercept) // Based on observation, lossLBFGS2 runs 3 iterations, no theoretically guaranteed. @@ -182,7 +178,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers { assert((lossLBFGS2(2) - lossLBFGS2(3)) / lossLBFGS2(2) < convergenceTol) convergenceTol = 0.01 - val (_, lossLBFGS3) = LBFGS.runMiniBatchLBFGS( + val (_, lossLBFGS3) = LBFGS.runLBFGS( dataRDD, gradient, squaredL2Updater, @@ -190,7 +186,6 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers { convergenceTol, maxNumIterations, regParam, - miniBatchFrac, initialWeightsWithIntercept) // With smaller convergenceTol, it takes more steps. From 191279ce4edb940821d11a6b25cd33c8ad0af054 Mon Sep 17 00:00:00 2001 From: Funes Date: Thu, 8 May 2014 17:54:10 -0700 Subject: [PATCH 07/33] Bug fix of sparse vector conversion Fixed a small bug caused by the inconsistency of index/data array size and vector length. Author: Funes Author: funes Closes #661 from funes/bugfix and squashes the following commits: edb2b9d [funes] remove unused import 75dced3 [Funes] update test case d129a66 [Funes] Add test for sparse breeze by vector builder 64e7198 [Funes] Copy data only when necessary b85806c [Funes] Bug fix of sparse vector conversion --- .../scala/org/apache/spark/mllib/linalg/Vectors.scala | 6 +++++- .../spark/mllib/linalg/BreezeVectorConversionSuite.scala | 9 +++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 7cdf6bd56acd9..84d223908c1f6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -136,7 +136,11 @@ object Vectors { new DenseVector(v.toArray) // Can't use underlying array directly, so make a new one } case v: BSV[Double] => - new SparseVector(v.length, v.index, v.data) + if (v.index.length == v.used) { + new SparseVector(v.length, v.index, v.data) + } else { + new SparseVector(v.length, v.index.slice(0, v.used), v.data.slice(0, v.used)) + } case v: BV[_] => sys.error("Unsupported Breeze vector type: " + v.getClass.getName) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala index aacaa300849aa..8abdac72902c6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala @@ -55,4 +55,13 @@ class BreezeVectorConversionSuite extends FunSuite { assert(vec.indices.eq(indices), "should not copy data") assert(vec.values.eq(values), "should not copy data") } + + test("sparse breeze with partially-used arrays to 
vector") { + val activeSize = 3 + val breeze = new BSV[Double](indices, values, activeSize, n) + val vec = Vectors.fromBreeze(breeze).asInstanceOf[SparseVector] + assert(vec.size === n) + assert(vec.indices === indices.slice(0, activeSize)) + assert(vec.values === values.slice(0, activeSize)) + } } From 2fd2752e572921a9010614eb1c1238c493d34a7c Mon Sep 17 00:00:00 2001 From: Bouke van der Bijl Date: Thu, 8 May 2014 20:43:37 -0700 Subject: [PATCH 08/33] Include the sbin/spark-config.sh in spark-executor This is needed because broadcast values are broken on pyspark on Mesos, it tries to import pyspark but can't, as the PYTHONPATH is not set due to changes in ff5be9a4 https://issues.apache.org/jira/browse/SPARK-1725 Author: Bouke van der Bijl Closes #651 from bouk/include-spark-config-in-mesos-executor and squashes the following commits: b2f1295 [Bouke van der Bijl] Inline PYTHONPATH in spark-executor eedbbcc [Bouke van der Bijl] Include the sbin/spark-config.sh in spark-executor --- sbin/spark-executor | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sbin/spark-executor b/sbin/spark-executor index de5bfab563125..336549f29c9ce 100755 --- a/sbin/spark-executor +++ b/sbin/spark-executor @@ -19,5 +19,8 @@ FWDIR="$(cd `dirname $0`/..; pwd)" +export PYTHONPATH=$FWDIR/python:$PYTHONPATH +export PYTHONPATH=$FWDIR/python/lib/py4j-0.8.1-src.zip:$PYTHONPATH + echo "Running spark-executor with framework dir = $FWDIR" exec $FWDIR/bin/spark-class org.apache.spark.executor.MesosExecutorBackend From 8b7841299439b7dc590b2f7e2339f24e8f3e19f6 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 8 May 2014 20:45:29 -0700 Subject: [PATCH 09/33] [SPARK-1755] Respect SparkSubmit --name on YARN Right now, SparkSubmit ignores the `--name` flag for both yarn-client and yarn-cluster. This is a bug. In client mode, SparkSubmit treats `--name` as a [cluster config](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala#L170) and does not propagate this to SparkContext. In cluster mode, SparkSubmit passes this flag to `org.apache.spark.deploy.yarn.Client`, which only uses it for the [YARN ResourceManager](https://github.com/apache/spark/blob/master/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala#L80), but does not propagate this to SparkContext. This PR ensures that `spark.app.name` is always set if SparkSubmit receives the `--name` flag, which is what the usage promises. This makes it possible for applications to start a SparkContext with an empty conf `val sc = new SparkContext(new SparkConf)`, and inherit the app name from SparkSubmit. Tested both modes on a YARN cluster. 
Author: Andrew Or Closes #699 from andrewor14/yarn-app-name and squashes the following commits: 98f6a79 [Andrew Or] Fix tests dea932f [Andrew Or] Merge branch 'master' of github.com:apache/spark into yarn-app-name c86d9ca [Andrew Or] Respect SparkSubmit --name on YARN --- .../scala/org/apache/spark/deploy/SparkSubmit.scala | 9 +++++---- .../org/apache/spark/deploy/SparkSubmitSuite.scala | 10 ++++++---- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index e39723f38347c..16de6f7cdb100 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -160,6 +160,7 @@ object SparkSubmit { // each deploy mode; we iterate through these below val options = List[OptionAssigner]( OptionAssigner(args.master, ALL_CLUSTER_MGRS, false, sysProp = "spark.master"), + OptionAssigner(args.name, ALL_CLUSTER_MGRS, false, sysProp = "spark.app.name"), OptionAssigner(args.driverExtraClassPath, STANDALONE | YARN, true, sysProp = "spark.driver.extraClassPath"), OptionAssigner(args.driverExtraJavaOptions, STANDALONE | YARN, true, @@ -167,7 +168,7 @@ object SparkSubmit { OptionAssigner(args.driverExtraLibraryPath, STANDALONE | YARN, true, sysProp = "spark.driver.extraLibraryPath"), OptionAssigner(args.driverMemory, YARN, true, clOption = "--driver-memory"), - OptionAssigner(args.name, YARN, true, clOption = "--name"), + OptionAssigner(args.name, YARN, true, clOption = "--name", sysProp = "spark.app.name"), OptionAssigner(args.queue, YARN, true, clOption = "--queue"), OptionAssigner(args.queue, YARN, false, sysProp = "spark.yarn.queue"), OptionAssigner(args.numExecutors, YARN, true, clOption = "--num-executors"), @@ -188,8 +189,7 @@ object SparkSubmit { OptionAssigner(args.jars, YARN, true, clOption = "--addJars"), OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, false, sysProp = "spark.files"), OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, true, sysProp = "spark.files"), - OptionAssigner(args.jars, LOCAL | STANDALONE | MESOS, false, sysProp = "spark.jars"), - OptionAssigner(args.name, LOCAL | STANDALONE | MESOS, false, sysProp = "spark.app.name") + OptionAssigner(args.jars, LOCAL | STANDALONE | MESOS, false, sysProp = "spark.jars") ) // For client mode make any added jars immediately visible on the classpath @@ -205,7 +205,8 @@ object SparkSubmit { (clusterManager & opt.clusterManager) != 0) { if (opt.clOption != null) { childArgs += (opt.clOption, opt.value) - } else if (opt.sysProp != null) { + } + if (opt.sysProp != null) { sysProps.put(opt.sysProp, opt.value) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index d7e3b22ed476e..c9edb03cdeb0f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -104,7 +104,7 @@ class SparkSubmitSuite extends FunSuite with ShouldMatchers { "--master", "yarn", "--executor-memory", "5g", "--executor-cores", "5", "--class", "org.SomeClass", "--jars", "one.jar,two.jar,three.jar", "--driver-memory", "4g", "--queue", "thequeue", "--files", "file1.txt,file2.txt", - "--archives", "archive1.txt,archive2.txt", "--num-executors", "6", + "--archives", "archive1.txt,archive2.txt", "--num-executors", "6", "--name", "beauty", "thejar.jar", "arg1", "arg2") val 
appArgs = new SparkSubmitArguments(clArgs) val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) @@ -122,7 +122,8 @@ class SparkSubmitSuite extends FunSuite with ShouldMatchers { childArgsStr should include ("--num-executors 6") mainClass should be ("org.apache.spark.deploy.yarn.Client") classpath should have length (0) - sysProps should have size (1) + sysProps("spark.app.name") should be ("beauty") + sysProps("SPARK_SUBMIT") should be ("true") } test("handles YARN client mode") { @@ -130,8 +131,8 @@ class SparkSubmitSuite extends FunSuite with ShouldMatchers { "--master", "yarn", "--executor-memory", "5g", "--executor-cores", "5", "--class", "org.SomeClass", "--jars", "one.jar,two.jar,three.jar", "--driver-memory", "4g", "--queue", "thequeue", "--files", "file1.txt,file2.txt", - "--archives", "archive1.txt,archive2.txt", "--num-executors", "6", "thejar.jar", - "arg1", "arg2") + "--archives", "archive1.txt,archive2.txt", "--num-executors", "6", "--name", "trill", + "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") @@ -140,6 +141,7 @@ class SparkSubmitSuite extends FunSuite with ShouldMatchers { classpath should contain ("one.jar") classpath should contain ("two.jar") classpath should contain ("three.jar") + sysProps("spark.app.name") should be ("trill") sysProps("spark.executor.memory") should be ("5g") sysProps("spark.executor.cores") should be ("5") sysProps("spark.yarn.queue") should be ("thequeue") From 3f779d872d8459b262b3db9e4d12b011910b6ce9 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 8 May 2014 20:46:11 -0700 Subject: [PATCH 10/33] [SPARK-1631] Correctly set the Yarn app name when launching the AM. Author: Marcelo Vanzin Closes #539 from vanzin/yarn-app-name and squashes the following commits: 7d1ca4f [Marcelo Vanzin] [SPARK-1631] Correctly set the Yarn app name when launching the AM. --- .../scheduler/cluster/YarnClientSchedulerBackend.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index ce2dde0631ed9..2924189077b7d 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -35,10 +35,10 @@ private[spark] class YarnClientSchedulerBackend( private[spark] def addArg(optionName: String, envVar: String, sysProp: String, arrayBuf: ArrayBuffer[String]) { - if (System.getProperty(sysProp) != null) { - arrayBuf += (optionName, System.getProperty(sysProp)) - } else if (System.getenv(envVar) != null) { + if (System.getenv(envVar) != null) { arrayBuf += (optionName, System.getenv(envVar)) + } else if (sc.getConf.contains(sysProp)) { + arrayBuf += (optionName, sc.getConf.get(sysProp)) } } From 06b15baab25951d124bbe6b64906f4139e037deb Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 8 May 2014 22:26:17 -0700 Subject: [PATCH 11/33] SPARK-1565 (Addendum): Replace `run-example` with `spark-submit`. Gives a nicely formatted message to the user when `run-example` is run to tell them to use `spark-submit`. 
Author: Patrick Wendell Closes #704 from pwendell/examples and squashes the following commits: 1996ee8 [Patrick Wendell] Feedback form Andrew 3eb7803 [Patrick Wendell] Suggestions from TD 2474668 [Patrick Wendell] SPARK-1565 (Addendum): Replace `run-example` with `spark-submit`. --- README.md | 19 +++-- bin/pyspark | 2 +- bin/run-example | 71 +++++-------------- bin/spark-class | 2 +- .../main/scala/org/apache/spark/rdd/RDD.scala | 4 +- docs/running-on-yarn.md | 2 +- make-distribution.sh | 2 + 7 files changed, 37 insertions(+), 65 deletions(-) diff --git a/README.md b/README.md index e2d1dcb5672ff..9c2e32b90f162 100644 --- a/README.md +++ b/README.md @@ -39,17 +39,22 @@ And run the following command, which should also return 1000: ## Example Programs Spark also comes with several sample programs in the `examples` directory. -To run one of them, use `./bin/run-example `. For example: +To run one of them, use `./bin/run-example [params]`. For example: - ./bin/run-example org.apache.spark.examples.SparkLR local[2] + ./bin/run-example org.apache.spark.examples.SparkLR -will run the Logistic Regression example locally on 2 CPUs. +will run the Logistic Regression example locally. -Each of the example programs prints usage help if no params are given. +You can set the MASTER environment variable when running examples to submit +examples to a cluster. This can be a mesos:// or spark:// URL, +"yarn-cluster" or "yarn-client" to run on YARN, and "local" to run +locally with one thread, or "local[N]" to run locally with N threads. You +can also use an abbreviated class name if the class is in the `examples` +package. For instance: -All of the Spark samples take a `` parameter that is the cluster URL -to connect to. This can be a mesos:// or spark:// URL, or "local" to run -locally with one thread, or "local[N]" to run locally with N threads. + MASTER=spark://host:7077 ./bin/run-example SparkPi + +Many of the example programs print usage help if no params are given. ## Running Tests diff --git a/bin/pyspark b/bin/pyspark index f5558853e8a4e..10e35e0f1734e 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -31,7 +31,7 @@ if [ ! -f "$FWDIR/RELEASE" ]; then ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar >& /dev/null if [[ $? != 0 ]]; then echo "Failed to find Spark assembly in $FWDIR/assembly/target" >&2 - echo "You need to build Spark with sbt/sbt assembly before running this program" >&2 + echo "You need to build Spark before running this program" >&2 exit 1 fi fi diff --git a/bin/run-example b/bin/run-example index d8a94f2e31e07..146951ac0ee56 100755 --- a/bin/run-example +++ b/bin/run-example @@ -17,28 +17,10 @@ # limitations under the License. # -cygwin=false -case "`uname`" in - CYGWIN*) cygwin=true;; -esac - SCALA_VERSION=2.10 -# Figure out where the Scala framework is installed FWDIR="$(cd `dirname $0`/..; pwd)" - -# Export this as SPARK_HOME export SPARK_HOME="$FWDIR" - -. $FWDIR/bin/load-spark-env.sh - -if [ -z "$1" ]; then - echo "Usage: run-example []" >&2 - exit 1 -fi - -# Figure out the JAR file that our examples were packaged into. This includes a bit of a hack -# to avoid the -sources and -doc packages that are built by publish-local. 
EXAMPLES_DIR="$FWDIR"/examples if [ -f "$FWDIR/RELEASE" ]; then @@ -49,46 +31,29 @@ fi if [[ -z $SPARK_EXAMPLES_JAR ]]; then echo "Failed to find Spark examples assembly in $FWDIR/lib or $FWDIR/examples/target" >&2 - echo "You need to build Spark with sbt/sbt assembly before running this program" >&2 + echo "You need to build Spark before running this program" >&2 exit 1 fi +EXAMPLE_MASTER=${MASTER:-"local[*]"} -# Since the examples JAR ideally shouldn't include spark-core (that dependency should be -# "provided"), also add our standard Spark classpath, built using compute-classpath.sh. -CLASSPATH=`$FWDIR/bin/compute-classpath.sh` -CLASSPATH="$SPARK_EXAMPLES_JAR:$CLASSPATH" - -if $cygwin; then - CLASSPATH=`cygpath -wp $CLASSPATH` - export SPARK_EXAMPLES_JAR=`cygpath -w $SPARK_EXAMPLES_JAR` -fi - -# Find java binary -if [ -n "${JAVA_HOME}" ]; then - RUNNER="${JAVA_HOME}/bin/java" -else - if [ `command -v java` ]; then - RUNNER="java" - else - echo "JAVA_HOME is not set" >&2 - exit 1 - fi -fi - -# Set JAVA_OPTS to be able to load native libraries and to set heap size -JAVA_OPTS="$SPARK_JAVA_OPTS" -# Load extra JAVA_OPTS from conf/java-opts, if it exists -if [ -e "$FWDIR/conf/java-opts" ] ; then - JAVA_OPTS="$JAVA_OPTS `cat $FWDIR/conf/java-opts`" +if [ -n "$1" ]; then + EXAMPLE_CLASS="$1" + shift +else + echo "usage: ./bin/run-example [example-args]" + echo " - set MASTER=XX to use a specific master" + echo " - can use abbreviated example class name (e.g. SparkPi, mllib.MovieLensALS)" + echo + exit -1 fi -export JAVA_OPTS -if [ "$SPARK_PRINT_LAUNCH_COMMAND" == "1" ]; then - echo -n "Spark Command: " - echo "$RUNNER" -cp "$CLASSPATH" $JAVA_OPTS "$@" - echo "========================================" - echo +if [[ ! $EXAMPLE_CLASS == org.apache.spark.examples* ]]; then + EXAMPLE_CLASS="org.apache.spark.examples.$EXAMPLE_CLASS" fi -exec "$RUNNER" -cp "$CLASSPATH" $JAVA_OPTS "$@" +./bin/spark-submit \ + --master $EXAMPLE_MASTER \ + --class $EXAMPLE_CLASS \ + $SPARK_EXAMPLES_JAR \ + "$@" diff --git a/bin/spark-class b/bin/spark-class index 72f8b9bf9a495..6480ccb58d6aa 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -114,7 +114,7 @@ if [ ! -f "$FWDIR/RELEASE" ]; then jars_list=$(ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/ | grep "spark-assembly.*hadoop.*.jar") if [ "$num_jars" -eq "0" ]; then echo "Failed to find Spark assembly in $FWDIR/assembly/target/scala-$SCALA_VERSION/" >&2 - echo "You need to build Spark with 'sbt/sbt assembly' before running this program." >&2 + echo "You need to build Spark before running this program." 
>&2 exit 1 fi if [ "$num_jars" -gt "1" ]; then diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index a1ca612cc9a09..9d8d8044f07eb 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -330,9 +330,9 @@ abstract class RDD[T: ClassTag]( if (shuffle) { // include a shuffle step so that our upstream tasks are still distributed new CoalescedRDD( - new ShuffledRDD[T, Null, (T, Null)](map(x => (x, null)), + new ShuffledRDD[Int, T, (Int, T)](map(x => (Utils.random.nextInt(), x)), new HashPartitioner(numPartitions)), - numPartitions).keys + numPartitions).values } else { new CoalescedRDD(this, numPartitions) } diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 68183ee8b4613..c563594296802 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -53,7 +53,7 @@ For example: --driver-memory 4g \ --executor-memory 2g \ --executor-cores 1 - examples/target/scala-{{site.SCALA_BINARY_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \ + lib/spark-examples*.jar \ yarn-cluster 5 The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Viewing Logs" section below for how to see driver and executor logs. diff --git a/make-distribution.sh b/make-distribution.sh index 759e555b4b69a..1cc2844703fbb 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -40,6 +40,8 @@ # set -o pipefail +set -e + # Figure out where the Spark framework is installed FWDIR="$(cd `dirname $0`; pwd)" DISTDIR="$FWDIR/dist" From 7db47c463fefc244e9c100d4aab90451c3828261 Mon Sep 17 00:00:00 2001 From: Sandeep Date: Thu, 8 May 2014 22:30:17 -0700 Subject: [PATCH 12/33] SPARK-1775: Unneeded lock in ShuffleMapTask.deserializeInfo This was used in the past to have a cache of deserialized ShuffleMapTasks, but that's been removed, so there's no need for a lock. It slows down Spark when task descriptions are large, e.g. due to large lineage graphs or local variables. Author: Sandeep Closes #707 from techaddict/SPARK-1775 and squashes the following commits: 18d8ebf [Sandeep] SPARK-1775: Unneeded lock in ShuffleMapTask.deserializeInfo This was used in the past to have a cache of deserialized ShuffleMapTasks, but that's been removed, so there's no need for a lock. It slows down Spark when task descriptions are large, e.g. due to large lineage graphs or local variables. 
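As a side note, a minimal, self-contained sketch (illustrative only, not Spark code) of why the lock is unnecessary once the shared cache is gone: every call works only on per-call streams and the immutable byte array it is handed, so concurrent task threads share no mutable state.

```scala
import java.io.{ByteArrayInputStream, ObjectInputStream}

object LockFreeDeserialize {
  // All state here is local to the call: a fresh stream per invocation over an
  // immutable byte array. Multiple threads can call this concurrently without
  // synchronization because nothing mutable is shared between them.
  def deserialize(bytes: Array[Byte]): AnyRef = {
    val in = new ObjectInputStream(new ByteArrayInputStream(bytes))
    try in.readObject() finally in.close()
  }
}
```
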
--- .../apache/spark/scheduler/ShuffleMapTask.scala | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 4b0324f2b5447..9ba586f7581cf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -57,15 +57,13 @@ private[spark] object ShuffleMapTask { } def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_]) = { - synchronized { - val loader = Thread.currentThread.getContextClassLoader - val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val ser = SparkEnv.get.closureSerializer.newInstance() - val objIn = ser.deserializeStream(in) - val rdd = objIn.readObject().asInstanceOf[RDD[_]] - val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]] - (rdd, dep) - } + val loader = Thread.currentThread.getContextClassLoader + val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) + val ser = SparkEnv.get.closureSerializer.newInstance() + val objIn = ser.deserializeStream(in) + val rdd = objIn.readObject().asInstanceOf[RDD[_]] + val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]] + (rdd, dep) } // Since both the JarSet and FileSet have the same format this is used for both. From 4c60fd1e8c526278b7e5544d6164050d1aee0338 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 8 May 2014 22:33:06 -0700 Subject: [PATCH 13/33] MINOR: Removing dead code. Meant to do this when patching up the last merge. --- .../main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 9ba586f7581cf..ed0f56f1abdf5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -57,7 +57,6 @@ private[spark] object ShuffleMapTask { } def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_]) = { - val loader = Thread.currentThread.getContextClassLoader val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) val ser = SparkEnv.get.closureSerializer.newInstance() val objIn = ser.deserializeStream(in) From 32868f31f88aebd580ab9329dc51a30c26af7a74 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 8 May 2014 22:34:08 -0700 Subject: [PATCH 14/33] Converted bang to ask to avoid scary warning when a block is removed Removing a block through the blockmanager gave a scary warning messages in the driver. ``` 2014-05-08 20:16:19,172 WARN BlockManagerMasterActor: Got unknown message: true 2014-05-08 20:16:19,172 WARN BlockManagerMasterActor: Got unknown message: true 2014-05-08 20:16:19,172 WARN BlockManagerMasterActor: Got unknown message: true ``` This is because the [BlockManagerSlaveActor](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala#L44) would send back an acknowledgement ("true"). But the BlockManagerMasterActor would have sent the RemoveBlock message as a send, not as ask(), so would reject the receiver "true" as a unknown message. 
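A minimal sketch of the distinction the fix relies on, assuming Akka 2.x as bundled with Spark at the time; the actor and message names below are made up for illustration. With `!` the receiver's acknowledgement comes back as an ordinary, unsolicited message to the sender, whereas `ask` returns a `Future` that absorbs the reply.

```scala
import scala.concurrent.Future
import scala.concurrent.duration._
import akka.actor.ActorRef
import akka.pattern.ask
import akka.util.Timeout

object TellVsAsk {
  case class RemoveSomething(id: String)

  def demo(slave: ActorRef): Unit = {
    val timeout = Timeout(30.seconds)

    // Fire-and-forget: if the slave replies (e.g. with `true`), that reply lands in
    // the sender's mailbox as a plain message -- the source of the
    // "Got unknown message: true" warnings quoted above.
    slave ! RemoveSomething("block-1")

    // Ask: the reply is captured in a Future instead of being pushed back at the
    // sender, so no unexpected message reaches the caller's receive loop.
    val ack: Future[Any] = slave.ask(RemoveSomething("block-1"))(timeout)
  }
}
```
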
@pwendell Author: Tathagata Das Closes #708 from tdas/bm-fix and squashes the following commits: ed4ef15 [Tathagata Das] Converted bang to ask to avoid scary warning when a block is removed. --- .../org/apache/spark/storage/BlockManagerMasterActor.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 98fa0df6ec289..6aed322eeb185 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -250,7 +250,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // Remove the block from the slave's BlockManager. // Doesn't actually wait for a confirmation and the message might get lost. // If message loss becomes frequent, we should add retry logic here. - blockManager.get.slaveActor ! RemoveBlock(blockId) + blockManager.get.slaveActor.ask(RemoveBlock(blockId))(akkaTimeout) } } } From bd67551ee724fd7cce4f2e2977a862216c992ef5 Mon Sep 17 00:00:00 2001 From: witgo Date: Fri, 9 May 2014 01:51:26 -0700 Subject: [PATCH 15/33] [SPARK-1760]: fix building spark with maven documentation Author: witgo Closes #712 from witgo/building-with-maven and squashes the following commits: 215523b [witgo] fix building spark with maven documentation --- docs/building-with-maven.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/building-with-maven.md b/docs/building-with-maven.md index cac01ded60d94..b6dd553bbe06b 100644 --- a/docs/building-with-maven.md +++ b/docs/building-with-maven.md @@ -96,7 +96,7 @@ Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.o The ScalaTest plugin also supports running only a specific test suite as follows: - $ mvn -Dhadoop.version=... -Dsuites=org.apache.spark.repl.ReplSuite test + $ mvn -Dhadoop.version=... -DwildcardSuites=org.apache.spark.repl.ReplSuite test ## Continuous Compilation ## From 59577df14c06417676a9ffdd599f5713c448e299 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Fri, 9 May 2014 14:51:34 -0700 Subject: [PATCH 16/33] SPARK-1770: Revert accidental(?) fix Looks like this change was accidentally committed here: https://github.com/apache/spark/commit/06b15baab25951d124bbe6b64906f4139e037deb but the change does not show up in the PR itself (#704). Other than not intending to go in with that PR, this also broke the test JavaAPISuite.repartition. Author: Aaron Davidson Closes #716 from aarondav/shufflerand and squashes the following commits: b1cf70b [Aaron Davidson] SPARK-1770: Revert accidental(?) 
fix --- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 9d8d8044f07eb..a1ca612cc9a09 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -330,9 +330,9 @@ abstract class RDD[T: ClassTag]( if (shuffle) { // include a shuffle step so that our upstream tasks are still distributed new CoalescedRDD( - new ShuffledRDD[Int, T, (Int, T)](map(x => (Utils.random.nextInt(), x)), + new ShuffledRDD[T, Null, (T, Null)](map(x => (x, null)), new HashPartitioner(numPartitions)), - numPartitions).values + numPartitions).keys } else { new CoalescedRDD(this, numPartitions) } From 2f452cbaf35dbc609ab48ec0ee5e3dd7b6b9b790 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Fri, 9 May 2014 21:50:23 -0700 Subject: [PATCH 17/33] SPARK-1686: keep schedule() calling in the main thread https://issues.apache.org/jira/browse/SPARK-1686 moved from original JIRA (by @markhamstra): In deploy.master.Master, the completeRecovery method is the last thing to be called when a standalone Master is recovering from failure. It is responsible for resetting some state, relaunching drivers, and eventually resuming its scheduling duties. There are currently four places in Master.scala where completeRecovery is called. Three of them are from within the actor's receive method, and aren't problems. The last starts from within receive when the ElectedLeader message is received, but the actual completeRecovery() call is made from the Akka scheduler. That means that it will execute on a different scheduler thread, and Master itself will end up running (i.e., schedule() ) from that Akka scheduler thread. In this PR, I added a new master message TriggerSchedule to trigger the "local" call of schedule() in the scheduler thread Author: CodingCat Closes #639 from CodingCat/SPARK-1686 and squashes the following commits: 81bb4ca [CodingCat] rename variable 69e0a2a [CodingCat] style fix 36a2ac0 [CodingCat] address Aaron's comments ec9b7bb [CodingCat] address the comments 02b37ca [CodingCat] keep schedule() calling in the main thread --- .../org/apache/spark/deploy/master/Master.scala | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index fdb633bd33608..f254f5585ba25 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -104,6 +104,8 @@ private[spark] class Master( var leaderElectionAgent: ActorRef = _ + private var recoveryCompletionTask: Cancellable = _ + // As a temporary workaround before better ways of configuring memory, we allow users to set // a flag that will perform round-robin scheduling across the nodes (spreading out each app // among all the nodes) instead of trying to consolidate each app onto a small # of nodes. @@ -152,6 +154,10 @@ private[spark] class Master( } override def postStop() { + // prevent the CompleteRecovery message sending to restarted master + if (recoveryCompletionTask != null) { + recoveryCompletionTask.cancel() + } webUi.stop() fileSystemsUsed.foreach(_.close()) masterMetricsSystem.stop() @@ -171,10 +177,13 @@ private[spark] class Master( logInfo("I have been elected leader! 
New state: " + state) if (state == RecoveryState.RECOVERING) { beginRecovery(storedApps, storedDrivers, storedWorkers) - context.system.scheduler.scheduleOnce(WORKER_TIMEOUT millis) { completeRecovery() } + recoveryCompletionTask = context.system.scheduler.scheduleOnce(WORKER_TIMEOUT millis, self, + CompleteRecovery) } } + case CompleteRecovery => completeRecovery() + case RevokedLeadership => { logError("Leadership has been revoked -- master shutting down.") System.exit(0) @@ -465,7 +474,7 @@ private[spark] class Master( * Schedule the currently available resources among waiting apps. This method will be called * every time a new app joins or resource availability changes. */ - def schedule() { + private def schedule() { if (state != RecoveryState.ALIVE) { return } // First schedule drivers, they take strict precedence over applications @@ -485,7 +494,7 @@ private[spark] class Master( // Try to spread out each app among all the nodes, until it has all its cores for (app <- waitingApps if app.coresLeft > 0) { val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) - .filter(canUse(app, _)).sortBy(_.coresFree).reverse + .filter(canUse(app, _)).sortBy(_.coresFree).reverse val numUsable = usableWorkers.length val assigned = new Array[Int](numUsable) // Number of cores to give on each node var toAssign = math.min(app.coresLeft, usableWorkers.map(_.coresFree).sum) From 561510867a1b79beef57acf9df65c9f88481435d Mon Sep 17 00:00:00 2001 From: witgo Date: Sat, 10 May 2014 10:15:04 -0700 Subject: [PATCH 18/33] [SPARK-1644] The org.datanucleus:* should not be packaged into spark-assembly-*.jar Author: witgo Closes #688 from witgo/SPARK-1644 and squashes the following commits: 56ad6ac [witgo] review commit 87c03e4 [witgo] Merge branch 'master' of https://github.com/apache/spark into SPARK-1644 6ffa7e4 [witgo] review commit a597414 [witgo] The org.datanucleus:* should not be packaged into spark-assembly-*.jar --- assembly/pom.xml | 1 + project/SparkBuild.scala | 11 ++++++----- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index 7d123fb1d7f02..6c4d46aeb67bd 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -96,6 +96,7 @@ *:* + org.datanucleus:* META-INF/*.SF META-INF/*.DSA META-INF/*.RSA diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 7f9746ec4acc0..27e9505ec9831 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -579,12 +579,13 @@ object SparkBuild extends Build { def extraAssemblySettings() = Seq( test in assembly := {}, mergeStrategy in assembly := { - case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard - case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard - case "log4j.properties" => MergeStrategy.discard + case PathList("org", "datanucleus", xs @ _*) => MergeStrategy.discard + case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard + case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard + case "log4j.properties" => MergeStrategy.discard case m if m.toLowerCase.startsWith("meta-inf/services/") => MergeStrategy.filterDistinctLines - case "reference.conf" => MergeStrategy.concat - case _ => MergeStrategy.first + case "reference.conf" => MergeStrategy.concat + case _ => MergeStrategy.first } ) From 4d6055329846f5e09472e5f844127a5ab5880e15 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sat, 10 May 2014 11:48:01 -0700 Subject: [PATCH 19/33] [SQL] Upgrade parquet library. 
I think we are hitting this issue in some perf tests: https://github.com/Parquet/parquet-mr/commit/6aed5288fd4a1398063a5a219b2ae4a9f71b02cf Credit to @aarondav ! Author: Michael Armbrust Closes #684 from marmbrus/upgradeParquet and squashes the following commits: e10a619 [Michael Armbrust] Upgrade parquet library. --- pom.xml | 2 +- project/SparkBuild.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pom.xml b/pom.xml index fb904e868cfaf..e0bff60a54cde 100644 --- a/pom.xml +++ b/pom.xml @@ -123,7 +123,7 @@ ${hadoop.version} 0.94.6 0.12.0 - 1.3.2 + 1.4.3 1.2.3 8.1.14.v20131031 0.3.6 diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 27e9505ec9831..af882b3ea7beb 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -300,7 +300,7 @@ object SparkBuild extends Build { val jets3tVersion = if ("^2\\.[3-9]+".r.findFirstIn(hadoopVersion).isDefined) "0.9.0" else "0.7.1" val jettyVersion = "8.1.14.v20131031" val hiveVersion = "0.12.0" - val parquetVersion = "1.3.2" + val parquetVersion = "1.4.3" val slf4jVersion = "1.7.5" val excludeNetty = ExclusionRule(organization = "org.jboss.netty") From 8e94d2721a9d3d36697e13f8cc6567ae8aeee78b Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Sat, 10 May 2014 12:03:27 -0700 Subject: [PATCH 20/33] [SPARK-1778] [SQL] Add 'limit' transformation to SchemaRDD. Add `limit` transformation to `SchemaRDD`. Author: Takuya UESHIN Closes #711 from ueshin/issues/SPARK-1778 and squashes the following commits: 33169df [Takuya UESHIN] Add 'limit' transformation to SchemaRDD. --- .../src/main/scala/org/apache/spark/sql/SchemaRDD.scala | 9 +++++++++ .../test/scala/org/apache/spark/sql/DslQuerySuite.scala | 6 ++++++ 2 files changed, 15 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 34200be3ac955..2569815ebb209 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -178,6 +178,15 @@ class SchemaRDD( def orderBy(sortExprs: SortOrder*): SchemaRDD = new SchemaRDD(sqlContext, Sort(sortExprs, logicalPlan)) + /** + * Limits the results by the given expressions. + * {{{ + * schemaRDD.limit(10) + * }}} + */ + def limit(limitExpr: Expression): SchemaRDD = + new SchemaRDD(sqlContext, Limit(limitExpr, logicalPlan)) + /** * Performs a grouping followed by an aggregation. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index be0f4a4c73b36..92a707ea57504 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -71,6 +71,12 @@ class DslQuerySuite extends QueryTest { Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2))) } + test("limit") { + checkAnswer( + testData.limit(10), + testData.take(10).toSeq) + } + test("average") { checkAnswer( testData2.groupBy()(Average('a)), From 7eefc9d2b3f6ebc0ecb5562da7323f1e06afbb35 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 10 May 2014 12:10:24 -0700 Subject: [PATCH 21/33] SPARK-1708. Add a ClassTag on Serializer and things that depend on it This pull request contains a rebased patch from @heathermiller (https://github.com/heathermiller/spark/pull/1) to add ClassTags on Serializer and types that depend on it (Broadcast and AccumulableCollection). 
Putting these in the public API signatures now will allow us to use Scala Pickling for serialization down the line without breaking binary compatibility. One question remaining is whether we also want them on Accumulator -- Accumulator is passed as part of a bigger Task or TaskResult object via the closure serializer so it doesn't seem super useful to add the ClassTag there. Broadcast and AccumulableCollection in contrast were being serialized directly. CC @rxin, @pwendell, @heathermiller Author: Matei Zaharia Closes #700 from mateiz/spark-1708 and squashes the following commits: 1a3d8b0 [Matei Zaharia] Use fake ClassTag in Java 3b449ed [Matei Zaharia] test fix 2209a27 [Matei Zaharia] Code style fixes 9d48830 [Matei Zaharia] Add a ClassTag on Serializer and things that depend on it --- .../scala/org/apache/spark/Accumulators.scala | 7 +-- .../scala/org/apache/spark/SparkContext.scala | 4 +- .../spark/api/java/JavaSparkContext.scala | 2 +- .../apache/spark/broadcast/Broadcast.scala | 4 +- .../spark/broadcast/BroadcastFactory.scala | 4 +- .../spark/broadcast/BroadcastManager.scala | 4 +- .../spark/broadcast/HttpBroadcast.scala | 7 ++- .../broadcast/HttpBroadcastFactory.scala | 4 +- .../spark/broadcast/TorrentBroadcast.scala | 4 +- .../broadcast/TorrentBroadcastFactory.scala | 4 +- .../org/apache/spark/rdd/CheckpointRDD.scala | 4 +- .../spark/rdd/ParallelCollectionRDD.scala | 2 +- .../apache/spark/rdd/RDDCheckpointData.scala | 2 +- .../spark/serializer/JavaSerializer.scala | 13 +++--- .../spark/serializer/KryoSerializer.scala | 12 ++--- .../apache/spark/serializer/Serializer.scala | 17 +++---- .../scala/org/apache/spark/util/Utils.scala | 2 +- .../serializer/KryoSerializerSuite.scala | 11 ++--- .../bagel/WikipediaPageRankStandalone.scala | 12 ++--- .../spark/graphx/impl/Serializers.scala | 45 ++++++++++--------- .../apache/spark/graphx/SerializerSuite.scala | 5 ++- .../sql/execution/SparkSqlSerializer.scala | 6 ++- 22 files changed, 103 insertions(+), 72 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 6d652faae149a..cdfd338081fa2 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -21,6 +21,7 @@ import java.io.{ObjectInputStream, Serializable} import scala.collection.generic.Growable import scala.collection.mutable.Map +import scala.reflect.ClassTag import org.apache.spark.serializer.JavaSerializer @@ -164,9 +165,9 @@ trait AccumulableParam[R, T] extends Serializable { def zero(initialValue: R): R } -private[spark] -class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable, T] - extends AccumulableParam[R,T] { +private[spark] class +GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T] + extends AccumulableParam[R, T] { def addAccumulator(growable: R, elem: T): R = { growable += elem diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 9d7c2c8d3d630..c639b3e15ded5 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -756,7 +756,7 @@ class SparkContext(config: SparkConf) extends Logging { * Growable and TraversableOnce are the standard APIs that guarantee += and ++=, implemented by * standard mutable collections. So you can use this with mutable Map, Set, etc. 
*/ - def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T] + def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T] (initialValue: R): Accumulable[R, T] = { val param = new GrowableAccumulableParam[R,T] new Accumulable(initialValue, param) @@ -767,7 +767,7 @@ class SparkContext(config: SparkConf) extends Logging { * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions. * The variable will be sent to each cluster only once. */ - def broadcast[T](value: T): Broadcast[T] = { + def broadcast[T: ClassTag](value: T): Broadcast[T] = { val bc = env.broadcastManager.newBroadcast[T](value, isLocal) cleaner.foreach(_.registerBroadcastForCleanup(bc)) bc diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 8b95cda511643..a7cfee6d01711 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -447,7 +447,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions. * The variable will be sent to each cluster only once. */ - def broadcast[T](value: T): Broadcast[T] = sc.broadcast(value) + def broadcast[T](value: T): Broadcast[T] = sc.broadcast(value)(fakeClassTag) /** Shut down the SparkContext. */ def stop() { diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index 738a3b1bed7f3..76956f6a345d1 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -21,6 +21,8 @@ import java.io.Serializable import org.apache.spark.SparkException +import scala.reflect.ClassTag + /** * A broadcast variable. Broadcast variables allow the programmer to keep a read-only variable * cached on each machine rather than shipping a copy of it with tasks. They can be used, for @@ -50,7 +52,7 @@ import org.apache.spark.SparkException * @param id A unique identifier for the broadcast variable. * @tparam T Type of the data contained in the broadcast variable. 
*/ -abstract class Broadcast[T](val id: Long) extends Serializable { +abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable { /** * Flag signifying whether the broadcast variable is valid diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index 8c8ce9b1691ac..a8c827030a1ef 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -17,6 +17,8 @@ package org.apache.spark.broadcast +import scala.reflect.ClassTag + import org.apache.spark.SecurityManager import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi @@ -31,7 +33,7 @@ import org.apache.spark.annotation.DeveloperApi @DeveloperApi trait BroadcastFactory { def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit - def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] + def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T] def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index cf62aca4d45e8..c88be6aba6901 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -19,6 +19,8 @@ package org.apache.spark.broadcast import java.util.concurrent.atomic.AtomicLong +import scala.reflect.ClassTag + import org.apache.spark._ private[spark] class BroadcastManager( @@ -56,7 +58,7 @@ private[spark] class BroadcastManager( private val nextBroadcastId = new AtomicLong(0) - def newBroadcast[T](value_ : T, isLocal: Boolean) = { + def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean) = { broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 29372f16f2cac..78fc286e5192c 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -22,6 +22,8 @@ import java.io.{BufferedInputStream, BufferedOutputStream} import java.net.{URL, URLConnection, URI} import java.util.concurrent.TimeUnit +import scala.reflect.ClassTag + import org.apache.spark.{HttpServer, Logging, SecurityManager, SparkConf, SparkEnv} import org.apache.spark.io.CompressionCodec import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} @@ -34,7 +36,8 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH * (through a HTTP server running at the driver) and stored in the BlockManager of the * executor to speed up future accesses. 
*/ -private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) +private[spark] class HttpBroadcast[T: ClassTag]( + @transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { def getValue = value_ @@ -173,7 +176,7 @@ private[spark] object HttpBroadcast extends Logging { files += file.getAbsolutePath } - def read[T](id: Long): T = { + def read[T: ClassTag](id: Long): T = { logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id) val url = serverUri + "/" + BroadcastBlockId(id).name diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala index e3f6cdc6154dd..d5a031e2bbb59 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala @@ -17,6 +17,8 @@ package org.apache.spark.broadcast +import scala.reflect.ClassTag + import org.apache.spark.{SecurityManager, SparkConf} /** @@ -29,7 +31,7 @@ class HttpBroadcastFactory extends BroadcastFactory { HttpBroadcast.initialize(isDriver, conf, securityMgr) } - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = + def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = new HttpBroadcast[T](value_, isLocal, id) def stop() { HttpBroadcast.stop() } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 2659274c5e98e..734de37ba115d 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -19,6 +19,7 @@ package org.apache.spark.broadcast import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream} +import scala.reflect.ClassTag import scala.math import scala.util.Random @@ -44,7 +45,8 @@ import org.apache.spark.util.Utils * copies of the broadcast data (one per executor) as done by the * [[org.apache.spark.broadcast.HttpBroadcast]]. 
*/ -private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) +private[spark] class TorrentBroadcast[T: ClassTag]( + @transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { def getValue = value_ diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala index d216b58718148..1de8396a0e17f 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala @@ -17,6 +17,8 @@ package org.apache.spark.broadcast +import scala.reflect.ClassTag + import org.apache.spark.{SecurityManager, SparkConf} /** @@ -30,7 +32,7 @@ class TorrentBroadcastFactory extends BroadcastFactory { TorrentBroadcast.initialize(isDriver, conf) } - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = + def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = new TorrentBroadcast[T](value_, isLocal, id) def stop() { TorrentBroadcast.stop() } diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 888af541cf970..34c51b833025e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -84,7 +84,7 @@ private[spark] object CheckpointRDD extends Logging { "part-%05d".format(splitId) } - def writeToFile[T]( + def writeToFile[T: ClassTag]( path: String, broadcastedConf: Broadcast[SerializableWritable[Configuration]], blockSize: Int = -1 @@ -160,7 +160,7 @@ private[spark] object CheckpointRDD extends Logging { val conf = SparkHadoopUtil.get.newConfiguration() val fs = path.getFileSystem(conf) val broadcastedConf = sc.broadcast(new SerializableWritable(conf)) - sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, broadcastedConf, 1024) _) + sc.runJob(rdd, CheckpointRDD.writeToFile[Int](path.toString, broadcastedConf, 1024) _) val cpRDD = new CheckpointRDD[Int](sc, path.toString) assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same") assert(cpRDD.collect.toList == rdd.collect.toList, "Data of partitions not the same") diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index 5f03d7d650a30..2425929fc73c5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -77,7 +77,7 @@ private[spark] class ParallelCollectionPartition[T: ClassTag]( slice = in.readInt() val ser = sfactory.newInstance() - Utils.deserializeViaNestedStream(in, ser)(ds => values = ds.readObject()) + Utils.deserializeViaNestedStream(in, ser)(ds => values = ds.readObject[Seq[T]]()) } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index 953f0555e57c5..c3b2a33fb54d0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -92,7 +92,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) // Save to file, and reload it as an RDD val broadcastedConf = rdd.context.broadcast( new 
SerializableWritable(rdd.context.hadoopConfiguration)) - rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path.toString, broadcastedConf) _) + rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _) val newRDD = new CheckpointRDD[T](rdd.context, path.toString) if (newRDD.partitions.size != rdd.partitions.size) { throw new SparkException( diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index e9163deaf2036..0a7e1ec539679 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -20,6 +20,8 @@ package org.apache.spark.serializer import java.io._ import java.nio.ByteBuffer +import scala.reflect.ClassTag + import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.ByteBufferInputStream @@ -36,7 +38,7 @@ private[spark] class JavaSerializationStream(out: OutputStream, counterReset: In * But only call it every 10,000th time to avoid bloated serialization streams (when * the stream 'resets' object class descriptions have to be re-written) */ - def writeObject[T](t: T): SerializationStream = { + def writeObject[T: ClassTag](t: T): SerializationStream = { objOut.writeObject(t) if (counterReset > 0 && counter >= counterReset) { objOut.reset() @@ -46,6 +48,7 @@ private[spark] class JavaSerializationStream(out: OutputStream, counterReset: In } this } + def flush() { objOut.flush() } def close() { objOut.close() } } @@ -57,12 +60,12 @@ extends DeserializationStream { Class.forName(desc.getName, false, loader) } - def readObject[T](): T = objIn.readObject().asInstanceOf[T] + def readObject[T: ClassTag](): T = objIn.readObject().asInstanceOf[T] def close() { objIn.close() } } private[spark] class JavaSerializerInstance(counterReset: Int) extends SerializerInstance { - def serialize[T](t: T): ByteBuffer = { + def serialize[T: ClassTag](t: T): ByteBuffer = { val bos = new ByteArrayOutputStream() val out = serializeStream(bos) out.writeObject(t) @@ -70,13 +73,13 @@ private[spark] class JavaSerializerInstance(counterReset: Int) extends Serialize ByteBuffer.wrap(bos.toByteArray) } - def deserialize[T](bytes: ByteBuffer): T = { + def deserialize[T: ClassTag](bytes: ByteBuffer): T = { val bis = new ByteBufferInputStream(bytes) val in = deserializeStream(bis) in.readObject().asInstanceOf[T] } - def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { + def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { val bis = new ByteBufferInputStream(bytes) val in = deserializeStream(bis, loader) in.readObject().asInstanceOf[T] diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index c4daec7875d26..5286f7b4c211a 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -31,6 +31,8 @@ import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage._ import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock} +import scala.reflect.ClassTag + /** * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. 
* @@ -95,7 +97,7 @@ private[spark] class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream { val output = new KryoOutput(outStream) - def writeObject[T](t: T): SerializationStream = { + def writeObject[T: ClassTag](t: T): SerializationStream = { kryo.writeClassAndObject(output, t) this } @@ -108,7 +110,7 @@ private[spark] class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream { val input = new KryoInput(inStream) - def readObject[T](): T = { + def readObject[T: ClassTag](): T = { try { kryo.readClassAndObject(input).asInstanceOf[T] } catch { @@ -131,18 +133,18 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ lazy val output = ks.newKryoOutput() lazy val input = new KryoInput() - def serialize[T](t: T): ByteBuffer = { + def serialize[T: ClassTag](t: T): ByteBuffer = { output.clear() kryo.writeClassAndObject(output, t) ByteBuffer.wrap(output.toBytes) } - def deserialize[T](bytes: ByteBuffer): T = { + def deserialize[T: ClassTag](bytes: ByteBuffer): T = { input.setBuffer(bytes.array) kryo.readClassAndObject(input).asInstanceOf[T] } - def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { + def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { val oldClassLoader = kryo.getClassLoader kryo.setClassLoader(loader) input.setBuffer(bytes.array) diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index f2c8f9b6218d6..ee26970a3d874 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -20,6 +20,8 @@ package org.apache.spark.serializer import java.io.{ByteArrayOutputStream, EOFException, InputStream, OutputStream} import java.nio.ByteBuffer +import scala.reflect.ClassTag + import org.apache.spark.SparkEnv import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.{ByteBufferInputStream, NextIterator} @@ -59,17 +61,17 @@ object Serializer { */ @DeveloperApi trait SerializerInstance { - def serialize[T](t: T): ByteBuffer + def serialize[T: ClassTag](t: T): ByteBuffer - def deserialize[T](bytes: ByteBuffer): T + def deserialize[T: ClassTag](bytes: ByteBuffer): T - def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T + def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T def serializeStream(s: OutputStream): SerializationStream def deserializeStream(s: InputStream): DeserializationStream - def serializeMany[T](iterator: Iterator[T]): ByteBuffer = { + def serializeMany[T: ClassTag](iterator: Iterator[T]): ByteBuffer = { // Default implementation uses serializeStream val stream = new ByteArrayOutputStream() serializeStream(stream).writeAll(iterator) @@ -85,18 +87,17 @@ trait SerializerInstance { } } - /** * :: DeveloperApi :: * A stream for writing serialized objects. 
*/ @DeveloperApi trait SerializationStream { - def writeObject[T](t: T): SerializationStream + def writeObject[T: ClassTag](t: T): SerializationStream def flush(): Unit def close(): Unit - def writeAll[T](iter: Iterator[T]): SerializationStream = { + def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream = { while (iter.hasNext) { writeObject(iter.next()) } @@ -111,7 +112,7 @@ trait SerializationStream { */ @DeveloperApi trait DeserializationStream { - def readObject[T](): T + def readObject[T: ClassTag](): T def close(): Unit /** diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 3f0ed61c5bbfb..95777fbf57d8b 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -850,7 +850,7 @@ private[spark] object Utils extends Logging { /** * Clone an object using a Spark serializer. */ - def clone[T](value: T, serializer: SerializerInstance): T = { + def clone[T: ClassTag](value: T, serializer: SerializerInstance): T = { serializer.deserialize[T](serializer.serialize(value)) } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 5d4673aebe9e8..cdd6b3d8feed7 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.serializer import scala.collection.mutable +import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo import org.scalatest.FunSuite @@ -31,7 +32,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { test("basic types") { val ser = new KryoSerializer(conf).newInstance() - def check[T](t: T) { + def check[T: ClassTag](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) } check(1) @@ -61,7 +62,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { test("pairs") { val ser = new KryoSerializer(conf).newInstance() - def check[T](t: T) { + def check[T: ClassTag](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) } check((1, 1)) @@ -85,7 +86,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { test("Scala data structures") { val ser = new KryoSerializer(conf).newInstance() - def check[T](t: T) { + def check[T: ClassTag](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) } check(List[Int]()) @@ -108,7 +109,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { test("ranges") { val ser = new KryoSerializer(conf).newInstance() - def check[T](t: T) { + def check[T: ClassTag](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) // Check that very long ranges don't get written one element at a time assert(ser.serialize(t).limit < 100) @@ -129,7 +130,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { test("custom registrator") { val ser = new KryoSerializer(conf).newInstance() - def check[T](t: T) { + def check[T: ClassTag](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) } diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala index a197dac87d6db..576a3e371b993 100644 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala +++ 
b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala @@ -28,6 +28,8 @@ import org.apache.spark.serializer.{DeserializationStream, SerializationStream, import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD +import scala.reflect.ClassTag + object WikipediaPageRankStandalone { def main(args: Array[String]) { if (args.length < 4) { @@ -143,15 +145,15 @@ class WPRSerializer extends org.apache.spark.serializer.Serializer { } class WPRSerializerInstance extends SerializerInstance { - def serialize[T](t: T): ByteBuffer = { + def serialize[T: ClassTag](t: T): ByteBuffer = { throw new UnsupportedOperationException() } - def deserialize[T](bytes: ByteBuffer): T = { + def deserialize[T: ClassTag](bytes: ByteBuffer): T = { throw new UnsupportedOperationException() } - def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { + def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { throw new UnsupportedOperationException() } @@ -167,7 +169,7 @@ class WPRSerializerInstance extends SerializerInstance { class WPRSerializationStream(os: OutputStream) extends SerializationStream { val dos = new DataOutputStream(os) - def writeObject[T](t: T): SerializationStream = t match { + def writeObject[T: ClassTag](t: T): SerializationStream = t match { case (id: String, wrapper: ArrayBuffer[_]) => wrapper(0) match { case links: Array[String] => { dos.writeInt(0) // links @@ -200,7 +202,7 @@ class WPRSerializationStream(os: OutputStream) extends SerializationStream { class WPRDeserializationStream(is: InputStream) extends DeserializationStream { val dis = new DataInputStream(is) - def readObject[T](): T = { + def readObject[T: ClassTag](): T = { val typeId = dis.readInt() typeId match { case 0 => { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala index 2f0531ee5f379..1de42eeca1f00 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala @@ -17,20 +17,22 @@ package org.apache.spark.graphx.impl +import scala.language.existentials + import java.io.{EOFException, InputStream, OutputStream} import java.nio.ByteBuffer +import scala.reflect.ClassTag + import org.apache.spark.graphx._ import org.apache.spark.serializer._ -import scala.language.existentials - private[graphx] class VertexIdMsgSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T](t: T) = { + def writeObject[T: ClassTag](t: T) = { val msg = t.asInstanceOf[(VertexId, _)] writeVarLong(msg._1, optimizePositive = false) this @@ -38,7 +40,7 @@ class VertexIdMsgSerializer extends Serializer with Serializable { } override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - override def readObject[T](): T = { + override def readObject[T: ClassTag](): T = { (readVarLong(optimizePositive = false), null).asInstanceOf[T] } } @@ -51,7 +53,7 @@ class IntVertexBroadcastMsgSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T](t: T) = { + def writeObject[T: ClassTag](t: T) = { val msg = 
t.asInstanceOf[VertexBroadcastMsg[Int]] writeVarLong(msg.vid, optimizePositive = false) writeInt(msg.data) @@ -60,7 +62,7 @@ class IntVertexBroadcastMsgSerializer extends Serializer with Serializable { } override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - override def readObject[T](): T = { + override def readObject[T: ClassTag](): T = { val a = readVarLong(optimizePositive = false) val b = readInt() new VertexBroadcastMsg[Int](0, a, b).asInstanceOf[T] @@ -75,7 +77,7 @@ class LongVertexBroadcastMsgSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T](t: T) = { + def writeObject[T: ClassTag](t: T) = { val msg = t.asInstanceOf[VertexBroadcastMsg[Long]] writeVarLong(msg.vid, optimizePositive = false) writeLong(msg.data) @@ -84,7 +86,7 @@ class LongVertexBroadcastMsgSerializer extends Serializer with Serializable { } override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - override def readObject[T](): T = { + override def readObject[T: ClassTag](): T = { val a = readVarLong(optimizePositive = false) val b = readLong() new VertexBroadcastMsg[Long](0, a, b).asInstanceOf[T] @@ -99,7 +101,7 @@ class DoubleVertexBroadcastMsgSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T](t: T) = { + def writeObject[T: ClassTag](t: T) = { val msg = t.asInstanceOf[VertexBroadcastMsg[Double]] writeVarLong(msg.vid, optimizePositive = false) writeDouble(msg.data) @@ -108,7 +110,7 @@ class DoubleVertexBroadcastMsgSerializer extends Serializer with Serializable { } override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - def readObject[T](): T = { + def readObject[T: ClassTag](): T = { val a = readVarLong(optimizePositive = false) val b = readDouble() new VertexBroadcastMsg[Double](0, a, b).asInstanceOf[T] @@ -123,7 +125,7 @@ class IntAggMsgSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T](t: T) = { + def writeObject[T: ClassTag](t: T) = { val msg = t.asInstanceOf[(VertexId, Int)] writeVarLong(msg._1, optimizePositive = false) writeUnsignedVarInt(msg._2) @@ -132,7 +134,7 @@ class IntAggMsgSerializer extends Serializer with Serializable { } override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - override def readObject[T](): T = { + override def readObject[T: ClassTag](): T = { val a = readVarLong(optimizePositive = false) val b = readUnsignedVarInt() (a, b).asInstanceOf[T] @@ -147,7 +149,7 @@ class LongAggMsgSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T](t: T) = { + def writeObject[T: ClassTag](t: T) = { val msg = t.asInstanceOf[(VertexId, Long)] writeVarLong(msg._1, optimizePositive = false) writeVarLong(msg._2, optimizePositive = true) @@ -156,7 +158,7 @@ class LongAggMsgSerializer extends Serializer with Serializable { } override def deserializeStream(s: 
InputStream) = new ShuffleDeserializationStream(s) { - override def readObject[T](): T = { + override def readObject[T: ClassTag](): T = { val a = readVarLong(optimizePositive = false) val b = readVarLong(optimizePositive = true) (a, b).asInstanceOf[T] @@ -171,7 +173,7 @@ class DoubleAggMsgSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T](t: T) = { + def writeObject[T: ClassTag](t: T) = { val msg = t.asInstanceOf[(VertexId, Double)] writeVarLong(msg._1, optimizePositive = false) writeDouble(msg._2) @@ -180,7 +182,7 @@ class DoubleAggMsgSerializer extends Serializer with Serializable { } override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - def readObject[T](): T = { + def readObject[T: ClassTag](): T = { val a = readVarLong(optimizePositive = false) val b = readDouble() (a, b).asInstanceOf[T] @@ -196,7 +198,7 @@ class DoubleAggMsgSerializer extends Serializer with Serializable { private[graphx] abstract class ShuffleSerializationStream(s: OutputStream) extends SerializationStream { // The implementation should override this one. - def writeObject[T](t: T): SerializationStream + def writeObject[T: ClassTag](t: T): SerializationStream def writeInt(v: Int) { s.write(v >> 24) @@ -309,7 +311,7 @@ abstract class ShuffleSerializationStream(s: OutputStream) extends Serialization private[graphx] abstract class ShuffleDeserializationStream(s: InputStream) extends DeserializationStream { // The implementation should override this one. - def readObject[T](): T + def readObject[T: ClassTag](): T def readInt(): Int = { val first = s.read() @@ -398,11 +400,12 @@ abstract class ShuffleDeserializationStream(s: InputStream) extends Deserializat private[graphx] sealed trait ShuffleSerializerInstance extends SerializerInstance { - override def serialize[T](t: T): ByteBuffer = throw new UnsupportedOperationException + override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException - override def deserialize[T](bytes: ByteBuffer): T = throw new UnsupportedOperationException + override def deserialize[T: ClassTag](bytes: ByteBuffer): T = + throw new UnsupportedOperationException - override def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = + override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = throw new UnsupportedOperationException // The implementation should override the following two. 
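Every serializer entry point touched in the hunks above — writeObject, readObject, serialize, deserialize — now carries a `T: ClassTag` context bound. As a rough illustration (a standalone sketch, not Spark code; all names below are made up), the bound is shorthand for an extra implicit `ClassTag[T]` argument, which survives type erasure and is what lets an implementation inspect the runtime class of `T` or allocate an `Array[T]`:

    import scala.reflect.ClassTag

    object ClassTagSketch {
      // `def describe[T: ClassTag](t: T)` desugars to exactly this explicit form.
      def describe[T](t: T)(implicit tag: ClassTag[T]): String =
        s"value $t has runtime class ${tag.runtimeClass.getName}"

      // The bound lets generic code build a properly typed array despite erasure.
      def toTypedArray[T: ClassTag](values: Seq[T]): Array[T] =
        values.toArray

      def main(args: Array[String]): Unit = {
        println(describe(42))
        println(toTypedArray(Seq(1, 2, 3)).mkString("[", ", ", "]"))
      }
    }

Because the tag is supplied implicitly, call sites such as `ser.deserialize[T](ser.serialize(t))` keep compiling unchanged as long as a `ClassTag[T]` is in scope, which is why the test suites above only need to add the bound to their local `check[T]` helpers.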
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala index 73438d9535962..91caa6b605a1e 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.graphx import java.io.{EOFException, ByteArrayInputStream, ByteArrayOutputStream} import scala.util.Random +import scala.reflect.ClassTag import org.scalatest.FunSuite @@ -164,7 +165,7 @@ class SerializerSuite extends FunSuite with LocalSparkContext { def testVarLongEncoding(v: Long, optimizePositive: Boolean) { val bout = new ByteArrayOutputStream val stream = new ShuffleSerializationStream(bout) { - def writeObject[T](t: T): SerializationStream = { + def writeObject[T: ClassTag](t: T): SerializationStream = { writeVarLong(t.asInstanceOf[Long], optimizePositive = optimizePositive) this } @@ -173,7 +174,7 @@ class SerializerSuite extends FunSuite with LocalSparkContext { val bin = new ByteArrayInputStream(bout.toByteArray) val dstream = new ShuffleDeserializationStream(bin) { - def readObject[T](): T = { + def readObject[T: ClassTag](): T = { readVarLong(optimizePositive).asInstanceOf[T] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 5067c14ddffeb..1c6e29b3cdee9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution import java.nio.ByteBuffer +import scala.reflect.ClassTag + import com.esotericsoftware.kryo.io.{Input, Output} import com.esotericsoftware.kryo.{Serializer, Kryo} @@ -59,11 +61,11 @@ private[sql] object SparkSqlSerializer { new KryoSerializer(sparkConf) } - def serialize[T](o: T): Array[Byte] = { + def serialize[T: ClassTag](o: T): Array[Byte] = { ser.newInstance().serialize(o).array() } - def deserialize[T](bytes: Array[Byte]): T = { + def deserialize[T: ClassTag](bytes: Array[Byte]): T = { ser.newInstance().deserialize[T](ByteBuffer.wrap(bytes)) } } From c05d11bb307eaba40c5669da2d374c28debaa55a Mon Sep 17 00:00:00 2001 From: Andy Konwinski Date: Sat, 10 May 2014 12:46:51 -0700 Subject: [PATCH 22/33] fix broken in link in python docs Author: Andy Konwinski Closes #650 from andyk/python-docs-link-fix and squashes the following commits: a1f9d51 [Andy Konwinski] fix broken in link in python docs --- docs/python-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md index 6813963bb080c..39fb5f0c99ca3 100644 --- a/docs/python-programming-guide.md +++ b/docs/python-programming-guide.md @@ -45,7 +45,7 @@ errors = logData.filter(is_error) PySpark will automatically ship these functions to executors, along with any objects that they reference. Instances of classes will be serialized and shipped to executors by PySpark, but classes themselves cannot be automatically distributed to executors. -The [Standalone Use](#standalone-use) section describes how to ship code dependencies to executors. +The [Standalone Use](#standalone-programs) section describes how to ship code dependencies to executors. 
In addition, PySpark fully supports interactive use---simply run `./bin/pyspark` to launch an interactive shell. From 3776f2f283842543ff766398292532c6e94221cc Mon Sep 17 00:00:00 2001 From: Bouke van der Bijl Date: Sat, 10 May 2014 13:02:13 -0700 Subject: [PATCH 23/33] Add Python includes to path before depickling broadcast values This fixes https://issues.apache.org/jira/browse/SPARK-1731 by adding the Python includes to the PYTHONPATH before depickling the broadcast values @airhorns Author: Bouke van der Bijl Closes #656 from bouk/python-includes-before-broadcast and squashes the following commits: 7b0dfe4 [Bouke van der Bijl] Add Python includes to path before depickling broadcast values --- .../org/apache/spark/api/python/PythonRDD.scala | 10 +++++----- python/pyspark/worker.py | 14 +++++++------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index fecd9762f3f60..388b838d78bba 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -179,6 +179,11 @@ private[spark] class PythonRDD[T: ClassTag]( dataOut.writeInt(split.index) // sparkFilesDir PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut) + // Python includes (*.zip and *.egg files) + dataOut.writeInt(pythonIncludes.length) + for (include <- pythonIncludes) { + PythonRDD.writeUTF(include, dataOut) + } // Broadcast variables dataOut.writeInt(broadcastVars.length) for (broadcast <- broadcastVars) { @@ -186,11 +191,6 @@ private[spark] class PythonRDD[T: ClassTag]( dataOut.writeInt(broadcast.value.length) dataOut.write(broadcast.value) } - // Python includes (*.zip and *.egg files) - dataOut.writeInt(pythonIncludes.length) - for (include <- pythonIncludes) { - PythonRDD.writeUTF(include, dataOut) - } dataOut.flush() // Serialized command: dataOut.writeInt(command.length) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 4c214ef359685..f43210c6c0301 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -56,13 +56,6 @@ def main(infile, outfile): SparkFiles._root_directory = spark_files_dir SparkFiles._is_running_on_worker = True - # fetch names and values of broadcast variables - num_broadcast_variables = read_int(infile) - for _ in range(num_broadcast_variables): - bid = read_long(infile) - value = pickleSer._read_with_length(infile) - _broadcastRegistry[bid] = Broadcast(bid, value) - # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH sys.path.append(spark_files_dir) # *.py files that were added will be copied here num_python_includes = read_int(infile) @@ -70,6 +63,13 @@ def main(infile, outfile): filename = utf8_deserializer.loads(infile) sys.path.append(os.path.join(spark_files_dir, filename)) + # fetch names and values of broadcast variables + num_broadcast_variables = read_int(infile) + for _ in range(num_broadcast_variables): + bid = read_long(infile) + value = pickleSer._read_with_length(infile) + _broadcastRegistry[bid] = Broadcast(bid, value) + command = pickleSer._read_with_length(infile) (func, deserializer, serializer) = command init_time = time.time() From 6c2691d0a0ed46a8b8093e05a4708706cf187168 Mon Sep 17 00:00:00 2001 From: Kan Zhang Date: Sat, 10 May 2014 14:01:08 -0700 Subject: [PATCH 24/33] [SPARK-1690] Tolerating empty elements when saving Python RDD to text files Tolerate empty strings in PythonRDD Author: Kan 
Zhang Closes #644 from kanzhang/SPARK-1690 and squashes the following commits: c62ad33 [Kan Zhang] Adding Python doctest 473ec4b [Kan Zhang] [SPARK-1690] Tolerating empty elements when saving Python RDD to text files --- .../scala/org/apache/spark/api/python/PythonRDD.scala | 5 +++-- python/pyspark/rdd.py | 8 ++++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 388b838d78bba..2971c277aa863 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -94,6 +94,7 @@ private[spark] class PythonRDD[T: ClassTag]( val obj = new Array[Byte](length) stream.readFully(obj) obj + case 0 => Array.empty[Byte] case SpecialLengths.TIMING_DATA => // Timing data from worker val bootTime = stream.readLong() @@ -123,7 +124,7 @@ private[spark] class PythonRDD[T: ClassTag]( stream.readFully(update) accumulator += Collections.singletonList(update) } - Array.empty[Byte] + null } } catch { @@ -143,7 +144,7 @@ private[spark] class PythonRDD[T: ClassTag]( var _nextObj = read() - def hasNext = _nextObj.length != 0 + def hasNext = _nextObj != null } new InterruptibleIterator(context, stdoutIterator) } diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 3a1c56af5b221..4f74824ba4cf2 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -891,6 +891,14 @@ def saveAsTextFile(self, path): >>> from glob import glob >>> ''.join(sorted(input(glob(tempFile.name + "/part-0000*")))) '0\\n1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n' + + Empty lines are tolerated when saving to text files. + + >>> tempFile2 = NamedTemporaryFile(delete=True) + >>> tempFile2.close() + >>> sc.parallelize(['', 'foo', '', 'bar', '']).saveAsTextFile(tempFile2.name) + >>> ''.join(sorted(input(glob(tempFile2.name + "/part-0000*")))) + '\\n\\n\\nbar\\nfoo\\n' """ def func(split, iterator): for x in iterator: From 905173df57b90f90ebafb22e43f55164445330e6 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Sat, 10 May 2014 14:48:07 -0700 Subject: [PATCH 25/33] Unify GraphImpl RDDs + other graph load optimizations This PR makes the following changes, primarily in e4fbd329aef85fe2c38b0167255d2a712893d683: 1. *Unify RDDs to avoid zipPartitions.* A graph used to be four RDDs: vertices, edges, routing table, and triplet view. This commit merges them down to two: vertices (with routing table), and edges (with replicated vertices). 2. *Avoid duplicate shuffle in graph building.* We used to do two shuffles when building a graph: one to extract routing information from the edges and move it to the vertices, and another to find nonexistent vertices referred to by edges. With this commit, the latter is done as a side effect of the former. 3. *Avoid no-op shuffle when joins are fully eliminated.* This is a side effect of unifying the edges and the triplet view. 4. *Join elimination for mapTriplets.* 5. *Ship only the needed vertex attributes when upgrading the triplet view.* If the triplet view already contains source attributes, and we now need both attributes, only ship destination attributes rather than re-shipping both. This is done in `ReplicatedVertexView#upgrade`. 
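As a deliberately tiny illustration of points 1 and 5 above (plain Scala with made-up names such as `LocalEdgePartition`; this is not the real `EdgePartition` or `ReplicatedVertexView` code), an edge partition that carries a replicated copy of the vertex attributes it references can materialize the triplet view with a local lookup instead of a zipPartitions join, and an attribute refresh only needs to ship the entries that changed:

    final case class Edge[ED](srcId: Long, dstId: Long, attr: ED)
    final case class Triplet[VD, ED](srcId: Long, srcAttr: VD, dstId: Long, dstAttr: VD, attr: ED)

    // Each partition keeps its edges plus the vertex attributes they refer to.
    final class LocalEdgePartition[ED, VD](
        edges: Array[Edge[ED]],
        vertices: Map[Long, VD]) {

      // Triplets come from a per-partition lookup; no join with a separate triplet RDD.
      def triplets: Iterator[Triplet[VD, ED]] =
        edges.iterator.map(e =>
          Triplet(e.srcId, vertices(e.srcId), e.dstId, vertices(e.dstId), e.attr))

      // Point 5 in miniature: refreshing attributes only ships the updated entries.
      def updateVertices(updates: Iterator[(Long, VD)]): LocalEdgePartition[ED, VD] =
        new LocalEdgePartition(edges, vertices ++ updates)
    }

    object UnifiedGraphSketch {
      def main(args: Array[String]): Unit = {
        val part = new LocalEdgePartition(
          Array(Edge(1L, 2L, "follows"), Edge(2L, 3L, "follows")),
          Map(1L -> "alice", 2L -> "bob", 3L -> "carol"))
        part.triplets.foreach(println)
        part.updateVertices(Iterator(2L -> "bob (renamed)")).triplets.foreach(println)
      }
    }

In the actual change below, the replicated attributes live in a VertexPartition stored inside each EdgePartition (see the new `vertices` field and `updateVertices` method in EdgePartition.scala), and shipping only the missing attributes when the triplet view is upgraded is handled by `ReplicatedVertexView#upgrade`.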
Author: Ankur Dave Closes #497 from ankurdave/unify-rdds and squashes the following commits: 332ab43 [Ankur Dave] Merge remote-tracking branch 'apache-spark/master' into unify-rdds 4933e2e [Ankur Dave] Exclude RoutingTable from binary compatibility check 5ba8789 [Ankur Dave] Add GraphX upgrade guide from Spark 0.9.1 13ac845 [Ankur Dave] Merge remote-tracking branch 'apache-spark/master' into unify-rdds a04765c [Ankur Dave] Remove unnecessary toOps call 57202e8 [Ankur Dave] Replace case with pair parameter 75af062 [Ankur Dave] Add explicit return types 04d3ae5 [Ankur Dave] Convert implicit parameter to context bound c88b269 [Ankur Dave] Revert upgradeIterator to if-in-a-loop 0d3584c [Ankur Dave] EdgePartition.size should be val 2a928b2 [Ankur Dave] Set locality wait 10b3596 [Ankur Dave] Clean up public API ae36110 [Ankur Dave] Fix style errors e4fbd32 [Ankur Dave] Unify GraphImpl RDDs + other graph load optimizations d6d60e2 [Ankur Dave] In GraphLoader, coalesce to minEdgePartitions 62c7b78 [Ankur Dave] In Analytics, take PageRank numIter d64e8d4 [Ankur Dave] Log current Pregel iteration --- docs/graphx-programming-guide.md | 22 +- .../org/apache/spark/graphx/EdgeRDD.scala | 56 +-- .../org/apache/spark/graphx/EdgeTriplet.scala | 2 + .../scala/org/apache/spark/graphx/Graph.scala | 2 +- .../spark/graphx/GraphKryoRegistrator.scala | 8 +- .../org/apache/spark/graphx/GraphLoader.scala | 10 +- .../org/apache/spark/graphx/GraphOps.scala | 17 +- .../org/apache/spark/graphx/Pregel.scala | 6 +- .../org/apache/spark/graphx/VertexRDD.scala | 166 ++++++--- .../spark/graphx/impl/EdgePartition.scala | 132 +++++-- .../graphx/impl/EdgePartitionBuilder.scala | 18 +- .../graphx/impl/EdgeTripletIterator.scala | 50 ++- .../apache/spark/graphx/impl/GraphImpl.scala | 344 +++++++----------- .../graphx/impl/MessageToPartition.scala | 21 +- .../graphx/impl/ReplicatedVertexView.scala | 238 ++++-------- .../spark/graphx/impl/RoutingTable.scala | 82 ----- .../graphx/impl/RoutingTablePartition.scala | 158 ++++++++ .../spark/graphx/impl/Serializers.scala | 29 ++ .../impl/ShippableVertexPartition.scala | 149 ++++++++ .../spark/graphx/impl/VertexPartition.scala | 269 ++------------ .../graphx/impl/VertexPartitionBase.scala | 91 +++++ .../graphx/impl/VertexPartitionBaseOps.scala | 245 +++++++++++++ .../apache/spark/graphx/lib/Analytics.scala | 8 +- .../org/apache/spark/graphx/GraphSuite.scala | 10 +- .../graphx/impl/EdgePartitionSuite.scala | 48 ++- .../impl/EdgeTripletIteratorSuite.scala | 10 +- .../graphx/impl/VertexPartitionSuite.scala | 11 - project/MimaBuild.scala | 2 + 28 files changed, 1353 insertions(+), 851 deletions(-) delete mode 100644 graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTable.scala create mode 100644 graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala create mode 100644 graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala create mode 100644 graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala create mode 100644 graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 07be8ba58efa3..42ab27bf55ccf 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -86,6 +86,12 @@ support the [Bagel API](api/scala/index.html#org.apache.spark.bagel.package) and [Bagel programming guide](bagel-programming-guide.html). 
However, we encourage Bagel users to explore the new GraphX API and comment on issues that may complicate the transition from Bagel. +## Upgrade Guide from Spark 0.9.1 + +GraphX in Spark {{site.SPARK_VERSION}} contains one user-facing interface change from Spark 0.9.1. [`EdgeRDD`][EdgeRDD] may now store adjacent vertex attributes to construct the triplets, so it has gained a type parameter. The edges of a graph of type `Graph[VD, ED]` are of type `EdgeRDD[ED, VD]` rather than `EdgeRDD[ED]`. + +[EdgeRDD]: api/scala/index.html#org.apache.spark.graphx.EdgeRDD + # Getting Started To get started you first need to import Spark and GraphX into your project, as follows: @@ -145,12 +151,12 @@ the vertices and edges of the graph: {% highlight scala %} class Graph[VD, ED] { val vertices: VertexRDD[VD] - val edges: EdgeRDD[ED] + val edges: EdgeRDD[ED, VD] } {% endhighlight %} -The classes `VertexRDD[VD]` and `EdgeRDD[ED]` extend and are optimized versions of `RDD[(VertexID, -VD)]` and `RDD[Edge[ED]]` respectively. Both `VertexRDD[VD]` and `EdgeRDD[ED]` provide additional +The classes `VertexRDD[VD]` and `EdgeRDD[ED, VD]` extend and are optimized versions of `RDD[(VertexID, +VD)]` and `RDD[Edge[ED]]` respectively. Both `VertexRDD[VD]` and `EdgeRDD[ED, VD]` provide additional functionality built around graph computation and leverage internal optimizations. We discuss the `VertexRDD` and `EdgeRDD` API in greater detail in the section on [vertex and edge RDDs](#vertex_and_edge_rdds) but for now they can be thought of as simply RDDs of the form: @@ -302,7 +308,7 @@ class Graph[VD, ED] { val degrees: VertexRDD[Int] // Views of the graph as collections ============================================================= val vertices: VertexRDD[VD] - val edges: EdgeRDD[ED] + val edges: EdgeRDD[ED, VD] val triplets: RDD[EdgeTriplet[VD, ED]] // Functions for caching graphs ================================================================== def persist(newLevel: StorageLevel = StorageLevel.MEMORY_ONLY): Graph[VD, ED] @@ -908,7 +914,7 @@ val setC: VertexRDD[Double] = setA.innerJoin(setB)((id, a, b) => a + b) ## EdgeRDDs -The `EdgeRDD[ED]`, which extends `RDD[Edge[ED]]` organizes the edges in blocks partitioned using one +The `EdgeRDD[ED, VD]`, which extends `RDD[Edge[ED]]` organizes the edges in blocks partitioned using one of the various partitioning strategies defined in [`PartitionStrategy`][PartitionStrategy]. Within each partition, edge attributes and adjacency structure, are stored separately enabling maximum reuse when changing attribute values. @@ -918,11 +924,11 @@ reuse when changing attribute values. The three additional functions exposed by the `EdgeRDD` are: {% highlight scala %} // Transform the edge attributes while preserving the structure -def mapValues[ED2](f: Edge[ED] => ED2): EdgeRDD[ED2] +def mapValues[ED2](f: Edge[ED] => ED2): EdgeRDD[ED2, VD] // Revere the edges reusing both attributes and structure -def reverse: EdgeRDD[ED] +def reverse: EdgeRDD[ED, VD] // Join two `EdgeRDD`s partitioned using the same partitioning strategy. 
-def innerJoin[ED2, ED3](other: EdgeRDD[ED2])(f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDD[ED3] +def innerJoin[ED2, ED3](other: EdgeRDD[ED2, VD])(f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDD[ED3, VD] {% endhighlight %} In most applications we have found that operations on the `EdgeRDD` are accomplished through the diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala index fa78ca99b8891..a8fc095072512 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala @@ -20,16 +20,19 @@ package org.apache.spark.graphx import scala.reflect.{classTag, ClassTag} import org.apache.spark.{OneToOneDependency, Partition, Partitioner, TaskContext} -import org.apache.spark.graphx.impl.EdgePartition import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.graphx.impl.EdgePartition + /** - * `EdgeRDD[ED]` extends `RDD[Edge[ED]]` by storing the edges in columnar format on each partition - * for performance. + * `EdgeRDD[ED, VD]` extends `RDD[Edge[ED]]` by storing the edges in columnar format on each + * partition for performance. It may additionally store the vertex attributes associated with each + * edge to provide the triplet view. Shipping of the vertex attributes is managed by + * `impl.ReplicatedVertexView`. */ -class EdgeRDD[@specialized ED: ClassTag]( - val partitionsRDD: RDD[(PartitionID, EdgePartition[ED])]) +class EdgeRDD[@specialized ED: ClassTag, VD: ClassTag]( + val partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])]) extends RDD[Edge[ED]](partitionsRDD.context, List(new OneToOneDependency(partitionsRDD))) { partitionsRDD.setName("EdgeRDD") @@ -45,8 +48,12 @@ class EdgeRDD[@specialized ED: ClassTag]( partitionsRDD.partitioner.orElse(Some(Partitioner.defaultPartitioner(partitionsRDD))) override def compute(part: Partition, context: TaskContext): Iterator[Edge[ED]] = { - val p = firstParent[(PartitionID, EdgePartition[ED])].iterator(part, context) - p.next._2.iterator.map(_.copy()) + val p = firstParent[(PartitionID, EdgePartition[ED, VD])].iterator(part, context) + if (p.hasNext) { + p.next._2.iterator.map(_.copy()) + } else { + Iterator.empty + } } override def collect(): Array[Edge[ED]] = this.map(_.copy()).collect() @@ -61,11 +68,15 @@ class EdgeRDD[@specialized ED: ClassTag]( this } - private[graphx] def mapEdgePartitions[ED2: ClassTag]( - f: (PartitionID, EdgePartition[ED]) => EdgePartition[ED2]): EdgeRDD[ED2] = { - new EdgeRDD[ED2](partitionsRDD.mapPartitions({ iter => - val (pid, ep) = iter.next() - Iterator(Tuple2(pid, f(pid, ep))) + private[graphx] def mapEdgePartitions[ED2: ClassTag, VD2: ClassTag]( + f: (PartitionID, EdgePartition[ED, VD]) => EdgePartition[ED2, VD2]): EdgeRDD[ED2, VD2] = { + new EdgeRDD[ED2, VD2](partitionsRDD.mapPartitions({ iter => + if (iter.hasNext) { + val (pid, ep) = iter.next() + Iterator(Tuple2(pid, f(pid, ep))) + } else { + Iterator.empty + } }, preservesPartitioning = true)) } @@ -76,7 +87,7 @@ class EdgeRDD[@specialized ED: ClassTag]( * @param f the function from an edge to a new edge value * @return a new EdgeRDD containing the new edge values */ - def mapValues[ED2: ClassTag](f: Edge[ED] => ED2): EdgeRDD[ED2] = + def mapValues[ED2: ClassTag](f: Edge[ED] => ED2): EdgeRDD[ED2, VD] = mapEdgePartitions((pid, part) => part.map(f)) /** @@ -84,7 +95,14 @@ class EdgeRDD[@specialized ED: ClassTag]( * * @return a new EdgeRDD containing all the edges 
reversed */ - def reverse: EdgeRDD[ED] = mapEdgePartitions((pid, part) => part.reverse) + def reverse: EdgeRDD[ED, VD] = mapEdgePartitions((pid, part) => part.reverse) + + /** Removes all edges but those matching `epred` and where both vertices match `vpred`. */ + def filter( + epred: EdgeTriplet[VD, ED] => Boolean, + vpred: (VertexId, VD) => Boolean): EdgeRDD[ED, VD] = { + mapEdgePartitions((pid, part) => part.filter(epred, vpred)) + } /** * Inner joins this EdgeRDD with another EdgeRDD, assuming both are partitioned using the same @@ -96,19 +114,15 @@ class EdgeRDD[@specialized ED: ClassTag]( * with values supplied by `f` */ def innerJoin[ED2: ClassTag, ED3: ClassTag] - (other: EdgeRDD[ED2]) - (f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDD[ED3] = { + (other: EdgeRDD[ED2, _]) + (f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDD[ED3, VD] = { val ed2Tag = classTag[ED2] val ed3Tag = classTag[ED3] - new EdgeRDD[ED3](partitionsRDD.zipPartitions(other.partitionsRDD, true) { + new EdgeRDD[ED3, VD](partitionsRDD.zipPartitions(other.partitionsRDD, true) { (thisIter, otherIter) => val (pid, thisEPart) = thisIter.next() val (_, otherEPart) = otherIter.next() Iterator(Tuple2(pid, thisEPart.innerJoin(otherEPart)(f)(ed2Tag, ed3Tag))) }) } - - private[graphx] def collectVertexIds(): RDD[VertexId] = { - partitionsRDD.flatMap { case (_, p) => Array.concat(p.srcIds, p.dstIds) } - } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala index dfc6a801587d2..9d473d5ebda44 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala @@ -63,4 +63,6 @@ class EdgeTriplet[VD, ED] extends Edge[ED] { if (srcId == vid) srcAttr else { assert(dstId == vid); dstAttr } override def toString = ((srcId, srcAttr), (dstId, dstAttr), attr).toString() + + def toTuple: ((VertexId, VD), (VertexId, VD), ED) = ((srcId, srcAttr), (dstId, dstAttr), attr) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index 50395868902dc..dc5dac4fdad57 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -59,7 +59,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * along with their vertex data. * */ - @transient val edges: EdgeRDD[ED] + @transient val edges: EdgeRDD[ED, VD] /** * An RDD containing the edge triplets, which are edges along with the vertex data associated with diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala index dd380d8c182c9..d295d0127ac72 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala @@ -19,10 +19,11 @@ package org.apache.spark.graphx import com.esotericsoftware.kryo.Kryo -import org.apache.spark.graphx.impl._ import org.apache.spark.serializer.KryoRegistrator -import org.apache.spark.util.collection.BitSet import org.apache.spark.util.BoundedPriorityQueue +import org.apache.spark.util.collection.BitSet + +import org.apache.spark.graphx.impl._ /** * Registers GraphX classes with Kryo for improved performance. 
@@ -33,8 +34,9 @@ class GraphKryoRegistrator extends KryoRegistrator { kryo.register(classOf[Edge[Object]]) kryo.register(classOf[MessageToPartition[Object]]) kryo.register(classOf[VertexBroadcastMsg[Object]]) + kryo.register(classOf[RoutingTableMessage]) kryo.register(classOf[(VertexId, Object)]) - kryo.register(classOf[EdgePartition[Object]]) + kryo.register(classOf[EdgePartition[Object, Object]]) kryo.register(classOf[BitSet]) kryo.register(classOf[VertexIdToIndexMap]) kryo.register(classOf[VertexAttributeBlock[Object]]) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala index 18858466db27b..389490c139848 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala @@ -47,8 +47,7 @@ object GraphLoader extends Logging { * @param path the path to the file (e.g., /home/data/file or hdfs://file) * @param canonicalOrientation whether to orient edges in the positive * direction - * @param minEdgePartitions the number of partitions for the - * the edge RDD + * @param minEdgePartitions the number of partitions for the edge RDD */ def edgeListFile( sc: SparkContext, @@ -60,8 +59,9 @@ object GraphLoader extends Logging { val startTime = System.currentTimeMillis // Parse the edge data table directly into edge partitions - val edges = sc.textFile(path, minEdgePartitions).mapPartitionsWithIndex { (pid, iter) => - val builder = new EdgePartitionBuilder[Int] + val lines = sc.textFile(path, minEdgePartitions).coalesce(minEdgePartitions) + val edges = lines.mapPartitionsWithIndex { (pid, iter) => + val builder = new EdgePartitionBuilder[Int, Int] iter.foreach { line => if (!line.isEmpty && line(0) != '#') { val lineArray = line.split("\\s+") @@ -78,7 +78,7 @@ object GraphLoader extends Logging { } } Iterator((pid, builder.toEdgePartition)) - }.cache() + }.cache().setName("GraphLoader.edgeListFile - edges (%s)".format(path)) edges.count() logInfo("It took %d ms to load the edges".format(System.currentTimeMillis - startTime)) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index 4997fbc3cbcd8..edd5b79da1522 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -18,11 +18,13 @@ package org.apache.spark.graphx import scala.reflect.ClassTag -import org.apache.spark.SparkContext._ +import scala.util.Random + import org.apache.spark.SparkException -import org.apache.spark.graphx.lib._ +import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD -import scala.util.Random + +import org.apache.spark.graphx.lib._ /** * Contains additional functionality for [[Graph]]. All operations are expressed in terms of the @@ -43,19 +45,22 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali * The in-degree of each vertex in the graph. * @note Vertices with no in-edges are not returned in the resulting RDD. */ - @transient lazy val inDegrees: VertexRDD[Int] = degreesRDD(EdgeDirection.In) + @transient lazy val inDegrees: VertexRDD[Int] = + degreesRDD(EdgeDirection.In).setName("GraphOps.inDegrees") /** * The out-degree of each vertex in the graph. * @note Vertices with no out-edges are not returned in the resulting RDD. 
*/ - @transient lazy val outDegrees: VertexRDD[Int] = degreesRDD(EdgeDirection.Out) + @transient lazy val outDegrees: VertexRDD[Int] = + degreesRDD(EdgeDirection.Out).setName("GraphOps.outDegrees") /** * The degree of each vertex in the graph. * @note Vertices with no edges are not returned in the resulting RDD. */ - @transient lazy val degrees: VertexRDD[Int] = degreesRDD(EdgeDirection.Either) + @transient lazy val degrees: VertexRDD[Int] = + degreesRDD(EdgeDirection.Either).setName("GraphOps.degrees") /** * Computes the neighboring vertex degrees. diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index ac07a594a12e4..4572eab2875bb 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -18,6 +18,7 @@ package org.apache.spark.graphx import scala.reflect.ClassTag +import org.apache.spark.Logging /** @@ -52,7 +53,7 @@ import scala.reflect.ClassTag * }}} * */ -object Pregel { +object Pregel extends Logging { /** * Execute a Pregel-like iterative vertex-parallel abstraction. The @@ -142,6 +143,9 @@ object Pregel { // hides oldMessages (depended on by newVerts), newVerts (depended on by messages), and the // vertices of prevG (depended on by newVerts, oldMessages, and the vertices of g). activeMessages = messages.count() + + logInfo("Pregel finished iteration " + i) + // Unpersist the RDDs hidden by newly-materialized RDDs oldMessages.unpersist(blocking=false) newVerts.unpersist(blocking=false) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala index f0fc605c88575..8c62897037b6d 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala @@ -24,8 +24,11 @@ import org.apache.spark.SparkContext._ import org.apache.spark.rdd._ import org.apache.spark.storage.StorageLevel -import org.apache.spark.graphx.impl.MsgRDDFunctions -import org.apache.spark.graphx.impl.VertexPartition +import org.apache.spark.graphx.impl.RoutingTablePartition +import org.apache.spark.graphx.impl.ShippableVertexPartition +import org.apache.spark.graphx.impl.VertexAttributeBlock +import org.apache.spark.graphx.impl.RoutingTableMessageRDDFunctions._ +import org.apache.spark.graphx.impl.VertexRDDFunctions._ /** * Extends `RDD[(VertexId, VD)]` by ensuring that there is only one entry for each vertex and by @@ -33,6 +36,9 @@ import org.apache.spark.graphx.impl.VertexPartition * joined efficiently. All operations except [[reindex]] preserve the index. To construct a * `VertexRDD`, use the [[org.apache.spark.graphx.VertexRDD$ VertexRDD object]]. * + * Additionally, stores routing information to enable joining the vertex attributes with an + * [[EdgeRDD]]. + * * @example Construct a `VertexRDD` from a plain RDD: * {{{ * // Construct an initial vertex set @@ -50,13 +56,11 @@ import org.apache.spark.graphx.impl.VertexPartition * @tparam VD the vertex attribute associated with each vertex in the set. */ class VertexRDD[@specialized VD: ClassTag]( - val partitionsRDD: RDD[VertexPartition[VD]]) + val partitionsRDD: RDD[ShippableVertexPartition[VD]]) extends RDD[(VertexId, VD)](partitionsRDD.context, List(new OneToOneDependency(partitionsRDD))) { require(partitionsRDD.partitioner.isDefined) - partitionsRDD.setName("VertexRDD") - /** * Construct a new VertexRDD that is indexed by only the visible vertices. 
The resulting * VertexRDD will be based on a different index and can no longer be quickly joined with this @@ -71,6 +75,16 @@ class VertexRDD[@specialized VD: ClassTag]( override protected def getPreferredLocations(s: Partition): Seq[String] = partitionsRDD.preferredLocations(s) + override def setName(_name: String): this.type = { + if (partitionsRDD.name != null) { + partitionsRDD.setName(partitionsRDD.name + ", " + _name) + } else { + partitionsRDD.setName(_name) + } + this + } + setName("VertexRDD") + override def persist(newLevel: StorageLevel): this.type = { partitionsRDD.persist(newLevel) this @@ -90,14 +104,14 @@ class VertexRDD[@specialized VD: ClassTag]( * Provides the `RDD[(VertexId, VD)]` equivalent output. */ override def compute(part: Partition, context: TaskContext): Iterator[(VertexId, VD)] = { - firstParent[VertexPartition[VD]].iterator(part, context).next.iterator + firstParent[ShippableVertexPartition[VD]].iterator(part, context).next.iterator } /** * Applies a function to each `VertexPartition` of this RDD and returns a new VertexRDD. */ private[graphx] def mapVertexPartitions[VD2: ClassTag]( - f: VertexPartition[VD] => VertexPartition[VD2]) + f: ShippableVertexPartition[VD] => ShippableVertexPartition[VD2]) : VertexRDD[VD2] = { val newPartitionsRDD = partitionsRDD.mapPartitions(_.map(f), preservesPartitioning = true) new VertexRDD(newPartitionsRDD) @@ -208,10 +222,8 @@ class VertexRDD[@specialized VD: ClassTag]( case _ => new VertexRDD[VD3]( partitionsRDD.zipPartitions( - other.partitionBy(this.partitioner.get), preservesPartitioning = true) - { (part, msgs) => - val vertexPartition: VertexPartition[VD] = part.next() - Iterator(vertexPartition.leftJoin(msgs)(f)) + other.copartitionWithVertices(this.partitioner.get), preservesPartitioning = true) { + (partIter, msgs) => partIter.map(_.leftJoin(msgs)(f)) } ) } @@ -254,10 +266,8 @@ class VertexRDD[@specialized VD: ClassTag]( case _ => new VertexRDD( partitionsRDD.zipPartitions( - other.partitionBy(this.partitioner.get), preservesPartitioning = true) - { (part, msgs) => - val vertexPartition: VertexPartition[VD] = part.next() - Iterator(vertexPartition.innerJoin(msgs)(f)) + other.copartitionWithVertices(this.partitioner.get), preservesPartitioning = true) { + (partIter, msgs) => partIter.map(_.innerJoin(msgs)(f)) } ) } @@ -276,14 +286,31 @@ class VertexRDD[@specialized VD: ClassTag]( */ def aggregateUsingIndex[VD2: ClassTag]( messages: RDD[(VertexId, VD2)], reduceFunc: (VD2, VD2) => VD2): VertexRDD[VD2] = { - val shuffled = MsgRDDFunctions.partitionForAggregation(messages, this.partitioner.get) + val shuffled = messages.copartitionWithVertices(this.partitioner.get) val parts = partitionsRDD.zipPartitions(shuffled, true) { (thisIter, msgIter) => - val vertexPartition: VertexPartition[VD] = thisIter.next() - Iterator(vertexPartition.aggregateUsingIndex(msgIter, reduceFunc)) + thisIter.map(_.aggregateUsingIndex(msgIter, reduceFunc)) } new VertexRDD[VD2](parts) } + /** + * Returns a new `VertexRDD` reflecting a reversal of all edge directions in the corresponding + * [[EdgeRDD]]. + */ + def reverseRoutingTables(): VertexRDD[VD] = + this.mapVertexPartitions(vPart => vPart.withRoutingTable(vPart.routingTable.reverse)) + + /** Generates an RDD of vertex attributes suitable for shipping to the edge partitions. 
*/ + private[graphx] def shipVertexAttributes( + shipSrc: Boolean, shipDst: Boolean): RDD[(PartitionID, VertexAttributeBlock[VD])] = { + partitionsRDD.mapPartitions(_.flatMap(_.shipVertexAttributes(shipSrc, shipDst))) + } + + /** Generates an RDD of vertex IDs suitable for shipping to the edge partitions. */ + private[graphx] def shipVertexIds(): RDD[(PartitionID, Array[VertexId])] = { + partitionsRDD.mapPartitions(_.flatMap(_.shipVertexIds())) + } + } // end of VertexRDD @@ -293,52 +320,101 @@ class VertexRDD[@specialized VD: ClassTag]( object VertexRDD { /** - * Construct a `VertexRDD` from an RDD of vertex-attribute pairs. - * Duplicate entries are removed arbitrarily. + * Constructs a standalone `VertexRDD` (one that is not set up for efficient joins with an + * [[EdgeRDD]]) from an RDD of vertex-attribute pairs. Duplicate entries are removed arbitrarily. * * @tparam VD the vertex attribute type * - * @param rdd the collection of vertex-attribute pairs + * @param vertices the collection of vertex-attribute pairs */ - def apply[VD: ClassTag](rdd: RDD[(VertexId, VD)]): VertexRDD[VD] = { - val partitioned: RDD[(VertexId, VD)] = rdd.partitioner match { - case Some(p) => rdd - case None => rdd.partitionBy(new HashPartitioner(rdd.partitions.size)) + def apply[VD: ClassTag](vertices: RDD[(VertexId, VD)]): VertexRDD[VD] = { + val vPartitioned: RDD[(VertexId, VD)] = vertices.partitioner match { + case Some(p) => vertices + case None => vertices.copartitionWithVertices(new HashPartitioner(vertices.partitions.size)) } - val vertexPartitions = partitioned.mapPartitions( - iter => Iterator(VertexPartition(iter)), + val vertexPartitions = vPartitioned.mapPartitions( + iter => Iterator(ShippableVertexPartition(iter)), preservesPartitioning = true) new VertexRDD(vertexPartitions) } /** - * Constructs a `VertexRDD` from an RDD of vertex-attribute pairs, merging duplicates using - * `mergeFunc`. + * Constructs a `VertexRDD` from an RDD of vertex-attribute pairs. Duplicate vertex entries are + * removed arbitrarily. The resulting `VertexRDD` will be joinable with `edges`, and any missing + * vertices referred to by `edges` will be created with the attribute `defaultVal`. * * @tparam VD the vertex attribute type * - * @param rdd the collection of vertex-attribute pairs - * @param mergeFunc the associative, commutative merge function. + * @param vertices the collection of vertex-attribute pairs + * @param edges the [[EdgeRDD]] that these vertices may be joined with + * @param defaultVal the vertex attribute to use when creating missing vertices */ - def apply[VD: ClassTag](rdd: RDD[(VertexId, VD)], mergeFunc: (VD, VD) => VD): VertexRDD[VD] = { - val partitioned: RDD[(VertexId, VD)] = rdd.partitioner match { - case Some(p) => rdd - case None => rdd.partitionBy(new HashPartitioner(rdd.partitions.size)) + def apply[VD: ClassTag]( + vertices: RDD[(VertexId, VD)], edges: EdgeRDD[_, _], defaultVal: VD): VertexRDD[VD] = { + VertexRDD(vertices, edges, defaultVal, (a, b) => b) + } + + /** + * Constructs a `VertexRDD` from an RDD of vertex-attribute pairs. Duplicate vertex entries are + * merged using `mergeFunc`. The resulting `VertexRDD` will be joinable with `edges`, and any + * missing vertices referred to by `edges` will be created with the attribute `defaultVal`. 
+ * + * @tparam VD the vertex attribute type + * + * @param vertices the collection of vertex-attribute pairs + * @param edges the [[EdgeRDD]] that these vertices may be joined with + * @param defaultVal the vertex attribute to use when creating missing vertices + * @param mergeFunc the commutative, associative duplicate vertex attribute merge function + */ + def apply[VD: ClassTag]( + vertices: RDD[(VertexId, VD)], edges: EdgeRDD[_, _], defaultVal: VD, mergeFunc: (VD, VD) => VD + ): VertexRDD[VD] = { + val vPartitioned: RDD[(VertexId, VD)] = vertices.partitioner match { + case Some(p) => vertices + case None => vertices.copartitionWithVertices(new HashPartitioner(vertices.partitions.size)) + } + val routingTables = createRoutingTables(edges, vPartitioned.partitioner.get) + val vertexPartitions = vPartitioned.zipPartitions(routingTables, preservesPartitioning = true) { + (vertexIter, routingTableIter) => + val routingTable = + if (routingTableIter.hasNext) routingTableIter.next() else RoutingTablePartition.empty + Iterator(ShippableVertexPartition(vertexIter, routingTable, defaultVal)) } - val vertexPartitions = partitioned.mapPartitions( - iter => Iterator(VertexPartition(iter)), - preservesPartitioning = true) new VertexRDD(vertexPartitions) } /** - * Constructs a VertexRDD from the vertex IDs in `vids`, taking attributes from `rdd` and using - * `defaultVal` otherwise. + * Constructs a `VertexRDD` containing all vertices referred to in `edges`. The vertices will be + * created with the attribute `defaultVal`. The resulting `VertexRDD` will be joinable with + * `edges`. + * + * @tparam VD the vertex attribute type + * + * @param edges the [[EdgeRDD]] referring to the vertices to create + * @param numPartitions the desired number of partitions for the resulting `VertexRDD` + * @param defaultVal the vertex attribute to use when creating missing vertices */ - def apply[VD: ClassTag](vids: RDD[VertexId], rdd: RDD[(VertexId, VD)], defaultVal: VD) - : VertexRDD[VD] = { - VertexRDD(vids.map(vid => (vid, defaultVal))).leftJoin(rdd) { (vid, default, value) => - value.getOrElse(default) - } + def fromEdges[VD: ClassTag]( + edges: EdgeRDD[_, _], numPartitions: Int, defaultVal: VD): VertexRDD[VD] = { + val routingTables = createRoutingTables(edges, new HashPartitioner(numPartitions)) + val vertexPartitions = routingTables.mapPartitions({ routingTableIter => + val routingTable = + if (routingTableIter.hasNext) routingTableIter.next() else RoutingTablePartition.empty + Iterator(ShippableVertexPartition(Iterator.empty, routingTable, defaultVal)) + }, preservesPartitioning = true) + new VertexRDD(vertexPartitions) + } + + private def createRoutingTables( + edges: EdgeRDD[_, _], vertexPartitioner: Partitioner): RDD[RoutingTablePartition] = { + // Determine which vertices each edge partition needs by creating a mapping from vid to pid. 
+ val vid2pid = edges.partitionsRDD.mapPartitions(_.flatMap( + Function.tupled(RoutingTablePartition.edgePartitionToMsgs))) + .setName("VertexRDD.createRoutingTables - vid2pid (aggregation)") + + val numEdgePartitions = edges.partitions.size + vid2pid.copartitionWithVertices(vertexPartitioner).mapPartitions( + iter => Iterator(RoutingTablePartition.fromMsgs(numEdgePartitions, iter)), + preservesPartitioning = true) } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala index b7c472e905a9b..871e81f8d245c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala @@ -17,39 +17,86 @@ package org.apache.spark.graphx.impl -import scala.reflect.ClassTag +import scala.reflect.{classTag, ClassTag} import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap /** - * A collection of edges stored in 3 large columnar arrays (src, dst, attribute). The arrays are - * clustered by src. + * A collection of edges stored in columnar format, along with any vertex attributes referenced. The + * edges are stored in 3 large columnar arrays (src, dst, attribute). The arrays are clustered by + * src. There is an optional active vertex set for filtering computation on the edges. + * + * @tparam ED the edge attribute type + * @tparam VD the vertex attribute type * * @param srcIds the source vertex id of each edge * @param dstIds the destination vertex id of each edge * @param data the attribute associated with each edge * @param index a clustered index on source vertex id - * @tparam ED the edge attribute type. + * @param vertices a map from referenced vertex ids to their corresponding attributes. Must + * contain all vertex ids from `srcIds` and `dstIds`, though not necessarily valid attributes for + * those vertex ids. The mask is not used. + * @param activeSet an optional active vertex set for filtering computation on the edges */ private[graphx] -class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double) ED: ClassTag]( +class EdgePartition[ + @specialized(Char, Int, Boolean, Byte, Long, Float, Double) ED: ClassTag, VD: ClassTag]( @transient val srcIds: Array[VertexId], @transient val dstIds: Array[VertexId], @transient val data: Array[ED], - @transient val index: PrimitiveKeyOpenHashMap[VertexId, Int]) extends Serializable { + @transient val index: PrimitiveKeyOpenHashMap[VertexId, Int], + @transient val vertices: VertexPartition[VD], + @transient val activeSet: Option[VertexSet] = None + ) extends Serializable { + + /** Return a new `EdgePartition` with the specified edge data. */ + def withData[ED2: ClassTag](data_ : Array[ED2]): EdgePartition[ED2, VD] = { + new EdgePartition(srcIds, dstIds, data_, index, vertices, activeSet) + } + + /** Return a new `EdgePartition` with the specified vertex partition. */ + def withVertices[VD2: ClassTag]( + vertices_ : VertexPartition[VD2]): EdgePartition[ED, VD2] = { + new EdgePartition(srcIds, dstIds, data, index, vertices_, activeSet) + } + + /** Return a new `EdgePartition` with the specified active set, provided as an iterator. 
*/ + def withActiveSet(iter: Iterator[VertexId]): EdgePartition[ED, VD] = { + val newActiveSet = new VertexSet + iter.foreach(newActiveSet.add(_)) + new EdgePartition(srcIds, dstIds, data, index, vertices, Some(newActiveSet)) + } + + /** Return a new `EdgePartition` with the specified active set. */ + def withActiveSet(activeSet_ : Option[VertexSet]): EdgePartition[ED, VD] = { + new EdgePartition(srcIds, dstIds, data, index, vertices, activeSet_) + } + + /** Return a new `EdgePartition` with updates to vertex attributes specified in `iter`. */ + def updateVertices(iter: Iterator[(VertexId, VD)]): EdgePartition[ED, VD] = { + this.withVertices(vertices.innerJoinKeepLeft(iter)) + } + + /** Look up vid in activeSet, throwing an exception if it is None. */ + def isActive(vid: VertexId): Boolean = { + activeSet.get.contains(vid) + } + + /** The number of active vertices, if any exist. */ + def numActives: Option[Int] = activeSet.map(_.size) /** * Reverse all the edges in this partition. * * @return a new edge partition with all edges reversed. */ - def reverse: EdgePartition[ED] = { - val builder = new EdgePartitionBuilder(size) + def reverse: EdgePartition[ED, VD] = { + val builder = new EdgePartitionBuilder(size)(classTag[ED], classTag[VD]) for (e <- iterator) { builder.add(e.dstId, e.srcId, e.attr) } - builder.toEdgePartition + builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet) } /** @@ -64,7 +111,7 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double) * @return a new edge partition with the result of the function `f` * applied to each edge */ - def map[ED2: ClassTag](f: Edge[ED] => ED2): EdgePartition[ED2] = { + def map[ED2: ClassTag](f: Edge[ED] => ED2): EdgePartition[ED2, VD] = { val newData = new Array[ED2](data.size) val edge = new Edge[ED]() val size = data.size @@ -76,7 +123,7 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double) newData(i) = f(edge) i += 1 } - new EdgePartition(srcIds, dstIds, newData, index) + this.withData(newData) } /** @@ -91,7 +138,7 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double) * @tparam ED2 the type of the new attribute * @return a new edge partition with the attribute values replaced */ - def map[ED2: ClassTag](iter: Iterator[ED2]): EdgePartition[ED2] = { + def map[ED2: ClassTag](iter: Iterator[ED2]): EdgePartition[ED2, VD] = { // Faster than iter.toArray, because the expected size is known. val newData = new Array[ED2](data.size) var i = 0 @@ -100,7 +147,23 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double) i += 1 } assert(newData.size == i) - new EdgePartition(srcIds, dstIds, newData, index) + this.withData(newData) + } + + /** + * Construct a new edge partition containing only the edges matching `epred` and where both + * vertices match `vpred`. 
+ */ + def filter( + epred: EdgeTriplet[VD, ED] => Boolean, + vpred: (VertexId, VD) => Boolean): EdgePartition[ED, VD] = { + val filtered = tripletIterator().filter(et => + vpred(et.srcId, et.srcAttr) && vpred(et.dstId, et.dstAttr) && epred(et)) + val builder = new EdgePartitionBuilder[ED, VD] + for (e <- filtered) { + builder.add(e.srcId, e.dstId, e.attr) + } + builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet) } /** @@ -119,8 +182,8 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double) * @param merge a commutative associative merge operation * @return a new edge partition without duplicate edges */ - def groupEdges(merge: (ED, ED) => ED): EdgePartition[ED] = { - val builder = new EdgePartitionBuilder[ED] + def groupEdges(merge: (ED, ED) => ED): EdgePartition[ED, VD] = { + val builder = new EdgePartitionBuilder[ED, VD] var currSrcId: VertexId = null.asInstanceOf[VertexId] var currDstId: VertexId = null.asInstanceOf[VertexId] var currAttr: ED = null.asInstanceOf[ED] @@ -141,11 +204,11 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double) if (size > 0) { builder.add(currSrcId, currDstId, currAttr) } - builder.toEdgePartition + builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet) } /** - * Apply `f` to all edges present in both `this` and `other` and return a new EdgePartition + * Apply `f` to all edges present in both `this` and `other` and return a new `EdgePartition` * containing the resulting edges. * * If there are multiple edges with the same src and dst in `this`, `f` will be invoked once for @@ -155,9 +218,9 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double) * once. */ def innerJoin[ED2: ClassTag, ED3: ClassTag] - (other: EdgePartition[ED2]) - (f: (VertexId, VertexId, ED, ED2) => ED3): EdgePartition[ED3] = { - val builder = new EdgePartitionBuilder[ED3] + (other: EdgePartition[ED2, _]) + (f: (VertexId, VertexId, ED, ED2) => ED3): EdgePartition[ED3, VD] = { + val builder = new EdgePartitionBuilder[ED3, VD] var i = 0 var j = 0 // For i = index of each edge in `this`... @@ -175,7 +238,7 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double) } i += 1 } - builder.toEdgePartition + builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet) } /** @@ -183,7 +246,7 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double) * * @return size of the partition */ - def size: Int = srcIds.size + val size: Int = srcIds.size /** The number of unique source vertices in the partition. */ def indexSize: Int = index.size @@ -211,10 +274,35 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double) } } + /** + * Get an iterator over the edge triplets in this partition. + * + * It is safe to keep references to the objects from this iterator. + */ + def tripletIterator( + includeSrc: Boolean = true, includeDst: Boolean = true): Iterator[EdgeTriplet[VD, ED]] = { + new EdgeTripletIterator(this, includeSrc, includeDst) + } + + /** + * Upgrade the given edge iterator into a triplet iterator. + * + * Be careful not to keep references to the objects from this iterator. To improve GC performance + * the same object is re-used in `next()`. 
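+ * If triplet values must be retained, extract the needed fields into fresh objects rather than
+ * keeping the triplets themselves, e.g. (a hypothetical sketch):
+ * {{{
+ *   part.upgradeIterator(part.iterator).map(et => (et.srcId, et.dstId, et.srcAttr, et.attr)).toArray
+ * }}}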
+ */ + def upgradeIterator( + edgeIter: Iterator[Edge[ED]], includeSrc: Boolean = true, includeDst: Boolean = true) + : Iterator[EdgeTriplet[VD, ED]] = { + new ReusingEdgeTripletIterator(edgeIter, this, includeSrc, includeDst) + } + /** * Get an iterator over the edges in this partition whose source vertex ids match srcIdPred. The * iterator is generated using an index scan, so it is efficient at skipping edges that don't * match srcIdPred. + * + * Be careful not to keep references to the objects from this iterator. To improve GC performance + * the same object is re-used in `next()`. */ def indexIterator(srcIdPred: VertexId => Boolean): Iterator[Edge[ED]] = index.iterator.filter(kv => srcIdPred(kv._1)).flatMap(Function.tupled(clusterIterator)) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala index 63ccccb056b48..ecb49bef42e45 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala @@ -20,12 +20,14 @@ package org.apache.spark.graphx.impl import scala.reflect.ClassTag import scala.util.Sorting +import org.apache.spark.util.collection.{BitSet, OpenHashSet, PrimitiveVector} + import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap -import org.apache.spark.util.collection.PrimitiveVector private[graphx] -class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag](size: Int = 64) { +class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: ClassTag]( + size: Int = 64) { var edges = new PrimitiveVector[Edge[ED]](size) /** Add a new edge to the partition. */ @@ -33,7 +35,7 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag](size: I edges += Edge(src, dst, d) } - def toEdgePartition: EdgePartition[ED] = { + def toEdgePartition: EdgePartition[ED, VD] = { val edgeArray = edges.trim().array Sorting.quickSort(edgeArray)(Edge.lexicographicOrdering) val srcIds = new Array[VertexId](edgeArray.size) @@ -57,6 +59,14 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag](size: I i += 1 } } - new EdgePartition(srcIds, dstIds, data, index) + + // Create and populate a VertexPartition with vids from the edges, but no attributes + val vidsIter = srcIds.iterator ++ dstIds.iterator + val vertexIds = new OpenHashSet[VertexId] + vidsIter.foreach(vid => vertexIds.add(vid)) + val vertices = new VertexPartition( + vertexIds, new Array[VD](vertexIds.capacity), vertexIds.getBitSet) + + new EdgePartition(srcIds, dstIds, data, index, vertices) } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala index 220a89d73d711..ebb0b9418d65d 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala @@ -23,32 +23,62 @@ import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap /** - * The Iterator type returned when constructing edge triplets. This class technically could be - * an anonymous class in GraphImpl.triplets, but we name it here explicitly so it is easier to - * debug / profile. + * The Iterator type returned when constructing edge triplets. 
This could be an anonymous class in + * EdgePartition.tripletIterator, but we name it here explicitly so it is easier to debug / profile. */ private[impl] class EdgeTripletIterator[VD: ClassTag, ED: ClassTag]( - val vidToIndex: VertexIdToIndexMap, - val vertexArray: Array[VD], - val edgePartition: EdgePartition[ED]) + val edgePartition: EdgePartition[ED, VD], + val includeSrc: Boolean, + val includeDst: Boolean) extends Iterator[EdgeTriplet[VD, ED]] { // Current position in the array. private var pos = 0 - private val vmap = new PrimitiveKeyOpenHashMap[VertexId, VD](vidToIndex, vertexArray) - override def hasNext: Boolean = pos < edgePartition.size override def next() = { val triplet = new EdgeTriplet[VD, ED] triplet.srcId = edgePartition.srcIds(pos) - triplet.srcAttr = vmap(triplet.srcId) + if (includeSrc) { + triplet.srcAttr = edgePartition.vertices(triplet.srcId) + } triplet.dstId = edgePartition.dstIds(pos) - triplet.dstAttr = vmap(triplet.dstId) + if (includeDst) { + triplet.dstAttr = edgePartition.vertices(triplet.dstId) + } triplet.attr = edgePartition.data(pos) pos += 1 triplet } } + +/** + * An Iterator type for internal use that reuses EdgeTriplet objects. This could be an anonymous + * class in EdgePartition.upgradeIterator, but we name it here explicitly so it is easier to debug / + * profile. + */ +private[impl] +class ReusingEdgeTripletIterator[VD: ClassTag, ED: ClassTag]( + val edgeIter: Iterator[Edge[ED]], + val edgePartition: EdgePartition[ED, VD], + val includeSrc: Boolean, + val includeDst: Boolean) + extends Iterator[EdgeTriplet[VD, ED]] { + + private val triplet = new EdgeTriplet[VD, ED] + + override def hasNext = edgeIter.hasNext + + override def next() = { + triplet.set(edgeIter.next()) + if (includeSrc) { + triplet.srcAttr = edgePartition.vertices(triplet.srcId) + } + if (includeDst) { + triplet.dstAttr = edgePartition.vertices(triplet.dstId) + } + triplet + } +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index 9eabccdee48db..2f2d0e03fd7b5 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -19,54 +19,45 @@ package org.apache.spark.graphx.impl import scala.reflect.{classTag, ClassTag} -import org.apache.spark.util.collection.PrimitiveVector -import org.apache.spark.{HashPartitioner, Partitioner} +import org.apache.spark.HashPartitioner import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.{RDD, ShuffledRDD} +import org.apache.spark.storage.StorageLevel + import org.apache.spark.graphx._ import org.apache.spark.graphx.impl.GraphImpl._ import org.apache.spark.graphx.impl.MsgRDDFunctions._ import org.apache.spark.graphx.util.BytecodeUtils -import org.apache.spark.rdd.{ShuffledRDD, RDD} -import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.ClosureCleaner /** - * A graph that supports computation on graphs. + * An implementation of [[org.apache.spark.graphx.Graph]] to support computation on graphs. * - * Graphs are represented using two classes of data: vertex-partitioned and - * edge-partitioned. `vertices` contains vertex attributes, which are vertex-partitioned. `edges` - * contains edge attributes, which are edge-partitioned. For operations on vertex neighborhoods, - * vertex attributes are replicated to the edge partitions where they appear as sources or - * destinations. 
`routingTable` stores the routing information for shipping vertex attributes to - * edge partitions. `replicatedVertexView` stores a view of the replicated vertex attributes created - * using the routing table. + * Graphs are represented using two RDDs: `vertices`, which contains vertex attributes and the + * routing information for shipping vertex attributes to edge partitions, and + * `replicatedVertexView`, which contains edges and the vertex attributes mentioned by each edge. */ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( @transient val vertices: VertexRDD[VD], - @transient val edges: EdgeRDD[ED], - @transient val routingTable: RoutingTable, - @transient val replicatedVertexView: ReplicatedVertexView[VD]) + @transient val replicatedVertexView: ReplicatedVertexView[VD, ED]) extends Graph[VD, ED] with Serializable { /** Default constructor is provided to support serialization */ - protected def this() = this(null, null, null, null) + protected def this() = this(null, null) + + @transient override val edges: EdgeRDD[ED, VD] = replicatedVertexView.edges /** Return a RDD that brings edges together with their source and destination vertices. */ - @transient override val triplets: RDD[EdgeTriplet[VD, ED]] = { - val vdTag = classTag[VD] - val edTag = classTag[ED] - edges.partitionsRDD.zipPartitions( - replicatedVertexView.get(true, true), true) { (ePartIter, vPartIter) => - val (pid, ePart) = ePartIter.next() - val (_, vPart) = vPartIter.next() - new EdgeTripletIterator(vPart.index, vPart.values, ePart)(vdTag, edTag) - } + @transient override lazy val triplets: RDD[EdgeTriplet[VD, ED]] = { + replicatedVertexView.upgrade(vertices, true, true) + replicatedVertexView.edges.partitionsRDD.mapPartitions(_.flatMap { + case (pid, part) => part.tripletIterator() + }) } override def persist(newLevel: StorageLevel): Graph[VD, ED] = { vertices.persist(newLevel) - edges.persist(newLevel) + replicatedVertexView.edges.persist(newLevel) this } @@ -74,14 +65,15 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( override def unpersistVertices(blocking: Boolean = true): Graph[VD, ED] = { vertices.unpersist(blocking) - replicatedVertexView.unpersist(blocking) + // TODO: unpersist the replicated vertices in `replicatedVertexView` but leave the edges alone this } override def partitionBy(partitionStrategy: PartitionStrategy): Graph[VD, ED] = { - val numPartitions = edges.partitions.size + val numPartitions = replicatedVertexView.edges.partitions.size val edTag = classTag[ED] - val newEdges = new EdgeRDD(edges.map { e => + val vdTag = classTag[VD] + val newEdges = new EdgeRDD(replicatedVertexView.edges.map { e => val part: PartitionID = partitionStrategy.getPartition(e.srcId, e.dstId, numPartitions) // Should we be using 3-tuple or an optimized class @@ -89,105 +81,79 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( } .partitionBy(new HashPartitioner(numPartitions)) .mapPartitionsWithIndex( { (pid, iter) => - val builder = new EdgePartitionBuilder[ED]()(edTag) + val builder = new EdgePartitionBuilder[ED, VD]()(edTag, vdTag) iter.foreach { message => val data = message.data builder.add(data._1, data._2, data._3) } val edgePartition = builder.toEdgePartition Iterator((pid, edgePartition)) - }, preservesPartitioning = true).cache()) - GraphImpl(vertices, newEdges) + }, preservesPartitioning = true)) + GraphImpl.fromExistingRDDs(vertices, newEdges) } override def reverse: Graph[VD, ED] = { - val newETable = edges.mapEdgePartitions((pid, part) => part.reverse) - GraphImpl(vertices, 
newETable) + new GraphImpl(vertices.reverseRoutingTables(), replicatedVertexView.reverse()) } override def mapVertices[VD2: ClassTag](f: (VertexId, VD) => VD2): Graph[VD2, ED] = { if (classTag[VD] equals classTag[VD2]) { + vertices.cache() // The map preserves type, so we can use incremental replication val newVerts = vertices.mapVertexPartitions(_.map(f)).cache() val changedVerts = vertices.asInstanceOf[VertexRDD[VD2]].diff(newVerts) - val newReplicatedVertexView = new ReplicatedVertexView[VD2]( - changedVerts, edges, routingTable, - Some(replicatedVertexView.asInstanceOf[ReplicatedVertexView[VD2]])) - new GraphImpl(newVerts, edges, routingTable, newReplicatedVertexView) + val newReplicatedVertexView = replicatedVertexView.asInstanceOf[ReplicatedVertexView[VD2, ED]] + .updateVertices(changedVerts) + new GraphImpl(newVerts, newReplicatedVertexView) } else { // The map does not preserve type, so we must re-replicate all vertices - GraphImpl(vertices.mapVertexPartitions(_.map(f)), edges, routingTable) + GraphImpl(vertices.mapVertexPartitions(_.map(f)), replicatedVertexView.edges) } } override def mapEdges[ED2: ClassTag]( f: (PartitionID, Iterator[Edge[ED]]) => Iterator[ED2]): Graph[VD, ED2] = { - val newETable = edges.mapEdgePartitions((pid, part) => part.map(f(pid, part.iterator))) - new GraphImpl(vertices, newETable , routingTable, replicatedVertexView) + val newEdges = replicatedVertexView.edges + .mapEdgePartitions((pid, part) => part.map(f(pid, part.iterator))) + new GraphImpl(vertices, replicatedVertexView.withEdges(newEdges)) } override def mapTriplets[ED2: ClassTag]( f: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2]): Graph[VD, ED2] = { - val newEdgePartitions = - edges.partitionsRDD.zipPartitions(replicatedVertexView.get(true, true), true) { - (ePartIter, vTableReplicatedIter) => - val (ePid, edgePartition) = ePartIter.next() - val (vPid, vPart) = vTableReplicatedIter.next() - assert(!vTableReplicatedIter.hasNext) - assert(ePid == vPid) - val et = new EdgeTriplet[VD, ED] - val inputIterator = edgePartition.iterator.map { e => - et.set(e) - et.srcAttr = vPart(e.srcId) - et.dstAttr = vPart(e.dstId) - et - } - // Apply the user function to the vertex partition - val outputIter = f(ePid, inputIterator) - // Consume the iterator to update the edge attributes - val newEdgePartition = edgePartition.map(outputIter) - Iterator((ePid, newEdgePartition)) - } - new GraphImpl(vertices, new EdgeRDD(newEdgePartitions), routingTable, replicatedVertexView) + vertices.cache() + val mapUsesSrcAttr = accessesVertexAttr(f, "srcAttr") + val mapUsesDstAttr = accessesVertexAttr(f, "dstAttr") + replicatedVertexView.upgrade(vertices, mapUsesSrcAttr, mapUsesDstAttr) + val newEdges = replicatedVertexView.edges.mapEdgePartitions { (pid, part) => + part.map(f(pid, part.tripletIterator(mapUsesSrcAttr, mapUsesDstAttr))) + } + new GraphImpl(vertices, replicatedVertexView.withEdges(newEdges)) } override def subgraph( epred: EdgeTriplet[VD, ED] => Boolean = x => true, vpred: (VertexId, VD) => Boolean = (a, b) => true): Graph[VD, ED] = { + vertices.cache() // Filter the vertices, reusing the partitioner and the index from this graph val newVerts = vertices.mapVertexPartitions(_.filter(vpred)) - - // Filter the edges - val edTag = classTag[ED] - val newEdges = new EdgeRDD[ED](triplets.filter { et => - vpred(et.srcId, et.srcAttr) && vpred(et.dstId, et.dstAttr) && epred(et) - }.mapPartitionsWithIndex( { (pid, iter) => - val builder = new EdgePartitionBuilder[ED]()(edTag) - iter.foreach { et => 
builder.add(et.srcId, et.dstId, et.attr) } - val edgePartition = builder.toEdgePartition - Iterator((pid, edgePartition)) - }, preservesPartitioning = true)).cache() - - // Reuse the previous ReplicatedVertexView unmodified. The replicated vertices that have been - // removed will be ignored, since we only refer to replicated vertices when they are adjacent to - // an edge. - new GraphImpl(newVerts, newEdges, new RoutingTable(newEdges, newVerts), replicatedVertexView) - } // end of subgraph + // Filter the triplets. We must always upgrade the triplet view fully because vpred always runs + // on both src and dst vertices + replicatedVertexView.upgrade(vertices, true, true) + val newEdges = replicatedVertexView.edges.filter(epred, vpred) + new GraphImpl(newVerts, replicatedVertexView.withEdges(newEdges)) + } override def mask[VD2: ClassTag, ED2: ClassTag] ( other: Graph[VD2, ED2]): Graph[VD, ED] = { val newVerts = vertices.innerJoin(other.vertices) { (vid, v, w) => v } - val newEdges = edges.innerJoin(other.edges) { (src, dst, v, w) => v } - // Reuse the previous ReplicatedVertexView unmodified. The replicated vertices that have been - // removed will be ignored, since we only refer to replicated vertices when they are adjacent to - // an edge. - new GraphImpl(newVerts, newEdges, routingTable, replicatedVertexView) + val newEdges = replicatedVertexView.edges.innerJoin(other.edges) { (src, dst, v, w) => v } + new GraphImpl(newVerts, replicatedVertexView.withEdges(newEdges)) } override def groupEdges(merge: (ED, ED) => ED): Graph[VD, ED] = { - ClosureCleaner.clean(merge) - val newETable = edges.mapEdgePartitions((pid, part) => part.groupEdges(merge)) - new GraphImpl(vertices, newETable, routingTable, replicatedVertexView) + val newEdges = replicatedVertexView.edges.mapEdgePartitions( + (pid, part) => part.groupEdges(merge)) + new GraphImpl(vertices, replicatedVertexView.withEdges(newEdges)) } // /////////////////////////////////////////////////////////////////////////////////////////////// @@ -199,68 +165,58 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( reduceFunc: (A, A) => A, activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None): VertexRDD[A] = { - ClosureCleaner.clean(mapFunc) - ClosureCleaner.clean(reduceFunc) + vertices.cache() // For each vertex, replicate its attribute only to partitions where it is // in the relevant position in an edge. val mapUsesSrcAttr = accessesVertexAttr(mapFunc, "srcAttr") val mapUsesDstAttr = accessesVertexAttr(mapFunc, "dstAttr") - val vs = activeSetOpt match { + replicatedVertexView.upgrade(vertices, mapUsesSrcAttr, mapUsesDstAttr) + val view = activeSetOpt match { case Some((activeSet, _)) => - replicatedVertexView.get(mapUsesSrcAttr, mapUsesDstAttr, activeSet) + replicatedVertexView.withActiveSet(activeSet) case None => - replicatedVertexView.get(mapUsesSrcAttr, mapUsesDstAttr) + replicatedVertexView } val activeDirectionOpt = activeSetOpt.map(_._2) // Map and combine. 
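+ // For each edge partition, pick a scan method from the ratio of active vertices to unique
+ // source vertices (`activeFraction`): an index scan over active sources when the ratio is
+ // below 0.8, a full edge scan otherwise. mapFunc is then applied to the matching triplets and
+ // the messages are pre-aggregated locally with reduceFunc before the final reduction below.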
- val preAgg = edges.partitionsRDD.zipPartitions(vs, true) { (ePartIter, vPartIter) => - val (ePid, edgePartition) = ePartIter.next() - val (vPid, vPart) = vPartIter.next() - assert(!vPartIter.hasNext) - assert(ePid == vPid) - // Choose scan method - val activeFraction = vPart.numActives.getOrElse(0) / edgePartition.indexSize.toFloat - val edgeIter = activeDirectionOpt match { - case Some(EdgeDirection.Both) => - if (activeFraction < 0.8) { - edgePartition.indexIterator(srcVertexId => vPart.isActive(srcVertexId)) - .filter(e => vPart.isActive(e.dstId)) - } else { - edgePartition.iterator.filter(e => vPart.isActive(e.srcId) && vPart.isActive(e.dstId)) - } - case Some(EdgeDirection.Either) => - // TODO: Because we only have a clustered index on the source vertex ID, we can't filter - // the index here. Instead we have to scan all edges and then do the filter. - edgePartition.iterator.filter(e => vPart.isActive(e.srcId) || vPart.isActive(e.dstId)) - case Some(EdgeDirection.Out) => - if (activeFraction < 0.8) { - edgePartition.indexIterator(srcVertexId => vPart.isActive(srcVertexId)) - } else { - edgePartition.iterator.filter(e => vPart.isActive(e.srcId)) - } - case Some(EdgeDirection.In) => - edgePartition.iterator.filter(e => vPart.isActive(e.dstId)) - case _ => // None - edgePartition.iterator - } - - // Scan edges and run the map function - val et = new EdgeTriplet[VD, ED] - val mapOutputs = edgeIter.flatMap { e => - et.set(e) - if (mapUsesSrcAttr) { - et.srcAttr = vPart(e.srcId) - } - if (mapUsesDstAttr) { - et.dstAttr = vPart(e.dstId) + val preAgg = view.edges.partitionsRDD.mapPartitions(_.flatMap { + case (pid, edgePartition) => + // Choose scan method + val activeFraction = edgePartition.numActives.getOrElse(0) / edgePartition.indexSize.toFloat + val edgeIter = activeDirectionOpt match { + case Some(EdgeDirection.Both) => + if (activeFraction < 0.8) { + edgePartition.indexIterator(srcVertexId => edgePartition.isActive(srcVertexId)) + .filter(e => edgePartition.isActive(e.dstId)) + } else { + edgePartition.iterator.filter(e => + edgePartition.isActive(e.srcId) && edgePartition.isActive(e.dstId)) + } + case Some(EdgeDirection.Either) => + // TODO: Because we only have a clustered index on the source vertex ID, we can't filter + // the index here. Instead we have to scan all edges and then do the filter. + edgePartition.iterator.filter(e => + edgePartition.isActive(e.srcId) || edgePartition.isActive(e.dstId)) + case Some(EdgeDirection.Out) => + if (activeFraction < 0.8) { + edgePartition.indexIterator(srcVertexId => edgePartition.isActive(srcVertexId)) + } else { + edgePartition.iterator.filter(e => edgePartition.isActive(e.srcId)) + } + case Some(EdgeDirection.In) => + edgePartition.iterator.filter(e => edgePartition.isActive(e.dstId)) + case _ => // None + edgePartition.iterator } - mapFunc(et) - } - // Note: This doesn't allow users to send messages to arbitrary vertices. - vPart.aggregateUsingIndex(mapOutputs, reduceFunc).iterator - } + + // Scan edges and run the map function + val mapOutputs = edgePartition.upgradeIterator(edgeIter, mapUsesSrcAttr, mapUsesDstAttr) + .flatMap(mapFunc(_)) + // Note: This doesn't allow users to send messages to arbitrary vertices. 
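+ // The per-partition vertex index only contains vertices referenced by this edge partition,
+ // so aggregating by that index implicitly restricts message targets to those vertices. The
+ // pre-aggregated results are merged across partitions by the final reduction below.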
+ edgePartition.vertices.aggregateUsingIndex(mapOutputs, reduceFunc).iterator + }).setName("GraphImpl.mapReduceTriplets - preAgg") // do the final reduction reusing the index map vertices.aggregateUsingIndex(preAgg, reduceFunc) @@ -268,20 +224,19 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( override def outerJoinVertices[U: ClassTag, VD2: ClassTag] (other: RDD[(VertexId, U)]) - (updateF: (VertexId, VD, Option[U]) => VD2): Graph[VD2, ED] = - { + (updateF: (VertexId, VD, Option[U]) => VD2): Graph[VD2, ED] = { if (classTag[VD] equals classTag[VD2]) { + vertices.cache() // updateF preserves type, so we can use incremental replication - val newVerts = vertices.leftJoin(other)(updateF) + val newVerts = vertices.leftJoin(other)(updateF).cache() val changedVerts = vertices.asInstanceOf[VertexRDD[VD2]].diff(newVerts) - val newReplicatedVertexView = new ReplicatedVertexView[VD2]( - changedVerts, edges, routingTable, - Some(replicatedVertexView.asInstanceOf[ReplicatedVertexView[VD2]])) - new GraphImpl(newVerts, edges, routingTable, newReplicatedVertexView) + val newReplicatedVertexView = replicatedVertexView.asInstanceOf[ReplicatedVertexView[VD2, ED]] + .updateVertices(changedVerts) + new GraphImpl(newVerts, newReplicatedVertexView) } else { // updateF does not preserve type, so we must re-replicate all vertices val newVerts = vertices.leftJoin(other)(updateF) - GraphImpl(newVerts, edges, routingTable) + GraphImpl(newVerts, replicatedVertexView.edges) } } @@ -298,73 +253,68 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( object GraphImpl { + /** Create a graph from edges, setting referenced vertices to `defaultVertexAttr`. */ def apply[VD: ClassTag, ED: ClassTag]( edges: RDD[Edge[ED]], - defaultVertexAttr: VD): GraphImpl[VD, ED] = - { + defaultVertexAttr: VD): GraphImpl[VD, ED] = { fromEdgeRDD(createEdgeRDD(edges), defaultVertexAttr) } + /** Create a graph from EdgePartitions, setting referenced vertices to `defaultVertexAttr`. */ def fromEdgePartitions[VD: ClassTag, ED: ClassTag]( - edgePartitions: RDD[(PartitionID, EdgePartition[ED])], + edgePartitions: RDD[(PartitionID, EdgePartition[ED, VD])], defaultVertexAttr: VD): GraphImpl[VD, ED] = { fromEdgeRDD(new EdgeRDD(edgePartitions), defaultVertexAttr) } + /** Create a graph from vertices and edges, setting missing vertices to `defaultVertexAttr`. */ def apply[VD: ClassTag, ED: ClassTag]( vertices: RDD[(VertexId, VD)], edges: RDD[Edge[ED]], - defaultVertexAttr: VD): GraphImpl[VD, ED] = - { - val edgeRDD = createEdgeRDD(edges).cache() - - // Get the set of all vids - val partitioner = Partitioner.defaultPartitioner(vertices) - val vPartitioned = vertices.partitionBy(partitioner) - val vidsFromEdges = collectVertexIdsFromEdges(edgeRDD, partitioner) - val vids = vPartitioned.zipPartitions(vidsFromEdges) { (vertexIter, vidsFromEdgesIter) => - vertexIter.map(_._1) ++ vidsFromEdgesIter.map(_._1) - } - - val vertexRDD = VertexRDD(vids, vPartitioned, defaultVertexAttr) - + defaultVertexAttr: VD): GraphImpl[VD, ED] = { + val edgeRDD = createEdgeRDD(edges)(classTag[ED], classTag[VD]).cache() + val vertexRDD = VertexRDD(vertices, edgeRDD, defaultVertexAttr) GraphImpl(vertexRDD, edgeRDD) } + /** Create a graph from a VertexRDD and an EdgeRDD with arbitrary replicated vertices. 
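+ * The vertex attributes currently stored in the edge partitions of `edges` are discarded (reset
+ * to null), and attributes from `vertices` are shipped to the edge partitions on demand when a
+ * triplet view is constructed.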
*/ def apply[VD: ClassTag, ED: ClassTag]( vertices: VertexRDD[VD], - edges: EdgeRDD[ED]): GraphImpl[VD, ED] = { - // Cache RDDs that are referenced multiple times - edges.cache() - - GraphImpl(vertices, edges, new RoutingTable(edges, vertices)) + edges: EdgeRDD[ED, _]): GraphImpl[VD, ED] = { + // Convert the vertex partitions in edges to the correct type + val newEdges = edges.mapEdgePartitions( + (pid, part) => part.withVertices(part.vertices.map( + (vid, attr) => null.asInstanceOf[VD]))) + GraphImpl.fromExistingRDDs(vertices, newEdges) } - def apply[VD: ClassTag, ED: ClassTag]( + /** + * Create a graph from a VertexRDD and an EdgeRDD with the same replicated vertex type as the + * vertices. + */ + def fromExistingRDDs[VD: ClassTag, ED: ClassTag]( vertices: VertexRDD[VD], - edges: EdgeRDD[ED], - routingTable: RoutingTable): GraphImpl[VD, ED] = { - // Cache RDDs that are referenced multiple times. `routingTable` is cached by default, so we - // don't cache it explicitly. - vertices.cache() - edges.cache() - - new GraphImpl( - vertices, edges, routingTable, new ReplicatedVertexView(vertices, edges, routingTable)) + edges: EdgeRDD[ED, VD]): GraphImpl[VD, ED] = { + new GraphImpl(vertices, new ReplicatedVertexView(edges)) } /** - * Create the edge RDD, which is much more efficient for Java heap storage than the normal edges - * data structure (RDD[(VertexId, VertexId, ED)]). - * - * The edge RDD contains multiple partitions, and each partition contains only one RDD key-value - * pair: the key is the partition id, and the value is an EdgePartition object containing all the - * edges in a partition. + * Create a graph from an EdgeRDD with the correct vertex type, setting missing vertices to + * `defaultVertexAttr`. The vertices will have the same number of partitions as the EdgeRDD. */ - private def createEdgeRDD[ED: ClassTag]( - edges: RDD[Edge[ED]]): EdgeRDD[ED] = { + private def fromEdgeRDD[VD: ClassTag, ED: ClassTag]( + edges: EdgeRDD[ED, VD], + defaultVertexAttr: VD): GraphImpl[VD, ED] = { + edges.cache() + val vertices = VertexRDD.fromEdges(edges, edges.partitions.size, defaultVertexAttr) + fromExistingRDDs(vertices, edges) + } + + /** Create an EdgeRDD from a set of edges. */ + private def createEdgeRDD[ED: ClassTag, VD: ClassTag]( + edges: RDD[Edge[ED]]): EdgeRDD[ED, VD] = { val edgePartitions = edges.mapPartitionsWithIndex { (pid, iter) => - val builder = new EdgePartitionBuilder[ED] + val builder = new EdgePartitionBuilder[ED, VD] iter.foreach { e => builder.add(e.srcId, e.dstId, e.attr) } @@ -373,24 +323,4 @@ object GraphImpl { new EdgeRDD(edgePartitions) } - private def fromEdgeRDD[VD: ClassTag, ED: ClassTag]( - edges: EdgeRDD[ED], - defaultVertexAttr: VD): GraphImpl[VD, ED] = { - edges.cache() - // Get the set of all vids - val vids = collectVertexIdsFromEdges(edges, new HashPartitioner(edges.partitions.size)) - // Create the VertexRDD. - val vertices = VertexRDD(vids.mapValues(x => defaultVertexAttr)) - GraphImpl(vertices, edges) - } - - /** Collects all vids mentioned in edges and partitions them by partitioner. */ - private def collectVertexIdsFromEdges( - edges: EdgeRDD[_], - partitioner: Partitioner): RDD[(VertexId, Int)] = { - // TODO: Consider doing map side distinct before shuffle. 
- new ShuffledRDD[VertexId, Int, (VertexId, Int)]( - edges.collectVertexIds.map(vid => (vid, 0)), partitioner) - .setSerializer(new VertexIdMsgSerializer) - } } // end of object GraphImpl diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala index c45ba3d2f8c24..1c6d7e59e9a27 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala @@ -89,7 +89,6 @@ class MsgRDDFunctions[T: ClassTag](self: RDD[MessageToPartition[T]]) { } - private[graphx] object MsgRDDFunctions { implicit def rdd2PartitionRDDFunctions[T: ClassTag](rdd: RDD[MessageToPartition[T]]) = { @@ -99,18 +98,28 @@ object MsgRDDFunctions { implicit def rdd2vertexMessageRDDFunctions[T: ClassTag](rdd: RDD[VertexBroadcastMsg[T]]) = { new VertexBroadcastMsgRDDFunctions(rdd) } +} - def partitionForAggregation[T: ClassTag](msgs: RDD[(VertexId, T)], partitioner: Partitioner) = { - val rdd = new ShuffledRDD[VertexId, T, (VertexId, T)](msgs, partitioner) +private[graphx] +class VertexRDDFunctions[VD: ClassTag](self: RDD[(VertexId, VD)]) { + def copartitionWithVertices(partitioner: Partitioner): RDD[(VertexId, VD)] = { + val rdd = new ShuffledRDD[VertexId, VD, (VertexId, VD)](self, partitioner) // Set a custom serializer if the data is of int or double type. - if (classTag[T] == ClassTag.Int) { + if (classTag[VD] == ClassTag.Int) { rdd.setSerializer(new IntAggMsgSerializer) - } else if (classTag[T] == ClassTag.Long) { + } else if (classTag[VD] == ClassTag.Long) { rdd.setSerializer(new LongAggMsgSerializer) - } else if (classTag[T] == ClassTag.Double) { + } else if (classTag[VD] == ClassTag.Double) { rdd.setSerializer(new DoubleAggMsgSerializer) } rdd } } + +private[graphx] +object VertexRDDFunctions { + implicit def rdd2VertexRDDFunctions[VD: ClassTag](rdd: RDD[(VertexId, VD)]) = { + new VertexRDDFunctions(rdd) + } +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala index a8154b63ce5fb..3a0bba1b93b41 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala @@ -21,192 +21,102 @@ import scala.reflect.{classTag, ClassTag} import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD -import org.apache.spark.util.collection.{PrimitiveVector, OpenHashSet} import org.apache.spark.graphx._ /** - * A view of the vertices after they are shipped to the join sites specified in - * `vertexPlacement`. The resulting view is co-partitioned with `edges`. If `prevViewOpt` is - * specified, `updatedVerts` are treated as incremental updates to the previous view. Otherwise, a - * fresh view is created. - * - * The view is always cached (i.e., once it is evaluated, it remains materialized). This avoids - * constructing it twice if the user calls graph.triplets followed by graph.mapReduceTriplets, for - * example. However, it means iterative algorithms must manually call `Graph.unpersist` on previous - * iterations' graphs for best GC performance. See the implementation of - * [[org.apache.spark.graphx.Pregel]] for an example. + * Manages shipping vertex attributes to the edge partitions of an + * [[org.apache.spark.graphx.EdgeRDD]]. 
Vertex attributes may be partially shipped to construct a + * triplet view with vertex attributes on only one side, and they may be updated. An active vertex + * set may additionally be shipped to the edge partitions. Be careful not to store a reference to + * `edges`, since it may be modified when the attribute shipping level is upgraded. */ private[impl] -class ReplicatedVertexView[VD: ClassTag]( - updatedVerts: VertexRDD[VD], - edges: EdgeRDD[_], - routingTable: RoutingTable, - prevViewOpt: Option[ReplicatedVertexView[VD]] = None) { +class ReplicatedVertexView[VD: ClassTag, ED: ClassTag]( + var edges: EdgeRDD[ED, VD], + var hasSrcId: Boolean = false, + var hasDstId: Boolean = false) { /** - * Within each edge partition, create a local map from vid to an index into the attribute - * array. Each map contains a superset of the vertices that it will receive, because it stores - * vids from both the source and destination of edges. It must always include both source and - * destination vids because some operations, such as GraphImpl.mapReduceTriplets, rely on this. + * Return a new `ReplicatedVertexView` with the specified `EdgeRDD`, which must have the same + * shipping level. */ - private val localVertexIdMap: RDD[(Int, VertexIdToIndexMap)] = prevViewOpt match { - case Some(prevView) => - prevView.localVertexIdMap - case None => - edges.partitionsRDD.mapPartitions(_.map { - case (pid, epart) => - val vidToIndex = new VertexIdToIndexMap - epart.foreach { e => - vidToIndex.add(e.srcId) - vidToIndex.add(e.dstId) - } - (pid, vidToIndex) - }, preservesPartitioning = true).cache().setName("ReplicatedVertexView localVertexIdMap") - } - - private lazy val bothAttrs: RDD[(PartitionID, VertexPartition[VD])] = create(true, true) - private lazy val srcAttrOnly: RDD[(PartitionID, VertexPartition[VD])] = create(true, false) - private lazy val dstAttrOnly: RDD[(PartitionID, VertexPartition[VD])] = create(false, true) - private lazy val noAttrs: RDD[(PartitionID, VertexPartition[VD])] = create(false, false) - - def unpersist(blocking: Boolean = true): ReplicatedVertexView[VD] = { - bothAttrs.unpersist(blocking) - srcAttrOnly.unpersist(blocking) - dstAttrOnly.unpersist(blocking) - noAttrs.unpersist(blocking) - // Don't unpersist localVertexIdMap because a future ReplicatedVertexView may be using it - // without modification - this + def withEdges[VD2: ClassTag, ED2: ClassTag]( + edges_ : EdgeRDD[ED2, VD2]): ReplicatedVertexView[VD2, ED2] = { + new ReplicatedVertexView(edges_, hasSrcId, hasDstId) } - def get(includeSrc: Boolean, includeDst: Boolean): RDD[(PartitionID, VertexPartition[VD])] = { - (includeSrc, includeDst) match { - case (true, true) => bothAttrs - case (true, false) => srcAttrOnly - case (false, true) => dstAttrOnly - case (false, false) => noAttrs - } + /** + * Return a new `ReplicatedVertexView` where edges are reversed and shipping levels are swapped to + * match. + */ + def reverse() = { + val newEdges = edges.mapEdgePartitions((pid, part) => part.reverse) + new ReplicatedVertexView(newEdges, hasDstId, hasSrcId) } - def get( - includeSrc: Boolean, - includeDst: Boolean, - actives: VertexRDD[_]): RDD[(PartitionID, VertexPartition[VD])] = { - // Ship active sets to edge partitions using vertexPlacement, but ignoring includeSrc and - // includeDst. These flags govern attribute shipping, but the activeness of a vertex must be - // shipped to all edges mentioning that vertex, regardless of whether the vertex attribute is - // also shipped there. 
- val shippedActives = routingTable.get(true, true) - .zipPartitions(actives.partitionsRDD)(ReplicatedVertexView.buildActiveBuffer(_, _)) - .partitionBy(edges.partitioner.get) - // Update the view with shippedActives, setting activeness flags in the resulting - // VertexPartitions - get(includeSrc, includeDst).zipPartitions(shippedActives) { (viewIter, shippedActivesIter) => - val (pid, vPart) = viewIter.next() - val newPart = vPart.replaceActives(shippedActivesIter.flatMap(_._2.iterator)) - Iterator((pid, newPart)) + /** + * Upgrade the shipping level in-place to the specified levels by shipping vertex attributes from + * `vertices`. This operation modifies the `ReplicatedVertexView`, and callers can access `edges` + * afterwards to obtain the upgraded view. + */ + def upgrade(vertices: VertexRDD[VD], includeSrc: Boolean, includeDst: Boolean) { + val shipSrc = includeSrc && !hasSrcId + val shipDst = includeDst && !hasDstId + if (shipSrc || shipDst) { + val shippedVerts: RDD[(Int, VertexAttributeBlock[VD])] = + vertices.shipVertexAttributes(shipSrc, shipDst) + .setName("ReplicatedVertexView.upgrade(%s, %s) - shippedVerts %s %s (broadcast)".format( + includeSrc, includeDst, shipSrc, shipDst)) + .partitionBy(edges.partitioner.get) + val newEdges = new EdgeRDD(edges.partitionsRDD.zipPartitions(shippedVerts) { + (ePartIter, shippedVertsIter) => ePartIter.map { + case (pid, edgePartition) => + (pid, edgePartition.updateVertices(shippedVertsIter.flatMap(_._2.iterator))) + } + }) + edges = newEdges + hasSrcId = includeSrc + hasDstId = includeDst } } - private def create(includeSrc: Boolean, includeDst: Boolean) - : RDD[(PartitionID, VertexPartition[VD])] = { - val vdTag = classTag[VD] - - // Ship vertex attributes to edge partitions according to vertexPlacement - val verts = updatedVerts.partitionsRDD - val shippedVerts = routingTable.get(includeSrc, includeDst) - .zipPartitions(verts)(ReplicatedVertexView.buildBuffer(_, _)(vdTag)) + /** + * Return a new `ReplicatedVertexView` where the `activeSet` in each edge partition contains only + * vertex ids present in `actives`. This ships a vertex id to all edge partitions where it is + * referenced, ignoring the attribute shipping level. + */ + def withActiveSet(actives: VertexRDD[_]): ReplicatedVertexView[VD, ED] = { + val shippedActives = actives.shipVertexIds() + .setName("ReplicatedVertexView.withActiveSet - shippedActives (broadcast)") .partitionBy(edges.partitioner.get) - // TODO: Consider using a specialized shuffler. 
- - prevViewOpt match { - case Some(prevView) => - // Update prevView with shippedVerts, setting staleness flags in the resulting - // VertexPartitions - prevView.get(includeSrc, includeDst).zipPartitions(shippedVerts) { - (prevViewIter, shippedVertsIter) => - val (pid, prevVPart) = prevViewIter.next() - val newVPart = prevVPart.innerJoinKeepLeft(shippedVertsIter.flatMap(_._2.iterator)) - Iterator((pid, newVPart)) - }.cache().setName("ReplicatedVertexView delta %s %s".format(includeSrc, includeDst)) - case None => - // Within each edge partition, place the shipped vertex attributes into the correct - // locations specified in localVertexIdMap - localVertexIdMap.zipPartitions(shippedVerts) { (mapIter, shippedVertsIter) => - val (pid, vidToIndex) = mapIter.next() - assert(!mapIter.hasNext) - // Populate the vertex array using the vidToIndex map - val vertexArray = vdTag.newArray(vidToIndex.capacity) - for ((_, block) <- shippedVertsIter) { - for (i <- 0 until block.vids.size) { - val vid = block.vids(i) - val attr = block.attrs(i) - val ind = vidToIndex.getPos(vid) - vertexArray(ind) = attr - } - } - val newVPart = new VertexPartition( - vidToIndex, vertexArray, vidToIndex.getBitSet)(vdTag) - Iterator((pid, newVPart)) - }.cache().setName("ReplicatedVertexView %s %s".format(includeSrc, includeDst)) - } - } -} - -private object ReplicatedVertexView { - protected def buildBuffer[VD: ClassTag]( - pid2vidIter: Iterator[Array[Array[VertexId]]], - vertexPartIter: Iterator[VertexPartition[VD]]) = { - val pid2vid: Array[Array[VertexId]] = pid2vidIter.next() - val vertexPart: VertexPartition[VD] = vertexPartIter.next() - - Iterator.tabulate(pid2vid.size) { pid => - val vidsCandidate = pid2vid(pid) - val size = vidsCandidate.length - val vids = new PrimitiveVector[VertexId](pid2vid(pid).size) - val attrs = new PrimitiveVector[VD](pid2vid(pid).size) - var i = 0 - while (i < size) { - val vid = vidsCandidate(i) - if (vertexPart.isDefined(vid)) { - vids += vid - attrs += vertexPart(vid) - } - i += 1 + val newEdges = new EdgeRDD(edges.partitionsRDD.zipPartitions(shippedActives) { + (ePartIter, shippedActivesIter) => ePartIter.map { + case (pid, edgePartition) => + (pid, edgePartition.withActiveSet(shippedActivesIter.flatMap(_._2.iterator))) } - (pid, new VertexAttributeBlock(vids.trim().array, attrs.trim().array)) - } + }) + new ReplicatedVertexView(newEdges, hasSrcId, hasDstId) } - protected def buildActiveBuffer( - pid2vidIter: Iterator[Array[Array[VertexId]]], - activePartIter: Iterator[VertexPartition[_]]) - : Iterator[(Int, Array[VertexId])] = { - val pid2vid: Array[Array[VertexId]] = pid2vidIter.next() - val activePart: VertexPartition[_] = activePartIter.next() + /** + * Return a new `ReplicatedVertexView` where vertex attributes in edge partition are updated using + * `updates`. This ships a vertex attribute only to the edge partitions where it is in the + * position(s) specified by the attribute shipping level. 
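+ * This enables the incremental-replication pattern used by `GraphImpl.mapVertices` and
+ * `GraphImpl.outerJoinVertices` (a sketch, with names assumed):
+ * {{{
+ *   val changedVerts = oldVerts.diff(newVerts)
+ *   val newView = view.updateVertices(changedVerts)
+ * }}}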
+ */ + def updateVertices(updates: VertexRDD[VD]): ReplicatedVertexView[VD, ED] = { + val shippedVerts = updates.shipVertexAttributes(hasSrcId, hasDstId) + .setName("ReplicatedVertexView.updateVertices - shippedVerts %s %s (broadcast)".format( + hasSrcId, hasDstId)) + .partitionBy(edges.partitioner.get) - Iterator.tabulate(pid2vid.size) { pid => - val vidsCandidate = pid2vid(pid) - val size = vidsCandidate.length - val actives = new PrimitiveVector[VertexId](vidsCandidate.size) - var i = 0 - while (i < size) { - val vid = vidsCandidate(i) - if (activePart.isDefined(vid)) { - actives += vid - } - i += 1 + val newEdges = new EdgeRDD(edges.partitionsRDD.zipPartitions(shippedVerts) { + (ePartIter, shippedVertsIter) => ePartIter.map { + case (pid, edgePartition) => + (pid, edgePartition.updateVertices(shippedVertsIter.flatMap(_._2.iterator))) } - (pid, actives.trim().array) - } + }) + new ReplicatedVertexView(newEdges, hasSrcId, hasDstId) } } - -private[graphx] -class VertexAttributeBlock[VD: ClassTag](val vids: Array[VertexId], val attrs: Array[VD]) - extends Serializable { - def iterator: Iterator[(VertexId, VD)] = - (0 until vids.size).iterator.map { i => (vids(i), attrs(i)) } -} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTable.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTable.scala deleted file mode 100644 index 022d5668e2942..0000000000000 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTable.scala +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.graphx.impl - -import org.apache.spark.SparkContext._ -import org.apache.spark.graphx._ -import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.collection.PrimitiveVector - -/** - * Stores the locations of edge-partition join sites for each vertex attribute; that is, the routing - * information for shipping vertex attributes to edge partitions. This is always cached because it - * may be used multiple times in ReplicatedVertexView -- once to ship the vertex attributes and - * (possibly) once to ship the active-set information. 
- */ -private[impl] -class RoutingTable(edges: EdgeRDD[_], vertices: VertexRDD[_]) { - - val bothAttrs: RDD[Array[Array[VertexId]]] = createPid2Vid(true, true) - val srcAttrOnly: RDD[Array[Array[VertexId]]] = createPid2Vid(true, false) - val dstAttrOnly: RDD[Array[Array[VertexId]]] = createPid2Vid(false, true) - val noAttrs: RDD[Array[Array[VertexId]]] = createPid2Vid(false, false) - - def get(includeSrcAttr: Boolean, includeDstAttr: Boolean): RDD[Array[Array[VertexId]]] = - (includeSrcAttr, includeDstAttr) match { - case (true, true) => bothAttrs - case (true, false) => srcAttrOnly - case (false, true) => dstAttrOnly - case (false, false) => noAttrs - } - - private def createPid2Vid( - includeSrcAttr: Boolean, includeDstAttr: Boolean): RDD[Array[Array[VertexId]]] = { - // Determine which vertices each edge partition needs by creating a mapping from vid to pid. - val vid2pid: RDD[(VertexId, PartitionID)] = edges.partitionsRDD.mapPartitions { iter => - val (pid: PartitionID, edgePartition: EdgePartition[_]) = iter.next() - val numEdges = edgePartition.size - val vSet = new VertexSet - if (includeSrcAttr) { // Add src vertices to the set. - var i = 0 - while (i < numEdges) { - vSet.add(edgePartition.srcIds(i)) - i += 1 - } - } - if (includeDstAttr) { // Add dst vertices to the set. - var i = 0 - while (i < numEdges) { - vSet.add(edgePartition.dstIds(i)) - i += 1 - } - } - vSet.iterator.map { vid => (vid, pid) } - } - - val numEdgePartitions = edges.partitions.size - vid2pid.partitionBy(vertices.partitioner.get).mapPartitions { iter => - val pid2vid = Array.fill(numEdgePartitions)(new PrimitiveVector[VertexId]) - for ((vid, pid) <- iter) { - pid2vid(pid) += vid - } - - Iterator(pid2vid.map(_.trim().array)) - }.cache().setName("RoutingTable %s %s".format(includeSrcAttr, includeDstAttr)) - } -} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala new file mode 100644 index 0000000000000..927e32ad0f448 --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx.impl + +import scala.reflect.ClassTag + +import org.apache.spark.Partitioner +import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.ShuffledRDD +import org.apache.spark.util.collection.{BitSet, PrimitiveVector} + +import org.apache.spark.graphx._ +import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap + +/** + * A message from the edge partition `pid` to the vertex partition containing `vid` specifying that + * the edge partition references `vid` in the specified `position` (src, dst, or both). 
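+ * The `position` byte is a two-bit mask: 0x1 indicates the vertex is referenced as a source,
+ * 0x2 as a destination, and 0x3 as both (see `RoutingTablePartition.edgePartitionToMsgs`).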
+*/ +private[graphx] +class RoutingTableMessage( + var vid: VertexId, + var pid: PartitionID, + var position: Byte) + extends Product2[VertexId, (PartitionID, Byte)] with Serializable { + override def _1 = vid + override def _2 = (pid, position) + override def canEqual(that: Any): Boolean = that.isInstanceOf[RoutingTableMessage] +} + +private[graphx] +class RoutingTableMessageRDDFunctions(self: RDD[RoutingTableMessage]) { + /** Copartition an `RDD[RoutingTableMessage]` with the vertex RDD with the given `partitioner`. */ + def copartitionWithVertices(partitioner: Partitioner): RDD[RoutingTableMessage] = { + new ShuffledRDD[VertexId, (PartitionID, Byte), RoutingTableMessage](self, partitioner) + .setSerializer(new RoutingTableMessageSerializer) + } +} + +private[graphx] +object RoutingTableMessageRDDFunctions { + import scala.language.implicitConversions + + implicit def rdd2RoutingTableMessageRDDFunctions(rdd: RDD[RoutingTableMessage]) = { + new RoutingTableMessageRDDFunctions(rdd) + } +} + +private[graphx] +object RoutingTablePartition { + val empty: RoutingTablePartition = new RoutingTablePartition(Array.empty) + + /** Generate a `RoutingTableMessage` for each vertex referenced in `edgePartition`. */ + def edgePartitionToMsgs(pid: PartitionID, edgePartition: EdgePartition[_, _]) + : Iterator[RoutingTableMessage] = { + // Determine which positions each vertex id appears in using a map where the low 2 bits + // represent src and dst + val map = new PrimitiveKeyOpenHashMap[VertexId, Byte] + edgePartition.srcIds.iterator.foreach { srcId => + map.changeValue(srcId, 0x1, (b: Byte) => (b | 0x1).toByte) + } + edgePartition.dstIds.iterator.foreach { dstId => + map.changeValue(dstId, 0x2, (b: Byte) => (b | 0x2).toByte) + } + map.iterator.map { vidAndPosition => + new RoutingTableMessage(vidAndPosition._1, pid, vidAndPosition._2) + } + } + + /** Build a `RoutingTablePartition` from `RoutingTableMessage`s. */ + def fromMsgs(numEdgePartitions: Int, iter: Iterator[RoutingTableMessage]) + : RoutingTablePartition = { + val pid2vid = Array.fill(numEdgePartitions)(new PrimitiveVector[VertexId]) + val srcFlags = Array.fill(numEdgePartitions)(new PrimitiveVector[Boolean]) + val dstFlags = Array.fill(numEdgePartitions)(new PrimitiveVector[Boolean]) + for (msg <- iter) { + pid2vid(msg.pid) += msg.vid + srcFlags(msg.pid) += (msg.position & 0x1) != 0 + dstFlags(msg.pid) += (msg.position & 0x2) != 0 + } + + new RoutingTablePartition(pid2vid.zipWithIndex.map { + case (vids, pid) => (vids.trim().array, toBitSet(srcFlags(pid)), toBitSet(dstFlags(pid))) + }) + } + + /** Compact the given vector of Booleans into a BitSet. */ + private def toBitSet(flags: PrimitiveVector[Boolean]): BitSet = { + val bitset = new BitSet(flags.size) + var i = 0 + while (i < flags.size) { + if (flags(i)) { + bitset.set(i) + } + i += 1 + } + bitset + } +} + +/** + * Stores the locations of edge-partition join sites for each vertex attribute in a particular + * vertex partition. This provides routing information for shipping vertex attributes to edge + * partitions. + */ +private[graphx] +class RoutingTablePartition( + private val routingTable: Array[(Array[VertexId], BitSet, BitSet)]) { + /** The maximum number of edge partitions this `RoutingTablePartition` is built to join with. */ + val numEdgePartitions: Int = routingTable.size + + /** Returns the number of vertices that will be sent to the specified edge partition. 
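+ * (That is, the number of vertex ids in this vertex partition that edge partition `pid`
+ * references as a source or destination.)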
*/ + def partitionSize(pid: PartitionID): Int = routingTable(pid)._1.size + + /** Returns an iterator over all vertex ids stored in this `RoutingTablePartition`. */ + def iterator: Iterator[VertexId] = routingTable.iterator.flatMap(_._1.iterator) + + /** Returns a new RoutingTablePartition reflecting a reversal of all edge directions. */ + def reverse: RoutingTablePartition = { + new RoutingTablePartition(routingTable.map { + case (vids, srcVids, dstVids) => (vids, dstVids, srcVids) + }) + } + + /** + * Runs `f` on each vertex id to be sent to the specified edge partition. Vertex ids can be + * filtered by the position they have in the edge partition. + */ + def foreachWithinEdgePartition + (pid: PartitionID, includeSrc: Boolean, includeDst: Boolean) + (f: VertexId => Unit) { + val (vidsCandidate, srcVids, dstVids) = routingTable(pid) + val size = vidsCandidate.length + if (includeSrc && includeDst) { + // Avoid checks for performance + vidsCandidate.iterator.foreach(f) + } else if (!includeSrc && !includeDst) { + // Do nothing + } else { + val relevantVids = if (includeSrc) srcVids else dstVids + relevantVids.iterator.foreach { i => f(vidsCandidate(i)) } + } + } +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala index 1de42eeca1f00..033237f597216 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala @@ -27,6 +27,35 @@ import scala.reflect.ClassTag import org.apache.spark.graphx._ import org.apache.spark.serializer._ +private[graphx] +class RoutingTableMessageSerializer extends Serializer with Serializable { + override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { + + override def serializeStream(s: OutputStream): SerializationStream = + new ShuffleSerializationStream(s) { + def writeObject[T: ClassTag](t: T): SerializationStream = { + val msg = t.asInstanceOf[RoutingTableMessage] + writeVarLong(msg.vid, optimizePositive = false) + writeUnsignedVarInt(msg.pid) + // TODO: Write only the bottom two bits of msg.position + s.write(msg.position) + this + } + } + + override def deserializeStream(s: InputStream): DeserializationStream = + new ShuffleDeserializationStream(s) { + override def readObject[T: ClassTag](): T = { + val a = readVarLong(optimizePositive = false) + val b = readUnsignedVarInt() + val c = s.read() + if (c == -1) throw new EOFException + new RoutingTableMessage(a, b, c.toByte).asInstanceOf[T] + } + } + } +} + private[graphx] class VertexIdMsgSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala new file mode 100644 index 0000000000000..f4e221d4e05ae --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx.impl + +import scala.reflect.ClassTag + +import org.apache.spark.util.collection.{BitSet, PrimitiveVector} + +import org.apache.spark.graphx._ +import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap + +/** Stores vertex attributes to ship to an edge partition. */ +private[graphx] +class VertexAttributeBlock[VD: ClassTag](val vids: Array[VertexId], val attrs: Array[VD]) + extends Serializable { + def iterator: Iterator[(VertexId, VD)] = + (0 until vids.size).iterator.map { i => (vids(i), attrs(i)) } +} + +private[graphx] +object ShippableVertexPartition { + /** Construct a `ShippableVertexPartition` from the given vertices without any routing table. */ + def apply[VD: ClassTag](iter: Iterator[(VertexId, VD)]): ShippableVertexPartition[VD] = + apply(iter, RoutingTablePartition.empty, null.asInstanceOf[VD]) + + /** + * Construct a `ShippableVertexPartition` from the given vertices with the specified routing + * table, filling in missing vertices mentioned in the routing table using `defaultVal`. + */ + def apply[VD: ClassTag]( + iter: Iterator[(VertexId, VD)], routingTable: RoutingTablePartition, defaultVal: VD) + : ShippableVertexPartition[VD] = { + val fullIter = iter ++ routingTable.iterator.map(vid => (vid, defaultVal)) + val (index, values, mask) = VertexPartitionBase.initFrom(fullIter, (a: VD, b: VD) => a) + new ShippableVertexPartition(index, values, mask, routingTable) + } + + import scala.language.implicitConversions + + /** + * Implicit conversion to allow invoking `VertexPartitionBase` operations directly on a + * `ShippableVertexPartition`. + */ + implicit def shippablePartitionToOps[VD: ClassTag](partition: ShippableVertexPartition[VD]) = + new ShippableVertexPartitionOps(partition) + + /** + * Implicit evidence that `ShippableVertexPartition` is a member of the + * `VertexPartitionBaseOpsConstructor` typeclass. This enables invoking `VertexPartitionBase` + * operations on a `ShippableVertexPartition` via an evidence parameter, as in + * [[VertexPartitionBaseOps]]. + */ + implicit object ShippableVertexPartitionOpsConstructor + extends VertexPartitionBaseOpsConstructor[ShippableVertexPartition] { + def toOps[VD: ClassTag](partition: ShippableVertexPartition[VD]) + : VertexPartitionBaseOps[VD, ShippableVertexPartition] = shippablePartitionToOps(partition) + } +} + +/** + * A map from vertex id to vertex attribute that additionally stores edge partition join sites for + * each vertex attribute, enabling joining with an [[org.apache.spark.graphx.EdgeRDD]]. + */ +private[graphx] +class ShippableVertexPartition[VD: ClassTag]( + val index: VertexIdToIndexMap, + val values: Array[VD], + val mask: BitSet, + val routingTable: RoutingTablePartition) + extends VertexPartitionBase[VD] { + + /** Return a new ShippableVertexPartition with the specified routing table. 
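+ * Only the routing table is replaced; the index, values, and mask of this partition are reused.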
*/ + def withRoutingTable(routingTable_ : RoutingTablePartition): ShippableVertexPartition[VD] = { + new ShippableVertexPartition(index, values, mask, routingTable_) + } + + /** + * Generate a `VertexAttributeBlock` for each edge partition keyed on the edge partition ID. The + * `VertexAttributeBlock` contains the vertex attributes from the current partition that are + * referenced in the specified positions in the edge partition. + */ + def shipVertexAttributes( + shipSrc: Boolean, shipDst: Boolean): Iterator[(PartitionID, VertexAttributeBlock[VD])] = { + Iterator.tabulate(routingTable.numEdgePartitions) { pid => + val initialSize = if (shipSrc && shipDst) routingTable.partitionSize(pid) else 64 + val vids = new PrimitiveVector[VertexId](initialSize) + val attrs = new PrimitiveVector[VD](initialSize) + var i = 0 + routingTable.foreachWithinEdgePartition(pid, shipSrc, shipDst) { vid => + if (isDefined(vid)) { + vids += vid + attrs += this(vid) + } + i += 1 + } + (pid, new VertexAttributeBlock(vids.trim().array, attrs.trim().array)) + } + } + + /** + * Generate a `VertexId` array for each edge partition keyed on the edge partition ID. The array + * contains the visible vertex ids from the current partition that are referenced in the edge + * partition. + */ + def shipVertexIds(): Iterator[(PartitionID, Array[VertexId])] = { + Iterator.tabulate(routingTable.numEdgePartitions) { pid => + val vids = new PrimitiveVector[VertexId](routingTable.partitionSize(pid)) + var i = 0 + routingTable.foreachWithinEdgePartition(pid, true, true) { vid => + if (isDefined(vid)) { + vids += vid + } + i += 1 + } + (pid, vids.trim().array) + } + } +} + +private[graphx] class ShippableVertexPartitionOps[VD: ClassTag](self: ShippableVertexPartition[VD]) + extends VertexPartitionBaseOps[VD, ShippableVertexPartition](self) { + + def withIndex(index: VertexIdToIndexMap): ShippableVertexPartition[VD] = { + new ShippableVertexPartition(index, self.values, self.mask, self.routingTable) + } + + def withValues[VD2: ClassTag](values: Array[VD2]): ShippableVertexPartition[VD2] = { + new ShippableVertexPartition(self.index, values, self.mask, self.routingTable) + } + + def withMask(mask: BitSet): ShippableVertexPartition[VD] = { + new ShippableVertexPartition(self.index, self.values, mask, self.routingTable) + } +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala index 7a54b413dc8ca..f1d174720a1ba 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala @@ -19,260 +19,59 @@ package org.apache.spark.graphx.impl import scala.reflect.ClassTag -import org.apache.spark.Logging +import org.apache.spark.util.collection.BitSet + import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap -import org.apache.spark.util.collection.BitSet private[graphx] object VertexPartition { - - def apply[VD: ClassTag](iter: Iterator[(VertexId, VD)]): VertexPartition[VD] = { - val map = new PrimitiveKeyOpenHashMap[VertexId, VD] - iter.foreach { case (k, v) => - map(k) = v - } - new VertexPartition(map.keySet, map._values, map.keySet.getBitSet) - } - - def apply[VD: ClassTag](iter: Iterator[(VertexId, VD)], mergeFunc: (VD, VD) => VD) - : VertexPartition[VD] = - { - val map = new PrimitiveKeyOpenHashMap[VertexId, VD] - iter.foreach { case (k, v) => - map.setMerge(k, v, mergeFunc) - } - new 
VertexPartition(map.keySet, map._values, map.keySet.getBitSet) - } -} - - -private[graphx] -class VertexPartition[@specialized(Long, Int, Double) VD: ClassTag]( - val index: VertexIdToIndexMap, - val values: Array[VD], - val mask: BitSet, - /** A set of vids of active vertices. May contain vids not in index due to join rewrite. */ - private val activeSet: Option[VertexSet] = None) - extends Logging { - - val capacity: Int = index.capacity - - def size: Int = mask.cardinality() - - /** Return the vertex attribute for the given vertex ID. */ - def apply(vid: VertexId): VD = values(index.getPos(vid)) - - def isDefined(vid: VertexId): Boolean = { - val pos = index.getPos(vid) - pos >= 0 && mask.get(pos) - } - - /** Look up vid in activeSet, throwing an exception if it is None. */ - def isActive(vid: VertexId): Boolean = { - activeSet.get.contains(vid) + /** Construct a `VertexPartition` from the given vertices. */ + def apply[VD: ClassTag](iter: Iterator[(VertexId, VD)]) + : VertexPartition[VD] = { + val (index, values, mask) = VertexPartitionBase.initFrom(iter) + new VertexPartition(index, values, mask) } - /** The number of active vertices, if any exist. */ - def numActives: Option[Int] = activeSet.map(_.size) + import scala.language.implicitConversions /** - * Pass each vertex attribute along with the vertex id through a map - * function and retain the original RDD's partitioning and index. - * - * @tparam VD2 the type returned by the map function - * - * @param f the function applied to each vertex id and vertex - * attribute in the RDD - * - * @return a new VertexPartition with values obtained by applying `f` to - * each of the entries in the original VertexRDD. The resulting - * VertexPartition retains the same index. + * Implicit conversion to allow invoking `VertexPartitionBase` operations directly on a + * `VertexPartition`. */ - def map[VD2: ClassTag](f: (VertexId, VD) => VD2): VertexPartition[VD2] = { - // Construct a view of the map transformation - val newValues = new Array[VD2](capacity) - var i = mask.nextSetBit(0) - while (i >= 0) { - newValues(i) = f(index.getValue(i), values(i)) - i = mask.nextSetBit(i + 1) - } - new VertexPartition[VD2](index, newValues, mask) - } - - /** - * Restrict the vertex set to the set of vertices satisfying the given predicate. - * - * @param pred the user defined predicate - * - * @note The vertex set preserves the original index structure which means that the returned - * RDD can be easily joined with the original vertex-set. Furthermore, the filter only - * modifies the bitmap index and so no new values are allocated. - */ - def filter(pred: (VertexId, VD) => Boolean): VertexPartition[VD] = { - // Allocate the array to store the results into - val newMask = new BitSet(capacity) - // Iterate over the active bits in the old mask and evaluate the predicate - var i = mask.nextSetBit(0) - while (i >= 0) { - if (pred(index.getValue(i), values(i))) { - newMask.set(i) - } - i = mask.nextSetBit(i + 1) - } - new VertexPartition(index, values, newMask) - } + implicit def partitionToOps[VD: ClassTag](partition: VertexPartition[VD]) = + new VertexPartitionOps(partition) /** - * Hides vertices that are the same between this and other. For vertices that are different, keeps - * the values from `other`. The indices of `this` and `other` must be the same. + * Implicit evidence that `VertexPartition` is a member of the `VertexPartitionBaseOpsConstructor` + * typeclass. 
This enables invoking `VertexPartitionBase` operations on a `VertexPartition` via an + * evidence parameter, as in [[VertexPartitionBaseOps]]. */ - def diff(other: VertexPartition[VD]): VertexPartition[VD] = { - if (index != other.index) { - logWarning("Diffing two VertexPartitions with different indexes is slow.") - diff(createUsingIndex(other.iterator)) - } else { - val newMask = mask & other.mask - var i = newMask.nextSetBit(0) - while (i >= 0) { - if (values(i) == other.values(i)) { - newMask.unset(i) - } - i = newMask.nextSetBit(i + 1) - } - new VertexPartition(index, other.values, newMask) - } - } - - /** Left outer join another VertexPartition. */ - def leftJoin[VD2: ClassTag, VD3: ClassTag] - (other: VertexPartition[VD2]) - (f: (VertexId, VD, Option[VD2]) => VD3): VertexPartition[VD3] = { - if (index != other.index) { - logWarning("Joining two VertexPartitions with different indexes is slow.") - leftJoin(createUsingIndex(other.iterator))(f) - } else { - val newValues = new Array[VD3](capacity) - - var i = mask.nextSetBit(0) - while (i >= 0) { - val otherV: Option[VD2] = if (other.mask.get(i)) Some(other.values(i)) else None - newValues(i) = f(index.getValue(i), values(i), otherV) - i = mask.nextSetBit(i + 1) - } - new VertexPartition(index, newValues, mask) - } - } - - /** Left outer join another iterator of messages. */ - def leftJoin[VD2: ClassTag, VD3: ClassTag] - (other: Iterator[(VertexId, VD2)]) - (f: (VertexId, VD, Option[VD2]) => VD3): VertexPartition[VD3] = { - leftJoin(createUsingIndex(other))(f) - } - - /** Inner join another VertexPartition. */ - def innerJoin[U: ClassTag, VD2: ClassTag](other: VertexPartition[U]) - (f: (VertexId, VD, U) => VD2): VertexPartition[VD2] = { - if (index != other.index) { - logWarning("Joining two VertexPartitions with different indexes is slow.") - innerJoin(createUsingIndex(other.iterator))(f) - } else { - val newMask = mask & other.mask - val newValues = new Array[VD2](capacity) - var i = newMask.nextSetBit(0) - while (i >= 0) { - newValues(i) = f(index.getValue(i), values(i), other.values(i)) - i = newMask.nextSetBit(i + 1) - } - new VertexPartition(index, newValues, newMask) - } - } - - /** - * Inner join an iterator of messages. - */ - def innerJoin[U: ClassTag, VD2: ClassTag] - (iter: Iterator[Product2[VertexId, U]]) - (f: (VertexId, VD, U) => VD2): VertexPartition[VD2] = { - innerJoin(createUsingIndex(iter))(f) + implicit object VertexPartitionOpsConstructor + extends VertexPartitionBaseOpsConstructor[VertexPartition] { + def toOps[VD: ClassTag](partition: VertexPartition[VD]) + : VertexPartitionBaseOps[VD, VertexPartition] = partitionToOps(partition) } +} - /** - * Similar effect as aggregateUsingIndex((a, b) => a) - */ - def createUsingIndex[VD2: ClassTag](iter: Iterator[Product2[VertexId, VD2]]) - : VertexPartition[VD2] = { - val newMask = new BitSet(capacity) - val newValues = new Array[VD2](capacity) - iter.foreach { case (vid, vdata) => - val pos = index.getPos(vid) - if (pos >= 0) { - newMask.set(pos) - newValues(pos) = vdata - } - } - new VertexPartition[VD2](index, newValues, newMask) - } +/** A map from vertex id to vertex attribute. */ +private[graphx] class VertexPartition[VD: ClassTag]( + val index: VertexIdToIndexMap, + val values: Array[VD], + val mask: BitSet) + extends VertexPartitionBase[VD] - /** - * Similar to innerJoin, but vertices from the left side that don't appear in iter will remain in - * the partition, hidden by the bitmask. 
- */ - def innerJoinKeepLeft(iter: Iterator[Product2[VertexId, VD]]): VertexPartition[VD] = { - val newMask = new BitSet(capacity) - val newValues = new Array[VD](capacity) - System.arraycopy(values, 0, newValues, 0, newValues.length) - iter.foreach { case (vid, vdata) => - val pos = index.getPos(vid) - if (pos >= 0) { - newMask.set(pos) - newValues(pos) = vdata - } - } - new VertexPartition(index, newValues, newMask) - } +private[graphx] class VertexPartitionOps[VD: ClassTag](self: VertexPartition[VD]) + extends VertexPartitionBaseOps[VD, VertexPartition](self) { - def aggregateUsingIndex[VD2: ClassTag]( - iter: Iterator[Product2[VertexId, VD2]], - reduceFunc: (VD2, VD2) => VD2): VertexPartition[VD2] = { - val newMask = new BitSet(capacity) - val newValues = new Array[VD2](capacity) - iter.foreach { product => - val vid = product._1 - val vdata = product._2 - val pos = index.getPos(vid) - if (pos >= 0) { - if (newMask.get(pos)) { - newValues(pos) = reduceFunc(newValues(pos), vdata) - } else { // otherwise just store the new value - newMask.set(pos) - newValues(pos) = vdata - } - } - } - new VertexPartition[VD2](index, newValues, newMask) + def withIndex(index: VertexIdToIndexMap): VertexPartition[VD] = { + new VertexPartition(index, self.values, self.mask) } - def replaceActives(iter: Iterator[VertexId]): VertexPartition[VD] = { - val newActiveSet = new VertexSet - iter.foreach(newActiveSet.add(_)) - new VertexPartition(index, values, mask, Some(newActiveSet)) + def withValues[VD2: ClassTag](values: Array[VD2]): VertexPartition[VD2] = { + new VertexPartition(self.index, values, self.mask) } - /** - * Construct a new VertexPartition whose index contains only the vertices in the mask. - */ - def reindex(): VertexPartition[VD] = { - val hashMap = new PrimitiveKeyOpenHashMap[VertexId, VD] - val arbitraryMerge = (a: VD, b: VD) => a - for ((k, v) <- this.iterator) { - hashMap.setMerge(k, v, arbitraryMerge) - } - new VertexPartition(hashMap.keySet, hashMap._values, hashMap.keySet.getBitSet) + def withMask(mask: BitSet): VertexPartition[VD] = { + new VertexPartition(self.index, self.values, mask) } - - def iterator: Iterator[(VertexId, VD)] = - mask.iterator.map(ind => (index.getValue(ind), values(ind))) - - def vidIterator: Iterator[VertexId] = mask.iterator.map(ind => index.getValue(ind)) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala new file mode 100644 index 0000000000000..8d9e0204d27f2 --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.graphx.impl + +import scala.language.higherKinds +import scala.reflect.ClassTag + +import org.apache.spark.util.collection.BitSet + +import org.apache.spark.graphx._ +import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap + +private[graphx] object VertexPartitionBase { + /** + * Construct the constituents of a VertexPartitionBase from the given vertices, merging duplicate + * entries arbitrarily. + */ + def initFrom[VD: ClassTag](iter: Iterator[(VertexId, VD)]) + : (VertexIdToIndexMap, Array[VD], BitSet) = { + val map = new PrimitiveKeyOpenHashMap[VertexId, VD] + iter.foreach { pair => + map(pair._1) = pair._2 + } + (map.keySet, map._values, map.keySet.getBitSet) + } + + /** + * Construct the constituents of a VertexPartitionBase from the given vertices, merging duplicate + * entries using `mergeFunc`. + */ + def initFrom[VD: ClassTag](iter: Iterator[(VertexId, VD)], mergeFunc: (VD, VD) => VD) + : (VertexIdToIndexMap, Array[VD], BitSet) = { + val map = new PrimitiveKeyOpenHashMap[VertexId, VD] + iter.foreach { pair => + map.setMerge(pair._1, pair._2, mergeFunc) + } + (map.keySet, map._values, map.keySet.getBitSet) + } +} + +/** + * An abstract map from vertex id to vertex attribute. [[VertexPartition]] is the corresponding + * concrete implementation. [[VertexPartitionBaseOps]] provides a variety of operations for + * VertexPartitionBase and subclasses that provide implicit evidence of membership in the + * `VertexPartitionBaseOpsConstructor` typeclass (for example, + * [[VertexPartition.VertexPartitionOpsConstructor]]). + */ +private[graphx] abstract class VertexPartitionBase[@specialized(Long, Int, Double) VD: ClassTag] { + + def index: VertexIdToIndexMap + def values: Array[VD] + def mask: BitSet + + val capacity: Int = index.capacity + + def size: Int = mask.cardinality() + + /** Return the vertex attribute for the given vertex ID. */ + def apply(vid: VertexId): VD = values(index.getPos(vid)) + + def isDefined(vid: VertexId): Boolean = { + val pos = index.getPos(vid) + pos >= 0 && mask.get(pos) + } + + def iterator: Iterator[(VertexId, VD)] = + mask.iterator.map(ind => (index.getValue(ind), values(ind))) +} + +/** + * A typeclass for subclasses of `VertexPartitionBase` representing the ability to wrap them in a + * `VertexPartitionBaseOps`. + */ +private[graphx] trait VertexPartitionBaseOpsConstructor[T[X] <: VertexPartitionBase[X]] { + def toOps[VD: ClassTag](partition: T[VD]): VertexPartitionBaseOps[VD, T] +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala new file mode 100644 index 0000000000000..21ff615feca6c --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx.impl + +import scala.language.higherKinds +import scala.language.implicitConversions +import scala.reflect.ClassTag + +import org.apache.spark.Logging +import org.apache.spark.util.collection.BitSet + +import org.apache.spark.graphx._ +import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap + +/** + * An class containing additional operations for subclasses of VertexPartitionBase that provide + * implicit evidence of membership in the `VertexPartitionBaseOpsConstructor` typeclass (for + * example, [[VertexPartition.VertexPartitionOpsConstructor]]). + */ +private[graphx] abstract class VertexPartitionBaseOps + [VD: ClassTag, Self[X] <: VertexPartitionBase[X] : VertexPartitionBaseOpsConstructor] + (self: Self[VD]) + extends Logging { + + def withIndex(index: VertexIdToIndexMap): Self[VD] + def withValues[VD2: ClassTag](values: Array[VD2]): Self[VD2] + def withMask(mask: BitSet): Self[VD] + + /** + * Pass each vertex attribute along with the vertex id through a map + * function and retain the original RDD's partitioning and index. + * + * @tparam VD2 the type returned by the map function + * + * @param f the function applied to each vertex id and vertex + * attribute in the RDD + * + * @return a new VertexPartition with values obtained by applying `f` to + * each of the entries in the original VertexRDD. The resulting + * VertexPartition retains the same index. + */ + def map[VD2: ClassTag](f: (VertexId, VD) => VD2): Self[VD2] = { + // Construct a view of the map transformation + val newValues = new Array[VD2](self.capacity) + var i = self.mask.nextSetBit(0) + while (i >= 0) { + newValues(i) = f(self.index.getValue(i), self.values(i)) + i = self.mask.nextSetBit(i + 1) + } + this.withValues(newValues) + } + + /** + * Restrict the vertex set to the set of vertices satisfying the given predicate. + * + * @param pred the user defined predicate + * + * @note The vertex set preserves the original index structure which means that the returned + * RDD can be easily joined with the original vertex-set. Furthermore, the filter only + * modifies the bitmap index and so no new values are allocated. + */ + def filter(pred: (VertexId, VD) => Boolean): Self[VD] = { + // Allocate the array to store the results into + val newMask = new BitSet(self.capacity) + // Iterate over the active bits in the old mask and evaluate the predicate + var i = self.mask.nextSetBit(0) + while (i >= 0) { + if (pred(self.index.getValue(i), self.values(i))) { + newMask.set(i) + } + i = self.mask.nextSetBit(i + 1) + } + this.withMask(newMask) + } + + /** + * Hides vertices that are the same between this and other. For vertices that are different, keeps + * the values from `other`. The indices of `this` and `other` must be the same. 
+ */ + def diff(other: Self[VD]): Self[VD] = { + if (self.index != other.index) { + logWarning("Diffing two VertexPartitions with different indexes is slow.") + diff(createUsingIndex(other.iterator)) + } else { + val newMask = self.mask & other.mask + var i = newMask.nextSetBit(0) + while (i >= 0) { + if (self.values(i) == other.values(i)) { + newMask.unset(i) + } + i = newMask.nextSetBit(i + 1) + } + this.withValues(other.values).withMask(newMask) + } + } + + /** Left outer join another VertexPartition. */ + def leftJoin[VD2: ClassTag, VD3: ClassTag] + (other: Self[VD2]) + (f: (VertexId, VD, Option[VD2]) => VD3): Self[VD3] = { + if (self.index != other.index) { + logWarning("Joining two VertexPartitions with different indexes is slow.") + leftJoin(createUsingIndex(other.iterator))(f) + } else { + val newValues = new Array[VD3](self.capacity) + + var i = self.mask.nextSetBit(0) + while (i >= 0) { + val otherV: Option[VD2] = if (other.mask.get(i)) Some(other.values(i)) else None + newValues(i) = f(self.index.getValue(i), self.values(i), otherV) + i = self.mask.nextSetBit(i + 1) + } + this.withValues(newValues) + } + } + + /** Left outer join another iterator of messages. */ + def leftJoin[VD2: ClassTag, VD3: ClassTag] + (other: Iterator[(VertexId, VD2)]) + (f: (VertexId, VD, Option[VD2]) => VD3): Self[VD3] = { + leftJoin(createUsingIndex(other))(f) + } + + /** Inner join another VertexPartition. */ + def innerJoin[U: ClassTag, VD2: ClassTag] + (other: Self[U]) + (f: (VertexId, VD, U) => VD2): Self[VD2] = { + if (self.index != other.index) { + logWarning("Joining two VertexPartitions with different indexes is slow.") + innerJoin(createUsingIndex(other.iterator))(f) + } else { + val newMask = self.mask & other.mask + val newValues = new Array[VD2](self.capacity) + var i = newMask.nextSetBit(0) + while (i >= 0) { + newValues(i) = f(self.index.getValue(i), self.values(i), other.values(i)) + i = newMask.nextSetBit(i + 1) + } + this.withValues(newValues).withMask(newMask) + } + } + + /** + * Inner join an iterator of messages. + */ + def innerJoin[U: ClassTag, VD2: ClassTag] + (iter: Iterator[Product2[VertexId, U]]) + (f: (VertexId, VD, U) => VD2): Self[VD2] = { + innerJoin(createUsingIndex(iter))(f) + } + + /** + * Similar effect as aggregateUsingIndex((a, b) => a) + */ + def createUsingIndex[VD2: ClassTag](iter: Iterator[Product2[VertexId, VD2]]) + : Self[VD2] = { + val newMask = new BitSet(self.capacity) + val newValues = new Array[VD2](self.capacity) + iter.foreach { pair => + val pos = self.index.getPos(pair._1) + if (pos >= 0) { + newMask.set(pos) + newValues(pos) = pair._2 + } + } + this.withValues(newValues).withMask(newMask) + } + + /** + * Similar to innerJoin, but vertices from the left side that don't appear in iter will remain in + * the partition, hidden by the bitmask. 
+ */ + def innerJoinKeepLeft(iter: Iterator[Product2[VertexId, VD]]): Self[VD] = { + val newMask = new BitSet(self.capacity) + val newValues = new Array[VD](self.capacity) + System.arraycopy(self.values, 0, newValues, 0, newValues.length) + iter.foreach { pair => + val pos = self.index.getPos(pair._1) + if (pos >= 0) { + newMask.set(pos) + newValues(pos) = pair._2 + } + } + this.withValues(newValues).withMask(newMask) + } + + def aggregateUsingIndex[VD2: ClassTag]( + iter: Iterator[Product2[VertexId, VD2]], + reduceFunc: (VD2, VD2) => VD2): Self[VD2] = { + val newMask = new BitSet(self.capacity) + val newValues = new Array[VD2](self.capacity) + iter.foreach { product => + val vid = product._1 + val vdata = product._2 + val pos = self.index.getPos(vid) + if (pos >= 0) { + if (newMask.get(pos)) { + newValues(pos) = reduceFunc(newValues(pos), vdata) + } else { // otherwise just store the new value + newMask.set(pos) + newValues(pos) = vdata + } + } + } + this.withValues(newValues).withMask(newMask) + } + + /** + * Construct a new VertexPartition whose index contains only the vertices in the mask. + */ + def reindex(): Self[VD] = { + val hashMap = new PrimitiveKeyOpenHashMap[VertexId, VD] + val arbitraryMerge = (a: VD, b: VD) => a + for ((k, v) <- self.iterator) { + hashMap.setMerge(k, v, arbitraryMerge) + } + this.withIndex(hashMap.keySet).withValues(hashMap._values).withMask(hashMap.keySet.getBitSet) + } + + /** + * Converts a vertex partition (in particular, one of type `Self`) into a + * `VertexPartitionBaseOps`. Within this class, this allows chaining the methods defined above, + * because these methods return a `Self` and this implicit conversion re-wraps that in a + * `VertexPartitionBaseOps`. This relies on the context bound on `Self`. + */ + private implicit def toOps[VD2: ClassTag]( + partition: Self[VD2]): VertexPartitionBaseOps[VD2, Self] = { + implicitly[VertexPartitionBaseOpsConstructor[Self]].toOps(partition) + } +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala index d901d4fe225fe..069e042ed94a3 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala @@ -55,6 +55,7 @@ object Analytics extends Logging { val conf = new SparkConf() .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator") + .set("spark.locality.wait", "100000") taskType match { case "pagerank" => @@ -62,12 +63,14 @@ object Analytics extends Logging { var outFname = "" var numEPart = 4 var partitionStrategy: Option[PartitionStrategy] = None + var numIterOpt: Option[Int] = None options.foreach{ case ("tol", v) => tol = v.toFloat case ("output", v) => outFname = v case ("numEPart", v) => numEPart = v.toInt case ("partStrategy", v) => partitionStrategy = Some(pickPartitioner(v)) + case ("numIter", v) => numIterOpt = Some(v.toInt) case (opt, _) => throw new IllegalArgumentException("Invalid option: " + opt) } @@ -84,7 +87,10 @@ object Analytics extends Logging { println("GRAPHX: Number of vertices " + graph.vertices.count) println("GRAPHX: Number of edges " + graph.edges.count) - val pr = graph.pageRank(tol).vertices.cache() + val pr = (numIterOpt match { + case Some(numIter) => PageRank.run(graph, numIter) + case None => PageRank.runUntilConvergence(graph, tol) + }).vertices.cache() println("GRAPHX: Total rank: " + pr.map(_._2).reduce(_ 
+ _)) diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index 32b5fe4813594..7b9bac5d9c8ea 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -110,7 +110,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { val p = 100 val verts = 1 to n val graph = Graph.fromEdgeTuples(sc.parallelize(verts.flatMap(x => - verts.filter(y => y % x == 0).map(y => (x: VertexId, y: VertexId))), p), 0) + verts.withFilter(y => y % x == 0).map(y => (x: VertexId, y: VertexId))), p), 0) assert(graph.edges.partitions.length === p) val partitionedGraph = graph.partitionBy(EdgePartition2D) assert(graph.edges.partitions.length === p) @@ -120,7 +120,13 @@ class GraphSuite extends FunSuite with LocalSparkContext { val part = iter.next()._2 Iterator((part.srcIds ++ part.dstIds).toSet) }.collect - assert(verts.forall(id => partitionSets.count(_.contains(id)) <= bound)) + if (!verts.forall(id => partitionSets.count(_.contains(id)) <= bound)) { + val numFailures = verts.count(id => partitionSets.count(_.contains(id)) > bound) + val failure = verts.maxBy(id => partitionSets.count(_.contains(id))) + fail(("Replication bound test failed for %d/%d vertices. " + + "Example: vertex %d replicated to %d (> %f) partitions.").format( + numFailures, n, failure, partitionSets.count(_.contains(failure)), bound)) + } // This should not be true for the default hash partitioning val partitionSetsUnpartitioned = graph.edges.partitionsRDD.mapPartitions { iter => val part = iter.next()._2 diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala index e135d1d7ad6a3..d2e0c01bc35ef 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala @@ -26,10 +26,16 @@ import org.apache.spark.graphx._ class EdgePartitionSuite extends FunSuite { + def makeEdgePartition[A: ClassTag](xs: Iterable[(Int, Int, A)]): EdgePartition[A, Int] = { + val builder = new EdgePartitionBuilder[A, Int] + for ((src, dst, attr) <- xs) { builder.add(src: VertexId, dst: VertexId, attr) } + builder.toEdgePartition + } + test("reverse") { val edges = List(Edge(0, 1, 0), Edge(1, 2, 0), Edge(2, 0, 0)) val reversedEdges = List(Edge(0, 2, 0), Edge(1, 0, 0), Edge(2, 1, 0)) - val builder = new EdgePartitionBuilder[Int] + val builder = new EdgePartitionBuilder[Int, Nothing] for (e <- edges) { builder.add(e.srcId, e.dstId, e.attr) } @@ -40,7 +46,7 @@ class EdgePartitionSuite extends FunSuite { test("map") { val edges = List(Edge(0, 1, 0), Edge(1, 2, 0), Edge(2, 0, 0)) - val builder = new EdgePartitionBuilder[Int] + val builder = new EdgePartitionBuilder[Int, Nothing] for (e <- edges) { builder.add(e.srcId, e.dstId, e.attr) } @@ -49,11 +55,22 @@ class EdgePartitionSuite extends FunSuite { edges.map(e => e.copy(attr = e.srcId + e.dstId))) } + test("filter") { + val edges = List(Edge(0, 1, 0), Edge(0, 2, 0), Edge(2, 0, 0)) + val builder = new EdgePartitionBuilder[Int, Int] + for (e <- edges) { + builder.add(e.srcId, e.dstId, e.attr) + } + val edgePartition = builder.toEdgePartition + val filtered = edgePartition.filter(et => et.srcId == 0, (vid, attr) => vid == 0 || vid == 1) + assert(filtered.tripletIterator().toList.map(et => (et.srcId, et.dstId)) === List((0L, 1L))) + } + 
test("groupEdges") { val edges = List( Edge(0, 1, 1), Edge(1, 2, 2), Edge(2, 0, 4), Edge(0, 1, 8), Edge(1, 2, 16), Edge(2, 0, 32)) val groupedEdges = List(Edge(0, 1, 9), Edge(1, 2, 18), Edge(2, 0, 36)) - val builder = new EdgePartitionBuilder[Int] + val builder = new EdgePartitionBuilder[Int, Nothing] for (e <- edges) { builder.add(e.srcId, e.dstId, e.attr) } @@ -61,11 +78,19 @@ class EdgePartitionSuite extends FunSuite { assert(edgePartition.groupEdges(_ + _).iterator.map(_.copy()).toList === groupedEdges) } + test("upgradeIterator") { + val edges = List((0, 1, 0), (1, 0, 0)) + val verts = List((0L, 1), (1L, 2)) + val part = makeEdgePartition(edges).updateVertices(verts.iterator) + assert(part.upgradeIterator(part.iterator).map(_.toTuple).toList === + part.tripletIterator().toList.map(_.toTuple)) + } + test("indexIterator") { val edgesFrom0 = List(Edge(0, 1, 0)) val edgesFrom1 = List(Edge(1, 0, 0), Edge(1, 2, 0)) val sortedEdges = edgesFrom0 ++ edgesFrom1 - val builder = new EdgePartitionBuilder[Int] + val builder = new EdgePartitionBuilder[Int, Nothing] for (e <- Random.shuffle(sortedEdges)) { builder.add(e.srcId, e.dstId, e.attr) } @@ -77,11 +102,6 @@ class EdgePartitionSuite extends FunSuite { } test("innerJoin") { - def makeEdgePartition[A: ClassTag](xs: Iterable[(Int, Int, A)]): EdgePartition[A] = { - val builder = new EdgePartitionBuilder[A] - for ((src, dst, attr) <- xs) { builder.add(src: VertexId, dst: VertexId, attr) } - builder.toEdgePartition - } val aList = List((0, 1, 0), (1, 0, 0), (1, 2, 0), (5, 4, 0), (5, 5, 0)) val bList = List((0, 1, 0), (1, 0, 0), (1, 1, 0), (3, 4, 0), (5, 5, 0)) val a = makeEdgePartition(aList) @@ -90,4 +110,14 @@ class EdgePartitionSuite extends FunSuite { assert(a.innerJoin(b) { (src, dst, a, b) => a }.iterator.map(_.copy()).toList === List(Edge(0, 1, 0), Edge(1, 0, 0), Edge(5, 5, 0))) } + + test("isActive, numActives, replaceActives") { + val ep = new EdgePartitionBuilder[Nothing, Nothing].toEdgePartition + .withActiveSet(Iterator(0L, 2L, 0L)) + assert(ep.isActive(0)) + assert(!ep.isActive(1)) + assert(ep.isActive(2)) + assert(!ep.isActive(-1)) + assert(ep.numActives == Some(2)) + } } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala index 9cbb2d2acdc2d..49b2704390fea 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala @@ -26,17 +26,11 @@ import org.apache.spark.graphx._ class EdgeTripletIteratorSuite extends FunSuite { test("iterator.toList") { - val builder = new EdgePartitionBuilder[Int] + val builder = new EdgePartitionBuilder[Int, Int] builder.add(1, 2, 0) builder.add(1, 3, 0) builder.add(1, 4, 0) - val vidmap = new VertexIdToIndexMap - vidmap.add(1) - vidmap.add(2) - vidmap.add(3) - vidmap.add(4) - val vs = Array.fill(vidmap.capacity)(0) - val iter = new EdgeTripletIterator[Int, Int](vidmap, vs, builder.toEdgePartition) + val iter = new EdgeTripletIterator[Int, Int](builder.toEdgePartition, true, true) val result = iter.toList.map(et => (et.srcId, et.dstId)) assert(result === Seq((1, 2), (1, 3), (1, 4))) } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala index a048d13fd12b8..8bf1384d514c1 100644 --- 
a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala @@ -30,17 +30,6 @@ class VertexPartitionSuite extends FunSuite { assert(!vp.isDefined(-1)) } - test("isActive, numActives, replaceActives") { - val vp = VertexPartition(Iterator((0L, 1), (1L, 1))) - .filter { (vid, attr) => vid == 0 } - .replaceActives(Iterator(0, 2, 0)) - assert(vp.isActive(0)) - assert(!vp.isActive(1)) - assert(vp.isActive(2)) - assert(!vp.isActive(-1)) - assert(vp.numActives == Some(2)) - } - test("map") { val vp = VertexPartition(Iterator((0L, 1), (1L, 1))).map { (vid, attr) => 2 } assert(vp(0) === 2) diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index efdb38e907d14..fafc9b36a77d3 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -76,6 +76,8 @@ object MimaBuild { excludeSparkClass("util.XORShiftRandom") ++ excludeSparkClass("graphx.EdgeRDD") ++ excludeSparkClass("graphx.VertexRDD") ++ + excludeSparkClass("graphx.impl.GraphImpl") ++ + excludeSparkClass("graphx.impl.RoutingTable") ++ excludeSparkClass("mllib.recommendation.MFDataGenerator") ++ excludeSparkClass("mllib.optimization.SquaredGradient") ++ excludeSparkClass("mllib.regression.RidgeRegressionWithSGD") ++ From 2b7bd29eb6ee5baf739eec143044ecfc296b9b1f Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 10 May 2014 20:50:40 -0700 Subject: [PATCH 26/33] SPARK-1789. Multiple versions of Netty dependencies cause FlumeStreamSuite failure TL;DR is there is a bit of JAR hell trouble with Netty, that can be mostly resolved and will resolve a test failure. I hit the error described at http://apache-spark-user-list.1001560.n3.nabble.com/SparkContext-startup-time-out-td1753.html while running FlumeStreamingSuite, and have for a short while (is it just me?) velvia notes: "I have found a workaround. If you add akka 2.2.4 to your dependencies, then everything works, probably because akka 2.2.4 brings in newer version of Jetty." There are at least 3 versions of Netty in play in the build: - the new Flume 1.4.0 dependency brings in io.netty:netty:3.4.0.Final, and that is the immediate problem - the custom version of akka 2.2.3 depends on io.netty:netty:3.6.6. - but, Spark Core directly uses io.netty:netty-all:4.0.17.Final The POMs try to exclude other versions of netty, but are excluding org.jboss.netty:netty, when in fact older versions of io.netty:netty (not netty-all) are also an issue. The org.jboss.netty:netty excludes are largely unnecessary. I replaced many of them with io.netty:netty exclusions until everything agreed on io.netty:netty-all:4.0.17.Final. But this didn't work, since Akka 2.2.3 doesn't work with Netty 4.x. Down-grading to 3.6.6.Final across the board made some Spark code not compile. If the build *keeps* io.netty:netty:3.6.6.Final as well, everything seems to work. Part of the reason seems to be that Netty 3.x used the old `org.jboss.netty` packages. This is less than ideal, but is no worse than the current situation. 
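In sbt terms, the distinction being drawn is roughly the following; a minimal sketch, not the actual SparkBuild.scala change (that appears in the diff below), with the exclusion-rule names mirroring the ones introduced there and with the Hadoop version and object name used only as placeholders:

import sbt._
import Keys._

object NettyExcludesSketch {
  // Netty has been published under both the org.jboss.netty and io.netty Maven
  // coordinates, so two separate exclusion rules are needed to cover both families.
  val excludeJBossNetty = ExclusionRule(organization = "org.jboss.netty")
  val excludeIONetty = ExclusionRule(organization = "io.netty")

  val sketchSettings = Seq(
    libraryDependencies ++= Seq(
      // Hadoop artifacts only need the old-coordinate exclude (version is a placeholder)...
      "org.apache.hadoop" % "hadoop-client" % "1.0.4" excludeAll(excludeJBossNetty),
      // ...while Flume is what drags in io.netty:netty:3.4.0.Final, so it gets the
      // io.netty exclude; akka deliberately keeps its own io.netty:netty:3.6.6.Final.
      "org.apache.flume" % "flume-ng-sdk" % "1.4.0" % "compile" excludeAll(excludeIONetty)
    )
  )
}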
So this PR resolves the issue and improves the JAR hell, even if it leaves the existing theoretical Netty 3-vs-4 conflict: - Remove org.jboss.netty excludes where possible, for clarity; they're not needed except with Hadoop artifacts - Add io.netty:netty excludes where needed -- except, let akka keep its io.netty:netty - Change a bit of test code that actually depended on Netty 3.x, to use 4.x equivalent - Update SBT build accordingly A better change would be to update Akka far enough such that it agrees on Netty 4.x, but I don't know if that's feasible. Author: Sean Owen Closes #723 from srowen/SPARK-1789 and squashes the following commits: 43661b7 [Sean Owen] Update and add Netty excludes to prevent some JAR conflicts that cause test issues --- .../org/apache/spark/LocalSparkContext.scala | 3 +- examples/pom.xml | 4 +++ external/flume/pom.xml | 2 +- external/mqtt/pom.xml | 6 ---- external/twitter/pom.xml | 6 ---- external/zeromq/pom.xml | 6 ---- pom.xml | 32 ----------------- project/SparkBuild.scala | 35 ++++++++++--------- 8 files changed, 24 insertions(+), 70 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala index 4b972f88a9542..53e367a61715b 100644 --- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala @@ -17,8 +17,7 @@ package org.apache.spark -import org.jboss.netty.logging.InternalLoggerFactory -import org.jboss.netty.logging.Slf4JLoggerFactory +import _root_.io.netty.util.internal.logging.{Slf4JLoggerFactory, InternalLoggerFactory} import org.scalatest.BeforeAndAfterAll import org.scalatest.BeforeAndAfterEach import org.scalatest.Suite diff --git a/examples/pom.xml b/examples/pom.xml index e1fc149d87f17..874bcd7916f35 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -106,6 +106,10 @@ org.jboss.netty netty + + io.netty + netty + commons-logging commons-logging diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 03d3b2394f510..6aec215687fe0 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -50,7 +50,7 @@ 1.4.0 - org.jboss.netty + io.netty netty diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 9aa1c1a9f5b80..7b2dc5ba1d7f9 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -53,12 +53,6 @@ ${akka.group} akka-zeromq_${scala.binary.version} ${akka.version} - - - org.jboss.netty - netty - - org.scalatest diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index a443459594710..5766d3a0d44ec 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -48,12 +48,6 @@ org.twitter4j twitter4j-stream 3.0.3 - - - org.jboss.netty - netty - - org.scalatest diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index a40e55876e640..4ed4196bd8662 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -48,12 +48,6 @@ ${akka.group} akka-zeromq_${scala.binary.version} ${akka.version} - - - org.jboss.netty - netty - - org.scalatest diff --git a/pom.xml b/pom.xml index e0bff60a54cde..c4e1c6be52a1b 100644 --- a/pom.xml +++ b/pom.xml @@ -324,45 +324,21 @@ ${akka.group} akka-actor_${scala.binary.version} ${akka.version} - - - org.jboss.netty - netty - - ${akka.group} akka-remote_${scala.binary.version} ${akka.version} - - - org.jboss.netty - netty - - ${akka.group} akka-slf4j_${scala.binary.version} ${akka.version} - - - org.jboss.netty - netty - - ${akka.group} akka-testkit_${scala.binary.version} ${akka.version} 
- - - org.jboss.netty - netty - - colt @@ -513,10 +489,6 @@ avro ${avro.version} - - org.jboss.netty - netty - io.netty netty @@ -551,10 +523,6 @@ avro-mapred ${avro.version} - - org.jboss.netty - netty - io.netty netty diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index af882b3ea7beb..a12c61853e410 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -303,7 +303,8 @@ object SparkBuild extends Build { val parquetVersion = "1.4.3" val slf4jVersion = "1.7.5" - val excludeNetty = ExclusionRule(organization = "org.jboss.netty") + val excludeJBossNetty = ExclusionRule(organization = "org.jboss.netty") + val excludeIONetty = ExclusionRule(organization = "io.netty") val excludeEclipseJetty = ExclusionRule(organization = "org.eclipse.jetty") val excludeAsm = ExclusionRule(organization = "org.ow2.asm") val excludeOldAsm = ExclusionRule(organization = "asm") @@ -337,8 +338,8 @@ object SparkBuild extends Build { "commons-daemon" % "commons-daemon" % "1.0.10", // workaround for bug HADOOP-9407 "com.ning" % "compress-lzf" % "1.0.0", "org.xerial.snappy" % "snappy-java" % "1.0.5", - "org.spark-project.akka" %% "akka-remote" % akkaVersion excludeAll(excludeNetty), - "org.spark-project.akka" %% "akka-slf4j" % akkaVersion excludeAll(excludeNetty), + "org.spark-project.akka" %% "akka-remote" % akkaVersion, + "org.spark-project.akka" %% "akka-slf4j" % akkaVersion, "org.spark-project.akka" %% "akka-testkit" % akkaVersion % "test", "org.json4s" %% "json4s-jackson" % "3.2.6" excludeAll(excludeScalap), "colt" % "colt" % "1.2.0", @@ -346,8 +347,8 @@ object SparkBuild extends Build { "commons-net" % "commons-net" % "2.2", "net.java.dev.jets3t" % "jets3t" % jets3tVersion excludeAll(excludeCommonsLogging), "org.apache.derby" % "derby" % "10.4.2.0" % "test", - "org.apache.hadoop" % hadoopClient % hadoopVersion excludeAll(excludeNetty, excludeAsm, excludeCommonsLogging, excludeSLF4J, excludeOldAsm), - "org.apache.curator" % "curator-recipes" % "2.4.0" excludeAll(excludeNetty), + "org.apache.hadoop" % hadoopClient % hadoopVersion excludeAll(excludeJBossNetty, excludeAsm, excludeCommonsLogging, excludeSLF4J, excludeOldAsm), + "org.apache.curator" % "curator-recipes" % "2.4.0" excludeAll(excludeJBossNetty), "com.codahale.metrics" % "metrics-core" % codahaleMetricsVersion, "com.codahale.metrics" % "metrics-jvm" % codahaleMetricsVersion, "com.codahale.metrics" % "metrics-json" % codahaleMetricsVersion, @@ -421,7 +422,7 @@ object SparkBuild extends Build { v => "spark-examples-" + v + "-hadoop" + hadoopVersion + ".jar" }, libraryDependencies ++= Seq( "com.twitter" %% "algebird-core" % "0.1.11", - "org.apache.hbase" % "hbase" % HBASE_VERSION excludeAll(excludeNetty, excludeAsm, excludeOldAsm, excludeCommonsLogging, excludeJruby), + "org.apache.hbase" % "hbase" % HBASE_VERSION excludeAll(excludeIONetty, excludeJBossNetty, excludeAsm, excludeOldAsm, excludeCommonsLogging, excludeJruby), "org.apache.cassandra" % "cassandra-all" % "1.2.6" exclude("com.google.guava", "guava") exclude("com.googlecode.concurrentlinkedhashmap", "concurrentlinkedhashmap-lru") @@ -429,7 +430,7 @@ object SparkBuild extends Build { exclude("io.netty", "netty") exclude("jline","jline") exclude("org.apache.cassandra.deps", "avro") - excludeAll(excludeSLF4J), + excludeAll(excludeSLF4J, excludeIONetty), "com.github.scopt" %% "scopt" % "3.2.0" ) ) ++ assemblySettings ++ extraAssemblySettings @@ -561,11 +562,11 @@ object SparkBuild extends Build { def yarnEnabledSettings = Seq( libraryDependencies ++= Seq( // Exclude rule 
required for all ? - "org.apache.hadoop" % hadoopClient % hadoopVersion excludeAll(excludeNetty, excludeAsm, excludeOldAsm), - "org.apache.hadoop" % "hadoop-yarn-api" % hadoopVersion excludeAll(excludeNetty, excludeAsm, excludeOldAsm), - "org.apache.hadoop" % "hadoop-yarn-common" % hadoopVersion excludeAll(excludeNetty, excludeAsm, excludeOldAsm), - "org.apache.hadoop" % "hadoop-yarn-client" % hadoopVersion excludeAll(excludeNetty, excludeAsm, excludeOldAsm), - "org.apache.hadoop" % "hadoop-yarn-server-web-proxy" % hadoopVersion excludeAll(excludeNetty, excludeAsm, excludeOldAsm) + "org.apache.hadoop" % hadoopClient % hadoopVersion excludeAll(excludeJBossNetty, excludeAsm, excludeOldAsm), + "org.apache.hadoop" % "hadoop-yarn-api" % hadoopVersion excludeAll(excludeJBossNetty, excludeAsm, excludeOldAsm), + "org.apache.hadoop" % "hadoop-yarn-common" % hadoopVersion excludeAll(excludeJBossNetty, excludeAsm, excludeOldAsm), + "org.apache.hadoop" % "hadoop-yarn-client" % hadoopVersion excludeAll(excludeJBossNetty, excludeAsm, excludeOldAsm), + "org.apache.hadoop" % "hadoop-yarn-server-web-proxy" % hadoopVersion excludeAll(excludeJBossNetty, excludeAsm, excludeOldAsm) ) ) @@ -593,7 +594,7 @@ object SparkBuild extends Build { name := "spark-streaming-twitter", previousArtifact := sparkPreviousArtifact("spark-streaming-twitter"), libraryDependencies ++= Seq( - "org.twitter4j" % "twitter4j-stream" % "3.0.3" excludeAll(excludeNetty) + "org.twitter4j" % "twitter4j-stream" % "3.0.3" ) ) @@ -601,12 +602,12 @@ object SparkBuild extends Build { name := "spark-streaming-kafka", previousArtifact := sparkPreviousArtifact("spark-streaming-kafka"), libraryDependencies ++= Seq( - "com.github.sgroschupf" % "zkclient" % "0.1" excludeAll(excludeNetty), + "com.github.sgroschupf" % "zkclient" % "0.1", "org.apache.kafka" %% "kafka" % "0.8.0" exclude("com.sun.jdmk", "jmxtools") exclude("com.sun.jmx", "jmxri") exclude("net.sf.jopt-simple", "jopt-simple") - excludeAll(excludeNetty, excludeSLF4J) + excludeAll(excludeSLF4J) ) ) @@ -614,7 +615,7 @@ object SparkBuild extends Build { name := "spark-streaming-flume", previousArtifact := sparkPreviousArtifact("spark-streaming-flume"), libraryDependencies ++= Seq( - "org.apache.flume" % "flume-ng-sdk" % "1.4.0" % "compile" excludeAll(excludeNetty, excludeThrift) + "org.apache.flume" % "flume-ng-sdk" % "1.4.0" % "compile" excludeAll(excludeIONetty, excludeThrift) ) ) @@ -622,7 +623,7 @@ object SparkBuild extends Build { name := "spark-streaming-zeromq", previousArtifact := sparkPreviousArtifact("spark-streaming-zeromq"), libraryDependencies ++= Seq( - "org.spark-project.akka" %% "akka-zeromq" % akkaVersion excludeAll(excludeNetty) + "org.spark-project.akka" %% "akka-zeromq" % akkaVersion ) ) From 83e0424d87022e7a967088365931a08aa06ffd9f Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Sat, 10 May 2014 20:58:02 -0700 Subject: [PATCH 27/33] [SPARK-1774] Respect SparkSubmit --jars on YARN (client) SparkSubmit ignores `--jars` for YARN client. This is a bug. This PR also automatically adds the application jar to `spark.jar`. Previously, when running as yarn-client, you must specify the jar additionally through `--files` (because `--jars` didn't work). Now you don't have to explicitly specify it through either. Tested on a YARN cluster. 
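To make the intended yarn-client behavior concrete, here is a condensed sketch in the style of the SparkSubmitSuite additions further down in this patch; it assumes the suite's existing SparkSubmitArguments and createLaunchEnv helpers, and the jar names are arbitrary:

// --jars should now survive into both the client classpath and spark.jars in
// yarn-client mode, with the primary resource appended automatically.
val clArgs = Seq(
  "--deploy-mode", "client",
  "--master", "yarn",
  "--class", "org.SomeClass",
  "--jars", "one.jar,two.jar,three.jar",
  "thejar.jar")
val appArgs = new SparkSubmitArguments(clArgs)
val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs)

// The extra jars are made immediately visible on the classpath in client mode...
assert(classpath.contains("one.jar") && classpath.contains("two.jar") && classpath.contains("three.jar"))
// ...and the application jar is appended to spark.jars, so it no longer has to be
// passed through --files as well.
assert(sysProps("spark.jars") == "one.jar,two.jar,three.jar,thejar.jar")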
Author: Andrew Or Closes #710 from andrewor14/yarn-jars and squashes the following commits: 35d1928 [Andrew Or] Merge branch 'master' of github.com:apache/spark into yarn-jars c27bf6c [Andrew Or] For yarn-cluster and python, do not add primaryResource to spark.jar c92c5bf [Andrew Or] Minor cleanups 269f9f3 [Andrew Or] Fix format 013d840 [Andrew Or] Fix tests 1407474 [Andrew Or] Merge branch 'master' of github.com:apache/spark into yarn-jars 3bb75e8 [Andrew Or] Allow SparkSubmit --jars to take effect in yarn-client mode --- .../scala/org/apache/spark/SparkContext.scala | 2 +- .../org/apache/spark/deploy/SparkSubmit.scala | 39 ++++--- .../spark/deploy/SparkSubmitSuite.scala | 110 ++++++++++++------ .../spark/deploy/yarn/ClientArguments.scala | 4 +- 4 files changed, 102 insertions(+), 53 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index c639b3e15ded5..71bab295442fc 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -917,7 +917,7 @@ class SparkContext(config: SparkConf) extends Logging { if (SparkHadoopUtil.get.isYarnMode() && (master == "yarn-standalone" || master == "yarn-cluster")) { // In order for this to work in yarn-cluster mode the user must specify the - // --addjars option to the client to upload the file into the distributed cache + // --addJars option to the client to upload the file into the distributed cache // of the AM to make it show up in the current working directory. val fileName = new Path(uri.getPath).getName() try { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 16de6f7cdb100..c6d3cbd2e728b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -67,8 +67,7 @@ object SparkSubmit { private[spark] def printWarning(str: String) = printStream.println("Warning: " + str) /** - * @return - * a tuple containing the arguments for the child, a list of classpath + * @return a tuple containing the arguments for the child, a list of classpath * entries for the child, a list of system propertes, a list of env vars * and the main class for the child */ @@ -115,13 +114,16 @@ object SparkSubmit { val sysProps = new HashMap[String, String]() var childMainClass = "" + val isPython = args.isPython + val isYarnCluster = clusterManager == YARN && deployOnCluster + if (clusterManager == MESOS && deployOnCluster) { printErrorAndExit("Cannot currently run driver on the cluster in Mesos") } // If we're running a Python app, set the Java class to run to be our PythonRunner, add // Python files to deployment list, and pass the main file and Python path to PythonRunner - if (args.isPython) { + if (isPython) { if (deployOnCluster) { printErrorAndExit("Cannot currently run Python driver programs on cluster") } @@ -161,6 +163,7 @@ object SparkSubmit { val options = List[OptionAssigner]( OptionAssigner(args.master, ALL_CLUSTER_MGRS, false, sysProp = "spark.master"), OptionAssigner(args.name, ALL_CLUSTER_MGRS, false, sysProp = "spark.app.name"), + OptionAssigner(args.name, YARN, true, clOption = "--name", sysProp = "spark.app.name"), OptionAssigner(args.driverExtraClassPath, STANDALONE | YARN, true, sysProp = "spark.driver.extraClassPath"), OptionAssigner(args.driverExtraJavaOptions, STANDALONE | YARN, true, @@ -168,7 +171,8 @@ object 
SparkSubmit { OptionAssigner(args.driverExtraLibraryPath, STANDALONE | YARN, true, sysProp = "spark.driver.extraLibraryPath"), OptionAssigner(args.driverMemory, YARN, true, clOption = "--driver-memory"), - OptionAssigner(args.name, YARN, true, clOption = "--name", sysProp = "spark.app.name"), + OptionAssigner(args.driverMemory, STANDALONE, true, clOption = "--memory"), + OptionAssigner(args.driverCores, STANDALONE, true, clOption = "--cores"), OptionAssigner(args.queue, YARN, true, clOption = "--queue"), OptionAssigner(args.queue, YARN, false, sysProp = "spark.yarn.queue"), OptionAssigner(args.numExecutors, YARN, true, clOption = "--num-executors"), @@ -176,20 +180,18 @@ object SparkSubmit { OptionAssigner(args.executorMemory, YARN, true, clOption = "--executor-memory"), OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, false, sysProp = "spark.executor.memory"), - OptionAssigner(args.driverMemory, STANDALONE, true, clOption = "--memory"), - OptionAssigner(args.driverCores, STANDALONE, true, clOption = "--cores"), OptionAssigner(args.executorCores, YARN, true, clOption = "--executor-cores"), OptionAssigner(args.executorCores, YARN, false, sysProp = "spark.executor.cores"), OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, false, sysProp = "spark.cores.max"), OptionAssigner(args.files, YARN, false, sysProp = "spark.yarn.dist.files"), OptionAssigner(args.files, YARN, true, clOption = "--files"), + OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, false, sysProp = "spark.files"), + OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, true, sysProp = "spark.files"), OptionAssigner(args.archives, YARN, false, sysProp = "spark.yarn.dist.archives"), OptionAssigner(args.archives, YARN, true, clOption = "--archives"), OptionAssigner(args.jars, YARN, true, clOption = "--addJars"), - OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, false, sysProp = "spark.files"), - OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, true, sysProp = "spark.files"), - OptionAssigner(args.jars, LOCAL | STANDALONE | MESOS, false, sysProp = "spark.jars") + OptionAssigner(args.jars, ALL_CLUSTER_MGRS, false, sysProp = "spark.jars") ) // For client mode make any added jars immediately visible on the classpath @@ -212,9 +214,10 @@ object SparkSubmit { } } - // For standalone mode, add the application jar automatically so the user doesn't have to - // call sc.addJar. 
TODO: Standalone mode in the cluster - if (clusterManager == STANDALONE) { + // Add the application jar automatically so the user doesn't have to call sc.addJar + // For YARN cluster mode, the jar is already distributed on each node as "app.jar" + // For python files, the primary resource is already distributed as a regular file + if (!isYarnCluster && !isPython) { var jars = sysProps.get("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq()) if (args.primaryResource != RESERVED_JAR_NAME) { jars = jars ++ Seq(args.primaryResource) @@ -222,11 +225,11 @@ object SparkSubmit { sysProps.put("spark.jars", jars.mkString(",")) } + // Standalone cluster specific configurations if (deployOnCluster && clusterManager == STANDALONE) { if (args.supervise) { childArgs += "--supervise" } - childMainClass = "org.apache.spark.deploy.Client" childArgs += "launch" childArgs += (args.master, args.primaryResource, args.mainClass) @@ -243,6 +246,7 @@ object SparkSubmit { } } + // Read from default spark properties, if any for ((k, v) <- args.getDefaultSparkProperties) { if (!sysProps.contains(k)) sysProps(k) = v } @@ -250,9 +254,12 @@ object SparkSubmit { (childArgs, childClasspath, sysProps, childMainClass) } - private def launch(childArgs: ArrayBuffer[String], childClasspath: ArrayBuffer[String], - sysProps: Map[String, String], childMainClass: String, verbose: Boolean = false) - { + private def launch( + childArgs: ArrayBuffer[String], + childClasspath: ArrayBuffer[String], + sysProps: Map[String, String], + childMainClass: String, + verbose: Boolean = false) { if (verbose) { printStream.println(s"Main class:\n$childMainClass") printStream.println(s"Arguments:\n${childArgs.mkString("\n")}") diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index c9edb03cdeb0f..6c0deede53784 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -87,25 +87,41 @@ class SparkSubmitSuite extends FunSuite with ShouldMatchers { } test("handles arguments with --key=val") { - val clArgs = Seq("--jars=one.jar,two.jar,three.jar", "--name=myApp") + val clArgs = Seq( + "--jars=one.jar,two.jar,three.jar", + "--name=myApp") val appArgs = new SparkSubmitArguments(clArgs) appArgs.jars should be ("one.jar,two.jar,three.jar") appArgs.name should be ("myApp") } test("handles arguments to user program") { - val clArgs = Seq("--name", "myApp", "--class", "Foo", "userjar.jar", "some", "--weird", "args") + val clArgs = Seq( + "--name", "myApp", + "--class", "Foo", + "userjar.jar", + "some", + "--weird", "args") val appArgs = new SparkSubmitArguments(clArgs) appArgs.childArgs should be (Seq("some", "--weird", "args")) } test("handles YARN cluster mode") { - val clArgs = Seq("--deploy-mode", "cluster", - "--master", "yarn", "--executor-memory", "5g", "--executor-cores", "5", - "--class", "org.SomeClass", "--jars", "one.jar,two.jar,three.jar", - "--driver-memory", "4g", "--queue", "thequeue", "--files", "file1.txt,file2.txt", - "--archives", "archive1.txt,archive2.txt", "--num-executors", "6", "--name", "beauty", - "thejar.jar", "arg1", "arg2") + val clArgs = Seq( + "--deploy-mode", "cluster", + "--master", "yarn", + "--executor-memory", "5g", + "--executor-cores", "5", + "--class", "org.SomeClass", + "--jars", "one.jar,two.jar,three.jar", + "--driver-memory", "4g", + "--queue", "thequeue", + "--files", "file1.txt,file2.txt", + "--archives", 
"archive1.txt,archive2.txt", + "--num-executors", "6", + "--name", "beauty", + "thejar.jar", + "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) val childArgsStr = childArgs.mkString(" ") @@ -127,12 +143,21 @@ class SparkSubmitSuite extends FunSuite with ShouldMatchers { } test("handles YARN client mode") { - val clArgs = Seq("--deploy-mode", "client", - "--master", "yarn", "--executor-memory", "5g", "--executor-cores", "5", - "--class", "org.SomeClass", "--jars", "one.jar,two.jar,three.jar", - "--driver-memory", "4g", "--queue", "thequeue", "--files", "file1.txt,file2.txt", - "--archives", "archive1.txt,archive2.txt", "--num-executors", "6", "--name", "trill", - "thejar.jar", "arg1", "arg2") + val clArgs = Seq( + "--deploy-mode", "client", + "--master", "yarn", + "--executor-memory", "5g", + "--executor-cores", "5", + "--class", "org.SomeClass", + "--jars", "one.jar,two.jar,three.jar", + "--driver-memory", "4g", + "--queue", "thequeue", + "--files", "file1.txt,file2.txt", + "--archives", "archive1.txt,archive2.txt", + "--num-executors", "6", + "--name", "trill", + "thejar.jar", + "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") @@ -142,6 +167,7 @@ class SparkSubmitSuite extends FunSuite with ShouldMatchers { classpath should contain ("two.jar") classpath should contain ("three.jar") sysProps("spark.app.name") should be ("trill") + sysProps("spark.jars") should be ("one.jar,two.jar,three.jar,thejar.jar") sysProps("spark.executor.memory") should be ("5g") sysProps("spark.executor.cores") should be ("5") sysProps("spark.yarn.queue") should be ("thequeue") @@ -152,9 +178,15 @@ class SparkSubmitSuite extends FunSuite with ShouldMatchers { } test("handles standalone cluster mode") { - val clArgs = Seq("--deploy-mode", "cluster", - "--master", "spark://h:p", "--class", "org.SomeClass", - "--supervise", "--driver-memory", "4g", "--driver-cores", "5", "thejar.jar", "arg1", "arg2") + val clArgs = Seq( + "--deploy-mode", "cluster", + "--master", "spark://h:p", + "--class", "org.SomeClass", + "--supervise", + "--driver-memory", "4g", + "--driver-cores", "5", + "thejar.jar", + "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) val childArgsStr = childArgs.mkString(" ") @@ -166,9 +198,15 @@ class SparkSubmitSuite extends FunSuite with ShouldMatchers { } test("handles standalone client mode") { - val clArgs = Seq("--deploy-mode", "client", - "--master", "spark://h:p", "--executor-memory", "5g", "--total-executor-cores", "5", - "--class", "org.SomeClass", "--driver-memory", "4g", "thejar.jar", "arg1", "arg2") + val clArgs = Seq( + "--deploy-mode", "client", + "--master", "spark://h:p", + "--executor-memory", "5g", + "--total-executor-cores", "5", + "--class", "org.SomeClass", + "--driver-memory", "4g", + "thejar.jar", + "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") @@ -179,9 +217,15 @@ class SparkSubmitSuite extends FunSuite with ShouldMatchers { } test("handles mesos client mode") { - val clArgs = Seq("--deploy-mode", "client", - "--master", "mesos://h:p", "--executor-memory", "5g", "--total-executor-cores", "5", - "--class", "org.SomeClass", 
"--driver-memory", "4g", "thejar.jar", "arg1", "arg2") + val clArgs = Seq( + "--deploy-mode", "client", + "--master", "mesos://h:p", + "--executor-memory", "5g", + "--total-executor-cores", "5", + "--class", "org.SomeClass", + "--driver-memory", "4g", + "thejar.jar", + "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") @@ -192,15 +236,17 @@ class SparkSubmitSuite extends FunSuite with ShouldMatchers { } test("launch simple application with spark-submit") { - runSparkSubmit( - Seq( - "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"), - "--name", "testApp", - "--master", "local", - "unUsed.jar")) + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val args = Seq( + "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"), + "--name", "testApp", + "--master", "local", + unusedJar.toString) + runSparkSubmit(args) } test("spark submit includes jars passed in through --jar") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB")) val jarsString = Seq(jar1, jar2).map(j => j.toString).mkString(",") @@ -209,7 +255,7 @@ class SparkSubmitSuite extends FunSuite with ShouldMatchers { "--name", "testApp", "--master", "local-cluster[2,1,512]", "--jars", jarsString, - "unused.jar") + unusedJar.toString) runSparkSubmit(args) } @@ -227,7 +273,7 @@ object JarCreationTest { def main(args: Array[String]) { val conf = new SparkConf() val sc = new SparkContext(conf) - val result = sc.makeRDD(1 to 100, 10).mapPartitions{ x => + val result = sc.makeRDD(1 to 100, 10).mapPartitions { x => var foundClasses = false try { Class.forName("SparkSubmitClassA", true, Thread.currentThread().getContextClassLoader) @@ -248,7 +294,6 @@ object SimpleApplicationTest { def main(args: Array[String]) { val conf = new SparkConf() val sc = new SparkContext(conf) - val configs = Seq("spark.master", "spark.app.name") for (config <- configs) { val masterValue = conf.get(config) @@ -266,6 +311,5 @@ object SimpleApplicationTest { s"Master had $config=$masterValue but executor had $config=$executorValue") } } - } } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 3e4c739e34fe9..b2c413b6d267c 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -20,7 +20,7 @@ package org.apache.spark.deploy.yarn import scala.collection.mutable.{ArrayBuffer, HashMap} import org.apache.spark.SparkConf -import org.apache.spark.scheduler.{InputFormatInfo, SplitInfo} +import org.apache.spark.scheduler.InputFormatInfo import org.apache.spark.util.IntParam import org.apache.spark.util.MemoryParam @@ -40,9 +40,7 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { var amMemory: Int = 512 // MB var amClass: String = "org.apache.spark.deploy.yarn.ApplicationMaster" var appName: String = "Spark" - // TODO var inputFormatInfo: List[InputFormatInfo] = null - // TODO(harvey) var priority = 0 parseArgs(args.toList) From 70bcdef48a051028598d380d41dfce1c9bfb2b9b Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Sat, 10 May 2014 21:08:04 -0700 Subject: [PATCH 28/33] Enabled incremental build 
that comes with sbt 0.13.2 More info at: https://github.com/sbt/sbt/issues/1010 Author: Prashant Sharma Closes #525 from ScrapCodes/sbt-inc-opt and squashes the following commits: ba8fa42 [Prashant Sharma] Enabled incremental build that comes with sbt 0.13.2 --- project/SparkBuild.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index a12c61853e410..12791e490ae2b 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -176,7 +176,7 @@ object SparkBuild extends Build { retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", transitiveClassifiers in Scope.GlobalScope := Seq("sources"), testListeners <<= target.map(t => Seq(new eu.henkelmann.sbt.JUnitXmlTestsListener(t.getAbsolutePath))), - + incOptions := incOptions.value.withNameHashing(true), // Fork new JVMs for tests and set Java options for those fork := true, javaOptions in Test += "-Dspark.home=" + sparkHome, From 6bee01dd04ef73c6b829110ebcdd622d521ea8ff Mon Sep 17 00:00:00 2001 From: witgo Date: Sun, 11 May 2014 14:34:27 -0700 Subject: [PATCH 29/33] remove outdated runtime Information scala home Author: witgo Closes #728 from witgo/scala_home and squashes the following commits: cdfd8be [witgo] Merge branch 'master' of https://github.com/apache/spark into scala_home fac094a [witgo] remove outdated runtime Information scala home --- core/src/main/scala/org/apache/spark/SparkEnv.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 806e77d98fc5f..19d507c0cf860 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -281,8 +281,7 @@ object SparkEnv extends Logging { val jvmInformation = Seq( ("Java Version", "%s (%s)".format(Properties.javaVersion, Properties.javaVendor)), ("Java Home", Properties.javaHome), - ("Scala Version", Properties.versionString), - ("Scala Home", Properties.scalaHome) + ("Scala Version", Properties.versionString) ).sorted // Spark properties From 7d9cc9214bd06495f6838e355331dd2b5f1f7407 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sun, 11 May 2014 17:11:55 -0700 Subject: [PATCH 30/33] SPARK-1770: Load balance elements when repartitioning. This patch adds better balancing when performing a repartition of an RDD. Previously the elements in the RDD were hash partitioned, meaning that if the RDD was skewed, certain partitions would end up being very large. This commit adds load balancing of elements across the repartitioned RDD splits. The load balancing is not perfect: a given output partition can have up to N more elements than the average if there are N input partitions. However, some randomization is used to minimize the probability that this happens. Author: Patrick Wendell Closes #727 from pwendell/load-balance and squashes the following commits: f9da752 [Patrick Wendell] Response to Matei's feedback acfa46a [Patrick Wendell] SPARK-1770: Load balance elements when repartitioning.
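To make the balancing guarantee above concrete: each input partition starts writing at a (seeded) random output slot and then hands out its elements round-robin, so no output partition can exceed the average by more than one element per input partition. The standalone Scala sketch below mimics that scheme outside of Spark; the object and method names (RoundRobinBalanceSketch, balance) and the split sizes are illustrative only and are not part of this patch.

    import scala.util.Random

    object RoundRobinBalanceSketch {
      // Each input split starts at a pseudo-random output slot (seeded by its index) and
      // assigns its elements round-robin across numOutput slots, mirroring the
      // distributePartition helper introduced in the diff that follows.
      def balance(inputSplits: Seq[Seq[Int]], numOutput: Int): Map[Int, Int] = {
        val assignments = inputSplits.zipWithIndex.flatMap { case (items, splitIndex) =>
          var position = new Random(splitIndex).nextInt(numOutput)
          items.map { item =>
            position += 1
            (position % numOutput, item)
          }
        }
        // Count how many elements land in each output slot.
        assignments.groupBy(_._1).map { case (slot, elems) => slot -> elems.size }
      }

      def main(args: Array[String]): Unit = {
        val splits = Seq.fill(10)(Seq.fill(101)(1)) // 10 input splits of 101 elements each
        val counts = balance(splits, numOutput = 2)
        // Each slot ends up close to 505 elements; the skew is bounded by the split count (10).
        counts.toSeq.sortBy(_._1).foreach { case (slot, n) => println(s"slot $slot holds $n elements") }
      }
    }

In the RDD change below, the (slot, element) pairs are instead routed through a HashPartitioner, which for non-negative integer keys reduces to the same modulo assignment.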
--- .../main/scala/org/apache/spark/rdd/RDD.scala | 15 +++++++-- .../scala/org/apache/spark/rdd/RDDSuite.scala | 33 +++++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index a1ca612cc9a09..aa03e9276fb34 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -328,11 +328,22 @@ abstract class RDD[T: ClassTag]( def coalesce(numPartitions: Int, shuffle: Boolean = false)(implicit ord: Ordering[T] = null) : RDD[T] = { if (shuffle) { + /** Distributes elements evenly across output partitions, starting from a random partition. */ + def distributePartition(index: Int, items: Iterator[T]): Iterator[(Int, T)] = { + var position = (new Random(index)).nextInt(numPartitions) + items.map { t => + // Note that the hash code of the key will just be the key itself. The HashPartitioner + // will mod it with the number of total partitions. + position = position + 1 + (position, t) + } + } + // include a shuffle step so that our upstream tasks are still distributed new CoalescedRDD( - new ShuffledRDD[T, Null, (T, Null)](map(x => (x, null)), + new ShuffledRDD[Int, T, (Int, T)](mapPartitionsWithIndex(distributePartition), new HashPartitioner(numPartitions)), - numPartitions).keys + numPartitions).values } else { new CoalescedRDD(this, numPartitions) } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 8da9a0da700e0..e686068f7a99a 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -202,6 +202,39 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(repartitioned2.collect().toSet === (1 to 1000).toSet) } + test("repartitioned RDDs perform load balancing") { + // Coalesce partitions + val input = Array.fill(1000)(1) + val initialPartitions = 10 + val data = sc.parallelize(input, initialPartitions) + + val repartitioned1 = data.repartition(2) + assert(repartitioned1.partitions.size == 2) + val partitions1 = repartitioned1.glom().collect() + // some noise in balancing is allowed due to randomization + assert(math.abs(partitions1(0).length - 500) < initialPartitions) + assert(math.abs(partitions1(1).length - 500) < initialPartitions) + assert(repartitioned1.collect() === input) + + def testSplitPartitions(input: Seq[Int], initialPartitions: Int, finalPartitions: Int) { + val data = sc.parallelize(input, initialPartitions) + val repartitioned = data.repartition(finalPartitions) + assert(repartitioned.partitions.size === finalPartitions) + val partitions = repartitioned.glom().collect() + // assert all elements are present + assert(repartitioned.collect().sortWith(_ > _).toSeq === input.toSeq.sortWith(_ > _).toSeq) + // assert no bucket is overloaded + for (partition <- partitions) { + val avg = input.size / finalPartitions + val maxPossible = avg + initialPartitions + assert(partition.length <= maxPossible) + } + } + + testSplitPartitions(Array.fill(100)(1), 10, 20) + testSplitPartitions(Array.fill(10000)(1) ++ Array.fill(10000)(2), 20, 100) + } + test("coalesced RDDs") { val data = sc.parallelize(1 to 10, 10) From 05c9aa9eb1b7f13cd40bbca23e6bc7e1d20e91cd Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sun, 11 May 2014 18:17:34 -0700 Subject: [PATCH 31/33] SPARK-1652: Set driver memory correctly in spark-submit. 
The previous check didn't account for the fact that the default deploy mode is "client" unless otherwise specified. Also, this sets the more narrowly defined SPARK_DRIVER_MEMORY instead of setting SPARK_MEM. Author: Patrick Wendell Closes #730 from pwendell/spark-submit and squashes the following commits: 430b98f [Patrick Wendell] Feedback from Aaron e788edf [Patrick Wendell] Changes based on Aaron's feedback f508146 [Patrick Wendell] SPARK-1652: Set driver memory correctly in spark-submit. --- bin/spark-submit | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/bin/spark-submit b/bin/spark-submit index 49bc26252cadf..63903b17a2902 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -35,8 +35,10 @@ while (($#)); do shift done -if [ ! -z $DRIVER_MEMORY ] && [ ! -z $DEPLOY_MODE ] && [ $DEPLOY_MODE = "client" ]; then - export SPARK_MEM=$DRIVER_MEMORY +DEPLOY_MODE=${DEPLOY_MODE:-"client"} + +if [ -n "$DRIVER_MEMORY" ] && [ $DEPLOY_MODE == "client" ]; then + export SPARK_DRIVER_MEMORY=$DRIVER_MEMORY fi $SPARK_HOME/bin/spark-class org.apache.spark.deploy.SparkSubmit "${ORIG_ARGS[@]}" From f938a155b2a9c126b292d5403aca31de83d5105a Mon Sep 17 00:00:00 2001 From: "Joseph E. Gonzalez" Date: Sun, 11 May 2014 18:33:46 -0700 Subject: [PATCH 32/33] Fix error in 2d Graph Partitioner There was a minor bug in which negative partition ids could be generated when constructing a 2D partitioning of a graph. This could lead to an inefficient 2D partition for large vertex id values. Author: Joseph E. Gonzalez Closes #709 from jegonzal/fix_2d_partitioning and squashes the following commits: 937c562 [Joseph E. Gonzalez] fixing bug in 2d partitioning algorithm where negative partition ids could be generated. --- .../scala/org/apache/spark/graphx/PartitionStrategy.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala b/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala index 0470d74cf9efe..1526ccef06fd4 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala @@ -78,8 +78,8 @@ object PartitionStrategy { override def getPartition(src: VertexId, dst: VertexId, numParts: PartitionID): PartitionID = { val ceilSqrtNumParts: PartitionID = math.ceil(math.sqrt(numParts)).toInt val mixingPrime: VertexId = 1125899906842597L - val col: PartitionID = ((math.abs(src) * mixingPrime) % ceilSqrtNumParts).toInt - val row: PartitionID = ((math.abs(dst) * mixingPrime) % ceilSqrtNumParts).toInt + val col: PartitionID = (math.abs(src * mixingPrime) % ceilSqrtNumParts).toInt + val row: PartitionID = (math.abs(dst * mixingPrime) % ceilSqrtNumParts).toInt (col * ceilSqrtNumParts + row) % numParts } } From a6b02fb7486356493474c7f42bb714c9cce215ca Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Sun, 11 May 2014 19:20:42 -0700 Subject: [PATCH 33/33] SPARK-1786: Edge Partition Serialization This appears to address the issue with edge partition serialization. The solution appears to be just registering the `PrimitiveKeyOpenHashMap`. However, I noticed that we appear to have forked that code in GraphX but retained the same name (which is confusing). I also renamed our local copy to `GraphXPrimitiveKeyOpenHashMap`. We should consider dropping that and using the one in Spark if possible. Author: Ankur Dave Author: Joseph E.
Gonzalez Closes #724 from jegonzal/edge_partition_serialization and squashes the following commits: b0a525a [Ankur Dave] Disable reference tracking to fix serialization test bb7f548 [Ankur Dave] Add failing test for EdgePartition Kryo serialization 67dac22 [Joseph E. Gonzalez] Making EdgePartition serializable. --- .../spark/graphx/GraphKryoRegistrator.scala | 9 ++++++--- .../spark/graphx/impl/EdgePartition.scala | 14 +++++++------- .../graphx/impl/EdgePartitionBuilder.scala | 4 ++-- .../graphx/impl/EdgeTripletIterator.scala | 2 +- .../graphx/impl/RoutingTablePartition.scala | 4 ++-- .../graphx/impl/ShippableVertexPartition.scala | 2 +- .../spark/graphx/impl/VertexPartition.scala | 2 +- .../graphx/impl/VertexPartitionBase.scala | 6 +++--- .../graphx/impl/VertexPartitionBaseOps.scala | 4 ++-- ...ala => GraphXPrimitiveKeyOpenHashMap.scala} | 2 +- .../spark/graphx/impl/EdgePartitionSuite.scala | 18 ++++++++++++++++++ 11 files changed, 44 insertions(+), 23 deletions(-) rename graphx/src/main/scala/org/apache/spark/graphx/util/collection/{PrimitiveKeyOpenHashMap.scala => GraphXPrimitiveKeyOpenHashMap.scala} (98%) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala index d295d0127ac72..f97f329c0e832 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala @@ -24,6 +24,9 @@ import org.apache.spark.util.BoundedPriorityQueue import org.apache.spark.util.collection.BitSet import org.apache.spark.graphx.impl._ +import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap +import org.apache.spark.util.collection.OpenHashSet + /** * Registers GraphX classes with Kryo for improved performance. @@ -43,8 +46,8 @@ class GraphKryoRegistrator extends KryoRegistrator { kryo.register(classOf[PartitionStrategy]) kryo.register(classOf[BoundedPriorityQueue[Object]]) kryo.register(classOf[EdgeDirection]) - - // This avoids a large number of hash table lookups. - kryo.setReferences(false) + kryo.register(classOf[GraphXPrimitiveKeyOpenHashMap[VertexId, Int]]) + kryo.register(classOf[OpenHashSet[Int]]) + kryo.register(classOf[OpenHashSet[Long]]) } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala index 871e81f8d245c..a5c9cd1f8b4e6 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala @@ -20,7 +20,7 @@ package org.apache.spark.graphx.impl import scala.reflect.{classTag, ClassTag} import org.apache.spark.graphx._ -import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap +import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap /** * A collection of edges stored in columnar format, along with any vertex attributes referenced. 
The @@ -42,12 +42,12 @@ import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap private[graphx] class EdgePartition[ @specialized(Char, Int, Boolean, Byte, Long, Float, Double) ED: ClassTag, VD: ClassTag]( - @transient val srcIds: Array[VertexId], - @transient val dstIds: Array[VertexId], - @transient val data: Array[ED], - @transient val index: PrimitiveKeyOpenHashMap[VertexId, Int], - @transient val vertices: VertexPartition[VD], - @transient val activeSet: Option[VertexSet] = None + val srcIds: Array[VertexId] = null, + val dstIds: Array[VertexId] = null, + val data: Array[ED] = null, + val index: GraphXPrimitiveKeyOpenHashMap[VertexId, Int] = null, + val vertices: VertexPartition[VD] = null, + val activeSet: Option[VertexSet] = None ) extends Serializable { /** Return a new `EdgePartition` with the specified edge data. */ diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala index ecb49bef42e45..4520beb991515 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala @@ -23,7 +23,7 @@ import scala.util.Sorting import org.apache.spark.util.collection.{BitSet, OpenHashSet, PrimitiveVector} import org.apache.spark.graphx._ -import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap +import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap private[graphx] class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: ClassTag]( @@ -41,7 +41,7 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla val srcIds = new Array[VertexId](edgeArray.size) val dstIds = new Array[VertexId](edgeArray.size) val data = new Array[ED](edgeArray.size) - val index = new PrimitiveKeyOpenHashMap[VertexId, Int] + val index = new GraphXPrimitiveKeyOpenHashMap[VertexId, Int] // Copy edges into columnar structures, tracking the beginnings of source vertex id clusters and // adding them to the index if (edgeArray.length > 0) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala index ebb0b9418d65d..56f79a7097fce 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala @@ -20,7 +20,7 @@ package org.apache.spark.graphx.impl import scala.reflect.ClassTag import org.apache.spark.graphx._ -import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap +import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap /** * The Iterator type returned when constructing edge triplets. 
This could be an anonymous class in diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala index 927e32ad0f448..d02e9238adba5 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala @@ -25,7 +25,7 @@ import org.apache.spark.rdd.ShuffledRDD import org.apache.spark.util.collection.{BitSet, PrimitiveVector} import org.apache.spark.graphx._ -import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap +import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap /** * A message from the edge partition `pid` to the vertex partition containing `vid` specifying that @@ -69,7 +69,7 @@ object RoutingTablePartition { : Iterator[RoutingTableMessage] = { // Determine which positions each vertex id appears in using a map where the low 2 bits // represent src and dst - val map = new PrimitiveKeyOpenHashMap[VertexId, Byte] + val map = new GraphXPrimitiveKeyOpenHashMap[VertexId, Byte] edgePartition.srcIds.iterator.foreach { srcId => map.changeValue(srcId, 0x1, (b: Byte) => (b | 0x1).toByte) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala index f4e221d4e05ae..dca54b8a7da86 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala @@ -22,7 +22,7 @@ import scala.reflect.ClassTag import org.apache.spark.util.collection.{BitSet, PrimitiveVector} import org.apache.spark.graphx._ -import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap +import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap /** Stores vertex attributes to ship to an edge partition. */ private[graphx] diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala index f1d174720a1ba..55c7a19d1bdab 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala @@ -22,7 +22,7 @@ import scala.reflect.ClassTag import org.apache.spark.util.collection.BitSet import org.apache.spark.graphx._ -import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap +import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap private[graphx] object VertexPartition { /** Construct a `VertexPartition` from the given vertices. 
*/ diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala index 8d9e0204d27f2..34939b24440aa 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala @@ -23,7 +23,7 @@ import scala.reflect.ClassTag import org.apache.spark.util.collection.BitSet import org.apache.spark.graphx._ -import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap +import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap private[graphx] object VertexPartitionBase { /** @@ -32,7 +32,7 @@ private[graphx] object VertexPartitionBase { */ def initFrom[VD: ClassTag](iter: Iterator[(VertexId, VD)]) : (VertexIdToIndexMap, Array[VD], BitSet) = { - val map = new PrimitiveKeyOpenHashMap[VertexId, VD] + val map = new GraphXPrimitiveKeyOpenHashMap[VertexId, VD] iter.foreach { pair => map(pair._1) = pair._2 } @@ -45,7 +45,7 @@ private[graphx] object VertexPartitionBase { */ def initFrom[VD: ClassTag](iter: Iterator[(VertexId, VD)], mergeFunc: (VD, VD) => VD) : (VertexIdToIndexMap, Array[VD], BitSet) = { - val map = new PrimitiveKeyOpenHashMap[VertexId, VD] + val map = new GraphXPrimitiveKeyOpenHashMap[VertexId, VD] iter.foreach { pair => map.setMerge(pair._1, pair._2, mergeFunc) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala index 21ff615feca6c..a4f769b294010 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala @@ -25,7 +25,7 @@ import org.apache.spark.Logging import org.apache.spark.util.collection.BitSet import org.apache.spark.graphx._ -import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap +import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap /** * An class containing additional operations for subclasses of VertexPartitionBase that provide @@ -224,7 +224,7 @@ private[graphx] abstract class VertexPartitionBaseOps * Construct a new VertexPartition whose index contains only the vertices in the mask. */ def reindex(): Self[VD] = { - val hashMap = new PrimitiveKeyOpenHashMap[VertexId, VD] + val hashMap = new GraphXPrimitiveKeyOpenHashMap[VertexId, VD] val arbitraryMerge = (a: VD, b: VD) => a for ((k, v) <- self.iterator) { hashMap.setMerge(k, v, arbitraryMerge) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/collection/PrimitiveKeyOpenHashMap.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala similarity index 98% rename from graphx/src/main/scala/org/apache/spark/graphx/util/collection/PrimitiveKeyOpenHashMap.scala rename to graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala index 7b02e2ed1a9cb..57b01b6f2e1fb 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/collection/PrimitiveKeyOpenHashMap.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala @@ -29,7 +29,7 @@ import scala.reflect._ * Under the hood, it uses our OpenHashSet implementation. 
*/ private[graphx] -class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, +class GraphXPrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, @specialized(Long, Int, Double) V: ClassTag]( val keySet: OpenHashSet[K], var _values: Array[V]) extends Iterable[(K, V)] diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala index d2e0c01bc35ef..28fd112f2b124 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala @@ -22,6 +22,9 @@ import scala.util.Random import org.scalatest.FunSuite +import org.apache.spark.SparkConf +import org.apache.spark.serializer.KryoSerializer + import org.apache.spark.graphx._ class EdgePartitionSuite extends FunSuite { @@ -120,4 +123,19 @@ class EdgePartitionSuite extends FunSuite { assert(!ep.isActive(-1)) assert(ep.numActives == Some(2)) } + + test("Kryo serialization") { + val aList = List((0, 1, 0), (1, 0, 0), (1, 2, 0), (5, 4, 0), (5, 5, 0)) + val a: EdgePartition[Int, Int] = makeEdgePartition(aList) + val conf = new SparkConf() + .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator") + val s = new KryoSerializer(conf).newInstance() + val aSer: EdgePartition[Int, Int] = s.deserialize(s.serialize(a)) + assert(aSer.srcIds.toList === a.srcIds.toList) + assert(aSer.dstIds.toList === a.dstIds.toList) + assert(aSer.data.toList === a.data.toList) + assert(aSer.index != null) + assert(aSer.vertices.iterator.toSet === a.vertices.iterator.toSet) + } }
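Taken together with the 2D partitioner fix above, applications of this vintage can exercise these changes end to end by enabling Kryo with the GraphX registrator and repartitioning the edges with EdgePartition2D. The sketch below is illustrative only: GraphXKryoApp and the app name are placeholders, the edge-list path is taken from the command line, and the master is expected to be supplied by spark-submit rather than hard-coded.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.graphx.{GraphLoader, PartitionStrategy}

    object GraphXKryoApp {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
          .setAppName("graphx-kryo-app") // placeholder application name
          .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
          .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")
        val sc = new SparkContext(conf)
        // With Kryo enabled, graph operations that ship or cache edge partitions serialize
        // EdgePartition (and its GraphXPrimitiveKeyOpenHashMap index) through the
        // registrations added in this patch; EdgePartition2D uses the corrected getPartition
        // from the 2D partitioner fix.
        val graph = GraphLoader.edgeListFile(sc, args(0))
          .partitionBy(PartitionStrategy.EdgePartition2D)
        println(s"edge count after 2D partitioning: ${graph.edges.count()}")
        sc.stop()
      }
    }

Submitted through bin/spark-submit with --master and the application jar, this matches the submission style adopted earlier in this series.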