diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index a34c1ab90c5d..6997adc46d10 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -500,20 +500,22 @@ pub fn create_physical_fun( BuiltinScalarFunction::RegexpReplace => { Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { - let func = invoke_if_regex_expressions_feature_flag!( - regexp_replace, + let specializer_func = invoke_if_regex_expressions_feature_flag!( + specialize_regexp_replace, i32, "regexp_replace" ); - make_scalar_function(func)(args) + let func = specializer_func(args)?; + func(args) } DataType::LargeUtf8 => { - let func = invoke_if_regex_expressions_feature_flag!( - regexp_replace, + let specializer_func = invoke_if_regex_expressions_feature_flag!( + specialize_regexp_replace, i64, "regexp_replace" ); - make_scalar_function(func)(args) + let func = specializer_func(args)?; + func(args) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function regexp_replace", diff --git a/datafusion/physical-expr/src/regex_expressions.rs b/datafusion/physical-expr/src/regex_expressions.rs index c09c3c265e7d..26f106db6a78 100644 --- a/datafusion/physical-expr/src/regex_expressions.rs +++ b/datafusion/physical-expr/src/regex_expressions.rs @@ -21,16 +21,32 @@ //! Regex expressions -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::array::{ + new_null_array, Array, ArrayRef, GenericStringArray, OffsetSizeTrait, +}; use arrow::compute; use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; use hashbrown::HashMap; use lazy_static::lazy_static; use regex::Regex; use std::any::type_name; use std::sync::Arc; -macro_rules! downcast_string_arg { +use crate::functions::make_scalar_function; + +macro_rules! fetch_string_arg { + ($ARG:expr, $NAME:expr, $T:ident, $EARLY_ABORT:ident) => {{ + let array = downcast_string_array_arg!($ARG, $NAME, $T); + if array.is_null(0) { + return $EARLY_ABORT(array); + } else { + array.value(0) + } + }}; +} + +macro_rules! downcast_string_array_arg { ($ARG:expr, $NAME:expr, $T:ident) => {{ $ARG.as_any() .downcast_ref::>() @@ -48,14 +64,14 @@ macro_rules! downcast_string_arg { pub fn regexp_match(args: &[ArrayRef]) -> Result { match args.len() { 2 => { - let values = downcast_string_arg!(args[0], "string", T); - let regex = downcast_string_arg!(args[1], "pattern", T); + let values = downcast_string_array_arg!(args[0], "string", T); + let regex = downcast_string_array_arg!(args[1], "pattern", T); compute::regexp_match(values, regex, None).map_err(DataFusionError::ArrowError) } 3 => { - let values = downcast_string_arg!(args[0], "string", T); - let regex = downcast_string_arg!(args[1], "pattern", T); - let flags = Some(downcast_string_arg!(args[2], "flags", T)); + let values = downcast_string_array_arg!(args[0], "string", T); + let regex = downcast_string_array_arg!(args[1], "pattern", T); + let flags = Some(downcast_string_array_arg!(args[2], "flags", T)); compute::regexp_match(values, regex, flags).map_err(DataFusionError::ArrowError) } other => Err(DataFusionError::Internal(format!( @@ -80,14 +96,17 @@ fn regex_replace_posix_groups(replacement: &str) -> String { /// /// example: `regexp_replace('Thomas', '.[mN]a.', 'M') = 'ThM'` pub fn regexp_replace(args: &[ArrayRef]) -> Result { + // Default implementation for regexp_replace, assumes all args are arrays + // and args is a sequence of 3 or 4 elements. + // creating Regex is expensive so create hashmap for memoization let mut patterns: HashMap = HashMap::new(); match args.len() { 3 => { - let string_array = downcast_string_arg!(args[0], "string", T); - let pattern_array = downcast_string_arg!(args[1], "pattern", T); - let replacement_array = downcast_string_arg!(args[2], "replacement", T); + let string_array = downcast_string_array_arg!(args[0], "string", T); + let pattern_array = downcast_string_array_arg!(args[1], "pattern", T); + let replacement_array = downcast_string_array_arg!(args[2], "replacement", T); let result = string_array .iter() @@ -120,10 +139,10 @@ pub fn regexp_replace(args: &[ArrayRef]) -> Result Ok(Arc::new(result) as ArrayRef) } 4 => { - let string_array = downcast_string_arg!(args[0], "string", T); - let pattern_array = downcast_string_arg!(args[1], "pattern", T); - let replacement_array = downcast_string_arg!(args[2], "replacement", T); - let flags_array = downcast_string_arg!(args[3], "flags", T); + let string_array = downcast_string_array_arg!(args[0], "string", T); + let pattern_array = downcast_string_array_arg!(args[1], "pattern", T); + let replacement_array = downcast_string_array_arg!(args[2], "replacement", T); + let flags_array = downcast_string_array_arg!(args[3], "flags", T); let result = string_array .iter() @@ -178,10 +197,125 @@ pub fn regexp_replace(args: &[ArrayRef]) -> Result } } +fn _regexp_replace_early_abort( + input_array: &GenericStringArray, +) -> Result { + // Mimicing the existing behavior of regexp_replace, if any of the scalar arguments + // are actuall null, then the result will be an array of the same size but with nulls. + Ok(new_null_array(input_array.data_type(), input_array.len())) +} + +fn _regexp_replace_static_pattern( + args: &[ArrayRef], +) -> Result { + // Special cased regex_replace implementation for the scenerio where + // both the pattern itself and the flags are scalars. This means we can + // skip regex caching system and basically hold a single Regex object + // for the replace operation. + + let string_array = downcast_string_array_arg!(args[0], "string", T); + let pattern = fetch_string_arg!(args[1], "pattern", T, _regexp_replace_early_abort); + let replacement_array = downcast_string_array_arg!(args[2], "replacement", T); + let flags = match args.len() { + 3 => None, + 4 => Some(fetch_string_arg!(args[3], "flags", T, _regexp_replace_early_abort)), + other => { + return Err(DataFusionError::Internal(format!( + "regexp_replace was called with {} arguments. It requires at least 3 and at most 4.", + other + ))) + } + }; + + // Embed the flag (if it exists) into the pattern + let (pattern, replace_all) = match flags { + Some("g") => (pattern.to_string(), true), + Some(flags) => ( + format!("(?{}){}", flags.to_string().replace('g', ""), pattern), + flags.contains('g'), + ), + None => (pattern.to_string(), false), + }; + + let re = Regex::new(&pattern) + .map_err(|err| DataFusionError::Execution(err.to_string()))?; + + let result = string_array + .iter() + .zip(replacement_array.iter()) + .map(|(string, replacement)| match (string, replacement) { + (Some(string), Some(replacement)) => { + let replacement = regex_replace_posix_groups(replacement); + + if replace_all { + Some(re.replace_all(string, replacement.as_str())) + } else { + Some(re.replace(string, replacement.as_str())) + } + } + _ => None, + }) + .collect::>(); + Ok(Arc::new(result) as ArrayRef) +} + +/// Determine which implementation of the regexp_replace to use based +/// on the given set of arguments. +pub fn specialize_regexp_replace( + args: &[ColumnarValue], +) -> Result { + // This will serve as a dispatch table where we can + // leverage it in order to determine whether the scalarity + // of the given set of arguments fits a better specialized + // function. + let (is_source_scalar, is_pattern_scalar, is_replacement_scalar, is_flags_scalar) = ( + matches!(args[0], ColumnarValue::Scalar(_)), + matches!(args[1], ColumnarValue::Scalar(_)), + matches!(args[2], ColumnarValue::Scalar(_)), + // The forth argument (flags) is optional; so in the event that + // it is not available, we'll claim that it is scalar. + matches!(args.get(3), Some(ColumnarValue::Scalar(_)) | None), + ); + + match ( + is_source_scalar, + is_pattern_scalar, + is_replacement_scalar, + is_flags_scalar, + ) { + // This represents a very hot path for the case where the there is + // a single pattern that is being matched against. This is extremely + // important to specialize on since it removes the overhead of DF's + // in-house regex pattern cache (since there will be at most a single + // pattern). + // + // The flags needs to be a scalar as well since each pattern is actually + // constructed with the flags embedded into the pattern itself. This means + // even if the pattern itself is scalar, if the flags are an array then + // we will create many regexes and it is best to use the implementation + // that caches it. If there are no flags, we can simply ignore it here, + // and let the specialized function handle it. + (_, true, _, true) => { + // We still don't know the scalarity of source/replacement, so we + // need the adapter even if it will do some extra work for the pattern + // and the flags. + // + // TODO: maybe we need a way of telling the adapter on which arguments + // it can skip filling (so that we won't create N - 1 redundant cols). + Ok(make_scalar_function(_regexp_replace_static_pattern::)) + } + + // If there are no specialized implementations, we'll fall back to the + // generic implementation. + (_, _, _, _) => Ok(make_scalar_function(regexp_replace::)), + } +} + #[cfg(test)] mod tests { use super::*; use arrow::array::*; + use datafusion_common::ScalarValue; #[test] fn test_case_sensitive_regexp_match() { @@ -231,4 +365,130 @@ mod tests { assert_eq!(re.as_ref(), &expected); } + + #[test] + fn test_static_pattern_regexp_replace() { + let values = StringArray::from(vec!["abc"; 5]); + let patterns = StringArray::from(vec!["b"; 5]); + let replacements = StringArray::from(vec!["foo"; 5]); + let expected = StringArray::from(vec!["afooc"; 5]); + + let re = _regexp_replace_static_pattern::(&[ + Arc::new(values), + Arc::new(patterns), + Arc::new(replacements), + ]) + .unwrap(); + + assert_eq!(re.as_ref(), &expected); + } + + #[test] + fn test_static_pattern_regexp_replace_with_flags() { + let values = StringArray::from(vec!["abc", "ABC", "aBc", "AbC", "aBC"]); + let patterns = StringArray::from(vec!["b"; 5]); + let replacements = StringArray::from(vec!["foo"; 5]); + let flags = StringArray::from(vec!["i"; 5]); + let expected = + StringArray::from(vec!["afooc", "AfooC", "afooc", "AfooC", "afooC"]); + + let re = _regexp_replace_static_pattern::(&[ + Arc::new(values), + Arc::new(patterns), + Arc::new(replacements), + Arc::new(flags), + ]) + .unwrap(); + + assert_eq!(re.as_ref(), &expected); + } + + #[test] + fn test_static_pattern_regexp_replace_early_abort() { + let values = StringArray::from(vec!["abc"; 5]); + let patterns = StringArray::from(vec![None; 5]); + let replacements = StringArray::from(vec!["foo"; 5]); + let expected = StringArray::from(vec![None; 5]); + + let re = _regexp_replace_static_pattern::(&[ + Arc::new(values), + Arc::new(patterns), + Arc::new(replacements), + ]) + .unwrap(); + + assert_eq!(re.as_ref(), &expected); + } + + #[test] + fn test_static_pattern_regexp_replace_early_abort_flags() { + let values = StringArray::from(vec!["abc"; 5]); + let patterns = StringArray::from(vec!["a"; 5]); + let replacements = StringArray::from(vec!["foo"; 5]); + let flags = StringArray::from(vec![None; 5]); + let expected = StringArray::from(vec![None; 5]); + + let re = _regexp_replace_static_pattern::(&[ + Arc::new(values), + Arc::new(patterns), + Arc::new(replacements), + Arc::new(flags), + ]) + .unwrap(); + + assert_eq!(re.as_ref(), &expected); + } + + #[test] + fn test_static_pattern_regexp_replace_pattern_error() { + let values = StringArray::from(vec!["abc"; 5]); + // Delibaretely using an invalid pattern to see how the single pattern + // error is propagated on regexp_replace. + let patterns = StringArray::from(vec!["["; 5]); + let replacements = StringArray::from(vec!["foo"; 5]); + + let re = _regexp_replace_static_pattern::(&[ + Arc::new(values), + Arc::new(patterns), + Arc::new(replacements), + ]); + let pattern_err = re.expect_err("broken pattern should have failed"); + assert_eq!( + pattern_err.to_string(), + "Execution error: regex parse error:\n [\n ^\nerror: unclosed character class" + ); + } + + #[test] + fn test_regexp_can_specialize_all_cases() { + macro_rules! make_scalar { + () => { + ColumnarValue::Scalar(ScalarValue::Utf8(Some("foo".to_string()))) + }; + } + + macro_rules! make_array { + () => { + ColumnarValue::Array( + Arc::new(StringArray::from(vec!["bar"; 2])) as ArrayRef + ) + }; + } + + for source in [make_scalar!(), make_array!()] { + for pattern in [make_scalar!(), make_array!()] { + for replacement in [make_scalar!(), make_array!()] { + for flags in [Some(make_scalar!()), Some(make_array!()), None] { + let mut args = + vec![source.clone(), pattern.clone(), replacement.clone()]; + if let Some(flags) = flags { + args.push(flags.clone()); + } + let regex_func = specialize_regexp_replace::(&args); + assert!(regex_func.is_ok()); + } + } + } + } + } }