Skip to content

Commit

Permalink
added multiclass support
Browse files Browse the repository at this point in the history
  • Loading branch information
manishamde committed May 12, 2014
1 parent 6c7af22 commit 5c78e1a
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 56 deletions.
64 changes: 37 additions & 27 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -681,36 +681,47 @@ object DecisionTree extends Serializable with Logging {
topImpurity: Double): InformationGainStats = {
strategy.algo match {
case Classification =>
// TODO: Modify here
val left0Count = leftNodeAgg(featureIndex)(splitIndex)(0)
val left1Count = leftNodeAgg(featureIndex)(splitIndex)(1)
val leftCount = left0Count + left1Count

val right0Count = rightNodeAgg(featureIndex)(splitIndex)(0)
val right1Count = rightNodeAgg(featureIndex)(splitIndex)(1)
val rightCount = right0Count + right1Count
var classIndex = 0
val leftCounts: Array[Double] = new Array[Double](numClasses)
val rightCounts: Array[Double] = new Array[Double](numClasses)
var leftTotalCount = 0.0
var rightTotalCount = 0.0
while (classIndex < numClasses) {
val leftClassCount = leftNodeAgg(featureIndex)(splitIndex)(classIndex)
val rightClassCount = rightNodeAgg(featureIndex)(splitIndex)(classIndex)
leftCounts(classIndex) = leftClassCount
leftTotalCount += leftClassCount
rightCounts(classIndex) = rightClassCount
rightTotalCount += rightClassCount
classIndex += 1
}

val impurity = {
if (level > 0) {
topImpurity
} else {
// Calculate impurity for root node.
strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count)
val rootNodeCounts = new Array[Double](numClasses)
var classIndex = 0
while (classIndex < numClasses) {
rootNodeCounts(classIndex) = leftCounts(classIndex) + rightCounts(classIndex)
}
strategy.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount)
}
}

if (leftCount == 0) {
if (leftTotalCount == 0) {
return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,1)
}
if (rightCount == 0) {
if (rightTotalCount == 0) {
return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue,0)
}

val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)
val rightImpurity = strategy.impurity.calculate(right0Count, right1Count)
val leftImpurity = strategy.impurity.calculate(leftCounts, leftTotalCount)
val rightImpurity = strategy.impurity.calculate(rightCounts, rightTotalCount)

val leftWeight = leftCount.toDouble / (leftCount + rightCount)
val rightWeight = rightCount.toDouble / (leftCount + rightCount)
val leftWeight = leftTotalCount.toDouble / (leftTotalCount + rightTotalCount)
val rightWeight = rightTotalCount.toDouble / (leftTotalCount + rightTotalCount)

