Skip to content

Commit

Permalink
feat: add result object
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed May 31, 2024
1 parent f0b1098 commit a50d040
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 8 deletions.
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ Imports:
lgr,
mlr3misc (>= 0.15.0.9000),
paradox (>= 1.0.0),
R6
R6,
stabm
Suggests:
e1071,
genalg,
Expand All @@ -55,6 +56,7 @@ Collate:
'AutoFSelector.R'
'CallbackBatchFSelect.R'
'ContextBatchFSelect.R'
'EnsembleFSResult.R'
'FSelectInstanceBatchSingleCrit.R'
'FSelectInstanceBatchMultiCrit.R'
'mlr_fselectors.R'
Expand Down
75 changes: 75 additions & 0 deletions R/EnsembleFSResult.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#' @title Ensemble Feature Selection Result
#'
#' @description
#' The `EnsembleFSResult` stores the results of the ensemble feature selection.
#' The function [ensemble_fselect()] returns an object of this class.
#'
#' @examples
#' \donttest{
#' efsr = ensemble_fselect(
#' fselector = fs("rfe", n_features = 2, feature_fraction = 0.8),
#' task = tsk("sonar"),
#' learners = lrns(c("classif.rpart", "classif.featureless")),
#' init_resampling = rsmp("subsampling", repeats = 2),
#' inner_resampling = rsmp("cv", folds = 3),
#' measure = msr("classif.ce"),
#' terminator = trm("none")
#' )
#'
#' # contains the benchmark result
#' efsr$benchmark_result
#'
#' # contains the selected features for each iteration
#' efsr$grid
#'
#' # returns the stability of the selected features
#' efsr$stability(stability_measure = "jaccard")
#' }
EnsembleFSResult = R6Class("EnsembleFSResult",
public = list(

#' @field benchmark_result (`BenchmarkResult`)\cr
#' The benchmark result object.
benchmark_result = NULL,

#' @field grid (`data.table`)\cr
#' The grid of feature selection results.
grid = NULL,

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#'
#' @param benchmark_result (`BenchmarkResult`)\cr
#' The benchmark result object.
#' @param grid (`data.table`)\cr
#' The grid of feature selection results.
initialize = function(benchmark_result, grid) {
self$benchmark_result = assert_benchmark_result(benchmark_result)
self$grid = assert_data_table(grid)
},

#' @description
#' Returns the feature ranking.
feature_ranking = function() {

},

#' @description
#' Calculates the stability of the selected features with the `stabm` package.
#'
#' @param stability_measure (`character(1)`)\cr
#' The stability measure to be used.
#' One of the measures returned by [stabm::listStabilityMeasures()] in lower case.
#' Default is `"jaccard"`.
#' @param ... (`any`)\cr
#' Additional arguments passed to the stability measure function.
stability = function(stability_measure = "jaccard", ...) {
funs = stabm::listStabilityMeasures()$Name
keys = tolower(gsub("stability", "", funs))
assert_choice(stability_measure, choices = keys)

fun = get(funs[which(stability_measure == keys)], envir = asNamespace("stabm"))
fun(self$grid$features, ...)
}
)
)
2 changes: 1 addition & 1 deletion R/ensemble_fselect.R
Original file line number Diff line number Diff line change
Expand Up @@ -134,5 +134,5 @@ ensemble_fselect = function(
set(grid, j = "importance", value = imp_scores)
}

grid
EnsembleFSResult$new(bmr, grid)
}
123 changes: 123 additions & 0 deletions man/EnsembleFSResult.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 21 additions & 6 deletions tests/testthat/test_ensemble_fselect.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
test_that("ensemble feature selection works", {
res = ensemble_fselect(
efsr = ensemble_fselect(
fselector = fs("rfe", n_features = 2, feature_fraction = 0.8),
task = tsk("sonar"),
learners = lrns(c("classif.rpart", "classif.featureless")),
Expand All @@ -9,11 +9,26 @@ test_that("ensemble feature selection works", {
terminator = trm("none")
)

expect_data_table(res, nrows = 4)
expect_list(res$features, any.missing = FALSE, len = 4)
expect_vector(res$n_features, size = 4)
expect_vector(res$classif.ce, size = 4)
expect_list(res$importance, any.missing = FALSE, len = 4)
expect_data_table(efsr$grid, nrows = 4)
expect_list(efsr$grid$features, any.missing = FALSE, len = 4)
expect_vector(efsr$grid$n_features, size = 4)
expect_vector(efsr$grid$classif.ce, size = 4)
expect_list(efsr$grid$importance, any.missing = FALSE, len = 4)
expect_benchmark_result(efsr$benchmark_result)
})

test_that("stability method works", {
efsr = ensemble_fselect(
fselector = fs("rfe", n_features = 2, feature_fraction = 0.8),
task = tsk("sonar"),
learners = lrns(c("classif.rpart", "classif.featureless")),
init_resampling = rsmp("subsampling", repeats = 2),
inner_resampling = rsmp("cv", folds = 3),
measure = msr("classif.ce"),
terminator = trm("none")
)

expect_number(efsr$stability(stability_measure = "jaccard"))
})


0 comments on commit a50d040

Please sign in to comment.