diff --git a/cpp/src/strings/search/find.cu b/cpp/src/strings/search/find.cu index 15d89069ba3..1390b304e43 100644 --- a/cpp/src/strings/search/find.cu +++ b/cpp/src/strings/search/find.cu @@ -18,6 +18,8 @@ #include #include #include +#include +#include #include #include #include @@ -28,6 +30,7 @@ #include #include +#include #include #include @@ -162,6 +165,81 @@ std::unique_ptr rfind(strings_column_view const& strings, namespace detail { namespace { + +/** + * @brief Threshold to decide on using string or warp parallel functions. + * + * If the average byte length of a string in a column exceeds this value then + * the warp-parallel `contains_warp_fn` function is used. + * Otherwise, the string-parallel function in `contains_fn` is used. + * + * This is only used for the scalar version of `contains()` right now. + */ +constexpr size_type AVG_CHAR_BYTES_THRESHOLD = 64; + +/** + * @brief Check if `d_target` appears in a row in `d_strings`. + * + * This executes as a warp per string/row. + */ +struct contains_warp_fn { + column_device_view const d_strings; + string_view const d_target; + bool* d_results; + + __device__ void operator()(std::size_t idx) + { + auto const str_idx = static_cast(idx / cudf::detail::warp_size); + if (d_strings.is_null(str_idx)) { return; } + // get the string for this warp + auto const d_str = d_strings.element(str_idx); + // each thread of the warp will check just part of the string + auto found = false; + for (auto i = static_cast(idx % cudf::detail::warp_size); + !found && (i + d_target.size_bytes()) < d_str.size_bytes(); + i += cudf::detail::warp_size) { + // check the target matches this part of the d_str data + if (d_target.compare(d_str.data() + i, d_target.size_bytes()) == 0) { found = true; } + } + if (found) { atomicOr(d_results + str_idx, true); } + } +}; + +std::unique_ptr contains_warp_parallel(strings_column_view const& input, + string_scalar const& target, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + CUDF_EXPECTS(target.is_valid(stream), "Parameter target must be valid."); + auto d_target = string_view(target.data(), target.size()); + + // create output column + auto results = make_numeric_column(data_type{type_id::BOOL8}, + input.size(), + cudf::detail::copy_bitmask(input.parent(), stream, mr), + input.null_count(), + stream, + mr); + + // fill the output with `false` unless the `d_target` is empty + auto results_view = results->mutable_view(); + thrust::fill(rmm::exec_policy(stream), + results_view.begin(), + results_view.end(), + d_target.empty()); + + if (!d_target.empty()) { + // launch warp per string + auto d_strings = column_device_view::create(input.parent(), stream); + thrust::for_each_n(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + static_cast(input.size()) * cudf::detail::warp_size, + contains_warp_fn{*d_strings, d_target, results_view.data()}); + } + results->set_null_count(input.null_count()); + return results; +} + /** * @brief Utility to return a bool column indicating the presence of * a given target string in a strings column. @@ -286,15 +364,21 @@ std::unique_ptr contains_fn(strings_column_view const& strings, } // namespace std::unique_ptr contains( - strings_column_view const& strings, + strings_column_view const& input, string_scalar const& target, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { + // use warp parallel when the average string width is greater than the threshold + if (!input.is_empty() && ((input.chars_size() / input.size()) > AVG_CHAR_BYTES_THRESHOLD)) { + return contains_warp_parallel(input, target, stream, mr); + } + + // benchmark measurements showed this to be faster for smaller strings auto pfn = [] __device__(string_view d_string, string_view d_target) { return d_string.find(d_target) >= 0; }; - return contains_fn(strings, target, pfn, stream, mr); + return contains_fn(input, target, pfn, stream, mr); } std::unique_ptr contains( diff --git a/cpp/tests/strings/find_tests.cpp b/cpp/tests/strings/find_tests.cpp index 177e6d97f7f..208063adcb0 100644 --- a/cpp/tests/strings/find_tests.cpp +++ b/cpp/tests/strings/find_tests.cpp @@ -82,6 +82,26 @@ TEST_F(StringsFindTest, Contains) } } +TEST_F(StringsFindTest, ContainsLongStrings) +{ + cudf::test::strings_column_wrapper strings( + {"Héllo, there world and goodbye", + "quick brown fox jumped over the lazy brown dog; the fat cats jump in place without moving", + "the following code snippet demonstrates how to use search for values in an ordered range", + "it returns the last position where value could be inserted without violating the ordering", + "algorithms execution is parallelized as determined by an execution policy. t", + "he this is a continuation of previous row to make sure string boundaries are honored", + ""}); + auto strings_view = cudf::strings_column_view(strings); + auto results = cudf::strings::contains(strings_view, cudf::string_scalar("e")); + cudf::test::fixed_width_column_wrapper expected({1, 1, 1, 1, 1, 1, 0}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected); + + results = cudf::strings::contains(strings_view, cudf::string_scalar(" the ")); + cudf::test::fixed_width_column_wrapper expected2({0, 1, 0, 1, 0, 0, 0}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected2); +} + TEST_F(StringsFindTest, StartsWith) { cudf::test::strings_column_wrapper strings({"Héllo", "thesé", "", "lease", "tést strings", ""},