Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add pl.concat_arr to concatenate columns into an Array column #19881

Merged
merged 1 commit into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 182 additions & 0 deletions crates/polars-compute/src/horizontal_flatten/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
use arrow::array::{
Array, ArrayCollectIterExt, BinaryArray, BinaryViewArray, BooleanArray, FixedSizeListArray,
ListArray, NullArray, PrimitiveArray, StaticArray, StructArray, Utf8ViewArray,
};
use arrow::bitmap::Bitmap;
use arrow::datatypes::{ArrowDataType, PhysicalType};
use arrow::with_match_primitive_type_full;
use strength_reduce::StrengthReducedUsize;
mod struct_;

/// Low-level operation used by `concat_arr`. This should be called with the inner values array of
/// every FixedSizeList array.
///
/// Performs the flatten operation that creates the values array of the result. It is called from
/// `fn concat_arr` under polars-ops after that function has checked the safety conditions.
///
/// The physical type of `arrays[0]` selects the implementation: `Null` arrays only need an output
/// length, struct arrays are handled by the `struct_` submodule, and every other supported type is
/// downcast to its concrete array type and handed to `horizontal_flatten_unchecked_impl_generic`.
///
/// # Panics
/// * If the physical type of `arrays[0]` is not one of the supported types
/// * If an array cannot be downcast to the concrete array type selected from `arrays[0]`
/// * If the output length overflows `usize` (for `Null` arrays)
///
/// # Safety
/// * `arrays` is non-empty
/// * `arrays` and `widths` have equal length
/// * All widths in `widths` are non-zero
/// * Every array `arrays[i]` has a length of either
///   * `widths[i] * output_height`
///   * `widths[i]` (this would be broadcasted)
/// * All arrays in `arrays` have the same type
pub unsafe fn horizontal_flatten_unchecked(
    arrays: &[Box<dyn Array>],
    widths: &[usize],
    output_height: usize,
) -> Box<dyn Array> {
    use PhysicalType::*;

    let dtype = arrays[0].dtype();

    match dtype.to_physical_type() {
        Null => {
            // Null arrays carry no values, so only the output length has to be computed. Use a
            // checked multiplication like the generic implementation does, so an overflow panics
            // instead of silently wrapping in release builds.
            let row_width: usize = widths.iter().copied().sum();
            Box::new(NullArray::new(
                dtype.clone(),
                row_width.checked_mul(output_height).unwrap(),
            ))
        },
        Boolean => Box::new(horizontal_flatten_unchecked_impl_generic(
            &downcast_all::<BooleanArray>(arrays),
            widths,
            output_height,
            dtype,
        )),
        Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| {
            Box::new(horizontal_flatten_unchecked_impl_generic(
                &downcast_all::<PrimitiveArray<$T>>(arrays),
                widths,
                output_height,
                dtype
            ))
        }),
        LargeBinary => Box::new(horizontal_flatten_unchecked_impl_generic(
            &downcast_all::<BinaryArray<i64>>(arrays),
            widths,
            output_height,
            dtype,
        )),
        // Structs are flattened field by field, see the `struct_` submodule.
        Struct => Box::new(struct_::horizontal_flatten_unchecked(
            &downcast_all::<StructArray>(arrays),
            widths,
            output_height,
        )),
        LargeList => Box::new(horizontal_flatten_unchecked_impl_generic(
            &downcast_all::<ListArray<i64>>(arrays),
            widths,
            output_height,
            dtype,
        )),
        FixedSizeList => Box::new(horizontal_flatten_unchecked_impl_generic(
            &downcast_all::<FixedSizeListArray>(arrays),
            widths,
            output_height,
            dtype,
        )),
        BinaryView => Box::new(horizontal_flatten_unchecked_impl_generic(
            &downcast_all::<BinaryViewArray>(arrays),
            widths,
            output_height,
            dtype,
        )),
        Utf8View => Box::new(horizontal_flatten_unchecked_impl_generic(
            &downcast_all::<Utf8ViewArray>(arrays),
            widths,
            output_height,
            dtype,
        )),
        t => unimplemented!("horizontal_flatten not supported for data type {:?}", t),
    }
}

/// Downcasts every array in `arrays` to the concrete array type `T` and clones it.
///
/// # Panics
/// Panics if any array in `arrays` is not a `T`.
fn downcast_all<T: Array + Clone>(arrays: &[Box<dyn Array>]) -> Vec<T> {
    arrays
        .iter()
        // The caller of `horizontal_flatten_unchecked` guarantees that all arrays share one type,
        // so a failing downcast here is a violated precondition.
        .map(|arr| arr.as_any().downcast_ref::<T>().unwrap().clone())
        .collect()
}

