Skip to content

Commit

Permalink
fix conditional treatment
Browse files Browse the repository at this point in the history
  • Loading branch information
emitanaka committed Sep 16, 2023
1 parent 4b71fc9 commit a453bb6
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 13 deletions.
31 changes: 31 additions & 0 deletions R/graph-input.R
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,34 @@ graph_input.nest_lvls <- function(input, prov, name, class, ...) {
}
}
}



graph_input.cond_lvls <- function(input, prov, name, class, ...) {
parent <- input %@% "keyname"
cross_parents <- input %@% "parents"
clabels <- input %@% "labels"
attrs <- NULL # attributes(input)
prov$append_fct_nodes(name = name, role = class)
idp <- prov$fct_id(name = parent)
idv <- prov$fct_id(name = name)
prov$append_fct_edges(from = idp, to = idv, type = "nest")
plevels <- rep(names(input), lengths(input))
clevels <- unname(unlist(input))
pids <- prov$lvl_id(value = plevels, fid = idp)
## unique(clevels) is the only part that's different to nest_lvls
prov$append_lvl_nodes(value = unique(clevels), fid = idv)
vids <- prov$lvl_id(value = clevels, fid = idv)
prov$append_lvl_edges(from = pids, to = vids)

if(!is_null(cross_parents)) {
cross_df <- do.call("rbind", cross_parents[names(input)])
cross_parent_names <- colnames(cross_df)
for(across in cross_parent_names) {
prov$append_fct_edges(from = prov$fct_id(name = across), to = idv, type = "cross")
cpids <- prov$lvl_id(value = cross_df[[across]])
prov$append_lvl_edges(from = cpids, to = vids)
}
}
}

