Skip to content

Commit

Permalink
Update REGEXP_MATCH scalar function to support Utf8View (#14449) (#14457)
Browse files Browse the repository at this point in the history

* Update REGEXP_MATCH scalar function to support Utf8View

* Cargo fmt fix.

Co-authored-by: Bruce Ritchie <bruce.ritchie@veeva.com>
  • Loading branch information
alamb and Omega359 authored Feb 3, 2025
1 parent 755b26a commit 9d287bd
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 43 deletions.
59 changes: 57 additions & 2 deletions datafusion/functions/benches/regx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
extern crate criterion;

use arrow::array::builder::StringBuilder;
use arrow::array::{ArrayRef, AsArray, Int64Array, StringArray};
use arrow::array::{ArrayRef, AsArray, Int64Array, StringArray, StringViewArray};
use arrow::compute::cast;
use arrow::datatypes::DataType;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
Expand Down Expand Up @@ -141,6 +141,20 @@ fn criterion_benchmark(c: &mut Criterion) {
})
});

c.bench_function("regexp_like_1000 utf8view", |b| {
let mut rng = rand::thread_rng();
let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap();
let regex = cast(&regex(&mut rng), &DataType::Utf8View).unwrap();
let flags = cast(&flags(&mut rng), &DataType::Utf8View).unwrap();

b.iter(|| {
black_box(
regexp_like(&[Arc::clone(&data), Arc::clone(&regex), Arc::clone(&flags)])
.expect("regexp_like should work on valid values"),
)
})
});

