Skip to content

Commit

Permalink
Replace several optionals with conditionals
Browse files Browse the repository at this point in the history
  • Loading branch information
Md Naim committed Jul 26, 2023
1 parent cc2197d commit be7afcd
Showing 1 changed file with 73 additions and 53 deletions.
126 changes: 73 additions & 53 deletions cpp/src/prims/detail/nbr_intersection.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

#include <prims/kv_store.cuh>

#include <prims/detail/extract_transform_v_frontier_e.cuh>

#include <cugraph/edge_partition_device_view.cuh>
#include <cugraph/edge_partition_edge_property_device_view.cuh>
#include <cugraph/graph.hpp>
Expand Down Expand Up @@ -702,6 +704,10 @@ nbr_intersection(raft::handle_t const& handle,
typename EdgeValueInputIterator::value_type>>;

using edge_property_value_t = typename EdgeValueInputIterator::value_type;
using optional_property_buffer_value_type =
std::conditional_t<!std::is_same_v<edge_property_value_t, thrust::nullopt_t>,
edge_property_value_t,
void>;

static_assert(std::is_same_v<typename thrust::iterator_traits<VertexPairIterator>::value_type,
thrust::tuple<vertex_t, vertex_t>>);
Expand Down Expand Up @@ -859,12 +865,9 @@ nbr_intersection(raft::handle_t const& handle,
rmm::device_uvector<edge_t> local_degrees_for_rx_majors(size_t{0}, handle.get_stream());
rmm::device_uvector<vertex_t> local_nbrs_for_rx_majors(size_t{0}, handle.get_stream());

std::optional<rmm::device_uvector<edge_property_value_t>> local_nbrs_properties_for_rx_majors{
std::nullopt};
if constexpr (!std::is_same_v<edge_property_value_t, thrust::nullopt_t>) {
local_nbrs_properties_for_rx_majors = std::make_optional(
rmm::device_uvector<edge_property_value_t>(size_t{0}, handle.get_stream()));
}
[[maybe_unused]] auto local_nbrs_properties_for_rx_majors =
cugraph::detail::allocate_optional_dataframe_buffer<optional_property_buffer_value_type>(
0, handle.get_stream());

std::vector<size_t> local_nbr_counts{};
{
Expand Down Expand Up @@ -939,9 +942,14 @@ nbr_intersection(raft::handle_t const& handle,
local_nbrs_for_rx_majors.resize(
local_nbr_offsets_for_rx_majors.back_element(handle.get_stream()), handle.get_stream());

if (local_nbrs_properties_for_rx_majors)
(*local_nbrs_properties_for_rx_majors)
.resize(local_nbrs_for_rx_majors.size(), handle.get_stream());
raft::device_span<edge_property_value_t> local_nbrs_properties_span{};

if constexpr (!std::is_same_v<edge_property_value_t, thrust::nullopt_t>) {
local_nbrs_properties_for_rx_majors.resize(local_nbrs_for_rx_majors.size(),
handle.get_stream());
local_nbrs_properties_span = raft::device_span<edge_property_value_t>(
local_nbrs_properties_for_rx_majors.data(), local_nbrs_properties_for_rx_majors.size());
}

for (size_t i = 0; i < graph_view.number_of_local_edge_partitions(); ++i) {
auto edge_partition =
Expand Down Expand Up @@ -978,9 +986,7 @@ nbr_intersection(raft::handle_t const& handle,
local_nbr_offsets_for_rx_majors.size()),
raft::device_span<vertex_t>(local_nbrs_for_rx_majors.data(),
local_nbrs_for_rx_majors.size()),
raft::device_span<edge_property_value_t>(
(*local_nbrs_properties_for_rx_majors).data(),
(*local_nbrs_properties_for_rx_majors).size())});
local_nbrs_properties_span});
}

