Skip to content

Commit

Permalink
change stability args to a list
Browse files Browse the repository at this point in the history
  • Loading branch information
bblodfon committed Jun 21, 2024
1 parent 4fde94b commit 33b8621
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
8 changes: 4 additions & 4 deletions R/EnsembleFSResult.R
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ EnsembleFSResult = R6Class("EnsembleFSResult",
#' The stability measure to be used.
#' One of the measures returned by [stabm::listStabilityMeasures()] in lower case.
#' Default is `"jaccard"`.
#' @param ... (`any`)\cr
#' @param stability_args (`list`)\cr
#' Additional arguments passed to the stability measure function.
#' @param global (`logical(1)`)\cr
#' Whether to calculate the stability globally or for each learner.
Expand All @@ -167,7 +167,7 @@ EnsembleFSResult = R6Class("EnsembleFSResult",
#'
#' @return A `numeric()` value representing the stability of the selected features.
#' Or a `numeric()` vector with the stability of the selected features for each learner.
stability = function(stability_measure = "jaccard", ..., global = TRUE, reset_cache = FALSE) {
stability = function(stability_measure = "jaccard", stability_args = NULL, global = TRUE, reset_cache = FALSE) {
funs = stabm::listStabilityMeasures()$Name
keys = tolower(gsub("stability", "", funs))
assert_choice(stability_measure, choices = keys)
Expand All @@ -179,7 +179,7 @@ EnsembleFSResult = R6Class("EnsembleFSResult",
}

fun = get(funs[which(stability_measure == keys)], envir = asNamespace("stabm"))
private$.stability_global[[stability_measure]] = fun(private$.result$features, ...)
private$.stability_global[[stability_measure]] = invoke(fun, features = private$.result$features, .args = stability_args)
private$.stability_global[[stability_measure]]
} else {
# cached results
Expand All @@ -189,7 +189,7 @@ EnsembleFSResult = R6Class("EnsembleFSResult",

fun = get(funs[which(stability_measure == keys)], envir = asNamespace("stabm"))

tab = private$.result[, list(score = fun(.SD$features, ...)), by = learner_id]
tab = private$.result[, list(score = invoke(fun, features = .SD$features, .args = stability_args)), by = learner_id]
private$.stability_learner[[stability_measure]] = set_names(tab$score, tab$learner_id)
private$.stability_learner[[stability_measure]]
}
Expand Down
6 changes: 3 additions & 3 deletions man/ensemble_fs_result.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 33b8621

Please sign in to comment.