Skip to content

Commit

Permalink
Row major.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Dec 17, 2022
1 parent c319262 commit 3502fce
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
8 changes: 3 additions & 5 deletions src/tree/fit_stump.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,13 @@ void FitStump(Context const* ctx, HostDeviceVector<GradientPair> const& gpair,
bst_target_t n_targets, linalg::Vector<float>* out) {
out->SetDevice(ctx->gpu_id);
out->Reshape(n_targets);
// column-major
auto n_samples = gpair.Size() / n_targets;
std::size_t shape[2]{n_samples, n_targets};
std::size_t strides[2];
linalg::detail::CalcStride<2, true>(shape, strides);

gpair.SetDevice(ctx->gpu_id);
linalg::TensorView<GradientPair const, 2> gpair_t{
ctx->IsCPU() ? gpair.ConstHostSpan() : gpair.ConstDeviceSpan(), shape, strides, ctx->gpu_id};
ctx->IsCPU() ? gpair.ConstHostSpan() : gpair.ConstDeviceSpan(),
{n_samples, n_targets},
ctx->gpu_id};
ctx->IsCPU() ? cpu_impl::FitStump(ctx, gpair_t, out->HostView())
: cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->gpu_id));
}
Expand Down
6 changes: 3 additions & 3 deletions tests/cpp/tree/test_fit_stump.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ void TestFitStump(Context const *ctx) {
HostDeviceVector<GradientPair> gpair;
auto &h_gpair = gpair.HostVector();
h_gpair.resize(kRows * kTargets);
for (std::size_t t = 0; t < kTargets; ++t) {
for (std::size_t i = 0; i < kRows; ++i) {
h_gpair.at(t * kRows + i) = GradientPair{static_cast<float>(i), 1};
for (std::size_t i = 0; i < kRows; ++i) {
for (std::size_t t = 0; t < kTargets; ++t) {
h_gpair.at(i * kTargets + t) = GradientPair{static_cast<float>(i), 1};
}
}
linalg::Vector<float> out;
Expand Down

0 comments on commit 3502fce

Please sign in to comment.