Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor and standardize predcitions #290

Merged
merged 18 commits into from
Jul 29, 2024
Merged

Refactor and standardize predcitions #290

merged 18 commits into from
Jul 29, 2024

Conversation

zachmayer
Copy link
Owner

@zachmayer zachmayer commented Jul 29, 2024

New function caretStack:

  • stacked predictions and new predictions now share much of their logic (e.g. for excluding classes)
  • reg/classification/multiclass predictions are now much more consistent with each other
  • can now ensemble models with a mix of model types!
    • list of models no longer need to all have the same type
    • list of models no longer need to all have the same resamples
    • list of models no longer need to all have the same target
    • list of models no longer need to all have the same type of predictions
    • list of models DO need to be trained on data with the same number of rows, at least for stacked predictions (they can still predict on new data if they were trained on different data)
  • stacked predictions now aggregates by rowID and sorts so strategies which repeat (e.g. boot and repeatedCV) work correctly.
  • Each model must still be trained on a common length dataset

Other changes

  • predict.caretList and predict.caretStack now always return data.tables
  • caretStack now supports transfer learning!
  • refactored variable importance too
  • updated tests and vignettes
  • Reorganized some code files
  • more linting

Copy link
Contributor

coderabbitai bot commented Jul 29, 2024

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Walkthrough

The recent updates to the caretEnsemble package enhance its functionality for model prediction and evaluation. Key changes include improved methods for handling class exclusions during predictions, updates to model stacking with new parameters, and adjustments to the testing framework for better validation of outputs. Documentation has been refined to clarify usage, particularly in relation to managing predictions and class probabilities, contributing to a more robust and user-friendly implementation.

Changes

File(s) Change Summary
.gitignore Added directories /doc/ and /Meta/ to be ignored by Git.
.lintr Added cyclocomp_linter for complexity checks; some linters set to NULL with TODO notes.
DESCRIPTION Updated package release date from "2024-06-25" to "2024-07-27".
Makefile Added target for building vignettes and updated clean commands for vignettes and DLLs.
NAMESPACE Removed check_binary_classification from exported functions.
R/caretEnsemble-package.R Imported .SD from data.table and .data from rlang for enhanced data handling.
R/caretEnsemble.R Introduced check_binary_classification function for validating binary classification models.
R/caretList.R Added checkCustomModel function and modified prediction logic to improve handling of classes.
R/caretPredict.R New utilities for managing class exclusions and validating model types; added caretPredict function.
R/caretStack.R Enhanced caretStack function for predictions with new parameters; refined predictions logic.
R/helper_functions.R Removed multiple functions related to model validation and prediction extraction.
inst/WORDLIST Added new terms: classProbs, coercible, and newdata.
man/*.Rd Updated documentation for various functions to reflect changes in functionality and parameters.
tests/**/* Modified tests to accommodate new functionality in predictions, handling more robust scenarios.
vignettes/caretEnsemble-intro.Rmd Revised AUC reporting and streamlined model prediction methods to exclude specific classes.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant CaretEnsemble
    participant Model
    participant Predictions

    User->>CaretEnsemble: Request prediction
    CaretEnsemble->>Model: Generate predictions
    Model->>CaretEnsemble: Return predictions
    CaretEnsemble->>Predictions: Process predictions with exclusions
    Predictions->>User: Return final predictions
Loading
sequenceDiagram
    participant User
    participant CaretStack
    participant ModelStack
    participant Predictions

    User->>CaretStack: Request stacked predictions
    CaretStack->>ModelStack: Retrieve models
    ModelStack->>CaretStack: Return model predictions
    CaretStack->>Predictions: Combine predictions
    Predictions->>User: Return final stacked predictions
Loading

Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media?

Share
Tips

Chat

There are 3 ways to chat with CodeRabbit:

