
Weighted Custom Loss Function Different Training Loss #2834

Closed
amin-nejad opened this issue Feb 27, 2020 · 2 comments

amin-nejad commented Feb 27, 2020

I am experimenting with weighted custom loss functions for a binary classification problem, but I am having trouble replicating the exact behaviour of the default loss function used in conjunction with sample weights. If there are no sample weights, the custom loss function (the same as in the advanced example notebook) works as expected: the training/validation losses as well as the final predictions are exactly the same.

However, if I weight the samples (it doesn't matter which weights - any weights), the training loss with the custom function becomes slightly different from that of the default loss function. The difference is very minor, but I still do not understand why there is any difference at all. This looks like it might be a bug to me; it's not clear which value is right and which is wrong. Intriguingly, the final predictions are nevertheless still exactly the same.

I have looked at the C++ source code and I believe I have followed the same logic: when sample weights are provided, the gradient and the hessian are simply multiplied by the sample weights before being returned, and that is the only difference from the unweighted case. This is exactly what I'm doing.

Please let me know if my loss function is wrong or if this is a bug. Thanks

Environment info

Operating System: Ubuntu 18.04.4 LTS

CPU/GPU model: Intel Core i7-8565U CPU

C++/Python/R version: Python 3.7.6

LightGBM version or commit hash: 2.3.0

Error message

No error message.

Reproducible examples

import lightgbm as lgb
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# CREATE DATASET
np.random.seed(42)
X, y = make_classification(n_samples=10000, n_features=10)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1)
train = lgb.Dataset(X_train, y_train)
valid = lgb.Dataset(X_test, y_test)

# ASSIGN WEIGHTS (9000 training rows: first half weight 1.0, second half weight 2.0)
weights = np.concatenate((np.full(4500, 1.0), np.full(4500, 2.0)))
train.set_weight(weight=weights)

# DEFINE PARAMS
params = {
    'objective': 'binary',
    'n_estimators': 10,
    'boost_from_average': False,
    'random_state': 42
}

# DEFINE CUSTOM LOSS FUNCTION
def logloss(preds, data):
    y_true = data.get_label()
    preds = 1. / (1. + np.exp(-preds))
    weight = data.get_weight() if data.get_weight() is not None else 1
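    # mirror the C++ behaviour: scale grad and hess by the sample weights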
    grad = (preds - y_true) * weight
    hess = preds * (1.0 - preds) * weight
    return grad, hess

# DEFINE CUSTOM EVAL LOSS FUNCTION
def logloss_eval(preds, data):
    y_true = data.get_label()
    preds = 1. / (1. + np.exp(-preds))
    loss = -(y_true * np.log(preds)) - ((1 - y_true) * np.log(1 - preds))
    return "binary_logloss", np.mean(loss), False

# TRAIN MODEL_1
model_1 = lgb.train(params=params, train_set=train, valid_sets=[train, valid], valid_names=['train','valid'])
scores_1 = model_1.predict(X_test)

# TRAIN MODEL_2
model_2 = lgb.train(params=params, train_set=train, valid_sets=[train, valid], valid_names=['train','valid'], fobj=logloss, feval=logloss_eval)
scores_2 = model_2.predict(X_test)

# PRINT SCORES
print(scores_1[:20])
scores_2_trans = 1. / (1. + np.exp(-scores_2))
print(scores_2_trans[:20])

Steps to reproduce

  1. Run the above code in Python. It will print the training and validation losses, as well as the first 20 predictions, for both the model using the default loss function and the model using the custom loss function. The printed training losses will differ, but the final predictions will nevertheless be the same (see the check below).

  2. Now comment out the line which sets the weights (train.set_weight(weight=weights)): the training losses for the two models will be exactly the same.
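
For completeness, a quick check (not part of the original report; it assumes scores_1 and scores_2_trans from the script above are in scope) that the predictions really do match:

# both arrays hold probabilities at this point; per the report they should be identical
assert np.allclose(scores_1, scores_2_trans)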

guolinke (Collaborator) commented

Your logloss_eval does not apply the sample weights, but the built-in binary_logloss metric does. Note that feval is used to compute the metric for the training set as well, not only for the validation sets.

amin-nejad (Author) commented

Thanks very much Guolin. I thought that since I am not passing weights to the validation set, the logloss_eval function didn't matter, but I now realise it is used to compute the metric for the training set as well.

In case anyone else comes across this, the following eval function computes the weighted metric correctly on both the training and validation sets:

def logloss_eval(preds, data):
    y_true = data.get_label()
    # fall back to unit weights when the dataset has no weights (e.g. the validation set)
    weight = data.get_weight() if data.get_weight() is not None else np.ones(len(y_true))
    preds = 1. / (1. + np.exp(-preds))
    loss = -(y_true * np.log(preds)) - ((1 - y_true) * np.log(1 - preds))
    # weighted mean, matching the built-in binary_logloss
    return 'binary_logloss', np.sum(loss * weight) / np.sum(weight), False
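
With this eval function, re-training model_2 (a sketch reusing the objects defined in the repro above, with LightGBM 2.3.0's fobj/feval arguments) should report the same per-iteration training loss as model_1, even with weights set:

# re-train with the weight-aware eval; the 'train' binary_logloss printed
# each iteration should now match model_1's output
model_2 = lgb.train(params=params, train_set=train, valid_sets=[train, valid],
                    valid_names=['train', 'valid'], fobj=logloss, feval=logloss_eval)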
