Skip to content

Commit

Permalink
[BP] Check cub errors (#10721) (#10903)
Browse files Browse the repository at this point in the history
This backport cherry picks the specific fix in the row partitioner instead of the entire
patch.
  • Loading branch information
trivialfis authored Oct 17, 2024
1 parent 1c61752 commit cc6d03c
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/tree/gpu_hist/row_partitioner.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,13 @@ void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
});
size_t temp_bytes = 0;
if (tmp->empty()) {
cub::DeviceScan::InclusiveScan(nullptr, temp_bytes, input_iterator, discard_write_iterator,
IndexFlagOp(), total_rows);
dh::safe_cuda(cub::DeviceScan::InclusiveScan(
nullptr, temp_bytes, input_iterator, discard_write_iterator, IndexFlagOp(), total_rows));
tmp->resize(temp_bytes);
}
temp_bytes = tmp->size();
cub::DeviceScan::InclusiveScan(tmp->data().get(), temp_bytes, input_iterator,
discard_write_iterator, IndexFlagOp(), total_rows);
dh::safe_cuda(cub::DeviceScan::InclusiveScan(tmp->data().get(), temp_bytes, input_iterator,
discard_write_iterator, IndexFlagOp(), total_rows));

constexpr int kBlockSize = 256;

Expand Down

0 comments on commit cc6d03c

Please sign in to comment.