From a453bb6edb8316ae12ea2d04d441b6da1f14d966 Mon Sep 17 00:00:00 2001 From: Emi Tanaka Date: Sat, 16 Sep 2023 18:57:10 +1000 Subject: [PATCH] fix conditional treatment --- R/graph-input.R | 31 +++++++++++++++++++ R/nest.R | 2 +- R/provenance.R | 63 ++++++++++++++++++++++++++++++-------- tests/testthat/test-nest.R | 53 ++++++++++++++++++++++++++++++++ 4 files changed, 136 insertions(+), 13 deletions(-) diff --git a/R/graph-input.R b/R/graph-input.R index 7467ce29..d8234122 100644 --- a/R/graph-input.R +++ b/R/graph-input.R @@ -92,3 +92,34 @@ graph_input.nest_lvls <- function(input, prov, name, class, ...) { } } } + + + +graph_input.cond_lvls <- function(input, prov, name, class, ...) { + parent <- input %@% "keyname" + cross_parents <- input %@% "parents" + clabels <- input %@% "labels" + attrs <- NULL # attributes(input) + prov$append_fct_nodes(name = name, role = class) + idp <- prov$fct_id(name = parent) + idv <- prov$fct_id(name = name) + prov$append_fct_edges(from = idp, to = idv, type = "nest") + plevels <- rep(names(input), lengths(input)) + clevels <- unname(unlist(input)) + pids <- prov$lvl_id(value = plevels, fid = idp) + ## unique(clevels) is the only part that's different to nest_lvls + prov$append_lvl_nodes(value = unique(clevels), fid = idv) + vids <- prov$lvl_id(value = clevels, fid = idv) + prov$append_lvl_edges(from = pids, to = vids) + + if(!is_null(cross_parents)) { + cross_df <- do.call("rbind", cross_parents[names(input)]) + cross_parent_names <- colnames(cross_df) + for(across in cross_parent_names) { + prov$append_fct_edges(from = prov$fct_id(name = across), to = idv, type = "cross") + cpids <- prov$lvl_id(value = cross_df[[across]]) + prov$append_lvl_edges(from = cpids, to = vids) + } + } +} + diff --git a/R/nest.R b/R/nest.R index a1500e5d..87bfe94a 100644 --- a/R/nest.R +++ b/R/nest.R @@ -140,7 +140,7 @@ conditioned_on <- function(x, ...) { attributes(child_levels) <- c(attributes(child_levels), list(parents = attr(args, "parents"), labels = child_levels)) - class(child_levels) <- c("nest_lvls", class(child_levels)) + class(child_levels) <- c("cond_lvls", class(child_levels)) return(child_levels) diff --git a/R/provenance.R b/R/provenance.R index 989aef6a..9db74189 100644 --- a/R/provenance.R +++ b/R/provenance.R @@ -607,14 +607,10 @@ Provenance <- R6::R6Class("Provenance", for(i in seq(ncomp)) { if(i == 1L) { sub_graph <- self$graph_subset(id = scomps[[i]], include = "self") - out <- private$build_subtable(sub_graph, return = return) - trts_tbl <- as.data.frame(out) - colnames(trts_tbl) <- names(out) + trts_tbl <- private$build_condtable(sub_graph, return = return) } else { sub_graph <- self$graph_subset(id = scomps[[i]], include = "self") - out <- private$build_subtable(sub_graph, return = return) - new_trts_tbl <- as.data.frame(out) - colnames(new_trts_tbl) <- names(out) + new_trts_tbl <- private$build_condtable(sub_graph, return = return) xtabs <- expand.grid(old = seq(nrow(trts_tbl)), new = seq(nrow(new_trts_tbl))) trts_tbl <- cbind(trts_tbl[xtabs[[1]], , drop = FALSE], new_trts_tbl[xtabs[[2]], , drop = FALSE]) } @@ -746,13 +742,16 @@ Provenance <- R6::R6Class("Provenance", #' @description #' Get the level edges by factor #' @param from,to The factor id. - lvl_mapping = function(from, to) { + lvl_mapping = function(from, to, return = c("vector", "table")) { + return <- match.arg(return) lnodes <- self$lvl_nodes nodes_from <- lnodes[[as.character(from)]]$id nodes_to <- lnodes[[as.character(to)]]$id ledges <- self$lvl_edges - map <- subset(ledges, from %in% nodes_from & to %in% nodes_to) - setNames(map$to, map$from) + map <- subset(ledges, from %in% nodes_from & to %in% nodes_to, + select = c(from, to)) + if(return=="vector") return(setNames(map$to, map$from)) + map }, @@ -1032,7 +1031,7 @@ Provenance <- R6::R6Class("Provenance", #' Given a particular DAG, return a topological order #' Remember that there could be more than one order. - graph_reverse_topological_order = function(graph) { + graph_topological_order = function(graph, reverse = TRUE) { fnodes <- graph$factors$nodes lnodes <- graph$levels$nodes fedges <- graph$factors$edges @@ -1040,12 +1039,52 @@ Provenance <- R6::R6Class("Provenance", fnodes$parent <- map_int(fnodes$id, function(id) sum(fedges$to %in% id)) fnodes$child <- map_int(fnodes$id, function(id) sum(fedges$from %in% id)) fnodes$nlevels <- map_int(fnodes$id, function(id) nrow(lnodes[[id]])) - fnodes <- fnodes[order(fnodes$child, -fnodes$nlevels), ] + if(reverse) { + fnodes <- fnodes[order(fnodes$child, -fnodes$nlevels), ] + } else { + fnodes <- fnodes[order(-fnodes$child, -fnodes$nlevels), ] + } new_edibble_graph(fnodes = fnodes, lnodes = lnodes, fedges = fedges, ledges = ledges) }, + build_condtable = function(subgraph, return) { + top_graph <- private$graph_topological_order(subgraph, reverse = FALSE) + sub_fnodes <- top_graph$factors$nodes + sub_fedges <- top_graph$factors$edges + sub_lnodes <- top_graph$levels$nodes + sub_ledges <- top_graph$levels$edges + + out <- list() + if(nrow(sub_fnodes) == 1) { + iunit <- sub_fnodes$id[1] + out[[as.character(iunit)]] <- self$lvl_id(fid = iunit) + } else { + for(irow in 2:nrow(sub_fnodes)) { + iunit <- sub_fnodes$id[irow] + if(sub_fnodes$parent[irow] == 0) { + abort("This factor has no parents. This shouldn't happen.") + } else { + parent_id <- sub_fedges$from[sub_fedges$to == iunit] + map_tbl <- self$lvl_mapping(parent_id, iunit, return = "table") + colnames(map_tbl) <- c(parent_id, iunit) + if(as.character(parent_id) %in% names(out)) { + trts_tbl <- do.call(tibble::tibble, out) + map_tbl <- merge(trts_tbl, map_tbl) + } + for(anm in names(map_tbl)) { + out[[anm]] <- map_tbl[[anm]] + } + } + } + } + ret <- switch(return, + id = out, + value = self$fct_levels_id_to_value(out)) + do.call(tibble::tibble, ret) + }, + build_subtable = function(subgraph, return) { - top_graph <- private$graph_reverse_topological_order(subgraph) + top_graph <- private$graph_topological_order(subgraph, reverse = TRUE) sub_fnodes <- top_graph$factors$nodes sub_fedges <- top_graph$factors$edges sub_lnodes <- top_graph$levels$nodes diff --git a/tests/testthat/test-nest.R b/tests/testthat/test-nest.R index d1759005..b5848934 100644 --- a/tests/testthat/test-nest.R +++ b/tests/testthat/test-nest.R @@ -82,5 +82,58 @@ test_that("conditioning-structure", { serve_table() count_by(cond2, trt1, trt2, trt3) + + # FIXME + vitexp2 <- design("Vitamin experiment with control") %>% + set_trts(vitamin = c("control", "B", "C"), + dose = conditioned_on(vitamin, + "control" ~ "0", + c("B", "C") ~ c("0.5", "1", "2"))) %>% + trts_table() + + vitexp3 <- design("Vitamin experiment with control") %>% + set_trts(vitamin = c("control", "B", "C"), + dose = conditioned_on(vitamin, + "control" ~ "0", + c("B", "C") ~ c("0.5", "1", "2")), + trt = conditioned_on(vitamin, + c("control", "B") ~ c("a", "b"), + "C" ~ "c"), + vac = conditioned_on(trt, + "a" ~ c("I", "II"), + c("b", "c") ~ "I")) %>% + trts_table() + + vitexp4 <- design("Vitamin experiment with control") %>% + set_trts(vitamin = c("control", "B", "C"), + dose = conditioned_on(vitamin, + "control" ~ "0", + c("B", "C") ~ c("0.5", "1", "2")), + trt = conditioned_on(vitamin, + c("control", "B") ~ c("a", "b"), + "C" ~ "c"), + vac = conditioned_on(trt, + "a" ~ c("I", "II"), + c("b", "c") ~ "I"), + test = c("alpha", "beta")) %>% + trts_table() + + + vitexp5 <- design("Vitamin experiment with control") %>% + set_trts(vitamin = c("control", "B", "C"), + dose = conditioned_on(vitamin, + "control" ~ "0", + c("B", "C") ~ c("0.5", "1", "2")), + trt = conditioned_on(vitamin, + c("control", "B") ~ c("a", "b"), + "C" ~ "c"), + vac = conditioned_on(trt, + "a" ~ c("I", "II"), + c("b", "c") ~ "I"), + test = c("alpha", "beta"), + test2 = conditioned_on(test, + "alpha" ~ "AA", + "beta" ~ "BB")) %>% + trts_table() })