Skip to content

Commit

Permalink
[R-package] fixed sorting issues in lgb.cv when using a model with ob…
Browse files Browse the repository at this point in the history
…servation weights (fixes microsoft#2572)
  • Loading branch information
jameslamb committed Dec 2, 2019
1 parent 51ceef8 commit 251cb91
Showing 1 changed file with 50 additions and 29 deletions.
79 changes: 50 additions & 29 deletions R-package/R/lgb.cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ CVBooster <- R6::R6Class(
#' , learning_rate = 1.0
#' , early_stopping_rounds = 5L
#' )
#' @importFrom data.table data.table setorderv
#' @export
lgb.cv <- function(params = list()
, data
Expand All @@ -95,8 +96,7 @@ lgb.cv <- function(params = list()
) {

# Setup temporary variables
addiction_params <- list(...)
params <- append(params, addiction_params)
params <- append(params, list(...))
params$verbose <- verbose
params <- lgb.check.obj(params, obj)
params <- lgb.check.eval(params, eval)
Expand Down Expand Up @@ -267,35 +267,56 @@ lgb.cv <- function(params = list()
# Categorize callbacks
cb <- categorize.callbacks(callbacks)

# Construct booster using a list apply, check if requires group or not
if (!is.list(folds[[1L]])) {
bst_folds <- lapply(seq_along(folds), function(k) {
dtest <- slice(data, folds[[k]])
dtrain <- slice(data, seq_len(nrow(data))[-folds[[k]]])
setinfo(dtrain, "weight", getinfo(data, "weight")[-folds[[k]]])
setinfo(dtrain, "init_score", getinfo(data, "init_score")[-folds[[k]]])
setinfo(dtest, "weight", getinfo(data, "weight")[folds[[k]]])
setinfo(dtest, "init_score", getinfo(data, "init_score")[folds[[k]]])
booster <- Booster$new(params, dtrain)
booster$add_valid(dtest, "valid")
list(booster = booster)
})
} else {
bst_folds <- lapply(seq_along(folds), function(k) {
dtest <- slice(data, folds[[k]]$fold)
dtrain <- slice(data, (seq_len(nrow(data)))[-folds[[k]]$fold])
setinfo(dtrain, "weight", getinfo(data, "weight")[-folds[[k]]$fold])
setinfo(dtrain, "init_score", getinfo(data, "init_score")[-folds[[k]]$fold])
setinfo(dtrain, "group", getinfo(data, "group")[-folds[[k]]$group])
setinfo(dtest, "weight", getinfo(data, "weight")[folds[[k]]$fold])
setinfo(dtest, "init_score", getinfo(data, "init_score")[folds[[k]]$fold])
setinfo(dtest, "group", getinfo(data, "group")[folds[[k]]$group])
# Construct booster for each fold. The data.table() code below is used to
# guarantee that indices are sorted while keeping init_score and weight together
# with the correct indices. Note that it takes advantage of the fact that
# someDT$some_column returns NULL if 'some_column' does not exist in the data.table
bst_folds <- lapply(
  X = seq_along(folds)
  , FUN = function(k) {

    # A fold is either a plain vector of test-row indices or, when folds carry
    # group information, a list holding the test-row indices in `fold` and the
    # positions of the held-out groups in `group`.
    if (is.list(folds[[k]])) {
      test_indices <- folds[[k]]$fold
      test_group_indices <- folds[[k]]$group
    } else {
      test_indices <- folds[[k]]
      test_group_indices <- NULL
    }
    train_indices <- seq_len(nrow(data))[-test_indices]

    # Fetch the per-row metadata once; either may be NULL when it was never set
    # on `data`, in which case subsetting yields NULL and data.table() simply
    # omits the column.
    all_weights <- getinfo(data, "weight")
    all_init_scores <- getinfo(data, "init_score")

    # set up test set: rows, weights and init scores must all be taken at the
    # SAME (test) indices so they stay aligned after sorting
    testDT <- data.table::data.table(
      indices = test_indices
      , weight = all_weights[test_indices]
      , init_score = all_init_scores[test_indices]
    )
    data.table::setorderv(testDT, "indices", order = 1L)
    dtest <- slice(data, testDT$indices)
    setinfo(dtest, "weight", testDT$weight)
    setinfo(dtest, "init_score", testDT$init_score)

    # set up training set
    trainDT <- data.table::data.table(
      indices = train_indices
      , weight = all_weights[train_indices]
      , init_score = all_init_scores[train_indices]
    )
    data.table::setorderv(trainDT, "indices", order = 1L)
    dtrain <- slice(data, trainDT$indices)
    setinfo(dtrain, "weight", trainDT$weight)
    setinfo(dtrain, "init_score", trainDT$init_score)

    # `folds[[k]]$group` holds positions into the dataset's "group" vector, not
    # group sizes: the held-out groups go to the test set and the remaining
    # groups to the training set.
    if (!is.null(test_group_indices)) {
      all_groups <- getinfo(data, "group")
      setinfo(dtest, "group", all_groups[test_group_indices])
      setinfo(dtrain, "group", all_groups[-test_group_indices])
    }

    # Train on the fold's training rows, validate on its held-out rows
    booster <- Booster$new(params, dtrain)
    booster$add_valid(dtest, "valid")

    return(
      list(booster = booster)
    )
  }
)

# Create new booster
cv_booster <- CVBooster$new(bst_folds)
Expand Down

0 comments on commit 251cb91

Please sign in to comment.