diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 975dd4f0cd7e7..e8adef377481c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -87,7 +87,7 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging { topNode.build(nodes) val decisionTreeModel = { - return new DecisionTreeModel(topNode) + return new DecisionTreeModel(topNode, strategy.algo) } return decisionTreeModel @@ -98,14 +98,8 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging { val split = nodeSplitStats._1 val stats = nodeSplitStats._2 val nodeIndex = scala.math.pow(2, level).toInt - 1 + index - val predict = { - val leftSamples = nodeSplitStats._2.leftSamples.toDouble - val rightSamples = nodeSplitStats._2.rightSamples.toDouble - val totalSamples = leftSamples + rightSamples - leftSamples / totalSamples - } val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth - 1) - val node = new Node(nodeIndex, predict, isLeaf, Some(split), None, None, Some(stats)) + val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) logDebug("Node = " + node) nodes(nodeIndex) = node } @@ -370,8 +364,8 @@ object DecisionTree extends Serializable with Logging { val impurity = if (level > 0) topImpurity else strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count) - if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,0,topImpurity,rightCount.toLong) - if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,leftCount.toLong,Double.MinValue,0) + if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,topImpurity,1) + if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,Double.MinValue,0) val leftImpurity = strategy.impurity.calculate(left0Count, left1Count) val rightImpurity = strategy.impurity.calculate(right0Count, right1Count) @@ -387,7 +381,9 @@ object DecisionTree extends Serializable with Logging { } } - new InformationGainStats(gain,impurity,leftImpurity,leftCount.toLong,rightImpurity,rightCount.toLong) + val predict = leftCount / (leftCount + rightCount) + + new InformationGainStats(gain,impurity,leftImpurity,rightImpurity,predict) } case Regression => { val leftCount = leftNodeAgg(featureIndex)(3 * index) @@ -400,8 +396,8 @@ object DecisionTree extends Serializable with Logging { val impurity = if (level > 0) topImpurity else strategy.impurity.calculate(leftCount + rightCount, leftSum + rightSum, leftSumSquares + rightSumSquares) - if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,0,topImpurity,rightCount.toLong) - if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,leftCount.toLong,Double.MinValue,0) + if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,topImpurity,rightSum/rightCount) + if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,Double.MinValue,leftSum/leftCount) val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares) val rightImpurity = strategy.impurity.calculate(rightCount, rightSum, rightSumSquares) @@ -417,7 +413,7 @@ object DecisionTree extends Serializable with Logging { } } - new InformationGainStats(gain,impurity,leftImpurity,leftCount.toLong,rightImpurity,rightCount.toLong) + new InformationGainStats(gain,impurity,leftImpurity,rightImpurity,(leftSum + rightSum)/(leftCount+rightCount)) } } @@ -515,7 +511,7 @@ object DecisionTree extends Serializable with Logging { var bestFeatureIndex = 0 var bestSplitIndex = 0 //Initialization with infeasible values - var bestGainStats = new InformationGainStats(Double.MinValue,-1.0,-1.0,0,-1.0,0) + var bestGainStats = new InformationGainStats(Double.MinValue,-1.0,-1.0,-1.0,-1) for (featureIndex <- 0 until numFeatures) { for (splitIndex <- 0 until numSplits - 1){ val gainStats = gains(featureIndex)(splitIndex) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 1d7c03289c407..587e549c34ca8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -17,9 +17,19 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Algo._ -class DecisionTreeModel(val topNode : Node) extends Serializable { +class DecisionTreeModel(val topNode : Node, val algo : Algo) extends Serializable { - def predict(features : Array[Double]) = if (topNode.predictIfLeaf(features) >= 0.5) 0.0 else 1.0 + def predict(features : Array[Double]) = { + algo match { + case Classification => { + if (topNode.predictIfLeaf(features) >= 0.5) 0.0 else 1.0 + } + case Regression => { + topNode.predictIfLeaf(features) + } + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index 60a4f99b7f806..b992684b2b05b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -19,14 +19,14 @@ package org.apache.spark.mllib.tree.model class InformationGainStats(val gain : Double, val impurity: Double, val leftImpurity : Double, - val leftSamples : Long, + //val leftSamples : Long, val rightImpurity : Double, - val rightSamples : Long) extends Serializable { + //val rightSamples : Long + val predict : Double) extends Serializable { override def toString = "gain = " + gain + ", impurity = " + impurity + ", left impurity = " - + leftImpurity + ", leftSamples = " + leftSamples + ", right impurity = " - + rightImpurity + ", rightSamples = " + rightSamples + + leftImpurity + ", right impurity = " + rightImpurity + ", predict = " + predict }