
`wand` fits `torch` powered GAMs in a `tidymodels` friendly framework.

License

MIT (LICENSE.md); the repository also contains a LICENSE file of unrecognized type.

ben-e/wand


wand


wand is an R package which implements generalized additive models (GAMs) using torch as a backend, allowing users to fit semi-interpretable models using tools from the deep learning world. In addition to its GAM implementation, wand also provides tools for model interpretation. This package gets its name from the type of network it implements: wide and deep networks combine linear features (wide) with feature embeddings (deep).

wand draws heavy inspiration from the mgcv and brulee packages.

Installation

The development version of wand can be installed from GitHub with:

# install.packages("remotes")
remotes::install_github("ben-e/wand")

Note that wand is still in active development, and should not be considered stable.

Example

Setup:

suppressPackageStartupMessages({
  # Modeling
  library(wand)
  library(parsnip)
  library(recipes)
  library(workflows)
  library(yardstick)
  # Data manipulation
  library(dplyr)
  # Data viz
  library(ggplot2)
  library(ggraph)
})

set.seed(4242)

# Generate some synthetic data with a square shape and an uncorrelated feature
df <- expand.grid(x = seq(-2, 2, 0.1), 
                  y = seq(-2, 2, 0.1)) %>% 
  mutate(z = rnorm(n()),
         class = factor((abs(x) > 1 | abs(y) > 1), 
                        levels = c(TRUE, FALSE), 
                        labels = c("out", "in")))
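The order of the factor levels matters for the metrics computed later: yardstick treats the first level as the event by default (`event_level = "first"`), which is why the examples below score `.pred_out`. A quick base R check of the simulated classes, as a sketch that rebuilds the same grid without dplyr and without the `z` column (the `grid` name is only for illustration):

```r
# Rebuild the x/y grid from the example above using base R only.
grid <- expand.grid(x = seq(-2, 2, 0.1),
                    y = seq(-2, 2, 0.1))
grid$class <- factor(abs(grid$x) > 1 | abs(grid$y) > 1,
                     levels = c(TRUE, FALSE),
                     labels = c("out", "in"))

# "out" is the first level, so it is the event of interest under
# yardstick's default event_level = "first".
levels(grid$class)

# Class balance: points outside the inner square outnumber those inside.
table(grid$class)
```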

The primary model fitting function, wand::wand, has formula, x/y, and recipe user interfaces. It is also compatible with the tidymodels ecosystem via the nn_additive_mod specification and the wand engine.

Using wand alone:

# Fit a model with one linear term and one smooth interaction
wand_fit <- wand(class ~ z + s_mlp(x, y),
                 data = df)

predict(wand_fit, df, type = "prob") %>% 
  bind_cols(df) %>% 
  roc_auc(class, .pred_out)
#> # A tibble: 1 × 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary         0.918
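To make the structure of this model concrete, the following base R sketch computes a wide and deep additive predictor by hand: a linear term in z plus a small MLP applied to (x, y). It is not wand's implementation. The weights are random, the hidden size and the helper names (`mlp_smooth`, `additive_predictor`) are made up for illustration, and a single logit with a sigmoid stands in for whatever output layer wand actually uses.

```r
set.seed(1)

# A tiny one-hidden-layer MLP standing in for s_mlp(x, y).
# W1: 2 x 8 weights, b1: 8 biases, W2: 8 x 1 weights (all hypothetical).
W1 <- matrix(rnorm(2 * 8), nrow = 2)
b1 <- rnorm(8)
W2 <- matrix(rnorm(8), ncol = 1)

mlp_smooth <- function(x, y) {
  inputs <- cbind(x, y)                                  # n x 2
  hidden <- pmax(sweep(inputs %*% W1, 2, b1, "+"), 0)    # ReLU, n x 8
  drop(hidden %*% W2)                                    # length n
}

# Additive predictor: intercept + wide (linear) part + deep (smooth) part.
additive_predictor <- function(x, y, z, intercept = 0, beta_z = 0.1) {
  intercept + beta_z * z + mlp_smooth(x, y)
}

eta <- additive_predictor(x = c(0, 1.5), y = c(0, -1.5), z = c(0.3, -0.2))
prob <- 1 / (1 + exp(-eta))  # sigmoid maps the predictor to a probability
prob
```

Because the predictor is additive, changing z shifts the result by `beta_z` per unit regardless of x and y, which is what keeps the linear part interpretable.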

Using wand with the tidymodels ecosystem:

wand_recipe <- recipe(class ~ x + y + z, 
                      data = df)

wand_model_spec <- nn_additive_mod("classification") %>% 
  set_engine(engine = "wand", 
             # note that all recipe steps are carried out before smoothing
             smooth_specs = list(xy = s_mlp(x, y)))

wand_wf <- workflow() %>% 
  add_recipe(wand_recipe) %>% 
  add_model(wand_model_spec)

wand_wf_fit <- fit(wand_wf, df)

predict(wand_wf_fit, df, type = "prob") %>% 
  bind_cols(df) %>% 
  roc_auc(class, .pred_out)
#> # A tibble: 1 × 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary         0.914

The wand package also includes a few convenience functions for inspecting the fitted models.

First, we can take a look at the model graph to understand how data flows through the model. Note that wand only supplies a function to build the graph; the user can then plot it with any tbl_graph or igraph compatible method. I’ll use ggraph for this example.

wand_wf_fit %>% 
  extract_fit_engine() %>% 
  build_wand_graph() %>% 
  ggraph(layout = 'kk') +
  geom_edge_bend(aes(start_cap = label_rect(node1.name),
                     end_cap = label_rect(node2.name)),
                 arrow = arrow(length = unit(2, 'mm'))) +
  geom_node_label(aes(label = name))

Next, we should check the model’s training loss for any hints of overfitting.

wand_wf_fit %>% 
  extract_fit_engine() %>% 
  wand_plot_loss()

Now we can actually inspect the results of model training by looking at coefficients for linear terms.

wand_wf_fit %>% 
  extract_fit_engine() %>% 
  coef()
#>                     out           in
#> (Intercept) -0.06551073 -0.003939368
#> z            0.10245648  0.002335455
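The coefficient table has one column per class. If the two columns are per-class scores that are turned into probabilities with a softmax (an assumption on my part; the output above does not say so), then only the difference between the columns affects the predicted probabilities. A quick base R sketch using the printed values and ignoring the smooth term:

```r
# Coefficients copied from the output above.
coefs <- matrix(c(-0.06551073, 0.10245648, -0.003939368, 0.002335455),
                nrow = 2,
                dimnames = list(c("(Intercept)", "z"), c("out", "in")))

# Under the softmax assumption, the log-odds of "out" versus "in" depend
# only on the column differences.
log_odds <- coefs[, "out"] - coefs[, "in"]
log_odds

# Multiplicative change in the odds of "out" per unit increase in z,
# holding x and y (and therefore the smooth term) fixed.
exp(log_odds[["z"]])
```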

Finally, for smooth terms, we can plot the fitted smooth functions. In this case, the only smooth is two-dimensional, so we will plot a surface. The wand_plot_smooths function returns a list with one entry per smooth; here we are only interested in the first (and only) plot.

smooth_contours <- wand_wf_fit %>% 
  extract_fit_engine() %>%
  wand_plot_smooths(df)

smooth_contours[[1]] +
  annotate("rect", xmin = -1, xmax = 1, ymin = -1, ymax = 1,
           fill = alpha("grey", 0), colour = "black", 
           linetype = "dashed") +
  coord_fixed()

Feature Roadmap

Models

  • Implement linear regression and classification modules with torch.
  • Implement multilayer perceptron module with torch.
  • Implement wide and deep module with torch.
  • Explore implementation of other deep module architectures (RNN, CNN, transformer) with a focus on fitting mixed-type data (e.g. tabular + image data).
  • Add L1/L2 penalties to optimization, see brulee.
  • Explore use of MC dropout to generate prediction intervals.

User Interface

  • Implement formula, x/y, and recipe user interfaces via hardhat.
  • Add parsnip and workflows compatibility.
  • Add specification functions for smooths, similar to mgcv::s, for all interfaces.
  • Explore the use of constructor functions for new smooth specification functions, allowing users to specify their own smoothers using their own torch::nn_module’s.
  • Explore the use of a smooth function registry, similar to how parsnip tracks models.

Tuning

  • Add tune compatibility for model training parameters.
  • Add tune compatibility for output layer regularization, similar to brulee.
  • Explore possibility of “compiling” models such that parameters for all smooths are tune compatible, similar to the skorch Python package’s bridge between PyTorch and scikit-learn.

Interpretability Tools

  • Implement coef method to extract linear coefficients.
  • Implement summary method.
  • Implement plot methods for plotting training curves and smooth functions (à la mgcv).
  • Implement graph/module plot to show how features flow through modules.
  • Add compatibility with the marginaleffects package.
  • Add compatibility with the ggeffects package.

Documentation

  • Design hex.
  • Build pkgdown website.
  • Write vignette: how does wand compare to mgcv.
  • Write vignette: wand workflow guide.
  • Write vignette: a deep dive into wide and deep neural networks.
  • Document functions and add examples.
  • Add tests for each user interface.
  • Add tests for internal functions.
  • Ongoing: add more tests :)

Related Work

The ideas underpinning this package are not original. In particular, this package draws from:

  • The original Wide & Deep Learning for Recommender Systems paper.
  • The formula interface and model specification options available in {mgcv}.
  • The interface between the {torch} package and {tidymodels} used by the {brulee} package.

wand is, of course, not the only implementation of wide and deep networks:
