Skip to content

Commit

Permalink
adopts device-side test data gen
Browse files Browse the repository at this point in the history
  • Loading branch information
elstehle committed Jul 13, 2022
1 parent 694a365 commit f656f49
Showing 1 changed file with 9 additions and 25 deletions.
34 changes: 9 additions & 25 deletions cpp/tests/io/fst/fst_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

#include <rmm/cuda_stream.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/cuda_stream.hpp>
#include <rmm/device_buffer.hpp>
#include <rmm/device_uvector.hpp>

Expand Down Expand Up @@ -91,35 +90,26 @@ static std::pair<OutputItT, IndexOutputItT> fst_baseline(InputItT begin,
auto const& symbol = *it;

std::size_t symbol_group = 0;
bool found = false;

// Iterate over symbol groups and search for the first symbol group containing the current
// symbol
for (auto const& sg : symbol_group_lut) {
for (auto const& s : sg)
if (s == symbol) found = true;
if (found) break;
if (std::find(std::cbegin(sg), std::cend(sg), symbol) != std::cend(sg)) { break; }
symbol_group++;
}

// Output the translated symbols to the output tape
size_t inserted = 0;
for (auto out : translation_table[state][symbol_group]) {
// std::cout << in_offset << ": " << out << "\n";
*out_tape = out;
++out_tape;
inserted++;
}

// Output the index of the current symbol, iff it caused some output to be written
if (inserted > 0) {
*out_index_tape = in_offset;
out_index_tape++;
}

// Transition the state of the finite-state machine
state = transition_table[state][symbol_group];

// Continue with next symbol from input tape
in_offset++;
}
return {out_tape, out_index_tape};
Expand Down Expand Up @@ -195,10 +185,11 @@ TEST_F(FstTest, GroundTruth)

// Prepare cuda stream for data transfers & kernels
rmm::cuda_stream stream{};
rmm::cuda_stream_view stream_view(stream);

// Test input
std::string input = R"( {)"
R"(category": "reference",)"
R"("category": "reference",)"
R"("index:" [4,12,42],)"
R"("author": "Nigel Rees",)"
R"("title": "Sayings of the Century",)"
Expand All @@ -212,25 +203,19 @@ TEST_F(FstTest, GroundTruth)
R"("price": 8.95)"
R"(} {} [] [ ])";

// Repeat input sample 1024x
size_t string_size = 1 << 10;
size_t string_size = input.size() * (1 << 10);
auto d_input_scalar = cudf::make_string_scalar(input);
auto& d_string_scalar = static_cast<cudf::string_scalar&>(*d_input_scalar);
const cudf::size_type repeat_times = string_size / input.size();
auto d_input_string = cudf::strings::repeat_string(d_string_scalar, repeat_times);
auto& d_input = static_cast<cudf::scalar_type_t<std::string>&>(*d_input_string);
input = d_input.to_string(stream);



// Prepare input & output buffers
constexpr std::size_t single_item = 1;
rmm::device_uvector<SymbolT> d_input(input.size(), stream.view());
hostdevice_vector<SymbolT> output_gpu(input.size(), stream.view());
hostdevice_vector<SymbolOffsetT> output_gpu_size(single_item, stream.view());
hostdevice_vector<SymbolOffsetT> out_indexes_gpu(input.size(), stream.view());
ASSERT_CUDA_SUCCEEDED(cudaMemcpyAsync(
d_input.data(), input.data(), input.size() * sizeof(SymbolT), cudaMemcpyHostToDevice, stream.value()));
hostdevice_vector<SymbolT> output_gpu(input.size(), stream_view);
hostdevice_vector<SymbolOffsetT> output_gpu_size(single_item, stream_view);
hostdevice_vector<SymbolOffsetT> out_indexes_gpu(input.size(), stream_view);

// Run algorithm
DfaFstT parser{pda_sgs, pda_state_tt, pda_out_tt, stream.value()};
Expand Down Expand Up @@ -270,11 +255,10 @@ TEST_F(FstTest, GroundTruth)

// Verify results
ASSERT_EQ(output_gpu_size[0], output_cpu.size());
ASSERT_EQ(out_indexes_gpu.size(), out_index_cpu.size());
for (std::size_t i = 0; i < output_cpu.size(); i++) {
ASSERT_EQ(output_gpu[i], output_cpu[i]) << "Mismatch at index #" << i;
}
for (std::size_t i = 0; i < out_indexes_gpu.size(); i++) {
for (std::size_t i = 0; i < output_cpu.size(); i++) {
ASSERT_EQ(out_indexes_gpu[i], out_index_cpu[i]) << "Mismatch at index #" << i;
}
}
Expand Down

0 comments on commit f656f49

Please sign in to comment.