Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add length kernel support for List Array #1488

Merged
merged 4 commits into from
Mar 28, 2022
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 107 additions & 28 deletions arrow/src/compute/kernels/length.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,23 @@ macro_rules! unary_offsets {
}};
}

fn octet_length_binary<O: BinaryOffsetSizeTrait, T: ArrowPrimitiveType>(
array: &dyn Array,
) -> ArrayRef
fn length_list<O, T>(array: &dyn Array) -> ArrayRef
where
O: OffsetSizeTrait,
T: ArrowPrimitiveType,
T::Native: OffsetSizeTrait,
{
let array = array
.as_any()
.downcast_ref::<GenericListArray<O>>()
.unwrap();
unary_offsets!(array, T::DATA_TYPE, |x| x)
}

fn length_binary<O, T>(array: &dyn Array) -> ArrayRef
where
O: BinaryOffsetSizeTrait,
T: ArrowPrimitiveType,
T::Native: BinaryOffsetSizeTrait,
{
let array = array
Expand All @@ -69,10 +82,10 @@ where
unary_offsets!(array, T::DATA_TYPE, |x| x)
}

fn octet_length<O: StringOffsetSizeTrait, T: ArrowPrimitiveType>(
array: &dyn Array,
) -> ArrayRef
fn length_string<O, T>(array: &dyn Array) -> ArrayRef
where
O: StringOffsetSizeTrait,
T: ArrowPrimitiveType,
T::Native: StringOffsetSizeTrait,
{
let array = array
Expand All @@ -82,10 +95,10 @@ where
unary_offsets!(array, T::DATA_TYPE, |x| x)
}

fn bit_length_impl_binary<O: BinaryOffsetSizeTrait, T: ArrowPrimitiveType>(
array: &dyn Array,
) -> ArrayRef
fn bit_length_binary<O, T>(array: &dyn Array) -> ArrayRef
where
O: BinaryOffsetSizeTrait,
T: ArrowPrimitiveType,
T::Native: BinaryOffsetSizeTrait,
{
let array = array
Expand All @@ -96,10 +109,10 @@ where
unary_offsets!(array, T::DATA_TYPE, |x| x * bits_in_bytes)
}

fn bit_length_impl<O: StringOffsetSizeTrait, T: ArrowPrimitiveType>(
array: &dyn Array,
) -> ArrayRef
fn bit_length_string<O, T>(array: &dyn Array) -> ArrayRef
where
O: StringOffsetSizeTrait,
T: ArrowPrimitiveType,
T::Native: StringOffsetSizeTrait,
{
let array = array
Expand All @@ -110,20 +123,24 @@ where
unary_offsets!(array, T::DATA_TYPE, |x| x * bits_in_bytes)
}

/// Returns an array of Int32/Int64 denoting the number of bytes in each value in the array.
/// Returns an array of Int32/Int64 denoting the length of each value in the array.
/// For list array, length is the number of elements in each list.
/// For string array and binary array, length is the number of bytes of each value.
///
/// * this only accepts StringArray/Utf8, LargeString/LargeUtf8, BinaryArray and LargeBinaryArray
/// * this only accepts ListArray/LargeListArray, StringArray/LargeStringArray and BinaryArray/LargeBinaryArray
/// * length of null is null.
/// * length is in number of bytes
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// * length of null is null.
/// * length is in number of bytes
/// * length of null is null.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

