Skip to content

Commit

Permalink
[R] Fix multiclass demo (#1940)
Browse files Browse the repository at this point in the history
* Fix multiclass custom objective demo

* Use option not to boost from average instead of setting init score explicitly

* Reference #1846 when turning off boost_from_average

* Add trailing whitespace
  • Loading branch information
maximilianeber authored and Laurae2 committed Jan 19, 2019
1 parent e6a32c8 commit ace9c99
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions R-package/demo/multiclass_custom_objective.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,21 @@ data(iris)
# For instance: 0, 1, 2, 3, 4, 5...
iris$Species <- as.numeric(as.factor(iris$Species)) - 1

# We cut the data set into 80% train and 20% validation
# Create imbalanced training data (20, 30, 40 examples for classes 0, 1, 2)
train <- as.matrix(iris[c(1:20, 51:80, 101:140), ])
# The last 10 samples of each class are for validation

train <- as.matrix(iris[c(1:40, 51:90, 101:140), ])
test <- as.matrix(iris[c(41:50, 91:100, 141:150), ])

dtrain <- lgb.Dataset(data = train[, 1:4], label = train[, 5])
dtest <- lgb.Dataset.create.valid(dtrain, data = test[, 1:4], label = test[, 5])
valids <- list(test = dtest)
valids <- list(train = dtrain, test = dtest)

# Method 1 of training with built-in multiclass objective
# Note: boost_from_average must be turned off to match the custom objective
# (https://github.com/Microsoft/LightGBM/issues/1846)
model_builtin <- lgb.train(list(),
dtrain,
boost_from_average = FALSE,
100,
valids,
min_data = 1,
Expand All @@ -29,7 +32,8 @@ model_builtin <- lgb.train(list(),
metric = "multi_logloss",
num_class = 3)

preds_builtin <- predict(model_builtin, test[, 1:4], rawscore = TRUE)
preds_builtin <- predict(model_builtin, test[, 1:4], rawscore = TRUE, reshape = TRUE)
probs_builtin <- exp(preds_builtin) / rowSums(exp(preds_builtin))

# Method 2 of training with custom objective function

Expand Down Expand Up @@ -64,7 +68,6 @@ custom_multiclass_metric = function(preds, dtrain) {
return(list(name = "error",
value = -mean(log(prob[cbind(1:length(labels), labels + 1)])),
higher_better = FALSE))

}

model_custom <- lgb.train(list(),
Expand All @@ -78,8 +81,10 @@ model_custom <- lgb.train(list(),
eval = custom_multiclass_metric,
num_class = 3)

preds_custom <- predict(model_custom, test[, 1:4], rawscore = TRUE)
preds_custom <- predict(model_custom, test[, 1:4], rawscore = TRUE, reshape = TRUE)
probs_custom <- exp(preds_custom) / rowSums(exp(preds_custom))

# compare predictions
identical(preds_builtin, preds_custom)
stopifnot(identical(probs_builtin, probs_custom))
stopifnot(identical(preds_builtin, preds_custom))

0 comments on commit ace9c99

Please sign in to comment.