‼️ IMPORTANT
Auto-reply has been disabled for this repository in the CodeRabbit settings. The CodeRabbit bot will not respond to your replies unless it is explicitly tagged.

  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai generate interesting stats about this repository and render them as a table.
    • @coderabbitai show all the console.log statements in this repository.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (invoked as PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Additionally, you can add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

Outside diff range, codebase verification and nitpick comments (2)
R/helper_functions.R (1)

Line range hint 14-22:
Fix missing dependency.

The function extractModelName references checkCustomModel, which has been removed. This will cause an error when handling custom models.

-  if (is.list(x$method)) {
-    checkCustomModel(x$method)$method
-  } else if (x$method == "custom") {
-    checkCustomModel(x$modelInfo)$method
-  } else {
-    x$method
-  }
+  if (is.list(x$method) || x$method == "custom") {
+    stop("Custom model handling is not supported. The `checkCustomModel` function has been removed.")
+  } else {
+    x$method
+  }
man/caretPredict.Rd (1)

1-27: Consider adding an example for excluded_class_id.

The documentation is comprehensive, but an example illustrating the use of the excluded_class_id argument would enhance clarity.

+ \examples{
+   # Example usage of caretPredict with excluded_class_id
+   model <- train(Species ~ ., data = iris, method = "rf")
+   predictions <- caretPredict(model, newdata = iris, excluded_class_id = 2L)
+   print(predictions)
+ }
Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL

Commits

Files that changed from the base of the PR and between 037d994 and e2b8a1d.

Files selected for processing (32)
  • .gitignore (1 hunks)
  • .lintr (1 hunks)
  • DESCRIPTION (1 hunks)
  • Makefile (3 hunks)
  • NAMESPACE (1 hunks)
  • R/caretEnsemble-package.R (1 hunks)
  • R/caretEnsemble.R (8 hunks)
  • R/caretList.R (3 hunks)
  • R/caretPredict.R (1 hunks)
  • R/caretStack.R (4 hunks)
  • R/helper_functions.R (2 hunks)
  • inst/WORDLIST (3 hunks)
  • man/caretPredict.Rd (1 hunks)
  • man/caretStack.Rd (1 hunks)
  • man/checkCustomModel.Rd (1 hunks)
  • man/check_binary_classification.Rd (2 hunks)
  • man/dropExcludedClass.Rd (2 hunks)
  • man/extractModelType.Rd (1 hunks)
  • man/predict.caretList.Rd (1 hunks)
  • man/predict.caretStack.Rd (2 hunks)
  • man/validateExcludedClass.Rd (2 hunks)
  • man/wtd.sd.Rd (1 hunks)
  • tests/testthat/test-S3-generic-extensions.R (2 hunks)
  • tests/testthat/test-caretEnsemble.R (4 hunks)
  • tests/testthat/test-caretList.R (3 hunks)
  • tests/testthat/test-caretStack.R (9 hunks)
  • tests/testthat/test-classSelection.R (1 hunks)
  • tests/testthat/test-ensembleMethods.R (3 hunks)
  • tests/testthat/test-helper_functions.R (9 hunks)
  • tests/testthat/test-multiclass.R (10 hunks)
  • tests/testthat/test-parallel.R (1 hunks)
  • vignettes/caretEnsemble-intro.Rmd (4 hunks)
Files skipped from review due to trivial changes (7)
  • .gitignore
  • DESCRIPTION
  • man/checkCustomModel.Rd
  • man/check_binary_classification.Rd
  • man/dropExcludedClass.Rd
  • man/validateExcludedClass.Rd
  • man/wtd.sd.Rd
Additional comments not posted (101)
inst/WORDLIST (3)

4-4: Term Addition: classProbs

The term classProbs is correctly added to the list. Ensure that this term is consistently used across the codebase for clarity.


17-17: Term Addition: coercible

The term coercible is correctly added to the list. Ensure that this term is consistently used across the codebase for clarity.


33-33: Term Addition: newdata

The term newdata is correctly added to the list. Ensure that this term is consistently used across the codebase for clarity.

man/extractModelType.Rd (5)

2-2: Update Documentation Reference

The documentation reference is correctly updated to R/caretPredict.R. Ensure that this file contains the relevant documentation for the extractModelType function.


5-5: Update Title for Clarity

The title is updated to "Extract the model type from a train object," which is clear and specific. This change improves the readability of the documentation.


7-7: Update Function Usage

The function usage is updated to include the validate_for_stacking parameter. This update is clear and correctly formatted.


10-13: Update Parameter Descriptions

The parameter descriptions are updated to reflect the new object parameter and the validate_for_stacking flag. These updates are clear and improve the documentation's specificity.


14-20: Update Return Value and Description

The return value and description are updated to reflect the function's behavior accurately. The description now includes validation logic for classification models, ensuring they can predict probabilities and meet conditions for stacked predictions. This change enhances the documentation's clarity and usability.

man/predict.caretList.Rd (1)

7-7: Update Default Value for excluded_class_id

The default value for the excluded_class_id parameter is updated from 0L to 1L. Ensure that this change is correctly documented and that users are aware of the new default behavior.

.lintr (1)

17-24: Good addition of cyclocomp_linter.

The addition of the cyclocomp_linter with a complexity threshold of 17 is a positive step for maintainability. Ensure to prioritize the refactoring tasks indicated by the TODO comments.

tests/testthat/test-parallel.R (5)

2-8: Ensure data is loaded correctly.

Loading the data at the beginning of the test file is a good practice. Ensure that the data files models.reg, X.reg, Y.reg, models.class, X.class, and Y.class are correctly loaded and available for the tests.


13-13: Efficiently create a larger dataset.

Using data.table::rbindlist to create a larger dataset is efficient. Ensure that the X.reg data is correctly replicated 100 times.


17-20: Use expect_equivalent for flexible comparison.

Switching from expect_equal to expect_equivalent allows for more flexibility in numerical comparisons, which is useful for predictions that may have slight variations.


23-26: Consistent predictions with standard errors.

The tests for predictions with standard errors (se = TRUE) are consistent with the previous tests. Ensure that the predictions are equivalent when using a larger dataset.


29-35: Consistent predictions with weights.

The tests for predictions with weights (return_weights = TRUE) are consistent with the previous tests. Ensure that the predictions are equivalent when using a larger dataset.

R/caretEnsemble-package.R (1)

10-11: New imports for data manipulation.

The new import statements for .SD from data.table and .data from rlang enhance the package's data manipulation capabilities. Ensure that these imports are used correctly in the package functions.

NAMESPACE (1)

Line range hint 1-1:
Removal of check_binary_classification from exports.

The check_binary_classification function has been removed from the list of exported entities. Ensure that this function is not used externally or provide an alternative solution if needed.

man/caretStack.Rd (3)

7-7: LGTM! The usage section is clear and consistent.

The new parameters new_X and new_y are well-documented and enhance the function's flexibility.


10-17: LGTM! The arguments section is clear and comprehensive.

The descriptions for new_X and new_y are well-written and provide sufficient context for users.


31-40: LGTM! The details section is informative and relevant.

The expanded details provide valuable context on when to use stacking versus transfer learning.

man/predict.caretStack.Rd (4)

13-14: LGTM! The usage section is clear and consistent.

The new parameters excluded_class_id and return_class_only are well-documented and enhance the function's flexibility.


31-40: LGTM! The arguments section is clear and comprehensive.

The descriptions for excluded_class_id and return_class_only are well-written and provide sufficient context for users.


42-42: LGTM! The arguments section is clear and comprehensive.

The description for verbose is well-written and provides sufficient context for users.


47-48: LGTM! The value section is clear and relevant.

The updated value section specifies the return type as a data.table, which is consistent with the function's behavior.

Makefile (4)

3-3: LGTM! The .PHONY target is correctly updated.

The addition of build-vignettes to the .PHONY target ensures that the target is not confused with a file.


6-6: LGTM! The all target is correctly updated.

The inclusion of build-vignettes in the all target ensures that vignettes are built as part of the default build process.


92-93: LGTM! The build-vignettes target is correctly added.

The new build-vignettes target enhances the build process by integrating vignette management.


106-107: LGTM! The clean target is correctly updated.

The additional commands for cleaning vignettes and DLLs ensure that the clean process is comprehensive.

tests/testthat/test-classSelection.R (2)

70-72: Validation checks enhance robustness.

The added validation checks ensure that the predictions and probabilities conform to expected dimensions, enhancing the robustness of the function.


68-69: Ensure return_class_only parameter is supported.

The predict function now uses the return_class_only parameter. Ensure that this parameter is supported by the predict function in all contexts where it is used.

Verification successful

Ensure return_class_only parameter is supported.

The predict function supports the return_class_only parameter as confirmed by its presence in the function definitions and documentation.

  • R/caretStack.R
  • R/caretList.R
Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify that the `return_class_only` parameter is supported by the `predict` function.

# Test: Search for the `predict` function definition. Expect: The function should support `return_class_only` parameter.
rg --type r --pattern $'predict'

Length of output: 443


Script:

#!/bin/bash
# Description: Verify that the `return_class_only` parameter is supported by the `predict` function.

# Test: Search for the `predict` function definition and check if it supports the `return_class_only` parameter.
rg 'predict' -A 10

Length of output: 171730

tests/testthat/test-S3-generic-extensions.R (2)

220-220: Ensure excluded_class_id parameter is supported.

The predict function now uses the excluded_class_id parameter. Ensure that this parameter is supported by the predict function in all contexts where it is used.


208-208: Ensure excluded_class_id parameter is supported.

The predict function now uses the excluded_class_id parameter. Ensure that this parameter is supported by the predict function in all contexts where it is used.

Verification successful

The excluded_class_id parameter is supported.

The predict function and its related methods support the excluded_class_id parameter as evidenced by the various instances found in the codebase.

  • R/caretStack.R: preds <- predict.caretList(all.models, newdata = new_X, excluded_class_id = excluded_class_id)
  • R/caretList.R: predict.caretList <- function(object, newdata = NULL, verbose = FALSE, excluded_class_id = 1L, ...)
  • R/caretPredict.R: caretPredict <- function(object, newdata = NULL, excluded_class_id = 1L, ...)
Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify that the `excluded_class_id` parameter is supported by the `predict` function.

# Test: Search for the `predict` function definition. Expect: The function should support `excluded_class_id` parameter.
rg --type r --pattern $'predict'

Length of output: 443


Script:

#!/bin/bash
# Description: Verify that the `excluded_class_id` parameter is supported by the `predict` function.

# Test: Search for the `predict` function definition. Expect: The function should support `excluded_class_id` parameter.
rg 'predict' --type r -A 5 | grep -A 5 'predict' | grep 'excluded_class_id'

Length of output: 4053

R/caretPredict.R (3)

118-185: Ensure consistent column names and order.

The function ensures that the column names and order are consistent for both stacked predictions and new predictions. This is important for maintaining the integrity of the predictions.


76-105: Ensure trainControl settings are appropriate.

The function validates the trainControl settings for stacked predictions. Ensure that the trainControl settings are appropriate for all use cases.


57-66: Ensure validateExcludedClass handles all cases.

The function relies on validateExcludedClass to ensure the excluded class is valid. Ensure that validateExcludedClass handles all possible cases correctly.

Verification successful

Ensure validateExcludedClass handles all cases.

The validateExcludedClass function handles all possible cases correctly, as evidenced by the extensive test coverage in tests/testthat/test-helper_functions.R and tests/testthat/test-classSelection.R.

  • R/caretPredict.R (function definition)
  • tests/testthat/test-helper_functions.R (test cases)
  • tests/testthat/test-classSelection.R (test cases)
Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify that `validateExcludedClass` handles all possible cases correctly.

# Test: Search for the `validateExcludedClass` function definition. Expect: The function should handle all possible cases correctly.
rg --type r --pattern $'validateExcludedClass'

Length of output: 457


Script:

#!/bin/bash
# Description: Verify that `validateExcludedClass` handles all possible cases correctly.

# Test: Search for the `validateExcludedClass` function definition. Expect: The function should handle all possible cases correctly.
rg 'validateExcludedClass' --type r -A 10

Length of output: 10976

tests/testthat/test-ensembleMethods.R (10)

Line range hint 13-20: LGTM!

The test correctly verifies the variable importance functionality for ensemble models.


Line range hint 22-38: LGTM!

The test correctly verifies the varImp functionality for caretEnsemble objects with and without scaling.


Line range hint 40-47: LGTM!

The test correctly verifies the dimensions of the variable importance output for ensemble models.


Line range hint 50-84: LGTM!

The test correctly verifies the usage of the correct metric in ensemble models.


Line range hint 110-143: LGTM!

The test correctly verifies that no errors are thrown by generic functions for ensemble models.


Line range hint 146-163: LGTM!

The test correctly verifies that the model results in caretEnsemble match the component models for classification.


Line range hint 165-180: LGTM!

The test correctly verifies that the model results in caretEnsemble match the component models for regression.


186-208: LGTM!

The test correctly verifies that prediction options are respected in regression.


210-231: LGTM!

The test correctly verifies that prediction options are respected in classification.


Line range hint 233-237: LGTM!

The test correctly verifies that caretList and caretStack work for multiclass problems.

vignettes/caretEnsemble-intro.Rmd (3)

Line range hint 1-74: LGTM!

The section correctly introduces caretList and demonstrates its usage with comprehensive examples.


Line range hint 75-129: LGTM!

The section correctly introduces caretEnsemble and demonstrates its usage with comprehensive examples.


Line range hint 130-182: LGTM!

The section correctly introduces caretStack and demonstrates its usage with comprehensive examples.

tests/testthat/test-multiclass.R (10)

Line range hint 3-30: LGTM!

The test correctly verifies that caretList and caretStack can predict multiclass problems.


Line range hint 32-69: LGTM!

The test correctly verifies that the columns for caretList predictions are correct and ordered.


Line range hint 71-105: LGTM!

The test correctly verifies that the columns for caretStack predictions are correct.


Line range hint 107-155: LGTM!

The test correctly verifies that periods are supported in method and class names in caretList and caretStack.


Line range hint 157-193: LGTM!

The test correctly verifies that a confusion matrix can be made from ensemble models.


Line range hint 195-212: LGTM!

The test correctly verifies that multiclass is not supported for caretEnsemble.


Line range hint 214-252: LGTM!

The test correctly verifies that caretList and caretStack can handle imbalanced multiclass data.


Line range hint 254-282: LGTM!

The test correctly verifies that caretList and caretStack can handle a large number of classes.


Line range hint 284-314: LGTM!

The test correctly verifies that caretList and caretStack can handle ordinal multiclass data.


Line range hint 316-342: LGTM!

The test correctly verifies that caretList and caretStack produce consistent probability predictions.

tests/testthat/test-caretEnsemble.R (13)

84-85: Check the expected output type.

The expected output type for predictions has been updated to data.table. Ensure that the predict function returns the correct type.


157-159: Ensure proper handling of savePredictions.

The trainControl function now includes the savePredictions parameter. Ensure that this parameter is correctly handled in the caretEnsemble function.


168-170: Check the expected output type.

The expected output type for predictions has been updated to data.table. Ensure that the predict function returns the correct type.


186-193: Ensure proper handling of classProbs.

The trainControl function now includes the classProbs parameter. Ensure that this parameter is correctly handled in the caretEnsemble function.


203-205: Check the expected output type.

The expected output type for predictions has been updated to data.table. Ensure that the predict function returns the correct type.


249-253: Ensure proper handling of classProbs and savePredictions.

The trainControl function now includes the classProbs and savePredictions parameters. Ensure that these parameters are correctly handled in the caretEnsemble function.


260-262: Check the expected output type.

The expected output type for predictions has been updated to data.table. Ensure that the predict function returns the correct type.


Line range hint 314-317: Check the expected output type.

The expected output type for predictions has been updated to data.table. Ensure that the predict function returns the correct type.


Line range hint 368-370: Check the expected output type.

The expected output type for predictions has been updated to data.table. Ensure that the predict function returns the correct type.


74-75: Check the expected output type.

The expected output type for predictions has been updated to data.table. Ensure that the predict function returns the correct type.

Verification successful

The predict function returns a data.table as expected.

The search results confirm that the predict function within the caretEnsemble package is designed to return a data.table. This is evident from the various test cases and documentation snippets that explicitly mention the return type.

  • File: tests/testthat/test-caretEnsemble.R

    • Multiple instances confirm that the predict function returns a data.table.
  • File: tests/testthat/test-caretStack.R

    • Several test cases validate that predict returns a data.table.
  • File: tests/testthat/test-caretList.R

    • The predict function is confirmed to return a data.table.

Thus, the expected output type for predictions has been correctly updated to data.table, and the predict function adheres to this requirement.

Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the return type of the `predict` function in the `caretEnsemble` function.

# Test: Search for the return type of the `predict` function. Expect: `data.table`.
rg --type r 'predict' R/caretEnsemble.R -A 5

Length of output: 2301


Script:

#!/bin/bash
# Searching for the definition of the predict function and its return type
ast-grep --lang r --pattern 'predict($_, $$$) {
  $$$
}'

