diff --git a/src/treelearner/serial_tree_learner.cpp b/src/treelearner/serial_tree_learner.cpp index 2fd7fcba5859..70b9cdee596b 100644 --- a/src/treelearner/serial_tree_learner.cpp +++ b/src/treelearner/serial_tree_learner.cpp @@ -677,7 +677,7 @@ void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf, best_split_per_leaf_); // update leave outputs if needed for (auto leaf : leaves_need_update) { - RecomputeBestSplitForLeaf(leaf, &best_split_per_leaf_[leaf]); + RecomputeBestSplitForLeaf(tree, leaf, &best_split_per_leaf_[leaf]); } } @@ -768,7 +768,7 @@ double SerialTreeLearner::GetParentOutput(const Tree* tree, const LeafSplits* le return parent_output; } -void SerialTreeLearner::RecomputeBestSplitForLeaf(int leaf, SplitInfo* split) { +void SerialTreeLearner::RecomputeBestSplitForLeaf(Tree* tree, int leaf, SplitInfo* split) { FeatureHistogram* histogram_array_; if (!histogram_pool_.Get(leaf, &histogram_array_)) { Log::Warning( @@ -795,6 +795,7 @@ void SerialTreeLearner::RecomputeBestSplitForLeaf(int leaf, SplitInfo* split) { OMP_INIT_EX(); // find splits +std::vector<int8_t> node_used_features = col_sampler_.GetByNode(tree, leaf); #pragma omp parallel for schedule(static) num_threads(share_state_->num_threads) for (int feature_index = 0; feature_index < num_features_; ++feature_index) { OMP_LOOP_EX_BEGIN(); @@ -804,7 +805,7 @@ void SerialTreeLearner::RecomputeBestSplitForLeaf(int leaf, SplitInfo* split) { } const int tid = omp_get_thread_num(); int real_fidx = train_data_->RealFeatureIndex(feature_index); - ComputeBestSplitForFeature(histogram_array_, feature_index, real_fidx, true, + ComputeBestSplitForFeature(histogram_array_, feature_index, real_fidx, node_used_features[feature_index], num_data, &leaf_splits, &bests[tid], parent_output); OMP_LOOP_EX_END(); diff --git a/src/treelearner/serial_tree_learner.h b/src/treelearner/serial_tree_learner.h index 1a0bda53b5db..6a903542bef9 100644 --- a/src/treelearner/serial_tree_learner.h +++ 
b/src/treelearner/serial_tree_learner.h @@ -128,7 +128,7 @@ class SerialTreeLearner: public TreeLearner { void GetShareStates(const Dataset* dataset, bool is_constant_hessian, bool is_first_time); - void RecomputeBestSplitForLeaf(int leaf, SplitInfo* split); + void RecomputeBestSplitForLeaf(Tree* tree, int leaf, SplitInfo* split); /*! * \brief Some initial works before training diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index b22651ee246d..c83b2699ba9d 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -1252,7 +1252,8 @@ def generate_trainset_for_monotone_constraints_tests(x3_to_category=True): return trainset -def test_monotone_constraints(): +@pytest.mark.parametrize("test_with_interaction_constraints", [True, False]) +def test_monotone_constraints(test_with_interaction_constraints): def is_increasing(y): return (np.diff(y) >= 0.0).all() @@ -1273,28 +1274,69 @@ def is_correctly_constrained(learner, x3_to_category=True): monotonically_increasing_y = learner.predict(monotonically_increasing_x) monotonically_decreasing_x = np.column_stack((fixed_x, variable_x, fixed_x)) monotonically_decreasing_y = learner.predict(monotonically_decreasing_x) - non_monotone_x = np.column_stack((fixed_x, - fixed_x, - categorize(variable_x) if x3_to_category else variable_x)) + non_monotone_x = np.column_stack( + ( + fixed_x, + fixed_x, + categorize(variable_x) if x3_to_category else variable_x, + ) + ) non_monotone_y = learner.predict(non_monotone_x) - if not (is_increasing(monotonically_increasing_y) - and is_decreasing(monotonically_decreasing_y) - and is_non_monotone(non_monotone_y)): + if not ( + is_increasing(monotonically_increasing_y) + and is_decreasing(monotonically_decreasing_y) + and is_non_monotone(non_monotone_y) + ): return False return True + def are_interactions_enforced(gbm, feature_sets): + def parse_tree_features(gbm): + # trees start at position 1. 
+ tree_str = gbm.model_to_string().split("Tree")[1:] + feature_sets = [] + for tree in tree_str: + # split_features are in 4th line. + features = tree.splitlines()[3].split("=")[1].split(" ") + features = set(f"Column_{f}" for f in features) + feature_sets.append(features) + return np.array(feature_sets) + + def has_interaction(treef): + n = 0 + for fs in feature_sets: + if len(treef.intersection(fs)) > 0: + n += 1 + return n > 1 + + tree_features = parse_tree_features(gbm) + has_interaction_flag = np.array( + [has_interaction(treef) for treef in tree_features] + ) + + return not has_interaction_flag.any() + for test_with_categorical_variable in [True, False]: - trainset = generate_trainset_for_monotone_constraints_tests(test_with_categorical_variable) + trainset = generate_trainset_for_monotone_constraints_tests( + test_with_categorical_variable + ) for monotone_constraints_method in ["basic", "intermediate", "advanced"]: params = { - 'min_data': 20, - 'num_leaves': 20, - 'monotone_constraints': [1, -1, 0], + "min_data": 20, + "num_leaves": 20, + "monotone_constraints": [1, -1, 0], "monotone_constraints_method": monotone_constraints_method, "use_missing": False, } + if test_with_interaction_constraints: + params["interaction_constraints"] = [[0], [1], [2]] constrained_model = lgb.train(params, trainset) - assert is_correctly_constrained( + constrained_model, test_with_categorical_variable + ) + if test_with_interaction_constraints: + feature_sets = [["Column_0"], ["Column_1"], ["Column_2"]] + assert are_interactions_enforced(constrained_model, feature_sets) def test_monotone_penalty():