Skip to content

Commit

Permalink
[R-package] raise an informative error when custom objective produces…
Browse files Browse the repository at this point in the history
… incorrect output (fixes #5323) (#5329)
  • Loading branch information
jmoralez committed Jul 12, 2022
1 parent 273c9b0 commit 44fe591
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
16 changes: 14 additions & 2 deletions R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -231,21 +231,33 @@ Booster <- R6::R6Class(
private$set_objective_to_none <- TRUE
}
# Perform objective calculation
gpair <- fobj(private$inner_predict(1L), private$train_set)
preds <- private$inner_predict(1L)
gpair <- fobj(preds, private$train_set)

# Check for gradient and hessian as list
if (is.null(gpair$grad) || is.null(gpair$hess)) {
stop("lgb.Booster.update: custom objective should
return a list with attributes (hess, grad)")
}

# Check grad and hess have the right shape
n_grad <- length(gpair$grad)
n_hess <- length(gpair$hess)
n_preds <- length(preds)
if (n_grad != n_preds) {
stop(sprintf("Expected custom objective function to return grad with length %d, got %d.", n_preds, n_grad))
}
if (n_hess != n_preds) {
stop(sprintf("Expected custom objective function to return hess with length %d, got %d.", n_preds, n_hess))
}

# Return custom boosting gradient/hessian
.Call(
LGBM_BoosterUpdateOneIterCustom_R
, private$handle
, gpair$grad
, gpair$hess
, length(gpair$grad)
, n_preds
)

}
Expand Down
16 changes: 16 additions & 0 deletions R-package/tests/testthat/test_custom_objective.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,19 @@ test_that("using a custom objective, custom eval, and no other metrics works", {
expect_true(eval_results[["name"]] == "error")
expect_false(eval_results[["higher_better"]])
})

test_that("using a custom objective that returns wrong shape grad or hess raises an informative error", {
bad_grad <- function(preds, dtrain) {
return(list(grad = numeric(0L), hess = rep(1.0, length(preds))))
}
bad_hess <- function(preds, dtrain) {
return(list(grad = rep(1.0, length(preds)), hess = numeric(0L)))
}
params <- list(num_leaves = 3L, verbose = VERBOSITY)
expect_error({
lgb.train(params = params, data = dtrain, obj = bad_grad)
}, sprintf("Expected custom objective function to return grad with length %d, got 0.", nrow(dtrain)))
expect_error({
lgb.train(params = params, data = dtrain, obj = bad_hess)
}, sprintf("Expected custom objective function to return hess with length %d, got 0.", nrow(dtrain)))
})

0 comments on commit 44fe591

Please sign in to comment.