Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
macartan committed Apr 19, 2024
1 parent c08ff17 commit 5d5289c
Show file tree
Hide file tree
Showing 12 changed files with 78 additions and 36 deletions.
3 changes: 2 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ S3method(print,model_query)
S3method(print,nodal_types)
S3method(print,nodal_types_query)
S3method(print,nodes)
S3method(print,parameter_mapping)
S3method(print,parameter_matrix)
S3method(print,parameters)
S3method(print,parameters_df)
Expand All @@ -20,7 +21,7 @@ S3method(print,posterior_event_probabilities)
S3method(print,stan_summary)
S3method(print,statement)
S3method(print,summary.causal_model)
S3method(print,type_posterior)
S3method(print,type_distribution)
S3method(print,type_prior)
S3method(summary,causal_model)
export(collapse_data)
Expand Down
12 changes: 6 additions & 6 deletions R/grab.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#' \item \code{"causal_statement"} a character. Statement describing causal relations using dagitty syntax,
#' \item \code{"dag"} A data frame with columns ‘parent’ and ‘children’ indicating how nodes relate to each other,
#' \item \code{"nodes"} A list containing the nodes in the model,
#' \item \code{"parents"} a table listing nodes, whether they are root nodes or not, and the number and names of parents they have,
#' \item \code{"parents_df"} a table listing nodes, whether they are root nodes or not, and the number and names of parents they have,
#' \item \code{"parameters_df"} a data frame containing parameter information,
#' \item \code{"causal_types"} a data frame listing causal types and the nodal types that produce them,
#' \item \code{"causal_types_interpretation"} a key to interpreting types; see \code{"?interpret_type"} for options,
Expand All @@ -34,7 +34,7 @@
#' \item \code{"stan_fit"} the stanfit object generated by Stan,
#' \item \code{"stan_summary"} a summary of the stanfit object generated by Stan,
#' \item \code{"type_prior"} a matrix of type probabilities using priors,
#' \item \code{"type_posterior"} a matrix of type probabilities using posteriors,
#' \item \code{"type_distribution"} a matrix of type probabilities using posteriors,
#' }
#' @param ... Other arguments passed to helper \code{"get_*"} functions.
#' @return Objects from a \code{causal_model} as specified.
Expand All @@ -52,7 +52,7 @@
#' grab(model, object = "causal_statement")
#' grab(model, object = "dag")
#' grab(model, object = "nodes")
#' grab(model, object = "parents")
#' grab(model, object = "parents_df")
#' grab(model, object = "parameters_df")
#' grab(model, object = "causal_types")
#' grab(model, object = "causal_types_interpretation")
Expand All @@ -73,7 +73,7 @@
#' grab(model, object = "stan_fit")
#' grab(model, object = "stan_summary")
#' grab(model, object = "type_prior")
#' grab(model, object = "type_posterior")
#' grab(model, object = "type_distribution")
#'
#' # Example of arguments passed on to helpers
#' grab(model,
Expand All @@ -88,7 +88,7 @@ grab <- function(model, object = NULL, ...) {
causal_statement = model$statement,
dag = model$dag,
nodes = model$nodes,
parents = model$parents_df,
parents_df = model$parents_df,
parameters_df = model$parameters_df,
causal_types = get_causal_types(model),
causal_types_interpretation = interpret_type(model, ...),
Expand Down Expand Up @@ -144,7 +144,7 @@ grab <- function(model, object = NULL, ...) {
model$stan_objects$stan_summary
},
type_prior = get_type_prob_multiple(model, using = "priors") |> t(),
type_posterior =
type_distribution =
if (is.null(model$stan_objects$type_distribution)) {
stop(
"Model does not contain type_distribution; update model with type_distribution = TRUE"
Expand Down
25 changes: 21 additions & 4 deletions R/methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ print.statement <- function(x, ...) {
return(invisible(x))
}

#' Print a short summary for a causal_model nodes
#' Print a short summary for causal_model nodes
#'
#' print method for class \code{nodes}.
#'
Expand Down Expand Up @@ -335,6 +335,23 @@ print.type_prior <- function(x, ...) {
return(invisible(x))
}

#' Print a short summary for paramater mapping matrix
#'
#' print method for class \code{parameter_mapping}.
#'
#' @param x An object of \code{parameter_mapping} class.
#' @param ... Further arguments passed to or from other methods.
#'
#' @export
print.parameter_mapping <- function(x, ...) {
cat("\nParameter mapping matrix: \n\n")
cat("Maps from parameters to data types, with \n")
cat("possibly multiple columns for each data type \n")
cat("in cases with confounding. \n\n")
print(data.frame(x))
cat("\n")
return(invisible(x))
}


#' Print a short summary for stan fit
Expand Down Expand Up @@ -418,15 +435,15 @@ print.event_probabilities <- function(x, ...) {

#' Print a short summary for causal-type posterior distributions
#'
#' print method for class \code{type_posterior}.
#' print method for class \code{type_distribution}.
#'
#' @param x An object of \code{type_posterior} class, which is a sub-object of
#' @param x An object of \code{type_distribution} class, which is a sub-object of
#' an object of the \code{causal_model} class produced using
#' \code{get_type_prob_multiple}.
#' @param ... Further arguments passed to or from other methods.
#'
#' @export
print.type_posterior <- function(x, ...) {
print.type_distribution <- function(x, ...) {
cat("Posterior draws of causal types (transformed parameters)")
cat(paste("\nDimensions:", dim(x)[1], "rows (draws) by", dim(x)[2], "cols (types) \n\n", sep = " "))
cat("Summary: \n\n")
Expand Down
3 changes: 2 additions & 1 deletion R/parmap.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ make_parmap <- function(model, A = NULL, P = NULL) {
map <- diag(ncol(type_matrix))
rownames(map) <- colnames(map) <- colnames(A)
attr(type_matrix, "map") <- map
class(type_matrix) <- c("parameter_mapping", class(type_matrix))
return(type_matrix)
}

Expand Down Expand Up @@ -49,8 +50,8 @@ make_parmap <- function(model, A = NULL, P = NULL) {

colnames(type_matrix) <- .type_matrix$d

# type_matrix <- type_matrix[,match(colnames(type_matrix), colnames(A))]
attr(type_matrix, "map") <- data_to_data(type_matrix, A)
class(type_matrix) <- c("parameter_mapping", class(type_matrix))
type_matrix
}

Expand Down
17 changes: 12 additions & 5 deletions R/set_parameter_matrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,21 @@ set_parameter_matrix <- function(model, P = NULL) {
print.parameter_matrix <- function(x, ...) {
cat(paste0("\nRows are parameters, grouped in parameter sets"))
cat(paste0("\n\nColumns are causal types"))
cat(paste0("\n\nCell entries indicate whether a parameter probability is",
"used\nin the calculation of causal type probability\n\n"))
cat(
paste0(
"\n\nCell entries indicate whether a parameter probability is",
"used\nin the calculation of causal type probability\n\n"
)
)

param_set <- attr(x, "param_set")
class(x) <- "data.frame"
print(x)
cat("\n \n param_set (P)\n ")
cat(paste0(param_set, collapse = " "))
cat("\n")
if (!is.null(attr(x, "param_set"))) {
param_set <- attr(x, "param_set")
cat("\n param_set (P)\n ")
cat(paste0(param_set, collapse = " "))
}
return(invisible(x))
}

Expand Down
2 changes: 1 addition & 1 deletion R/update_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ update_model <- function(model,
model$stan_objects$type_distribution <- extract(newfit, pars = "types")$types

colnames(model$stan_objects$type_distribution) <- colnames(stan_data$P)
class(model$stan_objects$type_distribution) <- c("type_posterior", "matrix", "array")
class(model$stan_objects$type_distribution) <- c("type_distribution", "matrix", "array")
}

# Retain event (pre-censoring) probabilities
Expand Down
8 changes: 4 additions & 4 deletions man/grab.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/print.nodes.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions man/print.parameter_mapping.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 5 additions & 5 deletions man/print.type_posterior.Rd → man/print.type_distribution.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 6 additions & 6 deletions tests/testthat/test_grab.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ testthat::test_that(
"causal_statement",
"dag",
"nodes",
"parents",
"parents_df",
"parameters_df",
"causal_types",
"causal_types_interpretation",
Expand All @@ -33,7 +33,7 @@ testthat::test_that(
"stan_fit",
"stan_summary",
"type_prior",
"type_posterior"
"type_distribution"
)

classes <- c(
Expand All @@ -50,7 +50,7 @@ testthat::test_that(
"matrix",
"parameters",
"character",
"matrix",
"parameter_mapping",
"parameter_matrix",
"numeric",
"parameters_prior",
Expand All @@ -61,7 +61,7 @@ testthat::test_that(
"stanfit",
"stan_summary",
"type_prior",
"type_posterior"
"type_distribution"
)


Expand All @@ -79,7 +79,7 @@ testthat::test_that(

# Proper dimensions
expect_equal(grab(model, "type_prior") |> dim(), c(4000, 8))
expect_equal(grab(model, "type_posterior") |> dim(), c(4000, 8))
expect_equal(grab(model, "type_distribution") |> dim(), c(4000, 8))
expect_equal(grab(model, "posterior_distribution") |> dim(), c(4000, 6))
expect_equal(grab(model, "prior_distribution") |> dim(), c(4000, 6))

Expand All @@ -98,7 +98,7 @@ testthat::test_that(
# Print methods
out <- capture.output(print(grab(model, object = "nodes")))
expect_true(any(grepl("Nodes:", out)))
out <- capture.output(print(grab(model, object = "parents")))
out <- capture.output(print(grab(model, object = "parents_df")))
expect_true(any(grepl("parents", out)))
out <- capture.output(print(grab(model, object = "parameters_df")))
expect_false(any(grepl("first 10 rows:", out)))
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test_methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ testthat::test_that(
out <- capture.output(print(grab(model, object = "nodes")))
expect_true(any(grepl("Nodes:", out)))

out <- capture.output(print(grab(model, object = "parents")))
out <- capture.output(print(grab(model, object = "parents_df")))
expect_true(any(grepl("parents", out)))

out <- capture.output(print(grab(model, object = "parameters_df")))
Expand Down Expand Up @@ -65,7 +65,7 @@ testthat::test_that(
out <- capture.output(print(grab(model, object = "event_probabilities")))
expect_true(any(grepl("event_probs", out)))

out <- capture.output(print(grab(model, object = "type_posterior")))
out <- capture.output(print(grab(model, object = "type_distribution")))
expect_true(any(grepl("Posterior draws", out)))

out <- capture.output(print( query_model(model, "Y[X=1] - Y[X = 0]", using = "parameters")) )
Expand Down

0 comments on commit 5d5289c

Please sign in to comment.