Skip to content

Commit

Permalink
Re-register regexp_replace, remove from spark fuzzer, and optimize
Browse files Browse the repository at this point in the history
function
  • Loading branch information
codyschierbeck committed Feb 13, 2024
1 parent b0eeef9 commit a637807
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 81 deletions.
20 changes: 15 additions & 5 deletions velox/docs/functions/spark/regexp.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@ See https://github.com/google/re2/wiki/Syntax for more information.
.. spark:function:: regexp_replace(string, pattern, overwrite) -> varchar
Replaces all substrings in ``string`` that match the regular expression ``pattern`` with the string ``overwrite``. If no match is found, the original string is returned as is.
There is a limit to the number of unique regexes to be compiled per function call, which is 20.
There is a limit to the number of unique regexes to be compiled per function call, which is 20. If this limit is exceeded, the function throws an exception.

regexp_replace will throw an exception if ``string`` contains an invalid UTF-8 character, or if ``pattern`` does not conform to RE2 syntax: https://github.com/google/re2/wiki/Syntax.

regexp_replace does not support character class union, intersection, or difference and will throw an exception if they are detected within the provided ``pattern``.


Parameters:

Expand All @@ -57,8 +62,13 @@ See https://github.com/google/re2/wiki/Syntax for more information.
.. spark:function:: regexp_replace(string, pattern, overwrite, position) -> varchar
:noindex:

Replaces all substrings in ``string`` that match the regular expression ``pattern`` with the string ``overwrite`` starting from the specified ``position``. If the ``position`` is less than one, the function returns an error. If ``position`` is greater than the length of ``string``, the function returns the original ``string`` without any modifications.
There is a limit to the number of unique regexes to be compiled per function call, which is 20.
Replaces all substrings in ``string`` that match the regular expression ``pattern`` with the string ``overwrite`` starting from the specified ``position``. If no match is found, the original string is returned as is. If the ``position`` is less than one, the function throws an exception. If ``position`` is greater than the length of ``string``, the function returns the original ``string`` without any modifications.
There is a limit to the number of unique regexes to be compiled per function call, which is 20. If this limit is exceeded, the function throws an exception.

regexp_replace will throw an exception if ``string`` contains an invalid UTF-8 character, if ``position`` is less than 1, or if ``pattern`` does not conform to RE2 syntax: https://github.com/google/re2/wiki/Syntax.

regexp_replace does not support character class union, intersection, or difference and will throw an exception if they are detected within the provided ``pattern``.


This function is 1-indexed, meaning the position of the first character is 1.
Parameters:
Expand All @@ -74,6 +84,6 @@ See https://github.com/google/re2/wiki/Syntax for more information.

SELECT regexp_replace('Hello, World!', 'l', 'L', 6); -- 'Hello, WorLd!'

SELECT regexp_replace('Hello, World!', 'l', 'L', -5); -- 'Hello, World!'
SELECT regexp_replace('Hello, World!', 'l', 'L', 5); -- 'Hello, WorLd!'

