Skip to content

Commit

Permalink
all but vignette
Browse files Browse the repository at this point in the history
  • Loading branch information
CoryMcCartan committed Apr 24, 2020
1 parent 9d60a4b commit 86859a5
Show file tree
Hide file tree
Showing 40 changed files with 1,160 additions and 70 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
*.bak

.Rproj.user
inst/doc
14 changes: 10 additions & 4 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
Package: adjustr
Encoding: UTF-8
Type: Package
Title: Stan Model Adjustments and Sensitivity Analyses using Importance Sampling
Version: 0.0.0.9000
Authors@R: person("Cory", "McCartan", email = "cmccartan@g.harvard.edu",
Expand All @@ -12,22 +14,26 @@ License: BSD_3_clause + file LICENSE
Depends: R (>= 3.6.0)
Imports:
tibble,
tidyselect,
dplyr,
purrr,
methods,
utils,
stats,
rlang,
rstan,
ggplot2,
stringr,
dparser,
ggplot2,
loo
Suggests:
tidyr,
extraDistr,
tidyr,
testthat,
covr
Encoding: UTF-8
covr,
knitr,
rmarkdown
URL: https://corymccartan.github.io/adjustr/
LazyData: true
RoxygenNote: 7.1.0
VignetteBuilder: knitr
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
S3method(arrange,adjustr_spec)
S3method(as.data.frame,adjustr_spec)
S3method(length,adjustr_spec)
S3method(plot,adjustr_weighted)
S3method(print,adjustr_spec)
S3method(pull,adjustr_weighted)
S3method(rename,adjustr_spec)
Expand Down
34 changes: 29 additions & 5 deletions R/adjust_weights.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,25 @@
#' method. The returned object also includes the model sample draws, in the
#' \code{draws} attribute.
#'
#' @examples \dontrun{
#' model_data = list(
#' J = 8,
#' y = c(28, 8, -3, 7, -1, 1, 18, 12),
#' sigma = c(15, 10, 16, 11, 9, 11, 10, 18)
#' )
#'
#' spec = make_spec(eta ~ student_t(df, 0, 1), df=1:10)
#' adjust_weights(spec, eightschools_m)
#' adjust_weights(spec, eightschools_m, keep_bad=TRUE)
#'
#' spec = make_spec(y ~ student_t(df, theta, sigma), df=1:10)
#' adjust_weights(spec, eightschools_m, data=model_data)
#' # will throw an error because `y` and `sigma` aren't provided
#' adjust_weights(spec, eightschools_m)
#' }
#'
#' @export
adjust_weights = function(spec, object, data=NULL, keep_bad=F) {
adjust_weights = function(spec, object, data=NULL, keep_bad=FALSE) {
# CHECK ARGUMENTS
object = get_fit_obj(object)
model_code = object@stanmodel@model_code
Expand Down Expand Up @@ -59,12 +76,12 @@ adjust_weights = function(spec, object, data=NULL, keep_bad=F) {
psis_wgt = suppressWarnings(loo::psis(lratio, r_eff=r_eff))
pareto_k = loo::pareto_k_values(psis_wgt)
if (all(psis_wgt$log_weights == psis_wgt$log_weights[1])) {
warning("New specification equal to old specification.", call.=F)
warning("New specification equal to old specification.", call.=FALSE)
pareto_k = -Inf
}

list(
weights = loo::weights.importance_sampling(psis_wgt, log=F),
weights = loo::weights.importance_sampling(psis_wgt, log=FALSE),
pareto_k = pareto_k
)
})
Expand All @@ -84,7 +101,7 @@ adjust_weights = function(spec, object, data=NULL, keep_bad=F) {

# Generic methods
is.adjustr_weighted = function(x) inherits(x, "adjustr_weighted")
#' Extract Weights From an \code{adjustr_spec_weighted} Object
#' Extract Weights From an \code{adjustr_weighted} Object
#'
#' This function modifies the default behavior of \code{dplyr::pull} to extract
#' the \code{.weights} column.
Expand All @@ -96,10 +113,11 @@ is.adjustr_weighted = function(x) inherits(x, "adjustr_weighted")
#'
#' @export
pull.adjustr_weighted = function(.data, var=".weights") {
var = tidyselect::vars_pull(names(.data), !!enquo(var))
if (nrow(.data) == 1 && var == ".weights") {
.data$.weights[[1]]
} else {
NextMethod(.data, var=var)
.data[[var]]
}
}

Expand All @@ -113,6 +131,12 @@ pull.adjustr_weighted = function(.data, var=".weights") {
#'
#' @return Invisbly returns a list of sampling formulas.
#'
#' @examples \dontrun{
#' extract_samp_stmts(eightschools_m)
#' #> Sampling statements for model 2c8d1d8a30137533422c438f23b83428:
#' #> parameter eta ~ std_normal()
#' #> data y ~ normal(theta, sigma)
#' }
#' @export
extract_samp_stmts = function(object) {
model_code = get_fit_obj(object)@stanmodel@model_code
Expand Down
4 changes: 3 additions & 1 deletion R/adjustr-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#' \item \code{\link{make_spec}}
#' \item \code{\link{adjust_weights}}
#' \item \code{\link{summarize}}
#' \item \code{\link{plot}}
#' }
#'
#' @importFrom methods is
Expand All @@ -35,6 +36,7 @@ pkg_env = new_environment()
# create the Stan parser
tryCatch(get_parser(), error = function(e) {})

utils::globalVariables(c("name", "pos", "value"))
utils::globalVariables(c("name", "pos", "value", ".y", ".y_ol", ".y_ou",
".y_il", ".y_iu", ".y_med"))
} # nocov end
#> NULL
2 changes: 1 addition & 1 deletion R/logprob.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ get_base_data = function(object, samps, parsed_vars, data, extra_names=NULL) {
if (!all(found)) stop(paste(vars_indata[!found], collapse=", "), " not found")
# combine draws and data
base_data = append(
map(vars_indraws, ~ rstan::extract(object, ., permuted=F)) %>%
map(vars_indraws, ~ rstan::extract(object, ., permuted=FALSE)) %>%
set_names(vars_indraws),
map(vars_indata, ~ reshape_data(data[[.]])) %>%
set_names(vars_indata),
Expand Down
22 changes: 20 additions & 2 deletions R/make_spec.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@
#' \code{\link[dplyr]{rename}}, and \code{\link[dplyr]{slice}}) are
#' supported and operate on the underlying table of specification parameters.
#'
#' @examples
#' make_spec(eta ~ cauchy(0, 1))
#'
#' make_spec(eta ~ student_t(df, 0, 1), df=1:10)
#'
#' params = tidyr::crossing(df=1:10, infl=c(1, 1.5, 2))
#' make_spec(eta ~ student_t(df, 0, 1),
#' y ~ normal(theta, infl*sigma),
#' params)
#'
#' @export
make_spec = function(...) {
args = dots_list(..., .check_assign=T)
Expand Down Expand Up @@ -151,6 +161,14 @@ as.data.frame.adjustr_spec = function(x, ...) {
#' @param ... additional arguments to underlying method
#' @param .preserve as in \code{filter} and \code{slice}
#' @name dplyr.adjustr_spec
#'
#' @examples \dontrun{
#' spec = make_spec(eta ~ student_t(df, 0, 1), df=1:10)
#'
#' arrange(spec, desc(df))
#' slice(spec, 4:7)
#' filter(spec, df == 2)
#' }
NULL
# dplyr generics
dplyr_handler = function(dplyr_func, x, ...) {
Expand All @@ -165,7 +183,7 @@ dplyr_handler = function(dplyr_func, x, ...) {

# no @export because R CMD CHECK didn't like it
#' @rdname dplyr.adjustr_spec
filter.adjustr_spec = function(.data, ..., .preserve=F) {
filter.adjustr_spec = function(.data, ..., .preserve=FALSE) {
dplyr_handler(dplyr::filter, .data, ..., .preserve=.preserve)
}
#' @rdname dplyr.adjustr_spec
Expand All @@ -185,7 +203,7 @@ select.adjustr_spec = function(.data, ...) {
}
#' @rdname dplyr.adjustr_spec
#' @export
slice.adjustr_spec = function(.data, ..., .preserve=F) {
slice.adjustr_spec = function(.data, ..., .preserve=FALSE) {
dplyr_handler(dplyr::slice, .data, ..., .preserve=.preserve)
}

104 changes: 101 additions & 3 deletions R/use_weights.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@
#' containing the sampled indices. If any weights are \code{NA}, the indices
#' will also be \code{NA}.
#'
#' @examples \dontrun{
#' spec = make_spec(eta ~ student_t(df, 0, 1), df=1:10)
#' adjusted = adjust_weights(spec, eightschools_m)
#'
#' get_resampling_idxs(adjusted)
#' get_resampling_idxs(adjusted, frac=0.1, replace=FALSE)
#' }
#'
#' @export
get_resampling_idxs = function(x, frac=1, replace=T) {
if (frac < 0) stop("`frac` parameter must be nonnegative")
Expand Down Expand Up @@ -48,7 +56,10 @@ get_resampling_idxs = function(x, frac=1, replace=T) {
#' posterior distribution of eight alternative specification. For example,
#' a value of \code{mean(theta)} will compute the posterior mean of
#' \code{theta} for each alternative specification.
#' @param .resampling Wether to compute summary statistics by first resampling
#'
#' The arguments in \code{...} are automatically quoted and evaluated in the
#' context of \code{.data}. They support unquoting and splicing.
#' @param .resampling Whether to compute summary statistics by first resampling
#' the data according to the weights. Defaults to \code{FALSE}, but will be
#' used for any summary statistic that is not \code{mean}, \code{var} or
#' \code{sd}.
Expand All @@ -58,9 +69,24 @@ get_resampling_idxs = function(x, frac=1, replace=T) {
#' @return An \code{adjustr_weighted} object, wth the new columns specified in
#' \code{...} added.
#'
#' @examples \dontrun{
#' model_data = list(
#' J = 8,
#' y = c(28, 8, -3, 7, -1, 1, 18, 12),
#' sigma = c(15, 10, 16, 11, 9, 11, 10, 18)
#' )
#'
#' spec = make_spec(eta ~ student_t(df, 0, 1), df=1:10)
#' adjusted = adjust_weights(spec, eightschools_m)
#'
#' summarize(adjusted, mean(mu), var(mu))
#' summarize(adjusted, diff_1 = mean(y[1] - theta[1]), .model_data=model_data)
#' summarize(adjusted, quantile(tau, probs=c(0.05, 0.5, 0.95)))
#' }
#'
#' @rdname summarize.adjustr_weighted
#' @export
summarise.adjustr_weighted = function(.data, ..., .resampling=F, .model_data=NULL) {
summarise.adjustr_weighted = function(.data, ..., .resampling=FALSE, .model_data=NULL) {
stopifnot(is.adjustr_weighted(.data)) # just in case called manually
args = enexprs(...)

Expand Down Expand Up @@ -89,7 +115,7 @@ summarise.adjustr_weighted = function(.data, ..., .resampling=F, .model_data=NUL
expr = expr_deparse(call_args(call)[[1]])
expr = stringr::str_replace_all(expr, "\\[(\\d)", "[,\\1")
expr = stringr::str_replace_all(expr, "(?<![a-zA-Z0-9._])mean\\(", "rowMeans(")
expr = stringr::str_replace_all(expr, "(?<![a-zA-Z0-9._])sum\\(", "rowSum(")
expr = stringr::str_replace_all(expr, "(?<![a-zA-Z0-9._])sum\\(", "rowSums(")
computed = as.array(eval_tidy(parse_expr(expr), data))
if (length(dim(computed)) == 1) dim(computed) = c(dim(computed), 1)

Expand Down Expand Up @@ -124,3 +150,75 @@ funs_env = new_environment(list(
var = wtd_array_var,
sd = wtd_array_sd
))


#' Plot Posterior Quantities of Interest Under Alternative Model Specifications
#'
#' Uses weights computed in \code{\link{adjust_weights}} to plot posterior
#' quantities of interest versus
#'
#' @param x An \code{adjustr_weighted} object.
#' @param by The x-axis variable, which is usually one of the specification
#' parameters. Can be set to \code{1} if there is only one specification.
#' Automatically quoted and evaluated in the context of \code{x}.
#' @param post The posterior quantity of interest, to be computed for each
#' resampled draw of each specificaiton. Should evaluate to a single number
#' for each draw. Automatically quoted and evaluated in the context of \code{x}.
#' @param only_mean Whether to only plot the posterior mean. May be more stable.
#' @param ci_level The inner credible interval to plot. Central
#' 100*ci_level% intervals are computed from the quantiles of the resampled
#' posterior draws.
#' @param outer_level The outer credible interval to plot.
#' @param ... Ignored.
#'
#' @return A \code{\link[ggplot2]{ggplot}} object which can be further
#' customized with the \strong{ggplot2} package.
#'
#' @examples \dontrun{
#' spec = make_spec(eta ~ student_t(df, 0, scale),
#' df=1:10, scale=seq(2, 1, -1/9))
#' adjusted = adjust_weights(spec, eightschools_m)
#'
#' plot(adjusted, df, theta[1])
#' plot(adjusted, df, mu, only_mean=TRUE)
#' plot(adjusted, scale, tau)
#' }
#'
#' @export
plot.adjustr_weighted = function(x, by, post, only_mean=FALSE, ci_level=0.8,
outer_level=0.95, ...) {
if (!requireNamespace("ggplot2", quietly=TRUE)) { # nocov start
stop("Package `ggplot2` must be installed to plot posterior quantities of interest.")
} # nocov end
if (ci_level > outer_level) stop("`ci_level` should be less than `outer_level`")

post = enexpr(post)
if (!only_mean) {
outer = (1 - outer_level) / 2
inner = (1 - ci_level) / 2
q_probs = c(outer, inner, 0.5, 1-inner, 1-outer)
sum_arg = quo(stats::quantile(!!post, probs = !!q_probs))

summarise.adjustr_weighted(x, .y = !!sum_arg) %>%
rowwise() %>%
mutate(.y_ol = .y[1],
.y_il = .y[2],
.y_med = .y[3],
.y_iu = .y[4],
.y_ou = .y[5]) %>%
ggplot(aes({{ by }}, .y_med)) +
geom_ribbon(aes(ymin=.y_ol, ymax=.y_ou), alpha=0.4) +
geom_ribbon(aes(ymin=.y_il, ymax=.y_iu), alpha=0.5) +
geom_line() +
geom_point(size=3) +
theme_minimal() +
labs(y= expr_name(post))
} else {
summarise.adjustr_weighted(x, .y = mean(!!post)) %>%
ggplot(aes({{ by }}, .y)) +
geom_line() +
geom_point(size=3) +
theme_minimal() +
labs(y = expr_name(post))
}
}
9 changes: 5 additions & 4 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ template:
navbar:
title: "adjustr"
left:
#- text: "Vignettes"
# href: articles/
- text: "Vignettes"
href: articles/index.html
- text: "Functions"
href: reference/
href: reference/index.html
- text: "Other Packages"
menu:
- text: "rstan"
Expand Down Expand Up @@ -72,7 +72,8 @@ reference:
contents:
- make_spec
- adjust_weights
- summarise.adjustr_weighted
- summarize.adjustr_weighted
- plot.adjustr_weighted
- title: "Helper Functions"
desc: >
Various helper functions for examining a model or building sampling
Expand Down
5 changes: 4 additions & 1 deletion docs/404.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion docs/LICENSE-text.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 86859a5

Please sign in to comment.