Skip to content

Commit

Permalink
Merge pull request #571 from bgreenwell/master
Browse files Browse the repository at this point in the history
New deforest() function for removing trees from a fitted random forest
  • Loading branch information
mnwright authored Nov 12, 2021
2 parents b149f50 + 2cedf12 commit 9aac567
Show file tree
Hide file tree
Showing 5 changed files with 318 additions and 0 deletions.
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
# Generated by roxygen2: do not edit by hand

S3method(deforest,ranger)
S3method(importance,ranger)
S3method(predict,ranger)
S3method(predict,ranger.forest)
S3method(predictions,ranger)
S3method(predictions,ranger.prediction)
S3method(print,deforest.ranger)
S3method(print,ranger)
S3method(print,ranger.forest)
S3method(print,ranger.prediction)
S3method(timepoints,ranger)
S3method(timepoints,ranger.prediction)
export(csrf)
export(deforest)
export(getTerminalNodeIDs)
export(holdoutRF)
export(importance)
Expand Down
175 changes: 175 additions & 0 deletions R/deforest.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
#' Deforesting a random forest
#'
#' The main purpose of this function is to allow for post-processing of
#' ensembles via L1 regularized regression (i.e., the LASSO), as described in
#' Friedman and Popescu (2003). The basic idea is to use the LASSO to
#' post-process the predictions from the individual base learners in an ensemble
#' (i.e., decision trees) in the hopes of producing a much smaller model without
#' sacrificing much in the way of accuracy, and in some cases, improving it.
#' Friedman and Popescu (2003) describe conditions under which tree-based
#' ensembles, like random forest, can potentially benefit from such
#' post-processing (e.g., using shallower trees trained on much smaller samples
#' of the training data without replacement). However, the computational
#' benefits of such post-processing can only be realized if the base learners
#' "zeroed out" by the LASSO can actually be removed from the original ensemble,
#' hence the purpose of this function. A complete example using
#' \code{\link{ranger}} can be found at
#' \url{https://github.com/imbs-hl/ranger/issues/568}.
#'
#' @param object A fitted random forest (e.g., a \code{\link{ranger}}
#' object).
#'
#' @param which.trees Vector giving the indices of the trees to remove.
#'
#' @param warn Logical indicating whether or not to warn users that some of the
#' standard output of a typical \code{\link{ranger}} object are no longer
#' available after deforestation. Default is \code{TRUE}.
#'
#' @param ... Additional (optional) arguments. (Currently ignored.)
#'
#' @return An object of class \code{"deforest.ranger"}; essentially, a
#' \code{\link{ranger}} object with certain components replaced with
#' \code{NA}s (e.g., out-of-bag (OOB) predictions, variable importance scores
#' (if requested), and OOB-based error metrics).
#'
#' @note This function is a generic and can be extended by other packages.
#'
#' @references
#' Friedman, J. and Popescu, B. (2003). Importance sampled learning ensembles,
#' Technical report, Stanford University, Department of Statistics.
#' \url{https://statweb.stanford.edu/~jhf/ftp/isle.pdf}.
#'
#' @rdname deforest
#'
#' @export
#'
#' @author Brandon M. Greenwell
#'
#' @examples
#' ## Example of deforesting a random forest
#' rfo <- ranger(Species ~ ., data = iris, probability = TRUE, num.trees = 100)
#' dfo <- deforest(rfo, which.trees = c(1, 3, 5))
#' dfo # same as `rfo` but with trees 1, 3, and 5 removed
#'
#' ## Sanity check
#' preds.rfo <- predict(rfo, data = iris, predict.all = TRUE)$predictions
#' preds.dfo <- predict(dfo, data = iris, predict.all = TRUE)$predictions
#' identical(preds.rfo[, , -c(1, 3, 5)], y = preds.dfo)
deforest <- function(object, which.trees = NULL, ...) {
  # S3 generic: pick the method from the class of `object` (the first
  # argument); `which.trees` and `...` are passed on to that method unchanged.
  UseMethod("deforest", object)
}


