From 0afe5cb65a195d2f14e8dfcefdbec5dac023651f Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Fri, 15 Aug 2014 11:35:08 -0700 Subject: [PATCH 01/54] SPARK-3028. sparkEventToJson should support SparkListenerExecutorMetrics... ...Update Author: Sandy Ryza Closes #1961 from sryza/sandy-spark-3028 and squashes the following commits: dccdff5 [Sandy Ryza] Fix compile error f883ded [Sandy Ryza] SPARK-3028. sparkEventToJson should support SparkListenerExecutorMetricsUpdate --- .../org/apache/spark/scheduler/EventLoggingListener.scala | 2 ++ core/src/main/scala/org/apache/spark/util/JsonProtocol.scala | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 406147f167bf3..7378ce923f0ae 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -127,6 +127,8 @@ private[spark] class EventLoggingListener( logEvent(event, flushLogger = true) override def onApplicationEnd(event: SparkListenerApplicationEnd) = logEvent(event, flushLogger = true) + // No-op because logging every update would be overkill + override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate) { } /** * Stop logging events. diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 6f8eb1ee12634..1e18ec688c40d 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -72,8 +72,9 @@ private[spark] object JsonProtocol { case applicationEnd: SparkListenerApplicationEnd => applicationEndToJson(applicationEnd) - // Not used, but keeps compiler happy + // These aren't used, but keeps compiler happy case SparkListenerShutdown => JNothing + case SparkListenerExecutorMetricsUpdate(_, _) => JNothing } } From c7032290a3f0f5545aa4f0a9a144c62571344dc8 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 15 Aug 2014 14:50:10 -0700 Subject: [PATCH 02/54] [SPARK-3022] [SPARK-3041] [mllib] Call findBins once per level + unordered feature bug fix DecisionTree improvements: (1) TreePoint representation to avoid binning multiple times (2) Bug fix: isSampleValid indexed bins incorrectly for unordered categorical features (3) Timing for DecisionTree internals Details: (1) TreePoint representation to avoid binning multiple times [https://issues.apache.org/jira/browse/SPARK-3022] Added private[tree] TreePoint class for representing binned feature values. The input RDD of LabeledPoint is converted to the TreePoint representation initially and then cached. This avoids the previous problem of re-computing bins multiple times. (2) Bug fix: isSampleValid indexed bins incorrectly for unordered categorical features [https://issues.apache.org/jira/browse/SPARK-3041] isSampleValid used to treat unordered categorical features incorrectly: It treated the bins as if indexed by featured values, rather than by subsets of values/categories. * exhibited for unordered features (multi-class classification with categorical features of low arity) * Fix: Index bins correctly for unordered categorical features. (3) Timing for DecisionTree internals Added tree/impl/TimeTracker.scala class which is private[tree] for now, for timing key parts of DT code. Prints timing info via logDebug. CC: mengxr manishamde chouqin Very similar update, with one bug fix. Many apologies for the conflicting update, but I hope that a few more optimizations I have on the way (which depend on this update) will prove valuable to you: SPARK-3042 and SPARK-3043 Author: Joseph K. Bradley Closes #1950 from jkbradley/dt-opt1 and squashes the following commits: 5f2dec2 [Joseph K. Bradley] Fixed scalastyle issue in TreePoint 6b5651e [Joseph K. Bradley] Updates based on code review. 1 major change: persisting to memory + disk, not just memory. 2d2aaaf [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1 430d782 [Joseph K. Bradley] Added more debug info on binning error. Added some docs. d036089 [Joseph K. Bradley] Print timing info to logDebug. e66f1b1 [Joseph K. Bradley] TreePoint * Updated doc * Made some methods private 8464a6e [Joseph K. Bradley] Moved TimeTracker to tree/impl/ in its own file, and cleaned it up. Removed debugging println calls from DecisionTree. Made TreePoint extend Serialiable a87e08f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1 0f676e2 [Joseph K. Bradley] Optimizations + Bug fix for DecisionTree 3211f02 [Joseph K. Bradley] Optimizing DecisionTree * Added TreePoint representation to avoid calling findBin multiple times. * (not working yet, but debugging) f61e9d2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing bcf874a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing 511ec85 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing a95bc22 [Joseph K. Bradley] timing for DecisionTree internals --- .../spark/mllib/tree/DecisionTree.scala | 289 ++++++++---------- .../mllib/tree/configuration/Strategy.scala | 43 ++- .../spark/mllib/tree/impl/TimeTracker.scala | 73 +++++ .../spark/mllib/tree/impl/TreePoint.scala | 201 ++++++++++++ .../spark/mllib/tree/DecisionTreeSuite.scala | 50 +-- 5 files changed, 449 insertions(+), 207 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index bb50f07be5d7b..2a3107a13e916 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -17,22 +17,24 @@ package org.apache.spark.mllib.tree -import org.apache.spark.api.java.JavaRDD - import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD import org.apache.spark.Logging import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} +import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.impurity.{Impurities, Gini, Entropy, Impurity} +import org.apache.spark.mllib.tree.impl.{TimeTracker, TreePoint} +import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity} import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.random.XORShiftRandom + /** * :: Experimental :: * A class which implements a decision tree learning algorithm for classification and regression. @@ -53,16 +55,27 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo */ def train(input: RDD[LabeledPoint]): DecisionTreeModel = { - // Cache input RDD for speedup during multiple passes. - val retaggedInput = input.retag(classOf[LabeledPoint]).cache() + val timer = new TimeTracker() + + timer.start("total") + + timer.start("init") + + val retaggedInput = input.retag(classOf[LabeledPoint]) logDebug("algo = " + strategy.algo) // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. + timer.start("findSplitsBins") val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, strategy) val numBins = bins(0).length + timer.stop("findSplitsBins") logDebug("numBins = " + numBins) + // Cache input RDD for speedup during multiple passes. + val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins) + .persist(StorageLevel.MEMORY_AND_DISK) + // depth of the decision tree val maxDepth = strategy.maxDepth // the max number of nodes possible given the depth of the tree @@ -76,7 +89,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // dummy value for top node (updated during first split calculation) val nodes = new Array[Node](maxNumNodes) // num features - val numFeatures = retaggedInput.take(1)(0).features.size + val numFeatures = treeInput.take(1)(0).binnedFeatures.size // Calculate level for single group construction @@ -96,6 +109,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo (math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt, 0) logDebug("max level for single group = " + maxLevelForSingleGroup) + timer.stop("init") + /* * The main idea here is to perform level-wise training of the decision tree nodes thus * reducing the passes over the data from l to log2(l) where l is the total number of nodes. @@ -113,15 +128,21 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo logDebug("#####################################") // Find best split for all nodes at a level. - val splitsStatsForLevel = DecisionTree.findBestSplits(retaggedInput, parentImpurities, - strategy, level, filters, splits, bins, maxLevelForSingleGroup) + timer.start("findBestSplits") + val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities, + strategy, level, filters, splits, bins, maxLevelForSingleGroup, timer) + timer.stop("findBestSplits") for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { + timer.start("extractNodeInfo") // Extract info for nodes at the current level. extractNodeInfo(nodeSplitStats, level, index, nodes) + timer.stop("extractNodeInfo") + timer.start("extractInfoForLowerLevels") // Extract info for nodes at the next lower level. extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, filters) + timer.stop("extractInfoForLowerLevels") logDebug("final best split = " + nodeSplitStats._1) } require(math.pow(2, level) == splitsStatsForLevel.length) @@ -144,6 +165,11 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Build the full tree using the node info calculated in the level-wise best split calculations. topNode.build(nodes) + timer.stop("total") + + logInfo("Internal timing for DecisionTree:") + logInfo(s"$timer") + new DecisionTreeModel(topNode, strategy.algo) } @@ -406,7 +432,7 @@ object DecisionTree extends Serializable with Logging { * Returns an array of optimal splits for all nodes at a given level. Splits the task into * multiple groups if the level-wise training task could lead to memory overflow. * - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] * @param parentImpurities Impurities for all parent nodes for the current level * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing * parameters for constructing the DecisionTree @@ -415,44 +441,45 @@ object DecisionTree extends Serializable with Logging { * @param splits possible splits for all features * @param bins possible bins for all features * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation. - * @return array of splits with best splits for all nodes at a given level. + * @return array (over nodes) of splits with best split for each node at a given level. */ protected[tree] def findBestSplits( - input: RDD[LabeledPoint], + input: RDD[TreePoint], parentImpurities: Array[Double], strategy: Strategy, level: Int, filters: Array[List[Filter]], splits: Array[Array[Split]], bins: Array[Array[Bin]], - maxLevelForSingleGroup: Int): Array[(Split, InformationGainStats)] = { + maxLevelForSingleGroup: Int, + timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = { // split into groups to avoid memory overflow during aggregation if (level > maxLevelForSingleGroup) { // When information for all nodes at a given level cannot be stored in memory, // the nodes are divided into multiple groups at each level with the number of groups // increasing exponentially per level. For example, if maxLevelForSingleGroup is 10, // numGroups is equal to 2 at level 11 and 4 at level 12, respectively. - val numGroups = math.pow(2, (level - maxLevelForSingleGroup)).toInt + val numGroups = math.pow(2, level - maxLevelForSingleGroup).toInt logDebug("numGroups = " + numGroups) var bestSplits = new Array[(Split, InformationGainStats)](0) // Iterate over each group of nodes at a level. var groupIndex = 0 while (groupIndex < numGroups) { val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level, - filters, splits, bins, numGroups, groupIndex) + filters, splits, bins, timer, numGroups, groupIndex) bestSplits = Array.concat(bestSplits, bestSplitsForGroup) groupIndex += 1 } bestSplits } else { - findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins) + findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins, timer) } } /** * Returns an array of optimal splits for a group of nodes at a given level * - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] * @param parentImpurities Impurities for all parent nodes for the current level * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing * parameters for constructing the DecisionTree @@ -465,13 +492,14 @@ object DecisionTree extends Serializable with Logging { * @return array of splits with best splits for all nodes at a given level. */ private def findBestSplitsPerGroup( - input: RDD[LabeledPoint], + input: RDD[TreePoint], parentImpurities: Array[Double], strategy: Strategy, level: Int, filters: Array[List[Filter]], splits: Array[Array[Split]], bins: Array[Array[Bin]], + timer: TimeTracker, numGroups: Int = 1, groupIndex: Int = 0): Array[(Split, InformationGainStats)] = { @@ -507,7 +535,7 @@ object DecisionTree extends Serializable with Logging { logDebug("numNodes = " + numNodes) // Find the number of features by looking at the first sample. - val numFeatures = input.first().features.size + val numFeatures = input.first().binnedFeatures.size logDebug("numFeatures = " + numFeatures) // numBins: Number of bins = 1 + number of possible splits @@ -542,33 +570,43 @@ object DecisionTree extends Serializable with Logging { * Find whether the sample is valid input for the current node, i.e., whether it passes through * all the filters for the current node. */ - def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = { + def isSampleValid(parentFilters: List[Filter], treePoint: TreePoint): Boolean = { // leaf if ((level > 0) && (parentFilters.length == 0)) { return false } // Apply each filter and check sample validity. Return false when invalid condition found. - for (filter <- parentFilters) { - val features = labeledPoint.features + parentFilters.foreach { filter => val featureIndex = filter.split.feature - val threshold = filter.split.threshold val comparison = filter.comparison - val categories = filter.split.categories val isFeatureContinuous = filter.split.featureType == Continuous - val feature = features(featureIndex) if (isFeatureContinuous) { + val binId = treePoint.binnedFeatures(featureIndex) + val bin = bins(featureIndex)(binId) + val featureValue = bin.highSplit.threshold + val threshold = filter.split.threshold comparison match { - case -1 => if (feature > threshold) return false - case 1 => if (feature <= threshold) return false + case -1 => if (featureValue > threshold) return false + case 1 => if (featureValue <= threshold) return false } } else { - val containsFeature = categories.contains(feature) + val numFeatureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val isSpaceSufficientForAllCategoricalSplits = + numBins > math.pow(2, numFeatureCategories.toInt - 1) - 1 + val isUnorderedFeature = + isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits + val featureValue = if (isUnorderedFeature) { + treePoint.binnedFeatures(featureIndex) + } else { + val binId = treePoint.binnedFeatures(featureIndex) + bins(featureIndex)(binId).category + } + val containsFeature = filter.split.categories.contains(featureValue) comparison match { case -1 => if (!containsFeature) return false case 1 => if (containsFeature) return false } - } } @@ -576,103 +614,6 @@ object DecisionTree extends Serializable with Logging { true } - /** - * Find bin for one (labeledPoint, feature). - */ - def findBin( - featureIndex: Int, - labeledPoint: LabeledPoint, - isFeatureContinuous: Boolean, - isSpaceSufficientForAllCategoricalSplits: Boolean): Int = { - val binForFeatures = bins(featureIndex) - val feature = labeledPoint.features(featureIndex) - - /** - * Binary search helper method for continuous feature. - */ - def binarySearchForBins(): Int = { - var left = 0 - var right = binForFeatures.length - 1 - while (left <= right) { - val mid = left + (right - left) / 2 - val bin = binForFeatures(mid) - val lowThreshold = bin.lowSplit.threshold - val highThreshold = bin.highSplit.threshold - if ((lowThreshold < feature) && (highThreshold >= feature)) { - return mid - } - else if (lowThreshold >= feature) { - right = mid - 1 - } - else { - left = mid + 1 - } - } - -1 - } - - /** - * Sequential search helper method to find bin for categorical feature in multiclass - * classification. The category is returned since each category can belong to multiple - * splits. The actual left/right child allocation per split is performed in the - * sequential phase of the bin aggregate operation. - */ - def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = { - labeledPoint.features(featureIndex).toInt - } - - /** - * Sequential search helper method to find bin for categorical feature - * (for classification and regression). - */ - def sequentialBinSearchForOrderedCategoricalFeature(): Int = { - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val featureValue = labeledPoint.features(featureIndex) - var binIndex = 0 - while (binIndex < featureCategories) { - val bin = bins(featureIndex)(binIndex) - val categories = bin.highSplit.categories - if (categories.contains(featureValue)) { - return binIndex - } - binIndex += 1 - } - if (featureValue < 0 || featureValue >= featureCategories) { - throw new IllegalArgumentException( - s"DecisionTree given invalid data:" + - s" Feature $featureIndex is categorical with values in" + - s" {0,...,${featureCategories - 1}," + - s" but a data point gives it value $featureValue.\n" + - " Bad data point: " + labeledPoint.toString) - } - -1 - } - - if (isFeatureContinuous) { - // Perform binary search for finding bin for continuous features. - val binIndex = binarySearchForBins() - if (binIndex == -1) { - throw new UnknownError("no bin was found for continuous variable.") - } - binIndex - } else { - // Perform sequential search to find bin for categorical features. - val binIndex = { - val isUnorderedFeature = - isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits - if (isUnorderedFeature) { - sequentialBinSearchForUnorderedCategoricalFeatureInClassification() - } else { - sequentialBinSearchForOrderedCategoricalFeature() - } - } - if (binIndex == -1) { - throw new UnknownError("no bin was found for categorical variable.") - } - binIndex - } - } - /** * Finds bins for all nodes (and all features) at a given level. * For l nodes, k features the storage is as follows: @@ -689,17 +630,17 @@ object DecisionTree extends Serializable with Logging { * bin index for this labeledPoint * (or InvalidBinIndex if labeledPoint is not handled by this node) */ - def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { + def findBinsForLevel(treePoint: TreePoint): Array[Double] = { // Calculate bin index and label per feature per node. val arr = new Array[Double](1 + (numFeatures * numNodes)) // First element of the array is the label of the instance. - arr(0) = labeledPoint.label + arr(0) = treePoint.label // Iterate over nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { val parentFilters = findParentFilters(nodeIndex) // Find out whether the sample qualifies for the particular node. - val sampleValid = isSampleValid(parentFilters, labeledPoint) + val sampleValid = isSampleValid(parentFilters, treePoint) val shift = 1 + numFeatures * nodeIndex if (!sampleValid) { // Mark one bin as -1 is sufficient. @@ -707,19 +648,7 @@ object DecisionTree extends Serializable with Logging { } else { var featureIndex = 0 while (featureIndex < numFeatures) { - val featureInfo = strategy.categoricalFeaturesInfo.get(featureIndex) - val isFeatureContinuous = featureInfo.isEmpty - if (isFeatureContinuous) { - arr(shift + featureIndex) - = findBin(featureIndex, labeledPoint, isFeatureContinuous, false) - } else { - val featureCategories = featureInfo.get - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - arr(shift + featureIndex) - = findBin(featureIndex, labeledPoint, isFeatureContinuous, - isSpaceSufficientForAllCategoricalSplits) - } + arr(shift + featureIndex) = treePoint.binnedFeatures(featureIndex) featureIndex += 1 } } @@ -728,7 +657,8 @@ object DecisionTree extends Serializable with Logging { arr } - // Find feature bins for all nodes at a level. + // Find feature bins for all nodes at a level. + timer.start("aggregation") val binMappedRDD = input.map(x => findBinsForLevel(x)) /** @@ -830,6 +760,8 @@ object DecisionTree extends Serializable with Logging { } } + val rightChildShift = numClasses * numBins * numFeatures * numNodes + /** * Helper for binSeqOp. * @@ -853,7 +785,6 @@ object DecisionTree extends Serializable with Logging { val validSignalIndex = 1 + numFeatures * nodeIndex val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex if (isSampleValidForNode) { - val rightChildShift = numClasses * numBins * numFeatures * numNodes // actual class label val label = arr(0) // Iterate over all features. @@ -912,7 +843,7 @@ object DecisionTree extends Serializable with Logging { val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3 agg(aggIndex) = agg(aggIndex) + 1 agg(aggIndex + 1) = agg(aggIndex + 1) + label - agg(aggIndex + 2) = agg(aggIndex + 2) + label*label + agg(aggIndex + 2) = agg(aggIndex + 2) + label * label featureIndex += 1 } } @@ -977,6 +908,7 @@ object DecisionTree extends Serializable with Logging { val binAggregates = { binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp) } + timer.stop("aggregation") logDebug("binAggregates.length = " + binAggregates.length) /** @@ -1031,10 +963,17 @@ object DecisionTree extends Serializable with Logging { def indexOfLargestArrayElement(array: Array[Double]): Int = { val result = array.foldLeft(-1, Double.MinValue, 0) { case ((maxIndex, maxValue, currentIndex), currentValue) => - if(currentValue > maxValue) (currentIndex, currentValue, currentIndex + 1) - else (maxIndex, maxValue, currentIndex + 1) + if (currentValue > maxValue) { + (currentIndex, currentValue, currentIndex + 1) + } else { + (maxIndex, maxValue, currentIndex + 1) + } + } + if (result._1 < 0) { + throw new RuntimeException("DecisionTree internal error:" + + " calculateGainForSplit failed in indexOfLargestArrayElement") } - if (result._1 < 0) 0 else result._1 + result._1 } val predict = indexOfLargestArrayElement(leftRightCounts) @@ -1057,6 +996,7 @@ object DecisionTree extends Serializable with Logging { val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) + case Regression => val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0) val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1) @@ -1280,15 +1220,41 @@ object DecisionTree extends Serializable with Logging { nodeImpurity: Double): Array[Array[InformationGainStats]] = { val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1) - for (featureIndex <- 0 until numFeatures) { - for (splitIndex <- 0 until numBins - 1) { + var featureIndex = 0 + while (featureIndex < numFeatures) { + val numSplitsForFeature = getNumSplitsForFeature(featureIndex) + var splitIndex = 0 + while (splitIndex < numSplitsForFeature) { gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex, splitIndex, rightNodeAgg, nodeImpurity) + splitIndex += 1 } + featureIndex += 1 } gains } + /** + * Get the number of splits for a feature. + */ + def getNumSplitsForFeature(featureIndex: Int): Int = { + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { + numBins - 1 + } else { + // Categorical feature + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val isSpaceSufficientForAllCategoricalSplits = + numBins > math.pow(2, featureCategories.toInt - 1) - 1 + if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { + math.pow(2.0, featureCategories - 1).toInt - 1 + } else { + // Ordered features + featureCategories + } + } + } + /** * Find the best split for a node. * @param binData Bin data slice for this node, given by getBinDataForNode. @@ -1307,7 +1273,7 @@ object DecisionTree extends Serializable with Logging { // Calculate gains for all splits. val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) - val (bestFeatureIndex,bestSplitIndex, gainStats) = { + val (bestFeatureIndex, bestSplitIndex, gainStats) = { // Initialize with infeasible values. var bestFeatureIndex = Int.MinValue var bestSplitIndex = Int.MinValue @@ -1317,22 +1283,8 @@ object DecisionTree extends Serializable with Logging { while (featureIndex < numFeatures) { // Iterate over all splits. var splitIndex = 0 - val maxSplitIndex: Double = { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinuous) { - numBins - 1 - } else { // Categorical feature - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { - math.pow(2.0, featureCategories - 1).toInt - 1 - } else { // Binary classification - featureCategories - } - } - } - while (splitIndex < maxSplitIndex) { + val numSplitsForFeature = getNumSplitsForFeature(featureIndex) + while (splitIndex < numSplitsForFeature) { val gainStats = gains(featureIndex)(splitIndex) if (gainStats.gain > bestGainStats.gain) { bestGainStats = gainStats @@ -1383,6 +1335,7 @@ object DecisionTree extends Serializable with Logging { } // Calculate best splits for all nodes at a given level + timer.start("chooseSplits") val bestSplits = new Array[(Split, InformationGainStats)](numNodes) // Iterating over all nodes at this level var node = 0 @@ -1395,6 +1348,8 @@ object DecisionTree extends Serializable with Logging { bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) node += 1 } + timer.stop("chooseSplits") + bestSplits } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index f31a503608b22..cfc8192a85abd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -27,22 +27,30 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ /** * :: Experimental :: * Stores all the configuration options for tree construction - * @param algo classification or regression - * @param impurity criterion used for information gain calculation + * @param algo Learning goal. Supported: + * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], + * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] + * @param impurity Criterion used for information gain calculation. + * Supported for Classification: [[org.apache.spark.mllib.tree.impurity.Gini]], + * [[org.apache.spark.mllib.tree.impurity.Entropy]]. + * Supported for Regression: [[org.apache.spark.mllib.tree.impurity.Variance]]. * @param maxDepth Maximum depth of the tree. * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * @param numClassesForClassification number of classes for classification. Default value is 2 - * leads to binary classification - * @param maxBins maximum number of bins used for splitting features - * @param quantileCalculationStrategy algorithm for calculating quantiles + * @param numClassesForClassification Number of classes for classification. + * (Ignored for regression.) + * Default value is 2 (binary classification). + * @param maxBins Maximum number of bins used for discretizing continuous features and + * for choosing how to split on features at each node. + * More bins give higher granularity. + * @param quantileCalculationStrategy Algorithm for calculating quantiles. Supported: + * [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]] * @param categoricalFeaturesInfo A map storing information about the categorical variables and the * number of discrete values they take. For example, an entry (n -> * k) implies the feature n is categorical with k categories 0, * 1, 2, ... , k-1. It's important to note that features are * zero-indexed. - * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is + * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is * 128 MB. - * */ @Experimental class Strategy ( @@ -64,20 +72,7 @@ class Strategy ( = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) /** - * Java-friendly constructor. - * - * @param algo classification or regression - * @param impurity criterion used for information gain calculation - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * @param numClassesForClassification number of classes for classification. Default value is 2 - * leads to binary classification - * @param maxBins maximum number of bins used for splitting features - * @param categoricalFeaturesInfo A map storing information about the categorical variables and - * the number of discrete values they take. For example, an entry - * (n -> k) implies the feature n is categorical with k categories - * 0, 1, 2, ... , k-1. It's important to note that features are - * zero-indexed. + * Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]] */ def this( algo: Algo, @@ -90,6 +85,10 @@ class Strategy ( categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap) } + /** + * Check validity of parameters. + * Throws exception if invalid. + */ private[tree] def assertValid(): Unit = { algo match { case Classification => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala new file mode 100644 index 0000000000000..d215d68c4279e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impl + +import scala.collection.mutable.{HashMap => MutableHashMap} + +import org.apache.spark.annotation.Experimental + +/** + * Time tracker implementation which holds labeled timers. + */ +@Experimental +private[tree] class TimeTracker extends Serializable { + + private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]() + + private val totals: MutableHashMap[String, Long] = new MutableHashMap[String, Long]() + + /** + * Starts a new timer, or re-starts a stopped timer. + */ + def start(timerLabel: String): Unit = { + val currentTime = System.nanoTime() + if (starts.contains(timerLabel)) { + throw new RuntimeException(s"TimeTracker.start(timerLabel) called again on" + + s" timerLabel = $timerLabel before that timer was stopped.") + } + starts(timerLabel) = currentTime + } + + /** + * Stops a timer and returns the elapsed time in seconds. + */ + def stop(timerLabel: String): Double = { + val currentTime = System.nanoTime() + if (!starts.contains(timerLabel)) { + throw new RuntimeException(s"TimeTracker.stop(timerLabel) called on" + + s" timerLabel = $timerLabel, but that timer was not started.") + } + val elapsed = currentTime - starts(timerLabel) + starts.remove(timerLabel) + if (totals.contains(timerLabel)) { + totals(timerLabel) += elapsed + } else { + totals(timerLabel) = elapsed + } + elapsed / 1e9 + } + + /** + * Print all timing results in seconds. + */ + override def toString: String = { + totals.map { case (label, elapsed) => + s" $label: ${elapsed / 1e9}" + }.mkString("\n") + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala new file mode 100644 index 0000000000000..ccac1031fd9d9 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impl + +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.model.Bin +import org.apache.spark.rdd.RDD + + +/** + * Internal representation of LabeledPoint for DecisionTree. + * This bins feature values based on a subsampled of data as follows: + * (a) Continuous features are binned into ranges. + * (b) Unordered categorical features are binned based on subsets of feature values. + * "Unordered categorical features" are categorical features with low arity used in + * multiclass classification. + * (c) Ordered categorical features are binned based on feature values. + * "Ordered categorical features" are categorical features with high arity, + * or any categorical feature used in regression or binary classification. + * + * @param label Label from LabeledPoint + * @param binnedFeatures Binned feature values. + * Same length as LabeledPoint.features, but values are bin indices. + */ +private[tree] class TreePoint(val label: Double, val binnedFeatures: Array[Int]) + extends Serializable { +} + +private[tree] object TreePoint { + + /** + * Convert an input dataset into its TreePoint representation, + * binning feature values in preparation for DecisionTree training. + * @param input Input dataset. + * @param strategy DecisionTree training info, used for dataset metadata. + * @param bins Bins for features, of size (numFeatures, numBins). + * @return TreePoint dataset representation + */ + def convertToTreeRDD( + input: RDD[LabeledPoint], + strategy: Strategy, + bins: Array[Array[Bin]]): RDD[TreePoint] = { + input.map { x => + TreePoint.labeledPointToTreePoint(x, strategy.isMulticlassClassification, bins, + strategy.categoricalFeaturesInfo) + } + } + + /** + * Convert one LabeledPoint into its TreePoint representation. + * @param bins Bins for features, of size (numFeatures, numBins). + * @param categoricalFeaturesInfo Map over categorical features: feature index --> feature arity + */ + private def labeledPointToTreePoint( + labeledPoint: LabeledPoint, + isMulticlassClassification: Boolean, + bins: Array[Array[Bin]], + categoricalFeaturesInfo: Map[Int, Int]): TreePoint = { + + val numFeatures = labeledPoint.features.size + val numBins = bins(0).size + val arr = new Array[Int](numFeatures) + var featureIndex = 0 + while (featureIndex < numFeatures) { + val featureInfo = categoricalFeaturesInfo.get(featureIndex) + val isFeatureContinuous = featureInfo.isEmpty + if (isFeatureContinuous) { + arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous, false, + bins, categoricalFeaturesInfo) + } else { + val featureCategories = featureInfo.get + val isSpaceSufficientForAllCategoricalSplits + = numBins > math.pow(2, featureCategories.toInt - 1) - 1 + val isUnorderedFeature = + isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits + arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous, + isUnorderedFeature, bins, categoricalFeaturesInfo) + } + featureIndex += 1 + } + + new TreePoint(labeledPoint.label, arr) + } + + /** + * Find bin for one (labeledPoint, feature). + * + * @param isUnorderedFeature (only applies if feature is categorical) + * @param bins Bins for features, of size (numFeatures, numBins). + * @param categoricalFeaturesInfo Map over categorical features: feature index --> feature arity + */ + private def findBin( + featureIndex: Int, + labeledPoint: LabeledPoint, + isFeatureContinuous: Boolean, + isUnorderedFeature: Boolean, + bins: Array[Array[Bin]], + categoricalFeaturesInfo: Map[Int, Int]): Int = { + + /** + * Binary search helper method for continuous feature. + */ + def binarySearchForBins(): Int = { + val binForFeatures = bins(featureIndex) + val feature = labeledPoint.features(featureIndex) + var left = 0 + var right = binForFeatures.length - 1 + while (left <= right) { + val mid = left + (right - left) / 2 + val bin = binForFeatures(mid) + val lowThreshold = bin.lowSplit.threshold + val highThreshold = bin.highSplit.threshold + if ((lowThreshold < feature) && (highThreshold >= feature)) { + return mid + } else if (lowThreshold >= feature) { + right = mid - 1 + } else { + left = mid + 1 + } + } + -1 + } + + /** + * Sequential search helper method to find bin for categorical feature in multiclass + * classification. The category is returned since each category can belong to multiple + * splits. The actual left/right child allocation per split is performed in the + * sequential phase of the bin aggregate operation. + */ + def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = { + labeledPoint.features(featureIndex).toInt + } + + /** + * Sequential search helper method to find bin for categorical feature + * (for classification and regression). + */ + def sequentialBinSearchForOrderedCategoricalFeature(): Int = { + val featureCategories = categoricalFeaturesInfo(featureIndex) + val featureValue = labeledPoint.features(featureIndex) + var binIndex = 0 + while (binIndex < featureCategories) { + val bin = bins(featureIndex)(binIndex) + val categories = bin.highSplit.categories + if (categories.contains(featureValue)) { + return binIndex + } + binIndex += 1 + } + if (featureValue < 0 || featureValue >= featureCategories) { + throw new IllegalArgumentException( + s"DecisionTree given invalid data:" + + s" Feature $featureIndex is categorical with values in" + + s" {0,...,${featureCategories - 1}," + + s" but a data point gives it value $featureValue.\n" + + " Bad data point: " + labeledPoint.toString) + } + -1 + } + + if (isFeatureContinuous) { + // Perform binary search for finding bin for continuous features. + val binIndex = binarySearchForBins() + if (binIndex == -1) { + throw new RuntimeException("No bin was found for continuous feature." + + " This error can occur when given invalid data values (such as NaN)." + + s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}") + } + binIndex + } else { + // Perform sequential search to find bin for categorical features. + val binIndex = if (isUnorderedFeature) { + sequentialBinSearchForUnorderedCategoricalFeatureInClassification() + } else { + sequentialBinSearchForOrderedCategoricalFeature() + } + if (binIndex == -1) { + throw new RuntimeException("No bin was found for categorical feature." + + " This error can occur when given invalid data values (such as NaN)." + + s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}") + } + binIndex + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 70ca7c8a266f2..a5c49a38dc08f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -21,11 +21,12 @@ import scala.collection.JavaConverters._ import org.scalatest.FunSuite -import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} -import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split} -import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy} +import org.apache.spark.mllib.tree.impl.TreePoint +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} +import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LocalSparkContext import org.apache.spark.mllib.regression.LabeledPoint @@ -41,7 +42,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { prediction != expected.label } val accuracy = (input.length - numOffPredictions).toDouble / input.length - assert(accuracy >= requiredAccuracy) + assert(accuracy >= requiredAccuracy, + s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.") } def validateRegressor( @@ -54,7 +56,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { err * err }.sum val mse = squaredError / input.length - assert(mse <= requiredMSE) + assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.") } test("split and bin calculation") { @@ -427,7 +429,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) - val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0, Array[List[Filter]](), splits, bins, 10) val split = bestSplits(0)._1 @@ -454,7 +457,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) - val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0, Array[List[Filter]](), splits, bins, 10) val split = bestSplits(0)._1 @@ -499,7 +503,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -521,7 +526,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -544,7 +550,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -567,7 +574,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -596,7 +604,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val parentImpurities = Array(0.5, 0.5, 0.5) // Single group second level tree construction. - val bestSplits = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, filters, + val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1, filters, splits, bins, 10) assert(bestSplits.length === 2) assert(bestSplits(0)._2.gain > 0) @@ -604,7 +613,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second // level tree construction. - val bestSplitsWithGroups = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, + val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1, filters, splits, bins, 0) assert(bestSplitsWithGroups.length === 2) assert(bestSplitsWithGroups(0)._2.gain > 0) @@ -630,7 +639,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) assert(strategy.isMulticlassClassification) val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) @@ -689,7 +699,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(model.depth === 1) val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) @@ -714,7 +725,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { validateClassifier(model, arr, 0.9) val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) @@ -738,7 +750,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { validateClassifier(model, arr, 0.9) val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) @@ -757,7 +770,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) assert(strategy.isMulticlassClassification) val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) From cc3648774e9a744850107bb187f2828d447e0a48 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 15 Aug 2014 17:04:15 -0700 Subject: [PATCH 03/54] [SPARK-3046] use executor's class loader as the default serializer classloader The serializer is not always used in an executor thread (e.g. connection manager, broadcast), in which case the classloader might not have the user jar set, leading to corruption in deserialization. https://issues.apache.org/jira/browse/SPARK-3046 https://issues.apache.org/jira/browse/SPARK-2878 Author: Reynold Xin Closes #1972 from rxin/kryoBug and squashes the following commits: c1c7bf0 [Reynold Xin] Made change to JavaSerializer. 7204c33 [Reynold Xin] Added imports back. d879e67 [Reynold Xin] [SPARK-3046] use executor's class loader as the default serializer class loader. --- .../org/apache/spark/executor/Executor.scala | 3 + .../spark/serializer/JavaSerializer.scala | 9 ++- .../spark/serializer/KryoSerializer.scala | 9 ++- .../apache/spark/serializer/Serializer.scala | 17 +++++ .../KryoSerializerDistributedSuite.scala | 71 +++++++++++++++++++ .../serializer/KryoSerializerSuite.scala | 23 +++++- 6 files changed, 128 insertions(+), 4 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index eac1f2326a29d..fb3f7bd54bbfa 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -99,6 +99,9 @@ private[spark] class Executor( private val urlClassLoader = createClassLoader() private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader) + // Set the classloader for serializer + env.serializer.setDefaultClassLoader(urlClassLoader) + // Akka's message frame size. If task result is bigger than this, we use the block manager // to send the result back. private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 34bc3124097bb..af33a2f2ca3e1 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -63,7 +63,9 @@ extends DeserializationStream { def close() { objIn.close() } } -private[spark] class JavaSerializerInstance(counterReset: Int) extends SerializerInstance { +private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoader: ClassLoader) + extends SerializerInstance { + def serialize[T: ClassTag](t: T): ByteBuffer = { val bos = new ByteArrayOutputStream() val out = serializeStream(bos) @@ -109,7 +111,10 @@ private[spark] class JavaSerializerInstance(counterReset: Int) extends Serialize class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable { private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100) - def newInstance(): SerializerInstance = new JavaSerializerInstance(counterReset) + override def newInstance(): SerializerInstance = { + val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader) + new JavaSerializerInstance(counterReset, classLoader) + } override def writeExternal(out: ObjectOutput) { out.writeInt(counterReset) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 85944eabcfefc..99682220b4ab5 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -61,7 +61,9 @@ class KryoSerializer(conf: SparkConf) val instantiator = new EmptyScalaKryoInstantiator val kryo = instantiator.newKryo() kryo.setRegistrationRequired(registrationRequired) - val classLoader = Thread.currentThread.getContextClassLoader + + val oldClassLoader = Thread.currentThread.getContextClassLoader + val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader) // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops. // Do this before we invoke the user registrator so the user registrator can override this. @@ -84,10 +86,15 @@ class KryoSerializer(conf: SparkConf) try { val reg = Class.forName(regCls, true, classLoader).newInstance() .asInstanceOf[KryoRegistrator] + + // Use the default classloader when calling the user registrator. + Thread.currentThread.setContextClassLoader(classLoader) reg.registerClasses(kryo) } catch { case e: Exception => throw new SparkException(s"Failed to invoke $regCls", e) + } finally { + Thread.currentThread.setContextClassLoader(oldClassLoader) } } diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index f2f5cea469c61..e674438c8176c 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -44,6 +44,23 @@ import org.apache.spark.util.{ByteBufferInputStream, NextIterator} */ @DeveloperApi trait Serializer { + + /** + * Default ClassLoader to use in deserialization. Implementations of [[Serializer]] should + * make sure it is using this when set. + */ + @volatile protected var defaultClassLoader: Option[ClassLoader] = None + + /** + * Sets a class loader for the serializer to use in deserialization. + * + * @return this Serializer object + */ + def setDefaultClassLoader(classLoader: ClassLoader): Serializer = { + defaultClassLoader = Some(classLoader) + this + } + def newInstance(): SerializerInstance } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala new file mode 100644 index 0000000000000..11e8c9c4cb37f --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import org.apache.spark.util.Utils + +import com.esotericsoftware.kryo.Kryo +import org.scalatest.FunSuite + +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, TestUtils} +import org.apache.spark.SparkContext._ +import org.apache.spark.serializer.KryoDistributedTest._ + +class KryoSerializerDistributedSuite extends FunSuite { + + test("kryo objects are serialised consistently in different processes") { + val conf = new SparkConf(false) + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + conf.set("spark.kryo.registrator", classOf[AppJarRegistrator].getName) + conf.set("spark.task.maxFailures", "1") + + val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName)) + conf.setJars(List(jar.getPath)) + + val sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + val original = Thread.currentThread.getContextClassLoader + val loader = new java.net.URLClassLoader(Array(jar), Utils.getContextOrSparkClassLoader) + SparkEnv.get.serializer.setDefaultClassLoader(loader) + + val cachedRDD = sc.parallelize((0 until 10).map((_, new MyCustomClass)), 3).cache() + + // Randomly mix the keys so that the join below will require a shuffle with each partition + // sending data to multiple other partitions. + val shuffledRDD = cachedRDD.map { case (i, o) => (i * i * i - 10 * i * i, o)} + + // Join the two RDDs, and force evaluation + assert(shuffledRDD.join(cachedRDD).collect().size == 1) + + LocalSparkContext.stop(sc) + } +} + +object KryoDistributedTest { + class MyCustomClass + + class AppJarRegistrator extends KryoRegistrator { + override def registerClasses(k: Kryo) { + val classLoader = Thread.currentThread.getContextClassLoader + k.register(Class.forName(AppJarRegistrator.customClassName, true, classLoader)) + } + } + + object AppJarRegistrator { + val customClassName = "KryoSerializerDistributedSuiteCustomClass" + } +} diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 3bf9efebb39d2..a579fd50bd9e4 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -23,7 +23,7 @@ import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo import org.scalatest.FunSuite -import org.apache.spark.SharedSparkContext +import org.apache.spark.{SparkConf, SharedSparkContext} import org.apache.spark.serializer.KryoTest._ class KryoSerializerSuite extends FunSuite with SharedSparkContext { @@ -217,8 +217,29 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { val thrown = intercept[SparkException](new KryoSerializer(conf).newInstance()) assert(thrown.getMessage.contains("Failed to invoke this.class.does.not.exist")) } + + test("default class loader can be set by a different thread") { + val ser = new KryoSerializer(new SparkConf) + + // First serialize the object + val serInstance = ser.newInstance() + val bytes = serInstance.serialize(new ClassLoaderTestingObject) + + // Deserialize the object to make sure normal deserialization works + serInstance.deserialize[ClassLoaderTestingObject](bytes) + + // Set a special, broken ClassLoader and make sure we get an exception on deserialization + ser.setDefaultClassLoader(new ClassLoader() { + override def loadClass(name: String) = throw new UnsupportedOperationException + }) + intercept[UnsupportedOperationException] { + ser.newInstance().deserialize[ClassLoaderTestingObject](bytes) + } + } } +class ClassLoaderTestingObject + class KryoSerializerResizableOutputSuite extends FunSuite { import org.apache.spark.SparkConf import org.apache.spark.SparkContext From 5d25c0b74f6397d78164b96afb8b8cbb1b15cfbd Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 15 Aug 2014 21:04:29 -0700 Subject: [PATCH 04/54] [SPARK-3078][MLLIB] Make LRWithLBFGS API consistent with others Should ask users to set parameters through the optimizer. dbtsai Author: Xiangrui Meng Closes #1973 from mengxr/lr-lbfgs and squashes the following commits: e3efbb1 [Xiangrui Meng] fix tests 21b3579 [Xiangrui Meng] fix method name 641eea4 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into lr-lbfgs 456ab7c [Xiangrui Meng] update LRWithLBFGS --- .../examples/mllib/BinaryClassification.scala | 8 ++-- .../classification/LogisticRegression.scala | 40 +++---------------- .../spark/mllib/optimization/LBFGS.scala | 9 +++++ .../LogisticRegressionSuite.scala | 5 ++- .../spark/mllib/optimization/LBFGSSuite.scala | 24 +++++------ 5 files changed, 33 insertions(+), 53 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala index 56b02b65d8724..a6f78d2441db1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala @@ -21,7 +21,7 @@ import org.apache.log4j.{Level, Logger} import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.mllib.classification.{LogisticRegressionWithSGD, SVMWithSGD} +import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, SVMWithSGD} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.util.MLUtils import org.apache.spark.mllib.optimization.{SquaredL2Updater, L1Updater} @@ -66,7 +66,8 @@ object BinaryClassification { .text("number of iterations") .action((x, c) => c.copy(numIterations = x)) opt[Double]("stepSize") - .text(s"initial step size, default: ${defaultParams.stepSize}") + .text("initial step size (ignored by logistic regression), " + + s"default: ${defaultParams.stepSize}") .action((x, c) => c.copy(stepSize = x)) opt[String]("algorithm") .text(s"algorithm (${Algorithm.values.mkString(",")}), " + @@ -125,10 +126,9 @@ object BinaryClassification { val model = params.algorithm match { case LR => - val algorithm = new LogisticRegressionWithSGD() + val algorithm = new LogisticRegressionWithLBFGS() algorithm.optimizer .setNumIterations(params.numIterations) - .setStepSize(params.stepSize) .setUpdater(updater) .setRegParam(params.regParam) algorithm.run(training).clearThreshold() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 6790c86f651b4..486bdbfa9cb47 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -73,6 +73,8 @@ class LogisticRegressionModel ( /** * Train a classification model for Logistic Regression using Stochastic Gradient Descent. * NOTE: Labels used in Logistic Regression should be {0, 1} + * + * Using [[LogisticRegressionWithLBFGS]] is recommended over this. */ class LogisticRegressionWithSGD private ( private var stepSize: Double, @@ -191,51 +193,19 @@ object LogisticRegressionWithSGD { /** * Train a classification model for Logistic Regression using Limited-memory BFGS. + * Standard feature scaling and L2 regularization are used by default. * NOTE: Labels used in Logistic Regression should be {0, 1} */ -class LogisticRegressionWithLBFGS private ( - private var convergenceTol: Double, - private var maxNumIterations: Int, - private var regParam: Double) +class LogisticRegressionWithLBFGS extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable { - /** - * Construct a LogisticRegression object with default parameters - */ - def this() = this(1E-4, 100, 0.0) - this.setFeatureScaling(true) - private val gradient = new LogisticGradient() - private val updater = new SimpleUpdater() - // Have to return new LBFGS object every time since users can reset the parameters anytime. - override def optimizer = new LBFGS(gradient, updater) - .setNumCorrections(10) - .setConvergenceTol(convergenceTol) - .setMaxNumIterations(maxNumIterations) - .setRegParam(regParam) + override val optimizer = new LBFGS(new LogisticGradient, new SquaredL2Updater) override protected val validators = List(DataValidators.binaryLabelValidator) - /** - * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4. - * Smaller value will lead to higher accuracy with the cost of more iterations. - */ - def setConvergenceTol(convergenceTol: Double): this.type = { - this.convergenceTol = convergenceTol - this - } - - /** - * Set the maximal number of iterations for L-BFGS. Default 100. - */ - def setNumIterations(numIterations: Int): this.type = { - this.maxNumIterations = numIterations - this - } - override protected def createModel(weights: Vector, intercept: Double) = { new LogisticRegressionModel(weights, intercept) } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index 033fe44f34f3c..d16d0daf08565 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -69,8 +69,17 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) /** * Set the maximal number of iterations for L-BFGS. Default 100. + * @deprecated use [[LBFGS#setNumIterations]] instead */ + @deprecated("use setNumIterations instead", "1.1.0") def setMaxNumIterations(iters: Int): this.type = { + this.setNumIterations(iters) + } + + /** + * Set the maximal number of iterations for L-BFGS. Default 100. + */ + def setNumIterations(iters: Int): this.type = { this.maxNumIterations = iters this } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index bc05b2046878f..862178694a50e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -272,8 +272,9 @@ class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkCont }.cache() // If we serialize data directly in the task closure, the size of the serialized task would be // greater than 1MB and hence Spark would throw an error. - val model = - (new LogisticRegressionWithLBFGS().setIntercept(true).setNumIterations(2)).run(points) + val lr = new LogisticRegressionWithLBFGS().setIntercept(true) + lr.optimizer.setNumIterations(2) + val model = lr.run(points) val predictions = model.predict(points.map(_.features)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index 5f4c24115ac80..ccba004baa007 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -55,7 +55,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { val initialWeightsWithIntercept = Vectors.dense(1.0 +: initialWeights.toArray) val convergenceTol = 1e-12 - val maxNumIterations = 10 + val numIterations = 10 val (_, loss) = LBFGS.runLBFGS( dataRDD, @@ -63,7 +63,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { simpleUpdater, numCorrections, convergenceTol, - maxNumIterations, + numIterations, regParam, initialWeightsWithIntercept) @@ -99,7 +99,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { // Prepare another non-zero weights to compare the loss in the first iteration. val initialWeightsWithIntercept = Vectors.dense(0.3, 0.12) val convergenceTol = 1e-12 - val maxNumIterations = 10 + val numIterations = 10 val (weightLBFGS, lossLBFGS) = LBFGS.runLBFGS( dataRDD, @@ -107,7 +107,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { squaredL2Updater, numCorrections, convergenceTol, - maxNumIterations, + numIterations, regParam, initialWeightsWithIntercept) @@ -140,10 +140,10 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { /** * For the first run, we set the convergenceTol to 0.0, so that the algorithm will - * run up to the maxNumIterations which is 8 here. + * run up to the numIterations which is 8 here. */ val initialWeightsWithIntercept = Vectors.dense(0.0, 0.0) - val maxNumIterations = 8 + val numIterations = 8 var convergenceTol = 0.0 val (_, lossLBFGS1) = LBFGS.runLBFGS( @@ -152,7 +152,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { squaredL2Updater, numCorrections, convergenceTol, - maxNumIterations, + numIterations, regParam, initialWeightsWithIntercept) @@ -167,7 +167,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { squaredL2Updater, numCorrections, convergenceTol, - maxNumIterations, + numIterations, regParam, initialWeightsWithIntercept) @@ -182,7 +182,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { squaredL2Updater, numCorrections, convergenceTol, - maxNumIterations, + numIterations, regParam, initialWeightsWithIntercept) @@ -200,12 +200,12 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { // Prepare another non-zero weights to compare the loss in the first iteration. val initialWeightsWithIntercept = Vectors.dense(0.3, 0.12) val convergenceTol = 1e-12 - val maxNumIterations = 10 + val numIterations = 10 val lbfgsOptimizer = new LBFGS(gradient, squaredL2Updater) .setNumCorrections(numCorrections) .setConvergenceTol(convergenceTol) - .setMaxNumIterations(maxNumIterations) + .setNumIterations(numIterations) .setRegParam(regParam) val weightLBFGS = lbfgsOptimizer.optimize(dataRDD, initialWeightsWithIntercept) @@ -241,7 +241,7 @@ class LBFGSClusterSuite extends FunSuite with LocalClusterSparkContext { val lbfgs = new LBFGS(new LogisticGradient, new SquaredL2Updater) .setNumCorrections(1) .setConvergenceTol(1e-12) - .setMaxNumIterations(1) + .setNumIterations(1) .setRegParam(1.0) val random = new Random(0) // If we serialize data directly in the task closure, the size of the serialized task would be From 2e069ca6560bf7ab07bd019f9530b42f4fe45014 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 15 Aug 2014 21:07:55 -0700 Subject: [PATCH 05/54] [SPARK-3001][MLLIB] Improve Spearman's correlation The current implementation requires sorting individual columns, which could be done with a global sort. result on a 32-node cluster: m | n | prev | this ---|---|-------|----- 1000000 | 50 | 55s | 9s 10000000 | 50 | 97s | 76s 1000000 | 100 | 119s | 15s Author: Xiangrui Meng Closes #1917 from mengxr/spearman and squashes the following commits: 4d5d262 [Xiangrui Meng] remove unused import 85c48de [Xiangrui Meng] minor updates a048d0c [Xiangrui Meng] remove cache and set a limit to cachedIds b98bb18 [Xiangrui Meng] add comments 0846e07 [Xiangrui Meng] first version --- .../correlation/SpearmanCorrelation.scala | 120 ++++++------------ 1 file changed, 42 insertions(+), 78 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala index 9bd0c2cd05de4..4a6c677f06d28 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala @@ -19,10 +19,10 @@ package org.apache.spark.mllib.stat.correlation import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{Logging, HashPartitioner} +import org.apache.spark.Logging import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.linalg.{DenseVector, Matrix, Vector} -import org.apache.spark.rdd.{CoGroupedRDD, RDD} +import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors} +import org.apache.spark.rdd.RDD /** * Compute Spearman's correlation for two RDDs of the type RDD[Double] or the correlation matrix @@ -43,87 +43,51 @@ private[stat] object SpearmanCorrelation extends Correlation with Logging { /** * Compute Spearman's correlation matrix S, for the input matrix, where S(i, j) is the * correlation between column i and j. - * - * Input RDD[Vector] should be cached or checkpointed if possible since it would be split into - * numCol RDD[Double]s, each of which sorted, and the joined back into a single RDD[Vector]. */ override def computeCorrelationMatrix(X: RDD[Vector]): Matrix = { - val indexed = X.zipWithUniqueId() - - val numCols = X.first.size - if (numCols > 50) { - logWarning("Computing the Spearman correlation matrix can be slow for large RDDs with more" - + " than 50 columns.") - } - val ranks = new Array[RDD[(Long, Double)]](numCols) - - // Note: we use a for loop here instead of a while loop with a single index variable - // to avoid race condition caused by closure serialization - for (k <- 0 until numCols) { - val column = indexed.map { case (vector, index) => (vector(k), index) } - ranks(k) = getRanks(column) + // ((columnIndex, value), rowUid) + val colBased = X.zipWithUniqueId().flatMap { case (vec, uid) => + vec.toArray.view.zipWithIndex.map { case (v, j) => + ((j, v), uid) + } } - - val ranksMat: RDD[Vector] = makeRankMatrix(ranks, X) - PearsonCorrelation.computeCorrelationMatrix(ranksMat) - } - - /** - * Compute the ranks for elements in the input RDD, using the average method for ties. - * - * With the average method, elements with the same value receive the same rank that's computed - * by taking the average of their positions in the sorted list. - * e.g. ranks([2, 1, 0, 2]) = [2.5, 1.0, 0.0, 2.5] - * Note that positions here are 0-indexed, instead of the 1-indexed as in the definition for - * ranks in the standard definition for Spearman's correlation. This does not affect the final - * results and is slightly more performant. - * - * @param indexed RDD[(Double, Long)] containing pairs of the format (originalValue, uniqueId) - * @return RDD[(Long, Double)] containing pairs of the format (uniqueId, rank), where uniqueId is - * copied from the input RDD. - */ - private def getRanks(indexed: RDD[(Double, Long)]): RDD[(Long, Double)] = { - // Get elements' positions in the sorted list for computing average rank for duplicate values - val sorted = indexed.sortByKey().zipWithIndex() - - val ranks: RDD[(Long, Double)] = sorted.mapPartitions { iter => - // add an extra element to signify the end of the list so that flatMap can flush the last - // batch of duplicates - val end = -1L - val padded = iter ++ Iterator[((Double, Long), Long)](((Double.NaN, end), end)) - val firstEntry = padded.next() - var lastVal = firstEntry._1._1 - var firstRank = firstEntry._2.toDouble - val idBuffer = ArrayBuffer(firstEntry._1._2) - padded.flatMap { case ((v, id), rank) => - if (v == lastVal && id != end) { - idBuffer += id - Iterator.empty - } else { - val entries = if (idBuffer.size == 1) { - Iterator((idBuffer(0), firstRank)) - } else { - val averageRank = firstRank + (idBuffer.size - 1.0) / 2.0 - idBuffer.map(id => (id, averageRank)) - } - lastVal = v - firstRank = rank - idBuffer.clear() - idBuffer += id - entries + // global sort by (columnIndex, value) + val sorted = colBased.sortByKey() + // assign global ranks (using average ranks for tied values) + val globalRanks = sorted.zipWithIndex().mapPartitions { iter => + var preCol = -1 + var preVal = Double.NaN + var startRank = -1.0 + var cachedUids = ArrayBuffer.empty[Long] + val flush: () => Iterable[(Long, (Int, Double))] = () => { + val averageRank = startRank + (cachedUids.size - 1) / 2.0 + val output = cachedUids.map { uid => + (uid, (preCol, averageRank)) } + cachedUids.clear() + output } + iter.flatMap { case (((j, v), uid), rank) => + // If we see a new value or cachedUids is too big, we flush ids with their average rank. + if (j != preCol || v != preVal || cachedUids.size >= 10000000) { + val output = flush() + preCol = j + preVal = v + startRank = rank + cachedUids += uid + output + } else { + cachedUids += uid + Iterator.empty + } + } ++ flush() } - ranks - } - - private def makeRankMatrix(ranks: Array[RDD[(Long, Double)]], input: RDD[Vector]): RDD[Vector] = { - val partitioner = new HashPartitioner(input.partitions.size) - val cogrouped = new CoGroupedRDD[Long](ranks, partitioner) - cogrouped.map { - case (_, values: Array[Iterable[_]]) => - val doubles = values.asInstanceOf[Array[Iterable[Double]]] - new DenseVector(doubles.flatten.toArray) + // Replace values in the input matrix by their ranks compared with values in the same column. + // Note that shifting all ranks in a column by a constant value doesn't affect result. + val groupedRanks = globalRanks.groupByKey().map { case (uid, iter) => + // sort by column index and then convert values to a vector + Vectors.dense(iter.toSeq.sortBy(_._1).map(_._2).toArray) } + PearsonCorrelation.computeCorrelationMatrix(groupedRanks) } } From c9da466edb83e45a159ccc17c68856a511b9e8b7 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 15 Aug 2014 22:55:32 -0700 Subject: [PATCH 06/54] [SPARK-3015] Block on cleaning tasks to prevent Akka timeouts More detail on the issue is described in [SPARK-3015](https://issues.apache.org/jira/browse/SPARK-3015), but the TLDR is if we send too many blocking Akka messages that are dependent on each other in quick successions, then we end up causing a few of these messages to time out and ultimately kill the executors. As of #1498, we broadcast each RDD whether or not it is persisted. This means if we create many RDDs (each of which becomes a broadcast) and the driver performs a GC that cleans up all of these broadcast blocks, then we end up sending many `RemoveBroadcast` messages in parallel and trigger the chain of blocking messages at high frequencies. We do not know of the Akka-level root cause yet, so this is intended to be a temporary solution until we identify the real issue. I have done some preliminary testing of enabling blocking and observed that the queue length remains quite low (< 1000) even under very intensive workloads. In the long run, we should do something more sophisticated to allow a limited degree of parallelism through batching clean up tasks or processing them in a sliding window. In the longer run, we should clean up the whole `BlockManager*` message passing interface to avoid unnecessarily awaiting on futures created from Akka asks. tdas pwendell mengxr Author: Andrew Or Closes #1931 from andrewor14/reference-blocking and squashes the following commits: d0f7195 [Andrew Or] Merge branch 'master' of github.com:apache/spark into reference-blocking ce9daf5 [Andrew Or] Remove logic for logging queue length 111192a [Andrew Or] Add missing space in log message (minor) a183b83 [Andrew Or] Switch order of code blocks (minor) 9fd1fe6 [Andrew Or] Remove outdated log 104b366 [Andrew Or] Use the actual reference queue length 0b7e768 [Andrew Or] Block on cleaning tasks by default + log error on queue full --- .../main/scala/org/apache/spark/ContextCleaner.scala | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index bf3c3a6ceb5ef..3848734d6f639 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -66,10 +66,15 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { /** * Whether the cleaning thread will block on cleanup tasks. - * This is set to true only for tests. + * + * Due to SPARK-3015, this is set to true by default. This is intended to be only a temporary + * workaround for the issue, which is ultimately caused by the way the BlockManager actors + * issue inter-dependent blocking Akka messages to each other at high frequencies. This happens, + * for instance, when the driver performs a GC and cleans up all broadcast blocks that are no + * longer in scope. */ private val blockOnCleanupTasks = sc.conf.getBoolean( - "spark.cleaner.referenceTracking.blocking", false) + "spark.cleaner.referenceTracking.blocking", true) @volatile private var stopped = false @@ -174,9 +179,6 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { private def blockManagerMaster = sc.env.blockManager.master private def broadcastManager = sc.env.broadcastManager private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] - - // Used for testing. These methods explicitly blocks until cleanup is completed - // to ensure that more reliable testing. } private object ContextCleaner { From a83c7723bf7a90dc6cd5dde98a179303b7542020 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 15 Aug 2014 23:12:34 -0700 Subject: [PATCH 07/54] [SPARK-3045] Make Serializer interface Java friendly Author: Reynold Xin Closes #1948 from rxin/kryo and squashes the following commits: a3a80d8 [Reynold Xin] [SPARK-3046] use executor's class loader as the default serializer classloader 3d13277 [Reynold Xin] Reverted that in TestJavaSerializerImpl too. 196f3dc [Reynold Xin] Ok one more commit to revert the classloader change. c49b50c [Reynold Xin] Removed JavaSerializer change. afbf37d [Reynold Xin] Moved the test case also. a2e693e [Reynold Xin] Removed the Kryo bug fix from this pull request. c81bd6c [Reynold Xin] Use defaultClassLoader when executing user specified custom registrator. 68f261e [Reynold Xin] Added license check excludes. 0c28179 [Reynold Xin] [SPARK-3045] Make Serializer interface Java friendly [SPARK-3046] Set executor's class loader as the default serializer class loader --- .../spark/serializer/JavaSerializer.scala | 15 +-- .../spark/serializer/KryoSerializer.scala | 32 +++---- .../apache/spark/serializer/Serializer.scala | 25 ++--- .../apache/spark/serializer/package-info.java | 2 +- .../serializer/TestJavaSerializerImpl.java | 95 +++++++++++++++++++ .../KryoSerializerResizableOutputSuite.scala | 52 ++++++++++ .../serializer/KryoSerializerSuite.scala | 34 +------ project/MimaExcludes.scala | 11 +++ 8 files changed, 193 insertions(+), 73 deletions(-) create mode 100644 core/src/test/java/org/apache/spark/serializer/TestJavaSerializerImpl.java create mode 100644 core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index af33a2f2ca3e1..554a33ce7f1a6 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -63,10 +63,11 @@ extends DeserializationStream { def close() { objIn.close() } } + private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoader: ClassLoader) extends SerializerInstance { - def serialize[T: ClassTag](t: T): ByteBuffer = { + override def serialize[T: ClassTag](t: T): ByteBuffer = { val bos = new ByteArrayOutputStream() val out = serializeStream(bos) out.writeObject(t) @@ -74,23 +75,23 @@ private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoade ByteBuffer.wrap(bos.toByteArray) } - def deserialize[T: ClassTag](bytes: ByteBuffer): T = { + override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { val bis = new ByteBufferInputStream(bytes) val in = deserializeStream(bis) - in.readObject().asInstanceOf[T] + in.readObject() } - def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { + override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { val bis = new ByteBufferInputStream(bytes) val in = deserializeStream(bis, loader) - in.readObject().asInstanceOf[T] + in.readObject() } - def serializeStream(s: OutputStream): SerializationStream = { + override def serializeStream(s: OutputStream): SerializationStream = { new JavaSerializationStream(s, counterReset) } - def deserializeStream(s: InputStream): DeserializationStream = { + override def deserializeStream(s: InputStream): DeserializationStream = { new JavaDeserializationStream(s, Utils.getContextOrSparkClassLoader) } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 99682220b4ab5..87ef9bb0b43c6 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -91,7 +91,7 @@ class KryoSerializer(conf: SparkConf) Thread.currentThread.setContextClassLoader(classLoader) reg.registerClasses(kryo) } catch { - case e: Exception => + case e: Exception => throw new SparkException(s"Failed to invoke $regCls", e) } finally { Thread.currentThread.setContextClassLoader(oldClassLoader) @@ -106,7 +106,7 @@ class KryoSerializer(conf: SparkConf) kryo } - def newInstance(): SerializerInstance = { + override def newInstance(): SerializerInstance = { new KryoSerializerInstance(this) } } @@ -115,20 +115,20 @@ private[spark] class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream { val output = new KryoOutput(outStream) - def writeObject[T: ClassTag](t: T): SerializationStream = { + override def writeObject[T: ClassTag](t: T): SerializationStream = { kryo.writeClassAndObject(output, t) this } - def flush() { output.flush() } - def close() { output.close() } + override def flush() { output.flush() } + override def close() { output.close() } } private[spark] class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream { - val input = new KryoInput(inStream) + private val input = new KryoInput(inStream) - def readObject[T: ClassTag](): T = { + override def readObject[T: ClassTag](): T = { try { kryo.readClassAndObject(input).asInstanceOf[T] } catch { @@ -138,31 +138,31 @@ class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends Deser } } - def close() { + override def close() { // Kryo's Input automatically closes the input stream it is using. input.close() } } private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance { - val kryo = ks.newKryo() + private val kryo = ks.newKryo() // Make these lazy vals to avoid creating a buffer unless we use them - lazy val output = ks.newKryoOutput() - lazy val input = new KryoInput() + private lazy val output = ks.newKryoOutput() + private lazy val input = new KryoInput() - def serialize[T: ClassTag](t: T): ByteBuffer = { + override def serialize[T: ClassTag](t: T): ByteBuffer = { output.clear() kryo.writeClassAndObject(output, t) ByteBuffer.wrap(output.toBytes) } - def deserialize[T: ClassTag](bytes: ByteBuffer): T = { + override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { input.setBuffer(bytes.array) kryo.readClassAndObject(input).asInstanceOf[T] } - def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { + override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { val oldClassLoader = kryo.getClassLoader kryo.setClassLoader(loader) input.setBuffer(bytes.array) @@ -171,11 +171,11 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ obj } - def serializeStream(s: OutputStream): SerializationStream = { + override def serializeStream(s: OutputStream): SerializationStream = { new KryoSerializationStream(kryo, s) } - def deserializeStream(s: InputStream): DeserializationStream = { + override def deserializeStream(s: InputStream): DeserializationStream = { new KryoDeserializationStream(kryo, s) } } diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index e674438c8176c..a9144cdd97b8c 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -43,7 +43,7 @@ import org.apache.spark.util.{ByteBufferInputStream, NextIterator} * They are intended to be used to serialize/de-serialize data within a single Spark application. */ @DeveloperApi -trait Serializer { +abstract class Serializer { /** * Default ClassLoader to use in deserialization. Implementations of [[Serializer]] should @@ -61,10 +61,12 @@ trait Serializer { this } + /** Creates a new [[SerializerInstance]]. */ def newInstance(): SerializerInstance } +@DeveloperApi object Serializer { def getSerializer(serializer: Serializer): Serializer = { if (serializer == null) SparkEnv.get.serializer else serializer @@ -81,7 +83,7 @@ object Serializer { * An instance of a serializer, for use by one thread at a time. */ @DeveloperApi -trait SerializerInstance { +abstract class SerializerInstance { def serialize[T: ClassTag](t: T): ByteBuffer def deserialize[T: ClassTag](bytes: ByteBuffer): T @@ -91,21 +93,6 @@ trait SerializerInstance { def serializeStream(s: OutputStream): SerializationStream def deserializeStream(s: InputStream): DeserializationStream - - def serializeMany[T: ClassTag](iterator: Iterator[T]): ByteBuffer = { - // Default implementation uses serializeStream - val stream = new ByteArrayOutputStream() - serializeStream(stream).writeAll(iterator) - val buffer = ByteBuffer.wrap(stream.toByteArray) - buffer.flip() - buffer - } - - def deserializeMany(buffer: ByteBuffer): Iterator[Any] = { - // Default implementation uses deserializeStream - buffer.rewind() - deserializeStream(new ByteBufferInputStream(buffer)).asIterator - } } /** @@ -113,7 +100,7 @@ trait SerializerInstance { * A stream for writing serialized objects. */ @DeveloperApi -trait SerializationStream { +abstract class SerializationStream { def writeObject[T: ClassTag](t: T): SerializationStream def flush(): Unit def close(): Unit @@ -132,7 +119,7 @@ trait SerializationStream { * A stream for reading serialized objects. */ @DeveloperApi -trait DeserializationStream { +abstract class DeserializationStream { def readObject[T: ClassTag](): T def close(): Unit diff --git a/core/src/main/scala/org/apache/spark/serializer/package-info.java b/core/src/main/scala/org/apache/spark/serializer/package-info.java index 4c0b73ab36a00..207c6e02e4293 100644 --- a/core/src/main/scala/org/apache/spark/serializer/package-info.java +++ b/core/src/main/scala/org/apache/spark/serializer/package-info.java @@ -18,4 +18,4 @@ /** * Pluggable serializers for RDD and shuffle data. */ -package org.apache.spark.serializer; \ No newline at end of file +package org.apache.spark.serializer; diff --git a/core/src/test/java/org/apache/spark/serializer/TestJavaSerializerImpl.java b/core/src/test/java/org/apache/spark/serializer/TestJavaSerializerImpl.java new file mode 100644 index 0000000000000..3d50ab4fabe42 --- /dev/null +++ b/core/src/test/java/org/apache/spark/serializer/TestJavaSerializerImpl.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer; + +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; + +import scala.Option; +import scala.reflect.ClassTag; + + +/** + * A simple Serializer implementation to make sure the API is Java-friendly. + */ +class TestJavaSerializerImpl extends Serializer { + + @Override + public SerializerInstance newInstance() { + return null; + } + + static class SerializerInstanceImpl extends SerializerInstance { + @Override + public ByteBuffer serialize(T t, ClassTag evidence$1) { + return null; + } + + @Override + public T deserialize(ByteBuffer bytes, ClassLoader loader, ClassTag evidence$1) { + return null; + } + + @Override + public T deserialize(ByteBuffer bytes, ClassTag evidence$1) { + return null; + } + + @Override + public SerializationStream serializeStream(OutputStream s) { + return null; + } + + @Override + public DeserializationStream deserializeStream(InputStream s) { + return null; + } + } + + static class SerializationStreamImpl extends SerializationStream { + + @Override + public SerializationStream writeObject(T t, ClassTag evidence$1) { + return null; + } + + @Override + public void flush() { + + } + + @Override + public void close() { + + } + } + + static class DeserializationStreamImpl extends DeserializationStream { + + @Override + public T readObject(ClassTag evidence$1) { + return null; + } + + @Override + public void close() { + + } + } +} diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala new file mode 100644 index 0000000000000..967c9e9899c9d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import org.scalatest.FunSuite + +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +import org.apache.spark.LocalSparkContext +import org.apache.spark.SparkException + + +class KryoSerializerResizableOutputSuite extends FunSuite { + + // trial and error showed this will not serialize with 1mb buffer + val x = (1 to 400000).toArray + + test("kryo without resizable output buffer should fail on large array") { + val conf = new SparkConf(false) + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + conf.set("spark.kryoserializer.buffer.mb", "1") + conf.set("spark.kryoserializer.buffer.max.mb", "1") + val sc = new SparkContext("local", "test", conf) + intercept[SparkException](sc.parallelize(x).collect()) + LocalSparkContext.stop(sc) + } + + test("kryo with resizable output buffer should succeed on large array") { + val conf = new SparkConf(false) + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + conf.set("spark.kryoserializer.buffer.mb", "1") + conf.set("spark.kryoserializer.buffer.max.mb", "2") + val sc = new SparkContext("local", "test", conf) + assert(sc.parallelize(x).collect() === x) + LocalSparkContext.stop(sc) + } +} diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index a579fd50bd9e4..e1e35b688d581 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -26,6 +26,7 @@ import org.scalatest.FunSuite import org.apache.spark.{SparkConf, SharedSparkContext} import org.apache.spark.serializer.KryoTest._ + class KryoSerializerSuite extends FunSuite with SharedSparkContext { conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.kryo.registrator", classOf[MyRegistrator].getName) @@ -207,7 +208,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { .fold(new ClassWithoutNoArgConstructor(10))((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x assert(10 + control.sum === result) } - + test("kryo with nonexistent custom registrator should fail") { import org.apache.spark.{SparkConf, SparkException} @@ -238,39 +239,12 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { } } -class ClassLoaderTestingObject - -class KryoSerializerResizableOutputSuite extends FunSuite { - import org.apache.spark.SparkConf - import org.apache.spark.SparkContext - import org.apache.spark.LocalSparkContext - import org.apache.spark.SparkException - - // trial and error showed this will not serialize with 1mb buffer - val x = (1 to 400000).toArray - test("kryo without resizable output buffer should fail on large array") { - val conf = new SparkConf(false) - conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - conf.set("spark.kryoserializer.buffer.mb", "1") - conf.set("spark.kryoserializer.buffer.max.mb", "1") - val sc = new SparkContext("local", "test", conf) - intercept[SparkException](sc.parallelize(x).collect) - LocalSparkContext.stop(sc) - } +class ClassLoaderTestingObject - test("kryo with resizable output buffer should succeed on large array") { - val conf = new SparkConf(false) - conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - conf.set("spark.kryoserializer.buffer.mb", "1") - conf.set("spark.kryoserializer.buffer.max.mb", "2") - val sc = new SparkContext("local", "test", conf) - assert(sc.parallelize(x).collect === x) - LocalSparkContext.stop(sc) - } -} object KryoTest { + case class CaseClass(i: Int, s: String) {} class ClassWithNoArgConstructor { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 1e3c760b845de..bbe68b29d2d8e 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -61,6 +61,17 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.storage.MemoryStore.Entry") ) ++ + Seq( + // Serializer interface change. See SPARK-3045. + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.serializer.DeserializationStream"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.serializer.Serializer"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.serializer.SerializationStream"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.serializer.SerializerInstance") + )++ Seq( // Renamed putValues -> putArray + putIterator ProblemFilters.exclude[MissingMethodProblem]( From 20fcf3d0b72f3707dc1ed95d453f570fabdefd16 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 16 Aug 2014 00:04:55 -0700 Subject: [PATCH 08/54] [SPARK-2977] Ensure ShuffleManager is created before ShuffleBlockManager This is intended to fix SPARK-2977. Before, there was an implicit ordering dependency where we needed to know the ShuffleManager implementation before creating the ShuffleBlockManager. This patch makes that dependency explicit by adding ShuffleManager to a bunch of constructors. I think it's a little odd for BlockManager to take a ShuffleManager only to pass it to ShuffleBlockManager without using it itself; there's an opportunity to clean this up later if we sever the circular dependencies between BlockManager and other components and pass those components to BlockManager's constructor. Author: Josh Rosen Closes #1976 from JoshRosen/SPARK-2977 and squashes the following commits: a9cd1e1 [Josh Rosen] [SPARK-2977] Ensure ShuffleManager is created before ShuffleBlockManager. --- .../scala/org/apache/spark/SparkEnv.scala | 22 +++++++++---------- .../apache/spark/storage/BlockManager.scala | 11 ++++++---- .../spark/storage/ShuffleBlockManager.scala | 7 +++--- .../apache/spark/storage/ThreadingTest.scala | 3 ++- .../spark/storage/BlockManagerSuite.scala | 12 +++++----- .../spark/storage/DiskBlockManagerSuite.scala | 8 +++++-- 6 files changed, 37 insertions(+), 26 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 22d8d1cb1ddcf..fc36e37c53f5e 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -210,12 +210,22 @@ object SparkEnv extends Logging { "MapOutputTracker", new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) + // Let the user specify short names for shuffle managers + val shortShuffleMgrNames = Map( + "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager", + "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager") + val shuffleMgrName = conf.get("spark.shuffle.manager", "hash") + val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) + val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) + + val shuffleMemoryManager = new ShuffleMemoryManager(conf) + val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf) val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, - serializer, conf, securityManager, mapOutputTracker) + serializer, conf, securityManager, mapOutputTracker, shuffleManager) val connectionManager = blockManager.connectionManager @@ -250,16 +260,6 @@ object SparkEnv extends Logging { "." } - // Let the user specify short names for shuffle managers - val shortShuffleMgrNames = Map( - "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager", - "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager") - val shuffleMgrName = conf.get("spark.shuffle.manager", "hash") - val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) - val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) - - val shuffleMemoryManager = new ShuffleMemoryManager(conf) - // Warn about deprecated spark.cache.class property if (conf.contains("spark.cache.class")) { logWarning("The spark.cache.class property is no longer being used! Specify storage " + diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index e8bbd298c631a..e4c3d58905e7f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -33,6 +33,7 @@ import org.apache.spark.executor._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.util._ private[spark] sealed trait BlockValues @@ -57,11 +58,12 @@ private[spark] class BlockManager( maxMemory: Long, val conf: SparkConf, securityManager: SecurityManager, - mapOutputTracker: MapOutputTracker) + mapOutputTracker: MapOutputTracker, + shuffleManager: ShuffleManager) extends Logging { private val port = conf.getInt("spark.blockManager.port", 0) - val shuffleBlockManager = new ShuffleBlockManager(this) + val shuffleBlockManager = new ShuffleBlockManager(this, shuffleManager) val diskBlockManager = new DiskBlockManager(shuffleBlockManager, conf.get("spark.local.dir", System.getProperty("java.io.tmpdir"))) val connectionManager = @@ -142,9 +144,10 @@ private[spark] class BlockManager( serializer: Serializer, conf: SparkConf, securityManager: SecurityManager, - mapOutputTracker: MapOutputTracker) = { + mapOutputTracker: MapOutputTracker, + shuffleManager: ShuffleManager) = { this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), - conf, securityManager, mapOutputTracker) + conf, securityManager, mapOutputTracker, shuffleManager) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index 3565719b54545..b8f5d3a5b02aa 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -25,6 +25,7 @@ import scala.collection.JavaConversions._ import org.apache.spark.Logging import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} @@ -62,7 +63,8 @@ private[spark] trait ShuffleWriterGroup { */ // TODO: Factor this into a separate class for each ShuffleManager implementation private[spark] -class ShuffleBlockManager(blockManager: BlockManager) extends Logging { +class ShuffleBlockManager(blockManager: BlockManager, + shuffleManager: ShuffleManager) extends Logging { def conf = blockManager.conf // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file. @@ -71,8 +73,7 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { conf.getBoolean("spark.shuffle.consolidateFiles", false) // Are we using sort-based shuffle? - val sortBasedShuffle = - conf.get("spark.shuffle.manager", "") == classOf[SortShuffleManager].getName + val sortBasedShuffle = shuffleManager.isInstanceOf[SortShuffleManager] private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala index 75c2e09a6bbb8..aa83ea90ee9ee 100644 --- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala @@ -20,6 +20,7 @@ package org.apache.spark.storage import java.util.concurrent.ArrayBlockingQueue import akka.actor._ +import org.apache.spark.shuffle.hash.HashShuffleManager import util.Random import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} @@ -101,7 +102,7 @@ private[spark] object ThreadingTest { conf) val blockManager = new BlockManager( "", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf, - new SecurityManager(conf), new MapOutputTrackerMaster(conf)) + new SecurityManager(conf), new MapOutputTrackerMaster(conf), new HashShuffleManager(conf)) val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) producers.foreach(_.start) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 94bb2c445d2e9..20bac66105a69 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -24,6 +24,7 @@ import java.util.concurrent.TimeUnit import akka.actor._ import akka.pattern.ask import akka.util.Timeout +import org.apache.spark.shuffle.hash.HashShuffleManager import org.mockito.invocation.InvocationOnMock import org.mockito.Matchers.any @@ -61,6 +62,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter conf.set("spark.authenticate", "false") val securityMgr = new SecurityManager(conf) val mapOutputTracker = new MapOutputTrackerMaster(conf) + val shuffleManager = new HashShuffleManager(conf) // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test conf.set("spark.kryoserializer.buffer.mb", "1") @@ -71,8 +73,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId) private def makeBlockManager(maxMem: Long, name: String = ""): BlockManager = { - new BlockManager( - name, actorSystem, master, serializer, maxMem, conf, securityMgr, mapOutputTracker) + new BlockManager(name, actorSystem, master, serializer, maxMem, conf, securityMgr, + mapOutputTracker, shuffleManager) } before { @@ -791,7 +793,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("block store put failure") { // Use Java serializer so we can create an unserializable error. store = new BlockManager("", actorSystem, master, new JavaSerializer(conf), 1200, conf, - securityMgr, mapOutputTracker) + securityMgr, mapOutputTracker, shuffleManager) // The put should fail since a1 is not serializable. class UnserializableClass @@ -1007,7 +1009,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("return error message when error occurred in BlockManagerWorker#onBlockMessageReceive") { store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) + securityMgr, mapOutputTracker, shuffleManager) val worker = spy(new BlockManagerWorker(store)) val connManagerId = mock(classOf[ConnectionManagerId]) @@ -1054,7 +1056,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("return ack message when no error occurred in BlocManagerWorker#onBlockMessageReceive") { store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) + securityMgr, mapOutputTracker, shuffleManager) val worker = spy(new BlockManagerWorker(store)) val connManagerId = mock(classOf[ConnectionManagerId]) diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index b8299e2ea187f..777579bc570db 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.storage import java.io.{File, FileWriter} +import org.apache.spark.shuffle.hash.HashShuffleManager + import scala.collection.mutable import scala.language.reflectiveCalls @@ -42,7 +44,9 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before // so we coerce consolidation if not already enabled. testConf.set("spark.shuffle.consolidateFiles", "true") - val shuffleBlockManager = new ShuffleBlockManager(null) { + private val shuffleManager = new HashShuffleManager(testConf.clone) + + val shuffleBlockManager = new ShuffleBlockManager(null, shuffleManager) { override def conf = testConf.clone var idToSegmentMap = mutable.Map[ShuffleBlockId, FileSegment]() override def getBlockLocation(id: ShuffleBlockId) = idToSegmentMap(id) @@ -148,7 +152,7 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before actorSystem.actorOf(Props(new BlockManagerMasterActor(true, confCopy, new LiveListenerBus))), confCopy) val store = new BlockManager("", actorSystem, master , serializer, confCopy, - securityManager, null) + securityManager, null, shuffleManager) try { From b4a05928e95c0f6973fd21e60ff9c108f226e38c Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 16 Aug 2014 11:26:51 -0700 Subject: [PATCH 09/54] [SQL] Using safe floating-point numbers in doctest Test code in `sql.py` tries to compare two floating-point numbers directly, and cased [build failure(s)](https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/18365/consoleFull). [Doctest documentation](https://docs.python.org/3/library/doctest.html#warnings) recommends using numbers in the form of `I/2**J` to avoid the precision issue. Author: Cheng Lian Closes #1925 from liancheng/fix-pysql-fp-test and squashes the following commits: 0fbf584 [Cheng Lian] Removed unnecessary `...' from inferSchema doctest e8059d4 [Cheng Lian] Using safe floating-point numbers in doctest --- python/pyspark/sql.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 95086a2258222..d4ca0cc8f336e 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -1093,8 +1093,8 @@ def applySchema(self, rdd, schema): >>> sqlCtx.sql( ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " + - ... "float + 1.1 as float FROM table2").collect() - [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.1...)] + ... "float + 1.5 as float FROM table2").collect() + [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.5)] >>> rdd = sc.parallelize([(127, -32768, 1.0, ... datetime(2010, 1, 1, 1, 1, 1), From 4bdfaa16fce399bd97c98858151246b3b02f350f Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Sat, 16 Aug 2014 12:35:59 -0700 Subject: [PATCH 10/54] [SPARK-3076] [Jenkins] catch & report test timeouts * Remove unused code to get jq * Set timeout on tests and report gracefully on them Author: Nicholas Chammas Closes #1974 from nchammas/master and squashes the following commits: d1f1b6b [Nicholas Chammas] set timeout to realistic number 8b1ea41 [Nicholas Chammas] fix formatting 279526e [Nicholas Chammas] [SPARK-3076] catch & report test timeouts --- dev/run-tests-jenkins | 48 ++++++++++++++++++------------------------- 1 file changed, 20 insertions(+), 28 deletions(-) diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 721f09be5be6d..31506e28e05af 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -26,27 +26,17 @@ FWDIR="$(cd `dirname $0`/..; pwd)" cd "$FWDIR" -function get_jq () { - # Get jq so we can parse some JSON, man. - # Essential if we want to do anything with the GitHub API responses. - local JQ_EXECUTABLE_URL="http://stedolan.github.io/jq/download/linux64/jq" - - echo "Fetching jq from ${JQ_EXECUTABLE_URL}" - - curl --silent --output "$FWDIR/dev/jq" "$JQ_EXECUTABLE_URL" - local curl_status=$? - - if [ $curl_status -ne 0 ]; then - echo "Failed to get jq." >&2 - return $curl_status - fi - - chmod u+x "$FWDIR/dev/jq" -} - COMMENTS_URL="https://api.github.com/repos/apache/spark/issues/$ghprbPullId/comments" PULL_REQUEST_URL="https://github.com/apache/spark/pull/$ghprbPullId" +COMMIT_URL="https://github.com/apache/spark/commit/${ghprbActualCommit}" +# GitHub doesn't auto-link short hashes when submitted via the API, unfortunately. :( +SHORT_COMMIT_HASH="${ghprbActualCommit:0:7}" + +# NOTE: Jenkins will kill the whole build after 120 minutes. +# Tests are a large part of that, but not all of it. +TESTS_TIMEOUT="120m" + function post_message () { local message=$1 local data="{\"body\": \"$message\"}" @@ -96,10 +86,6 @@ function post_message () { fi } -COMMIT_URL="https://github.com/apache/spark/commit/${ghprbActualCommit}" -# GitHub doesn't auto-link short hashes when submitted via the API, unfortunately. :( -short_commit_hash=${ghprbActualCommit:0:7} - # check PR merge-ability and check for new public classes { if [ "$sha1" == "$ghprbActualCommit" ]; then @@ -138,7 +124,7 @@ short_commit_hash=${ghprbActualCommit:0:7} { start_message="\ [QA tests have started](${BUILD_URL}consoleFull) for \ - PR $ghprbPullId at commit [\`${short_commit_hash}\`](${COMMIT_URL})." + PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})." start_message="${start_message}\n${merge_note}" # start_message="${start_message}\n${public_classes_note}" @@ -148,13 +134,19 @@ short_commit_hash=${ghprbActualCommit:0:7} # run tests { - ./dev/run-tests + timeout "${TESTS_TIMEOUT}" ./dev/run-tests test_result="$?" - if [ "$test_result" -eq "0" ]; then - test_result_note=" * This patch **passes** unit tests." + if [ "$test_result" -eq "124" ]; then + fail_message="**Tests timed out** after a configured wait of \`${TESTS_TIMEOUT}\`." + post_message "$fail_message" + exit $test_result else - test_result_note=" * This patch **fails** unit tests." + if [ "$test_result" -eq "0" ]; then + test_result_note=" * This patch **passes** unit tests." + else + test_result_note=" * This patch **fails** unit tests." + fi fi } @@ -162,7 +154,7 @@ short_commit_hash=${ghprbActualCommit:0:7} { result_message="\ [QA tests have finished](${BUILD_URL}consoleFull) for \ - PR $ghprbPullId at commit [\`${short_commit_hash}\`](${COMMIT_URL})." + PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})." result_message="${result_message}\n${test_result_note}" result_message="${result_message}\n${merge_note}" From 76fa0eaf515fd6771cdd69422b1259485debcae5 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Sat, 16 Aug 2014 14:15:58 -0700 Subject: [PATCH 11/54] [SPARK-2677] BasicBlockFetchIterator#next can wait forever Author: Kousuke Saruta Closes #1632 from sarutak/SPARK-2677 and squashes the following commits: cddbc7b [Kousuke Saruta] Removed Exception throwing when ConnectionManager#handleMessage receives ack for non-referenced message d3bd2a8 [Kousuke Saruta] Modified configuration.md for spark.core.connection.ack.timeout e85f88b [Kousuke Saruta] Removed useless synchronized blocks 7ed48be [Kousuke Saruta] Modified ConnectionManager to use ackTimeoutMonitor ConnectionManager-wide 9b620a6 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2677 0dd9ad3 [Kousuke Saruta] Modified typo in ConnectionManagerSuite.scala 7cbb8ca [Kousuke Saruta] Modified to match with scalastyle 8a73974 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2677 ade279a [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2677 0174d6a [Kousuke Saruta] Modified ConnectionManager.scala to handle the case remote Executor cannot ack a454239 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2677 9b7b7c1 [Kousuke Saruta] (WIP) Modifying ConnectionManager.scala --- .../spark/network/ConnectionManager.scala | 45 ++++++++++++++----- .../network/ConnectionManagerSuite.scala | 44 +++++++++++++++++- docs/configuration.md | 9 ++++ 3 files changed, 87 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 95f96b8463a01..37d69a9ec4ce4 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -22,6 +22,7 @@ import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ import java.net._ +import java.util.{Timer, TimerTask} import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor} @@ -61,17 +62,17 @@ private[spark] class ConnectionManager( var ackMessage: Option[Message] = None def markDone(ackMessage: Option[Message]) { - this.synchronized { - this.ackMessage = ackMessage - completionHandler(this) - } + this.ackMessage = ackMessage + completionHandler(this) } } private val selector = SelectorProvider.provider.openSelector() + private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true) // default to 30 second timeout waiting for authentication private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30) + private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60) private val handleMessageExecutor = new ThreadPoolExecutor( conf.getInt("spark.core.connection.handler.threads.min", 20), @@ -652,19 +653,27 @@ private[spark] class ConnectionManager( } } if (bufferMessage.hasAckId()) { - val sentMessageStatus = messageStatuses.synchronized { + messageStatuses.synchronized { messageStatuses.get(bufferMessage.ackId) match { case Some(status) => { messageStatuses -= bufferMessage.ackId - status + status.markDone(Some(message)) } case None => { - throw new Exception("Could not find reference for received ack message " + - message.id) + /** + * We can fall down on this code because of following 2 cases + * + * (1) Invalid ack sent due to buggy code. + * + * (2) Late-arriving ack for a SendMessageStatus + * To avoid unwilling late-arriving ack + * caused by long pause like GC, you can set + * larger value than default to spark.core.connection.ack.wait.timeout + */ + logWarning(s"Could not find reference for received ack Message ${message.id}") } } } - sentMessageStatus.markDone(Some(message)) } else { var ackMessage : Option[Message] = None try { @@ -836,9 +845,23 @@ private[spark] class ConnectionManager( def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message) : Future[Message] = { val promise = Promise[Message]() + + val timeoutTask = new TimerTask { + override def run(): Unit = { + messageStatuses.synchronized { + messageStatuses.remove(message.id).foreach ( s => { + promise.failure( + new IOException(s"sendMessageReliably failed because ack " + + "was not received within ${ackTimeout} sec")) + }) + } + } + } + val status = new MessageStatus(message, connectionManagerId, s => { + timeoutTask.cancel() s.ackMessage match { - case None => // Indicates a failure where we either never sent or never got ACK'd + case None => // Indicates a failure where we either never sent or never got ACK'd promise.failure(new IOException("sendMessageReliably failed without being ACK'd")) case Some(ackMessage) => if (ackMessage.hasError) { @@ -852,6 +875,8 @@ private[spark] class ConnectionManager( messageStatuses.synchronized { messageStatuses += ((message.id, status)) } + + ackTimeoutMonitor.schedule(timeoutTask, ackTimeout * 1000) sendMessage(connectionManagerId, message) promise.future } diff --git a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala index 846537df003df..e2f4d4c57cdb5 100644 --- a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala @@ -19,14 +19,19 @@ package org.apache.spark.network import java.io.IOException import java.nio._ +import java.util.concurrent.TimeoutException import org.apache.spark.{SecurityManager, SparkConf} import org.scalatest.FunSuite +import org.mockito.Mockito._ +import org.mockito.Matchers._ + +import scala.concurrent.TimeoutException import scala.concurrent.{Await, TimeoutException} import scala.concurrent.duration._ import scala.language.postfixOps -import scala.util.Try +import scala.util.{Failure, Success, Try} /** * Test the ConnectionManager with various security settings. @@ -255,5 +260,42 @@ class ConnectionManagerSuite extends FunSuite { } + test("sendMessageReliably timeout") { + val clientConf = new SparkConf + clientConf.set("spark.authenticate", "false") + val ackTimeout = 30 + clientConf.set("spark.core.connection.ack.wait.timeout", s"${ackTimeout}") + + val clientSecurityManager = new SecurityManager(clientConf) + val manager = new ConnectionManager(0, clientConf, clientSecurityManager) + + val serverConf = new SparkConf + serverConf.set("spark.authenticate", "false") + val serverSecurityManager = new SecurityManager(serverConf) + val managerServer = new ConnectionManager(0, serverConf, serverSecurityManager) + managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + // sleep 60 sec > ack timeout for simulating server slow down or hang up + Thread.sleep(ackTimeout * 3 * 1000) + None + }) + + val size = 10 * 1024 * 1024 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + + val future = manager.sendMessageReliably(managerServer.id, bufferMessage) + + // Future should throw IOException in 30 sec. + // Otherwise TimeoutExcepton is thrown from Await.result. + // We expect TimeoutException is not thrown. + intercept[IOException] { + Await.result(future, (ackTimeout * 2) second) + } + + manager.stop() + managerServer.stop() + } + } diff --git a/docs/configuration.md b/docs/configuration.md index c408c468dcd94..981170d8b49b7 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -884,6 +884,15 @@ Apart from these, the following properties are also available, and may be useful out and giving up. + + spark.core.connection.ack.wait.timeout + 60 + + Number of seconds for the connection to wait for ack to occur before timing + out and giving up. To avoid unwilling timeout caused by long pause like GC, + you can set larger value. + + spark.ui.filters None From 7e70708a99949549adde00cb6246a9582bbc4929 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sat, 16 Aug 2014 15:13:34 -0700 Subject: [PATCH 12/54] [SPARK-3048][MLLIB] add LabeledPoint.parse and remove loadStreamingLabeledPoints Move `parse()` from `LabeledPointParser` to `LabeledPoint` and make it public. This breaks binary compatibility only when a user uses synthesized methods like `tupled` and `curried`, which is rare. `LabeledPoint.parse` is more consistent with `Vectors.parse`, which is why `LabeledPointParser` is not preferred. freeman-lab tdas Author: Xiangrui Meng Closes #1952 from mengxr/labelparser and squashes the following commits: c818fb2 [Xiangrui Meng] merge master ce20e6f [Xiangrui Meng] update mima excludes b386b8d [Xiangrui Meng] fix tests 2436b3d [Xiangrui Meng] add parse() to LabeledPoint --- .../mllib/StreamingLinearRegression.scala | 7 +++---- .../spark/mllib/regression/LabeledPoint.scala | 2 +- .../StreamingLinearRegressionWithSGD.scala | 2 +- .../org/apache/spark/mllib/util/MLUtils.scala | 17 ++--------------- .../mllib/regression/LabeledPointSuite.scala | 4 ++-- .../StreamingLinearRegressionSuite.scala | 6 +++--- project/MimaExcludes.scala | 5 +++++ 7 files changed, 17 insertions(+), 26 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala index 1fd37edfa7427..0e992fa9967bb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala @@ -18,8 +18,7 @@ package org.apache.spark.examples.mllib import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD +import org.apache.spark.mllib.regression.{LabeledPoint, StreamingLinearRegressionWithSGD} import org.apache.spark.SparkConf import org.apache.spark.streaming.{Seconds, StreamingContext} @@ -56,8 +55,8 @@ object StreamingLinearRegression { val conf = new SparkConf().setMaster("local").setAppName("StreamingLinearRegression") val ssc = new StreamingContext(conf, Seconds(args(2).toLong)) - val trainingData = MLUtils.loadStreamingLabeledPoints(ssc, args(0)) - val testData = MLUtils.loadStreamingLabeledPoints(ssc, args(1)) + val trainingData = ssc.textFileStream(args(0)).map(LabeledPoint.parse) + val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse) val model = new StreamingLinearRegressionWithSGD() .setInitialWeights(Vectors.dense(Array.fill[Double](args(3).toInt)(0))) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index 62a03af4a9964..17c753c56681f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -36,7 +36,7 @@ case class LabeledPoint(label: Double, features: Vector) { /** * Parser for [[org.apache.spark.mllib.regression.LabeledPoint]]. */ -private[mllib] object LabeledPointParser { +object LabeledPoint { /** * Parses a string resulted from `LabeledPoint#toString` into * an [[org.apache.spark.mllib.regression.LabeledPoint]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index 8851097050318..1d11fde24712c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.regression import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.Vector /** * Train or predict a linear regression model on streaming data. Training uses diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index f4cce86a65ba7..ca35100aa99c6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -27,7 +27,7 @@ import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.rdd.PartitionwiseSampledRDD import org.apache.spark.util.random.BernoulliSampler -import org.apache.spark.mllib.regression.{LabeledPointParser, LabeledPoint} +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext @@ -185,7 +185,7 @@ object MLUtils { * @return labeled points stored as an RDD[LabeledPoint] */ def loadLabeledPoints(sc: SparkContext, path: String, minPartitions: Int): RDD[LabeledPoint] = - sc.textFile(path, minPartitions).map(LabeledPointParser.parse) + sc.textFile(path, minPartitions).map(LabeledPoint.parse) /** * Loads labeled points saved using `RDD[LabeledPoint].saveAsTextFile` with the default number of @@ -194,19 +194,6 @@ object MLUtils { def loadLabeledPoints(sc: SparkContext, dir: String): RDD[LabeledPoint] = loadLabeledPoints(sc, dir, sc.defaultMinPartitions) - /** - * Loads streaming labeled points from a stream of text files - * where points are in the same format as used in `RDD[LabeledPoint].saveAsTextFile`. - * See `StreamingContext.textFileStream` for more details on how to - * generate a stream from files - * - * @param ssc Streaming context - * @param dir Directory path in any Hadoop-supported file system URI - * @return Labeled points stored as a DStream[LabeledPoint] - */ - def loadStreamingLabeledPoints(ssc: StreamingContext, dir: String): DStream[LabeledPoint] = - ssc.textFileStream(dir).map(LabeledPointParser.parse) - /** * Load labeled data from a file. The data format used here is * , ... diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala index d9308aaba6ee1..110c44a7193fd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala @@ -28,12 +28,12 @@ class LabeledPointSuite extends FunSuite { LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), LabeledPoint(0.0, Vectors.sparse(2, Array(1), Array(-1.0)))) points.foreach { p => - assert(p === LabeledPointParser.parse(p.toString)) + assert(p === LabeledPoint.parse(p.toString)) } } test("parse labeled points with v0.9 format") { - val point = LabeledPointParser.parse("1.0,1.0 0.0 -2.0") + val point = LabeledPoint.parse("1.0,1.0 0.0 -2.0") assert(point === LabeledPoint(1.0, Vectors.dense(1.0, 0.0, -2.0))) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index ed21f84472c9a..45e25eecf508e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -26,7 +26,7 @@ import com.google.common.io.Files import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext, MLUtils} +import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext} import org.apache.spark.streaming.{Milliseconds, StreamingContext} import org.apache.spark.util.Utils @@ -55,7 +55,7 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext { val numBatches = 10 val batchDuration = Milliseconds(1000) val ssc = new StreamingContext(sc, batchDuration) - val data = MLUtils.loadStreamingLabeledPoints(ssc, testDir.toString) + val data = ssc.textFileStream(testDir.toString).map(LabeledPoint.parse) val model = new StreamingLinearRegressionWithSGD() .setInitialWeights(Vectors.dense(0.0, 0.0)) .setStepSize(0.1) @@ -97,7 +97,7 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext { val batchDuration = Milliseconds(2000) val ssc = new StreamingContext(sc, batchDuration) val numBatches = 5 - val data = MLUtils.loadStreamingLabeledPoints(ssc, testDir.toString) + val data = ssc.textFileStream(testDir.toString()).map(LabeledPoint.parse) val model = new StreamingLinearRegressionWithSGD() .setInitialWeights(Vectors.dense(0.0)) .setStepSize(0.1) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index bbe68b29d2d8e..300589394b96f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -129,6 +129,11 @@ object MimaExcludes { Seq( // new Vector methods in MLlib (binary compatible assuming users do not implement Vector) ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.copy") ) ++ + Seq( // synthetic methods generated in LabeledPoint + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.mllib.regression.LabeledPoint$"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.regression.LabeledPoint.apply"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.regression.LabeledPoint.toString") + ) ++ Seq ( // Scala 2.11 compatibility fix ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.$default$2") ) From ac6411c6e75906997c78de23dfdbc8d225b87cfd Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sat, 16 Aug 2014 15:14:43 -0700 Subject: [PATCH 13/54] [SPARK-3081][MLLIB] rename RandomRDDGenerators to RandomRDDs `RandomRDDGenerators` means factory for `RandomRDDGenerator`. However, its methods return RDDs but not RDDGenerators. So a more proper (and shorter) name would be `RandomRDDs`. dorx brkyvz Author: Xiangrui Meng Closes #1979 from mengxr/randomrdds and squashes the following commits: b161a2d [Xiangrui Meng] rename RandomRDDGenerators to RandomRDDs --- .../mllib/api/python/PythonMLLibAPI.scala | 2 +- ...omRDDGenerators.scala => RandomRDDs.scala} | 6 ++--- ...atorsSuite.scala => RandomRDDsSuite.scala} | 16 ++++++------ python/pyspark/mllib/random.py | 25 +++++++++---------- 4 files changed, 24 insertions(+), 25 deletions(-) rename mllib/src/main/scala/org/apache/spark/mllib/random/{RandomRDDGenerators.scala => RandomRDDs.scala} (99%) rename mllib/src/test/scala/org/apache/spark/mllib/random/{RandomRDDGeneratorsSuite.scala => RandomRDDsSuite.scala} (88%) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 18dc087856785..4343124f102a0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -27,7 +27,7 @@ import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.linalg.{Matrix, SparseVector, Vector, Vectors} -import org.apache.spark.mllib.random.{RandomRDDGenerators => RG} +import org.apache.spark.mllib.random.{RandomRDDs => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala similarity index 99% rename from mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala rename to mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index b0a0593223910..36270369526cd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.random +import scala.reflect.ClassTag + import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.Vector @@ -24,14 +26,12 @@ import org.apache.spark.mllib.rdd.{RandomVectorRDD, RandomRDD} import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils -import scala.reflect.ClassTag - /** * :: Experimental :: * Generator methods for creating RDDs comprised of i.i.d. samples from some distribution. */ @Experimental -object RandomRDDGenerators { +object RandomRDDs { /** * :: Experimental :: diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala similarity index 88% rename from mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala rename to mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala index 96e0bc63b0fa4..c50b78bcbcc61 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.util.StatCounter * * TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged */ -class RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Serializable { +class RandomRDDsSuite extends FunSuite with LocalSparkContext with Serializable { def testGeneratedRDD(rdd: RDD[Double], expectedSize: Long, @@ -113,18 +113,18 @@ class RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Seri val poissonMean = 100.0 for (seed <- 0 until 5) { - val uniform = RandomRDDGenerators.uniformRDD(sc, size, numPartitions, seed) + val uniform = RandomRDDs.uniformRDD(sc, size, numPartitions, seed) testGeneratedRDD(uniform, size, numPartitions, 0.5, 1 / math.sqrt(12)) - val normal = RandomRDDGenerators.normalRDD(sc, size, numPartitions, seed) + val normal = RandomRDDs.normalRDD(sc, size, numPartitions, seed) testGeneratedRDD(normal, size, numPartitions, 0.0, 1.0) - val poisson = RandomRDDGenerators.poissonRDD(sc, poissonMean, size, numPartitions, seed) + val poisson = RandomRDDs.poissonRDD(sc, poissonMean, size, numPartitions, seed) testGeneratedRDD(poisson, size, numPartitions, poissonMean, math.sqrt(poissonMean), 0.1) } // mock distribution to check that partitions have unique seeds - val random = RandomRDDGenerators.randomRDD(sc, new MockDistro(), 1000L, 1000, 0L) + val random = RandomRDDs.randomRDD(sc, new MockDistro(), 1000L, 1000, 0L) assert(random.collect.size === random.collect.distinct.size) } @@ -135,13 +135,13 @@ class RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Seri val poissonMean = 100.0 for (seed <- 0 until 5) { - val uniform = RandomRDDGenerators.uniformVectorRDD(sc, rows, cols, parts, seed) + val uniform = RandomRDDs.uniformVectorRDD(sc, rows, cols, parts, seed) testGeneratedVectorRDD(uniform, rows, cols, parts, 0.5, 1 / math.sqrt(12)) - val normal = RandomRDDGenerators.normalVectorRDD(sc, rows, cols, parts, seed) + val normal = RandomRDDs.normalVectorRDD(sc, rows, cols, parts, seed) testGeneratedVectorRDD(normal, rows, cols, parts, 0.0, 1.0) - val poisson = RandomRDDGenerators.poissonVectorRDD(sc, poissonMean, rows, cols, parts, seed) + val poisson = RandomRDDs.poissonVectorRDD(sc, poissonMean, rows, cols, parts, seed) testGeneratedVectorRDD(poisson, rows, cols, parts, poissonMean, math.sqrt(poissonMean), 0.1) } } diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index eb496688b6eef..3f3b19053d32e 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -25,8 +25,7 @@ from pyspark.serializers import NoOpSerializer -class RandomRDDGenerators: - +class RandomRDDs: """ Generator methods for creating RDDs comprised of i.i.d samples from some distribution. @@ -40,17 +39,17 @@ def uniformRDD(sc, size, numPartitions=None, seed=None): To transform the distribution in the generated RDD from U[0.0, 1.0] to U[a, b], use - C{RandomRDDGenerators.uniformRDD(sc, n, p, seed)\ + C{RandomRDDs.uniformRDD(sc, n, p, seed)\ .map(lambda v: a + (b - a) * v)} - >>> x = RandomRDDGenerators.uniformRDD(sc, 100).collect() + >>> x = RandomRDDs.uniformRDD(sc, 100).collect() >>> len(x) 100 >>> max(x) <= 1.0 and min(x) >= 0.0 True - >>> RandomRDDGenerators.uniformRDD(sc, 100, 4).getNumPartitions() + >>> RandomRDDs.uniformRDD(sc, 100, 4).getNumPartitions() 4 - >>> parts = RandomRDDGenerators.uniformRDD(sc, 100, seed=4).getNumPartitions() + >>> parts = RandomRDDs.uniformRDD(sc, 100, seed=4).getNumPartitions() >>> parts == sc.defaultParallelism True """ @@ -66,10 +65,10 @@ def normalRDD(sc, size, numPartitions=None, seed=None): To transform the distribution in the generated RDD from standard normal to some other normal N(mean, sigma), use - C{RandomRDDGenerators.normal(sc, n, p, seed)\ + C{RandomRDDs.normal(sc, n, p, seed)\ .map(lambda v: mean + sigma * v)} - >>> x = RandomRDDGenerators.normalRDD(sc, 1000, seed=1L) + >>> x = RandomRDDs.normalRDD(sc, 1000, seed=1L) >>> stats = x.stats() >>> stats.count() 1000L @@ -89,7 +88,7 @@ def poissonRDD(sc, mean, size, numPartitions=None, seed=None): distribution with the input mean. >>> mean = 100.0 - >>> x = RandomRDDGenerators.poissonRDD(sc, mean, 1000, seed=1L) + >>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=1L) >>> stats = x.stats() >>> stats.count() 1000L @@ -110,12 +109,12 @@ def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): from the uniform distribution on [0.0 1.0]. >>> import numpy as np - >>> mat = np.matrix(RandomRDDGenerators.uniformVectorRDD(sc, 10, 10).collect()) + >>> mat = np.matrix(RandomRDDs.uniformVectorRDD(sc, 10, 10).collect()) >>> mat.shape (10, 10) >>> mat.max() <= 1.0 and mat.min() >= 0.0 True - >>> RandomRDDGenerators.uniformVectorRDD(sc, 10, 10, 4).getNumPartitions() + >>> RandomRDDs.uniformVectorRDD(sc, 10, 10, 4).getNumPartitions() 4 """ jrdd = sc._jvm.PythonMLLibAPI() \ @@ -130,7 +129,7 @@ def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): from the standard normal distribution. >>> import numpy as np - >>> mat = np.matrix(RandomRDDGenerators.normalVectorRDD(sc, 100, 100, seed=1L).collect()) + >>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1L).collect()) >>> mat.shape (100, 100) >>> abs(mat.mean() - 0.0) < 0.1 @@ -151,7 +150,7 @@ def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): >>> import numpy as np >>> mean = 100.0 - >>> rdd = RandomRDDGenerators.poissonVectorRDD(sc, mean, 100, 100, seed=1L) + >>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1L) >>> mat = np.mat(rdd.collect()) >>> mat.shape (100, 100) From 379e7585c356f20bf8b4878ecba9401e2195da12 Mon Sep 17 00:00:00 2001 From: iAmGhost Date: Sat, 16 Aug 2014 16:48:38 -0700 Subject: [PATCH 14/54] [SPARK-3035] Wrong example with SparkContext.addFile https://issues.apache.org/jira/browse/SPARK-3035 fix for wrong document. Author: iAmGhost Closes #1942 from iAmGhost/master and squashes the following commits: 487528a [iAmGhost] [SPARK-3035] Wrong example with SparkContext.addFile fix for wrong document. --- python/pyspark/context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 4001ecab5ea00..6c049238819a7 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -613,7 +613,7 @@ def addFile(self, path): >>> def func(iterator): ... with open(SparkFiles.get("test.txt")) as testFile: ... fileVal = int(testFile.readline()) - ... return [x * 100 for x in iterator] + ... return [x * fileVal for x in iterator] >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect() [100, 200, 300, 400] """ From 2fc8aca086a2679b854038b7e2c488f19039ecbd Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 16 Aug 2014 16:59:34 -0700 Subject: [PATCH 15/54] [SPARK-1065] [PySpark] improve supporting for large broadcast Passing large object by py4j is very slow (cost much memory), so pass broadcast objects via files (similar to parallelize()). Add an option to keep object in driver (it's False by default) to save memory in driver. Author: Davies Liu Closes #1912 from davies/broadcast and squashes the following commits: e06df4a [Davies Liu] load broadcast from disk in driver automatically db3f232 [Davies Liu] fix serialization of accumulator 631a827 [Davies Liu] Merge branch 'master' into broadcast c7baa8c [Davies Liu] compress serrialized broadcast and command 9a7161f [Davies Liu] fix doc tests e93cf4b [Davies Liu] address comments: add test 6226189 [Davies Liu] improve large broadcast --- .../apache/spark/api/python/PythonRDD.scala | 8 ++++ python/pyspark/broadcast.py | 37 ++++++++++++++----- python/pyspark/context.py | 20 ++++++---- python/pyspark/rdd.py | 5 ++- python/pyspark/serializers.py | 17 +++++++++ python/pyspark/tests.py | 7 ++++ python/pyspark/worker.py | 8 ++-- 7 files changed, 81 insertions(+), 21 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 9f5c5bd30f0c9..10210a2927dcc 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -315,6 +315,14 @@ private[spark] object PythonRDD extends Logging { JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) } + def readBroadcastFromFile(sc: JavaSparkContext, filename: String): Broadcast[Array[Byte]] = { + val file = new DataInputStream(new FileInputStream(filename)) + val length = file.readInt() + val obj = new Array[Byte](length) + file.readFully(obj) + sc.broadcast(obj) + } + def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) { // The right way to implement this would be to use TypeTags to get the full // type of T. Since I don't want to introduce breaking changes throughout the diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index f3e64989ed564..675a2fcd2ff4e 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -21,18 +21,16 @@ >>> b = sc.broadcast([1, 2, 3, 4, 5]) >>> b.value [1, 2, 3, 4, 5] - ->>> from pyspark.broadcast import _broadcastRegistry ->>> _broadcastRegistry[b.bid] = b ->>> from cPickle import dumps, loads ->>> loads(dumps(b)).value -[1, 2, 3, 4, 5] - >>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect() [1, 2, 3, 4, 5, 1, 2, 3, 4, 5] +>>> b.unpersist() >>> large_broadcast = sc.broadcast(list(range(10000))) """ +import os + +from pyspark.serializers import CompressedSerializer, PickleSerializer + # Holds broadcasted data received from Java, keyed by its id. _broadcastRegistry = {} @@ -52,17 +50,38 @@ class Broadcast(object): Access its value through C{.value}. """ - def __init__(self, bid, value, java_broadcast=None, pickle_registry=None): + def __init__(self, bid, value, java_broadcast=None, + pickle_registry=None, path=None): """ Should not be called directly by users -- use L{SparkContext.broadcast()} instead. """ - self.value = value self.bid = bid + if path is None: + self.value = value self._jbroadcast = java_broadcast self._pickle_registry = pickle_registry + self.path = path + + def unpersist(self, blocking=False): + self._jbroadcast.unpersist(blocking) + os.unlink(self.path) def __reduce__(self): self._pickle_registry.add(self) return (_from_id, (self.bid, )) + + def __getattr__(self, item): + if item == 'value' and self.path is not None: + ser = CompressedSerializer(PickleSerializer()) + value = ser.load_stream(open(self.path)).next() + self.value = value + return value + + raise AttributeError(item) + + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 6c049238819a7..a90870ed3a353 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -29,7 +29,7 @@ from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ - PairDeserializer + PairDeserializer, CompressedSerializer from pyspark.storagelevel import StorageLevel from pyspark import rdd from pyspark.rdd import RDD @@ -566,13 +566,19 @@ def broadcast(self, value): """ Broadcast a read-only variable to the cluster, returning a L{Broadcast} - object for reading it in distributed functions. The variable will be - sent to each cluster only once. + object for reading it in distributed functions. The variable will + be sent to each cluster only once. + + :keep: Keep the `value` in driver or not. """ - pickleSer = PickleSerializer() - pickled = pickleSer.dumps(value) - jbroadcast = self._jsc.broadcast(bytearray(pickled)) - return Broadcast(jbroadcast.id(), value, jbroadcast, self._pickled_broadcast_vars) + ser = CompressedSerializer(PickleSerializer()) + # pass large object by py4j is very slow and need much memory + tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir) + ser.dump_stream([value], tempFile) + tempFile.close() + jbroadcast = self._jvm.PythonRDD.readBroadcastFromFile(self._jsc, tempFile.name) + return Broadcast(jbroadcast.id(), None, jbroadcast, + self._pickled_broadcast_vars, tempFile.name) def accumulator(self, value, accum_param=None): """ diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 3934bdda0a466..240381e5bae12 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -36,7 +36,7 @@ from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ - PickleSerializer, pack_long + PickleSerializer, pack_long, CompressedSerializer from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup from pyspark.statcounter import StatCounter @@ -1810,7 +1810,8 @@ def _jrdd(self): self._jrdd_deserializer = NoOpSerializer() command = (self.func, self._prev_jrdd_deserializer, self._jrdd_deserializer) - pickled_command = CloudPickleSerializer().dumps(command) + ser = CompressedSerializer(CloudPickleSerializer()) + pickled_command = ser.dumps(command) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], self.ctx._gateway._gateway_client) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index df90cafb245bf..74870c0edcf99 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -67,6 +67,7 @@ import sys import types import collections +import zlib from pyspark import cloudpickle @@ -403,6 +404,22 @@ def loads(self, obj): raise ValueError("invalid sevialization type: %s" % _type) +class CompressedSerializer(FramedSerializer): + """ + compress the serialized data + """ + + def __init__(self, serializer): + FramedSerializer.__init__(self) + self.serializer = serializer + + def dumps(self, obj): + return zlib.compress(self.serializer.dumps(obj), 1) + + def loads(self, obj): + return self.serializer.loads(zlib.decompress(obj)) + + class UTF8Deserializer(Serializer): """ diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 22b51110ed671..f1fece998cd54 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -323,6 +323,13 @@ def test_namedtuple_in_rdd(self): theDoes = self.sc.parallelize([jon, jane]) self.assertEquals([jon, jane], theDoes.collect()) + def test_large_broadcast(self): + N = 100000 + data = [[float(i) for i in range(300)] for i in range(N)] + bdata = self.sc.broadcast(data) # 270MB + m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() + self.assertEquals(N, m) + class TestIO(PySparkTestCase): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 2770f63059853..77a9c4a0e0677 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -30,7 +30,8 @@ from pyspark.cloudpickle import CloudPickler from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, write_int, read_long, \ - write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer + write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ + CompressedSerializer pickleSer = PickleSerializer() @@ -65,12 +66,13 @@ def main(infile, outfile): # fetch names and values of broadcast variables num_broadcast_variables = read_int(infile) + ser = CompressedSerializer(pickleSer) for _ in range(num_broadcast_variables): bid = read_long(infile) - value = pickleSer._read_with_length(infile) + value = ser._read_with_length(infile) _broadcastRegistry[bid] = Broadcast(bid, value) - command = pickleSer._read_with_length(infile) + command = ser._read_with_length(infile) (func, deserializer, serializer) = command init_time = time.time() iterator = deserializer.load_stream(infile) From bc95fe08dff62a0abea314ab4ab9275c8f119598 Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Sat, 16 Aug 2014 20:05:55 -0700 Subject: [PATCH 16/54] In the stop method of ConnectionManager to cancel the ackTimeoutMonitor cc JoshRosen sarutak Author: GuoQiang Li Closes #1989 from witgo/cancel_ackTimeoutMonitor and squashes the following commits: 4a700fa [GuoQiang Li] In the stop method of ConnectionManager to cancel the ackTimeoutMonitor --- .../main/scala/org/apache/spark/network/ConnectionManager.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 37d69a9ec4ce4..e77d762bdf221 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -886,6 +886,7 @@ private[spark] class ConnectionManager( } def stop() { + ackTimeoutMonitor.cancel() selectorThread.interrupt() selectorThread.join() selector.close() From fbad72288d8b6e641b00417a544cae6e8bfef2d7 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sat, 16 Aug 2014 21:16:27 -0700 Subject: [PATCH 17/54] [SPARK-3077][MLLIB] fix some chisq-test - promote nullHypothesis field in ChiSqTestResult to TestResult. Every test should have a null hypothesis - correct null hypothesis statement for independence test - p-value: 0.01 -> 0.1 Author: Xiangrui Meng Closes #1982 from mengxr/fix-chisq and squashes the following commits: 5f0de02 [Xiangrui Meng] make ChiSqTestResult constructor package private bc74ea1 [Xiangrui Meng] update chisq-test --- .../spark/mllib/stat/test/ChiSqTest.scala | 2 +- .../spark/mllib/stat/test/TestResult.scala | 28 +++++++++++-------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala index 8f6752737402e..215de95db5113 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala @@ -56,7 +56,7 @@ private[stat] object ChiSqTest extends Logging { object NullHypothesis extends Enumeration { type NullHypothesis = Value val goodnessOfFit = Value("observed follows the same distribution as expected.") - val independence = Value("observations in each column are statistically independent.") + val independence = Value("the occurrence of the outcomes is statistically independent.") } // Method identification based on input methodName string diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala index 2f278621335e1..4784f9e947908 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala @@ -44,6 +44,11 @@ trait TestResult[DF] { */ def statistic: Double + /** + * Null hypothesis of the test. + */ + def nullHypothesis: String + /** * String explaining the hypothesis test result. * Specific classes implementing this trait should override this method to output test-specific @@ -53,13 +58,13 @@ trait TestResult[DF] { // String explaining what the p-value indicates. val pValueExplain = if (pValue <= 0.01) { - "Very strong presumption against null hypothesis." + s"Very strong presumption against null hypothesis: $nullHypothesis." } else if (0.01 < pValue && pValue <= 0.05) { - "Strong presumption against null hypothesis." - } else if (0.05 < pValue && pValue <= 0.01) { - "Low presumption against null hypothesis." + s"Strong presumption against null hypothesis: $nullHypothesis." + } else if (0.05 < pValue && pValue <= 0.1) { + s"Low presumption against null hypothesis: $nullHypothesis." } else { - "No presumption against null hypothesis." + s"No presumption against null hypothesis: $nullHypothesis." } s"degrees of freedom = ${degreesOfFreedom.toString} \n" + @@ -70,19 +75,18 @@ trait TestResult[DF] { /** * :: Experimental :: - * Object containing the test results for the chi squared hypothesis test. + * Object containing the test results for the chi-squared hypothesis test. */ @Experimental -class ChiSqTestResult(override val pValue: Double, +class ChiSqTestResult private[stat] (override val pValue: Double, override val degreesOfFreedom: Int, override val statistic: Double, val method: String, - val nullHypothesis: String) extends TestResult[Int] { + override val nullHypothesis: String) extends TestResult[Int] { override def toString: String = { - "Chi squared test summary: \n" + - s"method: $method \n" + - s"null hypothesis: $nullHypothesis \n" + - super.toString + "Chi squared test summary:\n" + + s"method: $method\n" + + super.toString } } From 73ab7f141c205df277c6ac19252e590d6806c41f Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Sat, 16 Aug 2014 23:53:14 -0700 Subject: [PATCH 18/54] [SPARK-3042] [mllib] DecisionTree Filter top-down instead of bottom-up DecisionTree needs to match each example to a node at each iteration. It currently does this with a set of filters very inefficiently: For each example, it examines each node at the current level and traces up to the root to see if that example should be handled by that node. Fix: Filter top-down using the partly built tree itself. Major changes: * Eliminated Filter class, findBinsForLevel() method. * Set up node parent links in main loop over levels in train(). * Added predictNodeIndex() for filtering top-down. * Added DTMetadata class Other changes: * Pre-compute set of unorderedFeatures. Notes for following expected PR based on [https://issues.apache.org/jira/browse/SPARK-3043]: * The unorderedFeatures set will next be stored in a metadata structure to simplify function calls (to store other items such as the data in strategy). I've done initial tests indicating that this speeds things up, but am only now running large-scale ones. CC: mengxr manishamde chouqin Any comments are welcome---thanks! Author: Joseph K. Bradley Closes #1975 from jkbradley/dt-opt2 and squashes the following commits: a0ed0da [Joseph K. Bradley] Renamed DTMetadata to DecisionTreeMetadata. Small doc updates. 3726d20 [Joseph K. Bradley] Small code improvements based on code review. ac0b9f8 [Joseph K. Bradley] Small updates based on code review. Main change: Now using << instead of math.pow. db0d773 [Joseph K. Bradley] scala style fix 6a38f48 [Joseph K. Bradley] Added DTMetadata class for cleaner code 931a3a7 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt2 797f68a [Joseph K. Bradley] Fixed DecisionTreeSuite bug for training second level. Needed to update treePointToNodeIndex with groupShift. f40381c [Joseph K. Bradley] Merge branch 'dt-opt1' into dt-opt2 5f2dec2 [Joseph K. Bradley] Fixed scalastyle issue in TreePoint 6b5651e [Joseph K. Bradley] Updates based on code review. 1 major change: persisting to memory + disk, not just memory. 2d2aaaf [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1 26d10dd [Joseph K. Bradley] Removed tree/model/Filter.scala since no longer used. Removed debugging println calls in DecisionTree.scala. 356daba [Joseph K. Bradley] Merge branch 'dt-opt1' into dt-opt2 430d782 [Joseph K. Bradley] Added more debug info on binning error. Added some docs. d036089 [Joseph K. Bradley] Print timing info to logDebug. e66f1b1 [Joseph K. Bradley] TreePoint * Updated doc * Made some methods private 8464a6e [Joseph K. Bradley] Moved TimeTracker to tree/impl/ in its own file, and cleaned it up. Removed debugging println calls from DecisionTree. Made TreePoint extend Serialiable a87e08f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1 c1565a5 [Joseph K. Bradley] Small DecisionTree updates: * Simplification: Updated calculateGainForSplit to take aggregates for a single (feature, split) pair. * Internal doc: findAggForOrderedFeatureClassification b914f3b [Joseph K. Bradley] DecisionTree optimization: eliminated filters + small changes b2ed1f3 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt 0f676e2 [Joseph K. Bradley] Optimizations + Bug fix for DecisionTree 3211f02 [Joseph K. Bradley] Optimizing DecisionTree * Added TreePoint representation to avoid calling findBin multiple times. * (not working yet, but debugging) f61e9d2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing bcf874a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing 511ec85 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing a95bc22 [Joseph K. Bradley] timing for DecisionTree internals --- .../spark/mllib/tree/DecisionTree.scala | 878 ++++++++---------- .../tree/impl/DecisionTreeMetadata.scala | 101 ++ .../spark/mllib/tree/impl/TreePoint.scala | 30 +- .../apache/spark/mllib/tree/model/Bin.scala | 18 +- .../mllib/tree/model/DecisionTreeModel.scala | 2 +- .../spark/mllib/tree/model/Filter.scala | 28 - .../apache/spark/mllib/tree/model/Node.scala | 16 +- .../apache/spark/mllib/tree/model/Split.scala | 5 +- .../spark/mllib/tree/DecisionTreeSuite.scala | 167 ++-- 9 files changed, 615 insertions(+), 630 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 2a3107a13e916..6b9a8f72c244e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -27,7 +27,7 @@ import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.impl.{TimeTracker, TreePoint} +import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TimeTracker, TreePoint} import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity} import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD @@ -62,43 +62,38 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo timer.start("init") val retaggedInput = input.retag(classOf[LabeledPoint]) + val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, strategy) logDebug("algo = " + strategy.algo) // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. timer.start("findSplitsBins") - val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata) val numBins = bins(0).length timer.stop("findSplitsBins") logDebug("numBins = " + numBins) + // Bin feature values (TreePoint representation). // Cache input RDD for speedup during multiple passes. - val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins) + val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata) .persist(StorageLevel.MEMORY_AND_DISK) + val numFeatures = metadata.numFeatures // depth of the decision tree val maxDepth = strategy.maxDepth // the max number of nodes possible given the depth of the tree - val maxNumNodes = math.pow(2, maxDepth + 1).toInt - 1 - // Initialize an array to hold filters applied to points for each node. - val filters = new Array[List[Filter]](maxNumNodes) - // The filter at the top node is an empty list. - filters(0) = List() + val maxNumNodes = (2 << maxDepth) - 1 // Initialize an array to hold parent impurity calculations for each node. val parentImpurities = new Array[Double](maxNumNodes) // dummy value for top node (updated during first split calculation) val nodes = new Array[Node](maxNumNodes) - // num features - val numFeatures = treeInput.take(1)(0).binnedFeatures.size // Calculate level for single group construction // Max memory usage for aggregates val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024 logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") - val numElementsPerNode = DecisionTree.getElementsPerNode(numFeatures, numBins, - strategy.numClassesForClassification, strategy.isMulticlassWithCategoricalFeatures, - strategy.algo) + val numElementsPerNode = DecisionTree.getElementsPerNode(metadata, numBins) logDebug("numElementsPerNode = " + numElementsPerNode) val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array @@ -114,9 +109,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo /* * The main idea here is to perform level-wise training of the decision tree nodes thus * reducing the passes over the data from l to log2(l) where l is the total number of nodes. - * Each data sample is checked for validity w.r.t to each node at a given level -- i.e., - * the sample is only used for the split calculation at the node if the sampled would have - * still survived the filters of the parent nodes. + * Each data sample is handled by a particular node at that level (or it reaches a leaf + * beforehand and is not used in later levels. */ var level = 0 @@ -130,22 +124,37 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Find best split for all nodes at a level. timer.start("findBestSplits") val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities, - strategy, level, filters, splits, bins, maxLevelForSingleGroup, timer) + metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer) timer.stop("findBestSplits") + val levelNodeIndexOffset = (1 << level) - 1 for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { + val nodeIndex = levelNodeIndexOffset + index + val isLeftChild = level != 0 && nodeIndex % 2 == 1 + val parentNodeIndex = if (isLeftChild) { // -1 for root node + (nodeIndex - 1) / 2 + } else { + (nodeIndex - 2) / 2 + } + // Extract info for this node (index) at the current level. timer.start("extractNodeInfo") - // Extract info for nodes at the current level. extractNodeInfo(nodeSplitStats, level, index, nodes) timer.stop("extractNodeInfo") - timer.start("extractInfoForLowerLevels") + if (level != 0) { + // Set parent. + if (isLeftChild) { + nodes(parentNodeIndex).leftNode = Some(nodes(nodeIndex)) + } else { + nodes(parentNodeIndex).rightNode = Some(nodes(nodeIndex)) + } + } // Extract info for nodes at the next lower level. - extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, - filters) + timer.start("extractInfoForLowerLevels") + extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities) timer.stop("extractInfoForLowerLevels") logDebug("final best split = " + nodeSplitStats._1) } - require(math.pow(2, level) == splitsStatsForLevel.length) + require((1 << level) == splitsStatsForLevel.length) // Check whether all the nodes at the current level at leaves. val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) logDebug("all leaf = " + allLeaf) @@ -183,7 +192,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo nodes: Array[Node]): Unit = { val split = nodeSplitStats._1 val stats = nodeSplitStats._2 - val nodeIndex = math.pow(2, level).toInt - 1 + index + val nodeIndex = (1 << level) - 1 + index val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth) val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) logDebug("Node = " + node) @@ -198,31 +207,21 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo index: Int, maxDepth: Int, nodeSplitStats: (Split, InformationGainStats), - parentImpurities: Array[Double], - filters: Array[List[Filter]]): Unit = { - // 0 corresponds to the left child node and 1 corresponds to the right child node. - var i = 0 - while (i <= 1) { - // Calculate the index of the node from the node level and the index at the current level. - val nodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index + i - if (level < maxDepth) { - val impurity = if (i == 0) { - nodeSplitStats._2.leftImpurity - } else { - nodeSplitStats._2.rightImpurity - } - logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity) - // noting the parent impurities - parentImpurities(nodeIndex) = impurity - // noting the parents filters for the child nodes - val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1) - filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2) - for (filter <- filters(nodeIndex)) { - logDebug("Filter = " + filter) - } - } - i += 1 + parentImpurities: Array[Double]): Unit = { + + if (level >= maxDepth) { + return } + + val leftNodeIndex = (2 << level) - 1 + 2 * index + val leftImpurity = nodeSplitStats._2.leftImpurity + logDebug("leftNodeIndex = " + leftNodeIndex + ", impurity = " + leftImpurity) + parentImpurities(leftNodeIndex) = leftImpurity + + val rightNodeIndex = leftNodeIndex + 1 + val rightImpurity = nodeSplitStats._2.rightImpurity + logDebug("rightNodeIndex = " + rightNodeIndex + ", impurity = " + rightImpurity) + parentImpurities(rightNodeIndex) = rightImpurity } } @@ -434,10 +433,8 @@ object DecisionTree extends Serializable with Logging { * * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] * @param parentImpurities Impurities for all parent nodes for the current level - * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for constructing the DecisionTree + * @param metadata Learning and dataset metadata * @param level Level of the tree - * @param filters Filters for all nodes at a given level * @param splits possible splits for all features * @param bins possible bins for all features * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation. @@ -446,9 +443,9 @@ object DecisionTree extends Serializable with Logging { protected[tree] def findBestSplits( input: RDD[TreePoint], parentImpurities: Array[Double], - strategy: Strategy, + metadata: DecisionTreeMetadata, level: Int, - filters: Array[List[Filter]], + nodes: Array[Node], splits: Array[Array[Split]], bins: Array[Array[Bin]], maxLevelForSingleGroup: Int, @@ -459,34 +456,32 @@ object DecisionTree extends Serializable with Logging { // the nodes are divided into multiple groups at each level with the number of groups // increasing exponentially per level. For example, if maxLevelForSingleGroup is 10, // numGroups is equal to 2 at level 11 and 4 at level 12, respectively. - val numGroups = math.pow(2, level - maxLevelForSingleGroup).toInt + val numGroups = 1 << level - maxLevelForSingleGroup logDebug("numGroups = " + numGroups) var bestSplits = new Array[(Split, InformationGainStats)](0) // Iterate over each group of nodes at a level. var groupIndex = 0 while (groupIndex < numGroups) { - val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level, - filters, splits, bins, timer, numGroups, groupIndex) + val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, metadata, level, + nodes, splits, bins, timer, numGroups, groupIndex) bestSplits = Array.concat(bestSplits, bestSplitsForGroup) groupIndex += 1 } bestSplits } else { - findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins, timer) + findBestSplitsPerGroup(input, parentImpurities, metadata, level, nodes, splits, bins, timer) } } - /** + /** * Returns an array of optimal splits for a group of nodes at a given level * * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] * @param parentImpurities Impurities for all parent nodes for the current level - * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for constructing the DecisionTree + * @param metadata Learning and dataset metadata * @param level Level of the tree - * @param filters Filters for all nodes at a given level * @param splits possible splits for all features - * @param bins possible bins for all features + * @param bins possible bins for all features, indexed as (numFeatures)(numBins) * @param numGroups total number of node groups at the current level. Default value is set to 1. * @param groupIndex index of the node group being processed. Default value is set to 0. * @return array of splits with best splits for all nodes at a given level. @@ -494,9 +489,9 @@ object DecisionTree extends Serializable with Logging { private def findBestSplitsPerGroup( input: RDD[TreePoint], parentImpurities: Array[Double], - strategy: Strategy, + metadata: DecisionTreeMetadata, level: Int, - filters: Array[List[Filter]], + nodes: Array[Node], splits: Array[Array[Split]], bins: Array[Array[Bin]], timer: TimeTracker, @@ -515,7 +510,7 @@ object DecisionTree extends Serializable with Logging { * We use a bin-wise best split computation strategy instead of a straightforward best split * computation strategy. Instead of analyzing each sample for contribution to the left/right * child node impurity of every split, we first categorize each feature of a sample into a - * bin. Each bin is an interval between a low and high split. Since each splits, and thus bin, + * bin. Each bin is an interval between a low and high split. Since each split, and thus bin, * is ordered (read ordering for categorical variables in the findSplitsBins method), * we exploit this structure to calculate aggregates for bins and then use these aggregates * to calculate information gain for each split. @@ -531,160 +526,124 @@ object DecisionTree extends Serializable with Logging { // numNodes: Number of nodes in this (level of tree, group), // where nodes at deeper (larger) levels may be divided into groups. - val numNodes = math.pow(2, level).toInt / numGroups + val numNodes = (1 << level) / numGroups logDebug("numNodes = " + numNodes) // Find the number of features by looking at the first sample. - val numFeatures = input.first().binnedFeatures.size + val numFeatures = metadata.numFeatures logDebug("numFeatures = " + numFeatures) // numBins: Number of bins = 1 + number of possible splits val numBins = bins(0).length logDebug("numBins = " + numBins) - val numClasses = strategy.numClassesForClassification + val numClasses = metadata.numClasses logDebug("numClasses = " + numClasses) - val isMulticlassClassification = strategy.isMulticlassClassification - logDebug("isMulticlassClassification = " + isMulticlassClassification) + val isMulticlass = metadata.isMulticlass + logDebug("isMulticlass = " + isMulticlass) - val isMulticlassClassificationWithCategoricalFeatures - = strategy.isMulticlassWithCategoricalFeatures - logDebug("isMultiClassWithCategoricalFeatures = " + - isMulticlassClassificationWithCategoricalFeatures) + val isMulticlassWithCategoricalFeatures = metadata.isMulticlassWithCategoricalFeatures + logDebug("isMultiClassWithCategoricalFeatures = " + isMulticlassWithCategoricalFeatures) // shift when more than one group is used at deep tree level val groupShift = numNodes * groupIndex - /** Find the filters used before reaching the current code. */ - def findParentFilters(nodeIndex: Int): List[Filter] = { - if (level == 0) { - List[Filter]() - } else { - val nodeFilterIndex = math.pow(2, level).toInt - 1 + nodeIndex + groupShift - filters(nodeFilterIndex) - } - } - /** - * Find whether the sample is valid input for the current node, i.e., whether it passes through - * all the filters for the current node. + * Get the node index corresponding to this data point. + * This function mimics prediction, passing an example from the root node down to a node + * at the current level being trained; that node's index is returned. + * + * @return Leaf index if the data point reaches a leaf. + * Otherwise, last node reachable in tree matching this example. */ - def isSampleValid(parentFilters: List[Filter], treePoint: TreePoint): Boolean = { - // leaf - if ((level > 0) && (parentFilters.length == 0)) { - return false - } - - // Apply each filter and check sample validity. Return false when invalid condition found. - parentFilters.foreach { filter => - val featureIndex = filter.split.feature - val comparison = filter.comparison - val isFeatureContinuous = filter.split.featureType == Continuous - if (isFeatureContinuous) { - val binId = treePoint.binnedFeatures(featureIndex) - val bin = bins(featureIndex)(binId) - val featureValue = bin.highSplit.threshold - val threshold = filter.split.threshold - comparison match { - case -1 => if (featureValue > threshold) return false - case 1 => if (featureValue <= threshold) return false + def predictNodeIndex(node: Node, binnedFeatures: Array[Int]): Int = { + if (node.isLeaf) { + node.id + } else { + val featureIndex = node.split.get.feature + val splitLeft = node.split.get.featureType match { + case Continuous => { + val binIndex = binnedFeatures(featureIndex) + val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold + // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold] + // We do not need to check lowSplit since bins are separated by splits. + featureValueUpperBound <= node.split.get.threshold } - } else { - val numFeatureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits = - numBins > math.pow(2, numFeatureCategories.toInt - 1) - 1 - val isUnorderedFeature = - isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits - val featureValue = if (isUnorderedFeature) { - treePoint.binnedFeatures(featureIndex) + case Categorical => { + val featureValue = if (metadata.isUnordered(featureIndex)) { + binnedFeatures(featureIndex) + } else { + val binIndex = binnedFeatures(featureIndex) + bins(featureIndex)(binIndex).category + } + node.split.get.categories.contains(featureValue) + } + case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.") + } + if (node.leftNode.isEmpty || node.rightNode.isEmpty) { + // Return index from next layer of nodes to train + if (splitLeft) { + node.id * 2 + 1 // left } else { - val binId = treePoint.binnedFeatures(featureIndex) - bins(featureIndex)(binId).category + node.id * 2 + 2 // right } - val containsFeature = filter.split.categories.contains(featureValue) - comparison match { - case -1 => if (!containsFeature) return false - case 1 => if (containsFeature) return false + } else { + if (splitLeft) { + predictNodeIndex(node.leftNode.get, binnedFeatures) + } else { + predictNodeIndex(node.rightNode.get, binnedFeatures) } } } + } - // Return true when the sample is valid for all filters. - true + def nodeIndexToLevel(idx: Int): Int = { + if (idx == 0) { + 0 + } else { + math.floor(math.log(idx) / math.log(2)).toInt + } } + // Used for treePointToNodeIndex + val levelOffset = (1 << level) - 1 + /** - * Finds bins for all nodes (and all features) at a given level. - * For l nodes, k features the storage is as follows: - * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk, - * where b_ij is an integer between 0 and numBins - 1 for regressions and binary - * classification and the categorical feature value in multiclass classification. - * Invalid sample is denoted by noting bin for feature 1 as -1. - * - * For unordered features, the "bin index" returned is actually the feature value (category). - * - * @return Array of size 1 + numFeatures * numNodes, where - * arr(0) = label for labeledPoint, and - * arr(1 + numFeatures * nodeIndex + featureIndex) = - * bin index for this labeledPoint - * (or InvalidBinIndex if labeledPoint is not handled by this node) + * Find the node index for the given example. + * Nodes are indexed from 0 at the start of this (level, group). + * If the example does not reach this level, returns a value < 0. */ - def findBinsForLevel(treePoint: TreePoint): Array[Double] = { - // Calculate bin index and label per feature per node. - val arr = new Array[Double](1 + (numFeatures * numNodes)) - // First element of the array is the label of the instance. - arr(0) = treePoint.label - // Iterate over nodes. - var nodeIndex = 0 - while (nodeIndex < numNodes) { - val parentFilters = findParentFilters(nodeIndex) - // Find out whether the sample qualifies for the particular node. - val sampleValid = isSampleValid(parentFilters, treePoint) - val shift = 1 + numFeatures * nodeIndex - if (!sampleValid) { - // Mark one bin as -1 is sufficient. - arr(shift) = InvalidBinIndex - } else { - var featureIndex = 0 - while (featureIndex < numFeatures) { - arr(shift + featureIndex) = treePoint.binnedFeatures(featureIndex) - featureIndex += 1 - } - } - nodeIndex += 1 + def treePointToNodeIndex(treePoint: TreePoint): Int = { + if (level == 0) { + 0 + } else { + val globalNodeIndex = predictNodeIndex(nodes(0), treePoint.binnedFeatures) + // Get index for this (level, group). + globalNodeIndex - levelOffset - groupShift } - arr } - // Find feature bins for all nodes at a level. - timer.start("aggregation") - val binMappedRDD = input.map(x => findBinsForLevel(x)) - /** * Increment aggregate in location for (node, feature, bin, label). * - * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. - * Array of size 1 + (numFeatures * numNodes). + * @param treePoint Data point being aggregated. * @param agg Array storing aggregate calculation, of size: * numClasses * numBins * numFeatures * numNodes. * Indexed by (node, feature, bin, label) where label is the least significant bit. + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). */ def updateBinForOrderedFeature( - arr: Array[Double], + treePoint: TreePoint, agg: Array[Double], nodeIndex: Int, - label: Double, featureIndex: Int): Unit = { - // Find the bin index for this feature. - val arrShift = 1 + numFeatures * nodeIndex - val arrIndex = arrShift + featureIndex // Update the left or right count for one bin. val aggIndex = numClasses * numBins * numFeatures * nodeIndex + numClasses * numBins * featureIndex + - numClasses * arr(arrIndex).toInt + - label.toInt + numClasses * treePoint.binnedFeatures(featureIndex) + + treePoint.label.toInt agg(aggIndex) += 1 } @@ -693,8 +652,8 @@ object DecisionTree extends Serializable with Logging { * where [bins] ranges over all bins. * Updates left or right side of aggregate depending on split. * - * @param arr arr(0) = label. - * arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category) + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). + * @param treePoint Data point being aggregated. * @param agg Indexed by (left/right, node, feature, bin, label) * where label is the least significant bit. * The left/right specifier is a 0/1 index indicating left/right child info. @@ -703,21 +662,18 @@ object DecisionTree extends Serializable with Logging { def updateBinForUnorderedFeature( nodeIndex: Int, featureIndex: Int, - arr: Array[Double], - label: Double, + treePoint: TreePoint, agg: Array[Double], rightChildShift: Int): Unit = { - // Find the bin index for this feature. - val arrIndex = 1 + numFeatures * nodeIndex + featureIndex - val featureValue = arr(arrIndex).toInt + val featureValue = treePoint.binnedFeatures(featureIndex) // Update the left or right count for one bin. val aggShift = numClasses * numBins * numFeatures * nodeIndex + numClasses * numBins * featureIndex + - label.toInt + treePoint.label.toInt // Find all matching bins and increment their values - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 + val featureCategories = metadata.featureArity(featureIndex) + val numCategoricalBins = (1 << featureCategories - 1) - 1 var binIndex = 0 while (binIndex < numCategoricalBins) { val aggIndex = aggShift + binIndex * numClasses @@ -733,30 +689,21 @@ object DecisionTree extends Serializable with Logging { /** * Helper for binSeqOp. * - * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. - * Array of size 1 + (numFeatures * numNodes). * @param agg Array storing aggregate calculation, of size: * numClasses * numBins * numFeatures * numNodes. * Indexed by (node, feature, bin, label) where label is the least significant bit. + * @param treePoint Data point being aggregated. + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). */ - def binaryOrNotCategoricalBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { - // Iterate over all nodes. - var nodeIndex = 0 - while (nodeIndex < numNodes) { - // Check whether the instance was valid for this nodeIndex. - val validSignalIndex = 1 + numFeatures * nodeIndex - val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex - if (isSampleValidForNode) { - // actual class label - val label = arr(0) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) - featureIndex += 1 - } - } - nodeIndex += 1 + def binaryOrNotCategoricalBinSeqOp( + agg: Array[Double], + treePoint: TreePoint, + nodeIndex: Int): Unit = { + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex) + featureIndex += 1 } } @@ -765,49 +712,28 @@ object DecisionTree extends Serializable with Logging { /** * Helper for binSeqOp. * - * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. - * Array of size 1 + (numFeatures * numNodes). - * For ordered features, - * arr(1 + featureIndex + nodeIndex * numFeatures) = bin index. - * For unordered features, - * arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category). * @param agg Array storing aggregate calculation. * For ordered features, this is of size: * numClasses * numBins * numFeatures * numNodes. * For unordered features, this is of size: * 2 * numClasses * numBins * numFeatures * numNodes. + * @param treePoint Data point being aggregated. + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). */ - def multiclassWithCategoricalBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { - // Iterate over all nodes. - var nodeIndex = 0 - while (nodeIndex < numNodes) { - // Check whether the instance was valid for this nodeIndex. - val validSignalIndex = 1 + numFeatures * nodeIndex - val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex - if (isSampleValidForNode) { - // actual class label - val label = arr(0) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinuous) { - updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) - } else { - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - if (isSpaceSufficientForAllCategoricalSplits) { - updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg, - rightChildShift) - } else { - updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) - } - } - featureIndex += 1 - } + def multiclassWithCategoricalBinSeqOp( + agg: Array[Double], + treePoint: TreePoint, + nodeIndex: Int): Unit = { + val label = treePoint.label + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + if (metadata.isUnordered(featureIndex)) { + updateBinForUnorderedFeature(nodeIndex, featureIndex, treePoint, agg, rightChildShift) + } else { + updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex) } - nodeIndex += 1 + featureIndex += 1 } } @@ -818,36 +744,25 @@ object DecisionTree extends Serializable with Logging { * * @param agg Array storing aggregate calculation, updated by this function. * Size: 3 * numBins * numFeatures * numNodes - * @param arr Bin mapping from findBinsForLevel. - * Array of size 1 + (numFeatures * numNodes). + * @param treePoint Data point being aggregated. + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). * @return agg */ - def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { - // Iterate over all nodes. - var nodeIndex = 0 - while (nodeIndex < numNodes) { - // Check whether the instance was valid for this nodeIndex. - val validSignalIndex = 1 + numFeatures * nodeIndex - val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex - if (isSampleValidForNode) { - // actual class label - val label = arr(0) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - // Find the bin index for this feature. - val arrShift = 1 + numFeatures * nodeIndex - val arrIndex = arrShift + featureIndex - // Update count, sum, and sum^2 for one bin. - val aggShift = 3 * numBins * numFeatures * nodeIndex - val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3 - agg(aggIndex) = agg(aggIndex) + 1 - agg(aggIndex + 1) = agg(aggIndex + 1) + label - agg(aggIndex + 2) = agg(aggIndex + 2) + label * label - featureIndex += 1 - } - } - nodeIndex += 1 + def regressionBinSeqOp(agg: Array[Double], treePoint: TreePoint, nodeIndex: Int): Unit = { + val label = treePoint.label + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + // Update count, sum, and sum^2 for one bin. + val binIndex = treePoint.binnedFeatures(featureIndex) + val aggIndex = + 3 * numBins * numFeatures * nodeIndex + + 3 * numBins * featureIndex + + 3 * binIndex + agg(aggIndex) += 1 + agg(aggIndex + 1) += label + agg(aggIndex + 2) += label * label + featureIndex += 1 } } @@ -866,26 +781,30 @@ object DecisionTree extends Serializable with Logging { * 2 * numClasses * numBins * numFeatures * numNodes for unordered features. * Size for regression: * 3 * numBins * numFeatures * numNodes. - * @param arr Bin mapping from findBinsForLevel. - * Array of size 1 + (numFeatures * numNodes). + * @param treePoint Data point being aggregated. * @return agg */ - def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = { - strategy.algo match { - case Classification => - if(isMulticlassClassificationWithCategoricalFeatures) { - multiclassWithCategoricalBinSeqOp(arr, agg) + def binSeqOp(agg: Array[Double], treePoint: TreePoint): Array[Double] = { + val nodeIndex = treePointToNodeIndex(treePoint) + // If the example does not reach this level, then nodeIndex < 0. + // If the example reaches this level but is handled in a different group, + // then either nodeIndex < 0 (previous group) or nodeIndex >= numNodes (later group). + if (nodeIndex >= 0 && nodeIndex < numNodes) { + if (metadata.isClassification) { + if (isMulticlassWithCategoricalFeatures) { + multiclassWithCategoricalBinSeqOp(agg, treePoint, nodeIndex) } else { - binaryOrNotCategoricalBinSeqOp(arr, agg) + binaryOrNotCategoricalBinSeqOp(agg, treePoint, nodeIndex) } - case Regression => regressionBinSeqOp(arr, agg) + } else { + regressionBinSeqOp(agg, treePoint, nodeIndex) + } } agg } // Calculate bin aggregate length for classification or regression. - val binAggregateLength = numNodes * getElementsPerNode(numFeatures, numBins, numClasses, - isMulticlassClassificationWithCategoricalFeatures, strategy.algo) + val binAggregateLength = numNodes * getElementsPerNode(metadata, numBins) logDebug("binAggregateLength = " + binAggregateLength) /** @@ -905,144 +824,134 @@ object DecisionTree extends Serializable with Logging { } // Calculate bin aggregates. + timer.start("aggregation") val binAggregates = { - binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp) + input.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp) } timer.stop("aggregation") logDebug("binAggregates.length = " + binAggregates.length) /** - * Calculates the information gain for all splits based upon left/right split aggregates. - * @param leftNodeAgg left node aggregates - * @param featureIndex feature index - * @param splitIndex split index - * @param rightNodeAgg right node aggregate + * Calculate the information gain for a given (feature, split) based upon left/right aggregates. + * @param leftNodeAgg left node aggregates for this (feature, split) + * @param rightNodeAgg right node aggregate for this (feature, split) * @param topImpurity impurity of the parent node * @return information gain and statistics for all splits */ def calculateGainForSplit( - leftNodeAgg: Array[Array[Array[Double]]], - featureIndex: Int, - splitIndex: Int, - rightNodeAgg: Array[Array[Array[Double]]], + leftNodeAgg: Array[Double], + rightNodeAgg: Array[Double], topImpurity: Double): InformationGainStats = { - strategy.algo match { - case Classification => - val leftCounts: Array[Double] = leftNodeAgg(featureIndex)(splitIndex) - val rightCounts: Array[Double] = rightNodeAgg(featureIndex)(splitIndex) - val leftTotalCount = leftCounts.sum - val rightTotalCount = rightCounts.sum - - val impurity = { - if (level > 0) { - topImpurity - } else { - // Calculate impurity for root node. - val rootNodeCounts = new Array[Double](numClasses) - var classIndex = 0 - while (classIndex < numClasses) { - rootNodeCounts(classIndex) = leftCounts(classIndex) + rightCounts(classIndex) - classIndex += 1 - } - strategy.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount) - } - } + if (metadata.isClassification) { + val leftTotalCount = leftNodeAgg.sum + val rightTotalCount = rightNodeAgg.sum - val totalCount = leftTotalCount + rightTotalCount - if (totalCount == 0) { - // Return arbitrary prediction. - return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) + val impurity = { + if (level > 0) { + topImpurity + } else { + // Calculate impurity for root node. + val rootNodeCounts = new Array[Double](numClasses) + var classIndex = 0 + while (classIndex < numClasses) { + rootNodeCounts(classIndex) = leftNodeAgg(classIndex) + rightNodeAgg(classIndex) + classIndex += 1 + } + metadata.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount) } + } - // Sum of count for each label - val leftRightCounts: Array[Double] = - leftCounts.zip(rightCounts).map { case (leftCount, rightCount) => - leftCount + rightCount - } + val totalCount = leftTotalCount + rightTotalCount + if (totalCount == 0) { + // Return arbitrary prediction. + return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) + } - def indexOfLargestArrayElement(array: Array[Double]): Int = { - val result = array.foldLeft(-1, Double.MinValue, 0) { - case ((maxIndex, maxValue, currentIndex), currentValue) => - if (currentValue > maxValue) { - (currentIndex, currentValue, currentIndex + 1) - } else { - (maxIndex, maxValue, currentIndex + 1) - } - } - if (result._1 < 0) { - throw new RuntimeException("DecisionTree internal error:" + - " calculateGainForSplit failed in indexOfLargestArrayElement") - } - result._1 + // Sum of count for each label + val leftrightNodeAgg: Array[Double] = + leftNodeAgg.zip(rightNodeAgg).map { case (leftCount, rightCount) => + leftCount + rightCount } - val predict = indexOfLargestArrayElement(leftRightCounts) - val prob = leftRightCounts(predict) / totalCount - - val leftImpurity = if (leftTotalCount == 0) { - topImpurity - } else { - strategy.impurity.calculate(leftCounts, leftTotalCount) + def indexOfLargestArrayElement(array: Array[Double]): Int = { + val result = array.foldLeft(-1, Double.MinValue, 0) { + case ((maxIndex, maxValue, currentIndex), currentValue) => + if (currentValue > maxValue) { + (currentIndex, currentValue, currentIndex + 1) + } else { + (maxIndex, maxValue, currentIndex + 1) + } } - val rightImpurity = if (rightTotalCount == 0) { - topImpurity - } else { - strategy.impurity.calculate(rightCounts, rightTotalCount) + if (result._1 < 0) { + throw new RuntimeException("DecisionTree internal error:" + + " calculateGainForSplit failed in indexOfLargestArrayElement") } + result._1 + } - val leftWeight = leftTotalCount / totalCount - val rightWeight = rightTotalCount / totalCount + val predict = indexOfLargestArrayElement(leftrightNodeAgg) + val prob = leftrightNodeAgg(predict) / totalCount - val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + val leftImpurity = if (leftTotalCount == 0) { + topImpurity + } else { + metadata.impurity.calculate(leftNodeAgg, leftTotalCount) + } + val rightImpurity = if (rightTotalCount == 0) { + topImpurity + } else { + metadata.impurity.calculate(rightNodeAgg, rightTotalCount) + } - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) + val leftWeight = leftTotalCount / totalCount + val rightWeight = rightTotalCount / totalCount - case Regression => - val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0) - val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1) - val leftSumSquares = leftNodeAgg(featureIndex)(splitIndex)(2) + val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - val rightCount = rightNodeAgg(featureIndex)(splitIndex)(0) - val rightSum = rightNodeAgg(featureIndex)(splitIndex)(1) - val rightSumSquares = rightNodeAgg(featureIndex)(splitIndex)(2) + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) - val impurity = { - if (level > 0) { - topImpurity - } else { - // Calculate impurity for root node. - val count = leftCount + rightCount - val sum = leftSum + rightSum - val sumSquares = leftSumSquares + rightSumSquares - strategy.impurity.calculate(count, sum, sumSquares) - } - } + } else { + // Regression - if (leftCount == 0) { - return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, - rightSum / rightCount) - } - if (rightCount == 0) { - return new InformationGainStats(0, topImpurity ,topImpurity, - Double.MinValue, leftSum / leftCount) + val leftCount = leftNodeAgg(0) + val leftSum = leftNodeAgg(1) + val leftSumSquares = leftNodeAgg(2) + + val rightCount = rightNodeAgg(0) + val rightSum = rightNodeAgg(1) + val rightSumSquares = rightNodeAgg(2) + + val impurity = { + if (level > 0) { + topImpurity + } else { + // Calculate impurity for root node. + val count = leftCount + rightCount + val sum = leftSum + rightSum + val sumSquares = leftSumSquares + rightSumSquares + metadata.impurity.calculate(count, sum, sumSquares) } + } + + if (leftCount == 0) { + return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, + rightSum / rightCount) + } + if (rightCount == 0) { + return new InformationGainStats(0, topImpurity, topImpurity, + Double.MinValue, leftSum / leftCount) + } - val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares) - val rightImpurity = strategy.impurity.calculate(rightCount, rightSum, rightSumSquares) + val leftImpurity = metadata.impurity.calculate(leftCount, leftSum, leftSumSquares) + val rightImpurity = metadata.impurity.calculate(rightCount, rightSum, rightSumSquares) - val leftWeight = leftCount.toDouble / (leftCount + rightCount) - val rightWeight = rightCount.toDouble / (leftCount + rightCount) + val leftWeight = leftCount.toDouble / (leftCount + rightCount) + val rightWeight = rightCount.toDouble / (leftCount + rightCount) - val gain = { - if (level > 0) { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } else { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } - } + val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - val predict = (leftSum + rightSum) / (leftCount + rightCount) - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) + val predict = (leftSum + rightSum) / (leftCount + rightCount) + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) } } @@ -1065,6 +974,19 @@ object DecisionTree extends Serializable with Logging { binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = { + /** + * The input binData is indexed as (feature, bin, class). + * This computes cumulative sums over splits. + * Each (feature, class) pair is handled separately. + * Note: numSplits = numBins - 1. + * @param leftNodeAgg Each (feature, class) slice is an array over splits. + * Element i (i = 0, ..., numSplits - 2) is set to be + * the cumulative sum (from left) over binData for bins 0, ..., i. + * @param rightNodeAgg Each (feature, class) slice is an array over splits. + * Element i (i = 1, ..., numSplits - 1) is set to be + * the cumulative sum (from right) over binData for bins + * numBins - 1, ..., numBins - 1 - i. + */ def findAggForOrderedFeatureClassification( leftNodeAgg: Array[Array[Array[Double]]], rightNodeAgg: Array[Array[Array[Double]]], @@ -1169,45 +1091,32 @@ object DecisionTree extends Serializable with Logging { } } - strategy.algo match { - case Classification => - // Initialize left and right split aggregates. - val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) - val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) - var featureIndex = 0 - while (featureIndex < numFeatures) { - if (isMulticlassClassificationWithCategoricalFeatures) { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinuous) { - findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } else { - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - if (isSpaceSufficientForAllCategoricalSplits) { - findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } else { - findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } - } - } else { - findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } - featureIndex += 1 - } - - (leftNodeAgg, rightNodeAgg) - case Regression => - // Initialize left and right split aggregates. - val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) - val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - findAggForRegression(leftNodeAgg, rightNodeAgg, featureIndex) - featureIndex += 1 + if (metadata.isClassification) { + // Initialize left and right split aggregates. + val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) + val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) + var featureIndex = 0 + while (featureIndex < numFeatures) { + if (metadata.isUnordered(featureIndex)) { + findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) + } else { + findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) } - (leftNodeAgg, rightNodeAgg) + featureIndex += 1 + } + (leftNodeAgg, rightNodeAgg) + } else { + // Regression + // Initialize left and right split aggregates. + val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) + val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + findAggForRegression(leftNodeAgg, rightNodeAgg, featureIndex) + featureIndex += 1 + } + (leftNodeAgg, rightNodeAgg) } } @@ -1225,8 +1134,9 @@ object DecisionTree extends Serializable with Logging { val numSplitsForFeature = getNumSplitsForFeature(featureIndex) var splitIndex = 0 while (splitIndex < numSplitsForFeature) { - gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex, - splitIndex, rightNodeAgg, nodeImpurity) + gains(featureIndex)(splitIndex) = + calculateGainForSplit(leftNodeAgg(featureIndex)(splitIndex), + rightNodeAgg(featureIndex)(splitIndex), nodeImpurity) splitIndex += 1 } featureIndex += 1 @@ -1238,18 +1148,14 @@ object DecisionTree extends Serializable with Logging { * Get the number of splits for a feature. */ def getNumSplitsForFeature(featureIndex: Int): Int = { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinuous) { + if (metadata.isContinuous(featureIndex)) { numBins - 1 } else { // Categorical feature - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits = - numBins > math.pow(2, featureCategories.toInt - 1) - 1 - if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { - math.pow(2.0, featureCategories - 1).toInt - 1 + val featureCategories = metadata.featureArity(featureIndex) + if (metadata.isUnordered(featureIndex)) { + (1 << featureCategories - 1) - 1 } else { - // Ordered features featureCategories } } @@ -1308,29 +1214,29 @@ object DecisionTree extends Serializable with Logging { * Get bin data for one node. */ def getBinDataForNode(node: Int): Array[Double] = { - strategy.algo match { - case Classification => - if (isMulticlassClassificationWithCategoricalFeatures) { - val shift = numClasses * node * numBins * numFeatures - val rightChildShift = numClasses * numBins * numFeatures * numNodes - val binsForNode = { - val leftChildData - = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) - val rightChildData - = binAggregates.slice(rightChildShift + shift, - rightChildShift + shift + numClasses * numBins * numFeatures) - leftChildData ++ rightChildData - } - binsForNode - } else { - val shift = numClasses * node * numBins * numFeatures - val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) - binsForNode + if (metadata.isClassification) { + if (isMulticlassWithCategoricalFeatures) { + val shift = numClasses * node * numBins * numFeatures + val rightChildShift = numClasses * numBins * numFeatures * numNodes + val binsForNode = { + val leftChildData + = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) + val rightChildData + = binAggregates.slice(rightChildShift + shift, + rightChildShift + shift + numClasses * numBins * numFeatures) + leftChildData ++ rightChildData } - case Regression => - val shift = 3 * node * numBins * numFeatures - val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures) binsForNode + } else { + val shift = numClasses * node * numBins * numFeatures + val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) + binsForNode + } + } else { + // Regression + val shift = 3 * node * numBins * numFeatures + val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures) + binsForNode } } @@ -1340,7 +1246,7 @@ object DecisionTree extends Serializable with Logging { // Iterating over all nodes at this level var node = 0 while (node < numNodes) { - val nodeImpurityIndex = math.pow(2, level).toInt - 1 + node + groupShift + val nodeImpurityIndex = (1 << level) - 1 + node + groupShift val binsForNode: Array[Double] = getBinDataForNode(node) logDebug("nodeImpurityIndex = " + nodeImpurityIndex) val parentNodeImpurity = parentImpurities(nodeImpurityIndex) @@ -1358,20 +1264,15 @@ object DecisionTree extends Serializable with Logging { * * @param numBins Number of bins = 1 + number of possible splits. */ - private def getElementsPerNode( - numFeatures: Int, - numBins: Int, - numClasses: Int, - isMulticlassClassificationWithCategoricalFeatures: Boolean, - algo: Algo): Int = { - algo match { - case Classification => - if (isMulticlassClassificationWithCategoricalFeatures) { - 2 * numClasses * numBins * numFeatures - } else { - numClasses * numBins * numFeatures - } - case Regression => 3 * numBins * numFeatures + private def getElementsPerNode(metadata: DecisionTreeMetadata, numBins: Int): Int = { + if (metadata.isClassification) { + if (metadata.isMulticlassWithCategoricalFeatures) { + 2 * metadata.numClasses * numBins * metadata.numFeatures + } else { + metadata.numClasses * numBins * metadata.numFeatures + } + } else { + 3 * numBins * metadata.numFeatures } } @@ -1390,16 +1291,15 @@ object DecisionTree extends Serializable with Logging { * For multiclass classification with a low-arity feature * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits), * the feature is split based on subsets of categories. - * There are math.pow(2, maxFeatureValue - 1) - 1 splits. + * There are (1 << maxFeatureValue - 1) - 1 splits. * (b) "ordered features" * For regression and binary classification, * and for multiclass classification with a high-arity feature, * there is one bin per category. * * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] - * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for construction the DecisionTree - * @return A tuple of (splits,bins). + * @param metadata Learning and dataset metadata + * @return A tuple of (splits, bins). * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]] * of size (numFeatures, numBins - 1). * Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]] @@ -1407,19 +1307,18 @@ object DecisionTree extends Serializable with Logging { */ protected[tree] def findSplitsBins( input: RDD[LabeledPoint], - strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = { + metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = { val count = input.count() // Find the number of features by looking at the first sample val numFeatures = input.take(1)(0).features.size - val maxBins = strategy.maxBins + val maxBins = metadata.maxBins val numBins = if (maxBins <= count) maxBins else count.toInt logDebug("numBins = " + numBins) - val isMulticlassClassification = strategy.isMulticlassClassification - logDebug("isMulticlassClassification = " + isMulticlassClassification) - + val isMulticlass = metadata.isMulticlass + logDebug("isMulticlass = " + isMulticlass) /* * Ensure numBins is always greater than the categories. For multiclass classification, @@ -1431,13 +1330,12 @@ object DecisionTree extends Serializable with Logging { * by the number of training examples. * TODO: Allow this case, where we simply will know nothing about some categories. */ - if (strategy.categoricalFeaturesInfo.size > 0) { - val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2 + if (metadata.featureArity.size > 0) { + val maxCategoriesForFeatures = metadata.featureArity.maxBy(_._2)._2 require(numBins > maxCategoriesForFeatures, "numBins should be greater than max categories " + "in categorical features") } - // Calculate the number of sample for approximate quantile calculation. val requiredSamples = numBins*numBins val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0 @@ -1451,7 +1349,7 @@ object DecisionTree extends Serializable with Logging { val stride: Double = numSamples.toDouble / numBins logDebug("stride = " + stride) - strategy.quantileCalculationStrategy match { + metadata.quantileStrategy match { case Sort => val splits = Array.ofDim[Split](numFeatures, numBins - 1) val bins = Array.ofDim[Bin](numFeatures, numBins) @@ -1462,7 +1360,7 @@ object DecisionTree extends Serializable with Logging { var featureIndex = 0 while (featureIndex < numFeatures) { // Check whether the feature is continuous. - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + val isFeatureContinuous = metadata.isContinuous(featureIndex) if (isFeatureContinuous) { val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted val stride: Double = numSamples.toDouble / numBins @@ -1475,18 +1373,14 @@ object DecisionTree extends Serializable with Logging { splits(featureIndex)(index) = split } } else { // Categorical feature - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 + val featureCategories = metadata.featureArity(featureIndex) // Use different bin/split calculation strategy for categorical features in multiclass // classification that satisfy the space constraint. - val isUnorderedFeature = - isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits - if (isUnorderedFeature) { + if (metadata.isUnordered(featureIndex)) { // 2^(maxFeatureValue- 1) - 1 combinations var index = 0 - while (index < math.pow(2.0, featureCategories - 1).toInt - 1) { + while (index < (1 << featureCategories - 1) - 1) { val categories: List[Double] = extractMultiClassCategories(index + 1, featureCategories) splits(featureIndex)(index) @@ -1516,7 +1410,7 @@ object DecisionTree extends Serializable with Logging { * centroidForCategories is a mapping: category (for the given feature) --> centroid */ val centroidForCategories = { - if (isMulticlassClassification) { + if (isMulticlass) { // For categorical variables in multiclass classification, // each bin is a category. The bins are sorted and they // are ordered by calculating the impurity of their corresponding labels. @@ -1524,7 +1418,7 @@ object DecisionTree extends Serializable with Logging { .groupBy(_._1) .mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble)) .map(x => (x._1, x._2.values.toArray)) - .map(x => (x._1, strategy.impurity.calculate(x._2, x._2.sum))) + .map(x => (x._1, metadata.impurity.calculate(x._2, x._2.sum))) } else { // regression or binary classification // For categorical variables in regression and binary classification, // each bin is a category. The bins are sorted and they @@ -1576,7 +1470,7 @@ object DecisionTree extends Serializable with Logging { // Find all bins. featureIndex = 0 while (featureIndex < numFeatures) { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + val isFeatureContinuous = metadata.isContinuous(featureIndex) if (isFeatureContinuous) { // Bins for categorical variables are already assigned. bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous), splits(featureIndex)(0), Continuous, Double.MinValue) @@ -1590,7 +1484,7 @@ object DecisionTree extends Serializable with Logging { } featureIndex += 1 } - (splits,bins) + (splits, bins) case MinMax => throw new UnsupportedOperationException("minmax not supported yet.") case ApproxHist => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala new file mode 100644 index 0000000000000..d9eda354dc986 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impl + +import scala.collection.mutable + +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.impurity.Impurity +import org.apache.spark.rdd.RDD + + +/** + * Learning and dataset metadata for DecisionTree. + * + * @param numClasses For classification: labels can take values {0, ..., numClasses - 1}. + * For regression: fixed at 0 (no meaning). + * @param featureArity Map: categorical feature index --> arity. + * I.e., the feature takes values in {0, ..., arity - 1}. + */ +private[tree] class DecisionTreeMetadata( + val numFeatures: Int, + val numExamples: Long, + val numClasses: Int, + val maxBins: Int, + val featureArity: Map[Int, Int], + val unorderedFeatures: Set[Int], + val impurity: Impurity, + val quantileStrategy: QuantileStrategy) extends Serializable { + + def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex) + + def isClassification: Boolean = numClasses >= 2 + + def isMulticlass: Boolean = numClasses > 2 + + def isMulticlassWithCategoricalFeatures: Boolean = isMulticlass && (featureArity.size > 0) + + def isCategorical(featureIndex: Int): Boolean = featureArity.contains(featureIndex) + + def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex) + +} + +private[tree] object DecisionTreeMetadata { + + def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = { + + val numFeatures = input.take(1)(0).features.size + val numExamples = input.count() + val numClasses = strategy.algo match { + case Classification => strategy.numClassesForClassification + case Regression => 0 + } + + val maxBins = math.min(strategy.maxBins, numExamples).toInt + val log2MaxBinsp1 = math.log(maxBins + 1) / math.log(2.0) + + val unorderedFeatures = new mutable.HashSet[Int]() + if (numClasses > 2) { + strategy.categoricalFeaturesInfo.foreach { case (f, k) => + if (k - 1 < log2MaxBinsp1) { + // Note: The above check is equivalent to checking: + // numUnorderedBins = (1 << k - 1) - 1 < maxBins + unorderedFeatures.add(f) + } else { + // TODO: Allow this case, where we simply will know nothing about some categories? + require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " + + s"in categorical features (>= $k)") + } + } + } else { + strategy.categoricalFeaturesInfo.foreach { case (f, k) => + require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " + + s"in categorical features (>= $k)") + } + } + + new DecisionTreeMetadata(numFeatures, numExamples, numClasses, maxBins, + strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, + strategy.impurity, strategy.quantileCalculationStrategy) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala index ccac1031fd9d9..170e43e222083 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala @@ -18,7 +18,6 @@ package org.apache.spark.mllib.tree.impl import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.model.Bin import org.apache.spark.rdd.RDD @@ -48,50 +47,35 @@ private[tree] object TreePoint { * Convert an input dataset into its TreePoint representation, * binning feature values in preparation for DecisionTree training. * @param input Input dataset. - * @param strategy DecisionTree training info, used for dataset metadata. * @param bins Bins for features, of size (numFeatures, numBins). + * @param metadata Learning and dataset metadata * @return TreePoint dataset representation */ def convertToTreeRDD( input: RDD[LabeledPoint], - strategy: Strategy, - bins: Array[Array[Bin]]): RDD[TreePoint] = { + bins: Array[Array[Bin]], + metadata: DecisionTreeMetadata): RDD[TreePoint] = { input.map { x => - TreePoint.labeledPointToTreePoint(x, strategy.isMulticlassClassification, bins, - strategy.categoricalFeaturesInfo) + TreePoint.labeledPointToTreePoint(x, bins, metadata) } } /** * Convert one LabeledPoint into its TreePoint representation. * @param bins Bins for features, of size (numFeatures, numBins). - * @param categoricalFeaturesInfo Map over categorical features: feature index --> feature arity */ private def labeledPointToTreePoint( labeledPoint: LabeledPoint, - isMulticlassClassification: Boolean, bins: Array[Array[Bin]], - categoricalFeaturesInfo: Map[Int, Int]): TreePoint = { + metadata: DecisionTreeMetadata): TreePoint = { val numFeatures = labeledPoint.features.size val numBins = bins(0).size val arr = new Array[Int](numFeatures) var featureIndex = 0 while (featureIndex < numFeatures) { - val featureInfo = categoricalFeaturesInfo.get(featureIndex) - val isFeatureContinuous = featureInfo.isEmpty - if (isFeatureContinuous) { - arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous, false, - bins, categoricalFeaturesInfo) - } else { - val featureCategories = featureInfo.get - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - val isUnorderedFeature = - isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits - arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous, - isUnorderedFeature, bins, categoricalFeaturesInfo) - } + arr(featureIndex) = findBin(featureIndex, labeledPoint, metadata.isContinuous(featureIndex), + metadata.isUnordered(featureIndex), bins, metadata.featureArity) featureIndex += 1 } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala index c89c1e371a40e..af35d88f713e5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala @@ -20,15 +20,25 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.mllib.tree.configuration.FeatureType._ /** - * Used for "binning" the features bins for faster best split calculation. For a continuous - * feature, a bin is determined by a low and a high "split". For a categorical feature, - * the a bin is determined using a single label value (category). + * Used for "binning" the features bins for faster best split calculation. + * + * For a continuous feature, the bin is determined by a low and a high split, + * where an example with featureValue falls into the bin s.t. + * lowSplit.threshold < featureValue <= highSplit.threshold. + * + * For ordered categorical features, there is a 1-1-1 correspondence between + * bins, splits, and feature values. The bin is determined by category/feature value. + * However, the bins are not necessarily ordered by feature value; + * they are ordered using impurity. + * For unordered categorical features, there is a 1-1 correspondence between bins, splits, + * where bins and splits correspond to subsets of feature values (in highSplit.categories). + * * @param lowSplit signifying the lower threshold for the continuous feature to be * accepted in the bin * @param highSplit signifying the upper threshold for the continuous feature to be * accepted in the bin * @param featureType type of feature -- categorical or continuous - * @param category categorical label value accepted in the bin for binary classification + * @param category categorical label value accepted in the bin for ordered features */ private[tree] case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 3d3406b5d5f22..0594fd0749d21 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -39,7 +39,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable * @return Double prediction from the trained model */ def predict(features: Vector): Double = { - topNode.predictIfLeaf(features) + topNode.predict(features) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala deleted file mode 100644 index 2deaf4ae8dcab..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree.model - -/** - * Filter specifying a split and type of comparison to be applied on features - * @param split split specifying the feature index, type and threshold - * @param comparison integer specifying <,=,> - */ -private[tree] case class Filter(split: Split, comparison: Int) { - // Comparison -1,0,1 signifies <.=,> - override def toString = " split = " + split + "comparison = " + comparison -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 944f11c2c2e4f..0eee6262781c1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -69,24 +69,24 @@ class Node ( /** * predict value if node is not leaf - * @param feature feature value + * @param features feature value * @return predicted value */ - def predictIfLeaf(feature: Vector) : Double = { + def predict(features: Vector) : Double = { if (isLeaf) { predict } else{ if (split.get.featureType == Continuous) { - if (feature(split.get.feature) <= split.get.threshold) { - leftNode.get.predictIfLeaf(feature) + if (features(split.get.feature) <= split.get.threshold) { + leftNode.get.predict(features) } else { - rightNode.get.predictIfLeaf(feature) + rightNode.get.predict(features) } } else { - if (split.get.categories.contains(feature(split.get.feature))) { - leftNode.get.predictIfLeaf(feature) + if (split.get.categories.contains(features(split.get.feature))) { + leftNode.get.predict(features) } else { - rightNode.get.predictIfLeaf(feature) + rightNode.get.predict(features) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index d7ffd386c05ee..50fb48b40de3d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -24,9 +24,10 @@ import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType * :: DeveloperApi :: * Split applied to a feature * @param feature feature index - * @param threshold threshold for continuous feature + * @param threshold Threshold for continuous feature. + * Split left if feature <= threshold, else right. * @param featureType type of feature -- categorical or continuous - * @param categories accepted values for categorical variables + * @param categories Split left if categorical feature value is in this set, else right. */ @DeveloperApi case class Split( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index a5c49a38dc08f..2f36fd907772c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -23,10 +23,10 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ -import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy} -import org.apache.spark.mllib.tree.impl.TreePoint +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} -import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split} +import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LocalSparkContext import org.apache.spark.mllib.regression.LabeledPoint @@ -64,7 +64,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(bins.length === 2) assert(splits(0).length === 99) @@ -82,7 +83,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(bins.length === 2) assert(splits(0).length === 99) @@ -162,7 +164,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) // Check splits. @@ -279,7 +282,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) // Expecting 2^2 - 1 = 3 bins/splits assert(splits(0)(0).feature === 0) @@ -373,7 +377,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1-> 10)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) // 2^10 - 1 > 100, so categorical variables will be ordered @@ -428,10 +433,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) - val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0, + new Array[Node](0), splits, bins, 10) val split = bestSplits(0)._1 assert(split.categories.length === 1) @@ -456,10 +462,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) - val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0, + new Array[Node](0), splits, bins, 10) val split = bestSplits(0)._1 assert(split.categories.length === 1) @@ -495,7 +502,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -503,9 +511,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) @@ -518,7 +526,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -526,9 +535,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) @@ -542,7 +551,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -550,9 +560,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) @@ -566,7 +576,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -574,9 +585,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) @@ -590,7 +601,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -598,14 +610,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val leftFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()), -1) - val rightFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()) ,1) - val filters = Array[List[Filter]](List(), List(leftFilter), List(rightFilter)) + // Train a 1-node model + val strategyOneNode = new Strategy(Classification, Entropy, 1, 2, 100) + val modelOneNode = DecisionTree.train(rdd, strategyOneNode) + val nodes: Array[Node] = new Array[Node](7) + nodes(0) = modelOneNode.topNode + nodes(0).leftNode = None + nodes(0).rightNode = None + val parentImpurities = Array(0.5, 0.5, 0.5) // Single group second level tree construction. - val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1, filters, + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1, nodes, splits, bins, 10) assert(bestSplits.length === 2) assert(bestSplits(0)._2.gain > 0) @@ -613,8 +630,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second // level tree construction. - val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1, - filters, splits, bins, 0) + val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1, + nodes, splits, bins, 0) assert(bestSplitsWithGroups.length === 2) assert(bestSplitsWithGroups(0)._2.gain > 0) assert(bestSplitsWithGroups(1)._2.gain > 0) @@ -629,19 +646,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity) assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict) } - } test("stump with categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(strategy.isMulticlassClassification) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -657,11 +674,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0)) arr(2) = new LabeledPoint(1.0, Vectors.dense(2.0)) arr(3) = new LabeledPoint(1.0, Vectors.dense(3.0)) - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 2) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) @@ -688,20 +705,22 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with categorical variables for multiclass classification, with just enough bins") { val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, - numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + numClassesForClassification = 3, maxBins = maxBins, + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) assert(strategy.isMulticlassClassification) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -716,18 +735,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with continuous variables for multiclass classification") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3) assert(strategy.isMulticlassClassification) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 0.9) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -741,18 +761,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with continuous + categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3)) assert(strategy.isMulticlassClassification) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 0.9) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -765,14 +786,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with categorical variables for ordered multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) assert(strategy.isMulticlassClassification) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 From 318e28b503f22a89c23b7b3624e5fcf689fb92a2 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sun, 17 Aug 2014 17:06:55 -0700 Subject: [PATCH 19/54] SPARK-2881. Upgrade snappy-java to 1.1.1.3. This upgrades snappy-java which fixes the issue reported in SPARK-2881. This is the master branch equivalent to #1994 which provides a different work-around for the 1.1 branch. Author: Patrick Wendell Closes #1995 from pwendell/snappy-1.1 and squashes the following commits: 0c7c4c2 [Patrick Wendell] SPARK-2881. Upgrade snappy-java to 1.1.1.3. --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 920912353fe9c..ef12c8f1a5c49 100644 --- a/pom.xml +++ b/pom.xml @@ -316,7 +316,7 @@ org.xerial.snappy snappy-java - 1.0.5 + 1.1.1.3 net.jpountz.lz4 From 5ecb08ea063166564178885b7515abef0d76eecb Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 17 Aug 2014 18:10:45 -0700 Subject: [PATCH 20/54] Revert "[SPARK-2970] [SQL] spark-sql script ends with IOException when EventLogging is enabled" Revert #1891 due to issues with hadoop 1 compatibility. Author: Michael Armbrust Closes #2007 from marmbrus/revert1891 and squashes the following commits: 68706c0 [Michael Armbrust] Revert "[SPARK-2970] [SQL] spark-sql script ends with IOException when EventLogging is enabled" --- .../sql/hive/thriftserver/SparkSQLCLIDriver.scala | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index c16a7d3661c66..b092f42372171 100755 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -26,8 +26,6 @@ import jline.{ConsoleReader, History} import org.apache.commons.lang.StringUtils import org.apache.commons.logging.LogFactory import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.util.ShutdownHookManager import org.apache.hadoop.hive.cli.{CliDriver, CliSessionState, OptionsProcessor} import org.apache.hadoop.hive.common.LogUtils.LogInitializationException import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils, LogUtils} @@ -118,17 +116,13 @@ private[hive] object SparkSQLCLIDriver { SessionState.start(sessionState) // Clean up after we exit - /** - * This should be executed before shutdown hook of - * FileSystem to avoid race condition of FileSystem operation - */ - ShutdownHookManager.get.addShutdownHook( + Runtime.getRuntime.addShutdownHook( new Thread() { override def run() { SparkSQLEnv.stop() } } - , FileSystem.SHUTDOWN_HOOK_PRIORITY - 1) + ) // "-h" option has been passed, so connect to Hive thrift server. if (sessionState.getHost != null) { From bfa09b01d7eddc572cd22ca2e418a735b4ccc826 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 17 Aug 2014 19:00:38 -0700 Subject: [PATCH 21/54] [SQL] Improve debug logging and toStrings. Author: Michael Armbrust Closes #2004 from marmbrus/codgenDebugging and squashes the following commits: b7a7e41 [Michael Armbrust] Improve debug logging and toStrings. --- .../expressions/codegen/CodeGenerator.scala | 21 +++++++++++++++++-- .../catalyst/expressions/nullFunctions.scala | 2 ++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 5b398695bf560..de2d67ce82ff1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -78,7 +78,12 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin .build( new CacheLoader[InType, OutType]() { override def load(in: InType): OutType = globalLock.synchronized { - create(in) + val startTime = System.nanoTime() + val result = create(in) + val endTime = System.nanoTime() + def timeMs = (endTime - startTime).toDouble / 1000000 + logInfo(s"Code generated expression $in in $timeMs ms") + result } }) @@ -413,7 +418,19 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin """.children } - EvaluatedExpression(code, nullTerm, primitiveTerm, objectTerm) + // Only inject debugging code if debugging is turned on. + val debugCode = + if (log.isDebugEnabled) { + val localLogger = log + val localLoggerTree = reify { localLogger } + q""" + $localLoggerTree.debug(${e.toString} + ": " + (if($nullTerm) "null" else $primitiveTerm)) + """ :: Nil + } else { + Nil + } + + EvaluatedExpression(code ++ debugCode, nullTerm, primitiveTerm, objectTerm) } protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index ce6d99c911ab3..e88c5d4fa178a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -60,6 +60,8 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr override def eval(input: Row): Any = { child.eval(input) == null } + + override def toString = s"IS NULL $child" } case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { From 99243288b049f4a4fb4ba0505ea2310be5eb4bd2 Mon Sep 17 00:00:00 2001 From: Chris Fregly Date: Sun, 17 Aug 2014 19:33:15 -0700 Subject: [PATCH 22/54] [SPARK-1981] updated streaming-kinesis.md fixed markup, separated out sections more-clearly, more thorough explanations Author: Chris Fregly Closes #1757 from cfregly/master and squashes the following commits: 9b1c71a [Chris Fregly] better explained why spark checkpoints are disabled in the example (due to no stateful operations being used) 0f37061 [Chris Fregly] SPARK-1981: (Kinesis streaming support) updated streaming-kinesis.md 862df67 [Chris Fregly] Merge remote-tracking branch 'upstream/master' 8e1ae2e [Chris Fregly] Merge remote-tracking branch 'upstream/master' 4774581 [Chris Fregly] updated docs, renamed retry to retryRandom to be more clear, removed retries around store() method 0393795 [Chris Fregly] moved Kinesis examples out of examples/ and back into extras/kinesis-asl 691a6be [Chris Fregly] fixed tests and formatting, fixed a bug with JavaKinesisWordCount during union of streams 0e1c67b [Chris Fregly] Merge remote-tracking branch 'upstream/master' 74e5c7c [Chris Fregly] updated per TD's feedback. simplified examples, updated docs e33cbeb [Chris Fregly] Merge remote-tracking branch 'upstream/master' bf614e9 [Chris Fregly] per matei's feedback: moved the kinesis examples into the examples/ dir d17ca6d [Chris Fregly] per TD's feedback: updated docs, simplified the KinesisUtils api 912640c [Chris Fregly] changed the foundKinesis class to be a publically-avail class db3eefd [Chris Fregly] Merge remote-tracking branch 'upstream/master' 21de67f [Chris Fregly] Merge remote-tracking branch 'upstream/master' 6c39561 [Chris Fregly] parameterized the versions of the aws java sdk and kinesis client 338997e [Chris Fregly] improve build docs for kinesis 828f8ae [Chris Fregly] more cleanup e7c8978 [Chris Fregly] Merge remote-tracking branch 'upstream/master' cd68c0d [Chris Fregly] fixed typos and backward compatibility d18e680 [Chris Fregly] Merge remote-tracking branch 'upstream/master' b3b0ff1 [Chris Fregly] [SPARK-1981] Add AWS Kinesis streaming support --- docs/streaming-kinesis.md | 97 ++++++++++++++++++++------------------- 1 file changed, 49 insertions(+), 48 deletions(-) diff --git a/docs/streaming-kinesis.md b/docs/streaming-kinesis.md index 801c905c88df8..16ad3222105a2 100644 --- a/docs/streaming-kinesis.md +++ b/docs/streaming-kinesis.md @@ -3,56 +3,57 @@ layout: global title: Spark Streaming Kinesis Receiver --- -### Kinesis -Build notes: -
  • Spark supports a Kinesis Streaming Receiver which is not included in the default build due to licensing restrictions.
  • -
  • _**Note that by embedding this library you will include [ASL](https://aws.amazon.com/asl/)-licensed code in your Spark package**_.
  • -
  • The Spark Kinesis Streaming Receiver source code, examples, tests, and artifacts live in $SPARK_HOME/extras/kinesis-asl.
  • -
  • To build with Kinesis, you must run the maven or sbt builds with -Pkinesis-asl`.
  • -
  • Applications will need to link to the 'spark-streaming-kinesis-asl` artifact.
  • +## Kinesis +###Design +
  • The KinesisReceiver uses the Kinesis Client Library (KCL) provided by Amazon under the Amazon Software License.
  • +
  • The KCL builds on top of the Apache 2.0 licensed AWS Java SDK and provides load-balancing, fault-tolerance, checkpointing through the concept of Workers, Checkpoints, and Shard Leases.
  • +
  • The KCL uses DynamoDB to maintain all state. A DynamoDB table is created in the us-east-1 region (regardless of Kinesis stream region) during KCL initialization for each Kinesis application name.
  • +
  • A single KinesisReceiver can process many shards of a stream by spinning up multiple KinesisRecordProcessor threads.
  • +
  • You never need more KinesisReceivers than the number of shards in your stream as each will spin up at least one KinesisRecordProcessor thread.
  • +
  • Horizontal scaling is achieved by autoscaling additional KinesisReceiver (separate processes) or spinning up new KinesisRecordProcessor threads within each KinesisReceiver - up to the number of current shards for a given stream, of course. Don't forget to autoscale back down!
  • -Kinesis examples notes: -
  • To build the Kinesis examples, you must run the maven or sbt builds with -Pkinesis-asl`.
  • -
  • These examples automatically determine the number of local threads and KinesisReceivers to spin up based on the number of shards for the stream.
  • -
  • KinesisWordCountProducerASL will generate random data to put onto the Kinesis stream for testing.
  • -
  • Checkpointing is disabled (no checkpoint dir is set). The examples as written will not recover from a driver failure.
  • +### Build +
  • Spark supports a Streaming KinesisReceiver, but it is not included in the default build due to Amazon Software Licensing (ASL) restrictions.
  • +
  • To build with the Kinesis Streaming Receiver and supporting ASL-licensed code, you must run the maven or sbt builds with the **-Pkinesis-asl** profile.
  • +
  • All KinesisReceiver-related code, examples, tests, and artifacts live in **$SPARK_HOME/extras/kinesis-asl/**.
  • +
  • Kinesis-based Spark Applications will need to link to the **spark-streaming-kinesis-asl** artifact that is built when **-Pkinesis-asl** is specified.
  • +
  • _**Note that by linking to this library, you will include [ASL](https://aws.amazon.com/asl/)-licensed code in your Spark package**_.
  • -Deployment and runtime notes: -
  • A single KinesisReceiver can process many shards of a stream.
  • -
  • Each shard of a stream is processed by one or more KinesisReceiver's managed by the Kinesis Client Library (KCL) Worker.
  • -
  • You never need more KinesisReceivers than the number of shards in your stream.
  • -
  • You can horizontally scale the receiving by creating more KinesisReceiver/DStreams (up to the number of shards for a given stream)
  • -
  • The Kinesis libraries must be present on all worker nodes, as they will need access to the Kinesis Client Library.
  • -
  • This code uses the DefaultAWSCredentialsProviderChain and searches for credentials in the following order of precedence:
    - 1) Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY
    - 2) Java System Properties - aws.accessKeyId and aws.secretKey
    - 3) Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs
    - 4) Instance profile credentials - delivered through the Amazon EC2 metadata service
    -
  • -
  • You need to setup a Kinesis stream with 1 or more shards per the following:
    - http://docs.aws.amazon.com/kinesis/latest/dev/step-one-create-stream.html
  • -
  • Valid Kinesis endpoint urls can be found here: Valid endpoint urls: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region
  • -
  • When you first start up the KinesisReceiver, the Kinesis Client Library (KCL) needs ~30s to establish connectivity with the AWS Kinesis service, -retrieve any checkpoint data, and negotiate with other KCL's reading from the same stream.
  • -
  • Be careful when changing the app name. Kinesis maintains a mapping table in DynamoDB based on this app name (http://docs.aws.amazon.com/kinesis/latest/dev/kinesis-record-processor-implementation-app.html#kinesis-record-processor-initialization). -Changing the app name could lead to Kinesis errors as only 1 logical application can process a stream. In order to start fresh, -it's always best to delete the DynamoDB table that matches your app name. This DynamoDB table lives in us-east-1 regardless of the Kinesis endpoint URL.
  • +###Example +
  • To build the Kinesis example, you must run the maven or sbt builds with the **-Pkinesis-asl** profile.
  • +
  • You need to setup a Kinesis stream at one of the valid Kinesis endpoints with 1 or more shards per the following: http://docs.aws.amazon.com/kinesis/latest/dev/step-one-create-stream.html
  • +
  • Valid Kinesis endpoints can be found here: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region
  • +
  • When running **locally**, the example automatically determines the number of threads and KinesisReceivers to spin up based on the number of shards configured for the stream. Therefore, **local[n]** is not needed when starting the example as with other streaming examples.
  • +
  • While this example could use a single KinesisReceiver which spins up multiple KinesisRecordProcessor threads to process multiple shards, I wanted to demonstrate unioning multiple KinesisReceivers as a single DStream. (It's a bit confusing in local mode.)
  • +
  • **KinesisWordCountProducerASL** is provided to generate random records into the Kinesis stream for testing.
  • +
  • The example has been configured to immediately replicate incoming stream data to another node by using (StorageLevel.MEMORY_AND_DISK_2) +
  • Spark checkpointing is disabled because the example does not use any stateful or window-based DStream operations such as updateStateByKey and reduceByWindow. If those operations are introduced, you would need to enable checkpointing or risk losing data in the case of a failure.
  • +
  • Kinesis checkpointing is enabled. This means that the example will recover from a Kinesis failure.
  • +
  • The example uses InitialPositionInStream.LATEST strategy to pull from the latest tip of the stream if no Kinesis checkpoint info exists.
  • +
  • In our example, **KinesisWordCount** is the Kinesis application name for both the Scala and Java versions. The use of this application name is described next.
  • -Failure recovery notes: -
  • The combination of Spark Streaming and Kinesis creates 3 different checkpoints as follows:
    - 1) RDD data checkpoint (Spark Streaming) - frequency is configurable with DStream.checkpoint(Duration)
    - 2) RDD metadata checkpoint (Spark Streaming) - frequency is every DStream batch
    - 3) Kinesis checkpointing (Kinesis) - frequency is controlled by the developer calling ICheckpointer.checkpoint() directly
    +###Deployment and Runtime +
  • A Kinesis application name must be unique for a given account and region.
  • +
  • A DynamoDB table and CloudWatch namespace are created during KCL initialization using this Kinesis application name. http://docs.aws.amazon.com/kinesis/latest/dev/kinesis-record-processor-implementation-app.html#kinesis-record-processor-initialization
  • +
  • This DynamoDB table lives in the us-east-1 region regardless of the Kinesis endpoint URL.
  • +
  • Changing the app name or stream name could lead to Kinesis errors as only a single logical application can process a single stream.
  • +
  • If you are seeing errors after changing the app name or stream name, it may be necessary to manually delete the DynamoDB table and start from scratch.
  • +
  • The Kinesis libraries must be present on all worker nodes, as they will need access to the KCL.
  • +
  • The KinesisReceiver uses the DefaultAWSCredentialsProviderChain for AWS credentials which searches for credentials in the following order of precedence:
    +1) Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY
    +2) Java System Properties - aws.accessKeyId and aws.secretKey
    +3) Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs
    +4) Instance profile credentials - delivered through the Amazon EC2 metadata service
  • -
  • Checkpointing too frequently will cause excess load on the AWS checkpoint storage layer and may lead to AWS throttling
  • -
  • Upon startup, a KinesisReceiver will begin processing records with sequence numbers greater than the last checkpoint sequence number recorded per shard.
  • -
  • If no checkpoint info exists, the worker will start either from the oldest record available (InitialPositionInStream.TRIM_HORIZON) -or from the tip/latest (InitialPostitionInStream.LATEST). This is configurable.
  • -
  • When pulling from the stream tip (InitialPositionInStream.LATEST), only new stream data will be picked up after the KinesisReceiver starts.
  • -
  • InitialPositionInStream.LATEST could lead to missed records if data is added to the stream while no KinesisReceivers are running.
  • -
  • In production, you'll want to switch to InitialPositionInStream.TRIM_HORIZON which will read up to 24 hours (Kinesis limit) of previous stream data -depending on the checkpoint frequency.
  • -
  • InitialPositionInStream.TRIM_HORIZON may lead to duplicate processing of records depending on the checkpoint frequency.
  • + +###Fault-Tolerance +
  • The combination of Spark Streaming and Kinesis creates 2 different checkpoints that may occur at different intervals.
  • +
  • Checkpointing too frequently against Kinesis will cause excess load on the AWS checkpoint storage layer and may lead to AWS throttling. The provided example handles this throttling with a random backoff retry strategy.
  • +
  • Upon startup, a KinesisReceiver will begin processing records with sequence numbers greater than the last Kinesis checkpoint sequence number recorded per shard (stored in the DynamoDB table).
  • +
  • If no Kinesis checkpoint info exists, the KinesisReceiver will start either from the oldest record available (InitialPositionInStream.TRIM_HORIZON) or from the latest tip (InitialPostitionInStream.LATEST). This is configurable.
  • +
  • InitialPositionInStream.LATEST could lead to missed records if data is added to the stream while no KinesisReceivers are running (and no checkpoint info is being stored.)
  • +
  • In production, you'll want to switch to InitialPositionInStream.TRIM_HORIZON which will read up to 24 hours (Kinesis limit) of previous stream data.
  • +
  • InitialPositionInStream.TRIM_HORIZON may lead to duplicate processing of records where the impact is dependent on checkpoint frequency.
  • Record processing should be idempotent when possible.
  • -
  • Failed or latent KinesisReceivers will be detected and automatically shutdown/load-balanced by the KCL.
  • -
  • If possible, explicitly shutdown the worker if a failure occurs in order to trigger the final checkpoint.
  • +
  • A failed or latent KinesisRecordProcessor within the KinesisReceiver will be detected and automatically restarted by the KCL.
  • +
  • If possible, the KinesisReceiver should be shutdown cleanly in order to trigger a final checkpoint of all KinesisRecordProcessors to avoid duplicate record processing.
  • \ No newline at end of file From 95470a03ae85d7d37d75f73435425a0e22918bc9 Mon Sep 17 00:00:00 2001 From: Hari Shreedharan Date: Sun, 17 Aug 2014 19:50:31 -0700 Subject: [PATCH 23/54] [HOTFIX][STREAMING] Allow the JVM/Netty to decide which port to bind to in Flume Polling Tests. Author: Hari Shreedharan Closes #1820 from harishreedharan/use-free-ports and squashes the following commits: b939067 [Hari Shreedharan] Remove unused import. 67856a8 [Hari Shreedharan] Remove findFreePort. 0ea51d1 [Hari Shreedharan] Make some changes to getPort to use map on the serverOpt. 1fb0283 [Hari Shreedharan] Merge branch 'master' of https://github.com/apache/spark into use-free-ports b351651 [Hari Shreedharan] Allow Netty to choose port, and query it to decide the port to bind to. Leaving findFreePort as is, if other tests want to use it at some point. e6c9620 [Hari Shreedharan] Making sure the second sink uses the correct port. 11c340d [Hari Shreedharan] Add info about race condition to scaladoc. e89d135 [Hari Shreedharan] Adding Scaladoc. 6013bb0 [Hari Shreedharan] [STREAMING] Find free ports to use before attempting to create Flume Sink in Flume Polling Suite --- .../streaming/flume/sink/SparkSink.scala | 8 +++ .../flume/FlumePollingStreamSuite.scala | 55 +++++++++---------- 2 files changed, 34 insertions(+), 29 deletions(-) diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala index 7b735133e3d14..948af5947f5e1 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala @@ -131,6 +131,14 @@ class SparkSink extends AbstractSink with Logging with Configurable { blockingLatch.await() Status.BACKOFF } + + private[flume] def getPort(): Int = { + serverOpt + .map(_.getPort) + .getOrElse( + throw new RuntimeException("Server was not started!") + ) + } } /** diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index a69baa16981a1..8a85b0f987e42 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -22,6 +22,8 @@ import java.net.InetSocketAddress import java.util.concurrent.{Callable, ExecutorCompletionService, Executors} import java.util.Random +import org.apache.spark.TestUtils + import scala.collection.JavaConversions._ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} @@ -39,9 +41,6 @@ import org.apache.spark.util.Utils class FlumePollingStreamSuite extends TestSuiteBase { - val random = new Random() - /** Return a port in the ephemeral range. */ - def getTestPort = random.nextInt(16382) + 49152 val batchCount = 5 val eventsPerBatch = 100 val totalEventsPerChannel = batchCount * eventsPerBatch @@ -77,17 +76,6 @@ class FlumePollingStreamSuite extends TestSuiteBase { } private def testFlumePolling(): Unit = { - val testPort = getTestPort - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = - FlumeUtils.createPollingStream(ssc, Seq(new InetSocketAddress("localhost", testPort)), - StorageLevel.MEMORY_AND_DISK, eventsPerBatch, 1) - val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] - with SynchronizedBuffer[Seq[SparkFlumeEvent]] - val outputStream = new TestOutputStream(flumeStream, outputBuffer) - outputStream.register() - // Start the channel and sink. val context = new Context() context.put("capacity", channelCapacity.toString) @@ -98,10 +86,19 @@ class FlumePollingStreamSuite extends TestSuiteBase { val sink = new SparkSink() context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(testPort)) + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) Configurables.configure(sink, context) sink.setChannel(channel) sink.start() + // Set up the streaming context and input streams + val ssc = new StreamingContext(conf, batchDuration) + val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = + FlumeUtils.createPollingStream(ssc, Seq(new InetSocketAddress("localhost", sink.getPort())), + StorageLevel.MEMORY_AND_DISK, eventsPerBatch, 1) + val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] + with SynchronizedBuffer[Seq[SparkFlumeEvent]] + val outputStream = new TestOutputStream(flumeStream, outputBuffer) + outputStream.register() ssc.start() writeAndVerify(Seq(channel), ssc, outputBuffer) @@ -111,18 +108,6 @@ class FlumePollingStreamSuite extends TestSuiteBase { } private def testFlumePollingMultipleHost(): Unit = { - val testPort = getTestPort - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val addresses = Seq(testPort, testPort + 1).map(new InetSocketAddress("localhost", _)) - val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = - FlumeUtils.createPollingStream(ssc, addresses, StorageLevel.MEMORY_AND_DISK, - eventsPerBatch, 5) - val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] - with SynchronizedBuffer[Seq[SparkFlumeEvent]] - val outputStream = new TestOutputStream(flumeStream, outputBuffer) - outputStream.register() - // Start the channel and sink. val context = new Context() context.put("capacity", channelCapacity.toString) @@ -136,17 +121,29 @@ class FlumePollingStreamSuite extends TestSuiteBase { val sink = new SparkSink() context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(testPort)) + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) Configurables.configure(sink, context) sink.setChannel(channel) sink.start() val sink2 = new SparkSink() context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(testPort + 1)) + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) Configurables.configure(sink2, context) sink2.setChannel(channel2) sink2.start() + + // Set up the streaming context and input streams + val ssc = new StreamingContext(conf, batchDuration) + val addresses = Seq(sink.getPort(), sink2.getPort()).map(new InetSocketAddress("localhost", _)) + val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = + FlumeUtils.createPollingStream(ssc, addresses, StorageLevel.MEMORY_AND_DISK, + eventsPerBatch, 5) + val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] + with SynchronizedBuffer[Seq[SparkFlumeEvent]] + val outputStream = new TestOutputStream(flumeStream, outputBuffer) + outputStream.register() + ssc.start() writeAndVerify(Seq(channel, channel2), ssc, outputBuffer) assertChannelIsEmpty(channel) From c77f40668fbb5b8bca9a9b25c039895cb7a4a80c Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sun, 17 Aug 2014 20:53:18 -0700 Subject: [PATCH 24/54] [SPARK-3087][MLLIB] fix col indexing bug in chi-square and add a check for number of distinct values There is a bug determining the column index. dorx Author: Xiangrui Meng Closes #1997 from mengxr/chisq-index and squashes the following commits: 8fc2ab2 [Xiangrui Meng] fix col indexing bug and add a check for number of distinct values --- .../apache/spark/mllib/stat/Statistics.scala | 2 +- .../spark/mllib/stat/test/ChiSqTest.scala | 37 +++++++++++++++---- .../mllib/stat/HypothesisTestSuite.scala | 37 ++++++++++++++----- 3 files changed, 59 insertions(+), 17 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index 3cf1028fbc725..3cf4e807b4cf7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -155,7 +155,7 @@ object Statistics { * :: Experimental :: * Conduct Pearson's independence test for every feature against the label across the input RDD. * For each feature, the (feature, label) pairs are converted into a contingency matrix for which - * the chi-squared statistic is computed. + * the chi-squared statistic is computed. All label and feature values must be categorical. * * @param data an `RDD[LabeledPoint]` containing the labeled dataset with categorical features. * Real-valued features will be treated as categorical for each distinct value. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala index 215de95db5113..0089419c2c5d4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala @@ -20,11 +20,13 @@ package org.apache.spark.mllib.stat.test import breeze.linalg.{DenseMatrix => BDM} import cern.jet.stat.Probability.chiSquareComplemented -import org.apache.spark.Logging +import org.apache.spark.{SparkException, Logging} import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD +import scala.collection.mutable + /** * Conduct the chi-squared test for the input RDDs using the specified method. * Goodness-of-fit test is conducted on two `Vectors`, whereas test of independence is conducted @@ -75,21 +77,42 @@ private[stat] object ChiSqTest extends Logging { */ def chiSquaredFeatures(data: RDD[LabeledPoint], methodName: String = PEARSON.name): Array[ChiSqTestResult] = { + val maxCategories = 10000 val numCols = data.first().features.size val results = new Array[ChiSqTestResult](numCols) var labels: Map[Double, Int] = null - // At most 100 columns at a time - val batchSize = 100 + // at most 1000 columns at a time + val batchSize = 1000 var batch = 0 while (batch * batchSize < numCols) { // The following block of code can be cleaned up and made public as // chiSquared(data: RDD[(V1, V2)]) val startCol = batch * batchSize val endCol = startCol + math.min(batchSize, numCols - startCol) - val pairCounts = data.flatMap { p => - // assume dense vectors - p.features.toArray.slice(startCol, endCol).zipWithIndex.map { case (feature, col) => - (col, feature, p.label) + val pairCounts = data.mapPartitions { iter => + val distinctLabels = mutable.HashSet.empty[Double] + val allDistinctFeatures: Map[Int, mutable.HashSet[Double]] = + Map((startCol until endCol).map(col => (col, mutable.HashSet.empty[Double])): _*) + var i = 1 + iter.flatMap { case LabeledPoint(label, features) => + if (i % 1000 == 0) { + if (distinctLabels.size > maxCategories) { + throw new SparkException(s"Chi-square test expect factors (categorical values) but " + + s"found more than $maxCategories distinct label values.") + } + allDistinctFeatures.foreach { case (col, distinctFeatures) => + if (distinctFeatures.size > maxCategories) { + throw new SparkException(s"Chi-square test expect factors (categorical values) but " + + s"found more than $maxCategories distinct values in column $col.") + } + } + } + i += 1 + distinctLabels += label + features.toArray.view.zipWithIndex.slice(startCol, endCol).map { case (feature, col) => + allDistinctFeatures(col) += feature + (col, feature, label) + } } }.countByValue() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala index 5bd0521298c14..6de3840b3f198 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala @@ -17,8 +17,11 @@ package org.apache.spark.mllib.stat +import java.util.Random + import org.scalatest.FunSuite +import org.apache.spark.SparkException import org.apache.spark.mllib.linalg.{DenseVector, Matrices, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.test.ChiSqTest @@ -107,12 +110,13 @@ class HypothesisTestSuite extends FunSuite with LocalSparkContext { // labels: 1.0 (2 / 6), 0.0 (4 / 6) // feature1: 0.5 (1 / 6), 1.5 (2 / 6), 3.5 (3 / 6) // feature2: 10.0 (1 / 6), 20.0 (1 / 6), 30.0 (2 / 6), 40.0 (2 / 6) - val data = Array(new LabeledPoint(0.0, Vectors.dense(0.5, 10.0)), - new LabeledPoint(0.0, Vectors.dense(1.5, 20.0)), - new LabeledPoint(1.0, Vectors.dense(1.5, 30.0)), - new LabeledPoint(0.0, Vectors.dense(3.5, 30.0)), - new LabeledPoint(0.0, Vectors.dense(3.5, 40.0)), - new LabeledPoint(1.0, Vectors.dense(3.5, 40.0))) + val data = Seq( + LabeledPoint(0.0, Vectors.dense(0.5, 10.0)), + LabeledPoint(0.0, Vectors.dense(1.5, 20.0)), + LabeledPoint(1.0, Vectors.dense(1.5, 30.0)), + LabeledPoint(0.0, Vectors.dense(3.5, 30.0)), + LabeledPoint(0.0, Vectors.dense(3.5, 40.0)), + LabeledPoint(1.0, Vectors.dense(3.5, 40.0))) for (numParts <- List(2, 4, 6, 8)) { val chi = Statistics.chiSqTest(sc.parallelize(data, numParts)) val feature1 = chi(0) @@ -130,10 +134,25 @@ class HypothesisTestSuite extends FunSuite with LocalSparkContext { } // Test that the right number of results is returned - val numCols = 321 - val sparseData = Array(new LabeledPoint(0.0, Vectors.sparse(numCols, Seq((100, 2.0)))), - new LabeledPoint(0.0, Vectors.sparse(numCols, Seq((200, 1.0))))) + val numCols = 1001 + val sparseData = Array( + new LabeledPoint(0.0, Vectors.sparse(numCols, Seq((100, 2.0)))), + new LabeledPoint(0.1, Vectors.sparse(numCols, Seq((200, 1.0))))) val chi = Statistics.chiSqTest(sc.parallelize(sparseData)) assert(chi.size === numCols) + assert(chi(1000) != null) // SPARK-3087 + + // Detect continous features or labels + val random = new Random(11L) + val continuousLabel = + Seq.fill(100000)(LabeledPoint(random.nextDouble(), Vectors.dense(random.nextInt(2)))) + intercept[SparkException] { + Statistics.chiSqTest(sc.parallelize(continuousLabel, 2)) + } + val continuousFeature = + Seq.fill(100000)(LabeledPoint(random.nextInt(2), Vectors.dense(random.nextDouble()))) + intercept[SparkException] { + Statistics.chiSqTest(sc.parallelize(continuousFeature, 2)) + } } } From 5173f3c40f6b64f224f11364e038953826013895 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sun, 17 Aug 2014 22:29:58 -0700 Subject: [PATCH 25/54] SPARK-2884: Create binary builds in parallel with release script. --- dev/create-release/create-release.sh | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 1867cf4ec46ca..28f26d2368254 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -117,12 +117,13 @@ make_binary_release() { spark-$RELEASE_VERSION-bin-$NAME.tgz.sha } -make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4" -make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" +make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4" & +make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" & make_binary_release "hadoop2" \ - "-Phive -Phive-thriftserver -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" + "-Phive -Phive-thriftserver -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" & make_binary_release "hadoop2-without-hive" \ - "-Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" + "-Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" & +wait # Copy data echo "Copying release tarballs" From df652ea02a3e42d987419308ef14874300347373 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Sun, 17 Aug 2014 22:39:06 -0700 Subject: [PATCH 26/54] SPARK-2900. aggregate inputBytes per stage Author: Sandy Ryza Closes #1826 from sryza/sandy-spark-2900 and squashes the following commits: 43f9091 [Sandy Ryza] SPARK-2900 --- .../org/apache/spark/ui/jobs/JobProgressListener.scala | 6 ++++++ .../apache/spark/ui/jobs/JobProgressListenerSuite.scala | 9 ++++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index a3e9566832d06..74cd637d88155 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -200,6 +200,12 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { stageData.shuffleReadBytes += shuffleReadDelta execSummary.shuffleRead += shuffleReadDelta + val inputBytesDelta = + (taskMetrics.inputMetrics.map(_.bytesRead).getOrElse(0L) + - oldMetrics.flatMap(_.inputMetrics).map(_.bytesRead).getOrElse(0L)) + stageData.inputBytes += inputBytesDelta + execSummary.inputBytes += inputBytesDelta + val diskSpillDelta = taskMetrics.diskBytesSpilled - oldMetrics.map(_.diskBytesSpilled).getOrElse(0L) stageData.diskBytesSpilled += diskSpillDelta diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index f5ba31c309277..147ec0bc52e39 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -22,7 +22,7 @@ import org.scalatest.Matchers import org.apache.spark._ import org.apache.spark.{LocalSparkContext, SparkConf, Success} -import org.apache.spark.executor.{ShuffleWriteMetrics, ShuffleReadMetrics, TaskMetrics} +import org.apache.spark.executor._ import org.apache.spark.scheduler._ import org.apache.spark.util.Utils @@ -150,6 +150,9 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc taskMetrics.executorRunTime = base + 4 taskMetrics.diskBytesSpilled = base + 5 taskMetrics.memoryBytesSpilled = base + 6 + val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) + taskMetrics.inputMetrics = Some(inputMetrics) + inputMetrics.bytesRead = base + 7 taskMetrics } @@ -182,6 +185,8 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc assert(stage1Data.diskBytesSpilled == 205) assert(stage0Data.memoryBytesSpilled == 112) assert(stage1Data.memoryBytesSpilled == 206) + assert(stage0Data.inputBytes == 114) + assert(stage1Data.inputBytes == 207) assert(stage0Data.taskData.get(1234L).get.taskMetrics.get.shuffleReadMetrics.get .totalBlocksFetched == 2) assert(stage0Data.taskData.get(1235L).get.taskMetrics.get.shuffleReadMetrics.get @@ -208,6 +213,8 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc assert(stage1Data.diskBytesSpilled == 610) assert(stage0Data.memoryBytesSpilled == 412) assert(stage1Data.memoryBytesSpilled == 612) + assert(stage0Data.inputBytes == 414) + assert(stage1Data.inputBytes == 614) assert(stage0Data.taskData.get(1234L).get.taskMetrics.get.shuffleReadMetrics.get .totalBlocksFetched == 302) assert(stage1Data.taskData.get(1237L).get.taskMetrics.get.shuffleReadMetrics.get From 3c8fa505900ac158d57de36f6b0fd6da05f8893b Mon Sep 17 00:00:00 2001 From: Liquan Pei Date: Sun, 17 Aug 2014 23:29:44 -0700 Subject: [PATCH 27/54] [SPARK-3097][MLlib] Word2Vec performance improvement mengxr Please review the code. Adding weights in reduceByKey soon. Only output model entry for words appeared in the partition before merging and use reduceByKey to combine model. In general, this implementation is 30s or so faster than implementation using big array. Author: Liquan Pei Closes #1932 from Ishiihara/Word2Vec-improve2 and squashes the following commits: d5377a9 [Liquan Pei] use syn0Global and syn1Global to represent model cad2011 [Liquan Pei] bug fix for synModify array out of bound 083aa66 [Liquan Pei] update synGlobal in place and reduce synOut size 9075e1c [Liquan Pei] combine syn0Global and syn1Global to synGlobal aa2ab36 [Liquan Pei] use reduceByKey to combine models --- .../apache/spark/mllib/feature/Word2Vec.scala | 50 +++++++++++++------ 1 file changed, 35 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index ecd49ea2ff533..d2ae62b482aff 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -34,6 +34,7 @@ import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.rdd._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom +import org.apache.spark.util.collection.PrimitiveKeyOpenHashMap /** * Entry in vocabulary @@ -287,11 +288,12 @@ class Word2Vec extends Serializable with Logging { var syn0Global = Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) var syn1Global = new Array[Float](vocabSize * vectorSize) - var alpha = startingAlpha for (k <- 1 to numIterations) { val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) + val syn0Modify = new Array[Int](vocabSize) + val syn1Modify = new Array[Int](vocabSize) val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) { case ((syn0, syn1, lastWordCount, wordCount), sentence) => var lwc = lastWordCount @@ -321,7 +323,8 @@ class Word2Vec extends Serializable with Logging { // Hierarchical softmax var d = 0 while (d < bcVocab.value(word).codeLen) { - val l2 = bcVocab.value(word).point(d) * vectorSize + val inner = bcVocab.value(word).point(d) + val l2 = inner * vectorSize // Propagate hidden -> output var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1) if (f > -MAX_EXP && f < MAX_EXP) { @@ -330,10 +333,12 @@ class Word2Vec extends Serializable with Logging { val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1) blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1) + syn1Modify(inner) += 1 } d += 1 } blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1) + syn0Modify(lastWord) += 1 } } a += 1 @@ -342,21 +347,36 @@ class Word2Vec extends Serializable with Logging { } (syn0, syn1, lwc, wc) } - Iterator(model) + val syn0Local = model._1 + val syn1Local = model._2 + val synOut = new PrimitiveKeyOpenHashMap[Int, Array[Float]](vocabSize * 2) + var index = 0 + while(index < vocabSize) { + if (syn0Modify(index) != 0) { + synOut.update(index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)) + } + if (syn1Modify(index) != 0) { + synOut.update(index + vocabSize, + syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)) + } + index += 1 + } + Iterator(synOut) } - val (aggSyn0, aggSyn1, _, _) = - partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) => - val n = syn0_1.length - val weight1 = 1.0f * wc_1 / (wc_1 + wc_2) - val weight2 = 1.0f * wc_2 / (wc_1 + wc_2) - blas.sscal(n, weight1, syn0_1, 1) - blas.sscal(n, weight1, syn1_1, 1) - blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1) - blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1) - (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2) + val synAgg = partial.flatMap(x => x).reduceByKey { case (v1, v2) => + blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1) + v1 + }.collect() + var i = 0 + while (i < synAgg.length) { + val index = synAgg(i)._1 + if (index < vocabSize) { + Array.copy(synAgg(i)._2, 0, syn0Global, index * vectorSize, vectorSize) + } else { + Array.copy(synAgg(i)._2, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize) } - syn0Global = aggSyn0 - syn1Global = aggSyn1 + i += 1 + } } newSentences.unpersist() From eef779b8d631de971d440051cae21040f4de558f Mon Sep 17 00:00:00 2001 From: Liquan Pei Date: Sun, 17 Aug 2014 23:30:47 -0700 Subject: [PATCH 28/54] [SPARK-2842][MLlib]Word2Vec documentation mengxr Documentation for Word2Vec Author: Liquan Pei Closes #2003 from Ishiihara/Word2Vec-doc and squashes the following commits: 4ff11d4 [Liquan Pei] minor fix 8d7458f [Liquan Pei] code reformat 6df0dcb [Liquan Pei] add Word2Vec documentation --- docs/mllib-feature-extraction.md | 63 +++++++++++++++++++++++++++++++- 1 file changed, 62 insertions(+), 1 deletion(-) diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 21453cb9cd8c9..4b3cb715c58c7 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -9,4 +9,65 @@ displayTitle: MLlib - Feature Extraction ## Word2Vec -## TFIDF +Word2Vec computes distributed vector representation of words. The main advantage of the distributed +representations is that similar words are close in the vector space, which makes generalization to +novel patterns easier and model estimation more robust. Distributed vector representation is +showed to be useful in many natural language processing applications such as named entity +recognition, disambiguation, parsing, tagging and machine translation. + +### Model + +In our implementation of Word2Vec, we used skip-gram model. The training objective of skip-gram is +to learn word vector representations that are good at predicting its context in the same sentence. +Mathematically, given a sequence of training words `$w_1, w_2, \dots, w_T$`, the objective of the +skip-gram model is to maximize the average log-likelihood +`\[ +\frac{1}{T} \sum_{t = 1}^{T}\sum_{j=-k}^{j=k} \log p(w_{t+j} | w_t) +\]` +where $k$ is the size of the training window. + +In the skip-gram model, every word $w$ is associated with two vectors $u_w$ and $v_w$ which are +vector representations of $w$ as word and context respectively. The probability of correctly +predicting word $w_i$ given word $w_j$ is determined by the softmax model, which is +`\[ +p(w_i | w_j ) = \frac{\exp(u_{w_i}^{\top}v_{w_j})}{\sum_{l=1}^{V} \exp(u_l^{\top}v_{w_j})} +\]` +where $V$ is the vocabulary size. + +The skip-gram model with softmax is expensive because the cost of computing $\log p(w_i | w_j)$ +is proportional to $V$, which can be easily in order of millions. To speed up training of Word2Vec, +we used hierarchical softmax, which reduced the complexity of computing of $\log p(w_i | w_j)$ to +$O(\log(V))$ + +### Example + +The example below demonstrates how to load a text file, parse it as an RDD of `Seq[String]`, +construct a `Word2Vec` instance and then fit a `Word2VecModel` with the input data. Finally, +we display the top 40 synonyms of the specified word. To run the example, first download +the [text8](http://mattmahoney.net/dc/text8.zip) data and extract it to your preferred directory. +Here we assume the extracted file is `text8` and in same directory as you run the spark shell. + +
    +
    +{% highlight scala %} +import org.apache.spark._ +import org.apache.spark.rdd._ +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.feature.Word2Vec + +val input = sc.textFile("text8").map(line => line.split(" ").toSeq) + +val word2vec = new Word2Vec() + +val model = word2vec.fit(input) + +val synonyms = model.findSynonyms("china", 40) + +for((synonym, cosineSimilarity) <- synonyms) { + println(s"$synonym $cosineSimilarity") +} +{% endhighlight %} +
    +
    + +## TFIDF \ No newline at end of file From 9306b8c6c8c412b9d0d5cffb6bd7a87784f0f6bf Mon Sep 17 00:00:00 2001 From: Liquan Pei Date: Mon, 18 Aug 2014 01:15:45 -0700 Subject: [PATCH 29/54] [MLlib] Remove transform(dataset: RDD[String]) from Word2Vec public API mengxr Remove transform(dataset: RDD[String]) from public API. Author: Liquan Pei Closes #2010 from Ishiihara/Word2Vec-api and squashes the following commits: 17b1031 [Liquan Pei] remove transform(dataset: RDD[String]) from public API --- .../scala/org/apache/spark/mllib/feature/Word2Vec.scala | 9 --------- 1 file changed, 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index d2ae62b482aff..1dcaa2cd2e630 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -434,15 +434,6 @@ class Word2VecModel private[mllib] ( } } - /** - * Transforms an RDD to its vector representation - * @param dataset a an RDD of words - * @return RDD of vector representation - */ - def transform(dataset: RDD[String]): RDD[Vector] = { - dataset.map(word => transform(word)) - } - /** * Find synonyms of a word * @param word a word From c0cbbdeaf4f2033be03d32e3ea0288812b4edbf6 Mon Sep 17 00:00:00 2001 From: CrazyJvm Date: Mon, 18 Aug 2014 09:34:36 -0700 Subject: [PATCH 30/54] SPARK-3093 : masterLock in Worker is no longer need there's no need to use masterLock in Worker now since all communications are within Akka actor Author: CrazyJvm Closes #2008 from CrazyJvm/no-need-master-lock and squashes the following commits: dd39e20 [CrazyJvm] fix format 58e7fa5 [CrazyJvm] there's no need to use masterLock now since all communications are within Akka actor --- .../apache/spark/deploy/worker/Worker.scala | 41 +++++++------------ 1 file changed, 14 insertions(+), 27 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 80fde7e4b2624..81400af22c0bf 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -72,7 +72,6 @@ private[spark] class Worker( val APP_DATA_RETENTION_SECS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) val testing: Boolean = sys.props.contains("spark.testing") - val masterLock: Object = new Object() var master: ActorSelection = null var masterAddress: Address = null var activeMasterUrl: String = "" @@ -145,18 +144,16 @@ private[spark] class Worker( } def changeMaster(url: String, uiUrl: String) { - masterLock.synchronized { - activeMasterUrl = url - activeMasterWebUiUrl = uiUrl - master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl)) - masterAddress = activeMasterUrl match { - case Master.sparkUrlRegex(_host, _port) => - Address("akka.tcp", Master.systemName, _host, _port.toInt) - case x => - throw new SparkException("Invalid spark URL: " + x) - } - connected = true + activeMasterUrl = url + activeMasterWebUiUrl = uiUrl + master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl)) + masterAddress = activeMasterUrl match { + case Master.sparkUrlRegex(_host, _port) => + Address("akka.tcp", Master.systemName, _host, _port.toInt) + case x => + throw new SparkException("Invalid spark URL: " + x) } + connected = true } def tryRegisterAllMasters() { @@ -199,9 +196,7 @@ private[spark] class Worker( } case SendHeartbeat => - masterLock.synchronized { - if (connected) { master ! Heartbeat(workerId) } - } + if (connected) { master ! Heartbeat(workerId) } case WorkDirCleanup => // Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker actor @@ -244,9 +239,7 @@ private[spark] class Worker( manager.start() coresUsed += cores_ memoryUsed += memory_ - masterLock.synchronized { - master ! ExecutorStateChanged(appId, execId, manager.state, None, None) - } + master ! ExecutorStateChanged(appId, execId, manager.state, None, None) } catch { case e: Exception => { logError("Failed to launch executor %s/%d for %s".format(appId, execId, appDesc.name)) @@ -254,17 +247,13 @@ private[spark] class Worker( executors(appId + "/" + execId).kill() executors -= appId + "/" + execId } - masterLock.synchronized { - master ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, None, None) - } + master ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, None, None) } } } case ExecutorStateChanged(appId, execId, state, message, exitStatus) => - masterLock.synchronized { - master ! ExecutorStateChanged(appId, execId, state, message, exitStatus) - } + master ! ExecutorStateChanged(appId, execId, state, message, exitStatus) val fullId = appId + "/" + execId if (ExecutorState.isFinished(state)) { executors.get(fullId) match { @@ -330,9 +319,7 @@ private[spark] class Worker( case _ => logDebug(s"Driver $driverId changed state to $state") } - masterLock.synchronized { - master ! DriverStateChanged(driverId, state, exception) - } + master ! DriverStateChanged(driverId, state, exception) val driver = drivers.remove(driverId).get finishedDrivers(driverId) = driver memoryUsed -= driver.driverDesc.mem From f45efbb8aaa65bc46d65e77e93076fbc29f4455d Mon Sep 17 00:00:00 2001 From: Chandan Kumar Date: Mon, 18 Aug 2014 09:52:25 -0700 Subject: [PATCH 31/54] [SPARK-2862] histogram method fails on some choices of bucketCount Author: Chandan Kumar Closes #1787 from nrchandan/spark-2862 and squashes the following commits: a76bbf6 [Chandan Kumar] [SPARK-2862] Fix for a broken test case and add new test cases 4211eea [Chandan Kumar] [SPARK-2862] Add Scala bug id 13854f1 [Chandan Kumar] [SPARK-2862] Use shorthand range notation to avoid Scala bug --- .../apache/spark/rdd/DoubleRDDFunctions.scala | 15 ++++++++---- .../org/apache/spark/rdd/DoubleRDDSuite.scala | 23 +++++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index f233544d128f5..e0494ee39657c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -95,7 +95,12 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { * If the elements in RDD do not vary (max == min) always returns a single bucket. */ def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = { - // Compute the minimum and the maxium + // Scala's built-in range has issues. See #SI-8782 + def customRange(min: Double, max: Double, steps: Int): IndexedSeq[Double] = { + val span = max - min + Range.Int(0, steps, 1).map(s => min + (s * span) / steps) :+ max + } + // Compute the minimum and the maximum val (max: Double, min: Double) = self.mapPartitions { items => Iterator(items.foldRight(Double.NegativeInfinity, Double.PositiveInfinity)((e: Double, x: Pair[Double, Double]) => @@ -107,9 +112,11 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { throw new UnsupportedOperationException( "Histogram on either an empty RDD or RDD containing +/-infinity or NaN") } - val increment = (max-min)/bucketCount.toDouble - val range = if (increment != 0) { - Range.Double.inclusive(min, max, increment) + val range = if (min != max) { + // Range.Double.inclusive(min, max, increment) + // The above code doesn't always work. See Scala bug #SI-8782. + // https://issues.scala-lang.org/browse/SI-8782 + customRange(min, max, bucketCount) } else { List(min, min) } diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala index a822bd18bfdbd..f89bdb6e07dea 100644 --- a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala @@ -245,6 +245,29 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { assert(histogramBuckets === expectedHistogramBuckets) } + test("WorksWithoutBucketsForLargerDatasets") { + // Verify the case of slighly larger datasets + val rdd = sc.parallelize(6 to 99) + val (histogramBuckets, histogramResults) = rdd.histogram(8) + val expectedHistogramResults = + Array(12, 12, 11, 12, 12, 11, 12, 12) + val expectedHistogramBuckets = + Array(6.0, 17.625, 29.25, 40.875, 52.5, 64.125, 75.75, 87.375, 99.0) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets === expectedHistogramBuckets) + } + + test("WorksWithoutBucketsWithIrrationalBucketEdges") { + // Verify the case of buckets with irrational edges. See #SPARK-2862. + val rdd = sc.parallelize(6 to 99) + val (histogramBuckets, histogramResults) = rdd.histogram(9) + val expectedHistogramResults = + Array(11, 10, 11, 10, 10, 11, 10, 10, 11) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets(0) === 6.0) + assert(histogramBuckets(9) === 99.0) + } + // Test the failure mode with an invalid RDD test("ThrowsExceptionOnInvalidRDDs") { // infinity From 7ae28d1247e4756219016206c51fec1656e3917b Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 18 Aug 2014 10:00:46 -0700 Subject: [PATCH 32/54] SPARK-3096: Include parquet hive serde by default in build A small change - we should just add this dependency. It doesn't have any recursive deps and it's needed for reading have parquet tables. Author: Patrick Wendell Closes #2009 from pwendell/parquet and squashes the following commits: e411f9f [Patrick Wendell] SPARk-309: Include parquet hive serde by default in build --- sql/hive/pom.xml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 93d00f7c37c9b..30ff277e67c88 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -36,6 +36,11 @@ + + com.twitter + parquet-hive-bundle + 1.5.0 + org.apache.spark spark-core_${scala.binary.version} From 6a13dca12fac06f3af892ffcc8922cc84f91b786 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 18 Aug 2014 10:05:52 -0700 Subject: [PATCH 33/54] [SPARK-3084] [SQL] Collect broadcasted tables in parallel in joins BroadcastHashJoin has a broadcastFuture variable that tries to collect the broadcasted table in a separate thread, but this doesn't help because it's a lazy val that only gets initialized when you attempt to build the RDD. Thus queries that broadcast multiple tables would collect and broadcast them sequentially. I changed this to a val to let it start collecting right when the operator is created. Author: Matei Zaharia Closes #1990 from mateiz/spark-3084 and squashes the following commits: f468766 [Matei Zaharia] [SPARK-3084] Collect broadcasted tables in parallel in joins --- .../src/main/scala/org/apache/spark/sql/execution/joins.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index c86811e838bd8..481bb8c05e71b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -424,7 +424,7 @@ case class BroadcastHashJoin( UnspecifiedDistribution :: UnspecifiedDistribution :: Nil @transient - lazy val broadcastFuture = future { + val broadcastFuture = future { sparkContext.broadcast(buildPlan.executeCollect()) } From 4bf3de71074053af94f077c99e9c65a1962739e1 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 18 Aug 2014 10:45:24 -0700 Subject: [PATCH 34/54] [SPARK-3085] [SQL] Use compact data structures in SQL joins This reuses the CompactBuffer from Spark Core to save memory and pointer dereferences. I also tried AppendOnlyMap instead of java.util.HashMap but unfortunately that slows things down because it seems to do more equals() calls and the equals on GenericRow, and especially JoinedRow, is pretty expensive. Author: Matei Zaharia Closes #1993 from mateiz/spark-3085 and squashes the following commits: 188221e [Matei Zaharia] Remove unneeded import 5f903ee [Matei Zaharia] [SPARK-3085] [SQL] Use compact data structures in SQL joins --- .../apache/spark/sql/execution/joins.scala | 67 +++++++++---------- 1 file changed, 33 insertions(+), 34 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 481bb8c05e71b..b08f9aacc1fcb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -19,16 +19,15 @@ package org.apache.spark.sql.execution import java.util.{HashMap => JavaHashMap} -import scala.collection.mutable.{ArrayBuffer, BitSet} import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent._ import scala.concurrent.duration._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.util.collection.CompactBuffer @DeveloperApi sealed abstract class BuildSide @@ -67,7 +66,7 @@ trait HashJoin { def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = { // TODO: Use Spark's HashMap implementation. - val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]() + val hashTable = new java.util.HashMap[Row, CompactBuffer[Row]]() var currentRow: Row = null // Create a mapping of buildKeys -> rows @@ -77,7 +76,7 @@ trait HashJoin { if (!rowKey.anyNull) { val existingMatchList = hashTable.get(rowKey) val matchList = if (existingMatchList == null) { - val newMatchList = new ArrayBuffer[Row]() + val newMatchList = new CompactBuffer[Row]() hashTable.put(rowKey, newMatchList) newMatchList } else { @@ -89,7 +88,7 @@ trait HashJoin { new Iterator[Row] { private[this] var currentStreamedRow: Row = _ - private[this] var currentHashMatches: ArrayBuffer[Row] = _ + private[this] var currentHashMatches: CompactBuffer[Row] = _ private[this] var currentMatchPosition: Int = -1 // Mutable per row objects. @@ -140,7 +139,7 @@ trait HashJoin { /** * :: DeveloperApi :: - * Performs a hash based outer join for two child relations by shuffling the data using + * Performs a hash based outer join for two child relations by shuffling the data using * the join keys. This operator requires loading the associated partition in both side into memory. */ @DeveloperApi @@ -179,26 +178,26 @@ case class HashOuterJoin( @transient private[this] lazy val EMPTY_LIST = Seq.empty[Row] // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala - // iterator for performance purpose. + // iterator for performance purpose. private[this] def leftOuterIterator( key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { val joinedRow = new JoinedRow() val rightNullRow = new GenericRow(right.output.length) - val boundCondition = + val boundCondition = condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - leftIter.iterator.flatMap { l => + leftIter.iterator.flatMap { l => joinedRow.withLeft(l) var matched = false - (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) => + (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) => matched = true joinedRow.copy } else { Nil }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all of the + // as we don't know whether we need to append it until finish iterating all of the // records in right side. // If we didn't get any proper row, then append a single row with empty right joinedRow.withRight(rightNullRow).copy @@ -210,20 +209,20 @@ case class HashOuterJoin( key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { val joinedRow = new JoinedRow() val leftNullRow = new GenericRow(left.output.length) - val boundCondition = + val boundCondition = condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - rightIter.iterator.flatMap { r => + rightIter.iterator.flatMap { r => joinedRow.withRight(r) var matched = false - (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) => + (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) => matched = true joinedRow.copy } else { Nil }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all of the + // as we don't know whether we need to append it until finish iterating all of the // records in left side. // If we didn't get any proper row, then append a single row with empty left. joinedRow.withLeft(leftNullRow).copy @@ -236,7 +235,7 @@ case class HashOuterJoin( val joinedRow = new JoinedRow() val leftNullRow = new GenericRow(left.output.length) val rightNullRow = new GenericRow(right.output.length) - val boundCondition = + val boundCondition = condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) if (!key.anyNull) { @@ -246,8 +245,8 @@ case class HashOuterJoin( leftIter.iterator.flatMap[Row] { l => joinedRow.withLeft(l) var matched = false - rightIter.zipWithIndex.collect { - // 1. For those matched (satisfy the join condition) records with both sides filled, + rightIter.zipWithIndex.collect { + // 1. For those matched (satisfy the join condition) records with both sides filled, // append them directly case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> { @@ -260,7 +259,7 @@ case class HashOuterJoin( // 2. For those unmatched records in left, append additional records with empty right. // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all + // as we don't know whether we need to append it until finish iterating all // of the records in right side. // If we didn't get any proper row, then append a single row with empty right. joinedRow.withRight(rightNullRow).copy @@ -268,8 +267,8 @@ case class HashOuterJoin( } ++ rightIter.zipWithIndex.collect { // 3. For those unmatched records in right, append additional records with empty left. - // Re-visiting the records in right, and append additional row with empty left, if its not - // in the matched set. + // Re-visiting the records in right, and append additional row with empty left, if its not + // in the matched set. case (r, idx) if (!rightMatchedSet.contains(idx)) => { joinedRow(leftNullRow, r).copy } @@ -284,15 +283,15 @@ case class HashOuterJoin( } private[this] def buildHashTable( - iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, ArrayBuffer[Row]] = { - val hashTable = new JavaHashMap[Row, ArrayBuffer[Row]]() + iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, CompactBuffer[Row]] = { + val hashTable = new JavaHashMap[Row, CompactBuffer[Row]]() while (iter.hasNext) { val currentRow = iter.next() val rowKey = keyGenerator(currentRow) var existingMatchList = hashTable.get(rowKey) if (existingMatchList == null) { - existingMatchList = new ArrayBuffer[Row]() + existingMatchList = new CompactBuffer[Row]() hashTable.put(rowKey, existingMatchList) } @@ -311,20 +310,20 @@ case class HashOuterJoin( val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) import scala.collection.JavaConversions._ - val boundCondition = + val boundCondition = condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) joinType match { case LeftOuter => leftHashTable.keysIterator.flatMap { key => - leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), + leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), rightHashTable.getOrElse(key, EMPTY_LIST)) } case RightOuter => rightHashTable.keysIterator.flatMap { key => - rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), + rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), rightHashTable.getOrElse(key, EMPTY_LIST)) } case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => - fullOuterIterator(key, - leftHashTable.getOrElse(key, EMPTY_LIST), + fullOuterIterator(key, + leftHashTable.getOrElse(key, EMPTY_LIST), rightHashTable.getOrElse(key, EMPTY_LIST)) } case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") @@ -550,7 +549,7 @@ case class BroadcastNestedLoopJoin( /** All rows that either match both-way, or rows from streamed joined with nulls. */ val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => - val matchedRows = new ArrayBuffer[Row] + val matchedRows = new CompactBuffer[Row] // TODO: Use Spark's BitSet. val includedBroadcastTuples = new scala.collection.mutable.BitSet(broadcastedRelation.value.size) @@ -602,20 +601,20 @@ case class BroadcastNestedLoopJoin( val rightNulls = new GenericMutableRow(right.output.size) /** Rows from broadcasted joined with nulls. */ val broadcastRowsWithNulls: Seq[Row] = { - val arrBuf: collection.mutable.ArrayBuffer[Row] = collection.mutable.ArrayBuffer() + val buf: CompactBuffer[Row] = new CompactBuffer() var i = 0 val rel = broadcastedRelation.value while (i < rel.length) { if (!allIncludedBroadcastTuples.contains(i)) { (joinType, buildSide) match { - case (RightOuter | FullOuter, BuildRight) => arrBuf += new JoinedRow(leftNulls, rel(i)) - case (LeftOuter | FullOuter, BuildLeft) => arrBuf += new JoinedRow(rel(i), rightNulls) + case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i)) + case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls) case _ => } } i += 1 } - arrBuf.toSeq + buf.toSeq } // TODO: Breaks lineage. From 6bca8898a1aa4ca7161492229bac1748b3da2ad7 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 18 Aug 2014 10:52:20 -0700 Subject: [PATCH 35/54] SPARK-3025 [SQL]: Allow JDBC clients to set a fair scheduler pool This definitely needs review as I am not familiar with this part of Spark. I tested this locally and it did seem to work. Author: Patrick Wendell Closes #1937 from pwendell/scheduler and squashes the following commits: b858e33 [Patrick Wendell] SPARK-3025: Allow JDBC clients to set a fair scheduler pool --- docs/sql-programming-guide.md | 5 ++++ .../scala/org/apache/spark/sql/SQLConf.scala | 3 +++ .../server/SparkSQLOperationManager.scala | 27 ++++++++++++++----- 3 files changed, 28 insertions(+), 7 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index cd6543945c385..34accade36ea9 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -605,6 +605,11 @@ Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. You may also use the beeline script comes with Hive. +To set a [Fair Scheduler](job-scheduling.html#fair-scheduler-pools) pool for a JDBC client session, +users can set the `spark.sql.thriftserver.scheduler.pool` variable: + + SET spark.sql.thriftserver.scheduler.pool=accounting; + ### Migration Guide for Shark Users #### Reducer number diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 90de11182e605..56face2992bcf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -33,6 +33,9 @@ private[spark] object SQLConf { val DIALECT = "spark.sql.dialect" val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString" + // This is only used for the thriftserver + val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool" + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index 9338e8121b0fe..699a1103f3248 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -17,24 +17,24 @@ package org.apache.spark.sql.hive.thriftserver.server -import scala.collection.JavaConversions._ -import scala.collection.mutable.ArrayBuffer -import scala.math.{random, round} - import java.sql.Timestamp import java.util.{Map => JMap} +import scala.collection.JavaConversions._ +import scala.collection.mutable.{ArrayBuffer, Map} +import scala.math.{random, round} + import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hive.service.cli._ import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager} import org.apache.hive.service.cli.session.HiveSession - import org.apache.spark.Logging +import org.apache.spark.sql.{Row => SparkRow, SQLConf, SchemaRDD} +import org.apache.spark.sql.catalyst.plans.logical.SetCommand import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.hive.thriftserver.ReflectionUtils import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -import org.apache.spark.sql.{SchemaRDD, Row => SparkRow} +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils /** * Executes queries using Spark SQL, and maintains a list of handles to active queries. @@ -43,6 +43,9 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage val handleToOperation = ReflectionUtils .getSuperField[JMap[OperationHandle, Operation]](this, "handleToOperation") + // TODO: Currenlty this will grow infinitely, even as sessions expire + val sessionToActivePool = Map[HiveSession, String]() + override def newExecuteStatementOperation( parentSession: HiveSession, statement: String, @@ -165,8 +168,18 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage try { result = hiveContext.sql(statement) logDebug(result.queryExecution.toString()) + result.queryExecution.logical match { + case SetCommand(Some(key), Some(value)) if (key == SQLConf.THRIFTSERVER_POOL) => + sessionToActivePool(parentSession) = value + logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") + case _ => + } + val groupId = round(random * 1000000).toString hiveContext.sparkContext.setJobGroup(groupId, statement) + sessionToActivePool.get(parentSession).foreach { pool => + hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) + } iter = { val resultRdd = result.queryExecution.toRdd val useIncrementalCollect = From 9eb74c7d2cbe127dd4c32bf1a8318497b2fb55b6 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 18 Aug 2014 11:00:10 -0700 Subject: [PATCH 36/54] [SPARK-3091] [SQL] Add support for caching metadata on Parquet files For larger Parquet files, reading the file footers (which is done in parallel on up to 5 threads) and HDFS block locations (which is serial) can take multiple seconds. We can add an option to cache this data within FilteringParquetInputFormat. Unfortunately ParquetInputFormat only caches footers within each instance of ParquetInputFormat, not across them. Note: this PR leaves this turned off by default for 1.1, but I believe it's safe to turn it on after. The keys in the hash maps are FileStatus objects that include a modification time, so this will work fine if files are modified. The location cache could become invalid if files have moved within HDFS, but that's rare so I just made it invalidate entries every 15 minutes. Author: Matei Zaharia Closes #2005 from mateiz/parquet-cache and squashes the following commits: dae8efe [Matei Zaharia] Bug fix c71e9ed [Matei Zaharia] Handle empty statuses directly 22072b0 [Matei Zaharia] Use Guava caches and add a config option for caching metadata 8fb56ce [Matei Zaharia] Cache file block locations too 453bd21 [Matei Zaharia] Bug fix 4094df6 [Matei Zaharia] First attempt at caching Parquet footers --- .../scala/org/apache/spark/sql/SQLConf.scala | 1 + .../sql/parquet/ParquetTableOperations.scala | 84 ++++++++++++++++--- 2 files changed, 72 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 56face2992bcf..4f2adb006fbc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -32,6 +32,7 @@ private[spark] object SQLConf { val CODEGEN_ENABLED = "spark.sql.codegen" val DIALECT = "spark.sql.dialect" val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString" + val PARQUET_CACHE_METADATA = "spark.sql.parquet.cacheMetadata" // This is only used for the thriftserver val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 759a2a586b926..c6dca10f6ad7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -17,22 +17,23 @@ package org.apache.spark.sql.parquet -import scala.collection.JavaConversions._ -import scala.collection.mutable -import scala.util.Try - import java.io.IOException import java.lang.{Long => JLong} import java.text.SimpleDateFormat -import java.util.{Date, List => JList} +import java.util.concurrent.{Callable, TimeUnit} +import java.util.{ArrayList, Collections, Date, List => JList} +import scala.collection.JavaConversions._ +import scala.collection.mutable +import scala.util.Try + +import com.google.common.cache.CacheBuilder import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path} import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter - import parquet.hadoop._ import parquet.hadoop.api.{InitContext, ReadSupport} import parquet.hadoop.metadata.GlobalMetaData @@ -41,7 +42,7 @@ import parquet.io.ParquetDecodingException import parquet.schema.MessageType import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row} import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} import org.apache.spark.{Logging, SerializableWritable, TaskContext} @@ -96,6 +97,11 @@ case class ParquetTableScan( ParquetFilters.serializeFilterExpressions(columnPruningPred, conf) } + // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata + conf.set( + SQLConf.PARQUET_CACHE_METADATA, + sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "false")) + sc.newAPIHadoopRDD( conf, classOf[FilteringParquetRowInputFormat], @@ -323,10 +329,40 @@ private[parquet] class FilteringParquetRowInputFormat } override def getFooters(jobContext: JobContext): JList[Footer] = { + import FilteringParquetRowInputFormat.footerCache + if (footers eq null) { + val conf = ContextUtil.getConfiguration(jobContext) + val cacheMetadata = conf.getBoolean(SQLConf.PARQUET_CACHE_METADATA, false) val statuses = listStatus(jobContext) fileStatuses = statuses.map(file => file.getPath -> file).toMap - footers = getFooters(ContextUtil.getConfiguration(jobContext), statuses) + if (statuses.isEmpty) { + footers = Collections.emptyList[Footer] + } else if (!cacheMetadata) { + // Read the footers from HDFS + footers = getFooters(conf, statuses) + } else { + // Read only the footers that are not in the footerCache + val foundFooters = footerCache.getAllPresent(statuses) + val toFetch = new ArrayList[FileStatus] + for (s <- statuses) { + if (!foundFooters.containsKey(s)) { + toFetch.add(s) + } + } + val newFooters = new mutable.HashMap[FileStatus, Footer] + if (toFetch.size > 0) { + val fetched = getFooters(conf, toFetch) + for ((status, i) <- toFetch.zipWithIndex) { + newFooters(status) = fetched.get(i) + } + footerCache.putAll(newFooters) + } + footers = new ArrayList[Footer](statuses.size) + for (status <- statuses) { + footers.add(newFooters.getOrElse(status, foundFooters.get(status))) + } + } } footers @@ -339,6 +375,10 @@ private[parquet] class FilteringParquetRowInputFormat configuration: Configuration, footers: JList[Footer]): JList[ParquetInputSplit] = { + import FilteringParquetRowInputFormat.blockLocationCache + + val cacheMetadata = configuration.getBoolean(SQLConf.PARQUET_CACHE_METADATA, false) + val maxSplitSize: JLong = configuration.getLong("mapred.max.split.size", Long.MaxValue) val minSplitSize: JLong = Math.max(getFormatMinSplitSize(), configuration.getLong("mapred.min.split.size", 0L)) @@ -366,16 +406,23 @@ private[parquet] class FilteringParquetRowInputFormat for (footer <- footers) { val fs = footer.getFile.getFileSystem(configuration) val file = footer.getFile - val fileStatus = fileStatuses.getOrElse(file, fs.getFileStatus(file)) + val status = fileStatuses.getOrElse(file, fs.getFileStatus(file)) val parquetMetaData = footer.getParquetMetadata val blocks = parquetMetaData.getBlocks - val fileBlockLocations = fs.getFileBlockLocations(fileStatus, 0, fileStatus.getLen) + var blockLocations: Array[BlockLocation] = null + if (!cacheMetadata) { + blockLocations = fs.getFileBlockLocations(status, 0, status.getLen) + } else { + blockLocations = blockLocationCache.get(status, new Callable[Array[BlockLocation]] { + def call(): Array[BlockLocation] = fs.getFileBlockLocations(status, 0, status.getLen) + }) + } splits.addAll( generateSplits.invoke( null, blocks, - fileBlockLocations, - fileStatus, + blockLocations, + status, parquetMetaData.getFileMetaData, readContext.getRequestedSchema.toString, readContext.getReadSupportMetadata, @@ -387,6 +434,17 @@ private[parquet] class FilteringParquetRowInputFormat } } +private[parquet] object FilteringParquetRowInputFormat { + private val footerCache = CacheBuilder.newBuilder() + .maximumSize(20000) + .build[FileStatus, Footer]() + + private val blockLocationCache = CacheBuilder.newBuilder() + .maximumSize(20000) + .expireAfterWrite(15, TimeUnit.MINUTES) // Expire locations since HDFS files might move + .build[FileStatus, Array[BlockLocation]]() +} + private[parquet] object FileSystemHelper { def listFiles(pathStr: String, conf: Configuration): Seq[Path] = { val origPath = new Path(pathStr) From 3abd0c1cda09bb575adc99847a619bc84af37fd0 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 18 Aug 2014 13:17:10 -0700 Subject: [PATCH 37/54] [SPARK-2406][SQL] Initial support for using ParquetTableScan to read HiveMetaStore tables. This PR adds an experimental flag `spark.sql.hive.convertMetastoreParquet` that when true causes the planner to detects tables that use Hive's Parquet SerDe and instead plans them using Spark SQL's native `ParquetTableScan`. Author: Michael Armbrust Author: Yin Huai Closes #1819 from marmbrus/parquetMetastore and squashes the following commits: 1620079 [Michael Armbrust] Revert "remove hive parquet bundle" cc30430 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into parquetMetastore 4f3d54f [Michael Armbrust] fix style 41ebc5f [Michael Armbrust] remove hive parquet bundle a43e0da [Michael Armbrust] Merge remote-tracking branch 'origin/master' into parquetMetastore 4c4dc19 [Michael Armbrust] Fix bug with tree splicing. ebb267e [Michael Armbrust] include parquet hive to tests pass (Remove this later). c0d9b72 [Michael Armbrust] Avoid creating a HadoopRDD per partition. Add dirty hacks to retrieve partition values from the InputSplit. 8cdc93c [Michael Armbrust] Merge pull request #8 from yhuai/parquetMetastore a0baec7 [Yin Huai] Partitioning columns can be resolved. 1161338 [Michael Armbrust] Add a test to make sure conversion is actually happening 212d5cd [Michael Armbrust] Initial support for using ParquetTableScan to read HiveMetaStore tables. --- project/SparkBuild.scala | 1 - .../spark/sql/execution/basicOperators.scala | 12 ++ .../spark/sql/parquet/ParquetRelation.scala | 8 +- .../sql/parquet/ParquetTableOperations.scala | 74 ++++++-- .../apache/spark/sql/hive/HiveContext.scala | 9 + .../spark/sql/hive/HiveStrategies.scala | 119 +++++++++++- .../sql/hive/parquet/FakeParquetSerDe.scala | 56 ++++++ .../sql/parquet/ParquetMetastoreSuite.scala | 171 ++++++++++++++++++ 8 files changed, 427 insertions(+), 23 deletions(-) create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 63a285b81a60c..49d52aefca17a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -228,7 +228,6 @@ object SQL { object Hive { lazy val settings = Seq( - javaOptions += "-XX:MaxPermSize=1g", // Multiple queries rely on the TestHive singleton. See comments there for more details. parallelExecution in Test := false, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 0027f3cf1fc79..f9dfa3c92f1eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -303,3 +303,15 @@ case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode { left.execute().map(_.copy()).intersection(right.execute().map(_.copy())) } } + +/** + * :: DeveloperApi :: + * A plan node that does nothing but lie about the output of its child. Used to spice a + * (hopefully structurally equivalent) tree from a different optimization sequence into an already + * resolved tree. + */ +@DeveloperApi +case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPlan { + def children = child :: Nil + def execute() = child.execute() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 053b2a154389c..1713ae6fb5d93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -47,7 +47,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode} private[sql] case class ParquetRelation( path: String, @transient conf: Option[Configuration], - @transient sqlContext: SQLContext) + @transient sqlContext: SQLContext, + partitioningAttributes: Seq[Attribute] = Nil) extends LeafNode with MultiInstanceRelation { self: Product => @@ -61,12 +62,13 @@ private[sql] case class ParquetRelation( /** Attributes */ override val output = + partitioningAttributes ++ ParquetTypesConverter.readSchemaFromFile( - new Path(path), + new Path(path.split(",").head), conf, sqlContext.isParquetBinaryAsString) - override def newInstance = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type] + override def newInstance() = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type] // Equals must also take into account the output attributes so that we can distinguish between // different instances of the same relation, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index c6dca10f6ad7c..f6cfab736d98a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -34,6 +34,7 @@ import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter + import parquet.hadoop._ import parquet.hadoop.api.{InitContext, ReadSupport} import parquet.hadoop.metadata.GlobalMetaData @@ -42,6 +43,7 @@ import parquet.io.ParquetDecodingException import parquet.schema.MessageType import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row} import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} @@ -60,11 +62,18 @@ case class ParquetTableScan( // The resolution of Parquet attributes is case sensitive, so we resolve the original attributes // by exprId. note: output cannot be transient, see // https://issues.apache.org/jira/browse/SPARK-1367 - val output = attributes.map { a => - relation.output - .find(o => o.exprId == a.exprId) - .getOrElse(sys.error(s"Invalid parquet attribute $a in ${relation.output.mkString(",")}")) - } + val normalOutput = + attributes + .filterNot(a => relation.partitioningAttributes.map(_.exprId).contains(a.exprId)) + .flatMap(a => relation.output.find(o => o.exprId == a.exprId)) + + val partOutput = + attributes.flatMap(a => relation.partitioningAttributes.find(o => o.exprId == a.exprId)) + + def output = partOutput ++ normalOutput + + assert(normalOutput.size + partOutput.size == attributes.size, + s"$normalOutput + $partOutput != $attributes, ${relation.output}") override def execute(): RDD[Row] = { val sc = sqlContext.sparkContext @@ -72,16 +81,19 @@ case class ParquetTableScan( ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport]) val conf: Configuration = ContextUtil.getConfiguration(job) - val qualifiedPath = { - val path = new Path(relation.path) - path.getFileSystem(conf).makeQualified(path) + + relation.path.split(",").foreach { curPath => + val qualifiedPath = { + val path = new Path(curPath) + path.getFileSystem(conf).makeQualified(path) + } + NewFileInputFormat.addInputPath(job, qualifiedPath) } - NewFileInputFormat.addInputPath(job, qualifiedPath) // Store both requested and original schema in `Configuration` conf.set( RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, - ParquetTypesConverter.convertToString(output)) + ParquetTypesConverter.convertToString(normalOutput)) conf.set( RowWriteSupport.SPARK_ROW_SCHEMA, ParquetTypesConverter.convertToString(relation.output)) @@ -102,13 +114,41 @@ case class ParquetTableScan( SQLConf.PARQUET_CACHE_METADATA, sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "false")) - sc.newAPIHadoopRDD( - conf, - classOf[FilteringParquetRowInputFormat], - classOf[Void], - classOf[Row]) - .map(_._2) - .filter(_ != null) // Parquet's record filters may produce null values + val baseRDD = + new org.apache.spark.rdd.NewHadoopRDD( + sc, + classOf[FilteringParquetRowInputFormat], + classOf[Void], + classOf[Row], + conf) + + if (partOutput.nonEmpty) { + baseRDD.mapPartitionsWithInputSplit { case (split, iter) => + val partValue = "([^=]+)=([^=]+)".r + val partValues = + split.asInstanceOf[parquet.hadoop.ParquetInputSplit] + .getPath + .toString + .split("/") + .flatMap { + case partValue(key, value) => Some(key -> value) + case _ => None + }.toMap + + val partitionRowValues = + partOutput.map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow)) + + new Iterator[Row] { + private[this] val joinedRow = new JoinedRow(Row(partitionRowValues:_*), null) + + def hasNext = iter.hasNext + + def next() = joinedRow.withRight(iter.next()._2) + } + } + } else { + baseRDD.map(_._2) + }.filter(_ != null) // Parquet's record filters may produce null values } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index a8da676ffa0e0..ff32c7c90a0d2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -79,6 +79,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { // Change the default SQL dialect to HiveQL override private[spark] def dialect: String = getConf(SQLConf.DIALECT, "hiveql") + /** + * When true, enables an experimental feature where metastore tables that use the parquet SerDe + * are automatically converted to use the Spark SQL parquet table scan, instead of the Hive + * SerDe. + */ + private[spark] def convertMetastoreParquet: Boolean = + getConf("spark.sql.hive.convertMetastoreParquet", "false") == "true" + override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } @@ -326,6 +334,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { TakeOrdered, ParquetOperations, InMemoryScans, + ParquetConversion, // Must be before HiveTableScans HiveTableScans, DataSinks, Scripts, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 5fcc1bd4b9adf..389ace726d205 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -17,14 +17,20 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.SQLContext +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LowerCaseSchema} import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.columnar.InMemoryRelation +import org.apache.spark.sql.parquet.{ParquetRelation, ParquetTableScan} + +import scala.collection.JavaConversions._ private[hive] trait HiveStrategies { // Possibly being too clever with types here... or not clever enough. @@ -32,6 +38,115 @@ private[hive] trait HiveStrategies { val hiveContext: HiveContext + /** + * :: Experimental :: + * Finds table scans that would use the Hive SerDe and replaces them with our own native parquet + * table scan operator. + * + * TODO: Much of this logic is duplicated in HiveTableScan. Ideally we would do some refactoring + * but since this is after the code freeze for 1.1 all logic is here to minimize disruption. + * + * Other issues: + * - Much of this logic assumes case insensitive resolution. + */ + @Experimental + object ParquetConversion extends Strategy { + implicit class LogicalPlanHacks(s: SchemaRDD) { + def lowerCase = + new SchemaRDD(s.sqlContext, LowerCaseSchema(s.logicalPlan)) + + def addPartitioningAttributes(attrs: Seq[Attribute]) = + new SchemaRDD( + s.sqlContext, + s.logicalPlan transform { + case p: ParquetRelation => p.copy(partitioningAttributes = attrs) + }) + } + + implicit class PhysicalPlanHacks(originalPlan: SparkPlan) { + def fakeOutput(newOutput: Seq[Attribute]) = + OutputFaker( + originalPlan.output.map(a => + newOutput.find(a.name.toLowerCase == _.name.toLowerCase) + .getOrElse( + sys.error(s"Can't find attribute $a to fake in set ${newOutput.mkString(",")}"))), + originalPlan) + } + + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalOperation(projectList, predicates, relation: MetastoreRelation) + if relation.tableDesc.getSerdeClassName.contains("Parquet") && + hiveContext.convertMetastoreParquet => + + // Filter out all predicates that only deal with partition keys + val partitionKeyIds = relation.partitionKeys.map(_.exprId).toSet + val (pruningPredicates, otherPredicates) = predicates.partition { + _.references.map(_.exprId).subsetOf(partitionKeyIds) + } + + // We are going to throw the predicates and projection back at the whole optimization + // sequence so lets unresolve all the attributes, allowing them to be rebound to the + // matching parquet attributes. + val unresolvedOtherPredicates = otherPredicates.map(_ transform { + case a: AttributeReference => UnresolvedAttribute(a.name) + }).reduceOption(And).getOrElse(Literal(true)) + + val unresolvedProjection = projectList.map(_ transform { + case a: AttributeReference => UnresolvedAttribute(a.name) + }) + + if (relation.hiveQlTable.isPartitioned) { + val rawPredicate = pruningPredicates.reduceOption(And).getOrElse(Literal(true)) + // Translate the predicate so that it automatically casts the input values to the correct + // data types during evaluation + val castedPredicate = rawPredicate transform { + case a: AttributeReference => + val idx = relation.partitionKeys.indexWhere(a.exprId == _.exprId) + val key = relation.partitionKeys(idx) + Cast(BoundReference(idx, StringType, nullable = true), key.dataType) + } + + val inputData = new GenericMutableRow(relation.partitionKeys.size) + val pruningCondition = + if(codegenEnabled) { + GeneratePredicate(castedPredicate) + } else { + InterpretedPredicate(castedPredicate) + } + + val partitions = relation.hiveQlPartitions.filter { part => + val partitionValues = part.getValues + var i = 0 + while (i < partitionValues.size()) { + inputData(i) = partitionValues(i) + i += 1 + } + pruningCondition(inputData) + } + + hiveContext + .parquetFile(partitions.map(_.getLocation).mkString(",")) + .addPartitioningAttributes(relation.partitionKeys) + .lowerCase + .where(unresolvedOtherPredicates) + .select(unresolvedProjection:_*) + .queryExecution + .executedPlan + .fakeOutput(projectList.map(_.toAttribute)):: Nil + } else { + hiveContext + .parquetFile(relation.hiveQlTable.getDataLocation.getPath) + .lowerCase + .where(unresolvedOtherPredicates) + .select(unresolvedProjection:_*) + .queryExecution + .executedPlan + .fakeOutput(projectList.map(_.toAttribute)) :: Nil + } + case _ => Nil + } + } + object Scripts extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.ScriptTransformation(input, script, output, child) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala new file mode 100644 index 0000000000000..544abfc32423c --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.parquet + +import java.util.Properties + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category +import org.apache.hadoop.hive.serde2.{SerDeStats, SerDe} +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector +import org.apache.hadoop.io.Writable + +/** + * A placeholder that allows SparkSQL users to create metastore tables that are stored as + * parquet files. It is only intended to pass the checks that the serde is valid and exists + * when a CREATE TABLE is run. The actual work of decoding will be done by ParquetTableScan + * when "spark.sql.hive.convertMetastoreParquet" is set to true. + */ +@deprecated("No code should depend on FakeParquetHiveSerDe as it is only intended as a " + + "placeholder in the Hive MetaStore") +class FakeParquetSerDe extends SerDe { + override def getObjectInspector: ObjectInspector = new ObjectInspector { + override def getCategory: Category = Category.PRIMITIVE + + override def getTypeName: String = "string" + } + + override def deserialize(p1: Writable): AnyRef = throwError + + override def initialize(p1: Configuration, p2: Properties): Unit = {} + + override def getSerializedClass: Class[_ <: Writable] = throwError + + override def getSerDeStats: SerDeStats = throwError + + override def serialize(p1: scala.Any, p2: ObjectInspector): Writable = throwError + + private def throwError = + sys.error( + "spark.sql.hive.convertMetastoreParquet must be set to true to use FakeParquetSerDe") +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala new file mode 100644 index 0000000000000..0723be7298e15 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala @@ -0,0 +1,171 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import java.io.File + +import org.apache.spark.sql.hive.execution.HiveTableScan +import org.scalatest.BeforeAndAfterAll + +import scala.reflect.ClassTag + +import org.apache.spark.sql.{SQLConf, QueryTest} +import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin} +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ + +case class ParquetData(intField: Int, stringField: String) + +/** + * Tests for our SerDe -> Native parquet scan conversion. + */ +class ParquetMetastoreSuite extends QueryTest with BeforeAndAfterAll { + + override def beforeAll(): Unit = { + setConf("spark.sql.hive.convertMetastoreParquet", "true") + } + + override def afterAll(): Unit = { + setConf("spark.sql.hive.convertMetastoreParquet", "false") + } + + val partitionedTableDir = File.createTempFile("parquettests", "sparksql") + partitionedTableDir.delete() + partitionedTableDir.mkdir() + + (1 to 10).foreach { p => + val partDir = new File(partitionedTableDir, s"p=$p") + sparkContext.makeRDD(1 to 10) + .map(i => ParquetData(i, s"part-$p")) + .saveAsParquetFile(partDir.getCanonicalPath) + } + + sql(s""" + create external table partitioned_parquet + ( + intField INT, + stringField STRING + ) + PARTITIONED BY (p int) + ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + STORED AS + INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + location '${partitionedTableDir.getCanonicalPath}' + """) + + sql(s""" + create external table normal_parquet + ( + intField INT, + stringField STRING + ) + ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + STORED AS + INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + location '${new File(partitionedTableDir, "p=1").getCanonicalPath}' + """) + + (1 to 10).foreach { p => + sql(s"ALTER TABLE partitioned_parquet ADD PARTITION (p=$p)") + } + + test("project the partitioning column") { + checkAnswer( + sql("SELECT p, count(*) FROM partitioned_parquet group by p"), + (1, 10) :: + (2, 10) :: + (3, 10) :: + (4, 10) :: + (5, 10) :: + (6, 10) :: + (7, 10) :: + (8, 10) :: + (9, 10) :: + (10, 10) :: Nil + ) + } + + test("project partitioning and non-partitioning columns") { + checkAnswer( + sql("SELECT stringField, p, count(intField) " + + "FROM partitioned_parquet GROUP BY p, stringField"), + ("part-1", 1, 10) :: + ("part-2", 2, 10) :: + ("part-3", 3, 10) :: + ("part-4", 4, 10) :: + ("part-5", 5, 10) :: + ("part-6", 6, 10) :: + ("part-7", 7, 10) :: + ("part-8", 8, 10) :: + ("part-9", 9, 10) :: + ("part-10", 10, 10) :: Nil + ) + } + + test("simple count") { + checkAnswer( + sql("SELECT COUNT(*) FROM partitioned_parquet"), + 100) + } + + test("pruned count") { + checkAnswer( + sql("SELECT COUNT(*) FROM partitioned_parquet WHERE p = 1"), + 10) + } + + test("multi-partition pruned count") { + checkAnswer( + sql("SELECT COUNT(*) FROM partitioned_parquet WHERE p IN (1,2,3)"), + 30) + } + + test("non-partition predicates") { + checkAnswer( + sql("SELECT COUNT(*) FROM partitioned_parquet WHERE intField IN (1,2,3)"), + 30) + } + + test("sum") { + checkAnswer( + sql("SELECT SUM(intField) FROM partitioned_parquet WHERE intField IN (1,2,3) AND p = 1"), + 1 + 2 + 3 + ) + } + + test("non-part select(*)") { + checkAnswer( + sql("SELECT COUNT(*) FROM normal_parquet"), + 10 + ) + } + + test("conversion is working") { + assert( + sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { + case _: HiveTableScan => true + }.isEmpty) + assert( + sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { + case _: ParquetTableScan => true + }.nonEmpty) + } +} From 66ade00f91a9343ac9277c5a7c09314087a4831e Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 18 Aug 2014 13:25:30 -0700 Subject: [PATCH 38/54] [SPARK-2169] Don't copy appName / basePath everywhere. Instead of keeping copies in all pages, just reference the values kept in the base SparkUI instance (by making them available via getters). Author: Marcelo Vanzin Closes #1252 from vanzin/SPARK-2169 and squashes the following commits: 4412fc6 [Marcelo Vanzin] Simplify UIUtils.headerSparkPage signature. 4e5d35a [Marcelo Vanzin] [SPARK-2169] Don't copy appName / basePath everywhere. --- .../apache/spark/deploy/master/Master.scala | 2 +- .../scala/org/apache/spark/ui/SparkUI.scala | 9 +++++++++ .../scala/org/apache/spark/ui/UIUtils.scala | 12 +++++------- .../scala/org/apache/spark/ui/WebUI.scala | 3 +++ .../apache/spark/ui/env/EnvironmentPage.scala | 4 +--- .../apache/spark/ui/env/EnvironmentTab.scala | 4 +--- .../apache/spark/ui/exec/ExecutorsPage.scala | 5 +---- .../apache/spark/ui/exec/ExecutorsTab.scala | 6 ++---- .../spark/ui/jobs/JobProgressPage.scala | 4 +--- .../apache/spark/ui/jobs/JobProgressTab.scala | 7 +++---- .../org/apache/spark/ui/jobs/PoolPage.scala | 5 +---- .../org/apache/spark/ui/jobs/PoolTable.scala | 7 +++---- .../org/apache/spark/ui/jobs/StagePage.scala | 8 ++------ .../org/apache/spark/ui/jobs/StageTable.scala | 19 ++++++++++--------- .../org/apache/spark/ui/storage/RDDPage.scala | 8 ++------ .../apache/spark/ui/storage/StoragePage.scala | 6 ++---- .../apache/spark/ui/storage/StorageTab.scala | 4 +--- .../spark/streaming/ui/StreamingPage.scala | 3 +-- .../spark/streaming/ui/StreamingTab.scala | 6 ++---- 19 files changed, 51 insertions(+), 71 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index cfa2c028a807b..5017273e87c07 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -697,7 +697,7 @@ private[spark] class Master( appIdToUI(app.id) = ui webUi.attachSparkUI(ui) // Application UI is successfully rebuilt, so link the Master UI to it - app.desc.appUiUrl = ui.basePath + app.desc.appUiUrl = ui.getBasePath true } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 6c788a37dc70b..cccd59d122a92 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -76,6 +76,8 @@ private[spark] class SparkUI( } } + def getAppName = appName + /** Set the app name for this UI. */ def setAppName(name: String) { appName = name @@ -100,6 +102,13 @@ private[spark] class SparkUI( private[spark] def appUIAddress = s"http://$appUIHostPort" } +private[spark] abstract class SparkUITab(parent: SparkUI, prefix: String) + extends WebUITab(parent, prefix) { + + def appName: String = parent.getAppName + +} + private[spark] object SparkUI { val DEFAULT_PORT = 4040 val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static" diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 715cc2f4df8dd..bee6dad3387e5 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -163,17 +163,15 @@ private[spark] object UIUtils extends Logging { /** Returns a spark page with correctly formatted headers */ def headerSparkPage( - content: => Seq[Node], - basePath: String, - appName: String, title: String, - tabs: Seq[WebUITab], - activeTab: WebUITab, + content: => Seq[Node], + activeTab: SparkUITab, refreshInterval: Option[Int] = None): Seq[Node] = { - val header = tabs.map { tab => + val appName = activeTab.appName + val header = activeTab.headerTabs.map { tab =>
  • - {tab.name} + {tab.name}
  • } diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 5f52f95088007..5d88ca403a674 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -50,6 +50,7 @@ private[spark] abstract class WebUI( protected val publicHostName = Option(System.getenv("SPARK_PUBLIC_DNS")).getOrElse(localHostName) private val className = Utils.getFormattedClassName(this) + def getBasePath: String = basePath def getTabs: Seq[WebUITab] = tabs.toSeq def getHandlers: Seq[ServletContextHandler] = handlers.toSeq def getSecurityManager: SecurityManager = securityManager @@ -135,6 +136,8 @@ private[spark] abstract class WebUITab(parent: WebUI, val prefix: String) { /** Get a list of header tabs from the parent UI. */ def headerTabs: Seq[WebUITab] = parent.getTabs + + def basePath: String = parent.getBasePath } diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala index b347eb1b83c1f..f0a1174a71d34 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala @@ -24,8 +24,6 @@ import scala.xml.Node import org.apache.spark.ui.{UIUtils, WebUIPage} private[ui] class EnvironmentPage(parent: EnvironmentTab) extends WebUIPage("") { - private val appName = parent.appName - private val basePath = parent.basePath private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { @@ -45,7 +43,7 @@ private[ui] class EnvironmentPage(parent: EnvironmentTab) extends WebUIPage("")

    Classpath Entries

    {classpathEntriesTable} - UIUtils.headerSparkPage(content, basePath, appName, "Environment", parent.headerTabs, parent) + UIUtils.headerSparkPage("Environment", content, parent) } private def propertyHeader = Seq("Name", "Value") diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala index bbbe55ecf44a1..0d158fbe638d3 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala @@ -21,9 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.ui._ -private[ui] class EnvironmentTab(parent: SparkUI) extends WebUITab(parent, "environment") { - val appName = parent.appName - val basePath = parent.basePath +private[ui] class EnvironmentTab(parent: SparkUI) extends SparkUITab(parent, "environment") { val listener = new EnvironmentListener attachPage(new EnvironmentPage(this)) diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index b814b0e6b8509..02df4e8fe61af 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -43,8 +43,6 @@ private case class ExecutorSummaryInfo( maxMemory: Long) private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { - private val appName = parent.appName - private val basePath = parent.basePath private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { @@ -101,8 +99,7 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { ; - UIUtils.headerSparkPage(content, basePath, appName, "Executors (" + execInfo.size + ")", - parent.headerTabs, parent) + UIUtils.headerSparkPage("Executors (" + execInfo.size + ")", content, parent) } /** Render an HTML row representing an executor */ diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 5c2d1d1fe75d3..61eb111cd9100 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -23,11 +23,9 @@ import org.apache.spark.ExceptionFailure import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.storage.StorageStatusListener -import org.apache.spark.ui.{SparkUI, WebUITab} +import org.apache.spark.ui.{SparkUI, SparkUITab} -private[ui] class ExecutorsTab(parent: SparkUI) extends WebUITab(parent, "executors") { - val appName = parent.appName - val basePath = parent.basePath +private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "executors") { val listener = new ExecutorsListener(parent.storageStatusListener) attachPage(new ExecutorsPage(this)) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala index 0da62892118d4..a82f71ed08475 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala @@ -26,8 +26,6 @@ import org.apache.spark.ui.{WebUIPage, UIUtils} /** Page showing list of all ongoing and recently finished stages and pools */ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("") { - private val appName = parent.appName - private val basePath = parent.basePath private val live = parent.live private val sc = parent.sc private val listener = parent.listener @@ -94,7 +92,7 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("")

    Failed Stages ({failedStages.size})

    ++ failedStagesTable.toNodeSeq - UIUtils.headerSparkPage(content, basePath, appName, "Spark Stages", parent.headerTabs, parent) + UIUtils.headerSparkPage("Spark Stages", content, parent) } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala index 8a01ec80c9dd6..c16542c9db30f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala @@ -21,12 +21,10 @@ import javax.servlet.http.HttpServletRequest import org.apache.spark.SparkConf import org.apache.spark.scheduler.SchedulingMode -import org.apache.spark.ui.{SparkUI, WebUITab} +import org.apache.spark.ui.{SparkUI, SparkUITab} /** Web UI showing progress status of all jobs in the given SparkContext. */ -private[ui] class JobProgressTab(parent: SparkUI) extends WebUITab(parent, "stages") { - val appName = parent.appName - val basePath = parent.basePath +private[ui] class JobProgressTab(parent: SparkUI) extends SparkUITab(parent, "stages") { val live = parent.live val sc = parent.sc val conf = if (live) sc.conf else new SparkConf @@ -53,4 +51,5 @@ private[ui] class JobProgressTab(parent: SparkUI) extends WebUITab(parent, "stag Thread.sleep(100) } } + } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index 0a2bf31833d2b..7a6c7d1a497ed 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -26,8 +26,6 @@ import org.apache.spark.ui.{WebUIPage, UIUtils} /** Page showing specific pool details */ private[ui] class PoolPage(parent: JobProgressTab) extends WebUIPage("pool") { - private val appName = parent.appName - private val basePath = parent.basePath private val live = parent.live private val sc = parent.sc private val listener = parent.listener @@ -51,8 +49,7 @@ private[ui] class PoolPage(parent: JobProgressTab) extends WebUIPage("pool") {

    Summary

    ++ poolTable.toNodeSeq ++

    {activeStages.size} Active Stages

    ++ activeStagesTable.toNodeSeq - UIUtils.headerSparkPage(content, basePath, appName, "Fair Scheduler Pool: " + poolName, - parent.headerTabs, parent) + UIUtils.headerSparkPage("Fair Scheduler Pool: " + poolName, content, parent) } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala index f4b68f241966d..64178e1e33d41 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala @@ -25,7 +25,6 @@ import org.apache.spark.ui.UIUtils /** Table showing list of pools */ private[ui] class PoolTable(pools: Seq[Schedulable], parent: JobProgressTab) { - private val basePath = parent.basePath private val listener = parent.listener def toNodeSeq: Seq[Node] = { @@ -59,11 +58,11 @@ private[ui] class PoolTable(pools: Seq[Schedulable], parent: JobProgressTab) { case Some(stages) => stages.size case None => 0 } + val href = "%s/stages/pool?poolname=%s" + .format(UIUtils.prependBaseUri(parent.basePath), p.name) - - {p.name} - + {p.name} {p.minShare} {p.weight} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 8bc1ba758cf77..d4eb02722ad12 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -29,8 +29,6 @@ import org.apache.spark.scheduler.AccumulableInfo /** Page showing statistics and task list for a given stage */ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { - private val appName = parent.appName - private val basePath = parent.basePath private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { @@ -44,8 +42,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {

    Summary Metrics

    No tasks have started yet

    Tasks

    No tasks have started yet - return UIUtils.headerSparkPage(content, basePath, appName, - "Details for Stage %s".format(stageId), parent.headerTabs, parent) + return UIUtils.headerSparkPage("Details for Stage %s".format(stageId), content, parent) } val stageData = stageDataOption.get @@ -227,8 +224,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { maybeAccumulableTable ++

    Tasks

    ++ taskTable - UIUtils.headerSparkPage(content, basePath, appName, "Details for Stage %d".format(stageId), - parent.headerTabs, parent) + UIUtils.headerSparkPage("Details for Stage %d".format(stageId), content, parent) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 15998404ed612..16ad0df45aa0d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -32,7 +32,6 @@ private[ui] class StageTableBase( parent: JobProgressTab, killEnabled: Boolean = false) { - private val basePath = parent.basePath private val listener = parent.listener protected def isFairScheduler = parent.isFairScheduler @@ -88,17 +87,19 @@ private[ui] class StageTableBase( private def makeDescription(s: StageInfo): Seq[Node] = { // scalastyle:off val killLink = if (killEnabled) { + val killLinkUri = "%s/stages/stage/kill?id=%s&terminate=true" + .format(UIUtils.prependBaseUri(parent.basePath), s.stageId) + val confirm = "return window.confirm('Are you sure you want to kill stage %s ?');" + .format(s.stageId) - (kill) + (kill) } // scalastyle:on - val nameLink = - - {s.name} - + val nameLinkUri ="%s/stages/stage?id=%s" + .format(UIUtils.prependBaseUri(parent.basePath), s.stageId) + val nameLink = {s.name} val cachedRddInfos = s.rddInfos.filter(_.numCachedPartitions > 0) val details = if (s.details.nonEmpty) { @@ -111,7 +112,7 @@ private[ui] class StageTableBase( Text("RDD: ") ++ // scalastyle:off cachedRddInfos.map { i => - {i.name} + {i.name} } // scalastyle:on }} @@ -157,7 +158,7 @@ private[ui] class StageTableBase( {if (isFairScheduler) { + .format(UIUtils.prependBaseUri(parent.basePath), stageData.schedulingPool)}> {stageData.schedulingPool} diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 84ac53da47552..8a0075ae8daf7 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -27,8 +27,6 @@ import org.apache.spark.util.Utils /** Page showing storage details for a given RDD */ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { - private val appName = parent.appName - private val basePath = parent.basePath private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { @@ -36,8 +34,7 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { val storageStatusList = listener.storageStatusList val rddInfo = listener.rddInfoList.find(_.id == rddId).getOrElse { // Rather than crashing, render an "RDD Not Found" page - return UIUtils.headerSparkPage(Seq[Node](), basePath, appName, "RDD Not Found", - parent.headerTabs, parent) + return UIUtils.headerSparkPage("RDD Not Found", Seq[Node](), parent) } // Worker table @@ -96,8 +93,7 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { ; - UIUtils.headerSparkPage(content, basePath, appName, "RDD Storage Info for " + rddInfo.name, - parent.headerTabs, parent) + UIUtils.headerSparkPage("RDD Storage Info for " + rddInfo.name, content, parent) } /** Header fields for the worker table */ diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index 9813d9330ac7f..716591c9ed449 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -27,14 +27,12 @@ import org.apache.spark.util.Utils /** Page showing list of RDD's currently stored in the cluster */ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { - private val appName = parent.appName - private val basePath = parent.basePath private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { val rdds = listener.rddInfoList val content = UIUtils.listingTable(rddHeader, rddRow, rdds) - UIUtils.headerSparkPage(content, basePath, appName, "Storage ", parent.headerTabs, parent) + UIUtils.headerSparkPage("Storage", content, parent) } /** Header fields for the RDD table */ @@ -52,7 +50,7 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { // scalastyle:off - + {rdd.name} diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala index 5f6740d495521..67f72a94f0269 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala @@ -25,9 +25,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.storage._ /** Web UI showing storage status of all RDD's in the given SparkContext. */ -private[ui] class StorageTab(parent: SparkUI) extends WebUITab(parent, "storage") { - val appName = parent.appName - val basePath = parent.basePath +private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storage") { val listener = new StorageListener(parent.storageStatusListener) attachPage(new StoragePage(this)) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 451b23e01c995..1353e487c72cf 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -42,8 +42,7 @@ private[ui] class StreamingPage(parent: StreamingTab)

    Statistics over last {listener.retainedCompletedBatches.size} processed batches

    ++ generateReceiverStats() ++ generateBatchStatsTable() - UIUtils.headerSparkPage( - content, parent.basePath, parent.appName, "Streaming", parent.headerTabs, parent, Some(5000)) + UIUtils.headerSparkPage("Streaming", content, parent, Some(5000)) } /** Generate basic stats of the streaming program */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala index 51448d15c6516..34ac254f337eb 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala @@ -19,15 +19,13 @@ package org.apache.spark.streaming.ui import org.apache.spark.Logging import org.apache.spark.streaming.StreamingContext -import org.apache.spark.ui.WebUITab +import org.apache.spark.ui.SparkUITab /** Spark Web UI tab that shows statistics of a streaming job */ private[spark] class StreamingTab(ssc: StreamingContext) - extends WebUITab(ssc.sc.ui, "streaming") with Logging { + extends SparkUITab(ssc.sc.ui, "streaming") with Logging { val parent = ssc.sc.ui - val appName = parent.appName - val basePath = parent.basePath val listener = new StreamingJobProgressListener(ssc) ssc.addStreamingListener(listener) From 3a5962f0f5acea5cbfd3cf1e3ed16e03b3bec37a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 18 Aug 2014 13:38:03 -0700 Subject: [PATCH 39/54] Removed .travis.yml file since we are not using Travis. --- .travis.yml | 32 -------------------------------- 1 file changed, 32 deletions(-) delete mode 100644 .travis.yml diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 8ebd0d68429fc..0000000000000 --- a/.travis.yml +++ /dev/null @@ -1,32 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - language: scala - scala: - - "2.10.3" - jdk: - - oraclejdk7 - env: - matrix: - - TEST="scalastyle assembly/assembly" - - TEST="catalyst/test sql/test streaming/test mllib/test graphx/test bagel/test" - - TEST=hive/test - cache: - directories: - - $HOME/.m2 - - $HOME/.ivy2 - - $HOME/.sbt - script: - - "sbt ++$TRAVIS_SCALA_VERSION $TEST" From d1d0ee41c27f1d07fed0c5d56ba26c723cc3dc26 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 18 Aug 2014 13:58:35 -0700 Subject: [PATCH 40/54] [SPARK-3103] [PySpark] fix saveAsTextFile() with utf-8 bugfix: It will raise an exception when it try to encode non-ASCII strings into unicode. It should only encode unicode as "utf-8". Author: Davies Liu Closes #2018 from davies/fix_utf8 and squashes the following commits: 4db7967 [Davies Liu] fix saveAsTextFile() with utf-8 --- python/pyspark/rdd.py | 4 +++- python/pyspark/tests.py | 9 +++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 240381e5bae12..c708b69cc1e31 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1191,7 +1191,9 @@ def func(split, iterator): for x in iterator: if not isinstance(x, basestring): x = unicode(x) - yield x.encode("utf-8") + if isinstance(x, unicode): + x = x.encode("utf-8") + yield x keyed = self.mapPartitionsWithIndex(func) keyed._bypass_serializer = True keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index f1fece998cd54..69d543d9d045d 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -256,6 +256,15 @@ def test_save_as_textfile_with_unicode(self): raw_contents = ''.join(input(glob(tempFile.name + "/part-0000*"))) self.assertEqual(x, unicode(raw_contents.strip(), "utf-8")) + def test_save_as_textfile_with_utf8(self): + x = u"\u00A1Hola, mundo!" + data = self.sc.parallelize([x.encode("utf-8")]) + tempFile = tempfile.NamedTemporaryFile(delete=True) + tempFile.close() + data.saveAsTextFile(tempFile.name) + raw_contents = ''.join(input(glob(tempFile.name + "/part-0000*"))) + self.assertEqual(x, unicode(raw_contents.strip(), "utf-8")) + def test_transforming_cartesian_result(self): # Regression test for SPARK-1034 rdd1 = self.sc.parallelize([1, 2]) From 6201b27643023569e19b68aa9d5c4e4e59ce0d79 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 18 Aug 2014 14:10:10 -0700 Subject: [PATCH 41/54] [SPARK-2718] [yarn] Handle quotes and other characters in user args. Due to the way Yarn runs things through bash, normal quoting doesn't work as expected. This change applies the necessary voodoo to the user args to avoid issues with bash and special characters. The change also uncovered an issue with the event logger app name sanitizing code; it wasn't cleaning up all "bad" characters, so sometimes it would fail to create the log dirs. I just added some more bad character replacements. Author: Marcelo Vanzin Closes #1724 from vanzin/SPARK-2718 and squashes the following commits: cc84b89 [Marcelo Vanzin] Review feedback. c1a257a [Marcelo Vanzin] Add test for backslashes. 55571d4 [Marcelo Vanzin] Unbreak yarn-client. 515613d [Marcelo Vanzin] [SPARK-2718] [yarn] Handle quotes and other characters in user args. --- .../scheduler/EventLoggingListener.scala | 3 +- .../yarn/ApplicationMasterArguments.scala | 6 +- .../apache/spark/deploy/yarn/ClientBase.scala | 9 +-- .../deploy/yarn/ExecutorRunnableUtil.scala | 4 +- .../deploy/yarn/YarnSparkHadoopUtil.scala | 25 ++++++++ .../yarn/YarnSparkHadoopUtilSuite.scala | 64 +++++++++++++++++++ 6 files changed, 101 insertions(+), 10 deletions(-) create mode 100644 yarn/common/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 7378ce923f0ae..370fcd85aa680 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -54,7 +54,8 @@ private[spark] class EventLoggingListener( private val testing = sparkConf.getBoolean("spark.eventLog.testing", false) private val outputBufferSize = sparkConf.getInt("spark.eventLog.buffer.kb", 100) * 1024 private val logBaseDir = sparkConf.get("spark.eventLog.dir", DEFAULT_LOG_DIR).stripSuffix("/") - private val name = appName.replaceAll("[ :/]", "-").toLowerCase + "-" + System.currentTimeMillis + private val name = appName.replaceAll("[ :/]", "-").replaceAll("[${}'\"]", "_") + .toLowerCase + "-" + System.currentTimeMillis val logDir = Utils.resolveURI(logBaseDir) + "/" + name.stripSuffix("/") protected val logger = new FileLogger(logDir, sparkConf, hadoopConf, outputBufferSize, diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index 4c383ab574abe..424b0fb0936f2 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -29,7 +29,7 @@ class ApplicationMasterArguments(val args: Array[String]) { var numExecutors = ApplicationMasterArguments.DEFAULT_NUMBER_EXECUTORS parseArgs(args.toList) - + private def parseArgs(inputArgs: List[String]): Unit = { val userArgsBuffer = new ArrayBuffer[String]() @@ -47,7 +47,7 @@ class ApplicationMasterArguments(val args: Array[String]) { userClass = value args = tail - case ("--args") :: value :: tail => + case ("--args" | "--arg") :: value :: tail => userArgsBuffer += value args = tail @@ -75,7 +75,7 @@ class ApplicationMasterArguments(val args: Array[String]) { userArgs = userArgsBuffer.readOnly } - + def printUsageAndExit(exitCode: Int, unknownParam: Any = null) { if (unknownParam != null) { System.err.println("Unknown/unsupported param " + unknownParam) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 1da0a1b675554..3897b3a373a8c 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -300,11 +300,11 @@ trait ClientBase extends Logging { } def userArgsToString(clientArgs: ClientArguments): String = { - val prefix = " --args " + val prefix = " --arg " val args = clientArgs.userArgs val retval = new StringBuilder() for (arg <- args) { - retval.append(prefix).append(" '").append(arg).append("' ") + retval.append(prefix).append(" ").append(YarnSparkHadoopUtil.escapeForShell(arg)) } retval.toString } @@ -386,7 +386,7 @@ trait ClientBase extends Logging { // TODO: it might be nicer to pass these as an internal environment variable rather than // as Java options, due to complications with string parsing of nested quotes. for ((k, v) <- sparkConf.getAll) { - javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\"" + javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") } if (args.amClass == classOf[ApplicationMaster].getName) { @@ -400,7 +400,8 @@ trait ClientBase extends Logging { // Command for the ApplicationMaster val commands = Seq(Environment.JAVA_HOME.$() + "/bin/java", "-server") ++ javaOpts ++ - Seq(args.amClass, "--class", args.userClass, "--jar ", args.userJar, + Seq(args.amClass, "--class", YarnSparkHadoopUtil.escapeForShell(args.userClass), + "--jar ", YarnSparkHadoopUtil.escapeForShell(args.userJar), userArgsToString(args), "--executor-memory", args.executorMemory.toString, "--executor-cores", args.executorCores.toString, diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala index 71a9e42846b2b..312d82a649792 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala @@ -68,10 +68,10 @@ trait ExecutorRunnableUtil extends Logging { // authentication settings. sparkConf.getAll. filter { case (k, v) => k.startsWith("spark.auth") || k.startsWith("spark.akka") }. - foreach { case (k, v) => javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\"" } + foreach { case (k, v) => javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") } sparkConf.getAkkaConf. - foreach { case (k, v) => javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\"" } + foreach { case (k, v) => javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") } // Commenting it out for now - so that people can refer to the properties if required. Remove // it once cpuset version is pushed out. diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index e98308cdbd74e..10aef5eb2486f 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -148,4 +148,29 @@ object YarnSparkHadoopUtil { } } + /** + * Escapes a string for inclusion in a command line executed by Yarn. Yarn executes commands + * using `bash -c "command arg1 arg2"` and that means plain quoting doesn't really work. The + * argument is enclosed in single quotes and some key characters are escaped. + * + * @param arg A single argument. + * @return Argument quoted for execution via Yarn's generated shell script. + */ + def escapeForShell(arg: String): String = { + if (arg != null) { + val escaped = new StringBuilder("'") + for (i <- 0 to arg.length() - 1) { + arg.charAt(i) match { + case '$' => escaped.append("\\$") + case '"' => escaped.append("\\\"") + case '\'' => escaped.append("'\\''") + case c => escaped.append(c) + } + } + escaped.append("'").toString() + } else { + arg + } + } + } diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala new file mode 100644 index 0000000000000..7650bd4396c12 --- /dev/null +++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import java.io.{File, IOException} + +import com.google.common.io.{ByteStreams, Files} +import org.scalatest.{FunSuite, Matchers} + +import org.apache.spark.Logging + +class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging { + + val hasBash = + try { + val exitCode = Runtime.getRuntime().exec(Array("bash", "--version")).waitFor() + exitCode == 0 + } catch { + case e: IOException => + false + } + + if (!hasBash) { + logWarning("Cannot execute bash, skipping bash tests.") + } + + def bashTest(name: String)(fn: => Unit) = + if (hasBash) test(name)(fn) else ignore(name)(fn) + + bashTest("shell script escaping") { + val scriptFile = File.createTempFile("script.", ".sh") + val args = Array("arg1", "${arg.2}", "\"arg3\"", "'arg4'", "$arg5", "\\arg6") + try { + val argLine = args.map(a => YarnSparkHadoopUtil.escapeForShell(a)).mkString(" ") + Files.write(("bash -c \"echo " + argLine + "\"").getBytes(), scriptFile) + scriptFile.setExecutable(true) + + val proc = Runtime.getRuntime().exec(Array(scriptFile.getAbsolutePath())) + val out = new String(ByteStreams.toByteArray(proc.getInputStream())).trim() + val err = new String(ByteStreams.toByteArray(proc.getErrorStream())) + val exitCode = proc.waitFor() + exitCode should be (0) + out should be (args.mkString(" ")) + } finally { + scriptFile.delete() + } + } + +} From 115eeb30dd9c9dd10685a71f2c23ca23794d3142 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 18 Aug 2014 14:40:05 -0700 Subject: [PATCH 42/54] [mllib] DecisionTree: treeAggregate + Python example bug fix Small DecisionTree updates: * Changed main DecisionTree aggregate to treeAggregate. * Fixed bug in python example decision_tree_runner.py with missing argument (since categoricalFeaturesInfo is no longer an optional argument for trainClassifier). * Fixed same bug in python doc tests, and added tree.py to doc tests. CC: mengxr Author: Joseph K. Bradley Closes #2015 from jkbradley/dt-opt2 and squashes the following commits: b5114fa [Joseph K. Bradley] Fixed python tree.py doc test (extra newline) 8e4665d [Joseph K. Bradley] Added tree.py to python doc tests. Fixed bug from missing categoricalFeaturesInfo argument. b7b2922 [Joseph K. Bradley] Fixed bug in python example decision_tree_runner.py with missing argument. Changed main DecisionTree aggregate to treeAggregate. 85bbc1f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt2 66d076f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt2 a0ed0da [Joseph K. Bradley] Renamed DTMetadata to DecisionTreeMetadata. Small doc updates. 3726d20 [Joseph K. Bradley] Small code improvements based on code review. ac0b9f8 [Joseph K. Bradley] Small updates based on code review. Main change: Now using << instead of math.pow. db0d773 [Joseph K. Bradley] scala style fix 6a38f48 [Joseph K. Bradley] Added DTMetadata class for cleaner code 931a3a7 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt2 797f68a [Joseph K. Bradley] Fixed DecisionTreeSuite bug for training second level. Needed to update treePointToNodeIndex with groupShift. f40381c [Joseph K. Bradley] Merge branch 'dt-opt1' into dt-opt2 5f2dec2 [Joseph K. Bradley] Fixed scalastyle issue in TreePoint 6b5651e [Joseph K. Bradley] Updates based on code review. 1 major change: persisting to memory + disk, not just memory. 2d2aaaf [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1 26d10dd [Joseph K. Bradley] Removed tree/model/Filter.scala since no longer used. Removed debugging println calls in DecisionTree.scala. 356daba [Joseph K. Bradley] Merge branch 'dt-opt1' into dt-opt2 430d782 [Joseph K. Bradley] Added more debug info on binning error. Added some docs. d036089 [Joseph K. Bradley] Print timing info to logDebug. e66f1b1 [Joseph K. Bradley] TreePoint * Updated doc * Made some methods private 8464a6e [Joseph K. Bradley] Moved TimeTracker to tree/impl/ in its own file, and cleaned it up. Removed debugging println calls from DecisionTree. Made TreePoint extend Serialiable a87e08f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1 c1565a5 [Joseph K. Bradley] Small DecisionTree updates: * Simplification: Updated calculateGainForSplit to take aggregates for a single (feature, split) pair. * Internal doc: findAggForOrderedFeatureClassification b914f3b [Joseph K. Bradley] DecisionTree optimization: eliminated filters + small changes b2ed1f3 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt 0f676e2 [Joseph K. Bradley] Optimizations + Bug fix for DecisionTree 3211f02 [Joseph K. Bradley] Optimizing DecisionTree * Added TreePoint representation to avoid calling findBin multiple times. * (not working yet, but debugging) f61e9d2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing bcf874a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing 511ec85 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing a95bc22 [Joseph K. Bradley] timing for DecisionTree internals --- .../src/main/python/mllib/decision_tree_runner.py | 4 +++- .../org/apache/spark/mllib/tree/DecisionTree.scala | 3 ++- python/pyspark/mllib/tree.py | 14 ++++++++------ python/run-tests | 1 + 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/examples/src/main/python/mllib/decision_tree_runner.py b/examples/src/main/python/mllib/decision_tree_runner.py index 8efadb5223f56..db96a7cb3730f 100755 --- a/examples/src/main/python/mllib/decision_tree_runner.py +++ b/examples/src/main/python/mllib/decision_tree_runner.py @@ -124,7 +124,9 @@ def usage(): (reindexedData, origToNewLabels) = reindexClassLabels(points) # Train a classifier. - model = DecisionTree.trainClassifier(reindexedData, numClasses=2) + categoricalFeaturesInfo={} # no categorical features + model = DecisionTree.trainClassifier(reindexedData, numClasses=2, + categoricalFeaturesInfo=categoricalFeaturesInfo) # Print learned tree and stats. print "Trained DecisionTree for classification:" print " Model numNodes: %d\n" % model.numNodes() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 6b9a8f72c244e..5cdd258f6c20b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -22,6 +22,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.Logging +import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ @@ -826,7 +827,7 @@ object DecisionTree extends Serializable with Logging { // Calculate bin aggregates. timer.start("aggregation") val binAggregates = { - input.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp) + input.treeAggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp) } timer.stop("aggregation") logDebug("binAggregates.length = " + binAggregates.length) diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index e1a4671709b7d..e9d778df5a24b 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -88,7 +88,8 @@ class DecisionTree(object): It will probably be modified for Spark v1.2. Example usage: - >>> from numpy import array, ndarray + >>> from numpy import array + >>> import sys >>> from pyspark.mllib.regression import LabeledPoint >>> from pyspark.mllib.tree import DecisionTree >>> from pyspark.mllib.linalg import SparseVector @@ -99,15 +100,15 @@ class DecisionTree(object): ... LabeledPoint(1.0, [2.0]), ... LabeledPoint(1.0, [3.0]) ... ] - >>> - >>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2) - >>> print(model) + >>> categoricalFeaturesInfo = {} # no categorical features + >>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2, + ... categoricalFeaturesInfo=categoricalFeaturesInfo) + >>> sys.stdout.write(model) DecisionTreeModel classifier If (feature 0 <= 0.5) Predict: 0.0 Else (feature 0 > 0.5) Predict: 1.0 - >>> model.predict(array([1.0])) > 0 True >>> model.predict(array([0.0])) == 0 @@ -119,7 +120,8 @@ class DecisionTree(object): ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) ... ] >>> - >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data)) + >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data), + ... categoricalFeaturesInfo=categoricalFeaturesInfo) >>> model.predict(array([0.0, 1.0])) == 1 True >>> model.predict(array([0.0, 0.0])) == 0 diff --git a/python/run-tests b/python/run-tests index 1218edcbd7e08..a6271e0cf5fa9 100755 --- a/python/run-tests +++ b/python/run-tests @@ -79,6 +79,7 @@ run_test "pyspark/mllib/random.py" run_test "pyspark/mllib/recommendation.py" run_test "pyspark/mllib/regression.py" run_test "pyspark/mllib/tests.py" +run_test "pyspark/mllib/tree.py" run_test "pyspark/mllib/util.py" if [[ $FAILED == 0 ]]; then From c8b16ca0d86cc60fb960eebf0cb383f159a88b03 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 18 Aug 2014 18:01:39 -0700 Subject: [PATCH 43/54] [SPARK-2850] [SPARK-2626] [mllib] MLlib stats examples + small fixes Added examples for statistical summarization: * Scala: StatisticalSummary.scala ** Tests: correlation, MultivariateOnlineSummarizer * python: statistical_summary.py ** Tests: correlation (since MultivariateOnlineSummarizer has no Python API) Added examples for random and sampled RDDs: * Scala: RandomAndSampledRDDs.scala * python: random_and_sampled_rdds.py * Both test: ** RandomRDDGenerators.normalRDD, normalVectorRDD ** RDD.sample, takeSample, sampleByKey Added sc.stop() to all examples. CorrelationSuite.scala * Added 1 test for RDDs with only 1 value RowMatrix.scala * numCols(): Added check for numRows = 0, with error message. * computeCovariance(): Added check for numRows <= 1, with error message. Python SparseVector (pyspark/mllib/linalg.py) * Added toDense() function python/run-tests script * Added stat.py (doc test) CC: mengxr dorx Main changes were examples to show usage across APIs. Author: Joseph K. Bradley Closes #1878 from jkbradley/mllib-stats-api-check and squashes the following commits: ea5c047 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check dafebe2 [Joseph K. Bradley] Bug fixes for examples SampledRDDs.scala and sampled_rdds.py: Check for division by 0 and for missing key in maps. 8d1e555 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check 60c72d9 [Joseph K. Bradley] Fixed stat.py doc test to work for Python versions printing nan or NaN. b20d90a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check 4e5d15e [Joseph K. Bradley] Changed pyspark/mllib/stat.py doc tests to use NaN instead of nan. 32173b7 [Joseph K. Bradley] Stats examples update. c8c20dc [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check cf70b07 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check 0b7cec3 [Joseph K. Bradley] Small updates based on code review. Renamed statistical_summary.py to correlations.py ab48f6e [Joseph K. Bradley] RowMatrix.scala * numCols(): Added check for numRows = 0, with error message. * computeCovariance(): Added check for numRows <= 1, with error message. 65e4ebc [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check 8195c78 [Joseph K. Bradley] Added examples for random and sampled RDDs: * Scala: RandomAndSampledRDDs.scala * python: random_and_sampled_rdds.py * Both test: ** RandomRDDGenerators.normalRDD, normalVectorRDD ** RDD.sample, takeSample, sampleByKey 064985b [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check ee918e9 [Joseph K. Bradley] Added examples for statistical summarization: * Scala: StatisticalSummary.scala ** Tests: correlation, MultivariateOnlineSummarizer * python: statistical_summary.py ** Tests: correlation (since MultivariateOnlineSummarizer has no Python API) --- examples/src/main/python/als.py | 2 + .../src/main/python/cassandra_inputformat.py | 2 + .../src/main/python/cassandra_outputformat.py | 2 + examples/src/main/python/hbase_inputformat.py | 2 + .../src/main/python/hbase_outputformat.py | 2 + examples/src/main/python/kmeans.py | 2 + .../src/main/python/logistic_regression.py | 2 + .../src/main/python/mllib/correlations.py | 60 +++++++++ .../main/python/mllib/decision_tree_runner.py | 5 + examples/src/main/python/mllib/kmeans.py | 1 + .../main/python/mllib/logistic_regression.py | 1 + .../python/mllib/random_rdd_generation.py | 55 ++++++++ .../src/main/python/mllib/sampled_rdds.py | 86 ++++++++++++ examples/src/main/python/pagerank.py | 2 + examples/src/main/python/pi.py | 2 + examples/src/main/python/sort.py | 2 + .../src/main/python/transitive_closure.py | 2 + examples/src/main/python/wordcount.py | 2 + .../spark/examples/mllib/Correlations.scala | 92 +++++++++++++ .../mllib/MultivariateSummarizer.scala | 98 ++++++++++++++ .../examples/mllib/RandomRDDGeneration.scala | 60 +++++++++ .../spark/examples/mllib/SampledRDDs.scala | 126 ++++++++++++++++++ .../mllib/linalg/distributed/RowMatrix.scala | 14 +- .../stat/MultivariateOnlineSummarizer.scala | 8 +- .../spark/mllib/stat/CorrelationSuite.scala | 15 ++- .../MultivariateOnlineSummarizerSuite.scala | 6 +- python/pyspark/mllib/linalg.py | 10 ++ python/pyspark/mllib/stat.py | 22 +-- python/run-tests | 1 + 29 files changed, 664 insertions(+), 20 deletions(-) create mode 100755 examples/src/main/python/mllib/correlations.py create mode 100755 examples/src/main/python/mllib/random_rdd_generation.py create mode 100755 examples/src/main/python/mllib/sampled_rdds.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py index c862650b0aa1d..5b1fa4d997eeb 100755 --- a/examples/src/main/python/als.py +++ b/examples/src/main/python/als.py @@ -97,3 +97,5 @@ def update(i, vec, mat, ratings): error = rmse(R, ms, us) print "Iteration %d:" % i print "\nRMSE: %5.4f\n" % error + + sc.stop() diff --git a/examples/src/main/python/cassandra_inputformat.py b/examples/src/main/python/cassandra_inputformat.py index 39fa6b0d22ef5..e4a897f61e39d 100644 --- a/examples/src/main/python/cassandra_inputformat.py +++ b/examples/src/main/python/cassandra_inputformat.py @@ -77,3 +77,5 @@ output = cass_rdd.collect() for (k, v) in output: print (k, v) + + sc.stop() diff --git a/examples/src/main/python/cassandra_outputformat.py b/examples/src/main/python/cassandra_outputformat.py index 1dfbf98604425..836c35b5c6794 100644 --- a/examples/src/main/python/cassandra_outputformat.py +++ b/examples/src/main/python/cassandra_outputformat.py @@ -81,3 +81,5 @@ conf=conf, keyConverter="org.apache.spark.examples.pythonconverters.ToCassandraCQLKeyConverter", valueConverter="org.apache.spark.examples.pythonconverters.ToCassandraCQLValueConverter") + + sc.stop() diff --git a/examples/src/main/python/hbase_inputformat.py b/examples/src/main/python/hbase_inputformat.py index c9fa8e171c2a1..befacee0dea56 100644 --- a/examples/src/main/python/hbase_inputformat.py +++ b/examples/src/main/python/hbase_inputformat.py @@ -71,3 +71,5 @@ output = hbase_rdd.collect() for (k, v) in output: print (k, v) + + sc.stop() diff --git a/examples/src/main/python/hbase_outputformat.py b/examples/src/main/python/hbase_outputformat.py index 5e11548fd13f7..49bbc5aebdb0b 100644 --- a/examples/src/main/python/hbase_outputformat.py +++ b/examples/src/main/python/hbase_outputformat.py @@ -63,3 +63,5 @@ conf=conf, keyConverter="org.apache.spark.examples.pythonconverters.StringToImmutableBytesWritableConverter", valueConverter="org.apache.spark.examples.pythonconverters.StringListToPutConverter") + + sc.stop() diff --git a/examples/src/main/python/kmeans.py b/examples/src/main/python/kmeans.py index 036bdf4c4f999..86ef6f32c84e8 100755 --- a/examples/src/main/python/kmeans.py +++ b/examples/src/main/python/kmeans.py @@ -77,3 +77,5 @@ def closestPoint(p, centers): kPoints[x] = y print "Final centers: " + str(kPoints) + + sc.stop() diff --git a/examples/src/main/python/logistic_regression.py b/examples/src/main/python/logistic_regression.py index 8456b272f9c05..3aa56b0528168 100755 --- a/examples/src/main/python/logistic_regression.py +++ b/examples/src/main/python/logistic_regression.py @@ -80,3 +80,5 @@ def add(x, y): w -= points.map(lambda m: gradient(m, w)).reduce(add) print "Final w: " + str(w) + + sc.stop() diff --git a/examples/src/main/python/mllib/correlations.py b/examples/src/main/python/mllib/correlations.py new file mode 100755 index 0000000000000..6b16a56e44af7 --- /dev/null +++ b/examples/src/main/python/mllib/correlations.py @@ -0,0 +1,60 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Correlations using MLlib. +""" + +import sys + +from pyspark import SparkContext +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.stat import Statistics +from pyspark.mllib.util import MLUtils + + +if __name__ == "__main__": + if len(sys.argv) not in [1,2]: + print >> sys.stderr, "Usage: correlations ()" + exit(-1) + sc = SparkContext(appName="PythonCorrelations") + if len(sys.argv) == 2: + filepath = sys.argv[1] + else: + filepath = 'data/mllib/sample_linear_regression_data.txt' + corrType = 'pearson' + + points = MLUtils.loadLibSVMFile(sc, filepath)\ + .map(lambda lp: LabeledPoint(lp.label, lp.features.toArray())) + + print + print 'Summary of data file: ' + filepath + print '%d data points' % points.count() + + # Statistics (correlations) + print + print 'Correlation (%s) between label and each feature' % corrType + print 'Feature\tCorrelation' + numFeatures = points.take(1)[0].features.size + labelRDD = points.map(lambda lp: lp.label) + for i in range(numFeatures): + featureRDD = points.map(lambda lp: lp.features[i]) + corr = Statistics.corr(labelRDD, featureRDD, corrType) + print '%d\t%g' % (i, corr) + print + + sc.stop() diff --git a/examples/src/main/python/mllib/decision_tree_runner.py b/examples/src/main/python/mllib/decision_tree_runner.py index db96a7cb3730f..6e4a4a0cb6be0 100755 --- a/examples/src/main/python/mllib/decision_tree_runner.py +++ b/examples/src/main/python/mllib/decision_tree_runner.py @@ -17,6 +17,8 @@ """ Decision tree classification and regression using MLlib. + +This example requires NumPy (http://www.numpy.org/). """ import numpy, os, sys @@ -117,6 +119,7 @@ def usage(): if len(sys.argv) == 2: dataPath = sys.argv[1] if not os.path.isfile(dataPath): + sc.stop() usage() points = MLUtils.loadLibSVMFile(sc, dataPath) @@ -133,3 +136,5 @@ def usage(): print " Model depth: %d\n" % model.depth() print " Training accuracy: %g\n" % getAccuracy(model, reindexedData) print model + + sc.stop() diff --git a/examples/src/main/python/mllib/kmeans.py b/examples/src/main/python/mllib/kmeans.py index b308132c9aeeb..2eeb1abeeb12b 100755 --- a/examples/src/main/python/mllib/kmeans.py +++ b/examples/src/main/python/mllib/kmeans.py @@ -42,3 +42,4 @@ def parseVector(line): k = int(sys.argv[2]) model = KMeans.train(data, k) print "Final centers: " + str(model.clusterCenters) + sc.stop() diff --git a/examples/src/main/python/mllib/logistic_regression.py b/examples/src/main/python/mllib/logistic_regression.py index 9d547ff77c984..8cae27fc4a52d 100755 --- a/examples/src/main/python/mllib/logistic_regression.py +++ b/examples/src/main/python/mllib/logistic_regression.py @@ -50,3 +50,4 @@ def parsePoint(line): model = LogisticRegressionWithSGD.train(points, iterations) print "Final weights: " + str(model.weights) print "Final intercept: " + str(model.intercept) + sc.stop() diff --git a/examples/src/main/python/mllib/random_rdd_generation.py b/examples/src/main/python/mllib/random_rdd_generation.py new file mode 100755 index 0000000000000..b388d8d83fb86 --- /dev/null +++ b/examples/src/main/python/mllib/random_rdd_generation.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Randomly generated RDDs. +""" + +import sys + +from pyspark import SparkContext +from pyspark.mllib.random import RandomRDDs + + +if __name__ == "__main__": + if len(sys.argv) not in [1, 2]: + print >> sys.stderr, "Usage: random_rdd_generation" + exit(-1) + + sc = SparkContext(appName="PythonRandomRDDGeneration") + + numExamples = 10000 # number of examples to generate + fraction = 0.1 # fraction of data to sample + + # Example: RandomRDDs.normalRDD + normalRDD = RandomRDDs.normalRDD(sc, numExamples) + print 'Generated RDD of %d examples sampled from the standard normal distribution'\ + % normalRDD.count() + print ' First 5 samples:' + for sample in normalRDD.take(5): + print ' ' + str(sample) + print + + # Example: RandomRDDs.normalVectorRDD + normalVectorRDD = RandomRDDs.normalVectorRDD(sc, numRows = numExamples, numCols = 2) + print 'Generated RDD of %d examples of length-2 vectors.' % normalVectorRDD.count() + print ' First 5 samples:' + for sample in normalVectorRDD.take(5): + print ' ' + str(sample) + print + + sc.stop() diff --git a/examples/src/main/python/mllib/sampled_rdds.py b/examples/src/main/python/mllib/sampled_rdds.py new file mode 100755 index 0000000000000..ec64a5978c672 --- /dev/null +++ b/examples/src/main/python/mllib/sampled_rdds.py @@ -0,0 +1,86 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Randomly sampled RDDs. +""" + +import sys + +from pyspark import SparkContext +from pyspark.mllib.util import MLUtils + + +if __name__ == "__main__": + if len(sys.argv) not in [1, 2]: + print >> sys.stderr, "Usage: sampled_rdds " + exit(-1) + if len(sys.argv) == 2: + datapath = sys.argv[1] + else: + datapath = 'data/mllib/sample_binary_classification_data.txt' + + sc = SparkContext(appName="PythonSampledRDDs") + + fraction = 0.1 # fraction of data to sample + + examples = MLUtils.loadLibSVMFile(sc, datapath) + numExamples = examples.count() + if numExamples == 0: + print >> sys.stderr, "Error: Data file had no samples to load." + exit(1) + print 'Loaded data with %d examples from file: %s' % (numExamples, datapath) + + # Example: RDD.sample() and RDD.takeSample() + expectedSampleSize = int(numExamples * fraction) + print 'Sampling RDD using fraction %g. Expected sample size = %d.' \ + % (fraction, expectedSampleSize) + sampledRDD = examples.sample(withReplacement = True, fraction = fraction) + print ' RDD.sample(): sample has %d examples' % sampledRDD.count() + sampledArray = examples.takeSample(withReplacement = True, num = expectedSampleSize) + print ' RDD.takeSample(): sample has %d examples' % len(sampledArray) + + print + + # Example: RDD.sampleByKey() + keyedRDD = examples.map(lambda lp: (int(lp.label), lp.features)) + print ' Keyed data using label (Int) as key ==> Orig' + # Count examples per label in original data. + keyCountsA = keyedRDD.countByKey() + + # Subsample, and count examples per label in sampled data. + fractions = {} + for k in keyCountsA.keys(): + fractions[k] = fraction + sampledByKeyRDD = keyedRDD.sampleByKey(withReplacement = True, fractions = fractions) + keyCountsB = sampledByKeyRDD.countByKey() + sizeB = sum(keyCountsB.values()) + print ' Sampled %d examples using approximate stratified sampling (by label). ==> Sample' \ + % sizeB + + # Compare samples + print ' \tFractions of examples with key' + print 'Key\tOrig\tSample' + for k in sorted(keyCountsA.keys()): + fracA = keyCountsA[k] / float(numExamples) + if sizeB != 0: + fracB = keyCountsB.get(k, 0) / float(sizeB) + else: + fracB = 0 + print '%d\t%g\t%g' % (k, fracA, fracB) + + sc.stop() diff --git a/examples/src/main/python/pagerank.py b/examples/src/main/python/pagerank.py index 0b96343158d44..b539c4128cdcc 100755 --- a/examples/src/main/python/pagerank.py +++ b/examples/src/main/python/pagerank.py @@ -68,3 +68,5 @@ def parseNeighbors(urls): # Collects all URL ranks and dump them to console. for (link, rank) in ranks.collect(): print "%s has rank: %s." % (link, rank) + + sc.stop() diff --git a/examples/src/main/python/pi.py b/examples/src/main/python/pi.py index 21d94a2cd4b64..fc37459dc74aa 100755 --- a/examples/src/main/python/pi.py +++ b/examples/src/main/python/pi.py @@ -37,3 +37,5 @@ def f(_): count = sc.parallelize(xrange(1, n+1), slices).map(f).reduce(add) print "Pi is roughly %f" % (4.0 * count / n) + + sc.stop() diff --git a/examples/src/main/python/sort.py b/examples/src/main/python/sort.py index 41d00c1b79133..bb686f17518a0 100755 --- a/examples/src/main/python/sort.py +++ b/examples/src/main/python/sort.py @@ -34,3 +34,5 @@ output = sortedCount.collect() for (num, unitcount) in output: print num + + sc.stop() diff --git a/examples/src/main/python/transitive_closure.py b/examples/src/main/python/transitive_closure.py index 8698369b13d84..bf331b542c438 100755 --- a/examples/src/main/python/transitive_closure.py +++ b/examples/src/main/python/transitive_closure.py @@ -64,3 +64,5 @@ def generateGraph(): break print "TC has %i edges" % tc.count() + + sc.stop() diff --git a/examples/src/main/python/wordcount.py b/examples/src/main/python/wordcount.py index dcc095fdd0ed9..ae6cd13b83d92 100755 --- a/examples/src/main/python/wordcount.py +++ b/examples/src/main/python/wordcount.py @@ -33,3 +33,5 @@ output = counts.collect() for (word, count) in output: print "%s: %i" % (word, count) + + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala new file mode 100644 index 0000000000000..d6b2fe430e5a4 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import scopt.OptionParser + +import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.{SparkConf, SparkContext} + + +/** + * An example app for summarizing multivariate data from a file. Run with + * {{{ + * bin/run-example org.apache.spark.examples.mllib.Correlations + * }}} + * By default, this loads a synthetic dataset from `data/mllib/sample_linear_regression_data.txt`. + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object Correlations { + + case class Params(input: String = "data/mllib/sample_linear_regression_data.txt") + + def main(args: Array[String]) { + + val defaultParams = Params() + + val parser = new OptionParser[Params]("Correlations") { + head("Correlations: an example app for computing correlations") + opt[String]("input") + .text(s"Input path to labeled examples in LIBSVM format, default: ${defaultParams.input}") + .action((x, c) => c.copy(input = x)) + note( + """ + |For example, the following command runs this app on a synthetic dataset: + | + | bin/spark-submit --class org.apache.spark.examples.mllib.Correlations \ + | examples/target/scala-*/spark-examples-*.jar \ + | --input data/mllib/sample_linear_regression_data.txt + """.stripMargin) + } + + parser.parse(args, defaultParams).map { params => + run(params) + } getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"Correlations with $params") + val sc = new SparkContext(conf) + + val examples = MLUtils.loadLibSVMFile(sc, params.input).cache() + + println(s"Summary of data file: ${params.input}") + println(s"${examples.count()} data points") + + // Calculate label -- feature correlations + val labelRDD = examples.map(_.label) + val numFeatures = examples.take(1)(0).features.size + val corrType = "pearson" + println() + println(s"Correlation ($corrType) between label and each feature") + println(s"Feature\tCorrelation") + var feature = 0 + while (feature < numFeatures) { + val featureRDD = examples.map(_.features(feature)) + val corr = Statistics.corr(labelRDD, featureRDD) + println(s"$feature\t$corr") + feature += 1 + } + println() + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala new file mode 100644 index 0000000000000..4532512c01f84 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import scopt.OptionParser + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.{SparkConf, SparkContext} + + +/** + * An example app for summarizing multivariate data from a file. Run with + * {{{ + * bin/run-example org.apache.spark.examples.mllib.MultivariateSummarizer + * }}} + * By default, this loads a synthetic dataset from `data/mllib/sample_linear_regression_data.txt`. + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object MultivariateSummarizer { + + case class Params(input: String = "data/mllib/sample_linear_regression_data.txt") + + def main(args: Array[String]) { + + val defaultParams = Params() + + val parser = new OptionParser[Params]("MultivariateSummarizer") { + head("MultivariateSummarizer: an example app for MultivariateOnlineSummarizer") + opt[String]("input") + .text(s"Input path to labeled examples in LIBSVM format, default: ${defaultParams.input}") + .action((x, c) => c.copy(input = x)) + note( + """ + |For example, the following command runs this app on a synthetic dataset: + | + | bin/spark-submit --class org.apache.spark.examples.mllib.MultivariateSummarizer \ + | examples/target/scala-*/spark-examples-*.jar \ + | --input data/mllib/sample_linear_regression_data.txt + """.stripMargin) + } + + parser.parse(args, defaultParams).map { params => + run(params) + } getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"MultivariateSummarizer with $params") + val sc = new SparkContext(conf) + + val examples = MLUtils.loadLibSVMFile(sc, params.input).cache() + + println(s"Summary of data file: ${params.input}") + println(s"${examples.count()} data points") + + // Summarize labels + val labelSummary = examples.aggregate(new MultivariateOnlineSummarizer())( + (summary, lp) => summary.add(Vectors.dense(lp.label)), + (sum1, sum2) => sum1.merge(sum2)) + + // Summarize features + val featureSummary = examples.aggregate(new MultivariateOnlineSummarizer())( + (summary, lp) => summary.add(lp.features), + (sum1, sum2) => sum1.merge(sum2)) + + println() + println(s"Summary statistics") + println(s"\tLabel\tFeatures") + println(s"mean\t${labelSummary.mean(0)}\t${featureSummary.mean.toArray.mkString("\t")}") + println(s"var\t${labelSummary.variance(0)}\t${featureSummary.variance.toArray.mkString("\t")}") + println( + s"nnz\t${labelSummary.numNonzeros(0)}\t${featureSummary.numNonzeros.toArray.mkString("\t")}") + println(s"max\t${labelSummary.max(0)}\t${featureSummary.max.toArray.mkString("\t")}") + println(s"min\t${labelSummary.min(0)}\t${featureSummary.min.toArray.mkString("\t")}") + println() + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala new file mode 100644 index 0000000000000..924b586e3af99 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import org.apache.spark.mllib.random.RandomRDDs +import org.apache.spark.rdd.RDD + +import org.apache.spark.{SparkConf, SparkContext} + +/** + * An example app for randomly generated RDDs. Run with + * {{{ + * bin/run-example org.apache.spark.examples.mllib.RandomRDDGeneration + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object RandomRDDGeneration { + + def main(args: Array[String]) { + + val conf = new SparkConf().setAppName(s"RandomRDDGeneration") + val sc = new SparkContext(conf) + + val numExamples = 10000 // number of examples to generate + val fraction = 0.1 // fraction of data to sample + + // Example: RandomRDDs.normalRDD + val normalRDD: RDD[Double] = RandomRDDs.normalRDD(sc, numExamples) + println(s"Generated RDD of ${normalRDD.count()}" + + " examples sampled from the standard normal distribution") + println(" First 5 samples:") + normalRDD.take(5).foreach( x => println(s" $x") ) + + // Example: RandomRDDs.normalVectorRDD + val normalVectorRDD = RandomRDDs.normalVectorRDD(sc, numRows = numExamples, numCols = 2) + println(s"Generated RDD of ${normalVectorRDD.count()} examples of length-2 vectors.") + println(" First 5 samples:") + normalVectorRDD.take(5).foreach( x => println(s" $x") ) + + println() + + sc.stop() + } + +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala new file mode 100644 index 0000000000000..f01b8266e3fe3 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import org.apache.spark.mllib.util.MLUtils +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.SparkContext._ + +/** + * An example app for randomly generated and sampled RDDs. Run with + * {{{ + * bin/run-example org.apache.spark.examples.mllib.SampledRDDs + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object SampledRDDs { + + case class Params(input: String = "data/mllib/sample_binary_classification_data.txt") + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("SampledRDDs") { + head("SampledRDDs: an example app for randomly generated and sampled RDDs.") + opt[String]("input") + .text(s"Input path to labeled examples in LIBSVM format, default: ${defaultParams.input}") + .action((x, c) => c.copy(input = x)) + note( + """ + |For example, the following command runs this app: + | + | bin/spark-submit --class org.apache.spark.examples.mllib.SampledRDDs \ + | examples/target/scala-*/spark-examples-*.jar + """.stripMargin) + } + + parser.parse(args, defaultParams).map { params => + run(params) + } getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"SampledRDDs with $params") + val sc = new SparkContext(conf) + + val fraction = 0.1 // fraction of data to sample + + val examples = MLUtils.loadLibSVMFile(sc, params.input) + val numExamples = examples.count() + if (numExamples == 0) { + throw new RuntimeException("Error: Data file had no samples to load.") + } + println(s"Loaded data with $numExamples examples from file: ${params.input}") + + // Example: RDD.sample() and RDD.takeSample() + val expectedSampleSize = (numExamples * fraction).toInt + println(s"Sampling RDD using fraction $fraction. Expected sample size = $expectedSampleSize.") + val sampledRDD = examples.sample(withReplacement = true, fraction = fraction) + println(s" RDD.sample(): sample has ${sampledRDD.count()} examples") + val sampledArray = examples.takeSample(withReplacement = true, num = expectedSampleSize) + println(s" RDD.takeSample(): sample has ${sampledArray.size} examples") + + println() + + // Example: RDD.sampleByKey() and RDD.sampleByKeyExact() + val keyedRDD = examples.map { lp => (lp.label.toInt, lp.features) } + println(s" Keyed data using label (Int) as key ==> Orig") + // Count examples per label in original data. + val keyCounts = keyedRDD.countByKey() + + // Subsample, and count examples per label in sampled data. (approximate) + val fractions = keyCounts.keys.map((_, fraction)).toMap + val sampledByKeyRDD = keyedRDD.sampleByKey(withReplacement = true, fractions = fractions) + val keyCountsB = sampledByKeyRDD.countByKey() + val sizeB = keyCountsB.values.sum + println(s" Sampled $sizeB examples using approximate stratified sampling (by label)." + + " ==> Approx Sample") + + // Subsample, and count examples per label in sampled data. (approximate) + val sampledByKeyRDDExact = + keyedRDD.sampleByKeyExact(withReplacement = true, fractions = fractions) + val keyCountsBExact = sampledByKeyRDDExact.countByKey() + val sizeBExact = keyCountsBExact.values.sum + println(s" Sampled $sizeBExact examples using exact stratified sampling (by label)." + + " ==> Exact Sample") + + // Compare samples + println(s" \tFractions of examples with key") + println(s"Key\tOrig\tApprox Sample\tExact Sample") + keyCounts.keys.toSeq.sorted.foreach { key => + val origFrac = keyCounts(key) / numExamples.toDouble + val approxFrac = if (sizeB != 0) { + keyCountsB.getOrElse(key, 0L) / sizeB.toDouble + } else { + 0 + } + val exactFrac = if (sizeBExact != 0) { + keyCountsBExact.getOrElse(key, 0L) / sizeBExact.toDouble + } else { + 0 + } + println(s"$key\t$origFrac\t$approxFrac\t$exactFrac") + } + + sc.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index e76bc9fefff01..2e414a73be8e0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -53,8 +53,14 @@ class RowMatrix( /** Gets or computes the number of columns. */ override def numCols(): Long = { if (nCols <= 0) { - // Calling `first` will throw an exception if `rows` is empty. - nCols = rows.first().size + try { + // Calling `first` will throw an exception if `rows` is empty. + nCols = rows.first().size + } catch { + case err: UnsupportedOperationException => + sys.error("Cannot determine the number of cols because it is not specified in the " + + "constructor and the rows RDD is empty.") + } } nCols } @@ -293,6 +299,10 @@ class RowMatrix( (s1._1 + s2._1, s1._2 += s2._2) ) + if (m <= 1) { + sys.error(s"RowMatrix.computeCovariance called on matrix with only $m rows." + + " Cannot compute the covariance of a RowMatrix with <= 1 row.") + } updateNumRows(m) mean :/= m.toDouble diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 5105b5c37aaaa..7d845c44365dd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -55,8 +55,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ def add(sample: Vector): this.type = { if (n == 0) { - require(sample.toBreeze.length > 0, s"Vector should have dimension larger than zero.") - n = sample.toBreeze.length + require(sample.size > 0, s"Vector should have dimension larger than zero.") + n = sample.size currMean = BDV.zeros[Double](n) currM2n = BDV.zeros[Double](n) @@ -65,8 +65,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S currMin = BDV.fill(n)(Double.MaxValue) } - require(n == sample.toBreeze.length, s"Dimensions mismatch when adding new sample." + - s" Expecting $n but got ${sample.toBreeze.length}.") + require(n == sample.size, s"Dimensions mismatch when adding new sample." + + s" Expecting $n but got ${sample.size}.") sample.toBreeze.activeIterator.foreach { case (_, 0.0) => // Skip explicit zero elements. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala index a3f76f77a5dcc..34548c86ebc14 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala @@ -39,6 +39,17 @@ class CorrelationSuite extends FunSuite with LocalSparkContext { Vectors.dense(9.0, 0.0, 0.0, 1.0) ) + test("corr(x, y) pearson, 1 value in data") { + val x = sc.parallelize(Array(1.0)) + val y = sc.parallelize(Array(4.0)) + intercept[RuntimeException] { + Statistics.corr(x, y, "pearson") + } + intercept[RuntimeException] { + Statistics.corr(x, y, "spearman") + } + } + test("corr(x, y) default, pearson") { val x = sc.parallelize(xData) val y = sc.parallelize(yData) @@ -58,7 +69,7 @@ class CorrelationSuite extends FunSuite with LocalSparkContext { // RDD of zero variance val z = sc.parallelize(zeros) - assert(Statistics.corr(x, z).isNaN()) + assert(Statistics.corr(x, z).isNaN) } test("corr(x, y) spearman") { @@ -78,7 +89,7 @@ class CorrelationSuite extends FunSuite with LocalSparkContext { // RDD of zero variance => zero variance in ranks val z = sc.parallelize(zeros) - assert(Statistics.corr(x, z, "spearman").isNaN()) + assert(Statistics.corr(x, z, "spearman").isNaN) } test("corr(X) default, pearson") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index db13f142df517..1e9415249104b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -139,7 +139,8 @@ class MultivariateOnlineSummarizerSuite extends FunSuite { assert(summarizer.numNonzeros ~== Vectors.dense(3, 5, 2) absTol 1E-5, "numNonzeros mismatch") assert(summarizer.variance ~== - Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5, "variance mismatch") + Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5, + "variance mismatch") assert(summarizer.count === 6) } @@ -167,7 +168,8 @@ class MultivariateOnlineSummarizerSuite extends FunSuite { assert(summarizer.numNonzeros ~== Vectors.dense(3, 5, 2) absTol 1E-5, "numNonzeros mismatch") assert(summarizer.variance ~== - Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5, "variance mismatch") + Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5, + "variance mismatch") assert(summarizer.count === 6) } diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 9a239abfbbeb1..f485a69db1fa2 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -23,6 +23,7 @@ SciPy is available in their environment. """ +import numpy from numpy import array, array_equal, ndarray, float64, int32 @@ -160,6 +161,15 @@ def squared_distance(self, other): j += 1 return result + def toArray(self): + """ + Returns a copy of this SparseVector as a 1-dimensional NumPy array. + """ + arr = numpy.zeros(self.size) + for i in xrange(self.indices.size): + arr[self.indices[i]] = self.values[i] + return arr + def __str__(self): inds = "[" + ",".join([str(i) for i in self.indices]) + "]" vals = "[" + ",".join([str(v) for v in self.values]) + "]" diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py index a73abc5ff90df..feef0d16cd644 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat.py @@ -118,16 +118,18 @@ def corr(x, y=None, method=None): >>> from linalg import Vectors >>> rdd = sc.parallelize([Vectors.dense([1, 0, 0, -2]), Vectors.dense([4, 5, 0, 3]), ... Vectors.dense([6, 7, 0, 8]), Vectors.dense([9, 0, 0, 1])]) - >>> Statistics.corr(rdd) - array([[ 1. , 0.05564149, nan, 0.40047142], - [ 0.05564149, 1. , nan, 0.91359586], - [ nan, nan, 1. , nan], - [ 0.40047142, 0.91359586, nan, 1. ]]) - >>> Statistics.corr(rdd, method="spearman") - array([[ 1. , 0.10540926, nan, 0.4 ], - [ 0.10540926, 1. , nan, 0.9486833 ], - [ nan, nan, 1. , nan], - [ 0.4 , 0.9486833 , nan, 1. ]]) + >>> pearsonCorr = Statistics.corr(rdd) + >>> print str(pearsonCorr).replace('nan', 'NaN') + [[ 1. 0.05564149 NaN 0.40047142] + [ 0.05564149 1. NaN 0.91359586] + [ NaN NaN 1. NaN] + [ 0.40047142 0.91359586 NaN 1. ]] + >>> spearmanCorr = Statistics.corr(rdd, method="spearman") + >>> print str(spearmanCorr).replace('nan', 'NaN') + [[ 1. 0.10540926 NaN 0.4 ] + [ 0.10540926 1. NaN 0.9486833 ] + [ NaN NaN 1. NaN] + [ 0.4 0.9486833 NaN 1. ]] >>> try: ... Statistics.corr(rdd, "spearman") ... print "Method name as second argument without 'method=' shouldn't be allowed." diff --git a/python/run-tests b/python/run-tests index a6271e0cf5fa9..b506559a5e810 100755 --- a/python/run-tests +++ b/python/run-tests @@ -78,6 +78,7 @@ run_test "pyspark/mllib/linalg.py" run_test "pyspark/mllib/random.py" run_test "pyspark/mllib/recommendation.py" run_test "pyspark/mllib/regression.py" +run_test "pyspark/mllib/stat.py" run_test "pyspark/mllib/tests.py" run_test "pyspark/mllib/tree.py" run_test "pyspark/mllib/util.py" From 217b5e915e2f21f047dfc4be680cd20d58baf9f8 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 18 Aug 2014 18:20:54 -0700 Subject: [PATCH 44/54] [SPARK-3108][MLLIB] add predictOnValues to StreamingLR and fix predictOn It is useful in streaming to allow users to carry extra data with the prediction, for monitoring the prediction error for example. freeman-lab Author: Xiangrui Meng Closes #2023 from mengxr/predict-on-values and squashes the following commits: cac47b8 [Xiangrui Meng] add classtag 2821b3b [Xiangrui Meng] use mapValues 0925efa [Xiangrui Meng] add predictOnValues to StreamingLR and fix predictOn --- .../mllib/StreamingLinearRegression.scala | 4 +-- .../regression/StreamingLinearAlgorithm.scala | 31 +++++++++++++++---- 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala index 0e992fa9967bb..c5bd5b0b178d9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala @@ -59,10 +59,10 @@ object StreamingLinearRegression { val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse) val model = new StreamingLinearRegressionWithSGD() - .setInitialWeights(Vectors.dense(Array.fill[Double](args(3).toInt)(0))) + .setInitialWeights(Vectors.zeros(args(3).toInt)) model.trainOn(trainingData) - model.predictOn(testData).print() + model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() ssc.start() ssc.awaitTermination() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index b8b0b42611775..8db0442a7a569 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -17,8 +17,12 @@ package org.apache.spark.mllib.regression -import org.apache.spark.annotation.DeveloperApi +import scala.reflect.ClassTag + import org.apache.spark.Logging +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.dstream.DStream /** @@ -92,15 +96,30 @@ abstract class StreamingLinearAlgorithm[ /** * Use the model to make predictions on batches of data from a DStream * - * @param data DStream containing labeled data + * @param data DStream containing feature vectors * @return DStream containing predictions */ - def predictOn(data: DStream[LabeledPoint]): DStream[Double] = { + def predictOn(data: DStream[Vector]): DStream[Double] = { if (Option(model.weights) == None) { - logError("Initial weights must be set before starting prediction") - throw new IllegalArgumentException + val msg = "Initial weights must be set before starting prediction" + logError(msg) + throw new IllegalArgumentException(msg) } - data.map(x => model.predict(x.features)) + data.map(model.predict) } + /** + * Use the model to make predictions on the values of a DStream and carry over its keys. + * @param data DStream containing feature vectors + * @tparam K key type + * @return DStream containing the input keys and the predictions as values + */ + def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Double)] = { + if (Option(model.weights) == None) { + val msg = "Initial weights must be set before starting prediction" + logError(msg) + throw new IllegalArgumentException(msg) + } + data.mapValues(model.predict) + } } From 1f1819b20f887b487557c31e54b8bcd95b582dc6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 18 Aug 2014 20:42:19 -0700 Subject: [PATCH 45/54] [SPARK-3114] [PySpark] Fix Python UDFs in Spark SQL. This fixes SPARK-3114, an issue where we inadvertently broke Python UDFs in Spark SQL. This PR modifiers the test runner script to always run the PySpark SQL tests, irrespective of whether SparkSQL itself has been modified. It also includes Davies' fix for the bug. Closes #2026. Author: Josh Rosen Author: Davies Liu Closes #2027 from JoshRosen/pyspark-sql-fix and squashes the following commits: 9af2708 [Davies Liu] bugfix: disable compression of command 0d8d3a4 [Josh Rosen] Always run Python Spark SQL tests. --- dev/run-tests | 17 +++++++++++++---- python/pyspark/rdd.py | 2 +- python/pyspark/worker.py | 2 +- python/run-tests | 4 +--- 4 files changed, 16 insertions(+), 9 deletions(-) diff --git a/dev/run-tests b/dev/run-tests index 0e24515d1376c..132f696d6447a 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -58,7 +58,7 @@ if [ -n "$AMPLAB_JENKINS" ]; then diffs=`git diff --name-only master | grep "^sql/"` if [ -n "$diffs" ]; then echo "Detected changes in SQL. Will run Hive test suite." - export _RUN_SQL_TESTS=true # exported for PySpark tests + _RUN_SQL_TESTS=true fi fi @@ -89,13 +89,22 @@ echo "=========================================================================" echo "Running Spark unit tests" echo "=========================================================================" +# Build Spark; we always build with Hive because the PySpark SparkSQL tests need it. +# echo "q" is needed because sbt on encountering a build file with failure +# (either resolution or compilation) prompts the user for input either q, r, +# etc to quit or retry. This echo is there to make it not block. +BUILD_MVN_PROFILE_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver " +echo -e "q\n" | sbt/sbt $BUILD_MVN_PROFILE_ARGS clean package assembly/assembly | \ + grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" + +# If the Spark SQL tests are enabled, run the tests with the Hive profiles enabled: if [ -n "$_RUN_SQL_TESTS" ]; then SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver" fi -# echo "q" is needed because sbt on encountering a build file with failure -# (either resolution or compilation) prompts the user for input either q, r, +# echo "q" is needed because sbt on encountering a build file with failure +# (either resolution or compilation) prompts the user for input either q, r, # etc to quit or retry. This echo is there to make it not block. -echo -e "q\n" | sbt/sbt $SBT_MAVEN_PROFILES_ARGS clean package assembly/assembly test | \ +echo -e "q\n" | sbt/sbt $SBT_MAVEN_PROFILES_ARGS test | \ grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" echo "" diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index c708b69cc1e31..86cd89b245aea 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1812,7 +1812,7 @@ def _jrdd(self): self._jrdd_deserializer = NoOpSerializer() command = (self.func, self._prev_jrdd_deserializer, self._jrdd_deserializer) - ser = CompressedSerializer(CloudPickleSerializer()) + ser = CloudPickleSerializer() pickled_command = ser.dumps(command) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 77a9c4a0e0677..6805063e06798 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -72,7 +72,7 @@ def main(infile, outfile): value = ser._read_with_length(infile) _broadcastRegistry[bid] = Broadcast(bid, value) - command = ser._read_with_length(infile) + command = pickleSer._read_with_length(infile) (func, deserializer, serializer) = command init_time = time.time() iterator = deserializer.load_stream(infile) diff --git a/python/run-tests b/python/run-tests index b506559a5e810..7b1ee3e1cddba 100755 --- a/python/run-tests +++ b/python/run-tests @@ -59,9 +59,7 @@ $PYSPARK_PYTHON --version run_test "pyspark/rdd.py" run_test "pyspark/context.py" run_test "pyspark/conf.py" -if [ -n "$_RUN_SQL_TESTS" ]; then - run_test "pyspark/sql.py" -fi +run_test "pyspark/sql.py" # These tests are included in the module-level docs, and so must # be handled on a higher level rather than within the python file. export PYSPARK_DOC_TEST=1 From 82577339dd58b5811eab5d10667775e61e37ff51 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 18 Aug 2014 20:51:41 -0700 Subject: [PATCH 46/54] [SPARK-3116] Remove the excessive lockings in TorrentBroadcast Author: Reynold Xin Closes #2028 from rxin/torrentBroadcast and squashes the following commits: 92c62a5 [Reynold Xin] Revert the MEMORY_AND_DISK_SER changes. 03a5221 [Reynold Xin] [SPARK-3116] Remove the excessive lockings in TorrentBroadcast --- .../spark/broadcast/TorrentBroadcast.scala | 66 ++++++++----------- 1 file changed, 27 insertions(+), 39 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index fe73456ef8fad..d8be649f96e5f 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -17,8 +17,7 @@ package org.apache.spark.broadcast -import java.io.{ByteArrayOutputStream, ByteArrayInputStream, InputStream, - ObjectInputStream, ObjectOutputStream, OutputStream} +import java.io._ import scala.reflect.ClassTag import scala.util.Random @@ -53,10 +52,8 @@ private[spark] class TorrentBroadcast[T: ClassTag]( private val broadcastId = BroadcastBlockId(id) - TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.putSingle( - broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) - } + SparkEnv.get.blockManager.putSingle( + broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) @transient private var arrayOfBlocks: Array[TorrentBlock] = null @transient private var totalBlocks = -1 @@ -91,18 +88,14 @@ private[spark] class TorrentBroadcast[T: ClassTag]( // Store meta-info val metaId = BroadcastBlockId(id, "meta") val metaInfo = TorrentInfo(null, totalBlocks, totalBytes) - TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.putSingle( - metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, tellMaster = true) - } + SparkEnv.get.blockManager.putSingle( + metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, tellMaster = true) // Store individual pieces for (i <- 0 until totalBlocks) { val pieceId = BroadcastBlockId(id, "piece" + i) - TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.putSingle( - pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, tellMaster = true) - } + SparkEnv.get.blockManager.putSingle( + pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, tellMaster = true) } } @@ -165,21 +158,20 @@ private[spark] class TorrentBroadcast[T: ClassTag]( val metaId = BroadcastBlockId(id, "meta") var attemptId = 10 while (attemptId > 0 && totalBlocks == -1) { - TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.getSingle(metaId) match { - case Some(x) => - val tInfo = x.asInstanceOf[TorrentInfo] - totalBlocks = tInfo.totalBlocks - totalBytes = tInfo.totalBytes - arrayOfBlocks = new Array[TorrentBlock](totalBlocks) - hasBlocks = 0 - - case None => - Thread.sleep(500) - } + SparkEnv.get.blockManager.getSingle(metaId) match { + case Some(x) => + val tInfo = x.asInstanceOf[TorrentInfo] + totalBlocks = tInfo.totalBlocks + totalBytes = tInfo.totalBytes + arrayOfBlocks = new Array[TorrentBlock](totalBlocks) + hasBlocks = 0 + + case None => + Thread.sleep(500) } attemptId -= 1 } + if (totalBlocks == -1) { return false } @@ -192,17 +184,15 @@ private[spark] class TorrentBroadcast[T: ClassTag]( val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList) for (pid <- recvOrder) { val pieceId = BroadcastBlockId(id, "piece" + pid) - TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.getSingle(pieceId) match { - case Some(x) => - arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock] - hasBlocks += 1 - SparkEnv.get.blockManager.putSingle( - pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, tellMaster = true) + SparkEnv.get.blockManager.getSingle(pieceId) match { + case Some(x) => + arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock] + hasBlocks += 1 + SparkEnv.get.blockManager.putSingle( + pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, tellMaster = true) - case None => - throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) - } + case None => + throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) } } @@ -291,9 +281,7 @@ private[broadcast] object TorrentBroadcast extends Logging { * If removeFromDriver is true, also remove these persisted blocks on the driver. */ def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = { - synchronized { - SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) - } + SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) } } From cd0720ca77894d481fb73a8b5bb517013843cb1e Mon Sep 17 00:00:00 2001 From: Matt Forbes Date: Mon, 18 Aug 2014 21:43:32 -0700 Subject: [PATCH 47/54] Fix typo in decision tree docs Candidate splits were inconsistent with the example. Author: Matt Forbes Closes #1837 from emef/tree-doc and squashes the following commits: 3be14a1 [Matt Forbes] Fix typo in decision tree docs --- docs/mllib-decision-tree.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index 9cbd880897578..c01a92a9a1b26 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -84,8 +84,8 @@ Section 9.2.4 in [Elements of Statistical Machine Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for details). For example, for a binary classification problem with one categorical feature with three categories A, B and C with corresponding proportion of label 1 as 0.2, 0.6 and 0.4, the categorical -features are ordered as A followed by C followed B or A, B, C. The two split candidates are A \| C, B -and A , B \| C where \| denotes the split. A similar heuristic is used for multiclass classification +features are ordered as A followed by C followed B or A, C, B. The two split candidates are A \| C, B +and A , C \| B where \| denotes the split. A similar heuristic is used for multiclass classification when `$2^(M-1)-1$` is greater than the number of bins -- the impurity for each categorical feature value is used for ordering. From 7eb9cbc273d758522e787fcb2ef68ef65911475f Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Tue, 19 Aug 2014 09:40:31 -0500 Subject: [PATCH 48/54] [SPARK-3072] YARN - Exit when reach max number failed executors In some cases on hadoop 2.x the spark application master doesn't properly exit and hangs around for 10 minutes after its really done. We should make sure it exits properly and stops the driver. Author: Thomas Graves Closes #2022 from tgravescs/SPARK-3072 and squashes the following commits: 665701d [Thomas Graves] Exit when reach max number failed executors --- .../spark/deploy/yarn/ApplicationMaster.scala | 33 ++++++++++++------- .../spark/deploy/yarn/ExecutorLauncher.scala | 5 +-- .../spark/deploy/yarn/ApplicationMaster.scala | 16 ++++++--- .../spark/deploy/yarn/ExecutorLauncher.scala | 5 +-- 4 files changed, 40 insertions(+), 19 deletions(-) diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 62b5c3bc5f0f3..46a01f5a9a2cc 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -267,12 +267,10 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, // TODO: This is a bit ugly. Can we make it nicer? // TODO: Handle container failure - // Exists the loop if the user thread exits. - while (yarnAllocator.getNumExecutorsRunning < args.numExecutors && userThread.isAlive) { - if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) { - finishApplicationMaster(FinalApplicationStatus.FAILED, - "max number of executor failures reached") - } + // Exits the loop if the user thread exits. + while (yarnAllocator.getNumExecutorsRunning < args.numExecutors && userThread.isAlive + && !isFinished) { + checkNumExecutorsFailed() yarnAllocator.allocateContainers( math.max(args.numExecutors - yarnAllocator.getNumExecutorsRunning, 0)) Thread.sleep(ApplicationMaster.ALLOCATE_HEARTBEAT_INTERVAL) @@ -303,11 +301,8 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, val t = new Thread { override def run() { - while (userThread.isAlive) { - if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) { - finishApplicationMaster(FinalApplicationStatus.FAILED, - "max number of executor failures reached") - } + while (userThread.isAlive && !isFinished) { + checkNumExecutorsFailed() val missingExecutorCount = args.numExecutors - yarnAllocator.getNumExecutorsRunning if (missingExecutorCount > 0) { logInfo("Allocating %d containers to make up for (potentially) lost containers". @@ -327,6 +322,22 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, t } + private def checkNumExecutorsFailed() { + if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) { + logInfo("max number of executor failures reached") + finishApplicationMaster(FinalApplicationStatus.FAILED, + "max number of executor failures reached") + // make sure to stop the user thread + val sparkContext = ApplicationMaster.sparkContextRef.get() + if (sparkContext != null) { + logInfo("Invoking sc stop from checkNumExecutorsFailed") + sparkContext.stop() + } else { + logError("sparkContext is null when should shutdown") + } + } + } + private def sendProgress() { logDebug("Sending progress") // Simulated with an allocate request with no nodes requested ... diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala index 184e2ad6c82cd..72c7143edcd71 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala @@ -249,7 +249,8 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp // Wait until all containers have finished // TODO: This is a bit ugly. Can we make it nicer? // TODO: Handle container failure - while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed)) { + while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed) && + !isFinished) { yarnAllocator.allocateContainers( math.max(args.numExecutors - yarnAllocator.getNumExecutorsRunning, 0)) checkNumExecutorsFailed() @@ -271,7 +272,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp val t = new Thread { override def run() { - while (!driverClosed) { + while (!driverClosed && !isFinished) { checkNumExecutorsFailed() val missingExecutorCount = args.numExecutors - yarnAllocator.getNumExecutorsRunning if (missingExecutorCount > 0) { diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 035356d390c80..9c2bcf17a8508 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -247,13 +247,12 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, yarnAllocator.allocateResources() // Exits the loop if the user thread exits. - var iters = 0 - while (yarnAllocator.getNumExecutorsRunning < args.numExecutors && userThread.isAlive) { + while (yarnAllocator.getNumExecutorsRunning < args.numExecutors && userThread.isAlive + && !isFinished) { checkNumExecutorsFailed() allocateMissingExecutor() yarnAllocator.allocateResources() Thread.sleep(ApplicationMaster.ALLOCATE_HEARTBEAT_INTERVAL) - iters += 1 } } logInfo("All executors have launched.") @@ -271,8 +270,17 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, private def checkNumExecutorsFailed() { if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) { + logInfo("max number of executor failures reached") finishApplicationMaster(FinalApplicationStatus.FAILED, "max number of executor failures reached") + // make sure to stop the user thread + val sparkContext = ApplicationMaster.sparkContextRef.get() + if (sparkContext != null) { + logInfo("Invoking sc stop from checkNumExecutorsFailed") + sparkContext.stop() + } else { + logError("sparkContext is null when should shutdown") + } } } @@ -289,7 +297,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, val t = new Thread { override def run() { - while (userThread.isAlive) { + while (userThread.isAlive && !isFinished) { checkNumExecutorsFailed() allocateMissingExecutor() logDebug("Sending progress") diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala index fc7b8320d734d..a7585748b7f88 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala @@ -217,7 +217,8 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp // Wait until all containers have launched yarnAllocator.addResourceRequests(args.numExecutors) yarnAllocator.allocateResources() - while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed)) { + while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed) && + !isFinished) { checkNumExecutorsFailed() allocateMissingExecutor() yarnAllocator.allocateResources() @@ -249,7 +250,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp val t = new Thread { override def run() { - while (!driverClosed) { + while (!driverClosed && !isFinished) { checkNumExecutorsFailed() allocateMissingExecutor() logDebug("Sending progress") From cbfc26ba45f49559e64276c72e3054c6fe30ddd5 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Tue, 19 Aug 2014 10:15:11 -0700 Subject: [PATCH 49/54] [SPARK-3089] Fix meaningless error message in ConnectionManager Author: Kousuke Saruta Closes #2000 from sarutak/SPARK-3089 and squashes the following commits: 02dfdea [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-3089 e759ce7 [Kousuke Saruta] Improved error message when closing SendingConnection --- .../main/scala/org/apache/spark/network/ConnectionManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index e77d762bdf221..b3e951ded6e77 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -467,7 +467,7 @@ private[spark] class ConnectionManager( val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId) if (!sendingConnectionOpt.isDefined) { - logError("Corresponding SendingConnectionManagerId not found") + logError(s"Corresponding SendingConnection to ${remoteConnectionManagerId} not found") return } From 31f0b071efd0b63eb9d6a6a131e5c4fa28237583 Mon Sep 17 00:00:00 2001 From: freeman Date: Tue, 19 Aug 2014 13:28:57 -0700 Subject: [PATCH 50/54] [SPARK-3128][MLLIB] Use streaming test suite for StreamingLR Refactored tests for streaming linear regression to use existing streaming test utilities. Summary of changes: - Made ``mllib`` depend on tests from ``streaming`` - Rewrote accuracy and convergence tests to use ``setupStreams`` and ``runStreams`` - Added new test for the accuracy of predictions generated by ``predictOnValue`` These tests should run faster, be easier to extend/maintain, and provide a reference for new tests. mengxr tdas Author: freeman Closes #2037 from freeman-lab/streamingLR-predict-tests and squashes the following commits: e851ca7 [freeman] Fixed long lines 50eb0bf [freeman] Refactored tests to use streaming test tools 32c43c2 [freeman] Added test for prediction --- mllib/pom.xml | 7 + .../StreamingLinearRegressionSuite.scala | 121 ++++++++++-------- .../spark/streaming/TestSuiteBase.scala | 4 +- 3 files changed, 77 insertions(+), 55 deletions(-) diff --git a/mllib/pom.xml b/mllib/pom.xml index fc1ecfbea708f..c7a1e2ae75c84 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -91,6 +91,13 @@ junit-interface test
    + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + test-jar + test +
    diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index 45e25eecf508e..28489410f8225 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -17,20 +17,19 @@ package org.apache.spark.mllib.regression -import java.io.File -import java.nio.charset.Charset - import scala.collection.mutable.ArrayBuffer -import com.google.common.io.Files import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext} -import org.apache.spark.streaming.{Milliseconds, StreamingContext} -import org.apache.spark.util.Utils +import org.apache.spark.mllib.util.LinearDataGenerator +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.streaming.TestSuiteBase + +class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { -class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext { + // use longer wait time to ensure job completion + override def maxWaitTimeMillis = 20000 // Assert that two values are equal within tolerance epsilon def assertEqual(v1: Double, v2: Double, epsilon: Double) { @@ -49,35 +48,26 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext { } // Test if we can accurately learn Y = 10*X1 + 10*X2 on streaming data - test("streaming linear regression parameter accuracy") { + test("parameter accuracy") { - val testDir = Files.createTempDir() - val numBatches = 10 - val batchDuration = Milliseconds(1000) - val ssc = new StreamingContext(sc, batchDuration) - val data = ssc.textFileStream(testDir.toString).map(LabeledPoint.parse) + // create model val model = new StreamingLinearRegressionWithSGD() .setInitialWeights(Vectors.dense(0.0, 0.0)) .setStepSize(0.1) - .setNumIterations(50) + .setNumIterations(25) - model.trainOn(data) - - ssc.start() - - // write data to a file stream - for (i <- 0 until numBatches) { - val samples = LinearDataGenerator.generateLinearInput( - 0.0, Array(10.0, 10.0), 100, 42 * (i + 1)) - val file = new File(testDir, i.toString) - Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8")) - Thread.sleep(batchDuration.milliseconds) + // generate sequence of simulated data + val numBatches = 10 + val input = (0 until numBatches).map { i => + LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 42 * (i + 1)) } - ssc.stop(stopSparkContext=false) - - System.clearProperty("spark.driver.port") - Utils.deleteRecursively(testDir) + // apply model training to input stream + val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + inputDStream.count() + }) + runStreams(ssc, numBatches, numBatches) // check accuracy of final parameter estimates assertEqual(model.latestModel().intercept, 0.0, 0.1) @@ -91,39 +81,33 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext { } // Test that parameter estimates improve when learning Y = 10*X1 on streaming data - test("streaming linear regression parameter convergence") { + test("parameter convergence") { - val testDir = Files.createTempDir() - val batchDuration = Milliseconds(2000) - val ssc = new StreamingContext(sc, batchDuration) - val numBatches = 5 - val data = ssc.textFileStream(testDir.toString()).map(LabeledPoint.parse) + // create model val model = new StreamingLinearRegressionWithSGD() .setInitialWeights(Vectors.dense(0.0)) .setStepSize(0.1) - .setNumIterations(50) - - model.trainOn(data) - - ssc.start() + .setNumIterations(25) - // write data to a file stream - val history = new ArrayBuffer[Double](numBatches) - for (i <- 0 until numBatches) { - val samples = LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1)) - val file = new File(testDir, i.toString) - Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8")) - Thread.sleep(batchDuration.milliseconds) - // wait an extra few seconds to make sure the update finishes before new data arrive - Thread.sleep(4000) - history.append(math.abs(model.latestModel().weights(0) - 10.0)) + // generate sequence of simulated data + val numBatches = 10 + val input = (0 until numBatches).map { i => + LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1)) } - ssc.stop(stopSparkContext=false) + // create buffer to store intermediate fits + val history = new ArrayBuffer[Double](numBatches) - System.clearProperty("spark.driver.port") - Utils.deleteRecursively(testDir) + // apply model training to input stream, storing the intermediate results + // (we add a count to ensure the result is a DStream) + val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - 10.0))) + inputDStream.count() + }) + runStreams(ssc, numBatches, numBatches) + // compute change in error val deltas = history.drop(1).zip(history.dropRight(1)) // check error stability (it always either shrinks, or increases with small tol) assert(deltas.forall(x => (x._1 - x._2) <= 0.1)) @@ -132,4 +116,33 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext { } + // Test predictions on a stream + test("predictions") { + + // create model initialized with true weights + val model = new StreamingLinearRegressionWithSGD() + .setInitialWeights(Vectors.dense(10.0, 10.0)) + .setStepSize(0.1) + .setNumIterations(25) + + // generate sequence of simulated data for testing + val numBatches = 10 + val nPoints = 100 + val testInput = (0 until numBatches).map { i => + LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1)) + } + + // apply model predictions to test stream + val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + }) + // collect the output as (true, estimated) tuples + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + + // compute the mean absolute error and check that it's always less than 0.1 + val errors = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints) + assert(errors.forall(x => x <= 0.1)) + + } + } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index cc178fba12c9d..f095da9cb55d3 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -242,7 +242,9 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput) // Get the output buffer - val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStreamWithPartitions[V]] + val outputStream = ssc.graph.getOutputStreams. + filter(_.isInstanceOf[TestOutputStreamWithPartitions[_]]). + head.asInstanceOf[TestOutputStreamWithPartitions[V]] val output = outputStream.output try { From 94053a7b766788bb62e2dbbf352ccbcc75f71fc0 Mon Sep 17 00:00:00 2001 From: Vida Ha Date: Tue, 19 Aug 2014 13:35:05 -0700 Subject: [PATCH 51/54] SPARK-2333 - spark_ec2 script should allow option for existing security group - Uses the name tag to identify machines in a cluster. - Allows overriding the security group name so it doesn't need to coincide with the cluster name. - Outputs the request id's of up to 10 pending spot instance requests. Author: Vida Ha Closes #1899 from vidaha/vida/ec2-reuse-security-group and squashes the following commits: c80d5c3 [Vida Ha] wrap retries in a try catch block b2989d5 [Vida Ha] SPARK-2333: spark_ec2 script should allow option for existing security group --- docs/ec2-scripts.md | 14 +++++---- ec2/spark_ec2.py | 71 +++++++++++++++++++++++++++++++-------------- 2 files changed, 57 insertions(+), 28 deletions(-) diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md index 156a727026790..f5ac6d894e1eb 100644 --- a/docs/ec2-scripts.md +++ b/docs/ec2-scripts.md @@ -12,14 +12,16 @@ on the [Amazon Web Services site](http://aws.amazon.com/). `spark-ec2` is designed to manage multiple named clusters. You can launch a new cluster (telling the script its size and giving it a name), -shutdown an existing cluster, or log into a cluster. Each cluster is -identified by placing its machines into EC2 security groups whose names -are derived from the name of the cluster. For example, a cluster named +shutdown an existing cluster, or log into a cluster. Each cluster +launches a set of instances, which are tagged with the cluster name, +and placed into EC2 security groups. If you don't specify a security +group, the `spark-ec2` script will create security groups based on the +cluster name you request. For example, a cluster named `test` will contain a master node in a security group called `test-master`, and a number of slave nodes in a security group called -`test-slaves`. The `spark-ec2` script will create these security groups -for you based on the cluster name you request. You can also use them to -identify machines belonging to each cluster in the Amazon EC2 Console. +`test-slaves`. You can also specify a security group prefix to be used +in place of the cluster name. Machines in a cluster can be identified +by looking for the "Name" tag of the instance in the Amazon EC2 Console. # Before You Start diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 0c2f85a3868f4..3a8c816cfffa1 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -124,7 +124,7 @@ def parse_args(): help="The SSH user you want to connect as (default: root)") parser.add_option( "--delete-groups", action="store_true", default=False, - help="When destroying a cluster, delete the security groups that were created") + help="When destroying a cluster, delete the security groups that were created.") parser.add_option( "--use-existing-master", action="store_true", default=False, help="Launch fresh slaves, but use an existing stopped master if possible") @@ -138,7 +138,9 @@ def parse_args(): parser.add_option( "--user-data", type="string", default="", help="Path to a user-data file (most AMI's interpret this as an initialization script)") - + parser.add_option( + "--security-group-prefix", type="string", default=None, + help="Use this prefix for the security group rather than the cluster name.") (opts, args) = parser.parse_args() if len(args) != 2: @@ -285,8 +287,12 @@ def launch_cluster(conn, opts, cluster_name): user_data_content = user_data_file.read() print "Setting up security groups..." - master_group = get_or_make_group(conn, cluster_name + "-master") - slave_group = get_or_make_group(conn, cluster_name + "-slaves") + if opts.security_group_prefix is None: + master_group = get_or_make_group(conn, cluster_name + "-master") + slave_group = get_or_make_group(conn, cluster_name + "-slaves") + else: + master_group = get_or_make_group(conn, opts.security_group_prefix + "-master") + slave_group = get_or_make_group(conn, opts.security_group_prefix + "-slaves") if master_group.rules == []: # Group was just now created master_group.authorize(src_group=master_group) master_group.authorize(src_group=slave_group) @@ -310,12 +316,11 @@ def launch_cluster(conn, opts, cluster_name): slave_group.authorize('tcp', 60060, 60060, '0.0.0.0/0') slave_group.authorize('tcp', 60075, 60075, '0.0.0.0/0') - # Check if instances are already running in our groups + # Check if instances are already running with the cluster name existing_masters, existing_slaves = get_existing_cluster(conn, opts, cluster_name, die_on_error=False) if existing_slaves or (existing_masters and not opts.use_existing_master): - print >> stderr, ("ERROR: There are already instances running in " + - "group %s or %s" % (master_group.name, slave_group.name)) + print >> stderr, ("ERROR: There are already instances for name: %s " % cluster_name) sys.exit(1) # Figure out Spark AMI @@ -371,9 +376,13 @@ def launch_cluster(conn, opts, cluster_name): for r in reqs: id_to_req[r.id] = r active_instance_ids = [] + outstanding_request_ids = [] for i in my_req_ids: - if i in id_to_req and id_to_req[i].state == "active": - active_instance_ids.append(id_to_req[i].instance_id) + if i in id_to_req: + if id_to_req[i].state == "active": + active_instance_ids.append(id_to_req[i].instance_id) + else: + outstanding_request_ids.append(i) if len(active_instance_ids) == opts.slaves: print "All %d slaves granted" % opts.slaves reservations = conn.get_all_instances(active_instance_ids) @@ -382,8 +391,8 @@ def launch_cluster(conn, opts, cluster_name): slave_nodes += r.instances break else: - print "%d of %d slaves granted, waiting longer" % ( - len(active_instance_ids), opts.slaves) + print "%d of %d slaves granted, waiting longer for request ids including %s" % ( + len(active_instance_ids), opts.slaves, outstanding_request_ids[0:10]) except: print "Canceling spot instance requests" conn.cancel_spot_instance_requests(my_req_ids) @@ -440,14 +449,29 @@ def launch_cluster(conn, opts, cluster_name): print "Launched master in %s, regid = %s" % (zone, master_res.id) # Give the instances descriptive names + # TODO: Add retry logic for tagging with name since it's used to identify a cluster. for master in master_nodes: - master.add_tag( - key='Name', - value='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id)) + name = '{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id) + for i in range(0, 5): + try: + master.add_tag(key='Name', value=name) + except: + print "Failed attempt %i of 5 to tag %s" % ((i + 1), name) + if (i == 5): + raise "Error - failed max attempts to add name tag" + time.sleep(5) + + for slave in slave_nodes: - slave.add_tag( - key='Name', - value='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id)) + name = '{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id) + for i in range(0, 5): + try: + slave.add_tag(key='Name', value=name) + except: + print "Failed attempt %i of 5 to tag %s" % ((i + 1), name) + if (i == 5): + raise "Error - failed max attempts to add name tag" + time.sleep(5) # Return all the instances return (master_nodes, slave_nodes) @@ -463,10 +487,10 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): for res in reservations: active = [i for i in res.instances if is_active(i)] for inst in active: - group_names = [g.name for g in inst.groups] - if group_names == [cluster_name + "-master"]: + name = inst.tags.get(u'Name', "") + if name.startswith(cluster_name + "-master"): master_nodes.append(inst) - elif group_names == [cluster_name + "-slaves"]: + elif name.startswith(cluster_name + "-slave"): slave_nodes.append(inst) if any((master_nodes, slave_nodes)): print ("Found %d master(s), %d slaves" % (len(master_nodes), len(slave_nodes))) @@ -474,7 +498,7 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): return (master_nodes, slave_nodes) else: if master_nodes == [] and slave_nodes != []: - print >> sys.stderr, "ERROR: Could not find master in group " + cluster_name + "-master" + print >> sys.stderr, "ERROR: Could not find master in with name " + cluster_name + "-master" else: print >> sys.stderr, "ERROR: Could not find any existing cluster" sys.exit(1) @@ -816,7 +840,10 @@ def real_main(): # Delete security groups as well if opts.delete_groups: print "Deleting security groups (this will take some time)..." - group_names = [cluster_name + "-master", cluster_name + "-slaves"] + if opts.security_group_prefix is None: + group_names = [cluster_name + "-master", cluster_name + "-slaves"] + else: + group_names = [opts.security_group_prefix + "-master", opts.security_group_prefix + "-slaves"] attempt = 1 while attempt <= 3: From 76eaeb4523ee01cabbea2d867daac48a277885a1 Mon Sep 17 00:00:00 2001 From: hzw19900416 Date: Tue, 19 Aug 2014 14:04:49 -0700 Subject: [PATCH 52/54] Move a bracket in validateSettings of SparkConf Move a bracket in validateSettings of SparkConf Author: hzw19900416 Closes #2012 from hzw19900416/codereading and squashes the following commits: e717fb6 [hzw19900416] Move a bracket in validateSettings of SparkConf --- core/src/main/scala/org/apache/spark/SparkConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 13f0bff7ee507..b4f321ec99e78 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -227,7 +227,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { // Validate spark.executor.extraJavaOptions settings.get(executorOptsKey).map { javaOpts => if (javaOpts.contains("-Dspark")) { - val msg = s"$executorOptsKey is not allowed to set Spark options (was '$javaOpts)'. " + + val msg = s"$executorOptsKey is not allowed to set Spark options (was '$javaOpts'). " + "Set them directly on a SparkConf or in a properties file when using ./bin/spark-submit." throw new Exception(msg) } From d7e80c2597d4a9cae2e0cb35a86f7889323f4cbb Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 19 Aug 2014 14:46:32 -0700 Subject: [PATCH 53/54] [SPARK-2790] [PySpark] fix zip with serializers which have different batch sizes. If two RDDs have different batch size in serializers, then it will try to re-serialize the one with smaller batch size, then call RDD.zip() in Spark. Author: Davies Liu Closes #1894 from davies/zip and squashes the following commits: c4652ea [Davies Liu] add more test cases 6d05fc8 [Davies Liu] Merge branch 'master' into zip 813b1e4 [Davies Liu] add more tests for failed cases a4aafda [Davies Liu] fix zip with serializers which have different batch sizes. --- python/pyspark/rdd.py | 25 +++++++++++++++++++++++++ python/pyspark/serializers.py | 3 +++ python/pyspark/tests.py | 27 ++++++++++++++++++++++++++- 3 files changed, 54 insertions(+), 1 deletion(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 86cd89b245aea..140cbe05a43b0 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1687,6 +1687,31 @@ def zip(self, other): >>> x.zip(y).collect() [(0, 1000), (1, 1001), (2, 1002), (3, 1003), (4, 1004)] """ + if self.getNumPartitions() != other.getNumPartitions(): + raise ValueError("Can only zip with RDD which has the same number of partitions") + + def get_batch_size(ser): + if isinstance(ser, BatchedSerializer): + return ser.batchSize + return 0 + + def batch_as(rdd, batchSize): + ser = rdd._jrdd_deserializer + if isinstance(ser, BatchedSerializer): + ser = ser.serializer + return rdd._reserialize(BatchedSerializer(ser, batchSize)) + + my_batch = get_batch_size(self._jrdd_deserializer) + other_batch = get_batch_size(other._jrdd_deserializer) + if my_batch != other_batch: + # use the greatest batchSize to batch the other one. + if my_batch > other_batch: + other = batch_as(other, my_batch) + else: + self = batch_as(self, other_batch) + + # There will be an Exception in JVM if there are different number + # of items in each partitions. pairRDD = self._jrdd.zip(other._jrdd) deserializer = PairDeserializer(self._jrdd_deserializer, other._jrdd_deserializer) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 74870c0edcf99..fc49aa42dbaf9 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -255,6 +255,9 @@ def __init__(self, key_ser, val_ser): def load_stream(self, stream): for (keys, vals) in self.prepare_keys_values(stream): + if len(keys) != len(vals): + raise ValueError("Can not deserialize RDD with different number of items" + " in pair: (%d, %d)" % (len(keys), len(vals))) for pair in izip(keys, vals): yield pair diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 69d543d9d045d..51bfbb47e53c2 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -39,7 +39,7 @@ from pyspark.context import SparkContext from pyspark.files import SparkFiles -from pyspark.serializers import read_int +from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger _have_scipy = False @@ -339,6 +339,31 @@ def test_large_broadcast(self): m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() self.assertEquals(N, m) + def test_zip_with_different_serializers(self): + a = self.sc.parallelize(range(5)) + b = self.sc.parallelize(range(100, 105)) + self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]) + a = a._reserialize(BatchedSerializer(PickleSerializer(), 2)) + b = b._reserialize(MarshalSerializer()) + self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]) + + def test_zip_with_different_number_of_items(self): + a = self.sc.parallelize(range(5), 2) + # different number of partitions + b = self.sc.parallelize(range(100, 106), 3) + self.assertRaises(ValueError, lambda: a.zip(b)) + # different number of batched items in JVM + b = self.sc.parallelize(range(100, 104), 2) + self.assertRaises(Exception, lambda: a.zip(b).count()) + # different number of items in one pair + b = self.sc.parallelize(range(100, 106), 2) + self.assertRaises(Exception, lambda: a.zip(b).count()) + # same total number of items, but different distributions + a = self.sc.parallelize([2, 3], 2).flatMap(range) + b = self.sc.parallelize([3, 2], 2).flatMap(range) + self.assertEquals(a.count(), b.count()) + self.assertRaises(Exception, lambda: a.zip(b).count()) + class TestIO(PySparkTestCase): From 825d4fe47b9c4d48de88622dd48dcf83beb8b80a Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 19 Aug 2014 16:06:48 -0700 Subject: [PATCH 54/54] [SPARK-3136][MLLIB] Create Java-friendly methods in RandomRDDs Though we don't use default argument for methods in RandomRDDs, it is still not easy for Java users to use because the output type is either `RDD[Double]` or `RDD[Vector]`. Java users should expect `JavaDoubleRDD` and `JavaRDD[Vector]`, respectively. We should create dedicated methods for Java users, and allow default arguments in Scala methods in RandomRDDs, to make life easier for both Java and Scala users. This PR also contains documentation for random data generation. brkyvz Author: Xiangrui Meng Closes #2041 from mengxr/stat-doc and squashes the following commits: fc5eedf [Xiangrui Meng] add missing comma ffde810 [Xiangrui Meng] address comments aef6d07 [Xiangrui Meng] add doc for random data generation b99d94b [Xiangrui Meng] add java-friendly methods to RandomRDDs --- docs/mllib-guide.md | 2 +- docs/mllib-stats.md | 74 ++- .../mllib/random/RandomDataGenerator.scala | 18 +- .../spark/mllib/random/RandomRDDs.scala | 476 +++++++----------- .../mllib/random/JavaRandomRDDsSuite.java | 134 +++++ python/pyspark/mllib/random.py | 20 +- 6 files changed, 418 insertions(+), 306 deletions(-) create mode 100644 mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 23d5a0c4607af..ca0a84a8c53fd 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -9,7 +9,7 @@ filtering, dimensionality reduction, as well as underlying optimization primitiv * [Data types](mllib-basics.html) * [Basic statistics](mllib-stats.html) - * data generators + * random data generation * stratified sampling * summary statistics * hypothesis testing diff --git a/docs/mllib-stats.md b/docs/mllib-stats.md index ca9ef46c15186..f25dca746ba3a 100644 --- a/docs/mllib-stats.md +++ b/docs/mllib-stats.md @@ -25,7 +25,79 @@ displayTitle: MLlib - Statistics Functionality \newcommand{\zero}{\mathbf{0}} \]` -## Data Generators +## Random data generation + +Random data generation is useful for randomized algorithms, prototyping, and performance testing. +MLlib supports generating random RDDs with i.i.d. values drawn from a given distribution: +uniform, standard normal, or Poisson. + +
    +
    +[`RandomRDDs`](api/scala/index.html#org.apache.spark.mllib.random.RandomRDDs) provides factory +methods to generate random double RDDs or vector RDDs. +The following example generates a random double RDD, whose values follows the standard normal +distribution `N(0, 1)`, and then map it to `N(1, 4)`. + +{% highlight scala %} +import org.apache.spark.SparkContext +import org.apache.spark.mllib.random.RandomRDDs._ + +val sc: SparkContext = ... + +// Generate a random double RDD that contains 1 million i.i.d. values drawn from the +// standard normal distribution `N(0, 1)`, evenly distributed in 10 partitions. +val u = normalRDD(sc, 1000000L, 10) +// Apply a transform to get a random double RDD following `N(1, 4)`. +val v = u.map(x => 1.0 + 2.0 * x) +{% endhighlight %} +
    + +
    +[`RandomRDDs`](api/java/index.html#org.apache.spark.mllib.random.RandomRDDs) provides factory +methods to generate random double RDDs or vector RDDs. +The following example generates a random double RDD, whose values follows the standard normal +distribution `N(0, 1)`, and then map it to `N(1, 4)`. + +{% highlight java %} +import org.apache.spark.SparkContext; +import org.apache.spark.api.JavaDoubleRDD; +import static org.apache.spark.mllib.random.RandomRDDs.*; + +JavaSparkContext jsc = ... + +// Generate a random double RDD that contains 1 million i.i.d. values drawn from the +// standard normal distribution `N(0, 1)`, evenly distributed in 10 partitions. +JavaDoubleRDD u = normalJavaRDD(jsc, 1000000L, 10); +// Apply a transform to get a random double RDD following `N(1, 4)`. +JavaDoubleRDD v = u.map( + new Function() { + public Double call(Double x) { + return 1.0 + 2.0 * x; + } + }); +{% endhighlight %} +
    + +
    +[`RandomRDDs`](api/python/pyspark.mllib.random.RandomRDDs-class.html) provides factory +methods to generate random double RDDs or vector RDDs. +The following example generates a random double RDD, whose values follows the standard normal +distribution `N(0, 1)`, and then map it to `N(1, 4)`. + +{% highlight python %} +from pyspark.mllib.random import RandomRDDs + +sc = ... # SparkContext + +# Generate a random double RDD that contains 1 million i.i.d. values drawn from the +# standard normal distribution `N(0, 1)`, evenly distributed in 10 partitions. +u = RandomRDDs.uniformRDD(sc, 1000000L, 10) +# Apply a transform to get a random double RDD following `N(1, 4)`. +v = u.map(lambda x: 1.0 + 2.0 * x) +{% endhighlight %} +
    + +
    ## Stratified Sampling diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala index 9cab49f6ed1f0..28179fbc450c0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala @@ -20,14 +20,14 @@ package org.apache.spark.mllib.random import cern.jet.random.Poisson import cern.jet.random.engine.DRand -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom} /** - * :: Experimental :: + * :: DeveloperApi :: * Trait for random data generators that generate i.i.d. data. */ -@Experimental +@DeveloperApi trait RandomDataGenerator[T] extends Pseudorandom with Serializable { /** @@ -43,10 +43,10 @@ trait RandomDataGenerator[T] extends Pseudorandom with Serializable { } /** - * :: Experimental :: + * :: DeveloperApi :: * Generates i.i.d. samples from U[0.0, 1.0] */ -@Experimental +@DeveloperApi class UniformGenerator extends RandomDataGenerator[Double] { // XORShiftRandom for better performance. Thread safety isn't necessary here. @@ -62,10 +62,10 @@ class UniformGenerator extends RandomDataGenerator[Double] { } /** - * :: Experimental :: + * :: DeveloperApi :: * Generates i.i.d. samples from the standard normal distribution. */ -@Experimental +@DeveloperApi class StandardNormalGenerator extends RandomDataGenerator[Double] { // XORShiftRandom for better performance. Thread safety isn't necessary here. @@ -81,12 +81,12 @@ class StandardNormalGenerator extends RandomDataGenerator[Double] { } /** - * :: Experimental :: + * :: DeveloperApi :: * Generates i.i.d. samples from the Poisson distribution with the given mean. * * @param mean mean for the Poisson distribution. */ -@Experimental +@DeveloperApi class PoissonGenerator(val mean: Double) extends RandomDataGenerator[Double] { private var rng = new Poisson(mean, new DRand) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index 36270369526cd..c5f4b084321f7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -20,9 +20,10 @@ package org.apache.spark.mllib.random import scala.reflect.ClassTag import org.apache.spark.SparkContext -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD, JavaSparkContext} import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.rdd.{RandomVectorRDD, RandomRDD} +import org.apache.spark.mllib.rdd.{RandomRDD, RandomVectorRDD} import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils @@ -34,335 +35,279 @@ import org.apache.spark.util.Utils object RandomRDDs { /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples from the uniform distribution on [0.0, 1.0]. + * Generates an RDD comprised of i.i.d. samples from the uniform distribution `U(0.0, 1.0)`. * - * To transform the distribution in the generated RDD from U[0.0, 1.0] to U[a, b], use - * `RandomRDDGenerators.uniformRDD(sc, n, p, seed).map(v => a + (b - a) * v)`. + * To transform the distribution in the generated RDD from `U(0.0, 1.0)` to `U(a, b)`, use + * `RandomRDDs.uniformRDD(sc, n, p, seed).map(v => a + (b - a) * v)`. * * @param sc SparkContext used to create the RDD. * @param size Size of the RDD. - * @param numPartitions Number of partitions in the RDD. - * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Double] comprised of i.i.d. samples ~ U[0.0, 1.0]. + * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). + * @param seed Random seed (default: a random long integer). + * @return RDD[Double] comprised of i.i.d. samples ~ `U(0.0, 1.0)`. */ - @Experimental - def uniformRDD(sc: SparkContext, size: Long, numPartitions: Int, seed: Long): RDD[Double] = { + def uniformRDD( + sc: SparkContext, + size: Long, + numPartitions: Int = 0, + seed: Long = Utils.random.nextLong()): RDD[Double] = { val uniform = new UniformGenerator() - randomRDD(sc, uniform, size, numPartitions, seed) + randomRDD(sc, uniform, size, numPartitionsOrDefault(sc, numPartitions), seed) } /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples from the uniform distribution on [0.0, 1.0]. - * - * To transform the distribution in the generated RDD from U[0.0, 1.0] to U[a, b], use - * `RandomRDDGenerators.uniformRDD(sc, n, p).map(v => a + (b - a) * v)`. - * - * @param sc SparkContext used to create the RDD. - * @param size Size of the RDD. - * @param numPartitions Number of partitions in the RDD. - * @return RDD[Double] comprised of i.i.d. samples ~ U[0.0, 1.0]. + * Java-friendly version of [[RandomRDDs#uniformRDD]]. */ - @Experimental - def uniformRDD(sc: SparkContext, size: Long, numPartitions: Int): RDD[Double] = { - uniformRDD(sc, size, numPartitions, Utils.random.nextLong) + def uniformJavaRDD( + jsc: JavaSparkContext, + size: Long, + numPartitions: Int, + seed: Long): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(uniformRDD(jsc.sc, size, numPartitions, seed)) } /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples from the uniform distribution on [0.0, 1.0]. - * sc.defaultParallelism used for the number of partitions in the RDD. - * - * To transform the distribution in the generated RDD from U[0.0, 1.0] to U[a, b], use - * `RandomRDDGenerators.uniformRDD(sc, n).map(v => a + (b - a) * v)`. - * - * @param sc SparkContext used to create the RDD. - * @param size Size of the RDD. - * @return RDD[Double] comprised of i.i.d. samples ~ U[0.0, 1.0]. + * [[RandomRDDs#uniformJavaRDD]] with the default seed. */ - @Experimental - def uniformRDD(sc: SparkContext, size: Long): RDD[Double] = { - uniformRDD(sc, size, sc.defaultParallelism, Utils.random.nextLong) + def uniformJavaRDD(jsc: JavaSparkContext, size: Long, numPartitions: Int): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(uniformRDD(jsc.sc, size, numPartitions)) } /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples from the standard normal distribution. - * - * To transform the distribution in the generated RDD from standard normal to some other normal - * N(mean, sigma), use `RandomRDDGenerators.normalRDD(sc, n, p, seed).map(v => mean + sigma * v)`. - * - * @param sc SparkContext used to create the RDD. - * @param size Size of the RDD. - * @param numPartitions Number of partitions in the RDD. - * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Double] comprised of i.i.d. samples ~ N(0.0, 1.0). + * [[RandomRDDs#uniformJavaRDD]] with the default number of partitions and the default seed. */ - @Experimental - def normalRDD(sc: SparkContext, size: Long, numPartitions: Int, seed: Long): RDD[Double] = { - val normal = new StandardNormalGenerator() - randomRDD(sc, normal, size, numPartitions, seed) + def uniformJavaRDD(jsc: JavaSparkContext, size: Long): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(uniformRDD(jsc.sc, size)) } /** - * :: Experimental :: * Generates an RDD comprised of i.i.d. samples from the standard normal distribution. * * To transform the distribution in the generated RDD from standard normal to some other normal - * N(mean, sigma), use `RandomRDDGenerators.normalRDD(sc, n, p).map(v => mean + sigma * v)`. + * `N(mean, sigma^2^)`, use `RandomRDDs.normalRDD(sc, n, p, seed).map(v => mean + sigma * v)`. * * @param sc SparkContext used to create the RDD. * @param size Size of the RDD. - * @param numPartitions Number of partitions in the RDD. + * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). + * @param seed Random seed (default: a random long integer). * @return RDD[Double] comprised of i.i.d. samples ~ N(0.0, 1.0). */ - @Experimental - def normalRDD(sc: SparkContext, size: Long, numPartitions: Int): RDD[Double] = { - normalRDD(sc, size, numPartitions, Utils.random.nextLong) + def normalRDD( + sc: SparkContext, + size: Long, + numPartitions: Int = 0, + seed: Long = Utils.random.nextLong()): RDD[Double] = { + val normal = new StandardNormalGenerator() + randomRDD(sc, normal, size, numPartitionsOrDefault(sc, numPartitions), seed) } /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples from the standard normal distribution. - * sc.defaultParallelism used for the number of partitions in the RDD. - * - * To transform the distribution in the generated RDD from standard normal to some other normal - * N(mean, sigma), use `RandomRDDGenerators.normalRDD(sc, n).map(v => mean + sigma * v)`. - * - * @param sc SparkContext used to create the RDD. - * @param size Size of the RDD. - * @return RDD[Double] comprised of i.i.d. samples ~ N(0.0, 1.0). + * Java-friendly version of [[RandomRDDs#normalRDD]]. */ - @Experimental - def normalRDD(sc: SparkContext, size: Long): RDD[Double] = { - normalRDD(sc, size, sc.defaultParallelism, Utils.random.nextLong) + def normalJavaRDD( + jsc: JavaSparkContext, + size: Long, + numPartitions: Int, + seed: Long): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(normalRDD(jsc.sc, size, numPartitions, seed)) } /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. - * - * @param sc SparkContext used to create the RDD. - * @param mean Mean, or lambda, for the Poisson distribution. - * @param size Size of the RDD. - * @param numPartitions Number of partitions in the RDD. - * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean). + * [[RandomRDDs#normalJavaRDD]] with the default seed. */ - @Experimental - def poissonRDD(sc: SparkContext, - mean: Double, - size: Long, - numPartitions: Int, - seed: Long): RDD[Double] = { - val poisson = new PoissonGenerator(mean) - randomRDD(sc, poisson, size, numPartitions, seed) + def normalJavaRDD(jsc: JavaSparkContext, size: Long, numPartitions: Int): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(normalRDD(jsc.sc, size, numPartitions)) } /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. - * - * @param sc SparkContext used to create the RDD. - * @param mean Mean, or lambda, for the Poisson distribution. - * @param size Size of the RDD. - * @param numPartitions Number of partitions in the RDD. - * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean). + * [[RandomRDDs#normalJavaRDD]] with the default number of partitions and the default seed. */ - @Experimental - def poissonRDD(sc: SparkContext, mean: Double, size: Long, numPartitions: Int): RDD[Double] = { - poissonRDD(sc, mean, size, numPartitions, Utils.random.nextLong) + def normalJavaRDD(jsc: JavaSparkContext, size: Long): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(normalRDD(jsc.sc, size)) } /** - * :: Experimental :: * Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. - * sc.defaultParallelism used for the number of partitions in the RDD. * * @param sc SparkContext used to create the RDD. * @param mean Mean, or lambda, for the Poisson distribution. * @param size Size of the RDD. + * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). + * @param seed Random seed (default: a random long integer). * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean). */ - @Experimental - def poissonRDD(sc: SparkContext, mean: Double, size: Long): RDD[Double] = { - poissonRDD(sc, mean, size, sc.defaultParallelism, Utils.random.nextLong) + def poissonRDD( + sc: SparkContext, + mean: Double, + size: Long, + numPartitions: Int = 0, + seed: Long = Utils.random.nextLong()): RDD[Double] = { + val poisson = new PoissonGenerator(mean) + randomRDD(sc, poisson, size, numPartitionsOrDefault(sc, numPartitions), seed) } /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples produced by the input DistributionGenerator. - * - * @param sc SparkContext used to create the RDD. - * @param generator DistributionGenerator used to populate the RDD. - * @param size Size of the RDD. - * @param numPartitions Number of partitions in the RDD. - * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Double] comprised of i.i.d. samples produced by generator. + * Java-friendly version of [[RandomRDDs#poissonRDD]]. */ - @Experimental - def randomRDD[T: ClassTag](sc: SparkContext, - generator: RandomDataGenerator[T], + def poissonJavaRDD( + jsc: JavaSparkContext, + mean: Double, size: Long, numPartitions: Int, - seed: Long): RDD[T] = { - new RandomRDD[T](sc, size, numPartitions, generator, seed) + seed: Long): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(poissonRDD(jsc.sc, mean, size, numPartitions, seed)) } /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples produced by the input DistributionGenerator. - * - * @param sc SparkContext used to create the RDD. - * @param generator DistributionGenerator used to populate the RDD. - * @param size Size of the RDD. - * @param numPartitions Number of partitions in the RDD. - * @return RDD[Double] comprised of i.i.d. samples produced by generator. + * [[RandomRDDs#poissonJavaRDD]] with the default seed. */ - @Experimental - def randomRDD[T: ClassTag](sc: SparkContext, - generator: RandomDataGenerator[T], + def poissonJavaRDD( + jsc: JavaSparkContext, + mean: Double, size: Long, - numPartitions: Int): RDD[T] = { - randomRDD[T](sc, generator, size, numPartitions, Utils.random.nextLong) + numPartitions: Int): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(poissonRDD(jsc.sc, mean, size, numPartitions)) } /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples produced by the input DistributionGenerator. - * sc.defaultParallelism used for the number of partitions in the RDD. + * [[RandomRDDs#poissonJavaRDD]] with the default number of partitions and the default seed. + */ + def poissonJavaRDD(jsc: JavaSparkContext, mean: Double, size: Long): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(poissonRDD(jsc.sc, mean, size)) + } + + /** + * :: DeveloperApi :: + * Generates an RDD comprised of i.i.d. samples produced by the input RandomDataGenerator. * * @param sc SparkContext used to create the RDD. - * @param generator DistributionGenerator used to populate the RDD. + * @param generator RandomDataGenerator used to populate the RDD. * @param size Size of the RDD. + * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). + * @param seed Random seed (default: a random long integer). * @return RDD[Double] comprised of i.i.d. samples produced by generator. */ - @Experimental - def randomRDD[T: ClassTag](sc: SparkContext, + @DeveloperApi + def randomRDD[T: ClassTag]( + sc: SparkContext, generator: RandomDataGenerator[T], - size: Long): RDD[T] = { - randomRDD[T](sc, generator, size, sc.defaultParallelism, Utils.random.nextLong) + size: Long, + numPartitions: Int = 0, + seed: Long = Utils.random.nextLong()): RDD[T] = { + new RandomRDD[T](sc, size, numPartitionsOrDefault(sc, numPartitions), generator, seed) } // TODO Generate RDD[Vector] from multivariate distributions. /** - * :: Experimental :: * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the - * uniform distribution on [0.0 1.0]. + * uniform distribution on `U(0.0, 1.0)`. * * @param sc SparkContext used to create the RDD. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD. * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Vector] with vectors containing i.i.d samples ~ U[0.0, 1.0]. + * @return RDD[Vector] with vectors containing i.i.d samples ~ `U(0.0, 1.0)`. */ - @Experimental - def uniformVectorRDD(sc: SparkContext, + def uniformVectorRDD( + sc: SparkContext, numRows: Long, numCols: Int, - numPartitions: Int, - seed: Long): RDD[Vector] = { + numPartitions: Int = 0, + seed: Long = Utils.random.nextLong()): RDD[Vector] = { val uniform = new UniformGenerator() - randomVectorRDD(sc, uniform, numRows, numCols, numPartitions, seed) + randomVectorRDD(sc, uniform, numRows, numCols, numPartitionsOrDefault(sc, numPartitions), seed) } /** - * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the - * uniform distribution on [0.0 1.0]. - * - * @param sc SparkContext used to create the RDD. - * @param numRows Number of Vectors in the RDD. - * @param numCols Number of elements in each Vector. - * @param numPartitions Number of partitions in the RDD. - * @return RDD[Vector] with vectors containing i.i.d. samples ~ U[0.0, 1.0]. + * Java-friendly version of [[RandomRDDs#uniformVectorRDD]]. */ - @Experimental - def uniformVectorRDD(sc: SparkContext, + def uniformJavaVectorRDD( + jsc: JavaSparkContext, numRows: Long, numCols: Int, - numPartitions: Int): RDD[Vector] = { - uniformVectorRDD(sc, numRows, numCols, numPartitions, Utils.random.nextLong) + numPartitions: Int, + seed: Long): JavaRDD[Vector] = { + uniformVectorRDD(jsc.sc, numRows, numCols, numPartitions, seed).toJavaRDD() } /** - * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the - * uniform distribution on [0.0 1.0]. - * sc.defaultParallelism used for the number of partitions in the RDD. - * - * @param sc SparkContext used to create the RDD. - * @param numRows Number of Vectors in the RDD. - * @param numCols Number of elements in each Vector. - * @return RDD[Vector] with vectors containing i.i.d. samples ~ U[0.0, 1.0]. + * [[RandomRDDs#uniformJavaVectorRDD]] with the default seed. */ - @Experimental - def uniformVectorRDD(sc: SparkContext, numRows: Long, numCols: Int): RDD[Vector] = { - uniformVectorRDD(sc, numRows, numCols, sc.defaultParallelism, Utils.random.nextLong) + def uniformJavaVectorRDD( + jsc: JavaSparkContext, + numRows: Long, + numCols: Int, + numPartitions: Int): JavaRDD[Vector] = { + uniformVectorRDD(jsc.sc, numRows, numCols, numPartitions).toJavaRDD() + } + + /** + * [[RandomRDDs#uniformJavaVectorRDD]] with the default number of partitions and the default seed. + */ + def uniformJavaVectorRDD( + jsc: JavaSparkContext, + numRows: Long, + numCols: Int): JavaRDD[Vector] = { + uniformVectorRDD(jsc.sc, numRows, numCols).toJavaRDD() } /** - * :: Experimental :: * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * standard normal distribution. * * @param sc SparkContext used to create the RDD. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. - * @param numPartitions Number of partitions in the RDD. - * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Vector] with vectors containing i.i.d. samples ~ N(0.0, 1.0). + * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). + * @param seed Random seed (default: a random long integer). + * @return RDD[Vector] with vectors containing i.i.d. samples ~ `N(0.0, 1.0)`. + */ + def normalVectorRDD( + sc: SparkContext, + numRows: Long, + numCols: Int, + numPartitions: Int = 0, + seed: Long = Utils.random.nextLong()): RDD[Vector] = { + val normal = new StandardNormalGenerator() + randomVectorRDD(sc, normal, numRows, numCols, numPartitionsOrDefault(sc, numPartitions), seed) + } + + /** + * Java-friendly version of [[RandomRDDs#normalVectorRDD]]. */ - @Experimental - def normalVectorRDD(sc: SparkContext, + def normalJavaVectorRDD( + jsc: JavaSparkContext, numRows: Long, numCols: Int, numPartitions: Int, - seed: Long): RDD[Vector] = { - val uniform = new StandardNormalGenerator() - randomVectorRDD(sc, uniform, numRows, numCols, numPartitions, seed) + seed: Long): JavaRDD[Vector] = { + normalVectorRDD(jsc.sc, numRows, numCols, numPartitions, seed).toJavaRDD() } /** - * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the - * standard normal distribution. - * - * @param sc SparkContext used to create the RDD. - * @param numRows Number of Vectors in the RDD. - * @param numCols Number of elements in each Vector. - * @param numPartitions Number of partitions in the RDD. - * @return RDD[Vector] with vectors containing i.i.d. samples ~ N(0.0, 1.0). + * [[RandomRDDs#normalJavaVectorRDD]] with the default seed. */ - @Experimental - def normalVectorRDD(sc: SparkContext, + def normalJavaVectorRDD( + jsc: JavaSparkContext, numRows: Long, numCols: Int, - numPartitions: Int): RDD[Vector] = { - normalVectorRDD(sc, numRows, numCols, numPartitions, Utils.random.nextLong) + numPartitions: Int): JavaRDD[Vector] = { + normalVectorRDD(jsc.sc, numRows, numCols, numPartitions).toJavaRDD() } /** - * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the - * standard normal distribution. - * sc.defaultParallelism used for the number of partitions in the RDD. - * - * @param sc SparkContext used to create the RDD. - * @param numRows Number of Vectors in the RDD. - * @param numCols Number of elements in each Vector. - * @return RDD[Vector] with vectors containing i.i.d. samples ~ N(0.0, 1.0). + * [[RandomRDDs#normalJavaVectorRDD]] with the default number of partitions and the default seed. */ - @Experimental - def normalVectorRDD(sc: SparkContext, numRows: Long, numCols: Int): RDD[Vector] = { - normalVectorRDD(sc, numRows, numCols, sc.defaultParallelism, Utils.random.nextLong) + def normalJavaVectorRDD( + jsc: JavaSparkContext, + numRows: Long, + numCols: Int): JavaRDD[Vector] = { + normalVectorRDD(jsc.sc, numRows, numCols).toJavaRDD() } /** - * :: Experimental :: * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the * Poisson distribution with the input mean. * @@ -370,124 +315,85 @@ object RandomRDDs { * @param mean Mean, or lambda, for the Poisson distribution. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. - * @param numPartitions Number of partitions in the RDD. - * @param seed Seed for the RNG that generates the seed for the generator in each partition. + * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`) + * @param seed Random seed (default: a random long integer). * @return RDD[Vector] with vectors containing i.i.d. samples ~ Pois(mean). */ - @Experimental - def poissonVectorRDD(sc: SparkContext, + def poissonVectorRDD( + sc: SparkContext, mean: Double, numRows: Long, numCols: Int, - numPartitions: Int, - seed: Long): RDD[Vector] = { + numPartitions: Int = 0, + seed: Long = Utils.random.nextLong()): RDD[Vector] = { val poisson = new PoissonGenerator(mean) - randomVectorRDD(sc, poisson, numRows, numCols, numPartitions, seed) + randomVectorRDD(sc, poisson, numRows, numCols, numPartitionsOrDefault(sc, numPartitions), seed) } /** - * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the - * Poisson distribution with the input mean. - * - * @param sc SparkContext used to create the RDD. - * @param mean Mean, or lambda, for the Poisson distribution. - * @param numRows Number of Vectors in the RDD. - * @param numCols Number of elements in each Vector. - * @param numPartitions Number of partitions in the RDD. - * @return RDD[Vector] with vectors containing i.i.d. samples ~ Pois(mean). + * Java-friendly version of [[RandomRDDs#poissonVectorRDD]]. */ - @Experimental - def poissonVectorRDD(sc: SparkContext, + def poissonJavaVectorRDD( + jsc: JavaSparkContext, mean: Double, numRows: Long, numCols: Int, - numPartitions: Int): RDD[Vector] = { - poissonVectorRDD(sc, mean, numRows, numCols, numPartitions, Utils.random.nextLong) + numPartitions: Int, + seed: Long): JavaRDD[Vector] = { + poissonVectorRDD(jsc.sc, mean, numRows, numCols, numPartitions, seed).toJavaRDD() } /** - * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the - * Poisson distribution with the input mean. - * sc.defaultParallelism used for the number of partitions in the RDD. - * - * @param sc SparkContext used to create the RDD. - * @param mean Mean, or lambda, for the Poisson distribution. - * @param numRows Number of Vectors in the RDD. - * @param numCols Number of elements in each Vector. - * @return RDD[Vector] with vectors containing i.i.d. samples ~ Pois(mean). + * [[RandomRDDs#poissonJavaVectorRDD]] with the default seed. */ - @Experimental - def poissonVectorRDD(sc: SparkContext, + def poissonJavaVectorRDD( + jsc: JavaSparkContext, mean: Double, numRows: Long, - numCols: Int): RDD[Vector] = { - poissonVectorRDD(sc, mean, numRows, numCols, sc.defaultParallelism, Utils.random.nextLong) + numCols: Int, + numPartitions: Int): JavaRDD[Vector] = { + poissonVectorRDD(jsc.sc, mean, numRows, numCols, numPartitions).toJavaRDD() } /** - * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples produced by the - * input DistributionGenerator. - * - * @param sc SparkContext used to create the RDD. - * @param generator DistributionGenerator used to populate the RDD. - * @param numRows Number of Vectors in the RDD. - * @param numCols Number of elements in each Vector. - * @param numPartitions Number of partitions in the RDD. - * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator. + * [[RandomRDDs#poissonJavaVectorRDD]] with the default number of partitions and the default seed. */ - @Experimental - def randomVectorRDD(sc: SparkContext, - generator: RandomDataGenerator[Double], + def poissonJavaVectorRDD( + jsc: JavaSparkContext, + mean: Double, numRows: Long, - numCols: Int, - numPartitions: Int, - seed: Long): RDD[Vector] = { - new RandomVectorRDD(sc, numRows, numCols, numPartitions, generator, seed) + numCols: Int): JavaRDD[Vector] = { + poissonVectorRDD(jsc.sc, mean, numRows, numCols).toJavaRDD() } /** - * :: Experimental :: + * :: DeveloperApi :: * Generates an RDD[Vector] with vectors containing i.i.d. samples produced by the - * input DistributionGenerator. + * input RandomDataGenerator. * * @param sc SparkContext used to create the RDD. - * @param generator DistributionGenerator used to populate the RDD. + * @param generator RandomDataGenerator used to populate the RDD. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. - * @param numPartitions Number of partitions in the RDD. + * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). + * @param seed Random seed (default: a random long integer). * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator. */ - @Experimental + @DeveloperApi def randomVectorRDD(sc: SparkContext, generator: RandomDataGenerator[Double], numRows: Long, numCols: Int, - numPartitions: Int): RDD[Vector] = { - randomVectorRDD(sc, generator, numRows, numCols, numPartitions, Utils.random.nextLong) + numPartitions: Int = 0, + seed: Long = Utils.random.nextLong()): RDD[Vector] = { + new RandomVectorRDD( + sc, numRows, numCols, numPartitionsOrDefault(sc, numPartitions), generator, seed) } /** - * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples produced by the - * input DistributionGenerator. - * sc.defaultParallelism used for the number of partitions in the RDD. - * - * @param sc SparkContext used to create the RDD. - * @param generator DistributionGenerator used to populate the RDD. - * @param numRows Number of Vectors in the RDD. - * @param numCols Number of elements in each Vector. - * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator. + * Returns `numPartitions` if it is positive, or `sc.defaultParallelism` otherwise. */ - @Experimental - def randomVectorRDD(sc: SparkContext, - generator: RandomDataGenerator[Double], - numRows: Long, - numCols: Int): RDD[Vector] = { - randomVectorRDD(sc, generator, numRows, numCols, - sc.defaultParallelism, Utils.random.nextLong) + private def numPartitionsOrDefault(sc: SparkContext, numPartitions: Int): Int = { + if (numPartitions > 0) numPartitions else sc.defaultMinPartitions } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java new file mode 100644 index 0000000000000..a725736ca1a58 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.random; + +import com.google.common.collect.Lists; +import org.apache.spark.api.java.JavaRDD; +import org.junit.Assert; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import static org.apache.spark.mllib.random.RandomRDDs.*; + +public class JavaRandomRDDsSuite { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaRandomRDDsSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void testUniformRDD() { + long m = 1000L; + int p = 2; + long seed = 1L; + JavaDoubleRDD rdd1 = uniformJavaRDD(sc, m); + JavaDoubleRDD rdd2 = uniformJavaRDD(sc, m, p); + JavaDoubleRDD rdd3 = uniformJavaRDD(sc, m, p, seed); + for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(m, rdd.count()); + } + } + + @Test + public void testNormalRDD() { + long m = 1000L; + int p = 2; + long seed = 1L; + JavaDoubleRDD rdd1 = normalJavaRDD(sc, m); + JavaDoubleRDD rdd2 = normalJavaRDD(sc, m, p); + JavaDoubleRDD rdd3 = normalJavaRDD(sc, m, p, seed); + for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(m, rdd.count()); + } + } + + @Test + public void testPoissonRDD() { + double mean = 2.0; + long m = 1000L; + int p = 2; + long seed = 1L; + JavaDoubleRDD rdd1 = poissonJavaRDD(sc, mean, m); + JavaDoubleRDD rdd2 = poissonJavaRDD(sc, mean, m, p); + JavaDoubleRDD rdd3 = poissonJavaRDD(sc, mean, m, p, seed); + for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(m, rdd.count()); + } + } + + @Test + @SuppressWarnings("unchecked") + public void testUniformVectorRDD() { + long m = 100L; + int n = 10; + int p = 2; + long seed = 1L; + JavaRDD rdd1 = uniformJavaVectorRDD(sc, m, n); + JavaRDD rdd2 = uniformJavaVectorRDD(sc, m, n, p); + JavaRDD rdd3 = uniformJavaVectorRDD(sc, m, n, p, seed); + for (JavaRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(m, rdd.count()); + Assert.assertEquals(n, rdd.first().size()); + } + } + + @Test + @SuppressWarnings("unchecked") + public void testNormalVectorRDD() { + long m = 100L; + int n = 10; + int p = 2; + long seed = 1L; + JavaRDD rdd1 = normalJavaVectorRDD(sc, m, n); + JavaRDD rdd2 = normalJavaVectorRDD(sc, m, n, p); + JavaRDD rdd3 = normalJavaVectorRDD(sc, m, n, p, seed); + for (JavaRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(m, rdd.count()); + Assert.assertEquals(n, rdd.first().size()); + } + } + + @Test + @SuppressWarnings("unchecked") + public void testPoissonVectorRDD() { + double mean = 2.0; + long m = 100L; + int n = 10; + int p = 2; + long seed = 1L; + JavaRDD rdd1 = poissonJavaVectorRDD(sc, mean, m, n); + JavaRDD rdd2 = poissonJavaVectorRDD(sc, mean, m, n, p); + JavaRDD rdd3 = poissonJavaVectorRDD(sc, mean, m, n, p, seed); + for (JavaRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(m, rdd.count()); + Assert.assertEquals(n, rdd.first().size()); + } + } +} diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 3f3b19053d32e..4dc1a4a912421 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -35,10 +35,10 @@ class RandomRDDs: def uniformRDD(sc, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the - uniform distribution on [0.0, 1.0]. + uniform distribution U(0.0, 1.0). - To transform the distribution in the generated RDD from U[0.0, 1.0] - to U[a, b], use + To transform the distribution in the generated RDD from U(0.0, 1.0) + to U(a, b), use C{RandomRDDs.uniformRDD(sc, n, p, seed)\ .map(lambda v: a + (b - a) * v)} @@ -60,11 +60,11 @@ def uniformRDD(sc, size, numPartitions=None, seed=None): @staticmethod def normalRDD(sc, size, numPartitions=None, seed=None): """ - Generates an RDD comprised of i.i.d samples from the standard normal + Generates an RDD comprised of i.i.d. samples from the standard normal distribution. To transform the distribution in the generated RDD from standard normal - to some other normal N(mean, sigma), use + to some other normal N(mean, sigma^2), use C{RandomRDDs.normal(sc, n, p, seed)\ .map(lambda v: mean + sigma * v)} @@ -84,7 +84,7 @@ def normalRDD(sc, size, numPartitions=None, seed=None): @staticmethod def poissonRDD(sc, mean, size, numPartitions=None, seed=None): """ - Generates an RDD comprised of i.i.d samples from the Poisson + Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. >>> mean = 100.0 @@ -105,8 +105,8 @@ def poissonRDD(sc, mean, size, numPartitions=None, seed=None): @staticmethod def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): """ - Generates an RDD comprised of vectors containing i.i.d samples drawn - from the uniform distribution on [0.0 1.0]. + Generates an RDD comprised of vectors containing i.i.d. samples drawn + from the uniform distribution U(0.0, 1.0). >>> import numpy as np >>> mat = np.matrix(RandomRDDs.uniformVectorRDD(sc, 10, 10).collect()) @@ -125,7 +125,7 @@ def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): @staticmethod def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): """ - Generates an RDD comprised of vectors containing i.i.d samples drawn + Generates an RDD comprised of vectors containing i.i.d. samples drawn from the standard normal distribution. >>> import numpy as np @@ -145,7 +145,7 @@ def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): @staticmethod def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): """ - Generates an RDD comprised of vectors containing i.i.d samples drawn + Generates an RDD comprised of vectors containing i.i.d. samples drawn from the Poisson distribution with the input mean. >>> import numpy as np