c.bench_function("regexp_match_1000", |b| {
let mut rng = rand::thread_rng();
let data = Arc::new(data(&mut rng)) as ArrayRef;
Expand All @@ -149,7 +163,25 @@ fn criterion_benchmark(c: &mut Criterion) {

b.iter(|| {
black_box(
regexp_match::<i32>(&[
regexp_match(&[
Arc::clone(&data),
Arc::clone(&regex),
Arc::clone(&flags),
])
.expect("regexp_match should work on valid values"),
)
})
});

c.bench_function("regexp_match_1000 utf8view", |b| {
let mut rng = rand::thread_rng();
let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap();
let regex = cast(&regex(&mut rng), &DataType::Utf8View).unwrap();
let flags = cast(&flags(&mut rng), &DataType::Utf8View).unwrap();

b.iter(|| {
black_box(
regexp_match(&[
Arc::clone(&data),
Arc::clone(&regex),
Arc::clone(&flags),
Expand Down Expand Up @@ -180,6 +212,29 @@ fn criterion_benchmark(c: &mut Criterion) {
)
})
});

c.bench_function("regexp_replace_1000 utf8view", |b| {
let mut rng = rand::thread_rng();
let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap();
let regex = cast(&regex(&mut rng), &DataType::Utf8View).unwrap();
// flags are not allowed to be utf8view according to the function
let flags = Arc::new(flags(&mut rng)) as ArrayRef;
let replacement = Arc::new(StringViewArray::from_iter_values(
iter::repeat("XX").take(1000),
));

b.iter(|| {
black_box(
regexp_replace::<i32, _, _>(
data.as_string_view(),
regex.as_string_view(),
&replacement,
Some(&flags),
)
.expect("regexp_replace should work on valid values"),
)
})
});
}

criterion_group!(benches, criterion_benchmark);
Expand Down
66 changes: 33 additions & 33 deletions datafusion/functions/src/regex/regexpmatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,14 @@
// under the License.

//! Regex expressions
use arrow::array::{Array, ArrayRef, OffsetSizeTrait};
use arrow::array::{Array, ArrayRef, AsArray};
use arrow::compute::kernels::regexp;
use arrow::datatypes::DataType;
use arrow::datatypes::Field;
use datafusion_common::exec_err;
use datafusion_common::ScalarValue;
use datafusion_common::{arrow_datafusion_err, plan_err};
use datafusion_common::{
cast::as_generic_string_array, internal_err, DataFusionError, Result,
};
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::{ColumnarValue, Documentation, TypeSignature};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use datafusion_macros::user_doc;
Expand Down Expand Up @@ -86,11 +84,12 @@ impl RegexpMatchFunc {
signature: Signature::one_of(
vec![
// Planner attempts coercion to the target type starting with the most preferred candidate.
// For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8, Utf8)`.
// If that fails, it proceeds to `(LargeUtf8, Utf8)`.
// TODO: Native support Utf8View for regexp_match.
// For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8View, Utf8View)`.
// If that fails, it proceeds to `(Utf8, Utf8)`.
TypeSignature::Exact(vec![Utf8View, Utf8View]),
TypeSignature::Exact(vec![Utf8, Utf8]),
TypeSignature::Exact(vec![LargeUtf8, LargeUtf8]),
TypeSignature::Exact(vec![Utf8View, Utf8View, Utf8View]),
TypeSignature::Exact(vec![Utf8, Utf8, Utf8]),
TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, LargeUtf8]),
],
Expand Down Expand Up @@ -138,7 +137,7 @@ impl ScalarUDFImpl for RegexpMatchFunc {
.map(|arg| arg.to_array(inferred_length))
.collect::<Result<Vec<_>>>()?;

let result = regexp_match_func(&args);
let result = regexp_match(&args);
if is_scalar {
// If all inputs are scalar, keeps output as scalar
let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
Expand All @@ -153,33 +152,35 @@ impl ScalarUDFImpl for RegexpMatchFunc {
}
}

/// Dispatches to `regexp_match` based on the data type of the first argument.
///
/// `Utf8` input is routed to `regexp_match::<i32>` and `LargeUtf8` input to
/// `regexp_match::<i64>`. Any other data type for `args[0]` produces an
/// internal error naming the unsupported type. Only `args[0]` is inspected
/// here; the remaining arguments are passed through unchanged.
///
/// NOTE(review): `args[0]` is indexed directly, so an empty `args` slice would
/// panic — presumably callers always supply at least one array; confirm.
fn regexp_match_func(args: &[ArrayRef]) -> Result<ArrayRef> {
match args[0].data_type() {
// `Utf8` arrays are handled with the `i32` offset type.
DataType::Utf8 => regexp_match::<i32>(args),
// `LargeUtf8` arrays are handled with the `i64` offset type.
DataType::LargeUtf8 => regexp_match::<i64>(args),
// Every other type (including `Utf8View`) is rejected by this dispatcher.
other => {
internal_err!("Unsupported data type {other:?} for function regexp_match")
}
}
}
pub fn regexp_match<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
pub fn regexp_match(args: &[ArrayRef]) -> Result<ArrayRef> {
match args.len() {
2 => {
let values = as_generic_string_array::<T>(&args[0])?;
let regex = as_generic_string_array::<T>(&args[1])?;
regexp::regexp_match(values, regex, None)
regexp::regexp_match(&args[0], &args[1], None)
.map_err(|e| arrow_datafusion_err!(e))
}
3 => {
let values = as_generic_string_array::<T>(&args[0])?;
let regex = as_generic_string_array::<T>(&args[1])?;
let flags = as_generic_string_array::<T>(&args[2])?;

if flags.iter().any(|s| s == Some("g")) {
return plan_err!("regexp_match() does not support the \"global\" option");
match args[2].data_type() {
DataType::Utf8View => {
if args[2].as_string_view().iter().any(|s| s == Some("g")) {
return plan_err!("regexp_match() does not support the \"global\" option");
}
}
DataType::Utf8 => {
if args[2].as_string::<i32>().iter().any(|s| s == Some("g")) {
return plan_err!("regexp_match() does not support the \"global\" option");
}
}
DataType::LargeUtf8 => {
if args[2].as_string::<i64>().iter().any(|s| s == Some("g")) {
return plan_err!("regexp_match() does not support the \"global\" option");
}
}
e => {
return plan_err!("regexp_match was called with unexpected data type {e:?}");
}
}

regexp::regexp_match(values, regex, Some(flags))
regexp::regexp_match(&args[0], &args[1], Some(&args[2]))
.map_err(|e| arrow_datafusion_err!(e))
}
other => exec_err!(
Expand Down Expand Up @@ -211,7 +212,7 @@ mod tests {
expected_builder.append(false);
let expected = expected_builder.finish();

let re = regexp_match::<i32>(&[Arc::new(values), Arc::new(patterns)]).unwrap();
let re = regexp_match(&[Arc::new(values), Arc::new(patterns)]).unwrap();

assert_eq!(re.as_ref(), &expected);
}
Expand All @@ -236,9 +237,8 @@ mod tests {
expected_builder.append(false);
let expected = expected_builder.finish();

let re =
regexp_match::<i32>(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
.unwrap();
let re = regexp_match(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
.unwrap();

assert_eq!(re.as_ref(), &expected);
}
Expand All @@ -250,7 +250,7 @@ mod tests {
let flags = StringArray::from(vec!["g"]);

let re_err =
regexp_match::<i32>(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
regexp_match(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
.expect_err("unsupported flag should have failed");

assert_eq!(re_err.strip_backtrace(), "Error during planning: regexp_match() does not support the \"global\" option");
Expand Down
60 changes: 53 additions & 7 deletions datafusion/sqllogictest/test_files/regexp.slt
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,29 @@ NULL
[Köln]
[إسرائيل]

# test string view
statement ok
CREATE TABLE t_stringview AS
SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(flags, 'Utf8View') as flags FROM t;

query ?
SELECT regexp_match(str, pattern, flags) FROM t_stringview;
----
[a]
[A]
[B]
NULL
NULL
NULL
[010]
[Düsseldorf]
[Москва]
[Köln]
[إسرائيل]

statement ok
DROP TABLE t_stringview;

query ?
SELECT regexp_match('foobarbequebaz', '');
----
Expand Down Expand Up @@ -354,6 +377,29 @@ X
X
X

# test string view
statement ok
CREATE TABLE t_stringview AS
SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(flags, 'Utf8View') as flags FROM t;

query T
SELECT regexp_replace(str, pattern, 'X', concat('g', flags)) FROM t_stringview;
----
Xbc
X
aXc
AbC
aBC
4000
X
X
X
X
X

statement ok
DROP TABLE t_stringview;

query T
SELECT regexp_replace('ABCabcABC', '(abc)', 'X', 'gi');
----
Expand Down Expand Up @@ -621,7 +667,7 @@ CREATE TABLE t_stringview AS
SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(start, 'Int64') as start, arrow_cast(flags, 'Utf8View') as flags FROM t;

query I
SELECT regexp_count(str, '\w') from t;
SELECT regexp_count(str, '\w') from t_stringview;
----
3
3
Expand All @@ -636,7 +682,7 @@ SELECT regexp_count(str, '\w') from t;
7

query I
SELECT regexp_count(str, '\w{2}', start) from t;
SELECT regexp_count(str, '\w{2}', start) from t_stringview;
----
1
1
Expand All @@ -651,7 +697,7 @@ SELECT regexp_count(str, '\w{2}', start) from t;
3

query I
SELECT regexp_count(str, 'ab', 1, 'i') from t;
SELECT regexp_count(str, 'ab', 1, 'i') from t_stringview;
----
1
1
Expand All @@ -667,7 +713,7 @@ SELECT regexp_count(str, 'ab', 1, 'i') from t;


query I
SELECT regexp_count(str, pattern) from t;
SELECT regexp_count(str, pattern) from t_stringview;
----
1
1
Expand All @@ -682,7 +728,7 @@ SELECT regexp_count(str, pattern) from t;
1

query I
SELECT regexp_count(str, pattern, start) from t;
SELECT regexp_count(str, pattern, start) from t_stringview;
----
1
1
Expand All @@ -697,7 +743,7 @@ SELECT regexp_count(str, pattern, start) from t;
1

query I
SELECT regexp_count(str, pattern, start, flags) from t;
SELECT regexp_count(str, pattern, start, flags) from t_stringview;
----
1
1
Expand All @@ -713,7 +759,7 @@ SELECT regexp_count(str, pattern, start, flags) from t;

# test type coercion
query I
SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from t;
SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from t_stringview;
----
1
1
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/string/string_view.slt
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,7 @@ EXPLAIN SELECT
FROM test;
----
logical_plan
01)Projection: regexp_match(CAST(test.column1_utf8view AS Utf8), Utf8("^https?://(?:www\.)?([^/]+)/.*$")) AS k
01)Projection: regexp_match(test.column1_utf8view, Utf8View("^https?://(?:www\.)?([^/]+)/.*$")) AS k
02)--TableScan: test projection=[column1_utf8view]

## Ensure no casts for REGEXP_REPLACE
Expand Down

0 comments on commit 9d287bd

Please sign in to comment.