Length of output: 161


Script:

#!/bin/bash
# Searching for the definition of the predict function and its return type
rg 'predict' -A 10

Length of output: 171730


Line range hint 346-356: Ensure proper handling of savePredictions and classProbs.

The trainControl function now includes the savePredictions and classProbs parameters. Ensure that these parameters are correctly handled in the caretStack function.


65-65: Ensure proper handling of savePredictions.

The trainControl function now includes the savePredictions parameter. Ensure that this parameter is correctly handled in the caretEnsemble function.


76-80: Ensure proper handling of classProbs.

The trainControl function now includes the classProbs parameter. Ensure that this parameter is correctly handled in the caretEnsemble function.

tests/testthat/test-caretStack.R (8)

29-30: Check the expected output type.

The expected output type for predictions has been updated to data.table. Ensure that the predict function returns the correct type.


43-48: Check the expected output type.

The expected output type for predictions has been updated to data.table. Ensure that the predict function returns the correct type.


77-79: Check the expected output type.

The expected output type for predictions has been updated to data.table. Ensure that the predict function returns the correct type.


88-89: Check the expected output type.

The expected output type for predictions has been updated to data.table. Ensure that the predict function returns the correct type.


346-356: Ensure proper handling of savePredictions and classProbs.

The trainControl function now includes the savePredictions and classProbs parameters. Ensure that these parameters are correctly handled in the caretStack function.


