-The first thing a Spark program must do is to create a [SparkContext](api/python/pyspark.context.SparkContext-class.html) object, which tells Spark
-how to access a cluster. To create a `SparkContext` you first need to build a [SparkConf](api/python/pyspark.conf.SparkConf-class.html) object
+The first thing a Spark program must do is to create a [SparkContext](api/python/pyspark.html#pyspark.SparkContext) object, which tells Spark
+how to access a cluster. To create a `SparkContext` you first need to build a [SparkConf](api/python/pyspark.html#pyspark.SparkConf) object
that contains information about your application.
{% highlight python %}
@@ -912,7 +912,7 @@ The following table lists some of the common transformations supported by Spark.
RDD API doc
([Scala](api/scala/index.html#org.apache.spark.rdd.RDD),
[Java](api/java/index.html?org/apache/spark/api/java/JavaRDD.html),
- [Python](api/python/pyspark.rdd.RDD-class.html))
+ [Python](api/python/pyspark.html#pyspark.RDD))
and pair RDD functions doc
([Scala](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions),
[Java](api/java/index.html?org/apache/spark/api/java/JavaPairRDD.html))
@@ -1025,7 +1025,7 @@ The following table lists some of the common actions supported by Spark. Refer t
RDD API doc
([Scala](api/scala/index.html#org.apache.spark.rdd.RDD),
[Java](api/java/index.html?org/apache/spark/api/java/JavaRDD.html),
- [Python](api/python/pyspark.rdd.RDD-class.html))
+ [Python](api/python/pyspark.html#pyspark.RDD))
and pair RDD functions doc
([Scala](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions),
[Java](api/java/index.html?org/apache/spark/api/java/JavaPairRDD.html))
@@ -1105,7 +1105,7 @@ replicate it across nodes, or store it off-heap in [Tachyon](http://tachyon-proj
These levels are set by passing a
`StorageLevel` object ([Scala](api/scala/index.html#org.apache.spark.storage.StorageLevel),
[Java](api/java/index.html?org/apache/spark/storage/StorageLevel.html),
-[Python](api/python/pyspark.storagelevel.StorageLevel-class.html))
+[Python](api/python/pyspark.html#pyspark.StorageLevel))
to `persist()`. The `cache()` method is a shorthand for using the default storage level,
which is `StorageLevel.MEMORY_ONLY` (store deserialized objects in memory). The full set of
storage levels is:
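The table of levels itself falls outside this hunk; as a quick illustrative sketch of how one of these levels is handed to `persist()` (Scala shown, the PySpark call is analogous; `sc` and `data.txt` are assumed names, not part of the patch):
{% highlight scala %}
import org.apache.spark.storage.StorageLevel

val lines = sc.textFile("data.txt")
// Keep the RDD in memory, spilling partitions to disk if they do not fit.
lines.persist(StorageLevel.MEMORY_AND_DISK)

// cache() is shorthand for persist(StorageLevel.MEMORY_ONLY).
val words = lines.flatMap(_.split(" "))
words.cache()
{% endhighlight %}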
@@ -1374,7 +1374,7 @@ scala> accum.value
{% endhighlight %}
While this code used the built-in support for accumulators of type Int, programmers can also
-create their own types by subclassing [AccumulatorParam](api/python/pyspark.accumulators.AccumulatorParam-class.html).
+create their own types by subclassing [AccumulatorParam](api/python/pyspark.html#pyspark.AccumulatorParam).
The AccumulatorParam interface has two methods: `zero` for providing a "zero value" for your data
type, and `addInPlace` for adding two values together. For example, supposing we had a `Vector` class
representing mathematical vectors, we could write:
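The example itself is cut off by the hunk boundary. As a hedged sketch of the idea (shown in Scala, whose `AccumulatorParam` has the same `zero`/`addInPlace` contract as the linked Python class; the `Vector` class below is hypothetical and defined only for illustration):
{% highlight scala %}
import org.apache.spark.AccumulatorParam

// Hypothetical math-vector class used only for this sketch.
case class Vector(values: Array[Double]) {
  def add(other: Vector): Vector =
    Vector(values.zip(other.values).map { case (a, b) => a + b })
}
object Vector {
  def zeros(size: Int): Vector = Vector(Array.fill(size)(0.0))
}

object VectorAccumulatorParam extends AccumulatorParam[Vector] {
  // "Zero value" for accumulation: a same-sized vector of zeros.
  def zero(initialValue: Vector): Vector = Vector.zeros(initialValue.values.length)
  // Combine two accumulated values.
  def addInPlace(v1: Vector, v2: Vector): Vector = v1.add(v2)
}

// Then, assuming an existing SparkContext named sc:
// val vecAccum = sc.accumulator(Vector.zeros(3))(VectorAccumulatorParam)
{% endhighlight %}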
diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md
index 6a9d304501dc0..c984639bd34cf 100644
--- a/docs/running-on-mesos.md
+++ b/docs/running-on-mesos.md
@@ -224,11 +224,9 @@ See the [configuration page](configuration.html) for information on Spark config
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 2cbb4c967eb81..6a333fdb562a7 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -56,7 +56,7 @@ SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc);
The entry point into all relational functionality in Spark is the
-[`SQLContext`](api/python/pyspark.sql.SQLContext-class.html) class, or one
+[`SQLContext`](api/python/pyspark.sql.html#pyspark.sql.SQLContext) class, or one
of its descendants. To create a basic `SQLContext`, all you need is a SparkContext.
{% highlight python %}
@@ -509,8 +509,11 @@ val people = sc.textFile("examples/src/main/resources/people.txt")
// The schema is encoded in a string
val schemaString = "name age"
-// Import Spark SQL data types and Row.
-import org.apache.spark.sql._
+// Import Row.
+import org.apache.spark.sql.Row;
+
+// Import Spark SQL data types
+import org.apache.spark.sql.types.{StructType,StructField,StringType};
// Generate the schema based on the string of schema
val schema =
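The hunk is cut off at `val schema =`; for context, a minimal sketch of how the new imports are typically used to build the schema from the string above and apply it (reconstructed here, not part of the patch; `schemaString`, `people`, and `sqlContext` come from the surrounding guide):
{% highlight scala %}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StructType, StructField, StringType}

// Generate the schema: one nullable string field per name in "name age".
val schema =
  StructType(
    schemaString.split(" ").map(fieldName => StructField(fieldName, StringType, true)))

// Convert records of the RDD (people) to Rows, then apply the schema.
val rowRDD = people.map(_.split(",")).map(p => Row(p(0), p(1).trim))
val peopleDataFrame = sqlContext.createDataFrame(rowRDD, schema)
{% endhighlight %}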
diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md
index 57b074778f2b0..3ecbf2308cd44 100644
--- a/docs/submitting-applications.md
+++ b/docs/submitting-applications.md
@@ -133,10 +133,10 @@ The master URL passed to Spark can be in one of the following formats:
Or, for a Mesos cluster using ZooKeeper, use
mesos://zk://...
.
yarn-client | Connect to a YARN cluster in
-client mode. The cluster location will be found based on the HADOOP_CONF_DIR variable.
+client mode. The cluster location will be found based on the HADOOP_CONF_DIR or YARN_CONF_DIR variable.
|
yarn-cluster | Connect to a YARN cluster in
-cluster mode. The cluster location will be found based on HADOOP_CONF_DIR.
+cluster mode. The cluster location will be found based on the HADOOP_CONF_DIR or YARN_CONF_DIR variable.
|
diff --git a/examples/pom.xml b/examples/pom.xml
index 994071d94d0ad..7e93f0eec0b91 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.10
- 1.3.0-SNAPSHOT
+ 1.4.0-SNAPSHOT
../pom.xml
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
index 91a0a860d6c71..1f4ca4fbe7778 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
@@ -175,7 +175,8 @@ object MovieLensALS {
}
/** Compute RMSE (Root Mean Squared Error). */
- def computeRmse(model: MatrixFactorizationModel, data: RDD[Rating], implicitPrefs: Boolean) = {
+ def computeRmse(model: MatrixFactorizationModel, data: RDD[Rating], implicitPrefs: Boolean)
+ : Double = {
def mapPredictedRating(r: Double) = if (implicitPrefs) math.max(math.min(r, 1.0), 0.0) else r
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala
index 91c9772744f18..9f22d40c15f3f 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala
@@ -116,7 +116,7 @@ object PowerIterationClusteringExample {
sc.stop()
}
- def generateCircle(radius: Double, n: Int) = {
+ def generateCircle(radius: Double, n: Int): Seq[(Double, Double)] = {
Seq.tabulate(n) { i =>
val theta = 2.0 * math.Pi * i / n
(radius * math.cos(theta), radius * math.sin(theta))
@@ -147,7 +147,7 @@ object PowerIterationClusteringExample {
/**
* Gaussian Similarity: http://en.wikipedia.org/wiki/Radial_basis_function_kernel
*/
- def gaussianSimilarity(p1: (Double, Double), p2: (Double, Double), sigma: Double) = {
+ def gaussianSimilarity(p1: (Double, Double), p2: (Double, Double), sigma: Double): Double = {
val coeff = 1.0 / (math.sqrt(2.0 * math.Pi) * sigma)
val expCoeff = -1.0 / 2.0 * math.pow(sigma, 2.0)
val ssquares = (p1._1 - p2._1) * (p1._1 - p2._1) + (p1._2 - p2._2) * (p1._2 - p2._2)
diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml
index 96c2787e35cd0..67907bbfb6d1b 100644
--- a/external/flume-sink/pom.xml
+++ b/external/flume-sink/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.10
- 1.3.0-SNAPSHOT
+ 1.4.0-SNAPSHOT
../../pom.xml
diff --git a/external/flume/pom.xml b/external/flume/pom.xml
index 172d447b77cda..8df7edbdcad33 100644
--- a/external/flume/pom.xml
+++ b/external/flume/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.10
- 1.3.0-SNAPSHOT
+ 1.4.0-SNAPSHOT
../../pom.xml
diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml
index 5109b8ed87524..0b79f47647f6b 100644
--- a/external/kafka-assembly/pom.xml
+++ b/external/kafka-assembly/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.10
- 1.3.0-SNAPSHOT
+ 1.4.0-SNAPSHOT
../../pom.xml
diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml
index 369856187a244..f695cff410a18 100644
--- a/external/kafka/pom.xml
+++ b/external/kafka/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.10
- 1.3.0-SNAPSHOT
+ 1.4.0-SNAPSHOT
../../pom.xml
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala
index fc53c23abda85..3cd960d1fd1d4 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala
@@ -25,16 +25,15 @@ import scala.concurrent.duration._
import scala.language.postfixOps
import scala.util.Random
-import com.google.common.io.Files
import kafka.serializer.StringDecoder
import kafka.utils.{ZKGroupTopicDirs, ZkUtils}
-import org.apache.commons.io.FileUtils
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually
import org.apache.spark.SparkConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Milliseconds, StreamingContext}
+import org.apache.spark.util.Utils
class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter with Eventually {
@@ -60,7 +59,7 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter
)
ssc = new StreamingContext(sparkConf, Milliseconds(500))
- tempDirectory = Files.createTempDir()
+ tempDirectory = Utils.createTempDir()
ssc.checkpoint(tempDirectory.getAbsolutePath)
}
@@ -68,10 +67,7 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter
if (ssc != null) {
ssc.stop()
}
- if (tempDirectory != null && tempDirectory.exists()) {
- FileUtils.deleteDirectory(tempDirectory)
- tempDirectory = null
- }
+ Utils.deleteRecursively(tempDirectory)
tearDownKafka()
}
diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml
index a344f000c5002..98f95a9a64fa0 100644
--- a/external/mqtt/pom.xml
+++ b/external/mqtt/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.10
- 1.3.0-SNAPSHOT
+ 1.4.0-SNAPSHOT
../../pom.xml
diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml
index e95853f005ce2..8b6a8959ac4cf 100644
--- a/external/twitter/pom.xml
+++ b/external/twitter/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.10
- 1.3.0-SNAPSHOT
+ 1.4.0-SNAPSHOT
../../pom.xml
diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml
index 9b3475d7c3dc2..a50d378b34335 100644
--- a/external/zeromq/pom.xml
+++ b/external/zeromq/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.10
- 1.3.0-SNAPSHOT
+ 1.4.0-SNAPSHOT
../../pom.xml
diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml
index bc2f8be10c9ce..4351a8a12fe21 100644
--- a/extras/java8-tests/pom.xml
+++ b/extras/java8-tests/pom.xml
@@ -20,7 +20,7 @@
org.apache.spark
spark-parent_2.10
- 1.3.0-SNAPSHOT
+ 1.4.0-SNAPSHOT
../../pom.xml
diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml
index 7e49a71907336..25847a1b33d9c 100644
--- a/extras/kinesis-asl/pom.xml
+++ b/extras/kinesis-asl/pom.xml
@@ -20,7 +20,7 @@
org.apache.spark
spark-parent_2.10
- 1.3.0-SNAPSHOT
+ 1.4.0-SNAPSHOT
../../pom.xml
diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml
index 6eb29af03f833..e14bbae4a9b6e 100644
--- a/extras/spark-ganglia-lgpl/pom.xml
+++ b/extras/spark-ganglia-lgpl/pom.xml
@@ -20,7 +20,7 @@
org.apache.spark
spark-parent_2.10
- 1.3.0-SNAPSHOT
+ 1.4.0-SNAPSHOT
../../pom.xml
diff --git a/graphx/pom.xml b/graphx/pom.xml
index c0d534e185d7f..d38a3aa8256b7 100644
--- a/graphx/pom.xml
+++ b/graphx/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.10
- 1.3.0-SNAPSHOT
+ 1.4.0-SNAPSHOT
../pom.xml
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
index dc8b4789c4b61..86f611d55aa8a 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
@@ -113,7 +113,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
* Collect the neighbor vertex attributes for each vertex.
*
* @note This function could be highly inefficient on power-law
- * graphs where high degree vertices may force a large ammount of
+ * graphs where high degree vertices may force a large amount of
* information to be collected to a single location.
*
* @param edgeDirection the direction along which to collect
@@ -187,7 +187,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
/**
* Join the vertices with an RDD and then apply a function from the
- * the vertex and RDD entry to a new vertex value. The input table
+ * vertex and RDD entry to a new vertex value. The input table
* should contain at most one entry for each vertex. If no entry is
* provided the map function is skipped and the old value is used.
*
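A minimal sketch of the `joinVertices` behavior described above (assumed names: an existing `Graph[Int, Int]` called `graph` and a `SparkContext` called `sc`):
{% highlight scala %}
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD

// Extra data for some (not necessarily all) vertices.
val bonuses: RDD[(VertexId, Int)] = sc.parallelize(Seq((1L, 10), (2L, 20)))

// Vertices with an entry in `bonuses` get their attribute increased;
// vertices without an entry keep their old value, as documented above.
val updated: Graph[Int, Int] = graph.joinVertices(bonuses) {
  (id, oldAttr, bonus) => oldAttr + bonus
}
{% endhighlight %}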
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
index 5e55620147df8..01b013ff716fc 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
@@ -78,8 +78,8 @@ object Pregel extends Logging {
*
* @param graph the input graph.
*
- * @param initialMsg the message each vertex will receive at the on
- * the first iteration
+ * @param initialMsg the message each vertex will receive at the first
+ * iteration
*
* @param maxIterations the maximum number of iterations to run for
*
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
index e139959c3f5c1..570440ba4441f 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
@@ -25,8 +25,8 @@ import org.apache.spark.graphx._
/**
* PageRank algorithm implementation. There are two implementations of PageRank implemented.
*
- * The first implementation uses the [[Pregel]] interface and runs PageRank for a fixed number
- * of iterations:
+ * The first implementation uses the standalone [[Graph]] interface and runs PageRank
+ * for a fixed number of iterations:
* {{{
* var PR = Array.fill(n)( 1.0 )
* val oldPR = Array.fill(n)( 1.0 )
@@ -38,7 +38,7 @@ import org.apache.spark.graphx._
* }
* }}}
*
- * The second implementation uses the standalone [[Graph]] interface and runs PageRank until
+ * The second implementation uses the [[Pregel]] interface and runs PageRank until
* convergence:
*
* {{{
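Both implementations are exposed through `GraphOps`; a minimal sketch of calling them (assuming an existing `Graph` named `graph`):
{% highlight scala %}
// Fixed number of iterations (the standalone Graph-based implementation).
val ranksStatic = graph.staticPageRank(numIter = 20).vertices

// Run until the per-vertex change drops below a tolerance
// (the Pregel-based implementation described above).
val ranksDynamic = graph.pageRank(tol = 0.0001).vertices
{% endhighlight %}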
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
index b61d9f0fbe5e4..8d15150458d26 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -19,13 +19,12 @@ package org.apache.spark.graphx
import org.scalatest.FunSuite
-import com.google.common.io.Files
-
import org.apache.spark.SparkContext
import org.apache.spark.graphx.Graph._
import org.apache.spark.graphx.PartitionStrategy._
import org.apache.spark.rdd._
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
class GraphSuite extends FunSuite with LocalSparkContext {
@@ -369,8 +368,7 @@ class GraphSuite extends FunSuite with LocalSparkContext {
}
test("checkpoint") {
- val checkpointDir = Files.createTempDir()
- checkpointDir.deleteOnExit()
+ val checkpointDir = Utils.createTempDir()
withSpark { sc =>
sc.setCheckpointDir(checkpointDir.getAbsolutePath)
val ring = (0L to 100L).zip((1L to 99L) :+ 0L).map { case (a, b) => Edge(a, b, 1)}
diff --git a/launcher/pom.xml b/launcher/pom.xml
index ccbd9d0419a98..0fe2814135d88 100644
--- a/launcher/pom.xml
+++ b/launcher/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.10
- 1.3.0-SNAPSHOT
+ 1.4.0-SNAPSHOT
../pom.xml
diff --git a/make-distribution.sh b/make-distribution.sh
index 9ed1abfe8c598..8162fe94c1af0 100755
--- a/make-distribution.sh
+++ b/make-distribution.sh
@@ -32,7 +32,7 @@ SPARK_HOME="$(cd "`dirname "$0"`"; pwd)"
DISTDIR="$SPARK_HOME/dist"
SPARK_TACHYON=false
-TACHYON_VERSION="0.5.0"
+TACHYON_VERSION="0.6.1"
TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz"
TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/${TACHYON_TGZ}"
diff --git a/mllib/pom.xml b/mllib/pom.xml
index a76704a8c2c59..4c183543e3fa8 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.10
- 1.3.0-SNAPSHOT
+ 1.4.0-SNAPSHOT
../pom.xml
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 5bbcd2e080e07..c4a36103303a2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.types.StructType
abstract class PipelineStage extends Serializable with Logging {
/**
- * :: DeveloperAPI ::
+ * :: DeveloperApi ::
*
* Derives the output schema from the input schema and parameters.
* The schema describes the columns and types of the data.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 6131ba8832691..fc4e12773c46d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -41,7 +41,7 @@ class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
def getNumFeatures: Int = get(numFeatures)
/** @group setParam */
- def setNumFeatures(value: Int) = set(numFeatures, value)
+ def setNumFeatures(value: Int): this.type = set(numFeatures, value)
override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = {
val hashingTF = new feature.HashingTF(paramMap(numFeatures))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
index 1a70322b4cace..5d660d1e151a7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
@@ -138,3 +138,14 @@ private[ml] trait HasOutputCol extends Params {
/** @group getParam */
def getOutputCol: String = get(outputCol)
}
+
+private[ml] trait HasCheckpointInterval extends Params {
+ /**
+ * param for checkpoint interval
+ * @group param
+ */
+ val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval")
+
+ /** @group getParam */
+ def getCheckpointInterval: Int = get(checkpointInterval)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index e3515ee81af3d..514b4ef98dc5b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.recommendation
import java.{util => ju}
+import java.io.IOException
import scala.collection.mutable
import scala.reflect.ClassTag
@@ -26,6 +27,7 @@ import scala.util.hashing.byteswap64
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
+import org.apache.hadoop.fs.{FileSystem, Path}
import org.netlib.util.intW
import org.apache.spark.{Logging, Partitioner}
@@ -46,7 +48,7 @@ import org.apache.spark.util.random.XORShiftRandom
* Common params for ALS.
*/
private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam
- with HasPredictionCol {
+ with HasPredictionCol with HasCheckpointInterval {
/**
* Param for rank of the matrix factorization.
@@ -164,6 +166,7 @@ class ALSModel private[ml] (
itemFactors: RDD[(Int, Array[Float])])
extends Model[ALSModel] with ALSParams {
+ /** @group setParam */
def setPredictionCol(value: String): this.type = set(predictionCol, value)
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
@@ -262,6 +265,9 @@ class ALS extends Estimator[ALSModel] with ALSParams {
/** @group setParam */
def setNonnegative(value: Boolean): this.type = set(nonnegative, value)
+ /** @group setParam */
+ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
+
/**
* Sets both numUserBlocks and numItemBlocks to the specific value.
* @group setParam
@@ -274,6 +280,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
setMaxIter(20)
setRegParam(1.0)
+ setCheckpointInterval(10)
override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = {
val map = this.paramMap ++ paramMap
@@ -285,7 +292,8 @@ class ALS extends Estimator[ALSModel] with ALSParams {
val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank),
numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks),
maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs),
- alpha = map(alpha), nonnegative = map(nonnegative))
+ alpha = map(alpha), nonnegative = map(nonnegative),
+ checkpointInterval = map(checkpointInterval))
val model = new ALSModel(this, map, map(rank), userFactors, itemFactors)
Params.inheritValues(map, this, model)
model
@@ -494,6 +502,7 @@ object ALS extends Logging {
nonnegative: Boolean = false,
intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
+ checkpointInterval: Int = 10,
seed: Long = 0L)(
implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = {
require(intermediateRDDStorageLevel != StorageLevel.NONE,
@@ -521,6 +530,18 @@ object ALS extends Logging {
val seedGen = new XORShiftRandom(seed)
var userFactors = initialize(userInBlocks, rank, seedGen.nextLong())
var itemFactors = initialize(itemInBlocks, rank, seedGen.nextLong())
+ var previousCheckpointFile: Option[String] = None
+ val shouldCheckpoint: Int => Boolean = (iter) =>
+ sc.checkpointDir.isDefined && (iter % checkpointInterval == 0)
+ val deletePreviousCheckpointFile: () => Unit = () =>
+ previousCheckpointFile.foreach { file =>
+ try {
+ FileSystem.get(sc.hadoopConfiguration).delete(new Path(file), true)
+ } catch {
+ case e: IOException =>
+ logWarning(s"Cannot delete checkpoint file $file:", e)
+ }
+ }
if (implicitPrefs) {
for (iter <- 1 to maxIter) {
userFactors.setName(s"userFactors-$iter").persist(intermediateRDDStorageLevel)
@@ -528,19 +549,30 @@ object ALS extends Logging {
itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
userLocalIndexEncoder, implicitPrefs, alpha, solver)
previousItemFactors.unpersist()
- if (sc.checkpointDir.isDefined && (iter % 3 == 0)) {
- itemFactors.checkpoint()
- }
itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel)
+ // TODO: Generalize PeriodicGraphCheckpointer and use it here.
+ if (shouldCheckpoint(iter)) {
+ itemFactors.checkpoint() // itemFactors gets materialized in computeFactors.
+ }
val previousUserFactors = userFactors
userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
itemLocalIndexEncoder, implicitPrefs, alpha, solver)
+ if (shouldCheckpoint(iter)) {
+ deletePreviousCheckpointFile()
+ previousCheckpointFile = itemFactors.getCheckpointFile
+ }
previousUserFactors.unpersist()
}
} else {
for (iter <- 0 until maxIter) {
itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
userLocalIndexEncoder, solver = solver)
+ if (shouldCheckpoint(iter)) {
+ itemFactors.checkpoint()
+ itemFactors.count() // checkpoint item factors and cut lineage
+ deletePreviousCheckpointFile()
+ previousCheckpointFile = itemFactors.getCheckpointFile
+ }
userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
itemLocalIndexEncoder, solver = solver)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index cbd87ea8aeb37..15ca2547d56a8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -345,9 +345,13 @@ private[python] class PythonMLLibAPI extends Serializable {
def predict(userAndProducts: JavaRDD[Array[Any]]): RDD[Rating] =
predict(SerDe.asTupleRDD(userAndProducts.rdd))
- def getUserFeatures = SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]])
+ def getUserFeatures: RDD[Array[Any]] = {
+ SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]])
+ }
- def getProductFeatures = SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]])
+ def getProductFeatures: RDD[Array[Any]] = {
+ SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]])
+ }
}
@@ -909,7 +913,7 @@ private[spark] object SerDe extends Serializable {
// Pickler for DenseVector
private[python] class DenseVectorPickler extends BasePickler[DenseVector] {
- def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
val vector: DenseVector = obj.asInstanceOf[DenseVector]
val bytes = new Array[Byte](8 * vector.size)
val bb = ByteBuffer.wrap(bytes)
@@ -941,7 +945,7 @@ private[spark] object SerDe extends Serializable {
// Pickler for DenseMatrix
private[python] class DenseMatrixPickler extends BasePickler[DenseMatrix] {
- def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
val m: DenseMatrix = obj.asInstanceOf[DenseMatrix]
val bytes = new Array[Byte](8 * m.values.size)
val order = ByteOrder.nativeOrder()
@@ -973,7 +977,7 @@ private[spark] object SerDe extends Serializable {
// Pickler for SparseVector
private[python] class SparseVectorPickler extends BasePickler[SparseVector] {
- def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
val v: SparseVector = obj.asInstanceOf[SparseVector]
val n = v.indices.size
val indiceBytes = new Array[Byte](4 * n)
@@ -1015,7 +1019,7 @@ private[spark] object SerDe extends Serializable {
// Pickler for LabeledPoint
private[python] class LabeledPointPickler extends BasePickler[LabeledPoint] {
- def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
val point: LabeledPoint = obj.asInstanceOf[LabeledPoint]
saveObjects(out, pickler, point.label, point.features)
}
@@ -1031,7 +1035,7 @@ private[spark] object SerDe extends Serializable {
// Pickler for Rating
private[python] class RatingPickler extends BasePickler[Rating] {
- def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
val rating: Rating = obj.asInstanceOf[Rating]
saveObjects(out, pickler, rating.user, rating.product, rating.rating)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 2ebc7fa5d4234..d60e82c410979 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -17,6 +17,10 @@
package org.apache.spark.mllib.classification
+import java.lang.{Iterable => JIterable}
+
+import scala.collection.JavaConverters._
+
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
@@ -41,6 +45,13 @@ class NaiveBayesModel private[mllib] (
val pi: Array[Double],
val theta: Array[Array[Double]]) extends ClassificationModel with Serializable with Saveable {
+ /** A Java-friendly constructor that takes three Iterable parameters. */
+ private[mllib] def this(
+ labels: JIterable[Double],
+ pi: JIterable[Double],
+ theta: JIterable[JIterable[Double]]) =
+ this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray))
+
private val brzPi = new BDV[Double](pi)
private val brzTheta = new BDM[Double](theta.length, theta(0).length)
@@ -83,10 +94,10 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
private object SaveLoadV1_0 {
- def thisFormatVersion = "1.0"
+ def thisFormatVersion: String = "1.0"
/** Hard-code class name string in case it changes in the future */
- def thisClassName = "org.apache.spark.mllib.classification.NaiveBayesModel"
+ def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel"
/** Model data for model import/export */
case class Data(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]])
@@ -174,7 +185,7 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
*
* @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
*/
- def run(data: RDD[LabeledPoint]) = {
+ def run(data: RDD[LabeledPoint]): NaiveBayesModel = {
val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
val values = v match {
case SparseVector(size, indices, values) =>
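For reference, a minimal training/prediction sketch against this API (assuming an existing `SparkContext` named `sc`; `run()` is invoked internally by the `train()` helpers):
{% highlight scala %}
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

val training = sc.parallelize(Seq(
  LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
  LabeledPoint(1.0, Vectors.dense(0.0, 1.0))))

val model = NaiveBayes.train(training, lambda = 1.0)
val predicted = model.predict(Vectors.dense(0.0, 2.0))
{% endhighlight %}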
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
index 8956189ff1158..3b6790cce47c6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
@@ -32,7 +32,7 @@ private[classification] object GLMClassificationModel {
object SaveLoadV1_0 {
- def thisFormatVersion = "1.0"
+ def thisFormatVersion: String = "1.0"
/** Model data for import/export */
case class Data(weights: Vector, intercept: Double, threshold: Option[Double])
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index e41f941fd2c2c..0f8d6a399682d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -536,5 +536,5 @@ class VectorWithNorm(val vector: Vector, val norm: Double) extends Serializable
def this(array: Array[Double]) = this(Vectors.dense(array))
/** Converts the vector to a dense vector. */
- def toDense = new VectorWithNorm(Vectors.dense(vector.toArray), norm)
+ def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
index ea10bde5fa252..a8378a76d20ae 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
@@ -96,30 +96,30 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]
* Returns precision for a given label (category)
* @param label the label.
*/
- def precision(label: Double) = {
+ def precision(label: Double): Double = {
val tp = tpPerClass(label)
val fp = fpPerClass.getOrElse(label, 0L)
- if (tp + fp == 0) 0 else tp.toDouble / (tp + fp)
+ if (tp + fp == 0) 0.0 else tp.toDouble / (tp + fp)
}
/**
* Returns recall for a given label (category)
* @param label the label.
*/
- def recall(label: Double) = {
+ def recall(label: Double): Double = {
val tp = tpPerClass(label)
val fn = fnPerClass.getOrElse(label, 0L)
- if (tp + fn == 0) 0 else tp.toDouble / (tp + fn)
+ if (tp + fn == 0) 0.0 else tp.toDouble / (tp + fn)
}
/**
* Returns f1-measure for a given label (category)
* @param label the label.
*/
- def f1Measure(label: Double) = {
+ def f1Measure(label: Double): Double = {
val p = precision(label)
val r = recall(label)
- if((p + r) == 0) 0 else 2 * p * r / (p + r)
+ if((p + r) == 0) 0.0 else 2 * p * r / (p + r)
}
private lazy val sumTp = tpPerClass.foldLeft(0L) { case (sum, (_, tp)) => sum + tp }
@@ -130,7 +130,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]
* Returns micro-averaged label-based precision
* (equals to micro-averaged document-based precision)
*/
- lazy val microPrecision = {
+ lazy val microPrecision: Double = {
val sumFp = fpPerClass.foldLeft(0L){ case(cum, (_, fp)) => cum + fp}
sumTp.toDouble / (sumTp + sumFp)
}
@@ -139,7 +139,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]
* Returns micro-averaged label-based recall
* (equals to micro-averaged document-based recall)
*/
- lazy val microRecall = {
+ lazy val microRecall: Double = {
val sumFn = fnPerClass.foldLeft(0.0){ case(cum, (_, fn)) => cum + fn}
sumTp.toDouble / (sumTp + sumFn)
}
@@ -148,7 +148,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]
* Returns micro-averaged label-based f1-measure
* (equals to micro-averaged document-based f1-measure)
*/
- lazy val microF1Measure = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass)
+ lazy val microF1Measure: Double = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass)
/**
* Returns the sequence of labels in ascending order
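A small usage sketch for the metrics above (assuming an existing `SparkContext` named `sc`; each pair holds the predicted and true label sets for one document):
{% highlight scala %}
import org.apache.spark.mllib.evaluation.MultilabelMetrics

val predictionAndLabels = sc.parallelize(Seq(
  (Array(0.0, 1.0), Array(0.0, 2.0)),
  (Array(1.0, 2.0), Array(1.0, 2.0))))

val metrics = new MultilabelMetrics(predictionAndLabels)
metrics.precision(0.0)    // per-label precision
metrics.recall(2.0)       // per-label recall
metrics.f1Measure(1.0)    // per-label F1
metrics.microF1Measure    // micro-averaged F1 over all labels
{% endhighlight %}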
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 0e4a4d0085895..849f44295f089 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -23,9 +23,15 @@ import scala.collection.mutable.{ArrayBuilder => MArrayBuilder, HashSet => MHash
import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM}
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+
/**
* Trait for a local matrix.
*/
+@SQLUserDefinedType(udt = classOf[MatrixUDT])
sealed trait Matrix extends Serializable {
/** Number of rows. */
@@ -102,6 +108,88 @@ sealed trait Matrix extends Serializable {
private[spark] def foreachActive(f: (Int, Int, Double) => Unit)
}
+@DeveloperApi
+private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
+
+ override def sqlType: StructType = {
+ // type: 0 = sparse, 1 = dense
+ // the dense matrix is built by numRows, numCols, values and isTransposed, all of which are
+ // set as not nullable, except values since in the future, support for binary matrices might
+ // be added for which values are not needed.
+ // the sparse matrix needs colPtrs and rowIndices, which are set as
+ // null, while building the dense matrix.
+ StructType(Seq(
+ StructField("type", ByteType, nullable = false),
+ StructField("numRows", IntegerType, nullable = false),
+ StructField("numCols", IntegerType, nullable = false),
+ StructField("colPtrs", ArrayType(IntegerType, containsNull = false), nullable = true),
+ StructField("rowIndices", ArrayType(IntegerType, containsNull = false), nullable = true),
+ StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true),
+ StructField("isTransposed", BooleanType, nullable = false)
+ ))
+ }
+
+ override def serialize(obj: Any): Row = {
+ val row = new GenericMutableRow(7)
+ obj match {
+ case sm: SparseMatrix =>
+ row.setByte(0, 0)
+ row.setInt(1, sm.numRows)
+ row.setInt(2, sm.numCols)
+ row.update(3, sm.colPtrs.toSeq)
+ row.update(4, sm.rowIndices.toSeq)
+ row.update(5, sm.values.toSeq)
+ row.setBoolean(6, sm.isTransposed)
+
+ case dm: DenseMatrix =>
+ row.setByte(0, 1)
+ row.setInt(1, dm.numRows)
+ row.setInt(2, dm.numCols)
+ row.setNullAt(3)
+ row.setNullAt(4)
+ row.update(5, dm.values.toSeq)
+ row.setBoolean(6, dm.isTransposed)
+ }
+ row
+ }
+
+ override def deserialize(datum: Any): Matrix = {
+ datum match {
+ // TODO: something wrong with UDT serialization, should never happen.
+ case m: Matrix => m
+ case row: Row =>
+ require(row.length == 7,
+ s"MatrixUDT.deserialize given row with length ${row.length} but requires length == 7")
+ val tpe = row.getByte(0)
+ val numRows = row.getInt(1)
+ val numCols = row.getInt(2)
+ val values = row.getAs[Iterable[Double]](5).toArray
+ val isTransposed = row.getBoolean(6)
+ tpe match {
+ case 0 =>
+ val colPtrs = row.getAs[Iterable[Int]](3).toArray
+ val rowIndices = row.getAs[Iterable[Int]](4).toArray
+ new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
+ case 1 =>
+ new DenseMatrix(numRows, numCols, values, isTransposed)
+ }
+ }
+ }
+
+ override def userClass: Class[Matrix] = classOf[Matrix]
+
+ override def equals(o: Any): Boolean = {
+ o match {
+ case v: MatrixUDT => true
+ case _ => false
+ }
+ }
+
+ override def hashCode(): Int = 1994
+
+ private[spark] override def asNullable: MatrixUDT = this
+}
+
/**
* Column-major dense matrix.
* The entry values are stored in a single array of doubles with columns listed in sequence.
@@ -119,6 +207,7 @@ sealed trait Matrix extends Serializable {
* @param isTransposed whether the matrix is transposed. If true, `values` stores the matrix in
* row major.
*/
+@SQLUserDefinedType(udt = classOf[MatrixUDT])
class DenseMatrix(
val numRows: Int,
val numCols: Int,
@@ -146,12 +235,16 @@ class DenseMatrix(
def this(numRows: Int, numCols: Int, values: Array[Double]) =
this(numRows, numCols, values, false)
- override def equals(o: Any) = o match {
+ override def equals(o: Any): Boolean = o match {
case m: DenseMatrix =>
m.numRows == numRows && m.numCols == numCols && Arrays.equals(toArray, m.toArray)
case _ => false
}
+ override def hashCode: Int = {
+ com.google.common.base.Objects.hashCode(numRows : Integer, numCols: Integer, toArray)
+ }
+
private[mllib] def toBreeze: BM[Double] = {
if (!isTransposed) {
new BDM[Double](numRows, numCols, values)
@@ -173,7 +266,7 @@ class DenseMatrix(
values(index(i, j)) = v
}
- override def copy = new DenseMatrix(numRows, numCols, values.clone())
+ override def copy: DenseMatrix = new DenseMatrix(numRows, numCols, values.clone())
private[mllib] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f))
@@ -356,6 +449,7 @@ object DenseMatrix {
* Compressed Sparse Row (CSR) format, where `colPtrs` behaves as rowPtrs,
* and `rowIndices` behave as colIndices, and `values` are stored in row major.
*/
+@SQLUserDefinedType(udt = classOf[MatrixUDT])
class SparseMatrix(
val numRows: Int,
val numCols: Int,
@@ -431,7 +525,9 @@ class SparseMatrix(
}
}
- override def copy = new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.clone())
+ override def copy: SparseMatrix = {
+ new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.clone())
+ }
private[mllib] def map(f: Double => Double) =
new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.map(f))
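A round-trip sketch of the encoding described in `sqlType` above. This is illustrative only: `MatrixUDT` is `private[spark]`, so code like this would have to live inside a Spark package; only the `serialize`/`deserialize` methods added in this patch are assumed.
{% highlight scala %}
import org.apache.spark.mllib.linalg.{Matrices, Matrix, MatrixUDT}

val udt = new MatrixUDT

// Dense 2x2 matrix, column-major values.
val dense: Matrix = Matrices.dense(2, 2, Array(1.0, 2.0, 3.0, 4.0))
val row = udt.serialize(dense)   // type byte = 1, colPtrs/rowIndices set to null
val back = udt.deserialize(row)  // an equal DenseMatrix comes back out
{% endhighlight %}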
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index e9d25dcb7e778..2cda9b252ee06 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -183,6 +183,8 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
}
}
+ override def hashCode: Int = 7919
+
private[spark] override def asNullable: VectorUDT = this
}
@@ -478,7 +480,7 @@ class DenseVector(val values: Array[Double]) extends Vector {
private[mllib] override def toBreeze: BV[Double] = new BDV[Double](values)
- override def apply(i: Int) = values(i)
+ override def apply(i: Int): Double = values(i)
override def copy: DenseVector = {
new DenseVector(values.clone())
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
index 1d253963130f1..3323ae7b1fba0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
@@ -49,7 +49,7 @@ private[mllib] class GridPartitioner(
private val rowPartitions = math.ceil(rows * 1.0 / rowsPerPart).toInt
private val colPartitions = math.ceil(cols * 1.0 / colsPerPart).toInt
- override val numPartitions = rowPartitions * colPartitions
+ override val numPartitions: Int = rowPartitions * colPartitions
/**
* Returns the index of the partition the input coordinate belongs to.
@@ -85,6 +85,14 @@ private[mllib] class GridPartitioner(
false
}
}
+
+ override def hashCode: Int = {
+ com.google.common.base.Objects.hashCode(
+ rows: java.lang.Integer,
+ cols: java.lang.Integer,
+ rowsPerPart: java.lang.Integer,
+ colsPerPart: java.lang.Integer)
+ }
}
private[mllib] object GridPartitioner {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala
index 405bae62ee8b6..9349ecaa13f56 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala
@@ -56,7 +56,7 @@ class UniformGenerator extends RandomDataGenerator[Double] {
random.nextDouble()
}
- override def setSeed(seed: Long) = random.setSeed(seed)
+ override def setSeed(seed: Long): Unit = random.setSeed(seed)
override def copy(): UniformGenerator = new UniformGenerator()
}
@@ -75,7 +75,7 @@ class StandardNormalGenerator extends RandomDataGenerator[Double] {
random.nextGaussian()
}
- override def setSeed(seed: Long) = random.setSeed(seed)
+ override def setSeed(seed: Long): Unit = random.setSeed(seed)
override def copy(): StandardNormalGenerator = new StandardNormalGenerator()
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctions.scala
new file mode 100644
index 0000000000000..9213fd3f595c3
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctions.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.rdd
+
+import scala.language.implicitConversions
+import scala.reflect.ClassTag
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.BoundedPriorityQueue
+
+/**
+ * Machine learning specific Pair RDD functions.
+ */
+@DeveloperApi
+class MLPairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) extends Serializable {
+ /**
+ * Returns the top k (largest) elements for each key from this RDD as defined by the specified
+ * implicit Ordering[V].
+ * If the number of elements for a certain key is less than k, all of them will be returned.
+ *
+ * @param num k, the number of top elements to return
+ * @param ord the implicit ordering for V
+ * @return an RDD that contains the top k values for each key
+ */
+ def topByKey(num: Int)(implicit ord: Ordering[V]): RDD[(K, Array[V])] = {
+ self.aggregateByKey(new BoundedPriorityQueue[V](num)(ord))(
+ seqOp = (queue, item) => {
+ queue += item
+ queue
+ },
+ combOp = (queue1, queue2) => {
+ queue1 ++= queue2
+ queue1
+ }
+ ).mapValues(_.toArray.sorted(ord.reverse))
+ }
+}
+
+@DeveloperApi
+object MLPairRDDFunctions {
+ /** Implicit conversion from a pair RDD to MLPairRDDFunctions. */
+ implicit def fromPairRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): MLPairRDDFunctions[K, V] =
+ new MLPairRDDFunctions[K, V](rdd)
+}
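A minimal usage sketch of the new `topByKey` helper (assuming an existing `SparkContext` named `sc`):
{% highlight scala %}
import org.apache.spark.mllib.rdd.MLPairRDDFunctions.fromPairRDD

val scores = sc.parallelize(Seq(
  ("a", 3), ("a", 1), ("a", 5),
  ("b", 2), ("b", 4)))

// For each key, keep the 2 largest values, sorted in descending order.
val top2 = scores.topByKey(2).collect()
// e.g. Array(("a", Array(5, 3)), ("b", Array(4, 2)))
{% endhighlight %}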
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index caacab943030b..dddefe1944e9d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -82,6 +82,9 @@ class ALS private (
private var intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK
private var finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK
+ /** checkpoint interval */
+ private var checkpointInterval: Int = 10
+
/**
* Set the number of blocks for both user blocks and product blocks to parallelize the computation
* into; pass -1 for an auto-configured number of blocks. Default: -1.
@@ -182,6 +185,19 @@ class ALS private (
this
}
+ /**
+ * Set period (in iterations) between checkpoints (default = 10). Checkpointing helps with
+ * recovery (when nodes fail) and StackOverflow exceptions caused by long lineage. It also helps
+ * with eliminating temporary shuffle files on disk, which can be important when there are many
+ * ALS iterations. If the checkpoint directory is not set in [[org.apache.spark.SparkContext]],
+ * this setting is ignored.
+ */
+ @DeveloperApi
+ def setCheckpointInterval(checkpointInterval: Int): this.type = {
+ this.checkpointInterval = checkpointInterval
+ this
+ }
+
/**
* Run ALS with the configured parameters on an input RDD of (user, product, rating) triples.
* Returns a MatrixFactorizationModel with feature vectors for each user and product.
@@ -212,6 +228,7 @@ class ALS private (
nonnegative = nonnegative,
intermediateRDDStorageLevel = intermediateRDDStorageLevel,
finalRDDStorageLevel = StorageLevel.NONE,
+ checkpointInterval = checkpointInterval,
seed = seed)
val userFactors = floatUserFactors
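A usage sketch of the new checkpointing knob on the RDD-based ALS (assuming an existing `SparkContext` named `sc` and an `RDD[Rating]` named `ratings`; as documented above, the interval is ignored unless a checkpoint directory has been set):
{% highlight scala %}
import org.apache.spark.mllib.recommendation.ALS

// Without a checkpoint directory, setCheckpointInterval has no effect.
sc.setCheckpointDir("/tmp/spark-checkpoints")

val model = new ALS()
  .setRank(10)
  .setIterations(20)
  .setLambda(0.01)
  .setCheckpointInterval(5)   // checkpoint factors every 5 iterations
  .run(ratings)
{% endhighlight %}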
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
index bd7e340ca2d8e..b55944f74f623 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
@@ -32,7 +32,7 @@ private[regression] object GLMRegressionModel {
object SaveLoadV1_0 {
- def thisFormatVersion = "1.0"
+ def thisFormatVersion: String = "1.0"
/** Model data for model import/export */
case class Data(weights: Vector, intercept: Double)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 8d5c36da32bdb..ada227c200a79 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -83,10 +83,13 @@ class Strategy (
@BeanProperty var useNodeIdCache: Boolean = false,
@BeanProperty var checkpointInterval: Int = 10) extends Serializable {
- def isMulticlassClassification =
+ def isMulticlassClassification: Boolean = {
algo == Classification && numClasses > 2
- def isMulticlassWithCategoricalFeatures
- = isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
+ }
+
+ def isMulticlassWithCategoricalFeatures: Boolean = {
+ isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
+ }
/**
* Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]]
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index b7950e00786ab..5ac10f3fd32dd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -71,7 +71,7 @@ object Entropy extends Impurity {
* Get this impurity instance.
* This is useful for passing impurity parameters to a Strategy in Java.
*/
- def instance = this
+ def instance: this.type = this
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index c946db9c0d1c8..19d318203c344 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -67,7 +67,7 @@ object Gini extends Impurity {
* Get this impurity instance.
* This is useful for passing impurity parameters to a Strategy in Java.
*/
- def instance = this
+ def instance: this.type = this
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index df9eafa5da16a..7104a7fa4dd4c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -58,7 +58,7 @@ object Variance extends Impurity {
* Get this impurity instance.
* This is useful for passing impurity parameters to a Strategy in Java.
*/
- def instance = this
+ def instance: this.type = this
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
index d1bde15e6b150..793dd664c5d5a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
@@ -47,18 +47,9 @@ object AbsoluteError extends Loss {
if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0
}
- /**
- * Method to calculate loss of the base learner for the gradient boosting calculation.
- * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
- * purposes.
- * @param model Ensemble model
- * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * @return Mean absolute error of model on data
- */
- override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
- data.map { y =>
- val err = model.predict(y.features) - y.label
- math.abs(err)
- }.mean()
+ override def computeError(prediction: Double, label: Double): Double = {
+ val err = label - prediction
+ math.abs(err)
}
+
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
index 55213e695638c..51b1aed167b66 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -50,20 +50,10 @@ object LogLoss extends Loss {
- 4.0 * point.label / (1.0 + math.exp(2.0 * point.label * prediction))
}
- /**
- * Method to calculate loss of the base learner for the gradient boosting calculation.
- * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
- * purposes.
- * @param model Ensemble model
- * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * @return Mean log loss of model on data
- */
- override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
- data.map { case point =>
- val prediction = model.predict(point.features)
- val margin = 2.0 * point.label * prediction
- // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
- 2.0 * MLUtils.log1pExp(-margin)
- }.mean()
+ override def computeError(prediction: Double, label: Double): Double = {
+ val margin = 2.0 * label * prediction
+ // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
+ 2.0 * MLUtils.log1pExp(-margin)
}
+
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
index e1169d9f66ea4..357869ff6b333 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
@@ -47,6 +47,18 @@ trait Loss extends Serializable {
* @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* @return Measure of model error on data
*/
- def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double
+ def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
+ data.map(point => computeError(model.predict(point.features), point.label)).mean()
+ }
+
+ /**
+ * Method to calculate loss when the predictions are already known.
+ * Note: This method is used in the method evaluateEachIteration to avoid recomputing the
+ * predicted values from previously fit trees.
+ * @param prediction Predicted label.
+ * @param label True label.
+ * @return Measure of model error on datapoint.
+ */
+ def computeError(prediction: Double, label: Double): Double
}
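The new per-point overload makes it easy to check a single prediction; a small sketch against the concrete loss objects updated in this patch:
{% highlight scala %}
import org.apache.spark.mllib.tree.loss.{AbsoluteError, LogLoss, SquaredError}

SquaredError.computeError(prediction = 2.5, label = 3.0)   // (2.5 - 3.0)^2 = 0.25
AbsoluteError.computeError(prediction = 2.5, label = 3.0)  // |3.0 - 2.5| = 0.5
LogLoss.computeError(prediction = 0.8, label = 1.0)        // 2 * log(1 + exp(-1.6))
{% endhighlight %}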
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
index 50ecaa2f86f35..b990707ca4525 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
@@ -47,18 +47,9 @@ object SquaredError extends Loss {
2.0 * (model.predict(point.features) - point.label)
}
- /**
- * Method to calculate loss of the base learner for the gradient boosting calculation.
- * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
- * purposes.
- * @param model Ensemble model
- * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * @return Mean squared error of model on data
- */
- override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
- data.map { y =>
- val err = model.predict(y.features) - y.label
- err * err
- }.mean()
+ override def computeError(prediction: Double, label: Double): Double = {
+ val err = prediction - label
+ err * err
}
+
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index 8a57ebc387d01..c9bafd60fba4d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -120,10 +120,10 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {
private[tree] object SaveLoadV1_0 {
- def thisFormatVersion = "1.0"
+ def thisFormatVersion: String = "1.0"
// Hard-code class name string in case it changes in the future
- def thisClassName = "org.apache.spark.mllib.tree.DecisionTreeModel"
+ def thisClassName: String = "org.apache.spark.mllib.tree.DecisionTreeModel"
case class PredictData(predict: Double, prob: Double) {
def toPredict: Predict = new Predict(predict, prob)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index 80990aa9a603f..f209fdafd3653 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -38,23 +38,32 @@ class InformationGainStats(
val leftPredict: Predict,
val rightPredict: Predict) extends Serializable {
- override def toString = {
+ override def toString: String = {
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f"
.format(gain, impurity, leftImpurity, rightImpurity)
}
- override def equals(o: Any) =
- o match {
- case other: InformationGainStats => {
- gain == other.gain &&
- impurity == other.impurity &&
- leftImpurity == other.leftImpurity &&
- rightImpurity == other.rightImpurity &&
- leftPredict == other.leftPredict &&
- rightPredict == other.rightPredict
- }
- case _ => false
- }
+ override def equals(o: Any): Boolean = o match {
+ case other: InformationGainStats =>
+ gain == other.gain &&
+ impurity == other.impurity &&
+ leftImpurity == other.leftImpurity &&
+ rightImpurity == other.rightImpurity &&
+ leftPredict == other.leftPredict &&
+ rightPredict == other.rightPredict
+
+ case _ => false
+ }
+
+ override def hashCode: Int = {
+ com.google.common.base.Objects.hashCode(
+ gain: java.lang.Double,
+ impurity: java.lang.Double,
+ leftImpurity: java.lang.Double,
+ rightImpurity: java.lang.Double,
+ leftPredict,
+ rightPredict)
+ }
}
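
A brief sketch of the contract the new `hashCode` preserves: stats objects that compare equal must also hash identically so they collapse correctly in hash-based collections. The constructor argument order is assumed from the class above; the values are arbitrary:

{% highlight scala %}
import org.apache.spark.mllib.tree.model.{InformationGainStats, Predict}

val left = new Predict(1.0, 0.7)
val right = new Predict(0.0, 0.3)
val a = new InformationGainStats(0.2, 0.5, 0.4, 0.6, left, right)
val b = new InformationGainStats(0.2, 0.5, 0.4, 0.6, left, right)

// equals (above) already held; hashCode now agrees with it.
assert(a == b && a.hashCode == b.hashCode)
assert(Set(a, b).size == 1)
{% endhighlight %}
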
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index d961081d185e9..4f72bb8014cc0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -50,8 +50,10 @@ class Node (
var rightNode: Option[Node],
var stats: Option[InformationGainStats]) extends Serializable with Logging {
- override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " +
- "impurity = " + impurity + "split = " + split + ", stats = " + stats
+ override def toString: String = {
+ "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " +
+ "impurity = " + impurity + "split = " + split + ", stats = " + stats
+ }
/**
* build the left node and right nodes if not leaf
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
index ad4c0dbbfb3e5..25990af7c6cf7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
@@ -29,7 +29,7 @@ class Predict(
val predict: Double,
val prob: Double = 0.0) extends Serializable {
- override def toString = {
+ override def toString: String = {
"predict = %f, prob = %f".format(predict, prob)
}
@@ -39,4 +39,8 @@ class Predict(
case _ => false
}
}
+
+ override def hashCode: Int = {
+ com.google.common.base.Objects.hashCode(predict: java.lang.Double, prob: java.lang.Double)
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
index b7a85f58544a3..fb35e70a8d077 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
@@ -38,9 +38,10 @@ case class Split(
featureType: FeatureType,
categories: List[Double]) {
- override def toString =
+ override def toString: String = {
"Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType +
", categories = " + categories
+ }
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index 30a8f7ca301af..1950254b2aa6d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -28,9 +28,11 @@ import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
+import org.apache.spark.mllib.tree.loss.Loss
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
@@ -79,7 +81,7 @@ object RandomForestModel extends Loader[RandomForestModel] {
private object SaveLoadV1_0 {
// Hard-code class name string in case it changes in the future
- def thisClassName = "org.apache.spark.mllib.tree.model.RandomForestModel"
+ def thisClassName: String = "org.apache.spark.mllib.tree.model.RandomForestModel"
}
}
@@ -108,6 +110,58 @@ class GradientBoostedTreesModel(
}
override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
+
+ /**
+ * Method to compute error or loss for every iteration of gradient boosting.
+ * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
+ * @param loss evaluation metric.
+ * @return an array with index i having the losses or errors for the ensemble
+ * containing the first i+1 trees
+ */
+ def evaluateEachIteration(
+ data: RDD[LabeledPoint],
+ loss: Loss): Array[Double] = {
+
+ val sc = data.sparkContext
+ val remappedData = algo match {
+ case Classification => data.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ case _ => data
+ }
+
+ val numIterations = trees.length
+ val evaluationArray = Array.fill(numIterations)(0.0)
+
+ var predictionAndError: RDD[(Double, Double)] = remappedData.map { i =>
+ val pred = treeWeights(0) * trees(0).predict(i.features)
+ val error = loss.computeError(pred, i.label)
+ (pred, error)
+ }
+ evaluationArray(0) = predictionAndError.values.mean()
+
+ // Avoid the model being copied across numIterations.
+ val broadcastTrees = sc.broadcast(trees)
+ val broadcastWeights = sc.broadcast(treeWeights)
+
+ (1 until numIterations).map { nTree =>
+ predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter =>
+ val currentTree = broadcastTrees.value(nTree)
+ val currentTreeWeight = broadcastWeights.value(nTree)
+ iter.map {
+ case (point, (pred, error)) => {
+ val newPred = pred + currentTree.predict(point.features) * currentTreeWeight
+ val newError = loss.computeError(newPred, point.label)
+ (newPred, newError)
+ }
+ }
+ }
+ evaluationArray(nTree) = predictionAndError.values.mean()
+ }
+
+ broadcastTrees.unpersist()
+ broadcastWeights.unpersist()
+ evaluationArray
+ }
+
}
object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
@@ -130,7 +184,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
private object SaveLoadV1_0 {
// Hard-code class name string in case it changes in the future
- def thisClassName = "org.apache.spark.mllib.tree.model.GradientBoostedTreesModel"
+ def thisClassName: String = "org.apache.spark.mllib.tree.model.GradientBoostedTreesModel"
}
}
@@ -257,7 +311,7 @@ private[tree] object TreeEnsembleModel extends Logging {
import org.apache.spark.mllib.tree.model.DecisionTreeModel.SaveLoadV1_0.{NodeData, constructTrees}
- def thisFormatVersion = "1.0"
+ def thisFormatVersion: String = "1.0"
case class Metadata(
algo: String,
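
A hedged usage sketch of the `evaluateEachIteration` method added above, useful for choosing the number of boosting iterations from a held-out set. The data and helper name are hypothetical; `GradientBoostedTrees.train` and `BoostingStrategy.defaultParams` are existing MLlib entry points:

{% highlight scala %}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.loss.SquaredError
import org.apache.spark.rdd.RDD

def validationLossPerIteration(
    train: RDD[LabeledPoint],
    validation: RDD[LabeledPoint]): Array[Double] = {
  val boostingStrategy = BoostingStrategy.defaultParams("Regression")
  val model = GradientBoostedTrees.train(train, boostingStrategy)
  // Index i holds the validation loss of the ensemble built from the first i + 1 trees.
  model.evaluateEachIteration(validation, SquaredError)
}
{% endhighlight %}
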
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index bb86bafc0eb0a..0bb06e9e8ac9c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.ml.recommendation
+import java.io.File
import java.util.Random
import scala.collection.mutable
@@ -32,16 +33,25 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.util.Utils
class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
private var sqlContext: SQLContext = _
+ private var tempDir: File = _
override def beforeAll(): Unit = {
super.beforeAll()
+ tempDir = Utils.createTempDir()
+ sc.setCheckpointDir(tempDir.getAbsolutePath)
sqlContext = new SQLContext(sc)
}
+ override def afterAll(): Unit = {
+ Utils.deleteRecursively(tempDir)
+ super.afterAll()
+ }
+
test("LocalIndexEncoder") {
val random = new Random
for (numBlocks <- Seq(1, 2, 5, 10, 20, 50, 100)) {
@@ -485,4 +495,11 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
}.count()
}
}
+
+ test("als with large number of iterations") {
+ val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
+ ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2)
+ ALS.train(
+ ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2, implicitPrefs = true)
+ }
}
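
The checkpoint-directory setup and the 50-iteration test above guard against very long lineage chains in deep ALS runs. A minimal sketch of the same pattern with the public MLlib API; the path and ratings are hypothetical:

{% highlight scala %}
import org.apache.spark.SparkContext
import org.apache.spark.mllib.recommendation.{ALS, Rating}

def trainDeepALS(sc: SparkContext): Unit = {
  // Truncate RDD lineage periodically; without this, very deep iteration counts
  // can overflow the stack during serialization.
  sc.setCheckpointDir("/tmp/spark-checkpoints")  // hypothetical path

  val ratings = sc.parallelize(Seq(
    Rating(1, 1, 5.0), Rating(1, 2, 1.0), Rating(2, 1, 4.0), Rating(2, 2, 2.0)))
  val model = ALS.train(ratings, rank = 1, iterations = 50)
  println(model.predict(1, 2))
}
{% endhighlight %}
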
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
index c098b5458fe6b..96f677db3f377 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
@@ -424,4 +424,17 @@ class MatricesSuite extends FunSuite {
assert(mat.rowIndices.toSeq === Seq(3, 0, 2, 1))
assert(mat.values.toSeq === Seq(1.0, 2.0, 3.0, 4.0))
}
+
+ test("MatrixUDT") {
+ val dm1 = new DenseMatrix(2, 2, Array(0.9, 1.2, 2.3, 9.8))
+ val dm2 = new DenseMatrix(3, 2, Array(0.0, 1.21, 2.3, 9.8, 9.0, 0.0))
+ val dm3 = new DenseMatrix(0, 0, Array())
+ val sm1 = dm1.toSparse
+ val sm2 = dm2.toSparse
+ val sm3 = dm3.toSparse
+ val mUDT = new MatrixUDT()
+ Seq(dm1, dm2, dm3, sm1, sm2, sm3).foreach {
+ mat => assert(mat.toArray === mUDT.deserialize(mUDT.serialize(mat)).toArray)
+ }
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala
new file mode 100644
index 0000000000000..1ac7c12c4e8e6
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.rdd
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.rdd.MLPairRDDFunctions._
+
+class MLPairRDDFunctionsSuite extends FunSuite with MLlibTestSparkContext {
+ test("topByKey") {
+ val topMap = sc.parallelize(Array((1, 1), (1, 2), (3, 2), (3, 7), (3, 5), (5, 1), (5, 3)), 2)
+ .topByKey(2)
+ .collectAsMap()
+
+ assert(topMap.size === 3)
+ assert(topMap(1) === Array(2, 1))
+ assert(topMap(3) === Array(7, 5))
+ assert(topMap(5) === Array(3, 1))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index b437aeaaf0547..55b0bac7d49fe 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -175,10 +175,11 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
val gbtValidate = new GradientBoostedTrees(boostingStrategy)
.runWithValidation(trainRdd, validateRdd)
- assert(gbtValidate.numTrees !== numIterations)
+ val numTrees = gbtValidate.numTrees
+ assert(numTrees !== numIterations)
// Test that it performs better on the validation dataset.
- val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy)
+ val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd)
val (errorWithoutValidation, errorWithValidation) = {
if (algo == Classification) {
val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
@@ -188,6 +189,17 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
}
}
assert(errorWithValidation <= errorWithoutValidation)
+
+ // Test that results from evaluateEachIteration comply with runWithValidation.
+ // Note that validationTol is set to 0.0
+ val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss)
+ assert(evaluationArray.length === numIterations)
+ assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
+ var i = 1
+ while (i < numTrees) {
+ assert(evaluationArray(i) <= evaluationArray(i - 1))
+ i += 1
+ }
}
}
}
diff --git a/network/common/pom.xml b/network/common/pom.xml
index 74437f37c47e4..7b51845206f4a 100644
--- a/network/common/pom.xml
+++ b/network/common/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.10
- 1.3.0-SNAPSHOT
+ 1.4.0-SNAPSHOT
../../pom.xml
diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml
index a2bcca26d8344..7dc7c65825e34 100644
--- a/network/shuffle/pom.xml
+++ b/network/shuffle/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.10
- 1.3.0-SNAPSHOT
+ 1.4.0-SNAPSHOT
../../pom.xml
diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml
index cea7a20c223e2..1e2e9c80af6cc 100644
--- a/network/yarn/pom.xml
+++ b/network/yarn/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.10
- 1.3.0-SNAPSHOT
+ 1.4.0-SNAPSHOT
../../pom.xml
diff --git a/pom.xml b/pom.xml
index 6fc56a86d44ac..23bb16130b504 100644
--- a/pom.xml
+++ b/pom.xml
@@ -26,7 +26,7 @@
org.apache.spark
spark-parent_2.10
- 1.3.0-SNAPSHOT
+ 1.4.0-SNAPSHOT
pom
Spark Project Parent POM
http://spark.apache.org/
@@ -120,7 +120,7 @@
shaded-protobuf
1.7.10
1.2.17
- 1.0.4
+ 2.2.0
2.4.1
${hadoop.version}
0.98.7-hadoop1
diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala
index f0cbf4e57b8c5..dde92949fa175 100644
--- a/project/MimaBuild.scala
+++ b/project/MimaBuild.scala
@@ -91,7 +91,7 @@ object MimaBuild {
def mimaSettings(sparkHome: File, projectRef: ProjectRef) = {
val organization = "org.apache.spark"
- val previousSparkVersion = "1.2.0"
+ val previousSparkVersion = "1.3.0"
val fullId = "spark-" + projectRef.project + "_2.10"
mimaDefaultSettings ++
Seq(previousArtifact := Some(organization % fullId % previousSparkVersion),
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index a6b07fa7cddec..328d59485a731 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -16,6 +16,7 @@
*/
import com.typesafe.tools.mima.core._
+import com.typesafe.tools.mima.core.ProblemFilters._
/**
* Additional excludes for checking of Spark's binary compatibility.
@@ -33,6 +34,19 @@ import com.typesafe.tools.mima.core._
object MimaExcludes {
def excludes(version: String) =
version match {
+ case v if v.startsWith("1.4") =>
+ Seq(
+ MimaBuild.excludeSparkPackage("deploy"),
+ MimaBuild.excludeSparkPackage("ml"),
+ // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff"),
+ // These are needed if checking against the sbt build, since they are part of
+ // the maven-generated artifacts in 1.3.
+ excludePackage("org.spark-project.jetty"),
+ MimaBuild.excludeSparkPackage("unused"),
+ ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional")
+ )
+
case v if v.startsWith("1.3") =>
Seq(
MimaBuild.excludeSparkPackage("deploy"),
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index e4765173709e8..6766f3ebb8894 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -21,9 +21,10 @@
from numpy import array
from pyspark import RDD
-from pyspark.mllib.common import callMLlibFunc
+from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py
from pyspark.mllib.linalg import SparseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper
+from pyspark.mllib.util import Saveable, Loader, inherit_doc
__all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'LogisticRegressionWithLBFGS',
@@ -99,6 +100,18 @@ class LogisticRegressionModel(LinearBinaryClassificationModel):
1
>>> lrm.predict(SparseVector(2, {0: 1.0}))
0
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> lrm.save(sc, path)
+ >>> sameModel = LogisticRegressionModel.load(sc, path)
+ >>> sameModel.predict(array([0.0, 1.0]))
+ 1
+ >>> sameModel.predict(SparseVector(2, {0: 1.0}))
+ 0
+ >>> try:
+ ... os.removedirs(path)
+ ... except:
+ ... pass
"""
def __init__(self, weights, intercept):
super(LogisticRegressionModel, self).__init__(weights, intercept)
@@ -124,6 +137,22 @@ def predict(self, x):
else:
return 1 if prob > self._threshold else 0
+ def save(self, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel(
+ _py2java(sc, self._coeff), self.intercept)
+ java_model.save(sc._jsc.sc(), path)
+
+ @classmethod
+ def load(cls, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel.load(
+ sc._jsc.sc(), path)
+ weights = _java2py(sc, java_model.weights())
+ intercept = java_model.intercept()
+ threshold = java_model.getThreshold().get()
+ model = LogisticRegressionModel(weights, intercept)
+ model.setThreshold(threshold)
+ return model
+
class LogisticRegressionWithSGD(object):
@@ -243,6 +272,18 @@ class SVMModel(LinearBinaryClassificationModel):
1
>>> svm.predict(SparseVector(2, {0: -1.0}))
0
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> svm.save(sc, path)
+ >>> sameModel = SVMModel.load(sc, path)
+ >>> sameModel.predict(SparseVector(2, {1: 1.0}))
+ 1
+ >>> sameModel.predict(SparseVector(2, {0: -1.0}))
+ 0
+ >>> try:
+ ... os.removedirs(path)
+ ... except:
+ ... pass
"""
def __init__(self, weights, intercept):
super(SVMModel, self).__init__(weights, intercept)
@@ -263,6 +304,22 @@ def predict(self, x):
else:
return 1 if margin > self._threshold else 0
+ def save(self, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.classification.SVMModel(
+ _py2java(sc, self._coeff), self.intercept)
+ java_model.save(sc._jsc.sc(), path)
+
+ @classmethod
+ def load(cls, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.classification.SVMModel.load(
+ sc._jsc.sc(), path)
+ weights = _java2py(sc, java_model.weights())
+ intercept = java_model.intercept()
+ threshold = java_model.getThreshold().get()
+ model = SVMModel(weights, intercept)
+ model.setThreshold(threshold)
+ return model
+
class SVMWithSGD(object):
@@ -303,7 +360,8 @@ def train(rdd, i):
return _regression_train_wrapper(train, SVMModel, data, initialWeights)
-class NaiveBayesModel(object):
+@inherit_doc
+class NaiveBayesModel(Saveable, Loader):
"""
Model for Naive Bayes classifiers.
@@ -334,6 +392,16 @@ class NaiveBayesModel(object):
0.0
>>> model.predict(SparseVector(2, {0: 1.0}))
1.0
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> model.save(sc, path)
+ >>> sameModel = NaiveBayesModel.load(sc, path)
+ >>> sameModel.predict(SparseVector(2, {0: 1.0})) == model.predict(SparseVector(2, {0: 1.0}))
+ True
+ >>> try:
+ ... os.removedirs(path)
+ ... except OSError:
+ ... pass
"""
def __init__(self, labels, pi, theta):
@@ -348,6 +416,23 @@ def predict(self, x):
x = _convert_to_vector(x)
return self.labels[numpy.argmax(self.pi + x.dot(self.theta.transpose()))]
+ def save(self, sc, path):
+ java_labels = _py2java(sc, self.labels.tolist())
+ java_pi = _py2java(sc, self.pi.tolist())
+ java_theta = _py2java(sc, self.theta.tolist())
+ java_model = sc._jvm.org.apache.spark.mllib.classification.NaiveBayesModel(
+ java_labels, java_pi, java_theta)
+ java_model.save(sc._jsc.sc(), path)
+
+ @classmethod
+ def load(cls, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.classification.NaiveBayesModel.load(
+ sc._jsc.sc(), path)
+ py_labels = _java2py(sc, java_model.labels())
+ py_pi = _java2py(sc, java_model.pi())
+ py_theta = _java2py(sc, java_model.theta())
+ return NaiveBayesModel(py_labels, py_pi, numpy.array(py_theta))
+
class NaiveBayes(object):
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index 0c21ad578793f..414a0ada80787 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -18,8 +18,9 @@
import numpy as np
from numpy import array
-from pyspark.mllib.common import callMLlibFunc, inherit_doc
+from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py, inherit_doc
from pyspark.mllib.linalg import SparseVector, _convert_to_vector
+from pyspark.mllib.util import Saveable, Loader
__all__ = ['LabeledPoint', 'LinearModel',
'LinearRegressionModel', 'LinearRegressionWithSGD',
@@ -114,6 +115,20 @@ class LinearRegressionModel(LinearRegressionModelBase):
True
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> lrm.save(sc, path)
+ >>> sameModel = LinearRegressionModel.load(sc, path)
+ >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5
+ True
+ >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5
+ True
+ >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
+ True
+ >>> try:
+ ... os.removedirs(path)
+ ... except:
+ ... pass
>>> data = [
... LabeledPoint(0.0, SparseVector(1, {0: 0.0})),
... LabeledPoint(1.0, SparseVector(1, {0: 1.0})),
@@ -126,6 +141,19 @@ class LinearRegressionModel(LinearRegressionModelBase):
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
"""
+ def save(self, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel(
+ _py2java(sc, self._coeff), self.intercept)
+ java_model.save(sc._jsc.sc(), path)
+
+ @classmethod
+ def load(cls, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel.load(
+ sc._jsc.sc(), path)
+ weights = _java2py(sc, java_model.weights())
+ intercept = java_model.intercept()
+ model = LinearRegressionModel(weights, intercept)
+ return model
# train_func should take two parameters, namely data and initial_weights, and
@@ -135,7 +163,8 @@ def _regression_train_wrapper(train_func, modelClass, data, initial_weights):
first = data.first()
if not isinstance(first, LabeledPoint):
raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first)
- initial_weights = initial_weights or [0.0] * len(data.first().features)
+ if initial_weights is None:
+ initial_weights = [0.0] * len(data.first().features)
weights, intercept = train_func(data, _convert_to_vector(initial_weights))
return modelClass(weights, intercept)
@@ -199,6 +228,20 @@ class LassoModel(LinearRegressionModelBase):
True
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> lrm.save(sc, path)
+ >>> sameModel = LassoModel.load(sc, path)
+ >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5
+ True
+ >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5
+ True
+ >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
+ True
+ >>> try:
+ ... os.removedirs(path)
+ ... except:
+ ... pass
>>> data = [
... LabeledPoint(0.0, SparseVector(1, {0: 0.0})),
... LabeledPoint(1.0, SparseVector(1, {0: 1.0})),
@@ -211,6 +254,19 @@ class LassoModel(LinearRegressionModelBase):
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
"""
+ def save(self, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel(
+ _py2java(sc, self._coeff), self.intercept)
+ java_model.save(sc._jsc.sc(), path)
+
+ @classmethod
+ def load(cls, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel.load(
+ sc._jsc.sc(), path)
+ weights = _java2py(sc, java_model.weights())
+ intercept = java_model.intercept()
+ model = LassoModel(weights, intercept)
+ return model
class LassoWithSGD(object):
@@ -246,6 +302,20 @@ class RidgeRegressionModel(LinearRegressionModelBase):
True
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> lrm.save(sc, path)
+ >>> sameModel = RidgeRegressionModel.load(sc, path)
+ >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5
+ True
+ >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5
+ True
+ >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
+ True
+ >>> try:
+ ... os.removedirs(path)
+ ... except:
+ ... pass
>>> data = [
... LabeledPoint(0.0, SparseVector(1, {0: 0.0})),
... LabeledPoint(1.0, SparseVector(1, {0: 1.0})),
@@ -258,6 +328,19 @@ class RidgeRegressionModel(LinearRegressionModelBase):
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
"""
+ def save(self, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel(
+ _py2java(sc, self._coeff), self.intercept)
+ java_model.save(sc._jsc.sc(), path)
+
+ @classmethod
+ def load(cls, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel.load(
+ sc._jsc.sc(), path)
+ weights = _java2py(sc, java_model.weights())
+ intercept = java_model.intercept()
+ model = RidgeRegressionModel(weights, intercept)
+ return model
class RidgeRegressionWithSGD(object):
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 5328d99b69684..155019638f806 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -323,6 +323,13 @@ def test_regression(self):
self.assertTrue(gbt_model.predict(features[2]) <= 0)
self.assertTrue(gbt_model.predict(features[3]) > 0)
+ try:
+ LinearRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]))
+ LassoWithSGD.train(rdd, initialWeights=array([1.0, 1.0]))
+ RidgeRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]))
+ except ValueError:
+ self.fail()
+
class StatTests(PySparkTestCase):
# SPARK-4023
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
index e877c720ac77a..c5c3468eb95e9 100644
--- a/python/pyspark/mllib/util.py
+++ b/python/pyspark/mllib/util.py
@@ -20,7 +20,6 @@
from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper, inherit_doc
from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector
-from pyspark.mllib.regression import LabeledPoint
class MLUtils(object):
@@ -50,6 +49,7 @@ def _parse_libsvm_line(line, multiclass=None):
@staticmethod
def _convert_labeled_point_to_libsvm(p):
"""Converts a LabeledPoint to a string in LIBSVM format."""
+ from pyspark.mllib.regression import LabeledPoint
assert isinstance(p, LabeledPoint)
items = [str(p.label)]
v = _convert_to_vector(p.features)
@@ -92,6 +92,7 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None
>>> from tempfile import NamedTemporaryFile
>>> from pyspark.mllib.util import MLUtils
+ >>> from pyspark.mllib.regression import LabeledPoint
>>> tempFile = NamedTemporaryFile(delete=True)
>>> tempFile.write("+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0")
>>> tempFile.flush()
@@ -110,6 +111,7 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None
>>> print examples[2]
(-1.0,(6,[1,3,5],[4.0,5.0,6.0]))
"""
+ from pyspark.mllib.regression import LabeledPoint
if multiclass is not None:
warnings.warn("deprecated", DeprecationWarning)
@@ -130,6 +132,7 @@ def saveAsLibSVMFile(data, dir):
>>> from tempfile import NamedTemporaryFile
>>> from fileinput import input
+ >>> from pyspark.mllib.regression import LabeledPoint
>>> from glob import glob
>>> from pyspark.mllib.util import MLUtils
>>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, 1.23), (2, 4.56)])), \
@@ -156,6 +159,7 @@ def loadLabeledPoints(sc, path, minPartitions=None):
>>> from tempfile import NamedTemporaryFile
>>> from pyspark.mllib.util import MLUtils
+ >>> from pyspark.mllib.regression import LabeledPoint
>>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, -1.23), (2, 4.56e-7)])), \
LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))]
>>> tempFile = NamedTemporaryFile(delete=True)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index bf17f513c0bc3..c337a43c8a7fc 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -346,6 +346,12 @@ def sample(self, withReplacement, fraction, seed=None):
"""
Return a sampled subset of this RDD.
+ :param withReplacement: can elements be sampled multiple times (replaced when sampled out)
+ :param fraction: expected size of the sample as a fraction of this RDD's size
+ without replacement: probability that each element is chosen; fraction must be [0, 1]
+ with replacement: expected number of times each element is chosen; fraction must be >= 0
+ :param seed: seed for the random number generator
+
>>> rdd = sc.parallelize(range(100), 4)
>>> rdd.sample(False, 0.1, 81).count()
10
diff --git a/repl/pom.xml b/repl/pom.xml
index 295f88ea3ecf9..edfa1c7f2c29c 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.10
- 1.3.0-SNAPSHOT
+ 1.4.0-SNAPSHOT
../pom.xml
diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index fbef5b25ba688..14f5e9ed4f25e 100644
--- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -21,11 +21,9 @@ import java.io._
import java.net.URLClassLoader
import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.Await
import scala.concurrent.duration._
import scala.tools.nsc.interpreter.SparkILoop
-import com.google.common.io.Files
import org.scalatest.FunSuite
import org.apache.commons.lang3.StringEscapeUtils
import org.apache.spark.SparkContext
@@ -196,8 +194,7 @@ class ReplSuite extends FunSuite {
}
test("interacting with files") {
- val tempDir = Files.createTempDir()
- tempDir.deleteOnExit()
+ val tempDir = Utils.createTempDir()
val out = new FileWriter(tempDir + "/input")
out.write("Hello world!\n")
out.write("What's up?\n")
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 8ad026dbdf8ff..3dea2ee76542f 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.10
- 1.3.0-SNAPSHOT
+ 1.4.0-SNAPSHOT
../../pom.xml
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 54ab13ca352d2..ea7d44a3723d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.types._
* This is currently included mostly for illustrative purposes. Users wanting more complete support
* for a SQL like language should checkout the HiveQL support in the sql/hive sub-project.
*/
-class SqlParser extends AbstractSparkSQLParser {
+class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
def parseExpression(input: String): Expression = {
// Initialize the Keywords.
@@ -61,11 +61,8 @@ class SqlParser extends AbstractSparkSQLParser {
protected val CAST = Keyword("CAST")
protected val COALESCE = Keyword("COALESCE")
protected val COUNT = Keyword("COUNT")
- protected val DATE = Keyword("DATE")
- protected val DECIMAL = Keyword("DECIMAL")
protected val DESC = Keyword("DESC")
protected val DISTINCT = Keyword("DISTINCT")
- protected val DOUBLE = Keyword("DOUBLE")
protected val ELSE = Keyword("ELSE")
protected val END = Keyword("END")
protected val EXCEPT = Keyword("EXCEPT")
@@ -78,7 +75,6 @@ class SqlParser extends AbstractSparkSQLParser {
protected val IF = Keyword("IF")
protected val IN = Keyword("IN")
protected val INNER = Keyword("INNER")
- protected val INT = Keyword("INT")
protected val INSERT = Keyword("INSERT")
protected val INTERSECT = Keyword("INTERSECT")
protected val INTO = Keyword("INTO")
@@ -105,13 +101,11 @@ class SqlParser extends AbstractSparkSQLParser {
protected val SELECT = Keyword("SELECT")
protected val SEMI = Keyword("SEMI")
protected val SQRT = Keyword("SQRT")
- protected val STRING = Keyword("STRING")
protected val SUBSTR = Keyword("SUBSTR")
protected val SUBSTRING = Keyword("SUBSTRING")
protected val SUM = Keyword("SUM")
protected val TABLE = Keyword("TABLE")
protected val THEN = Keyword("THEN")
- protected val TIMESTAMP = Keyword("TIMESTAMP")
protected val TRUE = Keyword("TRUE")
protected val UNION = Keyword("UNION")
protected val UPPER = Keyword("UPPER")
@@ -315,7 +309,9 @@ class SqlParser extends AbstractSparkSQLParser {
)
protected lazy val cast: Parser[Expression] =
- CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ { case exp ~ t => Cast(exp, t) }
+ CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ {
+ case exp ~ t => Cast(exp, t)
+ }
protected lazy val literal: Parser[Literal] =
( numericLiteral
@@ -387,19 +383,4 @@ class SqlParser extends AbstractSparkSQLParser {
(ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ {
case i1 ~ i2 ~ rest => UnresolvedAttribute((Seq(i1, i2) ++ rest).mkString("."))
}
-
- protected lazy val dataType: Parser[DataType] =
- ( STRING ^^^ StringType
- | TIMESTAMP ^^^ TimestampType
- | DOUBLE ^^^ DoubleType
- | fixedDecimalType
- | DECIMAL ^^^ DecimalType.Unlimited
- | DATE ^^^ DateType
- | INT ^^^ IntegerType
- )
-
- protected lazy val fixedDecimalType: Parser[DataType] =
- (DECIMAL ~ "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ {
- case precision ~ scale => DecimalType(precision.toInt, scale.toInt)
- }
}
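
With the per-keyword type rules removed and `DataTypeParser` mixed in, `CAST` in the simple SQL parser should accept the same type strings as the new parser. A hedged illustration, assuming an existing `SQLContext` named `sqlContext` using its default `sql` dialect:

{% highlight scala %}
// Illustrative only; the expressions being cast are arbitrary literals.
val df = sqlContext.sql(
  "SELECT CAST('42' AS int), CAST(1 AS decimal(10, 2)), CAST('abc' AS varchar(10))")
df.show()
{% endhighlight %}
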
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
index a9ba0be596349..adaeab0b5c027 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.analysis.Star
protected class AttributeEquals(val a: Attribute) {
override def hashCode() = a match {
@@ -115,7 +114,7 @@ class AttributeSet private (val baseSet: Set[AttributeEquals])
// sorts of things in its closure.
override def toSeq: Seq[Attribute] = baseSet.map(_.a).toArray.toSeq
- override def toString = "{" + baseSet.map(_.a).mkString(", ") + "}"
+ override def toString: String = "{" + baseSet.map(_.a).mkString(", ") + "}"
override def isEmpty: Boolean = baseSet.isEmpty
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 735b7488fdcbd..5297d1e31246c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -346,13 +346,13 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
case DecimalType.Fixed(_, _) =>
val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")()
SplitEvaluation(
- Cast(Sum(partialSum.toAttribute), dataType),
+ Cast(CombineSum(partialSum.toAttribute), dataType),
partialSum :: Nil)
case _ =>
val partialSum = Alias(Sum(child), "PartialSum")()
SplitEvaluation(
- Sum(partialSum.toAttribute),
+ CombineSum(partialSum.toAttribute),
partialSum :: Nil)
}
}
@@ -360,6 +360,30 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
override def newInstance() = new SumFunction(child, this)
}
+/**
+ * Sum should satisfy 3 cases:
+ * 1) sum of all null values = zero
+ * 2) sum for table column with no data = null
+ * 3) sum of column with null and not null values = sum of not null values
+ * A separate CombineSum expression and function are required because they have to distinguish
+ * the "no data" case from the "data equals null" case, both while aggregating results and at
+ * each partial expression, i.e.,
+ * Combining PartitionLevel InputData
+ * <-- null
+ * Zero <-- Zero <-- null
+ *
+ * <-- null <-- no data
+ * null <-- null <-- no data
+ */
+case class CombineSum(child: Expression) extends AggregateExpression {
+ def this() = this(null)
+
+ override def children = child :: Nil
+ override def nullable = true
+ override def dataType = child.dataType
+ override def toString = s"CombineSum($child)"
+ override def newInstance() = new CombineSumFunction(child, this)
+}
+
case class SumDistinct(child: Expression)
extends PartialAggregate with trees.UnaryNode[Expression] {
@@ -565,7 +589,8 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr
private val sum = MutableLiteral(null, calcType)
- private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum))
+ private val addFunction =
+ Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero))
override def update(input: Row): Unit = {
sum.update(addFunction, input)
@@ -580,6 +605,43 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr
}
}
+case class CombineSumFunction(expr: Expression, base: AggregateExpression)
+ extends AggregateFunction {
+
+ def this() = this(null, null) // Required for serialization.
+
+ private val calcType =
+ expr.dataType match {
+ case DecimalType.Fixed(_, _) =>
+ DecimalType.Unlimited
+ case _ =>
+ expr.dataType
+ }
+
+ private val zero = Cast(Literal(0), calcType)
+
+ private val sum = MutableLiteral(null, calcType)
+
+ private val addFunction =
+ Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero))
+
+ override def update(input: Row): Unit = {
+ val result = expr.eval(input)
+ // partial sum result can be null only when no input rows present
+ if(result != null) {
+ sum.update(addFunction, input)
+ }
+ }
+
+ override def eval(input: Row): Any = {
+ expr.dataType match {
+ case DecimalType.Fixed(_, _) =>
+ Cast(sum, dataType).eval(null)
+ case _ => sum.eval(null)
+ }
+ }
+}
+
case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
extends AggregateFunction {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala
index 80c7dfd376c96..528e38a50a740 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.rules
-import org.apache.spark.sql.catalyst.util
+import org.apache.spark.util.Utils
/**
* A collection of generators that build custom bytecode at runtime for performing the evaluation
@@ -52,7 +52,7 @@ package object codegen {
@DeveloperApi
object DumpByteCode {
import scala.sys.process._
- val dumpDirectory = util.getTempFilePath("sparkSqlByteCode")
+ val dumpDirectory = Utils.createTempDir()
dumpDirectory.mkdir()
def apply(obj: Any): Unit = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index faa366771824b..f03d6f71a9fae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -146,6 +146,27 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
result
}
+ override def equals(o: Any): Boolean = o match {
+ case other: Row =>
+ if (values.length != other.length) {
+ return false
+ }
+
+ var i = 0
+ while (i < values.length) {
+ if (isNullAt(i) != other.isNullAt(i)) {
+ return false
+ }
+ if (apply(i) != other.apply(i)) {
+ return false
+ }
+ i += 1
+ }
+ true
+
+ case _ => false
+ }
+
def copy() = this
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 1e7b449d75b80..384fe53a68362 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -289,6 +289,15 @@ case class Distinct(child: LogicalPlan) extends UnaryNode {
case object NoRelation extends LeafNode {
override def output = Nil
+
+ /**
+ * Computes [[Statistics]] for this plan. The default implementation assumes the output
+ * cardinality is the product of all child plans' cardinalities, which applies in the case
+ * of cartesian joins.
+ *
+ * [[LeafNode]]s must override this.
+ */
+ override def statistics: Statistics = Statistics(sizeInBytes = 1)
}
case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
index d8da45ae70c4b..feed50f9a2a2d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
@@ -19,20 +19,9 @@ package org.apache.spark.sql.catalyst
import java.io.{PrintWriter, ByteArrayOutputStream, FileInputStream, File}
-import org.apache.spark.util.{Utils => SparkUtils}
+import org.apache.spark.util.Utils
package object util {
- /**
- * Returns a path to a temporary file that probably does not exist.
- * Note, there is always the race condition that someone created this
- * file since the last time we checked. Thus, this shouldn't be used
- * for anything security conscious.
- */
- def getTempFilePath(prefix: String, suffix: String = ""): File = {
- val tempFile = File.createTempFile(prefix, suffix)
- tempFile.delete()
- tempFile
- }
def fileToString(file: File, encoding: String = "UTF-8") = {
val inStream = new FileInputStream(file)
@@ -56,7 +45,7 @@ package object util {
def resourceToString(
resource:String,
encoding: String = "UTF-8",
- classLoader: ClassLoader = SparkUtils.getSparkClassLoader) = {
+ classLoader: ClassLoader = Utils.getSparkClassLoader) = {
val inStream = classLoader.getResourceAsStream(resource)
val outStream = new ByteArrayOutputStream
try {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala
new file mode 100644
index 0000000000000..89278f7dbc806
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import scala.language.implicitConversions
+import scala.util.matching.Regex
+import scala.util.parsing.combinator.syntactical.StandardTokenParsers
+
+import org.apache.spark.sql.catalyst.SqlLexical
+
+/**
+ * This is a data type parser that can be used to parse string representations of data types
+ * provided in SQL queries. This parser is mixed in with DDLParser and SqlParser.
+ */
+private[sql] trait DataTypeParser extends StandardTokenParsers {
+
+ // This is used to create a parser from a regex. We are using regexes for data type strings
+ // since these strings can also be used as column names or field names.
+ import lexical.Identifier
+ implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch(
+ s"identifier matching regex ${regex}",
+ { case Identifier(str) if regex.unapplySeq(str).isDefined => str }
+ )
+
+ protected lazy val primitiveType: Parser[DataType] =
+ "(?i)string".r ^^^ StringType |
+ "(?i)float".r ^^^ FloatType |
+ "(?i)int".r ^^^ IntegerType |
+ "(?i)tinyint".r ^^^ ByteType |
+ "(?i)smallint".r ^^^ ShortType |
+ "(?i)double".r ^^^ DoubleType |
+ "(?i)bigint".r ^^^ LongType |
+ "(?i)binary".r ^^^ BinaryType |
+ "(?i)boolean".r ^^^ BooleanType |
+ fixedDecimalType |
+ "(?i)decimal".r ^^^ DecimalType.Unlimited |
+ "(?i)date".r ^^^ DateType |
+ "(?i)timestamp".r ^^^ TimestampType |
+ varchar
+
+ protected lazy val fixedDecimalType: Parser[DataType] =
+ ("(?i)decimal".r ~> "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ {
+ case precision ~ scale =>
+ DecimalType(precision.toInt, scale.toInt)
+ }
+
+ protected lazy val varchar: Parser[DataType] =
+ "(?i)varchar".r ~> "(" ~> (numericLit <~ ")") ^^^ StringType
+
+ protected lazy val arrayType: Parser[DataType] =
+ "(?i)array".r ~> "<" ~> dataType <~ ">" ^^ {
+ case tpe => ArrayType(tpe)
+ }
+
+ protected lazy val mapType: Parser[DataType] =
+ "(?i)map".r ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ {
+ case t1 ~ _ ~ t2 => MapType(t1, t2)
+ }
+
+ protected lazy val structField: Parser[StructField] =
+ ident ~ ":" ~ dataType ^^ {
+ case name ~ _ ~ tpe => StructField(name, tpe, nullable = true)
+ }
+
+ protected lazy val structType: Parser[DataType] =
+ ("(?i)struct".r ~> "<" ~> repsep(structField, ",") <~ ">" ^^ {
+ case fields => new StructType(fields.toArray)
+ }) |
+ ("(?i)struct".r ~ "<>" ^^^ StructType(Nil))
+
+ protected lazy val dataType: Parser[DataType] =
+ arrayType |
+ mapType |
+ structType |
+ primitiveType
+
+ def toDataType(dataTypeString: String): DataType = synchronized {
+ phrase(dataType)(new lexical.Scanner(dataTypeString)) match {
+ case Success(result, _) => result
+ case failure: NoSuccess => throw new DataTypeException(failMessage(dataTypeString))
+ }
+ }
+
+ private def failMessage(dataTypeString: String): String = {
+ s"Unsupported dataType: $dataTypeString. If you have a struct and a field name of it has " +
+ "any special characters, please use backticks (`) to quote that field name, e.g. `x+y`. " +
+ "Please note that backtick itself is not supported in a field name."
+ }
+}
+
+private[sql] object DataTypeParser {
+ lazy val dataTypeParser = new DataTypeParser {
+ override val lexical = new SqlLexical
+ }
+
+ def apply(dataTypeString: String): DataType = dataTypeParser.toDataType(dataTypeString)
+}
+
+/** The exception thrown from the [[DataTypeParser]]. */
+protected[sql] class DataTypeException(message: String) extends Exception(message)
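
A short sketch of exercising the new parser. Note it is `private[sql]`, so this only compiles from inside the `org.apache.spark.sql` package; the type strings are illustrative:

{% highlight scala %}
import org.apache.spark.sql.types._

val arrayOfInt = DataTypeParser("array<int>")           // ArrayType(IntegerType, true)
val mapType    = DataTypeParser("map<string, bigint>")  // MapType(StringType, LongType, true)
val structType = DataTypeParser("struct<a:int, b:decimal(10,2)>")
assert(structType == StructType(Seq(
  StructField("a", IntegerType), StructField("b", DecimalType(10, 2)))))
{% endhighlight %}
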
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 21cc6cea4bf54..994c5202c15dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -246,7 +246,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
}
}
- override def equals(other: Any) = other match {
+ override def equals(other: Any): Boolean = other match {
case d: Decimal =>
compare(d) == 0
case _ =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
index bf39603d13bd5..d973144de3468 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
@@ -307,7 +307,7 @@ protected[sql] object NativeType {
protected[sql] trait PrimitiveType extends DataType {
- override def isPrimitive = true
+ override def isPrimitive: Boolean = true
}
@@ -442,7 +442,7 @@ class TimestampType private() extends NativeType {
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
private[sql] val ordering = new Ordering[JvmType] {
- def compare(x: Timestamp, y: Timestamp) = x.compareTo(y)
+ def compare(x: Timestamp, y: Timestamp): Int = x.compareTo(y)
}
/**
@@ -542,7 +542,7 @@ class LongType private() extends IntegralType {
*/
override def defaultSize: Int = 8
- override def simpleString = "bigint"
+ override def simpleString: String = "bigint"
private[spark] override def asNullable: LongType = this
}
@@ -572,7 +572,7 @@ class IntegerType private() extends IntegralType {
*/
override def defaultSize: Int = 4
- override def simpleString = "int"
+ override def simpleString: String = "int"
private[spark] override def asNullable: IntegerType = this
}
@@ -602,7 +602,7 @@ class ShortType private() extends IntegralType {
*/
override def defaultSize: Int = 2
- override def simpleString = "smallint"
+ override def simpleString: String = "smallint"
private[spark] override def asNullable: ShortType = this
}
@@ -632,7 +632,7 @@ class ByteType private() extends IntegralType {
*/
override def defaultSize: Int = 1
- override def simpleString = "tinyint"
+ override def simpleString: String = "tinyint"
private[spark] override def asNullable: ByteType = this
}
@@ -696,7 +696,7 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT
*/
override def defaultSize: Int = 4096
- override def simpleString = precisionInfo match {
+ override def simpleString: String = precisionInfo match {
case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)"
case None => "decimal(10,0)"
}
@@ -836,7 +836,7 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
*/
override def defaultSize: Int = 100 * elementType.defaultSize
- override def simpleString = s"array<${elementType.simpleString}>"
+ override def simpleString: String = s"array<${elementType.simpleString}>"
private[spark] override def asNullable: ArrayType =
ArrayType(elementType.asNullable, containsNull = true)
@@ -1065,7 +1065,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
*/
override def defaultSize: Int = fields.map(_.dataType.defaultSize).sum
- override def simpleString = {
+ override def simpleString: String = {
val fieldTypes = fields.map(field => s"${field.name}:${field.dataType.simpleString}")
s"struct<${fieldTypes.mkString(",")}>"
}
@@ -1142,7 +1142,7 @@ case class MapType(
*/
override def defaultSize: Int = 100 * (keyType.defaultSize + valueType.defaultSize)
- override def simpleString = s"map<${keyType.simpleString},${valueType.simpleString}>"
+ override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>"
private[spark] override def asNullable: MapType =
MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala
new file mode 100644
index 0000000000000..1ba21b64603ac
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala
@@ -0,0 +1,116 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.types
+
+import org.scalatest.FunSuite
+
+class DataTypeParserSuite extends FunSuite {
+
+ def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = {
+ test(s"parse ${dataTypeString.replace("\n", "")}") {
+ assert(DataTypeParser(dataTypeString) === expectedDataType)
+ }
+ }
+
+ def unsupported(dataTypeString: String): Unit = {
+ test(s"$dataTypeString is not supported") {
+ intercept[DataTypeException](DataTypeParser(dataTypeString))
+ }
+ }
+
+ checkDataType("int", IntegerType)
+ checkDataType("BooLean", BooleanType)
+ checkDataType("tinYint", ByteType)
+ checkDataType("smallINT", ShortType)
+ checkDataType("INT", IntegerType)
+ checkDataType("bigint", LongType)
+ checkDataType("float", FloatType)
+ checkDataType("dOUBle", DoubleType)
+ checkDataType("decimal(10, 5)", DecimalType(10, 5))
+ checkDataType("decimal", DecimalType.Unlimited)
+ checkDataType("DATE", DateType)
+ checkDataType("timestamp", TimestampType)
+ checkDataType("string", StringType)
+ checkDataType("varchAr(20)", StringType)
+ checkDataType("BINARY", BinaryType)
+
+ checkDataType("array
", ArrayType(DoubleType, true))
+ checkDataType("Array