pub fn length(array: &dyn Array) -> Result<ArrayRef> {
match array.data_type() {
DataType::Utf8 => Ok(octet_length::<i32, Int32Type>(array)),
DataType::LargeUtf8 => Ok(octet_length::<i64, Int64Type>(array)),
DataType::Binary => Ok(octet_length_binary::<i32, Int32Type>(array)),
DataType::LargeBinary => Ok(octet_length_binary::<i64, Int64Type>(array)),
_ => Err(ArrowError::ComputeError(format!(
DataType::List(_) => Ok(length_list::<i32, Int32Type>(array)),
DataType::LargeList(_) => Ok(length_list::<i64, Int64Type>(array)),
DataType::Utf8 => Ok(length_string::<i32, Int32Type>(array)),
DataType::LargeUtf8 => Ok(length_string::<i64, Int64Type>(array)),
DataType::Binary => Ok(length_binary::<i32, Int32Type>(array)),
DataType::LargeBinary => Ok(length_binary::<i64, Int64Type>(array)),
other => Err(ArrowError::ComputeError(format!(
"length not supported for {:?}",
array.data_type()
other
))),
}
}
Expand All @@ -135,19 +152,21 @@ pub fn length(array: &dyn Array) -> Result<ArrayRef> {
/// * bit_length is in number of bits
pub fn bit_length(array: &dyn Array) -> Result<ArrayRef> {
match array.data_type() {
DataType::Utf8 => Ok(bit_length_impl::<i32, Int32Type>(array)),
DataType::LargeUtf8 => Ok(bit_length_impl::<i64, Int64Type>(array)),
DataType::Binary => Ok(bit_length_impl_binary::<i32, Int32Type>(array)),
DataType::LargeBinary => Ok(bit_length_impl_binary::<i64, Int64Type>(array)),
_ => Err(ArrowError::ComputeError(format!(
DataType::Utf8 => Ok(bit_length_string::<i32, Int32Type>(array)),
DataType::LargeUtf8 => Ok(bit_length_string::<i64, Int64Type>(array)),
DataType::Binary => Ok(bit_length_binary::<i32, Int32Type>(array)),
DataType::LargeBinary => Ok(bit_length_binary::<i64, Int64Type>(array)),
other => Err(ArrowError::ComputeError(format!(
"bit_length not supported for {:?}",
array.data_type()
other
))),
}
}

#[cfg(test)]
mod tests {
use crate::datatypes::{Float32Type, Int8Type};

use super::*;

fn double_vec<T: Clone>(v: Vec<T>) -> Vec<T> {
Expand Down Expand Up @@ -182,6 +201,20 @@ mod tests {
}};
}

macro_rules! length_list_helper {
($offset_ty: ty, $result_ty: ty, $element_ty: ty, $value: expr, $expected: expr) => {{
let array =
GenericListArray::<$offset_ty>::from_iter_primitive::<$element_ty, _, _>(
$value,
);
let result = length(&array)?;
let result = result.as_any().downcast_ref::<$result_ty>().unwrap();
let expected: $result_ty = $expected.into();
assert_eq!(expected.data(), result.data());
Ok(())
}};
}

#[test]
#[cfg_attr(miri, ignore)] // running forever
fn length_test_string() -> Result<()> {
Expand Down Expand Up @@ -230,6 +263,28 @@ mod tests {
length_binary_helper!(i64, Int64Array, length, value, result)
}

#[test]
fn length_test_list() -> Result<()> {
let value = vec![
Some(vec![]),
Some(vec![Some(1), Some(2), Some(4)]),
Some(vec![Some(0)]),
];
let result: Vec<i32> = vec![0, 3, 1];
length_list_helper!(i32, Int32Array, Int32Type, value, result)
}

#[test]
fn length_test_large_list() -> Result<()> {
let value = vec![
Some(vec![]),
Some(vec![Some(1.1), Some(2.2), Some(3.3)]),
Some(vec![None]),
];
let result: Vec<i64> = vec![0, 3, 1];
length_list_helper!(i64, Int64Array, Float32Type, value, result)
}

type OptionStr = Option<&'static str>;

fn length_null_cases_string() -> Vec<(Vec<OptionStr>, usize, Vec<Option<i32>>)> {
Expand Down Expand Up @@ -293,6 +348,30 @@ mod tests {
length_binary_helper!(i64, Int64Array, length, value, result)
}

#[test]
fn length_null_list() -> Result<()> {
let value = vec![
Some(vec![]),
None,
Some(vec![Some(1), None, Some(2), Some(4)]),
Some(vec![Some(0)]),
];
let result: Vec<Option<i32>> = vec![Some(0), None, Some(4), Some(1)];
length_list_helper!(i32, Int32Array, Int8Type, value, result)
}

#[test]
fn length_null_large_list() -> Result<()> {
let value = vec![
Some(vec![]),
None,
Some(vec![Some(1.1), None, Some(4.0)]),
Some(vec![Some(0.1)]),
];
let result: Vec<Option<i64>> = vec![Some(0), None, Some(3), Some(1)];
length_list_helper!(i64, Int64Array, Float32Type, value, result)
}

/// Tests that length is not valid for u64.
#[test]
fn length_wrong_type() {
Expand All @@ -303,7 +382,7 @@ mod tests {

/// Tests with an offset
#[test]
fn length_offsets() -> Result<()> {
fn length_offsets_string() -> Result<()> {
let a = StringArray::from(vec![Some("hello"), Some(" "), Some("world"), None]);
let b = a.slice(1, 3);
let result = length(b.as_ref())?;
Expand All @@ -316,7 +395,7 @@ mod tests {
}

#[test]
fn binary_length_offsets() -> Result<()> {
fn length_offsets_binary() -> Result<()> {
let value: Vec<Option<&[u8]>> =
vec![Some(b"hello"), Some(b" "), Some(&[0xff, 0xf8]), None];
let a = BinaryArray::from(value);
Expand Down