Skip to content

Commit

Permalink
superlearner impl. and new ml_model wrapper: predictor*
Browse files Browse the repository at this point in the history
  • Loading branch information
kkholst committed Nov 23, 2024
1 parent 581e7d5 commit d59bb2b
Show file tree
Hide file tree
Showing 19 changed files with 954 additions and 252 deletions.
9 changes: 6 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Package: targeted
Type: Package
Title: Targeted Inference
Version: 0.6
Date: 2024-04-15
Date: 2024-10-30
Authors@R: c(person(given = "Klaus K.",
family = "Holst",
role = c("aut", "cre"),
Expand All @@ -26,19 +26,22 @@ Depends:
lava (>= 1.8.0)
Imports:
data.table,
digest,
futile.logger,
future.apply,
glmnet,
optimx,
progressr,
methods,
mets,
R6,
Rcpp (>= 1.0.0),
rlang,
survival
Suggests:
e1071,
grf,
hal9001,
mgcv,
nnls,
testthat (>= 0.11),
rmarkdown,
scatterplot3d,
Expand Down
26 changes: 22 additions & 4 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,21 @@ S3method(print,cross_validated)
S3method(print,summary.ate.targeted)
S3method(print,summary.riskreg.targeted)
S3method(print,summary.targeted)
S3method(print,superlearner)
S3method(print,targeted)
S3method(score,cross_validated)
S3method(score,predictor_sl)
S3method(specials,design)
S3method(summary,ate.targeted)
S3method(summary,cross_validated)
S3method(summary,design)
S3method(summary,predictor_sl)
S3method(summary,riskreg.targeted)
S3method(summary,targeted)
S3method(update,design)
S3method(vcov,targeted)
S3method(weights,design)
S3method(weights,predictor_sl)
export(ML)
export(NB)
export(NB2)
Expand All @@ -52,6 +57,20 @@ export(ml_model)
export(nondom)
export(offsets)
export(pava)
export(predictor)
export(predictor_gam)
export(predictor_glm)
export(predictor_glmnet)
export(predictor_grf)
export(predictor_grf_binary)
export(predictor_hal)
export(predictor_isoreg)
export(predictor_sl)
export(predictor_xgboost)
export(predictor_xgboost_binary)
export(predictor_xgboost_count)
export(predictor_xgboost_cox)
export(predictor_xgboost_multiclass)
export(riskreg)
export(riskreg_cens)
export(riskreg_fit)
Expand All @@ -66,10 +85,6 @@ import(methods)
importFrom(R6,R6Class)
importFrom(data.table,data.table)
importFrom(data.table,is.data.table)
importFrom(digest,sha1)
importFrom(futile.logger,flog.debug)
importFrom(futile.logger,flog.info)
importFrom(futile.logger,flog.warn)
importFrom(grDevices,nclass.Sturges)
importFrom(graphics,abline)
importFrom(graphics,lines)
Expand All @@ -80,7 +95,10 @@ importFrom(lava,Inverse)
importFrom(lava,estimate)
importFrom(lava,getoutcome)
importFrom(lava,na.pass0)
importFrom(lava,score)
importFrom(optimx,optimx)
importFrom(rlang,call_match)
importFrom(rlang,hash)
importFrom(stats,.getXlevels)
importFrom(stats,approxfun)
importFrom(stats,as.formula)
Expand Down
96 changes: 71 additions & 25 deletions R/cate.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ ate_if_fold <- function(fold, data,
} else {
dtrain <- data[-fold, ]
deval <- data[fold, ]
}
}

pmod <- propensity_model$estimate(dtrain)
X <- deval
Expand All @@ -18,7 +18,6 @@ ate_if_fold <- function(fold, data,
tmp <- response_model$estimate(dtrain)
X[, treatment] <- level
}

A <- propensity_model$response(deval)
Y <- response_model$response(deval, na.action=lava::na.pass0)
pr <- propensity_model$predict(newdata = deval)
Expand Down Expand Up @@ -48,8 +47,15 @@ cate_fold1 <- function(fold, data, score, cate_des) {
lm.fit(y = y, x = x)$coef
}

##' Conditional Average Treatment Effect estimation via Double Machine Learning
##' Conditional Average Treatment Effect estimation with cross-fitting.
##'
##' We have observed data \eqn{(Y,A,W)} where \eqn{Y} is the response variable,
##' \eqn{A} the binary treatment, and \eqn{W} covariates. We further let \eqn{V}
##' be a subset of the covariates. Define the conditional potential mean outcome
##' \deqn{\psi_{a}(P)(V) = E_{P}[E_{P}(Y\mid A=a, W)|V]} and let \eqn{m(V;
##' \beta)} denote a parametric working model, then the target parameter is the
##' mean-squared error \deqn{\beta(P) = \operatorname{argmin}_{\beta}
##' E_{P}[\{\Psi_{1}(P)(V)-\Psi_{0}(P)(V)\} - m(V; \beta)]^{2}}
##' @title Conditional Average Treatment Effect estimation
##' @param response_model formula or ml_model object (formula => glm)
##' @param propensity_model formula or ml_model object (formula => glm)
Expand All @@ -66,25 +72,42 @@ cate_fold1 <- function(fold, data, score, cate_des) {
##' @param ... additional arguments to future.apply::future_mapply
##' @return cate.targeted object
##' @author Klaus Kähler Holst, Andreas Nordland
##' @references Mark J. van der Laan (2006) Statistical Inference for Variable
##' Importance, The International Journal of Biostatistics.
##' @examples
##' sim1 <- function(n=1e4,
##' seed=NULL,
##' return_model=FALSE, ...) {
##' suppressPackageStartupMessages(require("lava"))
##' if (!is.null(seed)) set.seed(seed)
##' m <- lava::lvm()
##' lava::regression(m, ~a) <- function(z1,z2,z3,z4,z5)
##' cos(z1)+sin(z1*z2)+z3+z4+z5^2
##' lava::regression(m, ~u) <- function(a,z1,z2,z3,z4,z5)
##' (z1+z2+z3)*a + z1+z2+z3 + a
##' lava::distribution(m, ~a) <- lava::binomial.lvm()
##' if (return_model) return(m)
##' lava::sim(m, n, p=par)
##' sim1 <- function(n=1000, ...) {
##' w1 <- rnorm(n)
##' w2 <- rnorm(n)
##' a <- rbinom(n, 1, expit(-1 + w1))
##' y <- cos(w1) + w2*a + 0.2*w2^2 + a + rnorm(n)
##' data.frame(y, a, w1, w2)
##' }
##'
##' d <- sim1(5000)
##' ## ATE
##' cate(cate_model=~1,
##' response_model=y~a*(w1+w2),
##' propensity_model=a~w1+w2,
##' data=d)
##' ## CATE
##' cate(cate_model=~1+w2,
##' response_model=y~a*(w1+w2),
##' propensity_model=a~w1+w2,
##' data=d)
##'
##' \dontrun{ ## superlearner example
##' mod1 <- list(
##' glm=predictor_glm(y~w1+w2),
##' gam=predictor_gam(y~s(w1) + s(w2))
##' )
##' s1 <- predictor_sl(mod1, nfolds=5)
##' cate(cate_model=~1,
##' response_model=s1,
##' propensity_model=predictor_glm(a~w1+w2, family=binomial),
##' data=d,
##' stratify=TRUE)
##' }
##'
##' d <- sim1(200)
##' e <- cate(a ~ z1+z2+z3, response=u~., data=d)
##' e
##' @export
cate <- function(response_model,
propensity_model,
Expand Down Expand Up @@ -125,14 +148,14 @@ cate <- function(response_model,
if (inherits(response_model, "formula")) {
response_model <- ML(response_model)
}
response_var <- lava::getoutcome(response_model$formula, data=data)

if (length(contrast) > 2) {
stop("Expected contrast vector of length 1 or 2.")
}
propensity_outcome <- function(treatment_level)
paste0("I(", treatment_var, "==", treatment_level, ")")
if (missing(propensity_model)) {
response_var <- lava::getoutcome(response_model$formula, data=data)
newf <- reformulate(
paste0(" . - ", response_var),
response=propensity_outcome(contrast[1])
Expand Down Expand Up @@ -213,13 +236,21 @@ cate <- function(response_model,
)
}

scores <- adj <- list()
qval <- pval <- scores <- adj <- list()
for (i in contrast) {
ii <- which(fargs[, 2] == i)
scores <- c(
scores,
list(unlist(lapply(ii, function(x) val[[x]]$IC))[idx])
)
qval <- c(
qval,
list(unlist(lapply(ii, function(x) val[[x]]$qmod))[idx])
)
pval <- c(
pval,
list(unlist(lapply(ii, function(x) val[[x]]$pmod))[idx])
)
if (!is.null(val[[1]]$adj)) {
A <- lapply(ii, function(x) {
val[[x]]$adj
Expand All @@ -228,8 +259,10 @@ cate <- function(response_model,
}
}
names(scores) <- contrast
names(qval) <- contrast
names(pval) <- contrast
if (length(adj) > 0) names(adj) <- contrast
list(scores = scores, adj = adj)
list(scores = scores, adj = adj, qval = qval, pval = pval)
}

mc <- !missing(mc.cores)
Expand All @@ -241,27 +274,37 @@ cate <- function(response_model,
}
if (mc) {
val <- parallel::mclapply(1:rep, f,
mc.cores = mc.cores, pb=pb, ...)
mc.cores = mc.cores, pb = pb, ...
)
} else {
val <- future.apply::future_lapply(
1:rep, f, pb=pb, future.seed = TRUE, ...)
1:rep, f,
pb = pb, future.seed = TRUE, ...
)
}
} else {
val <- list(calculate_scores())
val <- list(calculate_scores())
}

pval <- val[[1]]$pval
qval <- val[[1]]$qval
scores <- val[[1]]$scores
adj <- val[[1]]$adj
if (rep > 1) {
for (i in 2:rep) {
for (j in seq_len(length(scores))) {
scores[[j]] <- scores[[j]] + val[[i]]$scores[[j]]
qval[[j]] <- qval[[j]] + val[[i]]$qval[[j]]
pval[[j]] <- pval[[j]] + val[[i]]$pval[[j]]
if (length(adj) > 0) {
adj[[j]] <- adj[[j]] + val[[j]]$adj[[j]]
}
}
}
for (j in seq_len(length(scores))) {
scores[[j]] <- scores[[j]] / rep
qval[[j]] <- qval[[j]] / rep
pval[[j]] <- pval[[j]] / rep
if (length(adj) > 0) {
adj[[j]] <- adj[[j]] / rep
}
Expand Down Expand Up @@ -321,6 +364,9 @@ cate <- function(response_model,

res <- list(scores=scores, cate_des=desA,
coef=est,
response_model = response_model,
propensity_model = propensity_model,
pval = pval, qval = qval,
potential.outcomes=potential.outcomes,
call=cl,
estimate=estimate)
Expand Down
Loading

0 comments on commit d59bb2b

Please sign in to comment.