[python-package] Seeing unexpected behaviour in refit() #6583

Open
Abir-lal-Roy opened this issue Jul 31, 2024 · 1 comment

Abir-lal-Roy commented Jul 31, 2024

Description

Hi team,
I'm currently working on a classification task that involves incremental learning. To update the model with new data, I am using the refit() function. Given an instance X = [x1, x2, x3, ..., xn] and a tree_index, I am seeing X fall into a different leaf node in the initial model than in the updated model. This is not expected, since refit() does not change the structure of the main model. Can somebody please help me understand why this is happening?

Reproducible example

# Updating main model
updated_model = initial_model.refit(
    data=df_train[initial_model.feature_name()],
    label=df_train["target"],
    weight=df_train["sample_weights"],
    decay_rate=0.95,
)

example = df_test[df_test.index == 0]  # keep as a DataFrame so columns can be selected by name
l1 = initial_model.predict(example[initial_model.feature_name()], num_iteration=1, pred_leaf=True)  # leaf node in the 1st tree of the main model
l2 = updated_model.predict(example[updated_model.feature_name()], num_iteration=1, pred_leaf=True)  # leaf node in the 1st tree of the updated model
l1, l2
(array([177], dtype=int32), array([187], dtype=int32))

I was expecting l1 and l2 to be equal.

Environment info

LightGBM version: 4.4.0
Python version: 3.8.3

@jameslamb jameslamb changed the title Seeing unexpected behaviour in refit () [python-package] Seeing unexpected behaviour in refit () Aug 2, 2024
@jameslamb jameslamb changed the title [python-package] Seeing unexpected behaviour in refit () [python-package] Seeing unexpected behaviour in refit() Aug 2, 2024
jameslamb (Collaborator) commented

Thanks for using LightGBM.

You're right that given two Booster objects, one of which was created by calling .refit() on the other, the same sample should fall into exactly the same leaf nodes by default. This is because .refit() preserves the tree structure.

I tried tonight with the latest development version of lightgbm (a9df7f1) and observed exactly that behavior:

import lightgbm as lgb
import pandas as pd
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=10_000, n_features=5, random_state=708)
df_train = pd.DataFrame(X)

bst = lgb.train(
    train_set=lgb.Dataset(
        data=df_train,
        label=y
    ),
    params={
        "objective": "regression",
        "num_iterations": 10,
        "seed": 708,
        "deterministic": True,
        "force_row_wise": True,
        "num_threads": 1
    }
)

bst_new = bst.refit(
    data=df_train,
    label=y,
    decay_rate=0.95
)

I chose a sample and saw it fall into the same leaf node in the first tree of both models:

example_case = df_train.head(1)

orig_leaf_index = bst.predict(example_case, num_iteration=1, pred_leaf=True)
new_leaf_index = bst_new.predict(example_case, num_iteration=1, pred_leaf=True)
print(f"leaf indices: old={orig_leaf_index}, new={new_leaf_index}")
# leaf indices: old=[10], new=[10]

But with different raw predicted values:

bst.predict(example_case, num_iteration=1, raw_score=True)
# array([-12.81107821])

bst_new.predict(example_case, num_iteration=1, raw_score=True)
# array([-17.67358137])
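
Different raw scores with identical leaf indices are consistent with how the LightGBM documentation describes `decay_rate`: the splits stay fixed and each leaf value is blended as `leaf_output = decay_rate * old_leaf_output + (1.0 - decay_rate) * new_leaf_output`. A minimal sketch of that blend in plain Python follows. The function name `blend_leaf_output` and the leaf values are hypothetical, chosen only to show the arithmetic; they are not part of LightGBM's API and are not taken from the models above.

```python
import math


def blend_leaf_output(old_leaf_output: float, new_leaf_output: float, decay_rate: float) -> float:
    """Blend an existing leaf value with the value estimated from the new data.

    Follows the formula given in the docs for the `decay_rate` argument of refit().
    """
    if not 0.0 <= decay_rate <= 1.0:
        raise ValueError("decay_rate must be in [0, 1]")
    return decay_rate * old_leaf_output + (1.0 - decay_rate) * new_leaf_output


# Hypothetical leaf values (assumptions for illustration only).
old_value = 2.0
new_value = 4.0

blended = blend_leaf_output(old_value, new_value, decay_rate=0.95)
print(blended)  # close to 2.1: mostly the old value, nudged toward the new one
```

With `decay_rate=1.0` the old leaf values are kept unchanged, and with `decay_rate=0.0` they are replaced entirely by the values estimated from the new data.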
Environment info
  • M2 Mac, macOS Sonoma 14.4.1
  • Python 3.11.9

conda env export:

name: lgb-dev
channels:
  - conda-forge
dependencies:
  - atk-1.0=2.38.0=hd03087b_2
  - aws-c-auth=0.7.22=hec39e38_2
  - aws-c-cal=0.6.14=h5db4892_1
  - aws-c-common=0.9.19=h99b78c6_0
  - aws-c-compression=0.2.18=h5db4892_6
  - aws-c-event-stream=0.4.2=h5eab607_12
  - aws-c-http=0.8.1=had10953_17
  - aws-c-io=0.14.8=hb5a7b21_5
  - aws-c-mqtt=0.10.4=h78534b8_4
  - aws-c-s3=0.5.9=h1755d02_3
  - aws-c-sdkutils=0.1.16=h5db4892_2
  - aws-checksums=0.1.18=h5db4892_6
  - aws-crt-cpp=0.26.9=h03bff2b_0
  - aws-sdk-cpp=1.11.329=hb37a6d0_3
  - brotli=1.1.0=hb547adb_1
  - brotli-bin=1.1.0=hb547adb_1
  - bzip2=1.0.8=h93a5062_5
  - c-ares=1.28.1=h93a5062_0
  - ca-certificates=2024.7.4=hf0a4a13_0
  - cairo=1.18.0=hd1e100b_0
  - certifi=2024.6.2=pyhd8ed1ab_0
  - cffi=1.16.0=py311h4a08483_0
  - cfgv=3.3.1=pyhd8ed1ab_0
  - click=8.1.7=unix_pyh707e725_0
  - cloudpickle=3.0.0=pyhd8ed1ab_0
  - cmakelint=1.4.3=pyhd8ed1ab_0
  - colorama=0.4.6=pyhd8ed1ab_0
  - coverage=7.5.0=py311hd23d018_0
  - cpplint=1.6.0=pyhd8ed1ab_0
  - cycler=0.12.1=pyhd8ed1ab_0
  - distlib=0.3.8=pyhd8ed1ab_0
  - exceptiongroup=1.2.0=pyhd8ed1ab_2
  - expat=2.6.2=hebf3989_0
  - filelock=3.14.0=pyhd8ed1ab_0
  - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
  - font-ttf-inconsolata=3.000=h77eed37_0
  - font-ttf-source-code-pro=2.038=h77eed37_0
  - font-ttf-ubuntu=0.83=h77eed37_1
  - fontconfig=2.14.2=h82840c6_0
  - fonts-conda-ecosystem=1=0
  - fonts-conda-forge=1=0
  - fonttools=4.53.0=py311hd3f4193_0
  - freetype=2.12.1=hadb7bae_2
  - fribidi=1.0.10=h27ca646_0
  - gdk-pixbuf=2.42.11=h13c029f_0
  - gettext=0.22.5=h8fbad5d_2
  - gettext-tools=0.22.5=h8fbad5d_2
  - gflags=2.2.2=hc88da5d_1004
  - giflib=5.2.2=h93a5062_0
  - glog=0.7.1=heb240a5_0
  - gmp=6.3.0=hebf3989_1
  - graphite2=1.3.13=hebf3989_1003
  - graphviz=9.0.0=h3face73_1
  - gtk2=2.24.33=h7895bb2_4
  - gts=0.7.6=he42f4ea_4
  - harfbuzz=8.4.0=hbe0f7c0_0
  - icu=73.2=hc8870d7_0
  - identify=2.5.36=pyhd8ed1ab_0
  - iniconfig=2.0.0=pyhd8ed1ab_0
  - joblib=1.4.2=pyhd8ed1ab_0
  - kiwisolver=1.4.5=py311he4fd1f5_1
  - krb5=1.21.3=h237132a_0
  - lcms2=2.16=ha0e7c42_0
  - lerc=4.0.0=h9a09cb3_0
  - libabseil=20240116.2=cxx17_hebf3989_0
  - libarrow=16.1.0=h28dd788_6_cpu
  - libarrow-acero=16.1.0=h00cdb27_6_cpu
  - libarrow-dataset=16.1.0=h00cdb27_6_cpu
  - libarrow-substrait=16.1.0=hc68f6b8_6_cpu
  - libasprintf=0.22.5=h8fbad5d_2
  - libasprintf-devel=0.22.5=h8fbad5d_2
  - libblas=3.9.0=22_osxarm64_openblas
  - libbrotlicommon=1.1.0=hb547adb_1
  - libbrotlidec=1.1.0=hb547adb_1
  - libbrotlienc=1.1.0=hb547adb_1
  - libcblas=3.9.0=22_osxarm64_openblas
  - libcrc32c=1.1.2=hbdafb3b_0
  - libcurl=8.8.0=h7b6f9a7_1
  - libcxx=16.0.6=h4653b0c_0
  - libdeflate=1.20=h93a5062_0
  - libedit=3.1.20191231=hc8eb9b7_2
  - libev=4.33=h93a5062_2
  - libevent=2.1.12=h2757513_1
  - libexpat=2.6.2=hebf3989_0
  - libffi=3.4.2=h3422bc3_5
  - libgd=2.3.3=hfdf3952_9
  - libgettextpo=0.22.5=h8fbad5d_2
  - libgettextpo-devel=0.22.5=h8fbad5d_2
  - libgfortran=5.0.0=13_2_0_hd922786_3
  - libgfortran5=13.2.0=hf226fd6_3
  - libglib=2.80.0=hfc324ee_6
  - libgoogle-cloud=2.24.0=hfe08963_0
  - libgoogle-cloud-storage=2.24.0=h3fa5b87_0
  - libgrpc=1.62.2=h9c18a4f_0
  - libiconv=1.17=h0d3ecfb_2
  - libintl=0.22.5=h8fbad5d_2
  - libintl-devel=0.22.5=h8fbad5d_2
  - libjpeg-turbo=3.0.0=hb547adb_1
  - liblapack=3.9.0=22_osxarm64_openblas
  - libnghttp2=1.58.0=ha4dd798_1
  - libopenblas=0.3.27=openmp_h6c19121_0
  - libparquet=16.1.0=hcf52c46_6_cpu
  - libpng=1.6.43=h091b4b1_0
  - libprotobuf=4.25.3=hbfab5d5_0
  - libre2-11=2023.09.01=h7b2c953_2
  - librsvg=2.58.0=hb3d354b_1
  - libsqlite=3.45.3=h091b4b1_0
  - libssh2=1.11.0=h7a5bd25_0
  - libthrift=0.19.0=h026a170_1
  - libtiff=4.6.0=h07db509_3
  - libutf8proc=2.8.0=h1a8c8d9_0
  - libwebp=1.3.2=hf30222e_1
  - libwebp-base=1.3.2=h93a5062_1
  - libxcb=1.15=hf346824_0
  - libxml2=2.12.6=h0d0cfa8_2
  - libzlib=1.2.13=h53f4e23_5
  - llvm-openmp=18.1.8=hde57baf_0
  - lz4-c=1.9.4=hb7217d7_0
  - matplotlib=3.8.4=py311ha1ab1f8_2
  - matplotlib-base=3.8.4=py311h000fb6e_2
  - munkres=1.1.4=pyh9f0ad1d_0
  - mypy=1.10.0=py311hd23d018_0
  - mypy_extensions=1.0.0=pyha770c72_0
  - ncurses=6.4.20240210=h078ce10_0
  - nodeenv=1.8.0=pyhd8ed1ab_0
  - numpy=2.0.0=py311h4268184_0
  - openjpeg=2.5.2=h9f1df11_0
  - openssl=3.3.1=hfb2fe0b_1
  - orc=2.0.1=h47ade37_1
  - packaging=24.0=pyhd8ed1ab_0
  - pandas=2.2.2=py311h4b4568b_1
  - pandoc=3.1.13=hce30654_0
  - pango=1.52.2=hb067d4f_0
  - pcre2=10.43=h26f9a81_0
  - pillow=10.3.0=py311h0b5d0a1_0
  - pip=24.0=pyhd8ed1ab_0
  - pixman=0.43.4=hebf3989_0
  - platformdirs=4.2.1=pyhd8ed1ab_0
  - pluggy=1.5.0=pyhd8ed1ab_0
  - pre-commit=3.7.0=pyha770c72_0
  - psutil=5.9.8=py311h05b510d_0
  - pthread-stubs=0.4=h27ca646_1001
  - pyarrow=16.1.0=py311h35c05fe_2
  - pyarrow-core=16.1.0=py311hb5ba6a5_2_cpu
  - pycparser=2.22=pyhd8ed1ab_0
  - pyparsing=3.1.2=pyhd8ed1ab_0
  - pytest=8.2.2=pyhd8ed1ab_0
  - pytest-cov=5.0.0=pyhd8ed1ab_0
  - python=3.11.9=h932a869_0_cpython
  - python-dateutil=2.9.0=pyhd8ed1ab_0
  - python-graphviz=0.20.3=pyh717bed2_0
  - python-tzdata=2024.1=pyhd8ed1ab_0
  - python_abi=3.11=4_cp311
  - pytz=2024.1=pyhd8ed1ab_0
  - pyyaml=6.0.1=py311heffc1b2_1
  - re2=2023.09.01=h4cba328_2
  - readline=8.2=h92ec313_1
  - ruff=0.4.7=py311hd374d79_0
  - scikit-learn=1.5.1=py311hbfb48bc_0
  - scipy=1.14.0=py311hceeca8c_0
  - setuptools=69.5.1=pyhd8ed1ab_0
  - shellcheck=0.10.0=hecfb573_0
  - six=1.16.0=pyh6c4a22f_0
  - snappy=1.2.1=hd02b534_0
  - threadpoolctl=3.5.0=pyhc1e730c_0
  - tk=8.6.13=h5083fa2_1
  - toml=0.10.2=pyhd8ed1ab_0
  - tomli=2.0.1=pyhd8ed1ab_0
  - tornado=6.4.1=py311hd3f4193_0
  - typing_extensions=4.11.0=pyha770c72_0
  - tzdata=2024a=h0c530f3_0
  - ukkonen=1.0.1=py311he4fd1f5_4
  - virtualenv=20.26.0=pyhd8ed1ab_0
  - wheel=0.43.0=pyhd8ed1ab_1
  - xorg-libxau=1.0.11=hb547adb_0
  - xorg-libxdmcp=1.1.3=h27ca646_0
  - xz=5.2.6=h57fd34a_0
  - yaml=0.2.5=h3422bc3_2
  - zlib=1.2.13=h53f4e23_5
  - zstd=1.5.6=hb46c0d2_0
  - pip:
      - altgraph==0.17.4
      - annotated-types==0.7.0
      - attrs==24.2.0
      - backports-tarfile==1.2.0
      - build==1.2.1
      - charset-normalizer==3.3.2
      - check-wheel-contents==0.6.0
      - contourpy==1.3.0.dev1
      - delocate==0.11.0
      - docutils==0.21.2
      - idna==3.8
      - importlib-metadata==8.4.0
      - jaraco-classes==3.4.0
      - jaraco-context==6.0.1
      - jaraco-functools==4.0.2
      - keyring==25.3.0
      - lightgbm==4.5.0.99
      - macholib==1.16.3
      - markdown-it-py==3.0.0
      - mdurl==0.1.2
      - more-itertools==10.4.0
      - nh3==0.2.18
      - pkginfo==1.10.0
      - pydantic==2.8.2
      - pydantic-core==2.20.1
      - pydistcheck==0.7.1
      - pygments==2.18.0
      - pyproject-hooks==1.0.0
      - readme-renderer==44.0
      - requests==2.32.3
      - requests-toolbelt==1.0.0
      - rfc3986==2.0.0
      - rich==13.8.0
      - twine==5.1.1
      - urllib3==2.2.2
      - wheel-filename==1.4.1
      - zipp==3.20.1
prefix: /Users/jlamb/miniforge3/envs/lgb-dev

So to investigate this further, we're going to need your help.

Can you share a minimal, reproducible example? That'd be code that any of us can copy, paste, and run which reproduces the behavior you're seeing. You could use my example above as a starting point.

Your report leaves out a lot of important details, including the content of the training data. You may want to review "How to create a minimal, reproducible example" in the Stack Overflow docs for some tips on how to report software questions like this.

If you're unable to reproduce this behavior, then could you provide the content of both models in text files? You could call .save_model() to produce those files.
