Problem with predicting after butchering bart #262
I'm going in the wrong direction for simplifying the reprex, but the issue seems to be in one of the `axe_*()` steps; prediction fails after `axe_fitted()`:

library(tidymodels)
## note that `bart()` has a pretty nasty namespace collision:
tidymodels_prefer()
data(ames)
set.seed(502)
ames_split <- ames |>
  mutate(Sale_Price = log10(Sale_Price)) |>
  initial_split(prop = 0.80, strata = Sale_Price)
ames_train <- training(ames_split)
ames_test <- testing(ames_split)
ames_rec <-
  recipe(Sale_Price ~ Neighborhood + Gr_Liv_Area + Year_Built + Bldg_Type +
           Latitude + Longitude, data = ames_train) |>
  step_log(Gr_Liv_Area, base = 10) |>
  step_other(Neighborhood, threshold = 0.01) |>
  step_dummy(all_nominal_predictors()) |>
  step_interact( ~ Gr_Liv_Area:starts_with("Bldg_Type_") ) |>
  step_ns(Latitude, Longitude, deg_free = 20)
bart_model <- bart(mode = "regression")
bart_wflow <- workflow(ames_rec, bart_model)
fitted <- bart_wflow |> last_fit(ames_split) |> extract_workflow()
predict(fitted, new_data = slice_sample(ames_test, n = 1))
#> # A tibble: 1 × 1
#> .pred
#> <dbl>
#> 1 5.12
butchered <- butcher::butcher(fitted, verbose = TRUE)
#> ✔ Memory released: "20.51 MB"
x <- fitted
old <- fitted
x <- butcher::axe_call(x, verbose = TRUE)
#> ✖ No memory released. Do not butcher.
#> ✖ No memory released. Do not butcher.
predict(x, new_data = slice_sample(ames_test, n = 1))
#> # A tibble: 1 × 1
#> .pred
#> <dbl>
#> 1 5.07
x <- butcher::axe_ctrl(x, verbose = TRUE)
#> ✔ Memory released: "0 B"
#> ✖ Disabled: `print()` and `summary()`
#> ✔ Memory released: "0 B"
predict(x, new_data = slice_sample(ames_test, n = 1))
#> # A tibble: 1 × 1
#> .pred
#> <dbl>
#> 1 5.07
x <- butcher::axe_data(x, verbose = TRUE)
#> ✔ Memory released: "0 B"
#> ✖ Disabled: `print()` and `summary()`
#> ✔ Memory released: "0 B"
predict(x, new_data = slice_sample(ames_test, n = 1))
#> # A tibble: 1 × 1
#> .pred
#> <dbl>
#> 1 5.53
x <- butcher::axe_env(x, verbose = TRUE)
#> ✔ Memory released: "0 B"
#> ✖ Disabled: `print()` and `summary()`
#> ✔ Memory released: "0 B"
#> ✔ Memory released: "320 B"
#> ✔ Memory released: "208 B"
predict(x, new_data = slice_sample(ames_test, n = 1))
#> # A tibble: 1 × 1
#> .pred
#> <dbl>
#> 1 5.17
x <- butcher::axe_fitted(x, verbose = TRUE)
#> ✔ Memory released: "19.05 MB"
#> ✖ Disabled: `print()` and `summary()`
#> ✔ Memory released: "19.05 MB"
#> ✔ Memory released: "93.88 kB"
predict(x, new_data = slice_sample(ames_test, n = 1))
#> Error in apply(post_dist, 2, mean, na.rm = TRUE): dim(X) must have a positive length

Created on 2023-08-14 with reprex v2.0.2

I noticed that if I extract the fit engine, it works. That makes me think

library(tidymodels)
tidymodels_prefer()
data(ames)
set.seed(502)
ames_split <- ames |>
  mutate(Sale_Price = log10(Sale_Price)) |>
  initial_split(prop = 0.80, strata = Sale_Price)
ames_train <- training(ames_split)
ames_test <- testing(ames_split)
ames_rec <-
  recipe(Sale_Price ~ Neighborhood + Gr_Liv_Area + Year_Built + Bldg_Type +
           Latitude + Longitude, data = ames_train) |>
  step_log(Gr_Liv_Area, base = 10) |>
  step_other(Neighborhood, threshold = 0.01) |>
  step_dummy(all_nominal_predictors()) |>
  step_interact( ~ Gr_Liv_Area:starts_with("Bldg_Type_") ) |>
  step_ns(Latitude, Longitude, deg_free = 20)
bart_model <- bart(mode = "regression")
bart_wflow <- workflow(ames_rec, bart_model)
fitted <-
  bart_wflow |> last_fit(ames_split) |> extract_workflow() |> extract_fit_engine()
sample_data <- ames_rec |> prep() |> bake(new_data = slice_sample(ames_test, n = 1))
predict(fitted, newdata = sample_data) |> median()
#> [1] 5.117399

Created on 2023-08-14 with reprex v2.0.2
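The step-by-step `axe_*()` calls in the first reprex are a manual bisection: apply one axe, call `predict()`, repeat. A minimal sketch of automating that loop follows. `bisect_axes()` is a hypothetical helper, not part of butcher; it only assumes that each axe takes the model as its first argument and that `predict()` accepts new data positionally, which holds for workflows, parsnip fits, and base R models such as `lm`.

```r
# Hypothetical helper (not part of butcher): apply each "axe" function
# cumulatively and record whether predict() still succeeds after each step.
bisect_axes <- function(model, new_data, axes) {
  # `axes` must be a named list; the names label the rows of the result
  results <- data.frame(
    step = names(axes),
    predict_ok = NA,
    message = NA_character_,
    stringsAsFactors = FALSE
  )
  for (i in seq_along(axes)) {
    # Apply this axe on top of all previous ones, mirroring the reprex
    model <- axes[[i]](model)
    outcome <- tryCatch(
      {
        predict(model, new_data)
        list(ok = TRUE, msg = NA_character_)
      },
      # Keep the error text so the failing step is easy to diagnose
      error = function(e) list(ok = FALSE, msg = conditionMessage(e))
    )
    results$predict_ok[i] <- outcome$ok
    results$message[i] <- outcome$msg
  }
  results
}
```

With the objects from the first reprex, the call would be `bisect_axes(fitted, slice_sample(ames_test, n = 1), axes)` with `axes <- list(axe_call = butcher::axe_call, axe_ctrl = butcher::axe_ctrl, axe_data = butcher::axe_data, axe_env = butcher::axe_env, axe_fitted = butcher::axe_fitted)`. Judging from the output above, one would expect `predict_ok` to become `FALSE` only at the `axe_fitted` step, but that has not been run here.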
Smaller reprex:

library(parsnip)
## note that `bart()` has a pretty nasty namespace collision:
tidymodels::tidymodels_prefer()
bart_model <- bart(mode = "regression")
fitted <- bart_model |> fit(mpg ~ ., mtcars)
predict(fitted, new_data = mtcars[11,])
#> # A tibble: 1 × 1
#> .pred
#> <dbl>
#> 1 18.6
butchered <- butcher::butcher(fitted)
predict(butchered, new_data = mtcars[11,])
#> Error in apply(post_dist, 2, mean, na.rm = TRUE): dim(X) must have a positive length

Created on 2023-08-14 with reprex v2.0.2
We can butcher and still predict on the underlying model, but not on the parsnip model:

library(parsnip)
## note that `bart()` has a pretty nasty namespace collision:
tidymodels::tidymodels_prefer()
bart_model <- bart(mode = "regression")
fitted <- bart_model |> fit(mpg ~ ., mtcars)
predict(fitted, new_data = mtcars[11,])
#> # A tibble: 1 × 1
#> .pred
#> <dbl>
#> 1 18.5
predict(fitted$fit, mtcars[11,]) |> head() ## spits out a lot
#> [,1]
#> [1,] 19.71279
#> [2,] 18.52702
#> [3,] 17.92176
#> [4,] 17.84751
#> [5,] 19.16888
#> [6,] 15.65998
## we can butcher the underlying model
butchered <- butcher::butcher(fitted$fit)
predict(butchered, mtcars[11,]) |> head() ## spits out a lot
#> [,1]
#> [1,] 19.71279
#> [2,] 18.52702
#> [3,] 17.92176
#> [4,] 17.84751
#> [5,] 19.16888
#> [6,] 15.65998
## doesn't work for the parsnip model
butchered <- butcher::butcher(fitted)
predict(butchered, mtcars[11,])
#> Error in apply(post_dist, 2, mean, na.rm = TRUE): dim(X) must have a positive length

Created on 2023-08-14 with reprex v2.0.2
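The comparison between butchering `fitted$fit` and butchering `fitted` suggests a simple regression check: predict before and after butchering and require identical results. A minimal sketch follows. `predictions_survive()` is a hypothetical helper, not a butcher API, and it treats both an error and an empty result (such as `numeric(0)`) as a failure.

```r
# Hypothetical helper (not part of butcher): TRUE only if the butchered model
# still predicts and returns the same result as the original model.
predictions_survive <- function(model, butcher_fn, new_data) {
  before <- predict(model, new_data)
  # Treat an error during prediction as a failure instead of stopping
  after <- tryCatch(
    predict(butcher_fn(model), new_data),
    error = function(e) NULL
  )
  # An empty result such as numeric(0) also fails all.equal()
  !is.null(after) && isTRUE(all.equal(before, after))
}
```

With the reprex objects, `predictions_survive(fitted$fit, butcher::butcher, mtcars[11, ])` and `predictions_survive(fitted, butcher::butcher, mtcars[11, ])` would be the two cases shown above; based on that output, one would expect the first to pass and the second to fail before the fix.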
You are right @JamesHWade about

library(parsnip)
## note that `bart()` has a pretty nasty namespace collision:
tidymodels::tidymodels_prefer()
bart_model <- bart(mode = "regression")
fitted <- bart_model |> fit(mpg ~ ., mtcars)
butchered <- butcher::butcher(fitted)
## Nothing returned if butchered:
predict(fitted$fit, mtcars[11,], type = "ppd") |> head() ## lots of output
#> [,1]
#> [1,] 17.75746
#> [2,] 21.09958
#> [3,] 19.05330
#> [4,] 19.24808
#> [5,] 20.92507
#> [6,] 15.06747
predict(butchered$fit, mtcars[11,], type = "ppd")
#> numeric(0)

Created on 2023-08-14 with reprex v2.0.2
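The error in the parsnip reprexes names the failing call, `apply(post_dist, 2, mean, na.rm = TRUE)`, and the engine output above is a matrix with one row per posterior draw and one column per observation. Assuming `post_dist` is such a matrix, the base R sketch below shows why an empty result like the `numeric(0)` above surfaces as this particular error: `apply()` requires an object with a `dim` attribute, and `numeric(0)` has none. The values and the names `post_dist`, `point_preds`, and `empty_result` are illustrative only.

```r
# Illustrative posterior draws: 3 draws (rows) for 2 observations (columns)
post_dist <- matrix(c(1, 2, 3, 10, 20, 30), nrow = 3)

# Averaging each column over the draws gives one point prediction per observation
point_preds <- apply(post_dist, 2, mean, na.rm = TRUE)
print(point_preds)  # prints [1]  2 20

# An engine that returns numeric(0) gives apply() no dim attribute,
# so it stops before computing anything
empty_result <- tryCatch(
  apply(numeric(0), 2, mean, na.rm = TRUE),
  error = function(e) conditionMessage(e)
)
print(empty_result)  # the same message as in the reprex error
```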
A new version of butcher with this bug fix is now on CRAN @JamesHWade. Thanks again for the report! 🙌
That was fast! Thank you!!!
This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.
In rstudio/vetiver-r#234 @JamesHWade reported a problem predicting with `bart()` used via tidymodels:

Created on 2023-08-14 with reprex v2.0.2
This reprex is probably too big right now for us to find the real problem but it's a place to start.