-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Add `pl.concat_arr` to concatenate columns into an Array column
#19881
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
use arrow::array::{ | ||
Array, ArrayCollectIterExt, BinaryArray, BinaryViewArray, BooleanArray, FixedSizeListArray, | ||
ListArray, NullArray, PrimitiveArray, StaticArray, StructArray, Utf8ViewArray, | ||
}; | ||
use arrow::bitmap::Bitmap; | ||
use arrow::datatypes::{ArrowDataType, PhysicalType}; | ||
use arrow::with_match_primitive_type_full; | ||
use strength_reduce::StrengthReducedUsize; | ||
mod struct_; | ||
|
||
/// Low-level operation used by `concat_arr`. This should be called with the inner values array of | ||
/// every FixedSizeList array. | ||
/// | ||
/// The physical type of `arrays[0]` selects the implementation: every input is downcast to the | ||
/// matching concrete array type and handed to the typed flatten routine (structs are delegated | ||
/// to the `struct_` submodule, null arrays are built directly from the output length). | ||
/// | ||
/// # Panics | ||
/// * Via `unimplemented!` when the physical type has no match arm below. | ||
/// * Via `unwrap` when an input cannot be downcast to the concrete type chosen from `arrays[0]`. | ||
/// | ||
/// # Safety | ||
/// * `arrays` is non-empty | ||
/// * `arrays` and `widths` have equal length | ||
/// * All widths in `widths` are non-zero | ||
/// * Every array `arrays[i]` has a length of either | ||
/// * `widths[i] * output_height` | ||
/// * `widths[i]` (this array is then broadcast across all output rows) | ||
/// * All arrays in `arrays` have the same type | ||
pub unsafe fn horizontal_flatten_unchecked( | ||
arrays: &[Box<dyn Array>], | ||
widths: &[usize], | ||
output_height: usize, | ||
) -> Box<dyn Array> { | ||
use PhysicalType::*; | ||
|
||
// All inputs are required to share one type, so the first array decides the dispatch. | ||
let dtype = arrays[0].dtype(); | ||
|
||
match dtype.to_physical_type() { | ||
// Only the output length is passed to `NullArray::new`; the inputs are not read. | ||
Null => Box::new(NullArray::new( | ||
dtype.clone(), | ||
output_height * widths.iter().copied().sum::<usize>(), | ||
)), | ||
Boolean => Box::new(horizontal_flatten_unchecked_impl_generic( | ||
&arrays | ||
.iter() | ||
.map(|x| x.as_any().downcast_ref::<BooleanArray>().unwrap().clone()) | ||
.collect::<Vec<_>>(), | ||
widths, | ||
output_height, | ||
dtype, | ||
)), | ||
// `$T` is supplied by the macro for the matched `primitive` type. | ||
Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { | ||
Box::new(horizontal_flatten_unchecked_impl_generic( | ||
&arrays | ||
.iter() | ||
.map(|x| x.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap().clone()) | ||
.collect::<Vec<_>>(), | ||
widths, | ||
output_height, | ||
dtype | ||
)) | ||
}), | ||
LargeBinary => Box::new(horizontal_flatten_unchecked_impl_generic( | ||
&arrays | ||
.iter() | ||
.map(|x| { | ||
x.as_any() | ||
.downcast_ref::<BinaryArray<i64>>() | ||
.unwrap() | ||
.clone() | ||
}) | ||
.collect::<Vec<_>>(), | ||
widths, | ||
output_height, | ||
dtype, | ||
)), | ||
// Structs use a dedicated routine instead of the generic element-wise implementation. | ||
Struct => Box::new(struct_::horizontal_flatten_unchecked( | ||
&arrays | ||
.iter() | ||
.map(|x| x.as_any().downcast_ref::<StructArray>().unwrap().clone()) | ||
.collect::<Vec<_>>(), | ||
widths, | ||
output_height, | ||
)), | ||
LargeList => Box::new(horizontal_flatten_unchecked_impl_generic( | ||
&arrays | ||
.iter() | ||
.map(|x| x.as_any().downcast_ref::<ListArray<i64>>().unwrap().clone()) | ||
.collect::<Vec<_>>(), | ||
widths, | ||
output_height, | ||
dtype, | ||
)), | ||
FixedSizeList => Box::new(horizontal_flatten_unchecked_impl_generic( | ||
&arrays | ||
.iter() | ||
.map(|x| { | ||
x.as_any() | ||
.downcast_ref::<FixedSizeListArray>() | ||
.unwrap() | ||
.clone() | ||
}) | ||
.collect::<Vec<_>>(), | ||
widths, | ||
output_height, | ||
dtype, | ||
)), | ||
BinaryView => Box::new(horizontal_flatten_unchecked_impl_generic( | ||
&arrays | ||
.iter() | ||
.map(|x| { | ||
x.as_any() | ||
.downcast_ref::<BinaryViewArray>() | ||
.unwrap() | ||
.clone() | ||
}) | ||
.collect::<Vec<_>>(), | ||
widths, | ||
output_height, | ||
dtype, | ||
)), | ||
Utf8View => Box::new(horizontal_flatten_unchecked_impl_generic( | ||
&arrays | ||
.iter() | ||
.map(|x| x.as_any().downcast_ref::<Utf8ViewArray>().unwrap().clone()) | ||
.collect::<Vec<_>>(), | ||
widths, | ||
output_height, | ||
dtype, | ||
)), | ||
// Every other physical type is unsupported and aborts with a descriptive panic. | ||
t => unimplemented!("horizontal_flatten not supported for data type {:?}", t), | ||
} | ||
} | ||
|
||
/// Typed implementation of the horizontal flatten for any `StaticArray`. | ||
/// | ||
/// Produces `output_height` rows. Each row is `widths[0]` consecutive elements of `arrays[0]`, | ||
/// followed by `widths[1]` consecutive elements of `arrays[1]`, and so on. Element indices are | ||
/// taken modulo the array length, so an array holding a single row (`widths[i]` elements) is | ||
/// repeated for every output row. | ||
/// | ||
/// # Panics | ||
/// * If `arrays` is empty or `arrays` and `widths` differ in length (checked in all builds). | ||
/// * If `sum(widths) * output_height` overflows `usize`. | ||
/// * NOTE(review): `StrengthReducedUsize::new` is documented to panic on a zero divisor, so a | ||
/// zero-length input array (e.g. when `output_height == 0`) presumably panics here - confirm. | ||
/// | ||
/// # Safety | ||
/// The width and length preconditions are only verified by `debug_assert!`, yet the element | ||
/// reads below are unchecked. The caller must guarantee that every width is non-zero and that | ||
/// `arrays[i].len()` is `widths[i] * output_height` or `widths[i]`. | ||
unsafe fn horizontal_flatten_unchecked_impl_generic<T>( | ||
arrays: &[T], | ||
widths: &[usize], | ||
output_height: usize, | ||
dtype: &ArrowDataType, | ||
) -> T | ||
where | ||
T: StaticArray, | ||
{ | ||
assert!(!arrays.is_empty()); | ||
assert_eq!(widths.len(), arrays.len()); | ||
|
||
debug_assert!(widths.iter().all(|x| *x > 0)); | ||
debug_assert!(arrays | ||
.iter() | ||
.zip(widths) | ||
.all(|(arr, width)| arr.len() == output_height * *width || arr.len() == *width)); | ||
|
||
// We modulo the array length to support broadcasting. | ||
// The divisors are precomputed once per array because the modulo runs for every element. | ||
let lengths = arrays | ||
.iter() | ||
.map(|x| StrengthReducedUsize::new(x.len())) | ||
.collect::<Vec<_>>(); | ||
let out_row_width: usize = widths.iter().cloned().sum(); | ||
let out_len = out_row_width.checked_mul(output_height).unwrap(); | ||
|
||
// Cursor state, advanced once per emitted element: | ||
// * `col_idx`: index of the input array currently being read | ||
// * `row_idx`: element index (before the modulo) inside that array | ||
// * `until`: exclusive end of the current run of elements inside that array | ||
// * `outer_row_idx`: index of the output row currently being produced | ||
let mut col_idx = 0; | ||
let mut row_idx = 0; | ||
let mut until = widths[0]; | ||
let mut outer_row_idx = 0; | ||
|
||
// We do `0..out_len` to get an `ExactSizeIterator`. | ||
(0..out_len) | ||
.map(|_| { | ||
let arr = arrays.get_unchecked(col_idx); | ||
let out = arr.get_unchecked(row_idx % *lengths.get_unchecked(col_idx)); | ||
|
||
row_idx += 1; | ||
|
||
// End of the current array's run for this output row: move on to the next array, | ||
// wrapping around to the first array (and the next output row) after the last one. | ||
if row_idx == until { | ||
// Safety: All widths are non-zero so we only need to increment once. | ||
col_idx = if 1 + col_idx == widths.len() { | ||
outer_row_idx += 1; | ||
0 | ||
} else { | ||
1 + col_idx | ||
}; | ||
// Position the cursor on the slice of the new array that belongs to this output row. | ||
row_idx = outer_row_idx * *widths.get_unchecked(col_idx); | ||
until = (1 + outer_row_idx) * *widths.get_unchecked(col_idx) | ||
} | ||
|
||
out | ||
}) | ||
.collect_arr_trusted_with_dtype(dtype.clone()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
use super::*; | ||
|
||
/// # Safety | ||
/// All preconditions in [`super::horizontal_flatten_unchecked`] | ||
pub(super) unsafe fn horizontal_flatten_unchecked( | ||
arrays: &[StructArray], | ||
widths: &[usize], | ||
output_height: usize, | ||
) -> StructArray { | ||
// For StructArrays, we perform the flatten operation individually for every field in the struct | ||
// as well as on the outer validity. We then construct the result array from the individual | ||
// result parts. | ||
|
||
let dtype = arrays[0].dtype(); | ||
|
||
let field_arrays: Vec<&[Box<dyn Array>]> = arrays | ||
.iter() | ||
.inspect(|x| debug_assert_eq!(x.dtype(), dtype)) | ||
.map(|x| x.values()) | ||
.collect::<Vec<_>>(); | ||
|
||
let n_fields = field_arrays[0].len(); | ||
|
||
let mut scratch = Vec::with_capacity(field_arrays.len()); | ||
// Safety: We can take by index as all struct arrays have the same columns names in the same | ||
// order. | ||
// Note: `field_arrays` can be empty for 0-field structs. | ||
let field_arrays = (0..n_fields) | ||
.map(|i| { | ||
scratch.clear(); | ||
scratch.extend(field_arrays.iter().map(|v| v[i].clone())); | ||
|
||
super::horizontal_flatten_unchecked(&scratch, widths, output_height) | ||
}) | ||
.collect::<Vec<_>>(); | ||
|
||
let validity = if arrays.iter().any(|x| x.validity().is_some()) { | ||
let max_height = output_height * widths.iter().fold(0usize, |a, b| a.max(*b)); | ||
let mut shared_validity = None; | ||
|
||
// We need to create BooleanArrays from the Bitmaps for dispatch. | ||
let validities: Vec<BooleanArray> = arrays | ||
.iter() | ||
.map(|x| { | ||
x.validity().cloned().unwrap_or_else(|| { | ||
if shared_validity.is_none() { | ||
shared_validity = Some(Bitmap::new_with_value(true, max_height)) | ||
}; | ||
// We have to slice to exact length to pass an assertion. | ||
shared_validity.clone().unwrap().sliced(0, x.len()) | ||
}) | ||
}) | ||
.map(|x| BooleanArray::from_inner_unchecked(ArrowDataType::Boolean, x, None)) | ||
.collect::<Vec<_>>(); | ||
|
||
Some( | ||
super::horizontal_flatten_unchecked_impl_generic::<BooleanArray>( | ||
&validities, | ||
widths, | ||
output_height, | ||
&ArrowDataType::Boolean, | ||
) | ||
.as_any() | ||
.downcast_ref::<BooleanArray>() | ||
.unwrap() | ||
.values() | ||
.clone(), | ||
) | ||
} else { | ||
None | ||
}; | ||
|
||
StructArray::new( | ||
dtype.clone(), | ||
if n_fields == 0 { | ||
output_height * widths.iter().copied().sum::<usize>() | ||
} else { | ||
debug_assert_eq!( | ||
field_arrays[0].len(), | ||
output_height * widths.iter().copied().sum::<usize>() | ||
); | ||
|
||
field_arrays[0].len() | ||
}, | ||
field_arrays, | ||
validity, | ||
) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
use arrow::array::FixedSizeListArray; | ||
use arrow::compute::utils::combine_validities_and; | ||
use polars_compute::horizontal_flatten::horizontal_flatten_unchecked; | ||
use polars_core::prelude::{ArrayChunked, Column, CompatLevel, DataType, IntoColumn}; | ||
use polars_core::series::Series; | ||
use polars_error::{polars_bail, PolarsResult}; | ||
use polars_utils::pl_str::PlSmallStr; | ||
|
||
/// Concatenates `args` horizontally into a single `Array` column of type `dtype`. | ||
/// | ||
/// Each column is either an `Array` column (contributing its own width) or a column of the inner | ||
/// dtype (contributing width 1). Columns of length 1 are broadcast to the output height. The | ||
/// result takes its name from `args[0]`. | ||
/// | ||
/// Note: The caller must ensure all columns in `args` have the same type. | ||
/// | ||
/// # Errors | ||
/// Returns `ShapeMismatch` if a column's length is neither the output height nor 1. | ||
/// | ||
/// # Panics | ||
/// Panics if | ||
/// * `args` is empty | ||
/// * `dtype` is not a `DataType::Array` | ||
/// * the widths contributed by `args` do not add up to the width of `dtype` | ||
pub fn concat_arr(args: &[Column], dtype: &DataType) -> PolarsResult<Column> { | ||
let DataType::Array(inner_dtype, width) = dtype else { | ||
panic!("{}", dtype); | ||
}; | ||
|
||
let inner_dtype = inner_dtype.as_ref(); | ||
let width = *width; | ||
|
||
// State updated while the columns are scanned below. | ||
let mut output_height = args[0].len(); | ||
let mut calculated_width = 0; | ||
// (name, len) of the first column whose length conflicts with `output_height`. While | ||
// `.1 == output_height`, no conflict has been recorded. | ||
let mut mismatch_height = (&PlSmallStr::EMPTY, output_height); | ||
// If there is a `Array` column with a single NULL, the output will be entirely NULL. | ||
let mut return_all_null = false; | ||
|
||
let (arrays, widths): (Vec<_>, Vec<_>) = args | ||
.iter() | ||
.map(|c| { | ||
// Handle broadcasting | ||
// A height of 1 is provisional: the next column's length replaces it. | ||
if output_height == 1 { | ||
output_height = c.len(); | ||
mismatch_height.1 = c.len(); | ||
} | ||
|
||
// Record only the first offending column. | ||
if c.len() != output_height && c.len() != 1 && mismatch_height.1 == output_height { | ||
mismatch_height = (c.name(), c.len()); | ||
} | ||
|
||
match c.dtype() { | ||
// Array column: pass on its inner values array together with its width. | ||
DataType::Array(inner, width) => { | ||
debug_assert_eq!(inner.as_ref(), inner_dtype); | ||
|
||
let arr = c.array().unwrap().rechunk(); | ||
|
||
return_all_null |= | ||
arr.len() == 1 && arr.rechunk_validity().map_or(false, |x| !x.get_bit(0)); | ||
|
||
(arr.rechunk().downcast_into_array().values().clone(), *width) | ||
}, | ||
// Column of the inner dtype: treated as an array of width 1. | ||
dtype => { | ||
debug_assert_eq!(dtype, inner_dtype); | ||
( | ||
c.as_materialized_series().rechunk().into_chunks()[0].clone(), | ||
1, | ||
) | ||
}, | ||
} | ||
}) | ||
// Zero-width arrays are dropped: `horizontal_flatten_unchecked` requires non-zero widths. | ||
.filter(|x| x.1 > 0) | ||
.inspect(|x| calculated_width += x.1) | ||
.unzip(); | ||
|
||
assert_eq!(calculated_width, width); | ||
|
||
// NOTE(review): the message labels `output_height` as the length of the first column, but | ||
// `output_height` is taken from a later column when the first one has length 1 - confirm | ||
// whether the wording should be adjusted. | ||
if mismatch_height.1 != output_height { | ||
polars_bail!( | ||
ShapeMismatch: | ||
"concat_arr: length of column '{}' (len={}) did not match length of \ | ||
first column '{}' (len={})", | ||
mismatch_height.0, mismatch_height.1, args[0].name(), output_height, | ||
) | ||
} | ||
|
||
if return_all_null { | ||
let arr = | ||
FixedSizeListArray::new_null(dtype.to_arrow(CompatLevel::newest()), output_height); | ||
return Ok(ArrayChunked::with_chunk(args[0].name().clone(), arr).into_column()); | ||
} | ||
|
||
// An output row is valid only if it is valid in every full-height `Array` input. | ||
let outer_validity = args | ||
.iter() | ||
// Note: We ignore the validity of non-array input columns, their outer is always valid after | ||
// being reshaped to (-1, 1). | ||
.filter(|x| { | ||
// Unit length validities at this point always contain a single valid, as we would have | ||
// returned earlier otherwise with `return_all_null`, so we filter them out. | ||
debug_assert!(x.len() == output_height || x.len() == 1); | ||
|
||
x.dtype().is_array() && x.len() == output_height | ||
}) | ||
.map(|x| x.as_materialized_series().rechunk_validity()) | ||
.fold(None, |a, b| combine_validities_and(a.as_ref(), b.as_ref())); | ||
|
||
// With no rows or zero width there is nothing to flatten; use an empty values array of the | ||
// inner dtype instead. | ||
let inner_arr = if output_height == 0 || width == 0 { | ||
Series::new_empty(PlSmallStr::EMPTY, inner_dtype) | ||
.into_chunks() | ||
.into_iter() | ||
.next() | ||
.unwrap() | ||
} else { | ||
// SAFETY: zero widths were filtered out, `width != 0` together with the width assertion | ||
// implies at least one array, and column lengths were checked against `output_height` | ||
// above. Equal dtypes are the caller's responsibility (see the note in the docs). | ||
unsafe { horizontal_flatten_unchecked(&arrays, &widths, output_height) } | ||
}; | ||
|
||
let arr = FixedSizeListArray::new( | ||
dtype.to_arrow(CompatLevel::newest()), | ||
output_height, | ||
inner_arr, | ||
outer_validity, | ||
); | ||
|
||
Ok(ArrayChunked::with_chunk(args[0].name().clone(), arr).into_column()) | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added `fn horizontal_flatten_unchecked` under `polars-compute`, which performs the flatten operation that creates the values array of the result. This is called from `fn concat_arr` under `polars-ops` after it checks the safety conditions.