Skip to content

Commit

Permalink
Allow variable pattern and replacement in regexp_replace Presto funct…
Browse files Browse the repository at this point in the history
…ion (facebookincubator#10108)

Summary:
Pull Request resolved: facebookincubator#10108

Presto allows non-constant pattern and replacement.

Reviewed By: kagamiori

Differential Revision: D58285738
  • Loading branch information
mbasmanova authored and facebook-github-bot committed Jun 7, 2024
1 parent 1c0df40 commit 90f1399
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 97 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/scheduled.yml
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ jobs:
--retry_with_try \
--enable_dereference \
--logtostderr=1 \
--minloglevel=1 \
--minloglevel=0 \
--repro_persist_path=/tmp/presto_fuzzer_repro \
&& echo -e "\n\nPresto Fuzzer run finished successfully."
Expand Down
53 changes: 24 additions & 29 deletions velox/functions/lib/Re2Functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
namespace facebook::velox::functions {
namespace {

static const int kMaxCompiledRegexes = 20;

void checkForBadPattern(const RE2& re) {
if (UNLIKELY(!re.ok())) {
VELOX_USER_FAIL("invalid regular expression:{}", re.error());
Expand All @@ -34,36 +32,33 @@ re2::StringPiece toStringPiece(const T& s) {
return re2::StringPiece(s.data(), s.size());
}

// A cache of compiled regular expressions (RE2 instances). Allows up to
// 'kMaxCompiledRegexes' different expressions.
//
// Compiling regular expressions is expensive. It can take up to 200 times
// more CPU time to compile a regex vs. evaluate it.
class ReCache {
public:
RE2* findOrCompile(const StringView& pattern) {
const std::string key = pattern;
} // namespace

auto reIt = cache_.find(key);
if (reIt != cache_.end()) {
return reIt->second.get();
}
namespace detail {

VELOX_USER_CHECK_LT(
cache_.size(), kMaxCompiledRegexes, "Max number of regex reached");
RE2* ReCache::findOrCompile(const StringView& pattern) {
const std::string key = pattern;

auto re = std::make_unique<RE2>(toStringPiece(pattern), RE2::Quiet);
checkForBadPattern(*re);
auto reIt = cache_.find(key);
if (reIt != cache_.end()) {
return reIt->second.get();
}

auto [it, inserted] = cache_.emplace(key, std::move(re));
VELOX_CHECK(inserted);
VELOX_USER_CHECK_LT(
cache_.size(), kMaxCompiledRegexes, "Max number of regex reached");

return it->second.get();
}
auto re = std::make_unique<RE2>(toStringPiece(pattern), RE2::Quiet);
checkForBadPattern(*re);

private:
folly::F14FastMap<std::string, std::unique_ptr<RE2>> cache_;
};
auto [it, inserted] = cache_.emplace(key, std::move(re));
VELOX_CHECK(inserted);

return it->second.get();
}

} // namespace detail

namespace {

std::string printTypesCsv(
const std::vector<exec::VectorFunctionArg>& inputArgs) {
Expand Down Expand Up @@ -257,7 +252,7 @@ class Re2Match final : public exec::VectorFunction {
}

private:
mutable ReCache cache_;
mutable detail::ReCache cache_;
};

void checkForBadGroupId(int64_t groupId, const RE2& re) {
Expand Down Expand Up @@ -403,7 +398,7 @@ class Re2SearchAndExtract final : public exec::VectorFunction {

private:
const bool emptyNoMatch_;
mutable ReCache cache_;
mutable detail::ReCache cache_;
};

namespace {
Expand Down Expand Up @@ -1173,7 +1168,7 @@ class Re2ExtractAll final : public exec::VectorFunction {
}

private:
mutable ReCache cache_;
mutable detail::ReCache cache_;
};

template <bool (*Fn)(StringView, const RE2&)>
Expand Down
102 changes: 77 additions & 25 deletions velox/functions/lib/Re2Functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,23 @@ std::shared_ptr<exec::VectorFunction> makeRe2ExtractAll(

std::vector<std::shared_ptr<exec::FunctionSignature>> re2ExtractAllSignatures();

namespace detail {

// A cache of compiled regular expressions (RE2 instances). Allows up to
// 'kMaxCompiledRegexes' different expressions.
//
// Compiling regular expressions is expensive. It can take up to 200 times
// more CPU time to compile a regex vs. evaluate it.
class ReCache {
public:
// Returns the compiled regex for 'pattern', compiling and caching it on
// first use. The returned pointer refers to an entry owned by this cache.
// Raises a user error if 'pattern' is not a valid regular expression, or if
// a new pattern is requested while the cache already holds the maximum
// number of distinct patterns.
RE2* findOrCompile(const StringView& pattern);

private:
// Compiled regexes keyed by the pattern string.
folly::F14FastMap<std::string, std::unique_ptr<RE2>> cache_;
};

} // namespace detail

/// regexp_replace(string, pattern, replacement) -> string
/// regexp_replace(string, pattern) -> string
///
Expand All @@ -255,55 +272,90 @@ template <
struct Re2RegexpReplace {
VELOX_DEFINE_FUNCTION_TYPES(T);

std::string processedReplacement_;
std::string result_;
std::optional<RE2> re_;

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& /*inputTypes*/,
const core::QueryConfig& config,
const arg_type<Varchar>* /*string*/,
const arg_type<Varchar>* pattern,
const arg_type<Varchar>* replacement) {
VELOX_USER_CHECK(
pattern != nullptr, "Pattern of regexp_replace must be constant.");
VELOX_USER_CHECK(
replacement != nullptr,
"Replacement sequence of regexp_replace must be constant.");

auto processedPattern = prepareRegexpPattern(*pattern);

re_.emplace(processedPattern, RE2::Quiet);
if (UNLIKELY(!re_->ok())) {
VELOX_USER_FAIL(
"Invalid regular expression {}: {}.", processedPattern, re_->error());
if (pattern != nullptr) {
const auto processedPattern = prepareRegexpPattern(*pattern);
re_.emplace(processedPattern, RE2::Quiet);
VELOX_USER_CHECK(
re_->ok(),
"Invalid regular expression {}: {}.",
processedPattern,
re_->error());
}

processedReplacement_ = prepareRegexpReplacement(*re_, *replacement);
if (replacement != nullptr) {
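// A constant 'replacement' can be pre-processed here only if 'pattern' is
// also constant: processing takes the compiled pattern as input, so with a
// non-constant 'pattern' it has to be redone for each row.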
if (pattern != nullptr) {
ensureProcessedReplacement(re_.value(), *replacement);
constantReplacement_ = true;
}
}
}

FOLLY_ALWAYS_INLINE void initialize(
const std::vector<TypePtr>& inputTypes,
const core::QueryConfig& config,
const arg_type<Varchar>* string,
const arg_type<Varchar>* pattern) {
StringView emptyReplacement;

initialize(inputTypes, config, string, pattern, &emptyReplacement);
initialize(inputTypes, config, string, pattern, nullptr);
}

FOLLY_ALWAYS_INLINE bool call(
FOLLY_ALWAYS_INLINE void call(
out_type<Varchar>& out,
const arg_type<Varchar>& string,
const arg_type<Varchar>& /*pattern*/,
const arg_type<Varchar>& /*replacement*/ = StringView{}) {
const arg_type<Varchar>& pattern,
const arg_type<Varchar>& replacement = StringView{}) {
auto& re = ensurePattern(pattern);
const auto& processedReplacement =
ensureProcessedReplacement(re, replacement);

result_.assign(string.data(), string.size());
RE2::GlobalReplace(&result_, *re_, processedReplacement_);
RE2::GlobalReplace(&result_, re, processedReplacement);

UDFOutputString::assign(out, result_);
}

return true;
private:
// Returns the regex to apply to the current row: the one compiled in
// initialize() when the pattern is constant, otherwise the entry that
// 'cache_' holds (or compiles on demand) for the processed 'pattern'.
RE2& ensurePattern(const arg_type<Varchar>& pattern) {
  // Constant pattern: already compiled, nothing to look up.
  if (re_.has_value()) {
    return *re_;
  }

  auto preparedPattern = prepareRegexpPattern(pattern);
  RE2* compiled = cache_.findOrCompile(StringView(preparedPattern));
  return *compiled;
}

// Returns the processed form of 'replacement' for use with 're'. When
// 'constantReplacement_' is set, the value already stored in
// 'processedReplacement_' is returned as is; otherwise it is recomputed from
// 're' and 'replacement' and stored before being returned.
const std::string& ensureProcessedReplacement(
    RE2& re,
    const arg_type<Varchar>& replacement) {
  if (constantReplacement_) {
    return processedReplacement_;
  }

  processedReplacement_ = prepareRegexpReplacement(re, replacement);
  return processedReplacement_;
}

// Used when pattern is constant.
std::optional<RE2> re_;

// True if both pattern and replacement are constant, in which case
// 'processedReplacement_' is computed once in initialize().
bool constantReplacement_{false};

// Processed replacement. Computed once in initialize() if
// 'constantReplacement_' is true; otherwise recomputed for each row and
// holds the value for the current row.
std::string processedReplacement_;

// Used when pattern is not constant.
detail::ReCache cache_;

// Scratch memory to store result of replacement.
std::string result_;
};

} // namespace facebook::velox::functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,10 @@ void registerSimpleFunctions(const std::string& prefix) {
exec::registerStatefulVectorFunction(
prefix + "like", likeSignatures(), makeLike);

registerFunction<Re2RegexpReplacePresto, Varchar, Varchar, Constant<Varchar>>(
registerFunction<Re2RegexpReplacePresto, Varchar, Varchar, Varchar>(
{prefix + "regexp_replace"});
registerFunction<Re2RegexpReplacePresto, Varchar, Varchar, Varchar, Varchar>(
{prefix + "regexp_replace"});
registerFunction<
Re2RegexpReplacePresto,
Varchar,
Varchar,
Constant<Varchar>,
Constant<Varchar>>({prefix + "regexp_replace"});
}
} // namespace

Expand Down
100 changes: 65 additions & 35 deletions velox/functions/prestosql/tests/RegexpReplaceTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@
namespace facebook::velox {
namespace {

class RegexFunctionsTest : public functions::test::FunctionBaseTest {
class RegexpReplaceTest : public functions::test::FunctionBaseTest {
protected:
std::optional<std::string> regexp_replace(
std::optional<std::string> regexpReplace(
const std::optional<std::string>& string,
const std::string& pattern) {
return evaluateOnce<std::string>(
fmt::format("regexp_replace(c0, '{}')", pattern), string);
}

std::optional<std::string> regexp_replace(
std::optional<std::string> regexpReplace(
const std::optional<std::string>& string,
const std::string& pattern,
const std::string& replacement) {
Expand All @@ -40,47 +40,77 @@ class RegexFunctionsTest : public functions::test::FunctionBaseTest {
}
};

TEST_F(RegexFunctionsTest, RegexpReplaceNoReplacement) {
EXPECT_EQ(regexp_replace("abcd", "cd"), "ab");
EXPECT_EQ(regexp_replace("a12b34c5", "\\d"), "abc");
EXPECT_EQ(regexp_replace("abc", "\\w"), "");
EXPECT_EQ(regexp_replace("", "\\d"), "");
EXPECT_EQ(regexp_replace("abc", ""), "abc");
EXPECT_EQ(regexp_replace("$$$.", "\\$"), ".");
EXPECT_EQ(regexp_replace("???.", "\\?"), ".");
EXPECT_EQ(regexp_replace(std::nullopt, "abc"), std::nullopt);
// Covers the two-argument form regexp_replace(string, pattern), where every
// match of the pattern is removed from the input.
TEST_F(RegexpReplaceTest, noReplacement) {
// Pattern supplied as a literal in the expression.
EXPECT_EQ(regexpReplace("abcd", "cd"), "ab");
EXPECT_EQ(regexpReplace("a12b34c5", "\\d"), "abc");
EXPECT_EQ(regexpReplace("abc", "\\w"), "");
EXPECT_EQ(regexpReplace("", "\\d"), "");
// An empty pattern leaves the input unchanged.
EXPECT_EQ(regexpReplace("abc", ""), "abc");
// Escaped regex metacharacters match literally.
EXPECT_EQ(regexpReplace("$$$.", "\\$"), ".");
EXPECT_EQ(regexpReplace("???.", "\\?"), ".");
// Null input produces null output.
EXPECT_EQ(regexpReplace(std::nullopt, "abc"), std::nullopt);

// Non-constant pattern: each row of 'c1' supplies the pattern applied to the
// same row of 'c0'.
auto input = makeRowVector({
makeFlatVector<std::string>(
{"apple123", "1 banana", "orange 23 ...", "12 34 56"}),
makeFlatVector<std::string>({"[0-9]+", "\\d+", "ge\\s", "[4-9 ]"}),
});

auto result = evaluate("regexp_replace(c0, c1)", input);

auto expected =
makeFlatVector<std::string>({"apple", " banana", "oran23 ...", "123"});
test::assertEqualVectors(expected, result);
}

TEST_F(RegexFunctionsTest, RegexpReplaceWithReplacement) {
EXPECT_EQ(regexp_replace("abcd", "cd", "ef"), "abef");
EXPECT_EQ(regexp_replace("abc", "\\w", ""), "");
EXPECT_EQ(regexp_replace("a12b34c5", "\\d", "."), "a..b..c.");
EXPECT_EQ(regexp_replace("", "\\d", "."), "");
EXPECT_EQ(regexp_replace("abc", "", "."), ".a.b.c.");
TEST_F(RegexpReplaceTest, withReplacement) {
EXPECT_EQ(regexpReplace("abcd", "cd", "ef"), "abef");
EXPECT_EQ(regexpReplace("abc", "\\w", ""), "");
EXPECT_EQ(regexpReplace("a12b34c5", "\\d", "."), "a..b..c.");
EXPECT_EQ(regexpReplace("", "\\d", "."), "");
EXPECT_EQ(regexpReplace("abc", "", "."), ".a.b.c.");
EXPECT_EQ(
regexp_replace("1a 2b 14m", "(\\d+)([ab]) ", "3c$2 "), "3ca 3cb 14m");
EXPECT_EQ(regexp_replace("1a 2b 14m", "(\\d+)([ab])", "3c$2"), "3ca 3cb 14m");
EXPECT_EQ(regexp_replace("abc", "(?P<alpha>\\w)", "1${alpha}"), "1a1b1c");
regexpReplace("1a 2b 14m", "(\\d+)([ab]) ", "3c$2 "), "3ca 3cb 14m");
EXPECT_EQ(regexpReplace("1a 2b 14m", "(\\d+)([ab])", "3c$2"), "3ca 3cb 14m");
EXPECT_EQ(regexpReplace("abc", "(?P<alpha>\\w)", "1${alpha}"), "1a1b1c");
EXPECT_EQ(
regexp_replace("1a1b1c", "(?<digit>\\d)(?<alpha>\\w)", "${alpha}\\$"),
regexpReplace("1a1b1c", "(?<digit>\\d)(?<alpha>\\w)", "${alpha}\\$"),
"a$b$c$");
EXPECT_EQ(
regexp_replace(
"1a2b3c", "(?<digit>\\d)(?<alpha>\\w)", "${alpha}${digit}"),
regexpReplace("1a2b3c", "(?<digit>\\d)(?<alpha>\\w)", "${alpha}${digit}"),
"a1b2c3");
EXPECT_EQ(regexp_replace("123", "(\\d)", "\\$"), "$$$");
EXPECT_EQ(regexpReplace("123", "(\\d)", "\\$"), "$$$");
EXPECT_EQ(
regexp_replace("123", "(?<digit>(?<nest>\\d))", ".${digit}"), ".1.2.3");
regexpReplace("123", "(?<digit>(?<nest>\\d))", ".${digit}"), ".1.2.3");
EXPECT_EQ(
regexp_replace("123", "(?<digit>(?<nest>\\d))", ".${nest}"), ".1.2.3");
EXPECT_EQ(regexp_replace(std::nullopt, "abc", "def"), std::nullopt);

EXPECT_THROW(regexp_replace("123", "(?<d", "."), VeloxUserError);
EXPECT_THROW(regexp_replace("123", R"((?''digit''\d))", "."), VeloxUserError);
EXPECT_THROW(regexp_replace("123", "(?P<>\\d)", "."), VeloxUserError);
EXPECT_THROW(
regexp_replace("123", "(?P<digit>\\d)", "${dd}"), VeloxUserError);
EXPECT_THROW(regexp_replace("123", "(?P<digit>\\d)", "${}"), VeloxUserError);
regexpReplace("123", "(?<digit>(?<nest>\\d))", ".${nest}"), ".1.2.3");
EXPECT_EQ(regexpReplace(std::nullopt, "abc", "def"), std::nullopt);

EXPECT_THROW(regexpReplace("123", "(?<d", "."), VeloxUserError);
EXPECT_THROW(regexpReplace("123", R"((?''digit''\d))", "."), VeloxUserError);
EXPECT_THROW(regexpReplace("123", "(?P<>\\d)", "."), VeloxUserError);
EXPECT_THROW(regexpReplace("123", "(?P<digit>\\d)", "${dd}"), VeloxUserError);
EXPECT_THROW(regexpReplace("123", "(?P<digit>\\d)", "${}"), VeloxUserError);

auto input = makeRowVector({
makeFlatVector<std::string>(
{"apple123", "1 banana", "orange 23 ...", "12 34 56"}),
makeFlatVector<std::string>({"[0-9]+", "\\d+", "ge\\s", "[4-9]"}),
makeFlatVector<std::string>({"_", ".", "", "[===]"}),
});

auto result = evaluate("regexp_replace(c0, c1, c2)", input);

auto expected = makeFlatVector<std::string>(
{"apple_", ". banana", "oran23 ...", "12 3[===] [===][===]"});
test::assertEqualVectors(expected, result);

// Constant 'replacement' with non-constant 'pattern'.
result = evaluate("regexp_replace(c0, c1, '||')", input);

expected = makeFlatVector<std::string>(
{"apple||", "|| banana", "oran||23 ...", "12 3|| ||||"});
test::assertEqualVectors(expected, result);
}

} // namespace
Expand Down

0 comments on commit 90f1399

Please sign in to comment.