[SPARK-12379][ML][MLLIB] Copy GBT implementation to spark.ml
Currently, GBTs in spark.ml wrap the implementation in spark.mllib. This blocks several planned improvements to GBTs in spark.ml, so we need to move the implementation into spark.ml and build it on spark.ml decision trees. As a first step, the port makes only minimal changes to the implementation.
Performance testing was done to ensure there were no regressions.

Performance testing results are [here](https://docs.google.com/document/d/1dYd2mnfGdUKkQ3vZe2BpzsTnI5IrpSLQ-NNKDZhUkgw/edit?usp=sharing)

Author: sethah <seth.hendrickson16@gmail.com>

Closes #10607 from sethah/SPARK-12379.
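
In short, `GBTClassifier.train` and `GBTRegressor.train` stop delegating to the old spark.mllib wrapper and call the copied implementation directly. A condensed sketch of the change in the classifier, assembled from the diff below (the regressor change is analogous; `OldGBT` is the import alias for mllib's `GradientBoostedTrees`):

```scala
// Before: delegate to the spark.mllib wrapper, then convert the old model.
val oldGBT = new OldGBT(boostingStrategy)
val oldModel = oldGBT.run(oldDataset)
GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures, numFeatures)

// After: call the spark.ml implementation, which trains spark.ml decision
// trees directly, so no model conversion is needed.
val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy)
new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
```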
sethah authored and MLnick committed Mar 15, 2016
1 parent 10251a7 commit dafd70f
Showing 12 changed files with 306 additions and 15 deletions.
mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -93,6 +93,14 @@ final class DecisionTreeClassifier @Since("1.4.0") (
     trees.head.asInstanceOf[DecisionTreeClassificationModel]
   }
 
+  /** (private[ml]) Train a decision tree on an RDD */
+  private[ml] def train(data: RDD[LabeledPoint],
+      oldStrategy: OldStrategy): DecisionTreeClassificationModel = {
+    val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
+      seed = 0L, parentUID = Some(uid))
+    trees.head.asInstanceOf[DecisionTreeClassificationModel]
+  }
+
   /** (private[ml]) Create a Strategy instance to use with the old API. */
   private[ml] def getOldStrategy(
       categoricalFeatures: Map[Int, Int],
mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -26,10 +26,10 @@ import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeClassifierParams,
   TreeEnsembleModel}
+import org.apache.spark.ml.tree.impl.GradientBoostedTrees
 import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss}
 import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
@@ -158,9 +158,8 @@ final class GBTClassifier @Since("1.4.0") (
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
     val numFeatures = oldDataset.first().features.size
     val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
-    val oldGBT = new OldGBT(boostingStrategy)
-    val oldModel = oldGBT.run(oldDataset)
-    GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures, numFeatures)
+    val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy)
+    new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
   }
 
   @Since("1.4.1")
mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -87,6 +87,14 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val
     trees.head.asInstanceOf[DecisionTreeRegressionModel]
   }
 
+  /** (private[ml]) Train a decision tree on an RDD */
+  private[ml] def train(data: RDD[LabeledPoint],
+      oldStrategy: OldStrategy): DecisionTreeRegressionModel = {
+    val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
+      seed = 0L, parentUID = Some(uid))
+    trees.head.asInstanceOf[DecisionTreeRegressionModel]
+  }
+
   /** (private[ml]) Create a Strategy instance to use with the old API. */
   private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = {
     super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity,
mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -25,10 +25,10 @@ import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeEnsembleModel,
   TreeRegressorParams}
+import org.apache.spark.ml.tree.impl.GradientBoostedTrees
 import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss => OldLoss,
   SquaredError => OldSquaredError}
@@ -145,9 +145,8 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
     val numFeatures = oldDataset.first().features.size
     val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
-    val oldGBT = new OldGBT(boostingStrategy)
-    val oldModel = oldGBT.run(oldDataset)
-    GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures, numFeatures)
+    val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy)
+    new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures)
   }
 
   @Since("1.4.0")
mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala (new file)
@@ -0,0 +1,277 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.tree.impl

import org.apache.spark.Logging
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy}
import org.apache.spark.mllib.tree.impl.TimeTracker
import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance}
import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

private[ml] object GradientBoostedTrees extends Logging {

  /**
   * Method to train a gradient boosting model
   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
   * @return tuple of ensemble models and weights:
   *         (array of decision tree models, array of model weights)
   */
  def run(input: RDD[LabeledPoint],
      boostingStrategy: OldBoostingStrategy
    ): (Array[DecisionTreeRegressionModel], Array[Double]) = {
    val algo = boostingStrategy.treeStrategy.algo
    algo match {
      case OldAlgo.Regression =>
        GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false)
      case OldAlgo.Classification =>
        // Map labels to -1, +1 so binary classification can be treated as regression.
        val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
        GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false)
      case _ =>
        throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
    }
  }

  /**
   * Method to validate a gradient boosting model
   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
   * @param validationInput Validation dataset.
   *                        This dataset should be different from the training dataset,
   *                        but it should follow the same distribution.
   *                        E.g., these two datasets could be created from an original dataset
   *                        by using [[org.apache.spark.rdd.RDD.randomSplit()]]
   * @return tuple of ensemble models and weights:
   *         (array of decision tree models, array of model weights)
   */
  def runWithValidation(
      input: RDD[LabeledPoint],
      validationInput: RDD[LabeledPoint],
      boostingStrategy: OldBoostingStrategy
    ): (Array[DecisionTreeRegressionModel], Array[Double]) = {
    val algo = boostingStrategy.treeStrategy.algo
    algo match {
      case OldAlgo.Regression =>
        GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true)
      case OldAlgo.Classification =>
        // Map labels to -1, +1 so binary classification can be treated as regression.
        val remappedInput = input.map(
          x => new LabeledPoint((x.label * 2) - 1, x.features))
        val remappedValidationInput = validationInput.map(
          x => new LabeledPoint((x.label * 2) - 1, x.features))
        GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
          validate = true)
      case _ =>
        throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
    }
  }

  /**
   * Compute the initial predictions and errors for a dataset for the first
   * iteration of gradient boosting.
   * @param data: training data.
   * @param initTreeWeight: learning rate assigned to the first tree.
   * @param initTree: first DecisionTreeModel.
   * @param loss: evaluation metric.
   * @return an RDD with each element being a zip of the prediction and error
   *         corresponding to every sample.
   */
  def computeInitialPredictionAndError(
      data: RDD[LabeledPoint],
      initTreeWeight: Double,
      initTree: DecisionTreeRegressionModel,
      loss: OldLoss): RDD[(Double, Double)] = {
    data.map { lp =>
      val pred = initTreeWeight * initTree.rootNode.predictImpl(lp.features).prediction
      val error = loss.computeError(pred, lp.label)
      (pred, error)
    }
  }

  /**
   * Update a zipped predictionError RDD
   * (as obtained with computeInitialPredictionAndError)
   * @param data: training data.
   * @param predictionAndError: predictionError RDD
   * @param treeWeight: Learning rate.
   * @param tree: Tree with which the prediction and error should be updated.
   * @param loss: evaluation metric.
   * @return an RDD with each element being a zip of the prediction and error
   *         corresponding to each sample.
   */
  def updatePredictionError(
      data: RDD[LabeledPoint],
      predictionAndError: RDD[(Double, Double)],
      treeWeight: Double,
      tree: DecisionTreeRegressionModel,
      loss: OldLoss): RDD[(Double, Double)] = {

    val newPredError = data.zip(predictionAndError).mapPartitions { iter =>
      iter.map { case (lp, (pred, error)) =>
        val newPred = pred + tree.rootNode.predictImpl(lp.features).prediction * treeWeight
        val newError = loss.computeError(newPred, lp.label)
        (newPred, newError)
      }
    }
    newPredError
  }

  /**
   * Internal method for performing regression using trees as base learners.
   * @param input training dataset
   * @param validationInput validation dataset, ignored if validate is set to false.
   * @param boostingStrategy boosting parameters
   * @param validate whether or not to use the validation dataset.
   * @return tuple of ensemble models and weights:
   *         (array of decision tree models, array of model weights)
   */
  def boost(
      input: RDD[LabeledPoint],
      validationInput: RDD[LabeledPoint],
      boostingStrategy: OldBoostingStrategy,
      validate: Boolean): (Array[DecisionTreeRegressionModel], Array[Double]) = {
    val timer = new TimeTracker()
    timer.start("total")
    timer.start("init")

    boostingStrategy.assertValid()

    // Initialize gradient boosting parameters
    val numIterations = boostingStrategy.numIterations
    val baseLearners = new Array[DecisionTreeRegressionModel](numIterations)
    val baseLearnerWeights = new Array[Double](numIterations)
    val loss = boostingStrategy.loss
    val learningRate = boostingStrategy.learningRate
    // Prepare strategy for individual trees, which use regression with variance impurity.
    val treeStrategy = boostingStrategy.treeStrategy.copy
    val validationTol = boostingStrategy.validationTol
    treeStrategy.algo = OldAlgo.Regression
    treeStrategy.impurity = OldVariance
    treeStrategy.assertValid()

    // Cache input
    val persistedInput = if (input.getStorageLevel == StorageLevel.NONE) {
      input.persist(StorageLevel.MEMORY_AND_DISK)
      true
    } else {
      false
    }

    // Prepare periodic checkpointers
    val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
      treeStrategy.getCheckpointInterval, input.sparkContext)
    val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
      treeStrategy.getCheckpointInterval, input.sparkContext)

    timer.stop("init")

    logDebug("##########")
    logDebug("Building tree 0")
    logDebug("##########")

    // Initialize tree
    timer.start("building tree 0")
    val firstTree = new DecisionTreeRegressor()
    val firstTreeModel = firstTree.train(input, treeStrategy)
    val firstTreeWeight = 1.0
    baseLearners(0) = firstTreeModel
    baseLearnerWeights(0) = firstTreeWeight

    var predError: RDD[(Double, Double)] =
      computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
    predErrorCheckpointer.update(predError)
    logDebug("error of gbt = " + predError.values.mean())

    // Note: A model of type regression is used since we require raw prediction
    timer.stop("building tree 0")

    var validatePredError: RDD[(Double, Double)] =
      computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
    if (validate) validatePredErrorCheckpointer.update(validatePredError)
    var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
    var bestM = 1

    var m = 1
    var doneLearning = false
    while (m < numIterations && !doneLearning) {
      // Update data with pseudo-residuals
      val data = predError.zip(input).map { case ((pred, _), point) =>
        LabeledPoint(-loss.gradient(pred, point.label), point.features)
      }

      timer.start(s"building tree $m")
      logDebug("###################################################")
      logDebug("Gradient boosting tree iteration " + m)
      logDebug("###################################################")
      val dt = new DecisionTreeRegressor()
      val model = dt.train(data, treeStrategy)
      timer.stop(s"building tree $m")
      // Update partial model
      baseLearners(m) = model
      // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
      //       Technically, the weight should be optimized for the particular loss.
      //       However, the behavior should be reasonable, though not optimal.
      baseLearnerWeights(m) = learningRate

      predError = updatePredictionError(
        input, predError, baseLearnerWeights(m), baseLearners(m), loss)
      predErrorCheckpointer.update(predError)
      logDebug("error of gbt = " + predError.values.mean())

      if (validate) {
        // Stop training early if
        // 1. the reduction in error is less than the validationTol, or
        // 2. the error increases, i.e. the model is overfitting.
        // We want the returned model to correspond to the best validation error.

        validatePredError = updatePredictionError(
          validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
        validatePredErrorCheckpointer.update(validatePredError)
        val currentValidateError = validatePredError.values.mean()
        if (bestValidateError - currentValidateError < validationTol * Math.max(
            currentValidateError, 0.01)) {
          doneLearning = true
        } else if (currentValidateError < bestValidateError) {
          bestValidateError = currentValidateError
          bestM = m + 1
        }
      }
      m += 1
    }

    timer.stop("total")

    logInfo("Internal timing for DecisionTree:")
    logInfo(s"$timer")

    predErrorCheckpointer.deleteAllCheckpoints()
    validatePredErrorCheckpointer.deleteAllCheckpoints()
    if (persistedInput) input.unpersist()

    if (validate) {
      (baseLearners.slice(0, bestM), baseLearnerWeights.slice(0, bestM))
    } else {
      (baseLearners, baseLearnerWeights)
    }
  }
}
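
For context, here is a minimal sketch of how this entry point is driven. This is hypothetical standalone usage: the object is `private[ml]`, so it only compiles from inside `org.apache.spark.ml`, and in the real code path the `BoostingStrategy` comes from `GBTParams.getOldBoostingStrategy` rather than `defaultParams`:

```scala
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree.impl.GradientBoostedTrees
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy}
import org.apache.spark.rdd.RDD

// Hypothetical driver; assumes `trainData` has labels in {0, 1}
// (run() remaps them to {-1, +1} for classification internally).
def fitGbt(trainData: RDD[LabeledPoint])
  : (Array[DecisionTreeRegressionModel], Array[Double]) = {
  val strategy = BoostingStrategy.defaultParams(Algo.Classification)
  strategy.numIterations = 10
  strategy.learningRate = 0.1
  GradientBoostedTrees.run(trainData, strategy)
}
```

The remaining hunks widen visibility from `private[mllib]`/`private[tree]` to `private[spark]` so that this new spark.ml implementation can reuse those mllib internals.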
mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
@@ -74,7 +74,7 @@ import org.apache.spark.storage.StorageLevel
  *
  * TODO: Move this out of MLlib?
  */
-private[mllib] class PeriodicRDDCheckpointer[T](
+private[spark] class PeriodicRDDCheckpointer[T](
     checkpointInterval: Int,
     sc: SparkContext)
   extends PeriodicCheckpointer[RDD[T]](checkpointInterval, sc) {
mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -59,7 +59,7 @@ case class BoostingStrategy @Since("1.4.0") (
    * Check validity of parameters.
    * Throws exception if invalid.
    */
-  private[tree] def assertValid(): Unit = {
+  private[spark] def assertValid(): Unit = {
     treeStrategy.algo match {
       case Classification =>
         require(treeStrategy.numClasses == 2,
mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -133,7 +133,7 @@ class Strategy @Since("1.3.0") (
    * Check validity of parameters.
    * Throws exception if invalid.
    */
-  private[tree] def assertValid(): Unit = {
+  private[spark] def assertValid(): Unit = {
     algo match {
       case Classification =>
         require(numClasses >= 2,
mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
@@ -45,7 +45,7 @@ object AbsoluteError extends Loss {
     if (label - prediction < 0) 1.0 else -1.0
   }
 
-  override private[mllib] def computeError(prediction: Double, label: Double): Double = {
+  override private[spark] def computeError(prediction: Double, label: Double): Double = {
     val err = label - prediction
     math.abs(err)
   }
[Diffs for the remaining changed files are truncated in this view.]

1 comment on commit dafd70f

@vlad17 commented on dafd70f, Aug 5, 2016

Thanks for the change, and especially for the performance test.

Would you mind posting the benchmark test file and invocation that you used as well?
