Skip to content

Commit

Permalink
[324] Fix rev dep (#325)
Browse files Browse the repository at this point in the history
* handle array/matrix case

* usethis

* nbrbrbr

* just a regular test

* ok the test script works

* script works

* rebuild

* rebuild

* add test case

* shrink it

* add failing test

* fix failing test

* add seed to readme lol

* re add readme workflow, with a seed it should pass now

* ok different platforms = different rf results.
  • Loading branch information
zachmayer authored Aug 14, 2024
1 parent f8df26b commit eaee3bb
Show file tree
Hide file tree
Showing 13 changed files with 316 additions and 18 deletions.
9 changes: 6 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,23 @@ Suggests:
caTools,
covr,
devtools,
earth,
gbm,
glmnet,
klaR,
knitr,
lintr,
mgcv,
mlbench,
nnet,
pkgdown,
randomForest,
rmarkdown,
rhub,
rpart,
spelling,
testthat,
usethis,
pkgdown,
rhub
usethis
Imports:
caret,
data.table,
Expand Down
8 changes: 7 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ help:
@echo " readme Build readme"
@echo " check-win Run R CMD on the winbuilder service from CRAN"
@echo " check-rhub Run R CMD on the rhub service"
@echo " check-many-preds Check that caretList can predict on ~200 caret models"
@echo " release Release to CRAN"
@echo " preview-site Preview pkgdown site"
@echo " dev-guide Open the R package development guide"
Expand Down Expand Up @@ -135,6 +136,10 @@ preview-site:
Rscript -e "pkgdown::build_site()"
open docs/index.html

.PHONY: check-many-preds
check-many-preds:
Rscript inst/data-raw/test-all_models.R

.PHONY: check-win
check-win:
rm -rf lib/
Expand All @@ -146,7 +151,7 @@ check-rhub:
Rscript -e "rhub::rhub_check(platform='linux')"

.PHONY: release
release: check-rhub check-win
release: check-many-preds check-rhub check-win
R --no-save --quiet --interactive
devtools::release()

Expand All @@ -158,6 +163,7 @@ dev-guide:
clean:
rm -rf *.Rcheck
rm -f *.tar.gz
rm -f *.Rout
rm -rf man/
rm -f README.md
rm -f coverage.rds
Expand Down
12 changes: 10 additions & 2 deletions R/caretPredict.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,18 @@ caretPredict <- function(object, newdata = NULL, excluded_class_id = 1L, ...) {

# Otherwise, predict on newdata
} else {
if (any(object$modelInfo$library %in% c("neuralnet", "klaR"))) {
newdata <- as.matrix(newdata) # I hate some of these packages
}
if (is_class) {
pred <- caret::predict.train(object, type = "prob", newdata = newdata, ...)
pred <- stats::predict(object, type = "prob", newdata = newdata, ...)
stopifnot(is.data.frame(pred))
} else {
pred <- caret::predict.train(object, type = "raw", newdata = newdata, ...)
pred <- stats::predict(object, type = "raw", newdata = newdata, ...)
stopifnot(is.numeric(pred))
if (!is.vector(pred)) {
pred <- as.vector(pred) # Backwards compatability with older earth and caret::train models
}
stopifnot(
is.vector(pred),
is.numeric(pred),
Expand Down
21 changes: 11 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ stack them with another caret model.
First, use caretList to fit many models to the same data:

``` r
set.seed(42L)
data(diamonds, package = "ggplot2")
dat <- data.table::data.table(diamonds)
dat <- dat[sample.int(nrow(diamonds), 500L), ]
Expand All @@ -48,8 +49,8 @@ print(summary(models))
#> Model accuracy:
#> model_name metric value sd
#> <char> <char> <num> <num>
#> 1: rf RMSE 1110.199 114.5286
#> 2: glmnet RMSE 1256.668 100.4436
#> 1: rf RMSE 1076.492 215.4737
#> 2: glmnet RMSE 1142.082 105.6022
```

Then, use caretEnsemble to make a greedy ensemble of these models
Expand All @@ -68,17 +69,17 @@ print(greedy_stack)
#> Resampling results:
#>
#> RMSE Rsquared MAE
#> 1015.885 0.9364999 572.8984
#> 969.2517 0.9406218 557.1987
#>
#> Tuning parameter 'max_iter' was held constant at a value of 100
#>
#> Final model:
#> Greedy MSE
#> RMSE: 1031.297
#> RMSE: 989.2085
#> Weights:
#> [,1]
#> rf 0.63
#> glmnet 0.37
#> rf 0.55
#> glmnet 0.45
```

You can also use caretStack to make a non-linear ensemble
Expand All @@ -97,8 +98,8 @@ print(rf_stack)
#> Summary of sample sizes: 400, 400, 400, 400, 400
#> Resampling results:
#>
#> RMSE Rsquared MAE
#> 1020.387 0.9363342 527.8525
#> RMSE Rsquared MAE
#> 1081.425 0.930012 540.3294
#>
#> Tuning parameter 'mtry' was held constant at a value of 2
#>
Expand All @@ -110,8 +111,8 @@ print(rf_stack)
#> Number of trees: 500
#> No. of variables tried at each split: 2
#>
#> Mean of squared residuals: 1065781
#> % Var explained: 93.33
#> Mean of squared residuals: 925377
#> % Var explained: 93.95
```

Use autoplot from ggplot2 to plot ensemble diagnostics:
Expand Down
1 change: 1 addition & 0 deletions README.rmd
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ Use `caretList` to fit multiple models, and then use `caretStack` to stack them

First, use caretList to fit many models to the same data:
```{r}
set.seed(42L)
data(diamonds, package = "ggplot2")
dat <- data.table::data.table(diamonds)
dat <- dat[sample.int(nrow(diamonds), 500L), ]
Expand Down
4 changes: 2 additions & 2 deletions inst/WORDLIST
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ rpart
savePredictions
scalability
scikit
setosa
SDs
setosa
trainControl
travis
tuneGrid
Expand All @@ -71,4 +71,4 @@ varImp
vecstack
verions
xvars
yhat
yhat
115 changes: 115 additions & 0 deletions inst/data-raw/build_backwards_compatability_data.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# This script is a little big sorry.
# We're using a 3rd party dataset from a package
# It depends on caretEnsemble, and I broke it with the 4.0.0 pre-release
# This script isolates the bad, saved model in that package
# and then removes all the parts of that model that aren't needed to make predictions
# this gives us a minimal test case for the backwards compatability issue
# This script shouldn't ever need to get run again, just use the old saved
# caretlist_with_bad_earth_model.rds file in data in testthat in the tests folder.
# This script is for posterity.

# Note this is not in our depends or suggests. Also note new version may not have the bug.
devtools::install_version("LDLcalc", version = "2.1", repos = "http://cran.us.r-project.org")
devtools::load_all()

# Load the data and fit the model
data(SampleData, package = "LDLcalc")
ldl_model <- LDLcalc:::LDL_ML_train_StackingAlgorithm(SampleData) # nolint undesirable_operator_linter
testthat::expect_s3_class(ldl_model$stackModel, "caretStack")
testthat::expect_s3_class(ldl_model$stackModel$models, "caretList")

# Make a caretList with just the bad model
caretlist_with_old_earth_model <- ldl_model$stackModel$models["earth"]

# Function to test the error and warnings after removing a specific part
test_error <- function(obj, path, SampleData) {
modified_obj <- obj
eval(parse(text = paste0("modified_obj", path, " <- NULL")))

wrns <- NULL

# Capture both errors and warnings
result <- tryCatch(
{
withCallingHandlers(
{
predict(modified_obj, SampleData)
},
warning = function(w) {
wrns <<- c(wrns, conditionMessage(w)) # nolint undesirable_operator_linter
invokeRestart("muffleWarning")
}
)
list(error = NULL, wrns = wrns)
},
error = function(e) {
list(error = e$message, wrns = wrns)
}
)

result
}

# Function to iteratively prune the object
prune_list_iterative <- function(obj, SampleData) { # nolint cyclocomp_linter
the_stack <- list(list(obj = obj, path = ""))
pruned_obj <- obj

while (length(the_stack) > 0L) {
# Pop the last element from the the_stack
current <- the_stack[[length(the_stack)]]
the_stack <- the_stack[-length(the_stack)]

if (is.list(current$obj)) {
keys <- names(current$obj)
for (key in keys) {
current_path <- paste0(current$path, "$", key)

# Test by removing the current element
result <- test_error(pruned_obj, current_path, SampleData)

# Determine if we should keep or remove the element
if ((!is.null(result$error) && result$error != "is.vector(pred) is not TRUE") || !is.null(result$wrns)) {
# If error changes, goes away, or a warning appears, keep the element
the_stack <- c(the_stack, list(list(obj = current$obj[[key]], path = current_path)))
} else {
# If error remains the same and no wrns, remove the element
eval(parse(text = paste0("pruned_obj", current_path, " <- NULL")))
}
}
}
}
pruned_obj
}

# Start the pruning process
pruned_caretlist <- prune_list_iterative(caretlist_with_old_earth_model, SampleData)

# Prune attributes
attr(pruned_caretlist$earth$terms, ".Environment") <- NULL
attr(pruned_caretlist$earth$terms, "dimnames") <- NULL
attr(pruned_caretlist$earth$terms, "term.labels") <- NULL
attr(pruned_caretlist$earth$terms, "order") <- NULL
attr(pruned_caretlist$earth$terms, "intercept") <- NULL
attr(pruned_caretlist$earth$terms, "response") <- NULL
attr(pruned_caretlist$earth$terms, "predvars") <- NULL
attr(pruned_caretlist$earth$terms, "dataClasses") <- NULL

# Test the final pruned object
# Note that once the bug is fixed, this will no longer fail
# this requires version 4.0.0 of caretEnsemble, prior to the PR fixing the prediciton issue
# https://github.com/zachmayer/caretEnsemble/issues/324
testthat::expect_error(
predict(pruned_caretlist, SampleData),
"is.vector(pred) is not TRUE",
fixed = TRUE
)

# Save
saveRDS(
pruned_caretlist,
file.path("tests", "testthat", "data", "caretlist_with_bad_earth_model.rds"),
ascii = FALSE,
version = 3L,
compress = "xz"
)
106 changes: 106 additions & 0 deletions inst/data-raw/test-all_models.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# This test takes a few minutes and needs to install and load a lot of packages
# I don't want to make it a dependency for the package or even for PR tests
# But I do want to run this every release to make sure that the models
# we can run predict correctly.

devtools::load_all()

very_quiet <- function(expr) {
testthat::expect_output(suppressWarnings(suppressMessages(expr)))
}

#################################################################
# Setup data
#################################################################
set.seed(42L)
nrows <- 10L
ncols <- 2L

X <- matrix(stats::rnorm(nrows * ncols), ncol = ncols)
colnames(X) <- paste0("X", 1L:ncols)

y <- X[, 1L] + X[, 2L] + stats::rnorm(nrows) / 10.0
y_bin <- factor(ifelse(y > median(y), "yes", "no"))

all_models <- data.table::data.table(caret::modelLookup())
all_models <- unique(all_models[, c("model", "forReg", "probModel")])

java_models <- c(
"gbm_h2o",
"glmnet_h2o",
"bartMachine",
"M5",
"M5Rules",
"J48",
"JRip",
"LMT",
"PART",
"OneR",
"evtree"
)

#################################################################
# Reg
#################################################################

# From https://github.com/zachmayer/caretEnsemble/issues/324
# Problem models:
# bam - array
# blackboost - matrix, array
# dnn - matrix, array
# earth - matrix, array
# gam - array
# gamboost - matrix, array
# glmboost - matrix, array
# pcaNNet - matrix, array
# rvmLinear - matrix, array
# rvmRadial - matrix, array
# spls - matrix, array
# xyf - matrix, array
reg_models <- sort(unique(all_models[which(forReg), ][["model"]]))
reg_models <- setdiff(reg_models, c( # Can't install or too slow
"elm", "extraTrees", "foba", "logicBag", "mlpSGD", "mxnet",
"mxnetAdam", "nodeHarvest", "relaxo",
java_models
))

#################################################################
# Class
#################################################################

# Problem models: None!
bin_models <- sort(unique(all_models[which(probModel), ][["model"]]))
bin_models <- setdiff(bin_models, c( # Can't install or too slow
"gaussprLinear", "adaboost", "amdai", "chaid", "extraTrees",
"gpls", "logicBag", "mlpSGD", "mxnet", "mxnetAdam", "nodeHarvest",
"ORFlog", "ORFpls", "ORFridge", "ORFsvm", "rrlda", "vbmpRadial",
java_models
))

#################################################################
# Tests
#################################################################

testthat::test_that("Most caret models can predict", {
# Fit the big caret lists
models_reg <- very_quiet(caretList(X, y, methodList = reg_models, tuneLength = 1L, continue_on_fail = TRUE))
models_bin <- very_quiet(caretList(X, y_bin, methodList = bin_models, tuneLength = 1L, continue_on_fail = TRUE))
all_models <- c(models_reg, models_bin)
testthat::expect_gt(length(all_models), 200L) # About 100 each of class/reg

# Make sure we can predict
pred <- very_quiet(predict(all_models, head(X, 5L)))
testthat::expect_identical(nrow(pred), 5L)
testthat::expect_identical(ncol(pred), length(all_models))
testthat::expect_true(all(unlist(lapply(pred, is.finite))))

# Make sure we can stacked predict
# Some of these stupid models predict Infs lol, so whatever.
# I guess beware of what models you ensemble.
# The bagEarth models are bad, as is rvmPoly and some others.
# These are stacked preds btw, so probably it indicates a fit failure
# on one fold. Many ensemble models can handle Nans, but we'll see.
pred_stack <- suppressWarnings(suppressMessages(predict(all_models)))
testthat::expect_identical(nrow(pred_stack), nrow(X))
testthat::expect_identical(ncol(pred_stack), length(all_models))
})
Binary file modified man/figures/README-greedy-stack-6-plot-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified man/figures/README-unnamed-chunk-5-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading

0 comments on commit eaee3bb

Please sign in to comment.