368-370: Check the expected output type.

The expected output type for predictions has been updated to data.table. Ensure that the predict function returns the correct type.


222-222: Check the expected output type.

The expected output type for predictions has been updated to data.table. Ensure that the predict function returns the correct type.


19-23: Ensure proper handling of savePredictions.

The trainControl function now includes the savePredictions parameter. Ensure that this parameter is correctly handled in the caretStack function.

Verification successful

Ensure proper handling of savePredictions.

The trainControl function now includes the savePredictions parameter. This parameter is indirectly handled by the train function within the caretStack function. Ensure that the train function appropriately manages this parameter.

  • caretStack function in R/caretStack.R
Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the handling of `savePredictions` parameter in the `caretStack` function.

# Test: Search for the usage of `savePredictions` in `caretStack` function. Expect: Proper handling of the parameter.
rg --type r 'savePredictions' R/caretStack.R

Length of output: 44


Script:

#!/bin/bash
# Description: Trace the usage of `trainControl` and its parameters within the `caretStack` implementation.

# Step 1: Search for the `trainControl` function usage in the R/ directory.
rg 'trainControl' R/

# Step 2: Search for the `caretStack` function implementation to understand how it handles the control parameters.
rg 'caretStack' R/

Length of output: 4126


Script:

#!/bin/bash
# Description: Extract the `caretStack` function implementation from `R/caretStack.R` to verify the handling of `savePredictions` parameter.

