Skip to content

Commit

Permalink
register internal s3 methods (doesn't actually export them)
Browse files Browse the repository at this point in the history
  • Loading branch information
jgabry committed Jan 22, 2024
1 parent e4908c2 commit 3a25167
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 1 deletion.
15 changes: 15 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,31 @@

S3method("[",neff_ratio)
S3method("[",rhat)
S3method(apply_transformations,array)
S3method(apply_transformations,matrix)
S3method(diagnostic_factor,neff_ratio)
S3method(diagnostic_factor,rhat)
S3method(log_posterior,CmdStanMCMC)
S3method(log_posterior,stanfit)
S3method(log_posterior,stanreg)
S3method(melt_mcmc,matrix)
S3method(melt_mcmc,mcmc_array)
S3method(neff_ratio,CmdStanMCMC)
S3method(neff_ratio,stanfit)
S3method(neff_ratio,stanreg)
S3method(num_chains,data.frame)
S3method(num_chains,mcmc_array)
S3method(num_iters,data.frame)
S3method(num_iters,mcmc_array)
S3method(num_params,data.frame)
S3method(num_params,mcmc_array)
S3method(nuts_params,CmdStanMCMC)
S3method(nuts_params,list)
S3method(nuts_params,stanfit)
S3method(nuts_params,stanreg)
S3method(parameter_names,array)
S3method(parameter_names,default)
S3method(parameter_names,matrix)
S3method(plot,bayesplot_grid)
S3method(plot,bayesplot_scheme)
S3method(pp_check,default)
Expand Down
17 changes: 17 additions & 0 deletions R/helpers-mcmc.R
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ select_parameters <-
#' @return A molten data frame.
#'
melt_mcmc <- function(x, ...) UseMethod("melt_mcmc")

#' @export
melt_mcmc.mcmc_array <- function(x,
varnames =
c("Iteration", "Chain", "Parameter"),
Expand All @@ -144,6 +146,7 @@ melt_mcmc.mcmc_array <- function(x,
}

# If all chains are already merged
#' @export
melt_mcmc.matrix <- function(x,
varnames = c("Draw", "Parameter"),
value.name = "Value",
Expand Down Expand Up @@ -305,13 +308,17 @@ chain_list2array <- function(x) {

# Get parameter names from a 3-D array
parameter_names <- function(x) UseMethod("parameter_names")

#' @export
parameter_names.array <- function(x) {
stopifnot(is_3d_array(x))
dimnames(x)[[3]] %||% abort("No parameter names found.")
}
#' @export
parameter_names.default <- function(x) {
colnames(x) %||% abort("No parameter names found.")
}
#' @export
parameter_names.matrix <- function(x) {
colnames(x) %||% abort("No parameter names found.")
}
Expand Down Expand Up @@ -391,6 +398,8 @@ validate_transformations <-
apply_transformations <- function(x, ...) {
UseMethod("apply_transformations")
}

#' @export
apply_transformations.matrix <- function(x, ..., transformations = list()) {
pars <- colnames(x)
x_transforms <- validate_transformations(transformations, pars)
Expand All @@ -400,6 +409,8 @@ apply_transformations.matrix <- function(x, ..., transformations = list()) {

x
}

#' @export
apply_transformations.array <- function(x, ..., transformations = list()) {
stopifnot(length(dim(x)) == 3)
pars <- dimnames(x)[[3]]
Expand Down Expand Up @@ -437,17 +448,23 @@ num_chains <- function(x, ...) UseMethod("num_chains")
num_iters <- function(x, ...) UseMethod("num_iters")
num_params <- function(x, ...) UseMethod("num_params")

#' @export
num_params.mcmc_array <- function(x, ...) dim(x)[3]
#' @export
num_chains.mcmc_array <- function(x, ...) dim(x)[2]
#' @export
num_iters.mcmc_array <- function(x, ...) dim(x)[1]
#' @export
num_params.data.frame <- function(x, ...) {
stopifnot("Parameter" %in% colnames(x))
length(unique(x$Parameter))
}
#' @export
num_chains.data.frame <- function(x, ...) {
stopifnot("Chain" %in% colnames(x))
length(unique(x$Chain))
}
#' @export
num_iters.data.frame <- function(x, ...) {
cols <- colnames(x)
stopifnot("Iteration" %in% cols || "Draws" %in% cols)
Expand Down
2 changes: 2 additions & 0 deletions R/mcmc-diagnostics.R
Original file line number Diff line number Diff line change
Expand Up @@ -364,12 +364,14 @@ diagnostic_factor <- function(x, ...) {
UseMethod("diagnostic_factor")
}

#' @export
diagnostic_factor.rhat <- function(x, ..., breaks = c(1.05, 1.1)) {
cut(x, breaks = c(-Inf, breaks, Inf),
labels = c("low", "ok", "high"),
ordered_result = FALSE)
}

#' @export
diagnostic_factor.neff_ratio <- function(x, ..., breaks = c(0.1, 0.5)) {
cut(x, breaks = c(-Inf, breaks, Inf),
labels = c("low", "ok", "high"),
Expand Down
2 changes: 1 addition & 1 deletion R/ppc-discrete.R
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ ppc_bars_data <-
#' @param y,yrep,group User's already validated `y`, `yrep`, and (if applicable)
#' `group` arguments.
#' @param prob,freq User's `prob` and `freq` arguments.
#' @importFrom dplyr "%>%" ungroup count arrange mutate summarise across full_join rename all_of
#' @importFrom dplyr %>% ungroup count arrange mutate summarise across full_join rename all_of
.ppc_bars_data <- function(y, yrep, group = NULL, prob = 0.9, freq = TRUE) {
alpha <- (1 - prob) / 2
probs <- sort(c(alpha, 0.5, 1 - alpha))
Expand Down

0 comments on commit 3a25167

Please sign in to comment.