diff --git a/DESCRIPTION b/DESCRIPTION index 202a1bf1..70e44dbb 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: ingredients Title: Effects and Importances of Model Ingredients -Version: 0.3.2 +Version: 0.3.3 Authors@R: c(person("Przemyslaw", "Biecek", email = "przemyslaw.biecek@gmail.com", role = c("aut", "cre"), comment = c(ORCID = "0000-0001-8423-1823")), diff --git a/NAMESPACE b/NAMESPACE index c2df90fc..6c1e6de4 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -44,7 +44,7 @@ export(partial_dependency) export(plotD3) export(select_neighbours) export(select_sample) -export(show_aggreagated_profiles) +export(show_aggregated_profiles) export(show_observations) export(show_profiles) export(show_residuals) diff --git a/NEWS.md b/NEWS.md index 15619d35..8283b44a 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,9 @@ +ingredients 0.3.3 +---------------------------------------------------------------- +* `show_profiles` and `show_residuals` functions extend Ceteris Paribus Plots. +* `show_aggreagated_profiles` is renamed to `show_aggregated_profiles` +* plot titles and subtitles in the `theme_drwhy` themes are now explicitly left-aligned (`hjust = 0`) + ingredients 0.3.2 ---------------------------------------------------------------- * added new functions `describe()` and `print.ceteris_paribus_descriptions()` for text based descriptions of Ceteris Paribus explainers diff --git a/R/aggreagate_profiles.R b/R/aggreagate_profiles.R index a2e475aa..9c2a9b56 100644 --- a/R/aggreagate_profiles.R +++ b/R/aggreagate_profiles.R @@ -46,7 +46,7 @@ #' plot(cp_rf, variables = "age") + #' show_observations(cp_rf, variables = "age") + #' show_rugs(cp_rf, variables = "age", color = "red") + -#' show_aggreagated_profiles(pdp_rf, size = 3, color = "_label_") +#' show_aggregated_profiles(pdp_rf, size = 3, color = "_label_") #' } #' @export aggregate_profiles <- function(x, ..., diff --git a/R/cluster_profiles.R b/R/cluster_profiles.R index 81ff7ea6..734a3031 100644 --- a/R/cluster_profiles.R +++ b/R/cluster_profiles.R @@ -51,10 +51,10 @@ #' 
head(clust_rf) #' #' plot(clust_rf, color = "_label_") + -#' show_aggreagated_profiles(pdp_rf, color = "black", size = 3) +#' show_aggregated_profiles(pdp_rf, color = "black", size = 3) #' #' plot(cp_rf, color = "grey", variables = "age") + -#' show_aggreagated_profiles(clust_rf, color = "_label_", size = 2) +#' show_aggregated_profiles(clust_rf, color = "_label_", size = 2) #' #' clust_rf <- cluster_profiles(cp_rf, k = 3, center = TRUE, variables = "age") #' head(clust_rf) diff --git a/R/plot_aggregated_ceteris_paribus_explainer.R b/R/plot_aggregated_ceteris_paribus_explainer.R index d7341aae..9e02c0bc 100644 --- a/R/plot_aggregated_ceteris_paribus_explainer.R +++ b/R/plot_aggregated_ceteris_paribus_explainer.R @@ -1,6 +1,6 @@ #' Adds a Layer with Aggregated Profiles #' -#' Function 'show_aggreagated_profiles' adds a layer to a plot created with 'plot.ceteris_paribus_explainer'. +#' Function 'show_aggregated_profiles' adds a layer to a plot created with 'plot.ceteris_paribus_explainer'. #' #' @param x a ceteris paribus explainer produced with function `ceteris_paribus()` #' @param ... other explainers that shall be plotted together @@ -62,7 +62,7 @@ #' plot(cp_rf, variables = "age") + #' show_observations(cp_rf, variables = "age") + #' show_rugs(cp_rf, variables = "age", color = "red") + -#' show_aggreagated_profiles(pdp_rf_p, size = 2) +#' show_aggregated_profiles(pdp_rf_p, size = 2) #' #' } #' @export diff --git a/R/show_aggregated_profiles.R b/R/show_aggregated_profiles.R index 0893d87c..ddc4f1c6 100644 --- a/R/show_aggregated_profiles.R +++ b/R/show_aggregated_profiles.R @@ -1,6 +1,6 @@ #' Adds a Layer with Aggregated Profiles #' -#' Function 'show_aggreagated_profiles' adds a layer to a plot created with 'plot.ceteris_paribus_explainer'. +#' Function 'show_aggregated_profiles' adds a layer to a plot created with 'plot.ceteris_paribus_explainer'. #' #' @param x a ceteris paribus explainer produced with function `ceteris_paribus()` #' @param ... 
other explainers that shall be plotted together @@ -27,7 +27,7 @@ #' pdp_rf <- aggregate_profiles(cp_rf, variables = "age") #' plot(cp_rf, variables = "age") + #' show_observations(cp_rf, variables = "age") + -#' show_aggreagated_profiles(pdp_rf, size = 3) +#' show_aggregated_profiles(pdp_rf, size = 3) #' #' \donttest{ #' library("randomForest") @@ -48,13 +48,13 @@ #' plot(cp_rf, variables = "age") + #' show_observations(cp_rf, variables = "age") + #' show_rugs(cp_rf, variables = "age", color = "red") + -#' show_aggreagated_profiles(pdp_rf, size = 3) +#' show_aggregated_profiles(pdp_rf, size = 3) #' #' plot(pdp_rf, variables = "age", color = "grey") #' #' } #' @export -show_aggreagated_profiles <- function(x, ..., +show_aggregated_profiles <- function(x, ..., size = 0.5, alpha = 1, color = "#371ea3", diff --git a/R/show_rugs.R b/R/show_rugs.R index b7b13f4d..880640fd 100644 --- a/R/show_rugs.R +++ b/R/show_rugs.R @@ -30,7 +30,7 @@ #' pdp_rf <- aggregate_profiles(cp_rf, variables = "age") #' plot(cp_rf, variables = "age") + #' show_observations(cp_rf, variables = "age") + -#' show_aggreagated_profiles(pdp_rf, size = 3) +#' show_aggregated_profiles(pdp_rf, size = 3) #' #' \donttest{ #' library("randomForest") @@ -51,7 +51,7 @@ #' plot(cp_rf, variables = "age") + #' show_observations(cp_rf, variables = "age") + #' show_rugs(cp_rf, variables = "age", color = "red") + -#' show_aggreagated_profiles(pdp_rf, size = 3) +#' show_aggregated_profiles(pdp_rf, size = 3) #' #' plot(pdp_rf, variables = "age", color = "grey") #' diff --git a/R/theme_drwhy.R b/R/theme_drwhy.R index 6c6cc1d1..54cf9a48 100644 --- a/R/theme_drwhy.R +++ b/R/theme_drwhy.R @@ -15,8 +15,8 @@ theme_drwhy <- function() { axis.line.y = element_line(color = "white"), axis.ticks.y = element_line(color = "white"), #axis.line = element_line(color = "#371ea3", size = 0.5, linetype = 1), - plot.title = element_text(color = "#371ea3", size = 16), - plot.subtitle = element_text(color = "#371ea3", size = 14), + 
plot.title = element_text(color = "#371ea3", size = 16, hjust = 0), + plot.subtitle = element_text(color = "#371ea3", size = 14, hjust = 0), axis.title = element_text(color = "#371ea3"), axis.text = element_text(color = "#371ea3", size = 10), strip.text = element_text(color = "#371ea3", size = 12, hjust = 0, margin = margin(0, 0, 1, 0)), @@ -36,8 +36,8 @@ theme_drwhy_vertical <- function() { legend.direction = "horizontal", legend.position = "top", axis.line.x = element_line(color = "white"), axis.ticks.x = element_line(color = "white"), - plot.title = element_text(color = "#371ea3", size = 16), - plot.subtitle = element_text(color = "#371ea3", size = 14), + plot.title = element_text(color = "#371ea3", size = 16, hjust = 0), + plot.subtitle = element_text(color = "#371ea3", size = 14, hjust = 0), #axis.line = element_line(color = "#371ea3", size = 0.5, linetype = 1), axis.title = element_text(color = "#371ea3"), axis.text = element_text(color = "#371ea3", size = 10), @@ -56,8 +56,8 @@ theme_drwhy_blank <- function() { panel.border = element_blank(), strip.background = element_blank(), plot.background = element_blank(), complete = TRUE, legend.direction = "horizontal", legend.position = "top", - plot.title = element_text(color = "#371ea3", size = 16), - plot.subtitle = element_text(color = "#371ea3", size = 14), + plot.title = element_text(color = "#371ea3", size = 16, hjust = 0), + plot.subtitle = element_text(color = "#371ea3", size = 14, hjust = 0), axis.line.x = element_line(color = "white"), axis.ticks.x = element_line(color = "white"), axis.title = element_text(color = "#371ea3"), diff --git a/docs/articles/vignette_titanic.html b/docs/articles/vignette_titanic.html index 76a34f72..c9f330bd 100644 --- a/docs/articles/vignette_titanic.html +++ b/docs/articles/vignette_titanic.html @@ -76,7 +76,7 @@

Survival on the RMS Titanic

Przemyslaw Biecek

-

2019-04-14

+

2019-04-24

@@ -116,8 +116,8 @@

#> Number of trees: 500 #> No. of variables tried at each split: 2 #> -#> Mean of squared residuals: 0.143218 -#> % Var explained: 34.66 +#> Mean of squared residuals: 0.1428399 +#> % Var explained: 34.83

@@ -138,18 +138,18 @@

fi_rf <- feature_importance(explain_titanic_rf) head(fi_rf)

#>       variable dropout_loss            label
-#> 1 _full_model_    0.3338510 Random Forest v7
-#> 2      country    0.3338510 Random Forest v7
-#> 3        parch    0.3447046 Random Forest v7
-#> 4        sibsp    0.3452359 Random Forest v7
-#> 5     embarked    0.3507052 Random Forest v7
-#> 6         fare    0.3740003 Random Forest v7
+#> 1 _full_model_ 0.3329331 Random Forest v7 +#> 2 country 0.3329331 Random Forest v7 +#> 3 parch 0.3441589 Random Forest v7 +#> 4 sibsp 0.3452181 Random Forest v7 +#> 5 embarked 0.3508888 Random Forest v7 +#> 6 fare 0.3744391 Random Forest v7
plot(fi_rf)

library("r2d3")
 plotD3(fi_rf)
-
- +
+

@@ -166,12 +166,12 @@

pp_age  <- partial_dependency(explain_titanic_rf, variables =  c("age", "fare"))
 head(pp_age)
#>   _vname_          _label_       _x_    _yhat_ _ids_
-#> 1    fare Random Forest v7 0.0000000 0.3419235     0
-#> 2     age Random Forest v7 0.1666667 0.5430471     0
-#> 3     age Random Forest v7 2.0000000 0.5745798     0
-#> 4     age Random Forest v7 4.0000000 0.5812799     0
-#> 5    fare Random Forest v7 6.1904000 0.3271178     0
-#> 6     age Random Forest v7 7.0000000 0.5388531     0
+#> 1 fare Random Forest v7 0.0000000 0.3060236 0 +#> 2 age Random Forest v7 0.1666667 0.5124979 0 +#> 3 age Random Forest v7 2.0000000 0.5525708 0 +#> 4 age Random Forest v7 4.0000000 0.5472667 0 +#> 5 fare Random Forest v7 6.1904000 0.2901396 0 +#> 6 age Random Forest v7 7.0000000 0.5094431 0

@@ -226,14 +226,14 @@

clust_rf <- cluster_profiles(sp_rf, k = 3) head(clust_rf)
#>   _vname_            _label_       _x_ _cluster_    _yhat_ _ids_
-#> 1    fare Random Forest v7_1 0.0000000         1 0.7424063     0
-#> 2   sibsp Random Forest v7_1 0.0000000         1 0.8697368     0
-#> 3   parch Random Forest v7_1 0.0000000         1 0.8240121     0
-#> 4     age Random Forest v7_1 0.1666667         1 0.8003449     0
-#> 5   parch Random Forest v7_1 0.2800000         1 0.8242621     0
-#> 6   sibsp Random Forest v7_1 1.0000000         1 0.8182423     0
+#> 1 fare Random Forest v7_1 0.0000000 1 0.7411514 0 +#> 2 sibsp Random Forest v7_1 0.0000000 1 0.8946084 0 +#> 3 parch Random Forest v7_1 0.0000000 1 0.8490757 0 +#> 4 age Random Forest v7_1 0.1666667 1 0.8119124 0 +#> 5 parch Random Forest v7_1 0.2800000 1 0.8493557 0 +#> 6 sibsp Random Forest v7_1 1.0000000 1 0.8359760 0
plot(sp_rf, alpha = 0.1) +
-  show_aggreagated_profiles(clust_rf, color = "_label_", size = 2)
+ show_aggregated_profiles(clust_rf, color = "_label_", size = 2)

@@ -255,7 +255,7 @@

#> [1] stats graphics grDevices utils datasets methods base #> #> other attached packages: -#> [1] r2d3_0.2.3 ingredients_0.3.2 randomForest_4.6-14 +#> [1] r2d3_0.2.3 ingredients_0.3.3 randomForest_4.6-14 #> [4] DALEX_0.3.1 #> #> loaded via a namespace (and not attached): @@ -263,7 +263,7 @@

#> [5] tools_3.5.0 digest_0.6.18 jsonlite_1.6 evaluate_0.13 #> [9] memoise_1.1.0 tibble_2.1.1 gtable_0.3.0 pkgconfig_2.0.2 #> [13] rlang_0.3.4 rstudioapi_0.10 commonmark_1.5 yaml_2.2.0 -#> [17] pkgdown_1.0.0 xfun_0.5 stringr_1.4.0 dplyr_0.8.0.1 +#> [17] pkgdown_1.0.0 xfun_0.6 stringr_1.4.0 dplyr_0.8.0.1 #> [21] roxygen2_6.1.1 xml2_1.2.0 knitr_1.22 htmlwidgets_1.3 #> [25] desc_1.2.0 fs_1.2.6 rprojroot_1.3-2 grid_3.5.0 #> [29] tidyselect_0.2.5 glue_1.3.1 R6_2.4.0 rmarkdown_1.12 diff --git a/docs/articles/vignette_titanic_files/figure-html/unnamed-chunk-10-1.png b/docs/articles/vignette_titanic_files/figure-html/unnamed-chunk-10-1.png index 3fd3c71a..ad0c4537 100644 Binary files a/docs/articles/vignette_titanic_files/figure-html/unnamed-chunk-10-1.png and b/docs/articles/vignette_titanic_files/figure-html/unnamed-chunk-10-1.png differ diff --git a/docs/articles/vignette_titanic_files/figure-html/unnamed-chunk-4-1.png b/docs/articles/vignette_titanic_files/figure-html/unnamed-chunk-4-1.png index 5bd9db7d..b1c9cf18 100644 Binary files a/docs/articles/vignette_titanic_files/figure-html/unnamed-chunk-4-1.png and b/docs/articles/vignette_titanic_files/figure-html/unnamed-chunk-4-1.png differ diff --git a/docs/articles/vignette_titanic_files/figure-html/unnamed-chunk-5-1.png b/docs/articles/vignette_titanic_files/figure-html/unnamed-chunk-5-1.png index ab6eeead..923ce464 100644 Binary files a/docs/articles/vignette_titanic_files/figure-html/unnamed-chunk-5-1.png and b/docs/articles/vignette_titanic_files/figure-html/unnamed-chunk-5-1.png differ diff --git a/docs/articles/vignette_titanic_files/figure-html/unnamed-chunk-6-1.png b/docs/articles/vignette_titanic_files/figure-html/unnamed-chunk-6-1.png index e879de67..dec8dabc 100644 Binary files a/docs/articles/vignette_titanic_files/figure-html/unnamed-chunk-6-1.png and b/docs/articles/vignette_titanic_files/figure-html/unnamed-chunk-6-1.png differ diff --git 
a/docs/articles/vignette_titanic_files/figure-html/unnamed-chunk-7-1.png b/docs/articles/vignette_titanic_files/figure-html/unnamed-chunk-7-1.png index a540255a..6e374175 100644 Binary files a/docs/articles/vignette_titanic_files/figure-html/unnamed-chunk-7-1.png and b/docs/articles/vignette_titanic_files/figure-html/unnamed-chunk-7-1.png differ diff --git a/docs/articles/vignette_titanic_files/figure-html/unnamed-chunk-8-1.png b/docs/articles/vignette_titanic_files/figure-html/unnamed-chunk-8-1.png index 01755e0e..99760d7f 100644 Binary files a/docs/articles/vignette_titanic_files/figure-html/unnamed-chunk-8-1.png and b/docs/articles/vignette_titanic_files/figure-html/unnamed-chunk-8-1.png differ diff --git a/docs/articles/vignette_titanic_files/figure-html/unnamed-chunk-9-1.png b/docs/articles/vignette_titanic_files/figure-html/unnamed-chunk-9-1.png index ad8314a4..eeffe1c4 100644 Binary files a/docs/articles/vignette_titanic_files/figure-html/unnamed-chunk-9-1.png and b/docs/articles/vignette_titanic_files/figure-html/unnamed-chunk-9-1.png differ diff --git a/docs/news/index.html b/docs/news/index.html index cc0ef6d6..4be7a9b2 100644 --- a/docs/news/index.html +++ b/docs/news/index.html @@ -105,6 +105,17 @@

Change log

+
+

+ingredients 0.3.3

+
    +
  • +show_profiles and show_residuals functions extend Ceteris Paribus Plots.
  • +
  • +show_aggreagated_profiles is renamed to show_aggregated_profiles +
  • +
+

ingredients 0.3.2

@@ -163,6 +174,7 @@

Contents

#> Error in plot(cp_rf, variables = "age"): object 'cp_rf' not found

plot(clust_rf, color = "_label_") + - show_aggreagated_profiles(pdp_rf, color = "black", size = 3)
+ show_aggregated_profiles(pdp_rf, color = "black", size = 3)
plot(cp_rf, color = "grey", variables = "age") + - show_aggreagated_profiles(clust_rf, color = "_label_", size = 2)
+ show_aggregated_profiles(clust_rf, color = "_label_", size = 2)
clust_rf <- cluster_profiles(cp_rf, k = 3, center = TRUE, variables = "age") head(clust_rf)
#> _vname_ _label_ _x_ _cluster_ _yhat_ _ids_ #> 1 age Random Forest v7_1 0.1666667 1 0.7170289 0 diff --git a/docs/reference/feature_importance-5.png b/docs/reference/feature_importance-5.png index e56d8466..5ff9b608 100644 Binary files a/docs/reference/feature_importance-5.png and b/docs/reference/feature_importance-5.png differ diff --git a/docs/reference/feature_importance.html b/docs/reference/feature_importance.html index 7ccde646..decb2a14 100644 --- a/docs/reference/feature_importance.html +++ b/docs/reference/feature_importance.html @@ -219,8 +219,7 @@

Examp #> 4 salary 0.4149351 lm #> 5 gender 0.4151134 lm #> 6 evaluation 0.4291616 lm

plot(vd_glm)
-library("xgboost") -model_martix_train <- model.matrix(status == "fired" ~ . -1, HR) +library("xgboost")
#> Warning: package ‘xgboost’ was built under R version 3.5.2
model_martix_train <- model.matrix(status == "fired" ~ . -1, HR) data_train <- xgb.DMatrix(model_martix_train, label = HR$status == "fired") param <- list(max_depth = 2, eta = 1, silent = 1, nthread = 2, objective = "binary:logistic", eval_metric = "auc") @@ -231,10 +230,10 @@

Examp head(vd_xgb)

#> variable dropout_loss label #> 1 _full_model_ 0.3142729 xgboost #> 2 gendermale 0.3142729 xgboost -#> 3 evaluation 0.3372445 xgboost -#> 4 genderfemale 0.3909165 xgboost -#> 5 age 0.4038176 xgboost -#> 6 salary 0.4158528 xgboost
plot(vd_xgb, vd_glm)
+#> 3 evaluation 0.3382139 xgboost +#> 4 genderfemale 0.3916174 xgboost +#> 5 age 0.3979185 xgboost +#> 6 salary 0.4117663 xgboost
plot(vd_xgb, vd_glm)
+#> no 1321 97 0.06840621 +#> yes 305 376 0.44787078
explain_titanic_rf <- explain(model_titanic_rf, data = titanic[,-9], y = titanic$survived, diff --git a/docs/reference/plot.aggregated_profiles_explainer-1.png b/docs/reference/plot.aggregated_profiles_explainer-1.png index a29f9a0d..6f82bbc9 100644 Binary files a/docs/reference/plot.aggregated_profiles_explainer-1.png and b/docs/reference/plot.aggregated_profiles_explainer-1.png differ diff --git a/docs/reference/plot.aggregated_profiles_explainer-2.png b/docs/reference/plot.aggregated_profiles_explainer-2.png index 50f0cb71..140d5403 100644 Binary files a/docs/reference/plot.aggregated_profiles_explainer-2.png and b/docs/reference/plot.aggregated_profiles_explainer-2.png differ diff --git a/docs/reference/plot.aggregated_profiles_explainer-3.png b/docs/reference/plot.aggregated_profiles_explainer-3.png index 02194bb4..632c1f0a 100644 Binary files a/docs/reference/plot.aggregated_profiles_explainer-3.png and b/docs/reference/plot.aggregated_profiles_explainer-3.png differ diff --git a/docs/reference/plot.aggregated_profiles_explainer-4.png b/docs/reference/plot.aggregated_profiles_explainer-4.png index 53703459..b455f896 100644 Binary files a/docs/reference/plot.aggregated_profiles_explainer-4.png and b/docs/reference/plot.aggregated_profiles_explainer-4.png differ diff --git a/docs/reference/plot.aggregated_profiles_explainer.html b/docs/reference/plot.aggregated_profiles_explainer.html index 0b68f89e..3a7d863d 100644 --- a/docs/reference/plot.aggregated_profiles_explainer.html +++ b/docs/reference/plot.aggregated_profiles_explainer.html @@ -104,7 +104,7 @@

Adds a Layer with Aggregated Profiles

-

Function 'show_aggreagated_profiles' adds a layer to a plot created with 'plot.ceteris_paribus_explainer'.

+

Function 'show_aggregated_profiles' adds a layer to a plot created with 'plot.ceteris_paribus_explainer'.

# S3 method for aggregated_profiles_explainer
@@ -172,12 +172,12 @@ 

Examp pdp_rf_a<- accumulated_dependency(explain_titanic_glm, N = 50) pdp_rf_a$`_label_` <- "RF_accumulated" head(pdp_rf_p)
#> _vname_ _label_ _x_ _yhat_ _ids_ -#> 1 fare RF_partial 0.0000000 0.2874121 0 -#> 2 sibsp RF_partial 0.0000000 0.3200546 0 -#> 3 parch RF_partial 0.0000000 0.3200546 0 -#> 4 age RF_partial 0.1666667 0.3693770 0 -#> 5 parch RF_partial 0.2800000 0.3200546 0 -#> 6 sibsp RF_partial 1.0000000 0.3200546 0
plot(pdp_rf_p, pdp_rf_l, pdp_rf_a, color = "_label_")
+#> 1 fare RF_partial 0.0000000 0.3138426 0 +#> 2 sibsp RF_partial 0.0000000 0.3297662 0 +#> 3 parch RF_partial 0.0000000 0.3297662 0 +#> 4 age RF_partial 0.1666667 0.3715088 0 +#> 5 parch RF_partial 0.2800000 0.3297662 0 +#> 6 sibsp RF_partial 1.0000000 0.3297662 0
plot(pdp_rf_p, pdp_rf_l, pdp_rf_a, color = "_label_")
library("randomForest") titanic <- na.omit(titanic) model_titanic_rf <- randomForest(survived == "yes" ~ gender + age + class + embarked + @@ -188,7 +188,7 @@

Examp #> Number of trees: 500 #> No. of variables tried at each split: 2 #> -#> Mean of squared residuals: 0.1428164 +#> Mean of squared residuals: 0.1428201 #> % Var explained: 34.84

explain_titanic_rf <- explain(model_titanic_rf, data = titanic[,-9], @@ -206,12 +206,12 @@

Examp #> 1594 male 44 victualling crew Southampton England 0.0000 0 0 #> 1594.1 female 44 victualling crew Southampton England 0.0000 0 0 #> _yhat_ _vname_ _ids_ _label_ -#> 1960 0.1583068 gender 1960 Random Forest v7 -#> 1960.1 0.7717861 gender 1960 Random Forest v7 -#> 883 0.1193948 gender 883 Random Forest v7 -#> 883.1 0.4775326 gender 883 Random Forest v7 -#> 1594 0.1414053 gender 1594 Random Forest v7 -#> 1594.1 0.8851591 gender 1594 Random Forest v7 +#> 1960 0.1588032 gender 1960 Random Forest v7 +#> 1960.1 0.7758566 gender 1960 Random Forest v7 +#> 883 0.1140935 gender 883 Random Forest v7 +#> 883.1 0.4709515 gender 883 Random Forest v7 +#> 1594 0.1430898 gender 1594 Random Forest v7 +#> 1594.1 0.8663689 gender 1594 Random Forest v7 #> #> #> Top observations: @@ -223,12 +223,12 @@

Examp #> 1445 male 20 victualling crew Southampton England 0.0000 0 0 #> 1432 male 33 victualling crew Belfast England 0.0000 0 0 #> _yhat_ _label_ _ids_ -#> 1960 0.77178609 Random Forest v7 1 -#> 883 0.11939482 Random Forest v7 2 -#> 1594 0.14140527 Random Forest v7 3 -#> 1131 0.07144135 Random Forest v7 4 -#> 1445 0.22359315 Random Forest v7 5 -#> 1432 0.17579153 Random Forest v7 6

+#> 1960 0.77585659 Random Forest v7 1 +#> 883 0.11409353 Random Forest v7 2 +#> 1594 0.14308976 Random Forest v7 3 +#> 1131 0.07799722 Random Forest v7 4 +#> 1445 0.21975678 Random Forest v7 5 +#> 1432 0.16557374 Random Forest v7 6
pdp_rf_p <- aggregate_profiles(cp_rf, variables = "age", type = "partial") pdp_rf_p$`_label_` <- "RF_partial" pdp_rf_c <- aggregate_profiles(cp_rf, variables = "age", type = "conditional") @@ -236,16 +236,16 @@

Examp pdp_rf_a <- aggregate_profiles(cp_rf, variables = "age", type = "accumulated") pdp_rf_a$`_label_` <- "RF_accumulated" head(pdp_rf_p)

#> _vname_ _label_ _x_ _yhat_ _ids_ -#> 1 age RF_partial 0.1666667 0.5421567 0 -#> 2 age RF_partial 2.0000000 0.5772933 0 -#> 3 age RF_partial 4.0000000 0.5815934 0 -#> 4 age RF_partial 7.0000000 0.5527687 0 -#> 5 age RF_partial 9.0000000 0.5497517 0 -#> 6 age RF_partial 13.0000000 0.5047613 0
plot(pdp_rf_p)
plot(pdp_rf_p, pdp_rf_c, pdp_rf_a)
+#> 1 age RF_partial 0.1666667 0.5304977 0 +#> 2 age RF_partial 2.0000000 0.5554616 0 +#> 3 age RF_partial 4.0000000 0.5539585 0 +#> 4 age RF_partial 7.0000000 0.5210260 0 +#> 5 age RF_partial 9.0000000 0.5207317 0 +#> 6 age RF_partial 13.0000000 0.4811878 0
plot(pdp_rf_p)
plot(pdp_rf_p, pdp_rf_c, pdp_rf_a)
plot(cp_rf, variables = "age") + show_observations(cp_rf, variables = "age") + show_rugs(cp_rf, variables = "age", color = "red") + - show_aggreagated_profiles(pdp_rf_p, size = 2)
+ show_aggregated_profiles(pdp_rf_p, size = 2)

#> variable dropout_loss label #> 1 _full_model_ 0.3142729 xgboost #> 2 gendermale 0.3142729 xgboost -#> 3 evaluation 0.3401847 xgboost -#> 4 genderfemale 0.3937621 xgboost -#> 5 age 0.4006677 xgboost -#> 6 salary 0.4146221 xgboost
+#> 3 evaluation 0.3375758 xgboost +#> 4 genderfemale 0.3876988 xgboost +#> 5 age 0.3971998 xgboost +#> 6 salary 0.4113673 xgboost
plot(vd_glm, vd_xgb, bar_width = 5)
diff --git a/docs/reference/plotD3.html b/docs/reference/plotD3.html index 836a05a0..4a0c079a 100644 --- a/docs/reference/plotD3.html +++ b/docs/reference/plotD3.html @@ -157,20 +157,23 @@

Value

Examples

library("DALEX") library("ingredients") -library("caret") - +library("caret")
#> Loading required package: lattice
#> Loading required package: ggplot2
#> Warning: package ‘ggplot2’ was built under R version 3.5.2
#> RStudio Community is a great place to get help: +#> https://community.rstudio.com/c/tidyverse.
#> +#> Attaching package: ‘ggplot2’
#> The following object is masked from ‘package:randomForest’: +#> +#> margin
rf_model <- train(m2.price~., data = apartments, method="rf", ntree = 100) explainer_rf <- explain(rf_model, data = apartments_test[,2:6], y = apartments_test$m2.price, label="rf") fi_rf <- feature_importance(explainer_rf, loss_function = loss_root_mean_square) head(fi_rf)
#> variable dropout_loss label -#> 1 _full_model_ 157.8401 rf -#> 2 no.rooms 165.1408 rf -#> 3 construction.year 369.3096 rf -#> 4 floor 418.9614 rf -#> 5 surface 602.1783 rf -#> 6 district 961.8930 rf
plotD3(fi_rf) +#> 1 _full_model_ 159.6260 rf +#> 2 no.rooms 169.6482 rf +#> 3 construction.year 368.1726 rf +#> 4 floor 417.9736 rf +#> 5 surface 592.5975 rf +#> 6 district 960.6579 rf
plotD3(fi_rf) svm_model <- train(m2.price~., data = apartments, method="svmLinear") explainer_svm <- explain(svm_model, data = apartments_test[,2:6], @@ -179,11 +182,11 @@

Examp head(fi_svm)

#> variable dropout_loss label #> 1 _full_model_ 301.5889 svm -#> 2 construction.year 301.5907 svm -#> 3 no.rooms 317.8417 svm -#> 4 floor 502.6867 svm -#> 5 surface 615.7759 svm -#> 6 district 1008.2428 svm
plotD3(fi_rf, fi_svm) +#> 2 construction.year 301.6034 svm +#> 3 no.rooms 316.6314 svm +#> 4 floor 498.8596 svm +#> 5 surface 610.8330 svm +#> 6 district 995.0613 svm
plotD3(fi_rf, fi_svm) plotD3(fi_rf, fi_svm, split = "feature") diff --git a/docs/reference/show_aggregated_profiles-1.png b/docs/reference/show_aggregated_profiles-1.png new file mode 100644 index 00000000..c6dba876 Binary files /dev/null and b/docs/reference/show_aggregated_profiles-1.png differ diff --git a/docs/reference/show_aggregated_profiles-2.png b/docs/reference/show_aggregated_profiles-2.png new file mode 100644 index 00000000..f958f17d Binary files /dev/null and b/docs/reference/show_aggregated_profiles-2.png differ diff --git a/docs/reference/show_aggregated_profiles-3.png b/docs/reference/show_aggregated_profiles-3.png new file mode 100644 index 00000000..dba088b5 Binary files /dev/null and b/docs/reference/show_aggregated_profiles-3.png differ diff --git a/docs/reference/show_aggregated_profiles.html b/docs/reference/show_aggregated_profiles.html new file mode 100644 index 00000000..c04eb88f --- /dev/null +++ b/docs/reference/show_aggregated_profiles.html @@ -0,0 +1,259 @@ + + + + + + + + +Adds a Layer with Aggregated Profiles — show_aggregated_profiles • ingredients + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+
+ + + +
+ +
+
+ + + +

Function 'show_aggregated_profiles' adds a layer to a plot created with 'plot.ceteris_paribus_explainer'.

+ + +
show_aggregated_profiles(x, ..., size = 0.5, alpha = 1,
+  color = "#371ea3", variables = NULL)
+ +

Arguments

+ + + + + + + + + + + + + + + + + + + + + + + + + + +
x

a ceteris paribus explainer produced with function `ceteris_paribus()`

...

other explainers that shall be plotted together

size

a numeric. Size of lines to be plotted

alpha

a numeric between 0 and 1. Opacity of lines

color

a character. Either name of a color or name of a variable that should be used for coloring

variables

if not NULL then only `variables` will be presented

+ +

Value

+ +

a ggplot2 layer

+ + +

Examples

+
library("DALEX") +# Toy examples, because CRAN angels ask for them +titanic <- na.omit(titanic) +selected_passangers <- select_sample(titanic, n = 100) + +model_titanic_glm <- glm(survived == "yes" ~ gender + age + fare, + data = titanic, family = "binomial") + +explain_titanic_glm <- explain(model_titanic_glm, + data = titanic[,-9], + y = titanic$survived == "yes") + +cp_rf <- ceteris_paribus(explain_titanic_glm, selected_passangers) +pdp_rf <- aggregate_profiles(cp_rf, variables = "age") +plot(cp_rf, variables = "age") + + show_observations(cp_rf, variables = "age") + + show_aggregated_profiles(pdp_rf, size = 3)
+
library("randomForest") + model_titanic_rf <- randomForest(survived ~ gender + age + class + embarked + + fare + sibsp + parch, data = titanic) + model_titanic_rf
#> +#> Call: +#> randomForest(formula = survived ~ gender + age + class + embarked + fare + sibsp + parch, data = titanic) +#> Type of random forest: classification +#> Number of trees: 500 +#> No. of variables tried at each split: 2 +#> +#> OOB estimate of error rate: 19.2% +#> Confusion matrix: +#> no yes class.error +#> no 1317 101 0.07122708 +#> yes 302 379 0.44346549
+ explain_titanic_rf <- explain(model_titanic_rf, + data = titanic[,-9], + y = titanic$survived) + +cp_rf <- ceteris_paribus(explain_titanic_rf, selected_passangers) +cp_rf
#> Top profiles : +#> gender age class embarked country fare sibsp parch +#> 1960 male 36 victualling crew Southampton England 0.0000 0 0 +#> 1960.1 female 36 victualling crew Southampton England 0.0000 0 0 +#> 883 male 21 3rd Southampton Sweden 7.1701 0 0 +#> 883.1 female 21 3rd Southampton Sweden 7.1701 0 0 +#> 1594 male 44 victualling crew Southampton England 0.0000 0 0 +#> 1594.1 female 44 victualling crew Southampton England 0.0000 0 0 +#> _yhat_ _vname_ _ids_ _label_ +#> 1960 0.002 gender 1960 randomForest +#> 1960.1 0.862 gender 1960 randomForest +#> 883 0.002 gender 883 randomForest +#> 883.1 0.538 gender 883 randomForest +#> 1594 0.002 gender 1594 randomForest +#> 1594.1 0.950 gender 1594 randomForest +#> +#> +#> Top observations: +#> gender age class embarked country fare sibsp parch _yhat_ +#> 1960 female 36 victualling crew Southampton England 0.0000 0 0 0.862 +#> 883 male 21 3rd Southampton Sweden 7.1701 0 0 0.002 +#> 1594 male 44 victualling crew Southampton England 0.0000 0 0 0.002 +#> 1131 male 37 3rd Southampton Croatia 8.1303 0 0 0.004 +#> 1445 male 20 victualling crew Southampton England 0.0000 0 0 0.002 +#> 1432 male 33 victualling crew Belfast England 0.0000 0 0 0.002 +#> _label_ _ids_ +#> 1960 randomForest 1 +#> 883 randomForest 2 +#> 1594 randomForest 3 +#> 1131 randomForest 4 +#> 1445 randomForest 5 +#> 1432 randomForest 6
+pdp_rf <- aggregate_profiles(cp_rf, variables = "age") +head(pdp_rf)
#> _vname_ _label_ _x_ _yhat_ _ids_ +#> 1 age randomForest 0.1666667 0.51946 0 +#> 2 age randomForest 2.0000000 0.56250 0 +#> 3 age randomForest 4.0000000 0.56690 0 +#> 4 age randomForest 7.0000000 0.51762 0 +#> 5 age randomForest 9.0000000 0.51634 0 +#> 6 age randomForest 13.0000000 0.44654 0
+plot(cp_rf, variables = "age") + + show_observations(cp_rf, variables = "age") + + show_rugs(cp_rf, variables = "age", color = "red") + + show_aggregated_profiles(pdp_rf, size = 3)
+plot(pdp_rf, variables = "age", color = "grey")
+
+
+ +
+ +
+ + +
+

Site built with pkgdown.

+
+ +
+
+ + + + diff --git a/docs/reference/show_residuals.html b/docs/reference/show_residuals.html index 23b00bed..ffb49ff1 100644 --- a/docs/reference/show_residuals.html +++ b/docs/reference/show_residuals.html @@ -185,7 +185,7 @@

Examp variable_splits = list(age = seq(0,70, length.out = 1000))) plot(cp_johny, variables = "age", size = 1.5, color = "#8bdcbe") + - show_profiles(cp_neighbours, variables = "age", color = "#ceced9") + + show_profiles(cp_neighbours, variables = "age", color = "#ceced9") + show_observations(cp_johny, variables = "age", size = 5, color = "#371ea3") + show_residuals(cp_neighbours, variables = "age")

diff --git a/docs/reference/show_rugs.html b/docs/reference/show_rugs.html index 12b93194..6b1b2e49 100644 --- a/docs/reference/show_rugs.html +++ b/docs/reference/show_rugs.html @@ -170,7 +170,7 @@

Examp pdp_rf <- aggregate_profiles(cp_rf, variables = "age") plot(cp_rf, variables = "age") + show_observations(cp_rf, variables = "age") + - show_aggreagated_profiles(pdp_rf, size = 3)
+ show_aggregated_profiles(pdp_rf, size = 3)
library("randomForest") model_titanic_rf <- randomForest(survived ~ gender + age + class + embarked + fare + sibsp + parch, data = titanic) @@ -234,7 +234,7 @@

Examp plot(cp_rf, variables = "age") + show_observations(cp_rf, variables = "age") + show_rugs(cp_rf, variables = "age", color = "red") + - show_aggreagated_profiles(pdp_rf, size = 3)

+ show_aggregated_profiles(pdp_rf, size = 3)
plot(pdp_rf, variables = "age", color = "grey")
diff --git a/man/aggregate_profiles.Rd b/man/aggregate_profiles.Rd index 47537447..5b53a19b 100644 --- a/man/aggregate_profiles.Rd +++ b/man/aggregate_profiles.Rd @@ -59,7 +59,7 @@ head(pdp_rf) plot(cp_rf, variables = "age") + show_observations(cp_rf, variables = "age") + show_rugs(cp_rf, variables = "age", color = "red") + - show_aggreagated_profiles(pdp_rf, size = 3, color = "_label_") + show_aggregated_profiles(pdp_rf, size = 3, color = "_label_") } } \references{ diff --git a/man/cluster_profiles.Rd b/man/cluster_profiles.Rd index 1d238ece..270d7d95 100644 --- a/man/cluster_profiles.Rd +++ b/man/cluster_profiles.Rd @@ -66,10 +66,10 @@ clust_rf <- cluster_profiles(cp_rf, k = 3, variables = "age") head(clust_rf) plot(clust_rf, color = "_label_") + - show_aggreagated_profiles(pdp_rf, color = "black", size = 3) + show_aggregated_profiles(pdp_rf, color = "black", size = 3) plot(cp_rf, color = "grey", variables = "age") + - show_aggreagated_profiles(clust_rf, color = "_label_", size = 2) + show_aggregated_profiles(clust_rf, color = "_label_", size = 2) clust_rf <- cluster_profiles(cp_rf, k = 3, center = TRUE, variables = "age") head(clust_rf) diff --git a/man/plot.aggregated_profiles_explainer.Rd b/man/plot.aggregated_profiles_explainer.Rd index a611e090..cc30a202 100644 --- a/man/plot.aggregated_profiles_explainer.Rd +++ b/man/plot.aggregated_profiles_explainer.Rd @@ -27,7 +27,7 @@ a ggplot2 layer } \description{ -Function 'show_aggreagated_profiles' adds a layer to a plot created with 'plot.ceteris_paribus_explainer'. +Function 'show_aggregated_profiles' adds a layer to a plot created with 'plot.ceteris_paribus_explainer'. 
} \examples{ library("DALEX") @@ -78,7 +78,7 @@ plot(pdp_rf_p, pdp_rf_c, pdp_rf_a) plot(cp_rf, variables = "age") + show_observations(cp_rf, variables = "age") + show_rugs(cp_rf, variables = "age", color = "red") + - show_aggreagated_profiles(pdp_rf_p, size = 2) + show_aggregated_profiles(pdp_rf_p, size = 2) } } diff --git a/man/show_aggreagated_profiles.Rd b/man/show_aggregated_profiles.Rd similarity index 86% rename from man/show_aggreagated_profiles.Rd rename to man/show_aggregated_profiles.Rd index 324d6add..0646b4b9 100644 --- a/man/show_aggreagated_profiles.Rd +++ b/man/show_aggregated_profiles.Rd @@ -1,10 +1,10 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/show_aggregated_profiles.R -\name{show_aggreagated_profiles} -\alias{show_aggreagated_profiles} +\name{show_aggregated_profiles} +\alias{show_aggregated_profiles} \title{Adds a Layer with Aggregated Profiles} \usage{ -show_aggreagated_profiles(x, ..., size = 0.5, alpha = 1, +show_aggregated_profiles(x, ..., size = 0.5, alpha = 1, color = "#371ea3", variables = NULL) } \arguments{ @@ -24,7 +24,7 @@ show_aggreagated_profiles(x, ..., size = 0.5, alpha = 1, a ggplot2 layer } \description{ -Function 'show_aggreagated_profiles' adds a layer to a plot created with 'plot.ceteris_paribus_explainer'. +Function 'show_aggregated_profiles' adds a layer to a plot created with 'plot.ceteris_paribus_explainer'. 
} \examples{ library("DALEX") @@ -43,7 +43,7 @@ cp_rf <- ceteris_paribus(explain_titanic_glm, selected_passangers) pdp_rf <- aggregate_profiles(cp_rf, variables = "age") plot(cp_rf, variables = "age") + show_observations(cp_rf, variables = "age") + - show_aggreagated_profiles(pdp_rf, size = 3) + show_aggregated_profiles(pdp_rf, size = 3) \donttest{ library("randomForest") @@ -64,7 +64,7 @@ head(pdp_rf) plot(cp_rf, variables = "age") + show_observations(cp_rf, variables = "age") + show_rugs(cp_rf, variables = "age", color = "red") + - show_aggreagated_profiles(pdp_rf, size = 3) + show_aggregated_profiles(pdp_rf, size = 3) plot(pdp_rf, variables = "age", color = "grey") diff --git a/man/show_rugs.Rd b/man/show_rugs.Rd index aee68dce..646b8e79 100644 --- a/man/show_rugs.Rd +++ b/man/show_rugs.Rd @@ -48,7 +48,7 @@ cp_rf <- ceteris_paribus(explain_titanic_glm, selected_passangers) pdp_rf <- aggregate_profiles(cp_rf, variables = "age") plot(cp_rf, variables = "age") + show_observations(cp_rf, variables = "age") + - show_aggreagated_profiles(pdp_rf, size = 3) + show_aggregated_profiles(pdp_rf, size = 3) \donttest{ library("randomForest") @@ -69,7 +69,7 @@ head(pdp_rf) plot(cp_rf, variables = "age") + show_observations(cp_rf, variables = "age") + show_rugs(cp_rf, variables = "age", color = "red") + - show_aggreagated_profiles(pdp_rf, size = 3) + show_aggregated_profiles(pdp_rf, size = 3) plot(pdp_rf, variables = "age", color = "grey") diff --git a/tests/testthat/test_cluster_profiles.R b/tests/testthat/test_cluster_profiles.R index 4b9bf5a8..78ce7d5e 100644 --- a/tests/testthat/test_cluster_profiles.R +++ b/tests/testthat/test_cluster_profiles.R @@ -25,10 +25,10 @@ test_that("plot cluster_profiles",{ expect_true("aggregated_profiles_explainer" %in% class(clust_rf)) pl1 <- plot(clust_rf, color = "_label_") + - show_aggreagated_profiles(pdp_rf, color = "black", size = 3) + show_aggregated_profiles(pdp_rf, color = "black", size = 3) pl2 <- plot(cp_rf, color = "grey", 
variables = "Age") + - show_aggreagated_profiles(clust_rf, color = "_label_", size = 2) + show_aggregated_profiles(clust_rf, color = "_label_", size = 2) pl3 <- plot(cp_rf, variables = "Embarked", only_numerical = FALSE) diff --git a/tests/testthat/test_single_variable.R b/tests/testthat/test_single_variable.R index b3bae038..ad5cfdec 100644 --- a/tests/testthat/test_single_variable.R +++ b/tests/testthat/test_single_variable.R @@ -9,7 +9,7 @@ test_that("test plot",{ pl <- plot(cp_rf, variables = "Age") + show_observations(cp_rf, variables = "Age") + show_rugs(cp_rf, variables = "Age", color = "red") + - show_aggreagated_profiles(pdp_rf, size = 2) + show_aggregated_profiles(pdp_rf, size = 2) expect_true("gg" %in% class(pl)) }) diff --git a/vignettes/vignette_titanic.Rmd b/vignettes/vignette_titanic.Rmd index 41b6a0bf..cc3f50ad 100644 --- a/vignettes/vignette_titanic.Rmd +++ b/vignettes/vignette_titanic.Rmd @@ -142,7 +142,7 @@ sp_rf <- ceteris_paribus(explain_titanic_rf, passangers) clust_rf <- cluster_profiles(sp_rf, k = 3) head(clust_rf) plot(sp_rf, alpha = 0.1) + - show_aggreagated_profiles(clust_rf, color = "_label_", size = 2) + show_aggregated_profiles(clust_rf, color = "_label_", size = 2) ```