wand is an R package which implements generalized additive models (GAMs) using torch as a backend, allowing users to fit semi-interpretable models using tools from the deep learning world. In addition to its GAM implementation, wand also provides tools for model interpretation. This package gets its name from the type of network it implements: wide and deep networks combine linear features (wide) with feature embeddings (deep). wand draws heavy inspiration from the mgcv and brulee packages.
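To give a sense of the architecture, here is a minimal torch sketch of a wide and deep module. This is a hypothetical illustration, not wand's internal implementation: a linear (wide) path and an MLP (deep) path whose outputs are combined additively.

library(torch)

# Hypothetical wide & deep module: a linear model over the "wide" features
# plus a small MLP over the "deep" features, combined additively.
wide_and_deep <- nn_module(
  "wide_and_deep",
  initialize = function(n_wide, n_deep, hidden = 16) {
    self$wide <- nn_linear(n_wide, 1)
    self$deep <- nn_sequential(
      nn_linear(n_deep, hidden),
      nn_relu(),
      nn_linear(hidden, 1)
    )
  },
  forward = function(x_wide, x_deep) {
    # Both paths contribute additively, like terms in a GAM
    self$wide(x_wide) + self$deep(x_deep)
  }
)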
The development version of wand can be installed from GitHub with:
# install.packages("remotes")
remotes::install_github("ben-e/wand")
Note that wand is still in active development, and should not be considered stable.
Setup:
suppressPackageStartupMessages({
  # Modeling
  library(wand)
  library(parsnip)
  library(recipes)
  library(workflows)
  library(yardstick)
  # Data manipulation
  library(dplyr)
  # Data viz
  library(ggplot2)
  library(ggraph)
})
set.seed(4242)
# Generate some synthetic data with a square shape and an uncorrelated feature
df <- expand.grid(x = seq(-2, 2, 0.1),
                  y = seq(-2, 2, 0.1)) %>%
  mutate(z = rnorm(n()),
         class = factor((abs(x) > 1 | abs(y) > 1),
                        levels = c(TRUE, FALSE),
                        labels = c("out", "in")))
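Since ggplot2 is already loaded, a quick plot shows the square class boundary that the smooth term will need to learn:

# Visualize the synthetic data: "in" points fall inside the unit square
ggplot(df, aes(x, y, fill = class)) +
  geom_tile() +
  coord_fixed()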
The primary model fitting function, wand::wand, has formula, x/y, and recipe user interfaces, and is compatible with the tidymodels ecosystem via the nn_additive_mod model specification and the wand engine.

Using wand alone:
# Fit a model with one linear term and one smooth interaction
wand_fit <- wand(class ~ z + s_mlp(x, y),
                 data = df)

predict(wand_fit, df, type = "prob") %>%
  bind_cols(df) %>%
  roc_auc(class, .pred_out)
#> # A tibble: 1 × 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary         0.918
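Hard class predictions follow the same pattern. Assuming wand follows the usual hardhat/parsnip prediction conventions (a .pred_class column for type = "class"), accuracy can be computed the same way:

# Class predictions and accuracy; .pred_class is assumed per hardhat conventions
predict(wand_fit, df, type = "class") %>%
  bind_cols(df) %>%
  accuracy(class, .pred_class)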
Using wand with the tidymodels ecosystem:
wand_recipe <- recipe(class ~ x + y + z,
                      data = df)

wand_model_spec <- nn_additive_mod("classification") %>%
  set_engine(engine = "wand",
             # note that all recipe steps are carried out before smoothing
             smooth_specs = list(xy = s_mlp(x, y)))

wand_wf <- workflow() %>%
  add_recipe(wand_recipe) %>%
  add_model(wand_model_spec)

wand_wf_fit <- fit(wand_wf, df)

predict(wand_wf_fit, df, type = "prob") %>%
  bind_cols(df) %>%
  roc_auc(class, .pred_out)
#> # A tibble: 1 × 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary         0.914
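Because all recipe steps run before the smooths are applied, standard preprocessing composes naturally with the smooth terms. As a sketch, the same recipe with normalization added:

# Illustrative variation: normalize predictors before they reach the smooths
wand_recipe_norm <- recipe(class ~ x + y + z, data = df) %>%
  step_normalize(all_numeric_predictors())

wand_wf_norm <- workflow() %>%
  add_recipe(wand_recipe_norm) %>%
  add_model(wand_model_spec)
# fit(wand_wf_norm, df) would then train on the normalized features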
The wand package also includes a few convenience functions for inspecting fitted models.

First, we can take a look at the model graph to understand how data flows through the model. Note that wand only supplies a function to build the graph; the user can then use any tbl_graph or igraph compatible plotting method. I'll use ggraph for this example.
wand_wf_fit %>%
  extract_fit_engine() %>%
  build_wand_graph() %>%
  ggraph(layout = 'kk') +
  geom_edge_bend(aes(start_cap = label_rect(node1.name),
                     end_cap = label_rect(node2.name)),
                 arrow = arrow(length = unit(2, 'mm'))) +
  geom_node_label(aes(label = name))
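Since a tbl_graph converts losslessly to a plain igraph object, base igraph plotting is an alternative if you prefer not to use ggraph:

library(igraph)

# Build the graph once, then hand it to igraph's base plot method
g <- wand_wf_fit %>%
  extract_fit_engine() %>%
  build_wand_graph()
plot(as.igraph(g))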
Next, we can check the model's training loss for any hints of overfitting.
wand_wf_fit %>%
  extract_fit_engine() %>%
  wand_plot_loss()
Now we can inspect the results of training by looking at the coefficients of the linear terms.
wand_wf_fit %>%
  extract_fit_engine() %>%
  coef()
#>                     out           in
#> (Intercept) -0.06551073 -0.003939368
#> z            0.10245648  0.002335455
Finally, for smooth terms, we can plot the smooth functions themselves. In this case the only smooth is two-dimensional, so we will plot a surface. The wand_plot_smooths function returns a list with an entry for each smooth; here we are only interested in the first and only plot.
smooth_contours <- wand_wf_fit %>%
  extract_fit_engine() %>%
  wand_plot_smooths(df)

smooth_contours[[1]] +
  annotate("rect", xmin = -1, xmax = 1, ymin = -1, ymax = 1,
           fill = alpha("grey", 0), colour = "black",
           linetype = "dashed") +
  coord_fixed()
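When a model has several smooths, the whole list can be arranged into a single figure; one sketch uses the patchwork package (not loaded in the setup above):

library(patchwork)

# Arrange every smooth plot returned by wand_plot_smooths in a grid
wrap_plots(smooth_contours)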
Roadmap:

- Implement linear regression and classification modules with torch.
- Implement multilayer perceptron module with torch.
- Implement wide and deep module with torch.
- Explore implementation of other deep module architectures (RNN, CNN, transformer) with a focus on fitting mixed-type data (e.g. tabular + image data).
- Add L1/L2 penalties to optimization, see brulee.
- Explore use of MC dropout to generate prediction intervals.

- Implement formula, x/y, and recipe user interfaces via hardhat.
- Add parsnip and workflows compatibility.
- Add specification functions for smooths, similar to mgcv::s, for all interfaces.
- Explore the use of constructor functions for new smooth specification functions, allowing users to specify their own smoothers using their own torch::nn_module objects (see the sketch after this list).
- Explore the use of a smooth function registry, similar to how parsnip tracks models.

- Add tune compatibility for model training parameters.
- Add tune compatibility for output layer regularization, similar to brulee.
- Explore the possibility of "compiling" models such that parameters for all smooths are tune compatible, similar to the skorch Python package's bridge between PyTorch and scikit-learn.

- Implement coef method to extract linear coefficients.
- Implement summary method.
- Implement plot methods for plotting training curves and smooth functions (à la mgcv).
- Implement graph/module plot to show how features flow through modules.
- Add compatibility with the marginaleffects package.
- Add compatibility with the ggeffects package.

- Design hex.
- Build pkgdown website.
- Write vignette: how does wand compare to mgcv.
- Write vignette: wand workflow guide.
- Write vignette: a deep dive into wide and deep neural networks.
- Document functions and add examples.

- Add tests for each user interface.
- Add tests for internal functions.
- Ongoing: add more tests :)
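The constructor-function idea above is still exploratory, but a user-supplied smoother would presumably just be a torch::nn_module. A minimal sketch of what one might look like (hypothetical; the registration API does not exist yet):

library(torch)

# Hypothetical custom smoother: a one-hidden-layer MLP over its inputs.
# How such a module would be registered with wand is not yet defined.
my_smoother <- nn_module(
  "my_smoother",
  initialize = function(n_inputs, hidden = 32) {
    self$net <- nn_sequential(
      nn_linear(n_inputs, hidden),
      nn_tanh(),
      nn_linear(hidden, 1)
    )
  },
  forward = function(x) {
    self$net(x)
  }
)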
The ideas underpinning this package are not original; in particular, this package draws from:
- The original Wide & Deep Learning for Recommender Systems paper.
- The formula interface and model specification options available in {mgcv}.
- The interface between the {torch} package and {tidymodels} used by the {brulee} package.
wand is, of course, not the only implementation of wide and deep networks: