Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Dec 20, 2024
1 parent 41028cd commit adc95f8
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
2 changes: 0 additions & 2 deletions R/mlr_callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,6 @@ load_callback_holdout_task = function() {
man = "mlr3::mlr3.holdout_task",

on_resample_before_predict = function(callback, context) {
assert_task(callback$state$task)

pred = context$learner$predict(callback$state$task)
context$data_extra = list(prediction_holdout = pred)
}
Expand Down
26 changes: 26 additions & 0 deletions tests/testthat/test_CallbackResample.R
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,29 @@ test_that("data_extra is null", {
expect_names(names(tab), disjunct.from = "data_extra")
})

test_that("learner cloning in workhorse is passed to context", {
task = tsk("pima")
learner = lrn("classif.rpart")
resampling = rsmp("holdout")

callback = callback_resample("test",
on_resample_begin = function(callback, context) {
callback$state$address_1 = data.table::address(context$learner)
},

on_resample_before_train = function(callback, context) {
callback$state$address_2 = data.table::address(context$learner)
},

on_resample_end = function(callback, context) {
context$data_extra = list(
address_1 = callback$state$address_1,
address_2 = callback$state$address_2
)
}
)

rr = resample(task, learner, resampling, callbacks = callback)

expect_true(rr$data_extra[[1]]$address_1 != rr$data_extra[[1]]$address_2)
})

0 comments on commit adc95f8

Please sign in to comment.