Skip to content

Commit

Permalink
Better handling of formula specials (design).
Browse files Browse the repository at this point in the history
- New methods `weights.design`, `specials.design`, `offsets.design`
- cate also returns expected potential outcomes and influence functions
- `ml_model` now updates the environment of the prediction method (if the method was referring to 'self', it was still using the old version).
- Documentation update
- The scoring method only switches to log-score + Brier score when the response is a factor. The model-scoring function (cv argument `modelscore`) automatically gets 'weights' appended to its formal arguments.
  • Loading branch information
kkholst committed Apr 15, 2024
1 parent 5a20dc7 commit 6b162af
Showing 16 changed files with 254 additions and 147 deletions.
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Package: targeted
Type: Package
Title: Targeted Inference
Version: 0.5
Date: 2024-02-22
Version: 0.6
Date: 2024-04-15
Authors@R: c(person(given = "Klaus K.",
family = "Holst",
role = c("aut", "cre"),
@@ -23,7 +23,7 @@ Description: Various methods for targeted and semiparametric inference including
linear model parameters (Vansteelandt et al. (2022) <doi:10.1111/rssb.12504>).
Depends:
R (>= 4.0),
lava (>= 1.7.0)
lava (>= 1.8.0)
Imports:
data.table,
digest,
5 changes: 5 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@ S3method(coef,targeted)
S3method(estimate,ml_model)
S3method(logLik,targeted)
S3method(model.matrix,design)
S3method(offsets,design)
S3method(plot,calibration)
S3method(predict,NB)
S3method(predict,NB2)
@@ -20,13 +21,15 @@ S3method(print,summary.ate.targeted)
S3method(print,summary.riskreg.targeted)
S3method(print,summary.targeted)
S3method(print,targeted)
S3method(specials,design)
S3method(summary,ate.targeted)
S3method(summary,cross_validated)
S3method(summary,design)
S3method(summary,riskreg.targeted)
S3method(summary,targeted)
S3method(update,design)
S3method(vcov,targeted)
S3method(weights,design)
export(ML)
export(NB)
export(NB2)
@@ -47,6 +50,7 @@ export(expand.list)
export(isoregw)
export(ml_model)
export(nondom)
export(offsets)
export(pava)
export(riskreg)
export(riskreg_cens)
@@ -55,6 +59,7 @@ export(riskreg_mle)
export(scoring)
export(softmax)
export(solve_ode)
export(specials)
export(specify_ode)
import(Rcpp)
import(methods)
2 changes: 1 addition & 1 deletion R/RATE.R
Original file line number Diff line number Diff line change
@@ -190,7 +190,7 @@ RATE.surv <- function(response, post.treatment, treatment, censoring,
attr(surv.response, "type") == "right", # only allows right censoring
attr(surv.censoring, "type") == "right", # only allows right censoring
all(surv.response[,1] == surv.censoring[ ,1]), # time must be equal
all(order(surv.response[,1]) == (1:nrow(data))) # data must be ordered by time and have no missing values
all(order(surv.response[,1]) == (seq_len(nrow(data)))) # data must be ordered by time and have no missing values
)
rm(surv.response, surv.censoring)

67 changes: 34 additions & 33 deletions R/cate.R
Original file line number Diff line number Diff line change
@@ -31,8 +31,8 @@ ate_if_fold <- function(fold, data,

cate_fold1 <- function(fold, data, score, treatment_des) {
y <- score[fold]
x <- update(treatment_des, data[fold, , drop=FALSE])$x
lm.fit(y=y, x=x)$coef
x <- update(treatment_des, data[fold, , drop = FALSE])$x
lm.fit(y = y, x = x)$coef
}

##' Conditional Average Treatment Effect estimation via Double Machine Learning
@@ -225,40 +225,42 @@ cate <- function(treatment,
}

score_fold <- function(fold,
data,
propensity_model,
response_model,
importance_model,
treatment, level) {
dtrain <- data[-fold,]
deval <- data[fold,]
data,
propensity_model,
response_model,
importance_model,
treatment, level) {
dtrain <- data[-fold, ]
deval <- data[fold, ]

# training
tmp <- propensity_model$estimate(dtrain)
tmp <- response_model$estimate(dtrain)
A <- propensity_model$response(dtrain)
Y <- response_model$response(dtrain)
X <- dtrain
X[, treatment] <- level
pr <- propensity_model$predict(newdata=dtrain)
if (NCOL(pr)>1)
pr <- pr[,2]
eY <- response_model$predict(newdata=X)
D <- A/pr*(Y-eY) + eY
pr <- propensity_model$predict(newdata = dtrain)
if (NCOL(pr) > 1) {
pr <- pr[, 2]
}
eY <- response_model$predict(newdata = X)
D <- A / pr * (Y - eY) + eY
X[["D_"]] <- D
tmp <- importance_model$estimate(data = X)

# evaluation
# evaluation
A <- propensity_model$response(deval)
Y <- response_model$response(deval)
X <- deval
X[, treatment] <- level
pr <- propensity_model$predict(newdata=deval)
if (NCOL(pr)>1)
pr <- pr[,2]
eY <- response_model$predict(newdata=X)
D <- A/pr*(Y-eY) + eY
II <- importance_model$predict(newdata=X)
pr <- propensity_model$predict(newdata = deval)
if (NCOL(pr) > 1) {
pr <- pr[, 2]
}
eY <- response_model$predict(newdata = X)
D <- A / pr * (Y - eY) + eY
II <- importance_model$predict(newdata = X)

return(list(II = II, D = D))
}
@@ -315,7 +317,7 @@ crr <- function(treatment,
data,
nfolds=5,
type="dml1",
...){
...) {
cl <- match.call()
if (is.character(treatment)) {
treatment <- as.formula(paste0(treatment, "~", 1))
@@ -326,7 +328,7 @@ crr <- function(treatment,
}
if (length(contrast)!=2)
stop("Expected contrast vector of length 2.")

response_var <- lava::getoutcome(response_model$formula, data=data)
treatment_var <- lava::getoutcome(treatment)
treatment_f <- function(treatment_level, x=paste0(".-", response_var))
@@ -338,18 +340,18 @@ crr <- function(treatment,
importance_formula <- update(treatment, D_~.)
importance_model <- SL(importance_formula, ...)
}

n <- nrow(data)
folds <- split(sample(1:n, n), rep(1:nfolds, length.out = n))
folds <- lapply(folds, sort)

ff <- Reduce(c, folds)
idx <- order(ff)

# D_a = I(A=a)/P(A=a|W)[Y - E[Y|A=a, W]] + E[Y|A=a, W], a = {1,0}
D <- list()
# II = E[E[Y|A=a, W]|V] = E[D_a|V], a = {1,0}
II <- list()
II <- list()
pb <- progressr::progressor(steps = length(contrast)*nfolds)
for (i in seq_along(contrast)) {
a <- contrast[i]
@@ -373,11 +375,11 @@ crr <- function(treatment,
}
names(D) <- contrast
names(II) <- contrast

score <- D[[1]]*II[[2]] - D[[2]]*II[[1]]
score <- score + II[[1]] * II[[2]]
score <- score * II[[2]]^(-2)

if (type=="dml1") {
est1 <- lapply(folds, function(x) cate_fold1(x,
data = data,
@@ -388,14 +390,13 @@ crr <- function(treatment,
est <- coef(lm(score ~ -1+desA$x))
}
names(est) <- names(desA$x)

M1 <- desA$x
C <- -n^(-1) * crossprod(M1)
IF <- -solve(C) %*% t(M1 * as.vector(score - M1 %*% est))
IF <- t(IF)

estimate <- estimate(coef=est, IC=IF)

res <- list(folds=folds,
score=score,
treatment_des=desA,
51 changes: 40 additions & 11 deletions R/design.R
Original file line number Diff line number Diff line change
@@ -10,15 +10,18 @@
##' @author Klaus Kähler Holst
##' @export
design <- function(formula, data, intercept=FALSE,
rm_envir=FALSE, ...) {
tt <- terms(formula, data=data)
rm_envir=FALSE, ..., specials = c("weights", "offset")) {
tt <- terms(formula, data = data)
if (!intercept)
attr(tt, "intercept") <- 0
mf <- model.frame(tt, data=data, ...)
x_levels <- .getXlevels(tt, mf)
x <- model.matrix(mf, data=data)
y <- model.response(mf, type="any")
specials <- names(substitute(list(...)))[-1]
specials <- union(
specials,
names(substitute(list(...)))[-1]
)
specials_list <- c()
if (length(specials)>0) {
for (s in specials) {
@@ -38,19 +41,23 @@ design <- function(formula, data, intercept=FALSE,
}

##' @export
update.design <- function(object, data=NULL, ...) {
update.design <- function(object, data = NULL, ...,
specials = c("weights", "offset")) {
if (is.null(data)) data <- object$data
mf <- with(object, model.frame(terms, data=data, ...,
xlev = xlevels,
drop.unused.levels=FALSE))
mf <- model.frame(object$terms, data=data, ...,
xlev = object$xlevels,
drop.unused.levels=FALSE)
x <- model.matrix(mf, data=data, ..., xlev = object$xlevels)
object[["y"]] <- NULL
for (s in object$specials) {
object[[s]] <- NULL
}
specials <- names(substitute(list(...)))[-1]
specials2 <- names(substitute(list(...)))[-1]
for (s in specials2) {
object[[s]] <- eval(substitute(model.extract(mf, s), list(s = s)))
}
for (s in specials) {
object[[s]] <- eval(substitute(model.extract(mf, s), list(s=s)))
object[[s]] <- do.call(model.extract, list(mf, s))
}
object$specials <- specials
object$x <- x
@@ -61,12 +68,34 @@ update.design <- function(object, data=NULL, ...) {
model.matrix.design <- function(object, drop.intercept = FALSE, ...) {
if (drop.intercept) {
intercept <- which(attr(object$x, "assign") == 0)
if (length(intercept)>0)
return(object$x[, -intercept, drop=FALSE])
if (length(intercept) > 0) {
return(object$x[, -intercept, drop = FALSE])
}
}
return(object$x)
}

##' Extract the weights stored in a design object
##'
##' Delegates to the `specials` generic, asking for the element
##' named "weights".
##' @param object design object
##' @param ... currently unused
##' @return whatever `specials(object, "weights")` returns
##' @export
weights.design <- function(object, ...) {
  special_name <- "weights"
  specials(object, special_name)
}

##' Extract offsets from an object
##'
##' S3 generic; the work is done by the method for the class of `object`.
##' @param object object to extract offsets from
##' @param ... additional arguments passed on to the method
##' @export
offsets <- function(object, ...) {
  UseMethod("offsets")
}

##' Extract the offset stored in a design object
##'
##' Delegates to the `specials` generic, asking for the element
##' named "offset".
##' @param object design object
##' @param ... currently unused
##' @return whatever `specials(object, "offset")` returns
##' @export
offsets.design <- function(object, ...) {
  special_name <- "offset"
  specials(object, special_name)
}

##' Extract a special term from an object
##'
##' S3 generic; the work is done by the method for the class of `object`.
##' @param object object to extract the special term from
##' @param ... additional arguments passed on to the method
##' @export
specials <- function(object, ...) {
  UseMethod("specials")
}

##' Extract a named special term from a design object
##'
##' @param object design object (indexed with `[[`)
##' @param which name of the element to extract, e.g. "weights" or "offset"
##' @param ... currently unused
##' @return the element `object[[which]]`
##' @export
specials.design <- function(object, which, ...) {
  object[[which]]
}

##' @export
summary.design <- function(object, ...) {
object$x <- object$x[0, ]
36 changes: 24 additions & 12 deletions R/ml_model.R
Original file line number Diff line number Diff line change
@@ -58,7 +58,8 @@ ml_model <- R6::R6Class("ml_model",
estimate,
predict=stats::predict,
predict.args=NULL,
info=NULL, specials,
info=NULL,
specials = c(),
response.arg="y",
x.arg="x",
...) {
@@ -72,7 +73,8 @@ ml_model <- R6::R6Class("ml_model",
if (!("..." %in% formalArgs(estimate))) {
formals(estimate) <- c(formals(estimate), alist(... = ))
}
des.args <- lapply(substitute(specials), function(x) x)[-1]
## des.args <- lapply(substitute(specials), function(x) x)[-1]
des.args <- list(specials = specials)
fit_formula <- "formula"%in%formalArgs(estimate)
fit_response_arg <- response.arg %in% formalArgs(estimate)
fit_x_arg <- x.arg%in%formalArgs(estimate)
@@ -82,7 +84,6 @@ ml_model <- R6::R6Class("ml_model",
## if (!fit_x_arg && !("data"%in%formalArgs(estimate)))
## stop("Estimation method must have an argument 'x' or 'data'")


self$args <- dots
no_formula <- is.null(formula)
if (no_formula) {
@@ -128,10 +129,16 @@ ml_model <- R6::R6Class("ml_model",
}
private$predfun <- function(object, data, ...) {
if (fit_formula || no_formula) {
args <- c(list(object, newdata=data), predict.args, list(...))
args <- c(list(object, newdata = data), predict.args, list(...))
} else {
x <- model.matrix(update(attr(object, "design"), data))
args <- c(list(object, newdata=x), predict.args, list(...))
args <- list(...)
des <- update(attr(object, "design"), data)
for (s in des$specials) {
if (is.null(args[[s]])) args[[s]] <- des[[s]]
}
args <- c(list(object,
newdata = model.matrix(des)
), predict.args, args)
}
return(do.call(private$init.predict, args))
}
@@ -154,7 +161,7 @@ ml_model <- R6::R6Class("ml_model",
estimate = function(data, ..., store=TRUE) {
res <- private$fitfun(data, ...)
if (store) private$fitted <- res
invisible(res)
return(invisible(res))
},

##' @description
@@ -292,7 +299,7 @@ predict.ml_model <- function(object, ...) {
##' @param ... additional arguments to model object
##' @details
##' model 'sl' (SuperLearner::SuperLearner)
##' args: SL.library, cvControl, f<aamily, method
##' args: SL.library, cvControl, family, method
##' example:
##'
##' model 'grf' (grf::regression_forest)
@@ -306,7 +313,6 @@ predict.ml_model <- function(object, ...) {
##' model 'glm'
##' args: family, weights, offset, ...
##'
##'
ML <- function(formula, model="glm", ...) {
model <- tolower(model)
dots <- list(...)
@@ -425,10 +431,16 @@ ML <- function(formula, model="glm", ...) {
## glm, default
m <- ml_model$new(formula, info = "glm", ...,
estimate = function(formula, data, ...) {
stats::glm(formula, data=data, ...)
stats::glm(formula,
data = data,
...
)
},
predict = function(object, newdata) {
stats::predict(object, newdata = newdata, type = "response")
predict = function(object, newdata, ...) {
stats::predict(object,
newdata = newdata,
type = "response"
)
}
)
return(m)
Loading

0 comments on commit 6b162af

Please sign in to comment.