diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index ddb78d3903049..43ede29ef6fd8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -121,7 +121,7 @@ object DecisionTree extends Serializable { /*Finds the right bin for the given feature*/ def findBin(featureIndex: Int, labeledPoint: LabeledPoint) : Int = { - println("finding bin for labeled point " + labeledPoint.features(featureIndex)) + //println("finding bin for labeled point " + labeledPoint.features(featureIndex)) //TODO: Do binary search for (binIndex <- 0 until strategy.numSplits) { val bin = bins(featureIndex)(binIndex) @@ -227,7 +227,7 @@ object DecisionTree extends Serializable { val binAggregates = binMappedRDD.aggregate(Array.fill[Double](2*numSplits*numFeatures*numNodes)(0))(binSeqOp,binCombOp) println("binAggregates.length = " + binAggregates.length) - binAggregates.foreach(x => println(x)) + //binAggregates.foreach(x => println(x)) def calculateGainForSplit(leftNodeAgg: Array[Array[Double]], featureIndex: Int, index: Int, rightNodeAgg: Array[Array[Double]], topImpurity: Double): Double = { @@ -235,13 +235,19 @@ object DecisionTree extends Serializable { val left0Count = leftNodeAgg(featureIndex)(2 * index) val left1Count = leftNodeAgg(featureIndex)(2 * index + 1) val leftCount = left0Count + left1Count - println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount) + + if (leftCount == 0) return 0 + + //println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount) val leftImpurity = strategy.impurity.calculate(left0Count, left1Count) val right0Count = rightNodeAgg(featureIndex)(2 * index) val right1Count = rightNodeAgg(featureIndex)(2 * index + 1) val rightCount = right0Count + right1Count - println("right0count = 
" + right0Count + ", right1count = " + right1Count + ", rightCount = " + rightCount) + + if (rightCount == 0) return 0 + + //println("right0count = " + right0Count + ", right1count = " + right1Count + ", rightCount = " + rightCount) val rightImpurity = strategy.impurity.calculate(right0Count, right1Count) val leftWeight = leftCount.toDouble / (leftCount + rightCount) @@ -261,21 +267,21 @@ object DecisionTree extends Serializable { def extractLeftRightNodeAggregates(binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = { val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1)) val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numSplits - 1)) - println("binData.length = " + binData.length) - println("binData.sum = " + binData.sum) + //println("binData.length = " + binData.length) + //println("binData.sum = " + binData.sum) for (featureIndex <- 0 until numFeatures) { - println("featureIndex = " + featureIndex) + //println("featureIndex = " + featureIndex) val shift = 2*featureIndex*numSplits leftNodeAgg(featureIndex)(0) = binData(shift + 0) - println("binData(shift + 0) = " + binData(shift + 0)) + //println("binData(shift + 0) = " + binData(shift + 0)) leftNodeAgg(featureIndex)(1) = binData(shift + 1) - println("binData(shift + 1) = " + binData(shift + 1)) + //println("binData(shift + 1) = " + binData(shift + 1)) rightNodeAgg(featureIndex)(2 * (numSplits - 2)) = binData(shift + (2 * (numSplits - 1))) - println(binData(shift + (2 * (numSplits - 1)))) + //println(binData(shift + (2 * (numSplits - 1)))) rightNodeAgg(featureIndex)(2 * (numSplits - 2) + 1) = binData(shift + (2 * (numSplits - 1)) + 1) - println(binData(shift + (2 * (numSplits - 1)) + 1)) + //println(binData(shift + (2 * (numSplits - 1)) + 1)) for (splitIndex <- 1 until numSplits - 1) { - println("splitIndex = " + splitIndex) + //println("splitIndex = " + splitIndex) leftNodeAgg(featureIndex)(2 * splitIndex) = binData(shift + 2*splitIndex) + leftNodeAgg(featureIndex)(2 
* splitIndex - 2) leftNodeAgg(featureIndex)(2 * splitIndex + 1) @@ -295,7 +301,7 @@ object DecisionTree extends Serializable { for (featureIndex <- 0 until numFeatures) { for (index <- 0 until numSplits -1) { - println("splitIndex = " + index) + //println("splitIndex = " + index) gains(featureIndex)(index) = calculateGainForSplit(leftNodeAgg, featureIndex, index, rightNodeAgg, nodeImpurity) } } @@ -312,8 +318,8 @@ object DecisionTree extends Serializable { val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) - println("gains.size = " + gains.size) - println("gains(0).size = " + gains(0).size) + //println("gains.size = " + gains.size) + //println("gains(0).size = " + gains(0).size) val (bestFeatureIndex,bestSplitIndex) = { var bestFeatureIndex = 0 @@ -322,7 +328,7 @@ object DecisionTree extends Serializable { for (featureIndex <- 0 until numFeatures) { for (splitIndex <- 0 until numSplits - 1){ val gain = gains(featureIndex)(splitIndex) - println("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + gain) + //println("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + gain) if(gain > maxGain) { maxGain = gain bestFeatureIndex = featureIndex @@ -335,6 +341,8 @@ object DecisionTree extends Serializable { } splits(bestFeatureIndex)(bestSplitIndex) + + //TODO: Return array of node stats with split and impurity information } //Calculate best splits for all nodes at a given level @@ -388,6 +396,9 @@ object DecisionTree extends Serializable { for (featureIndex <- 0 until numFeatures){ val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted val stride : Double = numSamples.toDouble/numSplits + + println("stride = " + stride) + for (index <- 0 until numSplits-1) { val sampleIndex = (index+1)*stride.toInt val split = new Split(featureIndex,featureSamples(sampleIndex),"continuous") diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index e886c40901b45..2c9794371eb29 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.SparkContext._ import org.jblas._ import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.impurity.Gini +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini} import org.apache.spark.mllib.tree.model.Filter class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { @@ -44,7 +44,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { } test("split and bin calculation"){ - val arr = DecisionTreeSuite.generateReverseOrderedLabeledPoints() + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length == 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy("regression",Gini,3,100,"sort") @@ -56,8 +56,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { println(splits(1)(98)) } - test("stump"){ - val arr = DecisionTreeSuite.generateReverseOrderedLabeledPoints() + test("stump with fixed label 0 for Gini"){ + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length == 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy("regression",Gini,3,100,"sort") @@ -69,17 +69,85 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(splits(0).length==99) assert(bins(0).length==100) println(splits(1)(98)) - DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins) + val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins) + assert(bestSplits.length == 1) + println(bestSplits(0)) } + test("stump with fixed label 1 
for Gini"){ + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() + assert(arr.length == 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy("regression",Gini,3,100,"sort") + val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy) + assert(splits.length==2) + assert(splits(0).length==99) + assert(bins.length==2) + assert(bins(0).length==100) + assert(splits(0).length==99) + assert(bins(0).length==100) + println(splits(1)(98)) + val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins) + assert(bestSplits.length == 1) + println(bestSplits(0)) + } + + + test("stump with fixed label 0 for Entropy"){ + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() + assert(arr.length == 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy("regression",Entropy,3,100,"sort") + val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy) + assert(splits.length==2) + assert(splits(0).length==99) + assert(bins.length==2) + assert(bins(0).length==100) + assert(splits(0).length==99) + assert(bins(0).length==100) + println(splits(1)(98)) + val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins) + assert(bestSplits.length == 1) + println(bestSplits(0)) + } + + test("stump with fixed label 1 for Entropy"){ + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() + assert(arr.length == 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy("regression",Entropy,3,100,"sort") + val (splits, bins) = DecisionTree.find_splits_bins(rdd,strategy) + assert(splits.length==2) + assert(splits(0).length==99) + assert(bins.length==2) + assert(bins(0).length==100) + assert(splits(0).length==99) + assert(bins(0).length==100) + println(splits(1)(98)) + val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins) + assert(bestSplits.length == 1) + 
println(bestSplits(0)) + } + + } object DecisionTreeSuite { - def generateReverseOrderedLabeledPoints() : Array[LabeledPoint] = { + def generateOrderedLabeledPointsWithLabel0() : Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](1000) + for (i <- 0 until 1000){ + val lp = new LabeledPoint(0.0,Array(i.toDouble,1000.0-i)) + arr(i) = lp + } + arr + } + + + def generateOrderedLabeledPointsWithLabel1() : Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000){ - val lp = new LabeledPoint(1.0,Array(i.toDouble,1000.0-i)) + val lp = new LabeledPoint(1.0,Array(i.toDouble,999.0-i)) arr(i) = lp } arr