Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Nov 21, 2024
1 parent af87ce7 commit 6a97bf0
Show file tree
Hide file tree
Showing 13 changed files with 280 additions and 133 deletions.
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ RoxygenNote: 7.3.2
Collate:
'mlr_reflections.R'
'BenchmarkResult.R'
'CallbackWorkhorse.R'
'ContextWorkhorse.R'
'warn_deprecated.R'
'DataBackend.R'
'DataBackendCbind.R'
Expand Down
5 changes: 5 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ S3method(unmarshal_model,classif.debug_model_marshaled)
S3method(unmarshal_model,default)
S3method(unmarshal_model,learner_state_marshaled)
export(BenchmarkResult)
export(CallbackWorkhorse)
export(ContextWorkhorse)
export(DataBackend)
export(DataBackendDataTable)
export(DataBackendMatrix)
Expand Down Expand Up @@ -214,9 +216,12 @@ export(assert_row_ids)
export(assert_task)
export(assert_tasks)
export(assert_validate)
export(assert_workhorse_callback)
export(assert_workhorse_callbacks)
export(auto_convert)
export(benchmark)
export(benchmark_grid)
export(callback_workhorse)
export(check_prediction_data)
export(col_info)
export(convert_task)
Expand Down
113 changes: 0 additions & 113 deletions R/CallbackResample.R

This file was deleted.

93 changes: 93 additions & 0 deletions R/CallbackWorkhorse.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#' @title Create Workhorse Callback
#'
#' @description
#' Callbacks allow to customize the behavior of processes in mlr3.
#'
#' @export
CallbackWorkhorse= R6Class("CallbackWorkhorse",
inherit = Callback,
public = list(

on_workhorse_before_train = NULL,

on_workhorse_before_predict = NULL,

on_workhorse_before_result = NULL
)
)

#' @title Create Workhorse Callback
#'
#' @description
#' Function to create a [CallbackWorkhorse].
#'
#' ```
#' Start Workhorse
#' - on_workhorse_before_train
#' - on_workhorse_before_predict
#' - on_workhorse_before_result
#' End Tuning
#' ```
#
#' @details
#' When implementing a callback, each function must have two arguments named `callback` and `context`.
#' A callback can write data to the state (`$state`), e.g. settings that affect the callback itself.
#' Workhorse callbacks access [ContextWorkhorse].
#'
#' @param id (`character(1)`)\cr
#' Identifier for the new instance.
#' @param label (`character(1)`)\cr
#' Label for the new instance.
#' @param man (`character(1)`)\cr
#' String in the format `[pkg]::[topic]` pointing to a manual page for this object.
#' The referenced help package can be opened via method `$help()`.
#'
#' @export
#' @inherit CallbackWorkhorse examples
callback_workhorse = function(
id,
label = NA_character_,
man = NA_character_,
on_workhorse_before_train = NULL,
on_workhorse_before_predict = NULL,
on_workhorse_before_result = NULL
) {
stages = discard(set_names(list(
on_workhorse_before_train,
on_workhorse_before_predict,
on_workhorse_before_result),
c(
"on_workhorse_before_train",
"on_workhorse_before_predict",
"on_workhorse_before_result"
)), is.null)

walk(stages, function(stage) assert_function(stage, args = c("callback", "context")))
callback = CallbackWorkhorse$new(id, label, man)
iwalk(stages, function(stage, name) callback[[name]] = stage)
callback
}

#' @title Assertions for Callbacks
#'
#' @description
#' Assertions for [CallbackWorkhorse] class.
#'
#' @param callback ([CallbackWorkhorse]).
#' @param null_ok (`logical(1)`)\cr
#' If `TRUE`, `NULL` is allowed.
#'
#' @return [CallbackWorkhorse | List of [CallbackWorkhorse]s.
#' @export
assert_workhorse_callback = function(callback, null_ok = FALSE) {
if (null_ok && is.null(callback)) return(invisible(NULL))
assert_class(callback, "CallbackWorkhorse")
invisible(callback)
}

