enforce interaction constraints with monotone_constraints_method = intermediate/advanced #4043

Merged 6 commits on Apr 11, 2021 (view of changes from 4 commits)
7 changes: 4 additions & 3 deletions src/treelearner/serial_tree_learner.cpp
@@ -677,7 +677,7 @@ void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
                                best_split_per_leaf_);
     // update leave outputs if needed
     for (auto leaf : leaves_need_update) {
-      RecomputeBestSplitForLeaf(leaf, &best_split_per_leaf_[leaf]);
+      RecomputeBestSplitForLeaf(tree, leaf, &best_split_per_leaf_[leaf]);
     }
   }

@@ -768,7 +768,7 @@ double SerialTreeLearner::GetParentOutput(const Tree* tree, const LeafSplits* le
   return parent_output;
 }

-void SerialTreeLearner::RecomputeBestSplitForLeaf(int leaf, SplitInfo* split) {
+void SerialTreeLearner::RecomputeBestSplitForLeaf(Tree* tree, int leaf, SplitInfo* split) {
   FeatureHistogram* histogram_array_;
   if (!histogram_pool_.Get(leaf, &histogram_array_)) {
     Log::Warning(
@@ -795,6 +795,7 @@ void SerialTreeLearner::RecomputeBestSplitForLeaf(int leaf, SplitInfo* split) {

   OMP_INIT_EX();
   // find splits
+  std::vector<int8_t> node_used_features = col_sampler_.GetByNode(tree, leaf);
 #pragma omp parallel for schedule(static) num_threads(share_state_->num_threads)
   for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
     OMP_LOOP_EX_BEGIN();
@@ -804,7 +805,7 @@
     }
     const int tid = omp_get_thread_num();
     int real_fidx = train_data_->RealFeatureIndex(feature_index);
-    ComputeBestSplitForFeature(histogram_array_, feature_index, real_fidx, true,
+    ComputeBestSplitForFeature(histogram_array_, feature_index, real_fidx, node_used_features[feature_index],
                                num_data, &leaf_splits, &bests[tid], parent_output);

     OMP_LOOP_EX_END();
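For orientation: the change above threads the current tree into RecomputeBestSplitForLeaf so that the per-node feature restrictions from col_sampler_ (which implement interaction constraints) are honored when leaf splits are recomputed, instead of unconditionally passing true for feature availability. A minimal sketch of the user-facing configuration this code path serves; the data, parameter values, and variable names below are illustrative, not taken from this PR:

import numpy as np
import lightgbm as lgb

# illustrative data: y is increasing in x0 and does not depend on x1
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 3))
y = 2 * X[:, 0] + np.cos(X[:, 2]) + rng.normal(scale=0.1, size=1000)

params = {
    "objective": "regression",
    # each feature may only appear on its own within a tree
    "interaction_constraints": [[0], [1], [2]],
    "monotone_constraints": [1, 0, 0],
    # one of the two methods named in the PR title
    "monotone_constraints_method": "intermediate",
}
booster = lgb.train(params, lgb.Dataset(X, y), num_boost_round=50)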
2 changes: 1 addition & 1 deletion src/treelearner/serial_tree_learner.h
@@ -128,7 +128,7 @@ class SerialTreeLearner: public TreeLearner {

   void GetShareStates(const Dataset* dataset, bool is_constant_hessian, bool is_first_time);

-  void RecomputeBestSplitForLeaf(int leaf, SplitInfo* split);
+  void RecomputeBestSplitForLeaf(Tree* tree, int leaf, SplitInfo* split);

   /*!
    * \brief Some initial works before training
134 changes: 134 additions & 0 deletions tests/python_package_test/test_interaction_constraints.py
@@ -0,0 +1,134 @@
import numpy as np
import pandas as pd
import pytest

import lightgbm as lgb


def simple_pd(X, estimator, feature, values):
    """Calculate simple partial dependence."""
    Xc = X.copy()
    yps = np.zeros_like(values)
    for i, x in enumerate(values):
        Xc[feature] = x
        yps[i] = np.mean(estimator.predict(Xc.values))
    return values, yps


def simple_pds(df, estimator, features):
    """Calculate simple partial dependence for all features."""
    pds = {}
    for feat in features:
        values, yps = simple_pd(
            df[features], estimator, feat, np.sort(df[feat].unique())
        )
        pds[feat] = pd.DataFrame(data={feat: values, "y": yps})
    return pds
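As a usage note (assuming df and a trained booster gbm as constructed in the test below): the helpers return one small frame per feature, so checking monotonicity reduces to inspecting first differences along the feature grid.

pds = simple_pds(df, gbm, ["x0", "x1", "x2"])
df_pd = pds["x0"].sort_values(by="x0")
# non-negative first differences <=> prediction increasing in x0
assert (df_pd["y"].diff().dropna() >= 0).all()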


@pytest.fixture
def make_data():
    """Make toy data."""
    np.random.seed(1)
    n = 10000
    d = 3
    X = np.random.normal(size=(n, d))
    # round to speed things up
    X = np.round(X)
    # per-row noise (np.random.normal() with no size would add a single
    # shared offset rather than noise)
    eps = np.random.normal(size=n) * 0.1
    y = -1 * X[:, 0] + 3 * X[:, 1] + X[:, 0] * X[:, 1] + np.cos(X[:, 2]) + eps
    df = pd.DataFrame(data={"y": y, "x0": X[:, 0], "x1": X[:, 1], "x2": X[:, 2]})
    features = ["x0", "x1", "x2"]
    outcome = "y"

    return df, features, outcome


def find_interactions(gbm, feature_sets):
    """Find interactions in trees.

    Parameters
    ----------
    gbm : booster
    feature_sets : list of lists
        Sets of features across which to check for interactions.

    Returns
    -------
    tree_features : pandas.DataFrame
        One row per tree, with a boolean flag indicating whether the tree
        splits on features from more than one feature set.
    """
    df_trees = gbm.trees_to_dataframe()
    tree_features = (
        df_trees.groupby("tree_index")
        .apply(lambda x: set(x["split_feature"]) - {None})
        .reset_index()
        .rename(columns={0: "features"})
    )

    def has_interaction(tree):
        # count how many of the disjoint feature sets this tree draws from;
        # touching more than one means the tree encodes an interaction
        n = 0
        for fs in feature_sets:
            if len(tree["features"].intersection(fs)) > 0:
                n += 1
        return n > 1

    tree_features["has_interaction"] = tree_features.apply(has_interaction, axis=1)

    return tree_features
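find_interactions leans on Booster.trees_to_dataframe(), which flattens the model into one row per tree node; only its tree_index and split_feature columns are used here. A short sketch of the intended call (assuming gbm was trained with the constraints from the fixture below):

feature_sets = [["x0"], ["x1"], ["x2"]]
tree_features = find_interactions(gbm, feature_sets)
# with interaction_constraints = [[0], [1], [2]] no tree may mix sets
assert not tree_features["has_interaction"].any()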


@pytest.fixture
def get_boosting_params():
    boosting_params = {
        "boosting_type": "gbdt",
        "objective": "regression",
        "num_leaves": 5,
        "learning_rate": 0.1,
        "num_boost_round": 100,
        # disallow all interactions
        "interaction_constraints": [[0], [1], [2]],
        "monotone_constraints": [1, 1, 0],
        "monotone_constraints_method": "basic",
    }
    return boosting_params


@pytest.mark.parametrize(
    "monotone_constraints_method", ["basic", "intermediate", "advanced"]
)
def test_interaction_constraints(
    make_data, get_boosting_params, monotone_constraints_method
):
    df, features, outcome = make_data
    data = lgb.Dataset(df[features], df[outcome])

    boosting_params = get_boosting_params
    boosting_params.update({"monotone_constraints_method": monotone_constraints_method})
    gbm = lgb.train(boosting_params, data)

    feature_sets = [[0], [1], [2]]
    feature_sets = [[f"x{f}" for f in fs] for fs in feature_sets]
    tree_features = find_interactions(gbm, feature_sets)

    # Should not find any co-occurrences in a given tree,
    # since above we're disallowing all interactions.
    assert not tree_features["has_interaction"].any()

    # Check monotonicity
    pds = simple_pds(df, gbm, features)
    for cnt, (feat, df_pd) in enumerate(pds.items()):
        df_pd = df_pd.sort_values(by=feat, ascending=True)
        y_pred_diff = df_pd["y"].diff().values[1:]
        if boosting_params["monotone_constraints"][cnt] == 1:
            assert (y_pred_diff >= 0).all()
        elif boosting_params["monotone_constraints"][cnt] == -1:
            assert (y_pred_diff <= 0).all()