Skip to content

Commit

Permalink
Merge d0f1020 into ced3b27
Browse files Browse the repository at this point in the history
  • Loading branch information
isidentical authored Sep 26, 2022
2 parents ced3b27 + d0f1020 commit bbb8c8b
Show file tree
Hide file tree
Showing 2 changed files with 282 additions and 20 deletions.
14 changes: 8 additions & 6 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -500,20 +500,22 @@ pub fn create_physical_fun(
BuiltinScalarFunction::RegexpReplace => {
Arc::new(|args| match args[0].data_type() {
DataType::Utf8 => {
let func = invoke_if_regex_expressions_feature_flag!(
regexp_replace,
let specializer_func = invoke_if_regex_expressions_feature_flag!(
specialize_regexp_replace,
i32,
"regexp_replace"
);
make_scalar_function(func)(args)
let func = specializer_func(args)?;
func(args)
}
DataType::LargeUtf8 => {
let func = invoke_if_regex_expressions_feature_flag!(
regexp_replace,
let specializer_func = invoke_if_regex_expressions_feature_flag!(
specialize_regexp_replace,
i64,
"regexp_replace"
);
make_scalar_function(func)(args)
let func = specializer_func(args)?;
func(args)
}
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function regexp_replace",
Expand Down
288 changes: 274 additions & 14 deletions datafusion/physical-expr/src/regex_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,32 @@

//! Regex expressions

use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
use arrow::array::{
new_null_array, Array, ArrayRef, GenericStringArray, OffsetSizeTrait,
};
use arrow::compute;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation};
use hashbrown::HashMap;
use lazy_static::lazy_static;
use regex::Regex;
use std::any::type_name;
use std::sync::Arc;

macro_rules! downcast_string_arg {
use crate::functions::make_scalar_function;

macro_rules! fetch_string_arg {
($ARG:expr, $NAME:expr, $T:ident, $EARLY_ABORT:ident) => {{
let array = downcast_string_array_arg!($ARG, $NAME, $T);
if array.is_null(0) {
return $EARLY_ABORT(array);
} else {
array.value(0)
}
}};
}

macro_rules! downcast_string_array_arg {
($ARG:expr, $NAME:expr, $T:ident) => {{
$ARG.as_any()
.downcast_ref::<GenericStringArray<T>>()
Expand All @@ -48,14 +64,14 @@ macro_rules! downcast_string_arg {
pub fn regexp_match<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
match args.len() {
2 => {
let values = downcast_string_arg!(args[0], "string", T);
let regex = downcast_string_arg!(args[1], "pattern", T);
let values = downcast_string_array_arg!(args[0], "string", T);
let regex = downcast_string_array_arg!(args[1], "pattern", T);
compute::regexp_match(values, regex, None).map_err(DataFusionError::ArrowError)
}
3 => {
let values = downcast_string_arg!(args[0], "string", T);
let regex = downcast_string_arg!(args[1], "pattern", T);
let flags = Some(downcast_string_arg!(args[2], "flags", T));
let values = downcast_string_array_arg!(args[0], "string", T);
let regex = downcast_string_array_arg!(args[1], "pattern", T);
let flags = Some(downcast_string_array_arg!(args[2], "flags", T));
compute::regexp_match(values, regex, flags).map_err(DataFusionError::ArrowError)
}
other => Err(DataFusionError::Internal(format!(
Expand All @@ -80,14 +96,17 @@ fn regex_replace_posix_groups(replacement: &str) -> String {
///
/// example: `regexp_replace('Thomas', '.[mN]a.', 'M') = 'ThM'`
pub fn regexp_replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
// Default implementation for regexp_replace, assumes all args are arrays
// and args is a sequence of 3 or 4 elements.

// creating Regex is expensive so create hashmap for memoization
let mut patterns: HashMap<String, Regex> = HashMap::new();

match args.len() {
3 => {
let string_array = downcast_string_arg!(args[0], "string", T);
let pattern_array = downcast_string_arg!(args[1], "pattern", T);
let replacement_array = downcast_string_arg!(args[2], "replacement", T);
let string_array = downcast_string_array_arg!(args[0], "string", T);
let pattern_array = downcast_string_array_arg!(args[1], "pattern", T);
let replacement_array = downcast_string_array_arg!(args[2], "replacement", T);

let result = string_array
.iter()
Expand Down Expand Up @@ -120,10 +139,10 @@ pub fn regexp_replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef>
Ok(Arc::new(result) as ArrayRef)
}
4 => {
let string_array = downcast_string_arg!(args[0], "string", T);
let pattern_array = downcast_string_arg!(args[1], "pattern", T);
let replacement_array = downcast_string_arg!(args[2], "replacement", T);
let flags_array = downcast_string_arg!(args[3], "flags", T);
let string_array = downcast_string_array_arg!(args[0], "string", T);
let pattern_array = downcast_string_array_arg!(args[1], "pattern", T);
let replacement_array = downcast_string_array_arg!(args[2], "replacement", T);
let flags_array = downcast_string_array_arg!(args[3], "flags", T);

let result = string_array
.iter()
Expand Down Expand Up @@ -178,10 +197,125 @@ pub fn regexp_replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef>
}
}

fn _regexp_replace_early_abort<T: OffsetSizeTrait>(
input_array: &GenericStringArray<T>,
) -> Result<ArrayRef> {
// Mimicing the existing behavior of regexp_replace, if any of the scalar arguments
// are actuall null, then the result will be an array of the same size but with nulls.
Ok(new_null_array(input_array.data_type(), input_array.len()))
}

fn _regexp_replace_static_pattern<T: OffsetSizeTrait>(
args: &[ArrayRef],
) -> Result<ArrayRef> {
// Special cased regex_replace implementation for the scenerio where
// both the pattern itself and the flags are scalars. This means we can
// skip regex caching system and basically hold a single Regex object
// for the replace operation.

let string_array = downcast_string_array_arg!(args[0], "string", T);
let pattern = fetch_string_arg!(args[1], "pattern", T, _regexp_replace_early_abort);
let replacement_array = downcast_string_array_arg!(args[2], "replacement", T);
let flags = match args.len() {
3 => None,
4 => Some(fetch_string_arg!(args[3], "flags", T, _regexp_replace_early_abort)),
other => {
return Err(DataFusionError::Internal(format!(
"regexp_replace was called with {} arguments. It requires at least 3 and at most 4.",
other
)))
}
};

// Embed the flag (if it exists) into the pattern
let (pattern, replace_all) = match flags {
Some("g") => (pattern.to_string(), true),
Some(flags) => (
format!("(?{}){}", flags.to_string().replace('g', ""), pattern),
flags.contains('g'),
),
None => (pattern.to_string(), false),
};

let re = Regex::new(&pattern)
.map_err(|err| DataFusionError::Execution(err.to_string()))?;

let result = string_array
.iter()
.zip(replacement_array.iter())
.map(|(string, replacement)| match (string, replacement) {
(Some(string), Some(replacement)) => {
let replacement = regex_replace_posix_groups(replacement);

if replace_all {
Some(re.replace_all(string, replacement.as_str()))
} else {
Some(re.replace(string, replacement.as_str()))
}
}
_ => None,
})
.collect::<GenericStringArray<T>>();
Ok(Arc::new(result) as ArrayRef)
}

/// Determine which implementation of the regexp_replace to use based
/// on the given set of arguments.
pub fn specialize_regexp_replace<T: OffsetSizeTrait>(
args: &[ColumnarValue],
) -> Result<ScalarFunctionImplementation> {
// This will serve as a dispatch table where we can
// leverage it in order to determine whether the scalarity
// of the given set of arguments fits a better specialized
// function.
let (is_source_scalar, is_pattern_scalar, is_replacement_scalar, is_flags_scalar) = (
matches!(args[0], ColumnarValue::Scalar(_)),
matches!(args[1], ColumnarValue::Scalar(_)),
matches!(args[2], ColumnarValue::Scalar(_)),
// The forth argument (flags) is optional; so in the event that
// it is not available, we'll claim that it is scalar.
matches!(args.get(3), Some(ColumnarValue::Scalar(_)) | None),
);

match (
is_source_scalar,
is_pattern_scalar,
is_replacement_scalar,
is_flags_scalar,
) {
// This represents a very hot path for the case where the there is
// a single pattern that is being matched against. This is extremely
// important to specialize on since it removes the overhead of DF's
// in-house regex pattern cache (since there will be at most a single
// pattern).
//
// The flags needs to be a scalar as well since each pattern is actually
// constructed with the flags embedded into the pattern itself. This means
// even if the pattern itself is scalar, if the flags are an array then
// we will create many regexes and it is best to use the implementation
// that caches it. If there are no flags, we can simply ignore it here,
// and let the specialized function handle it.
(_, true, _, true) => {
// We still don't know the scalarity of source/replacement, so we
// need the adapter even if it will do some extra work for the pattern
// and the flags.
//
// TODO: maybe we need a way of telling the adapter on which arguments
// it can skip filling (so that we won't create N - 1 redundant cols).
Ok(make_scalar_function(_regexp_replace_static_pattern::<T>))
}

// If there are no specialized implementations, we'll fall back to the
// generic implementation.
(_, _, _, _) => Ok(make_scalar_function(regexp_replace::<T>)),
}
}

#[cfg(test)]
mod tests {
use super::*;
use arrow::array::*;
use datafusion_common::ScalarValue;

#[test]
fn test_case_sensitive_regexp_match() {
Expand Down Expand Up @@ -231,4 +365,130 @@ mod tests {

assert_eq!(re.as_ref(), &expected);
}

#[test]
fn test_static_pattern_regexp_replace() {
let values = StringArray::from(vec!["abc"; 5]);
let patterns = StringArray::from(vec!["b"; 5]);
let replacements = StringArray::from(vec!["foo"; 5]);
let expected = StringArray::from(vec!["afooc"; 5]);

let re = _regexp_replace_static_pattern::<i32>(&[
Arc::new(values),
Arc::new(patterns),
Arc::new(replacements),
])
.unwrap();

assert_eq!(re.as_ref(), &expected);
}

#[test]
fn test_static_pattern_regexp_replace_with_flags() {
let values = StringArray::from(vec!["abc", "ABC", "aBc", "AbC", "aBC"]);
let patterns = StringArray::from(vec!["b"; 5]);
let replacements = StringArray::from(vec!["foo"; 5]);
let flags = StringArray::from(vec!["i"; 5]);
let expected =
StringArray::from(vec!["afooc", "AfooC", "afooc", "AfooC", "afooC"]);

let re = _regexp_replace_static_pattern::<i32>(&[
Arc::new(values),
Arc::new(patterns),
Arc::new(replacements),
Arc::new(flags),
])
.unwrap();

assert_eq!(re.as_ref(), &expected);
}

#[test]
fn test_static_pattern_regexp_replace_early_abort() {
let values = StringArray::from(vec!["abc"; 5]);
let patterns = StringArray::from(vec![None; 5]);
let replacements = StringArray::from(vec!["foo"; 5]);
let expected = StringArray::from(vec![None; 5]);

let re = _regexp_replace_static_pattern::<i32>(&[
Arc::new(values),
Arc::new(patterns),
Arc::new(replacements),
])
.unwrap();

assert_eq!(re.as_ref(), &expected);
}

#[test]
fn test_static_pattern_regexp_replace_early_abort_flags() {
let values = StringArray::from(vec!["abc"; 5]);
let patterns = StringArray::from(vec!["a"; 5]);
let replacements = StringArray::from(vec!["foo"; 5]);
let flags = StringArray::from(vec![None; 5]);
let expected = StringArray::from(vec![None; 5]);

let re = _regexp_replace_static_pattern::<i32>(&[
Arc::new(values),
Arc::new(patterns),
Arc::new(replacements),
Arc::new(flags),
])
.unwrap();

assert_eq!(re.as_ref(), &expected);
}

#[test]
fn test_static_pattern_regexp_replace_pattern_error() {
let values = StringArray::from(vec!["abc"; 5]);
// Delibaretely using an invalid pattern to see how the single pattern
// error is propagated on regexp_replace.
let patterns = StringArray::from(vec!["["; 5]);
let replacements = StringArray::from(vec!["foo"; 5]);

let re = _regexp_replace_static_pattern::<i32>(&[
Arc::new(values),
Arc::new(patterns),
Arc::new(replacements),
]);
let pattern_err = re.expect_err("broken pattern should have failed");
assert_eq!(
pattern_err.to_string(),
"Execution error: regex parse error:\n [\n ^\nerror: unclosed character class"
);
}

#[test]
fn test_regexp_can_specialize_all_cases() {
macro_rules! make_scalar {
() => {
ColumnarValue::Scalar(ScalarValue::Utf8(Some("foo".to_string())))
};
}

macro_rules! make_array {
() => {
ColumnarValue::Array(
Arc::new(StringArray::from(vec!["bar"; 2])) as ArrayRef
)
};
}

for source in [make_scalar!(), make_array!()] {
for pattern in [make_scalar!(), make_array!()] {
for replacement in [make_scalar!(), make_array!()] {
for flags in [Some(make_scalar!()), Some(make_array!()), None] {
let mut args =
vec![source.clone(), pattern.clone(), replacement.clone()];
if let Some(flags) = flags {
args.push(flags.clone());
}
let regex_func = specialize_regexp_replace::<i32>(&args);
assert!(regex_func.is_ok());
}
}
}
}
}
}

0 comments on commit bbb8c8b

Please sign in to comment.