Skip to content

Commit

Permalink
Custom models test (#310)
Browse files Browse the repository at this point in the history
* start on custom test:

* refactor tuneList later

* rebuild
  • Loading branch information
zachmayer authored Aug 10, 2024
1 parent 81a6d8e commit e9664fe
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 0 deletions.
1 change: 1 addition & 0 deletions R/greedyOpt.R
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ predict.greedyMSE <- function(object, newdata, return_labels = FALSE, ...) {
greedyMSE_caret <- function() {
list(
label = "Greedy Mean Squared Error Optimizer",
method = "greedyMSE",
library = NULL,
loop = NULL,
type = c("Regression", "Classification"),
Expand Down
Binary file modified coverage.rds
Binary file not shown.
41 changes: 41 additions & 0 deletions tests/testthat/test-caretList.R
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,44 @@ testthat::test_that("caretList supports combined regression, binary, multiclass"
testthat::expect_identical(nrow(stacked_p), nrow(iris))
testthat::expect_identical(nrow(new_p), 10L)
})

testthat::test_that("caretList supports custom models", {
set.seed(42L)

# Use the custom greedyMSE model
custom_list <- list(
custom.mse = caretModelSpec(method = greedyMSE_caret(), tuneLength = 1L)
)

# Fit it reg/bin/multi (it supports all 3!)
reg_models <- caretList(Sepal.Length ~ Sepal.Width, iris, tuneList = custom_list)
bin_models <- caretList(factor(ifelse(Species == "setosa", "Y", "N")) ~ Sepal.Width, iris, tuneList = custom_list)
multi_models <- caretList(Species ~ Sepal.Width, iris, tuneList = custom_list)

# Check the fit
all_models <- c(reg_models, bin_models, multi_models)
testthat::expect_s3_class(all_models, "caretList")

# Check predictions
stacked_p <- predict(all_models)
new_p <- predict(all_models, newdata = iris[1L:10L, ])
testthat::expect_is(stacked_p, "data.table")
testthat::expect_is(new_p, "data.table")
testthat::expect_identical(nrow(stacked_p), nrow(iris))
testthat::expect_identical(nrow(new_p), 10L)

# Check we can stack it
# Note that caretStack with method=greedyMSE_caret()
# is what caretEnsemble does under the hood
ens <- caretStack(
all_models,
method = greedyMSE_caret(),
trControl = trainControl(method = "cv", number = 2L, savePredictions = "final")
)
stacked_p <- predict(ens)
new_p <- predict(ens, newdata = iris[1L:10L, ])
testthat::expect_is(stacked_p, "data.table")
testthat::expect_is(new_p, "data.table")
testthat::expect_identical(nrow(stacked_p), nrow(iris))
testthat::expect_identical(nrow(new_p), 10L)
})

0 comments on commit e9664fe

Please sign in to comment.