Skip to content

Commit

Permalink
Refactor and standardize predcitions (#290)
Browse files Browse the repository at this point in the history
* reorg code files

* refactor preds/stack to share common logic

* continue refactor

* fix tests

* fix most tests except ensemble accuracy

* fix-ensemble-tests

* rebuild

* add failing test

* try to fix stacked stack

* try to make importance and SE worl

* fix some tests, add a new mode

* fix class selection

* tests passgit add -Agit add -Agit add -A

* work on vignettes

* vignettes

* 100 coverage

* rebuild

* Fix PR feedback
  • Loading branch information
zachmayer authored Jul 29, 2024
1 parent 037d994 commit 6ff11e0
Show file tree
Hide file tree
Showing 45 changed files with 6,436 additions and 6,621 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,5 @@ pip-log.txt
caretEnsemble_test_plots.png
doc
Meta
/doc/
/Meta/
3 changes: 2 additions & 1 deletion .lintr
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@ linters: linters_with_tags(
"robustness",
"style",
"tidy_design"
),
), # TODO: add todo linter
return_linter(),
object_overwrite_linter(),
object_length_linter(),
line_length_linter(120),
object_usage_linter(),
object_name_linter(),
cyclocomp_linter(17L), # predict.caretStack() has a cyclomatic complexity of 17
todo_comment_linter = NULL, # TODO refactor
implicit_assignment_linter = NULL, # TODO refactor
expect_identical_linter = NULL, # TODO big refactor
Expand Down
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Package: caretEnsemble
Type: Package
Title: Ensembles of Caret Models
Version: 2.0.4
Date: 2024-06-25
Date: 2024-07-27
Authors@R: c(person(c("Zachary", "A."), "Deane-Mayer", role = c("aut", "cre"),
email = "zach.mayer@gmail.com"),
person(c("Jared", "E."), "Knowles", role=c("aut"),
Expand Down
10 changes: 8 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Makefile for R project

.PHONY: all install-deps install document update-test-fixtures test coverage-test coverage check-cran fix-style lint spell clean
.PHONY: all install-deps install document update-test-fixtures test coverage-test coverage check-cran fix-style lint spell build-vignettes clean

# Default target
all: clean fix-style document install lint spell test check-cran coverage
all: clean fix-style document install lint spell test build-vignettes check-cran coverage

# Install dependencies
install-deps:
Expand Down Expand Up @@ -88,6 +88,10 @@ spell:
}; \
"

# Build vignettes
build-vignettes:
Rscript -e "devtools::build_vignettes()"

# Clean up generated files
clean:
rm -rf *.Rcheck
Expand All @@ -99,3 +103,5 @@ clean:
rm -f .Rhistory
rm -rf lib/
rm -f caretEnsemble_test_plots.png
Rscript -e "devtools::clean_vignettes()"
Rscript -e "devtools::clean_dll()"
1 change: 0 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ export(caretEnsemble)
export(caretList)
export(caretModelSpec)
export(caretStack)
export(check_binary_classification)
export(getMetric)
export(is.caretEnsemble)
export(is.caretList)
Expand Down
2 changes: 2 additions & 0 deletions R/caretEnsemble-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#' @importFrom graphics plot
#' @importFrom methods is
#' @importFrom stats coef median model.frame model.response predict qnorm reshape resid residuals weighted.mean weights
#' @importFrom data.table .SD
#' @importFrom rlang .data
"_PACKAGE"

#' @title caretList of classification models
Expand Down
62 changes: 42 additions & 20 deletions R/caretEnsemble.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
#' @title Check binary classification
#' @description Check that the problem is a binary classification problem
#'
#' @param list_of_models a list of caret models to check
#' @keywords internal
check_binary_classification <- function(list_of_models) {
if (is.list(list_of_models) && length(list_of_models) > 1L) {
lapply(list_of_models, function(x) {
# avoid regression models
if (is(x, "train") && !is.null(x$pred$obs) && is.factor(x$pred$obs) && nlevels(x$pred$obs) > 2L) {
stop("caretEnsemble only supports binary classification problems")
}
})
}
invisible(NULL)
}

#' @title Combine several predictive models via weights
#'
#' @description Find a good linear combination of several classification or regression models,
Expand Down Expand Up @@ -130,10 +147,12 @@ summary.caretEnsemble <- function(object, ...) {
varImpDataTable <- function(x, model_name, ...) {
imp <- caret::varImp(x, ...)
imp <- imp[["importance"]]
# Normalize to sum to 1, by class or overall
imp_by_class <- data.table::as.data.table(lapply(imp, function(x) x / sum(x)))
imp <- data.table::data.table(
model_name = model_name,
var = trimws(gsub("[`()]", "", row.names(imp)), which = "both"),
imp = imp[["Overall"]] / sum(imp[["Overall"]])
imp_by_class
)
imp
}
Expand All @@ -157,11 +176,11 @@ varImp.caretEnsemble <- function(object, ...) {
# TODO: varImp.caretList should be a separate function
model_imp <- mapply(varImpDataTable, object$models, model_names, MoreArgs = list(...), SIMPLIFY = FALSE)
model_imp <- data.table::rbindlist(model_imp, fill = TRUE, use.names = TRUE)
model_imp <- data.table::dcast.data.table(model_imp, var ~ model_name, value.var = "imp", fill = 0.0)
model_imp <- data.table::dcast.data.table(model_imp, var ~ model_name, value.var = "Overall", fill = 0.0)

# Overall importance
ens_imp <- varImpDataTable(object$ens_model, "ensemble")
ens_imp <- data.table::dcast.data.table(ens_imp, model_name ~ var, value.var = "imp", fill = 0.0)
ens_imp <- data.table::dcast.data.table(ens_imp, model_name ~ var, value.var = "Overall", fill = 0.0)

# Use overall importance to weight individual model importances
model_imp_mat <- as.matrix(model_imp[, model_names, with = FALSE])
Expand Down Expand Up @@ -227,21 +246,25 @@ plot.caretEnsemble <- function(x, ...) {
#' @return a data.table with predictions, observeds, and residuals
#' @importFrom data.table data.table
extractPredObsResid <- function(object, show_class_id = 2L) {
stopifnot(is(object, "train"))
stopifnot(
is(object, "train"),
is.data.frame(object$pred)
)
keep_cols <- c("pred", "obs", "rowIndex")
type <- object$modelType
predobs <- extractBestPredsAndObs(list(object))
pred <- predobs$pred
obs <- predobs$obs
id <- predobs$rowIndex
if (type == "Regression") {
pred <- pred[[1L]]
} else {
predobs <- data.table(object$pred)
if (type == "Classification") {
show_class <- levels(object)[show_class_id]
pred <- pred[[show_class]]
obs <- as.integer(obs == show_class)
set(predobs, j = "pred", value = predobs[[show_class]])
set(predobs, j = "obs", value = as.integer(predobs[["obs"]] == show_class))
}
out <- data.table::data.table(pred, obs, resid = obs - pred, id)
out
predobs <- predobs[, keep_cols, with = FALSE]
data.table::setkeyv(predobs, "rowIndex")
predobs <- predobs[, lapply(.SD, mean), by = "rowIndex"]
r <- predobs[["obs"]] - predobs[["pred"]]
data.table::set(predobs, j = "resid", value = r)
data.table::setorderv(predobs, "rowIndex")
predobs
}

#' @title Convenience function for more in-depth diagnostic plots of caretEnsemble objects
Expand Down Expand Up @@ -279,7 +302,6 @@ extractPredObsResid <- function(object, show_class_id = 2L) {
autoplot.caretEnsemble <- function(object, xvars = NULL, show_class_id = 2L, ...) {
stopifnot(is(object, "caretEnsemble"))
ensemble_data <- extractPredObsResid(object$ens_model, show_class_id = show_class_id)
data.table::setkeyv(ensemble_data, "id")

# Performance metrics by model
g1 <- plot(object) + labs(title = "Metric and SD For Component Models")
Expand Down Expand Up @@ -314,7 +336,7 @@ autoplot.caretEnsemble <- function(object, xvars = NULL, show_class_id = 2L, ...
ymax = max(.SD[["resid"]]),
yavg = median(.SD[["resid"]]),
yhat = .SD[["pred"]][1L]
), by = "id"]
), by = "rowIndex"]
g4 <- ggplot2::ggplot(sub_model_summary, ggplot2::aes(
x = .data[["yhat"]],
y = .data[["yavg"]]
Expand All @@ -341,8 +363,8 @@ autoplot.caretEnsemble <- function(object, xvars = NULL, show_class_id = 2L, ...
xvars <- setdiff(xvars, c(".outcome", ".weights", "(Intercept)"))
xvars <- sample(xvars, 2L)
}
data.table::set(x_data, j = "id", value = seq_len(nrow(x_data)))
plotdf <- merge(ensemble_data, x_data, by = "id")
data.table::set(x_data, j = "rowIndex", value = seq_len(nrow(x_data)))
plotdf <- merge(ensemble_data, x_data, by = "rowIndex")
g5 <- ggplot2::ggplot(plotdf, ggplot2::aes(.data[[xvars[1L]]], .data[["resid"]])) +
ggplot2::geom_point() +
ggplot2::geom_smooth(se = FALSE) +
Expand All @@ -357,5 +379,5 @@ autoplot.caretEnsemble <- function(object, xvars = NULL, show_class_id = 2L, ...
ggplot2::scale_y_continuous("Residuals") +
ggplot2::labs(title = paste0("Residuals Against ", xvars[2L])) +
ggplot2::theme_bw()
suppressMessages(gridExtra::grid.arrange(g1, g2, g3, g4, g5, g6, ncol = 2L))
suppressWarnings(suppressMessages(gridExtra::grid.arrange(g1, g2, g3, g4, g5, g6, ncol = 2L)))
}
94 changes: 60 additions & 34 deletions R/caretList.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,22 @@ tuneCheck <- function(x) {
x
}

#' @title Validate a custom caret model info list
#' @description Currently, this only ensures that all model info lists
#' were also assigned a "method" attribute for consistency with usage
#' of non-custom models
#' @param x a model info list (e.g. \code{getModelInfo("rf", regex=F)\[[1]]})
#' @return validated model info list (i.e. x)
checkCustomModel <- function(x) {
if (is.null(x$method)) {
stop(paste(
"Custom models must be defined with a \"method\" attribute containing the name",
"by which that model should be referenced. Example: my.glm.model$method <- \"custom_glm\""
))
}
x
}

#' @title Check that the methods supplied by the user are valid caret methods
#' @description This function uses modelLookup from caret to ensure the list of
#' methods supplied by the user are all models caret can fit.
Expand Down Expand Up @@ -276,7 +292,7 @@ as.caretList <- function(object) {
if (is.null(object)) {
stop("object is null")
}
UseMethod("as.caretList")
UseMethod("as.caretList", object)
}

#' @title Convert object to caretList object - For Future Use
Expand Down Expand Up @@ -345,51 +361,61 @@ as.caretList.list <- function(object) {
#' @importFrom data.table as.data.table setnames
#' @export
#' @method predict caretList
predict.caretList <- function(object, newdata = NULL, verbose = FALSE, excluded_class_id = 0L, ...) {
predict.caretList <- function(object, newdata = NULL, verbose = FALSE, excluded_class_id = 1L, ...) {
stopifnot(is.caretList(object))

# Decided whether to be verbose or quiet
apply_fun <- lapply
if (verbose) {
apply_fun <- pbapply::pblapply
}

# Check data
if (is.null(newdata)) {
train_data_nulls <- sapply((object), function(x) is.null(x[["trainingData"]]))
if (any(train_data_nulls)) {
stop("newdata is NULL and trainingData is NULL for some models. Use newdata or retrain with returnData=TRUE.")
}
}

# Loop over the models and make predictions
preds <- apply_fun(object, function(x) {
type <- x$modelType

# predict for class
if (type == "Classification") {
# use caret::levels.train to extract the levels of the target from each model
# and then drop the excluded class if needed
pred <- caret::predict.train(x, type = "prob", newdata = newdata, ...)
pred <- data.table::as.data.table(pred)
pred <- dropExcludedClass(pred, all_classes = levels(x), excluded_class_id = excluded_class_id)

# predict for reg
} else if (type == "Regression") {
pred <- caret::predict.train(x, type = "raw", newdata = newdata)
pred <- data.table::as.data.table(pred)

# Error
preds <- apply_fun(object, caretPredict, newdata = newdata, excluded_class_id = excluded_class_id, ...)
stopifnot(
is.list(preds),
length(preds) >= 1L,
length(preds) == length(object),
sapply(preds, data.table::is.data.table)
)

# All preds must have the same number of rows.
# We allow different columns, and even different column names!
# E.g. you could mix classification and regression models
# caretPredict will aggregate multiple predictions for the same row (e.g. repeated CV)
# caretPredict will make sure the rows are sorted by the original row order
pred_rows <- sapply(preds, nrow)
stopifnot(pred_rows == pred_rows[1L])

# Name the predictions
for (i in seq_along(preds)) {
p <- preds[[i]]
model_name <- names(object)[i]
if (ncol(p) == 1L) {
# For a single column, name it after the model (e.g. regression or binary with an excluded class)
setnames(p, names(p), model_name)
} else {
stop(paste("Unknown model type:", type))
# For multiple columns, name them including the model (e.g. multiclass)
setnames(p, names(p), paste(model_name, names(p), sep = "_"))
}
}
preds <- unname(preds)

# Return
pred
})

# Turn a list of data tables into one data.table
# Note that data.table will name the columns based off the names of the list and the names of each data.table
# Combine the predictions into a single data.table
preds <- data.table::as.data.table(preds)

stopifnot(
!is.null(names(preds)),
length(dim(preds)) == 2L
)
all_regression <- all(sapply(object, function(x) x$modelType == "Regression"))
if (all_regression) {
stopifnot(
length(names(preds)) == length(object),
names(preds) == names(object)
)
}

# Return
preds
}
Loading

0 comments on commit 6ff11e0

Please sign in to comment.