Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor inner functions #303

Merged
merged 3 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions R/caretList.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,7 @@ caretList <- function(
if (is.null(metric)) {
metric <- "RMSE"
if (is_class) {
metric <- "Accuracy"
if (is_binary) {
metric <- "ROC"
}
metric <- if (is_binary) "ROC" else "Accuracy"
}
}

Expand All @@ -94,6 +91,10 @@ caretList <- function(
)
}

# ALWAYS save class probs
trControl[["classProbs"]] <- is_class
trControl["savePredictions"] <- "final"

# Capture global arguments for train as a list
# Squish trControl back onto the global arguments list
global_args <- list(...)
Expand Down
202 changes: 106 additions & 96 deletions R/caretPredict.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,42 +9,22 @@
#' @param excluded_class_id an integer indicating the class to exclude. If 0L, no class is excluded
#' @param ... additional arguments to pass to \code{\link[caret]{predict.train}}, if newdata is not NULL
#' @return a data.table
#' @keywords internal
caretPredict <- function(object, newdata = NULL, excluded_class_id = 1L, ...) {
stopifnot(methods::is(object, "train"))

# Extract the model type
model_type <- extractModelType(object, validate_for_stacking = is.null(newdata))
is_class <- isClassifierAndValidate(object, validate_for_stacking = is.null(newdata))

# If newdata is NULL, return the stacked predictions
if (is.null(newdata)) {
# Extract the best tune
a <- data.table::data.table(object$bestTune, key = names(object$bestTune))

# Extract the best predictions
b <- data.table::data.table(object$pred, key = names(object$bestTune))

# Subset pred data to the best tune only
pred <- b[a, ]

# Keep only the predictions
keep_cols <- "pred"
if (model_type == "Classification") {
keep_cols <- levels(object)
}
pred <- pred[, c("rowIndex", keep_cols), drop = FALSE, with = FALSE]

# If we have multiple resamples per row
# e.g. for repeated CV, we need to average the predictions
data.table::setkeyv(pred, "rowIndex")
pred <- pred[, lapply(.SD, mean), by = "rowIndex"]
data.table::setorderv(pred, "rowIndex")

# Remove the rowIndex
data.table::set(pred, j = "rowIndex", value = NULL)
pred <- extractBestPreds(object)
keep_cols <- if (is_class) levels(object) else "pred"
pred <- pred[, keep_cols, with = FALSE]

# Otherwise, predict on newdata
} else {
if (model_type == "Classification") {
if (is_class) {
pred <- caret::predict.train(object, type = "prob", newdata = newdata, ...)
} else {
pred <- caret::predict.train(object, type = "raw", newdata = newdata, ...)
Expand All @@ -62,7 +42,7 @@ caretPredict <- function(object, newdata = NULL, excluded_class_id = 1L, ...) {
# Make sure in both cases we have consitent column names and column order
# Drop the excluded class for classificaiton
stopifnot(nrow(pred) == nrow(newdata))
if (model_type == "Classification") {
if (is_class) {
stopifnot(
ncol(pred) == nlevels(object),
names(pred) == levels(object)
Expand Down Expand Up @@ -110,9 +90,9 @@ caretTrain <- function(local_args, global_args, continue_on_fail = FALSE, trim =
model <- do.call(caret::train, model_args)
}

# Use data.table for stacked predictions
# Only save stacked predictions for the best model
if ("pred" %in% names(model)) {
model[["pred"]] <- data.table::data.table(model[["pred"]])
model[["pred"]] <- extractBestPreds(model)
}

if (trim) {
Expand Down Expand Up @@ -143,6 +123,55 @@ caretTrain <- function(local_args, global_args, continue_on_fail = FALSE, trim =
model
}

#' @title Aggregate mean or first
#' @description For numeric data take the mean. For character data take the first value.
#' @param x a train object
#' @return a data.table::data.table with predictions
#' @keywords internal
aggregate_mean_or_first <- function(x) {
if (is.numeric(x)) {
mean(x)
} else {
x[1L]
}
}

#' @title Extract the best predictions from a train object
#' @description Extract the best predictions from a train object.
#' @param x a train object
#' @return a data.table::data.table with predictions
#' @keywords internal
extractBestPreds <- function(x) {
stopifnot(methods::is(x, "train"))
if (is.null(x$pred)) {
stop("No predictions saved during training. Please set savePredictions = 'final' in trainControl", call. = FALSE)
}
stopifnot(methods::is(x$pred, "data.frame"))

# Extract the best tune
keys <- names(x$bestTune)
best_tune <- data.table::data.table(x$bestTune, key = keys)

# Extract the best predictions
pred <- data.table::data.table(x$pred, key = keys)

# Subset pred data to the best tune only
# Drop rows for other tunes
pred <- pred[best_tune, ]

# If we have multiple resamples per row
# e.g. for repeated CV, we need to average the predictions
keys <- "rowIndex"
data.table::setkeyv(pred, keys)
pred <- pred[, lapply(.SD, aggregate_mean_or_first), by = keys]

# Order results consistently
data.table::setorderv(pred, keys)

# Return
pred
}

#' @title Validate the excluded class
#' @description Helper function to ensure that the excluded level for classification is an integer.
#' Set to 0L to exclude no class.
Expand Down Expand Up @@ -202,45 +231,6 @@ dropExcludedClass <- function(x, all_classes, excluded_class_id) {
x
}

#' @title Extract the model type from a \code{\link[caret]{train}} object
#' @description Extract the model type from a \code{\link[caret]{train}} object.
#' For classification, validates that the model can predict probabilities, and,
#' if stacked predictions are requested, that classProbs = TRUE.
#' @param object a \code{\link[caret]{train}} object
#' @param validate_for_stacking a logical indicating whether to validate the object for stacked predictions
#' @return a character string
#' @keywords internal
extractModelType <- function(object, validate_for_stacking = TRUE) {
stopifnot(methods::is(object, "train"))

# Extract type
model_type <- object$modelType

# Class or reg?
is_class <- model_type == "Classification"

# Validate for predictions
if (is_class && !is.function(object$modelInfo$prob)) {
stop("No probability function found. Re-fit with a method that supports prob.", call. = FALSE)
}
# Validate for stacked predictions
if (validate_for_stacking) {
err <- "Must have savePredictions = 'all', 'final', or TRUE in trainControl to do stacked predictions."
if (is.null(object$control$savePredictions)) {
stop(err, call. = FALSE)
}
if (!object$control$savePredictions %in% c("all", "final", TRUE)) {
stop(err, call. = FALSE)
}
if (is_class && !object$control$classProbs) {
stop("classProbs = FALSE. Re-fit with classProbs = TRUE in trainControl.", call. = FALSE)
}
}

# Return
model_type
}

#' @title S3 definition for concatenating train objects
#'
#' @description take N objects of class train and concatenate into an object of class caretList for future ensembling
Expand Down Expand Up @@ -308,8 +298,8 @@ extractMetric <- function(x, ...) {
#' @param ... ignored
#' If NULL, uses the metric that was used to train the model.
#' @return A numeric representing the metric desired metric.
#' @export
#' @method extractMetric train
#' @export
extractMetric.train <- function(x, metric = NULL, ...) {
if (is.null(metric) || !metric %in% names(x$results)) {
metric <- x$metric
Expand Down Expand Up @@ -340,6 +330,7 @@ extractMetric.train <- function(x, metric = NULL, ...) {
#' used instead.
#' @param x a single caret train object
#' @return Name associated with model
#' @keywords internal
extractModelName <- function(x) {
if (is.list(x$method)) {
checkCustomModel(x$method)$method
Expand All @@ -350,33 +341,52 @@ extractModelName <- function(x) {
}
}

#' @title Extract the best predictions and observations from a train object
#' @description This function extracts the best predictions and observations from a train object
#' and then calculates residuals. It only uses one class for classification models, by default class 2.
#' @param object a \code{train} object
#' @param show_class_id For classification only: which class level to use for residuals
#' @return a data.table::data.table with predictions, observeds, and residuals
extractPredObsResid <- function(object, show_class_id = 2L) {
if (is.null(object$pred)) {
stop("No predictions saved during training. Please set savePredictions = 'final' in trainControl", call. = FALSE)
#' @title Is Classifier
#' @description Check if a model is a classifier.
#' @param model A train object from the caret package.
#' @return A logical indicating whether the model is a classifier.
#' @keywords internal
isClassifier <- function(model) {
stopifnot(methods::is(model, "train") || methods::is(model, "caretStack"))
if (methods::is(model, "train")) {
out <- model$modelType == "Classification"
} else {
out <- model$ens_model$modelType == "Classification"
}
stopifnot(
methods::is(object, "train"),
is.data.frame(object$pred)
)
keep_cols <- c("pred", "obs", "rowIndex")
type <- object$modelType
predobs <- data.table::data.table(object$pred)
if (type == "Classification") {
show_class <- levels(object)[show_class_id]
data.table::set(predobs, j = "pred", value = predobs[[show_class]])
data.table::set(predobs, j = "obs", value = as.integer(predobs[["obs"]] == show_class))
out
}

#' @title Validate a model type
#' @description Validate the model type from a \code{\link[caret]{train}} object.
#' For classification, validates that the model can predict probabilities, and,
#' if stacked predictions are requested, that classProbs = TRUE.
#' @param object a \code{\link[caret]{train}} object
#' @param validate_for_stacking a logical indicating whether to validate the object for stacked predictions
#' @return a logical. TRUE if classifier, otherwise FALSE.
#' @keywords internal
isClassifierAndValidate <- function(object, validate_for_stacking = TRUE) {
stopifnot(methods::is(object, "train"))

is_class <- isClassifier(object)

# Validate for predictions
if (is_class && !is.function(object$modelInfo$prob)) {
stop("No probability function found. Re-fit with a method that supports prob.", call. = FALSE)
}
predobs <- predobs[, keep_cols, with = FALSE]
data.table::setkeyv(predobs, "rowIndex")
predobs <- predobs[, lapply(.SD, mean), by = "rowIndex"]
r <- predobs[["obs"]] - predobs[["pred"]]
data.table::set(predobs, j = "resid", value = r)
data.table::setorderv(predobs, "rowIndex")
predobs
# Validate for stacked predictions
if (validate_for_stacking) {
err <- "Must have savePredictions = 'all', 'final', or TRUE in trainControl to do stacked predictions."
if (is.null(object$control$savePredictions)) {
stop(err, call. = FALSE)
}
if (!object$control$savePredictions %in% c("all", "final", TRUE)) {
stop(err, call. = FALSE)
}
if (is_class && !object$control$classProbs) {
stop("classProbs = FALSE. Re-fit with classProbs = TRUE in trainControl.", call. = FALSE)
}
}

# Return
is_class
}
43 changes: 35 additions & 8 deletions R/caretStack.R
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,10 @@ predict.caretStack <- function(
check_caretStack(object)

# Extract model types
model_type <- object$ens_model$modelType
is_class <- model_type == "Classification"
is_class <- isClassifier(object)

# If the excluded class wasn't set at train time, set it
object <- set_excluded_class_id(object, model_type)
object <- set_excluded_class_id(object, is_class)

# Check return_class_only
if (return_class_only) {
Expand Down Expand Up @@ -224,10 +223,10 @@ check_caretStack <- function(object) {
#' @description Set the excluded class id for a caretStack object
#'
#' @param object a caretStack object
#' @param model_type the model type as a character vector with length 1
#' @param is_class the model type as a logical vector with length 1
#' @keywords internal
set_excluded_class_id <- function(object, model_type) {
if (model_type == "Classification" && is.null(object[["excluded_class_id"]])) {
set_excluded_class_id <- function(object, is_class) {
if (is_class && is.null(object[["excluded_class_id"]])) {
object[["excluded_class_id"]] <- 1L
warning("No excluded_class_id set. Setting to 1L.", call. = FALSE)
}
Expand Down Expand Up @@ -412,6 +411,34 @@ plot.caretStack <- function(x, metric = NULL, ...) {
plt
}

#' @title Extracted stacked residuals for the autoplot
#' @description This function extracts the predictions, observeds, and residuals from a \code{train} object.
#' It uses the object's stacked predictions from cross-validation.
#' @param object a \code{train} object
#' @param show_class_id For classification only: which class level to use for residuals
#' @return a data.table::data.table with predictions, observeds, and residuals
#' @keywords internal
stackedTrainResiduals <- function(object, show_class_id = 2L) {
stopifnot(methods::is(object, "train"))
is_class <- isClassifier(object)
predobs <- extractBestPreds(object)
rowIndex <- predobs[["rowIndex"]]
pred <- predobs[["pred"]]
obs <- predobs[["obs"]]
if (is_class) {
show_class <- levels(object)[show_class_id]
pred <- predobs[[show_class]]
obs <- as.integer(obs == show_class)
}
predobs <- data.table::data.table(
rowIndex = rowIndex,
pred = pred,
obs = obs,
resid = obs - pred
)
predobs
}
zachmayer marked this conversation as resolved.
Show resolved Hide resolved

#' @title Convenience function for more in-depth diagnostic plots of caretStack objects
#' @description This function provides a more robust series of diagnostic plots
#' for a caretEnsemble object.
Expand Down Expand Up @@ -445,7 +472,7 @@ plot.caretStack <- function(x, metric = NULL, ...) {
# https://github.com/thomasp85/patchwork/issues/226 — why we need importFrom patchwork plot_layout
autoplot.caretStack <- function(object, xvars = NULL, show_class_id = 2L, ...) {
stopifnot(methods::is(object, "caretStack"))
ensemble_data <- extractPredObsResid(object$ens_model, show_class_id = show_class_id)
ensemble_data <- stackedTrainResiduals(object$ens_model, show_class_id = show_class_id)

# Performance metrics by model
g1 <- plot(object) + ggplot2::labs(title = "Metric and SD For Component Models")
Expand All @@ -470,7 +497,7 @@ autoplot.caretStack <- function(object, xvars = NULL, show_class_id = 2L, ...) {
ggplot2::theme_bw()

# Disagreement in sub-model residuals
sub_model_data <- lapply(object$models, extractPredObsResid, show_class_id = show_class_id)
sub_model_data <- lapply(object$models, stackedTrainResiduals, show_class_id = show_class_id)
for (model_name in names(sub_model_data)) {
data.table::set(sub_model_data[[model_name]], j = "model", value = model_name)
}
Expand Down
16 changes: 0 additions & 16 deletions R/permutationImportance.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,6 @@ mae <- function(a, b) {
mean(abs(a - b))
}

#' @title Is Classifier
#' @description Check if a model is a classifier.
#' @param model A train object from the caret package.
#' @return A logical indicating whether the model is a classifier.
#' @keywords internal
isClassifier <- function(model) {
stopifnot(methods::is(model, "train") || methods::is(model, "caretStack"))
if (methods::is(model, "train")) {
out <- model$modelType == "Classification"
} else {
out <- model$ens_model$modelType == "Classification"
}
out
}


#' @title Shuffled MAE
#' @description Compute the mean absolute error of a model's predictions when a variable is shuffled.
#' @param original_data A data.table of the original data.
Expand Down
Binary file modified coverage.rds
Binary file not shown.
Loading