unsafe fn horizontal_flatten_unchecked_impl_generic<T>(
arrays: &[T],
widths: &[usize],
output_height: usize,
dtype: &ArrowDataType,
) -> T
where
T: StaticArray,
{
assert!(!arrays.is_empty());
assert_eq!(widths.len(), arrays.len());

debug_assert!(widths.iter().all(|x| *x > 0));
debug_assert!(arrays
.iter()
.zip(widths)
.all(|(arr, width)| arr.len() == output_height * *width || arr.len() == *width));

// We modulo the array length to support broadcasting.
let lengths = arrays
.iter()
.map(|x| StrengthReducedUsize::new(x.len()))
.collect::<Vec<_>>();
let out_row_width: usize = widths.iter().cloned().sum();
let out_len = out_row_width.checked_mul(output_height).unwrap();

let mut col_idx = 0;
let mut row_idx = 0;
let mut until = widths[0];
let mut outer_row_idx = 0;

// We do `0..out_len` to get an `ExactSizeIterator`.
(0..out_len)
.map(|_| {
let arr = arrays.get_unchecked(col_idx);
let out = arr.get_unchecked(row_idx % *lengths.get_unchecked(col_idx));

row_idx += 1;

if row_idx == until {
// Safety: All widths are non-zero so we only need to increment once.
col_idx = if 1 + col_idx == widths.len() {
outer_row_idx += 1;
0
} else {
1 + col_idx
};
row_idx = outer_row_idx * *widths.get_unchecked(col_idx);
until = (1 + outer_row_idx) * *widths.get_unchecked(col_idx)
}

out
})
.collect_arr_trusted_with_dtype(dtype.clone())
}
88 changes: 88 additions & 0 deletions crates/polars-compute/src/horizontal_flatten/struct_.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
use super::*;

/// # Safety
/// All preconditions in [`super::horizontal_flatten_unchecked`]
pub(super) unsafe fn horizontal_flatten_unchecked(
arrays: &[StructArray],
widths: &[usize],
output_height: usize,
) -> StructArray {
// For StructArrays, we perform the flatten operation individually for every field in the struct
// as well as on the outer validity. We then construct the result array from the individual
// result parts.

let dtype = arrays[0].dtype();

let field_arrays: Vec<&[Box<dyn Array>]> = arrays
.iter()
.inspect(|x| debug_assert_eq!(x.dtype(), dtype))
.map(|x| x.values())
.collect::<Vec<_>>();

let n_fields = field_arrays[0].len();

let mut scratch = Vec::with_capacity(field_arrays.len());
// Safety: We can take by index as all struct arrays have the same columns names in the same
// order.
// Note: `field_arrays` can be empty for 0-field structs.
let field_arrays = (0..n_fields)
.map(|i| {
scratch.clear();
scratch.extend(field_arrays.iter().map(|v| v[i].clone()));

super::horizontal_flatten_unchecked(&scratch, widths, output_height)
})
.collect::<Vec<_>>();

let validity = if arrays.iter().any(|x| x.validity().is_some()) {
let max_height = output_height * widths.iter().fold(0usize, |a, b| a.max(*b));
let mut shared_validity = None;

// We need to create BooleanArrays from the Bitmaps for dispatch.
let validities: Vec<BooleanArray> = arrays
.iter()
.map(|x| {
x.validity().cloned().unwrap_or_else(|| {
if shared_validity.is_none() {
shared_validity = Some(Bitmap::new_with_value(true, max_height))
};
// We have to slice to exact length to pass an assertion.
shared_validity.clone().unwrap().sliced(0, x.len())
})
})
.map(|x| BooleanArray::from_inner_unchecked(ArrowDataType::Boolean, x, None))
.collect::<Vec<_>>();

Some(
super::horizontal_flatten_unchecked_impl_generic::<BooleanArray>(
&validities,
widths,
output_height,
&ArrowDataType::Boolean,
)
.as_any()
.downcast_ref::<BooleanArray>()
.unwrap()
.values()
.clone(),
)
} else {
None
};

StructArray::new(
dtype.clone(),
if n_fields == 0 {
output_height * widths.iter().copied().sum::<usize>()
} else {
debug_assert_eq!(
field_arrays[0].len(),
output_height * widths.iter().copied().sum::<usize>()
);

field_arrays[0].len()
},
field_arrays,
validity,
)
}
1 change: 1 addition & 0 deletions crates/polars-compute/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub mod cardinality;
pub mod comparisons;
pub mod filter;
pub mod float_sum;
pub mod horizontal_flatten;
#[cfg(feature = "approx_unique")]
pub mod hyperloglogplus;
pub mod if_then_else;
Expand Down
115 changes: 115 additions & 0 deletions crates/polars-ops/src/series/ops/concat_arr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
use arrow::array::FixedSizeListArray;
use arrow::compute::utils::combine_validities_and;
use polars_compute::horizontal_flatten::horizontal_flatten_unchecked;
use polars_core::prelude::{ArrayChunked, Column, CompatLevel, DataType, IntoColumn};
use polars_core::series::Series;
use polars_error::{polars_bail, PolarsResult};
use polars_utils::pl_str::PlSmallStr;

