From 5fa615bc48f5cbad741aaade68761a4f98c60f75 Mon Sep 17 00:00:00 2001 From: Serkan Korkmaz <77464572+serkor1@users.noreply.github.com> Date: Wed, 21 Aug 2024 22:34:15 +0200 Subject: [PATCH] [R-package] only warn about early stopping and DART boosting being incompatible if early stopping was requested (#6619) --- R-package/R/lgb.cv.R | 4 +- R-package/R/lgb.train.R | 4 +- R-package/tests/testthat/test_parameters.R | 61 ++++++++++++++++++++-- 3 files changed, 64 insertions(+), 5 deletions(-) diff --git a/R-package/R/lgb.cv.R b/R-package/R/lgb.cv.R index c22d0ea848bb..638d1c628e12 100644 --- a/R-package/R/lgb.cv.R +++ b/R-package/R/lgb.cv.R @@ -295,7 +295,9 @@ lgb.cv <- function(params = list() # Cannot use early stopping with 'dart' boosting if (using_dart) { - warning("Early stopping is not available in 'dart' mode.") + if (using_early_stopping) { + warning("Early stopping is not available in 'dart' mode.") + } using_early_stopping <- FALSE # Remove the cb_early_stop() function if it was passed in to callbacks diff --git a/R-package/R/lgb.train.R b/R-package/R/lgb.train.R index dafb4d83b66b..4d994cfc6f04 100644 --- a/R-package/R/lgb.train.R +++ b/R-package/R/lgb.train.R @@ -258,7 +258,9 @@ lgb.train <- function(params = list(), # Cannot use early stopping with 'dart' boosting if (using_dart) { - warning("Early stopping is not available in 'dart' mode.") + if (using_early_stopping) { + warning("Early stopping is not available in 'dart' mode.") + } using_early_stopping <- FALSE # Remove the cb_early_stop() function if it was passed in to callbacks diff --git a/R-package/tests/testthat/test_parameters.R b/R-package/tests/testthat/test_parameters.R index 367f01af817c..9949ffe646b9 100644 --- a/R-package/tests/testthat/test_parameters.R +++ b/R-package/tests/testthat/test_parameters.R @@ -91,7 +91,7 @@ test_that(".PARAMETER_ALIASES() uses the internal session cache", { expect_false(exists(cache_key, where = .lgb_session_cache_env)) }) -test_that("training should warn if you use 'dart' boosting, specified with 'boosting' or aliases", { +test_that("training should warn if you use 'dart' boosting with early stopping", { for (boosting_param in .PARAMETER_ALIASES()[["boosting"]]) { params <- list( num_leaves = 5L @@ -101,14 +101,69 @@ test_that("training should warn if you use 'dart' boosting, specified with 'boos , num_threads = .LGB_MAX_THREADS ) params[[boosting_param]] <- "dart" + + # warning: early stopping requested expect_warning({ result <- lightgbm( data = train$data , label = train$label , params = params - , nrounds = 5L - , verbose = -1L + , nrounds = 2L + , verbose = .LGB_VERBOSITY + , early_stopping_rounds = 1L + ) + }, regexp = "Early stopping is not available in 'dart' mode") + + # no warning: early stopping not requested + expect_silent({ + result <- lightgbm( + data = train$data + , label = train$label + , params = params + , nrounds = 2L + , verbose = .LGB_VERBOSITY + , early_stopping_rounds = NULL + ) + }) + } +}) + +test_that("lgb.cv() should warn if you use 'dart' boosting with early stopping", { + for (boosting_param in .PARAMETER_ALIASES()[["boosting"]]) { + params <- list( + num_leaves = 5L + , objective = "binary" + , metric = "binary_error" + , num_threads = .LGB_MAX_THREADS + ) + params[[boosting_param]] <- "dart" + + # warning: early stopping requested + expect_warning({ + result <- lgb.cv( + data = lgb.Dataset( + data = train$data + , label = train$label + ) + , params = params + , nrounds = 2L + , verbose = .LGB_VERBOSITY + , early_stopping_rounds = 1L ) }, regexp = "Early stopping is not available in 'dart' mode") + + # no warning: early stopping not requested + expect_silent({ + result <- lgb.cv( + data = lgb.Dataset( + data = train$data + , label = train$label + ) + , params = params + , nrounds = 2L + , verbose = .LGB_VERBOSITY + , early_stopping_rounds = NULL + ) + }) } })