feat: Add pl.concat_arr to concatenate columns into an Array column (
nameexhaustion authored Nov 20, 2024
1 parent 426cf27 commit cf6f375
Showing 15 changed files with 769 additions and 6 deletions.
182 changes: 182 additions & 0 deletions crates/polars-compute/src/horizontal_flatten/mod.rs
@@ -0,0 +1,182 @@
use arrow::array::{
Array, ArrayCollectIterExt, BinaryArray, BinaryViewArray, BooleanArray, FixedSizeListArray,
ListArray, NullArray, PrimitiveArray, StaticArray, StructArray, Utf8ViewArray,
};
use arrow::bitmap::Bitmap;
use arrow::datatypes::{ArrowDataType, PhysicalType};
use arrow::with_match_primitive_type_full;
use strength_reduce::StrengthReducedUsize;
mod struct_;

/// Low-level operation used by `concat_arr`. This should be called with the inner values array of
/// every FixedSizeList array.
///
/// # Safety
/// * `arrays` is non-empty
/// * `arrays` and `widths` have equal length
/// * All widths in `widths` are non-zero
/// * Every array `arrays[i]` has a length of either
///   * `widths[i] * output_height`
///   * `widths[i]` (in which case it is broadcast)
/// * All arrays in `arrays` have the same type
pub unsafe fn horizontal_flatten_unchecked(
arrays: &[Box<dyn Array>],
widths: &[usize],
output_height: usize,
) -> Box<dyn Array> {
use PhysicalType::*;

let dtype = arrays[0].dtype();

match dtype.to_physical_type() {
Null => Box::new(NullArray::new(
dtype.clone(),
output_height * widths.iter().copied().sum::<usize>(),
)),
Boolean => Box::new(horizontal_flatten_unchecked_impl_generic(
&arrays
.iter()
.map(|x| x.as_any().downcast_ref::<BooleanArray>().unwrap().clone())
.collect::<Vec<_>>(),
widths,
output_height,
dtype,
)),
Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| {
Box::new(horizontal_flatten_unchecked_impl_generic(
&arrays
.iter()
.map(|x| x.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap().clone())
.collect::<Vec<_>>(),
widths,
output_height,
dtype
))
}),
LargeBinary => Box::new(horizontal_flatten_unchecked_impl_generic(
&arrays
.iter()
.map(|x| {
x.as_any()
.downcast_ref::<BinaryArray<i64>>()
.unwrap()
.clone()
})
.collect::<Vec<_>>(),
widths,
output_height,
dtype,
)),
Struct => Box::new(struct_::horizontal_flatten_unchecked(
&arrays
.iter()
.map(|x| x.as_any().downcast_ref::<StructArray>().unwrap().clone())
.collect::<Vec<_>>(),
widths,
output_height,
)),
LargeList => Box::new(horizontal_flatten_unchecked_impl_generic(
&arrays
.iter()
.map(|x| x.as_any().downcast_ref::<ListArray<i64>>().unwrap().clone())
.collect::<Vec<_>>(),
widths,
output_height,
dtype,
)),
FixedSizeList => Box::new(horizontal_flatten_unchecked_impl_generic(
&arrays
.iter()
.map(|x| {
x.as_any()
.downcast_ref::<FixedSizeListArray>()
.unwrap()
.clone()
})
.collect::<Vec<_>>(),
widths,
output_height,
dtype,
)),
BinaryView => Box::new(horizontal_flatten_unchecked_impl_generic(
&arrays
.iter()
.map(|x| {
x.as_any()
.downcast_ref::<BinaryViewArray>()
.unwrap()
.clone()
})
.collect::<Vec<_>>(),
widths,
output_height,
dtype,
)),
Utf8View => Box::new(horizontal_flatten_unchecked_impl_generic(
&arrays
.iter()
.map(|x| x.as_any().downcast_ref::<Utf8ViewArray>().unwrap().clone())
.collect::<Vec<_>>(),
widths,
output_height,
dtype,
)),
t => unimplemented!("horizontal_flatten not supported for data type {:?}", t),
}
}

unsafe fn horizontal_flatten_unchecked_impl_generic<T>(
arrays: &[T],
widths: &[usize],
output_height: usize,
dtype: &ArrowDataType,
) -> T
where
T: StaticArray,
{
assert!(!arrays.is_empty());
assert_eq!(widths.len(), arrays.len());

debug_assert!(widths.iter().all(|x| *x > 0));
debug_assert!(arrays
.iter()
.zip(widths)
.all(|(arr, width)| arr.len() == output_height * *width || arr.len() == *width));

// We take the row index modulo the array length to support broadcasting.
let lengths = arrays
.iter()
.map(|x| StrengthReducedUsize::new(x.len()))
.collect::<Vec<_>>();
let out_row_width: usize = widths.iter().cloned().sum();
let out_len = out_row_width.checked_mul(output_height).unwrap();

let mut col_idx = 0;
let mut row_idx = 0;
let mut until = widths[0];
let mut outer_row_idx = 0;

// We do `0..out_len` to get an `ExactSizeIterator`.
(0..out_len)
.map(|_| {
let arr = arrays.get_unchecked(col_idx);
let out = arr.get_unchecked(row_idx % *lengths.get_unchecked(col_idx));

row_idx += 1;

if row_idx == until {
// Safety: All widths are non-zero so we only need to increment once.
col_idx = if 1 + col_idx == widths.len() {
outer_row_idx += 1;
0
} else {
1 + col_idx
};
row_idx = outer_row_idx * *widths.get_unchecked(col_idx);
until = (1 + outer_row_idx) * *widths.get_unchecked(col_idx)
}

out
})
.collect_arr_trusted_with_dtype(dtype.clone())
}
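
The index walk above is the core of the kernel: for each output row it emits `widths[0]` values from the first array, then `widths[1]` values from the second, and so on, with the modulo on the array length repeating broadcast (single-row) inputs. Below is a minimal standalone sketch of the same walk over plain `Vec<i64>` with hypothetical values; the `flatten_rows` helper is illustrative and not part of the commit.

fn flatten_rows(arrays: &[Vec<i64>], widths: &[usize], output_height: usize) -> Vec<i64> {
    let out_len: usize = widths.iter().sum::<usize>() * output_height;
    let (mut col_idx, mut row_idx, mut outer_row_idx) = (0usize, 0usize, 0usize);
    let mut until = widths[0];
    (0..out_len)
        .map(|_| {
            // Index modulo the input length so a broadcast (single-row) input repeats.
            let out = arrays[col_idx][row_idx % arrays[col_idx].len()];
            row_idx += 1;
            if row_idx == until {
                // Advance to the next column, or wrap around to the next output row.
                col_idx = if 1 + col_idx == widths.len() {
                    outer_row_idx += 1;
                    0
                } else {
                    1 + col_idx
                };
                row_idx = outer_row_idx * widths[col_idx];
                until = (1 + outer_row_idx) * widths[col_idx];
            }
            out
        })
        .collect()
}

fn main() {
    // Column A: 2 rows of width 2; column B: a single broadcast row of width 1.
    let out = flatten_rows(&[vec![1, 2, 3, 4], vec![10]], &[2, 1], 2);
    assert_eq!(out, vec![1, 2, 10, 3, 4, 10]); // row 0 -> [1, 2, 10], row 1 -> [3, 4, 10]
}
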
88 changes: 88 additions & 0 deletions crates/polars-compute/src/horizontal_flatten/struct_.rs
@@ -0,0 +1,88 @@
use super::*;

/// # Safety
/// All preconditions in [`super::horizontal_flatten_unchecked`]
pub(super) unsafe fn horizontal_flatten_unchecked(
arrays: &[StructArray],
widths: &[usize],
output_height: usize,
) -> StructArray {
// For StructArrays, we perform the flatten operation individually for every field in the struct
// as well as on the outer validity. We then construct the result array from the individual
// result parts.

let dtype = arrays[0].dtype();

let field_arrays: Vec<&[Box<dyn Array>]> = arrays
.iter()
.inspect(|x| debug_assert_eq!(x.dtype(), dtype))
.map(|x| x.values())
.collect::<Vec<_>>();

let n_fields = field_arrays[0].len();

let mut scratch = Vec::with_capacity(field_arrays.len());
// Safety: We can take by index as all struct arrays have the same column names in the same
// order.
// Note: `field_arrays` can be empty for 0-field structs.
let field_arrays = (0..n_fields)
.map(|i| {
scratch.clear();
scratch.extend(field_arrays.iter().map(|v| v[i].clone()));

super::horizontal_flatten_unchecked(&scratch, widths, output_height)
})
.collect::<Vec<_>>();

let validity = if arrays.iter().any(|x| x.validity().is_some()) {
let max_height = output_height * widths.iter().fold(0usize, |a, b| a.max(*b));
let mut shared_validity = None;

// We need to create BooleanArrays from the Bitmaps for dispatch.
let validities: Vec<BooleanArray> = arrays
.iter()
.map(|x| {
x.validity().cloned().unwrap_or_else(|| {
if shared_validity.is_none() {
shared_validity = Some(Bitmap::new_with_value(true, max_height))
};
// We have to slice to the exact length to pass an assertion.
shared_validity.clone().unwrap().sliced(0, x.len())
})
})
.map(|x| BooleanArray::from_inner_unchecked(ArrowDataType::Boolean, x, None))
.collect::<Vec<_>>();

Some(
super::horizontal_flatten_unchecked_impl_generic::<BooleanArray>(
&validities,
widths,
output_height,
&ArrowDataType::Boolean,
)
.as_any()
.downcast_ref::<BooleanArray>()
.unwrap()
.values()
.clone(),
)
} else {
None
};

StructArray::new(
dtype.clone(),
if n_fields == 0 {
output_height * widths.iter().copied().sum::<usize>()
} else {
debug_assert_eq!(
field_arrays[0].len(),
output_height * widths.iter().copied().sum::<usize>()
);

field_arrays[0].len()
},
field_arrays,
validity,
)
}
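
The validity branch above substitutes a shared all-true bitmap, sliced to each array's own length (so broadcast inputs stay unit-length), whenever a struct array carries no validity; the resulting `BooleanArray`s can then be flattened by the same generic kernel as any other column. A rough sketch of that substitution with plain `Vec<bool>` and a hypothetical `effective_validity` helper:

// Hypothetical stand-in for a validity bitmap: `None` means "all values valid".
fn effective_validity(
    validity: Option<&[bool]>,
    len: usize,
    shared: &mut Option<Vec<bool>>,
    max_len: usize,
) -> Vec<bool> {
    match validity {
        Some(v) => v.to_vec(),
        None => {
            // Lazily build one shared all-true buffer of the maximum length and slice it
            // down to this array's length (broadcast inputs are shorter).
            let buf = shared.get_or_insert_with(|| vec![true; max_len]);
            buf[..len].to_vec()
        },
    }
}

fn main() {
    let mut shared = None;
    // A full-height field with an explicit validity, and a broadcast field without one.
    assert_eq!(
        effective_validity(Some(&[true, false, true][..]), 3, &mut shared, 3),
        vec![true, false, true]
    );
    assert_eq!(effective_validity(None, 1, &mut shared, 3), vec![true]);
}
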
1 change: 1 addition & 0 deletions crates/polars-compute/src/lib.rs
@@ -15,6 +15,7 @@ pub mod cardinality;
pub mod comparisons;
pub mod filter;
pub mod float_sum;
pub mod horizontal_flatten;
#[cfg(feature = "approx_unique")]
pub mod hyperloglogplus;
pub mod if_then_else;
115 changes: 115 additions & 0 deletions crates/polars-ops/src/series/ops/concat_arr.rs
@@ -0,0 +1,115 @@
use arrow::array::FixedSizeListArray;
use arrow::compute::utils::combine_validities_and;
use polars_compute::horizontal_flatten::horizontal_flatten_unchecked;
use polars_core::prelude::{ArrayChunked, Column, CompatLevel, DataType, IntoColumn};
use polars_core::series::Series;
use polars_error::{polars_bail, PolarsResult};
use polars_utils::pl_str::PlSmallStr;

/// Note: The caller must ensure all columns in `args` have the same type.
///
/// # Panics
/// Panics if
/// * `args` is empty
/// * `dtype` is not a `DataType::Array`
pub fn concat_arr(args: &[Column], dtype: &DataType) -> PolarsResult<Column> {
let DataType::Array(inner_dtype, width) = dtype else {
panic!("{}", dtype);
};

let inner_dtype = inner_dtype.as_ref();
let width = *width;

let mut output_height = args[0].len();
let mut calculated_width = 0;
let mut mismatch_height = (&PlSmallStr::EMPTY, output_height);
// If there is an `Array` column with a single NULL, the output will be entirely NULL.
let mut return_all_null = false;

let (arrays, widths): (Vec<_>, Vec<_>) = args
.iter()
.map(|c| {
// Handle broadcasting
if output_height == 1 {
output_height = c.len();
mismatch_height.1 = c.len();
}

if c.len() != output_height && c.len() != 1 && mismatch_height.1 == output_height {
mismatch_height = (c.name(), c.len());
}

match c.dtype() {
DataType::Array(inner, width) => {
debug_assert_eq!(inner.as_ref(), inner_dtype);

let arr = c.array().unwrap().rechunk();

return_all_null |=
arr.len() == 1 && arr.rechunk_validity().map_or(false, |x| !x.get_bit(0));

(arr.rechunk().downcast_into_array().values().clone(), *width)
},
dtype => {
debug_assert_eq!(dtype, inner_dtype);
(
c.as_materialized_series().rechunk().into_chunks()[0].clone(),
1,
)
},
}
})
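// Zero-width array columns contribute no values; dropping them also upholds the
// non-zero-width precondition of `horizontal_flatten_unchecked`.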
.filter(|x| x.1 > 0)
.inspect(|x| calculated_width += x.1)
.unzip();

assert_eq!(calculated_width, width);

if mismatch_height.1 != output_height {
polars_bail!(
ShapeMismatch:
"concat_arr: length of column '{}' (len={}) did not match length of \
first column '{}' (len={})",
mismatch_height.0, mismatch_height.1, args[0].name(), output_height,
)
}

if return_all_null {
let arr =
FixedSizeListArray::new_null(dtype.to_arrow(CompatLevel::newest()), output_height);
return Ok(ArrayChunked::with_chunk(args[0].name().clone(), arr).into_column());
}

let outer_validity = args
.iter()
// Note: We ignore the validity of non-array input columns; their outer validity is always
// valid after being reshaped to (-1, 1).
.filter(|x| {
// Unit-length validities at this point always contain a single valid value, as we would
// otherwise have returned earlier via `return_all_null`, so we filter them out.
debug_assert!(x.len() == output_height || x.len() == 1);

x.dtype().is_array() && x.len() == output_height
})
.map(|x| x.as_materialized_series().rechunk_validity())
.fold(None, |a, b| combine_validities_and(a.as_ref(), b.as_ref()));

let inner_arr = if output_height == 0 || width == 0 {
Series::new_empty(PlSmallStr::EMPTY, inner_dtype)
.into_chunks()
.into_iter()
.next()
.unwrap()
} else {
unsafe { horizontal_flatten_unchecked(&arrays, &widths, output_height) }
};

let arr = FixedSizeListArray::new(
dtype.to_arrow(CompatLevel::newest()),
output_height,
inner_arr,
outer_validity,
);

Ok(ArrayChunked::with_chunk(args[0].name().clone(), arr).into_column())
}
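
The length handling at the top of `concat_arr` lets unit-length columns broadcast against the first non-unit height and reports the first column whose length matches neither. A minimal sketch of that resolution over hypothetical column lengths; the `resolve_height` helper is illustrative only, not the commit's API.

fn resolve_height(lens: &[usize]) -> Result<usize, (usize, usize)> {
    let mut height = lens[0];
    for &len in lens {
        if height == 1 {
            // Unit heights broadcast: adopt the first non-unit length seen.
            height = len;
        }
        if len != height && len != 1 {
            // Neither matching nor broadcastable: this is the shape-mismatch case.
            return Err((len, height));
        }
    }
    Ok(height)
}

fn main() {
    assert_eq!(resolve_height(&[1, 3, 3]), Ok(3)); // a unit first column broadcasts
    assert_eq!(resolve_height(&[3, 1, 3]), Ok(3)); // unit columns broadcast anywhere
    assert_eq!(resolve_height(&[3, 2, 3]), Err((2, 3))); // mismatching length is reported
}
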