Skip to content

Commit

Permalink
Fixed off-by-one error in bin-to-split conversion
Browse files: browse the repository at this point in the history
  • Loading branch information
manishamde committed Jul 9, 2014
1 parent 9cc3e31 commit 06b1690
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 10 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -1012,7 +1012,7 @@ object DecisionTree extends Serializable with Logging {
= binData(shift + numClasses * splitIndex + innerClassIndex) +
leftNodeAgg(featureIndex)(splitIndex - 1)(innerClassIndex)
rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(innerClassIndex) =
binData(shift + (numClasses * (numBins - 2 - splitIndex) + innerClassIndex)) +
binData(shift + (numClasses * (numBins - 1 - splitIndex) + innerClassIndex)) +
rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(innerClassIndex)
innerClassIndex += 1
}
Expand Down Expand Up @@ -1077,13 +1077,13 @@ object DecisionTree extends Serializable with Logging {
// calculating right node aggregate for a split as a sum of right node aggregate of a
// higher split and the right bin aggregate of a bin where the split is a low split
rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(0) =
binData(shift + (3 * (numBins - 2 - splitIndex))) +
binData(shift + (3 * (numBins - 1 - splitIndex))) +
rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(0)
rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(1) =
binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) +
binData(shift + (3 * (numBins - 1 - splitIndex) + 1)) +
rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(1)
rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(2) =
binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) +
binData(shift + (3 * (numBins - 1 - splitIndex) + 2)) +
rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(2)

splitIndex += 1
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -412,9 +412,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {

val stats = bestSplits(0)._2
assert(stats.gain > 0)
assert(stats.predict === 0)
assert(stats.prob > 0.5)
assert(stats.prob < 0.6)
assert(stats.predict === 1)
assert(stats.prob == 0.6)
assert(stats.impurity > 0.2)
}

Expand All @@ -440,8 +439,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {

val stats = bestSplits(0)._2
assert(stats.gain > 0)
assert(stats.predict > 0.4)
assert(stats.predict < 0.5)
assert(stats.predict == 0.6)
assert(stats.impurity > 0.2)
}

Expand Down Expand Up @@ -657,7 +655,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val bestSplit = bestSplits(0)._1
assert(bestSplit.feature === 0)
assert(bestSplit.categories.length === 1)
println(bestSplit)
assert(bestSplit.categories.contains(1.0))
assert(bestSplit.featureType === Categorical)
}
Expand Down

0 comments on commit 06b1690

Please sign in to comment.