# Extract the `caretStack` function implementation.
ast-grep --lang r --pattern $'caretStack <- function($_, $$$) {
  $$$
}' R/caretStack.R

Length of output: 191


Script:

#!/bin/bash
# Description: Extract the `caretStack` function implementation from `R/caretStack.R` to verify the handling of `savePredictions` parameter.

# Extract the `caretStack` function implementation.
rg -A 50 'caretStack <- function' R/caretStack.R

Length of output: 6004

R/caretStack.R (5)

47-50: Ensure proper validation of new_X and new_y.

The function includes validation checks to ensure that both new_X and new_y are either both provided or both NULL. Ensure that these validation checks are correctly implemented.


52-57: Validate the dimensions of new_X and new_y.

The function validates the dimensions of new_X and new_y to ensure they align correctly for predictions. Ensure that these validation checks are correctly implemented.


101-111: Review the wtd.sd function.

The function calculates a weighted standard deviation. Ensure


41-45: Ensure proper handling of new_X and new_y.

The function now includes the new_X and new_y parameters. Ensure that these parameters are correctly handled and validated.

Verification successful

The handling of new_X and new_y parameters in the caretStack function is correct.

The function includes appropriate checks to ensure that both new_X and new_y are either provided together or not at all. It also validates the data types and dimensions of these parameters to ensure they are suitable for the prediction logic.

  • The function checks that both new_X and new_y are either NULL or non-NULL together.
  • It validates that new_X is either a data frame or matrix, and new_y is numeric, factor, or character.
  • It ensures that the number of rows in new_X matches the length of new_y.
Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the handling of `new_X` and `new_y` parameters in the `caretStack` function.

