
Why does the R session crash when I try to make a prediction with a LightGBM model saved as ".rds"? #33

Open
rafzamb opened this issue Nov 17, 2020 · 2 comments


@rafzamb

rafzamb commented Nov 17, 2020

After fitting a LightGBM model with tidymodels and treesnip, I can take the fitted workflow and make predictions on new data without any problems. However, after saving the fitted model in ".rds" format, closing the session, and loading the ".rds" model in a new session, the R session crashes when I try to generate a prediction.

This only happens with the LightGBM model; other model types can be saved, reloaded and used for prediction without this problem. Here is a reproducible example:

The lightgbm R package was installed as follows:

PKG_URL <- "https://github.com/microsoft/LightGBM/releases/download/v3.0.0/lightgbm-3.0.0-r-cran.tar.gz"
remotes::install_url(PKG_URL)
library(dplyr)
library(parsnip)
library(rsample)
library(yardstick)
library(recipes)
library(workflows)
library(dials)
library(tune)
library(treesnip)

data = bind_rows(iris, iris, iris, iris, iris, iris, iris)

set.seed(2)
initial_split <- initial_split(data, prop = 0.75)
train <- training(initial_split)
test  <- testing(initial_split)
initial_split
#> <Analysis/Assess/Total>
#> <788/262/1050>

recipe <- recipe(Sepal.Length ~ ., data = data) %>%
  step_dummy(all_nominal(), -all_outcomes())

model <- boost_tree(
  mtry = 3, 
  trees = 1000, 
  min_n = tune(), 
  tree_depth = tune(),
  loss_reduction = tune(), 
  learn_rate = tune(), 
  sample_size = 0.75
) %>% 
  set_mode("regression") %>%
  set_engine("lightgbm")

wf <- workflow() %>% 
  add_model(model) %>% 
  add_recipe(recipe)

wf
#> ══ Workflow ════════════════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: boost_tree()
#> 
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> 1 Recipe Step
#> 
#> ● step_dummy()
#> 
#> ── Model ───────────────────────────────────────────────────────────────────────
#> Boosted Tree Model Specification (regression)
#> 
#> Main Arguments:
#>   mtry = 3
#>   trees = 1000
#>   min_n = tune()
#>   tree_depth = tune()
#>   learn_rate = tune()
#>   loss_reduction = tune()
#>   sample_size = 0.75
#> 
#> Computational engine: lightgbm

# resamples
resamples <- vfold_cv(train, v = 3)

# grid
grid <- parameters(model) %>% 
  finalize(train) %>% 
  grid_random(size = 10)

head(grid)
#> # A tibble: 6 x 4
#>   min_n tree_depth   learn_rate loss_reduction
#>   <int>      <int>        <dbl>          <dbl>
#> 1     2          4 0.000282          0.0000402
#> 2    13         10 0.00333          13.0      
#> 3    32         11 0.000000585       0.000106 
#> 4    32          7 0.000258          0.163    
#> 5    31         13 0.0000000881      0.000479 
#> 6    19         14 0.000000167       0.00174


# grid search
tune_grid <- wf %>%
  tune_grid(
    resamples = resamples,
    grid = grid,
    control = control_grid(verbose = FALSE),
    metrics = metric_set(rmse)
  )


# select the best hyperparameters found
best_params <- select_best(tune_grid, "rmse")
wf <- wf %>% finalize_workflow(best_params)

wf
#> ══ Workflow ════════════════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: boost_tree()
#> 
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> 1 Recipe Step
#> 
#> ● step_dummy()
#> 
#> ── Model ───────────────────────────────────────────────────────────────────────
#> Boosted Tree Model Specification (regression)
#> 
#> Main Arguments:
#>   mtry = 3
#>   trees = 1000
#>   min_n = 13
#>   tree_depth = 10
#>   learn_rate = 0.00333377440294304
#>   loss_reduction = 13.0320661814971
#>   sample_size = 0.75
#> 
#> Computational engine: lightgbm

# last fit
last_fit <- last_fit(wf, initial_split)

# metrics
collect_metrics(last_fit)
#> # A tibble: 2 x 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 rmse    standard       0.380
#> 2 rsq     standard       0.837

# fit to predict new data
model_fit <- fit(wf, data)
#> [LightGBM] [Warning] Auto-choosing row-wise multi-threading, the overhead of testing was 0.000020 seconds.
#> You can set `force_row_wise=true` to remove the overhead.
#> And if memory is not enough, you can set `force_col_wise=true`.
#> [LightGBM] [Info] Total Bins 95
#> [LightGBM] [Info] Number of data points in the train set: 1050, number of used features: 5
#> [LightGBM] [Info] Start training from score 5.843333
#> [LightGBM] [Warning] No further splits with positive gain, best gain: -inf
#> ...
predicciones = predict(model_fit, iris)

head(predicciones)
#> # A tibble: 6 x 1
#>   .pred
#>   <dbl>
#> 1  5.13
#> 2  5.12
#> 3  5.12
#> 4  5.12
#> 5  5.13
#> 6  5.25

# save model
saveRDS(model_fit, "model_fit.rds")

After saving the model, I close the session, start a new one, and load the model:

library(workflows)
library(treesnip)
model <- readRDS("model_fit.rds")

predicciones = predict(model, iris)

When I try to generate the predictions, the R session crashes. A workaround that mostly works is to extract the fit from the workflow and save it with LightGBM's own method, but then I lose everything else stored in the workflow (such as the preprocessing recipe). Any help or suggestions would be appreciated.

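# extract the parsnip model fit from the workflow
pull_lightgbm = pull_workflow_fit(model_fit)

library(lightgbm)

# save the underlying lgb.Booster with LightGBM's own serializer
lgb.save(pull_lightgbm$fit, "lightgbm.model")

# load it back (this works in a new session)
model = lgb.load("lightgbm.model")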
sessionInfo()
#> R version 4.0.3 (2020-10-10)
#> Platform: x86_64-apple-darwin17.0 (64-bit)
#> Running under: macOS Mojave 10.14.6
#> 
#> Matrix products: default
#> BLAS:   /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
#> 
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#>  [1] treesnip_0.1.0.9000 tune_0.1.1          dials_0.0.9        
#>  [4] scales_1.1.1        workflows_0.2.1     recipes_0.1.14     
#>  [7] yardstick_0.0.7     rsample_0.0.8       parsnip_0.1.4      
#> [10] dplyr_1.0.2        
#> 
#> loaded via a namespace (and not attached):
#>  [1] Rcpp_1.0.5         lubridate_1.7.9    lattice_0.20-41    tidyr_1.1.2       
#>  [5] listenv_0.8.0      class_7.3-17       assertthat_0.2.1   digest_0.6.27     
#>  [9] ipred_0.9-9        foreach_1.5.1      parallelly_1.21.0  R6_2.5.0          
#> [13] plyr_1.8.6         evaluate_0.14      ggplot2_3.3.2      highr_0.8         
#> [17] pillar_1.4.6       rlang_0.4.8        DiceDesign_1.8-1   furrr_0.2.1       
#> [21] rpart_4.1-15       Matrix_1.2-18      rmarkdown_2.5      splines_4.0.3     
#> [25] gower_0.2.2        stringr_1.4.0      munsell_0.5.0      compiler_4.0.3    
#> [29] xfun_0.19          pkgconfig_2.0.3    globals_0.13.1     htmltools_0.5.0   
#> [33] nnet_7.3-14        tidyselect_1.1.0   tibble_3.0.4       prodlim_2019.11.13
#> [37] codetools_0.2-16   GPfit_1.0-8        fansi_0.4.1        future_1.20.1     
#> [41] crayon_1.3.4       withr_2.3.0        MASS_7.3-53        grid_4.0.3        
#> [45] gtable_0.3.0       lifecycle_0.2.0    magrittr_1.5       pROC_1.16.2       
#> [49] cli_2.1.0          stringi_1.5.3      timeDate_3043.102  ellipsis_0.3.1    
#> [53] lhs_1.1.1          generics_0.1.0     vctrs_0.3.4        lava_1.6.8.1      
#> [57] iterators_1.0.13   tools_4.0.3        glue_1.4.2         purrr_0.3.4       
#> [61] parallel_4.0.3     survival_3.2-7     yaml_2.2.1         colorspace_1.4-1  
#> [65] knitr_1.30

Created on 2020-11-16 by the reprex package (v0.3.0)

@Athospd
Member

Athospd commented Nov 17, 2020

Thank you @rafzamb, I was able to reproduce the crash. It seems that some LightGBM object is lost when the session closes, most likely the booster's pointer to its C++ model, which saveRDS() cannot serialize.
One workaround would be to save both the workflow and the LightGBM model (as you did) and then put them back together like this:

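model <- readRDS("model_fit.rds")                  # workflow whose booster is no longer usable
model_lgb <- lightgbm::lgb.load("lightgbm.model")  # booster reloaded from LightGBM's own file
model$fit$fit$fit <- model_lgb                     # put the live booster back into the workflow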

It is obviously not ideal. At the same time, it would be pretty odd to expect saveRDS() to do anything beyond what it is supposed to do (like storing a side file for an lgb.Booster). We have to think about a good way to solve this issue!
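A minimal sketch of a single-file variant of this workaround, assuming the lightgbm 3.0.0 R API (the `save_model_to_string()` method of `lgb.Booster` and `lgb.load(model_str = ...)`) and the `model$fit$fit$fit` layout used above. The idea is to store a plain-text copy of the booster next to the workflow in one list, since a character string survives `saveRDS()` while the booster's C++ handle does not, and to rebuild the booster after `readRDS()`. The helpers `save_lgb_workflow()` and `load_lgb_workflow()` are hypothetical, not part of treesnip, workflows or lightgbm; the conversion functions are arguments only so they can be swapped out (for example, to exercise the helpers without a trained model).

```r
# Hypothetical helpers (not part of treesnip, workflows or lightgbm).
# Assumes the fitted lgb.Booster lives at wf_fit$fit$fit$fit, as in the
# workaround above, and the lightgbm 3.0.0 R API.

save_lgb_workflow <- function(wf_fit, file,
                              to_string = function(booster) booster$save_model_to_string()) {
  # Keep the workflow together with a text copy of the booster; the
  # string is serialized normally, unlike the booster's C++ handle.
  bundle <- list(
    workflow  = wf_fit,
    lgb_model = to_string(wf_fit$fit$fit$fit)
  )
  saveRDS(bundle, file)
  invisible(file)
}

load_lgb_workflow <- function(file,
                              from_string = function(s) lightgbm::lgb.load(model_str = s)) {
  bundle <- readRDS(file)
  wf_fit <- bundle$workflow
  # Rebuild a live booster from the saved string and put it back in place,
  # overwriting the unusable copy that was stored with the workflow.
  wf_fit$fit$fit$fit <- from_string(bundle$lgb_model)
  wf_fit
}

# Usage (assuming model_fit is the fitted workflow from the reprex):
# save_lgb_workflow(model_fit, "model_fit_bundle.rds")
# # ...in a new session, after library(workflows) and library(treesnip):
# model <- load_lgb_workflow("model_fit_bundle.rds")
# predicciones <- predict(model, iris)
```

This keeps a single `.rds` file per model; the workflow inside the bundle still carries the broken booster object, but it is replaced before any prediction is made.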

@rafzamb
Author

rafzamb commented Nov 19, 2020

Perfect @Athospd, thank you very much for the answer.
