diff --git a/cpp/src/prims/detail/nbr_intersection.cuh b/cpp/src/prims/detail/nbr_intersection.cuh index 6c856cc0d09..613cba83f17 100644 --- a/cpp/src/prims/detail/nbr_intersection.cuh +++ b/cpp/src/prims/detail/nbr_intersection.cuh @@ -17,6 +17,8 @@ #include +#include + #include #include #include @@ -702,6 +704,10 @@ nbr_intersection(raft::handle_t const& handle, typename EdgeValueInputIterator::value_type>>; using edge_property_value_t = typename EdgeValueInputIterator::value_type; + using optional_property_buffer_value_type = + std::conditional_t, + edge_property_value_t, + void>; static_assert(std::is_same_v::value_type, thrust::tuple>); @@ -859,12 +865,9 @@ nbr_intersection(raft::handle_t const& handle, rmm::device_uvector local_degrees_for_rx_majors(size_t{0}, handle.get_stream()); rmm::device_uvector local_nbrs_for_rx_majors(size_t{0}, handle.get_stream()); - std::optional> local_nbrs_properties_for_rx_majors{ - std::nullopt}; - if constexpr (!std::is_same_v) { - local_nbrs_properties_for_rx_majors = std::make_optional( - rmm::device_uvector(size_t{0}, handle.get_stream())); - } + [[maybe_unused]] auto local_nbrs_properties_for_rx_majors = + cugraph::detail::allocate_optional_dataframe_buffer( + 0, handle.get_stream()); std::vector local_nbr_counts{}; { @@ -939,9 +942,14 @@ nbr_intersection(raft::handle_t const& handle, local_nbrs_for_rx_majors.resize( local_nbr_offsets_for_rx_majors.back_element(handle.get_stream()), handle.get_stream()); - if (local_nbrs_properties_for_rx_majors) - (*local_nbrs_properties_for_rx_majors) - .resize(local_nbrs_for_rx_majors.size(), handle.get_stream()); + raft::device_span local_nbrs_properties_span{}; + + if constexpr (!std::is_same_v) { + local_nbrs_properties_for_rx_majors.resize(local_nbrs_for_rx_majors.size(), + handle.get_stream()); + local_nbrs_properties_span = raft::device_span( + local_nbrs_properties_for_rx_majors.data(), local_nbrs_properties_for_rx_majors.size()); + } for (size_t i = 0; i < graph_view.number_of_local_edge_partitions(); ++i) { auto edge_partition = @@ -978,9 +986,7 @@ nbr_intersection(raft::handle_t const& handle, local_nbr_offsets_for_rx_majors.size()), raft::device_span(local_nbrs_for_rx_majors.data(), local_nbrs_for_rx_majors.size()), - raft::device_span( - (*local_nbrs_properties_for_rx_majors).data(), - (*local_nbrs_properties_for_rx_majors).size())}); + local_nbrs_properties_span}); } std::vector h_rx_offsets(rx_major_counts.size() + size_t{1}, size_t{0}); @@ -1032,7 +1038,7 @@ nbr_intersection(raft::handle_t const& handle, std::tie(*major_nbr_properties, std::ignore) = shuffle_values(major_comm, - (*local_nbrs_properties_for_rx_majors).begin(), + local_nbrs_properties_for_rx_majors.begin(), local_nbr_counts, handle.get_stream()); } @@ -1065,20 +1071,16 @@ nbr_intersection(raft::handle_t const& handle, rmm::device_uvector nbr_intersection_offsets(size_t{0}, handle.get_stream()); rmm::device_uvector nbr_intersection_indices(size_t{0}, handle.get_stream()); - std::optional> nbr_intersection_properties0{ - std::nullopt}; - std::optional> nbr_intersection_properties1{ - std::nullopt}; - std::optional> nbr_intersection_idx_buffer{std::nullopt}; - - if constexpr (!std::is_same_v) { - nbr_intersection_properties0 = std::make_optional( - rmm::device_uvector(size_t{0}, handle.get_stream())); - nbr_intersection_properties1 = std::make_optional( - rmm::device_uvector(size_t{0}, handle.get_stream())); - nbr_intersection_idx_buffer = - std::make_optional(rmm::device_uvector(size_t{0}, handle.get_stream())); - } + [[maybe_unused]] auto nbr_intersection_properties0 = + cugraph::detail::allocate_optional_dataframe_buffer( + 0, handle.get_stream()); + + [[maybe_unused]] auto nbr_intersection_properties1 = + cugraph::detail::allocate_optional_dataframe_buffer( + 0, handle.get_stream()); + + [[maybe_unused]] auto nbr_intersection_idx_buffer = + cugraph::detail::allocate_optional_dataframe_buffer(0, handle.get_stream()); if constexpr (GraphViewType::is_multi_gpu) { auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); @@ -1612,8 +1614,8 @@ nbr_intersection(raft::handle_t const& handle, } nbr_intersection_indices.resize(num_nbr_intersection_indices, handle.get_stream()); if constexpr (!std::is_same_v) { - (*nbr_intersection_properties0).resize(nbr_intersection_indices.size(), handle.get_stream()); - (*nbr_intersection_properties1).resize(nbr_intersection_indices.size(), handle.get_stream()); + nbr_intersection_properties0.resize(nbr_intersection_indices.size(), handle.get_stream()); + nbr_intersection_properties1.resize(nbr_intersection_indices.size(), handle.get_stream()); } size_t size_offset{0}; size_t index_offset{0}; @@ -1632,12 +1634,12 @@ nbr_intersection(raft::handle_t const& handle, thrust::copy(handle.get_thrust_policy(), edge_partition_nbr_intersection_property0[i].begin(), edge_partition_nbr_intersection_property0[i].end(), - (*nbr_intersection_properties0).begin() + index_offset); + nbr_intersection_properties0.begin() + index_offset); thrust::copy(handle.get_thrust_policy(), edge_partition_nbr_intersection_property1[i].begin(), edge_partition_nbr_intersection_property1[i].end(), - (*nbr_intersection_properties1).begin() + index_offset); + nbr_intersection_properties1.begin() + index_offset); } index_offset += edge_partition_nbr_intersection_indices[i].size(); @@ -1687,10 +1689,23 @@ nbr_intersection(raft::handle_t const& handle, nbr_intersection_indices.resize(nbr_intersection_offsets.back_element(handle.get_stream()), handle.get_stream()); + raft::device_span nbr_intersection_properties0_span{}; + raft::device_span nbr_intersection_properties1_span{}; + raft::device_span nbr_intersection_idx_buffer_span{}; + if constexpr (!std::is_same_v) { - (*nbr_intersection_properties0).resize(nbr_intersection_indices.size(), handle.get_stream()); - (*nbr_intersection_properties1).resize(nbr_intersection_indices.size(), handle.get_stream()); - (*nbr_intersection_idx_buffer).resize(nbr_intersection_indices.size(), handle.get_stream()); + nbr_intersection_properties0.resize(nbr_intersection_indices.size(), handle.get_stream()); + nbr_intersection_properties1.resize(nbr_intersection_indices.size(), handle.get_stream()); + nbr_intersection_idx_buffer.resize(nbr_intersection_indices.size(), handle.get_stream()); + + nbr_intersection_properties0_span = raft::device_span( + nbr_intersection_properties0.data(), nbr_intersection_properties0.size()); + + nbr_intersection_properties1_span = raft::device_span( + nbr_intersection_properties1.data(), nbr_intersection_properties1.size()); + + nbr_intersection_idx_buffer_span = raft::device_span( + nbr_intersection_idx_buffer.data(), nbr_intersection_idx_buffer.size()); } if (intersect_minor_nbr[0] && intersect_minor_nbr[1]) { @@ -1721,12 +1736,9 @@ nbr_intersection(raft::handle_t const& handle, nbr_intersection_offsets.size()), raft::device_span(nbr_intersection_indices.data(), nbr_intersection_indices.size()), - raft::device_span((*nbr_intersection_properties0).data(), - (*nbr_intersection_properties0).size()), - raft::device_span((*nbr_intersection_properties1).data(), - (*nbr_intersection_properties1).size()), - raft::device_span((*nbr_intersection_idx_buffer).data(), - (*nbr_intersection_idx_buffer).size()), + nbr_intersection_properties0_span, + nbr_intersection_properties1_span, + nbr_intersection_idx_buffer_span, invalid_vertex_id::value}); } else { CUGRAPH_FAIL("unimplemented."); @@ -1741,23 +1753,31 @@ nbr_intersection(raft::handle_t const& handle, detail::not_equal_t{invalid_vertex_id::value}), handle.get_stream()); - std::optional> tmp_properties0{std::nullopt}; - std::optional> tmp_properties1{std::nullopt}; + [[maybe_unused]] auto tmp_properties0 = + cugraph::detail::allocate_optional_dataframe_buffer( + tmp_indices.size(), handle.get_stream()); + + [[maybe_unused]] auto tmp_properties1 = + cugraph::detail::allocate_optional_dataframe_buffer( + tmp_indices.size(), handle.get_stream()); + + raft::device_span tmp_properties0_span{}; + raft::device_span tmp_properties1_span{}; if constexpr (!std::is_same_v) { - tmp_properties0 = std::make_optional( - rmm::device_uvector(tmp_indices.size(), handle.get_stream())); - tmp_properties1 = std::make_optional( - rmm::device_uvector(tmp_indices.size(), handle.get_stream())); + tmp_properties0_span = + raft::device_span(tmp_properties0.data(), tmp_properties0.size()); + tmp_properties1_span = + raft::device_span(tmp_properties1.data(), tmp_properties1.size()); } auto zipped_itr_to_indices_and_properties_begin = thrust::make_zip_iterator(thrust::make_tuple(nbr_intersection_indices.begin(), - (*nbr_intersection_properties0).begin(), - (*nbr_intersection_properties1).begin())); + nbr_intersection_properties0_span.begin(), + nbr_intersection_properties1_span.begin())); auto zipped_itr_to_tmps_begin = thrust::make_zip_iterator(thrust::make_tuple( - tmp_indices.begin(), (*tmp_properties0).begin(), (*tmp_properties1).begin())); + tmp_indices.begin(), tmp_properties0_span.begin(), tmp_properties1_span.begin())); size_t num_copied{0}; size_t num_scanned{0}; @@ -1820,8 +1840,8 @@ nbr_intersection(raft::handle_t const& handle, })), handle.get_stream()); - (*nbr_intersection_properties0).resize(nbr_intersection_indices.size(), handle.get_stream()); - (*nbr_intersection_properties1).resize(nbr_intersection_indices.size(), handle.get_stream()); + nbr_intersection_properties0.resize(nbr_intersection_indices.size(), handle.get_stream()); + nbr_intersection_properties1.resize(nbr_intersection_indices.size(), handle.get_stream()); } #endif @@ -1840,8 +1860,8 @@ nbr_intersection(raft::handle_t const& handle, } else { return std::make_tuple(std::move(nbr_intersection_offsets), std::move(nbr_intersection_indices), - std::move((*nbr_intersection_properties0)), - std::move((*nbr_intersection_properties1))); + std::move(nbr_intersection_properties0), + std::move(nbr_intersection_properties1)); } }