Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support LargeList in array_prepend and array_append #8679

Merged
merged 6 commits into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions datafusion/common/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,10 +424,11 @@ pub fn arrays_into_list_array(
/// assert_eq!(base_type(&data_type), DataType::Int32);
/// ```
pub fn base_type(data_type: &DataType) -> DataType {
if let DataType::List(field) = data_type {
base_type(field.data_type())
} else {
data_type.to_owned()
match data_type {
DataType::List(field) | DataType::LargeList(field) => {
base_type(field.data_type())
}
_ => data_type.to_owned(),
}
}

Expand Down Expand Up @@ -462,17 +463,32 @@ pub fn coerced_type_with_base_type_only(
field.is_nullable(),
)))
}
DataType::LargeList(field) => {
let data_type = match field.data_type() {
DataType::LargeList(_) => {
coerced_type_with_base_type_only(field.data_type(), base_type)
}
_ => base_type.to_owned(),
};

DataType::LargeList(Arc::new(Field::new(
field.name(),
data_type,
field.is_nullable(),
)))
}

_ => base_type.clone(),
}
}

/// Compute the number of dimensions in a list data type.
pub fn list_ndims(data_type: &DataType) -> u64 {
if let DataType::List(field) = data_type {
1 + list_ndims(field.data_type())
} else {
0
match data_type {
DataType::List(field) | DataType::LargeList(field) => {
1 + list_ndims(field.data_type())
}
_ => 0,
}
}

Expand Down
24 changes: 12 additions & 12 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,18 +116,18 @@ fn get_valid_types(
&new_base_type,
);

if let DataType::List(ref field) = array_type {
let elem_type = field.data_type();
if is_append {
Ok(vec![vec![array_type.clone(), elem_type.to_owned()]])
} else {
Ok(vec![vec![elem_type.to_owned(), array_type.clone()]])
match array_type {
DataType::List(ref field) | DataType::LargeList(ref field) => {
let elem_type = field.data_type();
if is_append {
Ok(vec![vec![array_type.clone(), elem_type.to_owned()]])
} else {
Ok(vec![vec![elem_type.to_owned(), array_type.clone()]])
}
}
} else {
Ok(vec![vec![]])
_ => Ok(vec![vec![]]),
}
}

let valid_types = match signature {
TypeSignature::Variadic(valid_types) => valid_types
.iter()
Expand Down Expand Up @@ -311,9 +311,9 @@ fn coerced_from<'a>(
Utf8 | LargeUtf8 => Some(type_into.clone()),
Null if can_cast_types(type_from, type_into) => Some(type_into.clone()),

// Only accept list with the same number of dimensions unless the type is Null.
// List with different dimensions should be handled in TypeSignature or other places before this.
List(_)
// Only accept list and largelist with the same number of dimensions unless the type is Null.
// List or LargeList with different dimensions should be handled in TypeSignature or other places before this.
List(_) | LargeList(_)
if datafusion_common::utils::base_type(type_from).eq(&Null)
|| list_ndims(type_from) == list_ndims(type_into) =>
{
Expand Down
144 changes: 71 additions & 73 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,6 @@ macro_rules! downcast_arg {
}};
}

/// Downcasts multiple arguments into a single concrete type
/// $ARGS: &[ArrayRef]
/// $ARRAY_TYPE: type to downcast to
///
/// $returns a Vec<$ARRAY_TYPE>
macro_rules! downcast_vec {
($ARGS:expr, $ARRAY_TYPE:ident) => {{
$ARGS
.iter()
.map(|e| match e.as_any().downcast_ref::<$ARRAY_TYPE>() {
Some(array) => Ok(array),
_ => internal_err!("failed to downcast"),
})
}};
}

Comment on lines -55 to -70
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused function

/// Computes a BooleanArray indicating equality or inequality between elements in a list array and a specified element array.
///
/// # Arguments
Expand Down Expand Up @@ -832,17 +816,20 @@ pub fn array_pop_back(args: &[ArrayRef]) -> Result<ArrayRef> {
///
/// # Examples
///
/// general_append_and_prepend(
/// generic_append_and_prepend(
/// [1, 2, 3], 4, append => [1, 2, 3, 4]
/// 5, [6, 7, 8], prepend => [5, 6, 7, 8]
/// )
fn general_append_and_prepend(
list_array: &ListArray,
fn generic_append_and_prepend<O: OffsetSizeTrait>(
list_array: &GenericListArray<O>,
element_array: &ArrayRef,
data_type: &DataType,
is_append: bool,
) -> Result<ArrayRef> {
let mut offsets = vec![0];
) -> Result<ArrayRef>
where
i64: TryInto<O>,
{
let mut offsets = vec![O::usize_as(0)];
let values = list_array.values();
let original_data = values.to_data();
let element_data = element_array.to_data();
Expand All @@ -858,21 +845,21 @@ fn general_append_and_prepend(
let element_index = 1;

for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() {
let start = offset_window[0] as usize;
let end = offset_window[1] as usize;
let start = offset_window[0].to_usize().unwrap();
let end = offset_window[1].to_usize().unwrap();
if is_append {
mutable.extend(values_index, start, end);
mutable.extend(element_index, row_index, row_index + 1);
} else {
mutable.extend(element_index, row_index, row_index + 1);
mutable.extend(values_index, start, end);
}
offsets.push(offsets[row_index] + (end - start + 1) as i32);
offsets.push(offsets[row_index] + O::usize_as(end - start + 1));
}

let data = mutable.freeze();

Ok(Arc::new(ListArray::try_new(
Ok(Arc::new(GenericListArray::<O>::try_new(
Arc::new(Field::new("item", data_type.to_owned(), true)),
OffsetBuffer::new(offsets.into()),
arrow_array::make_array(data),
Expand Down Expand Up @@ -938,36 +925,6 @@ pub fn gen_range(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(arr)
}

/// Array_append SQL function
pub fn array_append(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_append expects two arguments");
}

let list_array = as_list_array(&args[0])?;
let element_array = &args[1];

let res = match list_array.value_type() {
DataType::List(_) => concat_internal(args)?,
DataType::Null => {
return make_array(&[
list_array.values().to_owned(),
element_array.to_owned(),
]);
}
data_type => {
return general_append_and_prepend(
list_array,
element_array,
&data_type,
true,
);
}
};

Ok(res)
}

/// Array_sort SQL function
pub fn array_sort(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.is_empty() || args.len() > 3 {
Expand Down Expand Up @@ -1051,32 +1008,71 @@ fn order_nulls_first(modifier: &str) -> Result<bool> {
}
}

/// Array_prepend SQL function
pub fn array_prepend(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_prepend expects two arguments");
}

let list_array = as_list_array(&args[1])?;
let element_array = &args[0];
fn general_append_and_prepend<O: OffsetSizeTrait>(
args: &[ArrayRef],
is_append: bool,
) -> Result<ArrayRef>
where
i64: TryInto<O>,
{
let (list_array, element_array) = if is_append {
let list_array = as_generic_list_array::<O>(&args[0])?;
let element_array = &args[1];
check_datatypes("array_append", &[element_array, list_array.values()])?;
(list_array, element_array)
} else {
let list_array = as_generic_list_array::<O>(&args[1])?;
let element_array = &args[0];
check_datatypes("array_prepend", &[list_array.values(), element_array])?;
(list_array, element_array)
};

check_datatypes("array_prepend", &[element_array, list_array.values()])?;
let res = match list_array.value_type() {
DataType::List(_) => concat_internal(args)?,
DataType::Null => return make_array(&[element_array.to_owned()]),
DataType::List(_) => concat_internal::<i32>(args)?,
DataType::LargeList(_) => concat_internal::<i64>(args)?,
DataType::Null => {
return make_array(&[
list_array.values().to_owned(),
element_array.to_owned(),
]);
}
data_type => {
return general_append_and_prepend(
return generic_append_and_prepend::<O>(
list_array,
element_array,
&data_type,
false,
is_append,
);
}
};

Ok(res)
}

/// Array_append SQL function
pub fn array_append(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_append expects two arguments");
}

match args[0].data_type() {
DataType::LargeList(_) => general_append_and_prepend::<i64>(args, true),
_ => general_append_and_prepend::<i32>(args, true),
}
}

/// Array_prepend SQL function
pub fn array_prepend(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_prepend expects two arguments");
}

match args[1].data_type() {
DataType::LargeList(_) => general_append_and_prepend::<i64>(args, false),
_ => general_append_and_prepend::<i32>(args, false),
}
}

fn align_array_dimensions(args: Vec<ArrayRef>) -> Result<Vec<ArrayRef>> {
let args_ndim = args
.iter()
Expand Down Expand Up @@ -1114,11 +1110,13 @@ fn align_array_dimensions(args: Vec<ArrayRef>) -> Result<Vec<ArrayRef>> {
}

// Concatenate arrays on the same row.
fn concat_internal(args: &[ArrayRef]) -> Result<ArrayRef> {
fn concat_internal<O: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
let args = align_array_dimensions(args.to_vec())?;

let list_arrays =
downcast_vec!(args, ListArray).collect::<Result<Vec<&ListArray>>>()?;
let list_arrays = args
.iter()
.map(|arg| as_generic_list_array::<O>(arg))
.collect::<Result<Vec<_>>>()?;

// Assume number of rows is the same for all arrays
let row_count = list_arrays[0].len();
Expand Down Expand Up @@ -1165,7 +1163,7 @@ fn concat_internal(args: &[ArrayRef]) -> Result<ArrayRef> {
.map(|a| a.as_ref())
.collect::<Vec<&dyn Array>>();

let list_arr = ListArray::new(
let list_arr = GenericListArray::<O>::new(
Arc::new(Field::new("item", data_type, true)),
OffsetBuffer::from_lengths(array_lengths),
Arc::new(compute::concat(elements.as_slice())?),
Expand All @@ -1192,7 +1190,7 @@ pub fn array_concat(args: &[ArrayRef]) -> Result<ArrayRef> {
}
}

concat_internal(new_args.as_slice())
concat_internal::<i32>(new_args.as_slice())
}

/// Array_empty SQL function
Expand Down
Loading