std::vector<size_t> h_rx_offsets(rx_major_counts.size() + size_t{1}, size_t{0});
Expand Down Expand Up @@ -1032,7 +1038,7 @@ nbr_intersection(raft::handle_t const& handle,

std::tie(*major_nbr_properties, std::ignore) =
shuffle_values(major_comm,
(*local_nbrs_properties_for_rx_majors).begin(),
local_nbrs_properties_for_rx_majors.begin(),
local_nbr_counts,
handle.get_stream());
}
Expand Down Expand Up @@ -1065,20 +1071,16 @@ nbr_intersection(raft::handle_t const& handle,
rmm::device_uvector<size_t> nbr_intersection_offsets(size_t{0}, handle.get_stream());
rmm::device_uvector<vertex_t> nbr_intersection_indices(size_t{0}, handle.get_stream());

std::optional<rmm::device_uvector<edge_property_value_t>> nbr_intersection_properties0{
std::nullopt};
std::optional<rmm::device_uvector<edge_property_value_t>> nbr_intersection_properties1{
std::nullopt};
std::optional<rmm::device_uvector<vertex_t>> nbr_intersection_idx_buffer{std::nullopt};

if constexpr (!std::is_same_v<edge_property_value_t, thrust::nullopt_t>) {
nbr_intersection_properties0 = std::make_optional(
rmm::device_uvector<edge_property_value_t>(size_t{0}, handle.get_stream()));
nbr_intersection_properties1 = std::make_optional(
rmm::device_uvector<edge_property_value_t>(size_t{0}, handle.get_stream()));
nbr_intersection_idx_buffer =
std::make_optional(rmm::device_uvector<vertex_t>(size_t{0}, handle.get_stream()));
}
[[maybe_unused]] auto nbr_intersection_properties0 =
cugraph::detail::allocate_optional_dataframe_buffer<optional_property_buffer_value_type>(
0, handle.get_stream());

[[maybe_unused]] auto nbr_intersection_properties1 =
cugraph::detail::allocate_optional_dataframe_buffer<optional_property_buffer_value_type>(
0, handle.get_stream());

[[maybe_unused]] auto nbr_intersection_idx_buffer =
cugraph::detail::allocate_optional_dataframe_buffer<vertex_t>(0, handle.get_stream());

if constexpr (GraphViewType::is_multi_gpu) {
auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name());
Expand Down Expand Up @@ -1612,8 +1614,8 @@ nbr_intersection(raft::handle_t const& handle,
}
nbr_intersection_indices.resize(num_nbr_intersection_indices, handle.get_stream());
if constexpr (!std::is_same_v<edge_property_value_t, thrust::nullopt_t>) {
(*nbr_intersection_properties0).resize(nbr_intersection_indices.size(), handle.get_stream());
(*nbr_intersection_properties1).resize(nbr_intersection_indices.size(), handle.get_stream());
nbr_intersection_properties0.resize(nbr_intersection_indices.size(), handle.get_stream());
nbr_intersection_properties1.resize(nbr_intersection_indices.size(), handle.get_stream());
}
size_t size_offset{0};
size_t index_offset{0};
Expand All @@ -1632,12 +1634,12 @@ nbr_intersection(raft::handle_t const& handle,
thrust::copy(handle.get_thrust_policy(),
edge_partition_nbr_intersection_property0[i].begin(),
edge_partition_nbr_intersection_property0[i].end(),
(*nbr_intersection_properties0).begin() + index_offset);
nbr_intersection_properties0.begin() + index_offset);

thrust::copy(handle.get_thrust_policy(),
edge_partition_nbr_intersection_property1[i].begin(),
edge_partition_nbr_intersection_property1[i].end(),
(*nbr_intersection_properties1).begin() + index_offset);
nbr_intersection_properties1.begin() + index_offset);
}

index_offset += edge_partition_nbr_intersection_indices[i].size();
Expand Down Expand Up @@ -1687,10 +1689,23 @@ nbr_intersection(raft::handle_t const& handle,
nbr_intersection_indices.resize(nbr_intersection_offsets.back_element(handle.get_stream()),
handle.get_stream());

raft::device_span<edge_property_value_t> nbr_intersection_properties0_span{};
raft::device_span<edge_property_value_t> nbr_intersection_properties1_span{};
raft::device_span<vertex_t> nbr_intersection_idx_buffer_span{};

if constexpr (!std::is_same_v<edge_property_value_t, thrust::nullopt_t>) {
(*nbr_intersection_properties0).resize(nbr_intersection_indices.size(), handle.get_stream());
(*nbr_intersection_properties1).resize(nbr_intersection_indices.size(), handle.get_stream());
(*nbr_intersection_idx_buffer).resize(nbr_intersection_indices.size(), handle.get_stream());
nbr_intersection_properties0.resize(nbr_intersection_indices.size(), handle.get_stream());
nbr_intersection_properties1.resize(nbr_intersection_indices.size(), handle.get_stream());
nbr_intersection_idx_buffer.resize(nbr_intersection_indices.size(), handle.get_stream());

nbr_intersection_properties0_span = raft::device_span<edge_property_value_t>(
nbr_intersection_properties0.data(), nbr_intersection_properties0.size());

nbr_intersection_properties1_span = raft::device_span<edge_property_value_t>(
nbr_intersection_properties1.data(), nbr_intersection_properties1.size());

nbr_intersection_idx_buffer_span = raft::device_span<vertex_t>(
nbr_intersection_idx_buffer.data(), nbr_intersection_idx_buffer.size());
}

if (intersect_minor_nbr[0] && intersect_minor_nbr[1]) {
Expand Down Expand Up @@ -1721,12 +1736,9 @@ nbr_intersection(raft::handle_t const& handle,
nbr_intersection_offsets.size()),
raft::device_span<vertex_t>(nbr_intersection_indices.data(),
nbr_intersection_indices.size()),
raft::device_span<edge_property_value_t>((*nbr_intersection_properties0).data(),
(*nbr_intersection_properties0).size()),
raft::device_span<edge_property_value_t>((*nbr_intersection_properties1).data(),
(*nbr_intersection_properties1).size()),
raft::device_span<vertex_t>((*nbr_intersection_idx_buffer).data(),
(*nbr_intersection_idx_buffer).size()),
nbr_intersection_properties0_span,
nbr_intersection_properties1_span,
nbr_intersection_idx_buffer_span,
invalid_vertex_id<vertex_t>::value});
} else {
CUGRAPH_FAIL("unimplemented.");
Expand All @@ -1741,23 +1753,31 @@ nbr_intersection(raft::handle_t const& handle,
detail::not_equal_t<vertex_t>{invalid_vertex_id<vertex_t>::value}),
handle.get_stream());

