Skip to content

Commit

Permalink
regression predict logic
Browse files Browse the repository at this point in the history
Signed-off-by: Manish Amde <manish9ue@gmail.com>
  • Loading branch information
manishamde committed Feb 28, 2014
1 parent 53108ed commit 6df35b9
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 21 deletions.
26 changes: 11 additions & 15 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
topNode.build(nodes)

val decisionTreeModel = {
return new DecisionTreeModel(topNode)
return new DecisionTreeModel(topNode, strategy.algo)
}

return decisionTreeModel
Expand All @@ -98,14 +98,8 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
val split = nodeSplitStats._1
val stats = nodeSplitStats._2
val nodeIndex = scala.math.pow(2, level).toInt - 1 + index
val predict = {
val leftSamples = nodeSplitStats._2.leftSamples.toDouble
val rightSamples = nodeSplitStats._2.rightSamples.toDouble
val totalSamples = leftSamples + rightSamples
leftSamples / totalSamples
}
val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth - 1)
val node = new Node(nodeIndex, predict, isLeaf, Some(split), None, None, Some(stats))
val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
logDebug("Node = " + node)
nodes(nodeIndex) = node
}
Expand Down Expand Up @@ -370,8 +364,8 @@ object DecisionTree extends Serializable with Logging {

val impurity = if (level > 0) topImpurity else strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count)

if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,0,topImpurity,rightCount.toLong)
if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,leftCount.toLong,Double.MinValue,0)
if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,topImpurity,1)
if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,Double.MinValue,0)

val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)
val rightImpurity = strategy.impurity.calculate(right0Count, right1Count)
Expand All @@ -387,7 +381,9 @@ object DecisionTree extends Serializable with Logging {
}
}

new InformationGainStats(gain,impurity,leftImpurity,leftCount.toLong,rightImpurity,rightCount.toLong)
val predict = leftCount / (leftCount + rightCount)

new InformationGainStats(gain,impurity,leftImpurity,rightImpurity,predict)
}
case Regression => {
val leftCount = leftNodeAgg(featureIndex)(3 * index)
Expand All @@ -400,8 +396,8 @@ object DecisionTree extends Serializable with Logging {

val impurity = if (level > 0) topImpurity else strategy.impurity.calculate(leftCount + rightCount, leftSum + rightSum, leftSumSquares + rightSumSquares)

if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,0,topImpurity,rightCount.toLong)
if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,leftCount.toLong,Double.MinValue,0)
if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,topImpurity,rightSum/rightCount)
if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,Double.MinValue,leftSum/leftCount)

val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares)
val rightImpurity = strategy.impurity.calculate(rightCount, rightSum, rightSumSquares)
Expand All @@ -417,7 +413,7 @@ object DecisionTree extends Serializable with Logging {
}
}

new InformationGainStats(gain,impurity,leftImpurity,leftCount.toLong,rightImpurity,rightCount.toLong)
new InformationGainStats(gain,impurity,leftImpurity,rightImpurity,(leftSum + rightSum)/(leftCount+rightCount))

}
}
Expand Down Expand Up @@ -515,7 +511,7 @@ object DecisionTree extends Serializable with Logging {
var bestFeatureIndex = 0
var bestSplitIndex = 0
//Initialization with infeasible values
var bestGainStats = new InformationGainStats(Double.MinValue,-1.0,-1.0,0,-1.0,0)
var bestGainStats = new InformationGainStats(Double.MinValue,-1.0,-1.0,-1.0,-1)
for (featureIndex <- 0 until numFeatures) {
for (splitIndex <- 0 until numSplits - 1){
val gainStats = gains(featureIndex)(splitIndex)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,19 @@
package org.apache.spark.mllib.tree.model

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._

class DecisionTreeModel(val topNode : Node) extends Serializable {
class DecisionTreeModel(val topNode : Node, val algo : Algo) extends Serializable {

def predict(features : Array[Double]) = if (topNode.predictIfLeaf(features) >= 0.5) 0.0 else 1.0
def predict(features : Array[Double]) = {
algo match {
case Classification => {
if (topNode.predictIfLeaf(features) >= 0.5) 0.0 else 1.0
}
case Regression => {
topNode.predictIfLeaf(features)
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ package org.apache.spark.mllib.tree.model
class InformationGainStats(val gain : Double,
val impurity: Double,
val leftImpurity : Double,
val leftSamples : Long,
//val leftSamples : Long,
val rightImpurity : Double,
val rightSamples : Long) extends Serializable {
//val rightSamples : Long
val predict : Double) extends Serializable {

override def toString =
"gain = " + gain + ", impurity = " + impurity + ", left impurity = "
+ leftImpurity + ", leftSamples = " + leftSamples + ", right impurity = "
+ rightImpurity + ", rightSamples = " + rightSamples
+ leftImpurity + ", right impurity = " + rightImpurity + ", predict = " + predict


}

0 comments on commit 6df35b9

Please sign in to comment.