/// Concatenates the columns in `args` horizontally into a single `Array` column of type `dtype`.
///
/// `Array` columns contribute their own width to every output row and every other column
/// contributes a width of 1. Columns of length 1 are broadcast to the output height. The result
/// is named after the first column.
///
/// Note: The caller must ensure all columns in `args` have the same type.
///
/// # Errors
/// Returns a `ShapeMismatch` error if a column has a length that is neither 1 nor the output
/// height.
///
/// # Panics
/// Panics if
/// * `args` is empty
/// * `dtype` is not a `DataType::Array`
/// * The widths of the columns in `args` do not sum to the width of `dtype`
pub fn concat_arr(args: &[Column], dtype: &DataType) -> PolarsResult<Column> {
    let DataType::Array(inner_dtype, width) = dtype else {
        panic!("{}", dtype);
    };

    let inner_dtype = inner_dtype.as_ref();
    let width = *width;

    let mut output_height = args[0].len();
    let mut calculated_width = 0;
    // Name and length of the first column found to disagree with `output_height`. The length
    // equals `output_height` for as long as no mismatch has been found.
    let mut mismatch_height = (&PlSmallStr::EMPTY, output_height);
    // If there is a unit-length `Array` column holding a single NULL, the output will be entirely
    // NULL.
    let mut return_all_null = false;

    let (arrays, widths): (Vec<_>, Vec<_>) = args
        .iter()
        .map(|c| {
            // Handle broadcasting: as long as only unit-length columns have been seen, the next
            // column decides the output height.
            if output_height == 1 {
                output_height = c.len();
                mismatch_height.1 = c.len();
            }

            // Only the first mismatch is recorded.
            if c.len() != output_height && c.len() != 1 && mismatch_height.1 == output_height {
                mismatch_height = (c.name(), c.len());
            }

            match c.dtype() {
                DataType::Array(inner, width) => {
                    debug_assert_eq!(inner.as_ref(), inner_dtype);

                    let arr = c.array().unwrap().rechunk();

                    return_all_null |=
                        arr.len() == 1 && arr.rechunk_validity().map_or(false, |x| !x.get_bit(0));

                    // `arr` was rechunked above, so it can be downcast without rechunking again.
                    (arr.downcast_into_array().values().clone(), *width)
                },
                dtype => {
                    debug_assert_eq!(dtype, inner_dtype);
                    (
                        c.as_materialized_series().rechunk().into_chunks()[0].clone(),
                        1,
                    )
                },
            }
        })
        // Zero-width arrays contribute nothing to the output.
        .filter(|x| x.1 > 0)
        .inspect(|x| calculated_width += x.1)
        .unzip();

    assert_eq!(calculated_width, width);

    if mismatch_height.1 != output_height {
        polars_bail!(
            ShapeMismatch:
            "concat_arr: length of column '{}' (len={}) did not match length of \
            first column '{}' (len={})",
            mismatch_height.0, mismatch_height.1, args[0].name(), output_height,
        )
    }

    if return_all_null {
        let arr =
            FixedSizeListArray::new_null(dtype.to_arrow(CompatLevel::newest()), output_height);
        return Ok(ArrayChunked::with_chunk(args[0].name().clone(), arr).into_column());
    }

    let outer_validity = args
        .iter()
        // Note: We ignore the validity of non-array input columns, their outer is always valid after
        // being reshaped to (-1, 1).
        .filter(|x| {
            // Unit length validities at this point always contain a single valid, as we would have
            // returned earlier otherwise with `return_all_null`, so we filter them out.
            debug_assert!(x.len() == output_height || x.len() == 1);

            x.dtype().is_array() && x.len() == output_height
        })
        .map(|x| x.as_materialized_series().rechunk_validity())
        .fold(None, |a, b| combine_validities_and(a.as_ref(), b.as_ref()));

    // With no rows or no width there is nothing to flatten (and `arrays` may be empty), so an
    // empty array of the inner type is used for the values instead.
    let inner_arr = if output_height == 0 || width == 0 {
        Series::new_empty(PlSmallStr::EMPTY, inner_dtype)
            .into_chunks()
            .into_iter()
            .next()
            .unwrap()
    } else {
        // SAFETY: `arrays` is non-empty because `width > 0`, zero-width arrays were filtered out,
        // `arrays` and `widths` were unzipped from the same pairs, and the column lengths were
        // checked against `output_height` above. The caller guarantees that all columns have the
        // same type.
        unsafe { horizontal_flatten_unchecked(&arrays, &widths, output_height) }
    };

    let arr = FixedSizeListArray::new(
        dtype.to_arrow(CompatLevel::newest()),
        output_height,
        inner_arr,
        outer_validity,
    );

    Ok(ArrayChunked::with_chunk(args[0].name().clone(), arr).into_column())
}
Loading