Not able to train with dart and early_stopping_rounds #1893

Closed
RickoClausen opened this issue Dec 6, 2018 · 16 comments


@RickoClausen

RickoClausen commented Dec 6, 2018

Environment info

Operating System: MacOS Mojave and Ubuntu

CPU/GPU model: CPU

C++/Python/R version: Python

Error message

When using the dart boosting type together with early_stopping_rounds, the returned model does not correspond to the best iteration:
the RMSE computed after training is not the same as the RMSE reported at the stopping point.

When I use gbdt, the model trains fine, and I am able to reproduce the RMSE reported during training.

Reproducible examples

import lightgbm
import numpy as np

np.random.seed(1234)

params = {
    "early_stopping_rounds": 100,
    "metric": "root_mean_squared_error",
    "objective": "regression",
    "num_boost_round": 1000,
    "boosting_type": "dart",
}

size = (245688, 470)
x = np.random.exponential(scale=10, size=size)
y = 2 * x[:, 0] + np.random.exponential(scale=2, size=(size[0],))

x_val = np.random.exponential(scale=10, size=(int(size[0] / 13), size[1]))
y_val = 2 * x_val[:, 0] + np.random.exponential(scale=2, size=(int(size[0] / 13),))

model = lightgbm.LGBMModel(**params)

model.fit(x, y, eval_set=[(x, y), (x_val, y_val)], verbose=50)
train_pred = model.predict(x)
rmse = np.sqrt(np.mean((y - train_pred) ** 2))
print(f"Train rmse: {rmse}")

Output:

UserWarning: Starting from version 2.2.1, the library file in distribution wheels for macOS is built by the Apple Clang (Xcode_8.3.1) compiler.
This means that in case of installing LightGBM from PyPI via the ``pip install lightgbm`` command, you don't need to install the gcc compiler anymore.
Instead of that, you need to install the OpenMP library, which is required for running LightGBM on the system with the Apple Clang compiler.
You can install the OpenMP library by the following command: ``brew install libomp``.
  "You can install the OpenMP library by the following command: ``brew install libomp``.", UserWarning)
UserWarning: Found `num_boost_round` in params. Will use it instead of argument
  warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
UserWarning: Found `early_stopping_rounds` in params. Will use it instead of argument
  warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
Training until validation scores don't improve for 100 rounds.
[50]    valid_0's rmse: 3.86284 valid_1's rmse: 4.04269
[100]   valid_0's rmse: 3.64912 valid_1's rmse: 3.88063
Early stopping, best iteration is:
[34]    valid_0's rmse: 2.60659 valid_1's rmse: 2.88739
Train rmse: 16.60181744687661

Thanks for an amazing product! 👍

@guolinke
Collaborator

guolinke commented Dec 6, 2018

I think early stopping and dart cannot be used together.
The reason is that dart updates the previously built trees.
For example, in your case, although iteration 34 is the best, the trees built up to that point are changed in later iterations, because dart keeps updating previous trees.
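A toy illustration of why the trees behind iteration 34 change afterwards. This sketch assumes the normalization described in the original DART paper: when k existing trees are dropped, each dropped tree is rescaled by k/(k+1) and the new tree enters with weight 1/(k+1). LightGBM's actual implementation also folds in learning_rate, and the helper name `dart_tree_weights` and the drop schedule are hypothetical.

```python
from typing import List, Sequence


def dart_tree_weights(drop_schedule: Sequence[Sequence[int]]) -> List[List[float]]:
    """Toy DART weight tracker (hypothetical helper, paper-style normalization).

    drop_schedule[t] lists the indices of already existing trees that are
    dropped at iteration t. Returns a copy of every tree weight after each
    iteration, so earlier snapshots can be compared with later ones.
    """
    weights: List[float] = []
    history: List[List[float]] = []
    for t, dropped in enumerate(drop_schedule):
        dropped_set = set(dropped)
        if any(i < 0 or i >= t for i in dropped_set):
            raise ValueError(f"iteration {t} can only drop one of the {t} existing trees")
        k = len(dropped_set)
        # Dropped trees are rescaled in place by k / (k + 1) ...
        for i in dropped_set:
            weights[i] *= k / (k + 1)
        # ... and the new tree is added with weight 1 / (k + 1).
        weights.append(1.0 / (k + 1))
        history.append(list(weights))
    return history


history = dart_tree_weights([[], [0], [0, 1]])
for t, w in enumerate(history, start=1):
    print(f"after iteration {t}: {w}")
```

In this schedule, tree 0 has weight 0.5 after iteration 2 but 1/3 after iteration 3. The first two trees of the final model are therefore not the model that was evaluated at iteration 2, which is why truncating a DART model to its best iteration cannot reproduce the score logged at that iteration.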

To support this, a simple solution is to clone the model at the best iteration when it is reached, so that later updates do not affect it.
You can write a callback function to achieve this.
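A minimal sketch of such a callback, assuming the callback interface of `lightgbm.train()`: each call receives an `env` whose `model` is the current `Booster`, whose `iteration` is 0-based, and whose `evaluation_result_list` holds `(dataset_name, metric_name, value, is_higher_better)` tuples. The class name `BestModelSnapshot` is hypothetical. On every improvement it serializes the model with `Booster.model_to_string()`, which copies the tree values as they are at that iteration, so later DART updates cannot touch the copy.

```python
from typing import Any, Optional


class BestModelSnapshot:
    """Hypothetical callback that keeps a frozen copy of the model at the best iteration.

    Assumes `env` exposes `.model`, `.iteration` (0-based) and
    `.evaluation_result_list` with 4-tuples, as in lightgbm.train().
    """

    def __init__(self, dataset_name: str, metric_name: str) -> None:
        self.dataset_name = dataset_name
        self.metric_name = metric_name
        self.best_score: Optional[float] = None
        self.best_iteration: Optional[int] = None
        self.best_model_str: Optional[str] = None

    def __call__(self, env: Any) -> None:
        for data_name, eval_name, value, higher_better in env.evaluation_result_list:
            # Only track the one dataset/metric pair we were asked to watch.
            if data_name != self.dataset_name or eval_name != self.metric_name:
                continue
            improved = self.best_score is None or (
                value > self.best_score if higher_better else value < self.best_score
            )
            if improved:
                self.best_score = value
                # 1-based, matching the "[34]" style of the training log.
                self.best_iteration = env.iteration + 1
                # Serialize now: later DART iterations rescale these trees in place,
                # so the string freezes the trees as they are at this iteration.
                self.best_model_str = env.model.model_to_string()
```

Hypothetical usage: create `snap = BestModelSnapshot("val", "rmse")`, pass `callbacks=[snap]` to `lightgbm.train(..., valid_sets=[dval], valid_names=["val"])`, then rebuild the frozen model with `lightgbm.Booster(model_str=snap.best_model_str)`. The callback does not stop training by itself, and serializing the whole model on every improvement costs time and memory for large models.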

@StrikerRUS
Is it easy for us to support this on the python-package side? At least a warning is needed.

@bbennett36

The issue here is that he's trying to use the sklearn version of LightGBM that doesn't support early stopping (from my understanding).

I have used early stopping and dart with no issues for the past couple of months on multiple models. However, I do have to set the early stopping rounds higher than normal because there are cases where the validation score will rise, then drop, then start rising again. I have to use a higher learning rate as well so it doesn't take forever to run.

@RickoClausen
Author

I get the same "problem" when using the non-sklearn syntax.

@bbennett36

@RickoClausen do you have "boost_from_average" = False?

@StrikerRUS
Collaborator

@guolinke Should the same (#1895) be done for R-package and then this issue can be closed?

@guolinke
Collaborator

guolinke commented Feb 3, 2019

Yeah, this fix should go into the R package too.

@StrikerRUS
Collaborator

ping @Laurae2 and @jameslamb for the R fix

@Laurae2
Contributor

Laurae2 commented Mar 20, 2019

For R we can add a simple check here: https://github.com/Microsoft/LightGBM/blob/master/R-package/R/lgb.train.R#L205-L209

Will try to do it by the end of this week.
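For comparison, a Python analogue of such a pre-training check might look like the sketch below. The helper name `warn_if_dart_with_early_stopping` is hypothetical, and the alias lists are taken from LightGBM's parameter documentation as an assumption (verify them for your version). Because it only inspects the params dict, a check at this level cannot see an early-stopping callback that is passed in separately.

```python
import warnings
from typing import Any, Mapping

# Parameter names and aliases as listed in LightGBM's parameter documentation
# (assumption: verify them against the docs for the version you use).
BOOSTING_ALIASES = ("boosting", "boosting_type", "boost")
EARLY_STOPPING_ALIASES = (
    "early_stopping_round",
    "early_stopping_rounds",
    "early_stopping",
    "n_iter_no_change",
)


def warn_if_dart_with_early_stopping(params: Mapping[str, Any]) -> bool:
    """Hypothetical pre-training check: warn and return True if `params`
    combine boosting='dart' with a positive early-stopping patience."""
    uses_dart = any(
        str(params.get(name, "")).lower() == "dart" for name in BOOSTING_ALIASES
    )
    # A patience of 0 (the documented default) means early stopping is disabled.
    uses_early_stopping = any(
        params.get(name) is not None and int(params[name]) > 0
        for name in EARLY_STOPPING_ALIASES
    )
    if uses_dart and uses_early_stopping:
        warnings.warn(
            "early stopping with boosting='dart' is unreliable: later iterations "
            "rescale earlier trees, so the best iteration's model is not kept",
            UserWarning,
        )
        return True
    return False
```

Run on the params from the original report (`"boosting_type": "dart"` with `"early_stopping_rounds": 100`), this helper would emit the warning and return True.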

@StrikerRUS
Collaborator

@Laurae2 It seems that this check will not solve the issue when users create the early stopping callback themselves.

@StrikerRUS
Collaborator

@Laurae2

@guolinke
Collaborator

guolinke commented Aug 1, 2019

ping @Laurae2 @jameslamb for the R fix

@jameslamb
Collaborator

Thank you for the ping, will pick it up soon.

@StrikerRUS
Collaborator

Any updates?

@jameslamb
Collaborator


Sorry for the delay. Attempted a fix in #2443. There are parts of this section of the code that I'm not very familiar with, so let's see if @Laurae2 agrees with my proposal in the PR review.

@StrikerRUS
Collaborator

R-package should be fixed in #2443.

@jameslamb
Collaborator

Thanks @StrikerRUS. Sorry, I should have come back and closed this.

@lock lock bot locked as resolved and limited conversation to collaborators Mar 10, 2020