# Test: Search for the usage of `new_X` and `new_y` in `caretStack` function. Expect: Proper handling and validation of the parameters.
rg --type r 'new_X|new_y' R/caretStack.R

Length of output: 1042


62-71: Update the prediction logic.

The function updates the prediction logic to handle the new parameters. Ensure that the prediction logic is correctly implemented and that predictions are made as expected.

Verification successful

Verification successful.

The prediction logic in the caretStack function correctly handles the new parameters new_X and excluded_class_id and ensures that predictions are made as expected.

  • The predict.caretList function handles the newdata and excluded_class_id parameters properly.
  • The caretPredict function further processes these parameters to ensure accurate predictions.
Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the updated prediction logic in the `caretStack` function.

# Test: Search for the prediction logic in the `caretStack` function. Expect: Proper handling of the new parameters and correct predictions.
rg --type r 'predict' R/caretStack.R -A 10

Length of output: 6572


Script:

#!/bin/bash
# Search for the implementation of the predict.caretList function to verify the handling of new parameters
ast-grep --lang r --pattern 'predict.caretList($_$)' 

Length of output: 157


Script:

#!/bin/bash
# Search for the implementation of the predict.caretList function to verify the handling of new parameters
rg 'predict.caretList' --type r -A 10

Length of output: 4504


