|
18 | 18 | #include <cudf/column/column_factories.hpp>
|
19 | 19 | #include <cudf/detail/null_mask.hpp>
|
20 | 20 | #include <cudf/detail/nvtx/ranges.hpp>
|
| 21 | +#include <cudf/detail/utilities/cuda.cuh> |
| 22 | +#include <cudf/detail/utilities/device_atomics.cuh> |
21 | 23 | #include <cudf/scalar/scalar_factories.hpp>
|
22 | 24 | #include <cudf/strings/detail/utilities.hpp>
|
23 | 25 | #include <cudf/strings/find.hpp>
|
|
28 | 30 | #include <rmm/cuda_stream_view.hpp>
|
29 | 31 | #include <rmm/exec_policy.hpp>
|
30 | 32 |
|
| 33 | +#include <thrust/binary_search.h> |
31 | 34 | #include <thrust/iterator/counting_iterator.h>
|
32 | 35 | #include <thrust/transform.h>
|
33 | 36 |
|
@@ -162,6 +165,81 @@ std::unique_ptr<column> rfind(strings_column_view const& strings,
|
162 | 165 |
|
163 | 166 | namespace detail {
|
164 | 167 | namespace {
|
| 168 | + |
| 169 | +/** |
| 170 | + * @brief Threshold to decide on using string or warp parallel functions. |
| 171 | + * |
| 172 | + * If the average byte length of a string in a column exceeds this value then |
| 173 | + * the warp-parallel `contains_warp_fn` function is used. |
| 174 | + * Otherwise, the string-parallel function in `contains_fn` is used. |
| 175 | + * |
| 176 | + * This is only used for the scalar version of `contains()` right now. |
| 177 | + */ |
| 178 | +constexpr size_type AVG_CHAR_BYTES_THRESHOLD = 64; |
| 179 | + |
| 180 | +/** |
| 181 | + * @brief Check if `d_target` appears in a row in `d_strings`. |
| 182 | + * |
| 183 | + * This executes as a warp per string/row. |
| 184 | + */ |
| 185 | +struct contains_warp_fn { |
| 186 | + column_device_view const d_strings; |
| 187 | + string_view const d_target; |
| 188 | + bool* d_results; |
| 189 | + |
| 190 | + __device__ void operator()(std::size_t idx) |
| 191 | + { |
| 192 | + auto const str_idx = static_cast<size_type>(idx / cudf::detail::warp_size); |
| 193 | + if (d_strings.is_null(str_idx)) { return; } |
| 194 | + // get the string for this warp |
| 195 | + auto const d_str = d_strings.element<string_view>(str_idx); |
| 196 | + // each thread of the warp will check just part of the string |
| 197 | + auto found = false; |
| 198 | + for (auto i = static_cast<size_type>(idx % cudf::detail::warp_size); |
| 199 | + !found && (i + d_target.size_bytes()) < d_str.size_bytes(); |
| 200 | + i += cudf::detail::warp_size) { |
| 201 | + // check the target matches this part of the d_str data |
| 202 | + if (d_target.compare(d_str.data() + i, d_target.size_bytes()) == 0) { found = true; } |
| 203 | + } |
| 204 | + if (found) { atomicOr(d_results + str_idx, true); } |
| 205 | + } |
| 206 | +}; |
| 207 | + |
| 208 | +std::unique_ptr<column> contains_warp_parallel(strings_column_view const& input, |
| 209 | + string_scalar const& target, |
| 210 | + rmm::cuda_stream_view stream, |
| 211 | + rmm::mr::device_memory_resource* mr) |
| 212 | +{ |
| 213 | + CUDF_EXPECTS(target.is_valid(stream), "Parameter target must be valid."); |
| 214 | + auto d_target = string_view(target.data(), target.size()); |
| 215 | + |
| 216 | + // create output column |
| 217 | + auto results = make_numeric_column(data_type{type_id::BOOL8}, |
| 218 | + input.size(), |
| 219 | + cudf::detail::copy_bitmask(input.parent(), stream, mr), |
| 220 | + input.null_count(), |
| 221 | + stream, |
| 222 | + mr); |
| 223 | + |
| 224 | + // fill the output with `false` unless the `d_target` is empty |
| 225 | + auto results_view = results->mutable_view(); |
| 226 | + thrust::fill(rmm::exec_policy(stream), |
| 227 | + results_view.begin<bool>(), |
| 228 | + results_view.end<bool>(), |
| 229 | + d_target.empty()); |
| 230 | + |
| 231 | + if (!d_target.empty()) { |
| 232 | + // launch warp per string |
| 233 | + auto d_strings = column_device_view::create(input.parent(), stream); |
| 234 | + thrust::for_each_n(rmm::exec_policy(stream), |
| 235 | + thrust::make_counting_iterator<std::size_t>(0), |
| 236 | + static_cast<std::size_t>(input.size()) * cudf::detail::warp_size, |
| 237 | + contains_warp_fn{*d_strings, d_target, results_view.data<bool>()}); |
| 238 | + } |
| 239 | + results->set_null_count(input.null_count()); |
| 240 | + return results; |
| 241 | +} |
| 242 | + |
165 | 243 | /**
|
166 | 244 | * @brief Utility to return a bool column indicating the presence of
|
167 | 245 | * a given target string in a strings column.
|
@@ -286,15 +364,21 @@ std::unique_ptr<column> contains_fn(strings_column_view const& strings,
|
286 | 364 | } // namespace
|
287 | 365 |
|
288 | 366 | std::unique_ptr<column> contains(
|
289 |
| - strings_column_view const& strings, |
| 367 | + strings_column_view const& input, |
290 | 368 | string_scalar const& target,
|
291 | 369 | rmm::cuda_stream_view stream,
|
292 | 370 | rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
|
293 | 371 | {
|
| 372 | + // use warp parallel when the average string width is greater than the threshold |
| 373 | + if (!input.is_empty() && ((input.chars_size() / input.size()) > AVG_CHAR_BYTES_THRESHOLD)) { |
| 374 | + return contains_warp_parallel(input, target, stream, mr); |
| 375 | + } |
| 376 | + |
| 377 | + // benchmark measurements showed this to be faster for smaller strings |
294 | 378 | auto pfn = [] __device__(string_view d_string, string_view d_target) {
|
295 | 379 | return d_string.find(d_target) >= 0;
|
296 | 380 | };
|
297 |
| - return contains_fn(strings, target, pfn, stream, mr); |
| 381 | + return contains_fn(input, target, pfn, stream, mr); |
298 | 382 | }
|
299 | 383 |
|
300 | 384 | std::unique_ptr<column> contains(
|
|
0 commit comments