diff --git a/Makefile b/Makefile index 6513f3ea..52d08fb2 100644 --- a/Makefile +++ b/Makefile @@ -3,21 +3,21 @@ .PHONY: help help: @echo "Available targets:" - @echo " all - Run clean, fix-style, document, install, build-vignettes, lint, spell, test, check, coverage" - @echo " install-deps - Install dependencies" - @echo " install - Install the whole package, including dependencies" - @echo " document - Generate documentation" - @echo " update-test-fixtures - Update test fixtures" - @echo " test - Run unit tests" - @echo " coverage - Generate coverage reports" - @echo " check - Run R CMD check as CRAN" - @echo " fix-style - Auto style the code" - @echo " lint - Check the code for lint" - @echo " spell - Check spelling" - @echo " build - Build the package" - @echo " build-vignettes - Build vignettes" - @echo " release - Release to CRAN" - @echo " clean - Clean up generated files" + @echo " all Run clean, fix-style, document, install, build-vignettes, lint, spell, test, check, coverage" + @echo " install-deps Install dependencies" + @echo " install Install the whole package, including dependencies" + @echo " document Generate documentation" + @echo " update-test-fixtures Update test fixtures" + @echo " test Run unit tests" + @echo " coverage Generate coverage reports" + @echo " check Run R CMD check as CRAN" + @echo " fix-style Auto style the code" + @echo " lint Check the code for lint" + @echo " spell Check spelling" + @echo " build Build the package" + @echo " build-vignettes Build vignettes" + @echo " release Release to CRAN" + @echo " clean Clean up generated files" .PHONY: all all: clean fix-style document install build-vignettes lint spell test check coverage diff --git a/coverage.rds b/coverage.rds index e4277563..247b11a1 100644 Binary files a/coverage.rds and b/coverage.rds differ diff --git a/tests/testthat/test-S3-generic-extensions.R b/tests/testthat/test-S3-generic-extensions.R deleted file mode 100644 index 1a6f3aee..00000000 --- a/tests/testthat/test-S3-generic-extensions.R +++ /dev/null @@ -1,203 +0,0 @@ -# test-S3-generic-extensions -set.seed(107L) - -data(models.class) -data(X.class) -data(Y.class) -data(models.reg) -data(X.reg) -data(Y.reg) - -data(Sonar, package = "mlbench") - -# A common control to use for both test fixtures -my_control <- caret::trainControl( - method = "cv", - number = 2L, - savePredictions = "final", - summaryFunction = caret::twoClassSummary, - classProbs = TRUE, - verboseIter = FALSE, - index = caret::createResample(Sonar$Class, 3L) -) - -# a model of class caretList -model_list1 <- caretList( - Class ~ ., - data = Sonar, - trControl = my_control, - tuneList = list( - glm = caretModelSpec(method = "rf"), - rpart = caretModelSpec(method = "rpart") - ) -) - -# a model of class train -rpart_model <- caret::train( - Class ~ ., - data = Sonar, - tuneLength = 2L, - metric = "ROC", - trControl = my_control, - method = "rpart" -) - -############################################### -testthat::context("Ancillary caretList S3 Generic Functions Extensions") -################################################ -testthat::test_that("c.caretEnsemble can bind two caretList objects", { - model_list2 <- caretList( - Class ~ ., - data = Sonar, - trControl = my_control, - tuneList = list( - glm = caretModelSpec(method = "rpart", tuneLength = 2L), - rpart = caretModelSpec(method = "rf", tuneLength = 2L) - ) - ) - - bigList <- c(model_list1, model_list2) - ens1 <- caretEnsemble(bigList) - - testthat::expect_is(bigList, "caretList") - testthat::expect_is(ens1, 
"caretEnsemble") - testthat::expect_identical(anyDuplicated(names(bigList)), 0L) - testthat::expect_length(unique(names(bigList)), 4L) -}) - -testthat::test_that("c.caretEnsemble can bind a caretList and train object", { - bigList <- c(model_list1, rpart_model) - ens1 <- caretEnsemble(bigList) - - testthat::expect_is(bigList, "caretList") - testthat::expect_is(ens1, "caretEnsemble") - testthat::expect_identical(anyDuplicated(names(bigList)), 0L) - testthat::expect_length(unique(names(bigList)), 3L) -}) - -testthat::test_that("c.caretList can bind two objects of class train", { - bigList <- c(rpart_model, rpart_model) - ens1 <- caretEnsemble(bigList) - - testthat::expect_is(bigList, "caretList") - testthat::expect_is(ens1, "caretEnsemble") - - testthat::expect_identical(anyDuplicated(names(bigList)), 0L) - testthat::expect_length(unique(names(bigList)), 2L) -}) - -testthat::test_that("c.caretList stops for invalid class", { - testthat::expect_error(c.caretList(list()), "class of modelList1 must be 'caretList' or 'train'") -}) - -testthat::test_that("c.train stops for invalid class", { - testthat::expect_error(c.train(list()), "class of modelList1 must be 'caretList' or 'train'") -}) - -############################################### -testthat::context("Edge cases for caretList S3 Generic Functions Extensions") -################################################ - -testthat::test_that("c.caretList combines caretList objects correctly", { - # Split models.class into two parts - models_class1 <- models.class[1L:2L] - models_class2 <- models.class[3L:4L] - class(models_class1) <- class(models_class2) <- "caretList" - - combined_models <- c(models_class1, models_class2) - - testthat::expect_s3_class(combined_models, "caretList") - testthat::expect_length(combined_models, length(models.class)) - testthat::expect_true(all(names(combined_models) %in% names(models.class))) -}) - -testthat::test_that("c.caretList combines caretList and train objects correctly", { - models_class1 <- models.class[1L:2L] - class(models_class1) <- "caretList" - single_model <- models.class[[3L]] - - combined_models <- c(models_class1, single_model) - - testthat::expect_s3_class(combined_models, "caretList") - testthat::expect_length(combined_models, 3L) - testthat::expect_true(all(names(combined_models) %in% names(models.class))) -}) - -testthat::test_that("c.train combines train objects correctly", { - model1 <- models.class[[1L]] - model2 <- models.class[[2L]] - - combined_models <- c(model1, model2) - - testthat::expect_s3_class(combined_models, "caretList") - testthat::expect_length(combined_models, 2L) - testthat::expect_true(all(names(combined_models) %in% names(models.class)[1L:2L])) -}) - -testthat::test_that("c.caretList handles duplicate names", { - models_class1 <- models.class[1L:2L] - class(models_class1) <- "caretList" - - combined_models <- c(models_class1, models_class1) - - testthat::expect_s3_class(combined_models, "caretList") - testthat::expect_length(combined_models, 4L) - testthat::expect_true(all(make.names(rep(names(models_class1), 2L), unique = TRUE) %in% names(combined_models))) -}) - -testthat::test_that("c.caretList and c.train fail for invalid inputs", { - testthat::expect_error(c.caretList(list(a = 1L, b = 2L)), "class of modelList1 must be 'caretList' or 'train'") - testthat::expect_error(c.train(list(a = 1L, b = 2L)), "class of modelList1 must be 'caretList' or 'train'") -}) - -testthat::test_that("[.caretList subsets caretList objects correctly", { - subset_models <- models.class[1L:2L] - - 
testthat::expect_s3_class(subset_models, "caretList") - testthat::expect_length(subset_models, 2L) - testthat::expect_true(all(names(subset_models) %in% names(models.class)[1L:2L])) -}) - -testthat::test_that("as.caretList.list converts list to caretList", { - model_list <- list(model1 = models.class[[1L]], model2 = models.class[[2L]]) - caretlist_object <- as.caretList(model_list) - - testthat::expect_s3_class(caretlist_object, "caretList") - testthat::expect_length(caretlist_object, 2L) - testthat::expect_true(all(names(caretlist_object) %in% names(model_list))) -}) - -testthat::test_that("as.caretList.list fails for invalid inputs", { - testthat::expect_error(as.caretList(list(a = 1L, b = 2L)), "object requires all elements of list to be caret models") -}) - -testthat::test_that("as.caretList.list names lists without names", { - models.no.name <- models.class - names(models.no.name) <- NULL - class(models.no.name) <- "list" - testthat::expect_null(names(models.no.name)) - cl <- as.caretList(models.no.name) - testthat::expect_named(cl, unname(vapply(models.class, "[[", character(1L), "method"))) -}) -testthat::test_that("as.caretList fails on non-list", { - testthat::expect_error(as.caretList(1L), "object must be a list") -}) - -testthat::test_that("predict.caretList works for classification and regression", { - class_preds <- predict(models.class, newdata = X.class, excluded_class_id = 0L) - reg_preds <- predict(models.reg, newdata = X.reg) - - testthat::expect_is(class_preds, "data.table") - testthat::expect_is(reg_preds, "data.table") - testthat::expect_identical(nrow(class_preds), nrow(X.class)) - testthat::expect_identical(nrow(reg_preds), nrow(X.reg)) - testthat::expect_identical(ncol(class_preds), length(models.class) * 2L) - testthat::expect_identical(ncol(reg_preds), length(models.reg)) -}) - -testthat::test_that("predict.caretList handles type='prob' for classification", { - class_probs <- predict(models.class, newdata = X.class, excluded_class_id = 0L) - testthat::expect_is(class_probs, "data.table") - testthat::expect_identical(nrow(class_probs), nrow(X.class)) - testthat::expect_identical(ncol(class_probs), length(models.class) * nlevels(Y.class)) -}) diff --git a/tests/testthat/test-caretEnsemble.R b/tests/testthat/test-caretEnsemble.R index 0390a260..4a661551 100644 --- a/tests/testthat/test-caretEnsemble.R +++ b/tests/testthat/test-caretEnsemble.R @@ -1,66 +1,105 @@ -# Are tests failing here? -# UPDATE THE FIXTURES! 
-# make update-test-fixtures
-
-data(models.reg)
-data(X.reg)
-data(Y.reg)
-
-data(models.class)
-data(X.class)
-data(Y.class)
-
+# Load required data
+utils::data(models.reg)
+utils::data(X.reg)
+utils::data(Y.reg)
+utils::data(models.class)
+utils::data(X.class)
+utils::data(Y.class)
+utils::data(Sonar, package = "mlbench")
+
+# Set up test environment
 set.seed(1234L)
+k <- 2L

 ens.reg <- caretEnsemble(
   models.reg,
-  trControl = caret::`trainControl`(method = "cv", number = 2L, savePredictions = "final")
+  trControl = caret::trainControl(
+    method = "cv",
+    number = k,
+    index = caret::createFolds(Y.reg, k = k),
+    savePredictions = "final"
+  )
+)
+
+ens.class <- caretEnsemble(
+  models.class,
+  metric = "ROC",
+  trControl = caret::trainControl(
+    method = "cv",
+    number = k,
+    index = caret::createFolds(Y.class, k = k),
+    summaryFunction = caret::twoClassSummary,
+    classProbs = TRUE,
+    savePredictions = TRUE
+  )
 )

-#############################################################################
-testthat::context("Test metric and residual extraction")
-#############################################################################
+# Helper function for prediction tests
+test_predictions <- function(ens, newdata, one_row_preds) {
+  is_class <- isClassifier(ens)
+  N <- nrow(newdata)
+
+  pred_stacked <- predict(ens)
+  pred <- predict(ens, newdata = newdata)
+  pred_se <- predict(ens, newdata = newdata, se = TRUE)
+  pred_one <- predict(ens, newdata = newdata[1L, , drop = FALSE])
+
+  testthat::expect_s3_class(pred_stacked, "data.table")
+  testthat::expect_s3_class(pred, "data.table")
+  testthat::expect_s3_class(pred_se, "data.table")
+  testthat::expect_s3_class(pred_one, "data.table")
+
+  testthat::expect_identical(nrow(pred_stacked), N)
+  testthat::expect_identical(nrow(pred), N)
+  testthat::expect_identical(nrow(pred_se), N)
+  testthat::expect_identical(nrow(pred_one), 1L)
+
+  testthat::expect_equal(pred_stacked, pred, tol = ifelse(is_class, 0.35, 0.05))
+
+  if (is_class) {
+    testthat::expect_identical(ncol(pred_stacked), 2L)
+    testthat::expect_identical(ncol(pred), 2L)
+    testthat::expect_identical(ncol(pred_one), 2L)
+    testthat::expect_equivalent(pred_one$Yes, one_row_preds[1L], tol = 0.05)
+    testthat::expect_equivalent(pred_one$No, one_row_preds[2L], tol = 0.05)
+  } else {
+    testthat::expect_equivalent(pred_one$pred, one_row_preds[1L], tol = 0.05)
+  }
+}

-testthat::test_that("We can extract resdiuals from train regression objects", {
+######################################################################
+testthat::context("Metric and residual extraction")
+######################################################################
+
+testthat::test_that("We can extract residuals from train regression objects", {
   data(iris)
-  mod <- caret::train(
-    iris[, 1L:2L], iris[, 3L],
-    method = "lm"
-  )
+  mod <- caret::train(iris[, 1L:2L], iris[, 3L], method = "lm")
   r <- stats::residuals(mod)
   testthat::expect_is(r, "numeric")
   testthat::expect_length(r, 150L)
 })

-#############################################################################
-testthat::context("Does ensembling and prediction work?")
-#############################################################################
+######################################################################
+testthat::context("Ensembling and prediction")
+######################################################################

 testthat::test_that("We can ensemble regression models", {
   testthat::expect_s3_class(ens.reg, "caretEnsemble")
-  pred.reg <- predict(ens.reg,
newdata = X.reg) - pred.reg2 <- predict(ens.reg, newdata = X.reg, se = TRUE) - - testthat::expect_true(all(pred.reg == pred.reg2$pred)) - + test_predictions(ens.reg, X.reg, 5.04) +}) - testthat::expect_s3_class(pred.reg, "data.table") - testthat::expect_identical(nrow(pred.reg), 150L) - ens.class <- caretEnsemble(models.class) +testthat::test_that("We can ensemble classification models", { testthat::expect_s3_class(ens.class, "caretEnsemble") - pred.class <- predict(ens.class, newdata = X.class) - testthat::expect_s3_class(pred.class, "data.table") - testthat::expect_identical(nrow(pred.class), 150L) + test_predictions(ens.class, X.class, c(0.02, 0.98)) }) -############################################################################# -testthat::context("Does ensembling work with models with differing predictors") -############################################################################# +###################################################################### +testthat::context("Ensembling with models of differing predictors") +###################################################################### testthat::test_that("We can ensemble models of different predictors", { data(iris) Y.reg <- iris[, 1L] X.reg <- model.matrix(~., iris[, -1L]) - mseeds <- vector(mode = "list", length = 12L) my_control <- caret::trainControl( method = "cv", number = 2L, p = 0.75, @@ -77,122 +116,45 @@ testthat::test_that("We can ensemble models of different predictors", { ) nestedList <- as.caretList(nestedList) - # Can we predict from the list pred_list <- predict(nestedList, newdata = X.reg) testthat::expect_s3_class(pred_list, "data.table") testthat::expect_identical(nrow(pred_list), 150L) testthat::expect_identical(ncol(pred_list), length(nestedList)) - # Can we predict from the ensemble ensNest <- caretEnsemble(nestedList) testthat::expect_s3_class(ensNest, "caretEnsemble") pred.nest <- predict(ensNest, newdata = X.reg) testthat::expect_s3_class(pred.nest, "data.table") testthat::expect_identical(nrow(pred.nest), 150L) - # Ensemble errors on NAs X_reg_new <- X.reg X_reg_new[2L, 3L] <- NA - expect_error( + testthat::expect_error( predict(ensNest, newdata = X_reg_new), "is.finite(newdata) are not all TRUE", fixed = TRUE ) }) -testthat::context("Does ensemble prediction work with new data") - -testthat::test_that("caretEnsemble works for regression models", { - set.seed(1234L) - testthat::expect_is(ens.reg, "caretEnsemble") - - # Predictions - pred_stacked <- predict(ens.reg) # stacked predictions - pred_in_sample <- predict(ens.reg, newdata = X.reg) # in sample predictions - pred_one <- predict(ens.reg, newdata = X.reg[2L, , drop = FALSE]) # one row predictions - - # Check class - testthat::expect_s3_class(pred_stacked, "data.table") - testthat::expect_s3_class(pred_in_sample, "data.table") - testthat::expect_s3_class(pred_one, "data.table") - - # Check len - testthat::expect_identical(nrow(pred_stacked), 150L) - testthat::expect_identical(nrow(pred_in_sample), 150L) - testthat::expect_identical(nrow(pred_one), 1L) - - # stacked predcitons should be similar to in sample predictions - testthat::expect_equal(pred_stacked, pred_in_sample, tol = 0.1) - - # One row predictions - testthat::expect_equivalent(pred_one$pred, 4.712639, tol = 0.05) -}) - -testthat::test_that("caretEnsemble works for classification models", { - set.seed(1234L) - ens.class <- caretEnsemble( - models.class, - trControl = caret::trainControl( - method = "cv", - number = 10L, - savePredictions = "final", - classProbs = TRUE - ) - ) - 
testthat::expect_s3_class(ens.class, "caretEnsemble") - ens.class$ens_model$finalModel - - # Predictions - pred_stacked <- predict(ens.class) # stacked predictions - pred_in_sample <- predict(ens.class, newdata = X.class) # in sample predictions - pred_one <- predict(ens.class, newdata = X.class[2L, , drop = FALSE]) # one row predictions - - # Check class - testthat::expect_s3_class(pred_stacked, "data.table") - testthat::expect_s3_class(pred_in_sample, "data.table") - testthat::expect_s3_class(pred_one, "data.table") - - # Check rows - testthat::expect_identical(nrow(pred_stacked), 150L) - testthat::expect_identical(nrow(pred_in_sample), 150L) - testthat::expect_identical(nrow(pred_one), 1L) - - # Check cols - testthat::expect_identical(ncol(pred_stacked), 2L) - testthat::expect_identical(ncol(pred_in_sample), 2L) - testthat::expect_identical(ncol(pred_one), 2L) - - # stacked predcitons should be similar to in sample predictions - testthat::expect_equal(pred_stacked, pred_in_sample, tol = 0.1) - - # One row predictions - testthat::expect_equivalent(pred_one$Yes, 0.02, tol = 0.05) - testthat::expect_equivalent(pred_one$No, 0.98, tol = 0.05) -}) - -testthat::context("Do ensembles of custom models work?") +###################################################################### +testthat::context("Ensembles with custom models") +###################################################################### testthat::test_that("Ensembles using custom models work correctly", { set.seed(1234L) - # Create custom caret models with a properly assigned method attribute custom.rf <- getModelInfo("rf", regex = FALSE)[[1L]] custom.rf$method <- "custom.rf" custom.rpart <- getModelInfo("rpart", regex = FALSE)[[1L]] custom.rpart$method <- "custom.rpart" - # Define models to be used in ensemble - # Add an unnamed model to ensure that method names are extracted from model info - # Add a named custom model, to contrast the above - # Add a non-custom model tune.list <- list( caretModelSpec(method = custom.rf, tuneLength = 1L), myrpart = caretModelSpec(method = custom.rpart, tuneLength = 1L), treebag = caretModelSpec(method = "treebag", tuneLength = 1L) ) - # Create an ensemble using the above models cl <- caretList(X.class, Y.class, tuneList = tune.list) cs <- caretEnsemble( cl, @@ -204,46 +166,18 @@ testthat::test_that("Ensembles using custom models work correctly", { ) ) testthat::expect_is(cs, "caretEnsemble") - - # Validate names assigned to ensembled models testthat::expect_named(cs$models, c("custom.rf", "myrpart", "treebag")) - # Validate ensemble predictions - pred_stacked <- predict(cs) # stacked predictions - pred_in_sample <- predict(cs, newdata = X.class) # in sample predictions - pred_one <- predict(cs, newdata = X.class[2L, , drop = FALSE]) # one row predictions - - # Check class - testthat::expect_s3_class(pred_stacked, "data.table") - testthat::expect_s3_class(pred_in_sample, "data.table") - testthat::expect_s3_class(pred_one, "data.table") - - # Check rows - testthat::expect_identical(nrow(pred_stacked), 150L) - testthat::expect_identical(nrow(pred_in_sample), 150L) - testthat::expect_identical(nrow(pred_one), 1L) - - # Check cols - testthat::expect_identical(ncol(pred_stacked), 2L) - testthat::expect_identical(ncol(pred_in_sample), 2L) - testthat::expect_identical(ncol(pred_one), 2L) - - # stacked predcitons should be similar to in sample predictions - # These differ a lot! 
- testthat::expect_equal(pred_stacked, pred_in_sample, tol = 0.4) + test_predictions(cs, X.class, c(0.0198, 0.9802)) - # One row predictions - testthat::expect_equivalent(pred_one$Yes, 0.07557944, tol = 0.1) - testthat::expect_equivalent(pred_one$No, 0.9244206, tol = 0.1) - - # Verify that not specifying a method attribute for custom models causes an error - # Add a custom caret model WITHOUT a properly assigned method attribute - tune.list <- list( + tune.list_bad <- list( caretModelSpec(method = getModelInfo("rf", regex = FALSE)[[1L]], tuneLength = 1L), treebag = caretModelSpec(method = "treebag", tuneLength = 1L) ) - msg <- "Custom models must be defined with a \"method\" attribute" - testthat::expect_error(caretList(X.class, Y.class, tuneList = tune.list, trControl = train.control), regexp = msg) + testthat::expect_error( + caretList(X.class, Y.class, tuneList = tune.list_bad, trControl = train.control), + "Custom models must be defined with a \"method\" attribute" + ) }) testthat::test_that("Ensembles fails if predictions are not saved", { @@ -254,3 +188,126 @@ testthat::test_that("Ensembles fails if predictions are not saved", { "No predictions saved during training. Please set savePredictions = 'final' in trainControl" ) }) + +###################################################################### +testthat::context("Variable importance and plotting") +###################################################################### + +testthat::test_that("caret::varImp.caretEnsemble works", { + set.seed(2239L) + + for (m in list(ens.class, ens.reg)) { + for (s in c(TRUE, FALSE)) { + i <- caret::varImp(m, normalize = s) + testthat::expect_is(i, "numeric") + if (isClassifier(m)) { + len <- length(m$models) * 2L + n <- c(outer(c("rf", "glm", "rpart", "treebag"), c("No", "Yes"), paste, sep = "_")) + n <- matrix(n, ncol = 2L) + n <- c(t(n)) + } else { + len <- length(m$models) + n <- names(m$models) + } + testthat::expect_length(i, len) + testthat::expect_named(i, n) + if (s) { + testthat::expect_true(all(i >= 0.0)) + testthat::expect_true(all(i <= 1.0)) + testthat::expect_equal(sum(i), 1.0, tolerance = 1e-6) + } + } + } +}) + +testthat::test_that("plot.caretEnsemble works", { + for (ens in list(ens.class, ens.reg)) { + plt <- plot(ens) + testthat::expect_is(plt, "ggplot") + testthat::expect_identical(nrow(plt$data), 5L) + testthat::expect_named(ens$models, plt$data$model_name[-1L]) + } +}) + +testthat::test_that("ggplot2::autoplot.caretEnsemble works", { + for (ens in list(ens.class, ens.reg)) { + plt1 <- ggplot2::autoplot(ens) + plt2 <- ggplot2::autoplot(ens, xvars = c("Petal.Length", "Petal.Width")) + + testthat::expect_is(plt1, "ggplot") + testthat::expect_is(plt2, "ggplot") + testthat::expect_is(plt1, "patchwork") + testthat::expect_is(plt2, "patchwork") + + train_model <- ens.reg$models[[1L]] + testthat::expect_error(ggplot2::autoplot(train_model), "Objects of class (.*?) 
are not supported by autoplot") + } +}) + +testthat::test_that("summary.caretEnsemble works", { + for (ens in list(ens.class, ens.reg)) { + smry <- testthat::expect_silent(summary(ens.class)) + testthat::expect_output(print(smry), ens.class$ens_model$metric) + for (name in names(ens.class$models)) { + testthat::expect_output(print(smry), name) + } + } +}) + +testthat::test_that("predict.caretEnsemble works with and without se and weights", { + for (ens in list(ens.class, ens.reg)) { + is_class <- isClassifier(ens) + for (se in c(FALSE, TRUE)) { + p <- predict( + ens, + newdata = X.reg, + se = se, + excluded_class_id = 1L + ) + testthat::expect_s3_class(p, "data.table") + if (se) { + testthat::expect_named(p, c("pred", "lwr", "upr")) + } else { + testthat::expect_named(p, ifelse(is_class, "Yes", "pred")) + } + } + } +}) + +testthat::test_that("We can train and ensemble models with custom tuning lists", { + target <- "Class" + + custom_list <- caretList( + x = Sonar[, setdiff(names(Sonar), target)], + y = Sonar[, target], + tuneList = list( + rpart = caretModelSpec( + method = "rpart", + tuneGrid = data.table::data.table(.cp = c(0.01, 0.001, 0.1, 1.0)) + ), + knn = caretModelSpec( + method = "knn", + tuneLength = 9L + ), + lda = caretModelSpec( + method = "lda2", + tuneLength = 1L + ), + nnet = caretModelSpec( + method = "nnet", + tuneLength = 2L, + trace = FALSE, + softmax = FALSE + ) + ) + ) + testthat::expect_is(custom_list, "caretList") + testthat::expect_identical(nrow(custom_list[["rpart"]]$results), 4L) + testthat::expect_identical(nrow(custom_list[["knn"]]$results), 9L) + testthat::expect_identical(nrow(custom_list[["lda"]]$results), 1L) + testthat::expect_identical(nrow(custom_list[["nnet"]]$results), 4L) + testthat::expect_false(custom_list[["nnet"]]$finalModel$softmax) + + custom_ensemble <- caretEnsemble(custom_list) + testthat::expect_is(custom_ensemble, "caretEnsemble") +}) diff --git a/tests/testthat/test-caretList.R b/tests/testthat/test-caretList.R index cac7bece..8d2ce95e 100644 --- a/tests/testthat/test-caretList.R +++ b/tests/testthat/test-caretList.R @@ -1,13 +1,13 @@ -# Test caretList +# Setup set.seed(442L) -data(models.reg) -data(X.reg) -data(Y.reg) +utils::data(models.reg) +utils::data(X.reg) +utils::data(Y.reg) -data(models.class) -data(X.class) -data(Y.class) +utils::data(models.class) +utils::data(X.class) +utils::data(Y.class) train <- caret::twoClassSim( n = 1000L, intercept = -8L, linearVars = 3L, @@ -18,627 +18,106 @@ test <- caret::twoClassSim( noiseVars = 10L, corrVars = 4L, corrValue = 0.6 ) -testthat::test_that("caretModelSpec returns valid specs", { - tuneList <- list( - rf1 = caretModelSpec(), - rf2 = caretModelSpec(method = "rf", tuneLength = 5L), - caretModelSpec(method = "rpart"), - caretModelSpec(method = "knn", tuneLength = 10L) - ) - tuneList <- caretEnsemble::tuneCheck(tuneList) - testthat::expect_type(tuneList, "list") - testthat::expect_length(tuneList, 4L) - testthat::expect_identical(sum(duplicated(names(tuneList))), 0L) -}) +n <- 100L +p <- 1000L +large_data <- list( + X = data.table::data.table(matrix(stats::rnorm(n * p), n, p)), + y = factor(sample(c("A", "B"), n, replace = TRUE)) +) +################################################################ +testthat::context("caretModelSpec, tuneCheck, methodCheck") +################################################################ testthat::test_that("caretModelSpec and checking functions work as expected", { all_models <- sort(unique(caret::modelLookup()$model)) - for (model in all_models) { - 
testthat::expect_identical(caretModelSpec(model, tuneLength = 5L, preProcess = "knnImpute")$method, model) - } - tuneList <- lapply(all_models, function(x) list(method = x, preProcess = "pca")) - all_models_check <- tuneCheck(tuneList) - testthat::expect_is(all_models_check, "list") - testthat::expect_length(all_models, length(all_models_check)) + testthat::expect_identical(caretModelSpec("rf", tuneLength = 5L, preProcess = "knnImpute")$method, "rf") tuneList <- lapply(all_models, function(x) list(method = x, preProcess = "pca")) - names(tuneList) <- all_models - names(tuneList)[c(1L, 5L, 10L)] <- "" all_models_check <- tuneCheck(tuneList) testthat::expect_is(all_models_check, "list") testthat::expect_length(all_models, length(all_models_check)) methodCheck(all_models) - err <- "The following models are not valid caret models: THIS_IS_NOT_A_REAL_MODEL" - testthat::expect_error(methodCheck(c(all_models, "THIS_IS_NOT_A_REAL_MODEL")), err) - testthat::expect_error(methodCheck(c(all_models, "THIS_IS_NOT_A_REAL_MODEL", "GBM"))) -}) - -testthat::test_that("Target extraction functions work", { - data(iris) - testthat::expect_identical(extractCaretTarget(iris[, 1L:4L], iris[, 5L]), iris[, 5L]) - testthat::expect_identical(extractCaretTarget(iris[, 2L:5L], iris[, 1L]), iris[, 1L]) - testthat::expect_identical(extractCaretTarget(Species ~ ., iris), iris[, "Species"]) - testthat::expect_identical(extractCaretTarget(Sepal.Width ~ ., iris), iris[, "Sepal.Width"]) -}) - -testthat::test_that("caretList errors for bad models", { - data(iris) - - # Basic checks - testthat::expect_error(caretList(Sepal.Width ~ ., iris), "Please either define a methodList or tuneList") - testthat::expect_warning( - caretList(Sepal.Width ~ ., iris, methodList = c("lm", "lm")), - "Duplicate entries in methodList. Using unique methodList values." - ) - testthat::expect_is(caretList(Sepal.Width ~ ., iris, methodList = "lm", continue_on_fail = TRUE), "caretList") - - # Check that by default a bad model kills the training job - bad <- list( - bad = caretModelSpec(method = "glm", tuneLength = 1L) - ) - testthat::expect_output( - testthat::expect_warning( - testthat::expect_error( - caretList(iris[, 1L:4L], iris[, 5L], tuneList = bad), - regexp = "Stopping" # Stop training on the first error. This is the mssage straight from train. - ), - regexp = "model fit failed for Fold1" - ), - regexp = "Something is wrong; all the Accuracy metric values are missing:" - ) - testthat::expect_output( - testthat::expect_warning( - testthat::expect_error( - caretList(iris[, 1L:4L], iris[, 5L], tuneList = bad, continue_on_fail = TRUE), - regexp = "caret:train failed for all models. Please inspect your data." 
- ), - regexp = "model fit failed for Fold1" - ), - regexp = "Something is wrong; all the Accuracy metric values are missing:" - ) - - # Check that at least one good model + continue_on_fail works - good_bad <- list( - good = caretModelSpec(method = "glmnet", tuneLength = 1L), - bad = caretModelSpec(method = "glm", tuneLength = 1L) - ) - testthat::expect_s3_class( - testthat::expect_output( - testthat::expect_warning( - caretList(iris[, 1L:4L], iris[, 5L], tuneList = good_bad, continue_on_fail = TRUE), - regexp = "model fit failed for Fold1" - ), - regexp = "Something is wrong; all the Accuracy metric values are missing:" - ), "caretList" - ) -}) - -testthat::test_that("caretList predictions", { - models <- testthat::expect_warning( - caretList( - iris[, 1L:2L], iris[, 3L], - tuneLength = 1L, verbose = FALSE, - methodList = "rf", tuneList = list(nnet = caretModelSpec(method = "nnet", trace = FALSE)) - ), "There were missing values in resampled performance measures." - ) - - p1 <- predict(models) - p2 <- predict(models, newdata = iris[100L, 1L:2L]) - p3 <- predict(models, newdata = iris[110L, 1L:2L]) - testthat::expect_is(p1, "data.table") - testthat::expect_is(p1[[1L]], "numeric") - testthat::expect_is(p1[[2L]], "numeric") - testthat::expect_named(models, colnames(p1)) - testthat::expect_is(p2, "data.table") - testthat::expect_is(p2[[1L]], "numeric") - testthat::expect_is(p2[[2L]], "numeric") - testthat::expect_named(models, colnames(p2)) - testthat::expect_is(p3, "data.table") - testthat::expect_is(p3[[1L]], "numeric") - testthat::expect_is(p3[[2L]], "numeric") - testthat::expect_named(models, colnames(p3)) - - models <- caretList( - iris[, 1L:2L], iris[, 5L], - tuneLength = 1L, verbose = FALSE, - methodList = "rf", - tuneList = list(nnet = caretModelSpec(method = "nnet", trace = FALSE)) - ) - - p2 <- predict(models, excluded_class_id = 0L) - p3 <- predict(models, newdata = iris[, 1L:2L], excluded_class_id = 0L) - testthat::expect_is(p2, "data.table") - testthat::expect_is(p2[[1L]], "numeric") - testthat::expect_is(p2[[2L]], "numeric") - testthat::expect_is(p2[[3L]], "numeric") - testthat::expect_is(p2[[4L]], "numeric") - testthat::expect_is(p3, "data.table") - testthat::expect_is(p3[[1L]], "numeric") - testthat::expect_is(p3[[2L]], "numeric") - testthat::expect_is(p3[[3L]], "numeric") - testthat::expect_is(p3[[4L]], "numeric") - testthat::expect_identical( - length(names(models)) * nlevels(as.factor(iris[, 5L])), - length(colnames(p3)) - ) # check that we have the right number of columns - testthat::expect_identical(dim(p2), dim(p3)) - testthat::expect_named(p2, names(p3)) - - modelnames <- names(models) - classes <- levels(iris[, 5L]) - combinations <- expand.grid(classes, modelnames) - correct_colnames <- apply(combinations, 1L, function(x) paste(x[2L], x[1L], sep = "_")) - testthat::expect_named( - p3, - correct_colnames - ) # check the column names are correct and ordered correctly (methodname_classname) -}) - -testthat::test_that("as.caretList.list returns a caretList object", { - modelList <- caretList(Sepal.Length ~ Sepal.Width, - head(iris, 50L), - methodList = c("glm", "lm", "knn") + testthat::expect_error( + methodCheck(c(all_models, "THIS_IS_NOT_A_REAL_MODEL")), + "The following models are not valid caret models: THIS_IS_NOT_A_REAL_MODEL" ) - class(modelList) <- "list" - testthat::expect_is(as.caretList(modelList), "caretList") -}) -############################################################# -testthat::context("Bad characters in target variable names and model names") 
-############################################################# -testthat::test_that("Target variable names with character | are not allowed", { - bad_iris <- iris[1L:100L, ] - bad_iris[, 5L] <- gsub("versicolor", "versicolor|1", bad_iris[, 5L], fixed = TRUE) - bad_iris[, 5L] <- gsub("setosa", "setosa|2", bad_iris[, 5L], fixed = TRUE) - bad_iris[, 5L] <- as.factor(as.character(bad_iris[, 5L])) - - # Expect an error from caret testthat::expect_error( - caretList( - x = bad_iris[, -5L], - y = bad_iris[, 5L], - methodList = c("rpart", "glmnet") - ), "At least one of the class levels is not a valid R variable name; This will cause errors when class prob" + methodCheck("InvalidMethod"), + "The following models are not valid caret models: InvalidMethod" ) -}) -testthat::test_that("Character | in model names is transformed into a point", { - reduced_iris <- iris[1L:100L, ] - reduced_iris[, 5L] <- as.factor(as.character(reduced_iris[, 5L])) - - # Chack that specified model names are transformed with function make.names - model_list <- caretList( - x = reduced_iris[, -5L], - y = reduced_iris[, 5L], - tuneList = list( - "nnet|1" = caretModelSpec( - method = "nnet", - tuneGrid = expand.grid(.size = c(1L, 3L), .decay = 0.3), - trace = FALSE - ), - "nnet|2" = caretModelSpec( - method = "nnet", - tuneGrid = expand.grid(.size = 3L, .decay = c(0.1, 0.3)), - trace = FALSE - ) - ) - ) - testthat::expect_named(model_list, c("nnet.1", "nnet.2")) -}) - -############################################### -testthat::context("We can fit models with a mix of methodList and tuneList") -################################################ -testthat::test_that("We can fit models with a mix of methodList and tuneList", { - myList <- list( - rpart = caretModelSpec(method = "rpart", tuneLength = 10L), - rf = caretModelSpec(method = "rf", tuneGrid = data.table::data.table(mtry = 2L)) - ) - test <- testthat::expect_warning( - caretList( - x = iris[, 1L:3L], - y = iris[, 4L], - methodList = c("knn", "glm"), - tuneList = myList - ), "There were missing values in resampled performance measures." + testthat::expect_error( + methodCheck(list(invalid_method = 42L)), + "Method \"42\" is invalid" ) - testthat::expect_is(test, "caretList") - testthat::expect_is(caretEnsemble(test), "caretEnsemble") - testthat::expect_length(test, 4L) - methods <- vapply(test, function(x) x$method, character(1L)) - names(methods) <- NULL - testthat::expect_identical(methods, c("rpart", "rf", "knn", "glm")) }) -################################################ -testthat::context("We can handle different CV methods") -################################################ -testthat::test_that("We can handle different CV methods", { - for (m in c( - "boot", - "adaptive_boot", - "cv", - "repeatedcv", - "adaptive_cv", - "LGOCV", - "adaptive_LGOCV" - ) - ) { - N <- 7L - x <- iris[, 1L:3L] - y <- iris[, 4L] - - if (m == "boot" || m == "adaptive_boot") { - idx <- caret::createResample(y, times = N, list = TRUE) - } else if (m == "cv" || m == "adaptive_cv") { - idx <- caret::createFolds(y, k = N, list = TRUE, returnTrain = TRUE) - } else if (m == "repeatedcv") { - idx <- caret::createMultiFolds(y, k = N, times = 2L) - } else if (m == "LGOCV" || m == "adaptive_LGOCV") { - idx <- caret::createDataPartition( - y, - times = N, - p = 0.5, - list = TRUE, - groups = min(5L, length(y)) - ) - } - - models <- testthat::expect_warning( - caretList( - x = x, - y = y, - tuneLength = 2L, - methodList = c("rpart", "rf") - ), "There were missing values in resampled performance measures." 
- ) - ens <- caretStack(models, method = "glm") - - for (x in models) { - testthat::expect_s3_class(x, "train") - } +################################################################ +testthat::context("S3 methods for caretlist") +################################################################ - ens <- caretEnsemble(models) - - testthat::expect_is(ens, "caretEnsemble") - - ens <- caretStack(models, method = "glm") - - testthat::expect_is(ens, "caretStack") - } -}) - -testthat::test_that("Non standard cv methods work", { +testthat::test_that("Target extraction functions work", { data(iris) - models <- lapply( - c("boot632", "LOOCV", "none"), - function(m) { - model <- caret::train( - x = iris[, 1L:2L], - y = iris[, 3L], - tuneLength = 1L, - data = iris, - method = "rf", - trControl = caret::trainControl( - method = m, - savePredictions = "final" - ) - ) - testthat::expect_is(model, "train") - model - } - ) - caret_list <- as.caretList(models) - p <- predict(caret_list, newdata = iris[, 1L:2L]) - testthat::expect_s3_class(p, "data.table") -}) - -############################################### -testthat::context("Classification models") -################################################ -testthat::test_that("Classification models", { - # Simple two method list - # Warning because we Are going to auto-set indexes - test1 <- caretList( - x = train[, -23L], - y = train[, "Class"], - methodList = c("knn", "glm") - ) - - testthat::expect_is(test1, "caretList") - testthat::expect_is(caretEnsemble(test1), "caretEnsemble") - testthat::expect_is(caretEnsemble(test1), "caretEnsemble") -}) - -testthat::test_that("Longer tests for Classification models", { - # Simple two method list - # Warning because we Are going to auto-set indexes - test1 <- caretList( - x = train[, -23L], - y = train[, "Class"], - methodList = c("knn", "glm") - ) - - testthat::expect_is(test1, "caretList") - testthat::expect_is(caretEnsemble(test1), "caretEnsemble") - testthat::expect_is(caretEnsemble(test1), "caretEnsemble") - - test2 <- caretList( - x = train[, -23L], - y = train[, "Class"], - metric = "ROC", - methodList = c("knn", "glm", "rpart") - ) - - test3 <- caretList( - x = train[, -23L], - y = train[, "Class"], - metric = "ROC", - methodList = c("rpart", "knn", "glm") - ) - - testthat::expect_is(test2, "caretList") - testthat::expect_is(test3, "caretList") - testthat::expect_is(caretEnsemble(test2), "caretEnsemble") - testthat::expect_is(caretEnsemble(test3), "caretEnsemble") - - testthat::expect_identical(test2[[1L]]$metric, "ROC") - testthat::expect_identical(test3[[1L]]$metric, "ROC") -}) - -testthat::test_that("Test that caretList preserves user specified error functions", { - test1 <- caretList( - x = train[, -23L], - y = train[, "Class"], - tuneLength = 7L, - methodList = c("knn", "rpart", "glm") - ) - - test2 <- caretList( - x = train[, -23L], - y = train[, "Class"], - tuneLength = 4L, - methodList = c("knn", "rpart", "glm") - ) - - testthat::expect_identical(test1[[1L]]$metric, "ROC") - testthat::expect_identical(test2[[1L]]$metric, "ROC") - - testthat::expect_identical(nrow(test1[[1L]]$results), 7L) - testthat::expect_gt(nrow(test1[[1L]]$results), nrow(test2[[1L]]$results)) - testthat::expect_identical(nrow(test2[[1L]]$results), 4L) - - myEns2 <- caretEnsemble(test2) - myEns1 <- caretEnsemble(test1) - testthat::expect_is(myEns2, "caretEnsemble") - testthat::expect_is(myEns1, "caretEnsemble") - - test1 <- caretList( - x = train[, -23L], - y = train[, "Class"], - tuneLength = 7L, - methodList = c("knn", "rpart", 
"glm") - ) - - test2 <- caretList( - x = train[, -23L], - y = train[, "Class"], - tuneLength = 4L, - methodList = c("knn", "rpart", "glm") - ) - - testthat::expect_identical(test1[[1L]]$metric, "ROC") - testthat::expect_identical(test2[[1L]]$metric, "ROC") - - testthat::expect_identical(nrow(test1[[1L]]$results), 7L) - testthat::expect_gt(nrow(test1[[1L]]$results), nrow(test2[[1L]]$results)) - testthat::expect_identical(nrow(test2[[1L]]$results), 4L) - - - myEns2 <- caretEnsemble(test2) - myEns1 <- caretEnsemble(test1) - - testthat::expect_is(myEns2, "caretEnsemble") - testthat::expect_is(myEns1, "caretEnsemble") + testthat::expect_identical(extractCaretTarget(iris[, 1L:4L], iris[, 5L]), iris[, 5L]) + testthat::expect_identical(extractCaretTarget(Species ~ ., iris), iris[, "Species"]) }) -testthat::test_that("Users can pass a custom tuneList", { - tuneTest <- list( - rpart = caretModelSpec( - method = "rpart", - tuneGrid = data.table::data.table(.cp = c(0.01, 0.001, 0.1, 1.0)) - ), - knn = caretModelSpec( - method = "knn", - tuneLength = 9L - ), - svmRadial = caretModelSpec( - method = "lda2", - tuneLength = 1L - ) - ) - - test2a <- caretList( - x = train[, -23L], - y = train[, "Class"], - tuneList = tuneTest - ) - - myEns2a <- caretEnsemble(test2a) - testthat::expect_is(myEns2a, "caretEnsemble") - testthat::expect_is(test2a, "caretList") - testthat::expect_identical(nrow(test2a[[1L]]$results), 4L) - testthat::expect_identical(nrow(test2a[[2L]]$results), 9L) - testthat::expect_identical(nrow(test2a[[3L]]$results), 1L) +testthat::test_that("[.caretList", { + subset_models <- models.class[1L:2L] + testthat::expect_s3_class(subset_models, "caretList") + testthat::expect_length(subset_models, 2L) }) -testthat::context("User tuneTest parameters are respected and model is ensembled") -testthat::test_that("User tuneTest parameters are respected and model is ensembled", { - tuneTest <- list( - nnet = caretModelSpec( - method = "nnet", - tuneLength = 3L, - trace = FALSE, - softmax = FALSE - ) - ) - test <- caretList( - x = train[, -23L], - y = train[, "Class"], - tuneList = tuneTest - ) - ens <- caretEnsemble(test) - testthat::expect_is(ens, "caretEnsemble") - testthat::expect_is(test, "caretList") - testthat::expect_identical(nrow(test[[1L]]$results), 3L * 3L) - testthat::expect_false(test[[1L]]$finalModel$softmax) -}) +testthat::test_that("c.caretList", { + combined_models <- c(models.class, models.class) + testthat::expect_s3_class(combined_models, "caretList") + testthat::expect_length(combined_models, length(models.class) * 2L) -testthat::context("Formula interface for caretList works") -testthat::test_that("User tuneTest parameters are respected and model is ensembled", { - tuneTest <- list( - rpart = list(method = "rpart", tuneLength = 2L), - nnet = list(method = "nnet", tuneLength = 2L, trace = FALSE), - glm = list(method = "glm") - ) - x <- iris[, 1L:3L] - y <- iris[, 4L] - set.seed(42L) - test_default <- testthat::expect_warning( - caretList( - x = x, - y = y, - tuneList = tuneTest - ), "There were missing values in resampled performance measures." - ) - set.seed(42L) - test_flma <- testthat::expect_warning( - caretList( - y ~ ., - data = data.table::data.table(y = y, x), - tuneList = tuneTest - ), "There were missing values in resampled performance measures." 
- ) - ens_default <- caretEnsemble(test_default) - ens_flma <- caretEnsemble(test_flma) - testthat::expect_is(ens_default, "caretEnsemble") - testthat::expect_is(ens_flma, "caretEnsemble") + combined_models <- c(models.class, models.class[[1L]]) + testthat::expect_s3_class(combined_models, "caretList") + testthat::expect_length(combined_models, length(models.class) + 1L) - testthat::expect_equal(ens_default$RMSE, ens_flma$RMSE, tol = 0.000001) - testthat::expect_equal(ens_default$weights, ens_flma$weights, tol = 0.000001) + testthat::expect_error(c.caretList(list(a = 1L, b = 2L)), "class of modelList1 must be 'caretList' or 'train'") }) -############################################### -testthat::context("Regression models") -############################################### +testthat::test_that("as.caretList", { + # Named + model_list <- list(model1 = models.class[[1L]], model2 = models.class[[2L]]) + caretlist_object <- as.caretList(model_list) + testthat::expect_s3_class(caretlist_object, "caretList") + testthat::expect_length(caretlist_object, 2L) -testthat::test_that("Regression Models", { - test1 <- caretList( - x = train[, c(-23L, -1L)], - y = train[, 1L], - methodList = c("glm", "lm") - ) - test2 <- caretList( - x = train[, c(-23L, -1L)], - y = train[, 1L], - methodList = c("glm", "ppr", "lm") - ) + # Unnamed + model_list <- list(models.class[[1L]], models.class[[2L]]) + caretlist_object <- as.caretList(model_list) + testthat::expect_s3_class(caretlist_object, "caretList") + testthat::expect_length(caretlist_object, 2L) - ens1 <- caretEnsemble(test1) - ens2 <- caretEnsemble(test2) - - testthat::expect_is(test1, "caretList") - testthat::expect_is(test2, "caretList") - - testthat::expect_is(ens1, "caretEnsemble") - testthat::expect_is(ens2, "caretEnsemble") -}) - -testthat::test_that("methodCheck stops for invalid method type", { - testthat::expect_error(methodCheck(list(123L)), "Method \"123\" is invalid.") - testthat::expect_error( - methodCheck(list("invalid_method")), - "The following models are not valid caret models: invalid_method" - ) -}) - -testthat::test_that("as.caretList stops for null object", { + # Error cases testthat::expect_error(as.caretList(NULL), "object is null") -}) - -testthat::test_that("as.caretList.list stops for non-list object", { + testthat::expect_error(as.caretList(1L), "object must be a list") + testthat::expect_error(as.caretList(list(1L)), "object requires all elements of list to be caret models") + testthat::expect_error(as.caretList(list(NULL)), "object requires all elements of list to be caret models") testthat::expect_error(as.caretList.list(1L), "object must be a list of caret models") }) -testthat::test_that("predict.caretList doesn't care about missing training data", { - new_model_list <- lapply(models.class, function(x) { - x$trainingData <- NULL - x - }) - new_model_list <- as.caretList(new_model_list) - pred <- predict.caretList(new_model_list) - testthat::expect_is(pred, "data.table") - testthat::expect_identical(nrow(pred), 150L) - testthat::expect_named(pred, names(new_model_list)) -}) - -testthat::test_that("extractModelName handles custom models correctly", { - mock_model <- list(method = list(method = "custom_method")) - class(mock_model) <- "train" - testthat::expect_identical(extractModelName(mock_model), "custom_method") -}) - -testthat::test_that("extractModelName handles custom models correctly", { - mock_model <- list(method = "custom", modelInfo = list(method = "custom_method")) - class(mock_model) <- "train" - 
testthat::expect_identical(extractModelName(mock_model), "custom_method") -}) - -testthat::test_that("as.caretList.list fails on NULL object", { - err <- "object requires all elements of list to be caret models" - testthat::expect_error(as.caretList(list(NULL)), err) -}) - -testthat::test_that("predict.caretList works when the progress bar is turned off", { - set.seed(42L) - N <- 100L - noise_level <- 1L / 10L - X <- data.table::data.table( - a = runif(N), - b = runif(N) - ) - y <- 7.5 - 10.0 * X$a + 5.0 * X$b + noise_level * rnorm(N) - models <- caretList( - X, y, - tuneLength = 1L, - methodList = "lm" - ) - pred <- predict(models, X, verbose = FALSE)[["lm"]] - rmse <- sqrt(mean((y - pred)^2L)) - testthat::expect_lt(rmse, noise_level) -}) - -testthat::test_that("caretList handles missing data correctly", { - data(iris) - iris_with_na <- iris - x <- iris_with_na[, 1L:4L] - y <- iris_with_na[, 5L] - x[sample.int(nrow(x), 10L), sample.int(ncol(x), 2L)] <- NA +################################################################ +testthat::context("predict.caretlist") +################################################################ - models <- caretList( - x = x, - y = y, - methodList = "rpart" - ) +testthat::test_that("predict.caretList works for classification and regression", { + class_preds <- predict(models.class, newdata = X.class, excluded_class_id = 0L) + reg_preds <- predict(models.reg, newdata = X.reg) - testthat::expect_s3_class(models, "caretList") - testthat::expect_length(models, 1L) -}) + testthat::expect_is(class_preds, "data.table") + testthat::expect_is(reg_preds, "data.table") + testthat::expect_identical(nrow(class_preds), nrow(X.class)) + testthat::expect_identical(nrow(reg_preds), nrow(X.reg)) -testthat::test_that("caretList handles new factor levels in prediction", { - data(iris) + # Test for handling new factor levels in prediction idx <- seq_len(nrow(iris)) idx_train <- sample(idx, 100L) idx_test <- setdiff(idx, idx_train) @@ -656,79 +135,122 @@ testthat::test_that("caretList handles new factor levels in prediction", { pred <- predict(models, newdata = test_data) testthat::expect_is(pred, "data.table") testthat::expect_identical(nrow(pred), nrow(test_data)) + + # Test verbose option + p <- predict(models, newdata = test_data, verbose = TRUE) + testthat::expect_s3_class(p, "data.table") + testthat::expect_identical(nrow(p), nrow(test_data)) }) -testthat::test_that("caretList handles large number of predictors", { - set.seed(123L) - n <- 100L - p <- 1000L - X <- data.table::data.table(matrix(rnorm(n * p), n, p)) - y <- factor(sample(c("A", "B"), n, replace = TRUE)) +################################################################ +testthat::context("caretList") +################################################################ - models <- caretList( - x = X, - y = y, - methodList = c("glmnet", "rpart") +testthat::test_that("caretList works for various scenarios", { + # Basic classification + test1 <- caretList( + x = train[, -23L], + y = train[, "Class"], + methodList = c("knn", "glm") ) + testthat::expect_is(test1, "caretList") + testthat::expect_is(caretEnsemble(test1), "caretEnsemble") - testthat::expect_s3_class(models, "caretList") - testthat::expect_length(models, 2L) -}) - -testthat::test_that("caretList handles imbalanced data", { - set.seed(123L) - n <- 1000L - X <- data.table::data.table(x1 = rnorm(n), x2 = rnorm(n)) - y <- factor(c(rep("A", 950L), rep("B", 50L))) - - models <- caretList( - x = X, - y = y, - methodList = c("glmnet", "rpart") + # Regression + 
test_reg <- caretList( + x = train[, c(-23L, -1L)], + y = train[, 1L], + methodList = c("glm", "lm") ) + testthat::expect_is(test_reg, "caretList") + testthat::expect_is(caretEnsemble(test_reg), "caretEnsemble") + # Handling missing data + iris_with_na <- iris + x <- iris_with_na[, 1L:4L] + y <- iris_with_na[, 5L] + x[sample.int(nrow(x), 10L), sample.int(ncol(x), 2L)] <- NA + models <- caretList(x = x, y = y, methodList = "rpart") testthat::expect_s3_class(models, "caretList") - testthat::expect_length(models, 2L) -}) - -testthat::test_that("caretList handles custom performance metrics", { - data(iris) - models <- caretList( - x = iris[, 1L:4L], - y = iris[, 5L], - metric = "default", - methodList = c("rpart", "rf"), + # Handling large number of predictors + models_large <- caretList(x = large_data$X, y = large_data$y, methodList = "rpart") + testthat::expect_s3_class(models_large, "caretList") + testthat::expect_length(models_large, 1L) + + # Handling imbalanced data + imbalanced_y <- factor(c(rep("A", 95L), rep("B", 5L))) + testthat::expect_length(imbalanced_y, nrow(large_data$X)) + models_imbalanced <- caretList( + x = large_data$X, + y = imbalanced_y, + methodList = "rpart", trControl = caret::trainControl( method = "cv", - number = 2L, - summaryFunction = function(data, lev = NULL, model = NULL) c(default = mean(data$obs == data$pred)), - allowParallel = FALSE, - classProbs = TRUE + classProbs = TRUE, + summaryFunction = twoClassSummary, + sampling = "up", + index = caret::createFolds(imbalanced_y, k = 5L, returnTrain = TRUE) ) ) - testthat::expect_s3_class(models, "caretList") - testthat::expect_true(all(vapply(models, function(m) "default" %in% colnames(m$results), logical(1L)))) -}) + testthat::expect_s3_class(models_imbalanced, "caretList") + testthat::expect_length(models_imbalanced, 1L) -############################################### -testthat::context("S3 methods") -############################################### + # Test error cases + testthat::expect_error(caretList(Sepal.Width ~ ., iris), "Please either define a methodList or tuneList") + testthat::expect_warning( + caretList(Sepal.Width ~ ., iris, methodList = c("lm", "lm")), + "Duplicate entries in methodList. Using unique methodList values." + ) -testthat::test_that("plot.caretList", { + # Test continue_on_fail + bad <- list( + bad = caretModelSpec(method = "glm", tuneLength = 1L) + ) + testthat::expect_output( + testthat::expect_warning( + testthat::expect_error( + caretList(iris[, 1L:4L], iris[, 5L], tuneList = bad, continue_on_fail = TRUE), + regexp = "caret:train failed for all models. Please inspect your data." 
+ ), + regexp = "model fit failed for Fold1" + ), + regexp = "Something is wrong; all the Accuracy metric values are missing:" + ) +}) + +# Test plot and summary methods +testthat::test_that("plot.caretList and summary.caretList work", { for (model_list in list(models.reg, models.class)) { plt <- plot(model_list) testthat::expect_is(plt, "ggplot") testthat::expect_identical(nrow(plt$data), 4L) testthat::expect_named(model_list, plt$data$model_name) - } -}) -testthat::test_that("summary.caretList", { - for (model_list in list(models.reg, models.class)) { smry <- testthat::expect_silent(summary(model_list)) for (name in names(model_list)) { testthat::expect_output(print(smry), name) } } }) + +# Test combined regression, binary, multiclass models +testthat::test_that("caretList supports combined regression, binary, multiclass", { + set.seed(42L) + + reg_models <- caretList(Sepal.Length ~ Sepal.Width, iris, methodList = c("glm", "lm")) + bin_models <- caretList(factor(ifelse(Species == "setosa", "Yes", "No")) ~ Sepal.Width, iris, + methodList = c("lda", "rpart") + ) + multi_models <- caretList(Species ~ Sepal.Width, iris, methodList = "rpart") + + all_models <- c(reg_models, bin_models, multi_models) + testthat::expect_s3_class(all_models, "caretList") + + stacked_p <- predict(all_models) + new_p <- predict(all_models, newdata = iris[1L:10L, ]) + testthat::expect_is(stacked_p, "data.table") + testthat::expect_is(new_p, "data.table") + testthat::expect_identical(nrow(stacked_p), nrow(iris)) + testthat::expect_identical(nrow(new_p), 10L) +}) diff --git a/tests/testthat/test-caretPredict.R b/tests/testthat/test-caretPredict.R index 5cf0eb38..2463c580 100644 --- a/tests/testthat/test-caretPredict.R +++ b/tests/testthat/test-caretPredict.R @@ -1,10 +1,28 @@ -data(models.reg) -data(X.reg) -data(Y.reg) +# Setup +utils::data(models.reg) +utils::data(X.reg) +utils::data(Y.reg) +utils::data(models.class) +utils::data(X.class) +utils::data(Y.class) -data(models.class) -data(X.class) -data(Y.class) +set.seed(1234L) + +ens.reg <- caretEnsemble( + models.reg, + trControl = caret::trainControl(method = "cv", number = 2L, savePredictions = "final") +) + +ens.class <- caretEnsemble( + models.class, + metric = "ROC", + trControl = caret::trainControl( + number = 2L, + summaryFunction = caret::twoClassSummary, + classProbs = TRUE, + savePredictions = TRUE + ) +) mod <- caret::train( X.reg, @@ -13,15 +31,162 @@ mod <- caret::train( trControl = caret::trainControl(method = "none") ) -testthat::test_that("Extracting metrics works if there is no SD", { - # In the case of no resampling, metrics will not have an SD to extract +# Helper function for testing +expect_data_table_structure <- function(dt, expected_names) { + testthat::expect_s3_class(dt, "data.table") + testthat::expect_named(dt, expected_names) +} + +############################################################################# +testthat::context("caretPredict and extractMetric") +############################################################################# +testthat::test_that("caretPredict extracts best predictions correctly", { + stacked_preds_class <- caretPredict(models.class[[1L]], excluded_class_id = 0L) + stacked_preds_reg <- caretPredict(models.reg[[1L]]) + + expect_data_table_structure(stacked_preds_class, c("No", "Yes")) + expect_data_table_structure(stacked_preds_reg, "pred") +}) + +testthat::test_that("extractMetric works for different model types", { + # Test for model with no resampling (no SD) metric <- extractMetric(mod) - 
expect_s3_class(metric, "data.table")
-  expect_named(metric, c("model_name", "metric", "value", "sd"))
-  expect_is(metric$model_name, "character")
-  expect_is(metric$metric, "character")
-  expect_is(metric$value, "numeric")
-  expect_is(metric$sd, "numeric")
-  testthat::expect_true(is.na(metric$value))
-  testthat::expect_true(is.na(metric$sd))
+  expect_data_table_structure(metric, c("model_name", "metric", "value", "sd"))
+  testthat::expect_true(is.na(metric$value))
+  testthat::expect_true(is.na(metric$sd))
+
+  # Test for ensemble models
+  for (ens in list(ens.class, ens.reg)) {
+    metrics <- extractMetric(ens)
+    expect_data_table_structure(metrics, c("model_name", "metric", "value", "sd"))
+    testthat::expect_named(ens$models, metrics$model_name[-1L])
+  }
+})
+
+#############################################################################
+testthat::context("S3 methods and model operations")
+#############################################################################
+testthat::test_that("c.train on 2 train objects", {
+  testthat::expect_error(c.train(list()), "class of modelList1 must be 'caretList' or 'train'")
+
+  combined_models <- c(models.class[[1L]], models.class[[1L]])
+  testthat::expect_s3_class(combined_models, "caretList")
+  testthat::expect_length(combined_models, 2L)
+  testthat::expect_identical(anyDuplicated(names(combined_models)), 0L)
+  testthat::expect_length(unique(names(combined_models)), 2L)
+})
+
+testthat::test_that("c.train on a train and a caretList", {
+  bigList <- c(models.reg[[1L]], models.class)
+  testthat::expect_is(bigList, "caretList")
+  testthat::expect_identical(anyDuplicated(names(bigList)), 0L)
+  testthat::expect_length(unique(names(bigList)), 5L)
+})
+
+testthat::test_that("extractModelName handles different model types", {
+  testthat::expect_identical(extractModelName(models.class[[1L]]), "rf")
+  testthat::expect_identical(extractModelName(models.reg[[1L]]), "rf")
+
+  custom_model <- models.class[[1L]]
+  custom_model$method <- list(method = "custom_rf")
+  testthat::expect_identical(extractModelName(custom_model), "custom_rf")
+
+  mock_model <- list(method = list(method = "custom_method"))
+  class(mock_model) <- "train"
+  testthat::expect_identical(extractModelName(mock_model), "custom_method")
+
+  mock_model <- list(method = "custom", modelInfo = list(method = "custom_method"))
+  class(mock_model) <- "train"
+  testthat::expect_identical(extractModelName(mock_model), "custom_method")
+})
+
+#############################################################################
+testthat::context("isClassifierAndValidate")
+#############################################################################
+testthat::test_that("isClassifierAndValidate handles various model types", {
+  models_multi <- caretList(
+    iris[, 1L:2L], iris[, 5L],
+    tuneLength = 1L, verbose = FALSE,
+    methodList = c("rf", "gbm")
+  )
+  models_multi_bin_reg <- c(models_multi, models.class, models.reg)
+  testthat::expect_is(vapply(models_multi_bin_reg, isClassifierAndValidate, logical(1L)), "logical")
+
+  # Test when predictions are missing
+  model_list <- models.class
+  model_list[[1L]]$pred <- NULL
+  testthat::expect_is(vapply(model_list, isClassifierAndValidate, logical(1L)), "logical")
+  testthat::expect_equivalent(unique(vapply(model_list, isClassifierAndValidate, logical(1L))), TRUE)
+
+  # Test error cases
+  model_list <- models.class
+  model_list[[1L]]$modelInfo$prob <- FALSE
+  testthat::expect_error(
+    lapply(model_list, isClassifierAndValidate),
+    "No probability function found. 
Re-fit with a method that supports prob." + ) + + model_list <- models.class + model_list[[1L]]$control$classProbs <- FALSE + testthat::expect_error( + lapply(model_list, isClassifierAndValidate, validate_for_stacking = TRUE), + "classProbs = FALSE. Re-fit with classProbs = TRUE in trainControl." + ) + + # Test for non-caretList object + testthat::expect_error( + isClassifierAndValidate(list(model = lm(Y.reg ~ ., data = as.data.frame(X.reg)))), + "is(object, \"train\") is not TRUE", + fixed = TRUE + ) + + # Test for models without savePredictions + model <- models.class[[1L]] + model$control$savePredictions <- NULL + testthat::expect_error( + isClassifierAndValidate(model), + "Must have savePredictions = 'all', 'final', or TRUE in trainControl to do stacked predictions." + ) + model$control$savePredictions <- "BAD_VALUE" + testthat::expect_error( + isClassifierAndValidate(model), + "Must have savePredictions = 'all', 'final', or TRUE in trainControl to do stacked predictions." + ) +}) + +############################################################################# +testthat::context("validateExcludedClass") +############################################################################# +testthat::test_that("validateExcludedClass handles various inputs", { + testthat::expect_error(validateExcludedClass("invalid"), "classification excluded level must be numeric: invalid") + testthat::expect_warning( + testthat::expect_error(validateExcludedClass(Inf), "classification excluded level must be finite: Inf"), + "classification excluded level is not an integer: Inf" + ) + testthat::expect_warning( + testthat::expect_error(validateExcludedClass(-1.0), "classification excluded level must be >= 0: -1"), + "classification excluded level is not an integer:" + ) + testthat::expect_warning(validateExcludedClass(1.1), "classification excluded level is not an integer: 1.1") + testthat::expect_identical(validateExcludedClass(3L), 3L) + + # Edge cases + testthat::expect_identical(validateExcludedClass(0L), 0L) + testthat::expect_identical(validateExcludedClass(1L), 1L) + testthat::expect_identical(validateExcludedClass(4L), 4L) + w <- "classification excluded level is not an integer:" + testthat::expect_warning(testthat::expect_identical(validateExcludedClass(0.0), 0L), w) + testthat::expect_warning(testthat::expect_identical(validateExcludedClass(1.0), 1L), w) + testthat::expect_warning(testthat::expect_identical(validateExcludedClass(4.0), 4L), w) + testthat::expect_error(validateExcludedClass(-1L), "classification excluded level must be >= 0: -1") + + # Additional tests + testthat::expect_warning(validateExcludedClass(NULL), "No excluded_class_id set. 
Setting to 1L.") + testthat::expect_error( + validateExcludedClass(c(1L, 2L)), + "classification excluded level must have a length of 1: length=2" + ) + testthat::expect_warning( + testthat::expect_error(validateExcludedClass(-0.000001), "classification excluded level must be >= 0: -1e-06"), + "classification excluded level is not an integer" + ) }) diff --git a/tests/testthat/test-caretStack.R b/tests/testthat/test-caretStack.R index ad127f5a..8f8eac2f 100644 --- a/tests/testthat/test-caretStack.R +++ b/tests/testthat/test-caretStack.R @@ -1,85 +1,85 @@ -data(models.reg) -data(X.reg) -data(Y.reg) -data(models.class) -data(X.class) -data(Y.class) - -ens.class <- caretStack( - models.class, - method = "glm", - preProcess = "pca", - trControl = caret::trainControl(method = "cv", number = 2L, savePredictions = "final", classProbs = TRUE) +# Load data and create models +utils::data(models.reg) +utils::data(X.reg) +utils::data(Y.reg) +utils::data(models.class) +utils::data(X.class) +utils::data(Y.class) +utils::data(iris) + +models_multiclass <- caretList( + x = iris[, -5L], + y = iris[, 5L], + methodList = c("rpart", "glmnet") ) -ens.reg <- caretStack( - models.reg, - method = "lm", - preProcess = "pca", - trControl = caret::trainControl(method = "cv", number = 2L, savePredictions = "final", classProbs = FALSE) +control_class <- caret::trainControl( + method = "cv", + number = 2L, + savePredictions = "final", + classProbs = TRUE, + summaryFunction = twoClassSummary, + index = caret::createFolds(Y.class, 2L) ) +ens.class <- caretStack(models.class, preProcess = "pca", method = "glm", metric = "ROC", trControl = control_class) + +control_reg <- caret::trainControl( + method = "cv", + number = 2L, + savePredictions = "final", + classProbs = FALSE, + index = caret::createFolds(Y.reg, 2L) +) +ens.reg <- caretStack(models.reg, preProcess = "pca", method = "lm", trControl = control_reg) + +# Helper functions +expect_all_finite <- function(x) { + testthat::expect_true(all(vapply(x, function(col) all(is.finite(col)), logical(1L)))) +} + +###################################################################### +testthat::context("caretStack") +###################################################################### + +testthat::test_that("caretStack creates valid ensemble models", { + testthat::expect_s3_class(ens.class, "caretStack") + testthat::expect_s3_class(ens.reg, "caretStack") -testthat::context("Does stacking and prediction work?") + testthat::expect_s3_class(summary(ens.class), "summary.caretStack") + testthat::expect_s3_class(plot(ens.class), "ggplot") + testthat::expect_output(print(ens.class), "The following models were ensembled: rf, glm, rpart, treebag") +}) -testthat::test_that("We can make predictions from stacks, including cases where the stacked model has preprocessing", { +testthat::test_that("predict works for classification and regression ensembles", { ens_list <- list(class = ens.class, reg = ens.reg) X_list <- list(class = X.class, reg = X.reg) for (model_name in names(ens_list)) { - # S3 methods ens <- ens_list[[model_name]] - testthat::expect_s3_class(ens, "caretStack") - testthat::expect_s3_class(summary(ens), "summary.caretStack") - testthat::expect_s3_class(plot(ens), "ggplot") - invisible(capture.output(print(ens))) + X <- X_list[[model_name]] - # Predictions param_grid <- expand.grid( se = c(TRUE, FALSE), newdata = c(TRUE, FALSE) ) + for (i in seq_len(nrow(param_grid))) { params <- param_grid[i, ] - X <- X_list[[model_name]] - expected_rows <- nrow(X) - newdata <- NULL - if 
diff --git a/tests/testthat/test-caretStack.R b/tests/testthat/test-caretStack.R
index ad127f5a..8f8eac2f 100644
--- a/tests/testthat/test-caretStack.R
+++ b/tests/testthat/test-caretStack.R
@@ -1,85 +1,85 @@
-data(models.reg)
-data(X.reg)
-data(Y.reg)
-data(models.class)
-data(X.class)
-data(Y.class)
-
-ens.class <- caretStack(
-  models.class,
-  method = "glm",
-  preProcess = "pca",
-  trControl = caret::trainControl(method = "cv", number = 2L, savePredictions = "final", classProbs = TRUE)
+# Load data and create models
+utils::data(models.reg)
+utils::data(X.reg)
+utils::data(Y.reg)
+utils::data(models.class)
+utils::data(X.class)
+utils::data(Y.class)
+utils::data(iris)
+
+models_multiclass <- caretList(
+  x = iris[, -5L],
+  y = iris[, 5L],
+  methodList = c("rpart", "glmnet")
 )
-ens.reg <- caretStack(
-  models.reg,
-  method = "lm",
-  preProcess = "pca",
-  trControl = caret::trainControl(method = "cv", number = 2L, savePredictions = "final", classProbs = FALSE)
+control_class <- caret::trainControl(
+  method = "cv",
+  number = 2L,
+  savePredictions = "final",
+  classProbs = TRUE,
+  summaryFunction = twoClassSummary,
+  index = caret::createFolds(Y.class, 2L)
 )
+ens.class <- caretStack(models.class, preProcess = "pca", method = "glm", metric = "ROC", trControl = control_class)
+
+control_reg <- caret::trainControl(
+  method = "cv",
+  number = 2L,
+  savePredictions = "final",
+  classProbs = FALSE,
+  index = caret::createFolds(Y.reg, 2L)
+)
+ens.reg <- caretStack(models.reg, preProcess = "pca", method = "lm", trControl = control_reg)
+
+# Helper functions
+expect_all_finite <- function(x) {
+  testthat::expect_true(all(vapply(x, function(col) all(is.finite(col)), logical(1L))))
+}
+
+######################################################################
+testthat::context("caretStack")
+######################################################################
+
+testthat::test_that("caretStack creates valid ensemble models", {
+  testthat::expect_s3_class(ens.class, "caretStack")
+  testthat::expect_s3_class(ens.reg, "caretStack")
-
-testthat::context("Does stacking and prediction work?")
+  testthat::expect_s3_class(summary(ens.class), "summary.caretStack")
+  testthat::expect_s3_class(plot(ens.class), "ggplot")
+  testthat::expect_output(print(ens.class), "The following models were ensembled: rf, glm, rpart, treebag")
+})
 
-testthat::test_that("We can make predictions from stacks, including cases where the stacked model has preprocessing", {
+testthat::test_that("predict works for classification and regression ensembles", {
   ens_list <- list(class = ens.class, reg = ens.reg)
   X_list <- list(class = X.class, reg = X.reg)
 
   for (model_name in names(ens_list)) {
-    # S3 methods
     ens <- ens_list[[model_name]]
-    testthat::expect_s3_class(ens, "caretStack")
-    testthat::expect_s3_class(summary(ens), "summary.caretStack")
-    testthat::expect_s3_class(plot(ens), "ggplot")
-    invisible(capture.output(print(ens)))
+    X <- X_list[[model_name]]
 
-    # Predictions
     param_grid <- expand.grid(
       se = c(TRUE, FALSE),
      newdata = c(TRUE, FALSE)
    )
+
    for (i in seq_len(nrow(param_grid))) {
      params <- param_grid[i, ]
-      X <- X_list[[model_name]]
-      expected_rows <- nrow(X)
-      newdata <- NULL
-      if (params$newdata) {
-        newdata <- X[1L:50L, ]
-        expected_rows <- 50L
-      }
+      newdata <- if (params$newdata) X[1L:50L, ] else NULL
+      expected_rows <- if (params$newdata) 50L else nrow(X)
+
       pred <- predict(ens, newdata = newdata, se = params$se)
       testthat::expect_s3_class(pred, "data.table")
       testthat::expect_identical(nrow(pred), expected_rows)
-      testthat::expect_true(all(vapply(pred, is.numeric, logical(1L))))
+      expect_all_finite(pred)
     }
   }
 })
 
-testthat::test_that("For classificaiton, we can predict the class labels", {
+testthat::test_that("Classification predictions return correct output", {
   pred.class <- predict(ens.class, X.class, return_class_only = TRUE)
-  expect_is(pred.class, "factor")
-  expect_length(pred.class, nrow(X.class))
-})
-
-testthat::test_that("caretStack plots", {
-  ens.gbm <- caretStack(
-    models.reg,
-    method = "gbm", tuneLength = 2L, verbose = FALSE
-  )
-  testthat::expect_s3_class(ens.gbm, "caretStack")
-
-  plt <- plot(ens.gbm)
-  expect_s3_class(plt, "ggplot")
-
-  dotplot <- lattice::dotplot(ens.gbm, metric = "RMSE")
-  expect_s3_class(dotplot, "trellis")
-})
-
-testthat::test_that("Prediction names are correct with SE", {
-  testthat::expect_named(
-    predict(ens.reg, X.reg, se = TRUE, excluded_class_id = 0L),
-    c("pred", "lwr", "upr")
-  )
+  testthat::expect_s3_class(pred.class, "factor")
+  testthat::expect_length(pred.class, nrow(X.class))
 
   testthat::expect_named(
     predict(ens.class, X.class, se = TRUE, excluded_class_id = 1L),
@@ -92,9 +92,14 @@ testthat::test_that("Prediction names are correct with SE", {
   )
 })
 
-testthat::test_that("Prediction equivalence", {
-  # Note that SE is stochastic, since it uses permutation importance
+testthat::test_that("Regression predictions return correct output", {
+  testthat::expect_named(
+    predict(ens.reg, X.reg, se = TRUE, excluded_class_id = 0L),
+    c("pred", "lwr", "upr")
+  )
+})
 
+testthat::test_that("Predictions are reproducible", {
   set.seed(42L)
   p1 <- predict(ens.class, X.class, se = TRUE, level = 0.8)
   set.seed(42L)
@@ -103,153 +108,51 @@
   p2 <- predict(ens.class, X.class, se = TRUE, level = 0.8)
   testthat::expect_equivalent(p1, p2)
 })
 
-testthat::test_that("Test na.action pass through", {
-  set.seed(1337L)
-
-  # drop the first model because it does not support na.pass
-  ens.reg <- caretStack(models.reg[2L:3L], method = "lm")
+testthat::test_that("caretStack handles missing data correctly", {
+  ens.reg.subset <- caretEnsemble::caretStack(models.reg[2L:3L], method = "lm")
 
   X_reg_na <- X.reg
-  # introduce random NA values into a column
   X_reg_na[sample.int(nrow(X_reg_na), 20L), sample.int(ncol(X_reg_na) - 1L, 1L)] <- NA
 
-  pred.reg <- predict(ens.reg, newdata = X_reg_na, na.action = na.pass)
+  pred.reg <- predict(ens.reg.subset, newdata = X_reg_na)
   testthat::expect_identical(nrow(pred.reg), nrow(X_reg_na))
 
-  pred.reg <- predict(ens.reg, newdata = X_reg_na)
-  testthat::expect_false(nrow(pred.reg) != nrow(X_reg_na))
+  pred.reg <- predict(ens.reg.subset, newdata = X_reg_na)
+  testthat::expect_identical(nrow(pred.reg), nrow(X_reg_na))
 })
 
-testthat::test_that("predict.caretStack works correctly if the multiclass excluded level is too high", {
-  data(iris)
-
-  # Create a caretList
-  model_list <- caretList(
-    Species ~ .,
-    data = iris,
-    methodList = c("rpart", "rf")
-  )
-
-  # Make sure predictions still work if the exlcuded level is too high
-  meta_model <- caretStack(
-    model_list,
-    method = "rpart",
-    excluded_class_id = 4L
-  )
+testthat::test_that("caretStack handles multiclass problems", {
+  meta_model <- caretEnsemble::caretStack(models_multiclass, method = "rpart", excluded_class_id = 4L)
 
   pred <- predict(meta_model, newdata = iris)
   testthat::expect_identical(nrow(pred), 150L)
   testthat::expect_identical(ncol(pred), 3L)
-  all_finite <- function(x) all(is.finite(x))
-  testthat::expect_true(all(vapply(pred, all_finite, logical(1L))))
+  expect_all_finite(pred)
 })
 
-testthat::context("caretStack edge cases")
-
-testthat::test_that("caretStack handles different stacking algorithms", {
-  for (x in list(list(models.reg, X.reg), list(models.class, X.class))) {
-    model_list <- x[[1L]]
-    test_data <- x[[2L]]
+testthat::test_that("caretStack works with different stacking algorithms", {
+  stack_methods <- c("glm", "rf", "gbm", "glmnet")
 
-    stack_methods <- c("glm", "rf", "gbm", "glmnet")
-
-    for (method in stack_methods) {
-      if (method == "gbm") {
-        stack <- caretStack(
-          model_list,
-          method = method,
-          verbose = FALSE
-        )
+  for (method in stack_methods) {
+    for (model_list in list(models.reg, models.class)) {
+      stack <- if (method == "gbm") {
+        caretEnsemble::caretStack(model_list, method = method, verbose = FALSE)
       } else {
-        stack <- caretStack(
-          model_list,
-          method = method
-        )
+        caretEnsemble::caretStack(model_list, method = method)
       }
 
       testthat::expect_s3_class(stack, "caretStack")
       testthat::expect_identical(stack$ens_model$method, method)
 
-      predictions <- predict(stack, newdata = test_data)
-      testthat::expect_identical(nrow(predictions), nrow(test_data))
+      predictions <- predict(stack, newdata = if (identical(model_list, models.reg)) X.reg else X.class)
+      testthat::expect_identical(nrow(predictions), nrow(if (identical(model_list, models.reg)) X.reg else X.class))
     }
   }
 })
 
-testthat::test_that("caretStack handles missing data in new data", {
-  models.class.subset <- models.class[c("rpart", "treebag")]
-
-  stack <- caretStack(
-    models.class.subset,
-    method = "rpart"
-  )
-
-  test_data_with_na <- X.class
-  test_data_with_na[1L:5L, 1L] <- NA
-
-  pred <- predict(stack, newdata = test_data_with_na)
-  testthat::expect_identical(nrow(pred), nrow(test_data_with_na))
-})
-
-testthat::test_that("caretStack handles different metrics", {
-  metrics <- c("ROC", "Sens", "Spec")
-  for (metric in metrics) {
-    stack <- caretStack(
-      models.class,
-      method = "glm",
-      metric = metric,
-      trControl = caret::trainControl(
-        method = "cv",
-        number = 3L,
-        classProbs = TRUE,
-        summaryFunction = caret::twoClassSummary
-      )
-    )
-    testthat::expect_s3_class(stack, "caretStack")
-    testthat::expect_identical(stack$ens_model$metric, metric)
-  }
-})
-
-testthat::test_that("caretStack handles upsampling data", {
-  data(iris)
-  train_data <- iris
-
-  imbalanced_data <- rbind(
-    train_data[train_data$Species == "setosa", ],
-    train_data[train_data$Species == "versicolor", ][1L:10L, ],
-    train_data[train_data$Species == "virginica", ][1L:5L, ]
-  )
-
-  model_list <- caretList(
-    x = imbalanced_data[, 1L:4L],
-    y = imbalanced_data$Species,
-    methodList = "rpart",
-    trControl = caret::trainControl(
-      method = "cv",
-      number = 3L,
-      classProbs = TRUE,
-      sampling = "up",
-      savePredictions = "final"
-    )
-  )
-
-  stack <- caretStack(
-    model_list,
-    method = "rpart"
-  )
-
-  testthat::expect_s3_class(stack, "caretStack")
-  pred <- predict(stack, newdata = imbalanced_data)
-  testthat::expect_identical(nrow(pred), nrow(imbalanced_data))
-})
-
 testthat::test_that("caretStack handles custom preprocessing", {
   preprocess <- c("center", "scale", "pca")
   for (model_list in list(models.class, models.reg)) {
-    stack <- caretStack(
-      model_list,
-      method = "glm",
-      preProcess = preprocess
-    )
+    stack <- caretEnsemble::caretStack(model_list, method = "glm", preProcess = preprocess)
     testthat::expect_s3_class(stack, "caretStack")
     testthat::expect_named(stack$ens_model$preProcess$method, c(preprocess, "ignore"))
   }
@@ -261,7 +164,7 @@ testthat::test_that("caretStack handles custom performance function", {
   }
 
   for (model_list in list(models.class, models.reg)) {
-    stack <- caretStack(
+    stack <- caretEnsemble::caretStack(
       model_list,
       method = "glm",
       metric = "default",
@@ -273,107 +176,167 @@
   }
 })
 
-testthat::test_that("predict.caretStack works if excluded_class_id is not set", {
-  ens <- caretStack(models.class)
-  ens[["excluded_class_id"]] <- NULL
-  pred <- testthat::expect_warning(predict(ens, X.class), "No excluded_class_id set. Setting to 1L.")
-
-  # Note that we don't exclude the class from the ensemble predictions, but merely from the preprocessing
-  testthat::expect_s3_class(pred, "data.table") # caret returns data.frame
-  testthat::expect_identical(nrow(pred), nrow(X.class))
-  testthat::expect_identical(ncol(pred), 2L)
-  testthat::expect_named(pred, c("No", "Yes"))
-})
-
-testthat::context("Edge cases")
-
-testthat::test_that("caretStack coerces lists to caretLists", {
-  model_list <- models.reg
-  class(model_list) <- "list"
-  names(model_list) <- NULL
-  ens <- testthat::expect_warning(
-    caretStack(model_list),
-    "Attempting to coerce all.models to a caretList."
-  )
-  testthat::expect_s3_class(ens, "caretStack")
-  testthat::expect_s3_class(ens$models, "caretList")
-  testthat::expect_named(ens$models, names(models.reg))
-})
-
-testthat::test_that("caretStack fails if new_X is NULL and newY is not and vice versa", {
-  err <- "Both new_X and new_y must be NULL, or neither."
-  testthat::expect_error(caretStack(models.reg, new_X = NULL, new_y = Y.reg), err)
-  testthat::expect_error(caretStack(models.reg, new_X = X.reg, new_y = NULL), err)
-})
-
-testthat::test_that("caretStack works if both new_X and new_Y are supplied", {
+testthat::test_that("caretStack handles new data correctly", {
   set.seed(42L)
   N <- 50L
   idx <- sample.int(nrow(X.reg), N)
+
   stack_class <- caretStack(
     models.class,
+    metric = "ROC",
+    method = "rpart",
     new_X = X.class[idx, ],
     new_y = Y.class[idx],
-    method = "rpart",
-    # Need probs for stacked preds
-    trControl = caret::trainControl(method = "cv", number = 2L, savePredictions = "final", classProbs = TRUE)
+    trControl = control_class
   )
+
   stack_reg <- caretStack(
     models.reg,
+    method = "glm",
     new_X = X.reg[idx, ],
     new_y = Y.reg[idx],
-    method = "glm",
-    # Need probs for stacked preds
-    trControl = caret::trainControl(method = "cv", number = 2L, savePredictions = "final", classProbs = FALSE)
+    trControl = control_reg
   )
 
   testthat::expect_s3_class(stack_class, "caretStack")
   testthat::expect_s3_class(stack_reg, "caretStack")
 
-  pred_class_stack <- predict(stack_class)
-  stack_reg_stack <- predict(stack_reg)
-
-  testthat::expect_s3_class(pred_class_stack, "data.table")
-  testthat::expect_s3_class(stack_reg_stack, "data.table")
-
-  testthat::expect_identical(nrow(pred_class_stack), N)
-  testthat::expect_identical(nrow(stack_reg_stack), N)
-
-  testthat::expect_identical(ncol(pred_class_stack), 2L)
-  testthat::expect_identical(ncol(stack_reg_stack), 1L)
-
-  pred_class <- predict(stack_class, new_X = X.class)
-  pred_reg <- predict(stack_reg, new_X = X.reg)
+  pred_class <- predict(stack_class, X.class)
+  pred_reg <- predict(stack_reg, X.reg)
 
   testthat::expect_s3_class(pred_class, "data.table")
   testthat::expect_s3_class(pred_reg, "data.table")
 
-  testthat::expect_identical(nrow(pred_class), N)
-  testthat::expect_identical(nrow(pred_reg), N)
+  testthat::expect_identical(nrow(pred_class), nrow(X.class))
+  testthat::expect_identical(nrow(pred_reg), nrow(X.reg))
 
   testthat::expect_identical(ncol(pred_class), 2L)
   testthat::expect_identical(ncol(pred_reg), 1L)
 })
 
-testthat::test_that("varImp works in for class and reg in sample", {
+testthat::test_that("caretStack coerces lists to caretLists", {
+  models <- list(
+    models.class[[1L]],
+    models.class[[2L]]
+  )
+  testthat::expect_warning(
+    caretStack(models, method = "glm", tuneLength = 1L),
+    "Attempting to coerce all.models to a caretList."
+  )
+})
+
+testthat::test_that("caretStack errors if new_X is provided but not new_y", {
+  testthat::expect_error(
+    caretStack(models.class, new_X = X.class),
+    "Both new_X and new_y must be NULL, or neither."
+  )
+})
+
+testthat::test_that("caretStack errors if new_y is provided but not new_X", {
+  testthat::expect_error(
+    caretStack(models.class, new_y = Y.class),
+    "Both new_X and new_y must be NULL, or neither."
+  )
+})
+
+######################################################################
+testthat::context("S3 methods for caretStack")
+######################################################################
+
+testthat::test_that("print", {
   for (ens in list(ens.class, ens.reg)) {
-    imp <- varImp(ens)
-    expect_is(imp, "numeric")
-    expect_named(imp, names(ens$models))
-    expect_equal(sum(imp), 1.0, tolerance = 1e-6)
+    testthat::expect_output(print(ens), "The following models were ensembled: rf, glm, rpart, treebag")
+    testthat::expect_output(print(ens), "150 samples")
+    testthat::expect_output(print(ens), "4 predictor")
   }
 })
 
-testthat::test_that("varImp works in for class on new data", {
-  imp <- varImp(ens.class, X.class)
-  expect_is(imp, "numeric")
-  expect_named(imp, names(ens.class$models))
-  expect_equal(sum(imp), 1.0, tolerance = 1e-6)
+testthat::test_that("summary", {
+  for (ens in list(ens.class, ens.reg)) {
+    s <- summary(ens)
+    testthat::expect_s3_class(s, "summary.caretStack")
+    testthat::expect_output(print(s), "The following models were ensembled: rf, glm, rpart, treebag")
+    testthat::expect_output(print(s), "Model Importance:")
+    testthat::expect_output(print(s), "Model accuracy:")
+  }
+})
+
+testthat::test_that("plot", {
+  for (ens in list(ens.class, ens.reg)) {
+    p <- plot(ens)
+    testthat::expect_s3_class(p, "ggplot")
+  }
+})
+
+testthat::test_that("dotplot", {
+  for (ens in list(ens.class, ens.reg)) {
+    p <- lattice::dotplot(ens)
+    testthat::expect_s3_class(p, "trellis")
+  }
+})
+
+testthat::test_that("autoplot", {
+  for (ens in list(ens.class, ens.reg)) {
+    p <- ggplot2::autoplot(ens)
+    testthat::expect_s3_class(p, "ggplot")
+  }
+})
+
+######################################################################
+testthat::context("varImp")
+######################################################################
+
+testthat::test_that("varImp works for classification and regression", {
+  for (ens in list(ens.class, ens.reg)) {
+    imp <- caret::varImp(ens)
+    testthat::expect_type(imp, "double")
+    testthat::expect_named(imp, names(ens$models))
+    testthat::expect_equal(sum(imp), 1.0, tolerance = 1e-6)
+
+    imp_new_data <- caret::varImp(ens, if (identical(ens, ens.class)) X.class else X.reg)
+    testthat::expect_type(imp_new_data, "double")
+    testthat::expect_named(imp_new_data, names(ens$models))
+    testthat::expect_equal(sum(imp_new_data), 1.0, tolerance = 1e-6)
+  }
 })
 
-testthat::test_that("varImp works in for reg on new data", {
-  imp <- varImp(ens.reg, X.class)
-  expect_is(imp, "numeric")
-  expect_named(imp, names(ens.reg$models))
-  expect_equal(sum(imp), 1.0, tolerance = 1e-6)
+######################################################################
+testthat::context("wtd.sd")
+######################################################################
+
+testthat::test_that("wtd.sd calculates weighted standard deviation correctly", {
+  x <- c(1L, 2L, 3L, 4L, 5L)
+  w <- c(1L, 1L, 1L, 1L, 1L)
+  testthat::expect_equal(caretEnsemble::wtd.sd(x, w), stats::sd(x), tolerance = 0.001)
+
+  w_uneven <- c(2L, 1L, 1L, 1L, 1L)
+  testthat::expect_false(isTRUE(all.equal(caretEnsemble::wtd.sd(x, w_uneven), stats::sd(x))))
+
+  x_na <- c(1L, 2L, NA, 4L, 5L)
+  testthat::expect_true(is.na(caretEnsemble::wtd.sd(x_na, w)))
+  testthat::expect_false(is.na(caretEnsemble::wtd.sd(x_na, w, na.rm = TRUE)))
+
+  testthat::expect_error(caretEnsemble::wtd.sd(x, w[-1L]), "'x' and 'w' must have the same length")
+
+  x3 <- c(10L, 10L, 10L, 20L)
+  w1 <- c(0.1, 0.1, 0.1, 0.7)
+  testthat::expect_equal(caretEnsemble::wtd.sd(x3, w = w1), 5.291503, tolerance = 0.001)
+  testthat::expect_equal(caretEnsemble::wtd.sd(x3, w = w1 * 100L), caretEnsemble::wtd.sd(x3, w = w1), tolerance = 0.001)
+})
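The expectations above pin down the weighted-SD convention: weights are normalized (so w and 100 * w agree) and the usual n / (n - 1) small-sample correction is applied, which reproduces 5.291503 for the example vector. A minimal sketch under those assumptions — illustrative, not the package's actual implementation:

wtd_sd_sketch <- function(x, w, na.rm = FALSE) {
  if (length(x) != length(w)) stop("'x' and 'w' must have the same length")
  if (na.rm) {
    ok <- !(is.na(x) | is.na(w))
    x <- x[ok]
    w <- w[ok]
  }
  w <- w / sum(w)   # normalization gives the scale invariance tested above
  mu <- sum(w * x)  # weighted mean
  n <- length(x)
  sqrt(sum(w * (x - mu)^2) * n / (n - 1L))
}
wtd_sd_sketch(c(10, 10, 10, 20), c(0.1, 0.1, 0.1, 0.7))  # 5.291503, matching the expectation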
+
+######################################################################
+testthat::context("set_excluded_class_id")
+######################################################################
+
+testthat::test_that("set_excluded_class_id warning if unset", {
+  old_ensemble <- ens.class
+  old_ensemble$excluded_class_id <- NULL
+
+  is_class <- isClassifier(old_ensemble)
+  new_ensemble <- testthat::expect_warning(
+    set_excluded_class_id(old_ensemble, is_class),
+    "No excluded_class_id set. Setting to 1L."
+  )
+
+  testthat::expect_identical(new_ensemble$excluded_class_id, 1L)
+})
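The defaulting behavior checked just above can be summarized in a small sketch (a hypothetical helper; the package's set_excluded_class_id may differ in its details):

set_excluded_default <- function(ensemble, is_class) {
  # assumed behavior, inferred from the test: warn and default to 1L when unset
  if (is_class && is.null(ensemble$excluded_class_id)) {
    warning("No excluded_class_id set. Setting to 1L.")
    ensemble$excluded_class_id <- 1L
  }
  ensemble
}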
diff --git a/tests/testthat/test-classSelection.R b/tests/testthat/test-classSelection.R
index 62c4e7fa..af631e2a 100644
--- a/tests/testthat/test-classSelection.R
+++ b/tests/testthat/test-classSelection.R
@@ -1,28 +1,21 @@
-testthat::context("Does binary class selection work?")
-
 # Load and prepare data for subsequent tests
 seed <- 2239L
 set.seed(seed)
-data(models.class)
-data(X.class)
-data(Y.class)
+utils::data(iris)
 
 # Create 80/20 train/test split
-index <- caret::createDataPartition(Y.class, p = 0.8)[[1L]]
-X.train <- X.class[index, ]
-X.test <- X.class[-index, ]
-Y.train <- Y.class[index]
-Y.test <- Y.class[-index]
-
-#############################################################################
-testthat::context("Do classifier predictions use the correct target classes?")
-#############################################################################
+target_col <- which(names(iris) == "Species")
+index <- caret::createDataPartition(iris[, target_col], p = 0.8)[[1L]]
+X.train <- iris[index, -target_col]
+X.test <- iris[-index, -target_col]
+Y.train <- iris[index, target_col]
+Y.test <- iris[-index, target_col]
 
 runBinaryLevelValidation <- function(Y.train, Y.test, pos.level = 1L) {
   # Extract levels of response input data
   Y.levels <- levels(Y.train)
   testthat::expect_identical(Y.levels, levels(Y.test))
-  testthat::expect_length(Y.levels, 2L)
+  testthat::expect_length(Y.levels, 3L)
 
   # Train a caret ensemble
   model.list <- caretList(
@@ -73,50 +66,31 @@ runBinaryLevelValidation <- function(Y.train, Y.test, pos.level = 1L) {
   # check exists to avoid previous errors where classifer ensemble predictions were
   # being made using the incorrect level of the response, causing the opposite
   # class labels to be predicted with new data.
-  testthat::expect_gt(cmat.pred$overall["Accuracy"], 0.79)
+  testthat::expect_gt(cmat.pred$overall["Accuracy"], 0.60)
 
   # Similar to the above, ensure that probability predictions are working correctly
   # by checking to see that accuracy is also high for class predictions created
   # from probabilities
-  testthat::expect_gt(cmat.cutoff$overall["Accuracy"], 0.79)
+  testthat::expect_gt(cmat.cutoff$overall["Accuracy"], 0.60)
 }
 
-testthat::test_that("Ensembled classifiers do not rearrange outcome factor levels", {
-  # First run the level selection test using the default levels
-  # of the response (i.e. c('No', 'Yes'))
-  set.seed(seed)
-  runBinaryLevelValidation(Y.train, Y.test, pos.level = 1L)
+#############################################################################
+testthat::context("Do classifier predictions use the correct target classes?")
+#############################################################################
 
-  # Now reverse the assigment of the response labels as well as
-  # the levels of the response factor. Reversing the assignment
-  # is necessary to make sure the expected accuracy numbers are
-  # the same (i.e. Making a "No" into a "Yes" in the response means
-  # predictions of the first class will still be as accurate).
-  # Reversing the level order then ensures that the outcome is not
-  # releveled at some point by caretEnsemble.
-  Y.levels <- levels(Y.train)
-  refactor <- function(d) {
-    factor(
-      ifelse(d == Y.levels[1L], Y.levels[2L], Y.levels[1L]),
-      levels = rev(Y.levels)
-    )
+testthat::test_that("validateExcludedClass for multiclass", {
+  for (excluded_class in c(1L, 2L, 3L)) {
+    testthat::expect_silent(validateExcludedClass(excluded_class))
   }
-
-  set.seed(seed)
-  runBinaryLevelValidation(refactor(Y.train), refactor(Y.test))
+  testthat::expect_error(validateExcludedClass("x"), "classification excluded level must be numeric")
 })
 
 testthat::test_that("Target class selection configuration works", {
-  # No error
-  excluded_class <- validateExcludedClass(1L)
-  excluded_class <- validateExcludedClass(2L)
-
-  # Should error
-  testthat::expect_error(validateExcludedClass("x"), "classification excluded level must be numeric")
-
-  # Check that we can exclude the first class
   Y.levels <- levels(Y.train)
   refactor <- function(d) factor(as.character(d), levels = rev(Y.levels))
   set.seed(seed)
-  runBinaryLevelValidation(refactor(Y.train), refactor(Y.test), pos.level = 1L)
+  for (pos in c(1L, 2L, 3L)) {
+    runBinaryLevelValidation(Y.train, Y.test, pos.level = pos)
+    runBinaryLevelValidation(refactor(Y.train), refactor(Y.test), pos.level = pos)
+  }
 })
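The refactor() helper in the tests above relies on an invariant worth spelling out: reversing a factor's level order changes the level bookkeeping but not which observation carries which label. A quick self-contained check:

y <- factor(c("setosa", "virginica", "setosa"), levels = c("setosa", "versicolor", "virginica"))
y_rev <- factor(as.character(y), levels = rev(levels(y)))
stopifnot(identical(as.character(y), as.character(y_rev)))  # labels unchanged
stopifnot(identical(levels(y_rev), rev(levels(y))))         # only the level order flips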
diff --git a/tests/testthat/test-ensembleMethods.R b/tests/testthat/test-ensembleMethods.R
deleted file mode 100644
index cac9430d..00000000
--- a/tests/testthat/test-ensembleMethods.R
+++ /dev/null
@@ -1,115 +0,0 @@
-# Are tests failing here?
-# UPDATE THE FIXTURES!
-# make update-test-fixtures
-
-testthat::context("Does variable importance work?")
-
-data(models.reg)
-data(X.reg)
-data(Y.reg)
-
-data(models.class)
-data(X.class)
-data(Y.class)
-
-ens.class <- caretEnsemble(
-  models.class,
-  metric = "ROC",
-  trControl = caret::trainControl(
-    number = 2L,
-    summaryFunction = caret::twoClassSummary,
-    classProbs = TRUE,
-    savePredictions = TRUE
-  )
-)
-ens.reg <- caretEnsemble(models.reg, trControl = caret::trainControl(number = 2L, savePredictions = TRUE))
-
-testthat::test_that("caret::varImp.caretEnsemble", {
-  set.seed(2239L)
-
-  for (m in list(ens.class, ens.reg)) {
-    for (s in c(TRUE, FALSE)) {
-      i <- caret::varImp(m, normalize = s)
-      testthat::expect_is(i, "numeric")
-      if (isClassifier(m)) {
-        len <- length(m$models) * 2L
-        n <- c(outer(c("rf", "glm", "rpart", "treebag"), c("No", "Yes"), paste, sep = "_"))
-        n <- matrix(n, ncol = 2L)
-        n <- c(t(n))
-      } else {
-        len <- length(m$models)
-        n <- names(m$models)
-      }
-      testthat::expect_length(i, len)
-      testthat::expect_named(i, n)
-      if (s) {
-        testthat::expect_true(all(i >= 0.0))
-        testthat::expect_true(all(i <= 1.0))
-        testthat::expect_equal(sum(i), 1.0, tolerance = 1e-6)
-      }
-    }
-  }
-})
-
-testthat::test_that("plot.caretEnsemble", {
-  for (ens in list(ens.class, ens.reg)) {
-    plt <- plot(ens)
-    testthat::expect_is(plt, "ggplot")
-    testthat::expect_identical(nrow(plt$data), 5L) # 4 models, one ensemble
-    testthat::expect_named(ens$models, plt$data$model_name[-1L]) # First is ensemble
-  }
-})
-
-testthat::test_that("ggplot2::autoplot.caretEnsemble", {
-  for (ens in list(ens.class, ens.reg)) {
-    plt1 <- ggplot2::autoplot(ens)
-    plt2 <- ggplot2::autoplot(ens, xvars = c("Petal.Length", "Petal.Width"))
-
-    testthat::expect_is(plt1, "ggplot")
-    testthat::expect_is(plt2, "ggplot")
-
-    testthat::expect_is(plt1, "patchwork")
-    testthat::expect_is(plt2, "patchwork")
-
-    train_model <- ens.reg$models[[1L]]
-    testthat::expect_error(ggplot2::autoplot(train_model), "Objects of class (.*?) are not supported by autoplot")
-  }
-})
-
-testthat::test_that("summary.caretEnsemble", {
-  for (ens in list(ens.class, ens.reg)) {
-    smry <- testthat::expect_silent(summary(ens.class))
-    testthat::expect_output(print(smry), ens.class$ens_model$metric)
-    for (name in names(ens.class$models)) {
-      testthat::expect_output(print(smry), name)
-    }
-  }
-})
-
-testthat::test_that("extractModelMetrics", {
-  for (ens in list(ens.class, ens.reg)) {
-    metrics <- extractMetric(ens)
-    testthat::expect_s3_class(metrics, "data.table")
-    testthat::expect_named(ens$models, metrics$model_name[-1L])
-  }
-})
-
-testthat::test_that("precict.caretEnsemble with and without se and weights", {
-  for (ens in list(ens.class, ens.reg)) {
-    is_class <- isClassifier(ens)
-    for (se in c(FALSE, TRUE)) {
-      p <- predict(
-        ens,
-        newdata = X.reg,
-        se = se,
-        excluded_class_id = 1L
-      )
-      expect_s3_class(p, "data.table")
-      if (se) {
-        testthat::expect_named(p, c("pred", "lwr", "upr"))
-      } else {
-        testthat::expect_named(p, ifelse(is_class, "Yes", "pred"))
-      }
-    }
-  }
-})
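The normalize = TRUE expectations in the deleted varImp test above (scores non-negative, at most 1, summing to 1) amount to a simple rescaling. A hedged sketch of that idea — not the package's actual code, and the clipping step is an assumption:

normalize_importance <- function(imp) {
  imp <- pmax(imp, 0)  # assumed: raw scores floored at zero before rescaling
  imp / sum(imp)
}
normalize_importance(c(rf = 3, glm = 1, rpart = 0.5, treebag = 0.5))  # sums to 1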
classification objects?") -######################################################################## - -testthat::test_that("We can make the stacked predictions matrix", { - out <- predict(models.class) - testthat::expect_s3_class(out, "data.table") - testthat::expect_identical(dim(out), c(150L, 4L * 1L)) # number of models * (number of classes-1) -}) - -testthat::test_that("We can predict", { - out <- predict(models.class, newdata = X.class, excluded_class_id = 0L) - testthat::expect_is(out, "data.table") - testthat::expect_identical(dim(out), c(150L, 4L * 2L)) - model_names <- c("rf", "glm", "rpart", "treebag") - class_names <- c("No", "Yes") - combinations <- expand.grid(class_names, model_names) - testthat::expect_named(out, paste(combinations$Var2, combinations$Var1, sep = "_")) - out2 <- predict(models.reg, newdata = X.reg) - testthat::expect_identical(dim(out2), c(150L, 4L)) - testthat::expect_named(out2, c("rf", "glm", "rpart", "treebag")) -}) - -testthat::test_that("predict results same regardless of verbose option", { - invisible(capture.output({ - testthat::expect_is(predict(models.class, newdata = X.class), "data.table") - out1 <- predict(models.class, newdata = X.class) - out2 <- predict(models.class, verbose = TRUE, newdata = X.class) - testthat::expect_identical(out1, out2) - - testthat::expect_is(predict(models.reg, newdata = X.reg), "data.table") - out1 <- predict(models.reg, newdata = X.reg) - out2 <- predict(models.reg, verbose = TRUE, newdata = X.reg) - testthat::expect_identical(out1, out2) - })) -}) - -testthat::context("Test weighted standard deviations") - -testthat::test_that("wtd.sd applies weights correctly", { - x1 <- c(3L, 5L, 9L, 3L, 4L, 6L, 4L) - x2 <- c(10L, 10L, 20L, 14L, 2L, 2L, 40L) - x3 <- c(10L, 10L, 10L, 20L) - w1 <- c(0.1, 0.1, 0.1, 0.7) - testthat::expect_error(caretEnsemble::wtd.sd(x1), 'argument "w" is missing, with no default') - testthat::expect_false(sd(x1) == caretEnsemble::wtd.sd(x1, w = x2)) - testthat::expect_false(sd(x1) == caretEnsemble::wtd.sd(x1, w = x2)) - testthat::expect_equal(caretEnsemble::wtd.sd(x3, w = w1), 5.291503, tolerance = 0.001) - testthat::expect_equal(caretEnsemble::wtd.sd(x3, w = w1 * 100L), caretEnsemble::wtd.sd(x3, w = w1), tolerance = 0.001) -}) - -testthat::test_that("wtd.sd handles NA values correctly", { - x1 <- c(10L, 10L, 10L, 20L, NA, NA) - w1 <- c(0.1, 0.1, 0.1, 0.7, NA, NA) - testthat::expect_true(is.na(caretEnsemble::wtd.sd(x1, w = w1))) - testthat::expect_true(is.na(sd(x1))) - testthat::expect_false(is.na(caretEnsemble::wtd.sd(x1, w = w1, na.rm = TRUE))) - testthat::expect_false(is.na(sd(x1, na.rm = TRUE))) - testthat::expect_true(is.na(caretEnsemble::wtd.sd(x1, w = w1))) - testthat::expect_false(is.na(caretEnsemble::wtd.sd(x1, w = w1, na.rm = TRUE))) -}) - -testthat::test_that("caretList supports combined regression, binary, multiclass", { - set.seed(42L) - - # Regression models - reg_models <- caretList( - Sepal.Length ~ Sepal.Width, - iris, - methodList = c("glm", "lm") - ) - testthat::expect_is(predict(reg_models), "data.table") - - # Binary model - bin_models <- caretList( - factor(ifelse(Species == "setosa", "Yes", "No")) ~ Sepal.Width, - iris, - methodList = c("lda", "rpart") - ) - testthat::expect_is(predict(bin_models), "data.table") - - # Multiclass model - multi_models <- caretList( - Species ~ Sepal.Width, - iris, - methodList = "rpart" - ) - testthat::expect_is(predict(multi_models), "data.table") - - # Combine them! 
- all_models <- c(reg_models, bin_models, multi_models) - testthat::expect_s3_class(all_models, "caretList") - testthat::expect_is(vapply(all_models, isClassifierAndValidate, logical(1L)), "logical") - - # Test preds - stacked_p <- predict(all_models) - new_p <- predict(all_models, newdata = iris[seq_len(10L), ]) - testthat::expect_is(stacked_p, "data.table") - testthat::expect_is(new_p, "data.table") - testthat::expect_identical(nrow(stacked_p), nrow(iris)) - testthat::expect_identical(nrow(new_p), 10L) -}) - -testthat::test_that("isClassifierAndValidate shouldn't care about predictions", { - model_list <- models.class - model_list[[1L]]$pred <- NULL - testthat::expect_is(vapply(model_list, isClassifierAndValidate, logical(1L)), "logical") - testthat::expect_equivalent(unique(vapply(model_list, isClassifierAndValidate, logical(1L))), TRUE) -}) - -testthat::test_that("isClassifierAndValidate stops when a classification model can't predict probabilities", { - model_list <- models.class - model_list[[1L]]$modelInfo$prob <- FALSE - err <- "No probability function found. Re-fit with a method that supports prob." - testthat::expect_error(lapply(model_list, isClassifierAndValidate), err) -}) - -testthat::test_that("isClassifierAndValidate stops when a classification model did not save probs", { - model_list <- models.class - model_list[[1L]]$control$classProbs <- FALSE - err <- "classProbs = FALSE. Re-fit with classProbs = TRUE in trainControl." - testthat::expect_error(lapply(model_list, isClassifierAndValidate, validate_for_stacking = TRUE), err) - testthat::context("Test helper functions for multiclass classification") - - testthat::test_that("Configuration function for excluded level work", { - # Integers work - testthat::expect_identical(validateExcludedClass(0L), 0L) - testthat::expect_identical(validateExcludedClass(1L), 1L) - testthat::expect_identical(validateExcludedClass(4L), 4L) - - # Decimals work with a warning - wrn <- "classification excluded level is not an integer:" - testthat::expect_warning(testthat::expect_identical(validateExcludedClass(0.0), 0L), wrn) - testthat::expect_warning(testthat::expect_identical(validateExcludedClass(1.0), 1L), wrn) - testthat::expect_warning(testthat::expect_identical(validateExcludedClass(4.0), 4L), wrn) - - # Less than 0 will error - testthat::expect_error(validateExcludedClass(-1L), "classification excluded level must be >= 0: -1") - - # Make a model list - data(iris) - model_list <- caretList( - x = iris[, -5L], - y = iris[, 5L], - methodList = c("rpart", "glmnet") - ) - - # Stacking with the excluded level should work - invisible(caretStack(model_list, method = "knn", excluded_class_id = 1L)) - - # Stacking with too great of a level should work. No error or warning. - # Should also validate it? 
- stack <- caretStack(model_list, method = "knn", excluded_class_id = 4L) - invisible(predict(stack, iris[, -5L])) - - # Check if we are actually excluding level 1 (setosa) - classes <- levels(iris[, 5L])[-1L] - models <- c("rpart", "glmnet") - class_model_combinations <- expand.grid(classes, models) - varImp_rownames <- apply(class_model_combinations, 1L, function(x) paste(x[2L], x[1L], sep = "_")) - - model_stack <- caretStack(model_list, method = "knn", excluded_class_id = 1L) - testthat::expect_identical(rownames(caret::varImp(model_stack$ens_model)$importance), varImp_rownames) - }) -}) - -# Tests for validateExcludedClass function -testthat::test_that("validateExcludedClass stops for non-numeric input", { - invalid_input <- "invalid" - err <- "classification excluded level must be numeric: invalid" - testthat::expect_error(validateExcludedClass(invalid_input), err) -}) - -testthat::test_that("validateExcludedClass stops for non-finite input", { - invalid_input <- Inf - err <- "classification excluded level must be finite: Inf" - testthat::expect_warning( - testthat::expect_error(validateExcludedClass(invalid_input), err), - "classification excluded level is not an integer: Inf" - ) -}) - -testthat::test_that("validateExcludedClass stops for non-positive input", { - invalid_input <- -1.0 - err <- "classification excluded level must be >= 0: -1" - wrn <- "classification excluded level is not an integer:" - testthat::expect_warning(testthat::expect_error(validateExcludedClass(invalid_input), err), wrn) -}) - -validated <- testthat::test_that("validateExcludedClass warns for non-integer input", { - testthat::expect_identical( - testthat::expect_warning( - validateExcludedClass(1.1), - "classification excluded level is not an integer: 1.1" - ), 1L - ) -}) - -testthat::test_that("validateExcludedClass passes for valid input", { - valid_input <- 3L - testthat::expect_identical(validateExcludedClass(valid_input), 3L) -}) - -######################################################################## -testthat::context("Helper function edge cases") -######################################################################## - -testthat::test_that("wtd.sd calculates weighted standard deviation correctly", { - x <- c(1L, 2L, 3L, 4L, 5L) - w <- c(1L, 1L, 1L, 1L, 1L) - testthat::expect_equal(wtd.sd(x, w), sd(x), tol = 0.001) - - w <- c(2L, 1L, 1L, 1L, 1L) - testthat::expect_true(wtd.sd(x, w) != sd(x)) - - # Test with NA values - x_na <- c(1L, 2L, NA, 4L, 5L) - testthat::expect_true(is.na(wtd.sd(x_na, w))) - testthat::expect_false(is.na(wtd.sd(x_na, w, na.rm = TRUE))) - - # Test error for mismatched lengths - testthat::expect_error(wtd.sd(x, w[-1L]), "'x' and 'w' must have the same length") -}) - -testthat::test_that("isClassifierAndValidate validates caretList correctly", { - testthat::expect_is(vapply(models.class, isClassifierAndValidate, logical(1L)), "logical") - testthat::expect_is(vapply(models.reg, isClassifierAndValidate, logical(1L)), "logical") - - # Test error for non-caretList object - testthat::expect_error( - isClassifierAndValidate(list(model = lm(Y.reg ~ ., data = as.data.frame(X.reg)))), - "is(object, \"train\") is not TRUE", - fixed = TRUE - ) -}) - -testthat::test_that("isClassifierAndValidate validates model types correctly", { - testthat::expect_is(vapply(models.class, isClassifierAndValidate, logical(1L)), "logical") - testthat::expect_is(vapply(models.reg, isClassifierAndValidate, logical(1L)), "logical") - - # Test error for mixed model types - mixed_list <- c(models.class, 
models.reg) - testthat::expect_is(vapply(mixed_list, isClassifierAndValidate, logical(1L)), "logical") -}) - -testthat::test_that("Stacked predictions for caret lists works", { - best_preds_class <- predict(models.class) - best_preds_reg <- predict(models.reg) - - testthat::expect_is(best_preds_class, "data.table") - testthat::expect_is(best_preds_reg, "data.table") - - testthat::expect_named(best_preds_class, names(models.class)) - testthat::expect_named(best_preds_reg, names(models.reg)) -}) - -testthat::test_that("Stacked predictions works with different resampling strategies", { - models.class.inconsistent <- models.class - models.class.inconsistent[[1L]]$pred$Resample <- "WEIRD_SAMPLING" - testthat::expect_is(predict(models.class.inconsistent), "data.table") -}) - -testthat::test_that("Stacked predictions works if the row indexes differ", { - models.class.inconsistent <- models.class - models.class.inconsistent[[1L]]$pred$rowIndex <- rev(models.class.inconsistent[[1L]]$pred$rowIndex) - big_preds <- rbind(models.class.inconsistent[[2L]]$pred, models.class.inconsistent[[2L]]$pred) - models.class.inconsistent[[2L]]$pred <- big_preds - testthat::expect_is(predict(models.class.inconsistent), "data.table") -}) - -testthat::test_that("extractModelName extracts model names correctly", { - testthat::expect_identical(extractModelName(models.class[[1L]]), "rf") - testthat::expect_identical(extractModelName(models.reg[[1L]]), "rf") - - # Test custom model - custom_model <- models.class[[1L]] - custom_model$method <- list(method = "custom_rf") - testthat::expect_identical(extractModelName(custom_model), "custom_rf") -}) - -testthat::test_that("isClassifierAndValidate extracts model types correctly", { - testthat::expect_true(unique(vapply(models.class, isClassifierAndValidate, logical(1L)))) - testthat::expect_false(unique(vapply(models.reg, isClassifierAndValidate, logical(1L)))) -}) - -testthat::test_that("caretPredict extracts best predictions correctly", { - stacked_preds_class <- caretPredict(models.class[[1L]], excluded_class_id = 0L) - stacked_preds_reg <- caretPredict(models.reg[[1L]]) - - testthat::expect_s3_class(stacked_preds_class, "data.table") - testthat::expect_s3_class(stacked_preds_reg, "data.table") - - testthat::expect_named(stacked_preds_class, c("No", "Yes")) - testthat::expect_named(stacked_preds_reg, "pred") -}) - -testthat::test_that("Stacked predictions creates prediction-observation data correctly", { - stacked_preds_class <- predict(models.class) - stacked_preds_reg <- predict(models.reg) - - testthat::expect_s3_class(stacked_preds_class, "data.table") - testthat::expect_s3_class(stacked_preds_reg, "data.table") - - testthat::expect_identical(ncol(stacked_preds_class), length(models.class)) - testthat::expect_identical(ncol(stacked_preds_reg), length(models.reg)) - - testthat::expect_named(stacked_preds_class, names(models.class)) - testthat::expect_named(stacked_preds_reg, names(stacked_preds_reg)) - - testthat::expect_identical(nrow(stacked_preds_class), 150L) - testthat::expect_identical(nrow(stacked_preds_reg), 150L) -}) - -testthat::test_that("Stacked predictions works on new model types", { - # Note that new model types would have to return a single column called 'pred' - models.class.new <- models.reg - for (idx in seq_along(models.class.new)) { - models.class.new[[idx]]$modelType <- "TimeSeries" - } - preds <- predict(models.class.new) - testthat::expect_s3_class(preds, "data.table") -}) - -testthat::test_that("validateExcludedClass validates excluded level 
correctly", { - testthat::expect_warning(validateExcludedClass(NULL), "No excluded_class_id set. Setting to 1L.") - testthat::expect_error( - validateExcludedClass(c(1L, 2L)), - "classification excluded level must have a length of 1: length=2" - ) - testthat::expect_error(validateExcludedClass("a"), "classification excluded level must be numeric: a") - testthat::expect_error(validateExcludedClass(-1L), "classification excluded level must be >= 0: -1") - testthat::expect_warning( - testthat::expect_error(validateExcludedClass(-0.000001), "classification excluded level must be >= 0: -1e-06"), - "classification excluded level is not an integer" - ) - testthat::expect_warning( - testthat::expect_error(validateExcludedClass(Inf), "classification excluded level must be finite: Inf"), - "classification excluded level is not an integer" - ) - testthat::expect_warning(validateExcludedClass(1.5), "classification excluded level is not an integer: 1.5") - txt <- "classification excluded level is not an integer: 2" - testthat::expect_warning( - testthat::expect_identical( - validateExcludedClass(2.0), 2L - ), txt, - "classification excluded level is not an integer" - ) -}) - -testthat::test_that("validateExcludedClass validates excluded level correctly", { - testthat::expect_warning(validateExcludedClass(NULL), "No excluded_class_id set. Setting to 1L.") - testthat::expect_error( - validateExcludedClass(c(1L, 2L)), - "classification excluded level must have a length of 1: length=2" - ) - testthat::expect_error(validateExcludedClass("a"), "classification excluded level must be numeric: a") - testthat::expect_error(validateExcludedClass(-1L), "classification excluded level must be >= 0: -1") - testthat::expect_warning( - testthat::expect_error(validateExcludedClass(-0.000001), "classification excluded level must be >= 0: -1e-06"), - "classification excluded level is not an integer" - ) - testthat::expect_warning( - testthat::expect_error(validateExcludedClass(Inf), "classification excluded level must be finite: Inf"), - "classification excluded level is not an integer" - ) - testthat::expect_warning(validateExcludedClass(1.5), "classification excluded level is not an integer: 1.5") - txt <- "classification excluded level is not an integer: 2" - testthat::expect_warning(testthat::expect_identical(validateExcludedClass(2.0), 2L), txt) -}) - -testthat::test_that("isClassifierAndValidate fails for models without object$control$savePredictions", { - model <- models.class[[1L]] - model$control$savePredictions <- NULL - err <- "Must have savePredictions = 'all', 'final', or TRUE in trainControl to do stacked predictions." 
- testthat::expect_error(isClassifierAndValidate(model), err) - model$control$savePredictions <- "BAD_VALUE" - testthat::expect_error(isClassifierAndValidate(model), err) -}) diff --git a/tests/testthat/test-multiclass.R b/tests/testthat/test-multiclass.R index 8a825633..54d4a5f0 100644 --- a/tests/testthat/test-multiclass.R +++ b/tests/testthat/test-multiclass.R @@ -1,16 +1,23 @@ -############################################################################# -testthat::context("caretList and caretStack work for multiclass problems") -############################################################################# -testthat::test_that("We can predict with caretList and caretStack multiclass problems", { - data(iris) - model_list <- caretList( - x = iris[, -5L], - y = iris[, 5L], - methodList = c("glmnet", "rpart") +utils::data(iris) +utils::data(Boston, package = "MASS") + +model_list <- caretList( + iris[, -5L], + iris[, 5L], + tuneLength = 1L, + methodList = c("glmnet", "rpart"), + tuneList = list( + nnet = caretModelSpec(method = "nnet", trace = FALSE) ) +) +###################################################################### +testthat::context("caretList and caretStack work for multiclass problems") +###################################################################### + +testthat::test_that("We can predict with caretList and caretStack for multiclass problems", { p <- predict(model_list, newdata = iris[, -5L]) - testthat::expect_is(p, "data.table") + testthat::expect_s3_class(p, "data.table") testthat::expect_identical(nrow(p), nrow(iris)) ens <- caretStack(model_list, method = "rpart") @@ -25,20 +32,9 @@ testthat::test_that("We can predict with caretList and caretStack multiclass pro }) testthat::test_that("Columns for caretList predictions are correct and ordered", { - data(iris) - model_list <- caretList( - x = iris[, -5L], - y = iris[, 5L], - methodList = c("glmnet", "rpart"), - tuneList = list( - nnet = caretModelSpec(method = "nnet", trace = FALSE) - ) - ) - num_methods <- length(model_list) num_classes <- length(unique(iris$Species)) - # Check the number of rows and columns is correct p <- predict(model_list, newdata = iris[, -5L], excluded_class_id = 0L) testthat::expect_identical(dim(p), c(nrow(iris), num_methods * num_classes)) @@ -47,48 +43,32 @@ testthat::test_that("Columns for caretList predictions are correct and ordered", class_method_combinations <- expand.grid(classes, methods) ordered_colnames <- apply(class_method_combinations, 1L, function(x) paste(x[2L], x[1L], sep = "_")) - # Check the names of the columns are correct testthat::expect_true(all(colnames(p) %in% ordered_colnames)) - - # Check that the columns are ordered correctly testthat::expect_named(p, ordered_colnames) }) testthat::test_that("Columns for caretStack are correct", { - data(iris) - model_list <- caretList( - x = iris[, -5L], - y = iris[, 5L], - methodList = "rpart", - tuneList = list( - nnet = caretModelSpec(method = "nnet", trace = FALSE) - ) - ) - model_stack <- caretStack(model_list, method = "knn") num_classes <- length(unique(iris$Species)) + classes <- levels(iris$Species) - # Check the number of rows and columns is correct p_raw <- predict(model_stack, newdata = iris[, -5L]) testthat::expect_identical(nrow(p_raw), nrow(iris)) + p_prob <- predict(model_stack, newdata = iris[, -5L]) testthat::expect_identical(dim(p_prob), c(nrow(iris), num_classes)) - - classes <- levels(iris$Species) - - # Check that the columns are ordered correctly testthat::expect_named(p_prob, classes) }) 
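A small sketch of the "<model>_<class>" column-naming convention that both the deleted helper tests above and the multiclass tests below assert; it mirrors the expand.grid construction used in the tests themselves:

methods <- c("glmnet", "rpart", "nnet")
classes <- c("setosa", "versicolor", "virginica")
combos <- expand.grid(classes, methods)
expected_cols <- apply(combos, 1L, function(x) paste(x[2L], x[1L], sep = "_"))
# e.g. "glmnet_setosa" "glmnet_versicolor" ... "nnet_virginica"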
testthat::test_that("Periods are supported in method and class names in caretList and caretStack", { - data(iris) - # Rename values and levels to have underscores - levels(iris[, 5L]) <- c("setosa_1", "versicolor_2", "virginica_3") - iris[, 5L] <- factor(iris[, 5L]) + iris_mod <- iris + levels(iris_mod[, 5L]) <- c("setosa_1", "versicolor_2", "virginica_3") + iris_mod[, 5L] <- factor(iris_mod[, 5L]) + model_list <- caretList( - x = iris[, -5L], - y = iris[, 5L], + x = iris_mod[, -5L], + y = iris_mod[, 5L], methodList = c("glmnet", "rpart"), tuneList = list( nnet_1 = caretModelSpec( @@ -105,10 +85,9 @@ testthat::test_that("Periods are supported in method and class names in caretLis ) methods <- names(model_list) - classes <- levels(iris[, 5L]) - - p <- predict(model_list, newdata = iris[, -5L], excluded_class_id = 0L) + classes <- levels(iris_mod[, 5L]) + p <- predict(model_list, newdata = iris_mod[, -5L], excluded_class_id = 0L) class_method_combinations <- expand.grid(classes, methods) ordered_colnames <- apply(class_method_combinations, 1L, function(x) paste(x[2L], x[1L], sep = "_")) testthat::expect_named(p, ordered_colnames) @@ -116,58 +95,32 @@ testthat::test_that("Periods are supported in method and class names in caretLis model_stack <- caretStack(model_list, method = "knn", trControl = trainControl( savePredictions = "final", classProbs = TRUE )) - p_prob <- predict(model_stack, newdata = iris[, -5L]) + p_prob <- predict(model_stack, newdata = iris_mod[, -5L]) testthat::expect_named(p_prob, classes) - p_raw <- predict(model_stack, newdata = iris[, -5L]) + p_raw <- predict(model_stack, newdata = iris_mod[, -5L]) testthat::expect_named(p_raw, classes) }) testthat::test_that("We can make a confusion matrix", { - data(iris) - - set.seed(42L) - n <- nrow(iris) - train_indices <- sample.int(n, n * 0.8) - train_data <- iris[train_indices, ] - test_data <- iris[-train_indices, ] - model_list <- caretList( - x = train_data[, -5L], - y = train_data[, 5L], - methodList = c("glmnet", "rpart"), - tuneList = list( - nnet = caretModelSpec(method = "nnet", trace = FALSE) - ) - ) - model_stack <- caretStack(model_list, method = "knn") - # Make a confusion matrix - predictions <- predict(model_stack, newdata = test_data[, -5L]) - classes <- apply(predictions, 1L, function(x) names(x)[which.max(x)]) - classes <- factor(classes, levels = levels(test_data[, 5L])) - cm <- confusionMatrix(classes, test_data[, 5L]) - testthat::expect_is(cm, "confusionMatrix") + predictions <- predict(model_stack, newdata = iris[, -5L], return_class_only = TRUE) + cm <- caret::confusionMatrix(predictions, iris[, 5L]) - # Check dims + testthat::expect_s3_class(cm, "confusionMatrix") testthat::expect_identical(dim(cm$table), c(3L, 3L)) - # Accuracy should be greater than 0.9 - testthat::expect_gt(cm$overall["Accuracy"], 0.9) + testthat::expect_gt(cm$overall["Accuracy"], 0.95) # In sample accuracy should be high }) testthat::test_that("caretList and caretStack handle imbalanced multiclass data", { set.seed(123L) n <- 1000L - X <- data.table::data.table(x1 = rnorm(n), x2 = rnorm(n)) + X <- data.table::data.table(x1 = stats::rnorm(n), x2 = stats::rnorm(n)) y <- factor(c(rep("A", 700L), rep("B", 200L), rep("C", 100L))) - model_list <- caretList( - x = X, - y = y, - methodList = c("rpart", "glmnet") - ) - + model_list <- caretList(X, y, methodList = "rpart") testthat::expect_s3_class(model_list, "caretList") - testthat::expect_length(model_list, 2L) + testthat::expect_length(model_list, 1L) stack <- caretStack(model_list, method = 
"glmnet") testthat::expect_s3_class(stack, "caretStack") @@ -179,15 +132,10 @@ testthat::test_that("caretList and caretStack handle imbalanced multiclass data" testthat::test_that("caretList and caretStack handle a large number of classes", { set.seed(123L) n <- 1000L - X <- data.table::data.table(x1 = rnorm(n), x2 = rnorm(n)) + X <- data.table::data.table(x1 = stats::rnorm(n), x2 = stats::rnorm(n)) y <- factor(sample(paste0("Class", 1L:100L), n, replace = TRUE)) - model_list <- caretList( - x = X, - y = y, - methodList = "rpart" - ) - + model_list <- caretList(X, y, methodList = "rpart") testthat::expect_s3_class(model_list, "caretList") stack <- caretStack(model_list, method = "rpart") @@ -198,16 +146,10 @@ testthat::test_that("caretList and caretStack handle a large number of classes", }) testthat::test_that("caretList and caretStack handle ordinal multiclass data", { - data(Boston, package = "MASS") Boston$chas <- as.factor(Boston$chas) Boston$rad <- factor(paste0("rad_", Boston$rad), ordered = TRUE) - model_list <- caretList( - rad ~ ., - data = Boston, - methodList = c("rpart", "glmnet") - ) - + model_list <- caretList(rad ~ ., data = Boston, methodList = c("rpart", "glmnet")) testthat::expect_s3_class(model_list, "caretList") stack <- caretStack(model_list, method = "rpart") @@ -216,18 +158,10 @@ testthat::test_that("caretList and caretStack handle ordinal multiclass data", { preds <- predict(stack, newdata = Boston) testthat::expect_s3_class(preds, "data.table") testthat::expect_named(preds, levels(Boston$rad)) - testthat::expect_equal(rowSums(preds), rep(1.0, nrow(Boston)), tol = 0.0001) + testthat::expect_equal(rowSums(preds), rep(1.0, nrow(Boston)), tolerance = 1e-4) }) testthat::test_that("caretList and caretStack produce consistent probability predictions", { - data(iris) - - model_list <- caretList( - x = iris[, -5L], - y = iris[, 5L], - methodList = c("rpart", "glmnet") - ) - stack <- caretStack(model_list, method = "rpart") prob_preds <- predict(stack, newdata = iris[, -5L]) @@ -238,7 +172,7 @@ testthat::test_that("caretList and caretStack produce consistent probability pre }) testthat::test_that("caretList and caretStack handle new levels in prediction data", { - data(iris) + set.seed(123L) idx <- seq_len(nrow(iris)) idx_train <- sample(idx, 120L) idx_test <- setdiff(idx, idx_train) @@ -247,32 +181,9 @@ testthat::test_that("caretList and caretStack handle new levels in prediction da test_data$Species <- factor(as.character(test_data$Species), levels = c(levels(iris$Species), "NewSpecies")) test_data$Species[1L] <- "NewSpecies" - model_list <- caretList( - x = train_data[, -5L], - y = train_data[, 5L], - methodList = c("rf", "rpart") - ) - + model_list <- caretList(train_data[, -5L], train_data[, 5L], methodList = c("rf", "rpart")) stack <- caretStack(model_list, method = "rpart") preds <- predict(stack, newdata = test_data) - testthat::expect_true(all(levels(preds) %in% levels(train_data$Species))) -}) - -testthat::test_that("caretList and caretStack produce consistent probability predictions", { - data(iris) - - model_list <- caretList( - x = iris[, -5L], - y = iris[, 5L], - methodList = c("rpart", "glmnet") - ) - - stack <- caretStack(model_list, method = "rpart") - - prob_preds <- predict(stack, newdata = iris[, -5L]) - testthat::expect_identical(nrow(prob_preds), nrow(iris)) - testthat::expect_identical(ncol(prob_preds), nlevels(iris$Species)) - testthat::expect_true(all(rowSums(prob_preds) >= 0.99)) - testthat::expect_true(all(rowSums(prob_preds) <= 1.01)) + 
 })
diff --git a/tests/testthat/test-permutationImportance.R b/tests/testthat/test-permutationImportance.R
index d6211c23..2aa772c4 100644
--- a/tests/testthat/test-permutationImportance.R
+++ b/tests/testthat/test-permutationImportance.R
@@ -1,21 +1,20 @@
-data(models.class)
-data(models.reg)
-data(iris)
+# Load test data and define helper functions
+utils::data(models.class)
+utils::data(models.reg)
+utils::data(iris)
 
-# Helper function to create a simple dataset
 create_dataset <- function(n = 200L, p = 5L, classification = TRUE) {
   set.seed(42L)
-  X <- data.table::data.table(matrix(rnorm(n * p), ncol = p))
+  X <- data.table::data.table(matrix(stats::rnorm(n * p), ncol = p))
   data.table::setnames(X, paste0("x", seq_len(p)))
   if (classification) {
     y <- factor(ifelse(rowSums(X) > 0L, "A", "B"))
   } else {
-    y <- rowSums(X) + rnorm(n)
+    y <- rowSums(X) + stats::rnorm(n)
   }
   list(X = X, y = y)
 }
 
-# Helper function to train a model
 train_model <- function(x, y, method = "rpart", ...) {
   set.seed(1234L)
   caret::train(
@@ -27,7 +26,6 @@ train_model <- function(x, y, method = "rpart", ...) {
   )
 }
 
-# Helper function to check test results
 check_importance_scores <- function(
     imp,
     expected_names = paste0("x", seq_len(5L)),
@@ -40,24 +38,37 @@ check_importance_scores <- function(
   testthat::expect_equal(sum(imp), 1L, tolerance = 1e-6)
 }
 
-testthat::test_that("isClassifier works for train models", {
+######################################################################
+testthat::context("isClassifier function")
+######################################################################
+
+testthat::test_that("isClassifier works for train and caretStack models", {
   testthat::expect_true(isClassifier(models.class[[1L]]))
   testthat::expect_false(isClassifier(models.reg[[1L]]))
-})
 
-testthat::test_that("isClassifier works for caretStacks models", {
-  ens_class <- caretEnsemble(models.class)
-  ens_reg <- caretEnsemble(models.reg)
+  ens_class <- caretEnsemble::caretEnsemble(models.class)
+  ens_reg <- caretEnsemble::caretEnsemble(models.reg)
   testthat::expect_true(isClassifier(ens_class))
   testthat::expect_false(isClassifier(ens_reg))
 })
 
-testthat::test_that("permutationImportance works for regression", {
-  dt <- create_dataset(classification = FALSE)
-  model <- train_model(dt[["X"]], dt[["y"]])
-  imp <- permutationImportance(model, dt[["X"]])
-  check_importance_scores(imp)
+######################################################################
+testthat::context("permutationImportance function")
+######################################################################
+
+testthat::test_that("permutationImportance works for regression and classification", {
+  # Regression
+  dt_reg <- create_dataset(classification = FALSE)
+  model_reg <- train_model(dt_reg[["X"]], dt_reg[["y"]])
+  imp_reg <- permutationImportance(model_reg, dt_reg[["X"]])
+  check_importance_scores(imp_reg)
+
+  # Classification
+  dt_class <- create_dataset(classification = TRUE)
+  model_class <- train_model(dt_class[["X"]], dt_class[["y"]])
+  imp_class <- permutationImportance(model_class, dt_class[["X"]])
+  check_importance_scores(imp_class)
 })
 
 testthat::test_that("permutationImportance works for multiclass classification", {
@@ -69,13 +80,13 @@ testthat::test_that("permutationImportance works for multiclass classification",
     x3 = stats::rnorm(n)
   )
   coef_matrix <- matrix(c(
-    1.0, -0.5, 0.2, # coefficients for class A
-    -0.5, 1.0, 0.2, # coefficients for class B
-    0.2, 0.2, 1.0 # coefficients for class C
+    1.0, -0.5, 0.2,
+    -0.5, 1.0, 0.2,
+    0.2, 0.2, 1.0
   ), nrow = 3L, byrow = TRUE)
   linear_combinations <- as.matrix(x) %*% t(coef_matrix)
-  linear_combinations <- linear_combinations + matrix(stats::rnorm(n * 3.0, sd = 0.1), nrow = n)
+  linear_combinations <- linear_combinations + matrix(stats::rnorm(n * 3L, sd = 0.1), nrow = n)
   probabilities <- exp(linear_combinations) / rowSums(exp(linear_combinations))
   y <- factor(apply(probabilities, 1L, function(prob) sample(c("A", "B", "C"), 1L, prob = prob)))
@@ -84,174 +95,120 @@ testthat::test_that("permutationImportance works for multiclass classification",
   check_importance_scores(imp, c("x1", "x2", "x3"))
 })
 
-testthat::test_that("permutationImportance works with a single feature unimportant feature", {
+testthat::test_that("permutationImportance works with single feature cases", {
   n <- 100L
-  x <- data.table::data.table(x1 = stats::rnorm(n))
-  y <- factor(sample(c("A", "B"), n, replace = TRUE))
-  model <- train_model(x, y)
-  imp <- permutationImportance(model, x)
-  check_importance_scores(imp, "x1")
-})
-
-testthat::test_that("permutationImportance works with a single important feature", {
-  set.seed(1234L)
-
-  make_var <- function(n) scale(stats::rnorm(n), center = TRUE, scale = TRUE)[, 1L]
-
-  n <- 1000L
-  x <- data.table::data.table(
-    x1 = make_var(n),
-    x2 = make_var(n),
-    x3 = make_var(n)
-  )
-
-  cf_set <- c(0L, 1L, 5L, 10L)
-  all_cfs <- expand.grid(
-    c(0L, 1L),
-    cf_set,
-    cf_set,
-    cf_set
-  )
-
-  evaluate_model <- function(cf, do_class) {
-    cf <- unname(unlist(cf))
-    y <- (cbind(1L, as.matrix(x)) %*% cf)[, 1L]
-    if (do_class) {
-      classes <- c("A", "B")
-      y <- factor(ifelse(y > 0L, classes[1L], classes[2L]), levels = classes)
-      if (length(unique(y)) == 1L) {
-        return(NULL)
-      }
-    }
-    model <- suppressWarnings(train_model(x, y, method = "glm"))
-    imp <- permutationImportance(model, x)
-    check_importance_scores(imp, c("x1", "x2", "x3"))
-
-    glm_imp <- normalize_to_one(abs(coef(model$finalModel))[-1L])
-    cf_norm <- normalize_to_one(cf[-1L])
-    testthat::expect_equivalent(glm_imp, cf_norm, tolerance = 0.1)
-    if (!do_class || cf[[1L]] == 0.0) {
-      testthat::expect_equivalent(imp, cf_norm, tolerance = 0.1)
-    }
-  }
-
-  for (do_class in c(FALSE, TRUE)) {
-    apply(all_cfs, 1L, evaluate_model, do_class)
-  }
+  # Unimportant feature
+  x_unimp <- data.table::data.table(x1 = stats::rnorm(n))
+  y_unimp <- factor(sample(c("A", "B"), n, replace = TRUE))
+  model_unimp <- train_model(x_unimp, y_unimp)
+  imp_unimp <- permutationImportance(model_unimp, x_unimp)
+  check_importance_scores(imp_unimp, "x1")
+
+  # Important feature
+  x_imp <- data.table::data.table(x1 = stats::rnorm(n))
+  y_imp <- x_imp$x1 + stats::rnorm(n, sd = 0.1)
+  model_imp <- train_model(x_imp, y_imp, method = "lm")
+  imp_imp <- permutationImportance(model_imp, x_imp)
+  check_importance_scores(imp_imp, "x1")
+  testthat::expect_gt(imp_imp["x1"], 0.9)
 })
 
-testthat::test_that("permutationImportance works a single, contant, unimportant feature", {
+testthat::test_that("permutationImportance works with constant features", {
   n <- 100L
-  x <- data.table::data.table(
-    x1 = rep(1L, n),
-    x2 = stats::rnorm(n)
-  )
-  y <- stats::rnorm(n)
-  model <- train_model(x, y)
-  imp <- permutationImportance(model, x)
-  check_importance_scores(imp, c("x1", "x2"))
-  testthat::expect_lte(imp["x1"], imp["x2"])
-})
-testthat::test_that("permutationImportance works a single, contant, important feature - aka intercept only", {
-  n <- 100L
-  x <- data.table::data.table(
-    x1 = rep(1L, n),
-    x2 = stats::rnorm(n)
-  )
-  y <- x$x1 + stats::rnorm(n) / 10L
-  model <- train_model(x, y)
-  imp <- permutationImportance(model, x)
-  check_importance_scores(imp, c("x1", "x2"))
-  testthat::expect_lte(imp["x1"], imp["x2"])
+  # Constant unimportant feature
+  x_const_unimp <- data.table::data.table(x1 = rep(1L, n), x2 = stats::rnorm(n))
+  y_const_unimp <- stats::rnorm(n)
+  model_const_unimp <- train_model(x_const_unimp, y_const_unimp)
+  imp_const_unimp <- permutationImportance(model_const_unimp, x_const_unimp)
+  check_importance_scores(imp_const_unimp, c("x1", "x2"))
+  testthat::expect_lte(imp_const_unimp["x1"], imp_const_unimp["x2"])
+
+  # Constant important feature (intercept only)
+  x_const_imp <- data.table::data.table(x1 = rep(1L, n), x2 = stats::rnorm(n))
+  y_const_imp <- x_const_imp$x1 + stats::rnorm(n, sd = 0.1)
+  model_const_imp <- train_model(x_const_imp, y_const_imp)
+  imp_const_imp <- permutationImportance(model_const_imp, x_const_imp)
+  check_importance_scores(imp_const_imp, c("x1", "x2"))
+  testthat::expect_lte(imp_const_imp["x2"], imp_const_imp["x1"])
 })
 
 testthat::test_that("permutationImportance works with perfect predictor", {
   n <- 100L
-  x <- data.table::data.table(
-    x1 = stats::rnorm(n),
-    x2 = stats::rnorm(n)
-  )
+  x <- data.table::data.table(x1 = stats::rnorm(n), x2 = stats::rnorm(n))
   y <- x$x1
   model <- train_model(x, y, method = "lm")
   imp <- permutationImportance(model, x)
   check_importance_scores(imp, c("x1", "x2"))
-  testthat::expect_gt(imp["x1"], imp["x2"])
+  testthat::expect_gt(imp["x1"], 0.9)
+  testthat::expect_lt(imp["x2"], 0.1)
 })
 
-testthat::test_that("permutationImportance works for multiclass classification and various edge cases", {
+testthat::test_that("permutationImportance works for multiclass classification with iris dataset", {
   model <- train_model(iris[, -5L], iris$Species, method = "rpart")
   imp <- permutationImportance(model, iris[, -5L])
   check_importance_scores(imp, names(iris[, -5L]))
 })
 
+######################################################################
 testthat::context("permutationImportance edge cases")
-testthat::test_that("permutationImportance normalizes to uniform distribution for all zero importances", {
-  n <- 100L
-  x <- data.table::data.table(x1 = rep(0L, n), x2 = rep(0L, n), x3 = rep(0L, n))
-  y <- rep(0L, n)
-  model <- train_model(x, y, method = "lm")
-  imp <- permutationImportance(model, x)
-  check_importance_scores(imp, names(x))
-  testthat::expect_equivalent(imp, normalize_to_one(rep(0L, length(imp))), tolerance = 1e-6)
-})
+######################################################################
 
-testthat::test_that("permutationImportance assigns full importance to perfect predictor", {
-  set.seed(1234L)
+testthat::test_that("permutationImportance handles various edge cases", {
   n <- 100L
   vars <- 25L
-  x <- data.table::data.table(
-    matrix(rnorm(n * vars), nrow = n, ncol = vars)
-  )
-  data.table::setnames(x, paste0("x", seq_len(vars)))
-  y <- x$x1
-  model <- train_model(x, y, method = "lm")
-  imp <- permutationImportance(model, x)
-  check_importance_scores(imp, names(x))
-  testthat::expect_equal(imp[["x1"]], 1L, tol = 1e-8)
-  testthat::expect_equal(sum(imp[-1L]), 0L, tol = 1e-8)
-})
-testthat::test_that("permutationImportance handles highly collinear features", {
-  set.seed(5678L)
-  n <- 100L
-  x <- data.table::data.table(
-    x1 = rnorm(n),
-    x2 = rnorm(n)
+  # All zero importances
+  x_zero <- data.table::data.table(matrix(0L, nrow = n, ncol = 3L))
+  y_zero <- rep(0L, n)
+  model_zero <- train_model(x_zero, y_zero, method = "lm")
+  imp_zero <- permutationImportance(model_zero, x_zero)
+  check_importance_scores(imp_zero, names(x_zero))
+  testthat::expect_equivalent(imp_zero, normalize_to_one(rep(0L, length(imp_zero))), tolerance = 1e-6)
+
+  # Perfect predictor among many variables
+  x_perfect <- data.table::data.table(matrix(stats::rnorm(n * vars), nrow = n, ncol = vars))
+  data.table::setnames(x_perfect, paste0("x", seq_len(vars)))
+  y_perfect <- x_perfect$x1
+  model_perfect <- train_model(x_perfect, y_perfect, method = "lm")
+  imp_perfect <- permutationImportance(model_perfect, x_perfect)
+  check_importance_scores(imp_perfect, names(x_perfect))
+  testthat::expect_equal(imp_perfect[["x1"]], 1L, tolerance = 1e-8)
+  testthat::expect_equal(sum(imp_perfect[-1L]), 0L, tolerance = 1e-8)
+
+  # Highly collinear features
+  x_collinear <- data.table::data.table(
+    x1 = stats::rnorm(n),
+    x2 = stats::rnorm(n)
   )
-  x$x3 <- x$x1 + rnorm(n, sd = 0.01)
-  y <- x$x1 + x$x2
-  model <- train_model(x, y, method = "lm")
-  imp <- permutationImportance(model, x)
-  check_importance_scores(imp, names(x))
-  testthat::expect_equal(imp[["x3"]], 0L, tol = 1e-6)
-})
-
-testthat::test_that("permutationImportance works with very small dataset", {
-  set.seed(9876L)
-  n <- 5L
-  x <- data.table::data.table(
-    x1 = rnorm(n),
-    x2 = rnorm(n),
-    x3 = rnorm(n)
+  x_collinear$x3 <- x_collinear$x1 + stats::rnorm(n, sd = 0.01)
+  y_collinear <- x_collinear$x1 + x_collinear$x2
+  model_collinear <- train_model(x_collinear, y_collinear, method = "lm")
+  imp_collinear <- permutationImportance(model_collinear, x_collinear)
+  check_importance_scores(imp_collinear, names(x_collinear))
+  testthat::expect_lt(imp_collinear[["x3"]], 0.1)
+
+  # Very small dataset
+  x_small <- data.table::data.table(
+    x1 = stats::rnorm(5L),
+    x2 = stats::rnorm(5L),
+    x3 = stats::rnorm(5L)
   )
-  y <- x$x1 + rnorm(n)
-  model <- train_model(x, y, method = "lm")
-  imp <- permutationImportance(model, x)
-  check_importance_scores(imp, names(x))
-})
+  y_small <- x_small$x1 + stats::rnorm(5L)
+  model_small <- train_model(x_small, y_small, method = "lm")
+  imp_small <- permutationImportance(model_small, x_small)
+  check_importance_scores(imp_small, names(x_small))
 
-testthat::test_that("permutationImportance handles identical features", {
-  n <- 100L
-  x <- data.table::data.table(
-    x1 = rnorm(n),
-    x2 = rnorm(n)
+  # Identical features
+  x_identical <- data.table::data.table(
+    x1 = stats::rnorm(n),
+    x2 = stats::rnorm(n)
   )
-  x$x3 <- x$x1
-  y <- x$x1 + x$x2 + rnorm(n)
-  model <- train_model(x, y, method = "glmnet")
-  imp <- permutationImportance(model, x)
-  check_importance_scores(imp, names(x))
-  testthat::expect_equal(imp[["x1"]], imp[["x3"]], tol = 1e-1)
+  x_identical$x3 <- x_identical$x1
+  y_identical <- x_identical$x1 + x_identical$x2 + stats::rnorm(n)
+  model_identical <- train_model(x_identical, y_identical, method = "glmnet")
+  imp_identical <- permutationImportance(model_identical, x_identical)
+  check_importance_scores(imp_identical, names(x_identical))
+  testthat::expect_equal(imp_identical[["x1"]], imp_identical[["x3"]], tolerance = 1e-1)
 })