Skip to content

Commit

Permalink
Let the user choose which class to exclude in careStack train and pre…
Browse files Browse the repository at this point in the history
…dict
  • Loading branch information
antongomez committed Jun 10, 2024
1 parent 284e9ae commit cd5250d
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 12 deletions.
3 changes: 0 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,6 @@ local.properties
*.vspscc
.builds
*.dotCover
# Auxiliar files
checkList.md
caretEnsembleTests.ipynb
## TODO: If you have NuGet Package Restore enabled, uncomment this
#packages/
# Visual C++ cache files
Expand Down
8 changes: 5 additions & 3 deletions R/caretStack.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,13 @@ predict.caretStack <- function(
preds <- predict(object$models, newdata = newdata, na.action = na.action)

if (type == "Classification") {
# Do not include last class asociated probabilities
column_names <- colnames(preds)
# TODO: let the user specify which classes to include or which to exclude
num_classes <- length(levels(object$models[[1]]$pred$obs))
classes_included <- levels(object$models[[1]]$pred$obs)[-num_classes] # exclude last class
if (getMulticlassExcludedLevel() >= 1 && getMulticlassExcludedLevel() <= num_classes) {
classes_included <- levels(object$models[[1]]$pred$obs)[-getMulticlassExcludedLevel()]
} else {
classes_included <- levels(object$models[[1]]$pred$obs)
}
pattern <- paste(classes_included, collapse = "|")
# Remove columns that are associated with the class that was excluded
filtered_column_names <- grep(pattern, column_names, value = TRUE)
Expand Down
88 changes: 83 additions & 5 deletions R/helper_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ setBinaryTargetLevel <- function(level) {
#' @param arg argument to potentially be used as new target level
#' @return Binary target level (as integer equal to 1 or 2)
validateBinaryTargetLevel <- function(arg) {
val <- suppressWarnings(try(as.integer(arg), silent = T))
val <- suppressWarnings(try(as.integer(arg), silent = TRUE))
if (!is.integer(val) || !val %in% c(1L, 2L)) {
stop(paste0(
"Specified target binary class level is not valid. ",
Expand All @@ -48,6 +48,59 @@ validateBinaryTargetLevel <- function(arg) {
val
}

#' @title Return the configured multiclass excluded level
#' @description To train a model using probability outputs
#' provided by other models in a classification problem, it is
#' necessary to exclude one of the classes. By default, this class
#' is assumed to be the first level in an outcome factor,
#' but this setting can be overridden using
#' \code{setMulticlassTargetLevel(3L)} if the classification
#' problem has at least 3 classes.
#' @seealso setMulticlassTargetLevel
#' @return Currently configured multiclass excluded level (as integer)
#' @export
getMulticlassExcludedLevel <- function() {
arg <- getOption("caret.ensemble.multiclass.excluded.level", default = 1L)
validateMulticlassExcludedLevel(arg)
}

#' @title Set the multiclass excluded level
#' @description To train a model using probability outputs
#' provided by other models in a classification problem, it is
#' necessary to exclude one of the classes. By default, this class
#' is assumed to be the first level in an outcome factor,
#' but this setting can be overridden using
#' \code{setMulticlassTargetLevel(3L)} if the classification
#' problem has at least 3 classes.
#' @note Setting this value outside the range between 1 and
#' the number of classes will cause caretStack to train the model
#' with the probabilities associated with ALL classes, leading to
#' potential collinearity issues.
#' @param level an integer to be used as excluded
#' @seealso getMulticlassExcludedLevel
#' @export
setMulticlassExcludedLevel <- function(level) {
level <- validateMulticlassExcludedLevel(level)
options(caret.ensemble.multiclass.excluded.level = level)
}

#' @title Validate arguments given as multiclass excluded level
#' @description Helper function used to ensure that excluded
#' multiclass levels given by clients can be coerced to an integer.
#' @param arg argument to potentially be used as new excluded level
#' @return Multiclass excluded level (as integer)
validateMulticlassExcludedLevel <- function(arg) {
val <- suppressWarnings(try(as.integer(arg), silent = TRUE))
if (!is.integer(val)) {
stop(paste0(
"Specified multiclass excluded level is not valid. ",
"Value should be a integer but '", arg, "' was given ",
"(see caretEnsemble::setMulticlassExcludedLevel for more details)"
))
}
val
}


#####################################################
# Misc. Functions
Expand Down Expand Up @@ -193,6 +246,27 @@ check_bestpreds_preds <- function(modelLibrary) {
return(invisible(NULL))
}

#' @title Check multiclass excluded level
#' @description Verifies that the multiclass excluded level is
#' within the range of the number of classes.
#'
#' @param excluded_level the level to exclude
#' @param num_classes the number of classes
check_multiclass_excluded_level <- function(excluded_level, num_classes) {
if (excluded_level < 1 || excluded_level > num_classes) {
warning(paste0(
"The excluded level must be between 1 and the number of classes (",
num_classes,
"). ",
"Provided value was ",
excluded_level,
". ",
"\nThis value can be changed using setMulticlassExcludedLevel(). ",
"\nAttempting to train a model with all classes included."
))
}
}

#####################################################
# caretEnsemble check functions
#####################################################
Expand All @@ -215,7 +289,7 @@ check_binary_classification <- function(list_of_models) {
# Extraction functions
#####################################################
#' @title Extract the method name associated with a single train object
#' @description Extracts the method name associated with a single train object. Note
#' @description Extracts the method name associated with a single train object. Note
#' that for standard models (i.e. those already prespecified by caret), the
#' "method" attribute on the train object is used directly while for custom
#' models the "method" attribute within the model$modelInfo attribute is
Expand Down Expand Up @@ -321,10 +395,14 @@ makePredObsMatrix <- function(list_of_models) {
# The names of the columns of the final matrix will consist of a
# concatenation of the model name and the class name for
# which the probability is provided.
# Remove at least one class to avoid colineality problems
num_classes <- length(levels(list_of_models[[1]]$pred$obs))
# remove at least one class to avoid colineality problems
# TODO: let the user choose which class to remove
classes_included <- levels(list_of_models[[1]]$pred$obs)[-num_classes] # remove last class
check_multiclass_excluded_level(getMulticlassExcludedLevel(), num_classes)
if (getMulticlassExcludedLevel() >= 1 && getMulticlassExcludedLevel() <= num_classes) {
classes_included <- levels(list_of_models[[1]]$pred$obs)[-getMulticlassExcludedLevel()]
} else {
classes_included <- levels(list_of_models[[1]]$pred$obs)
}
class_model_combinations <- expand.grid(classes_included, names(modelLibrary))
old_column_names <- apply(class_model_combinations, 1, function(x) paste(x[1], x[2], sep = "_"))
column_names <- apply(class_model_combinations, 1, function(x) paste(x[2], x[1], sep = "_"))
Expand Down
39 changes: 38 additions & 1 deletion tests/testthat/test-helper_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,9 @@ test_that("Checks generate errors", {
expect_error(check_caretList_model_types(x))
})

test_that("Check Binary Classification for caretEnsemble work", {
context("Test helper functions for multiclass classification")

test_that("Check errors in caretEnsemble for multiclass classification work", {
skip_on_cran()
skip_if_not_installed("rpart")
data(iris)
Expand All @@ -216,3 +218,38 @@ test_that("Check Binary Classification for caretEnsemble work", {
expect_true(is.null(check_binary_classification(list("string"))))
expect_true(is.null(check_binary_classification(iris)))
})

test_that("Configuration function for excluded level work", {
expect_warning(check_multiclass_excluded_level(4, 3))
expect_warning(check_multiclass_excluded_level(0, 3))
expect_true(is.null(check_multiclass_excluded_level(3, 3)))
expect_true(is.null(check_multiclass_excluded_level(1, 3)))

data(iris)
myControl <- trainControl(
method = "cv", number = 5,
savePredictions = "final", index = createResample(iris[, 5], 5),
classProbs = TRUE
)
model_list <- caretList(
x = iris[, -5],
y = iris[, 5],
methodList = c("rpart", "glmnet"),
trControl = myControl
)

setMulticlassExcludedLevel(0)
expect_warning(caretStack(model_list, method = "knn"))
setMulticlassExcludedLevel(4)
expect_warning(caretStack(model_list, method = "knn"))

# Check if we are actually excluding level 1 (setosa)
setMulticlassExcludedLevel(1)
classes <- levels(iris[, 5])[-1]
models <- c("rpart", "glmnet")
class_model_combinations <- expand.grid(classes, models)
varImp_rownames <- apply(class_model_combinations, 1, function(x) paste(x[2], x[1], sep = "_"))

model_stack <- caretStack(model_list, method = "knn")
expect_identical(rownames(varImp(model_stack$ens_model)$importance), varImp_rownames)
})
4 changes: 4 additions & 0 deletions tests/testthat/test-multiclass.R
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ test_that("We can make a confusion matrix", {
expect_true(cm$overall["Accuracy"] > 0.9)
})

#############################################################################
context("caretEnsemble not avaible for multiclass problems")
#############################################################################

test_that("Multiclass is not supported for caretEnsemble", {
data(iris)
data(models.class)
Expand Down

0 comments on commit cd5250d

Please sign in to comment.