Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Varimp #289

Merged
merged 4 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ Imports:
methods,
pbapply,
ggplot2,
digest,
lattice,
gridExtra,
data.table,
Expand Down
4 changes: 3 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,14 @@ importFrom(caret,varImp)
importFrom(data.table,.SD)
importFrom(data.table,as.data.table)
importFrom(data.table,data.table)
importFrom(data.table,dcast.data.table)
importFrom(data.table,merge.data.table)
importFrom(data.table,rbindlist)
importFrom(data.table,set)
importFrom(data.table,setcolorder)
importFrom(data.table,setkeyv)
importFrom(data.table,setnames)
importFrom(data.table,setorderv)
importFrom(digest,digest)
importFrom(ggplot2,aes)
importFrom(ggplot2,autoplot)
importFrom(ggplot2,geom_bar)
Expand Down
97 changes: 43 additions & 54 deletions R/caretEnsemble.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ getMetric <- function(x, metric = NULL, return_sd = FALSE) {
#' @param metric a character string representing the metric to extract.
#' If NULL, each model will return the metric it was trained on.
#' If not NULL, the specified metric must be present for EVERY trained model.
#' @impotFrom data.table data.table setorderv
#' @importFrom data.table data.table setorderv
extractModelMetrics <- function(ensemble, metric = NULL) {
stopifnot(is.caretEnsemble(ensemble))
model_metrics <- data.table(
Expand Down Expand Up @@ -117,28 +117,25 @@ summary.caretEnsemble <- function(object, ...) {
print(extractModelMetrics(object), row.names = FALSE)
}

#' @title Calculate the variable importance of variables in a caret model.
#' @description This function wraps the \code{\link[caret]{varImp}} function
#' from the caret package. It returns a \code{\link[data.table]{data.table}} with importances normalized to sum to 1.
#' @param x a \code{\link[caret]{train}} object
#' @param model_name a character string representing the name of the model
#' @param ... additional arguments passed to \code{\link[caret]{varImp}}
#' @return a \code{\link[data.table]{data.table}} with 2 columns: the variables and their importances.
#' @importFrom caret varImp
#' @importFrom data.table data.table
#' @keywords internal
# This function only gets called once, in varImp.caretEnsemble
varImpFrame <- function(x) {
dat <- do.call(rbind.data.frame, x)
dat <- dat[!duplicated(lapply(dat, summary))]

# Parse frame
dat$id <- row.names(dat)
dat$model <- sub("\\.[^\n]*", "", dat$id)
dat$var <- sub("^[^.]*", "", dat$id)
dat$var <- substr(dat$var, 2L, nchar(dat$var))

# Parse intercept variables
dat$var[grep("Inter", dat$var, fixed = TRUE)] <- "Intercept"
dat$id <- NULL
row.names(dat) <- NULL
dat <- reshape(dat,
direction = "wide", v.names = "Overall",
idvar = "var", timevar = "model"
varImpDataTable <- function(x, model_name, ...) {
imp <- caret::varImp(x, ...)
imp <- imp[["importance"]]
imp <- data.table::data.table(
model_name = model_name,
var = trimws(gsub("[`()]", "", row.names(imp)), which = "both"),
imp = imp[["Overall"]] / sum(imp[["Overall"]])
)
row.names(dat) <- dat[, 1L]
dat[, -1L]
imp
}

#' @title Calculate the variable importance of variables in a caretEnsemble.
Expand All @@ -148,47 +145,39 @@ varImpFrame <- function(x) {
#' importance for each model is calculated and then averaged by the weight of the overall model
#' in the ensembled object.
#' @param object a \code{caretEnsemble} to compute variable importances for.
#' @param ... other arguments to be passed to varImp
#' @return A \code{\link{data.frame}} with one row per variable and one column
#' @param ... additional arguments passed to \code{\link[caret]{varImp}}
#' @importFrom data.table rbindlist dcast.data.table merge.data.table setorderv setcolorder
#' @return A \code{\link[data.table]{data.table}} with one row per variable and one column
#' per model in object
#' @importFrom digest digest
#' @importFrom caret varImp
#' @export
varImp.caretEnsemble <- function(object, ...) {
# Extract and format individual model importances
# Todo, clean up this code! Make varImp.caretList
coef_importance <- lapply(object$models, caret::varImp)
coef_importance <- lapply(coef_importance, function(x) {
names(x$importance)[1L] <- "Overall"
x$importance <- x$importance[, "Overall", drop = FALSE]
x$importance
})
model_names <- make.names(names(object$models), unique = TRUE, allow_ = TRUE)

# Convert to data.frame
dat <- varImpFrame(coef_importance)
dat[is.na(dat)] <- 0L
names(dat) <- make.names(names(coef_importance))
# Individual model importances
# TODO: varImp.caretList should be a separate function
model_imp <- mapply(varImpDataTable, object$models, model_names, MoreArgs = list(...), SIMPLIFY = FALSE)
model_imp <- data.table::rbindlist(model_imp, fill = TRUE, use.names = TRUE)
model_imp <- data.table::dcast.data.table(model_imp, var ~ model_name, value.var = "imp", fill = 0.0)

# Scale the importances
norm_to_100 <- function(d) d / sum(d) * 100.0
dat <- apply(dat, 2L, norm_to_100)
# Overall importance
ens_imp <- varImpDataTable(object$ens_model, "ensemble")
ens_imp <- data.table::dcast.data.table(ens_imp, model_name ~ var, value.var = "imp", fill = 0.0)

# Calculate overall importance
model_weights <- coef(object$ens_model$finalModel)
# In the case of 2 classes each method will
# have only one coef associated.
# The names of the weights keep the order of the
# models in the ensemble
names(model_weights) <- names(object$models)
model_weights <- model_weights[names(model_weights) %in% names(coef_importance)]
model_weights <- abs(model_weights)
overall <- norm_to_100(apply(dat, 1L, weighted.mean, w = model_weights))
dat <- data.frame(overall = overall, dat)
# Use overall importance to weight individual model importances
model_imp_mat <- as.matrix(model_imp[, model_names, with = FALSE])
ens_imp_mat <- as.matrix(ens_imp[, model_names, with = FALSE])
overall_imp <- data.table::data.table(
var = model_imp[["var"]],
overall = (model_imp_mat %*% t(ens_imp_mat))[, 1L]
)

# Order by overall importance
dat <- dat[order(dat[["overall"]]), ]
# Merge overall importance with individual model importances
imp <- data.table::merge.data.table(overall_imp, model_imp, by = "var", all = TRUE)

dat
# Order and return
data.table::setorderv(imp, "overall", order = -1L)
data.table::setcolorder(imp, c("var", "overall", model_names))
imp
}

#' @title Plot Diagnostics for a caretEnsemble Object
Expand Down
Loading
Loading