Skip to content

Commit 6128e0d

Browse files
authored
Use warp per string for long strings in cudf::strings::contains() (#10739)
Improves the performance on `cudf::strings::contains()` for long strings. This executes a warp per string to match a target over sections of a single string in parallel. The benchmark showed this to be faster than the current implementation only for longer strings (greater than 64 bytes). It also proved somewhat faster and more consistent than a pure character-parallel approach. This change may also help improve the performance of the regex `contains_re()` function in the future. Authors: - David Wendt (https://github.com/davidwendt) Approvers: - Mike Wilson (https://github.com/hyperbolic2346) - Bradley Dice (https://github.com/bdice) URL: #10739
1 parent 027c34a commit 6128e0d

File tree

2 files changed

+106
-2
lines changed

2 files changed

+106
-2
lines changed

cpp/src/strings/search/find.cu

+86-2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#include <cudf/column/column_factories.hpp>
1919
#include <cudf/detail/null_mask.hpp>
2020
#include <cudf/detail/nvtx/ranges.hpp>
21+
#include <cudf/detail/utilities/cuda.cuh>
22+
#include <cudf/detail/utilities/device_atomics.cuh>
2123
#include <cudf/scalar/scalar_factories.hpp>
2224
#include <cudf/strings/detail/utilities.hpp>
2325
#include <cudf/strings/find.hpp>
@@ -28,6 +30,7 @@
2830
#include <rmm/cuda_stream_view.hpp>
2931
#include <rmm/exec_policy.hpp>
3032

33+
#include <thrust/binary_search.h>
3134
#include <thrust/iterator/counting_iterator.h>
3235
#include <thrust/transform.h>
3336

@@ -162,6 +165,81 @@ std::unique_ptr<column> rfind(strings_column_view const& strings,
162165

163166
namespace detail {
164167
namespace {
168+
169+
/**
170+
* @brief Threshold to decide on using string or warp parallel functions.
171+
*
172+
* If the average byte length of a string in a column exceeds this value then
173+
* the warp-parallel `contains_warp_fn` function is used.
174+
* Otherwise, the string-parallel function in `contains_fn` is used.
175+
*
176+
* This is only used for the scalar version of `contains()` right now.
177+
*/
178+
constexpr size_type AVG_CHAR_BYTES_THRESHOLD = 64;
179+
180+
/**
181+
* @brief Check if `d_target` appears in a row in `d_strings`.
182+
*
183+
* This executes as a warp per string/row.
184+
*/
185+
struct contains_warp_fn {
186+
column_device_view const d_strings;
187+
string_view const d_target;
188+
bool* d_results;
189+
190+
__device__ void operator()(std::size_t idx)
191+
{
192+
auto const str_idx = static_cast<size_type>(idx / cudf::detail::warp_size);
193+
if (d_strings.is_null(str_idx)) { return; }
194+
// get the string for this warp
195+
auto const d_str = d_strings.element<string_view>(str_idx);
196+
// each thread of the warp will check just part of the string
197+
auto found = false;
198+
for (auto i = static_cast<size_type>(idx % cudf::detail::warp_size);
199+
!found && (i + d_target.size_bytes()) < d_str.size_bytes();
200+
i += cudf::detail::warp_size) {
201+
// check the target matches this part of the d_str data
202+
if (d_target.compare(d_str.data() + i, d_target.size_bytes()) == 0) { found = true; }
203+
}
204+
if (found) { atomicOr(d_results + str_idx, true); }
205+
}
206+
};
207+
208+
std::unique_ptr<column> contains_warp_parallel(strings_column_view const& input,
209+
string_scalar const& target,
210+
rmm::cuda_stream_view stream,
211+
rmm::mr::device_memory_resource* mr)
212+
{
213+
CUDF_EXPECTS(target.is_valid(stream), "Parameter target must be valid.");
214+
auto d_target = string_view(target.data(), target.size());
215+
216+
// create output column
217+
auto results = make_numeric_column(data_type{type_id::BOOL8},
218+
input.size(),
219+
cudf::detail::copy_bitmask(input.parent(), stream, mr),
220+
input.null_count(),
221+
stream,
222+
mr);
223+
224+
// fill the output with `false` unless the `d_target` is empty
225+
auto results_view = results->mutable_view();
226+
thrust::fill(rmm::exec_policy(stream),
227+
results_view.begin<bool>(),
228+
results_view.end<bool>(),
229+
d_target.empty());
230+
231+
if (!d_target.empty()) {
232+
// launch warp per string
233+
auto d_strings = column_device_view::create(input.parent(), stream);
234+
thrust::for_each_n(rmm::exec_policy(stream),
235+
thrust::make_counting_iterator<std::size_t>(0),
236+
static_cast<std::size_t>(input.size()) * cudf::detail::warp_size,
237+
contains_warp_fn{*d_strings, d_target, results_view.data<bool>()});
238+
}
239+
results->set_null_count(input.null_count());
240+
return results;
241+
}
242+
165243
/**
166244
* @brief Utility to return a bool column indicating the presence of
167245
* a given target string in a strings column.
@@ -286,15 +364,21 @@ std::unique_ptr<column> contains_fn(strings_column_view const& strings,
286364
} // namespace
287365

288366
std::unique_ptr<column> contains(
289-
strings_column_view const& strings,
367+
strings_column_view const& input,
290368
string_scalar const& target,
291369
rmm::cuda_stream_view stream,
292370
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
293371
{
372+
// use warp parallel when the average string width is greater than the threshold
373+
if (!input.is_empty() && ((input.chars_size() / input.size()) > AVG_CHAR_BYTES_THRESHOLD)) {
374+
return contains_warp_parallel(input, target, stream, mr);
375+
}
376+
377+
// benchmark measurements showed this to be faster for smaller strings
294378
auto pfn = [] __device__(string_view d_string, string_view d_target) {
295379
return d_string.find(d_target) >= 0;
296380
};
297-
return contains_fn(strings, target, pfn, stream, mr);
381+
return contains_fn(input, target, pfn, stream, mr);
298382
}
299383

300384
std::unique_ptr<column> contains(

cpp/tests/strings/find_tests.cpp

+20
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,26 @@ TEST_F(StringsFindTest, Contains)
8282
}
8383
}
8484

85+
TEST_F(StringsFindTest, ContainsLongStrings)
86+
{
87+
cudf::test::strings_column_wrapper strings(
88+
{"Héllo, there world and goodbye",
89+
"quick brown fox jumped over the lazy brown dog; the fat cats jump in place without moving",
90+
"the following code snippet demonstrates how to use search for values in an ordered range",
91+
"it returns the last position where value could be inserted without violating the ordering",
92+
"algorithms execution is parallelized as determined by an execution policy. t",
93+
"he this is a continuation of previous row to make sure string boundaries are honored",
94+
""});
95+
auto strings_view = cudf::strings_column_view(strings);
96+
auto results = cudf::strings::contains(strings_view, cudf::string_scalar("e"));
97+
cudf::test::fixed_width_column_wrapper<bool> expected({1, 1, 1, 1, 1, 1, 0});
98+
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected);
99+
100+
results = cudf::strings::contains(strings_view, cudf::string_scalar(" the "));
101+
cudf::test::fixed_width_column_wrapper<bool> expected2({0, 1, 0, 1, 0, 0, 0});
102+
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected2);
103+
}
104+
85105
TEST_F(StringsFindTest, StartsWith)
86106
{
87107
cudf::test::strings_column_wrapper strings({"Héllo", "thesé", "", "lease", "tést strings", ""},

0 commit comments

Comments
 (0)