Skip to content

Commit

Permalink
fix: make dynamic data.table evaluation more robust throughout (#151)
Browse files Browse the repository at this point in the history
  • Loading branch information
niekdt committed Apr 3, 2024
1 parent c8f2358 commit 6b2cd97
Show file tree
Hide file tree
Showing 12 changed files with 56 additions and 20 deletions.
15 changes: 10 additions & 5 deletions R/assert.R
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ are_trajectories_length = function(data, min = 1, id, time) {
)

data = as.data.table(data)
all(data[, uniqueN(get(time)), by = c(id)]$V1 >= min)
all(data[, uniqueN(get(..time)), by = c(id)]$V1 >= min)
}

assertthat::on_failure(are_trajectories_length) = function(call, env) {
Expand All @@ -370,7 +370,7 @@ assertthat::on_failure(are_trajectories_length) = function(call, env) {
time = eval(call$time, env)
min = eval(call$min, env)

dtTraj = data[, .(Moments = uniqueN(get(time))), by = c(id)] %>%
dtTraj = data[, .(Moments = uniqueN(get(..time))), by = c(id)] %>%
.[Moments < min]

sprintf(
Expand All @@ -391,7 +391,12 @@ are_trajectories_equal_length = function(data, id, time) {
data = as.data.table(data)
nTimes = uniqueN(data[[time]])

all(data[, .N == nTimes && uniqueN(get(time)) == nTimes, by = c(id)]$V1)
all(
data[,
.N == nTimes && uniqueN(get(..time)) == nTimes,
by = c(id)
]$V1
)
}

assertthat::on_failure(are_trajectories_equal_length) = function(call, env) {
Expand All @@ -401,7 +406,7 @@ assertthat::on_failure(are_trajectories_equal_length) = function(call, env) {

nTimes = uniqueN(data[[time]])
# check for trajectories with multiple observations at the same moment in time
dtMult = data[, .(HasMult = anyDuplicated(get(time))), by = c(id)] %>%
dtMult = data[, .(HasMult = anyDuplicated(get(..time))), by = c(id)] %>%
.[HasMult == TRUE]

if (any(dtMult$HasMult)) {
Expand Down Expand Up @@ -437,7 +442,7 @@ assertthat::on_failure(have_trajectories_noNA) = function(call, env) {
id = eval(call$id, env)
response = eval(call$response, env)

dtMissing = data[, .(NaCount = sum(is.na(get(response)))), by = c(id)] %>%
dtMissing = data[, .(NaCount = sum(is.na(get(..response)))), by = c(id)] %>%
.[NaCount > 0]

sprintf(
Expand Down
2 changes: 1 addition & 1 deletion R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -890,7 +890,7 @@ setGeneric('trajectories', function(
keepIds = as.data.table(data)[, .(AllNA = all(is.na(get(..response)))), by = c(id)] %>%
.[AllNA == FALSE, get(..id)]

data = as.data.table(data)[get(id) %in% keepIds]
data = as.data.table(data)[get(..id) %in% keepIds]

if (is.factor(data)) {
data[[id]] = droplevels(data[[id]], exclude = NULL)
Expand Down
4 changes: 3 additions & 1 deletion R/latrend.R
Original file line number Diff line number Diff line change
Expand Up @@ -923,7 +923,9 @@ testFold = function(data, fold, id, folds, seed) {
all(ids %in% data[[id]])
)

newdata = data[data[[id]] %in% ids, ]
rowIds = data[[id]]

newdata = data[rowIds %in% ids, ]
if (is.factor(data[[id]])) {
newdata[[id]] = droplevels(newdata[[id]])
}
Expand Down
2 changes: 1 addition & 1 deletion R/methodLcmmGMM.R
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ gmm_prepare = function(method, data, envir, verbose, ...) {
# Check & process data
id = idVariable(method)
trainData = as.data.table(data) %>%
.[, c(id) := factor(get(id)) %>% as.integer()]
.[, c(id) := factor(get(..id)) %>% as.integer()]

# Create argument list
args = as.list(method, args = lcmm::hlme)
Expand Down
2 changes: 1 addition & 1 deletion R/methodStratify.R
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ setMethod('fit', 'lcMethodStratify', function(method, data, envir, verbose, ...)
# determine output order
ids = make.ids(data[[id]])

out[match(ids, get(id)), Cluster]
out[match(ids, .id), Cluster, env = list(.id = id)]
}


Expand Down
3 changes: 2 additions & 1 deletion R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -1696,8 +1696,9 @@ setMethod('trajectories', 'lcModel', function(object, ...) {
id = idVariable(object)
time = timeVariable(object)
res = responseVariable(object)
columns = c(id, time, res) # needed because of strange dynamic evaluation by data.table

trajdata = subset(data, select = c(id, time, res))
trajdata = subset(data, select = columns)

trajectories(trajdata, id = id, time = time, response = res, ...)
})
Expand Down
13 changes: 6 additions & 7 deletions R/modelApprox.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@ setClass('lcApproxModel', contains = 'lcModel')
#' @rdname lcApproxModel-class
#' @inheritParams fitted.lcModel
fitted.lcApproxModel = function(object, ..., clusters = trajectoryAssignments(object)) {
newdata = subset(
model.data(object),
select = c(idVariable(object), timeVariable(object), responseVariable(object))
)
columns = c(idVariable(object), timeVariable(object), responseVariable(object))
newdata = subset(model.data(object), select = columns)
pred = predict(object, newdata = newdata, useCluster = FALSE)
transformFitted(pred, model = object, clusters = clusters)
}
Expand All @@ -40,7 +38,7 @@ setMethod('predictForCluster', 'lcApproxModel',

# check if we need to do any interpolation
if (all(newtimes %in% clusTimes)) {
pred = clusTrajs[match(newtimes, get(time)), get(resp)]
pred = clusTrajs[match(newtimes, .time), .resp, env = list(.time = time, .resp = resp)]
return(pred)
}

Expand All @@ -55,8 +53,9 @@ setMethod('predictForCluster', 'lcApproxModel',
}

dtpred = clusTrajs[,
lapply(.SD, function(y) approxFun(x = get(time), y = y, xout = newtimes)$y),
keyby = Cluster, .SDcols = -c(time)
lapply(.SD, function(y) approxFun(x = .time, y = y, xout = newtimes)$y),
keyby = Cluster, .SDcols = -c(time),
env = list(.time = time)
]

dtpred[[resp]]
Expand Down
5 changes: 3 additions & 2 deletions R/modelPartition.R
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,9 @@ setMethod('clusterTrajectories', 'lcModelPartition',

# compute cluster trajectories at all moments in time
clusTrajs = data[,
.(Value = center(get(response))),
keyby = .(Cluster = rowClusters, Time = get(time))
.(Value = center(get(..response))),
keyby = .(Cluster = rowClusters, Time = .time),
env = list(.time = time)
]

if (uniqueN(trajClusters) < nClusters(object)) {
Expand Down
2 changes: 1 addition & 1 deletion R/trajectories.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ setMethod('plotClusterTrajectories', 'data.frame', function(
)

clusTrajData = as.data.table(object) %>%
.[, .(Value = center(get(response))), keyby = c(cluster, time)] %>%
.[, .(Value = center(get(..response))), keyby = c(cluster, time)] %>%
setnames('Value', response)

clusterNames = as.character(unique(clusTrajData[[cluster]]))
Expand Down
11 changes: 11 additions & 0 deletions tests/testthat/setup-data.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,14 @@ testLongData = generateLongData(
.[, Cluster := Class] %>%
.[, Traj := factor(Traj)] %>%
.[]

# set up capture functions to test for wrong column name handling
response = function(...) stop('response column mishandling')
cluster = function(...) stop('cluster column mishandling')
id = function(...) stop('id column mishandling')

Time = function(...) stop('Time name evaluation')
Value = function(...) stop('Value name evaluation')
Assessment = function(...) stop('Assessment name evaluation')
Cluster = function(...) stop('Cluster name evaluation')
Traj = function(...) stop('Traj name evaluation')
10 changes: 10 additions & 0 deletions tests/testthat/test-cluslong.R
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,13 @@ test_that('trajectory length warning', {
expect_warning(latrend(mTest, data = testLongData), regexp = 'warnTrajectoryLength')
options(latrend.warnTrajectoryLength = 0)
})


test_that('"time" column', {
timeData = copy(testLongData)
setnames(timeData, 'Assessment', 'time')
method = lcMethodTestKML(time = 'time')
model = latrend(method, data = timeData)
expect_is(model, 'lcModel')
})

7 changes: 7 additions & 0 deletions tests/testthat/test-method.R
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,13 @@ test_that('unevaluated values', {
expect_error(m[['missing', eval = FALSE]])
})

test_that('time function confusion', {
m = new('lcMethodTest', time = 'time')
expect_is(m[['time', eval = FALSE]], 'character')
expect_is(m[['time', eval = TRUE]], 'character')
expect_equal(m$time, 'time')
})

test_that('dependency function evaluation', {
method = lcMethodTestLMKM(fun = mean)
expect_is(method$fun, 'function')
Expand Down

0 comments on commit 6b2cd97

Please sign in to comment.