Script:

#!/bin/bash
# Search for the implementation of the caretPredict function to verify the handling of new parameters
rg 'caretPredict' --type r -A 10

Length of output: 3496

R/caretList.R (3)

44-58: LGTM! The function is well-structured.

The checkCustomModel function effectively validates the presence of the "method" attribute in custom caret model info lists.


295-295: LGTM! The method signature update is appropriate.

The as.caretList function now correctly accepts an object argument, ensuring proper method dispatch.


364-402: LGTM! The function logic is improved.

The updated predict.caretList function simplifies the prediction process and ensures consistency in the format and row counts of predictions.

However, ensure that all function calls to predict.caretList match the new signature.

R/caretEnsemble.R (5)

1-16: LGTM! The function is well-structured.

The check_binary_classification function effectively validates that a list of caret models is suitable for binary classification.


150-155: LGTM! The normalization by class is a valuable enhancement.

The updated varImpDataTable function provides more detailed insights into the importance of variables for each class.


177-185: LGTM! The aggregation by class is a valuable enhancement.

The updated varImp.caretEnsemble function ensures correct aggregation of importance values by class, enhancing the granularity of the importance metrics.


251-269: LGTM! The additional checks and processing enhance robustness.

The updated extractPredObsResid function includes additional checks and more explicit processing for classification tasks, enhancing the robustness of the function.


Line range hint 341-384:
LGTM! The adjustments ensure correct data merging and warning suppression.

The updated autoplot.caretEnsemble function ensures that data is merged correctly based on rowIndex and that warnings are appropriately suppressed during plotting.

tests/testthat/test-helper_functions.R (6)

Line range hint 32-52:
LGTM! The test case for extractModelType is comprehensive.

The test case effectively validates model types across different scenarios, ensuring that the function correctly identifies model types.


55-59: LGTM! The test case for the stacked predictions matrix is robust.

The test case ensures that the prediction results are of class data.table and contain the expected number of columns corresponding to the models being tested.


73-89: LGTM! The test case for predict is comprehensive.

The test case effectively validates the behavior of the predict function across different scenarios, ensuring consistent results regardless of the verbose option.


Line range hint 135-190:
LGTM! The test case for validateExcludedClass is thorough.

The test case ensures that the validateExcludedClass function properly handles various input scenarios, including warnings and errors for incorrect types and values.


172-183: LGTM! The test case for extractModelType edge cases is comprehensive.

The test case ensures that the extractModelType function correctly validates model types, raising appropriate errors for invalid models.


342-364: LGTM! The test case for stacked predictions is robust.

The test case ensures that the function correctly handles different resampling strategies and row indexes, producing consistent results.

