Skip to content

Commit

Permalink
regen test.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jul 24, 2023
1 parent 47137d1 commit c81bf63
Showing 1 changed file with 46 additions and 13 deletions.
59 changes: 46 additions & 13 deletions tests/cpp/tree/test_regen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,10 @@ class RegenTest : public ::testing::Test {
auto constexpr Iter() const { return 4; }

template <typename Page>
size_t TestTreeMethod(std::string tree_method, std::string obj, bool reset = true) const {
size_t TestTreeMethod(Context const* ctx, std::string tree_method, std::string obj,
bool reset = true) const {
auto learner = std::unique_ptr<Learner>{Learner::Create({p_fmat_})};
learner->SetParam("device", ctx->DeviceName());
learner->SetParam("tree_method", tree_method);
learner->SetParam("objective", obj);
learner->Configure();
Expand All @@ -87,40 +89,71 @@ class RegenTest : public ::testing::Test {
} // anonymous namespace

TEST_F(RegenTest, Approx) {
auto n = this->TestTreeMethod<GHistIndexMatrix>("approx", "reg:squarederror");
Context ctx;
auto n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "approx", "reg:squarederror");
ASSERT_EQ(n, 1);
n = this->TestTreeMethod<GHistIndexMatrix>("approx", "reg:logistic");
n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "approx", "reg:logistic");
ASSERT_EQ(n, this->Iter());
}

TEST_F(RegenTest, Hist) {
auto n = this->TestTreeMethod<GHistIndexMatrix>("hist", "reg:squarederror");
Context ctx;
auto n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "hist", "reg:squarederror");
ASSERT_EQ(n, 1);
n = this->TestTreeMethod<GHistIndexMatrix>("hist", "reg:logistic");
n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "hist", "reg:logistic");
ASSERT_EQ(n, 1);
}

TEST_F(RegenTest, Mixed) {
auto n = this->TestTreeMethod<GHistIndexMatrix>("hist", "reg:squarederror", false);
Context ctx;
auto n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "hist", "reg:squarederror", false);
ASSERT_EQ(n, 1);
n = this->TestTreeMethod<GHistIndexMatrix>("approx", "reg:logistic", true);
n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "approx", "reg:logistic", true);
ASSERT_EQ(n, this->Iter() + 1);

n = this->TestTreeMethod<GHistIndexMatrix>("approx", "reg:logistic", false);
n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "approx", "reg:logistic", false);
ASSERT_EQ(n, this->Iter());
n = this->TestTreeMethod<GHistIndexMatrix>("hist", "reg:squarederror", true);
n = this->TestTreeMethod<GHistIndexMatrix>(&ctx, "hist", "reg:squarederror", true);
ASSERT_EQ(n, this->Iter() + 1);
}

#if defined(XGBOOST_USE_CUDA)
TEST_F(RegenTest, GpuApprox) {
auto ctx = MakeCUDACtx(0);
auto n = this->TestTreeMethod<EllpackPage>(&ctx, "approx", "reg:squarederror", true);
ASSERT_EQ(n, 1);
n = this->TestTreeMethod<EllpackPage>(&ctx, "approx", "reg:logistic", false);
ASSERT_EQ(n, this->Iter());

n = this->TestTreeMethod<EllpackPage>(&ctx, "approx", "reg:logistic", true);
ASSERT_EQ(n, this->Iter() * 2);
}

TEST_F(RegenTest, GpuHist) {
auto n = this->TestTreeMethod<EllpackPage>("gpu_hist", "reg:squarederror");
auto ctx = MakeCUDACtx(0);
auto n = this->TestTreeMethod<EllpackPage>(&ctx, "hist", "reg:squarederror", true);
ASSERT_EQ(n, 1);
n = this->TestTreeMethod<EllpackPage>("gpu_hist", "reg:logistic", false);
n = this->TestTreeMethod<EllpackPage>(&ctx, "hist", "reg:logistic", false);
ASSERT_EQ(n, 1);

n = this->TestTreeMethod<EllpackPage>("hist", "reg:logistic");
ASSERT_EQ(n, 2);
{
Context ctx;
n = this->TestTreeMethod<EllpackPage>(&ctx, "hist", "reg:logistic");
ASSERT_EQ(n, 2);
}
}

TEST_F(RegenTest, GpuMixed) {
auto ctx = MakeCUDACtx(0);
auto n = this->TestTreeMethod<EllpackPage>(&ctx, "hist", "reg:squarederror", false);
ASSERT_EQ(n, 1);
n = this->TestTreeMethod<EllpackPage>(&ctx, "approx", "reg:logistic", true);
ASSERT_EQ(n, this->Iter() + 1);

n = this->TestTreeMethod<EllpackPage>(&ctx, "approx", "reg:logistic", false);
ASSERT_EQ(n, this->Iter());
n = this->TestTreeMethod<EllpackPage>(&ctx, "hist", "reg:squarederror", true);
ASSERT_EQ(n, this->Iter() + 1);
}
#endif // defined(XGBOOST_USE_CUDA)
} // namespace xgboost

0 comments on commit c81bf63

Please sign in to comment.