#' @export
#' @param callbacks (list of [CallbackWorkhorse]).
#' @rdname assert_workhorse_callback
assert_workhorse_callbacks = function(callbacks) {
invisible(lapply(callbacks, assert_workhorse_callback))
}
7 changes: 2 additions & 5 deletions R/ContextResample.R → R/ContextWorkhorse.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,10 @@
#' See the section on active bindings for a list of modifiable objects.
#' See [callback_batch_tuning()] for a list of stages that access `ContextBatchTuning`.
#'
#' @template param_inst_batch
#' @template param_tuner
#'
#' @export
ContextResample = R6Class("ContextResample",
ContextWorkhorse = R6Class("ContextResample",
inherit = Context,
public = list(
data = NULL
env = NULL
)
)
11 changes: 4 additions & 7 deletions R/resample.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ resample = function(
unmarshal = TRUE,
callbacks = NULL
) {
callbacks = assert_resample_callbacks(as_callbacks(callbacks))
context = ContextResample$new("resample")
callbacks = assert_workhorse_callbacks(as_callbacks(callbacks))

assert_subset(clone, c("task", "learner", "resampling"))
task = assert_task(as_task(task, clone = "task" %in% clone))
Expand Down Expand Up @@ -129,10 +128,10 @@ resample = function(
}

res = future_map(n, workhorse, iteration = seq_len(n), learner = grid$learner, mode = grid$mode,
MoreArgs = list(task = task, resampling = resampling, store_models = store_models, lgr_threshold = lgr_threshold, pb = pb, unmarshal = unmarshal)
MoreArgs = list(task = task, resampling = resampling, store_models = store_models, lgr_threshold = lgr_threshold, pb = pb, unmarshal = unmarshal, callbacks = callbacks)
)

context$data = data.table(
data = data.table(
task = list(task),
learner = grid$learner,
learner_state = map(res, "learner_state"),
Expand All @@ -144,9 +143,7 @@ resample = function(
learner_hash = map_chr(res, "learner_hash")
)

call_back("on_resample_before_result_data", callbacks, context)

result_data = ResultData$new(context$data, store_backends = store_backends)
result_data = ResultData$new(data, store_backends = store_backends)

# the worker already ensures that models are sent back in marshaled form if unmarshal = FALSE, so we don't have
# to do anything in this case. This allows us to minimize the amount of marshaling in those situtions where
Expand Down
15 changes: 11 additions & 4 deletions R/worker.R
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,11 @@ workhorse = function(
mode = "train",
is_sequential = TRUE,
unmarshal = TRUE,
callback = NULL,
callbacks = NULL
) {
context = ContextWorkhorse$new("workhorse")
context$env = environment()

if (!is.null(pb)) {
pb(sprintf("%s|%s|i:%i", task$id, learner$id, iteration))
}
Expand Down Expand Up @@ -319,6 +322,9 @@ workhorse = function(
validate = get0("validate", learner)

test_set = if (identical(validate, "test")) sets$test

call_back("on_workhorse_before_train", callbacks, context)

train_result = learner_train(learner, task, sets[["train"]], test_set, mode = mode)
learner = train_result$learner

Expand All @@ -337,6 +343,9 @@ workhorse = function(

pdatas = Map(function(set, row_ids, task) {
lg$debug("Creating Prediction for predict set '%s'", set)

call_back("on_workhorse_before_predict", callbacks, context)

learner_predict(learner, task, row_ids)
}, set = predict_sets, row_ids = pred_data$sets, task = pred_data$tasks)

Expand All @@ -345,9 +354,7 @@ workhorse = function(
}
pdatas = discard(pdatas, is.null)

if (!is.null(callback)) {
learner_state = c(learner_state, assert_list(callback(learner$model)))
}
call_back("on_workhorse_before_result", callbacks, context)

# set the model slot after prediction so it can be sent back to the main process
process_model_after_predict(
Expand Down
46 changes: 46 additions & 0 deletions man/CallbackWorkhorse.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 6a97bf0

Please sign in to comment.