Skip to content

Commit

Permalink
Limit number of different regexes in regexp_xxx Presto functions (fac…
Browse files Browse the repository at this point in the history
…ebookincubator#8687)

Summary:
Introduce a limit for the number of different regular expressions used in a
specific regexp_xxx function instance. The limit of 20 applies per function
instance and thread of execution.

The same limit has been applied earlier to LIKE expression with patterns that
require compiling regular expressions.

Pull Request resolved: facebookincubator#8687

Reviewed By: Yuhta

Differential Revision: D53479087

Pulled By: mbasmanova

fbshipit-source-id: e304b4bdb6eb0bfe1bec6957be3a25bbfb1d2157
  • Loading branch information
mbasmanova authored and facebook-github-bot committed Feb 7, 2024
1 parent e4a2ce2 commit fe52e60
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 27 deletions.
11 changes: 8 additions & 3 deletions velox/docs/functions/presto/regexp.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ supports only a subset of PCRE syntax and in particular does not support
backtracking and associated features (e.g. back references).
See https://github.com/google/re2/wiki/Syntax for more information.

Compiling regular expressions is CPU intensive. Hence, each function is
limited to 20 different expressions per instance and thread of execution.

.. function:: like(string, pattern) -> boolean
like(string, pattern, escape) -> boolean

Expand All @@ -19,9 +22,11 @@ See https://github.com/google/re2/wiki/Syntax for more information.
wildcard '_' represents exactly one character.

Note: Each function instance allows for a maximum of 20 regular expressions to
be compiled throughout the lifetime of the query. Not all Patterns requires
compilation of regular expressions; for example a pattern 'aa' does not.
Only those that require the compilation of regular expressions are counted.
be compiled per thread of execution. Not all patterns require
compilation of regular expressions. Patterns 'aaa', 'aaa%', '%aaa', where 'aaa'
contains only regular characters and '_' wildcards, are evaluated without
using regular expressions. Only those patterns that require the compilation of
regular expressions are counted towards the limit.

SELECT like('abc', '%b%'); -- true
SELECT like('a_c', '%#_%', '#'); -- true
Expand Down
4 changes: 4 additions & 0 deletions velox/expression/tests/ExpressionFuzzerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ int main(int argc, char** argv) {
"width_bucket",
// Fuzzer cannot generate valid 'comparator' lambda.
"array_sort(array(T),constant function(T,T,bigint)) -> array(T)",
// https://github.com/facebookincubator/velox/issues/8438#issuecomment-1907234044
"regexp_extract",
"regexp_extract_all",
"regexp_like",
};
size_t initialSeed = FLAGS_seed == 0 ? std::time(nullptr) : FLAGS_seed;
return FuzzerRunner::run(initialSeed, skipFunctions);
Expand Down
80 changes: 56 additions & 24 deletions velox/functions/lib/Re2Functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,48 @@ namespace {

// Upper bound on the number of distinct regular expressions a single ReCache
// keeps compiled; ReCache::findOrCompile fails with a user error once the
// cache already holds this many entries and a new pattern is requested.
static const int kMaxCompiledRegexes = 20;

// Raises a user error carrying RE2's own diagnostic when 're' failed to
// compile. Returns normally for a valid expression.
void checkForBadPattern(const RE2& re) {
  const bool compiled = re.ok();
  if (UNLIKELY(!compiled)) {
    VELOX_USER_FAIL("invalid regular expression:{}", re.error());
  }
}

// Builds a non-owning re2::StringPiece over the bytes exposed by 's' through
// data() and size(). No bytes are copied here, so the result refers to the
// storage returned by s.data().
template <typename T>
re2::StringPiece toStringPiece(const T& s) {
  const char* const begin = s.data();
  const auto length = s.size();
  return re2::StringPiece(begin, length);
}

// A cache of compiled regular expressions (RE2 instances). Allows up to
// 'kMaxCompiledRegexes' different expressions.
//
// Compiling regular expressions is expensive. It can take up to 200 times
// more CPU time to compile a regex vs. evaluate it.
class ReCache {
 public:
  // Returns the compiled RE2 for 'pattern', compiling and caching it on first
  // use. The returned pointer is owned by the cache.
  //
  // Throws a user error if 'pattern' is not yet cached and the cache already
  // holds 'kMaxCompiledRegexes' expressions, or if 'pattern' is not a valid
  // regular expression. Invalid patterns are never inserted into the cache.
  RE2* findOrCompile(const StringView& pattern) {
    // Deliberately non-const: on a cache miss the key is moved into the map
    // instead of being copied a second time.
    std::string key = pattern;

    if (auto cached = cache_.find(key); cached != cache_.end()) {
      return cached->second.get();
    }

    // The limit is checked before compiling so that no CPU is spent on a
    // pattern that cannot be cached.
    VELOX_USER_CHECK_LT(
        cache_.size(), kMaxCompiledRegexes, "Max number of regex reached");

    auto re = std::make_unique<RE2>(toStringPiece(pattern), RE2::Quiet);
    checkForBadPattern(*re);

    auto [it, inserted] = cache_.emplace(std::move(key), std::move(re));
    VELOX_CHECK(inserted);

    return it->second.get();
  }

 private:
  // Pattern text -> compiled expression.
  folly::F14FastMap<std::string, std::unique_ptr<RE2>> cache_;
};

std::string printTypesCsv(
const std::vector<exec::VectorFunctionArg>& inputArgs) {
std::string result;
Expand All @@ -34,11 +76,6 @@ std::string printTypesCsv(
return result;
}

// Builds a non-owning re2::StringPiece over the bytes exposed by 's' through
// data() and size(). No bytes are copied here, so the result refers to the
// storage returned by s.data().
template <typename T>
re2::StringPiece toStringPiece(const T& s) {
  const char* const begin = s.data();
  const auto length = s.size();
  return re2::StringPiece(begin, length);
}

// If v is a non-null constant vector, returns the constant value. Otherwise
// returns nullopt.
template <typename T>
Expand All @@ -50,12 +87,6 @@ std::optional<T> getIfConstant(const BaseVector& v) {
return std::nullopt;
}

// Raises a user error carrying RE2's own diagnostic when 're' failed to
// compile. Returns normally for a valid expression.
void checkForBadPattern(const RE2& re) {
  const bool compiled = re.ok();
  if (UNLIKELY(!compiled)) {
    VELOX_USER_FAIL("invalid regular expression:{}", re.error());
  }
}

FlatVector<bool>& ensureWritableBool(
const SelectivityVector& rows,
exec::EvalCtx& context,
Expand Down Expand Up @@ -220,11 +251,13 @@ class Re2Match final : public exec::VectorFunction {
exec::LocalDecodedVector toSearch(context, *args[0], rows);
exec::LocalDecodedVector pattern(context, *args[1], rows);
context.applyToSelectedNoThrow(rows, [&](vector_size_t row) {
RE2 re(toStringPiece(pattern->valueAt<StringView>(row)), RE2::Quiet);
checkForBadPattern(re);
auto& re = *cache_.findOrCompile(pattern->valueAt<StringView>(row));
result.set(row, Fn(toSearch->valueAt<StringView>(row), re));
});
}

private:
mutable ReCache cache_;
};

void checkForBadGroupId(int64_t groupId, const RE2& re) {
Expand Down Expand Up @@ -348,17 +381,15 @@ class Re2SearchAndExtract final : public exec::VectorFunction {
if (args.size() == 2) {
groups.resize(1);
context.applyToSelectedNoThrow(rows, [&](vector_size_t i) {
RE2 re(toStringPiece(pattern->valueAt<StringView>(i)), RE2::Quiet);
checkForBadPattern(re);
auto& re = *cache_.findOrCompile(pattern->valueAt<StringView>(i));
mustRefSourceStrings |=
re2Extract(result, i, re, toSearch, groups, 0, emptyNoMatch_);
});
} else {
exec::LocalDecodedVector groupIds(context, *args[2], rows);
context.applyToSelectedNoThrow(rows, [&](vector_size_t i) {
const auto groupId = groupIds->valueAt<T>(i);
RE2 re(toStringPiece(pattern->valueAt<StringView>(i)), RE2::Quiet);
checkForBadPattern(re);
auto& re = *cache_.findOrCompile(pattern->valueAt<StringView>(i));
checkForBadGroupId(groupId, re);
groups.resize(groupId + 1);
mustRefSourceStrings |=
Expand All @@ -372,6 +403,7 @@ class Re2SearchAndExtract final : public exec::VectorFunction {

private:
const bool emptyNoMatch_;
mutable ReCache cache_;
};

namespace {
Expand Down Expand Up @@ -1126,8 +1158,7 @@ class Re2ExtractAll final : public exec::VectorFunction {
//
groups.resize(1);
context.applyToSelectedNoThrow(rows, [&](vector_size_t row) {
RE2 re(toStringPiece(pattern->valueAt<StringView>(row)), RE2::Quiet);
checkForBadPattern(re);
auto& re = *cache_.findOrCompile(pattern->valueAt<StringView>(row));
re2ExtractAll(resultWriter, re, inputStrs, row, groups, 0);
});
} else {
Expand All @@ -1136,8 +1167,7 @@ class Re2ExtractAll final : public exec::VectorFunction {
exec::LocalDecodedVector groupIds(context, *args[2], rows);
context.applyToSelectedNoThrow(rows, [&](vector_size_t row) {
const T groupId = groupIds->valueAt<T>(row);
RE2 re(toStringPiece(pattern->valueAt<StringView>(row)), RE2::Quiet);
checkForBadPattern(re);
auto& re = *cache_.findOrCompile(pattern->valueAt<StringView>(row));
checkForBadGroupId(groupId, re);
groups.resize(groupId + 1);
re2ExtractAll(resultWriter, re, inputStrs, row, groups, groupId);
Expand All @@ -1150,6 +1180,9 @@ class Re2ExtractAll final : public exec::VectorFunction {
->asFlatVector<StringView>()
->acquireSharedStringBuffers(inputStrs->base());
}

private:
mutable ReCache cache_;
};

template <bool (*Fn)(StringView, const RE2&)>
Expand All @@ -1170,9 +1203,8 @@ std::shared_ptr<exec::VectorFunction> makeRe2MatchImpl(
return std::make_shared<Re2MatchConstantPattern<Fn>>(
constantPattern->as<ConstantVector<StringView>>()->valueAt(0));
}
static std::shared_ptr<Re2Match<Fn>> kMatchExpr =
std::make_shared<Re2Match<Fn>>();
return kMatchExpr;

return std::make_shared<Re2Match<Fn>>();
}

} // namespace
Expand Down
40 changes: 40 additions & 0 deletions velox/functions/lib/tests/Re2FunctionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1431,5 +1431,45 @@ TEST_F(Re2FunctionsTest, regexExtractAllLarge) {
"No group 4611686018427387904 in regex '(\\d+)([a-z]+)")
}

// Verifies that every regexp_xxx function fails with "Max number of regex
// reached" when a single evaluation sees 100 distinct patterns, and succeeds
// when the same number of rows uses only 20 distinct patterns.
TEST_F(Re2FunctionsTest, limit) {
  // Strings to search in; one distinct value per row.
  const auto subjectAt = [](auto row) {
    return fmt::format("Apples and oranges {}", row);
  };
  // A different pattern on every row.
  const auto distinctPatternAt = [](auto row) {
    return fmt::format("Apples (.*) oranges {}", row);
  };
  // Patterns cycle through 20 distinct values.
  const auto repeatedPatternAt = [](auto row) {
    return fmt::format("Apples (.*) oranges {}", row % 20);
  };

  auto input = makeRowVector({
      makeFlatVector<std::string>(100, subjectAt),
      makeFlatVector<std::string>(100, distinctPatternAt),
      makeFlatVector<std::string>(100, repeatedPatternAt),
  });

  // regexp_extract, with and without an explicit group id.
  VELOX_ASSERT_THROW(
      evaluate("regexp_extract(c0, c1)", input), "Max number of regex reached");
  ASSERT_NO_THROW(evaluate("regexp_extract(c0, c2)", input));

  VELOX_ASSERT_THROW(
      evaluate("regexp_extract(c0, c1, 1)", input),
      "Max number of regex reached");
  ASSERT_NO_THROW(evaluate("regexp_extract(c0, c2, 1)", input));

  // regexp_extract_all, with and without an explicit group id.
  VELOX_ASSERT_THROW(
      evaluate("regexp_extract_all(c0, c1)", input),
      "Max number of regex reached");
  ASSERT_NO_THROW(evaluate("regexp_extract_all(c0, c2)", input));

  VELOX_ASSERT_THROW(
      evaluate("regexp_extract_all(c0, c1, 1)", input),
      "Max number of regex reached");
  ASSERT_NO_THROW(evaluate("regexp_extract_all(c0, c2, 1)", input));

  // regexp_like.
  VELOX_ASSERT_THROW(
      evaluate("regexp_like(c0, c1)", input), "Max number of regex reached");
  ASSERT_NO_THROW(evaluate("regexp_like(c0, c2)", input));
}

} // namespace
} // namespace facebook::velox::functions

0 comments on commit fe52e60

Please sign in to comment.