diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 5f04b369a6a60..5e88109b5ffb5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -340,8 +340,8 @@ object DecisionTree extends Serializable with Logging { } throw new UnknownError("no bin was found for continuous variable.") } else { - - for (binIndex <- 0 until strategy.numBins) { + val numCategoricalBins = strategy.categoricalFeaturesInfo(featureIndex) + for (binIndex <- 0 until numCategoricalBins) { val bin = bins(featureIndex)(binIndex) val category = bin.category val features = labeledPoint.features @@ -917,13 +917,6 @@ object DecisionTree extends Serializable with Logging { bins(featureIndex)(numBins-1) = new Bin(splits(featureIndex)(numBins-2),new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue) - } else { - val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex) - for (i <- maxFeatureValue until numBins){ - bins(featureIndex)(i) - = new Bin(new DummyCategoricalSplit(featureIndex, Categorical), - new DummyCategoricalSplit(featureIndex, Categorical), Categorical, Double.MaxValue) - } } } (splits,bins) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index a299b087dfda8..f8914e03bd12f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -64,7 +64,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length == 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) + val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 2, + 1-> 2)) val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) assert(splits.length==2) assert(bins.length==2) @@ -120,7 +121,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0)(1).highSplit.categories.contains(1.0)) assert(bins(0)(1).highSplit.categories.contains(0.0)) - assert(bins(0)(2).category == Double.MaxValue) + assert(bins(0)(2) == null) assert(bins(1)(0).category == 0.0) assert(bins(1)(0).lowSplit.categories.length == 0) @@ -134,7 +135,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(1)(1).highSplit.categories.contains(0.0)) assert(bins(1)(1).highSplit.categories.contains(1.0)) - assert(bins(1)(2).category == Double.MaxValue) + assert(bins(1)(2) == null) } @@ -142,7 +143,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length == 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, + 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) //Checking splits @@ -217,7 +219,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0)(2).highSplit.categories.contains(0.0)) assert(bins(0)(2).highSplit.categories.contains(2.0)) - assert(bins(0)(3).category == Double.MaxValue) + assert(bins(0)(3) == null) assert(bins(1)(0).category == 0.0) assert(bins(1)(0).lowSplit.categories.length == 0) @@ -240,7 +242,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(1)(2).highSplit.categories.contains(1.0)) assert(bins(1)(2).highSplit.categories.contains(2.0)) - assert(bins(1)(3).category == Double.MaxValue) + assert(bins(1)(3) == null) } @@ -249,10 +251,12 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length == 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, + 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) strategy.numBins = 100 - val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins) + val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + Array[List[Filter]](), splits, bins) val split = bestSplits(0)._1 assert(split.categories.length == 1) @@ -272,10 +276,12 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length == 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Regression,Variance,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val strategy = new Strategy(Regression,Variance,3,100,categoricalFeaturesInfo = Map(0 -> 3, + 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) strategy.numBins = 100 - val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins) + val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + Array[List[Filter]](), splits, bins) val split = bestSplits(0)._1 assert(split.categories.length == 1) @@ -305,7 +311,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0).length==100) strategy.numBins = 100 - val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins) + val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + Array[List[Filter]](), splits, bins) assert(bestSplits.length == 1) assert(0==bestSplits(0)._1.feature) assert(10==bestSplits(0)._1.threshold) @@ -329,7 +336,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0).length==100) strategy.numBins = 100 - val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins) + val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + Array[List[Filter]](), splits, bins) assert(bestSplits.length == 1) assert(0==bestSplits(0)._1.feature) assert(10==bestSplits(0)._1.threshold) @@ -355,7 +363,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0).length==100) strategy.numBins = 100 - val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins) + val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + Array[List[Filter]](), splits, bins) assert(bestSplits.length == 1) assert(0==bestSplits(0)._1.feature) assert(10==bestSplits(0)._1.threshold) @@ -379,7 +388,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0).length==100) strategy.numBins = 100 - val bestSplits = DecisionTree.findBestSplits(rdd,Array(0.0),strategy,0,Array[List[Filter]](),splits,bins) + val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + Array[List[Filter]](), splits, bins) assert(bestSplits.length == 1) assert(0==bestSplits(0)._1.feature) assert(10==bestSplits(0)._1.threshold)