From c3053e6de65158d331b3965478f7ddc688e2bef5 Mon Sep 17 00:00:00 2001 From: Niek Den Teuling Date: Tue, 14 Mar 2023 18:57:15 +0100 Subject: [PATCH] Implemented gridsearch initialization for lcmm methods (#126) --- R/methodLcmmGMM.R | 89 ++++++++++++++++++++++++++++++++++---- man/lcMethodLcmmGBTM.Rd | 11 +++-- man/lcMethodLcmmGMM.Rd | 20 ++++++--- tests/testthat/test-lcmm.R | 11 ++++- 4 files changed, 112 insertions(+), 19 deletions(-) diff --git a/R/methodLcmmGMM.R b/R/methodLcmmGMM.R index 167d5211..5e2ee789 100644 --- a/R/methodLcmmGMM.R +++ b/R/methodLcmmGMM.R @@ -32,10 +32,11 @@ setValidity('lcMethodLcmmGMM', function(object) { #' @param time The name of the time variable. #' @param id The name of the trajectory identifier variable. This replaces the `subject` argument of [lcmm::hlme]. #' @param init Alternative for the `B` argument of [lcmm::hlme], for initializing the hlme fitting procedure. -#' If `"lme.random"` (default): random initialization through a standard linear mixed model. -#' Assigns a fitted standard linear mixed model enclosed in a call to random() to the `B` argument. -#' If `"lme"`, fits a standard linear mixed model and passes this to the `B` argument. -#' If `NULL` or `"default"`, the default [lcmm::hlme] input for `B` is used. +#' Options: +#' * `"lme.random"` (default): random initialization through a standard linear mixed model. Assigns a fitted standard linear mixed model enclosed in a call to random() to the `B` argument. +#' * `"lme"`, fits a standard linear mixed model and passes this to the `B` argument. +#' * `"gridsearch"`, a gridsearch is used with initialization from `"lme.random"`, following the approach used by [lcmm::gridsearch]. To use this initialization, specify arguments `gridsearch.maxiter` (max number of iterations during search), `gridsearch.rep` (number of fits during search), and `gridsearch.parallel` (whether to enable [parallel computation][latrend-parallel]). 
+#' * `NULL` or `"default"`, the default [lcmm::hlme] input for `B` is used. #' #' The argument is ignored if the `B` argument is specified, or `nClusters = 1`. #' @@ -51,19 +52,24 @@ setValidity('lcMethodLcmmGMM', function(object) { #' mixture = ~ Time, #' random = ~ 1, #' id = "Id", -#' time = "Time", , +#' time = "Time", #' nClusters = 2 #' ) #' gmm <- latrend(method, data = latrendData) #' summary(gmm) #' +#' # define method with gridsearch #' method <- lcMethodLcmmGMM( #' fixed = Y ~ Time, #' mixture = ~ Time, #' random = ~ Time, #' id = "Id", #' time = "Time", -#' nClusters = 3 +#' nClusters = 3, +#' init = "gridsearch", +#' gridsearch.maxiter = 10, +#' gridsearch.rep = 50, +#' gridsearch.parallel = TRUE #' ) #' } #' @family lcMethod implementations @@ -168,7 +174,7 @@ gmm_prepare = function(method, data, envir, verbose, ...) { } if (hasName(method, 'init') && method$nClusters > 1) { - init = match.arg(method$init, c('default', 'lme', 'lme.random')) + init = match.arg(method$init, c('default', 'lme', 'lme.random', 'gridsearch')) if (init == 'default') { init = 'lme' } @@ -191,6 +197,14 @@ gmm_prepare = function(method, data, envir, verbose, ...) { args1$classmb = NULL prepEnv$lme = do.call(lcmm::hlme, args1) args$B = quote(random(lme)) + }, + gridsearch = { + cat(verbose, 'Fitting standard linear mixed model for gridsearch initialization...') + args1 = args + args1$ng = 1 + args1$mixture = NULL + args1$classmb = NULL + prepEnv$lme = do.call(lcmm::hlme, args1) } ) } @@ -222,9 +236,68 @@ gmm_fit = function(method, data, envir, verbose, ...) { model } +# Gridsearch fit: runs `gridsearch.rep` hlme fits of at most `gridsearch.maxiter` iterations, each started from random(envir$lme), then refits the full model from the estimates of the candidate with the highest log-likelihood. +gmm_gridsearch = function(method, data, envir, verbose, ...) 
{ + assert_that( + is.count(method$gridsearch.maxiter), + is.count(method$gridsearch.rep), + is.flag(method$gridsearch.parallel), + hasName(envir, 'lme'), + inherits(envir$lme, 'hlme') + ) + + args = envir$args + gridArgs = args + gridArgs$maxiter = method$gridsearch.maxiter + rep = method$gridsearch.rep + .latrend.lme = envir$lme + `%infix%` = if (method$gridsearch.parallel) `%dopar%` else `%do%` + + # Conduct gridsearch + timing = .enterTimed(verbose, sprintf('Gridsearch with %d repetitions...', rep)) + gridModels = foreach(k = seq_len(rep)) %infix% { + cat( + verbose, + sprintf('Gridsearch fit %d/%d (%g%%)', k, rep, round(k / rep * 100)) + ) + e = environment() + assign('minit', .latrend.lme, envir = e) + gridArgs$B = substitute(random(minit), env = e) + gridModel = do.call(lcmm::hlme, gridArgs) + gridModel + } + .exitTimed(timing) + + # determine the best candidate solution + gridLogLiks = vapply(gridModels, function(x) x$loglik, FUN.VALUE = 0) + assert_that( + any(is.finite(gridLogLiks)), + msg = 'Failed to obtain a valid fit during gridsearch. Try more reps or higher maxiter?' + ) + iBest = which.max(gridLogLiks) + args$B = gridModels[[iBest]]$best + + # fit the final model + timing = .enterTimed(verbose, 'Final model optimization...') + model = do.call(lcmm::hlme, args) + .exitTimed(timing, msg = 'Done with final model optimization (%s)') + + # attach the fit specification (fixed, mixture, random, classmb) to the returned model + model$fixed = args$fixed + model$mixture = args$mixture + model$random = args$random + model$mb = envir$classmb + model +} + + #' @rdname interface-lcmm setMethod('fit', 'lcMethodLcmmGMM', function(method, data, envir, verbose, ...) { - model = gmm_fit(method, data, envir, verbose, ...) + if (method$init == 'gridsearch') { + model = gmm_gridsearch(method, data, envir, verbose, ...) + } else { + model = gmm_fit(method, data, envir, verbose, ...) 
+ } new( 'lcModelLcmmGMM', diff --git a/man/lcMethodLcmmGBTM.Rd b/man/lcMethodLcmmGBTM.Rd index 6040ec4f..34bd234f 100644 --- a/man/lcMethodLcmmGBTM.Rd +++ b/man/lcMethodLcmmGBTM.Rd @@ -29,10 +29,13 @@ lcMethodLcmmGBTM( \item{nClusters}{The number of clusters to fit. This replaces the \code{ng} argument of \link[lcmm:hlme]{lcmm::hlme}.} \item{init}{Alternative for the \code{B} argument of \link[lcmm:hlme]{lcmm::hlme}, for initializing the hlme fitting procedure. -If \code{"lme.random"} (default): random initialization through a standard linear mixed model. -Assigns a fitted standard linear mixed model enclosed in a call to random() to the \code{B} argument. -If \code{"lme"}, fits a standard linear mixed model and passes this to the \code{B} argument. -If \code{NULL} or \code{"default"}, the default \link[lcmm:hlme]{lcmm::hlme} input for \code{B} is used. +Options: +\itemize{ +\item \code{"lme.random"} (default): random initialization through a standard linear mixed model. Assigns a fitted standard linear mixed model enclosed in a call to random() to the \code{B} argument. +\item \code{"lme"}, fits a standard linear mixed model and passes this to the \code{B} argument. +\item \code{"gridsearch"}, a gridsearch is used with initialization from \code{"lme.random"}, following the approach used by \link[lcmm:gridsearch]{lcmm::gridsearch}. To use this initialization, specify arguments \code{gridsearch.maxiter} (max number of iterations during search), \code{gridsearch.rep} (number of fits during search), and \code{gridsearch.parallel} (whether to enable \link[=latrend-parallel]{parallel computation}). +\item \code{NULL} or \code{"default"}, the default \link[lcmm:hlme]{lcmm::hlme} input for \code{B} is used. 
+} The argument is ignored if the \code{B} argument is specified, or \code{nClusters = 1}.} diff --git a/man/lcMethodLcmmGMM.Rd b/man/lcMethodLcmmGMM.Rd index a00235f8..ee1c1d6c 100644 --- a/man/lcMethodLcmmGMM.Rd +++ b/man/lcMethodLcmmGMM.Rd @@ -30,10 +30,13 @@ lcMethodLcmmGMM( \item{id}{The name of the trajectory identifier variable. This replaces the \code{subject} argument of \link[lcmm:hlme]{lcmm::hlme}.} \item{init}{Alternative for the \code{B} argument of \link[lcmm:hlme]{lcmm::hlme}, for initializing the hlme fitting procedure. -If \code{"lme.random"} (default): random initialization through a standard linear mixed model. -Assigns a fitted standard linear mixed model enclosed in a call to random() to the \code{B} argument. -If \code{"lme"}, fits a standard linear mixed model and passes this to the \code{B} argument. -If \code{NULL} or \code{"default"}, the default \link[lcmm:hlme]{lcmm::hlme} input for \code{B} is used. +Options: +\itemize{ +\item \code{"lme.random"} (default): random initialization through a standard linear mixed model. Assigns a fitted standard linear mixed model enclosed in a call to random() to the \code{B} argument. +\item \code{"lme"}, fits a standard linear mixed model and passes this to the \code{B} argument. +\item \code{"gridsearch"}, a gridsearch is used with initialization from \code{"lme.random"}, following the approach used by \link[lcmm:gridsearch]{lcmm::gridsearch}. To use this initialization, specify arguments \code{gridsearch.maxiter} (max number of iterations during search), \code{gridsearch.rep} (number of fits during search), and \code{gridsearch.parallel} (whether to enable \link[=latrend-parallel]{parallel computation}). +\item \code{NULL} or \code{"default"}, the default \link[lcmm:hlme]{lcmm::hlme} input for \code{B} is used. 
+} The argument is ignored if the \code{B} argument is specified, or \code{nClusters = 1}.} @@ -54,19 +57,24 @@ if (rlang::is_installed("lcmm")) { mixture = ~ Time, random = ~ 1, id = "Id", - time = "Time", , + time = "Time", nClusters = 2 ) gmm <- latrend(method, data = latrendData) summary(gmm) + # define method with gridsearch method <- lcMethodLcmmGMM( fixed = Y ~ Time, mixture = ~ Time, random = ~ Time, id = "Id", time = "Time", - nClusters = 3 + nClusters = 3, + init = "gridsearch", + gridsearch.maxiter = 10, + gridsearch.rep = 50, + gridsearch.parallel = TRUE ) } } diff --git a/tests/testthat/test-lcmm.R b/tests/testthat/test-lcmm.R index 99cc6a23..bef7de4b 100644 --- a/tests/testthat/test-lcmm.R +++ b/tests/testthat/test-lcmm.R @@ -1,4 +1,3 @@ -context('LCMM models') skip_if_not_installed('lcmm') skip_on_cran() rngReset() @@ -28,6 +27,9 @@ make.gmm = function(id, time, response, ..., init = 'default') { mc$time = time mc$maxiter = 10 mc$seed = 1 + mc$gridsearch.rep = 2 + mc$gridsearch.maxiter = 5 + mc$gridsearch.parallel = FALSE do.call(lcMethodLcmmGMM, as.list(mc)[-1]) %>% evaluate() } @@ -54,6 +56,13 @@ test_that('gmm with init=lme.random', { expect_true(is.lcModel(model)) }) +test_that('gmm with init=gridsearch', { + skip_on_cran() + method = make.gmm(id = 'Traj', time = 'Assessment', response = 'Value', init = 'gridsearch') + model = latrend(method, testLongData) + expect_true(is.lcModel(model)) +}) + test_that('gmm with NA covariate', { expect_true({ suppressWarnings({