Skip to content

Commit

Permalink
support LargeList in array_remove (#8595)
Browse files Browse the repository at this point in the history
  • Loading branch information
Weijun-H authored Dec 22, 2023
1 parent 39e9f41 commit ef34af8
Show file tree
Hide file tree
Showing 2 changed files with 365 additions and 18 deletions.
114 changes: 96 additions & 18 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,14 @@ fn compare_element_to_list(
row_index: usize,
eq: bool,
) -> Result<BooleanArray> {
if list_array_row.data_type() != element_array.data_type() {
return exec_err!(
"compare_element_to_list received incompatible types: '{:?}' and '{:?}'.",
list_array_row.data_type(),
element_array.data_type()
);
}

let indices = UInt32Array::from(vec![row_index as u32]);
let element_array_row = arrow::compute::take(element_array, &indices, None)?;

Expand All @@ -126,6 +134,26 @@ fn compare_element_to_list(
})
.collect::<BooleanArray>()
}
DataType::LargeList(_) => {
// compare each element of the from array
let element_array_row_inner =
as_large_list_array(&element_array_row)?.value(0);
let list_array_row_inner = as_large_list_array(list_array_row)?;

list_array_row_inner
.iter()
// compare element by element the current row of list_array
.map(|row| {
row.map(|row| {
if eq {
row.eq(&element_array_row_inner)
} else {
row.ne(&element_array_row_inner)
}
})
})
.collect::<BooleanArray>()
}
_ => {
let element_arr = Scalar::new(element_array_row);
// use not_distinct so we can compare NULL
Expand Down Expand Up @@ -1511,14 +1539,14 @@ pub fn array_remove_n(args: &[ArrayRef]) -> Result<ArrayRef> {
/// [4, 5, 6, 5], 5, 20, 2 ==> [4, 20, 6, 20] (both 5s are replaced)
/// )
/// ```
fn general_replace(
list_array: &ListArray,
fn general_replace<O: OffsetSizeTrait>(
list_array: &GenericListArray<O>,
from_array: &ArrayRef,
to_array: &ArrayRef,
arr_n: Vec<i64>,
) -> Result<ArrayRef> {
// Build up the offsets for the final output array
let mut offsets: Vec<i32> = vec![0];
let mut offsets: Vec<O> = vec![O::usize_as(0)];
let values = list_array.values();
let original_data = values.to_data();
let to_data = to_array.to_data();
Expand All @@ -1540,8 +1568,8 @@ fn general_replace(
continue;
}

let start = offset_window[0] as usize;
let end = offset_window[1] as usize;
let start = offset_window[0];
let end = offset_window[1];

let list_array_row = list_array.value(row_index);

Expand All @@ -1550,43 +1578,56 @@ fn general_replace(
let eq_array =
compare_element_to_list(&list_array_row, &from_array, row_index, true)?;

let original_idx = 0;
let replace_idx = 1;
let original_idx = O::usize_as(0);
let replace_idx = O::usize_as(1);
let n = arr_n[row_index];
let mut counter = 0;

// All elements are false, no need to replace, just copy original data
if eq_array.false_count() == eq_array.len() {
mutable.extend(original_idx, start, end);
offsets.push(offsets[row_index] + (end - start) as i32);
mutable.extend(
original_idx.to_usize().unwrap(),
start.to_usize().unwrap(),
end.to_usize().unwrap(),
);
offsets.push(offsets[row_index] + (end - start));
valid.append(true);
continue;
}

for (i, to_replace) in eq_array.iter().enumerate() {
let i = O::usize_as(i);
if let Some(true) = to_replace {
mutable.extend(replace_idx, row_index, row_index + 1);
mutable.extend(replace_idx.to_usize().unwrap(), row_index, row_index + 1);
counter += 1;
if counter == n {
// copy original data for any matches past n
mutable.extend(original_idx, start + i + 1, end);
mutable.extend(
original_idx.to_usize().unwrap(),
(start + i).to_usize().unwrap() + 1,
end.to_usize().unwrap(),
);
break;
}
} else {
// copy original data for false / null matches
mutable.extend(original_idx, start + i, start + i + 1);
mutable.extend(
original_idx.to_usize().unwrap(),
(start + i).to_usize().unwrap(),
(start + i).to_usize().unwrap() + 1,
);
}
}

offsets.push(offsets[row_index] + (end - start) as i32);
offsets.push(offsets[row_index] + (end - start));
valid.append(true);
}

let data = mutable.freeze();

Ok(Arc::new(ListArray::try_new(
Ok(Arc::new(GenericListArray::<O>::try_new(
Arc::new(Field::new("item", list_array.value_type(), true)),
OffsetBuffer::new(offsets.into()),
OffsetBuffer::<O>::new(offsets.into()),
arrow_array::make_array(data),
Some(NullBuffer::new(valid.finish())),
)?))
Expand All @@ -1595,19 +1636,56 @@ fn general_replace(
pub fn array_replace(args: &[ArrayRef]) -> Result<ArrayRef> {
// replace at most one occurrence for each element
let arr_n = vec![1; args[0].len()];
general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n)
let array = &args[0];
match array.data_type() {
DataType::List(_) => {
let list_array = array.as_list::<i32>();
general_replace::<i32>(list_array, &args[1], &args[2], arr_n)
}
DataType::LargeList(_) => {
let list_array = array.as_list::<i64>();
general_replace::<i64>(list_array, &args[1], &args[2], arr_n)
}
array_type => exec_err!("array_replace does not support type '{array_type:?}'."),
}
}

pub fn array_replace_n(args: &[ArrayRef]) -> Result<ArrayRef> {
// replace the specified number of occurrences
let arr_n = as_int64_array(&args[3])?.values().to_vec();
general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n)
let array = &args[0];
match array.data_type() {
DataType::List(_) => {
let list_array = array.as_list::<i32>();
general_replace::<i32>(list_array, &args[1], &args[2], arr_n)
}
DataType::LargeList(_) => {
let list_array = array.as_list::<i64>();
general_replace::<i64>(list_array, &args[1], &args[2], arr_n)
}
array_type => {
exec_err!("array_replace_n does not support type '{array_type:?}'.")
}
}
}

pub fn array_replace_all(args: &[ArrayRef]) -> Result<ArrayRef> {
// replace all occurrences (up to "i64::MAX")
let arr_n = vec![i64::MAX; args[0].len()];
general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n)
let array = &args[0];
match array.data_type() {
DataType::List(_) => {
let list_array = array.as_list::<i32>();
general_replace::<i32>(list_array, &args[1], &args[2], arr_n)
}
DataType::LargeList(_) => {
let list_array = array.as_list::<i64>();
general_replace::<i64>(list_array, &args[1], &args[2], arr_n)
}
array_type => {
exec_err!("array_replace_all does not support type '{array_type:?}'.")
}
}
}

macro_rules! to_string {
Expand Down
Loading

0 comments on commit ef34af8

Please sign in to comment.