diff --git a/.github/workflows/recheck.yml b/.github/workflows/recheck.yml new file mode 100644 index 00000000..274a0a2e --- /dev/null +++ b/.github/workflows/recheck.yml @@ -0,0 +1,18 @@ +on: + workflow_dispatch: + inputs: + which: + type: choice + description: Which dependents to check + options: + - strong + - most + +name: Reverse dependency check + +jobs: + revdep_check: + name: Reverse check ${{ inputs.which }} dependents + uses: r-devel/recheck/.github/workflows/recheck.yml@v1 + with: + which: ${{ inputs.which }} \ No newline at end of file diff --git a/.lintr b/.lintr index 1ceb393c..93be52f9 100644 --- a/.lintr +++ b/.lintr @@ -1,36 +1,31 @@ linters: linters_with_tags( tags = c( + "best_practices", + "common_mistakes", + "consistency", + "correctness", "default", - "best_practices", - "common_mistakes", - "correctness", - "executing" - ), + "efficiency", + "executing", + "package_development", + "pkg_testthat", + "readability", + "regex", + "robustness", + "style", + "tidy_design" + ), return_linter(), object_overwrite_linter(), - unnecessary_concatenation_linter(), - cyclocomp_linter(complexity_limit = 19), # TODO small refactor - implicit_integer_linter = NULL, # TODO refactor - library_call_linter = NULL, # TODO refactor + object_length_linter(), + line_length_linter(120), + todo_comment_linter = NULL, # TODO refactor implicit_assignment_linter = NULL, # TODO refactor - length_levels_linter = NULL, # TODO refactor - any_duplicated_linter = NULL, # TODO refactor - expect_length_linter = NULL, # TODO refactor - expect_comparison_linter = NULL, # TODO refactor - expect_null_linter = NULL, # TODO refactor - expect_not_linter = NULL, # TODO refactor - expect_named_linter = NULL, # TODO refactor - expect_type_linter = NULL, # TODO refactor - expect_s3_class_linter = NULL, # TODO refactor - fixed_regex_linter = NULL, # TODO refactor - redundant_ifelse_linter = NULL, #TODO refactor + expect_identical_linter = NULL, # TODO big refactor undesirable_function_linter = NULL, # TODO big refactor - line_length_linter(305), # TODO change to 120 - big refactor - object_length_linter(length = 31), # May or may not refactor nzchar_linter = NULL, # Whats the point of this lintr? condition_message_linter = NULL, # Whats the point of this lintr? condition_call_linter = NULL, # Whats the point of this lintr? paste_linter = NULL, # Probably never gonna remove this one object_name_linter = NULL # Probably never gonna remove this one ) - diff --git a/R/S3GenericExtenstions.R b/R/S3GenericExtenstions.R index 16341d19..d096a5da 100644 --- a/R/S3GenericExtenstions.R +++ b/R/S3GenericExtenstions.R @@ -35,14 +35,15 @@ #' c.caretList <- function(...) { new_model_list <- unlist(lapply(list(...), function(x) { - if (!inherits(x, "caretList")) { - if (!inherits(x, "train")) stop("class of modelList1 must be 'caretList' or 'train'") - - # assuming this is a single train object + if (inherits(x, "caretList")) { + x + } else if (inherits(x, "train")) { x <- list(x) - names(x) <- x[[1]]$method + names(x) <- x[[1L]]$method + x + } else { + stop("class of modelList1 must be 'caretList' or 'train'") } - x }), recursive = FALSE) # Make sure names are unique @@ -81,15 +82,15 @@ c.caretList <- function(...) { #' c.train <- function(...) 
{ new_model_list <- unlist(lapply(list(...), function(x) { - if (!inherits(x, "caretList")) { - if (!inherits(x, "train")) stop("class of modelList1 must be 'caretList' or 'train'") - - # assuming this is a single train object + if (inherits(x, "caretList")) { + x + } else if (inherits(x, "train")) { x <- list(x) - names(x) <- x[[1]]$method + names(x) <- x[[1L]]$method x + } else { + stop("class of modelList1 must be 'caretList' or 'train'") } - x }), recursive = FALSE) # Make sure names are unique diff --git a/R/caretEnsemble.R b/R/caretEnsemble.R index ff7334e8..0737e971 100644 --- a/R/caretEnsemble.R +++ b/R/caretEnsemble.R @@ -64,8 +64,8 @@ summary.caretEnsemble <- function(object, ...) { val <- getMetric.train(object$ens_model) cat(paste0("The following models were ensembled: ", types, " \n")) cat("They were weighted: \n") - cat(paste0(paste0(round(wghts, 4), collapse = " "), "\n")) - cat(paste0("The resulting ", metric, " is: ", round(val, 4), "\n")) + cat(paste0(paste0(round(wghts, 4L), collapse = " "), "\n")) + cat(paste0("The resulting ", metric, " is: ", round(val, 4L), "\n")) # Add code to compare ensemble to individual models cat(paste0("The fit for each individual model on the ", metric, " is: \n")) @@ -96,7 +96,7 @@ extractModRes <- function(ensemble) { ), stringsAsFactors = FALSE ) - names(modRes)[2:3] <- c(metric, paste0(metric, "SD")) + names(modRes)[2L:3L] <- c(metric, paste0(metric, "SD")) modRes } @@ -163,12 +163,12 @@ varImp.caretEnsemble <- function(object, ...) { # Convert to data.frame dat <- varImpFrame(coef_importance) - dat[is.na(dat)] <- 0 + dat[is.na(dat)] <- 0L names(dat) <- make.names(names(coef_importance)) # Scale the importances - norm_to_100 <- function(d) d / sum(d) * 100 - dat <- apply(dat, 2, norm_to_100) + norm_to_100 <- function(d) d / sum(d) * 100.0 + dat <- apply(dat, 2L, norm_to_100) # Calculate overall importance model_weights <- coef(object$ens_model$finalModel) @@ -179,7 +179,7 @@ varImp.caretEnsemble <- function(object, ...) { names(model_weights) <- names(object$models) model_weights <- model_weights[names(model_weights) %in% names(coef_importance)] model_weights <- abs(model_weights) - overall <- norm_to_100(apply(dat, 1, weighted.mean, w = model_weights)) + overall <- norm_to_100(apply(dat, 1L, weighted.mean, w = model_weights)) dat <- data.frame(overall = overall, dat) # Order by overall importance @@ -191,7 +191,7 @@ varImp.caretEnsemble <- function(object, ...) { #' @keywords internal # This function only gets called once, in varImp.caretEnsemble clean_varImp <- function(x) { - names(x$importance)[1] <- "Overall" + names(x$importance)[1L] <- "Overall" x$importance <- x$importance[, "Overall", drop = FALSE] x$importance } @@ -206,18 +206,18 @@ varImpFrame <- function(x) { dat$id <- row.names(dat) dat$model <- sub("\\.[^\n]*", "", dat$id) dat$var <- sub("^[^.]*", "", dat$id) - dat$var <- substr(dat$var, 2, nchar(dat$var)) + dat$var <- substr(dat$var, 2L, nchar(dat$var)) # Parse intercept variables - dat$var[grep("Inter", dat$var)] <- "Intercept" + dat$var[grep("Inter", dat$var, fixed = TRUE)] <- "Intercept" dat$id <- NULL row.names(dat) <- NULL dat <- reshape(dat, direction = "wide", v.names = "Overall", idvar = "var", timevar = "model" ) - row.names(dat) <- dat[, 1] - dat[, -1] + row.names(dat) <- dat[, 1L] + dat[, -1L] } #' @title Plot Diagnostics for an caretEnsemble Object @@ -252,9 +252,9 @@ plot.caretEnsemble <- function(x, ...) 
{ theme_bw() + labs(x = "Individual Model Method", y = metricLab) - if (nrow(x$error) > 0) { + if (nrow(x$error) > 0L) { plt <- plt + - geom_hline(linetype = 2, linewidth = 0.2, yintercept = min(x$error[[metricLab]]), color = I("red")) + geom_hline(linetype = 2L, linewidth = 0.2, yintercept = min(x$error[[metricLab]]), color = I("red")) } plt } @@ -274,7 +274,7 @@ extractPredObsResid <- function(object, show_class_id = 2L) { obs <- predobs$obs id <- predobs$rowIndex if (type == "Regression") { - pred <- pred[[1]] + pred <- pred[[1L]] } else { show_class <- levels(object)[show_class_id] pred <- pred[[show_class]] @@ -304,6 +304,7 @@ extractPredObsResid <- function(object, show_class_id = 2L) { #' @importFrom gridExtra grid.arrange #' @export #' @examples +#' \dontrun{ #' set.seed(42) #' data(models.reg) #' ens <- caretEnsemble( @@ -313,6 +314,7 @@ extractPredObsResid <- function(object, show_class_id = 2L) { #' ) #' ) #' suppressWarnings(autoplot(ens)) +#' } autoplot.caretEnsemble <- function(object, xvars = NULL, show_class_id = 2L, ...) { stopifnot(is(object, "caretEnsemble")) ensemble_data <- extractPredObsResid(object$ens_model, show_class_id = show_class_id) @@ -361,11 +363,11 @@ autoplot.caretEnsemble <- function(object, xvars = NULL, show_class_id = 2L, ... ymin = .data[["ymin"]], ymax = .data[["ymax"]] )) + - ggplot2::geom_point(size = I(3), alpha = I(0.8)) + + ggplot2::geom_point(size = I(3L), alpha = I(0.8)) + ggplot2::theme_bw() + ggplot2::geom_smooth( method = "lm", se = FALSE, - linewidth = I(1.1), color = I("red"), linetype = 2 + linewidth = I(1.1), color = I("red"), linetype = 2L ) + ggplot2::labs( x = "Fitted Values", y = "Range of Resid.", @@ -377,24 +379,24 @@ autoplot.caretEnsemble <- function(object, xvars = NULL, show_class_id = 2L, ... 
if (is.null(xvars)) { xvars <- names(x_data) xvars <- setdiff(xvars, c(".outcome", ".weights", "(Intercept)")) - xvars <- sample(xvars, 2) + xvars <- sample(xvars, 2L) } data.table::set(x_data, j = "id", value = seq_len(nrow(x_data))) plotdf <- merge(ensemble_data, x_data, by = "id") g5 <- ggplot2::ggplot(plotdf, ggplot2::aes(.data[[xvars[1]]], .data[["resid"]])) + ggplot2::geom_point() + ggplot2::geom_smooth(se = FALSE) + - ggplot2::scale_x_continuous(xvars[1]) + + ggplot2::scale_x_continuous(xvars[1L]) + ggplot2::scale_y_continuous("Residuals") + - ggplot2::labs(title = paste0("Residuals Against ", xvars[1])) + + ggplot2::labs(title = paste0("Residuals Against ", xvars[1L])) + ggplot2::theme_bw() - g6 <- ggplot2::ggplot(plotdf, ggplot2::aes(.data[[xvars[2]]], .data[["resid"]])) + + g6 <- ggplot2::ggplot(plotdf, ggplot2::aes(.data[[xvars[2L]]], .data[["resid"]])) + ggplot2::geom_point() + ggplot2::geom_smooth(se = FALSE) + ggplot2::scale_x_continuous(xvars[2]) + ggplot2::scale_y_continuous("Residuals") + - ggplot2::labs(title = paste0("Residuals Against ", xvars[2])) + + ggplot2::labs(title = paste0("Residuals Against ", xvars[2L])) + ggplot2::theme_bw() # nolint end: object_usage_linter - suppressMessages(gridExtra::grid.arrange(g1, g2, g3, g4, g5, g6, ncol = 2)) + suppressMessages(gridExtra::grid.arrange(g1, g2, g3, g4, g5, g6, ncol = 2L)) } diff --git a/R/caretList.R b/R/caretList.R index fc5058e0..4058ba94 100644 --- a/R/caretList.R +++ b/R/caretList.R @@ -1,5 +1,6 @@ #' @title Generate a specification for fitting a caret model -#' @description A caret model specification consists of 2 parts: a model (as a string) and the arguments to the train call for fitting that model +#' @description A caret model specification consists of 2 parts: a model (as a string) and +#' the arguments to the train call for fitting that model #' @param method the modeling method to pass to caret::train #' @param ... Other arguments that will eventually be passed to caret::train #' @export @@ -12,7 +13,8 @@ caretModelSpec <- function(method = "rf", ...) { } #' @title Check that the tuning parameters list supplied by the user is valid -#' @description This function makes sure the tuning parameters passed by the user are valid and have the proper naming, etc. +#' @description This function makes sure the tuning parameters passed by the user +#' are valid and have the proper naming, etc. #' @param x a list of user-supplied tuning parameters and methods #' @return NULL #' @export @@ -40,7 +42,8 @@ tuneCheck <- function(x) { } #' @title Check that the methods supplied by the user are valid caret methods -#' @description This function uses modelLookup from caret to ensure the list of methods supplied by the user are all models caret can fit. +#' @description This function uses modelLookup from caret to ensure the list of +#' methods supplied by the user are all models caret can fit. #' @param x a list of user-supplied tuning parameters and methods #' @importFrom caret modelLookup #' @return NULL @@ -53,9 +56,9 @@ methodCheck <- function(x) { models <- lapply(x, function(m) { if (is.list(m)) { checkCustomModel(m) - data.frame(type = "custom", model = m$method) + data.frame(type = "custom", model = m$method, stringsAsFactors = FALSE) } else if (is.character(m)) { - data.frame(type = "native", model = m) + data.frame(type = "native", model = m, stringsAsFactors = FALSE) } else { stop(paste0( "Method \"", m, "\" is invalid. 
Methods must either be character names ", @@ -70,7 +73,7 @@ methodCheck <- function(x) { native_models <- subset(models, get("type") == "native")$model bad_models <- setdiff(native_models, supported_models) - if (length(bad_models) > 0) { + if (length(bad_models) > 0L) { msg <- paste(bad_models, collapse = ", ") stop(paste("The following models are not valid caret models:", msg)) } @@ -79,13 +82,16 @@ methodCheck <- function(x) { } #' @title Check that the trainControl object supplied by the user is valid and has defined re-sampling indexes. -#' @description This function checks the user-supplied trainControl object and makes sure it has all the required fields. If the resampling indexes are missing, it adds them to the model. If savePredictions=FALSE or "none", this function sets it to "final". +#' @description This function checks the user-supplied trainControl object and makes +#' sure it has all the required fields. +#' If the resampling indexes are missing, it adds them to the model. +#' If savePredictions=FALSE or "none", this function sets it to "final". #' @param x a trainControl object. #' @param y the target for the model. Used to determine resampling indexes. #' @importFrom caret createResample createFolds createMultiFolds createDataPartition #' @return NULL trControlCheck <- function(x, y) { - if (!length(x$savePredictions) == 1) { + if (length(x$savePredictions) != 1L) { stop("Please pass exactly 1 argument to savePredictions, e.g. savePredictions='final'") } @@ -95,9 +101,12 @@ trControlCheck <- function(x, y) { } if (is.null(x$index)) { - warning("indexes not defined in trControl. Attempting to set them ourselves, so each model in the ensemble will have the same resampling indexes.") + # So each model in the ensemble will have the same resampling indexes + warning("indexes not defined in trControl. Attempting to set them ourselves.") if (x$method == "none") { - stop("Models that aren't resampled cannot be ensembled. All good ensemble methods rely on out-of sample data. If you really need to ensemble without re-sampling, try the median or mean of the model's predictions.") + # All good ensemble methods rely on out-of sample data. + # If you really need to ensemble without re-sampling, try the median or mean of the model's predictions. + stop("Models that aren't resampled cannot be ensembled.") } else if (x$method == "boot" || x$method == "adaptive_boot") { x$index <- createResample(y, times = x$number, list = TRUE) } else if (x$method == "cv" || x$method == "adaptive_cv") { @@ -110,17 +119,18 @@ trControlCheck <- function(x, y) { times = x$number, p = 0.5, list = TRUE, - groups = min(5, length(y)) + groups = min(5L, length(y)) ) } else { - stop(paste0("caretList does not currently know how to handle cross-validation method='", x$method, "'. Please specify trControl$index manually")) + stop(paste0("caretList can't handle cv method='", x$method, "'. Please specify trControl$index manually")) } } x } #' @title Extracts the target variable from a set of arguments headed to the caret::train function. -#' @description This function extracts the y variable from a set of arguments headed to a caret::train model. Since there are 2 methods to call caret::train, this function also has 2 methods. +#' @description This function extracts the y variable from a set of arguments headed to a caret::train model. +#' Since there are 2 methods to call caret::train, this function also has 2 methods. #' @param ... 
a set of arguments, as in the caret::train function extractCaretTarget <- function(...) { UseMethod("extractCaretTarget") @@ -128,7 +138,8 @@ extractCaretTarget <- function(...) { #' @title Extracts the target variable from a set of arguments headed to the caret::train.default function. #' @description This function extracts the y variable from a set of arguments headed to a caret::train.default model. -#' @param x an object where samples are in rows and features are in columns. This could be a simple matrix, data frame or other type (e.g. sparse matrix). See Details below. +#' @param x an object where samples are in rows and features are in columns. This could be a simple matrix, data frame +#' or other type (e.g. sparse matrix). See Details below. #' @param y a numeric or factor vector containing the outcome for each sample. #' @param ... ignored #' @method extractCaretTarget default @@ -153,11 +164,18 @@ extractCaretTarget.formula <- function(form, data, ...) { #' Build a list of train objects suitable for ensembling using the \code{\link{caretEnsemble}} #' function. #' -#' @param ... arguments to pass to \code{\link[caret]{train}}. These arguments will determine which train method gets dispatched. -#' @param trControl a \code{\link[caret]{trainControl}} object. We are going to intercept this object check that it has the "index" slot defined, and define the indexes if they are not. -#' @param methodList optional, a character vector of caret models to ensemble. One of methodList or tuneList must be specified. -#' @param tuneList optional, a NAMED list of caretModelSpec objects. This much more flexible than methodList and allows the specification of model-specific parameters (e.g. passing trace=FALSE to nnet) -#' @param continue_on_fail, logical, should a valid caretList be returned that excludes models that fail, default is FALSE +#' @param ... arguments to pass to \code{\link[caret]{train}}. +#' These arguments will determine which train method gets dispatched. +#' @param trControl a \code{\link[caret]{trainControl}} object. +#' We are going to intercept this object check that it has the +#' "index" slot defined, and define the indexes if they are not. +#' @param methodList optional, a character vector of caret models to ensemble. +#' One of methodList or tuneList must be specified. +#' @param tuneList optional, a NAMED list of caretModelSpec objects. +#' This much more flexible than methodList and allows the +#' specification of model-specific parameters (e.g. passing trace=FALSE to nnet) +#' @param continue_on_fail, logical, should a valid caretList be returned that +#' excludes models that fail, default is FALSE #' @return A list of \code{\link[caret]{train}} objects. If the model fails to build, #' it is dropped from the list. #' @importFrom caret trainControl train @@ -194,7 +212,7 @@ caretList <- function( if (is.null(tuneList) && is.null(methodList)) { stop("Please either define a methodList or tuneList") } - if (!is.null(methodList) && any(duplicated(methodList))) { + if (!is.null(methodList) && anyDuplicated(methodList) > 0L) { warning("Duplicate entries in methodList. Using unqiue methodList values.") methodList <- unique(methodList) } @@ -233,7 +251,7 @@ caretList <- function( nulls <- sapply(modelList, is.null) modelList <- modelList[!nulls] - if (length(modelList) == 0) { + if (length(modelList) == 0L) { stop("caret:train failed for all models. 
Please inspect your data.") } class(modelList) <- c("caretList", "list") @@ -338,7 +356,7 @@ predict.caretList <- function(object, newdata = NULL, verbose = FALSE, excluded_ if (is.null(newdata)) { train_data_nulls <- sapply((object), function(x) is.null(x[["trainingData"]])) if (any(train_data_nulls)) { - stop("newdata is NULL and trainingData is NULL for some models. Please pass newdata or retrain with returnData=TRUE.") + stop("newdata is NULL and trainingData is NULL for some models. Use newdata or retrain with returnData=TRUE.") } } diff --git a/R/caretStack.R b/R/caretStack.R index 8039cb41..aa44d2bf 100644 --- a/R/caretStack.R +++ b/R/caretStack.R @@ -6,10 +6,13 @@ #' @details Check the models, and make a matrix of obs and preds #' #' @param all.models a list of caret models to ensemble. -#' @param excluded_class_id The integer level to exclude from binary classification or multiclass problems. If 0, will include all levels. +#' @param excluded_class_id The integer level to exclude from binary classification or multiclass problems. +#' If 0, will include all levels. #' @param ... additional arguments to pass to the optimization function #' @return S3 caretStack object -#' @references Caruana, R., Niculescu-Mizil, A., Crew, G., & Ksikes, A. (2004). Ensemble Selection from Libraries of Models. \url{https://www.cs.cornell.edu/~caruana/ctp/ct.papers/caruana.icml04.icdm06long.pdf} +#' @references Caruana, R., Niculescu-Mizil, A., Crew, G., & Ksikes, A. (2004). +#' Ensemble Selection from Libraries of Models. +#' \url{https://www.cs.cornell.edu/~caruana/ctp/ct.papers/caruana.icml04.icdm06long.pdf} #' @export #' @examples #' \dontrun{ @@ -91,7 +94,12 @@ predict.caretStack <- function( warning("No excluded_class_id set. Setting to 1L.") } - preds <- predict(object$models, newdata = newdata, verbose = verbose, excluded_class_id = object[["excluded_class_id"]]) + preds <- predict( + object$models, + newdata = newdata, + verbose = verbose, + excluded_class_id = object[["excluded_class_id"]] + ) meta_preds <- predict(object$ens_model, newdata = preds, type = type, ...) if (se || return_weights) { @@ -100,13 +108,13 @@ predict.caretStack <- function( model_methods <- colnames(preds) model_weights <- lapply(model_weights, function(class_weights) { # ensure that we have a numeric vector - class_weights <- ifelse(is.finite(class_weights), class_weights, 0) + class_weights <- ifelse(is.finite(class_weights), class_weights, 0L) # normalize weights class_weights <- class_weights / sum(class_weights) names(class_weights) <- row.names(imp) # set 0 weights for methods that are not present in varImp for (m in setdiff(model_methods, names(class_weights))) { - class_weights[m] <- 0 + class_weights[m] <- 0L } class_weights }) @@ -121,7 +129,7 @@ predict.caretStack <- function( overall_weights <- model_weights$Overall[model_methods] # Use overall weights to calculate standard error in regression estimations - std_error <- apply(preds, 1, wtd.sd, w = overall_weights, na.rm = TRUE) + std_error <- apply(preds, 1L, wtd.sd, w = overall_weights, na.rm = TRUE) std_error <- qnorm(level) * std_error out <- data.frame( fit = meta_preds, @@ -216,8 +224,9 @@ plot.caretStack <- function(x, ...) { } #' @title Comparison dotplot for a caretStack object -#' @description This is a function to make a dotplot from a caretStack. It uses dotplot from the caret package on all the models in the ensemble, excluding the final ensemble model. 
-#' At the moment, this function only works if the ensembling model has the same number of resamples as the component models. +#' @description This is a function to make a dotplot from a caretStack. It uses dotplot from the +#' caret package on all the models in the ensemble, excluding the final ensemble model.At the moment, +#' this function only works if the ensembling model has the same number of resamples as the component models. #' @param x An object of class caretStack #' @param ... passed to dotplot #' @importFrom lattice dotplot diff --git a/R/helper_functions.R b/R/helper_functions.R index e00b9f8e..a934ecde 100644 --- a/R/helper_functions.R +++ b/R/helper_functions.R @@ -14,7 +14,8 @@ utils::globalVariables(c(".SD", ".data")) # Disables warnings from R CMD CHECk, ##################################################### #' @title Validate the excluded class -#' @description Helper function to ensure that the excluded level for classification is an integer. Set to 0L to exclude no class. +#' @description Helper function to ensure that the excluded level for classification is an integer. +#' Set to 0L to exclude no class. #' @param arg The value to check #' @return integer validateExcludedClass <- function(arg) { @@ -29,7 +30,7 @@ validateExcludedClass <- function(arg) { "classification excluded level must be numeric: ", arg )) } - if (length(arg) != 1) { + if (length(arg) != 1L) { stop(paste0( "classification excluded level must have a length of 1: length=", length(arg) )) @@ -52,7 +53,7 @@ validateExcludedClass <- function(arg) { "classification excluded level must be finite: ", arg )) } - if (out < 0) { + if (out < 0L) { stop(paste0( "classification excluded level must be >= 0: ", arg )) @@ -67,10 +68,9 @@ validateExcludedClass <- function(arg) { #' @param all_classes a character vector of all classes #' @param excluded_class_id an integer indicating the class to exclude dropExcludedClass <- function(x, all_classes, excluded_class_id) { - stopifnot(is(x, "data.table")) - stopifnot(is.character(all_classes)) + stopifnot(is(x, "data.table"), is.character(all_classes)) excluded_class_id <- validateExcludedClass(excluded_class_id) - if (length(all_classes) > 1) { + if (length(all_classes) > 1L) { excluded_class <- all_classes[excluded_class_id] # Note that if excluded_class_id is 0, no class will be excludede classes_included <- setdiff(all_classes, excluded_class) x <- x[, classes_included, drop = FALSE, with = FALSE] @@ -106,8 +106,10 @@ extractModelName <- function(x) { #' @importFrom data.table data.table setorderv set extractBestPreds <- function(x) { # Checks - stopifnot(is(x, "train")) - stopifnot(x$control$savePredictions %in% c("all", "final", TRUE)) + stopifnot( + is(x, "train"), + x$control$savePredictions %in% c("all", "final", TRUE) + ) # Extract the best tune and the pred data a <- data.table::data.table(x$bestTune, key = names(x$bestTune)) @@ -141,7 +143,7 @@ extractModelType <- function(list_of_models) { # TODO: Maybe in the future we can combine reg and class models # Also, this check is redundant, but I think that is ok - stopifnot(length(type) == 1) + stopifnot(length(type) == 1L) if (!type %in% c("Classification", "Regression")) { stop(paste("Unknown model type:", type)) } @@ -154,14 +156,14 @@ extractModelType <- function(list_of_models) { extractObsLevels <- function(list_of_models) { all_levels <- lapply(list_of_models, levels) all_levels <- unique(all_levels) - stopifnot(length(all_levels) == 1) - all_levels <- all_levels[[1]] + stopifnot(length(all_levels) == 
1L) + all_levels <- all_levels[[1L]] all_levels } #' @title Extract the best predictions (and observeds) from a list of train objects -#' @description Extract predictions (and observeds) for the best tune from a list of caret models. This function extracts -#' the raw preds from regression models and the class probs from classification models. +#' @description Extract predictions (and observeds) for the best tune from a list of caret models. +#' This function extracts the raw preds from regression models and the class probs from classification models. #' Note that it extract preds and obs in one go, rather than separately. This is because caret can save the internal #' preds/obs from all resamples rather than just the final. So we subset the internal pred/obs to just the best tuning #' (from caret) and return the pred and obs for that tune. @@ -200,16 +202,18 @@ extractBestPredsAndObs <- function(list_of_models, excluded_class_id = 1L) { preds <- data.table::as.data.table(preds) # Return - # TODO: make this a data.table - # TODO: make Classifciaiton pull from each sub-model - # TODO: aggregate by row index and sort by row inde3x - # TODO: merge with all possible IDs, warn on NAs and fill with 0 - # TODO: allow different models, different methods, different resamples, different types. Only require a common set of rows + # TODO: + # - make this a data.table + # - make Classifciaiton pull from each sub-model + # - aggregate by row index and sort by row inde3x + # - merge with all possible IDs, warn on NAs and fill with 0 + # - allow different models, different methods, different resamples, different types. + # - Only require a common set of rows out <- list( preds = preds, - obs = preds_and_obs[[1]][["obs"]], - rowIndex = preds_and_obs[[1]][["rowIndex"]], - Resample = preds_and_obs[[1]][["Resample"]], + obs = preds_and_obs[[1L]][["obs"]], + rowIndex = preds_and_obs[[1L]][["rowIndex"]], + Resample = preds_and_obs[[1L]][["Resample"]], type = type ) invisible(gc(reset = TRUE)) @@ -243,8 +247,10 @@ checkCustomModel <- function(x) { #' @param list_of_models a list of caret models to check check_caretList_classes <- function(list_of_models) { # Check that we have a list of train models - stopifnot(is(list_of_models, "caretList")) - stopifnot(sapply(list_of_models, is, "train")) + stopifnot( + is(list_of_models, "caretList"), + sapply(list_of_models, is, "train") + ) invisible(NULL) } @@ -260,7 +266,7 @@ check_caretList_model_types <- function(list_of_models) { for (model in list_of_models) { unique_obs <- unique(model$pred$obs) if (is.null(unique_obs)) { - stop("No predictions saved by train. Please re-run models with trainControl set with savePredictions = 'final'.") + stop("No predictions saved by train. Please re-run models with trainControl savePredictions = 'final'") } } } @@ -273,7 +279,10 @@ check_caretList_model_types <- function(list_of_models) { if (!all(classProbs)) { bad_models <- names(list_of_models)[!classProbs] bad_models <- paste(bad_models, collapse = ", ") - stop("Some models were fit with no class probabilities. Please re-fit them with trainControl, classProbs = TRUE: ", bad_models) + stop( + "Some models were fit with no class probabilities. 
Please re-fit them with trainControl, classProbs = TRUE: ", + bad_models + ) } } invisible(NULL) } @@ -287,7 +296,7 @@ check_bestpreds_resamples <- function(modelLibrary) { resamples <- lapply(modelLibrary, function(x) x[["Resample"]]) names(resamples) <- names(modelLibrary) check <- length(unique(resamples)) - if (check != 1) { + if (check != 1L) { stop("Component models do not have the same re-sampling strategies") } invisible(NULL) @@ -301,7 +310,7 @@ check_bestpreds_indexes <- function(modelLibrary) { rows <- lapply(modelLibrary, function(x) x[["rowIndex"]]) names(rows) <- names(modelLibrary) check <- length(unique(rows)) - if (check != 1) { + if (check != 1L) { stop("Re-sampled predictions from each component model do not use the same rowIndexes from the origial dataset") } invisible(NULL) @@ -315,8 +324,8 @@ check_bestpreds_obs <- function(modelLibrary) { obs <- lapply(modelLibrary, function(x) x[["obs"]]) names(obs) <- names(modelLibrary) check <- length(unique(obs)) - if (check != 1) { - stop("Observed values for each component model are not the same. Please re-train the models with the same Y variable") + if (check != 1L) { + stop("Observed values for each component model are not the same. Re-train the models with the same Y variable") } invisible(NULL) } @@ -332,11 +341,11 @@ check_bestpreds_preds <- function(modelLibrary) { clases <- sapply(pred, class) if (is.matrix(clases)) { - clases <- apply(clases, 2, paste, collapse = " ") + clases <- apply(clases, 2L, paste, collapse = " ") } classes <- unique(clases) check <- length(classes) - if (check != 1) { + if (check != 1L) { stop( paste0( "Component models do not all have the same type of predicitons. Predictions are a mix of ", @@ -354,10 +363,10 @@ #' @param list_of_models a list of caret models to check #' @export check_binary_classification <- function(list_of_models) { - if (is.list(list_of_models) && length(list_of_models) > 1) { + if (is.list(list_of_models) && length(list_of_models) > 1L) { lapply(list_of_models, function(x) { # avoid regression models - if (is(x, "train") && !is.null(x$pred$obs) && is.factor(x$pred$obs) && length(levels(x$pred$obs)) > 2) { + if (is(x, "train") && !is.null(x$pred$obs) && is.factor(x$pred$obs) && nlevels(x$pred$obs) > 2L) { stop("caretEnsemble only supports binary classification problems") } }) @@ -377,13 +386,12 @@ #' @export # https://stats.stackexchange.com/a/61285 wtd.sd <- function(x, w, na.rm = FALSE) { - stopifnot(is.numeric(x)) - stopifnot(is.numeric(w)) + stopifnot(is.numeric(x), is.numeric(w)) xWbar <- weighted.mean(x, w, na.rm = na.rm) w <- w / mean(w, na.rm = na.rm) - variance <- sum((w * (x - xWbar)^2) / (sum(w, na.rm = na.rm) - 1), na.rm = na.rm) + variance <- sum((w * (x - xWbar)^2L) / (sum(w, na.rm = na.rm) - 1L), na.rm = na.rm) out <- sqrt(variance) out diff --git a/cobertura.xml b/cobertura.xml index f54ae90d..1d4a1c61 100644 --- a/cobertura.xml +++ b/cobertura.xml @@ -1,6 +1,6 @@ [regenerated Cobertura coverage XML (sources: /Users/zach/source/caretEnsemble); machine-generated per-line hit counts omitted]
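Setting the generated coverage files aside for a moment: the weighted standard deviation helper wtd.sd, tightened in the R/helper_functions.R hunks above, is easy to sanity-check. A minimal sketch, not part of the patch, assuming the package is loaded (wtd.sd is exported) and using made-up numbers:

library(caretEnsemble)
x <- c(1.2, 3.4, 5.6, 7.8)               # arbitrary values
w <- rep(1, length(x))                   # equal weights
all.equal(wtd.sd(x, w), sd(x))           # TRUE: with equal weights it reduces to stats::sd()
wtd.sd(x, w = c(0.1, 0.2, 0.3, 0.4))     # unequal weights give a weighted spread estimate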
diff --git a/coverage-report.html b/coverage-report.html index 61a8dce3..4c46b6ae 100644 --- a/coverage-report.html +++ b/coverage-report.html @@ -107,257 +107,257 @@ [regenerated HTML coverage report; only the package summary line below is retained]

caretEnsemble coverage - 100.00%

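The cobertura.xml and coverage-report.html hunks above are machine-generated coverage artifacts. The diff does not show how they are produced; one plausible way to regenerate them, assuming the covr package is used (the function names below are covr's, and the output file names match the files in this diff):

library(covr)
cov <- covr::package_coverage()                      # run the package tests under coverage
covr::report(cov, file = "coverage-report.html")     # HTML report like the one summarized above
covr::to_cobertura(cov, filename = "cobertura.xml")  # Cobertura XML for CI tooling
covr::percent_coverage(cov)                          # overall percentage, e.g. the 100.00% shown above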
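Finally, a small end-to-end sketch that ties together the interfaces touched in this diff (caretList, c.caretList/c.train, and caretEnsemble) and satisfies the checks enforced above: shared resampling indexes, savePredictions = "final", and classProbs = TRUE for a binary outcome. It is illustrative only; the simulated data, model methods, and fold count are assumptions, not part of the patch:

library(caret)
library(caretEnsemble)

set.seed(42L)
dat <- twoClassSim(200L)                      # simulated two-class data from caret
x <- dat[, setdiff(names(dat), "Class")]
y <- dat$Class

ctrl <- trainControl(
  method = "cv",
  number = 5L,
  savePredictions = "final",                  # required so best-tune predictions are kept
  classProbs = TRUE,                          # required for class-probability ensembling
  index = createFolds(y, k = 5L, returnTrain = TRUE)  # shared indexes avoid the trControlCheck() warning
)

models <- caretList(x = x, y = y, trControl = ctrl, methodList = c("glm", "rpart"))

# c() combines caretList objects and single train objects, de-duplicating names
extra <- train(x = x, y = y, method = "knn", trControl = ctrl)
combined <- c(models, extra)

ens <- caretEnsemble(combined)
summary(ens)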