Skip to content

Commit

Permalink
feat: Support FixedSizeListArray for length kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
Weijun-H committed Jul 13, 2023
1 parent c044464 commit f6acf47
Showing 1 changed file with 59 additions and 0 deletions.
59 changes: 59 additions & 0 deletions arrow-string/src/length.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,30 @@ where
unary_offsets!(array, T::DATA_TYPE, |x| x)
}

/// Builds the `length` result for a [`FixedSizeListArray`].
///
/// Every slot of the output holds `length` (the list size declared by
/// `DataType::FixedSizeList`), and the validity bitmap of the input is
/// carried over so null lists stay null.
///
/// `T` selects the primitive type of the returned array; the value buffer is
/// filled with `T::Native` values so it always agrees with `T::DATA_TYPE`.
///
/// # Panics
///
/// Panics if `array` is not a `FixedSizeListArray`, or if `length` is
/// negative or cannot be represented as `T::Native`.
fn length_list_fixed_size<T>(array: &dyn Array, length: i32) -> ArrayRef
where
    T: ArrowPrimitiveType,
    T::Native: OffsetSizeTrait,
{
    let list = array
        .as_any()
        .downcast_ref::<FixedSizeListArray>()
        .expect("length_list_fixed_size requires a FixedSizeListArray");
    let slots = list.len();

    // Convert the declared size into the output native type. Filling the
    // buffer with raw `i32` values would only be correct for `Int32Type`.
    let size = usize::try_from(length)
        .ok()
        .and_then(T::Native::from_usize)
        .expect("fixed-size list length must be non-negative and fit the output type");
    let values = Buffer::from_vec(vec![size; slots]);

    // Copy the validity bits out as a standalone buffer so they can be used
    // with the offset of 0 passed below.
    let validity = list.nulls().map(|nulls| nulls.inner().sliced());

    // SAFETY: `values` holds exactly `slots` elements of `T::Native`, which is
    // the native representation of `T::DATA_TYPE`; the optional validity
    // buffer is derived from the input array, which has the same `slots`
    // length; primitive arrays have no child data.
    let data = unsafe {
        ArrayData::new_unchecked(
            T::DATA_TYPE,
            slots,
            None,
            validity,
            0,
            vec![values],
            vec![],
        )
    };
    make_array(data)
}

fn length_binary<O, T>(array: &dyn Array) -> ArrayRef
where
O: OffsetSizeTrait,
Expand Down Expand Up @@ -172,6 +196,9 @@ pub fn length(array: &dyn Array) -> Result<ArrayRef, ArrowError> {
DataType::LargeUtf8 => Ok(length_string::<i64, Int64Type>(array)),
DataType::Binary => Ok(length_binary::<i32, Int32Type>(array)),
DataType::LargeBinary => Ok(length_binary::<i64, Int64Type>(array)),
DataType::FixedSizeList(_, len) => {
Ok(length_list_fixed_size::<Int32Type>(array, *len))
}
other => Err(ArrowError::ComputeError(format!(
"length not supported for {other:?}"
))),
Expand Down Expand Up @@ -215,6 +242,8 @@ pub fn bit_length(array: &dyn Array) -> Result<ArrayRef, ArrowError> {
mod tests {
use super::*;
use arrow_array::cast::AsArray;
use arrow_buffer::NullBuffer;
use arrow_schema::Field;

fn double_vec<T: Clone>(v: Vec<T>) -> Vec<T> {
[&v[..], &v[..]].concat()
Expand Down Expand Up @@ -696,4 +725,34 @@ mod tests {
assert_eq!(expected[i], actual[i],);
}
}

#[test]
fn test_fixed_size_list_length() {
    // Nine Int32 child values, grouped three at a time by the list array.
    let value_data = ArrayData::builder(DataType::Int32)
        .len(9)
        .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7, 8]))
        .build()
        .unwrap();
    let list_data_type = DataType::FixedSizeList(
        Arc::new(Field::new("item", DataType::Int32, false)),
        3,
    );
    // The middle list is null; the kernel must propagate that to its output.
    let nulls = NullBuffer::from(vec![true, false, true]);
    let list_data = ArrayData::builder(list_data_type)
        .len(3)
        .add_child_data(value_data)
        .nulls(Some(nulls))
        .build()
        .unwrap();
    let list_array = FixedSizeListArray::from(list_data);

    let lengths = length(&list_array).unwrap();
    let lengths = lengths.as_any().downcast_ref::<Int32Array>().unwrap();

    assert_eq!(lengths.len(), 3);
    assert_eq!(lengths.null_count(), 1);
    assert_eq!(lengths.value(0), 3);
    assert!(lengths.is_null(1));
    assert_eq!(lengths.value(2), 3);
}
}

0 comments on commit f6acf47

Please sign in to comment.