diff --git a/R/translate.R b/R/translate.R index 4732dab76..e096b5dfe 100644 --- a/R/translate.R +++ b/R/translate.R @@ -106,25 +106,21 @@ get_model_spec <- function(model, mode, engine) { env_obj <- grep(model, env_obj, value = TRUE) res <- list() - res$libs <- - rlang::env_get(m_env, paste0(model, "_pkgs")) %>% - dplyr::filter(engine == !!engine) %>% - purrr::pluck("pkg") %>% - purrr::pluck(1) - - res$fit <- - rlang::env_get(m_env, paste0(model, "_fit")) %>% - dplyr::filter(mode == !!mode & engine == !!engine) %>% - dplyr::pull(value) %>% - purrr::pluck(1) - - pred_code <- - rlang::env_get(m_env, paste0(model, "_predict")) %>% - dplyr::filter(mode == !!mode & engine == !!engine) %>% - dplyr::select(-engine, -mode) - - res$pred <- pred_code[["value"]] - names(res$pred) <- pred_code$type + + libs <- rlang::env_get(m_env, paste0(model, "_pkgs")) + libs <- vctrs::vec_slice(libs$pkg, libs$engine == engine) + res$libs <- if (length(libs) > 0) {libs[[1]]} else {NULL} + + fits <- rlang::env_get(m_env, paste0(model, "_fit")) + fits <- vctrs::vec_slice(fits$value, fits$mode == mode & fits$engine == engine) + res$fit <- if (length(fits) > 0) {fits[[1]]} else {NULL} + + preds <- rlang::env_get(m_env, paste0(model, "_predict")) + where <- preds$mode == mode & preds$engine == engine + types <- vctrs::vec_slice(preds$type, where) + values <- vctrs::vec_slice(preds$value, where) + names(values) <- types + res$pred <- values res } diff --git a/tests/testthat/test_translate.R b/tests/testthat/test_translate.R index 24a50bc5f..d4c5bc913 100644 --- a/tests/testthat/test_translate.R +++ b/tests/testthat/test_translate.R @@ -309,4 +309,38 @@ test_that("translate tuning paramter names", { expect_snapshot_error(.model_param_name_key(1)) }) +# ------------------------------------------------------------------------------ + +test_that("get_model_spec helper", { + mod1 <- get_model_spec("linear_reg", "regression", "lm") + + expect_type(mod1, "list") + + expect_type(mod1$libs, "character") + expect_length(mod1$libs, 1) + expect_equal(mod1$libs, "stats") + + expect_type(mod1$fit, "list") + expect_length(mod1$fit, 4) + expect_equal(names(mod1$fit), c("interface", "protect", "func", "defaults")) + expect_type(mod1$pred, "list") + expect_length(mod1$pred, 4) + expect_equal(names(mod1$pred), c("numeric", "conf_int", "pred_int", "raw")) + + expect_type(mod1$pred$numeric, "list") + expect_length(mod1$pred$numeric, 4) + expect_equal(names(mod1$pred$numeric), c("pre", "post", "func", "args")) + + expect_type(mod1$pred$conf_int, "list") + expect_length(mod1$pred$conf_int, 4) + expect_equal(names(mod1$pred$conf_int), c("pre", "post", "func", "args")) + + expect_type(mod1$pred$pred_int, "list") + expect_length(mod1$pred$pred_int, 4) + expect_equal(names(mod1$pred$pred_int), c("pre", "post", "func", "args")) + + expect_type(mod1$pred$raw, "list") + expect_length(mod1$pred$raw, 4) + expect_equal(names(mod1$pred$raw), c("pre", "post", "func", "args")) +})