std::optional<rmm::device_uvector<edge_property_value_t>> tmp_properties0{std::nullopt};
std::optional<rmm::device_uvector<edge_property_value_t>> tmp_properties1{std::nullopt};
[[maybe_unused]] auto tmp_properties0 =
cugraph::detail::allocate_optional_dataframe_buffer<optional_property_buffer_value_type>(
tmp_indices.size(), handle.get_stream());

[[maybe_unused]] auto tmp_properties1 =
cugraph::detail::allocate_optional_dataframe_buffer<optional_property_buffer_value_type>(
tmp_indices.size(), handle.get_stream());

raft::device_span<edge_property_value_t> tmp_properties0_span{};
raft::device_span<edge_property_value_t> tmp_properties1_span{};

if constexpr (!std::is_same_v<edge_property_value_t, thrust::nullopt_t>) {
tmp_properties0 = std::make_optional(
rmm::device_uvector<edge_property_value_t>(tmp_indices.size(), handle.get_stream()));
tmp_properties1 = std::make_optional(
rmm::device_uvector<edge_property_value_t>(tmp_indices.size(), handle.get_stream()));
tmp_properties0_span =
raft::device_span<edge_property_value_t>(tmp_properties0.data(), tmp_properties0.size());
tmp_properties1_span =
raft::device_span<edge_property_value_t>(tmp_properties1.data(), tmp_properties1.size());
}

auto zipped_itr_to_indices_and_properties_begin =
thrust::make_zip_iterator(thrust::make_tuple(nbr_intersection_indices.begin(),
(*nbr_intersection_properties0).begin(),
(*nbr_intersection_properties1).begin()));
nbr_intersection_properties0_span.begin(),
nbr_intersection_properties1_span.begin()));

auto zipped_itr_to_tmps_begin = thrust::make_zip_iterator(thrust::make_tuple(
tmp_indices.begin(), (*tmp_properties0).begin(), (*tmp_properties1).begin()));
tmp_indices.begin(), tmp_properties0_span.begin(), tmp_properties1_span.begin()));

size_t num_copied{0};
size_t num_scanned{0};
Expand Down Expand Up @@ -1820,8 +1840,8 @@ nbr_intersection(raft::handle_t const& handle,
})),
handle.get_stream());
(*nbr_intersection_properties0).resize(nbr_intersection_indices.size(), handle.get_stream());
(*nbr_intersection_properties1).resize(nbr_intersection_indices.size(), handle.get_stream());
nbr_intersection_properties0.resize(nbr_intersection_indices.size(), handle.get_stream());
nbr_intersection_properties1.resize(nbr_intersection_indices.size(), handle.get_stream());
}
#endif

Expand All @@ -1840,8 +1860,8 @@ nbr_intersection(raft::handle_t const& handle,
} else {
return std::make_tuple(std::move(nbr_intersection_offsets),
std::move(nbr_intersection_indices),
std::move((*nbr_intersection_properties0)),
std::move((*nbr_intersection_properties1)));
std::move(nbr_intersection_properties0),
std::move(nbr_intersection_properties1));
}
}

Expand Down

0 comments on commit be7afcd

Please sign in to comment.