diff --git a/R/nearmiss.R b/R/nearmiss.R index e305ae8..ef587e3 100644 --- a/R/nearmiss.R +++ b/R/nearmiss.R @@ -162,9 +162,11 @@ bake.step_nearmiss <- function(object, new_data, ...) { with_seed( seed = object$seed, code = { + original_levels <- levels(new_data[[object$column]]) new_data <- nearmiss(new_data, object$column, k = object$neighbors, under_ratio = object$under_ratio) + new_data[[object$column]] <- factor(new_data[[object$column]], levels = original_levels) } ) diff --git a/R/rose.R b/R/rose.R index 54a72fd..bb5cfd6 100644 --- a/R/rose.R +++ b/R/rose.R @@ -186,11 +186,13 @@ bake.step_rose <- function(object, new_data, ...) { with_seed( seed = object$seed, code = { + original_levels <- levels(new_data[[object$column]]) new_data <- ROSE(string2formula(object$column), new_data, N = majority_size * object$over_ratio, p = object$minority_prop, hmult.majo = object$majority_smoothness, hmult.mino = object$minority_smoothness)$data + new_data[[object$column]] <- factor(new_data[[object$column]], levels = original_levels) } ) diff --git a/R/tomek.R b/R/tomek.R index 8983925..5360e2a 100644 --- a/R/tomek.R +++ b/R/tomek.R @@ -147,10 +147,10 @@ response_0_1 <- function(x) { ifelse(x == names(sort(table(x)))[1], 1, 0) } # Turns 0-1 coded variable back into factor variable -response_0_1_to_org <- function(old, new) { +response_0_1_to_org <- function(old, new, levels) { ref <- names(sort(table(old))) names(ref) <- c("1", "0") - factor(unname(ref[as.character(new)])) + factor(unname(ref[as.character(new)]), levels = levels) } #' @importFrom tibble as_tibble tibble @@ -165,6 +165,7 @@ bake.step_tomek <- function(object, new_data, ...) { with_seed( seed = object$seed, code = { + original_levels <- levels(new_data[[object$column]]) tomek_data <- ubTomek(X = select(new_data, -!!object$column), Y = response_0_1(new_data[[object$column]]), verbose = FALSE) @@ -174,7 +175,7 @@ bake.step_tomek <- function(object, new_data, ...) { new_data0 <- mutate( tomek_data$X, !!object$column := response_0_1_to_org(new_data[[object$column]], - tomek_data$Y) + tomek_data$Y, levels = original_levels) ) as_tibble(new_data0[names(new_data)])