Skip to content

Commit

Permalink
Implemented gridsearch initialization for lcmm methods (#126)
Browse files Browse the repository at this point in the history
  • Loading branch information
niekdt committed Mar 14, 2023
1 parent 0131063 commit c3053e6
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 19 deletions.
89 changes: 81 additions & 8 deletions R/methodLcmmGMM.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@ setValidity('lcMethodLcmmGMM', function(object) {
#' @param time The name of the time variable.
#' @param id The name of the trajectory identifier variable. This replaces the `subject` argument of [lcmm::hlme].
#' @param init Alternative for the `B` argument of [lcmm::hlme], for initializing the hlme fitting procedure.
#' If `"lme.random"` (default): random initialization through a standard linear mixed model.
#' Assigns a fitted standard linear mixed model enclosed in a call to random() to the `B` argument.
#' If `"lme"`, fits a standard linear mixed model and passes this to the `B` argument.
#' If `NULL` or `"default"`, the default [lcmm::hlme] input for `B` is used.
#' Options:
#' * `"lme.random"` (default): random initialization through a standard linear mixed model. Assigns a fitted standard linear mixed model enclosed in a call to random() to the `B` argument.
#' * `"lme"`: fits a standard linear mixed model and passes this to the `B` argument.
#' * `"gridsearch"`: a gridsearch is used with initialization from `"lme.random"`, following the approach used by [lcmm::gridsearch]. To use this initialization, specify arguments `gridsearch.maxiter` (max number of iterations during search), `gridsearch.rep` (number of fits during search), and `gridsearch.parallel` (whether to enable [parallel computation][latrend-parallel]).
#' * `NULL` or `"default"`, the default [lcmm::hlme] input for `B` is used.
#'
#' The argument is ignored if the `B` argument is specified, or `nClusters = 1`.
#'
Expand All @@ -51,19 +52,24 @@ setValidity('lcMethodLcmmGMM', function(object) {
#' mixture = ~ Time,
#' random = ~ 1,
#' id = "Id",
#' time = "Time", ,
#' time = "Time",
#' nClusters = 2
#' )
#' gmm <- latrend(method, data = latrendData)
#' summary(gmm)
#'
#' # define method with gridsearch
#' method <- lcMethodLcmmGMM(
#' fixed = Y ~ Time,
#' mixture = ~ Time,
#' random = ~ Time,
#' id = "Id",
#' time = "Time",
#' nClusters = 3
#' nClusters = 3,
#' init = "gridsearch",
#' gridsearch.maxiter = 10,
#' gridsearch.rep = 50,
#' gridsearch.parallel = TRUE
#' )
#' }
#' @family lcMethod implementations
Expand Down Expand Up @@ -168,7 +174,7 @@ gmm_prepare = function(method, data, envir, verbose, ...) {
}

if (hasName(method, 'init') && method$nClusters > 1) {
init = match.arg(method$init, c('default', 'lme', 'lme.random'))
init = match.arg(method$init, c('default', 'lme', 'lme.random', 'gridsearch'))
if (init == 'default') {
init = 'lme'
}
Expand All @@ -191,6 +197,14 @@ gmm_prepare = function(method, data, envir, verbose, ...) {
args1$classmb = NULL
prepEnv$lme = do.call(lcmm::hlme, args1)
args$B = quote(random(lme))
},
gridsearch = {
cat(verbose, 'Fitting standard linear mixed model for gridsearch initialization...')
args1 = args
args1$ng = 1
args1$mixture = NULL
args1$classmb = NULL
prepEnv$lme = do.call(lcmm::hlme, args1)
}
)
}
Expand Down Expand Up @@ -222,9 +236,68 @@ gmm_fit = function(method, data, envir, verbose, ...) {
model
}


# Fit a model through lcmm::hlme using a gridsearch for the starting values.
#
# Procedure:
#   1. Fit `method$gridsearch.rep` candidate models, each limited to
#      `method$gridsearch.maxiter` iterations and initialized with
#      `B = random(minit)`, where `minit` is the hlme object in `envir$lme`.
#   2. Take the estimates (`$best`) of the candidate with the highest finite
#      log-likelihood.
#   3. Fit the final model with the unmodified `envir$args`, using those
#      estimates as `B`.
#
# Arguments:
#   method  Method object; `gridsearch.maxiter` and `gridsearch.rep` must be
#           counts, `gridsearch.parallel` must be a flag.
#   data    Not used by this function.
#   envir   Environment holding `args` (the hlme call arguments), `lme`
#           (an object inheriting from 'hlme'), and `classmb`.
#   verbose Passed on to cat() and .enterTimed() for progress output.
#   ...     Not used by this function.
#
# Returns the final hlme fit, with the `fixed`, `mixture`, `random` and `mb`
# elements set from `envir`.
#
# Errors if the arguments are invalid, or if none of the candidate fits has a
# finite log-likelihood.
gmm_gridsearch = function(method, data, envir, verbose, ...) {
  assert_that(
    is.count(method$gridsearch.maxiter),
    is.count(method$gridsearch.rep),
    is.flag(method$gridsearch.parallel),
    hasName(envir, 'lme'),
    inherits(envir$lme, 'hlme')
  )

  args = envir$args
  gridArgs = args
  gridArgs$maxiter = method$gridsearch.maxiter
  # named nRep to avoid masking base::rep
  nRep = method$gridsearch.rep
  .latrend.lme = envir$lme

  # scalar condition, so use if/else instead of the vectorised ifelse()
  `%infix%` = if (method$gridsearch.parallel) `%dopar%` else `%do%`

  # Conduct gridsearch
  timing = .enterTimed(verbose, sprintf('Gridsearch with %d repetitions...', nRep))
  gridModels = foreach(k = seq_len(nRep)) %infix% {
    cat(
      verbose,
      sprintf('Gridsearch fit %d/%d (%g%%)', k, nRep, round(k / nRep * 100))
    )
    # bind the initial model in the current environment so that substitute()
    # can insert it into the random() call that is passed as B
    e = environment()
    assign('minit', .latrend.lme, envir = e)
    gridArgs$B = substitute(random(minit), env = e)
    gridModel = do.call(lcmm::hlme, gridArgs)
    gridModel
  }
  .exitTimed(timing)

  # determine the best candidate solution among the valid (finite) fits only
  gridLogLiks = vapply(gridModels, function(x) x$loglik, FUN.VALUE = 0)
  isValid = is.finite(gridLogLiks)
  assert_that(
    any(isValid),
    msg = 'Failed to obtain a valid fit during gridsearch. Try more reps or higher maxiter?'
  )
  iBest = which(isValid)[which.max(gridLogLiks[isValid])]
  args$B = gridModels[[iBest]]$best

  # fit the final model
  timing = .enterTimed(verbose, 'Final model optimization...')
  model = do.call(lcmm::hlme, args)
  .exitTimed(timing, msg = 'Done with final model optimization (%s)')

  # attach the model specification to the fitted object
  model$fixed = args$fixed
  model$mixture = args$mixture
  model$random = args$random
  model$mb = envir$classmb
  model
}


#' @rdname interface-lcmm
setMethod('fit', 'lcMethodLcmmGMM', function(method, data, envir, verbose, ...) {
model = gmm_fit(method, data, envir, verbose, ...)
if (method$init == 'gridsearch') {
model = gmm_gridsearch(method, data, envir, verbose, ...)
} else {
model = gmm_fit(method, data, envir, verbose, ...)
}

new(
'lcModelLcmmGMM',
Expand Down
11 changes: 7 additions & 4 deletions man/lcMethodLcmmGBTM.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 14 additions & 6 deletions man/lcMethodLcmmGMM.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 10 additions & 1 deletion tests/testthat/test-lcmm.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
context('LCMM models')
skip_if_not_installed('lcmm')
skip_on_cran()
rngReset()
Expand Down Expand Up @@ -28,6 +27,9 @@ make.gmm = function(id, time, response, ..., init = 'default') {
mc$time = time
mc$maxiter = 10
mc$seed = 1
mc$gridsearch.rep = 2
mc$gridsearch.maxiter = 5
mc$gridsearch.parallel = FALSE

do.call(lcMethodLcmmGMM, as.list(mc)[-1]) %>% evaluate()
}
Expand All @@ -54,6 +56,13 @@ test_that('gmm with init=lme.random', {
expect_true(is.lcModel(model))
})

test_that('gmm with init=gridsearch', {
  skip_on_cran()

  # specify the GMM method with gridsearch initialization
  gridMethod = make.gmm(
    id = 'Traj',
    time = 'Assessment',
    response = 'Value',
    init = 'gridsearch'
  )

  # the fit should yield a valid lcModel
  fitted = latrend(gridMethod, testLongData)
  expect_true(is.lcModel(fitted))
})

test_that('gmm with NA covariate', {
expect_true({
suppressWarnings({
Expand Down

0 comments on commit c3053e6

Please sign in to comment.