Skip to content

Commit

Permalink
Implement native support StringView for character length (#11676)
Browse files Browse the repository at this point in the history
* native support for character length

* Update datafusion/functions/src/unicode/character_length.rs

---------

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
  • Loading branch information
XiangpengHao and alamb authored Jul 27, 2024
1 parent 322c3d2 commit ab8005d
Showing 1 changed file with 68 additions and 63 deletions.
131 changes: 68 additions & 63 deletions datafusion/functions/src/unicode/character_length.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@

use crate::utils::{make_scalar_function, utf8_to_int_type};
use arrow::array::{
ArrayRef, ArrowPrimitiveType, GenericStringArray, OffsetSizeTrait, PrimitiveArray,
Array, ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray,
OffsetSizeTrait, PrimitiveArray,
};
use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type};
use datafusion_common::cast::as_generic_string_array;
use datafusion_common::exec_err;
use datafusion_common::Result;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use std::any::Any;
Expand Down Expand Up @@ -71,17 +70,7 @@ impl ScalarUDFImpl for CharacterLengthFunc {
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
match args[0].data_type() {
DataType::Utf8 => {
make_scalar_function(character_length::<Int32Type>, vec![])(args)
}
DataType::LargeUtf8 => {
make_scalar_function(character_length::<Int64Type>, vec![])(args)
}
other => {
exec_err!("Unsupported data type {other:?} for function character_length")
}
}
make_scalar_function(character_length, vec![])(args)
}

fn aliases(&self) -> &[String] {
Expand All @@ -92,15 +81,32 @@ impl ScalarUDFImpl for CharacterLengthFunc {
/// Returns number of characters in the string.
/// character_length('josé') = 4
/// The implementation counts UTF-8 code points to count the number of characters
fn character_length<T: ArrowPrimitiveType>(args: &[ArrayRef]) -> Result<ArrayRef>
fn character_length(args: &[ArrayRef]) -> Result<ArrayRef> {
match args[0].data_type() {
DataType::Utf8 => {
let string_array = args[0].as_string::<i32>();
character_length_general::<Int32Type, _>(string_array)
}
DataType::LargeUtf8 => {
let string_array = args[0].as_string::<i64>();
character_length_general::<Int64Type, _>(string_array)
}
DataType::Utf8View => {
let string_array = args[0].as_string_view();
character_length_general::<Int32Type, _>(string_array)
}
_ => unreachable!(),
}
}

fn character_length_general<'a, T: ArrowPrimitiveType, V: ArrayAccessor<Item = &'a str>>(
array: V,
) -> Result<ArrayRef>
where
T::Native: OffsetSizeTrait,
{
let string_array: &GenericStringArray<T::Native> =
as_generic_string_array::<T::Native>(&args[0])?;

let result = string_array
.iter()
let iter = ArrayIter::new(array);
let result = iter
.map(|string| {
string.map(|string: &str| {
T::Native::from_usize(string.chars().count())
Expand All @@ -116,55 +122,54 @@ where
mod tests {
use crate::unicode::character_length::CharacterLengthFunc;
use crate::utils::test::test_function;
use arrow::array::{Array, Int32Array};
use arrow::datatypes::DataType::Int32;
use arrow::array::{Array, Int32Array, Int64Array};
use arrow::datatypes::DataType::{Int32, Int64};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

macro_rules! test_character_length {
($INPUT:expr, $EXPECTED:expr) => {
test_function!(
CharacterLengthFunc::new(),
&[ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))],
$EXPECTED,
i32,
Int32,
Int32Array
);

test_function!(
CharacterLengthFunc::new(),
&[ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))],
$EXPECTED,
i64,
Int64,
Int64Array
);

test_function!(
CharacterLengthFunc::new(),
&[ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))],
$EXPECTED,
i32,
Int32,
Int32Array
);
};
}

#[test]
fn test_functions() -> Result<()> {
#[cfg(feature = "unicode_expressions")]
test_function!(
CharacterLengthFunc::new(),
&[ColumnarValue::Scalar(ScalarValue::Utf8(Some(
String::from("chars")
)))],
Ok(Some(5)),
i32,
Int32,
Int32Array
);
#[cfg(feature = "unicode_expressions")]
test_function!(
CharacterLengthFunc::new(),
&[ColumnarValue::Scalar(ScalarValue::Utf8(Some(
String::from("josé")
)))],
Ok(Some(4)),
i32,
Int32,
Int32Array
);
#[cfg(feature = "unicode_expressions")]
test_function!(
CharacterLengthFunc::new(),
&[ColumnarValue::Scalar(ScalarValue::Utf8(Some(
String::from("")
)))],
Ok(Some(0)),
i32,
Int32,
Int32Array
);
#[cfg(feature = "unicode_expressions")]
test_function!(
CharacterLengthFunc::new(),
&[ColumnarValue::Scalar(ScalarValue::Utf8(None))],
Ok(None),
i32,
Int32,
Int32Array
);
{
test_character_length!(Some(String::from("chars")), Ok(Some(5)));
test_character_length!(Some(String::from("josé")), Ok(Some(4)));
// test long strings (more than 12 bytes for StringView)
test_character_length!(Some(String::from("joséjoséjoséjosé")), Ok(Some(16)));
test_character_length!(Some(String::from("")), Ok(Some(0)));
test_character_length!(None, Ok(None));
}

#[cfg(not(feature = "unicode_expressions"))]
test_function!(
CharacterLengthFunc::new(),
Expand Down

0 comments on commit ab8005d

Please sign in to comment.