Skip to content

Commit

Permalink
minor refactoring
Browse files Browse the repository at this point in the history
Signed-off-by: Manish Amde <manish9ue@gmail.com>
  • Loading branch information
manishamde committed Feb 28, 2014
1 parent 6b7de78 commit b09dc98
Showing 1 changed file with 23 additions and 20 deletions.
43 changes: 23 additions & 20 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -367,18 +367,18 @@ object DecisionTree extends Serializable with Logging {

def calculateGainForSplit(leftNodeAgg: Array[Array[Double]],
featureIndex: Int,
index: Int,
splitIndex: Int,
rightNodeAgg: Array[Array[Double]],
topImpurity: Double) : InformationGainStats = {
strategy.algo match {
case Classification => {

val left0Count = leftNodeAgg(featureIndex)(2 * index)
val left1Count = leftNodeAgg(featureIndex)(2 * index + 1)
val left0Count = leftNodeAgg(featureIndex)(2 * splitIndex)
val left1Count = leftNodeAgg(featureIndex)(2 * splitIndex + 1)
val leftCount = left0Count + left1Count

val right0Count = rightNodeAgg(featureIndex)(2 * index)
val right1Count = rightNodeAgg(featureIndex)(2 * index + 1)
val right0Count = rightNodeAgg(featureIndex)(2 * splitIndex)
val right1Count = rightNodeAgg(featureIndex)(2 * splitIndex + 1)
val rightCount = right0Count + right1Count

val impurity = if (level > 0) topImpurity else strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count)
Expand All @@ -405,13 +405,13 @@ object DecisionTree extends Serializable with Logging {
new InformationGainStats(gain,impurity,leftImpurity,rightImpurity,predict)
}
case Regression => {
val leftCount = leftNodeAgg(featureIndex)(3 * index)
val leftSum = leftNodeAgg(featureIndex)(3 * index + 1)
val leftSumSquares = leftNodeAgg(featureIndex)(3 * index + 2)
val leftCount = leftNodeAgg(featureIndex)(3 * splitIndex)
val leftSum = leftNodeAgg(featureIndex)(3 * splitIndex + 1)
val leftSumSquares = leftNodeAgg(featureIndex)(3 * splitIndex + 2)

val rightCount = rightNodeAgg(featureIndex)(3 * index)
val rightSum = rightNodeAgg(featureIndex)(3 * index + 1)
val rightSumSquares = rightNodeAgg(featureIndex)(3 * index + 2)
val rightCount = rightNodeAgg(featureIndex)(3 * splitIndex)
val rightSum = rightNodeAgg(featureIndex)(3 * splitIndex + 1)
val rightSumSquares = rightNodeAgg(featureIndex)(3 * splitIndex + 2)

val impurity = if (level > 0) topImpurity else strategy.impurity.calculate(leftCount + rightCount, leftSum + rightSum, leftSumSquares + rightSumSquares)

Expand Down Expand Up @@ -463,9 +463,9 @@ object DecisionTree extends Serializable with Logging {
leftNodeAgg(featureIndex)(2 * splitIndex + 1)
= binData(shift + 2*splitIndex + 1) + leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1)
rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex))
= binData(shift + (2 * (numBins - 1 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
= binData(shift + (2 * (numBins - 2 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1)
= binData(shift + (2 * (numBins - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1)
= binData(shift + (2 * (numBins - 2 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1)
}
}
(leftNodeAgg, rightNodeAgg)
Expand All @@ -490,11 +490,11 @@ object DecisionTree extends Serializable with Logging {
leftNodeAgg(featureIndex)(3 * splitIndex + 2)
= binData(shift + 3*splitIndex + 2) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2)
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex))
= binData(shift + (3 * (numBins - 1 - splitIndex))) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
= binData(shift + (3 * (numBins - 2 - splitIndex))) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1)
= binData(shift + (3 * (numBins - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1)
= binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1)
rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2)
= binData(shift + (3 * (numBins - 1 - splitIndex) + 2)) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2)
= binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2)
}
}
(leftNodeAgg, rightNodeAgg)
Expand All @@ -508,9 +508,9 @@ object DecisionTree extends Serializable with Logging {
val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1)

for (featureIndex <- 0 until numFeatures) {
for (index <- 0 until numBins -1) {
for (splitIndex <- 0 until numBins -1) {
//logDebug("splitIndex = " + index)
gains(featureIndex)(index) = calculateGainForSplit(leftNodeAgg, featureIndex, index, rightNodeAgg, nodeImpurity)
gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex, splitIndex, rightNodeAgg, nodeImpurity)
}
}
gains
Expand Down Expand Up @@ -544,6 +544,8 @@ object DecisionTree extends Serializable with Logging {
(bestFeatureIndex,bestSplitIndex,bestGainStats)
}

logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex))
logDebug("best split bin = " + splits(bestFeatureIndex)(bestSplitIndex))
(splits(bestFeatureIndex)(bestSplitIndex),gainStats)
}

Expand Down Expand Up @@ -614,13 +616,14 @@ object DecisionTree extends Serializable with Logging {

//Find all splits
for (featureIndex <- 0 until numFeatures){
val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
if (isFeatureContinous) {
val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
if (isFeatureContinuous) {
val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted

val stride : Double = numSamples.toDouble/numBins
logDebug("stride = " + stride)
for (index <- 0 until numBins-1) {
//TODO: Investigate this
val sampleIndex = (index+1)*stride.toInt
val split = new Split(featureIndex,featureSamples(sampleIndex),Continuous, List())
splits(featureIndex)(index) = split
Expand Down

0 comments on commit b09dc98

Please sign in to comment.