diff --git a/.gitignore b/.gitignore index b497da22..4cb362ad 100644 --- a/.gitignore +++ b/.gitignore @@ -49,9 +49,6 @@ local.properties *.vspscc .builds *.dotCover -# Auxiliar files -checkList.md -caretEnsembleTests.ipynb ## TODO: If you have NuGet Package Restore enabled, uncomment this #packages/ # Visual C++ cache files diff --git a/R/caretStack.R b/R/caretStack.R index d5d98f6e..45088f9c 100644 --- a/R/caretStack.R +++ b/R/caretStack.R @@ -75,11 +75,13 @@ predict.caretStack <- function( preds <- predict(object$models, newdata = newdata, na.action = na.action) if (type == "Classification") { - # Do not include last class asociated probabilities column_names <- colnames(preds) - # TODO: let the user specify which classes to include or which to exclude num_classes <- length(levels(object$models[[1]]$pred$obs)) - classes_included <- levels(object$models[[1]]$pred$obs)[-num_classes] # exclude last class + if (getMulticlassExcludedLevel() >= 1 && getMulticlassExcludedLevel() <= num_classes) { + classes_included <- levels(object$models[[1]]$pred$obs)[-getMulticlassExcludedLevel()] + } else { + classes_included <- levels(object$models[[1]]$pred$obs) + } pattern <- paste(classes_included, collapse = "|") # Remove columns that are associated with the class that was excluded filtered_column_names <- grep(pattern, column_names, value = TRUE) diff --git a/R/helper_functions.R b/R/helper_functions.R index 3f4d9e7a..ad0eed95 100644 --- a/R/helper_functions.R +++ b/R/helper_functions.R @@ -37,7 +37,7 @@ setBinaryTargetLevel <- function(level) { #' @param arg argument to potentially be used as new target level #' @return Binary target level (as integer equal to 1 or 2) validateBinaryTargetLevel <- function(arg) { - val <- suppressWarnings(try(as.integer(arg), silent = T)) + val <- suppressWarnings(try(as.integer(arg), silent = TRUE)) if (!is.integer(val) || !val %in% c(1L, 2L)) { stop(paste0( "Specified target binary class level is not valid. ", @@ -48,6 +48,59 @@ validateBinaryTargetLevel <- function(arg) { val } +#' @title Return the configured multiclass excluded level +#' @description To train a model using probability outputs +#' provided by other models in a classification problem, it is +#' necessary to exclude one of the classes. By default, this class +#' is assumed to be the first level in an outcome factor, +#' but this setting can be overridden using +#' \code{setMulticlassTargetLevel(3L)} if the classification +#' problem has at least 3 classes. +#' @seealso setMulticlassTargetLevel +#' @return Currently configured multiclass excluded level (as integer) +#' @export +getMulticlassExcludedLevel <- function() { + arg <- getOption("caret.ensemble.multiclass.excluded.level", default = 1L) + validateMulticlassExcludedLevel(arg) +} + +#' @title Set the multiclass excluded level +#' @description To train a model using probability outputs +#' provided by other models in a classification problem, it is +#' necessary to exclude one of the classes. By default, this class +#' is assumed to be the first level in an outcome factor, +#' but this setting can be overridden using +#' \code{setMulticlassTargetLevel(3L)} if the classification +#' problem has at least 3 classes. +#' @note Setting this value outside the range between 1 and +#' the number of classes will cause caretStack to train the model +#' with the probabilities associated with ALL classes, leading to +#' potential collinearity issues. +#' @param level an integer to be used as excluded +#' @seealso getMulticlassExcludedLevel +#' @export +setMulticlassExcludedLevel <- function(level) { + level <- validateMulticlassExcludedLevel(level) + options(caret.ensemble.multiclass.excluded.level = level) +} + +#' @title Validate arguments given as multiclass excluded level +#' @description Helper function used to ensure that excluded +#' multiclass levels given by clients can be coerced to an integer. +#' @param arg argument to potentially be used as new excluded level +#' @return Multiclass excluded level (as integer) +validateMulticlassExcludedLevel <- function(arg) { + val <- suppressWarnings(try(as.integer(arg), silent = TRUE)) + if (!is.integer(val)) { + stop(paste0( + "Specified multiclass excluded level is not valid. ", + "Value should be a integer but '", arg, "' was given ", + "(see caretEnsemble::setMulticlassExcludedLevel for more details)" + )) + } + val +} + ##################################################### # Misc. Functions @@ -193,6 +246,27 @@ check_bestpreds_preds <- function(modelLibrary) { return(invisible(NULL)) } +#' @title Check multiclass excluded level +#' @description Verifies that the multiclass excluded level is +#' within the range of the number of classes. +#' +#' @param excluded_level the level to exclude +#' @param num_classes the number of classes +check_multiclass_excluded_level <- function(excluded_level, num_classes) { + if (excluded_level < 1 || excluded_level > num_classes) { + warning(paste0( + "The excluded level must be between 1 and the number of classes (", + num_classes, + "). ", + "Provided value was ", + excluded_level, + ". ", + "\nThis value can be changed using setMulticlassExcludedLevel(). ", + "\nAttempting to train a model with all classes included." + )) + } +} + ##################################################### # caretEnsemble check functions ##################################################### @@ -215,7 +289,7 @@ check_binary_classification <- function(list_of_models) { # Extraction functions ##################################################### #' @title Extract the method name associated with a single train object -#' @description Extracts the method name associated with a single train object. Note +#' @description Extracts the method name associated with a single train object. Note #' that for standard models (i.e. those already prespecified by caret), the #' "method" attribute on the train object is used directly while for custom #' models the "method" attribute within the model$modelInfo attribute is @@ -321,10 +395,14 @@ makePredObsMatrix <- function(list_of_models) { # The names of the columns of the final matrix will consist of a # concatenation of the model name and the class name for # which the probability is provided. + # Remove at least one class to avoid colineality problems num_classes <- length(levels(list_of_models[[1]]$pred$obs)) - # remove at least one class to avoid colineality problems - # TODO: let the user choose which class to remove - classes_included <- levels(list_of_models[[1]]$pred$obs)[-num_classes] # remove last class + check_multiclass_excluded_level(getMulticlassExcludedLevel(), num_classes) + if (getMulticlassExcludedLevel() >= 1 && getMulticlassExcludedLevel() <= num_classes) { + classes_included <- levels(list_of_models[[1]]$pred$obs)[-getMulticlassExcludedLevel()] + } else { + classes_included <- levels(list_of_models[[1]]$pred$obs) + } class_model_combinations <- expand.grid(classes_included, names(modelLibrary)) old_column_names <- apply(class_model_combinations, 1, function(x) paste(x[1], x[2], sep = "_")) column_names <- apply(class_model_combinations, 1, function(x) paste(x[2], x[1], sep = "_")) diff --git a/tests/testthat/test-helper_functions.R b/tests/testthat/test-helper_functions.R index f855e53b..b34a0cd8 100644 --- a/tests/testthat/test-helper_functions.R +++ b/tests/testthat/test-helper_functions.R @@ -194,7 +194,9 @@ test_that("Checks generate errors", { expect_error(check_caretList_model_types(x)) }) -test_that("Check Binary Classification for caretEnsemble work", { +context("Test helper functions for multiclass classification") + +test_that("Check errors in caretEnsemble for multiclass classification work", { skip_on_cran() skip_if_not_installed("rpart") data(iris) @@ -216,3 +218,38 @@ test_that("Check Binary Classification for caretEnsemble work", { expect_true(is.null(check_binary_classification(list("string")))) expect_true(is.null(check_binary_classification(iris))) }) + +test_that("Configuration function for excluded level work", { + expect_warning(check_multiclass_excluded_level(4, 3)) + expect_warning(check_multiclass_excluded_level(0, 3)) + expect_true(is.null(check_multiclass_excluded_level(3, 3))) + expect_true(is.null(check_multiclass_excluded_level(1, 3))) + + data(iris) + myControl <- trainControl( + method = "cv", number = 5, + savePredictions = "final", index = createResample(iris[, 5], 5), + classProbs = TRUE + ) + model_list <- caretList( + x = iris[, -5], + y = iris[, 5], + methodList = c("rpart", "glmnet"), + trControl = myControl + ) + + setMulticlassExcludedLevel(0) + expect_warning(caretStack(model_list, method = "knn")) + setMulticlassExcludedLevel(4) + expect_warning(caretStack(model_list, method = "knn")) + + # Check if we are actually excluding level 1 (setosa) + setMulticlassExcludedLevel(1) + classes <- levels(iris[, 5])[-1] + models <- c("rpart", "glmnet") + class_model_combinations <- expand.grid(classes, models) + varImp_rownames <- apply(class_model_combinations, 1, function(x) paste(x[2], x[1], sep = "_")) + + model_stack <- caretStack(model_list, method = "knn") + expect_identical(rownames(varImp(model_stack$ens_model)$importance), varImp_rownames) +}) diff --git a/tests/testthat/test-multiclass.R b/tests/testthat/test-multiclass.R index 79b7a604..d38b70c8 100644 --- a/tests/testthat/test-multiclass.R +++ b/tests/testthat/test-multiclass.R @@ -184,6 +184,10 @@ test_that("We can make a confusion matrix", { expect_true(cm$overall["Accuracy"] > 0.9) }) +############################################################################# +context("caretEnsemble not avaible for multiclass problems") +############################################################################# + test_that("Multiclass is not supported for caretEnsemble", { data(iris) data(models.class)