2 changes: 1 addition & 1 deletion R/nest.R
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ conditioned_on <- function(x, ...) {
attributes(child_levels) <- c(attributes(child_levels),
list(parents = attr(args, "parents"),
labels = child_levels))
class(child_levels) <- c("nest_lvls", class(child_levels))
class(child_levels) <- c("cond_lvls", class(child_levels))
return(child_levels)


Expand Down
63 changes: 51 additions & 12 deletions R/provenance.R
Original file line number Diff line number Diff line change
Expand Up @@ -607,14 +607,10 @@ Provenance <- R6::R6Class("Provenance",
for(i in seq(ncomp)) {
if(i == 1L) {
sub_graph <- self$graph_subset(id = scomps[[i]], include = "self")
out <- private$build_subtable(sub_graph, return = return)
trts_tbl <- as.data.frame(out)
colnames(trts_tbl) <- names(out)
trts_tbl <- private$build_condtable(sub_graph, return = return)
} else {
sub_graph <- self$graph_subset(id = scomps[[i]], include = "self")
out <- private$build_subtable(sub_graph, return = return)
new_trts_tbl <- as.data.frame(out)
colnames(new_trts_tbl) <- names(out)
new_trts_tbl <- private$build_condtable(sub_graph, return = return)
xtabs <- expand.grid(old = seq(nrow(trts_tbl)), new = seq(nrow(new_trts_tbl)))
trts_tbl <- cbind(trts_tbl[xtabs[[1]], , drop = FALSE], new_trts_tbl[xtabs[[2]], , drop = FALSE])
}
Expand Down Expand Up @@ -746,13 +742,16 @@ Provenance <- R6::R6Class("Provenance",
#' @description
#' Get the level edges by factor
#' @param from,to The factor id.
lvl_mapping = function(from, to) {
lvl_mapping = function(from, to, return = c("vector", "table")) {
return <- match.arg(return)
lnodes <- self$lvl_nodes
nodes_from <- lnodes[[as.character(from)]]$id
nodes_to <- lnodes[[as.character(to)]]$id
ledges <- self$lvl_edges
map <- subset(ledges, from %in% nodes_from & to %in% nodes_to)
setNames(map$to, map$from)
map <- subset(ledges, from %in% nodes_from & to %in% nodes_to,
select = c(from, to))
if(return=="vector") return(setNames(map$to, map$from))
map
},


Expand Down Expand Up @@ -1032,20 +1031,60 @@ Provenance <- R6::R6Class("Provenance",

#' Given a particular DAG, return a topological order
#' Remember that there could be more than one order.
graph_reverse_topological_order = function(graph) {
graph_topological_order = function(graph, reverse = TRUE) {
fnodes <- graph$factors$nodes
lnodes <- graph$levels$nodes
fedges <- graph$factors$edges
ledges <- graph$levels$edges
fnodes$parent <- map_int(fnodes$id, function(id) sum(fedges$to %in% id))
fnodes$child <- map_int(fnodes$id, function(id) sum(fedges$from %in% id))
fnodes$nlevels <- map_int(fnodes$id, function(id) nrow(lnodes[[id]]))
fnodes <- fnodes[order(fnodes$child, -fnodes$nlevels), ]
if(reverse) {
fnodes <- fnodes[order(fnodes$child, -fnodes$nlevels), ]
} else {
fnodes <- fnodes[order(-fnodes$child, -fnodes$nlevels), ]
}
new_edibble_graph(fnodes = fnodes, lnodes = lnodes, fedges = fedges, ledges = ledges)
},

build_condtable = function(subgraph, return) {
top_graph <- private$graph_topological_order(subgraph, reverse = FALSE)
sub_fnodes <- top_graph$factors$nodes
sub_fedges <- top_graph$factors$edges
sub_lnodes <- top_graph$levels$nodes
sub_ledges <- top_graph$levels$edges

out <- list()
if(nrow(sub_fnodes) == 1) {
iunit <- sub_fnodes$id[1]
out[[as.character(iunit)]] <- self$lvl_id(fid = iunit)
} else {
for(irow in 2:nrow(sub_fnodes)) {
iunit <- sub_fnodes$id[irow]
if(sub_fnodes$parent[irow] == 0) {
abort("This factor has no parents. This shouldn't happen.")
} else {
parent_id <- sub_fedges$from[sub_fedges$to == iunit]
map_tbl <- self$lvl_mapping(parent_id, iunit, return = "table")
colnames(map_tbl) <- c(parent_id, iunit)
if(as.character(parent_id) %in% names(out)) {
trts_tbl <- do.call(tibble::tibble, out)
map_tbl <- merge(trts_tbl, map_tbl)
}
for(anm in names(map_tbl)) {
out[[anm]] <- map_tbl[[anm]]
}
}
}
}
ret <- switch(return,
id = out,
value = self$fct_levels_id_to_value(out))
do.call(tibble::tibble, ret)
},

build_subtable = function(subgraph, return) {
top_graph <- private$graph_reverse_topological_order(subgraph)
top_graph <- private$graph_topological_order(subgraph, reverse = TRUE)
sub_fnodes <- top_graph$factors$nodes
sub_fedges <- top_graph$factors$edges
sub_lnodes <- top_graph$levels$nodes
Expand Down
53 changes: 53 additions & 0 deletions tests/testthat/test-nest.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,5 +82,58 @@ test_that("conditioning-structure", {
serve_table()

count_by(cond2, trt1, trt2, trt3)

# FIXME
vitexp2 <- design("Vitamin experiment with control") %>%
set_trts(vitamin = c("control", "B", "C"),
dose = conditioned_on(vitamin,
"control" ~ "0",
c("B", "C") ~ c("0.5", "1", "2"))) %>%
trts_table()

vitexp3 <- design("Vitamin experiment with control") %>%
set_trts(vitamin = c("control", "B", "C"),
dose = conditioned_on(vitamin,
"control" ~ "0",
c("B", "C") ~ c("0.5", "1", "2")),
trt = conditioned_on(vitamin,
c("control", "B") ~ c("a", "b"),
"C" ~ "c"),
vac = conditioned_on(trt,
"a" ~ c("I", "II"),
c("b", "c") ~ "I")) %>%
trts_table()

vitexp4 <- design("Vitamin experiment with control") %>%
set_trts(vitamin = c("control", "B", "C"),
dose = conditioned_on(vitamin,
"control" ~ "0",
c("B", "C") ~ c("0.5", "1", "2")),
trt = conditioned_on(vitamin,
c("control", "B") ~ c("a", "b"),
"C" ~ "c"),
vac = conditioned_on(trt,
"a" ~ c("I", "II"),
c("b", "c") ~ "I"),
test = c("alpha", "beta")) %>%
trts_table()


vitexp5 <- design("Vitamin experiment with control") %>%
set_trts(vitamin = c("control", "B", "C"),
dose = conditioned_on(vitamin,
"control" ~ "0",
c("B", "C") ~ c("0.5", "1", "2")),
trt = conditioned_on(vitamin,
c("control", "B") ~ c("a", "b"),
"C" ~ "c"),
vac = conditioned_on(trt,
"a" ~ c("I", "II"),
c("b", "c") ~ "I"),
test = c("alpha", "beta"),
test2 = conditioned_on(test,
"alpha" ~ "AA",
"beta" ~ "BB")) %>%
trts_table()
})

0 comments on commit a453bb6

Please sign in to comment.