-
-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add
pl.concat_arr
to concatenate columns into an Array column (…
- Loading branch information
1 parent
426cf27
commit cf6f375
Showing
15 changed files
with
769 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
use arrow::array::{ | ||
Array, ArrayCollectIterExt, BinaryArray, BinaryViewArray, BooleanArray, FixedSizeListArray, | ||
ListArray, NullArray, PrimitiveArray, StaticArray, StructArray, Utf8ViewArray, | ||
}; | ||
use arrow::bitmap::Bitmap; | ||
use arrow::datatypes::{ArrowDataType, PhysicalType}; | ||
use arrow::with_match_primitive_type_full; | ||
use strength_reduce::StrengthReducedUsize; | ||
mod struct_; | ||
|
||
/// Low-level operation used by `concat_arr`. This should be called with the inner values array of | ||
/// every FixedSizeList array. | ||
/// | ||
/// # Safety | ||
/// * `arrays` is non-empty | ||
/// * `arrays` and `widths` have equal length | ||
/// * All widths in `widths` are non-zero | ||
/// * Every array `arrays[i]` has a length of either | ||
/// * `widths[i] * output_height` | ||
/// * `widths[i]` (this would be broadcasted) | ||
/// * All arrays in `arrays` have the same type | ||
pub unsafe fn horizontal_flatten_unchecked( | ||
arrays: &[Box<dyn Array>], | ||
widths: &[usize], | ||
output_height: usize, | ||
) -> Box<dyn Array> { | ||
use PhysicalType::*; | ||
|
||
let dtype = arrays[0].dtype(); | ||
|
||
match dtype.to_physical_type() { | ||
Null => Box::new(NullArray::new( | ||
dtype.clone(), | ||
output_height * widths.iter().copied().sum::<usize>(), | ||
)), | ||
Boolean => Box::new(horizontal_flatten_unchecked_impl_generic( | ||
&arrays | ||
.iter() | ||
.map(|x| x.as_any().downcast_ref::<BooleanArray>().unwrap().clone()) | ||
.collect::<Vec<_>>(), | ||
widths, | ||
output_height, | ||
dtype, | ||
)), | ||
Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { | ||
Box::new(horizontal_flatten_unchecked_impl_generic( | ||
&arrays | ||
.iter() | ||
.map(|x| x.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap().clone()) | ||
.collect::<Vec<_>>(), | ||
widths, | ||
output_height, | ||
dtype | ||
)) | ||
}), | ||
LargeBinary => Box::new(horizontal_flatten_unchecked_impl_generic( | ||
&arrays | ||
.iter() | ||
.map(|x| { | ||
x.as_any() | ||
.downcast_ref::<BinaryArray<i64>>() | ||
.unwrap() | ||
.clone() | ||
}) | ||
.collect::<Vec<_>>(), | ||
widths, | ||
output_height, | ||
dtype, | ||
)), | ||
Struct => Box::new(struct_::horizontal_flatten_unchecked( | ||
&arrays | ||
.iter() | ||
.map(|x| x.as_any().downcast_ref::<StructArray>().unwrap().clone()) | ||
.collect::<Vec<_>>(), | ||
widths, | ||
output_height, | ||
)), | ||
LargeList => Box::new(horizontal_flatten_unchecked_impl_generic( | ||
&arrays | ||
.iter() | ||
.map(|x| x.as_any().downcast_ref::<ListArray<i64>>().unwrap().clone()) | ||
.collect::<Vec<_>>(), | ||
widths, | ||
output_height, | ||
dtype, | ||
)), | ||
FixedSizeList => Box::new(horizontal_flatten_unchecked_impl_generic( | ||
&arrays | ||
.iter() | ||
.map(|x| { | ||
x.as_any() | ||
.downcast_ref::<FixedSizeListArray>() | ||
.unwrap() | ||
.clone() | ||
}) | ||
.collect::<Vec<_>>(), | ||
widths, | ||
output_height, | ||
dtype, | ||
)), | ||
BinaryView => Box::new(horizontal_flatten_unchecked_impl_generic( | ||
&arrays | ||
.iter() | ||
.map(|x| { | ||
x.as_any() | ||
.downcast_ref::<BinaryViewArray>() | ||
.unwrap() | ||
.clone() | ||
}) | ||
.collect::<Vec<_>>(), | ||
widths, | ||
output_height, | ||
dtype, | ||
)), | ||
Utf8View => Box::new(horizontal_flatten_unchecked_impl_generic( | ||
&arrays | ||
.iter() | ||
.map(|x| x.as_any().downcast_ref::<Utf8ViewArray>().unwrap().clone()) | ||
.collect::<Vec<_>>(), | ||
widths, | ||
output_height, | ||
dtype, | ||
)), | ||
t => unimplemented!("horizontal_flatten not supported for data type {:?}", t), | ||
} | ||
} | ||
|
||
unsafe fn horizontal_flatten_unchecked_impl_generic<T>( | ||
arrays: &[T], | ||
widths: &[usize], | ||
output_height: usize, | ||
dtype: &ArrowDataType, | ||
) -> T | ||
where | ||
T: StaticArray, | ||
{ | ||
assert!(!arrays.is_empty()); | ||
assert_eq!(widths.len(), arrays.len()); | ||
|
||
debug_assert!(widths.iter().all(|x| *x > 0)); | ||
debug_assert!(arrays | ||
.iter() | ||
.zip(widths) | ||
.all(|(arr, width)| arr.len() == output_height * *width || arr.len() == *width)); | ||
|
||
// We modulo the array length to support broadcasting. | ||
let lengths = arrays | ||
.iter() | ||
.map(|x| StrengthReducedUsize::new(x.len())) | ||
.collect::<Vec<_>>(); | ||
let out_row_width: usize = widths.iter().cloned().sum(); | ||
let out_len = out_row_width.checked_mul(output_height).unwrap(); | ||
|
||
let mut col_idx = 0; | ||
let mut row_idx = 0; | ||
let mut until = widths[0]; | ||
let mut outer_row_idx = 0; | ||
|
||
// We do `0..out_len` to get an `ExactSizeIterator`. | ||
(0..out_len) | ||
.map(|_| { | ||
let arr = arrays.get_unchecked(col_idx); | ||
let out = arr.get_unchecked(row_idx % *lengths.get_unchecked(col_idx)); | ||
|
||
row_idx += 1; | ||
|
||
if row_idx == until { | ||
// Safety: All widths are non-zero so we only need to increment once. | ||
col_idx = if 1 + col_idx == widths.len() { | ||
outer_row_idx += 1; | ||
0 | ||
} else { | ||
1 + col_idx | ||
}; | ||
row_idx = outer_row_idx * *widths.get_unchecked(col_idx); | ||
until = (1 + outer_row_idx) * *widths.get_unchecked(col_idx) | ||
} | ||
|
||
out | ||
}) | ||
.collect_arr_trusted_with_dtype(dtype.clone()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
use super::*; | ||
|
||
/// # Safety | ||
/// All preconditions in [`super::horizontal_flatten_unchecked`] | ||
pub(super) unsafe fn horizontal_flatten_unchecked( | ||
arrays: &[StructArray], | ||
widths: &[usize], | ||
output_height: usize, | ||
) -> StructArray { | ||
// For StructArrays, we perform the flatten operation individually for every field in the struct | ||
// as well as on the outer validity. We then construct the result array from the individual | ||
// result parts. | ||
|
||
let dtype = arrays[0].dtype(); | ||
|
||
let field_arrays: Vec<&[Box<dyn Array>]> = arrays | ||
.iter() | ||
.inspect(|x| debug_assert_eq!(x.dtype(), dtype)) | ||
.map(|x| x.values()) | ||
.collect::<Vec<_>>(); | ||
|
||
let n_fields = field_arrays[0].len(); | ||
|
||
let mut scratch = Vec::with_capacity(field_arrays.len()); | ||
// Safety: We can take by index as all struct arrays have the same columns names in the same | ||
// order. | ||
// Note: `field_arrays` can be empty for 0-field structs. | ||
let field_arrays = (0..n_fields) | ||
.map(|i| { | ||
scratch.clear(); | ||
scratch.extend(field_arrays.iter().map(|v| v[i].clone())); | ||
|
||
super::horizontal_flatten_unchecked(&scratch, widths, output_height) | ||
}) | ||
.collect::<Vec<_>>(); | ||
|
||
let validity = if arrays.iter().any(|x| x.validity().is_some()) { | ||
let max_height = output_height * widths.iter().fold(0usize, |a, b| a.max(*b)); | ||
let mut shared_validity = None; | ||
|
||
// We need to create BooleanArrays from the Bitmaps for dispatch. | ||
let validities: Vec<BooleanArray> = arrays | ||
.iter() | ||
.map(|x| { | ||
x.validity().cloned().unwrap_or_else(|| { | ||
if shared_validity.is_none() { | ||
shared_validity = Some(Bitmap::new_with_value(true, max_height)) | ||
}; | ||
// We have to slice to exact length to pass an assertion. | ||
shared_validity.clone().unwrap().sliced(0, x.len()) | ||
}) | ||
}) | ||
.map(|x| BooleanArray::from_inner_unchecked(ArrowDataType::Boolean, x, None)) | ||
.collect::<Vec<_>>(); | ||
|
||
Some( | ||
super::horizontal_flatten_unchecked_impl_generic::<BooleanArray>( | ||
&validities, | ||
widths, | ||
output_height, | ||
&ArrowDataType::Boolean, | ||
) | ||
.as_any() | ||
.downcast_ref::<BooleanArray>() | ||
.unwrap() | ||
.values() | ||
.clone(), | ||
) | ||
} else { | ||
None | ||
}; | ||
|
||
StructArray::new( | ||
dtype.clone(), | ||
if n_fields == 0 { | ||
output_height * widths.iter().copied().sum::<usize>() | ||
} else { | ||
debug_assert_eq!( | ||
field_arrays[0].len(), | ||
output_height * widths.iter().copied().sum::<usize>() | ||
); | ||
|
||
field_arrays[0].len() | ||
}, | ||
field_arrays, | ||
validity, | ||
) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
use arrow::array::FixedSizeListArray; | ||
use arrow::compute::utils::combine_validities_and; | ||
use polars_compute::horizontal_flatten::horizontal_flatten_unchecked; | ||
use polars_core::prelude::{ArrayChunked, Column, CompatLevel, DataType, IntoColumn}; | ||
use polars_core::series::Series; | ||
use polars_error::{polars_bail, PolarsResult}; | ||
use polars_utils::pl_str::PlSmallStr; | ||
|
||
/// Note: The caller must ensure all columns in `args` have the same type. | ||
/// | ||
/// # Panics | ||
/// Panics if | ||
/// * `args` is empty | ||
/// * `dtype` is not a `DataType::Array` | ||
pub fn concat_arr(args: &[Column], dtype: &DataType) -> PolarsResult<Column> { | ||
let DataType::Array(inner_dtype, width) = dtype else { | ||
panic!("{}", dtype); | ||
}; | ||
|
||
let inner_dtype = inner_dtype.as_ref(); | ||
let width = *width; | ||
|
||
let mut output_height = args[0].len(); | ||
let mut calculated_width = 0; | ||
let mut mismatch_height = (&PlSmallStr::EMPTY, output_height); | ||
// If there is a `Array` column with a single NULL, the output will be entirely NULL. | ||
let mut return_all_null = false; | ||
|
||
let (arrays, widths): (Vec<_>, Vec<_>) = args | ||
.iter() | ||
.map(|c| { | ||
// Handle broadcasting | ||
if output_height == 1 { | ||
output_height = c.len(); | ||
mismatch_height.1 = c.len(); | ||
} | ||
|
||
if c.len() != output_height && c.len() != 1 && mismatch_height.1 == output_height { | ||
mismatch_height = (c.name(), c.len()); | ||
} | ||
|
||
match c.dtype() { | ||
DataType::Array(inner, width) => { | ||
debug_assert_eq!(inner.as_ref(), inner_dtype); | ||
|
||
let arr = c.array().unwrap().rechunk(); | ||
|
||
return_all_null |= | ||
arr.len() == 1 && arr.rechunk_validity().map_or(false, |x| !x.get_bit(0)); | ||
|
||
(arr.rechunk().downcast_into_array().values().clone(), *width) | ||
}, | ||
dtype => { | ||
debug_assert_eq!(dtype, inner_dtype); | ||
( | ||
c.as_materialized_series().rechunk().into_chunks()[0].clone(), | ||
1, | ||
) | ||
}, | ||
} | ||
}) | ||
.filter(|x| x.1 > 0) | ||
.inspect(|x| calculated_width += x.1) | ||
.unzip(); | ||
|
||
assert_eq!(calculated_width, width); | ||
|
||
if mismatch_height.1 != output_height { | ||
polars_bail!( | ||
ShapeMismatch: | ||
"concat_arr: length of column '{}' (len={}) did not match length of \ | ||
first column '{}' (len={})", | ||
mismatch_height.0, mismatch_height.1, args[0].name(), output_height, | ||
) | ||
} | ||
|
||
if return_all_null { | ||
let arr = | ||
FixedSizeListArray::new_null(dtype.to_arrow(CompatLevel::newest()), output_height); | ||
return Ok(ArrayChunked::with_chunk(args[0].name().clone(), arr).into_column()); | ||
} | ||
|
||
let outer_validity = args | ||
.iter() | ||
// Note: We ignore the validity of non-array input columns, their outer is always valid after | ||
// being reshaped to (-1, 1). | ||
.filter(|x| { | ||
// Unit length validities at this point always contain a single valid, as we would have | ||
// returned earlier otherwise with `return_all_null`, so we filter them out. | ||
debug_assert!(x.len() == output_height || x.len() == 1); | ||
|
||
x.dtype().is_array() && x.len() == output_height | ||
}) | ||
.map(|x| x.as_materialized_series().rechunk_validity()) | ||
.fold(None, |a, b| combine_validities_and(a.as_ref(), b.as_ref())); | ||
|
||
let inner_arr = if output_height == 0 || width == 0 { | ||
Series::new_empty(PlSmallStr::EMPTY, inner_dtype) | ||
.into_chunks() | ||
.into_iter() | ||
.next() | ||
.unwrap() | ||
} else { | ||
unsafe { horizontal_flatten_unchecked(&arrays, &widths, output_height) } | ||
}; | ||
|
||
let arr = FixedSizeListArray::new( | ||
dtype.to_arrow(CompatLevel::newest()), | ||
output_height, | ||
inner_arr, | ||
outer_validity, | ||
); | ||
|
||
Ok(ArrayChunked::with_chunk(args[0].name().clone(), arr).into_column()) | ||
} |
Oops, something went wrong.