From fe52e60a09e6468d2886c195fef70d99cf5cdce6 Mon Sep 17 00:00:00 2001 From: Masha Basmanova Date: Tue, 6 Feb 2024 19:55:34 -0800 Subject: [PATCH] Limit number of different regexes in regexp_xxx Presto functions (#8687) Summary: Introduce a limit for the number of different regular expressions used in a specific regexp_xxx function instance. The limit of 20 applies per function instance and thread of execution. The same limit has been applied earlier to LIKE expression with patterns that require compiling regular expressions. Pull Request resolved: https://github.com/facebookincubator/velox/pull/8687 Reviewed By: Yuhta Differential Revision: D53479087 Pulled By: mbasmanova fbshipit-source-id: e304b4bdb6eb0bfe1bec6957be3a25bbfb1d2157 --- velox/docs/functions/presto/regexp.rst | 11 ++- .../expression/tests/ExpressionFuzzerTest.cpp | 4 + velox/functions/lib/Re2Functions.cpp | 80 +++++++++++++------ .../functions/lib/tests/Re2FunctionsTest.cpp | 40 ++++++++++ 4 files changed, 108 insertions(+), 27 deletions(-) diff --git a/velox/docs/functions/presto/regexp.rst b/velox/docs/functions/presto/regexp.rst index c8a8f4b81ac6..6171ad7a90ab 100644 --- a/velox/docs/functions/presto/regexp.rst +++ b/velox/docs/functions/presto/regexp.rst @@ -7,6 +7,9 @@ supports only a subset of PCRE syntax and in particular does not support backtracking and associated features (e.g. back references). See https://github.com/google/re2/wiki/Syntax for more information. +Compiling regular expressions is CPU intensive. Hence, each function is +limited to 20 different expressions per instance and thread of execution. + .. function:: like(string, pattern) -> boolean like(string, pattern, escape) -> boolean @@ -19,9 +22,11 @@ See https://github.com/google/re2/wiki/Syntax for more information. wildcard '_' represents exactly one character. Note: Each function instance allow for a maximum of 20 regular expressions to - be compiled throughout the lifetime of the query. Not all Patterns requires - compilation of regular expressions; for example a pattern 'aa' does not. - Only those that require the compilation of regular expressions are counted. + be compiled per thread of execution. Not all patterns require + compilation of regular expressions. Patterns 'aaa', 'aaa%', '%aaa', where 'aaa' + contains only regular characters and '_' wildcards are evaluated without + using regular expressions. Only those patterns that require the compilation of + regular expressions are counted towards the limit. SELECT like('abc', '%b%'); -- true SELECT like('a_c', '%#_%', '#'); -- true diff --git a/velox/expression/tests/ExpressionFuzzerTest.cpp b/velox/expression/tests/ExpressionFuzzerTest.cpp index d206add7c3bf..007e8a428e1d 100644 --- a/velox/expression/tests/ExpressionFuzzerTest.cpp +++ b/velox/expression/tests/ExpressionFuzzerTest.cpp @@ -54,6 +54,10 @@ int main(int argc, char** argv) { "width_bucket", // Fuzzer cannot generate valid 'comparator' lambda. "array_sort(array(T),constant function(T,T,bigint)) -> array(T)", + // https://github.com/facebookincubator/velox/issues/8438#issuecomment-1907234044 + "regexp_extract", + "regexp_extract_all", + "regexp_like", }; size_t initialSeed = FLAGS_seed == 0 ? std::time(nullptr) : FLAGS_seed; return FuzzerRunner::run(initialSeed, skipFunctions); diff --git a/velox/functions/lib/Re2Functions.cpp b/velox/functions/lib/Re2Functions.cpp index ff7053526a6d..8c2509f03adc 100644 --- a/velox/functions/lib/Re2Functions.cpp +++ b/velox/functions/lib/Re2Functions.cpp @@ -23,6 +23,48 @@ namespace { static const int kMaxCompiledRegexes = 20; +void checkForBadPattern(const RE2& re) { + if (UNLIKELY(!re.ok())) { + VELOX_USER_FAIL("invalid regular expression:{}", re.error()); + } +} + +template +re2::StringPiece toStringPiece(const T& s) { + return re2::StringPiece(s.data(), s.size()); +} + +// A cache of compiled regular expressions (RE2 instances). Allows up to +// 'kMaxCompiledRegexes' different expressions. +// +// Compiling regular expressions is expensive. It can take up to 200 times +// more CPU time to compile a regex vs. evaluate it. +class ReCache { + public: + RE2* findOrCompile(const StringView& pattern) { + const std::string key = pattern; + + auto reIt = cache_.find(key); + if (reIt != cache_.end()) { + return reIt->second.get(); + } + + VELOX_USER_CHECK_LT( + cache_.size(), kMaxCompiledRegexes, "Max number of regex reached"); + + auto re = std::make_unique(toStringPiece(pattern), RE2::Quiet); + checkForBadPattern(*re); + + auto [it, inserted] = cache_.emplace(key, std::move(re)); + VELOX_CHECK(inserted); + + return it->second.get(); + } + + private: + folly::F14FastMap> cache_; +}; + std::string printTypesCsv( const std::vector& inputArgs) { std::string result; @@ -34,11 +76,6 @@ std::string printTypesCsv( return result; } -template -re2::StringPiece toStringPiece(const T& s) { - return re2::StringPiece(s.data(), s.size()); -} - // If v is a non-null constant vector, returns the constant value. Otherwise // returns nullopt. template @@ -50,12 +87,6 @@ std::optional getIfConstant(const BaseVector& v) { return std::nullopt; } -void checkForBadPattern(const RE2& re) { - if (UNLIKELY(!re.ok())) { - VELOX_USER_FAIL("invalid regular expression:{}", re.error()); - } -} - FlatVector& ensureWritableBool( const SelectivityVector& rows, exec::EvalCtx& context, @@ -220,11 +251,13 @@ class Re2Match final : public exec::VectorFunction { exec::LocalDecodedVector toSearch(context, *args[0], rows); exec::LocalDecodedVector pattern(context, *args[1], rows); context.applyToSelectedNoThrow(rows, [&](vector_size_t row) { - RE2 re(toStringPiece(pattern->valueAt(row)), RE2::Quiet); - checkForBadPattern(re); + auto& re = *cache_.findOrCompile(pattern->valueAt(row)); result.set(row, Fn(toSearch->valueAt(row), re)); }); } + + private: + mutable ReCache cache_; }; void checkForBadGroupId(int64_t groupId, const RE2& re) { @@ -348,8 +381,7 @@ class Re2SearchAndExtract final : public exec::VectorFunction { if (args.size() == 2) { groups.resize(1); context.applyToSelectedNoThrow(rows, [&](vector_size_t i) { - RE2 re(toStringPiece(pattern->valueAt(i)), RE2::Quiet); - checkForBadPattern(re); + auto& re = *cache_.findOrCompile(pattern->valueAt(i)); mustRefSourceStrings |= re2Extract(result, i, re, toSearch, groups, 0, emptyNoMatch_); }); @@ -357,8 +389,7 @@ class Re2SearchAndExtract final : public exec::VectorFunction { exec::LocalDecodedVector groupIds(context, *args[2], rows); context.applyToSelectedNoThrow(rows, [&](vector_size_t i) { const auto groupId = groupIds->valueAt(i); - RE2 re(toStringPiece(pattern->valueAt(i)), RE2::Quiet); - checkForBadPattern(re); + auto& re = *cache_.findOrCompile(pattern->valueAt(i)); checkForBadGroupId(groupId, re); groups.resize(groupId + 1); mustRefSourceStrings |= @@ -372,6 +403,7 @@ class Re2SearchAndExtract final : public exec::VectorFunction { private: const bool emptyNoMatch_; + mutable ReCache cache_; }; namespace { @@ -1126,8 +1158,7 @@ class Re2ExtractAll final : public exec::VectorFunction { // groups.resize(1); context.applyToSelectedNoThrow(rows, [&](vector_size_t row) { - RE2 re(toStringPiece(pattern->valueAt(row)), RE2::Quiet); - checkForBadPattern(re); + auto& re = *cache_.findOrCompile(pattern->valueAt(row)); re2ExtractAll(resultWriter, re, inputStrs, row, groups, 0); }); } else { @@ -1136,8 +1167,7 @@ class Re2ExtractAll final : public exec::VectorFunction { exec::LocalDecodedVector groupIds(context, *args[2], rows); context.applyToSelectedNoThrow(rows, [&](vector_size_t row) { const T groupId = groupIds->valueAt(row); - RE2 re(toStringPiece(pattern->valueAt(row)), RE2::Quiet); - checkForBadPattern(re); + auto& re = *cache_.findOrCompile(pattern->valueAt(row)); checkForBadGroupId(groupId, re); groups.resize(groupId + 1); re2ExtractAll(resultWriter, re, inputStrs, row, groups, groupId); @@ -1150,6 +1180,9 @@ class Re2ExtractAll final : public exec::VectorFunction { ->asFlatVector() ->acquireSharedStringBuffers(inputStrs->base()); } + + private: + mutable ReCache cache_; }; template @@ -1170,9 +1203,8 @@ std::shared_ptr makeRe2MatchImpl( return std::make_shared>( constantPattern->as>()->valueAt(0)); } - static std::shared_ptr> kMatchExpr = - std::make_shared>(); - return kMatchExpr; + + return std::make_shared>(); } } // namespace diff --git a/velox/functions/lib/tests/Re2FunctionsTest.cpp b/velox/functions/lib/tests/Re2FunctionsTest.cpp index 1d0bdfb505b2..bb7503d74b5f 100644 --- a/velox/functions/lib/tests/Re2FunctionsTest.cpp +++ b/velox/functions/lib/tests/Re2FunctionsTest.cpp @@ -1431,5 +1431,45 @@ TEST_F(Re2FunctionsTest, regexExtractAllLarge) { "No group 4611686018427387904 in regex '(\\d+)([a-z]+)") } +// Make sure we do not compile more than kMaxCompiledRegexes. +TEST_F(Re2FunctionsTest, limit) { + auto data = makeRowVector({ + makeFlatVector( + 100, + [](auto row) { return fmt::format("Apples and oranges {}", row); }), + makeFlatVector( + 100, + [](auto row) { return fmt::format("Apples (.*) oranges {}", row); }), + makeFlatVector( + 100, + [](auto row) { + return fmt::format("Apples (.*) oranges {}", row % 20); + }), + }); + + VELOX_ASSERT_THROW( + evaluate("regexp_extract(c0, c1)", data), "Max number of regex reached"); + ASSERT_NO_THROW(evaluate("regexp_extract(c0, c2)", data)); + + VELOX_ASSERT_THROW( + evaluate("regexp_extract(c0, c1, 1)", data), + "Max number of regex reached"); + ASSERT_NO_THROW(evaluate("regexp_extract(c0, c2, 1)", data)); + + VELOX_ASSERT_THROW( + evaluate("regexp_extract_all(c0, c1)", data), + "Max number of regex reached"); + ASSERT_NO_THROW(evaluate("regexp_extract_all(c0, c2)", data)); + + VELOX_ASSERT_THROW( + evaluate("regexp_extract_all(c0, c1, 1)", data), + "Max number of regex reached"); + ASSERT_NO_THROW(evaluate("regexp_extract_all(c0, c2, 1)", data)); + + VELOX_ASSERT_THROW( + evaluate("regexp_like(c0, c1)", data), "Max number of regex reached"); + ASSERT_NO_THROW(evaluate("regexp_like(c0, c2)", data)); +} + } // namespace } // namespace facebook::velox::functions