#' @rdname deforest
#'
#' @export
deforest.ranger <- function(object, which.trees = NULL, warn = TRUE, ...) {
  # Remove the trees indexed by `which.trees` from a fitted "ranger" object.
  #
  # Arguments:
  #   object       A "ranger" object (or an already deforested one) that still
  #                carries its `forest` component.
  #   which.trees  Indices of the trees to remove; `NULL` removes nothing.
  #   warn         If `TRUE`, warn that OOB-based components become `NA`.
  #   ...          Currently ignored.
  #
  # Returns the modified object, with class "deforest.ranger" prepended and
  # the components that described the original forest set to `NA`.
  #
  # Raises an error if the object has no stored forest, or if `which.trees`
  # refers to trees that do not exist (or contains `NA`).

  # Without a stored forest there is nothing to remove trees from
  if (is.null(object$forest)) {
    stop("No saved forest in \"ranger\" object; cannot remove trees.",
         call. = FALSE)
  }

  # Resolve `which.trees` against the trees actually present. Indexing
  # `seq_len()` keeps the usual subsetting semantics but turns out-of-range or
  # missing indices into `NA`, which would otherwise pad the per-tree lists
  # with `NULL` entries and silently corrupt the forest.
  num.trees <- length(object$forest$child.nodeIDs)
  drop.ids <- unique(seq_len(num.trees)[which.trees])
  if (anyNA(drop.ids)) {
    stop("`which.trees` must refer to existing trees (1 to ", num.trees, ").",
         call. = FALSE)
  }

  # Warn users about `predictions` and `prediction.error` components
  if (isTRUE(warn)) {
    warning("Many of the components of a typical \"ranger\" object are ",
            "not available after deforestation and are instead replaced with ",
            "`NA` (e.g., out-of-bag (OOB) predictions, variable importance ",
            "scores (if requested), and OOB-based error metrics).",
            call. = FALSE)
  }

  # "Remove trees" by dropping the corresponding entries from every per-tree
  # component that is present (`terminal.class.counts` only exists for
  # probability forests, `chf` only for survival forests)
  per.tree <- c("child.nodeIDs", "split.values", "split.varIDs",
                "terminal.class.counts", "chf")
  for (component in per.tree) {
    if (!is.null(object$forest[[component]])) {
      object$forest[[component]][drop.ids] <- NULL
    }
  }

  # Update `num.trees` components so `predict.ranger()` works
  object$forest$num.trees <- object$num.trees <-
    length(object$forest$child.nodeIDs)

  # Coerce other components to `NA` as needed
  if (!is.null(object$prediction.error)) {
    object$prediction.error <- NA
  }
  if (!is.null(object$predictions)) {  # classification and regression
    object$predictions[] <- NA
  }
  if (!is.null(object$r.squared)) {  # regression
    object$r.squared <- NA
  }
  if (!is.null(object$chf)) {  # survival forests
    object$chf[] <- NA
  }
  if (!is.null(object$survival)) {  # survival forests
    object$survival[] <- NA
  }
  # `isTRUE()` guards against `importance.mode` being `NA` (already deforested
  # object) or missing, either of which would make a bare `if` fail
  if (isTRUE(object$importance.mode != "none")) {  # variable importance
    object$importance.mode <- NA
    object$variable.importance[] <- NA
  }

  # Return "deforested" forest; `unique()` avoids stacking the class twice
  # when an already deforested object is pruned again
  class(object) <- unique(c("deforest.ranger", class(object)))
  object

}


#' Print deforested ranger summary
#'
#' Print basic information about a deforested \code{\link{ranger}} object.
#'
#' @param x A \code{\link{deforest}} object (i.e., an object that inherits from
#' class \code{"deforest.ranger"}).
#'
#' @param ... Further arguments passed to or from other methods.
#'
#' @note Many of the components of a typical \code{\link{ranger}} object are not
#' available after deforestation and are instead replaced with \code{NA} (e.g.,
#' out-of-bag (OOB) predictions, variable importance scores (if requested), and
#' OOB-based error metrics).
#'
#' @seealso \code{\link{deforest}}.
#'
#' @author Brandon M. Greenwell
#'
#' @export
print.deforest.ranger <- function (x, ...) {
  # Print a summary of a deforested "ranger" object: a header, a note about the
  # components replaced with `NA`, the basic forest settings stored in `x`, and
  # the (OOB) error fields as stored in `x`.
  #
  # Arguments:
  #   x    An object inheriting from class "deforest.ranger".
  #   ...  Further arguments; not used by this method.
  #
  # Returns `x` invisibly, as is conventional for print methods.
  cat("Ranger (deforested) result\n\n")
  cat("Note that many of the components of a typical \"ranger\" object are",
      "not available after deforestation and are instead replaced with `NA`",
      "(e.g., out-of-bag (OOB) predictions, variable importance scores (if",
      "requested), and OOB-based error metrics)",
      "\n\n")
  cat("Type: ", x$treetype, "\n")
  cat("Number of trees: ", x$num.trees, "\n")
  cat("Sample size: ", x$num.samples, "\n")
  cat("Number of independent variables: ", x$num.independent.variables, "\n")
  cat("Mtry: ", x$mtry, "\n")
  cat("Target node size: ", x$min.node.size, "\n")
  cat("Variable importance mode: ", x$importance.mode, "\n")
  cat("Splitrule: ", x$splitrule, "\n")
  if (x$treetype == "Survival") {
    cat("Number of unique death times: ", length(x$unique.death.times), "\n")
  }
  if (!is.null(x$splitrule) && x$splitrule == "extratrees" &&
      !is.null(x$num.random.splits)) {
    cat("Number of random splits: ", x$num.random.splits, "\n")
  }

  # The label of the error line depends on the tree type; any other type falls
  # back to the generic label (the unnamed last `switch` argument)
  error.label <- switch(
    x$treetype,
    "Classification" = "OOB prediction error: ",
    "Regression" = "OOB prediction error (MSE): ",
    "Survival" = "OOB prediction error (1-C): ",
    "Probability estimation" = "OOB prediction error (Brier s.): ",
    "OOB prediction error: "
  )
  cat(error.label, x$prediction.error, "\n")
  if (x$treetype == "Regression") {
    cat("R squared (OOB): ", x$r.squared, "\n")
  }

  invisible(x)
}
68 changes: 68 additions & 0 deletions man/deforest.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 29 additions & 0 deletions man/print.deforest.ranger.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

43 changes: 43 additions & 0 deletions tests/testthat/test_deforest.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
library(ranger)
library(survival)
context("ranger_deforest")


test_that("deforest works as expected for probability estimation", {
  # Per-tree predictions of the pruned forest must equal those of the full
  # forest once the removed trees are dropped from the third dimension.
  removed <- c(1, 3, 5)
  full.forest <- ranger(Species ~ ., data = iris, num.trees = 10,
                        probability = TRUE)
  pruned.forest <- deforest(full.forest, which.trees = removed, warn = FALSE)
  full.preds <- predict(full.forest, data = iris,
                        predict.all = TRUE)$predictions
  pruned.preds <- predict(pruned.forest, data = iris,
                          predict.all = TRUE)$predictions
  expect_identical(full.preds[, , -removed], pruned.preds)
})

test_that("deforest works as expected for classification", {
  # Per-tree predictions of the pruned forest must equal the full forest's
  # per-tree predictions with the removed trees' columns dropped.
  removed <- c(1, 3, 5)
  full.forest <- ranger(Species ~ ., data = iris, num.trees = 10)
  pruned.forest <- deforest(full.forest, which.trees = removed, warn = FALSE)
  full.preds <- predict(full.forest, data = iris,
                        predict.all = TRUE)$predictions
  pruned.preds <- predict(pruned.forest, data = iris,
                          predict.all = TRUE)$predictions
  expect_identical(full.preds[, -removed], pruned.preds)
})

test_that("deforest works as expected for regression", {
  # Simulated noisy sine curve; the pruned forest's per-tree predictions must
  # equal the full forest's with the removed trees' columns dropped.
  removed <- c(1, 3, 5)
  n <- 50
  x <- runif(n, min = 0, max = 2*pi)
  y <- sin(x) + rnorm(n, sd = 0.1)
  dat <- data.frame(x = x, y = y)
  full.forest <- ranger(y ~ ., data = dat, num.trees = 10)
  pruned.forest <- deforest(full.forest, which.trees = removed, warn = FALSE)
  full.preds <- predict(full.forest, data = dat,
                        predict.all = TRUE)$predictions
  pruned.preds <- predict(pruned.forest, data = dat,
                          predict.all = TRUE)$predictions
  expect_identical(full.preds[, -removed], pruned.preds)
})

test_that("deforest works as expected for censored outcomes", {
  # Simulated censored data (columns drawn in the order time, status, x); both
  # the per-tree CHF and survival arrays of the pruned forest must equal the
  # full forest's with the removed trees dropped from the third dimension.
  removed <- c(1, 3, 5)
  time <- runif(100, 1, 10)
  status <- rbinom(100, 1, .5)
  x <- rbinom(100, 1, .5)
  dat <- data.frame(time = time, status = status, x = x)
  full.forest <- ranger(Surv(time, status) ~ x, data = dat, num.trees = 10,
                        splitrule = "logrank")
  pruned.forest <- deforest(full.forest, which.trees = removed, warn = FALSE)
  full.preds <- predict(full.forest, data = dat, predict.all = TRUE)
  pruned.preds <- predict(pruned.forest, data = dat, predict.all = TRUE)
  expect_identical(full.preds$chf[, , -removed], pruned.preds$chf)
  expect_identical(full.preds$survival[, , -removed], pruned.preds$survival)
})

0 comments on commit 9aac567

Please sign in to comment.