SELECT regexp_replace('Hello, World!', 'l', 'L', 100); -- ERROR: Position exceeds string length.
SELECT regexp_replace('Hello, World!', 'l', 'L', 100); -- 'Hello, World!'
2 changes: 2 additions & 0 deletions velox/expression/tests/SparkExpressionFuzzerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ int main(int argc, char** argv) {
// The following list contains the Spark UDFs that hit issues.
// For rlike you need the following combo in the only list:
// rlike, md5 and upper
// regexp_replace: https://github.com/facebookincubator/velox/issues/8438
std::unordered_set<std::string> skipFunctions = {
"regexp_extract",
"regexp_replace",
"rlike",
"chr",
"replace",
Expand Down
136 changes: 75 additions & 61 deletions velox/functions/sparksql/RegexFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
#include <folly/container/F14Map.h>
#include "velox/functions/lib/Re2Functions.h"
#include "velox/functions/lib/string/StringImpl.h"

namespace facebook::velox::functions::sparksql {
namespace {
Expand Down Expand Up @@ -110,23 +111,14 @@ template <typename T>
struct RegexpReplaceFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

static constexpr bool is_default_ascii_behavior = true;

void call(
out_type<Varchar>& result,
const arg_type<Varchar>& stringInput,
const arg_type<Varchar>& pattern,
const arg_type<Varchar>& replace) {
re2::RE2* patternRegex = getCachedRegex(pattern.str());
re2::StringPiece replaceStringPiece = toStringPiece(replace);

std::string string(stringInput.data(), stringInput.size());
RE2::GlobalReplace(&string, *patternRegex, replaceStringPiece);

if (string.size()) {
result.resize(string.size());
std::memcpy(result.data(), string.data(), string.size());
} else {
result.resize(0);
}
call(result, stringInput, pattern, replace, 1);
}

void call(
Expand All @@ -135,79 +127,101 @@ struct RegexpReplaceFunction {
const arg_type<Varchar>& pattern,
const arg_type<Varchar>& replace,
const arg_type<int64_t>& position) {
VELOX_USER_CHECK_GE(position, 1, "regexp_replace requires a position >= 1");
if (performChecks(result, stringInput, pattern, replace, position - 1)) {
return;
}
size_t start =
functions::stringImpl::cappedByteLength<false>(stringInput, position - 1);
if (start > stringInput.size() + 1) {
result = stringInput;
return;
}
performReplace(result, stringInput, pattern, replace, start);
}

re2::RE2* patternRegex = getCachedRegex(pattern.str());
re2::StringPiece replaceStringPiece = toStringPiece(replace);
re2::StringPiece inputStringPiece = toStringPiece(stringInput);
void callAscii(
out_type<Varchar>& result,
const arg_type<Varchar>& stringInput,
const arg_type<Varchar>& pattern,
const arg_type<Varchar>& replace) {
callAscii(result, stringInput, pattern, replace, 1);
}

if (position > stringInput.size() + 1) {
result.resize(inputStringPiece.size());
std::memcpy(
result.data(), inputStringPiece.data(), inputStringPiece.size());
void callAscii(
out_type<Varchar>& result,
const arg_type<Varchar>& stringInput,
const arg_type<Varchar>& pattern,
const arg_type<Varchar>& replace,
const arg_type<int64_t>& position) {
if (performChecks(result, stringInput, pattern, replace, position - 1)) {
return;
}
performReplace(result, stringInput, pattern, replace, position - 1);
}

// Adjust the position for UTF-8 by counting the code points.
size_t utf8Position = 0;
size_t numCodePoints = 0;
while (numCodePoints < position - 1 && utf8Position <= stringInput.size()) {
int charLength =
utf8proc_char_length(inputStringPiece.data() + utf8Position);
VELOX_USER_CHECK_GT(
charLength, 0, "regexp_replace encountered invalid UTF-8 character");
++numCodePoints;
utf8Position += charLength;
private:
bool performChecks(
out_type<Varchar>& result,
const arg_type<Varchar>& stringInput,
const arg_type<Varchar>& pattern,
const arg_type<Varchar>& replace,
const arg_type<int64_t>& position) {
VELOX_USER_CHECK_GE(
position + 1, 1, "regexp_replace requires a position >= 1");
if (position > stringInput.size()) {
result = stringInput;
return true;
}
if (utf8Position > stringInput.size() + 1) {
result.resize(inputStringPiece.size());
std::memcpy(
result.data(), inputStringPiece.data(), inputStringPiece.size());
return;

if (stringInput.size() == 0) {
if (pattern.size() == 0 && position == 1) {
result = replace;
return true;
}
if (pattern.size() > 0) {
result = stringInput;
return true;
}
}
return false;
}

re2::StringPiece prefix(inputStringPiece.data(), utf8Position);
re2::StringPiece targetStringPiece(
inputStringPiece.data() + utf8Position,
inputStringPiece.size() - utf8Position);
void performReplace(
out_type<Varchar>& result,
const arg_type<Varchar>& stringInput,
const arg_type<Varchar>& pattern,
const arg_type<Varchar>& replace,
const arg_type<int64_t>& position) {
re2::RE2* patternRegex = getRegex(pattern.str());
re2::StringPiece replaceStringPiece = toStringPiece(replace);

std::string prefix(stringInput.data(), position);
std::string targetString(
targetStringPiece.data(), targetStringPiece.size());
RE2::GlobalReplace(&targetString, *patternRegex, replaceStringPiece);
stringInput.data() + position, stringInput.size() - position);

if (targetString.size() || prefix.size()) {
result.resize(prefix.size() + targetString.size());
std::memcpy(result.data(), prefix.data(), prefix.size());
std::memcpy(
result.data() + prefix.size(),
targetString.data(),
targetString.size());
} else {
result.resize(0);
}
RE2::GlobalReplace(&targetString, *patternRegex, replaceStringPiece);
result = prefix + targetString;
}

private:
re2::RE2* getCachedRegex(const std::string& pattern) const {
auto it = patternCache_.find(pattern);
if (it != patternCache_.end()) {
re2::RE2* getRegex(const std::string& pattern) {
auto it = cache_.find(pattern);
if (it != cache_.end()) {
return it->second.get();
}
VELOX_USER_CHECK_LT(
patternCache_.size(),
cache_.size(),
kMaxCompiledRegexes,
"regexp_replace hit the maximum number of unique regexes: {}",
kMaxCompiledRegexes);
checkForCompatiblePattern(pattern, "regexp_replace");
auto patternRegex = std::make_unique<re2::RE2>(pattern);
auto* rawPatternRegex = patternRegex.get();
checkForBadPattern(*rawPatternRegex);
patternCache_.emplace(pattern, std::move(patternRegex));
cache_.emplace(pattern, std::move(patternRegex));
return rawPatternRegex;
}

mutable folly::F14FastMap<std::string, std::unique_ptr<re2::RE2>>
patternCache_;
folly::F14FastMap<std::string, std::unique_ptr<re2::RE2>> cache_;
};

} // namespace
Expand Down Expand Up @@ -238,14 +252,14 @@ std::shared_ptr<exec::VectorFunction> makeRegexExtract(

void registerRegexpReplace(const std::string& prefix) {
registerFunction<RegexpReplaceFunction, Varchar, Varchar, Varchar, Varchar>(
{prefix + "REGEXP_REPLACE"});
{prefix + "regexp_replace"});
registerFunction<
RegexpReplaceFunction,
Varchar,
Varchar,
Varchar,
Varchar,
int64_t>({prefix + "REGEXP_REPLACE"});
int32_t>({prefix + "regexp_replace"});
}

} // namespace facebook::velox::functions::sparksql
2 changes: 2 additions & 0 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ void registerFunctions(const std::string& prefix) {
// Register size functions
registerSize(prefix + "size");

registerRegexpReplace(prefix);

registerFunction<JsonExtractScalarFunction, Varchar, Varchar, Varchar>(
{prefix + "get_json_object"});

Expand Down
40 changes: 25 additions & 15 deletions velox/functions/sparksql/tests/RegexFunctionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/lib/Re2Functions.h"
#include "velox/functions/sparksql/RegexFunctions.h"
#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h"
Expand All @@ -34,6 +35,10 @@ class RegexFunctionsTest : public test::SparkFunctionBaseTest {
void SetUp() override {
SparkFunctionBaseTest::SetUp();
registerRegexpReplace("");
// For parsing literal integers as INTEGER, not BIGINT,
// required by regexp_replace because its position argument
// is INTEGER.
options_.parseIntegerAsBigint = false;
}

std::optional<bool> rlike(
Expand Down Expand Up @@ -301,16 +306,17 @@ TEST_F(RegexFunctionsTest, regexpReplaceWithEmptyString) {
}

TEST_F(RegexFunctionsTest, regexBadJavaPattern) {
EXPECT_THROW(testRegexpReplace("[]", "[a[b]]", ""), VeloxUserError);
EXPECT_THROW(testRegexpReplace("[]", "[a&&[b]]", ""), VeloxUserError);
EXPECT_THROW(testRegexpReplace("[]", "[a&&[^b]]", ""), VeloxUserError);
VELOX_ASSERT_THROW(
testRegexpReplace("[]", "[a[b]]", ""),
"regexp_replace does not support character class union, intersection, or difference ([a[b]], [a&&[b]], [a&&[^b]])");
VELOX_ASSERT_THROW(
testRegexpReplace("[]", "[a&&[b]]", ""),
"regexp_replace does not support character class union, intersection, or difference ([a[b]], [a&&[b]], [a&&[^b]])");
VELOX_ASSERT_THROW(
testRegexpReplace("[]", "[a&&[^b]]", ""),
"regexp_replace does not support character class union, intersection, or difference ([a[b]], [a&&[b]], [a&&[^b]])");
}

TEST_F(RegexFunctionsTest, regexpReplaceInvalidUTF8) {
EXPECT_THROW(
testRegexpReplace(std::string("\xA0") + "bcacbdefg", "", "", {2}),
VeloxUserError);
}

TEST_F(RegexFunctionsTest, regexpReplacePosition) {
std::string output1 = "abc";
Expand All @@ -325,11 +331,15 @@ TEST_F(RegexFunctionsTest, regexpReplacePosition) {
}

TEST_F(RegexFunctionsTest, regexpReplaceNegativePosition) {
EXPECT_THROW(testRegexpReplace("abc", "a", "", {-1}), VeloxUserError);
VELOX_ASSERT_THROW(
testRegexpReplace("abc", "a", "", {-1}),
"regexp_replace requires a position >= 1");
}

TEST_F(RegexFunctionsTest, regexpReplaceZeroPosition) {
EXPECT_THROW(testRegexpReplace("abc", "a", "", {0}), VeloxUserError);
VELOX_ASSERT_THROW(
testRegexpReplace("abc", "a", "", {0}),
"regexp_replace requires a position >= 1");
}

TEST_F(RegexFunctionsTest, regexpReplacePositionTooLarge) {
Expand Down Expand Up @@ -543,8 +553,9 @@ TEST_F(RegexFunctionsTest, regexpReplaceCacheLimitTest) {
"X" + std::to_string(i) + "-Y" + std::to_string(i));
}

EXPECT_THROW(
testingRegexpReplaceRows(strings, patterns, replaces), VeloxUserError);
VELOX_ASSERT_THROW(
testingRegexpReplaceRows(strings, patterns, replaces),
"regexp_replace hit the maximum number of unique regexes: 20");
}

TEST_F(RegexFunctionsTest, regexpReplaceCacheMissLimit) {
Expand All @@ -564,10 +575,9 @@ TEST_F(RegexFunctionsTest, regexpReplaceCacheMissLimit) {
}

auto result =
testingRegexpReplaceRows(strings, patterns, replaces, positions, 50000);
auto output = convertOutput(expectedOutputs, 50000);
testingRegexpReplaceRows(strings, patterns, replaces, positions, 3);
auto output = convertOutput(expectedOutputs, 3);
assertEqualVectors(result, output);
}

} // namespace
} // namespace facebook::velox::functions::sparksql

0 comments on commit a637807

Please sign in to comment.