Skip to content

Commit

Permalink
udpate tests for 'mg_homogeneous_uniform_neighbor_sampling'
Browse files Browse the repository at this point in the history
  • Loading branch information
jnke2016 committed Oct 22, 2024
1 parent d34d85c commit 2aa0903
Showing 1 changed file with 21 additions and 21 deletions.
42 changes: 21 additions & 21 deletions cpp/tests/sampling/mg_homogeneous_uniform_neighbor_sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,21 @@ class Tests_MGHomogeneous_Uniform_Neighbor_Sampling

auto batch_number = cugraph::test::modulo_sequence<int32_t>(
*handle_, random_sources.size(), num_batches, seed_offsets[handle_->get_comms().get_rank()]);

// Get unique batch_number -> label_list
rmm::device_uvector<int32_t> label_list(batch_number.size(), handle_->get_stream());

raft::copy(label_list.data(),
batch_number.data(),
batch_number.size(),
handle_->get_stream());

label_list = cugraph::test::sort<int32_t>(*handle_, std::move(label_list));
label_list = cugraph::test::unique<int32_t>(*handle_, std::move(label_list));

auto num_unique_labels = label_list.size();

//

/*
rmm::device_uvector<int32_t> unique_batches(num_batches, handle_->get_stream());
Expand All @@ -156,7 +171,12 @@ class Tests_MGHomogeneous_Uniform_Neighbor_Sampling
*/

auto comm_ranks = cugraph::test::scalar_fill<int32_t>(
*handle_, random_sources.size(), int32_t{handle_->get_comms().get_rank()});
*handle_, num_unique_labels, int32_t{handle_->get_comms().get_rank()});

// perform allgatherv
comm_ranks =
cugraph::test::device_allgatherv(*handle_, comm_ranks.data(), comm_ranks.size());


//raft::print_device_vector("random_sources", random_sources.data(), random_sources.size(), std::cout);
//raft::print_device_vector("comm_ranks", comm_ranks.data(), comm_ranks.size(), std::cout);
Expand Down Expand Up @@ -325,28 +345,9 @@ std::unique_ptr<raft::handle_t> Tests_MGHomogeneous_Uniform_Neighbor_Sampling<in
using Tests_MGHomogeneous_Uniform_Neighbor_Sampling_File =
Tests_MGHomogeneous_Uniform_Neighbor_Sampling<cugraph::test::File_Usecase>;

//#if 0
using Tests_MGHomogeneous_Uniform_Neighbor_Sampling_Rmat =
Tests_MGHomogeneous_Uniform_Neighbor_Sampling<cugraph::test::Rmat_Usecase>;

//#endif

#if 0
TEST_P(Tests_MGHomogeneous_Uniform_Neighbor_Sampling_File, CheckInt32Int32Float)
{
run_current_test<int32_t, int32_t, float>(
override_File_Usecase_with_cmd_line_arguments(GetParam()));
}

INSTANTIATE_TEST_SUITE_P(
file_test,
Tests_MGHomogeneous_Uniform_Neighbor_Sampling_File,
::testing::Combine(
::testing::Values(Homogeneous_Uniform_Neighbor_Sampling_Usecase{{4, 10}, 128, false, true}),
::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"))));
#endif

//#if 0
TEST_P(Tests_MGHomogeneous_Uniform_Neighbor_Sampling_File, CheckInt32Int32Float)
{
run_current_test<int32_t, int32_t, float>(
Expand Down Expand Up @@ -418,6 +419,5 @@ INSTANTIATE_TEST_SUITE_P(
Homogeneous_Uniform_Neighbor_Sampling_Usecase{{4, 10}, 128, true, false, false},
Homogeneous_Uniform_Neighbor_Sampling_Usecase{{4, 10}, 128, true, true, false}),
::testing::Values(cugraph::test::Rmat_Usecase(20, 32, 0.57, 0.19, 0.19, 0, false, false))));
//#endif

CUGRAPH_MG_TEST_PROGRAM_MAIN()

0 comments on commit 2aa0903

Please sign in to comment.