diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala
index 937261de00e3a..1bd2d87e00796 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala
@@ -32,10 +32,10 @@ private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages"
attachPage(new StagePage(this))
attachPage(new PoolPage(this))
- def isFairScheduler = listener.schedulingMode.exists(_ == SchedulingMode.FAIR)
+ def isFairScheduler: Boolean = listener.schedulingMode.exists(_ == SchedulingMode.FAIR)
- def handleKillRequest(request: HttpServletRequest) = {
- if ((killEnabled) && (parent.securityManager.checkModifyPermissions(request.getRemoteUser))) {
+ def handleKillRequest(request: HttpServletRequest): Unit = {
+ if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) {
val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean
val stageId = Option(request.getParameter("id")).getOrElse("-1").toInt
if (stageId >= 0 && killFlag && listener.activeStages.contains(stageId)) {
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
index dbf1ceeda1878..711a3697bda15 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
@@ -94,11 +94,11 @@ private[jobs] object UIData {
var taskData = new HashMap[Long, TaskUIData]
var executorSummary = new HashMap[String, ExecutorSummary]
- def hasInput = inputBytes > 0
- def hasOutput = outputBytes > 0
- def hasShuffleRead = shuffleReadTotalBytes > 0
- def hasShuffleWrite = shuffleWriteBytes > 0
- def hasBytesSpilled = memoryBytesSpilled > 0 && diskBytesSpilled > 0
+ def hasInput: Boolean = inputBytes > 0
+ def hasOutput: Boolean = outputBytes > 0
+ def hasShuffleRead: Boolean = shuffleReadTotalBytes > 0
+ def hasShuffleWrite: Boolean = shuffleWriteBytes > 0
+ def hasBytesSpilled: Boolean = memoryBytesSpilled > 0 && diskBytesSpilled > 0
}
/**
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
index a81291d505583..045bd784990d1 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
@@ -40,10 +40,10 @@ private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storag
class StorageListener(storageStatusListener: StorageStatusListener) extends SparkListener {
private[ui] val _rddInfoMap = mutable.Map[Int, RDDInfo]() // exposed for testing
- def storageStatusList = storageStatusListener.storageStatusList
+ def storageStatusList: Seq[StorageStatus] = storageStatusListener.storageStatusList
/** Filter RDD info to include only those with cached partitions */
- def rddInfoList = _rddInfoMap.values.filter(_.numCachedPartitions > 0).toSeq
+ def rddInfoList: Seq[RDDInfo] = _rddInfoMap.values.filter(_.numCachedPartitions > 0).toSeq
/** Update the storage info of the RDDs whose blocks are among the given updated blocks */
private def updateRDDInfo(updatedBlocks: Seq[(BlockId, BlockStatus)]): Unit = {
@@ -56,19 +56,19 @@ class StorageListener(storageStatusListener: StorageStatusListener) extends Spar
* Assumes the storage status list is fully up-to-date. This implies the corresponding
* StorageStatusSparkListener must process the SparkListenerTaskEnd event before this listener.
*/
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized {
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
val metrics = taskEnd.taskMetrics
if (metrics != null && metrics.updatedBlocks.isDefined) {
updateRDDInfo(metrics.updatedBlocks.get)
}
}
- override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = synchronized {
+ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized {
val rddInfos = stageSubmitted.stageInfo.rddInfos
rddInfos.foreach { info => _rddInfoMap.getOrElseUpdate(info.id, info) }
}
- override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized {
+ override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = synchronized {
// Remove all partitions that are no longer cached in current completed stage
val completedRddIds = stageCompleted.stageInfo.rddInfos.map(r => r.id).toSet
_rddInfoMap.retain { case (id, info) =>
@@ -76,7 +76,7 @@ class StorageListener(storageStatusListener: StorageStatusListener) extends Spar
}
}
- override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD) = synchronized {
+ override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit = synchronized {
_rddInfoMap.remove(unpersistRDD.rddId)
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala
index 390310243ee0a..9044aaeef2d48 100644
--- a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala
+++ b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala
@@ -27,8 +27,8 @@ abstract class CompletionIterator[ +A, +I <: Iterator[A]](sub: I) extends Iterat
// scalastyle:on
private[this] var completed = false
- def next() = sub.next()
- def hasNext = {
+ def next(): A = sub.next()
+ def hasNext: Boolean = {
val r = sub.hasNext
if (!r && !completed) {
completed = true
@@ -37,13 +37,13 @@ abstract class CompletionIterator[ +A, +I <: Iterator[A]](sub: I) extends Iterat
r
}
- def completion()
+ def completion(): Unit
}
private[spark] object CompletionIterator {
- def apply[A, I <: Iterator[A]](sub: I, completionFunction: => Unit) : CompletionIterator[A,I] = {
+ def apply[A, I <: Iterator[A]](sub: I, completionFunction: => Unit) : CompletionIterator[A, I] = {
new CompletionIterator[A,I](sub) {
- def completion() = completionFunction
+ def completion(): Unit = completionFunction
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/Distribution.scala b/core/src/main/scala/org/apache/spark/util/Distribution.scala
index a465298c8c5ab..9aea8efa38c7a 100644
--- a/core/src/main/scala/org/apache/spark/util/Distribution.scala
+++ b/core/src/main/scala/org/apache/spark/util/Distribution.scala
@@ -57,7 +57,7 @@ private[spark] class Distribution(val data: Array[Double], val startIdx: Int, va
out.println
}
- def statCounter = StatCounter(data.slice(startIdx, endIdx))
+ def statCounter: StatCounter = StatCounter(data.slice(startIdx, endIdx))
/**
* print a summary of this distribution to the given PrintStream.
diff --git a/core/src/main/scala/org/apache/spark/util/ManualClock.scala b/core/src/main/scala/org/apache/spark/util/ManualClock.scala
index cf89c1782fd67..1718554061985 100644
--- a/core/src/main/scala/org/apache/spark/util/ManualClock.scala
+++ b/core/src/main/scala/org/apache/spark/util/ManualClock.scala
@@ -39,31 +39,27 @@ private[spark] class ManualClock(private var time: Long) extends Clock {
/**
* @param timeToSet new time (in milliseconds) that the clock should represent
*/
- def setTime(timeToSet: Long) =
- synchronized {
- time = timeToSet
- notifyAll()
- }
+ def setTime(timeToSet: Long): Unit = synchronized {
+ time = timeToSet
+ notifyAll()
+ }
/**
* @param timeToAdd time (in milliseconds) to add to the clock's time
*/
- def advance(timeToAdd: Long) =
- synchronized {
- time += timeToAdd
- notifyAll()
- }
+ def advance(timeToAdd: Long): Unit = synchronized {
+ time += timeToAdd
+ notifyAll()
+ }
/**
* @param targetTime block until the clock time is set or advanced to at least this time
* @return current time reported by the clock when waiting finishes
*/
- def waitTillTime(targetTime: Long): Long =
- synchronized {
- while (time < targetTime) {
- wait(100)
- }
- getTimeMillis()
+ def waitTillTime(targetTime: Long): Long = synchronized {
+ while (time < targetTime) {
+ wait(100)
}
-
+ getTimeMillis()
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
index ac40f19ed6799..375ed430bde45 100644
--- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
@@ -67,14 +67,15 @@ private[spark] object MetadataCleanerType extends Enumeration {
type MetadataCleanerType = Value
- def systemProperty(which: MetadataCleanerType.MetadataCleanerType) =
- "spark.cleaner.ttl." + which.toString
+ def systemProperty(which: MetadataCleanerType.MetadataCleanerType): String = {
+ "spark.cleaner.ttl." + which.toString
+ }
}
// TODO: This mutates a Conf to set properties right now, which is kind of ugly when used in the
// initialization of StreamingContext. It's okay for users trying to configure stuff themselves.
private[spark] object MetadataCleaner {
- def getDelaySeconds(conf: SparkConf) = {
+ def getDelaySeconds(conf: SparkConf): Int = {
conf.getInt("spark.cleaner.ttl", -1)
}
diff --git a/core/src/main/scala/org/apache/spark/util/MutablePair.scala b/core/src/main/scala/org/apache/spark/util/MutablePair.scala
index 74fa77b68de0b..dad888548ed10 100644
--- a/core/src/main/scala/org/apache/spark/util/MutablePair.scala
+++ b/core/src/main/scala/org/apache/spark/util/MutablePair.scala
@@ -43,7 +43,7 @@ case class MutablePair[@specialized(Int, Long, Double, Char, Boolean/* , AnyRef
this
}
- override def toString = "(" + _1 + "," + _2 + ")"
+ override def toString: String = "(" + _1 + "," + _2 + ")"
override def canEqual(that: Any): Boolean = that.isInstanceOf[MutablePair[_,_]]
}
diff --git a/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala b/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala
index 6d8d9e8da3678..73d126ff6254e 100644
--- a/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala
+++ b/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala
@@ -22,7 +22,7 @@ package org.apache.spark.util
*/
private[spark] class ParentClassLoader(parent: ClassLoader) extends ClassLoader(parent) {
- override def findClass(name: String) = {
+ override def findClass(name: String): Class[_] = {
super.findClass(name)
}
diff --git a/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala b/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala
index 770ff9d5ad6ae..a06b6f84ef11b 100644
--- a/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala
@@ -27,7 +27,7 @@ import java.nio.channels.Channels
*/
private[spark]
class SerializableBuffer(@transient var buffer: ByteBuffer) extends Serializable {
- def value = buffer
+ def value: ByteBuffer = buffer
private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
val length = in.readInt()
diff --git a/core/src/main/scala/org/apache/spark/util/StatCounter.scala b/core/src/main/scala/org/apache/spark/util/StatCounter.scala
index d80eed455c427..8586da1996cf3 100644
--- a/core/src/main/scala/org/apache/spark/util/StatCounter.scala
+++ b/core/src/main/scala/org/apache/spark/util/StatCounter.scala
@@ -141,8 +141,8 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
object StatCounter {
/** Build a StatCounter from a list of values. */
- def apply(values: TraversableOnce[Double]) = new StatCounter(values)
+ def apply(values: TraversableOnce[Double]): StatCounter = new StatCounter(values)
/** Build a StatCounter from a list of values passed as variable-length arguments. */
- def apply(values: Double*) = new StatCounter(values)
+ def apply(values: Double*): StatCounter = new StatCounter(values)
}
diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala
index f5be5856c2109..310c0c109416c 100644
--- a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala
@@ -82,7 +82,7 @@ private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boo
this
}
- override def update(key: A, value: B) = this += ((key, value))
+ override def update(key: A, value: B): Unit = this += ((key, value))
override def apply(key: A): B = internalMap.apply(key)
@@ -92,14 +92,14 @@ private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boo
override def size: Int = internalMap.size
- override def foreach[U](f: ((A, B)) => U) = nonNullReferenceMap.foreach(f)
+ override def foreach[U](f: ((A, B)) => U): Unit = nonNullReferenceMap.foreach(f)
def putIfAbsent(key: A, value: B): Option[B] = internalMap.putIfAbsent(key, value)
def toMap: Map[A, B] = iterator.toMap
/** Remove old key-value pairs with timestamps earlier than `threshTime`. */
- def clearOldValues(threshTime: Long) = internalMap.clearOldValues(threshTime)
+ def clearOldValues(threshTime: Long): Unit = internalMap.clearOldValues(threshTime)
/** Remove entries with values that are no longer strongly reachable. */
def clearNullValues() {
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 91aa70870ab20..d9a671687aad0 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -85,7 +85,7 @@ private[spark] object Utils extends Logging {
def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
val bis = new ByteArrayInputStream(bytes)
val ois = new ObjectInputStream(bis) {
- override def resolveClass(desc: ObjectStreamClass) =
+ override def resolveClass(desc: ObjectStreamClass): Class[_] =
Class.forName(desc.getName, false, loader)
}
ois.readObject.asInstanceOf[T]
@@ -106,11 +106,10 @@ private[spark] object Utils extends Logging {
/** Serialize via nested stream using specific serializer */
def serializeViaNestedStream(os: OutputStream, ser: SerializerInstance)(
- f: SerializationStream => Unit) = {
+ f: SerializationStream => Unit): Unit = {
val osWrapper = ser.serializeStream(new OutputStream {
- def write(b: Int) = os.write(b)
-
- override def write(b: Array[Byte], off: Int, len: Int) = os.write(b, off, len)
+ override def write(b: Int): Unit = os.write(b)
+ override def write(b: Array[Byte], off: Int, len: Int): Unit = os.write(b, off, len)
})
try {
f(osWrapper)
@@ -121,10 +120,9 @@ private[spark] object Utils extends Logging {
/** Deserialize via nested stream using specific serializer */
def deserializeViaNestedStream(is: InputStream, ser: SerializerInstance)(
- f: DeserializationStream => Unit) = {
+ f: DeserializationStream => Unit): Unit = {
val isWrapper = ser.deserializeStream(new InputStream {
- def read(): Int = is.read()
-
+ override def read(): Int = is.read()
override def read(b: Array[Byte], off: Int, len: Int): Int = is.read(b, off, len)
})
try {
@@ -137,7 +135,7 @@ private[spark] object Utils extends Logging {
/**
* Get the ClassLoader which loaded Spark.
*/
- def getSparkClassLoader = getClass.getClassLoader
+ def getSparkClassLoader: ClassLoader = getClass.getClassLoader
/**
* Get the Context ClassLoader on this thread or, if not present, the ClassLoader that
@@ -146,7 +144,7 @@ private[spark] object Utils extends Logging {
* This should be used whenever passing a ClassLoader to Class.ForName or finding the currently
* active loader when setting up ClassLoader delegation chains.
*/
- def getContextOrSparkClassLoader =
+ def getContextOrSparkClassLoader: ClassLoader =
Option(Thread.currentThread().getContextClassLoader).getOrElse(getSparkClassLoader)
/** Determines whether the provided class is loadable in the current thread. */
@@ -155,12 +153,14 @@ private[spark] object Utils extends Logging {
}
/** Preferred alternative to Class.forName(className) */
- def classForName(className: String) = Class.forName(className, true, getContextOrSparkClassLoader)
+ def classForName(className: String): Class[_] = {
+ Class.forName(className, true, getContextOrSparkClassLoader)
+ }
/**
* Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]]
*/
- def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput) = {
+ def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput): Unit = {
if (bb.hasArray) {
out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
} else {
@@ -288,7 +288,7 @@ private[spark] object Utils extends Logging {
} catch { case e: SecurityException => dir = null; }
}
- dir
+ dir.getCanonicalFile
}
/**
@@ -1557,7 +1557,7 @@ private[spark] object Utils extends Logging {
/** Return the class name of the given object, removing all dollar signs */
- def getFormattedClassName(obj: AnyRef) = {
+ def getFormattedClassName(obj: AnyRef): String = {
obj.getClass.getSimpleName.replace("$", "")
}
@@ -1570,7 +1570,7 @@ private[spark] object Utils extends Logging {
}
/** Return an empty JSON object */
- def emptyJson = JObject(List[JField]())
+ def emptyJson: JsonAST.JObject = JObject(List[JField]())
/**
* Return a Hadoop FileSystem with the scheme encoded in the given path.
@@ -1618,7 +1618,7 @@ private[spark] object Utils extends Logging {
/**
* Indicates whether Spark is currently running unit tests.
*/
- def isTesting = {
+ def isTesting: Boolean = {
sys.env.contains("SPARK_TESTING") || sys.props.contains("spark.testing")
}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
index af1f64649f354..f79e8e0491ea1 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
@@ -156,10 +156,10 @@ class BitSet(numBits: Int) extends Serializable {
/**
* Get an iterator over the set bits.
*/
- def iterator = new Iterator[Int] {
+ def iterator: Iterator[Int] = new Iterator[Int] {
var ind = nextSetBit(0)
override def hasNext: Boolean = ind >= 0
- override def next() = {
+ override def next(): Int = {
val tmp = ind
ind = nextSetBit(ind + 1)
tmp
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 8a0f5a602de12..9ff4744593d4d 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -159,7 +159,7 @@ class ExternalAppendOnlyMap[K, V, C](
val batchSizes = new ArrayBuffer[Long]
// Flush the disk writer's contents to disk, and update relevant variables
- def flush() = {
+ def flush(): Unit = {
val w = writer
writer = null
w.commitAndClose()
@@ -355,7 +355,7 @@ class ExternalAppendOnlyMap[K, V, C](
val pairs: ArrayBuffer[(K, C)])
extends Comparable[StreamBuffer] {
- def isEmpty = pairs.length == 0
+ def isEmpty: Boolean = pairs.length == 0
// Invalid if there are no more pairs in this stream
def minKeyHash: Int = {
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index d69f2d9048055..3262e670c2030 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -283,7 +283,7 @@ private[spark] class ExternalSorter[K, V, C](
// Flush the disk writer's contents to disk, and update relevant variables.
// The writer is closed at the end of this process, and cannot be reused.
- def flush() = {
+ def flush(): Unit = {
val w = writer
writer = null
w.commitAndClose()
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
index b8de4ff9aa494..c52591b352340 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
@@ -109,7 +109,7 @@ class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag](
}
}
- override def iterator = new Iterator[(K, V)] {
+ override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] {
var pos = -1
var nextPair: (K, V) = computeNextPair()
@@ -132,9 +132,9 @@ class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag](
}
}
- def hasNext = nextPair != null
+ def hasNext: Boolean = nextPair != null
- def next() = {
+ def next(): (K, V) = {
val pair = nextPair
nextPair = computeNextPair()
pair
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
index 4e363b74f4bef..c80057f95e0b2 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
@@ -85,7 +85,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
protected var _bitset = new BitSet(_capacity)
- def getBitSet = _bitset
+ def getBitSet: BitSet = _bitset
// Init of the array in constructor (instead of in declaration) to work around a Scala compiler
// specialization bug that would generate two arrays (one for Object and one for specialized T).
@@ -183,7 +183,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
/** Return the value at the specified position. */
def getValue(pos: Int): T = _data(pos)
- def iterator = new Iterator[T] {
+ def iterator: Iterator[T] = new Iterator[T] {
var pos = nextPos(0)
override def hasNext: Boolean = pos != INVALID_POS
override def next(): T = {
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
index 2e1ef06cbc4e1..61e22642761f0 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
@@ -46,7 +46,7 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag,
private var _oldValues: Array[V] = null
- override def size = _keySet.size
+ override def size: Int = _keySet.size
/** Get the value for a given key */
def apply(k: K): V = {
@@ -87,7 +87,7 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag,
}
}
- override def iterator = new Iterator[(K, V)] {
+ override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] {
var pos = 0
var nextPair: (K, V) = computeNextPair()
@@ -103,9 +103,9 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag,
}
}
- def hasNext = nextPair != null
+ def hasNext: Boolean = nextPair != null
- def next() = {
+ def next(): (K, V) = {
val pair = nextPair
nextPair = computeNextPair()
pair
diff --git a/core/src/main/scala/org/apache/spark/util/collection/Utils.scala b/core/src/main/scala/org/apache/spark/util/collection/Utils.scala
index c5268c0fae0ef..bdbca00a00622 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/Utils.scala
@@ -32,7 +32,7 @@ private[spark] object Utils {
*/
def takeOrdered[T](input: Iterator[T], num: Int)(implicit ord: Ordering[T]): Iterator[T] = {
val ordering = new GuavaOrdering[T] {
- override def compare(l: T, r: T) = ord.compare(l, r)
+ override def compare(l: T, r: T): Int = ord.compare(l, r)
}
collectionAsScalaIterable(ordering.leastOf(asJavaIterator(input), num)).iterator
}
diff --git a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala
index 1d5467060623c..14b6ba4af489a 100644
--- a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala
+++ b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala
@@ -121,7 +121,7 @@ private[spark] object FileAppender extends Logging {
val rollingSizeBytes = conf.get(SIZE_PROPERTY, STRATEGY_DEFAULT)
val rollingInterval = conf.get(INTERVAL_PROPERTY, INTERVAL_DEFAULT)
- def createTimeBasedAppender() = {
+ def createTimeBasedAppender(): FileAppender = {
val validatedParams: Option[(Long, String)] = rollingInterval match {
case "daily" =>
logInfo(s"Rolling executor logs enabled for $file with daily rolling")
@@ -149,7 +149,7 @@ private[spark] object FileAppender extends Logging {
}
}
- def createSizeBasedAppender() = {
+ def createSizeBasedAppender(): FileAppender = {
rollingSizeBytes match {
case IntParam(bytes) =>
logInfo(s"Rolling executor logs enabled for $file with rolling every $bytes bytes")
diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
index 76e7a2760bcd1..786b97ad7b9ec 100644
--- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
@@ -105,7 +105,7 @@ class BernoulliCellSampler[T](lb: Double, ub: Double, complement: Boolean = fals
private val rng: Random = new XORShiftRandom
- override def setSeed(seed: Long) = rng.setSeed(seed)
+ override def setSeed(seed: Long): Unit = rng.setSeed(seed)
override def sample(items: Iterator[T]): Iterator[T] = {
if (ub - lb <= 0.0) {
@@ -131,7 +131,7 @@ class BernoulliCellSampler[T](lb: Double, ub: Double, complement: Boolean = fals
def cloneComplement(): BernoulliCellSampler[T] =
new BernoulliCellSampler[T](lb, ub, !complement)
- override def clone = new BernoulliCellSampler[T](lb, ub, complement)
+ override def clone: BernoulliCellSampler[T] = new BernoulliCellSampler[T](lb, ub, complement)
}
@@ -153,7 +153,7 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T
private val rng: Random = RandomSampler.newDefaultRNG
- override def setSeed(seed: Long) = rng.setSeed(seed)
+ override def setSeed(seed: Long): Unit = rng.setSeed(seed)
override def sample(items: Iterator[T]): Iterator[T] = {
if (fraction <= 0.0) {
@@ -167,7 +167,7 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T
}
}
- override def clone = new BernoulliSampler[T](fraction)
+ override def clone: BernoulliSampler[T] = new BernoulliSampler[T](fraction)
}
@@ -209,7 +209,7 @@ class PoissonSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T]
}
}
- override def clone = new PoissonSampler[T](fraction)
+ override def clone: PoissonSampler[T] = new PoissonSampler[T](fraction)
}
@@ -228,15 +228,18 @@ class GapSamplingIterator[T: ClassTag](
val arrayClass = Array.empty[T].iterator.getClass
val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass
data.getClass match {
- case `arrayClass` => ((n: Int) => { data = data.drop(n) })
- case `arrayBufferClass` => ((n: Int) => { data = data.drop(n) })
- case _ => ((n: Int) => {
+ case `arrayClass` =>
+ (n: Int) => { data = data.drop(n) }
+ case `arrayBufferClass` =>
+ (n: Int) => { data = data.drop(n) }
+ case _ =>
+ (n: Int) => {
var j = 0
while (j < n && data.hasNext) {
data.next()
j += 1
}
- })
+ }
}
}
@@ -244,21 +247,21 @@ class GapSamplingIterator[T: ClassTag](
override def next(): T = {
val r = data.next()
- advance
+ advance()
r
}
private val lnq = math.log1p(-f)
/** skip elements that won't be sampled, according to geometric dist P(k) = (f)(1-f)^k. */
- private def advance: Unit = {
+ private def advance(): Unit = {
val u = math.max(rng.nextDouble(), epsilon)
val k = (math.log(u) / lnq).toInt
iterDrop(k)
}
/** advance to first sample as part of object construction. */
- advance
+ advance()
// Attempting to invoke this closer to the top with other object initialization
// was causing it to break in strange ways, so I'm invoking it last, which seems to
// work reliably.
@@ -279,15 +282,18 @@ class GapSamplingReplacementIterator[T: ClassTag](
val arrayClass = Array.empty[T].iterator.getClass
val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass
data.getClass match {
- case `arrayClass` => ((n: Int) => { data = data.drop(n) })
- case `arrayBufferClass` => ((n: Int) => { data = data.drop(n) })
- case _ => ((n: Int) => {
+ case `arrayClass` =>
+ (n: Int) => { data = data.drop(n) }
+ case `arrayBufferClass` =>
+ (n: Int) => { data = data.drop(n) }
+ case _ =>
+ (n: Int) => {
var j = 0
while (j < n && data.hasNext) {
data.next()
j += 1
}
- })
+ }
}
}
@@ -300,7 +306,7 @@ class GapSamplingReplacementIterator[T: ClassTag](
override def next(): T = {
val r = v
rep -= 1
- if (rep <= 0) advance
+ if (rep <= 0) advance()
r
}
@@ -309,7 +315,7 @@ class GapSamplingReplacementIterator[T: ClassTag](
* Samples 'k' from geometric distribution P(k) = (1-q)(q)^k, where q = e^(-f), that is
* q is the probability of Poisson(0; f)
*/
- private def advance: Unit = {
+ private def advance(): Unit = {
val u = math.max(rng.nextDouble(), epsilon)
val k = (math.log(u) / (-f)).toInt
iterDrop(k)
@@ -343,7 +349,7 @@ class GapSamplingReplacementIterator[T: ClassTag](
}
/** advance to first sample as part of object construction. */
- advance
+ advance()
// Attempting to invoke this closer to the top with other object initialization
// was causing it to break in strange ways, so I'm invoking it last, which seems to
// work reliably.
diff --git a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala
index 2ae308dacf1ae..9e29bf9d61f17 100644
--- a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala
@@ -311,7 +311,7 @@ private[random] class AcceptanceResult(var numItems: Long = 0L, var numAccepted:
var acceptBound: Double = Double.NaN // upper bound for accepting item instantly
var waitListBound: Double = Double.NaN // upper bound for adding item to waitlist
- def areBoundsEmpty = acceptBound.isNaN || waitListBound.isNaN
+ def areBoundsEmpty: Boolean = acceptBound.isNaN || waitListBound.isNaN
def merge(other: Option[AcceptanceResult]): Unit = {
if (other.isDefined) {
diff --git a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
index 467b890fb4bb9..c4a7b4441c85c 100644
--- a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
@@ -83,7 +83,7 @@ private[spark] object XORShiftRandom {
* @return Map of execution times for {@link java.util.Random java.util.Random}
* and XORShift
*/
- def benchmark(numIters: Int) = {
+ def benchmark(numIters: Int): Map[String, Long] = {
val seed = 1L
val million = 1e6.toInt
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 8ec54360ca42a..d4b5bb519157c 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -24,11 +24,12 @@
import java.util.*;
import java.util.concurrent.*;
-import org.apache.spark.input.PortableDataStream;
+import scala.collection.JavaConversions;
import scala.Tuple2;
import scala.Tuple3;
import scala.Tuple4;
+import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
@@ -51,8 +52,11 @@
import org.apache.spark.api.java.*;
import org.apache.spark.api.java.function.*;
import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.input.PortableDataStream;
import org.apache.spark.partial.BoundedDouble;
import org.apache.spark.partial.PartialResult;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.serializer.KryoSerializer;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.util.StatCounter;
@@ -726,8 +730,8 @@ public void javaDoubleRDDHistoGram() {
Tuple2<double[], long[]> results = rdd.histogram(2);
double[] expected_buckets = {1.0, 2.5, 4.0};
long[] expected_counts = {2, 2};
- Assert.assertArrayEquals(expected_buckets, results._1, 0.1);
- Assert.assertArrayEquals(expected_counts, results._2);
+ Assert.assertArrayEquals(expected_buckets, results._1(), 0.1);
+ Assert.assertArrayEquals(expected_counts, results._2());
// Test with provided buckets
long[] histogram = rdd.histogram(expected_buckets);
Assert.assertArrayEquals(expected_counts, histogram);
@@ -1424,6 +1428,49 @@ public void checkpointAndRestore() {
Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), recovered.collect());
}
+ @Test
+ public void combineByKey() {
+ JavaRDD<Integer> originalRDD = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6));
+ Function<Integer, Integer> keyFunction = new Function<Integer, Integer>() {
+ @Override
+ public Integer call(Integer v1) throws Exception {
+ return v1 % 3;
+ }
+ };
+ Function<Integer, Integer> createCombinerFunction = new Function<Integer, Integer>() {
+ @Override
+ public Integer call(Integer v1) throws Exception {
+ return v1;
+ }
+ };
+
+ Function2<Integer, Integer, Integer> mergeValueFunction = new Function2<Integer, Integer, Integer>() {
+ @Override
+ public Integer call(Integer v1, Integer v2) throws Exception {
+ return v1 + v2;
+ }
+ };
+
+ JavaPairRDD<Integer, Integer> combinedRDD = originalRDD.keyBy(keyFunction)
+ .combineByKey(createCombinerFunction, mergeValueFunction, mergeValueFunction);
+ Map<Integer, Integer> results = combinedRDD.collectAsMap();
+ ImmutableMap<Integer, Integer> expected = ImmutableMap.of(0, 9, 1, 5, 2, 7);
+ Assert.assertEquals(expected, results);
+
+ Partitioner defaultPartitioner = Partitioner.defaultPartitioner(
+ combinedRDD.rdd(), JavaConversions.asScalaBuffer(Lists.<RDD<?>>newArrayList()));
+ combinedRDD = originalRDD.keyBy(keyFunction)
+ .combineByKey(
+ createCombinerFunction,
+ mergeValueFunction,
+ mergeValueFunction,
+ defaultPartitioner,
+ false,
+ new KryoSerializer(new SparkConf()));
+ results = combinedRDD.collectAsMap();
+ Assert.assertEquals(expected, results);
+ }
+
@SuppressWarnings("unchecked")
@Test
public void mapOnPairRDD() {
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index 3b10b3a042317..32abc65385267 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -33,8 +33,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
override def beforeEach() {
super.beforeEach()
- checkpointDir = File.createTempFile("temp", "")
- checkpointDir.deleteOnExit()
+ checkpointDir = File.createTempFile("temp", "", Utils.createTempDir())
checkpointDir.delete()
sc = new SparkContext("local", "test")
sc.setCheckpointDir(checkpointDir.toString)
diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala
index 43fbd3ff3f756..62cb7649c0284 100644
--- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala
@@ -21,6 +21,8 @@ import java.io.File
import org.scalatest.FunSuite
+import org.apache.spark.util.Utils
+
class SecurityManagerSuite extends FunSuite {
test("set security with conf") {
@@ -160,8 +162,7 @@ class SecurityManagerSuite extends FunSuite {
}
test("ssl off setup") {
- val file = File.createTempFile("SSLOptionsSuite", "conf")
- file.deleteOnExit()
+ val file = File.createTempFile("SSLOptionsSuite", "conf", Utils.createTempDir())
System.setProperty("spark.ssl.configFile", file.getAbsolutePath)
val conf = new SparkConf()
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
index b8e3e83b5a47b..b07c4d93db4e6 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
@@ -79,13 +79,14 @@ class SparkContextSuite extends FunSuite with LocalSparkContext {
val byteArray2 = converter.convert(bytesWritable)
assert(byteArray2.length === 0)
}
-
+
test("addFile works") {
- val file1 = File.createTempFile("someprefix1", "somesuffix1")
+ val dir = Utils.createTempDir()
+
+ val file1 = File.createTempFile("someprefix1", "somesuffix1", dir)
val absolutePath1 = file1.getAbsolutePath
- val pluto = Utils.createTempDir()
- val file2 = File.createTempFile("someprefix2", "somesuffix2", pluto)
+ val file2 = File.createTempFile("someprefix2", "somesuffix2", dir)
val relativePath = file2.getParent + "/../" + file2.getParentFile.getName + "/" + file2.getName
val absolutePath2 = file2.getAbsolutePath
@@ -129,7 +130,7 @@ class SparkContextSuite extends FunSuite with LocalSparkContext {
sc.stop()
}
}
-
+
test("addFile recursive works") {
val pluto = Utils.createTempDir()
val neptune = Utils.createTempDir(pluto.getAbsolutePath)
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index 46d745c4ecbfa..4561e5b8e9663 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -402,8 +402,10 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties
val archives = "file:/archive1,archive2" // spark.yarn.dist.archives
val pyFiles = "py-file1,py-file2" // spark.submit.pyFiles
+ val tmpDir = Utils.createTempDir()
+
// Test jars and files
- val f1 = File.createTempFile("test-submit-jars-files", "")
+ val f1 = File.createTempFile("test-submit-jars-files", "", tmpDir)
val writer1 = new PrintWriter(f1)
writer1.println("spark.jars " + jars)
writer1.println("spark.files " + files)
@@ -420,7 +422,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties
sysProps("spark.files") should be(Utils.resolveURIs(files))
// Test files and archives (Yarn)
- val f2 = File.createTempFile("test-submit-files-archives", "")
+ val f2 = File.createTempFile("test-submit-files-archives", "", tmpDir)
val writer2 = new PrintWriter(f2)
writer2.println("spark.yarn.dist.files " + files)
writer2.println("spark.yarn.dist.archives " + archives)
@@ -437,7 +439,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties
sysProps2("spark.yarn.dist.archives") should be(Utils.resolveURIs(archives))
// Test python files
- val f3 = File.createTempFile("test-submit-python-files", "")
+ val f3 = File.createTempFile("test-submit-python-files", "", tmpDir)
val writer3 = new PrintWriter(f3)
writer3.println("spark.submit.pyFiles " + pyFiles)
writer3.close()
diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
index 1a9a0e857e546..aea76c1adcc09 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
@@ -22,7 +22,6 @@ import java.io.File
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapred.{FileSplit, JobConf, TextInputFormat}
-import org.apache.spark._
import org.scalatest.FunSuite
import scala.collection.Map
@@ -30,6 +29,9 @@ import scala.language.postfixOps
import scala.sys.process._
import scala.util.Try
+import org.apache.spark._
+import org.apache.spark.util.Utils
+
class PipedRDDSuite extends FunSuite with SharedSparkContext {
test("basic pipe") {
@@ -141,7 +143,7 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext {
// make sure symlinks were created
assert(pipedLs.length > 0)
// clean up top level tasks directory
- new File("tasks").delete()
+ Utils.deleteRecursively(new File("tasks"))
} else {
assert(true)
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala
new file mode 100644
index 0000000000000..3fa0115e68259
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtilsSuite.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster.mesos
+
+import org.mockito.Mockito._
+import org.scalatest.FunSuite
+import org.scalatest.mock.MockitoSugar
+
+import org.apache.spark.{SparkConf, SparkContext}
+
+class MemoryUtilsSuite extends FunSuite with MockitoSugar {
+ test("MesosMemoryUtils should always override memoryOverhead when it's set") {
+ val sparkConf = new SparkConf
+
+ val sc = mock[SparkContext]
+ when(sc.conf).thenReturn(sparkConf)
+
+ // 384 > sc.executorMemory * 0.1 => 512 + 384 = 896
+ when(sc.executorMemory).thenReturn(512)
+ assert(MemoryUtils.calculateTotalMemory(sc) === 896)
+
+ // 384 < sc.executorMemory * 0.1 => 4096 + (4096 * 0.1) = 4505.6
+ when(sc.executorMemory).thenReturn(4096)
+ assert(MemoryUtils.calculateTotalMemory(sc) === 4505)
+
+ // set memoryOverhead
+ sparkConf.set("spark.mesos.executor.memoryOverhead", "100")
+ assert(MemoryUtils.calculateTotalMemory(sc) === 4196)
+ sparkConf.set("spark.mesos.executor.memoryOverhead", "400")
+ assert(MemoryUtils.calculateTotalMemory(sc) === 4496)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala
similarity index 98%
rename from core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala
rename to core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala
index afbaa9ade811f..f1a4380d349b3 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.mesos
+package org.apache.spark.scheduler.cluster.mesos
import java.nio.ByteBuffer
import java.util
@@ -24,21 +24,20 @@ import java.util.Collections
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import org.apache.mesos.SchedulerDriver
-import org.apache.mesos.Protos._
import org.apache.mesos.Protos.Value.Scalar
-import org.mockito.Mockito._
+import org.apache.mesos.Protos._
+import org.apache.mesos.SchedulerDriver
import org.mockito.Matchers._
+import org.mockito.Mockito._
import org.mockito.{ArgumentCaptor, Matchers}
import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext}
import org.apache.spark.executor.MesosExecutorBackend
+import org.apache.spark.scheduler.cluster.ExecutorInfo
import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerExecutorAdded,
TaskDescription, TaskSchedulerImpl, WorkerOffer}
-import org.apache.spark.scheduler.cluster.ExecutorInfo
-import org.apache.spark.scheduler.cluster.mesos.{MesosSchedulerBackend, MemoryUtils}
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext}
class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with MockitoSugar {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosTaskLaunchDataSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchDataSuite.scala
similarity index 92%
rename from core/src/test/scala/org/apache/spark/scheduler/mesos/MesosTaskLaunchDataSuite.scala
rename to core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchDataSuite.scala
index 86a42a7398e4d..eebcba40f8a1c 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosTaskLaunchDataSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchDataSuite.scala
@@ -15,14 +15,12 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.mesos
+package org.apache.spark.scheduler.cluster.mesos
import java.nio.ByteBuffer
import org.scalatest.FunSuite
-import org.apache.spark.scheduler.cluster.mesos.MesosTaskLaunchData
-
class MesosTaskLaunchDataSuite extends FunSuite {
test("serialize and deserialize data must be same") {
val serializedTask = ByteBuffer.allocate(40)
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
index c21c92b63ad13..78bbc4ec2c620 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
@@ -16,16 +16,18 @@
*/
package org.apache.spark.storage
-import org.scalatest.FunSuite
import java.io.File
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkConf
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.serializer.JavaSerializer
-import org.apache.spark.SparkConf
+import org.apache.spark.util.Utils
class BlockObjectWriterSuite extends FunSuite {
test("verify write metrics") {
- val file = new File("somefile")
- file.deleteOnExit()
+ val file = new File(Utils.createTempDir(), "somefile")
val writeMetrics = new ShuffleWriteMetrics()
val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics)
@@ -47,8 +49,7 @@ class BlockObjectWriterSuite extends FunSuite {
}
test("verify write metrics on revert") {
- val file = new File("somefile")
- file.deleteOnExit()
+ val file = new File(Utils.createTempDir(), "somefile")
val writeMetrics = new ShuffleWriteMetrics()
val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics)
@@ -71,8 +72,7 @@ class BlockObjectWriterSuite extends FunSuite {
}
test("Reopening a closed block writer") {
- val file = new File("somefile")
- file.deleteOnExit()
+ val file = new File(Utils.createTempDir(), "somefile")
val writeMetrics = new ShuffleWriteMetrics()
val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics)
diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala
index 4dc5b6103db74..43b6a405cb68c 100644
--- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.util.logging.{RollingFileAppender, SizeBasedRollingPolic
class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging {
- val testFile = new File("FileAppenderSuite-test-" + System.currentTimeMillis).getAbsoluteFile
+ val testFile = new File(Utils.createTempDir(), "FileAppenderSuite-test").getAbsoluteFile
before {
cleanup()
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index b91428efadfd0..5d93086082189 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -122,7 +122,6 @@ class UtilsSuite extends FunSuite with ResetSystemProperties {
test("reading offset bytes of a file") {
val tmpDir2 = Utils.createTempDir()
- tmpDir2.deleteOnExit()
val f1Path = tmpDir2 + "/f1"
val f1 = new FileOutputStream(f1Path)
f1.write("1\n2\n3\n4\n5\n6\n7\n8\n9\n".getBytes(UTF_8))
@@ -151,7 +150,6 @@ class UtilsSuite extends FunSuite with ResetSystemProperties {
test("reading offset bytes across multiple files") {
val tmpDir = Utils.createTempDir()
- tmpDir.deleteOnExit()
val files = (1 to 3).map(i => new File(tmpDir, i.toString))
Files.write("0123456789", files(0), UTF_8)
Files.write("abcdefghij", files(1), UTF_8)
@@ -357,7 +355,8 @@ class UtilsSuite extends FunSuite with ResetSystemProperties {
}
test("loading properties from file") {
- val outFile = File.createTempFile("test-load-spark-properties", "test")
+ val tmpDir = Utils.createTempDir()
+ val outFile = File.createTempFile("test-load-spark-properties", "test", tmpDir)
try {
System.setProperty("spark.test.fileNameLoadB", "2")
Files.write("spark.test.fileNameLoadA true\n" +
@@ -370,7 +369,7 @@ class UtilsSuite extends FunSuite with ResetSystemProperties {
assert(sparkConf.getBoolean("spark.test.fileNameLoadA", false) === true)
assert(sparkConf.getInt("spark.test.fileNameLoadB", 1) === 2)
} finally {
- outFile.delete()
+ Utils.deleteRecursively(tmpDir)
}
}
diff --git a/dev/lint-python b/dev/lint-python
index 772f856154ae0..fded654893a7c 100755
--- a/dev/lint-python
+++ b/dev/lint-python
@@ -19,43 +19,53 @@
SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )"
SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")"
-PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt"
+PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/"
+PYTHON_LINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/python-lint-report.txt"
cd "$SPARK_ROOT_DIR"
+# compileall: https://docs.python.org/2/library/compileall.html
+python -B -m compileall -q -l $PATHS_TO_CHECK > "$PYTHON_LINT_REPORT_PATH"
+compile_status="${PIPESTATUS[0]}"
+
# Get pep8 at runtime so that we don't rely on it being installed on the build server.
#+ See: https://github.com/apache/spark/pull/1744#issuecomment-50982162
#+ TODOs:
-#+ - Dynamically determine latest release version of pep8 and use that.
-#+ - Download this from a more reliable source. (GitHub raw can be flaky, apparently. (?))
+#+ - Download pep8 from PyPI. It's more "official".
PEP8_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pep8.py"
-PEP8_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/jcrocholl/pep8/1.5.7/pep8.py"
-PEP8_PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/"
+PEP8_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/jcrocholl/pep8/1.6.2/pep8.py"
+# if [ ! -e "$PEP8_SCRIPT_PATH" ]; then
curl --silent -o "$PEP8_SCRIPT_PATH" "$PEP8_SCRIPT_REMOTE_PATH"
-curl_status=$?
+curl_status="$?"
-if [ $curl_status -ne 0 ]; then
+if [ "$curl_status" -ne 0 ]; then
echo "Failed to download pep8.py from \"$PEP8_SCRIPT_REMOTE_PATH\"."
- exit $curl_status
+ exit "$curl_status"
fi
-
+# fi
# There is no need to write this output to a file
#+ first, but we do so so that the check status can
#+ be output before the report, like with the
#+ scalastyle and RAT checks.
-python "$PEP8_SCRIPT_PATH" $PEP8_PATHS_TO_CHECK > "$PEP8_REPORT_PATH"
-pep8_status=${PIPESTATUS[0]} #$?
+python "$PEP8_SCRIPT_PATH" --ignore=E402,E731,E241,W503,E226 $PATHS_TO_CHECK >> "$PYTHON_LINT_REPORT_PATH"
+pep8_status="${PIPESTATUS[0]}"
+
+if [ "$compile_status" -eq 0 -a "$pep8_status" -eq 0 ]; then
+ lint_status=0
+else
+ lint_status=1
+fi
-if [ $pep8_status -ne 0 ]; then
- echo "PEP 8 checks failed."
- cat "$PEP8_REPORT_PATH"
+if [ "$lint_status" -ne 0 ]; then
+ echo "Python lint checks failed."
+ cat "$PYTHON_LINT_REPORT_PATH"
else
- echo "PEP 8 checks passed."
+ echo "Python lint checks passed."
fi
-rm "$PEP8_REPORT_PATH"
rm "$PEP8_SCRIPT_PATH"
+rm "$PYTHON_LINT_REPORT_PATH"
-exit $pep8_status
+exit "$lint_status"
diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins
index 6a849e4f77207..5f4000e83925c 100755
--- a/dev/run-tests-jenkins
+++ b/dev/run-tests-jenkins
@@ -49,6 +49,21 @@ SHORT_COMMIT_HASH="${ghprbActualCommit:0:7}"
TESTS_TIMEOUT="120m" # format: http://linux.die.net/man/1/timeout
+# Array to capture all tests to run on the pull request. These tests are held under the
+#+ dev/tests/ directory.
+#
+# To write a PR test:
+#+ * the file must reside within the dev/tests directory
+#+ * be an executable bash script
+#+ * accept two arguments on the command line, the first being the Github PR long commit
+#+ hash and the second the Github SHA1 hash
+#+ * and, lastly, return string output to be included in the pr message output that will
+#+ be posted to Github
+PR_TESTS=(
+ "pr_merge_ability"
+ "pr_public_classes"
+)
+
function post_message () {
local message=$1
local data="{\"body\": \"$message\"}"
@@ -131,48 +146,22 @@ function send_archived_logs () {
fi
}
-
-# We diff master...$ghprbActualCommit because that gets us changes introduced in the PR
-#+ and not anything else added to master since the PR was branched.
-
-# check PR merge-ability and check for new public classes
-{
- if [ "$sha1" == "$ghprbActualCommit" ]; then
- merge_note=" * This patch **does not merge cleanly**."
- else
- merge_note=" * This patch merges cleanly."
+# Environment variable to capture PR test output
+pr_message=""
+
+# Run pull request tests
+for t in "${PR_TESTS[@]}"; do
+ this_test="${FWDIR}/dev/tests/${t}.sh"
+ # Ensure the test is a file and is executable
+ if [ -x "$this_test" ]; then
+ echo "ghprb: $ghprbActualCommit sha1: $sha1"
+ this_mssg="`bash \"${this_test}\" \"${ghprbActualCommit}\" \"${sha1}\" 2>/dev/null`"
+ # Check if this is the merge test as we submit that note *before* and *after*
+ # the tests run
+ [ "$t" == "pr_merge_ability" ] && merge_note="${this_mssg}"
+ pr_message="${pr_message}\n${this_mssg}"
fi
-
- source_files=$(
- git diff master...$ghprbActualCommit --name-only `# diff patch against master from branch point` \
- | grep -v -e "\/test" `# ignore files in test directories` \
- | grep -e "\.py$" -e "\.java$" -e "\.scala$" `# include only code files` \
- | tr "\n" " "
- )
- new_public_classes=$(
- git diff master...$ghprbActualCommit ${source_files} `# diff patch against master from branch point` \
- | grep "^\+" `# filter in only added lines` \
- | sed -r -e "s/^\+//g" `# remove the leading +` \
- | grep -e "trait " -e "class " `# filter in lines with these key words` \
- | grep -e "{" -e "(" `# filter in lines with these key words, too` \
- | grep -v -e "\@\@" -e "private" `# exclude lines with these words` \
- | grep -v -e "^// " -e "^/\*" -e "^ \* " `# exclude comment lines` \
- | sed -r -e "s/\{.*//g" `# remove from the { onwards` \
- | sed -r -e "s/\}//g" `# just in case, remove }; they mess the JSON` \
- | sed -r -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \
- | sed -r -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \
- | sed -r -e "s/^/ \* /g" `# prepend ' *' to start of line` \
- | sed -r -e "s/$/\\\n/g" `# append newline to end of line` \
- | tr -d "\n" `# remove actual LF characters`
- )
-
- if [ -z "$new_public_classes" ]; then
- public_classes_note=" * This patch adds no public classes."
- else
- public_classes_note=" * This patch adds the following public classes _(experimental)_:"
- public_classes_note="${public_classes_note}\n${new_public_classes}"
- fi
-}
+done
# post start message
{
@@ -181,7 +170,6 @@ function send_archived_logs () {
PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})."
start_message="${start_message}\n${merge_note}"
- # start_message="${start_message}\n${public_classes_note}"
post_message "$start_message"
}
@@ -234,8 +222,7 @@ function send_archived_logs () {
PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})."
result_message="${result_message}\n${test_result_note}"
- result_message="${result_message}\n${merge_note}"
- result_message="${result_message}\n${public_classes_note}"
+ result_message="${result_message}\n${pr_message}"
post_message "$result_message"
}
diff --git a/dev/tests/pr_merge_ability.sh b/dev/tests/pr_merge_ability.sh
new file mode 100755
index 0000000000000..d9a347fe24a8c
--- /dev/null
+++ b/dev/tests/pr_merge_ability.sh
@@ -0,0 +1,39 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+#
+# This script follows the base format for testing pull requests against
+# another branch and returning results to be published. More details can be
+# found at dev/run-tests-jenkins.
+#
+# Arg1: The Github Pull Request Actual Commit
+#+ known as `ghprbActualCommit` in `run-tests-jenkins`
+# Arg2: The SHA1 hash
+#+ known as `sha1` in `run-tests-jenkins`
+#
+
+ghprbActualCommit="$1"
+sha1="$2"
+
+# check PR merge-ability
+if [ "${sha1}" == "${ghprbActualCommit}" ]; then
+ echo " * This patch **does not merge cleanly**."
+else
+ echo " * This patch merges cleanly."
+fi
diff --git a/dev/tests/pr_public_classes.sh b/dev/tests/pr_public_classes.sh
new file mode 100755
index 0000000000000..927295b88c963
--- /dev/null
+++ b/dev/tests/pr_public_classes.sh
@@ -0,0 +1,65 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+#
+# This script follows the base format for testing pull requests against
+# another branch and returning results to be published. More details can be
+# found at dev/run-tests-jenkins.
+#
+# Arg1: The Github Pull Request Actual Commit
+#+ known as `ghprbActualCommit` in `run-tests-jenkins`
+# Arg2: The SHA1 hash
+#+ known as `sha1` in `run-tests-jenkins`
+#
+
+# We diff master...$ghprbActualCommit because that gets us changes introduced in the PR
+#+ and not anything else added to master since the PR was branched.
+
+ghprbActualCommit="$1"
+sha1="$2"
+
+source_files=$(
+ git diff master...$ghprbActualCommit --name-only `# diff patch against master from branch point` \
+ | grep -v -e "\/test" `# ignore files in test directories` \
+ | grep -e "\.py$" -e "\.java$" -e "\.scala$" `# include only code files` \
+ | tr "\n" " "
+)
+new_public_classes=$(
+ git diff master...$ghprbActualCommit ${source_files} `# diff patch against master from branch point` \
+ | grep "^\+" `# filter in only added lines` \
+ | sed -r -e "s/^\+//g" `# remove the leading +` \
+ | grep -e "trait " -e "class " `# filter in lines with these key words` \
+ | grep -e "{" -e "(" `# filter in lines with these key words, too` \
+ | grep -v -e "\@\@" -e "private" `# exclude lines with these words` \
+ | grep -v -e "^// " -e "^/\*" -e "^ \* " `# exclude comment lines` \
+ | sed -r -e "s/\{.*//g" `# remove from the { onwards` \
+ | sed -r -e "s/\}//g" `# just in case, remove }; they mess the JSON` \
+ | sed -r -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \
+ | sed -r -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \
+ | sed -r -e "s/^/ \* /g" `# prepend ' *' to start of line` \
+ | sed -r -e "s/$/\\\n/g" `# append newline to end of line` \
+ | tr -d "\n" `# remove actual LF characters`
+)
+
+if [ -z "$new_public_classes" ]; then
+ echo " * This patch adds no public classes."
+else
+ public_classes_note=" * This patch adds the following public classes _(experimental)_:"
+ echo "${public_classes_note}\n${new_public_classes}"
+fi
diff --git a/docs/_config.yml b/docs/_config.yml
index 0652927a8ce9b..b22b627f09007 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -14,8 +14,8 @@ include:
# These allow the documentation to be updated with newer releases
# of Spark, Scala, and Mesos.
-SPARK_VERSION: 1.3.0-SNAPSHOT
-SPARK_VERSION_SHORT: 1.3.0
+SPARK_VERSION: 1.4.0-SNAPSHOT
+SPARK_VERSION_SHORT: 1.4.0
SCALA_BINARY_VERSION: "2.10"
SCALA_VERSION: "2.10.4"
MESOS_VERSION: 0.21.0
diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md
index 8c9a1e1262d8f..7f60f82b966fe 100644
--- a/docs/ec2-scripts.md
+++ b/docs/ec2-scripts.md
@@ -5,7 +5,7 @@ title: Running Spark on EC2
The `spark-ec2` script, located in Spark's `ec2` directory, allows you
to launch, manage and shut down Spark clusters on Amazon EC2. It automatically
-sets up Spark, Shark and HDFS on the cluster for you. This guide describes
+sets up Spark and HDFS on the cluster for you. This guide describes
how to use `spark-ec2` to launch clusters, how to run jobs on them, and how
to shut them down. It assumes you've already signed up for an EC2 account
on the [Amazon Web Services site](http://aws.amazon.com/).
diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md
index 5295e351dd711..963e88a3e1d8f 100644
--- a/docs/job-scheduling.md
+++ b/docs/job-scheduling.md
@@ -14,8 +14,7 @@ runs an independent set of executor processes. The cluster managers that Spark r
facilities for [scheduling across applications](#scheduling-across-applications). Second,
_within_ each Spark application, multiple "jobs" (Spark actions) may be running concurrently
if they were submitted by different threads. This is common if your application is serving requests
-over the network; for example, the [Shark](http://shark.cs.berkeley.edu) server works this way. Spark
-includes a [fair scheduler](#scheduling-within-an-application) to schedule resources within each SparkContext.
+over the network. Spark includes a [fair scheduler](#scheduling-within-an-application) to schedule resources within each SparkContext.
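For concreteness, the fair scheduler referenced here is enabled per SparkContext through a configuration property; a minimal sketch (the application name is illustrative):

{% highlight scala %}
import org.apache.spark.{SparkConf, SparkContext}

// Switch the within-application scheduler from the default FIFO to FAIR.
val conf = new SparkConf()
  .setAppName("fair-scheduling-example")
  .set("spark.scheduler.mode", "FAIR")
val sc = new SparkContext(conf)
{% endhighlight %}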
# Scheduling Across Applications
@@ -52,8 +51,7 @@ an application to gain back cores on one node when it has work to do. To use thi
Note that none of the modes currently provide memory sharing across applications. If you would like to share
data this way, we recommend running a single server application that can serve multiple requests by querying
-the same RDDs. For example, the [Shark](http://shark.cs.berkeley.edu) JDBC server works this way for SQL
-queries. In future releases, in-memory storage systems such as [Tachyon](http://tachyon-project.org) will
+the same RDDs. In future releases, in-memory storage systems such as [Tachyon](http://tachyon-project.org) will
provide another approach to share RDDs.
## Dynamic Resource Allocation
diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md
index fe6c1bf7bfd99..4f2a2f71048f7 100644
--- a/docs/mllib-data-types.md
+++ b/docs/mllib-data-types.md
@@ -78,13 +78,13 @@ MLlib recognizes the following types as dense vectors:
and the following as sparse vectors:
-* MLlib's [`SparseVector`](api/python/pyspark.mllib.linalg.SparseVector-class.html).
+* MLlib's [`SparseVector`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.SparseVector).
* SciPy's
[`csc_matrix`](http://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csc_matrix.html#scipy.sparse.csc_matrix)
with a single column
We recommend using NumPy arrays over lists for efficiency, and using the factory methods implemented
-in [`Vectors`](api/python/pyspark.mllib.linalg.Vectors-class.html) to create sparse vectors.
+in [`Vectors`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Vector) to create sparse vectors.
{% highlight python %}
import numpy as np
@@ -151,7 +151,7 @@ LabeledPoint neg = new LabeledPoint(1.0, Vectors.sparse(3, new int[] {0, 2}, new
A labeled point is represented by
-[`LabeledPoint`](api/python/pyspark.mllib.regression.LabeledPoint-class.html).
+[`LabeledPoint`](api/python/pyspark.mllib.html#pyspark.mllib.regression.LabeledPoint).
{% highlight python %}
from pyspark.mllib.linalg import SparseVector
@@ -211,7 +211,7 @@ JavaRDD examples =
-[`MLUtils.loadLibSVMFile`](api/python/pyspark.mllib.util.MLUtils-class.html) reads training
+[`MLUtils.loadLibSVMFile`](api/python/pyspark.mllib.html#pyspark.mllib.util.MLUtils) reads training
examples stored in LIBSVM format.
{% highlight python %}
diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md
index cbfb682609af3..7521fb14a7bd6 100644
--- a/docs/mllib-ensembles.md
+++ b/docs/mllib-ensembles.md
@@ -464,8 +464,8 @@ first one being the training dataset and the second being the validation dataset
The training is stopped when the improvement in the validation error is not more than a certain tolerance
(supplied by the `validationTol` argument in `BoostingStrategy`). In practice, the validation error
decreases initially and later increases. There might be cases in which the validation error does not change monotonically,
-and the user is advised to set a large enough negative tolerance and examine the validation curve to to tune the number of
-iterations.
+and the user is advised to set a large enough negative tolerance and examine the validation curve using `evaluateEachIteration`
+(which gives the error or loss per iteration) to tune the number of iterations.
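As a rough sketch of how `evaluateEachIteration` can drive that tuning, assuming `trainingData` and `validationData` are `RDD[LabeledPoint]` (the names are illustrative):

{% highlight scala %}
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy

val boostingStrategy = BoostingStrategy.defaultParams("Regression")
boostingStrategy.numIterations = 100
val model = GradientBoostedTrees.train(trainingData, boostingStrategy)

// Validation error (loss) after each boosting iteration.
val errors = model.evaluateEachIteration(validationData, boostingStrategy.loss)
// Pick the iteration count with the lowest validation error (indices are 0-based).
val bestNumIterations = errors.zipWithIndex.minBy(_._1)._2 + 1
{% endhighlight %}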
### Examples
diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md
index 55b8f2ce6c364..a83472f5be52e 100644
--- a/docs/mllib-naive-bayes.md
+++ b/docs/mllib-naive-bayes.md
@@ -106,11 +106,11 @@ NaiveBayesModel sameModel = NaiveBayesModel.load(sc.sc(), "myModelPath");
-[NaiveBayes](api/python/pyspark.mllib.classification.NaiveBayes-class.html) implements multinomial
+[NaiveBayes](api/python/pyspark.mllib.html#pyspark.mllib.classification.NaiveBayes) implements multinomial
naive Bayes. It takes an RDD of
-[LabeledPoint](api/python/pyspark.mllib.regression.LabeledPoint-class.html) and an optionally
+[LabeledPoint](api/python/pyspark.mllib.html#pyspark.mllib.regression.LabeledPoint) and an optional
smoothing parameter `lambda` as input, and outputs a
-[NaiveBayesModel](api/python/pyspark.mllib.classification.NaiveBayesModel-class.html), which can be
+[NaiveBayesModel](api/python/pyspark.mllib.html#pyspark.mllib.classification.NaiveBayesModel), which can be
used for evaluation and prediction.
Note that the Python API does not yet support model save/load but will in the future.
diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md
index ca8c29218f52d..887eae7f4f07b 100644
--- a/docs/mllib-statistics.md
+++ b/docs/mllib-statistics.md
@@ -81,8 +81,8 @@ System.out.println(summary.numNonzeros()); // number of nonzeros in each column
-[`colStats()`](api/python/pyspark.mllib.stat.Statistics-class.html#colStats) returns an instance of
-[`MultivariateStatisticalSummary`](api/python/pyspark.mllib.stat.MultivariateStatisticalSummary-class.html),
+[`colStats()`](api/python/pyspark.mllib.html#pyspark.mllib.stat.Statistics.colStats) returns an instance of
+[`MultivariateStatisticalSummary`](api/python/pyspark.mllib.html#pyspark.mllib.stat.MultivariateStatisticalSummary),
which contains the column-wise max, min, mean, variance, and number of nonzeros, as well as the
total count.
@@ -169,7 +169,7 @@ Matrix correlMatrix = Statistics.corr(data.rdd(), "pearson");
-[`Statistics`](api/python/pyspark.mllib.stat.Statistics-class.html) provides methods to
+[`Statistics`](api/python/pyspark.mllib.html#pyspark.mllib.stat.Statistics) provides methods to
calculate correlations between series. Depending on the type of input, two `RDD[Double]`s or
an `RDD[Vector]`, the output will be a `Double` or the correlation `Matrix` respectively.
@@ -258,7 +258,7 @@ JavaPairRDD exactSample = data.sampleByKeyExact(false, fractions);
{% endhighlight %}
-[`sampleByKey()`](api/python/pyspark.rdd.RDD-class.html#sampleByKey) allows users to
+[`sampleByKey()`](api/python/pyspark.html#pyspark.RDD.sampleByKey) allows users to
sample approximately $\lceil f_k \cdot n_k \rceil \, \forall k \in K$ items, where $f_k$ is the
desired fraction for key $k$, $n_k$ is the number of key-value pairs for key $k$, and $K$ is the
set of keys.
@@ -476,7 +476,7 @@ JavaDoubleRDD v = u.map(
-[`RandomRDDs`](api/python/pyspark.mllib.random.RandomRDDs-class.html) provides factory
+[`RandomRDDs`](api/python/pyspark.mllib.html#pyspark.mllib.random.RandomRDDs) provides factory
methods to generate random double RDDs or vector RDDs.
The following example generates a random double RDD, whose values follow the standard normal
distribution `N(0, 1)`, and then maps it to `N(1, 4)`.
diff --git a/docs/programming-guide.md b/docs/programming-guide.md
index eda3a95426182..5fe832b6fa100 100644
--- a/docs/programming-guide.md
+++ b/docs/programming-guide.md
@@ -142,8 +142,8 @@ JavaSparkContext sc = new JavaSparkContext(conf);
-The first thing a Spark program must do is to create a [SparkContext](api/python/pyspark.context.SparkContext-class.html) object, which tells Spark
-how to access a cluster. To create a `SparkContext` you first need to build a [SparkConf](api/python/pyspark.conf.SparkConf-class.html) object
+The first thing a Spark program must do is to create a [SparkContext](api/python/pyspark.html#pyspark.SparkContext) object, which tells Spark
+how to access a cluster. To create a `SparkContext` you first need to build a [SparkConf](api/python/pyspark.html#pyspark.SparkConf) object
that contains information about your application.
{% highlight python %}
@@ -912,7 +912,7 @@ The following table lists some of the common transformations supported by Spark.
RDD API doc
([Scala](api/scala/index.html#org.apache.spark.rdd.RDD),
[Java](api/java/index.html?org/apache/spark/api/java/JavaRDD.html),
- [Python](api/python/pyspark.rdd.RDD-class.html))
+ [Python](api/python/pyspark.html#pyspark.RDD))
and pair RDD functions doc
([Scala](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions),
[Java](api/java/index.html?org/apache/spark/api/java/JavaPairRDD.html))
@@ -1025,7 +1025,7 @@ The following table lists some of the common actions supported by Spark. Refer t
RDD API doc
([Scala](api/scala/index.html#org.apache.spark.rdd.RDD),
[Java](api/java/index.html?org/apache/spark/api/java/JavaRDD.html),
- [Python](api/python/pyspark.rdd.RDD-class.html))
+ [Python](api/python/pyspark.html#pyspark.RDD))
and pair RDD functions doc
([Scala](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions),
[Java](api/java/index.html?org/apache/spark/api/java/JavaPairRDD.html))
@@ -1105,7 +1105,7 @@ replicate it across nodes, or store it off-heap in [Tachyon](http://tachyon-proj
These levels are set by passing a
`StorageLevel` object ([Scala](api/scala/index.html#org.apache.spark.storage.StorageLevel),
[Java](api/java/index.html?org/apache/spark/storage/StorageLevel.html),
-[Python](api/python/pyspark.storagelevel.StorageLevel-class.html))
+[Python](api/python/pyspark.html#pyspark.StorageLevel))
to `persist()`. The `cache()` method is a shorthand for using the default storage level,
which is `StorageLevel.MEMORY_ONLY` (store deserialized objects in memory). The full set of
storage levels is:
@@ -1374,7 +1374,7 @@ scala> accum.value
{% endhighlight %}
While this code used the built-in support for accumulators of type Int, programmers can also
-create their own types by subclassing [AccumulatorParam](api/python/pyspark.accumulators.AccumulatorParam-class.html).
+create their own types by subclassing [AccumulatorParam](api/python/pyspark.html#pyspark.AccumulatorParam).
The AccumulatorParam interface has two methods: `zero` for providing a "zero value" for your data
type, and `addInPlace` for adding two values together. For example, supposing we had a `Vector` class
representing mathematical vectors, we could write:
diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md
index 6a9d304501dc0..c984639bd34cf 100644
--- a/docs/running-on-mesos.md
+++ b/docs/running-on-mesos.md
@@ -224,11 +224,9 @@ See the [configuration page](configuration.html) for information on Spark config
spark.mesos.executor.memoryOverhead
executor memory * 0.10, with minimum of 384
- This value is an additive for spark.executor.memory, specified in MB,
- which is used to calculate the total Mesos task memory. A value of 384
- implies a 384MB overhead. Additionally, there is a hard-coded 10% minimum
- overhead. The final overhead will be the larger of either
- `spark.mesos.executor.memoryOverhead` or 10% of `spark.executor.memory`.
+ The amount of additional memory, specified in MB, to be allocated per executor. By default,
+ the overhead will be the larger of either 384 or 10% of `spark.executor.memory`. If it's set,
+ the final overhead will be this value.
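A hypothetical way to pin the overhead explicitly from application code (the 512 MB figure is only an example):

{% highlight scala %}
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.executor.memory", "4g")
  // Explicit per-executor overhead in MB; if unset, max(384, 10% of executor memory) applies.
  .set("spark.mesos.executor.memoryOverhead", "512")
{% endhighlight %}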
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 2cbb4c967eb81..6a333fdb562a7 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -56,7 +56,7 @@ SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc);
The entry point into all relational functionality in Spark is the
-[`SQLContext`](api/python/pyspark.sql.SQLContext-class.html) class, or one
+[`SQLContext`](api/python/pyspark.sql.html#pyspark.sql.SQLContext) class, or one
of its descendants. To create a basic `SQLContext`, all you need is a SparkContext.
{% highlight python %}
@@ -509,8 +509,11 @@ val people = sc.textFile("examples/src/main/resources/people.txt")
// The schema is encoded in a string
val schemaString = "name age"
-// Import Spark SQL data types and Row.
-import org.apache.spark.sql._
+// Import Row.
+import org.apache.spark.sql.Row;
+
+// Import Spark SQL data types
+import org.apache.spark.sql.types.{StructType,StructField,StringType};
// Generate the schema based on the string of schema
val schema =
diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md
index 57b074778f2b0..3ecbf2308cd44 100644
--- a/docs/submitting-applications.md
+++ b/docs/submitting-applications.md
@@ -133,10 +133,10 @@ The master URL passed to Spark can be in one of the following formats:
Or, for a Mesos cluster using ZooKeeper, use mesos://zk://....
yarn-client
Connect to a YARN cluster in
-client mode. The cluster location will be found based on the HADOOP_CONF_DIR variable.
+client mode. The cluster location will be found based on the HADOOP_CONF_DIR or YARN_CONF_DIR variable.
yarn-cluster
Connect to a YARN cluster in
-cluster mode. The cluster location will be found based on HADOOP_CONF_DIR.
+cluster mode. The cluster location will be found based on the HADOOP_CONF_DIR or YARN_CONF_DIR variable.
diff --git a/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh b/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh
index 0857657152ec7..4f3e8da809f7f 100644
--- a/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh
+++ b/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh
@@ -25,7 +25,6 @@ export MAPRED_LOCAL_DIRS="{{mapred_local_dirs}}"
export SPARK_LOCAL_DIRS="{{spark_local_dirs}}"
export MODULES="{{modules}}"
export SPARK_VERSION="{{spark_version}}"
-export SHARK_VERSION="{{shark_version}}"
export TACHYON_VERSION="{{tachyon_version}}"
export HADOOP_MAJOR_VERSION="{{hadoop_major_version}}"
export SWAP_MB="{{swap}}"
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index f848874b0c775..c467cd08ed742 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -1159,8 +1159,8 @@ def real_main():
if EC2_INSTANCE_TYPES[opts.instance_type] != \
EC2_INSTANCE_TYPES[opts.master_instance_type]:
print >> stderr, \
- "Error: spark-ec2 currently does not support having a master and slaves with " + \
- "different AMI virtualization types."
+ "Error: spark-ec2 currently does not support having a master and slaves " + \
+ "with different AMI virtualization types."
print >> stderr, "master instance virtualization type: {t}".format(
t=EC2_INSTANCE_TYPES[opts.master_instance_type])
print >> stderr, "slave instance virtualization type: {t}".format(
diff --git a/examples/pom.xml b/examples/pom.xml
index 994071d94d0ad..7e93f0eec0b91 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -21,7 +21,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.3.0-SNAPSHOT</version>
+    <version>1.4.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
index 91a0a860d6c71..1f4ca4fbe7778 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
@@ -175,7 +175,8 @@ object MovieLensALS {
}
/** Compute RMSE (Root Mean Squared Error). */
- def computeRmse(model: MatrixFactorizationModel, data: RDD[Rating], implicitPrefs: Boolean) = {
+ def computeRmse(model: MatrixFactorizationModel, data: RDD[Rating], implicitPrefs: Boolean)
+ : Double = {
def mapPredictedRating(r: Double) = if (implicitPrefs) math.max(math.min(r, 1.0), 0.0) else r
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala
index 91c9772744f18..9f22d40c15f3f 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala
@@ -116,7 +116,7 @@ object PowerIterationClusteringExample {
sc.stop()
}
- def generateCircle(radius: Double, n: Int) = {
+ def generateCircle(radius: Double, n: Int): Seq[(Double, Double)] = {
Seq.tabulate(n) { i =>
val theta = 2.0 * math.Pi * i / n
(radius * math.cos(theta), radius * math.sin(theta))
@@ -147,7 +147,7 @@ object PowerIterationClusteringExample {
/**
* Gaussian Similarity: http://en.wikipedia.org/wiki/Radial_basis_function_kernel
*/
- def gaussianSimilarity(p1: (Double, Double), p2: (Double, Double), sigma: Double) = {
+ def gaussianSimilarity(p1: (Double, Double), p2: (Double, Double), sigma: Double): Double = {
val coeff = 1.0 / (math.sqrt(2.0 * math.Pi) * sigma)
val expCoeff = -1.0 / 2.0 * math.pow(sigma, 2.0)
val ssquares = (p1._1 - p2._1) * (p1._1 - p2._1) + (p1._2 - p2._2) * (p1._2 - p2._2)
diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml
index 96c2787e35cd0..67907bbfb6d1b 100644
--- a/external/flume-sink/pom.xml
+++ b/external/flume-sink/pom.xml
@@ -21,7 +21,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.3.0-SNAPSHOT</version>
+    <version>1.4.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
diff --git a/external/flume/pom.xml b/external/flume/pom.xml
index 172d447b77cda..8df7edbdcad33 100644
--- a/external/flume/pom.xml
+++ b/external/flume/pom.xml
@@ -21,7 +21,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.3.0-SNAPSHOT</version>
+    <version>1.4.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml
index 5109b8ed87524..0b79f47647f6b 100644
--- a/external/kafka-assembly/pom.xml
+++ b/external/kafka-assembly/pom.xml
@@ -21,7 +21,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.3.0-SNAPSHOT</version>
+    <version>1.4.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml
index 369856187a244..f695cff410a18 100644
--- a/external/kafka/pom.xml
+++ b/external/kafka/pom.xml
@@ -21,7 +21,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.3.0-SNAPSHOT</version>
+    <version>1.4.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala
index fc53c23abda85..3cd960d1fd1d4 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala
@@ -25,16 +25,15 @@ import scala.concurrent.duration._
import scala.language.postfixOps
import scala.util.Random
-import com.google.common.io.Files
import kafka.serializer.StringDecoder
import kafka.utils.{ZKGroupTopicDirs, ZkUtils}
-import org.apache.commons.io.FileUtils
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually
import org.apache.spark.SparkConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Milliseconds, StreamingContext}
+import org.apache.spark.util.Utils
class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter with Eventually {
@@ -60,7 +59,7 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter
)
ssc = new StreamingContext(sparkConf, Milliseconds(500))
- tempDirectory = Files.createTempDir()
+ tempDirectory = Utils.createTempDir()
ssc.checkpoint(tempDirectory.getAbsolutePath)
}
@@ -68,10 +67,7 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter
if (ssc != null) {
ssc.stop()
}
- if (tempDirectory != null && tempDirectory.exists()) {
- FileUtils.deleteDirectory(tempDirectory)
- tempDirectory = null
- }
+ Utils.deleteRecursively(tempDirectory)
tearDownKafka()
}
diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml
index a344f000c5002..98f95a9a64fa0 100644
--- a/external/mqtt/pom.xml
+++ b/external/mqtt/pom.xml
@@ -21,7 +21,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.3.0-SNAPSHOT</version>
+    <version>1.4.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml
index e95853f005ce2..8b6a8959ac4cf 100644
--- a/external/twitter/pom.xml
+++ b/external/twitter/pom.xml
@@ -21,7 +21,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.3.0-SNAPSHOT</version>
+    <version>1.4.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml
index 9b3475d7c3dc2..a50d378b34335 100644
--- a/external/zeromq/pom.xml
+++ b/external/zeromq/pom.xml
@@ -21,7 +21,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.3.0-SNAPSHOT</version>
+    <version>1.4.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml
index bc2f8be10c9ce..4351a8a12fe21 100644
--- a/extras/java8-tests/pom.xml
+++ b/extras/java8-tests/pom.xml
@@ -20,7 +20,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.3.0-SNAPSHOT</version>
+    <version>1.4.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml
index 7e49a71907336..25847a1b33d9c 100644
--- a/extras/kinesis-asl/pom.xml
+++ b/extras/kinesis-asl/pom.xml
@@ -20,7 +20,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.3.0-SNAPSHOT</version>
+    <version>1.4.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml
index 6eb29af03f833..e14bbae4a9b6e 100644
--- a/extras/spark-ganglia-lgpl/pom.xml
+++ b/extras/spark-ganglia-lgpl/pom.xml
@@ -20,7 +20,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.3.0-SNAPSHOT</version>
+    <version>1.4.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
diff --git a/graphx/pom.xml b/graphx/pom.xml
index c0d534e185d7f..d38a3aa8256b7 100644
--- a/graphx/pom.xml
+++ b/graphx/pom.xml
@@ -21,7 +21,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.3.0-SNAPSHOT</version>
+    <version>1.4.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
index dc8b4789c4b61..86f611d55aa8a 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
@@ -113,7 +113,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
* Collect the neighbor vertex attributes for each vertex.
*
* @note This function could be highly inefficient on power-law
- * graphs where high degree vertices may force a large ammount of
+ * graphs where high degree vertices may force a large amount of
* information to be collected to a single location.
*
* @param edgeDirection the direction along which to collect
@@ -187,7 +187,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
/**
* Join the vertices with an RDD and then apply a function from the
- * the vertex and RDD entry to a new vertex value. The input table
+ * vertex and RDD entry to a new vertex value. The input table
* should contain at most one entry for each vertex. If no entry is
* provided the map function is skipped and the old value is used.
*
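The scaladoc above appears to describe `GraphOps.joinVertices`; a hypothetical usage sketch, assuming a SparkContext `sc` and a `graph: Graph[Int, Int]`:

{% highlight scala %}
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD

// Add a per-vertex delta from an RDD onto the existing vertex attribute.
val deltas: RDD[(VertexId, Int)] = sc.parallelize(Seq((1L, 10), (2L, 5)))
val updated: Graph[Int, Int] = graph.joinVertices(deltas) { (vid, oldAttr, delta) =>
  oldAttr + delta // vertices with no entry in `deltas` keep their old attribute
}
{% endhighlight %}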
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
index 5e55620147df8..01b013ff716fc 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
@@ -78,8 +78,8 @@ object Pregel extends Logging {
*
* @param graph the input graph.
*
- * @param initialMsg the message each vertex will receive at the on
- * the first iteration
+ * @param initialMsg the message each vertex will receive at the first
+ * iteration
*
* @param maxIterations the maximum number of iterations to run for
*
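For context, the single-source shortest paths example from the GraphX guide shows how `initialMsg` and the other arguments fit together, assuming a `graph: Graph[Double, Double]` whose vertex attributes are 0.0 at the source and `Double.PositiveInfinity` elsewhere:

{% highlight scala %}
import org.apache.spark.graphx._

val sssp = Pregel(graph, Double.PositiveInfinity)(
  (id, dist, newDist) => math.min(dist, newDist),   // vertex program
  triplet => {                                      // send messages along improving edges
    if (triplet.srcAttr + triplet.attr < triplet.dstAttr) {
      Iterator((triplet.dstId, triplet.srcAttr + triplet.attr))
    } else {
      Iterator.empty
    }
  },
  (a, b) => math.min(a, b)                          // merge messages
)
{% endhighlight %}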
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
index e139959c3f5c1..570440ba4441f 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
@@ -25,8 +25,8 @@ import org.apache.spark.graphx._
/**
* PageRank algorithm implementation. There are two implementations of PageRank implemented.
*
- * The first implementation uses the [[Pregel]] interface and runs PageRank for a fixed number
- * of iterations:
+ * The first implementation uses the standalone [[Graph]] interface and runs PageRank
+ * for a fixed number of iterations:
* {{{
* var PR = Array.fill(n)( 1.0 )
* val oldPR = Array.fill(n)( 1.0 )
@@ -38,7 +38,7 @@ import org.apache.spark.graphx._
* }
* }}}
*
- * The second implementation uses the standalone [[Graph]] interface and runs PageRank until
+ * The second implementation uses the [[Pregel]] interface and runs PageRank until
* convergence:
*
* {{{
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
index b61d9f0fbe5e4..8d15150458d26 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -19,13 +19,12 @@ package org.apache.spark.graphx
import org.scalatest.FunSuite
-import com.google.common.io.Files
-
import org.apache.spark.SparkContext
import org.apache.spark.graphx.Graph._
import org.apache.spark.graphx.PartitionStrategy._
import org.apache.spark.rdd._
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
class GraphSuite extends FunSuite with LocalSparkContext {
@@ -369,8 +368,7 @@ class GraphSuite extends FunSuite with LocalSparkContext {
}
test("checkpoint") {
- val checkpointDir = Files.createTempDir()
- checkpointDir.deleteOnExit()
+ val checkpointDir = Utils.createTempDir()
withSpark { sc =>
sc.setCheckpointDir(checkpointDir.getAbsolutePath)
val ring = (0L to 100L).zip((1L to 99L) :+ 0L).map { case (a, b) => Edge(a, b, 1)}
diff --git a/launcher/pom.xml b/launcher/pom.xml
index ccbd9d0419a98..0fe2814135d88 100644
--- a/launcher/pom.xml
+++ b/launcher/pom.xml
@@ -22,7 +22,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.3.0-SNAPSHOT</version>
+    <version>1.4.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
diff --git a/mllib/pom.xml b/mllib/pom.xml
index a76704a8c2c59..4c183543e3fa8 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -21,7 +21,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.3.0-SNAPSHOT</version>
+    <version>1.4.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 5bbcd2e080e07..c4a36103303a2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.types.StructType
abstract class PipelineStage extends Serializable with Logging {
/**
- * :: DeveloperAPI ::
+ * :: DeveloperApi ::
*
* Derives the output schema from the input schema and parameters.
* The schema describes the columns and types of the data.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 6131ba8832691..fc4e12773c46d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -41,7 +41,7 @@ class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
def getNumFeatures: Int = get(numFeatures)
/** @group setParam */
- def setNumFeatures(value: Int) = set(numFeatures, value)
+ def setNumFeatures(value: Int): this.type = set(numFeatures, value)
override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = {
val hashingTF = new feature.HashingTF(paramMap(numFeatures))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
index 1a70322b4cace..5d660d1e151a7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
@@ -138,3 +138,14 @@ private[ml] trait HasOutputCol extends Params {
/** @group getParam */
def getOutputCol: String = get(outputCol)
}
+
+private[ml] trait HasCheckpointInterval extends Params {
+ /**
+ * param for checkpoint interval
+ * @group param
+ */
+ val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval")
+
+ /** @group getParam */
+ def getCheckpointInterval: Int = get(checkpointInterval)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index e3515ee81af3d..514b4ef98dc5b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.recommendation
import java.{util => ju}
+import java.io.IOException
import scala.collection.mutable
import scala.reflect.ClassTag
@@ -26,6 +27,7 @@ import scala.util.hashing.byteswap64
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
+import org.apache.hadoop.fs.{FileSystem, Path}
import org.netlib.util.intW
import org.apache.spark.{Logging, Partitioner}
@@ -46,7 +48,7 @@ import org.apache.spark.util.random.XORShiftRandom
* Common params for ALS.
*/
private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam
- with HasPredictionCol {
+ with HasPredictionCol with HasCheckpointInterval {
/**
* Param for rank of the matrix factorization.
@@ -164,6 +166,7 @@ class ALSModel private[ml] (
itemFactors: RDD[(Int, Array[Float])])
extends Model[ALSModel] with ALSParams {
+ /** @group setParam */
def setPredictionCol(value: String): this.type = set(predictionCol, value)
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
@@ -262,6 +265,9 @@ class ALS extends Estimator[ALSModel] with ALSParams {
/** @group setParam */
def setNonnegative(value: Boolean): this.type = set(nonnegative, value)
+ /** @group setParam */
+ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
+
/**
* Sets both numUserBlocks and numItemBlocks to the specific value.
* @group setParam
@@ -274,6 +280,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
setMaxIter(20)
setRegParam(1.0)
+ setCheckpointInterval(10)
override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = {
val map = this.paramMap ++ paramMap
@@ -285,7 +292,8 @@ class ALS extends Estimator[ALSModel] with ALSParams {
val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank),
numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks),
maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs),
- alpha = map(alpha), nonnegative = map(nonnegative))
+ alpha = map(alpha), nonnegative = map(nonnegative),
+ checkpointInterval = map(checkpointInterval))
val model = new ALSModel(this, map, map(rank), userFactors, itemFactors)
Params.inheritValues(map, this, model)
model
@@ -494,6 +502,7 @@ object ALS extends Logging {
nonnegative: Boolean = false,
intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
+ checkpointInterval: Int = 10,
seed: Long = 0L)(
implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = {
require(intermediateRDDStorageLevel != StorageLevel.NONE,
@@ -521,6 +530,18 @@ object ALS extends Logging {
val seedGen = new XORShiftRandom(seed)
var userFactors = initialize(userInBlocks, rank, seedGen.nextLong())
var itemFactors = initialize(itemInBlocks, rank, seedGen.nextLong())
+ var previousCheckpointFile: Option[String] = None
+ val shouldCheckpoint: Int => Boolean = (iter) =>
+ sc.checkpointDir.isDefined && (iter % checkpointInterval == 0)
+ val deletePreviousCheckpointFile: () => Unit = () =>
+ previousCheckpointFile.foreach { file =>
+ try {
+ FileSystem.get(sc.hadoopConfiguration).delete(new Path(file), true)
+ } catch {
+ case e: IOException =>
+ logWarning(s"Cannot delete checkpoint file $file:", e)
+ }
+ }
if (implicitPrefs) {
for (iter <- 1 to maxIter) {
userFactors.setName(s"userFactors-$iter").persist(intermediateRDDStorageLevel)
@@ -528,19 +549,30 @@ object ALS extends Logging {
itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
userLocalIndexEncoder, implicitPrefs, alpha, solver)
previousItemFactors.unpersist()
- if (sc.checkpointDir.isDefined && (iter % 3 == 0)) {
- itemFactors.checkpoint()
- }
itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel)
+ // TODO: Generalize PeriodicGraphCheckpointer and use it here.
+ if (shouldCheckpoint(iter)) {
+ itemFactors.checkpoint() // itemFactors gets materialized in computeFactors.
+ }
val previousUserFactors = userFactors
userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
itemLocalIndexEncoder, implicitPrefs, alpha, solver)
+ if (shouldCheckpoint(iter)) {
+ deletePreviousCheckpointFile()
+ previousCheckpointFile = itemFactors.getCheckpointFile
+ }
previousUserFactors.unpersist()
}
} else {
for (iter <- 0 until maxIter) {
itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
userLocalIndexEncoder, solver = solver)
+ if (shouldCheckpoint(iter)) {
+ itemFactors.checkpoint()
+ itemFactors.count() // checkpoint item factors and cut lineage
+ deletePreviousCheckpointFile()
+ previousCheckpointFile = itemFactors.getCheckpointFile
+ }
userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
itemLocalIndexEncoder, solver = solver)
}
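A hypothetical caller-side sketch of the new `checkpointInterval` knob (the checkpoint path and the `ratings` DataFrame are assumed to exist):

{% highlight scala %}
import org.apache.spark.ml.recommendation.ALS

// Checkpointing only takes effect when a checkpoint directory has been set.
sc.setCheckpointDir("/tmp/als-checkpoints")

val als = new ALS()
  .setMaxIter(40)
  .setRegParam(0.1)
  .setCheckpointInterval(5) // checkpoint the factor RDDs every 5 iterations (default is 10)

val model = als.fit(ratings)
{% endhighlight %}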
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index cbd87ea8aeb37..15ca2547d56a8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -345,9 +345,13 @@ private[python] class PythonMLLibAPI extends Serializable {
def predict(userAndProducts: JavaRDD[Array[Any]]): RDD[Rating] =
predict(SerDe.asTupleRDD(userAndProducts.rdd))
- def getUserFeatures = SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]])
+ def getUserFeatures: RDD[Array[Any]] = {
+ SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]])
+ }
- def getProductFeatures = SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]])
+ def getProductFeatures: RDD[Array[Any]] = {
+ SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]])
+ }
}
@@ -909,7 +913,7 @@ private[spark] object SerDe extends Serializable {
// Pickler for DenseVector
private[python] class DenseVectorPickler extends BasePickler[DenseVector] {
- def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
val vector: DenseVector = obj.asInstanceOf[DenseVector]
val bytes = new Array[Byte](8 * vector.size)
val bb = ByteBuffer.wrap(bytes)
@@ -941,7 +945,7 @@ private[spark] object SerDe extends Serializable {
// Pickler for DenseMatrix
private[python] class DenseMatrixPickler extends BasePickler[DenseMatrix] {
- def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
val m: DenseMatrix = obj.asInstanceOf[DenseMatrix]
val bytes = new Array[Byte](8 * m.values.size)
val order = ByteOrder.nativeOrder()
@@ -973,7 +977,7 @@ private[spark] object SerDe extends Serializable {
// Pickler for SparseVector
private[python] class SparseVectorPickler extends BasePickler[SparseVector] {
- def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
val v: SparseVector = obj.asInstanceOf[SparseVector]
val n = v.indices.size
val indiceBytes = new Array[Byte](4 * n)
@@ -1015,7 +1019,7 @@ private[spark] object SerDe extends Serializable {
// Pickler for LabeledPoint
private[python] class LabeledPointPickler extends BasePickler[LabeledPoint] {
- def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
val point: LabeledPoint = obj.asInstanceOf[LabeledPoint]
saveObjects(out, pickler, point.label, point.features)
}
@@ -1031,7 +1035,7 @@ private[spark] object SerDe extends Serializable {
// Pickler for Rating
private[python] class RatingPickler extends BasePickler[Rating] {
- def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
val rating: Rating = obj.asInstanceOf[Rating]
saveObjects(out, pickler, rating.user, rating.product, rating.rating)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index b787667b018e6..e7c3599ff619c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -163,6 +163,10 @@ class LogisticRegressionModel (
}
override protected def formatVersion: String = "1.0"
+
+ override def toString: String = {
+ s"${super.toString}, numClasses = ${numClasses}, threshold = ${threshold.get}"
+ }
}
object LogisticRegressionModel extends Loader[LogisticRegressionModel] {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 2ebc7fa5d4234..d60e82c410979 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -17,6 +17,10 @@
package org.apache.spark.mllib.classification
+import java.lang.{Iterable => JIterable}
+
+import scala.collection.JavaConverters._
+
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
@@ -41,6 +45,13 @@ class NaiveBayesModel private[mllib] (
val pi: Array[Double],
val theta: Array[Array[Double]]) extends ClassificationModel with Serializable with Saveable {
+ /** A Java-friendly constructor that takes three Iterable parameters. */
+ private[mllib] def this(
+ labels: JIterable[Double],
+ pi: JIterable[Double],
+ theta: JIterable[JIterable[Double]]) =
+ this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray))
+
private val brzPi = new BDV[Double](pi)
private val brzTheta = new BDM[Double](theta.length, theta(0).length)
@@ -83,10 +94,10 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
private object SaveLoadV1_0 {
- def thisFormatVersion = "1.0"
+ def thisFormatVersion: String = "1.0"
/** Hard-code class name string in case it changes in the future */
- def thisClassName = "org.apache.spark.mllib.classification.NaiveBayesModel"
+ def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel"
/** Model data for model import/export */
case class Data(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]])
@@ -174,7 +185,7 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
*
* @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
*/
- def run(data: RDD[LabeledPoint]) = {
+ def run(data: RDD[LabeledPoint]): NaiveBayesModel = {
val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
val values = v match {
case SparseVector(size, indices, values) =>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index cfc7f868a02f0..52fb62dcff1b4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -86,6 +86,10 @@ class SVMModel (
}
override protected def formatVersion: String = "1.0"
+
+ override def toString: String = {
+ s"${super.toString}, numClasses = 2, threshold = ${threshold.get}"
+ }
}
object SVMModel extends Loader[SVMModel] {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
index 8956189ff1158..3b6790cce47c6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
@@ -32,7 +32,7 @@ private[classification] object GLMClassificationModel {
object SaveLoadV1_0 {
- def thisFormatVersion = "1.0"
+ def thisFormatVersion: String = "1.0"
/** Model data for import/export */
case class Data(weights: Vector, intercept: Double, threshold: Option[Double])
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index e41f941fd2c2c..0f8d6a399682d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -536,5 +536,5 @@ class VectorWithNorm(val vector: Vector, val norm: Double) extends Serializable
def this(array: Array[Double]) = this(Vectors.dense(array))
/** Converts the vector to a dense vector. */
- def toDense = new VectorWithNorm(Vectors.dense(vector.toArray), norm)
+ def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
index ea10bde5fa252..a8378a76d20ae 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
@@ -96,30 +96,30 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]
* Returns precision for a given label (category)
* @param label the label.
*/
- def precision(label: Double) = {
+ def precision(label: Double): Double = {
val tp = tpPerClass(label)
val fp = fpPerClass.getOrElse(label, 0L)
- if (tp + fp == 0) 0 else tp.toDouble / (tp + fp)
+ if (tp + fp == 0) 0.0 else tp.toDouble / (tp + fp)
}
/**
* Returns recall for a given label (category)
* @param label the label.
*/
- def recall(label: Double) = {
+ def recall(label: Double): Double = {
val tp = tpPerClass(label)
val fn = fnPerClass.getOrElse(label, 0L)
- if (tp + fn == 0) 0 else tp.toDouble / (tp + fn)
+ if (tp + fn == 0) 0.0 else tp.toDouble / (tp + fn)
}
/**
* Returns f1-measure for a given label (category)
* @param label the label.
*/
- def f1Measure(label: Double) = {
+ def f1Measure(label: Double): Double = {
val p = precision(label)
val r = recall(label)
- if((p + r) == 0) 0 else 2 * p * r / (p + r)
+ if((p + r) == 0) 0.0 else 2 * p * r / (p + r)
}
private lazy val sumTp = tpPerClass.foldLeft(0L) { case (sum, (_, tp)) => sum + tp }
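A hypothetical end-to-end usage of these per-label metrics (the prediction/label sets are made up):

{% highlight scala %}
import org.apache.spark.mllib.evaluation.MultilabelMetrics

val predictionAndLabels = sc.parallelize(Seq(
  (Array(0.0, 1.0), Array(0.0, 2.0)),
  (Array(0.0, 2.0), Array(0.0, 1.0)),
  (Array.empty[Double], Array(0.0))))

val metrics = new MultilabelMetrics(predictionAndLabels)
println(metrics.precision(0.0))  // per-label precision: tp / (tp + fp)
println(metrics.recall(0.0))     // per-label recall: tp / (tp + fn)
println(metrics.microF1Measure)  // micro-averaged F1 over all labels
{% endhighlight %}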
@@ -130,7 +130,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]
* Returns micro-averaged label-based precision
* (equals to micro-averaged document-based precision)
*/
- lazy val microPrecision = {
+ lazy val microPrecision: Double = {
val sumFp = fpPerClass.foldLeft(0L){ case(cum, (_, fp)) => cum + fp}
sumTp.toDouble / (sumTp + sumFp)
}
@@ -139,7 +139,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]
* Returns micro-averaged label-based recall
* (equals to micro-averaged document-based recall)
*/
- lazy val microRecall = {
+ lazy val microRecall: Double = {
val sumFn = fnPerClass.foldLeft(0.0){ case(cum, (_, fn)) => cum + fn}
sumTp.toDouble / (sumTp + sumFn)
}
@@ -148,7 +148,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]
* Returns micro-averaged label-based f1-measure
* (equals to micro-averaged document-based f1-measure)
*/
- lazy val microF1Measure = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass)
+ lazy val microF1Measure: Double = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass)
/**
* Returns the sequence of labels in ascending order
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 0e4a4d0085895..d1a174063caba 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -23,9 +23,15 @@ import scala.collection.mutable.{ArrayBuilder => MArrayBuilder, HashSet => MHash
import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM}
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+
/**
* Trait for a local matrix.
*/
+@SQLUserDefinedType(udt = classOf[MatrixUDT])
sealed trait Matrix extends Serializable {
/** Number of rows. */
@@ -102,6 +108,90 @@ sealed trait Matrix extends Serializable {
private[spark] def foreachActive(f: (Int, Int, Double) => Unit)
}
+@DeveloperApi
+private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
+
+ override def sqlType: StructType = {
+ // type: 0 = sparse, 1 = dense
+ // A dense matrix is built from numRows, numCols, values and isTransposed, all of which are
+ // marked not nullable except values, since support for binary matrices (which need no
+ // values) might be added in the future.
+ // A sparse matrix additionally needs colPtrs and rowIndices, which are left
+ // null when building a dense matrix.
+ StructType(Seq(
+ StructField("type", ByteType, nullable = false),
+ StructField("numRows", IntegerType, nullable = false),
+ StructField("numCols", IntegerType, nullable = false),
+ StructField("colPtrs", ArrayType(IntegerType, containsNull = false), nullable = true),
+ StructField("rowIndices", ArrayType(IntegerType, containsNull = false), nullable = true),
+ StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true),
+ StructField("isTransposed", BooleanType, nullable = false)
+ ))
+ }
+
+ override def serialize(obj: Any): Row = {
+ val row = new GenericMutableRow(7)
+ obj match {
+ case sm: SparseMatrix =>
+ row.setByte(0, 0)
+ row.setInt(1, sm.numRows)
+ row.setInt(2, sm.numCols)
+ row.update(3, sm.colPtrs.toSeq)
+ row.update(4, sm.rowIndices.toSeq)
+ row.update(5, sm.values.toSeq)
+ row.setBoolean(6, sm.isTransposed)
+
+ case dm: DenseMatrix =>
+ row.setByte(0, 1)
+ row.setInt(1, dm.numRows)
+ row.setInt(2, dm.numCols)
+ row.setNullAt(3)
+ row.setNullAt(4)
+ row.update(5, dm.values.toSeq)
+ row.setBoolean(6, dm.isTransposed)
+ }
+ row
+ }
+
+ override def deserialize(datum: Any): Matrix = {
+ datum match {
+ // TODO: something wrong with UDT serialization, should never happen.
+ case m: Matrix => m
+ case row: Row =>
+ require(row.length == 7,
+ s"MatrixUDT.deserialize given row with length ${row.length} but requires length == 7")
+ val tpe = row.getByte(0)
+ val numRows = row.getInt(1)
+ val numCols = row.getInt(2)
+ val values = row.getAs[Iterable[Double]](5).toArray
+ val isTransposed = row.getBoolean(6)
+ tpe match {
+ case 0 =>
+ val colPtrs = row.getAs[Iterable[Int]](3).toArray
+ val rowIndices = row.getAs[Iterable[Int]](4).toArray
+ new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
+ case 1 =>
+ new DenseMatrix(numRows, numCols, values, isTransposed)
+ }
+ }
+ }
+
+ override def userClass: Class[Matrix] = classOf[Matrix]
+
+ override def equals(o: Any): Boolean = {
+ o match {
+ case v: MatrixUDT => true
+ case _ => false
+ }
+ }
+
+ override def hashCode(): Int = 1994
+
+ override def typeName: String = "matrix"
+
+ private[spark] override def asNullable: MatrixUDT = this
+}
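Since `MatrixUDT` is `private[spark]`, the round trip below only compiles from inside the org.apache.spark packages; it is just an illustrative sketch of what `serialize`/`deserialize` do:

{% highlight scala %}
import org.apache.spark.mllib.linalg.{Matrices, MatrixUDT}

val udt = new MatrixUDT
val m = Matrices.dense(2, 2, Array(1.0, 2.0, 3.0, 4.0))
// Row(type, numRows, numCols, colPtrs, rowIndices, values, isTransposed)
val row = udt.serialize(m)
// DenseMatrix.equals compares dimensions and values, so the round trip is lossless here.
assert(udt.deserialize(row) == m)
{% endhighlight %}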
+
/**
* Column-major dense matrix.
* The entry values are stored in a single array of doubles with columns listed in sequence.
@@ -119,6 +209,7 @@ sealed trait Matrix extends Serializable {
* @param isTransposed whether the matrix is transposed. If true, `values` stores the matrix in
* row major.
*/
+@SQLUserDefinedType(udt = classOf[MatrixUDT])
class DenseMatrix(
val numRows: Int,
val numCols: Int,
@@ -146,12 +237,16 @@ class DenseMatrix(
def this(numRows: Int, numCols: Int, values: Array[Double]) =
this(numRows, numCols, values, false)
- override def equals(o: Any) = o match {
+ override def equals(o: Any): Boolean = o match {
case m: DenseMatrix =>
m.numRows == numRows && m.numCols == numCols && Arrays.equals(toArray, m.toArray)
case _ => false
}
+ override def hashCode: Int = {
+ com.google.common.base.Objects.hashCode(numRows : Integer, numCols: Integer, toArray)
+ }
+
private[mllib] def toBreeze: BM[Double] = {
if (!isTransposed) {
new BDM[Double](numRows, numCols, values)
@@ -173,7 +268,7 @@ class DenseMatrix(
values(index(i, j)) = v
}
- override def copy = new DenseMatrix(numRows, numCols, values.clone())
+ override def copy: DenseMatrix = new DenseMatrix(numRows, numCols, values.clone())
private[mllib] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f))
@@ -356,6 +451,7 @@ object DenseMatrix {
* Compressed Sparse Row (CSR) format, where `colPtrs` behaves as rowPtrs,
* and `rowIndices` behave as colIndices, and `values` are stored in row major.
*/
+@SQLUserDefinedType(udt = classOf[MatrixUDT])
class SparseMatrix(
val numRows: Int,
val numCols: Int,
@@ -431,7 +527,9 @@ class SparseMatrix(
}
}
- override def copy = new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.clone())
+ override def copy: SparseMatrix = {
+ new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.clone())
+ }
private[mllib] def map(f: Double => Double) =
new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.map(f))
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index e9d25dcb7e778..328dbe2ce11fa 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -183,6 +183,10 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
}
}
+ override def hashCode: Int = 7919
+
+ override def typeName: String = "vector"
+
private[spark] override def asNullable: VectorUDT = this
}
@@ -478,7 +482,7 @@ class DenseVector(val values: Array[Double]) extends Vector {
private[mllib] override def toBreeze: BV[Double] = new BDV[Double](values)
- override def apply(i: Int) = values(i)
+ override def apply(i: Int): Double = values(i)
override def copy: DenseVector = {
new DenseVector(values.clone())
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
index 1d253963130f1..3323ae7b1fba0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
@@ -49,7 +49,7 @@ private[mllib] class GridPartitioner(
private val rowPartitions = math.ceil(rows * 1.0 / rowsPerPart).toInt
private val colPartitions = math.ceil(cols * 1.0 / colsPerPart).toInt
- override val numPartitions = rowPartitions * colPartitions
+ override val numPartitions: Int = rowPartitions * colPartitions
/**
* Returns the index of the partition the input coordinate belongs to.
@@ -85,6 +85,14 @@ private[mllib] class GridPartitioner(
false
}
}
+
+ override def hashCode: Int = {
+ com.google.common.base.Objects.hashCode(
+ rows: java.lang.Integer,
+ cols: java.lang.Integer,
+ rowsPerPart: java.lang.Integer,
+ colsPerPart: java.lang.Integer)
+ }
}
private[mllib] object GridPartitioner {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala
index 405bae62ee8b6..9349ecaa13f56 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala
@@ -56,7 +56,7 @@ class UniformGenerator extends RandomDataGenerator[Double] {
random.nextDouble()
}
- override def setSeed(seed: Long) = random.setSeed(seed)
+ override def setSeed(seed: Long): Unit = random.setSeed(seed)
override def copy(): UniformGenerator = new UniformGenerator()
}
@@ -75,7 +75,7 @@ class StandardNormalGenerator extends RandomDataGenerator[Double] {
random.nextGaussian()
}
- override def setSeed(seed: Long) = random.setSeed(seed)
+ override def setSeed(seed: Long): Unit = random.setSeed(seed)
override def copy(): StandardNormalGenerator = new StandardNormalGenerator()
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctions.scala
new file mode 100644
index 0000000000000..9213fd3f595c3
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctions.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.rdd
+
+import scala.language.implicitConversions
+import scala.reflect.ClassTag
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.BoundedPriorityQueue
+
+/**
+ * Machine learning specific Pair RDD functions.
+ */
+@DeveloperApi
+class MLPairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) extends Serializable {
+ /**
+ * Returns the top k (largest) elements for each key from this RDD as defined by the specified
+ * implicit Ordering[V]. If the number of elements for a certain key is less than k, all of
+ * them are returned.
+ *
+ * @param num k, the number of top elements to return
+ * @param ord the implicit ordering for V
+ * @return an RDD that contains the top k values for each key
+ */
+ def topByKey(num: Int)(implicit ord: Ordering[V]): RDD[(K, Array[V])] = {
+ self.aggregateByKey(new BoundedPriorityQueue[V](num)(ord))(
+ seqOp = (queue, item) => {
+ queue += item
+ queue
+ },
+ combOp = (queue1, queue2) => {
+ queue1 ++= queue2
+ queue1
+ }
+ ).mapValues(_.toArray.sorted(ord.reverse))
+ }
+}
+
+@DeveloperApi
+object MLPairRDDFunctions {
+ /** Implicit conversion from a pair RDD to MLPairRDDFunctions. */
+ implicit def fromPairRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): MLPairRDDFunctions[K, V] =
+ new MLPairRDDFunctions[K, V](rdd)
+}
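For context, a minimal sketch of how the new topByKey helper could be used from application code; the SparkContext value `sc` and the sample (key, score) pairs are assumptions for illustration, not part of the patch:

```scala
import org.apache.spark.mllib.rdd.MLPairRDDFunctions._
import org.apache.spark.rdd.RDD

// Assumes an existing SparkContext `sc`; the pairs below are made up.
val scores: RDD[(String, Double)] =
  sc.parallelize(Seq(("a", 1.0), ("a", 3.0), ("a", 2.0), ("b", 5.0)))

// Keep the two largest scores per key, returned in descending order.
val top2: RDD[(String, Array[Double])] = scores.topByKey(2)
top2.collect().foreach { case (k, vs) => println(s"$k -> ${vs.mkString(", ")}") }
// Expected: a -> 3.0, 2.0 and b -> 5.0 (key order may vary)
```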
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index caacab943030b..dddefe1944e9d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -82,6 +82,9 @@ class ALS private (
private var intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK
private var finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK
+ /** checkpoint interval */
+ private var checkpointInterval: Int = 10
+
/**
* Set the number of blocks for both user blocks and product blocks to parallelize the computation
* into; pass -1 for an auto-configured number of blocks. Default: -1.
@@ -182,6 +185,19 @@ class ALS private (
this
}
+ /**
+ * Set period (in iterations) between checkpoints (default = 10). Checkpointing helps with
+ * recovery (when nodes fail) and StackOverflow exceptions caused by long lineage. It also helps
+ * with eliminating temporary shuffle files on disk, which can be important when there are many
+ * ALS iterations. If the checkpoint directory is not set in [[org.apache.spark.SparkContext]],
+ * this setting is ignored.
+ */
+ @DeveloperApi
+ def setCheckpointInterval(checkpointInterval: Int): this.type = {
+ this.checkpointInterval = checkpointInterval
+ this
+ }
+
/**
* Run ALS with the configured parameters on an input RDD of (user, product, rating) triples.
* Returns a MatrixFactorizationModel with feature vectors for each user and product.
@@ -212,6 +228,7 @@ class ALS private (
nonnegative = nonnegative,
intermediateRDDStorageLevel = intermediateRDDStorageLevel,
finalRDDStorageLevel = StorageLevel.NONE,
+ checkpointInterval = checkpointInterval,
seed = seed)
val userFactors = floatUserFactors
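A hedged sketch of how the new checkpoint interval might be wired up with the RDD-based ALS API; the checkpoint directory path and the `ratings` RDD are placeholders, not part of the patch:

```scala
import org.apache.spark.mllib.recommendation.{ALS, Rating}
import org.apache.spark.rdd.RDD

// Assumes an existing SparkContext `sc` and an RDD[Rating] named `ratings`.
// The interval only takes effect when a checkpoint directory has been configured.
sc.setCheckpointDir("/tmp/als-checkpoints")  // placeholder path

def train(ratings: RDD[Rating]) =
  new ALS()
    .setRank(10)
    .setIterations(50)          // long lineages are where checkpointing pays off
    .setCheckpointInterval(5)   // checkpoint every 5 iterations instead of the default 10
    .run(ratings)
```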
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index b262bec904525..45b9ebb4cc0d6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -76,7 +76,12 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double
predictPoint(testData, weights, intercept)
}
- override def toString() = "(weights=%s, intercept=%s)".format(weights, intercept)
+ /**
+ * Returns a summary of the model.
+ */
+ override def toString: String = {
+ s"${this.getClass.getName}: intercept = ${intercept}, numFeatures = ${weights.size}"
+ }
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
index bd7e340ca2d8e..b55944f74f623 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
@@ -32,7 +32,7 @@ private[regression] object GLMRegressionModel {
object SaveLoadV1_0 {
- def thisFormatVersion = "1.0"
+ def thisFormatVersion: String = "1.0"
/** Model data for model import/export */
case class Data(weights: Vector, intercept: Double)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 8d5c36da32bdb..ada227c200a79 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -83,10 +83,13 @@ class Strategy (
@BeanProperty var useNodeIdCache: Boolean = false,
@BeanProperty var checkpointInterval: Int = 10) extends Serializable {
- def isMulticlassClassification =
+ def isMulticlassClassification: Boolean = {
algo == Classification && numClasses > 2
- def isMulticlassWithCategoricalFeatures
- = isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
+ }
+
+ def isMulticlassWithCategoricalFeatures: Boolean = {
+ isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
+ }
/**
* Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]]
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index b7950e00786ab..5ac10f3fd32dd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -71,7 +71,7 @@ object Entropy extends Impurity {
* Get this impurity instance.
* This is useful for passing impurity parameters to a Strategy in Java.
*/
- def instance = this
+ def instance: this.type = this
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index c946db9c0d1c8..19d318203c344 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -67,7 +67,7 @@ object Gini extends Impurity {
* Get this impurity instance.
* This is useful for passing impurity parameters to a Strategy in Java.
*/
- def instance = this
+ def instance: this.type = this
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index df9eafa5da16a..7104a7fa4dd4c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -58,7 +58,7 @@ object Variance extends Impurity {
* Get this impurity instance.
* This is useful for passing impurity parameters to a Strategy in Java.
*/
- def instance = this
+ def instance: this.type = this
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
index d1bde15e6b150..793dd664c5d5a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
@@ -47,18 +47,9 @@ object AbsoluteError extends Loss {
if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0
}
- /**
- * Method to calculate loss of the base learner for the gradient boosting calculation.
- * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
- * purposes.
- * @param model Ensemble model
- * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * @return Mean absolute error of model on data
- */
- override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
- data.map { y =>
- val err = model.predict(y.features) - y.label
- math.abs(err)
- }.mean()
+ override def computeError(prediction: Double, label: Double): Double = {
+ val err = label - prediction
+ math.abs(err)
}
+
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
index 55213e695638c..51b1aed167b66 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -50,20 +50,10 @@ object LogLoss extends Loss {
- 4.0 * point.label / (1.0 + math.exp(2.0 * point.label * prediction))
}
- /**
- * Method to calculate loss of the base learner for the gradient boosting calculation.
- * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
- * purposes.
- * @param model Ensemble model
- * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * @return Mean log loss of model on data
- */
- override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
- data.map { case point =>
- val prediction = model.predict(point.features)
- val margin = 2.0 * point.label * prediction
- // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
- 2.0 * MLUtils.log1pExp(-margin)
- }.mean()
+ override def computeError(prediction: Double, label: Double): Double = {
+ val margin = 2.0 * label * prediction
+ // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
+ 2.0 * MLUtils.log1pExp(-margin)
}
+
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
index e1169d9f66ea4..357869ff6b333 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
@@ -47,6 +47,18 @@ trait Loss extends Serializable {
* @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* @return Measure of model error on data
*/
- def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double
+ def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
+ data.map(point => computeError(model.predict(point.features), point.label)).mean()
+ }
+
+ /**
+ * Method to calculate loss when the predictions are already known.
+ * Note: This method is used by evaluateEachIteration to avoid recomputing the predictions
+ * from previously fitted trees.
+ * @param prediction Predicted label.
+ * @param label True label.
+ * @return Measure of model error on the data point.
+ */
+ def computeError(prediction: Double, label: Double): Double
}
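To illustrate the refactored contract, a small sketch that uses the new per-point overload directly; the `predictionAndLabels` RDD is a placeholder produced elsewhere, and the choice of SquaredError is only an example:

```scala
import org.apache.spark.mllib.tree.loss.SquaredError
import org.apache.spark.rdd.RDD

// Mean squared error over (prediction, label) pairs, using the new per-point overload;
// the RDD-level computeError(model, data) is now derived from it in the Loss trait.
def meanSquaredError(predictionAndLabels: RDD[(Double, Double)]): Double =
  predictionAndLabels.map { case (pred, label) =>
    SquaredError.computeError(pred, label)
  }.mean()
```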
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
index 50ecaa2f86f35..b990707ca4525 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
@@ -47,18 +47,9 @@ object SquaredError extends Loss {
2.0 * (model.predict(point.features) - point.label)
}
- /**
- * Method to calculate loss of the base learner for the gradient boosting calculation.
- * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
- * purposes.
- * @param model Ensemble model
- * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * @return Mean squared error of model on data
- */
- override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
- data.map { y =>
- val err = model.predict(y.features) - y.label
- err * err
- }.mean()
+ override def computeError(prediction: Double, label: Double): Double = {
+ val err = prediction - label
+ err * err
}
+
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index 8a57ebc387d01..c9bafd60fba4d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -120,10 +120,10 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {
private[tree] object SaveLoadV1_0 {
- def thisFormatVersion = "1.0"
+ def thisFormatVersion: String = "1.0"
// Hard-code class name string in case it changes in the future
- def thisClassName = "org.apache.spark.mllib.tree.DecisionTreeModel"
+ def thisClassName: String = "org.apache.spark.mllib.tree.DecisionTreeModel"
case class PredictData(predict: Double, prob: Double) {
def toPredict: Predict = new Predict(predict, prob)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index 80990aa9a603f..f209fdafd3653 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -38,23 +38,32 @@ class InformationGainStats(
val leftPredict: Predict,
val rightPredict: Predict) extends Serializable {
- override def toString = {
+ override def toString: String = {
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f"
.format(gain, impurity, leftImpurity, rightImpurity)
}
- override def equals(o: Any) =
- o match {
- case other: InformationGainStats => {
- gain == other.gain &&
- impurity == other.impurity &&
- leftImpurity == other.leftImpurity &&
- rightImpurity == other.rightImpurity &&
- leftPredict == other.leftPredict &&
- rightPredict == other.rightPredict
- }
- case _ => false
- }
+ override def equals(o: Any): Boolean = o match {
+ case other: InformationGainStats =>
+ gain == other.gain &&
+ impurity == other.impurity &&
+ leftImpurity == other.leftImpurity &&
+ rightImpurity == other.rightImpurity &&
+ leftPredict == other.leftPredict &&
+ rightPredict == other.rightPredict
+
+ case _ => false
+ }
+
+ override def hashCode: Int = {
+ com.google.common.base.Objects.hashCode(
+ gain: java.lang.Double,
+ impurity: java.lang.Double,
+ leftImpurity: java.lang.Double,
+ rightImpurity: java.lang.Double,
+ leftPredict,
+ rightPredict)
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index d961081d185e9..4f72bb8014cc0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -50,8 +50,10 @@ class Node (
var rightNode: Option[Node],
var stats: Option[InformationGainStats]) extends Serializable with Logging {
- override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " +
- "impurity = " + impurity + "split = " + split + ", stats = " + stats
+ override def toString: String = {
+ "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " +
+ "impurity = " + impurity + "split = " + split + ", stats = " + stats
+ }
/**
* build the left node and right nodes if not leaf
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
index ad4c0dbbfb3e5..25990af7c6cf7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
@@ -29,7 +29,7 @@ class Predict(
val predict: Double,
val prob: Double = 0.0) extends Serializable {
- override def toString = {
+ override def toString: String = {
"predict = %f, prob = %f".format(predict, prob)
}
@@ -39,4 +39,8 @@ class Predict(
case _ => false
}
}
+
+ override def hashCode: Int = {
+ com.google.common.base.Objects.hashCode(predict: java.lang.Double, prob: java.lang.Double)
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
index b7a85f58544a3..fb35e70a8d077 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
@@ -38,9 +38,10 @@ case class Split(
featureType: FeatureType,
categories: List[Double]) {
- override def toString =
+ override def toString: String = {
"Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType +
", categories = " + categories
+ }
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index 30a8f7ca301af..1950254b2aa6d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -28,9 +28,11 @@ import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
+import org.apache.spark.mllib.tree.loss.Loss
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
@@ -79,7 +81,7 @@ object RandomForestModel extends Loader[RandomForestModel] {
private object SaveLoadV1_0 {
// Hard-code class name string in case it changes in the future
- def thisClassName = "org.apache.spark.mllib.tree.model.RandomForestModel"
+ def thisClassName: String = "org.apache.spark.mllib.tree.model.RandomForestModel"
}
}
@@ -108,6 +110,58 @@ class GradientBoostedTreesModel(
}
override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
+
+ /**
+ * Method to compute error or loss for every iteration of gradient boosting.
+ * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
+ * @param loss evaluation metric.
+ * @return an array with index i having the losses or errors for the ensemble
+ * containing the first i+1 trees
+ */
+ def evaluateEachIteration(
+ data: RDD[LabeledPoint],
+ loss: Loss): Array[Double] = {
+
+ val sc = data.sparkContext
+ val remappedData = algo match {
+ case Classification => data.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ case _ => data
+ }
+
+ val numIterations = trees.length
+ val evaluationArray = Array.fill(numIterations)(0.0)
+
+ var predictionAndError: RDD[(Double, Double)] = remappedData.map { i =>
+ val pred = treeWeights(0) * trees(0).predict(i.features)
+ val error = loss.computeError(pred, i.label)
+ (pred, error)
+ }
+ evaluationArray(0) = predictionAndError.values.mean()
+
+ // Broadcast the trees and weights once so the full model is not shipped with every task.
+ val broadcastTrees = sc.broadcast(trees)
+ val broadcastWeights = sc.broadcast(treeWeights)
+
+ (1 until numIterations).map { nTree =>
+ predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter =>
+ val currentTree = broadcastTrees.value(nTree)
+ val currentTreeWeight = broadcastWeights.value(nTree)
+ iter.map {
+ case (point, (pred, error)) => {
+ val newPred = pred + currentTree.predict(point.features) * currentTreeWeight
+ val newError = loss.computeError(newPred, point.label)
+ (newPred, newError)
+ }
+ }
+ }
+ evaluationArray(nTree) = predictionAndError.values.mean()
+ }
+
+ broadcastTrees.unpersist()
+ broadcastWeights.unpersist()
+ evaluationArray
+ }
+
}
object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
@@ -130,7 +184,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
private object SaveLoadV1_0 {
// Hard-code class name string in case it changes in the future
- def thisClassName = "org.apache.spark.mllib.tree.model.GradientBoostedTreesModel"
+ def thisClassName: String = "org.apache.spark.mllib.tree.model.GradientBoostedTreesModel"
}
}
@@ -257,7 +311,7 @@ private[tree] object TreeEnsembleModel extends Logging {
import org.apache.spark.mllib.tree.model.DecisionTreeModel.SaveLoadV1_0.{NodeData, constructTrees}
- def thisFormatVersion = "1.0"
+ def thisFormatVersion: String = "1.0"
case class Metadata(
algo: String,
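A sketch of how the new evaluateEachIteration method might be used to choose the number of trees after training; the boosting parameters and the data RDDs here are assumptions for illustration:

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.loss.SquaredError
import org.apache.spark.rdd.RDD

// Assumes training and validation RDD[LabeledPoint] prepared elsewhere.
def bestNumTrees(train: RDD[LabeledPoint], validation: RDD[LabeledPoint]): Int = {
  val boostingStrategy = BoostingStrategy.defaultParams("Regression")
  boostingStrategy.numIterations = 20
  val model = GradientBoostedTrees.train(train, boostingStrategy)

  // Index i holds the validation error of the ensemble made of the first i + 1 trees.
  val errors = model.evaluateEachIteration(validation, SquaredError)
  errors.indexOf(errors.min) + 1
}
```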
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index bb86bafc0eb0a..0bb06e9e8ac9c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.ml.recommendation
+import java.io.File
import java.util.Random
import scala.collection.mutable
@@ -32,16 +33,25 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.util.Utils
class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
private var sqlContext: SQLContext = _
+ private var tempDir: File = _
override def beforeAll(): Unit = {
super.beforeAll()
+ tempDir = Utils.createTempDir()
+ sc.setCheckpointDir(tempDir.getAbsolutePath)
sqlContext = new SQLContext(sc)
}
+ override def afterAll(): Unit = {
+ Utils.deleteRecursively(tempDir)
+ super.afterAll()
+ }
+
test("LocalIndexEncoder") {
val random = new Random
for (numBlocks <- Seq(1, 2, 5, 10, 20, 50, 100)) {
@@ -485,4 +495,11 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
}.count()
}
}
+
+ test("als with large number of iterations") {
+ val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
+ ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2)
+ ALS.train(
+ ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2, implicitPrefs = true)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
index c098b5458fe6b..0d2cec58e2c03 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
@@ -424,4 +424,19 @@ class MatricesSuite extends FunSuite {
assert(mat.rowIndices.toSeq === Seq(3, 0, 2, 1))
assert(mat.values.toSeq === Seq(1.0, 2.0, 3.0, 4.0))
}
+
+ test("MatrixUDT") {
+ val dm1 = new DenseMatrix(2, 2, Array(0.9, 1.2, 2.3, 9.8))
+ val dm2 = new DenseMatrix(3, 2, Array(0.0, 1.21, 2.3, 9.8, 9.0, 0.0))
+ val dm3 = new DenseMatrix(0, 0, Array())
+ val sm1 = dm1.toSparse
+ val sm2 = dm2.toSparse
+ val sm3 = dm3.toSparse
+ val mUDT = new MatrixUDT()
+ Seq(dm1, dm2, dm3, sm1, sm2, sm3).foreach {
+ mat => assert(mat.toArray === mUDT.deserialize(mUDT.serialize(mat)).toArray)
+ }
+ assert(mUDT.typeName == "matrix")
+ assert(mUDT.simpleString == "matrix")
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index 5def899cea117..2839c4c289b2d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -187,6 +187,8 @@ class VectorsSuite extends FunSuite {
for (v <- Seq(dv0, dv1, sv0, sv1)) {
assert(v === udt.deserialize(udt.serialize(v)))
}
+ assert(udt.typeName == "vector")
+ assert(udt.simpleString == "vector")
}
test("fromBreeze") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala
new file mode 100644
index 0000000000000..1ac7c12c4e8e6
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.rdd
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.rdd.MLPairRDDFunctions._
+
+class MLPairRDDFunctionsSuite extends FunSuite with MLlibTestSparkContext {
+ test("topByKey") {
+ val topMap = sc.parallelize(Array((1, 1), (1, 2), (3, 2), (3, 7), (3, 5), (5, 1), (5, 3)), 2)
+ .topByKey(2)
+ .collectAsMap()
+
+ assert(topMap.size === 3)
+ assert(topMap(1) === Array(2, 1))
+ assert(topMap(3) === Array(7, 5))
+ assert(topMap(5) === Array(3, 1))
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index b437aeaaf0547..55b0bac7d49fe 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -175,10 +175,11 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
val gbtValidate = new GradientBoostedTrees(boostingStrategy)
.runWithValidation(trainRdd, validateRdd)
- assert(gbtValidate.numTrees !== numIterations)
+ val numTrees = gbtValidate.numTrees
+ assert(numTrees !== numIterations)
// Test that it performs better on the validation dataset.
- val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy)
+ val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd)
val (errorWithoutValidation, errorWithValidation) = {
if (algo == Classification) {
val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
@@ -188,6 +189,17 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
}
}
assert(errorWithValidation <= errorWithoutValidation)
+
+ // Test that results from evaluateEachIteration are consistent with runWithValidation.
+ // Note that validationTol is set to 0.0 above.
+ val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss)
+ assert(evaluationArray.length === numIterations)
+ assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
+ var i = 1
+ while (i < numTrees) {
+ assert(evaluationArray(i) <= evaluationArray(i - 1))
+ i += 1
+ }
}
}
}
diff --git a/network/common/pom.xml b/network/common/pom.xml
index 74437f37c47e4..7b51845206f4a 100644
--- a/network/common/pom.xml
+++ b/network/common/pom.xml
@@ -22,7 +22,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.10</artifactId>
- <version>1.3.0-SNAPSHOT</version>
+ <version>1.4.0-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml
index a2bcca26d8344..7dc7c65825e34 100644
--- a/network/shuffle/pom.xml
+++ b/network/shuffle/pom.xml
@@ -22,7 +22,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.10</artifactId>
- <version>1.3.0-SNAPSHOT</version>
+ <version>1.4.0-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml
index cea7a20c223e2..1e2e9c80af6cc 100644
--- a/network/yarn/pom.xml
+++ b/network/yarn/pom.xml
@@ -22,7 +22,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.10</artifactId>
- <version>1.3.0-SNAPSHOT</version>
+ <version>1.4.0-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/pom.xml b/pom.xml
index 6fc56a86d44ac..23bb16130b504 100644
--- a/pom.xml
+++ b/pom.xml
@@ -26,7 +26,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.10</artifactId>
- <version>1.3.0-SNAPSHOT</version>
+ <version>1.4.0-SNAPSHOT</version>
<packaging>pom</packaging>
<name>Spark Project Parent POM</name>
<url>http://spark.apache.org/</url>
@@ -120,7 +120,7 @@
<mesos.classifier>shaded-protobuf</mesos.classifier>
<slf4j.version>1.7.10</slf4j.version>
<log4j.version>1.2.17</log4j.version>
- <hadoop.version>1.0.4</hadoop.version>
+ <hadoop.version>2.2.0</hadoop.version>
<protobuf.version>2.4.1</protobuf.version>
<yarn.version>${hadoop.version}</yarn.version>
<hbase.version>0.98.7-hadoop1</hbase.version>
diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala
index f0cbf4e57b8c5..dde92949fa175 100644
--- a/project/MimaBuild.scala
+++ b/project/MimaBuild.scala
@@ -91,7 +91,7 @@ object MimaBuild {
def mimaSettings(sparkHome: File, projectRef: ProjectRef) = {
val organization = "org.apache.spark"
- val previousSparkVersion = "1.2.0"
+ val previousSparkVersion = "1.3.0"
val fullId = "spark-" + projectRef.project + "_2.10"
mimaDefaultSettings ++
Seq(previousArtifact := Some(organization % fullId % previousSparkVersion),
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index a6b07fa7cddec..56f5dbe53fad4 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -16,6 +16,7 @@
*/
import com.typesafe.tools.mima.core._
+import com.typesafe.tools.mima.core.ProblemFilters._
/**
* Additional excludes for checking of Spark's binary compatibility.
@@ -33,6 +34,25 @@ import com.typesafe.tools.mima.core._
object MimaExcludes {
def excludes(version: String) =
version match {
+ case v if v.startsWith("1.4") =>
+ Seq(
+ MimaBuild.excludeSparkPackage("deploy"),
+ MimaBuild.excludeSparkPackage("ml"),
+ // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff"),
+ // These are needed if checking against the sbt build, since they are part of
+ // the maven-generated artifacts in 1.3.
+ excludePackage("org.spark-project.jetty"),
+ MimaBuild.excludeSparkPackage("unused"),
+ ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional"),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem](
+ "org.apache.spark.rdd.JdbcRDD.compute"),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem](
+ "org.apache.spark.broadcast.HttpBroadcastFactory.newBroadcast"),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem](
+ "org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast")
+ )
+
case v if v.startsWith("1.3") =>
Seq(
MimaBuild.excludeSparkPackage("deploy"),
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index e4765173709e8..6766f3ebb8894 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -21,9 +21,10 @@
from numpy import array
from pyspark import RDD
-from pyspark.mllib.common import callMLlibFunc
+from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py
from pyspark.mllib.linalg import SparseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper
+from pyspark.mllib.util import Saveable, Loader, inherit_doc
__all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'LogisticRegressionWithLBFGS',
@@ -99,6 +100,18 @@ class LogisticRegressionModel(LinearBinaryClassificationModel):
1
>>> lrm.predict(SparseVector(2, {0: 1.0}))
0
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> lrm.save(sc, path)
+ >>> sameModel = LogisticRegressionModel.load(sc, path)
+ >>> sameModel.predict(array([0.0, 1.0]))
+ 1
+ >>> sameModel.predict(SparseVector(2, {0: 1.0}))
+ 0
+ >>> try:
+ ... os.removedirs(path)
+ ... except OSError:
+ ... pass
"""
def __init__(self, weights, intercept):
super(LogisticRegressionModel, self).__init__(weights, intercept)
@@ -124,6 +137,22 @@ def predict(self, x):
else:
return 1 if prob > self._threshold else 0
+ def save(self, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel(
+ _py2java(sc, self._coeff), self.intercept)
+ java_model.save(sc._jsc.sc(), path)
+
+ @classmethod
+ def load(cls, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel.load(
+ sc._jsc.sc(), path)
+ weights = _java2py(sc, java_model.weights())
+ intercept = java_model.intercept()
+ threshold = java_model.getThreshold().get()
+ model = LogisticRegressionModel(weights, intercept)
+ model.setThreshold(threshold)
+ return model
+
class LogisticRegressionWithSGD(object):
@@ -243,6 +272,18 @@ class SVMModel(LinearBinaryClassificationModel):
1
>>> svm.predict(SparseVector(2, {0: -1.0}))
0
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> svm.save(sc, path)
+ >>> sameModel = SVMModel.load(sc, path)
+ >>> sameModel.predict(SparseVector(2, {1: 1.0}))
+ 1
+ >>> sameModel.predict(SparseVector(2, {0: -1.0}))
+ 0
+ >>> try:
+ ... os.removedirs(path)
+ ... except OSError:
+ ... pass
"""
def __init__(self, weights, intercept):
super(SVMModel, self).__init__(weights, intercept)
@@ -263,6 +304,22 @@ def predict(self, x):
else:
return 1 if margin > self._threshold else 0
+ def save(self, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.classification.SVMModel(
+ _py2java(sc, self._coeff), self.intercept)
+ java_model.save(sc._jsc.sc(), path)
+
+ @classmethod
+ def load(cls, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.classification.SVMModel.load(
+ sc._jsc.sc(), path)
+ weights = _java2py(sc, java_model.weights())
+ intercept = java_model.intercept()
+ threshold = java_model.getThreshold().get()
+ model = SVMModel(weights, intercept)
+ model.setThreshold(threshold)
+ return model
+
class SVMWithSGD(object):
@@ -303,7 +360,8 @@ def train(rdd, i):
return _regression_train_wrapper(train, SVMModel, data, initialWeights)
-class NaiveBayesModel(object):
+@inherit_doc
+class NaiveBayesModel(Saveable, Loader):
"""
Model for Naive Bayes classifiers.
@@ -334,6 +392,16 @@ class NaiveBayesModel(object):
0.0
>>> model.predict(SparseVector(2, {0: 1.0}))
1.0
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> model.save(sc, path)
+ >>> sameModel = NaiveBayesModel.load(sc, path)
+ >>> sameModel.predict(SparseVector(2, {0: 1.0})) == model.predict(SparseVector(2, {0: 1.0}))
+ True
+ >>> try:
+ ... os.removedirs(path)
+ ... except OSError:
+ ... pass
"""
def __init__(self, labels, pi, theta):
@@ -348,6 +416,23 @@ def predict(self, x):
x = _convert_to_vector(x)
return self.labels[numpy.argmax(self.pi + x.dot(self.theta.transpose()))]
+ def save(self, sc, path):
+ java_labels = _py2java(sc, self.labels.tolist())
+ java_pi = _py2java(sc, self.pi.tolist())
+ java_theta = _py2java(sc, self.theta.tolist())
+ java_model = sc._jvm.org.apache.spark.mllib.classification.NaiveBayesModel(
+ java_labels, java_pi, java_theta)
+ java_model.save(sc._jsc.sc(), path)
+
+ @classmethod
+ def load(cls, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.classification.NaiveBayesModel.load(
+ sc._jsc.sc(), path)
+ py_labels = _java2py(sc, java_model.labels())
+ py_pi = _java2py(sc, java_model.pi())
+ py_theta = _java2py(sc, java_model.theta())
+ return NaiveBayesModel(py_labels, py_pi, numpy.array(py_theta))
+
class NaiveBayes(object):
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index 0c21ad578793f..414a0ada80787 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -18,8 +18,9 @@
import numpy as np
from numpy import array
-from pyspark.mllib.common import callMLlibFunc, inherit_doc
+from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py, inherit_doc
from pyspark.mllib.linalg import SparseVector, _convert_to_vector
+from pyspark.mllib.util import Saveable, Loader
__all__ = ['LabeledPoint', 'LinearModel',
'LinearRegressionModel', 'LinearRegressionWithSGD',
@@ -114,6 +115,20 @@ class LinearRegressionModel(LinearRegressionModelBase):
True
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> lrm.save(sc, path)
+ >>> sameModel = LinearRegressionModel.load(sc, path)
+ >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5
+ True
+ >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5
+ True
+ >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
+ True
+ >>> try:
+ ... os.removedirs(path)
+ ... except OSError:
+ ... pass
>>> data = [
... LabeledPoint(0.0, SparseVector(1, {0: 0.0})),
... LabeledPoint(1.0, SparseVector(1, {0: 1.0})),
@@ -126,6 +141,19 @@ class LinearRegressionModel(LinearRegressionModelBase):
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
"""
+ def save(self, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel(
+ _py2java(sc, self._coeff), self.intercept)
+ java_model.save(sc._jsc.sc(), path)
+
+ @classmethod
+ def load(cls, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel.load(
+ sc._jsc.sc(), path)
+ weights = _java2py(sc, java_model.weights())
+ intercept = java_model.intercept()
+ model = LinearRegressionModel(weights, intercept)
+ return model
# train_func should take two parameters, namely data and initial_weights, and
@@ -135,7 +163,8 @@ def _regression_train_wrapper(train_func, modelClass, data, initial_weights):
first = data.first()
if not isinstance(first, LabeledPoint):
raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first)
- initial_weights = initial_weights or [0.0] * len(data.first().features)
+ if initial_weights is None:
+ initial_weights = [0.0] * len(data.first().features)
weights, intercept = train_func(data, _convert_to_vector(initial_weights))
return modelClass(weights, intercept)
@@ -199,6 +228,20 @@ class LassoModel(LinearRegressionModelBase):
True
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> lrm.save(sc, path)
+ >>> sameModel = LassoModel.load(sc, path)
+ >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5
+ True
+ >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5
+ True
+ >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
+ True
+ >>> try:
+ ... os.removedirs(path)
+ ... except OSError:
+ ... pass
>>> data = [
... LabeledPoint(0.0, SparseVector(1, {0: 0.0})),
... LabeledPoint(1.0, SparseVector(1, {0: 1.0})),
@@ -211,6 +254,19 @@ class LassoModel(LinearRegressionModelBase):
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
"""
+ def save(self, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel(
+ _py2java(sc, self._coeff), self.intercept)
+ java_model.save(sc._jsc.sc(), path)
+
+ @classmethod
+ def load(cls, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel.load(
+ sc._jsc.sc(), path)
+ weights = _java2py(sc, java_model.weights())
+ intercept = java_model.intercept()
+ model = LassoModel(weights, intercept)
+ return model
class LassoWithSGD(object):
@@ -246,6 +302,20 @@ class RidgeRegressionModel(LinearRegressionModelBase):
True
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> lrm.save(sc, path)
+ >>> sameModel = RidgeRegressionModel.load(sc, path)
+ >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5
+ True
+ >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5
+ True
+ >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
+ True
+ >>> try:
+ ... os.removedirs(path)
+ ... except OSError:
+ ... pass
>>> data = [
... LabeledPoint(0.0, SparseVector(1, {0: 0.0})),
... LabeledPoint(1.0, SparseVector(1, {0: 1.0})),
@@ -258,6 +328,19 @@ class RidgeRegressionModel(LinearRegressionModelBase):
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
"""
+ def save(self, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel(
+ _py2java(sc, self._coeff), self.intercept)
+ java_model.save(sc._jsc.sc(), path)
+
+ @classmethod
+ def load(cls, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel.load(
+ sc._jsc.sc(), path)
+ weights = _java2py(sc, java_model.weights())
+ intercept = java_model.intercept()
+ model = RidgeRegressionModel(weights, intercept)
+ return model
class RidgeRegressionWithSGD(object):
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 5328d99b69684..155019638f806 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -323,6 +323,13 @@ def test_regression(self):
self.assertTrue(gbt_model.predict(features[2]) <= 0)
self.assertTrue(gbt_model.predict(features[3]) > 0)
+ try:
+ LinearRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]))
+ LassoWithSGD.train(rdd, initialWeights=array([1.0, 1.0]))
+ RidgeRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]))
+ except ValueError:
+ self.fail()
+
class StatTests(PySparkTestCase):
# SPARK-4023
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
index e877c720ac77a..c5c3468eb95e9 100644
--- a/python/pyspark/mllib/util.py
+++ b/python/pyspark/mllib/util.py
@@ -20,7 +20,6 @@
from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper, inherit_doc
from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector
-from pyspark.mllib.regression import LabeledPoint
class MLUtils(object):
@@ -50,6 +49,7 @@ def _parse_libsvm_line(line, multiclass=None):
@staticmethod
def _convert_labeled_point_to_libsvm(p):
"""Converts a LabeledPoint to a string in LIBSVM format."""
+ from pyspark.mllib.regression import LabeledPoint
assert isinstance(p, LabeledPoint)
items = [str(p.label)]
v = _convert_to_vector(p.features)
@@ -92,6 +92,7 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None
>>> from tempfile import NamedTemporaryFile
>>> from pyspark.mllib.util import MLUtils
+ >>> from pyspark.mllib.regression import LabeledPoint
>>> tempFile = NamedTemporaryFile(delete=True)
>>> tempFile.write("+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0")
>>> tempFile.flush()
@@ -110,6 +111,7 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None
>>> print examples[2]
(-1.0,(6,[1,3,5],[4.0,5.0,6.0]))
"""
+ from pyspark.mllib.regression import LabeledPoint
if multiclass is not None:
warnings.warn("deprecated", DeprecationWarning)
@@ -130,6 +132,7 @@ def saveAsLibSVMFile(data, dir):
>>> from tempfile import NamedTemporaryFile
>>> from fileinput import input
+ >>> from pyspark.mllib.regression import LabeledPoint
>>> from glob import glob
>>> from pyspark.mllib.util import MLUtils
>>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, 1.23), (2, 4.56)])), \
@@ -156,6 +159,7 @@ def loadLabeledPoints(sc, path, minPartitions=None):
>>> from tempfile import NamedTemporaryFile
>>> from pyspark.mllib.util import MLUtils
+ >>> from pyspark.mllib.regression import LabeledPoint
>>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, -1.23), (2, 4.56e-7)])), \
LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))]
>>> tempFile = NamedTemporaryFile(delete=True)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index bf17f513c0bc3..c337a43c8a7fc 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -346,6 +346,12 @@ def sample(self, withReplacement, fraction, seed=None):
"""
Return a sampled subset of this RDD.
+ :param withReplacement: can elements be sampled multiple times (replaced when sampled out)
+ :param fraction: expected size of the sample as a fraction of this RDD's size
+ without replacement: probability that each element is chosen; fraction must be [0, 1]
+ with replacement: expected number of times each element is chosen; fraction must be >= 0
+ :param seed: seed for the random number generator
+
>>> rdd = sc.parallelize(range(100), 4)
>>> rdd.sample(False, 0.1, 81).count()
10
diff --git a/repl/pom.xml b/repl/pom.xml
index 295f88ea3ecf9..edfa1c7f2c29c 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.10</artifactId>
- <version>1.3.0-SNAPSHOT</version>
+ <version>1.4.0-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index fbef5b25ba688..14f5e9ed4f25e 100644
--- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -21,11 +21,9 @@ import java.io._
import java.net.URLClassLoader
import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.Await
import scala.concurrent.duration._
import scala.tools.nsc.interpreter.SparkILoop
-import com.google.common.io.Files
import org.scalatest.FunSuite
import org.apache.commons.lang3.StringEscapeUtils
import org.apache.spark.SparkContext
@@ -196,8 +194,7 @@ class ReplSuite extends FunSuite {
}
test("interacting with files") {
- val tempDir = Files.createTempDir()
- tempDir.deleteOnExit()
+ val tempDir = Utils.createTempDir()
val out = new FileWriter(tempDir + "/input")
out.write("Hello world!\n")
out.write("What's up?\n")
diff --git a/scalastyle-config.xml b/scalastyle-config.xml
index 0ff521706c71a..459a5035d4984 100644
--- a/scalastyle-config.xml
+++ b/scalastyle-config.xml
@@ -137,9 +137,9 @@
-
+
-
+
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 8ad026dbdf8ff..3dea2ee76542f 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -22,7 +22,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.10</artifactId>
- <version>1.3.0-SNAPSHOT</version>
+ <version>1.4.0-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 54ab13ca352d2..ea7d44a3723d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.types._
* This is currently included mostly for illustrative purposes. Users wanting more complete support
* for a SQL like language should checkout the HiveQL support in the sql/hive sub-project.
*/
-class SqlParser extends AbstractSparkSQLParser {
+class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
def parseExpression(input: String): Expression = {
// Initialize the Keywords.
@@ -61,11 +61,8 @@ class SqlParser extends AbstractSparkSQLParser {
protected val CAST = Keyword("CAST")
protected val COALESCE = Keyword("COALESCE")
protected val COUNT = Keyword("COUNT")
- protected val DATE = Keyword("DATE")
- protected val DECIMAL = Keyword("DECIMAL")
protected val DESC = Keyword("DESC")
protected val DISTINCT = Keyword("DISTINCT")
- protected val DOUBLE = Keyword("DOUBLE")
protected val ELSE = Keyword("ELSE")
protected val END = Keyword("END")
protected val EXCEPT = Keyword("EXCEPT")
@@ -78,7 +75,6 @@ class SqlParser extends AbstractSparkSQLParser {
protected val IF = Keyword("IF")
protected val IN = Keyword("IN")
protected val INNER = Keyword("INNER")
- protected val INT = Keyword("INT")
protected val INSERT = Keyword("INSERT")
protected val INTERSECT = Keyword("INTERSECT")
protected val INTO = Keyword("INTO")
@@ -105,13 +101,11 @@ class SqlParser extends AbstractSparkSQLParser {
protected val SELECT = Keyword("SELECT")
protected val SEMI = Keyword("SEMI")
protected val SQRT = Keyword("SQRT")
- protected val STRING = Keyword("STRING")
protected val SUBSTR = Keyword("SUBSTR")
protected val SUBSTRING = Keyword("SUBSTRING")
protected val SUM = Keyword("SUM")
protected val TABLE = Keyword("TABLE")
protected val THEN = Keyword("THEN")
- protected val TIMESTAMP = Keyword("TIMESTAMP")
protected val TRUE = Keyword("TRUE")
protected val UNION = Keyword("UNION")
protected val UPPER = Keyword("UPPER")
@@ -315,7 +309,9 @@ class SqlParser extends AbstractSparkSQLParser {
)
protected lazy val cast: Parser[Expression] =
- CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ { case exp ~ t => Cast(exp, t) }
+ CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ {
+ case exp ~ t => Cast(exp, t)
+ }
protected lazy val literal: Parser[Literal] =
( numericLiteral
@@ -387,19 +383,4 @@ class SqlParser extends AbstractSparkSQLParser {
(ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ {
case i1 ~ i2 ~ rest => UnresolvedAttribute((Seq(i1, i2) ++ rest).mkString("."))
}
-
- protected lazy val dataType: Parser[DataType] =
- ( STRING ^^^ StringType
- | TIMESTAMP ^^^ TimestampType
- | DOUBLE ^^^ DoubleType
- | fixedDecimalType
- | DECIMAL ^^^ DecimalType.Unlimited
- | DATE ^^^ DateType
- | INT ^^^ IntegerType
- )
-
- protected lazy val fixedDecimalType: Parser[DataType] =
- (DECIMAL ~ "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ {
- case precision ~ scale => DecimalType(precision.toInt, scale.toInt)
- }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 4e8fc892f3eea..425e1e41cbf21 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -63,7 +63,7 @@ class CheckAnalysis {
s"filter expression '${f.condition.prettyString}' " +
s"of type ${f.condition.dataType.simpleString} is not a boolean.")
- case aggregatePlan@Aggregate(groupingExprs, aggregateExprs, child) =>
+ case Aggregate(groupingExprs, aggregateExprs, child) =>
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
case _: AggregateExpression => // OK
case e: Attribute if !groupingExprs.contains(e) =>
@@ -85,14 +85,18 @@ class CheckAnalysis {
cleaned.foreach(checkValidAggregateExpression)
- case o if o.children.nonEmpty &&
- !o.references.filter(_.name != "grouping__id").subsetOf(o.inputSet) =>
- val missingAttributes = (o.references -- o.inputSet).map(_.prettyString).mkString(",")
- val input = o.inputSet.map(_.prettyString).mkString(",")
+ case _ => // Fallbacks to the following checks
+ }
+
+ operator match {
+ case o if o.children.nonEmpty && o.missingInput.nonEmpty =>
+ val missingAttributes = o.missingInput.mkString(",")
+ val input = o.inputSet.mkString(",")
- failAnalysis(s"resolved attributes $missingAttributes missing from $input")
+ failAnalysis(
+ s"resolved attribute(s) $missingAttributes missing from $input " +
+ s"in operator ${operator.simpleString}")
- // Catch all
case o if !o.resolved =>
failAnalysis(
s"unresolved operator ${operator.simpleString}")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
index a9ba0be596349..adaeab0b5c027 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.analysis.Star
protected class AttributeEquals(val a: Attribute) {
override def hashCode() = a match {
@@ -115,7 +114,7 @@ class AttributeSet private (val baseSet: Set[AttributeEquals])
// sorts of things in its closure.
override def toSeq: Seq[Attribute] = baseSet.map(_.a).toArray.toSeq
- override def toString = "{" + baseSet.map(_.a).mkString(", ") + "}"
+ override def toString: String = "{" + baseSet.map(_.a).mkString(", ") + "}"
override def isEmpty: Boolean = baseSet.isEmpty
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 735b7488fdcbd..5297d1e31246c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -346,13 +346,13 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
case DecimalType.Fixed(_, _) =>
val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")()
SplitEvaluation(
- Cast(Sum(partialSum.toAttribute), dataType),
+ Cast(CombineSum(partialSum.toAttribute), dataType),
partialSum :: Nil)
case _ =>
val partialSum = Alias(Sum(child), "PartialSum")()
SplitEvaluation(
- Sum(partialSum.toAttribute),
+ CombineSum(partialSum.toAttribute),
partialSum :: Nil)
}
}
@@ -360,6 +360,30 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
override def newInstance() = new SumFunction(child, this)
}
+/**
+ * Sum should satisfy 3 cases:
+ * 1) sum of all null values = zero
+ * 2) sum for table column with no data = null
+ * 3) sum of column with null and not null values = sum of not null values
+ * A separate CombineSum expression and function are required because the "no data" case must be
+ * distinguished from the "data equals null" case, both while combining results and at each
+ * partial aggregation, i.e.:
+ * Combining PartitionLevel InputData
+ * <-- null
+ * Zero <-- Zero <-- null
+ *
+ * <-- null <-- no data
+ * null <-- null <-- no data
+ */
+case class CombineSum(child: Expression) extends AggregateExpression {
+ def this() = this(null)
+
+ override def children = child :: Nil
+ override def nullable = true
+ override def dataType = child.dataType
+ override def toString = s"CombineSum($child)"
+ override def newInstance() = new CombineSumFunction(child, this)
+}
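A plain-Scala sketch of the behaviour the comment above describes, using Option[Long] as a stand-in for a nullable sum; this mirrors the intent of the Sum/CombineSum pair, not the Catalyst expressions themselves:

```scala
// Partition-level update: a NULL input value does not add anything, but it does turn
// "no data yet" (None) into a zero-valued sum, so SUM over all-null input yields 0.
def addInput(partialSum: Option[Long], value: Option[Long]): Option[Long] = value match {
  case Some(v) => Some(partialSum.getOrElse(0L) + v)
  case None    => partialSum.orElse(Some(0L))
}

// Final combine: a None partial sum means that partition saw no rows at all, so it is
// skipped; if every partition was empty, the overall result stays None (SQL NULL).
def combinePartials(acc: Option[Long], partial: Option[Long]): Option[Long] = partial match {
  case Some(p) => Some(acc.getOrElse(0L) + p)
  case None    => acc
}
```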
+
case class SumDistinct(child: Expression)
extends PartialAggregate with trees.UnaryNode[Expression] {
@@ -565,7 +589,8 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr
private val sum = MutableLiteral(null, calcType)
- private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum))
+ private val addFunction =
+ Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero))
override def update(input: Row): Unit = {
sum.update(addFunction, input)
@@ -580,6 +605,43 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr
}
}
+case class CombineSumFunction(expr: Expression, base: AggregateExpression)
+ extends AggregateFunction {
+
+ def this() = this(null, null) // Required for serialization.
+
+ private val calcType =
+ expr.dataType match {
+ case DecimalType.Fixed(_, _) =>
+ DecimalType.Unlimited
+ case _ =>
+ expr.dataType
+ }
+
+ private val zero = Cast(Literal(0), calcType)
+
+ private val sum = MutableLiteral(null, calcType)
+
+ private val addFunction =
+ Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero))
+
+ override def update(input: Row): Unit = {
+ val result = expr.eval(input)
+ // The partial sum result can be null only when no input rows are present.
+ if (result != null) {
+ sum.update(addFunction, input)
+ }
+ }
+
+ override def eval(input: Row): Any = {
+ expr.dataType match {
+ case DecimalType.Fixed(_, _) =>
+ Cast(sum, dataType).eval(null)
+ case _ => sum.eval(null)
+ }
+ }
+}
+
case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
extends AggregateFunction {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala
index 80c7dfd376c96..528e38a50a740 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.rules
-import org.apache.spark.sql.catalyst.util
+import org.apache.spark.util.Utils
/**
* A collection of generators that build custom bytecode at runtime for performing the evaluation
@@ -52,7 +52,7 @@ package object codegen {
@DeveloperApi
object DumpByteCode {
import scala.sys.process._
- val dumpDirectory = util.getTempFilePath("sparkSqlByteCode")
+ val dumpDirectory = Utils.createTempDir()
dumpDirectory.mkdir()
def apply(obj: Any): Unit = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index faa366771824b..f03d6f71a9fae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -146,6 +146,27 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
result
}
+ override def equals(o: Any): Boolean = o match {
+ case other: Row =>
+ if (values.length != other.length) {
+ return false
+ }
+
+ var i = 0
+ while (i < values.length) {
+ if (isNullAt(i) != other.isNullAt(i)) {
+ return false
+ }
+ if (apply(i) != other.apply(i)) {
+ return false
+ }
+ i += 1
+ }
+ true
+
+ case _ => false
+ }
+
def copy() = this
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 17a88e07de15f..48191f31198f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.plans
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression}
+import org.apache.spark.sql.catalyst.expressions.{VirtualColumn, Attribute, AttributeSet, Expression}
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType}
@@ -47,8 +47,12 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
* Attributes that are referenced by expressions but not provided by this nodes children.
* Subclasses should override this method if they produce attributes internally as it is used by
* assertions designed to prevent the construction of invalid plans.
+ *
+ * Note that virtual columns should be excluded. Currently, we only support the grouping ID
+ * virtual column.
*/
- def missingInput: AttributeSet = references -- inputSet
+ def missingInput: AttributeSet =
+ (references -- inputSet).filter(_.name != VirtualColumn.groupingIdName)
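A rough plain-Scala model of the new filter, with attributes reduced to their names; the grouping-ID string below is an assumption standing in for VirtualColumn.groupingIdName:

object MissingInputSketch {
  // Illustrative only: the real code filters an AttributeSet, not a Set[String].
  val groupingIdName = "grouping__id"  // assumption: mirrors VirtualColumn.groupingIdName

  def missingInput(references: Set[String], inputSet: Set[String]): Set[String] =
    (references -- inputSet).filterNot(_ == groupingIdName)

  // missingInput(Set("a", groupingIdName), Set("a")) == Set.empty  -- virtual column not reported
  // missingInput(Set("a", "b"), Set("a"))            == Set("b")   -- real missing attribute is
}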
/**
* Runs [[transform]] with `rule` on all expressions present in this query operator.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 1e7b449d75b80..384fe53a68362 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -289,6 +289,15 @@ case class Distinct(child: LogicalPlan) extends UnaryNode {
case object NoRelation extends LeafNode {
override def output = Nil
+
+ /**
+ * Computes [[Statistics]] for this plan. The default implementation assumes the output
+ * cardinality is the product of all child plans' cardinalities, i.e. it applies in the case
+ * of cartesian joins.
+ *
+ * [[LeafNode]]s must override this.
+ */
+ override def statistics: Statistics = Statistics(sizeInBytes = 1)
}
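The override matters because the inherited estimate is a product over the children, and a leaf has no children to multiply; reporting sizeInBytes = 1 also keeps NoRelation neutral when it is joined with other plans. A small sketch of that default, assuming sizeInBytes is a BigInt as in Catalyst's Statistics (names are illustrative):

object DefaultStatisticsSketch {
  // Product over the children's estimates, as the comment above describes.
  def defaultSizeInBytes(childSizes: Seq[BigInt]): BigInt = childSizes.product

  // defaultSizeInBytes(Seq(BigInt(100), BigInt(20))) == BigInt(2000)  -- cartesian-join case
  // NoRelation instead reports sizeInBytes = 1, the multiplicative identity, so joining it
  // into a plan does not change that plan's size estimate.
}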
case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
index d8da45ae70c4b..feed50f9a2a2d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
@@ -19,20 +19,9 @@ package org.apache.spark.sql.catalyst
import java.io.{PrintWriter, ByteArrayOutputStream, FileInputStream, File}
-import org.apache.spark.util.{Utils => SparkUtils}
+import org.apache.spark.util.Utils
package object util {
- /**
- * Returns a path to a temporary file that probably does not exist.
- * Note, there is always the race condition that someone created this
- * file since the last time we checked. Thus, this shouldn't be used
- * for anything security conscious.
- */
- def getTempFilePath(prefix: String, suffix: String = ""): File = {
- val tempFile = File.createTempFile(prefix, suffix)
- tempFile.delete()
- tempFile
- }
def fileToString(file: File, encoding: String = "UTF-8") = {
val inStream = new FileInputStream(file)
@@ -56,7 +45,7 @@ package object util {
def resourceToString(
resource:String,
encoding: String = "UTF-8",
- classLoader: ClassLoader = SparkUtils.getSparkClassLoader) = {
+ classLoader: ClassLoader = Utils.getSparkClassLoader) = {
val inStream = classLoader.getResourceAsStream(resource)
val outStream = new ByteArrayOutputStream
try {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala
new file mode 100644
index 0000000000000..89278f7dbc806
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import scala.language.implicitConversions
+import scala.util.matching.Regex
+import scala.util.parsing.combinator.syntactical.StandardTokenParsers
+
+import org.apache.spark.sql.catalyst.SqlLexical
+
+/**
+ * This is a data type parser that can be used to parse string representations of data types
+ * provided in SQL queries. This parser is mixed in with DDLParser and SqlParser.
+ */
+private[sql] trait DataTypeParser extends StandardTokenParsers {
+
+ // This is used to create a parser from a regex. We are using regexes for data type strings
+ // since these strings can also be used as column names or field names.
+ import lexical.Identifier
+ implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch(
+ s"identifier matching regex ${regex}",
+ { case Identifier(str) if regex.unapplySeq(str).isDefined => str }
+ )
+
+ protected lazy val primitiveType: Parser[DataType] =
+ "(?i)string".r ^^^ StringType |
+ "(?i)float".r ^^^ FloatType |
+ "(?i)int".r ^^^ IntegerType |
+ "(?i)tinyint".r ^^^ ByteType |
+ "(?i)smallint".r ^^^ ShortType |
+ "(?i)double".r ^^^ DoubleType |
+ "(?i)bigint".r ^^^ LongType |
+ "(?i)binary".r ^^^ BinaryType |
+ "(?i)boolean".r ^^^ BooleanType |
+ fixedDecimalType |
+ "(?i)decimal".r ^^^ DecimalType.Unlimited |
+ "(?i)date".r ^^^ DateType |
+ "(?i)timestamp".r ^^^ TimestampType |
+ varchar
+
+ protected lazy val fixedDecimalType: Parser[DataType] =
+ ("(?i)decimal".r ~> "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ {
+ case precision ~ scale =>
+ DecimalType(precision.toInt, scale.toInt)
+ }
+
+ protected lazy val varchar: Parser[DataType] =
+ "(?i)varchar".r ~> "(" ~> (numericLit <~ ")") ^^^ StringType
+
+ protected lazy val arrayType: Parser[DataType] =
+ "(?i)array".r ~> "<" ~> dataType <~ ">" ^^ {
+ case tpe => ArrayType(tpe)
+ }
+
+ protected lazy val mapType: Parser[DataType] =
+ "(?i)map".r ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ {
+ case t1 ~ _ ~ t2 => MapType(t1, t2)
+ }
+
+ protected lazy val structField: Parser[StructField] =
+ ident ~ ":" ~ dataType ^^ {
+ case name ~ _ ~ tpe => StructField(name, tpe, nullable = true)
+ }
+
+ protected lazy val structType: Parser[DataType] =
+ ("(?i)struct".r ~> "<" ~> repsep(structField, ",") <~ ">" ^^ {
+ case fields => new StructType(fields.toArray)
+ }) |
+ ("(?i)struct".r ~ "<>" ^^^ StructType(Nil))
+
+ protected lazy val dataType: Parser[DataType] =
+ arrayType |
+ mapType |
+ structType |
+ primitiveType
+
+ def toDataType(dataTypeString: String): DataType = synchronized {
+ phrase(dataType)(new lexical.Scanner(dataTypeString)) match {
+ case Success(result, _) => result
+ case failure: NoSuccess => throw new DataTypeException(failMessage(dataTypeString))
+ }
+ }
+
+ private def failMessage(dataTypeString: String): String = {
+ s"Unsupported dataType: $dataTypeString. If you have a struct and a field name of it has " +
+ "any special characters, please use backticks (`) to quote that field name, e.g. `x+y`. " +
+ "Please note that backtick itself is not supported in a field name."
+ }
+}
+
+private[sql] object DataTypeParser {
+ lazy val dataTypeParser = new DataTypeParser {
+ override val lexical = new SqlLexical
+ }
+
+ def apply(dataTypeString: String): DataType = dataTypeParser.toDataType(dataTypeString)
+}
+
+/** The exception thrown from the [[DataTypeParser]]. */
+protected[sql] class DataTypeException(message: String) extends Exception(message)
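Since the parser and its companion object are private[sql], they are only reachable from code under org.apache.spark.sql; assuming that, a short usage sketch showing mappings that follow directly from the grammar above:

// Usage sketch; assumes it compiles somewhere under the org.apache.spark.sql package,
// since DataTypeParser is private[sql].
import org.apache.spark.sql.types._

object DataTypeParserUsage {
  def main(args: Array[String]): Unit = {
    assert(DataTypeParser("varchAr(20)") == StringType)        // the varchar length is discarded
    assert(DataTypeParser("decimal(10, 5)") == DecimalType(10, 5))
    assert(DataTypeParser("array<map<string,bigint>>") == ArrayType(MapType(StringType, LongType)))
    assert(DataTypeParser("struct<a:int,b:double>") ==
      StructType(Seq(StructField("a", IntegerType), StructField("b", DoubleType))))
  }
}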
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 21cc6cea4bf54..994c5202c15dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -246,7 +246,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
}
}
- override def equals(other: Any) = other match {
+ override def equals(other: Any): Boolean = other match {
case d: Decimal =>
compare(d) == 0
case _ =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
index bf39603d13bd5..d973144de3468 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
@@ -307,7 +307,7 @@ protected[sql] object NativeType {
protected[sql] trait PrimitiveType extends DataType {
- override def isPrimitive = true
+ override def isPrimitive: Boolean = true
}
@@ -442,7 +442,7 @@ class TimestampType private() extends NativeType {
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
private[sql] val ordering = new Ordering[JvmType] {
- def compare(x: Timestamp, y: Timestamp) = x.compareTo(y)
+ def compare(x: Timestamp, y: Timestamp): Int = x.compareTo(y)
}
/**
@@ -542,7 +542,7 @@ class LongType private() extends IntegralType {
*/
override def defaultSize: Int = 8
- override def simpleString = "bigint"
+ override def simpleString: String = "bigint"
private[spark] override def asNullable: LongType = this
}
@@ -572,7 +572,7 @@ class IntegerType private() extends IntegralType {
*/
override def defaultSize: Int = 4
- override def simpleString = "int"
+ override def simpleString: String = "int"
private[spark] override def asNullable: IntegerType = this
}
@@ -602,7 +602,7 @@ class ShortType private() extends IntegralType {
*/
override def defaultSize: Int = 2
- override def simpleString = "smallint"
+ override def simpleString: String = "smallint"
private[spark] override def asNullable: ShortType = this
}
@@ -632,7 +632,7 @@ class ByteType private() extends IntegralType {
*/
override def defaultSize: Int = 1
- override def simpleString = "tinyint"
+ override def simpleString: String = "tinyint"
private[spark] override def asNullable: ByteType = this
}
@@ -696,7 +696,7 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT
*/
override def defaultSize: Int = 4096
- override def simpleString = precisionInfo match {
+ override def simpleString: String = precisionInfo match {
case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)"
case None => "decimal(10,0)"
}
@@ -836,7 +836,7 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
*/
override def defaultSize: Int = 100 * elementType.defaultSize
- override def simpleString = s"array<${elementType.simpleString}>"
+ override def simpleString: String = s"array<${elementType.simpleString}>"
private[spark] override def asNullable: ArrayType =
ArrayType(elementType.asNullable, containsNull = true)
@@ -1065,7 +1065,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
*/
override def defaultSize: Int = fields.map(_.dataType.defaultSize).sum
- override def simpleString = {
+ override def simpleString: String = {
val fieldTypes = fields.map(field => s"${field.name}:${field.dataType.simpleString}")
s"struct<${fieldTypes.mkString(",")}>"
}
@@ -1142,7 +1142,7 @@ case class MapType(
*/
override def defaultSize: Int = 100 * (keyType.defaultSize + valueType.defaultSize)
- override def simpleString = s"map<${keyType.simpleString},${valueType.simpleString}>"
+ override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>"
private[spark] override def asNullable: MapType =
MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index c1dd5aa913ddc..359aec4a7b5ab 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -199,4 +199,22 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
assert(pl(3).dataType == DecimalType.Unlimited)
assert(pl(4).dataType == DoubleType)
}
+
+ test("SPARK-6452 regression test") {
+ // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s)
+ val plan =
+ Aggregate(
+ Nil,
+ Alias(Sum(AttributeReference("a", StringType)(exprId = ExprId(1))), "b")() :: Nil,
+ LocalRelation(
+ AttributeReference("a", StringType)(exprId = ExprId(2))))
+
+ assert(plan.resolved)
+
+ val message = intercept[AnalysisException] {
+ caseSensitiveAnalyze(plan)
+ }.getMessage
+
+ assert(message.contains("resolved attribute(s) a#1 missing from a#2"))
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala
new file mode 100644
index 0000000000000..1ba21b64603ac
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala
@@ -0,0 +1,116 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.types
+
+import org.scalatest.FunSuite
+
+class DataTypeParserSuite extends FunSuite {
+
+ def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = {
+ test(s"parse ${dataTypeString.replace("\n", "")}") {
+ assert(DataTypeParser(dataTypeString) === expectedDataType)
+ }
+ }
+
+ def unsupported(dataTypeString: String): Unit = {
+ test(s"$dataTypeString is not supported") {
+ intercept[DataTypeException](DataTypeParser(dataTypeString))
+ }
+ }
+
+ checkDataType("int", IntegerType)
+ checkDataType("BooLean", BooleanType)
+ checkDataType("tinYint", ByteType)
+ checkDataType("smallINT", ShortType)
+ checkDataType("INT", IntegerType)
+ checkDataType("bigint", LongType)
+ checkDataType("float", FloatType)
+ checkDataType("dOUBle", DoubleType)
+ checkDataType("decimal(10, 5)", DecimalType(10, 5))
+ checkDataType("decimal", DecimalType.Unlimited)
+ checkDataType("DATE", DateType)
+ checkDataType("timestamp", TimestampType)
+ checkDataType("string", StringType)
+ checkDataType("varchAr(20)", StringType)
+ checkDataType("BINARY", BinaryType)
+
+ checkDataType("array", ArrayType(DoubleType, true))
+ checkDataType("Array