diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index a9a578c4ac262..89a3f6de4fcb5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -29,9 +29,21 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.Algo._ +/** +A class that implements a decision tree algorithm for classification and regression. +It supports both continuous and categorical features. +@param strategy The configuration parameters for the tree algorithm which specify the type of algorithm (classification, +regression, etc.), feature type (continuous, categorical), depth of the tree, quantile calculation strategy, etc. + */ class DecisionTree(val strategy : Strategy) extends Serializable with Logging { + /** + Method to train a decision tree model over an RDD. + + @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data for DecisionTree + @return a DecisionTreeModel that can be used for prediction + */ def train(input : RDD[LabeledPoint]) : DecisionTreeModel = { //Cache input RDD for speedup during multiple passes