-
Notifications
You must be signed in to change notification settings - Fork 75
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* handle array/matrix case * usethis * nbrbrbr * just a regular test * ok the test script works * script works * rebuild * rebuild * add test case * shrink it * add failing test * fix failing test * add seed to readme lol * re add readme workflow, with a seed it should pass now * ok different platforms = different rf results.
- Loading branch information
Showing
13 changed files
with
316 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
# This script is a little big sorry. | ||
# We're using a 3rd party dataset from a package | ||
# It depends on caretEnsemble, and I broke it with the 4.0.0 pre-release | ||
# This script isolates the bad, saved model in that package | ||
# and then removes all the parts of that model that aren't needed to make predictions | ||
# this gives us a minimal test case for the backwards compatability issue | ||
# This script shouldn't ever need to get run again, just use the old saved | ||
# caretlist_with_bad_earth_model.rds file in data in testthat in the tests folder. | ||
# This script is for posterity. | ||
|
||
# Note this is not in our depends or suggests. Also note new version may not have the bug. | ||
devtools::install_version("LDLcalc", version = "2.1", repos = "http://cran.us.r-project.org") | ||
devtools::load_all() | ||
|
||
# Load the data and fit the model | ||
data(SampleData, package = "LDLcalc") | ||
ldl_model <- LDLcalc:::LDL_ML_train_StackingAlgorithm(SampleData) # nolint undesirable_operator_linter | ||
testthat::expect_s3_class(ldl_model$stackModel, "caretStack") | ||
testthat::expect_s3_class(ldl_model$stackModel$models, "caretList") | ||
|
||
# Make a caretList with just the bad model | ||
caretlist_with_old_earth_model <- ldl_model$stackModel$models["earth"] | ||
|
||
# Function to test the error and warnings after removing a specific part | ||
test_error <- function(obj, path, SampleData) { | ||
modified_obj <- obj | ||
eval(parse(text = paste0("modified_obj", path, " <- NULL"))) | ||
|
||
wrns <- NULL | ||
|
||
# Capture both errors and warnings | ||
result <- tryCatch( | ||
{ | ||
withCallingHandlers( | ||
{ | ||
predict(modified_obj, SampleData) | ||
}, | ||
warning = function(w) { | ||
wrns <<- c(wrns, conditionMessage(w)) # nolint undesirable_operator_linter | ||
invokeRestart("muffleWarning") | ||
} | ||
) | ||
list(error = NULL, wrns = wrns) | ||
}, | ||
error = function(e) { | ||
list(error = e$message, wrns = wrns) | ||
} | ||
) | ||
|
||
result | ||
} | ||
|
||
# Function to iteratively prune the object | ||
prune_list_iterative <- function(obj, SampleData) { # nolint cyclocomp_linter | ||
the_stack <- list(list(obj = obj, path = "")) | ||
pruned_obj <- obj | ||
|
||
while (length(the_stack) > 0L) { | ||
# Pop the last element from the the_stack | ||
current <- the_stack[[length(the_stack)]] | ||
the_stack <- the_stack[-length(the_stack)] | ||
|
||
if (is.list(current$obj)) { | ||
keys <- names(current$obj) | ||
for (key in keys) { | ||
current_path <- paste0(current$path, "$", key) | ||
|
||
# Test by removing the current element | ||
result <- test_error(pruned_obj, current_path, SampleData) | ||
|
||
# Determine if we should keep or remove the element | ||
if ((!is.null(result$error) && result$error != "is.vector(pred) is not TRUE") || !is.null(result$wrns)) { | ||
# If error changes, goes away, or a warning appears, keep the element | ||
the_stack <- c(the_stack, list(list(obj = current$obj[[key]], path = current_path))) | ||
} else { | ||
# If error remains the same and no wrns, remove the element | ||
eval(parse(text = paste0("pruned_obj", current_path, " <- NULL"))) | ||
} | ||
} | ||
} | ||
} | ||
pruned_obj | ||
} | ||
|
||
# Start the pruning process | ||
pruned_caretlist <- prune_list_iterative(caretlist_with_old_earth_model, SampleData) | ||
|
||
# Prune attributes | ||
attr(pruned_caretlist$earth$terms, ".Environment") <- NULL | ||
attr(pruned_caretlist$earth$terms, "dimnames") <- NULL | ||
attr(pruned_caretlist$earth$terms, "term.labels") <- NULL | ||
attr(pruned_caretlist$earth$terms, "order") <- NULL | ||
attr(pruned_caretlist$earth$terms, "intercept") <- NULL | ||
attr(pruned_caretlist$earth$terms, "response") <- NULL | ||
attr(pruned_caretlist$earth$terms, "predvars") <- NULL | ||
attr(pruned_caretlist$earth$terms, "dataClasses") <- NULL | ||
|
||
# Test the final pruned object | ||
# Note that once the bug is fixed, this will no longer fail | ||
# this requires version 4.0.0 of caretEnsemble, prior to the PR fixing the prediciton issue | ||
# https://github.com/zachmayer/caretEnsemble/issues/324 | ||
testthat::expect_error( | ||
predict(pruned_caretlist, SampleData), | ||
"is.vector(pred) is not TRUE", | ||
fixed = TRUE | ||
) | ||
|
||
# Save | ||
saveRDS( | ||
pruned_caretlist, | ||
file.path("tests", "testthat", "data", "caretlist_with_bad_earth_model.rds"), | ||
ascii = FALSE, | ||
version = 3L, | ||
compress = "xz" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# This test takes a few minutes and needs to install and load a lot of packages | ||
# I don't want to make it a dependency for the package or even for PR tests | ||
# But I do want to run this every release to make sure that the models | ||
# we can run predict correctly. | ||
|
||
devtools::load_all() | ||
|
||
very_quiet <- function(expr) { | ||
testthat::expect_output(suppressWarnings(suppressMessages(expr))) | ||
} | ||
|
||
################################################################# | ||
# Setup data | ||
################################################################# | ||
set.seed(42L) | ||
nrows <- 10L | ||
ncols <- 2L | ||
|
||
X <- matrix(stats::rnorm(nrows * ncols), ncol = ncols) | ||
colnames(X) <- paste0("X", 1L:ncols) | ||
|
||
y <- X[, 1L] + X[, 2L] + stats::rnorm(nrows) / 10.0 | ||
y_bin <- factor(ifelse(y > median(y), "yes", "no")) | ||
|
||
all_models <- data.table::data.table(caret::modelLookup()) | ||
all_models <- unique(all_models[, c("model", "forReg", "probModel")]) | ||
|
||
java_models <- c( | ||
"gbm_h2o", | ||
"glmnet_h2o", | ||
"bartMachine", | ||
"M5", | ||
"M5Rules", | ||
"J48", | ||
"JRip", | ||
"LMT", | ||
"PART", | ||
"OneR", | ||
"evtree" | ||
) | ||
|
||
################################################################# | ||
# Reg | ||
################################################################# | ||
|
||
# From https://github.com/zachmayer/caretEnsemble/issues/324 | ||
# Problem models: | ||
# bam - array | ||
# blackboost - matrix, array | ||
# dnn - matrix, array | ||
# earth - matrix, array | ||
# gam - array | ||
# gamboost - matrix, array | ||
# glmboost - matrix, array | ||
# pcaNNet - matrix, array | ||
# rvmLinear - matrix, array | ||
# rvmRadial - matrix, array | ||
# spls - matrix, array | ||
# xyf - matrix, array | ||
reg_models <- sort(unique(all_models[which(forReg), ][["model"]])) | ||
reg_models <- setdiff(reg_models, c( # Can't install or too slow | ||
"elm", "extraTrees", "foba", "logicBag", "mlpSGD", "mxnet", | ||
"mxnetAdam", "nodeHarvest", "relaxo", | ||
java_models | ||
)) | ||
|
||
################################################################# | ||
# Class | ||
################################################################# | ||
|
||
# Problem models: None! | ||
bin_models <- sort(unique(all_models[which(probModel), ][["model"]])) | ||
bin_models <- setdiff(bin_models, c( # Can't install or too slow | ||
"gaussprLinear", "adaboost", "amdai", "chaid", "extraTrees", | ||
"gpls", "logicBag", "mlpSGD", "mxnet", "mxnetAdam", "nodeHarvest", | ||
"ORFlog", "ORFpls", "ORFridge", "ORFsvm", "rrlda", "vbmpRadial", | ||
java_models | ||
)) | ||
|
||
################################################################# | ||
# Tests | ||
################################################################# | ||
|
||
testthat::test_that("Most caret models can predict", { | ||
# Fit the big caret lists | ||
models_reg <- very_quiet(caretList(X, y, methodList = reg_models, tuneLength = 1L, continue_on_fail = TRUE)) | ||
models_bin <- very_quiet(caretList(X, y_bin, methodList = bin_models, tuneLength = 1L, continue_on_fail = TRUE)) | ||
all_models <- c(models_reg, models_bin) | ||
testthat::expect_gt(length(all_models), 200L) # About 100 each of class/reg | ||
|
||
# Make sure we can predict | ||
pred <- very_quiet(predict(all_models, head(X, 5L))) | ||
testthat::expect_identical(nrow(pred), 5L) | ||
testthat::expect_identical(ncol(pred), length(all_models)) | ||
testthat::expect_true(all(unlist(lapply(pred, is.finite)))) | ||
|
||
# Make sure we can stacked predict | ||
# Some of these stupid models predict Infs lol, so whatever. | ||
# I guess beware of what models you ensemble. | ||
# The bagEarth models are bad, as is rvmPoly and some others. | ||
# These are stacked preds btw, so probably it indicates a fit failure | ||
# on one fold. Many ensemble models can handle Nans, but we'll see. | ||
pred_stack <- suppressWarnings(suppressMessages(predict(all_models))) | ||
testthat::expect_identical(nrow(pred_stack), nrow(X)) | ||
testthat::expect_identical(ncol(pred_stack), length(all_models)) | ||
}) |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Oops, something went wrong.