val gain = {
if (level > 0) {
Expand All @@ -720,7 +731,8 @@ object DecisionTree extends Serializable with Logging {
}
}

val predict = (left1Count + right1Count) / (leftCount + rightCount)
//TODO: Make modification here
val predict = (leftCounts(1) + rightCounts(1)) / (leftTotalCount + rightTotalCount)

new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict)
case Regression =>
Expand Down Expand Up @@ -782,7 +794,6 @@ object DecisionTree extends Serializable with Logging {
binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = {
strategy.algo match {
case Classification =>
// TODO: Multiclass modification here

// Initialize left and right split aggregates.
val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
Expand All @@ -793,17 +804,19 @@ object DecisionTree extends Serializable with Logging {
while (featureIndex < numFeatures){
val numCategories = strategy.categoricalFeaturesInfo(featureIndex)
val maxSplits = math.pow(2, numCategories) - 1
var i = 0
// TODO: Add multiclass case here
while (i < maxSplits) {
var splitIndex = 0
while (splitIndex < maxSplits) {
var classIndex = 0
while (classIndex < numClasses) {
// shift for this featureIndex
val shift = numClasses * featureIndex * numBins

leftNodeAgg(featureIndex)(splitIndex)(classIndex)
= binData(shift + classIndex)
rightNodeAgg(featureIndex)(splitIndex)(classIndex)
= binData(shift + numClasses + classIndex)
classIndex += 1
}
i += 1
splitIndex += 1
}
featureIndex += 1
}
Expand Down Expand Up @@ -931,8 +944,6 @@ object DecisionTree extends Serializable with Logging {
binData: Array[Double],
nodeImpurity: Double): (Split, InformationGainStats) = {

// TODO: Multiclass modification here

logDebug("node impurity = " + nodeImpurity)

// Extract left right node aggregates.
Expand Down Expand Up @@ -977,9 +988,8 @@ object DecisionTree extends Serializable with Logging {
def getBinDataForNode(node: Int): Array[Double] = {
strategy.algo match {
case Classification =>
// TODO: Multiclass modification here
val shift = 2 * node * numBins * numFeatures
val binsForNode = binAggregates.slice(shift, shift + 2 * numBins * numFeatures)
val shift = numClasses * node * numBins * numFeatures
val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures)
binsForNode
case Regression =>
val shift = 3 * node * numBins * numFeatures
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,22 @@ object Entropy extends Impurity {

/**
* :: DeveloperApi ::
* entropy calculation
* @param c0 count of instances with label 0
* @param c1 count of instances with label 1
* @return entropy value
* information calculation for multiclass classification
* @param counts Array[Double] with counts for each label
* @param totalCount sum of counts for all labels
* @return information value
*/
@DeveloperApi
override def calculate(c0: Double, c1: Double): Double = {
if (c0 == 0 || c1 == 0) {
0
} else {
val total = c0 + c1
val f0 = c0 / total
val f1 = c1 / total
-(f0 * log2(f0)) - (f1 * log2(f1))
override def calculate(counts: Array[Double], totalCount: Double): Double = {
  // Entropy over the label distribution: -sum_i f_i * log2(f_i), where
  // f_i = counts(i) / totalCount.
  if (totalCount == 0) {
    // No instances at all: every frequency would be 0 / 0 = NaN. Report zero impurity,
    // as the former two-class implementation did when its counts were zero.
    0.0
  } else {
    val numClasses = counts.length
    var impurity = 0.0
    var classIndex = 0
    // while-loop kept on purpose: this is called once per candidate split.
    while (classIndex < numClasses) {
      val classCount = counts(classIndex)
      // A label with no instances contributes nothing to the entropy. It must be skipped
      // explicitly because 0 * log2(0) evaluates to 0 * -Infinity = NaN, which would
      // poison the whole sum.
      if (classCount != 0) {
        val freq = classCount / totalCount
        impurity -= freq * log2(freq)
      }
      classIndex += 1
    }
    impurity
  }
}

override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,22 @@ object Gini extends Impurity {

/**
* :: DeveloperApi ::
* Gini coefficient calculation
* @param c0 count of instances with label 0
* @param c1 count of instances with label 1
* @return Gini coefficient value
* Gini coefficient calculation for multiclass classification
* @param counts Array[Double] with counts for each label
* @param totalCount sum of counts for all labels
* @return Gini coefficient value
*/
@DeveloperApi
override def calculate(c0: Double, c1: Double): Double = {
if (c0 == 0 || c1 == 0) {
0
} else {
val total = c0 + c1
val f0 = c0 / total
val f1 = c1 / total
1 - f0 * f0 - f1 * f1
override def calculate(counts: Array[Double], totalCount: Double): Double = {
  // Gini impurity over the label distribution: 1 - sum_i f_i^2, where
  // f_i = counts(i) / totalCount.
  if (totalCount == 0) {
    // No instances at all: every frequency would be 0 / 0 = NaN. Report zero impurity,
    // as the former two-class implementation did when its counts were zero.
    0.0
  } else {
    val numClasses = counts.length
    var impurity = 1.0
    var classIndex = 0
    // while-loop kept on purpose: this is called once per candidate split.
    while (classIndex < numClasses) {
      val freq = counts(classIndex) / totalCount
      impurity -= freq * freq
      classIndex += 1
    }
    impurity
  }
}

override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ trait Impurity extends Serializable {

/**
* :: DeveloperApi ::
* information calculation for binary classification
* @param c0 count of instances with label 0
* @param c1 count of instances with label 1
* information calculation for multiclass classification
* @param counts Array[Double] with counts for each label
* @param totalCount sum of counts for all labels
* @return information value
*/
@DeveloperApi
def calculate(c0 : Double, c1 : Double): Double
// Abstract: declared without a body, so each Impurity implementation supplies its own
// formula over the per-label counts and their total.
def calculate(counts: Array[Double], totalCount: Double): Double

/**
* :: DeveloperApi ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
*/
@Experimental
object Variance extends Impurity {
override def calculate(c0: Double, c1: Double): Double =
override def calculate(counts: Array[Double], totalCounts: Double): Double = {
  // The label-count form of impurity is not provided by Variance: this overload never
  // returns a value and always signals an unsupported operation.
  throw new UnsupportedOperationException("Variance.calculate")
}

/**
Expand Down

0 comments on commit 5c78e1a

Please sign in to comment.