Problem with predicting after butchering bart #262

Closed
juliasilge opened this issue Aug 14, 2023 · 7 comments · Fixed by #263
Labels
bug an unexpected problem or unintended behavior

Comments

@juliasilge
Member

In rstudio/vetiver-r#234, @JamesHWade reported a problem predicting with bart() used via tidymodels:

library(tidymodels)
## note that `bart()` has a pretty nasty namespace collision:
tidymodels_prefer()
data(ames)

set.seed(502)
ames_split <- ames |> 
    mutate(Sale_Price = log10(Sale_Price)) |> 
    initial_split(prop = 0.80, strata = Sale_Price)
ames_train <- training(ames_split)
ames_test  <- testing(ames_split)

ames_rec <- 
    recipe(Sale_Price ~ Neighborhood + Gr_Liv_Area + Year_Built + Bldg_Type + 
               Latitude + Longitude, data = ames_train) |>
    step_log(Gr_Liv_Area, base = 10) |> 
    step_other(Neighborhood, threshold = 0.01) |> 
    step_dummy(all_nominal_predictors()) |> 
    step_interact( ~ Gr_Liv_Area:starts_with("Bldg_Type_") ) |> 
    step_ns(Latitude, Longitude, deg_free = 20)


bart_model <- bart(mode = "regression")
bart_wflow <- workflow(ames_rec, bart_model)
fitted <- bart_wflow |> last_fit(ames_split) |> extract_workflow()
predict(fitted, new_data = slice_sample(ames_test, n = 1))
#> # A tibble: 1 × 1
#>   .pred
#>   <dbl>
#> 1  5.12

butchered <- butcher::butcher(fitted)
predict(butchered,  new_data = slice_sample(ames_test, n = 1))
#> Error in apply(post_dist, 2, mean, na.rm = TRUE): dim(X) must have a positive length

Created on 2023-08-14 with reprex v2.0.2

This reprex is probably too big right now for us to find the real problem, but it's a place to start.

@juliasilge juliasilge added the bug an unexpected problem or unintended behavior label Aug 14, 2023
@JamesHWade

JamesHWade commented Aug 14, 2023

I'm going in the wrong direction for simplifying the reprex, but the issue seems to be in axe_fitted().

library(tidymodels)
## note that `bart()` has a pretty nasty namespace collision:
tidymodels_prefer()
data(ames)

set.seed(502)
ames_split <- ames |> 
  mutate(Sale_Price = log10(Sale_Price)) |> 
  initial_split(prop = 0.80, strata = Sale_Price)
ames_train <- training(ames_split)
ames_test  <- testing(ames_split)

ames_rec <- 
  recipe(Sale_Price ~ Neighborhood + Gr_Liv_Area + Year_Built + Bldg_Type + 
           Latitude + Longitude, data = ames_train) |>
  step_log(Gr_Liv_Area, base = 10) |> 
  step_other(Neighborhood, threshold = 0.01) |> 
  step_dummy(all_nominal_predictors()) |> 
  step_interact( ~ Gr_Liv_Area:starts_with("Bldg_Type_") ) |> 
  step_ns(Latitude, Longitude, deg_free = 20)


bart_model <- bart(mode = "regression")
bart_wflow <- workflow(ames_rec, bart_model)
fitted <- bart_wflow |> last_fit(ames_split) |> extract_workflow()
predict(fitted, new_data = slice_sample(ames_test, n = 1))
#> # A tibble: 1 × 1
#>   .pred
#>   <dbl>
#> 1  5.12
butchered <- butcher::butcher(fitted, verbose = TRUE)
#> ✔ Memory released: "20.51 MB"
x <- fitted
old <- fitted
x <- butcher::axe_call(x, verbose = TRUE)
#> ✖ No memory released. Do not butcher.
#> ✖ No memory released. Do not butcher.
predict(x,  new_data = slice_sample(ames_test, n = 1))
#> # A tibble: 1 × 1
#>   .pred
#>   <dbl>
#> 1  5.07
x <- butcher::axe_ctrl(x, verbose = TRUE)
#> ✔ Memory released: "0 B"
#> ✖ Disabled: `print()` and `summary()`
#> ✔ Memory released: "0 B"
predict(x,  new_data = slice_sample(ames_test, n = 1))
#> # A tibble: 1 × 1
#>   .pred
#>   <dbl>
#> 1  5.07
x <- butcher::axe_data(x, verbose = TRUE)
#> ✔ Memory released: "0 B"
#> ✖ Disabled: `print()` and `summary()`
#> ✔ Memory released: "0 B"
predict(x,  new_data = slice_sample(ames_test, n = 1))
#> # A tibble: 1 × 1
#>   .pred
#>   <dbl>
#> 1  5.53
x <- butcher::axe_env(x, verbose = TRUE)
#> ✔ Memory released: "0 B"
#> ✖ Disabled: `print()` and `summary()`
#> ✔ Memory released: "0 B"
#> ✔ Memory released: "320 B"
#> ✔ Memory released: "208 B"
predict(x,  new_data = slice_sample(ames_test, n = 1))
#> # A tibble: 1 × 1
#>   .pred
#>   <dbl>
#> 1  5.17
x <- butcher::axe_fitted(x, verbose = TRUE)
#> ✔ Memory released: "19.05 MB"
#> ✖ Disabled: `print()` and `summary()`
#> ✔ Memory released: "19.05 MB"
#> ✔ Memory released: "93.88 kB"
predict(x,  new_data = slice_sample(ames_test, n = 1))
#> Error in apply(post_dist, 2, mean, na.rm = TRUE): dim(X) must have a positive length

Created on 2023-08-14 with reprex v2.0.2
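The axe-by-axe bisection above can be automated. Below is a minimal sketch that assumes a hypothetical helper, `find_breaking_step()`, which is not part of butcher. It applies a named list of transformation functions cumulatively and returns the name of the first step after which a prediction check errors. The toy model and steps are stand-ins, with the last step emptying the stored draws to mimic the symptom in this thread.

```r
# Hypothetical helper (not part of butcher): apply each step cumulatively and
# return the name of the first step after which `check()` raises an error.
find_breaking_step <- function(obj, steps, check) {
  for (step_name in names(steps)) {
    obj <- steps[[step_name]](obj)
    ok <- tryCatch({
      check(obj)
      TRUE
    }, error = function(e) FALSE)
    if (!ok) {
      return(step_name)
    }
  }
  NA_character_  # every step passed the check
}

# Toy stand-ins: the third step replaces the draws with an empty vector,
# similar to what the butchered engine returned for type = "ppd".
toy_steps <- list(
  drop_call  = function(x) { x$call <- NULL; x },
  drop_data  = function(x) { x$data <- NULL; x },
  drop_draws = function(x) { x$draws <- numeric(0); x }
)
toy_model <- list(call = quote(f()), data = 1:3, draws = matrix(1:6, nrow = 3))
toy_check <- function(x) apply(x$draws, 2, mean)

find_breaking_step(toy_model, toy_steps, toy_check)  # "drop_draws"
```

For the real case, the steps could be `list(axe_call = butcher::axe_call, axe_ctrl = butcher::axe_ctrl, axe_data = butcher::axe_data, axe_env = butcher::axe_env, axe_fitted = butcher::axe_fitted)`, and the check could be `function(x) predict(x, new_data = ames_test[1, ])`. Using a fixed row instead of `slice_sample()` also keeps the predictions comparable from step to step.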

I noticed that if I extract the fit engine, prediction works. That makes me think parsnip::bart() might be another place to look. Maybe it's related to post_dist <- predict(obj$fit, new_data, type = "ppd") inside of dbart_predict_calc() (link to parsnip file).

library(tidymodels)
tidymodels_prefer()
data(ames)

set.seed(502)
ames_split <- ames |> 
  mutate(Sale_Price = log10(Sale_Price)) |> 
  initial_split(prop = 0.80, strata = Sale_Price)
ames_train <- training(ames_split)
ames_test  <- testing(ames_split)

ames_rec <- 
  recipe(Sale_Price ~ Neighborhood + Gr_Liv_Area + Year_Built + Bldg_Type + 
           Latitude + Longitude, data = ames_train) |>
  step_log(Gr_Liv_Area, base = 10) |> 
  step_other(Neighborhood, threshold = 0.01) |> 
  step_dummy(all_nominal_predictors()) |> 
  step_interact( ~ Gr_Liv_Area:starts_with("Bldg_Type_") ) |> 
  step_ns(Latitude, Longitude, deg_free = 20)

bart_model <- bart(mode = "regression")
bart_wflow <- workflow(ames_rec, bart_model)
fitted <-  
  bart_wflow |> last_fit(ames_split) |> extract_workflow() |> extract_fit_engine() 
sample_data <-  ames_rec |> prep() |> bake(new_data = slice_sample(ames_test, n = 1))
predict(fitted, newdata = sample_data) |> median()
#> [1] 5.117399

Created on 2023-08-14 with reprex v2.0.2
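For context, the engine-level `predict()` output shown in this thread is a matrix of posterior draws, with one row per draw and one column per observation. parsnip collapses it to one `.pred` per observation with `apply(post_dist, 2, mean, na.rm = TRUE)`. The reprex above summarizes with `median()` instead, so its value need not match `.pred` exactly. The following base-R sketch uses made-up draws to show both summaries:

```r
# Made-up posterior draws: 4 draws (rows) for 2 observations (columns),
# the same orientation as the engine-level predict() output in this thread.
post_dist <- matrix(
  c(19.7, 18.5, 17.9, 17.8,   # draws for observation 1
    20.1, 21.0, 19.4, 22.3),  # draws for observation 2
  ncol = 2
)

# Per-observation posterior mean, as in parsnip's summary step
pred_mean <- apply(post_dist, 2, mean, na.rm = TRUE)

# colMeans() gives the same result for a numeric matrix
all.equal(pred_mean, colMeans(post_dist, na.rm = TRUE))

# Per-observation posterior median, as in the reprex's `|> median()`
pred_median <- apply(post_dist, 2, median)
```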

@juliasilge
Member Author

juliasilge commented Aug 15, 2023

Smaller reprex:

library(parsnip)
## note that `bart()` has a pretty nasty namespace collision:
tidymodels::tidymodels_prefer()

bart_model <- bart(mode = "regression")
fitted <- bart_model |> fit(mpg ~ ., mtcars) 
predict(fitted, new_data = mtcars[11,])
#> # A tibble: 1 × 1
#>   .pred
#>   <dbl>
#> 1  18.6

butchered <- butcher::butcher(fitted)
predict(butchered,  new_data = mtcars[11,])
#> Error in apply(post_dist, 2, mean, na.rm = TRUE): dim(X) must have a positive length

Created on 2023-08-14 with reprex v2.0.2

@juliasilge
Member Author

We can butcher the underlying engine model and still predict with it, but we can't do the same with the parsnip model:

library(parsnip)
## note that `bart()` has a pretty nasty namespace collision:
tidymodels::tidymodels_prefer()

bart_model <- bart(mode = "regression")
fitted <- bart_model |> fit(mpg ~ ., mtcars) 
predict(fitted, new_data = mtcars[11,])
#> # A tibble: 1 × 1
#>   .pred
#>   <dbl>
#> 1  18.5
predict(fitted$fit, mtcars[11,]) |> head() ## spits out a lot
#>          [,1]
#> [1,] 19.71279
#> [2,] 18.52702
#> [3,] 17.92176
#> [4,] 17.84751
#> [5,] 19.16888
#> [6,] 15.65998

## we can butcher the underlying model
butchered <- butcher::butcher(fitted$fit)
predict(butchered, mtcars[11,]) |> head() ## spits out a lot
#>          [,1]
#> [1,] 19.71279
#> [2,] 18.52702
#> [3,] 17.92176
#> [4,] 17.84751
#> [5,] 19.16888
#> [6,] 15.65998

## doesn't work for the parsnip model
butchered <- butcher::butcher(fitted)
predict(butchered, mtcars[11,])
#> Error in apply(post_dist, 2, mean, na.rm = TRUE): dim(X) must have a positive length

Created on 2023-08-14 with reprex v2.0.2

@juliasilge
Member Author

You are right, @JamesHWade, about post_dist <- predict(obj$fit, new_data, type = "ppd"); the error happens here:
https://github.com/tidymodels/parsnip/blob/e6cd72fb3aab56d3dce3ffbd361ffd509fe36718/R/bart.R#L199

library(parsnip)
## note that `bart()` has a pretty nasty namespace collision:
tidymodels::tidymodels_prefer()

bart_model <- bart(mode = "regression")
fitted <- bart_model |> fit(mpg ~ ., mtcars) 
butchered <- butcher::butcher(fitted)

## Nothing returned if butchered:
predict(fitted$fit, mtcars[11,], type = "ppd") |> head() ## lots of output
#>          [,1]
#> [1,] 17.75746
#> [2,] 21.09958
#> [3,] 19.05330
#> [4,] 19.24808
#> [5,] 20.92507
#> [6,] 15.06747
predict(butchered$fit, mtcars[11,], type = "ppd")
#> numeric(0)

Created on 2023-08-14 with reprex v2.0.2
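Putting the pieces together: after butchering, the engine's `type = "ppd"` prediction comes back as `numeric(0)`. That is a plain vector with no `dim` attribute, and `apply()` refuses any input without dimensions. The base-R sketch below reproduces only that last step. No BART model is involved, and the values are stand-ins.

```r
# A healthy posterior-draws matrix (draws x observations) vs. the empty
# vector that the butchered engine returned for type = "ppd"
healthy <- matrix(c(17.8, 21.1, 19.1), ncol = 1)
empty <- numeric(0)

dim(healthy)  # 3 1
dim(empty)    # NULL, so apply() has no margin 2 to iterate over

apply(healthy, 2, mean, na.rm = TRUE)

# Capture the error message apply() raises on the empty vector
msg <- tryCatch(
  apply(empty, 2, mean, na.rm = TRUE),
  error = function(e) conditionMessage(e)
)
msg
```

This is why the symptom surfaces as a generic `apply()` error rather than anything that points at butchering. The fix for this issue itself landed in butcher (#263).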

@juliasilge
Member Author

A new version of butcher with this bug fix is now on CRAN, @JamesHWade. Thanks again for the report! 🙌

@JamesHWade

That was fast! Thank you!!!

@github-actions

github-actions bot commented Sep 7, 2023

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

@github-actions github-actions bot locked and limited conversation to collaborators Sep 7, 2023