Skip to content

Commit

Permalink
[R-package] prefer params to keyword argument in lgb.train() (#5007)
Browse files Browse the repository at this point in the history
* [R-package] prefer params to keyword argument in lgb.train()

* make test stricter

* Apply suggestions from code review

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
  • Loading branch information
jameslamb and StrikerRUS committed Feb 18, 2022
1 parent cb8c61e commit 9e73cee
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 14 deletions.
4 changes: 2 additions & 2 deletions R-package/R/lgb.cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ lgb.cv <- function(params = list()
params <- lgb.check.wrapper_param(
main_param_name = "objective"
, params = params
, alternative_kwarg_value = NULL
, alternative_kwarg_value = obj
)
params <- lgb.check.wrapper_param(
main_param_name = "early_stopping_round"
Expand All @@ -137,7 +137,7 @@ lgb.cv <- function(params = list()
early_stopping_rounds <- params[["early_stopping_round"]]

# extract any function objects passed for objective or metric
params <- lgb.check.obj(params = params, obj = obj)
params <- lgb.check.obj(params = params)
fobj <- NULL
if (is.function(params$objective)) {
fobj <- params$objective
Expand Down
4 changes: 2 additions & 2 deletions R-package/R/lgb.train.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ lgb.train <- function(params = list(),
params <- lgb.check.wrapper_param(
main_param_name = "objective"
, params = params
, alternative_kwarg_value = NULL
, alternative_kwarg_value = obj
)
params <- lgb.check.wrapper_param(
main_param_name = "early_stopping_round"
Expand All @@ -105,7 +105,7 @@ lgb.train <- function(params = list(),
early_stopping_rounds <- params[["early_stopping_round"]]

# extract any function objects passed for objective or metric
params <- lgb.check.obj(params = params, obj = obj)
params <- lgb.check.obj(params = params)
fobj <- NULL
if (is.function(params$objective)) {
fobj <- params$objective
Expand Down
13 changes: 3 additions & 10 deletions R-package/R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ lgb.check_interaction_constraints <- function(interaction_constraints, column_na

}

lgb.check.obj <- function(params, obj) {
lgb.check.obj <- function(params) {

# List known objectives in a vector
OBJECTIVES <- c(
Expand Down Expand Up @@ -158,25 +158,18 @@ lgb.check.obj <- function(params, obj) {
, "xendcg_mart"
)

# Check whether the objective is empty or not, and take it from params if needed
if (!is.null(obj)) {
params$objective <- obj
if (is.null(params$objective)) {
stop("lgb.check.obj: objective should be a character or a function")
}

# Check whether the objective is a character
if (is.character(params$objective)) {

# If the objective is a character, check if it is a known objective
if (!(params$objective %in% OBJECTIVES)) {

stop("lgb.check.obj: objective name error should be one of (", paste0(OBJECTIVES, collapse = ", "), ")")

}

} else if (!is.function(params$objective)) {

stop("lgb.check.obj: objective should be a character or a function")

}

return(params)
Expand Down
53 changes: 53 additions & 0 deletions R-package/tests/testthat/test_basic.R
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,34 @@ test_that("lgb.cv() respects parameter aliases for objective", {
expect_length(cv_bst$boosters, nfold)
})

test_that("lgb.cv() prefers objective in params to keyword argument", {
data("EuStockMarkets")
cv_bst <- lgb.cv(
data = lgb.Dataset(
data = EuStockMarkets[, c("SMI", "CAC", "FTSE")]
, label = EuStockMarkets[, "DAX"]
)
, params = list(
application = "regression_l1"
, verbosity = VERBOSITY
)
, nrounds = 5L
, obj = "regression_l2"
)
for (bst_list in cv_bst$boosters) {
bst <- bst_list[["booster"]]
expect_equal(bst$params$objective, "regression_l1")
# NOTE: using save_model_to_string() since that is the simplest public API in the R package
# allowing access to the "objective" attribute of the Booster object on the C++ side
model_txt_lines <- strsplit(
x = bst$save_model_to_string()
, split = "\n"
)[[1L]]
expect_true(any(model_txt_lines == "objective=regression_l1"))
expect_false(any(model_txt_lines == "objective=regression_l2"))
}
})

test_that("lgb.cv() respects parameter aliases for metric", {
nrounds <- 3L
nfold <- 4L
Expand Down Expand Up @@ -657,6 +685,31 @@ test_that("lgb.train() respects parameter aliases for objective", {
expect_equal(bst$params[["objective"]], "binary")
})

test_that("lgb.train() prefers objective in params to keyword argument", {
data("EuStockMarkets")
bst <- lgb.train(
data = lgb.Dataset(
data = EuStockMarkets[, c("SMI", "CAC", "FTSE")]
, label = EuStockMarkets[, "DAX"]
)
, params = list(
loss = "regression_l1"
, verbosity = VERBOSITY
)
, nrounds = 5L
, obj = "regression_l2"
)
expect_equal(bst$params$objective, "regression_l1")
# NOTE: using save_model_to_string() since that is the simplest public API in the R package
# allowing access to the "objective" attribute of the Booster object on the C++ side
model_txt_lines <- strsplit(
x = bst$save_model_to_string()
, split = "\n"
)[[1L]]
expect_true(any(model_txt_lines == "objective=regression_l1"))
expect_false(any(model_txt_lines == "objective=regression_l2"))
})

test_that("lgb.train() respects parameter aliases for metric", {
nrounds <- 3L
dtrain <- lgb.Dataset(
Expand Down

0 comments on commit 9e73cee

Please sign in to comment.