Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: start remote workers with mirai #276

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
3 changes: 3 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,15 @@ Suggests:
GenSA,
irace (>= 4.0.0),
knitr,
mirai,
nloptr,
progressr,
processx,
redux,
testthat (>= 3.0.0),
rush (>= 0.1.2)
Remotes:
mlr-org/rush@mirai
Config/testthat/edition: 3
Config/testthat/parallel: false
Encoding: UTF-8
Expand Down
10 changes: 9 additions & 1 deletion R/Objective.R
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,13 @@ Objective = R6Class("Objective",
man = function(rhs) {
  # Run `rhs` through assert_ro_binding() first, then hand back the value
  # kept in the private `.man` slot.
  assert_ro_binding(rhs)
  return(private$.man)
},

#' @field packages (`character()`)\cr
#' Names of the packages the objective requires.
packages = function(rhs) {
  # Same accessor shape as the other active fields: check `rhs` with
  # assert_ro_binding(), then return the privately stored value.
  assert_ro_binding(rhs)
  return(private$.packages)
}
),

Expand Down Expand Up @@ -211,6 +218,7 @@ Objective = R6Class("Objective",
},

.label = NULL,
.man = NULL
.man = NULL,
.packages = NULL
)
)
28 changes: 17 additions & 11 deletions R/OptimizerAsync.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ OptimizerAsync = R6Class("OptimizerAsync",
#' @keywords internal
#' @export
optimize_async_default = function(instance, optimizer, design = NULL, n_workers = NULL) {
assert_class(instance, "OptimInstanceAsync")
assert_class(optimizer, "OptimizerAsync")
assert_data_table(design, null.ok = TRUE)

instance$archive$start_time = Sys.time()
Expand All @@ -81,19 +79,17 @@ optimize_async_default = function(instance, optimizer, design = NULL, n_workers
# run .optimize() on workers
rush = instance$rush

# FIXME: How to pass globals and packages?
if (rush$n_pre_workers) {
# start remote workers
if (requireNamespace("mirai") && mirai::daemons()$connections) {
# remote workers
lg$info("Starting to optimize %i parameter(s) with '%s' and '%s' on %i remote worker(s)",
instance$search_space$length,
optimizer$format(),
instance$terminator$format(with_params = TRUE),
rush$n_pre_workers
)
rush::rush_config()$n_workers)

rush$start_remote_workers(
worker_loop = bbotk_worker_loop,
packages = c(optimizer$packages, "bbotk"), # add packages from objective
packages = c(optimizer$packages, instance$objective$packages, "bbotk"),
optimizer = optimizer,
instance = instance)
} else if (rush::rush_available()) {
Expand All @@ -107,18 +103,25 @@ optimize_async_default = function(instance, optimizer, design = NULL, n_workers

rush$start_local_workers(
worker_loop = bbotk_worker_loop,
packages = c(optimizer$packages, "bbotk"), # add packages from objective
packages = c(optimizer$packages, instance$objective$packages, "bbotk"),
optimizer = optimizer,
instance = instance)
} else {
stop("No rush plan available to start local workers and no pre-started remote workers found. See `?rush::rush_plan()`.")
stop("No rush plan available to start local workers and `mirai::daemons()` found. See `?rush::rush_plan()`.")
}
}

n_running_workers = 0
# wait until optimization is finished
# check terminated workers when the terminator is "none"
while(TRUE) {
Sys.sleep(1)

if (rush$n_running_workers > n_running_workers) {
n_running_workers = rush$n_running_workers
lg$info("%i of %i worker(s) started", n_running_workers, rush::rush_config()$n_workers)
}

instance$rush$print_log()

# fetch new results for printing
Expand All @@ -133,7 +136,10 @@ optimize_async_default = function(instance, optimizer, design = NULL, n_workers
}

if (instance$is_terminated) break
if (instance$rush$all_workers_terminated) break
if (instance$rush$all_workers_terminated) {
lg$info("All workers have terminated.")
break
}
}

# assign result
Expand Down
3 changes: 3 additions & 0 deletions man/Objective.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 3 additions & 6 deletions tests/testthat/test_OptimInstanceAsyncSingleCrit.R
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,6 @@ test_that("reconnect method works", {
skip_if_not_installed("rush")
flush_redis()

on.exit({
file.remove("instance.rds")
})

rush::rush_plan(n_workers = 2)

instance = oi_async(
Expand All @@ -94,8 +90,9 @@ test_that("reconnect method works", {
optimizer = opt("async_random_search")
optimizer$optimize(instance)

saveRDS(instance, file = "instance.rds")
instance = readRDS("instance.rds")
file = tempfile(fileext = ".rds")
saveRDS(instance, file = file )
instance = readRDS(file)

instance$reconnect()

Expand Down
26 changes: 7 additions & 19 deletions tests/testthat/test_OptimizerAsync.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ test_that("OptimizerAsync starts local workers", {
instance = oi_async(
objective = OBJ_2D,
search_space = PS_2D,
terminator = trm("evals", n_evals = 5L),
terminator = trm("evals", n_evals = 50L),
)

optimizer = opt("async_random_search")
Expand All @@ -22,41 +22,29 @@ test_that("OptimizerAsync starts local workers", {

# Runs `async_random_search` on an async instance and inspects
# `instance$rush$worker_info` afterwards.
# NOTE(review): this span comes from a rendered diff and appears to interleave
# the earlier processx-based worker setup with the newer mirai-based setup.
# Confirm which statements survive before treating it as runnable test code.
test_that("OptimizerAsync starts remote workers", {
skip_on_cran()
skip_if_not_installed("rush")
skip_if_not_installed("processx")
# NOTE(review): skip_if_not_installed() is documented for a single package
# name; verify that a character vector really checks both packages, otherwise
# use one call per package.
skip_if_not_installed(c("rush", "mirai"))
flush_redis()
library(processx)

rush = rsh(network_id = "test_rush")
expect_snapshot(rush$create_worker_script())

# Launches a separate Rscript process that calls rush::start_worker() for the
# "test_rush" network; it is killed again by the on.exit() handler below.
px = process$new("Rscript",
args = c("-e", 'rush::start_worker(network_id = "test_rush", remote = TRUE, url = "redis://127.0.0.1:6379", scheme = "redis", host = "127.0.0.1", port = "6379")'),
supervise = TRUE,
stderr = "|", stdout = "|")
library(rush)

on.exit({
px$kill()
}, add = TRUE)
# NOTE(review): no matching shutdown of these daemons is visible in this
# block; consider resetting them on exit so later tests do not see them.
mirai::daemons(2)

Sys.sleep(5)
rush_plan(n_workers = 2)

instance = oi_async(
objective = OBJ_2D,
search_space = PS_2D,
terminator = trm("evals", n_evals = 5L),
rush = rush
)

optimizer = opt("async_random_search")
optimizer$optimize(instance)

# NOTE(review): the two row-count expectations below (1 vs 2) cannot both
# hold; presumably one belongs to the removed side of the diff.
expect_data_table(instance$rush$worker_info, nrows = 1)
expect_true(instance$rush$worker_info$remote)
expect_data_table(instance$rush$worker_info, nrows = 2)

expect_rush_reset(instance$rush)
})


test_that("OptimizerAsync assigns result", {
skip_on_cran()
skip_if_not_installed("rush")
Expand Down
Loading