diff --git a/NEWS.md b/NEWS.md index 8d11f80..367d4d3 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,9 +1,11 @@ # mlr3batchmark (development version) -* feat: `reduceResultsBatchmark` gains argument `fun` which is passed on to `batchtools::reduceResultsList`, useful for deleting model data to avoid running out of memory, https://github.com/mlr-org/mlr3batchmark/issues/18 Thanks to Toby Dylan Hocking @tdhock for the PR. -* docs: A warning is now given when the loaded mlr3 version differs from the -mlr3 version stored in the trained learners -* Support marshaling +* feat: The design of `batchmark()` can now include parameter settings. +* feat: `reduceResultsBatchmark` gains argument `fun` which is passed on to `batchtools::reduceResultsList`. +Useful for deleting model data to avoid running out of memory. +Thanks to Toby Dylan Hocking @tdhock for the PR (https://github.com/mlr-org/mlr3batchmark/issues/18). +* docs: A warning is now given when the loaded mlr3 version differs from the mlr3 version stored in the trained learners. +* feat: support marshaling # mlr3batchmark 0.1.1 diff --git a/R/assertions.R b/R/assertions.R new file mode 100644 index 0000000..439bc7a --- /dev/null +++ b/R/assertions.R @@ -0,0 +1,12 @@ +assert_param_values = function(x, n_learners = NULL, .var.name = vname(x)) { + assert_list(x, len = n_learners, .var.name = .var.name) + + ok = every(x, function(x) { + test_list(x) && every(x, test_list, names = "unique", null.ok = TRUE) + }) + + if (!ok) { + stopf("'%s' must be a three-level nested list and the innermost lists must be named", .var.name) + } + invisible(x) +} diff --git a/R/batchmark.R b/R/batchmark.R index 0062ddf..c86f5fe 100644 --- a/R/batchmark.R +++ b/R/batchmark.R @@ -35,7 +35,7 @@ #' reduceResultsBatchmark(reg = reg) batchmark = function(design, store_models = FALSE, reg = batchtools::getDefaultRegistry()) { design = as.data.table(assert_data_frame(design, min.rows = 1L)) - assert_names(names(design), permutation.of = c("task", "learner", 
"resampling")) + assert_names(names(design), must.include = c("task", "learner", "resampling")) assert_flag(store_models) batchtools::assertRegistry(reg, class = "ExperimentRegistry", writeable = TRUE, sync = TRUE, running.ok = FALSE) @@ -53,10 +53,22 @@ batchmark = function(design, store_models = FALSE, reg = batchtools::getDefaultR batchtools::addAlgorithm("run_learner", fun = run_learner, reg = reg) } - # group per problem to speed up addExperiments() + # set hashes set(design, j = "task_hash", value = map_chr(design$task, "hash")) set(design, j = "learner_hash", value = map_chr(design$learner, "hash")) set(design, j = "resampling_hash", value = map_chr(design$resampling, "hash")) + + # expand with param values + if (is.null(design$param_values)) { + design$param_values = list() + } else { + design$param_values = list(assert_param_values(design$param_values, n_learners = length(design$learner))) + task = learner = resampling = NULL + design = design[, list(task, learner, resampling, param_values = unlist(get("param_values"), recursive = FALSE)), by = c("learner_hash", "task_hash", "resampling_hash")] + } + design[, "param_values_hash" := map(get("param_values"), calculate_hash)] + + # group per problem to speed up addExperiments() design[, "group" := .GRP, by = c("task_hash", "resampling_hash")] groups = unique(design$group) @@ -85,13 +97,23 @@ batchmark = function(design, store_models = FALSE, reg = batchtools::getDefaultR exports = c(exports, learner_hashes[i]) } + param_values_hashes = tab$param_values_hash + for (i in which(param_values_hashes %nin% exports)) { + batchtools::batchExport(export = set_names(list(tab$param_values[[i]]), param_values_hashes[i]), reg = reg) + exports = c(exports, param_values_hashes[i]) + } + prob_design = data.table( - task_hash = task_hash, task_id = task$id, - resampling_hash = resampling_hash, resampling_id = resampling$id + task_hash = task_hash, + task_id = task$id, + resampling_hash = resampling_hash, + resampling_id = 
resampling$id ) algo_design = data.table( - learner_hash = learner_hashes, learner_id = map_chr(tab$learner, "id"), + learner_hash = learner_hashes, + learner_id = map_chr(tab$learner, "id"), + param_values_hash = param_values_hashes, store_models = store_models ) diff --git a/R/reduceResultsBatchmark.R b/R/reduceResultsBatchmark.R index 441c1f2..f7f68bf 100644 --- a/R/reduceResultsBatchmark.R +++ b/R/reduceResultsBatchmark.R @@ -1,8 +1,7 @@ #' @title Collect Results from batchmark #' #' @description -#' Collect the results from jobs defined via [batchmark()] and combine them into -#' a [mlr3::BenchmarkResult]. +#' Collect the results from jobs defined via [batchmark()] and combine them into a [mlr3::BenchmarkResult]. #' #' Note that `ids` defaults to finished jobs (as reported by [batchtools::findDone()]). #' If a job threw an error, is expired or is still running, it will be ignored with this default. diff --git a/R/worker.R b/R/worker.R index 613587c..ac9442c 100644 --- a/R/worker.R +++ b/R/worker.R @@ -1,7 +1,10 @@ -run_learner = function(job, data, learner_hash, store_models, ...) { +run_learner = function(job, data, learner_hash, param_values_hash, store_models, ...) { workhorse = utils::getFromNamespace("workhorse", ns = asNamespace("mlr3")) resampling = get(job$prob.pars$resampling_hash, envir = .GlobalEnv) learner = get(learner_hash, envir = .GlobalEnv) + param_values = get(param_values_hash, envir = .GlobalEnv) + + if (!is.null(param_values)) learner$param_set$set_values(.values = param_values) workhorse( iteration = job$repl, diff --git a/man/reduceResultsBatchmark.Rd b/man/reduceResultsBatchmark.Rd index 7eb04e2..42e53fa 100644 --- a/man/reduceResultsBatchmark.Rd +++ b/man/reduceResultsBatchmark.Rd @@ -46,8 +46,7 @@ If you want to ensure that all learners are in marshaled form, you need to call \link[mlr3:BenchmarkResult]{mlr3::BenchmarkResult}. 
} \description{ -Collect the results from jobs defined via \code{\link[=batchmark]{batchmark()}} and combine them into -a \link[mlr3:BenchmarkResult]{mlr3::BenchmarkResult}. +Collect the results from jobs defined via \code{\link[=batchmark]{batchmark()}} and combine them into a \link[mlr3:BenchmarkResult]{mlr3::BenchmarkResult}. Note that \code{ids} defaults to finished jobs (as reported by \code{\link[batchtools:findJobs]{batchtools::findDone()}}). If a job threw an error, is expired or is still running, it will be ignored with this default. diff --git a/tests/testthat/test_batchmark.R b/tests/testthat/test_batchmark.R index b00e11e..746ce69 100644 --- a/tests/testthat/test_batchmark.R +++ b/tests/testthat/test_batchmark.R @@ -97,3 +97,32 @@ test_that("marshaling", { expect_true(bmr_marshaled$resample_result(1)$learners[[1]]$marshaled) expect_false(bmr_unmarshaled$resample_result(1)$learners[[1]]$marshaled) }) + +test_that("adding parameter values works", { + tasks = tsks(c("iris", "spam")) + resamplings = list(rsmp("cv", folds = 3)$instantiate(tasks[[1]])) + learners = lrns("classif.debug") + + design = data.table( + task = tasks, + learner = learners, + resampling = resamplings, + param_values = list(list(list(x = 1), list(x = 0.5)))) + + reg = batchtools::makeExperimentRegistry(NA, make.default = FALSE) + + ids = batchmark(design, reg = reg) + expect_data_table(ids, ncols = 1L, nrows = 12L) + ids = batchtools::submitJobs(reg = reg) + batchtools::waitForJobs(reg = reg) + expect_data_table(ids, nrows = 12L) + + logs = batchtools::getErrorMessages(reg = reg) + expect_data_table(logs, nrows = 0L) + results = reduceResultsBatchmark(reg = reg) + expect_s3_class(results, "BenchmarkResult") + expect_benchmark_result(results) + expect_data_table(as.data.table(results), nrows = 12L) + expect_equal(results$learners$learner[[1]]$param_set$values$x, 1) + expect_equal(results$learners$learner[[2]]$param_set$values$x, 0.5) +})