Skip to content

Commit

Permalink
Merge pull request #4 from mengxr/dtree
Browse files Browse the repository at this point in the history
another pass on code style
  • Loading branch information
manishamde committed Mar 31, 2014
2 parents e1dd86f + f536ae9 commit 7d54b4f
Show file tree
Hide file tree
Showing 11 changed files with 233 additions and 249 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,13 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* 1, 2, ... , k-1. It's important to note that features are
* zero-indexed.
*/
class Strategy (
val algo: Algo,
val impurity: Impurity,
val maxDepth: Int,
val maxBins: Int = 100,
val quantileCalculationStrategy: QuantileStrategy = Sort,
val categoricalFeaturesInfo: Map[Int,Int] = Map[Int,Int]()) extends Serializable {
class Strategy (
val algo: Algo,
val impurity: Impurity,
val maxDepth: Int,
val maxBins: Int = 100,
val quantileCalculationStrategy: QuantileStrategy = Sort,
val categoricalFeaturesInfo: Map[Int,Int] = Map[Int,Int]()) extends Serializable {

var numBins: Int = Int.MinValue

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.spark.mllib.tree.impurity

import java.lang.UnsupportedOperationException

/**
* Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during
* binary classification.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,30 @@

package org.apache.spark.mllib.tree.impurity

import java.lang.UnsupportedOperationException

/**
* Class for calculating the [[http://en.wikipedia
* .org/wiki/Decision_tree_learning#Gini_impurity]] during binary classification
* Class for calculating the
* [[http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity Gini impurity]]
* during binary classification.
*/
object Gini extends Impurity {

/**
* gini coefficient calculation
* Gini coefficient calculation
* @param c0 count of instances with label 0
* @param c1 count of instances with label 1
* @return gini coefficient value
* @return Gini coefficient value
*/
def calculate(c0 : Double, c1 : Double): Double = {
override def calculate(c0: Double, c1: Double): Double = {
if (c0 == 0 || c1 == 0) {
0
} else {
val total = c0 + c1
val f0 = c0 / total
val f1 = c1 / total
1 - f0*f0 - f1*f1
1 - f0 * f0 - f1 * f1
}
}

def calculate(count: Double, sum: Double, sumSquares: Double): Double =
throw new UnsupportedOperationException("Gini.calculate")

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.mllib.tree.impurity

/**
* Trail for calculating information gain
* Trait for calculating information gain.
*/
trait Impurity extends Serializable {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,21 @@

package org.apache.spark.mllib.tree.impurity

import java.lang.UnsupportedOperationException

/**
* Class for calculating variance during regression
*/
object Variance extends Impurity {
def calculate(c0: Double, c1: Double): Double
= throw new UnsupportedOperationException("Variance.calculate")
override def calculate(c0: Double, c1: Double): Double =
throw new UnsupportedOperationException("Variance.calculate")

/**
* variance calculation
* @param count number of instances
* @param sum sum of labels
* @param sumSquares summation of squares of the labels
* @return
*/
def calculate(count: Double, sum: Double, sumSquares: Double): Double = {
val squaredLoss = sumSquares - (sum*sum)/count
squaredLoss/count
override def calculate(count: Double, sum: Double, sumSquares: Double): Double = {
val squaredLoss = sumSquares - (sum * sum) / count
squaredLoss / count
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,4 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._
* @param featureType type of feature -- categorical or continuous
* @param category categorical label value accepted in the bin
*/
case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double) {

}
case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,4 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
def predict(features: RDD[Array[Double]]): RDD[Double] = {
features.map(x => predict(x))
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,4 @@ class InformationGainStats(
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f"
.format(gain, impurity, leftImpurity, rightImpurity, predict)
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Node (
val split: Option[Split],
var leftNode: Option[Node],
var rightNode: Option[Node],
val stats: Option[InformationGainStats]) extends Serializable with Logging{
val stats: Option[InformationGainStats]) extends Serializable with Logging {

override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " +
"split = " + split + ", stats = " + stats
Expand All @@ -46,7 +46,7 @@ class Node (
* build the left node and right nodes if not leaf
* @param nodes array of nodes
*/
def build(nodes : Array[Node]): Unit = {
def build(nodes: Array[Node]): Unit = {

logDebug("building node " + id + " at level " +
(scala.math.log(id + 1)/scala.math.log(2)).toInt )
Expand All @@ -68,7 +68,7 @@ class Node (
* @param feature feature value
* @return predicted value
*/
def predictIfLeaf(feature : Array[Double]) : Double = {
def predictIfLeaf(feature: Array[Double]) : Double = {
if (isLeaf) {
predict
} else{
Expand All @@ -87,5 +87,4 @@ class Node (
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ case class Split(
* @param feature feature index
* @param featureType type of feature -- categorical or continuous
*/
class DummyLowSplit(feature: Int, featureType : FeatureType)
class DummyLowSplit(feature: Int, featureType: FeatureType)
extends Split(feature, Double.MinValue, featureType, List())

/**
* Split with maximum threshold for continuous features. Helps with the highest bin creation.
* @param feature feature index
* @param featureType type of feature -- categorical or continuous
*/
class DummyHighSplit(feature: Int, featureType : FeatureType)
class DummyHighSplit(feature: Int, featureType: FeatureType)
extends Split(feature, Double.MaxValue, featureType, List())

/**
Expand All @@ -59,6 +59,6 @@ class DummyHighSplit(feature: Int, featureType : FeatureType)
* @param feature feature index
* @param featureType type of feature -- categorical or continuous
*/
class DummyCategoricalSplit(feature: Int, featureType : FeatureType)
class DummyCategoricalSplit(feature: Int, featureType: FeatureType)
extends Split(feature, Double.MaxValue, featureType, List())

Loading

0 comments on commit 7d54b4f

Please sign in to comment.