Skip to content

Commit

Permalink
[BP] Fix potential race in feature constraint. (#10719) (#10900)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Oct 17, 2024
1 parent a4c6cde commit 6b4d703
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 20 deletions.
16 changes: 11 additions & 5 deletions src/common/bitfield.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,11 @@ struct BitFieldContainer {
#if defined(__CUDA_ARCH__)
__device__ BitFieldContainer& operator|=(BitFieldContainer const& rhs) {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
size_t min_size = min(NumValues(), rhs.NumValues());
std::size_t min_size = std::min(this->Capacity(), rhs.Capacity());
if (tid < min_size) {
Data()[tid] |= rhs.Data()[tid];
if (this->Check(tid) || rhs.Check(tid)) {
this->Set(tid);
}
}
return *this;
}
Expand All @@ -126,16 +128,20 @@ struct BitFieldContainer {

#if defined(__CUDA_ARCH__)
__device__ BitFieldContainer& operator&=(BitFieldContainer const& rhs) {
size_t min_size = min(NumValues(), rhs.NumValues());
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
std::size_t min_size = std::min(this->Capacity(), rhs.Capacity());
if (tid < min_size) {
Data()[tid] &= rhs.Data()[tid];
if (this->Check(tid) && rhs.Check(tid)) {
this->Set(tid);
} else {
this->Clear(tid);
}
}
return *this;
}
#else
BitFieldContainer& operator&=(BitFieldContainer const& rhs) {
size_t min_size = std::min(NumValues(), rhs.NumValues());
std::size_t min_size = std::min(NumValues(), rhs.NumValues());
for (size_t i = 0; i < min_size; ++i) {
Data()[i] &= rhs.Data()[i];
}
Expand Down
11 changes: 4 additions & 7 deletions src/tree/constraints.cu
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#include <thrust/execution_policy.h>
#include <thrust/iterator/counting_iterator.h>

#include <algorithm>
#include <string>
#include <set>

Expand Down Expand Up @@ -279,10 +278,6 @@ __global__ void InteractionConstraintSplitKernel(LBitField64 feature,
}
// enable constraints from feature
node |= feature;
// clear the buffer after use
if (tid < feature.Capacity()) {
feature.Clear(tid);
}

// enable constraints from parent
left |= node;
Expand All @@ -304,7 +299,7 @@ void FeatureInteractionConstraintDevice::Split(
<< " Split node: " << node_id << " and its left child: "
<< left_id << " cannot be the same.";
CHECK_NE(node_id, right_id)
<< " Split node: " << node_id << " and its left child: "
<< " Split node: " << node_id << " and its right child: "
<< right_id << " cannot be the same.";
CHECK_LT(right_id, s_node_constraints_.size());
CHECK_NE(s_node_constraints_.size(), 0);
Expand All @@ -330,6 +325,8 @@ void FeatureInteractionConstraintDevice::Split(
feature_buffer_,
feature_id,
node, left, right);
}

// clear the buffer after use
thrust::fill_n(thrust::device, feature_buffer_.Data(), feature_buffer_.NumValues(), 0);
}
} // namespace xgboost
15 changes: 7 additions & 8 deletions tests/cpp/tree/test_constraints.cu
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
/**
* Copyright 2019-2023, XGBoost contributors
* Copyright 2019-2024, XGBoost contributors
*/
#include <gtest/gtest.h>
#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <cinttypes>
#include <string>
#include <bitset>

#include <cstdint>
#include <set>
#include <string>

#include "../../../src/common/device_helpers.cuh"
#include "../../../src/tree/constraints.cuh"
#include "../../../src/tree/param.h"
#include "../../../src/common/device_helpers.cuh"

namespace xgboost {
namespace {
Expand All @@ -36,9 +37,7 @@ std::string GetConstraintsStr() {
}

tree::TrainParam GetParameter() {
std::vector<std::pair<std::string, std::string>> args{
{"interaction_constraints", GetConstraintsStr()}
};
Args args{{"interaction_constraints", GetConstraintsStr()}};
tree::TrainParam param;
param.Init(args);
return param;
Expand Down

0 comments on commit 6b4d703

Please sign in to comment.