From f6acf4752bd91ef2a61058d53bfa22e9c90756fd Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Fri, 14 Jul 2023 00:15:25 +0200 Subject: [PATCH] feat: Support FixedSizedListArray for length kernel --- arrow-string/src/length.rs | 59 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/arrow-string/src/length.rs b/arrow-string/src/length.rs index 90efdd7b67cc..c688e549b973 100644 --- a/arrow-string/src/length.rs +++ b/arrow-string/src/length.rs @@ -88,6 +88,30 @@ where unary_offsets!(array, T::DATA_TYPE, |x| x) } +fn length_list_fixed_size(array: &dyn Array, length: i32) -> ArrayRef +where + T: ArrowPrimitiveType, + T::Native: OffsetSizeTrait, +{ + let array = array.as_any().downcast_ref::().unwrap(); + let null_bit_buffer = array.nulls().map(|b| b.inner().sliced()); + let length_list = array.len(); + let buffer = Buffer::from_vec(vec![length; length_list]); + + let data = unsafe { + ArrayData::new_unchecked( + T::DATA_TYPE, + length_list, + None, + null_bit_buffer, + 0, + vec![buffer], + vec![], + ) + }; + make_array(data) +} + fn length_binary(array: &dyn Array) -> ArrayRef where O: OffsetSizeTrait, @@ -172,6 +196,9 @@ pub fn length(array: &dyn Array) -> Result { DataType::LargeUtf8 => Ok(length_string::(array)), DataType::Binary => Ok(length_binary::(array)), DataType::LargeBinary => Ok(length_binary::(array)), + DataType::FixedSizeList(_, len) => { + Ok(length_list_fixed_size::(array, *len)) + } other => Err(ArrowError::ComputeError(format!( "length not supported for {other:?}" ))), @@ -215,6 +242,8 @@ pub fn bit_length(array: &dyn Array) -> Result { mod tests { use super::*; use arrow_array::cast::AsArray; + use arrow_buffer::NullBuffer; + use arrow_schema::Field; fn double_vec(v: Vec) -> Vec { [&v[..], &v[..]].concat() @@ -696,4 +725,34 @@ mod tests { assert_eq!(expected[i], actual[i],); } } + + #[test] + fn test_fixed_size_list_length() { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(9) + .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7, 8])) + .build() + .unwrap(); + let list_data_type = DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Int32, false)), + 3, + ); + let nulls = NullBuffer::from(vec![true, false, true]); + let list_data = ArrayData::builder(list_data_type.clone()) + .len(3) + .add_child_data(value_data.clone()) + .nulls(Some(nulls)) + .build() + .unwrap(); + let list_array = FixedSizeListArray::from(list_data); + + let lengths = length(&list_array).unwrap(); + let lengths = lengths.as_any().downcast_ref::().unwrap(); + + assert_eq!(lengths.len(), 3); + assert_eq!(lengths.value(0), 3); + assert_eq!(lengths.is_null(1), true); + assert_eq!(lengths.value(2), 3); + } }