diff --git a/cpp/src/sampling/renumber_sampled_edgelist_impl.cuh b/cpp/src/sampling/renumber_sampled_edgelist_impl.cuh index 0f57c217580..ae9dd64348a 100644 --- a/cpp/src/sampling/renumber_sampled_edgelist_impl.cuh +++ b/cpp/src/sampling/renumber_sampled_edgelist_impl.cuh @@ -47,13 +47,11 @@ namespace { template std::tuple, std::optional>> -compute_renumber_map( - raft::handle_t const& handle, - raft::device_span edgelist_srcs, - std::optional> edgelist_hops, - raft::device_span edgelist_dsts, - std::optional> - label_offsets) +compute_renumber_map(raft::handle_t const& handle, + raft::device_span edgelist_srcs, + std::optional> edgelist_hops, + raft::device_span edgelist_dsts, + std::optional> label_offsets) { auto approx_edges_to_sort_per_iteration = static_cast(handle.get_device_properties().multiProcessorCount) * @@ -69,8 +67,7 @@ compute_renumber_map( thrust::make_counting_iterator(edgelist_srcs.size()), (*edgelist_label_indices).begin(), [offsets = raft::device_span( - (*label_offsets).data() + 1, - (*label_offsets).size() - 1)] __device__(size_t i) { + (*label_offsets).data() + 1, (*label_offsets).size() - 1)] __device__(size_t i) { return static_cast(thrust::distance( offsets.begin(), thrust::upper_bound(thrust::seq, offsets.begin(), offsets.end(), i))); }); @@ -309,12 +306,12 @@ compute_renumber_map( rmm::device_uvector d_tmp_storage(0, handle.get_stream()); - auto [h_label_offsets, h_edge_offsets] = detail::compute_offset_aligned_edge_chunks( - handle, - (*label_offsets).data(), - static_cast((*label_offsets).size() - 1), - dsts.size(), - approx_edges_to_sort_per_iteration); + auto [h_label_offsets, h_edge_offsets] = + detail::compute_offset_aligned_edge_chunks(handle, + (*label_offsets).data(), + static_cast((*label_offsets).size() - 1), + dsts.size(), + approx_edges_to_sort_per_iteration); auto num_chunks = h_label_offsets.size() - 1; for (size_t i = 0; i < num_chunks; ++i) { @@ -517,12 +514,15 @@ renumber_sampled_edgelist( // 2. compute renumber_map - auto [renumber_map, renumber_map_label_indices] = compute_renumber_map( - handle, - raft::device_span(edgelist_srcs.data(), edgelist_srcs.size()), - edgelist_hops, - raft::device_span(edgelist_dsts.data(), edgelist_dsts.size()), - label_offsets ? std::make_optional>(std::get<1>(*label_offsets)) : std::nullopt); + auto [renumber_map, renumber_map_label_indices] = + compute_renumber_map( + handle, + raft::device_span(edgelist_srcs.data(), edgelist_srcs.size()), + edgelist_hops, + raft::device_span(edgelist_dsts.data(), edgelist_dsts.size()), + label_offsets + ? std::make_optional>(std::get<1>(*label_offsets)) + : std::nullopt); // 3. compute renumber map offsets for each label @@ -640,7 +640,8 @@ renumber_sampled_edgelist( new_vertices.shrink_to_fit(handle.get_stream()); d_tmp_storage.shrink_to_fit(handle.get_stream()); - rmm::device_uvector edgelist_label_indices(edgelist_srcs.size(), handle.get_stream()); + rmm::device_uvector edgelist_label_indices(edgelist_srcs.size(), + handle.get_stream()); thrust::transform( handle.get_thrust_policy(), thrust::make_counting_iterator(size_t{0}), @@ -651,7 +652,7 @@ renumber_sampled_edgelist( std::get<1>(*label_offsets).size() - 1)] __device__(size_t i) { return static_cast(thrust::distance( offsets.begin(), thrust::upper_bound(thrust::seq, offsets.begin(), offsets.end(), i))); - }); + }); auto pair_first = thrust::make_zip_iterator(edgelist_srcs.begin(), edgelist_label_indices.begin()); @@ -679,8 +680,7 @@ renumber_sampled_edgelist( return *(new_vertices.begin() + thrust::distance(old_vertices.begin(), it)); }); - pair_first = - thrust::make_zip_iterator(edgelist_dsts.begin(), edgelist_label_indices.begin()); + pair_first = thrust::make_zip_iterator(edgelist_dsts.begin(), edgelist_label_indices.begin()); thrust::transform( handle.get_thrust_policy(), pair_first, diff --git a/cpp/src/sampling/renumber_sampled_edgelist_sg.cu b/cpp/src/sampling/renumber_sampled_edgelist_sg.cu index 522440108da..9ffa3cb67ad 100644 --- a/cpp/src/sampling/renumber_sampled_edgelist_sg.cu +++ b/cpp/src/sampling/renumber_sampled_edgelist_sg.cu @@ -20,11 +20,10 @@ namespace cugraph { -template -std::tuple, - rmm::device_uvector, - rmm::device_uvector, - std::optional>> +template std::tuple, + rmm::device_uvector, + rmm::device_uvector, + std::optional>> renumber_sampled_edgelist( raft::handle_t const& handle, rmm::device_uvector&& edgelist_srcs, @@ -34,11 +33,10 @@ renumber_sampled_edgelist( label_offsets, bool do_expensive_check); -template -std::tuple, - rmm::device_uvector, - rmm::device_uvector, - std::optional>> +template std::tuple, + rmm::device_uvector, + rmm::device_uvector, + std::optional>> renumber_sampled_edgelist( raft::handle_t const& handle, rmm::device_uvector&& edgelist_srcs, diff --git a/cpp/src/structure/detail/structure_utils.cuh b/cpp/src/structure/detail/structure_utils.cuh index e24eb5dc81b..b6c292324fa 100644 --- a/cpp/src/structure/detail/structure_utils.cuh +++ b/cpp/src/structure/detail/structure_utils.cuh @@ -354,8 +354,8 @@ void sort_adjacency_list(raft::handle_t const& handle, if constexpr (std::is_arithmetic_v) { for (size_t i = 0; i < num_chunks; ++i) { size_t tmp_storage_bytes{0}; - auto offset_first = thrust::make_transform_iterator( - offsets.data() + h_vertex_offsets[i], shift_left_t{h_edge_offsets[i]}); + auto offset_first = thrust::make_transform_iterator(offsets.data() + h_vertex_offsets[i], + shift_left_t{h_edge_offsets[i]}); cub::DeviceSegmentedSort::SortPairs(static_cast(nullptr), tmp_storage_bytes, index_first + h_edge_offsets[i], @@ -402,8 +402,8 @@ void sort_adjacency_list(raft::handle_t const& handle, edge_t{0}); for (size_t i = 0; i < num_chunks; ++i) { size_t tmp_storage_bytes{0}; - auto offset_first = thrust::make_transform_iterator( - offsets.data() + h_vertex_offsets[i], shift_left_t{h_edge_offsets[i]}); + auto offset_first = thrust::make_transform_iterator(offsets.data() + h_vertex_offsets[i], + shift_left_t{h_edge_offsets[i]}); cub::DeviceSegmentedSort::SortPairs(static_cast(nullptr), tmp_storage_bytes, index_first + h_edge_offsets[i], diff --git a/cpp/tests/sampling/renumber_sampled_edgelist_test.cu b/cpp/tests/sampling/renumber_sampled_edgelist_test.cu index 49cb674214c..6d944314605 100644 --- a/cpp/tests/sampling/renumber_sampled_edgelist_test.cu +++ b/cpp/tests/sampling/renumber_sampled_edgelist_test.cu @@ -426,13 +426,13 @@ TEST_P(Tests_RenumberSampledEdgelist, CheckInt64) run_current_test(param); } -INSTANTIATE_TEST_SUITE_P(small_test, - Tests_RenumberSampledEdgelist, - ::testing::Values(RenumberSampledEdgelist_Usecase{1024, 4096, 1, 1, true}, - RenumberSampledEdgelist_Usecase{1024, 4096, 3, 1, true}, - RenumberSampledEdgelist_Usecase{ - 1024, 32768, 1, 256, true}, - RenumberSampledEdgelist_Usecase{1024, 32768, 3, 256, true})); +INSTANTIATE_TEST_SUITE_P( + small_test, + Tests_RenumberSampledEdgelist, + ::testing::Values(RenumberSampledEdgelist_Usecase{1024, 4096, 1, 1, true}, + RenumberSampledEdgelist_Usecase{1024, 4096, 3, 1, true}, + RenumberSampledEdgelist_Usecase{1024, 32768, 1, 256, true}, + RenumberSampledEdgelist_Usecase{1024, 32768, 3, 256, true})); INSTANTIATE_TEST_SUITE_P( benchmark_test,