tests/testthat/test-caretList.R (3)

Line range hint 141-158:
Verify the new argument excluded_class_id.

The predict function calls now include an additional argument excluded_class_id = 0L. Ensure that this argument is correctly handled within the predict function and that it does not introduce any unintended side effects.

Verification successful

The excluded_class_id argument is correctly handled within the predict function.

The predict function and its associated helper functions correctly process the excluded_class_id argument, ensuring that the specified class is excluded from predictions as intended.

  • R/caretStack.R: The predict.caretStack function handles the excluded_class_id argument.
  • R/caretPredict.R: The caretPredict function processes the excluded_class_id argument to exclude the specified class from predictions.
Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the handling of the new argument `excluded_class_id` in the `predict` function.

# Test: Search for the function definition and usage. Expect: Proper handling of `excluded_class_id`.
rg --type r -A 5 $'predict'

Length of output: 73154


Line range hint 141-158:
Ensure comprehensive test coverage for the new functionality.

The tests have been updated to verify the dimensions and names of the prediction outputs. Ensure that the tests cover all edge cases and scenarios for the new functionality introduced by the excluded_class_id argument.


694-703: LGTM! Verify the handling of NULL training data.

The test case for handling missing training data has been updated to verify that the predict.caretList function can handle models with NULL training data without throwing an error. This change improves robustness.

However, ensure that the predict.caretList function handles NULL training data correctly.

Verification successful

LGTM! Verify the handling of NULL training data.

The test case for handling missing training data has been updated to verify that the predict.caretList function can handle models with NULL training data without throwing an error. This change improves robustness.

The predict.caretList function and the underlying caretPredict function do not directly reference object$trainingData, ensuring that they handle NULL training data correctly.

Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the handling of `NULL` training data in the `predict.caretList` function.

# Test: Search for the function definition and usage. Expect: Proper handling of `NULL` training data.
rg --type r -A 5 $'predict.caretList'

Length of output: 2943


Script:

#!/bin/bash
# Extract the full implementation of the predict.caretList function from R/caretList.R
ast-grep --lang r --pattern $'predict.caretList <- function($_, $_, $_, $_, $$$) {
  $$$
}' 

Length of output: 195


Script:

#!/bin/bash
# Extract the full implementation of the predict.caretList function from R/caretList.R using rg
rg --type r -A 50 'predict.caretList <- function' R/caretList.R

Length of output: 1633


Script:

#!/bin/bash
# Extract the implementation of the caretPredict function to verify handling of NULL training data
rg --type r -A 50 'caretPredict' R/

Length of output: 6882

Comment on lines +7 to +49
validateExcludedClass <- function(arg) {
# Handle the null case (usually old object where the missing level was not defined)
if (is.null(arg)) {
arg <- 1L
warning("No excluded_class_id set. Setting to 1L.")
}
# Check the input
if (!is.numeric(arg)) {
stop(paste0(
"classification excluded level must be numeric: ", arg
))
}
if (length(arg) != 1L) {
stop(paste0(
"classification excluded level must have a length of 1: length=", length(arg)
))
}

# Convert to integer if possible
if (is.integer(arg)) {
out <- arg
} else {
warning(paste0("classification excluded level is not an integer: ", arg))
if (is.numeric(arg)) {
out <- floor(arg)
}
suppressWarnings(out <- as.integer(out))
}

# Check the output
if (!is.finite(out)) {
stop(paste0(
"classification excluded level must be finite: ", arg
))
}
if (out < 0L) {
stop(paste0(
"classification excluded level must be >= 0: ", arg
))
}

out
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handle non-integer numeric values correctly.

The function converts non-integer numeric values to integers using floor. This may lead to unexpected behavior if the user provides a non-integer value. Consider explicitly checking for integer values and providing a clear error message if the value is not an integer.

-    if (is.numeric(arg)) {
-      out <- floor(arg)
-    }
+    if (!is.integer(arg)) {
+      stop("classification excluded level must be an integer.")
+    }

Committable suggestion was skipped due to low confidence.

@zachmayer zachmayer merged commit 6ff11e0 into main Jul 29, 2024
9 checks passed
@zachmayer zachmayer deleted the pred branch July 29, 2024 16:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant