Skip to content

Commit

Permalink
Refactor Varimp (#289)
Browse files Browse the repository at this point in the history
* refactor varimp

* varimp

* revert test updates

* rebuild
  • Loading branch information
zachmayer authored Jul 26, 2024
1 parent 1c2c15d commit 037d994
Show file tree
Hide file tree
Showing 9 changed files with 3,795 additions and 3,901 deletions.
1 change: 0 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ Imports:
methods,
pbapply,
ggplot2,
digest,
lattice,
gridExtra,
data.table,
Expand Down
4 changes: 3 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,14 @@ importFrom(caret,varImp)
importFrom(data.table,.SD)
importFrom(data.table,as.data.table)
importFrom(data.table,data.table)
importFrom(data.table,dcast.data.table)
importFrom(data.table,merge.data.table)
importFrom(data.table,rbindlist)
importFrom(data.table,set)
importFrom(data.table,setcolorder)
importFrom(data.table,setkeyv)
importFrom(data.table,setnames)
importFrom(data.table,setorderv)
importFrom(digest,digest)
importFrom(ggplot2,aes)
importFrom(ggplot2,autoplot)
importFrom(ggplot2,geom_bar)
Expand Down
97 changes: 43 additions & 54 deletions R/caretEnsemble.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ getMetric <- function(x, metric = NULL, return_sd = FALSE) {
#' @param metric a character string representing the metric to extract.
#' If NULL, each model will return the metric it was trained on.
#' If not NULL, the specified metric must be present for EVERY trained model.
#' @impotFrom data.table data.table setorderv
#' @importFrom data.table data.table setorderv
extractModelMetrics <- function(ensemble, metric = NULL) {
stopifnot(is.caretEnsemble(ensemble))
model_metrics <- data.table(
Expand Down Expand Up @@ -117,28 +117,25 @@ summary.caretEnsemble <- function(object, ...) {
print(extractModelMetrics(object), row.names = FALSE)
}

#' @title Calculate the variable importance of variables in a caret model.
#' @description This function wraps the \code{\link[caret]{varImp}} function
#' from the caret package. It returns a \code{\link[data.table]{data.table}} with importances normalized to sum to 1.
#' @param x a \code{\link[caret]{train}} object
#' @param model_name a character string representing the name of the model
#' @param ... additional arguments passed to \code{\link[caret]{varImp}}
#' @return a \code{\link[data.table]{data.table}} with 2 columns: the variables and their importances.
#' @importFrom caret varImp
#' @importFrom data.table data.table
#' @keywords internal
# This function only gets called once, in varImp.caretEnsemble
varImpFrame <- function(x) {
dat <- do.call(rbind.data.frame, x)
dat <- dat[!duplicated(lapply(dat, summary))]

# Parse frame
dat$id <- row.names(dat)
dat$model <- sub("\\.[^\n]*", "", dat$id)
dat$var <- sub("^[^.]*", "", dat$id)
dat$var <- substr(dat$var, 2L, nchar(dat$var))

# Parse intercept variables
dat$var[grep("Inter", dat$var, fixed = TRUE)] <- "Intercept"
dat$id <- NULL
row.names(dat) <- NULL
dat <- reshape(dat,
direction = "wide", v.names = "Overall",
idvar = "var", timevar = "model"
varImpDataTable <- function(x, model_name, ...) {
imp <- caret::varImp(x, ...)
imp <- imp[["importance"]]
imp <- data.table::data.table(
model_name = model_name,
var = trimws(gsub("[`()]", "", row.names(imp)), which = "both"),
imp = imp[["Overall"]] / sum(imp[["Overall"]])
)
row.names(dat) <- dat[, 1L]
dat[, -1L]
imp
}

#' @title Calculate the variable importance of variables in a caretEnsemble.
Expand All @@ -148,47 +145,39 @@ varImpFrame <- function(x) {
#' importance for each model is calculated and then averaged by the weight of the overall model
#' in the ensembled object.
#' @param object a \code{caretEnsemble} to compute variable importance for.
#' @param ... other arguments to be passed to varImp
#' @return A \code{\link{data.frame}} with one row per variable and one column
#' @param ... additional arguments passed to \code{\link[caret]{varImp}}
#' @importFrom data.table rbindlist dcast.data.table merge.data.table setorderv setcolorder
#' @return A \code{\link[data.table]{data.table}} with one row per variable and one column
#' per model in object
#' @importFrom digest digest
#' @importFrom caret varImp
#' @export
varImp.caretEnsemble <- function(object, ...) {
# Extract and formal individual model importances
# Todo, clean up this code! Make varImp.caretList
coef_importance <- lapply(object$models, caret::varImp)
coef_importance <- lapply(coef_importance, function(x) {
names(x$importance)[1L] <- "Overall"
x$importance <- x$importance[, "Overall", drop = FALSE]
x$importance
})
model_names <- make.names(names(object$models), unique = TRUE, allow_ = TRUE)

# Convert to data.frame
dat <- varImpFrame(coef_importance)
dat[is.na(dat)] <- 0L
names(dat) <- make.names(names(coef_importance))
# Individual model importances
# TODO: varImp.caretList should be a separate function
model_imp <- mapply(varImpDataTable, object$models, model_names, MoreArgs = list(...), SIMPLIFY = FALSE)
model_imp <- data.table::rbindlist(model_imp, fill = TRUE, use.names = TRUE)
model_imp <- data.table::dcast.data.table(model_imp, var ~ model_name, value.var = "imp", fill = 0.0)

# Scale the importances
norm_to_100 <- function(d) d / sum(d) * 100.0
dat <- apply(dat, 2L, norm_to_100)
# Overall importance
ens_imp <- varImpDataTable(object$ens_model, "ensemble")
ens_imp <- data.table::dcast.data.table(ens_imp, model_name ~ var, value.var = "imp", fill = 0.0)

# Calculate overall importance
model_weights <- coef(object$ens_model$finalModel)
# In the case of 2 classes each method will
# have only one coef associated.
# The names of the weights keep the order of the
# models in the ensemble
names(model_weights) <- names(object$models)
model_weights <- model_weights[names(model_weights) %in% names(coef_importance)]
model_weights <- abs(model_weights)
overall <- norm_to_100(apply(dat, 1L, weighted.mean, w = model_weights))
dat <- data.frame(overall = overall, dat)
# Use overall importance to weight individual model importances
model_imp_mat <- as.matrix(model_imp[, model_names, with = FALSE])
ens_imp_mat <- as.matrix(ens_imp[, model_names, with = FALSE])
overall_imp <- data.table::data.table(
var = model_imp[["var"]],
overall = (model_imp_mat %*% t(ens_imp_mat))[, 1L]
)

# Order by overall importance
dat <- dat[order(dat[["overall"]]), ]
# Merge overall importance with individual model importances
imp <- data.table::merge.data.table(overall_imp, model_imp, by = "var", all = TRUE)

dat
# Order and return
data.table::setorderv(imp, "overall", order = -1L)
data.table::setcolorder(imp, c("var", "overall", model_names))
imp
}

#' @title Plot Diagnostics for a caretEnsemble Object
Expand Down
Loading

0 comments on commit 037d994

Please sign in to comment.