diff --git a/crates/polars-compute/src/horizontal_flatten/mod.rs b/crates/polars-compute/src/horizontal_flatten/mod.rs
new file mode 100644
index 000000000000..228e6bd81e4c
--- /dev/null
+++ b/crates/polars-compute/src/horizontal_flatten/mod.rs
@@ -0,0 +1,182 @@
+use arrow::array::{
+    Array, ArrayCollectIterExt, BinaryArray, BinaryViewArray, BooleanArray, FixedSizeListArray,
+    ListArray, NullArray, PrimitiveArray, StaticArray, StructArray, Utf8ViewArray,
+};
+use arrow::bitmap::Bitmap;
+use arrow::datatypes::{ArrowDataType, PhysicalType};
+use arrow::with_match_primitive_type_full;
+use strength_reduce::StrengthReducedUsize;
+mod struct_;
+
+/// Low-level operation used by `concat_arr`. This should be called with the inner values array of
+/// every FixedSizeList array.
+///
+/// # Safety
+/// * `arrays` is non-empty
+/// * `arrays` and `widths` have equal length
+/// * All widths in `widths` are non-zero
+/// * Every array `arrays[i]` has a length of either
+///   * `widths[i] * output_height`
+///   * `widths[i]` (this would be broadcasted)
+/// * All arrays in `arrays` have the same type
+pub unsafe fn horizontal_flatten_unchecked(
+    arrays: &[Box<dyn Array>],
+    widths: &[usize],
+    output_height: usize,
+) -> Box<dyn Array> {
+    use PhysicalType::*;
+
+    let dtype = arrays[0].dtype();
+
+    match dtype.to_physical_type() {
+        Null => Box::new(NullArray::new(
+            dtype.clone(),
+            output_height * widths.iter().copied().sum::<usize>(),
+        )),
+        Boolean => Box::new(horizontal_flatten_unchecked_impl_generic(
+            &arrays
+                .iter()
+                .map(|x| x.as_any().downcast_ref::<BooleanArray>().unwrap().clone())
+                .collect::<Vec<_>>(),
+            widths,
+            output_height,
+            dtype,
+        )),
+        Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| {
+            Box::new(horizontal_flatten_unchecked_impl_generic(
+                &arrays
+                    .iter()
+                    .map(|x| x.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap().clone())
+                    .collect::<Vec<_>>(),
+                widths,
+                output_height,
+                dtype
+            ))
+        }),
+        LargeBinary => Box::new(horizontal_flatten_unchecked_impl_generic(
+            &arrays
+                .iter()
+                .map(|x| {
+                    x.as_any()
+                        .downcast_ref::<BinaryArray<i64>>()
+                        .unwrap()
+                        .clone()
+                })
+                .collect::<Vec<_>>(),
+            widths,
+            output_height,
+            dtype,
+        )),
+        Struct => Box::new(struct_::horizontal_flatten_unchecked(
+            &arrays
+                .iter()
+                .map(|x| x.as_any().downcast_ref::<StructArray>().unwrap().clone())
+                .collect::<Vec<_>>(),
+            widths,
+            output_height,
+        )),
+        LargeList => Box::new(horizontal_flatten_unchecked_impl_generic(
+            &arrays
+                .iter()
+                .map(|x| x.as_any().downcast_ref::<ListArray<i64>>().unwrap().clone())
+                .collect::<Vec<_>>(),
+            widths,
+            output_height,
+            dtype,
+        )),
+        FixedSizeList => Box::new(horizontal_flatten_unchecked_impl_generic(
+            &arrays
+                .iter()
+                .map(|x| {
+                    x.as_any()
+                        .downcast_ref::<FixedSizeListArray>()
+                        .unwrap()
+                        .clone()
+                })
+                .collect::<Vec<_>>(),
+            widths,
+            output_height,
+            dtype,
+        )),
+        BinaryView => Box::new(horizontal_flatten_unchecked_impl_generic(
+            &arrays
+                .iter()
+                .map(|x| {
+                    x.as_any()
+                        .downcast_ref::<BinaryViewArray>()
+                        .unwrap()
+                        .clone()
+                })
+                .collect::<Vec<_>>(),
+            widths,
+            output_height,
+            dtype,
+        )),
+        Utf8View => Box::new(horizontal_flatten_unchecked_impl_generic(
+            &arrays
+                .iter()
+                .map(|x| x.as_any().downcast_ref::<Utf8ViewArray>().unwrap().clone())
+                .collect::<Vec<_>>(),
+            widths,
+            output_height,
+            dtype,
+        )),
+        t => unimplemented!("horizontal_flatten not supported for data type {:?}", t),
+    }
+}
+
+unsafe fn horizontal_flatten_unchecked_impl_generic<T>(
+    arrays: &[T],
+    widths: &[usize],
+    output_height: usize,
+    dtype: &ArrowDataType,
+) -> T
+where
+    T: StaticArray,
+{
+    assert!(!arrays.is_empty());
+    assert_eq!(widths.len(), arrays.len());
+
+    debug_assert!(widths.iter().all(|x| *x > 0));
+    debug_assert!(arrays
+        .iter()
+        .zip(widths)
+        .all(|(arr, width)| arr.len() == output_height * *width || arr.len() == *width));
+
+    // We modulo the array length to support broadcasting.
+    let lengths = arrays
+        .iter()
+        .map(|x| StrengthReducedUsize::new(x.len()))
+        .collect::<Vec<_>>();
+    let out_row_width: usize = widths.iter().cloned().sum();
+    let out_len = out_row_width.checked_mul(output_height).unwrap();
+
+    let mut col_idx = 0;
+    let mut row_idx = 0;
+    let mut until = widths[0];
+    let mut outer_row_idx = 0;
+
+    // We do `0..out_len` to get an `ExactSizeIterator`.
+    (0..out_len)
+        .map(|_| {
+            let arr = arrays.get_unchecked(col_idx);
+            let out = arr.get_unchecked(row_idx % *lengths.get_unchecked(col_idx));
+
+            row_idx += 1;
+
+            if row_idx == until {
+                // Safety: All widths are non-zero so we only need to increment once.
+                col_idx = if 1 + col_idx == widths.len() {
+                    outer_row_idx += 1;
+                    0
+                } else {
+                    1 + col_idx
+                };
+                row_idx = outer_row_idx * *widths.get_unchecked(col_idx);
+                until = (1 + outer_row_idx) * *widths.get_unchecked(col_idx)
+            }
+
+            out
+        })
+        .collect_arr_trusted_with_dtype(dtype.clone())
+}
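Review note: the index bookkeeping in `horizontal_flatten_unchecked_impl_generic` can be hard to follow on its own. The sketch below is a hypothetical, self-contained model (made-up names, plain `Vec<i64>` instead of Arrow arrays, no unsafe) of the same walk: for each output row it copies `widths[i]` values from every input in order, and broadcasting of unit-height inputs falls out of indexing modulo the input length — which is also why the real kernel keeps a `StrengthReducedUsize` per input, to make that per-element modulo cheap.

```rust
// Illustrative model only (not part of the patch).
fn horizontal_flatten_model(values: &[Vec<i64>], widths: &[usize], output_height: usize) -> Vec<i64> {
    let mut out = Vec::with_capacity(output_height * widths.iter().sum::<usize>());
    for row in 0..output_height {
        for (vals, &width) in values.iter().zip(widths) {
            for j in 0..width {
                // A broadcast input has length `width`, so the modulo wraps every
                // output row back onto its single input row.
                out.push(vals[(row * width + j) % vals.len()]);
            }
        }
    }
    out
}

fn main() {
    let a = vec![1, 2, 3, 4, 5, 6]; // width 2, height 3
    let b = vec![9]; // width 1, broadcast
    assert_eq!(
        horizontal_flatten_model(&[a, b], &[2, 1], 3),
        vec![1, 2, 9, 3, 4, 9, 5, 6, 9]
    );
}
```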
diff --git a/crates/polars-compute/src/horizontal_flatten/struct_.rs b/crates/polars-compute/src/horizontal_flatten/struct_.rs
new file mode 100644
index 000000000000..6dc12d93a1eb
--- /dev/null
+++ b/crates/polars-compute/src/horizontal_flatten/struct_.rs
@@ -0,0 +1,88 @@
+use super::*;
+
+/// # Safety
+/// All preconditions in [`super::horizontal_flatten_unchecked`]
+pub(super) unsafe fn horizontal_flatten_unchecked(
+    arrays: &[StructArray],
+    widths: &[usize],
+    output_height: usize,
+) -> StructArray {
+    // For StructArrays, we perform the flatten operation individually for every field in the struct
+    // as well as on the outer validity. We then construct the result array from the individual
+    // result parts.
+
+    let dtype = arrays[0].dtype();
+
+    let field_arrays: Vec<&[Box<dyn Array>]> = arrays
+        .iter()
+        .inspect(|x| debug_assert_eq!(x.dtype(), dtype))
+        .map(|x| x.values())
+        .collect::<Vec<_>>();
+
+    let n_fields = field_arrays[0].len();
+
+    let mut scratch = Vec::with_capacity(field_arrays.len());
+    // Safety: We can take by index as all struct arrays have the same column names in the same
+    // order.
+    // Note: `field_arrays` can be empty for 0-field structs.
+    let field_arrays = (0..n_fields)
+        .map(|i| {
+            scratch.clear();
+            scratch.extend(field_arrays.iter().map(|v| v[i].clone()));
+
+            super::horizontal_flatten_unchecked(&scratch, widths, output_height)
+        })
+        .collect::<Vec<_>>();
+
+    let validity = if arrays.iter().any(|x| x.validity().is_some()) {
+        let max_height = output_height * widths.iter().fold(0usize, |a, b| a.max(*b));
+        let mut shared_validity = None;
+
+        // We need to create BooleanArrays from the Bitmaps for dispatch.
+        let validities: Vec<BooleanArray> = arrays
+            .iter()
+            .map(|x| {
+                x.validity().cloned().unwrap_or_else(|| {
+                    if shared_validity.is_none() {
+                        shared_validity = Some(Bitmap::new_with_value(true, max_height))
+                    };
+                    // We have to slice to the exact length to pass an assertion.
+                    shared_validity.clone().unwrap().sliced(0, x.len())
+                })
+            })
+            .map(|x| BooleanArray::from_inner_unchecked(ArrowDataType::Boolean, x, None))
+            .collect::<Vec<_>>();
+
+        Some(
+            super::horizontal_flatten_unchecked_impl_generic::<BooleanArray>(
+                &validities,
+                widths,
+                output_height,
+                &ArrowDataType::Boolean,
+            )
+            .as_any()
+            .downcast_ref::<BooleanArray>()
+            .unwrap()
+            .values()
+            .clone(),
+        )
+    } else {
+        None
+    };
+
+    StructArray::new(
+        dtype.clone(),
+        if n_fields == 0 {
+            output_height * widths.iter().copied().sum::<usize>()
+        } else {
+            debug_assert_eq!(
+                field_arrays[0].len(),
+                output_height * widths.iter().copied().sum::<usize>()
+            );
+
+            field_arrays[0].len()
+        },
+        field_arrays,
+        validity,
+    )
+}
diff --git a/crates/polars-compute/src/lib.rs b/crates/polars-compute/src/lib.rs
index 952cbbd33052..73ca9e6232be 100644
--- a/crates/polars-compute/src/lib.rs
+++ b/crates/polars-compute/src/lib.rs
@@ -15,6 +15,7 @@ pub mod cardinality;
 pub mod comparisons;
 pub mod filter;
 pub mod float_sum;
+pub mod horizontal_flatten;
 #[cfg(feature = "approx_unique")]
 pub mod hyperloglogplus;
 pub mod if_then_else;
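The struct path above runs the same walk once per field and once more for the outer validity, after wrapping each input's validity bitmap in a `BooleanArray`. A minimal plain-`bool` sketch of the validity step (hypothetical values, not the actual Arrow types):

```rust
// Two width-1 struct columns with outer validities [t, f] and [t, t]:
// flattening them row-major yields the outer validity of the width-2 output rows.
fn main() {
    let validities = [vec![true, false], vec![true, true]];
    let widths = [1usize, 1];
    let output_height = 2;

    let mut out = Vec::new();
    for row in 0..output_height {
        for (bits, &width) in validities.iter().zip(&widths) {
            for j in 0..width {
                out.push(bits[(row * width + j) % bits.len()]);
            }
        }
    }
    assert_eq!(out, vec![true, true, false, true]);
}
```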
diff --git a/crates/polars-ops/src/series/ops/concat_arr.rs b/crates/polars-ops/src/series/ops/concat_arr.rs
new file mode 100644
index 000000000000..2bb811816308
--- /dev/null
+++ b/crates/polars-ops/src/series/ops/concat_arr.rs
@@ -0,0 +1,115 @@
+use arrow::array::FixedSizeListArray;
+use arrow::compute::utils::combine_validities_and;
+use polars_compute::horizontal_flatten::horizontal_flatten_unchecked;
+use polars_core::prelude::{ArrayChunked, Column, CompatLevel, DataType, IntoColumn};
+use polars_core::series::Series;
+use polars_error::{polars_bail, PolarsResult};
+use polars_utils::pl_str::PlSmallStr;
+
+/// Note: The caller must ensure all columns in `args` have the same type.
+///
+/// # Panics
+/// Panics if
+/// * `args` is empty
+/// * `dtype` is not a `DataType::Array`
+pub fn concat_arr(args: &[Column], dtype: &DataType) -> PolarsResult<Column> {
+    let DataType::Array(inner_dtype, width) = dtype else {
+        panic!("{}", dtype);
+    };
+
+    let inner_dtype = inner_dtype.as_ref();
+    let width = *width;
+
+    let mut output_height = args[0].len();
+    let mut calculated_width = 0;
+    let mut mismatch_height = (&PlSmallStr::EMPTY, output_height);
+    // If there is an `Array` column with a single NULL, the output will be entirely NULL.
+    let mut return_all_null = false;
+
+    let (arrays, widths): (Vec<_>, Vec<_>) = args
+        .iter()
+        .map(|c| {
+            // Handle broadcasting
+            if output_height == 1 {
+                output_height = c.len();
+                mismatch_height.1 = c.len();
+            }
+
+            if c.len() != output_height && c.len() != 1 && mismatch_height.1 == output_height {
+                mismatch_height = (c.name(), c.len());
+            }
+
+            match c.dtype() {
+                DataType::Array(inner, width) => {
+                    debug_assert_eq!(inner.as_ref(), inner_dtype);
+
+                    let arr = c.array().unwrap().rechunk();
+
+                    return_all_null |=
+                        arr.len() == 1 && arr.rechunk_validity().map_or(false, |x| !x.get_bit(0));
+
+                    (arr.rechunk().downcast_into_array().values().clone(), *width)
+                },
+                dtype => {
+                    debug_assert_eq!(dtype, inner_dtype);
+                    (
+                        c.as_materialized_series().rechunk().into_chunks()[0].clone(),
+                        1,
+                    )
+                },
+            }
+        })
+        .filter(|x| x.1 > 0)
+        .inspect(|x| calculated_width += x.1)
+        .unzip();
+
+    assert_eq!(calculated_width, width);
+
+    if mismatch_height.1 != output_height {
+        polars_bail!(
+            ShapeMismatch:
+            "concat_arr: length of column '{}' (len={}) did not match length of \
+            first column '{}' (len={})",
+            mismatch_height.0, mismatch_height.1, args[0].name(), output_height,
+        )
+    }
+
+    if return_all_null {
+        let arr =
+            FixedSizeListArray::new_null(dtype.to_arrow(CompatLevel::newest()), output_height);
+        return Ok(ArrayChunked::with_chunk(args[0].name().clone(), arr).into_column());
+    }
+
+    let outer_validity = args
+        .iter()
+        // Note: We ignore the validity of non-array input columns, as their outer level is always
+        // valid after being reshaped to (-1, 1).
+        .filter(|x| {
+            // Unit length validities at this point always contain a single valid, as we would have
+            // returned earlier otherwise with `return_all_null`, so we filter them out.
+            debug_assert!(x.len() == output_height || x.len() == 1);
+
+            x.dtype().is_array() && x.len() == output_height
+        })
+        .map(|x| x.as_materialized_series().rechunk_validity())
+        .fold(None, |a, b| combine_validities_and(a.as_ref(), b.as_ref()));
+
+    let inner_arr = if output_height == 0 || width == 0 {
+        Series::new_empty(PlSmallStr::EMPTY, inner_dtype)
+            .into_chunks()
+            .into_iter()
+            .next()
+            .unwrap()
+    } else {
+        unsafe { horizontal_flatten_unchecked(&arrays, &widths, output_height) }
+    };
+
+    let arr = FixedSizeListArray::new(
+        dtype.to_arrow(CompatLevel::newest()),
+        output_height,
+        inner_arr,
+        outer_validity,
+    );
+
+    Ok(ArrayChunked::with_chunk(args[0].name().clone(), arr).into_column())
+}
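The outer validity computed at the end of `concat_arr` above is an AND-fold over the outer validities of the full-height `Array` inputs; unit-height inputs were either handled earlier via `return_all_null` or are valid, and non-array inputs are always valid at the outer level after reshaping. A small model of that fold, with `Option<Vec<bool>>` standing in for `Option<Bitmap>` and a hypothetical helper name:

```rust
// `None` plays the same role as a missing validity bitmap: all rows valid.
fn combined_outer_validity(validities: &[Option<Vec<bool>>]) -> Option<Vec<bool>> {
    validities.iter().fold(None, |acc, v| match (acc, v) {
        (None, None) => None,
        (Some(a), None) => Some(a),
        (None, Some(b)) => Some(b.clone()),
        (Some(a), Some(b)) => Some(a.iter().zip(b).map(|(x, y)| *x && *y).collect()),
    })
}

fn main() {
    let a = Some(vec![true, false, true]); // Array input with a null outer value in row 1
    let b = None; // Array input without a validity bitmap (all rows valid)
    let c = Some(vec![true, true, false]);
    assert_eq!(
        combined_outer_validity(&[a, b, c]),
        Some(vec![true, false, false])
    );
}
```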
diff --git a/crates/polars-ops/src/series/ops/mod.rs b/crates/polars-ops/src/series/ops/mod.rs
index b684815238f7..6be3b85cb8e2 100644
--- a/crates/polars-ops/src/series/ops/mod.rs
+++ b/crates/polars-ops/src/series/ops/mod.rs
@@ -131,6 +131,8 @@ pub use unique::*;
 pub use various::*;
 
 mod not;
+#[cfg(feature = "dtype-array")]
+pub mod concat_arr;
 #[cfg(feature = "dtype-duration")]
 pub(crate) mod duration;
 #[cfg(feature = "dtype-duration")]
diff --git a/crates/polars-plan/src/dsl/function_expr/array.rs b/crates/polars-plan/src/dsl/function_expr/array.rs
index 08333beb3893..3b11243a66d5 100644
--- a/crates/polars-plan/src/dsl/function_expr/array.rs
+++ b/crates/polars-plan/src/dsl/function_expr/array.rs
@@ -31,12 +31,22 @@ pub enum ArrayFunction {
     CountMatches,
     Shift,
     Explode,
+    Concat,
 }
 
 impl ArrayFunction {
     pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
         use ArrayFunction::*;
         match self {
+            Concat => Ok(Field::new(
+                mapper
+                    .args()
+                    .first()
+                    .map_or(PlSmallStr::EMPTY, |x| x.name.clone()),
+                concat_arr_output_dtype(
+                    &mut mapper.args().iter().map(|x| (x.name.as_str(), &x.dtype)),
+                )?,
+            )),
             Min | Max => mapper.map_to_list_and_array_inner_dtype(),
             Sum => mapper.nested_sum_type(),
             ToList => mapper.try_map_dtype(map_array_dtype_to_list_dtype),
@@ -74,6 +84,7 @@ impl Display for ArrayFunction {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         use ArrayFunction::*;
         let name = match self {
+            Concat => "concat",
             Min => "min",
             Max => "max",
             Sum => "sum",
@@ -108,6 +119,7 @@ impl From<ArrayFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
     fn from(func: ArrayFunction) -> Self {
         use ArrayFunction::*;
         match func {
+            Concat => map_as_slice!(concat_arr),
             Min => map!(min),
             Max => map!(max),
             Sum => map!(sum),
@@ -257,3 +269,48 @@ pub(super) fn shift(s: &[Column]) -> PolarsResult<Column> {
 fn explode(c: &[Column]) -> PolarsResult<Column> {
     c[0].explode()
 }
+
+fn concat_arr(args: &[Column]) -> PolarsResult<Column> {
+    let dtype = concat_arr_output_dtype(&mut args.iter().map(|c| (c.name().as_str(), c.dtype())))?;
+
+    polars_ops::series::concat_arr::concat_arr(args, &dtype)
+}
+
+/// Determine the output dtype of a `concat_arr` operation. Also performs validation to ensure input
+/// dtypes are compatible.
+fn concat_arr_output_dtype(
+    inputs: &mut dyn ExactSizeIterator<Item = (&str, &DataType)>,
+) -> PolarsResult<DataType> {
+    #[allow(clippy::len_zero)]
+    if inputs.len() == 0 {
+        // should not be reachable - we did not set ALLOW_EMPTY_INPUTS
+        panic!();
+    }
+
+    let mut inputs = inputs.map(|(name, dtype)| {
+        let (inner_dtype, width) = match dtype {
+            DataType::Array(inner, width) => (inner.as_ref(), *width),
+            dt => (dt, 1),
+        };
+        (name, dtype, inner_dtype, width)
+    });
+    let (first_name, first_dtype, first_inner_dtype, mut out_width) = inputs.next().unwrap();
+
+    for (col_name, dtype, inner_dtype, width) in inputs {
+        out_width += width;
+
+        if inner_dtype != first_inner_dtype {
+            polars_bail!(
+                SchemaMismatch:
+                "concat_arr dtype mismatch: expected {} or array[{}] dtype to match dtype of first \
+                input column (name: {}, dtype: {}), got {} instead for column {}",
+                first_inner_dtype, first_inner_dtype, first_name, first_dtype, dtype, col_name,
+            )
+        }
+    }
+
+    Ok(DataType::Array(
+        Box::new(first_inner_dtype.clone()),
+        out_width,
+    ))
+}
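`concat_arr_output_dtype` above determines the output width and checks that all inner dtypes agree. A self-contained model of that rule (a toy `Dt` enum and a made-up helper name, not the real `DataType`): every `Array(inner, width)` input contributes `width`, every non-array input contributes 1, and every inner dtype must equal the first one.

```rust
#[derive(Clone, Debug, PartialEq)]
enum Dt {
    Int64,
    Array(Box<Dt>, usize),
}

fn concat_arr_width(inputs: &[Dt]) -> Result<(Dt, usize), String> {
    let mut it = inputs.iter().map(|dt| match dt {
        Dt::Array(inner, w) => (inner.as_ref().clone(), *w),
        other => (other.clone(), 1),
    });
    let (first_inner, mut width) =
        it.next().ok_or_else(|| String::from("need at least one input"))?;
    for (inner, w) in it {
        if inner != first_inner {
            return Err(format!("dtype mismatch: {inner:?} vs {first_inner:?}"));
        }
        width += w;
    }
    Ok((first_inner, width))
}

fn main() {
    let out = concat_arr_width(&[
        Dt::Array(Box::new(Dt::Int64), 2), // contributes 2
        Dt::Int64,                         // contributes 1
        Dt::Array(Box::new(Dt::Int64), 3), // contributes 3
    ]);
    assert_eq!(out, Ok((Dt::Int64, 6))); // -> Array(Int64, 6)
}
```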
diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs
index 17dc82e23e8c..5813fa7a72cd 100644
--- a/crates/polars-plan/src/dsl/function_expr/mod.rs
+++ b/crates/polars-plan/src/dsl/function_expr/mod.rs
@@ -801,9 +801,9 @@ macro_rules! wrap {
     }};
 }
 
-// Fn(&[Column], args)
-// all expression arguments are in the slice.
-// the first element is the root expression.
+/// `Fn(&[Column], args)`
+/// * all expression arguments are in the slice.
+/// * the first element is the root expression.
 #[macro_export]
 macro_rules! map_as_slice {
     ($func:path) => {{
@@ -823,8 +823,8 @@ macro_rules! map_as_slice {
     }};
 }
 
-// FnOnce(Series)
-// FnOnce(Series, args)
+/// * `FnOnce(Series)`
+/// * `FnOnce(Series, args)`
 #[macro_export]
 macro_rules! map_owned {
     ($func:path) => {{
@@ -846,7 +846,7 @@ macro_rules! map_owned {
     }};
 }
 
-// Fn(&Series, args)
+/// `Fn(&Series, args)`
 #[macro_export]
 macro_rules! map {
     ($func:path) => {{
diff --git a/crates/polars-plan/src/dsl/functions/concat.rs b/crates/polars-plan/src/dsl/functions/concat.rs
index d15b1769cf3a..9599af9d83f8 100644
--- a/crates/polars-plan/src/dsl/functions/concat.rs
+++ b/crates/polars-plan/src/dsl/functions/concat.rs
@@ -69,6 +69,23 @@ pub fn concat_list<E: AsRef<[IE]>, IE: Into<Expr> + Clone>(s: E) -> PolarsResult<Expr> {
     })
 }
 
+/// Horizontally concatenate columns into a single array-type column.
+pub fn concat_arr(input: Vec<Expr>) -> PolarsResult<Expr> {
+    feature_gated!("dtype-array", {
+        polars_ensure!(!input.is_empty(), ComputeError: "`concat_arr` needs one or more expressions");
+
+        Ok(Expr::Function {
+            input,
+            function: FunctionExpr::ArrayExpr(ArrayFunction::Concat),
+            options: FunctionOptions {
+                collect_groups: ApplyOptions::ElementWise,
+                flags: FunctionFlags::default() | FunctionFlags::INPUT_WILDCARD_EXPANSION,
+                ..Default::default()
+            },
+        })
+    })
+}
+
 pub fn concat_expr<E: AsRef<[IE]>, IE: Into<Expr> + Clone>(
     s: E,
     rechunk: bool,
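For completeness, a hedged usage sketch of the new Rust DSL entry point. It assumes `concat_arr` is re-exported through `polars::prelude` and that the `lazy` and `dtype-array` features are enabled; neither is shown in this patch.

```rust
use polars::prelude::*;

fn main() -> PolarsResult<()> {
    let df = df! {
        "a" => [1i64, 3, 5],
        "b" => [2i64, 4, 6],
    }?;

    // `concat_arr` builds an `Expr::Function` wrapping `ArrayFunction::Concat`;
    // the result here is an `array[i64, 2]` column named "ab".
    let out = df
        .lazy()
        .with_columns([concat_arr(vec![col("a"), col("b")])?.alias("ab")])
        .collect()?;

    println!("{out}");
    Ok(())
}
```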
diff --git a/crates/polars-python/src/functions/lazy.rs b/crates/polars-python/src/functions/lazy.rs
index 1c4e738ea69c..3a437b3281d3 100644
--- a/crates/polars-python/src/functions/lazy.rs
+++ b/crates/polars-python/src/functions/lazy.rs
@@ -200,6 +200,13 @@ pub fn concat_list(s: Vec<PyExpr>) -> PyResult<PyExpr> {
     Ok(expr.into())
 }
 
+#[pyfunction]
+pub fn concat_arr(s: Vec<PyExpr>) -> PyResult<PyExpr> {
+    let s = s.into_iter().map(|e| e.inner).collect::<Vec<_>>();
+    let expr = dsl::concat_arr(s).map_err(PyPolarsErr::from)?;
+    Ok(expr.into())
+}
+
 #[pyfunction]
 pub fn concat_str(s: Vec<PyExpr>, separator: &str, ignore_nulls: bool) -> PyExpr {
     let s = s.into_iter().map(|e| e.inner).collect::<Vec<_>>();
diff --git a/py-polars/docs/source/reference/expressions/functions.rst b/py-polars/docs/source/reference/expressions/functions.rst
index b27fa4e87f84..69649528a3fa 100644
--- a/py-polars/docs/source/reference/expressions/functions.rst
+++ b/py-polars/docs/source/reference/expressions/functions.rst
@@ -24,6 +24,7 @@ These functions are available from the Polars module root and can be used as exp
     arg_where
     business_day_count
     coalesce
+    concat_arr
     concat_list
     concat_str
     corr
diff --git a/py-polars/polars/__init__.py b/py-polars/polars/__init__.py
index eb33f23bf53f..d6c4ae511aa8 100644
--- a/py-polars/polars/__init__.py
+++ b/py-polars/polars/__init__.py
@@ -86,6 +86,7 @@
     collect_all,
     collect_all_async,
     concat,
+    concat_arr,
     concat_list,
     concat_str,
     corr,
@@ -329,6 +330,7 @@
     "col",
     "collect_all",
     "collect_all_async",
+    "concat_arr",
     "concat_list",
     "concat_str",
     "corr",
diff --git a/py-polars/polars/functions/__init__.py b/py-polars/polars/functions/__init__.py
index 32fbe4578059..708d24e55a02 100644
--- a/py-polars/polars/functions/__init__.py
+++ b/py-polars/polars/functions/__init__.py
@@ -14,6 +14,7 @@
     sum_horizontal,
 )
 from polars.functions.as_datatype import (
+    concat_arr,
     concat_list,
     concat_str,
     duration,
@@ -124,6 +125,7 @@
     "col",
     "collect_all",
     "collect_all_async",
+    "concat_arr",
     "concat_list",
     "concat_str",
     "corr",
diff --git a/py-polars/polars/functions/as_datatype.py b/py-polars/polars/functions/as_datatype.py
index 30398daa01d9..181dc7767d0e 100644
--- a/py-polars/polars/functions/as_datatype.py
+++ b/py-polars/polars/functions/as_datatype.py
@@ -502,6 +502,117 @@ def concat_list(exprs: IntoExpr | Iterable[IntoExpr], *more_exprs: IntoExpr) -> Expr:
     return wrap_expr(plr.concat_list(exprs))
 
 
+def concat_arr(exprs: IntoExpr | Iterable[IntoExpr], *more_exprs: IntoExpr) -> Expr:
+    """
+    Horizontally concatenate columns into a single array column.
+
+    Non-array columns are reshaped to a unit-width array. All columns must have
+    a dtype of either `pl.Array(<dtype>, width)` or `pl.<dtype>`.
+
+    Parameters
+    ----------
+    exprs
+        Columns to concatenate into a single array column. Accepts expression input.
+        Strings are parsed as column names, other non-expression inputs are parsed as
+        literals.
+    *more_exprs
+        Additional columns to concatenate into a single array column, specified as
+        positional arguments.
+
+    Examples
+    --------
+    Concatenate 2 array columns:
+
+    >>> (
+    ...     pl.select(
+    ...         a=pl.Series([[1], [3], None], dtype=pl.Array(pl.Int64, 1)),
+    ...         b=pl.Series([[3], [None], [5]], dtype=pl.Array(pl.Int64, 1)),
+    ...     ).with_columns(
+    ...         pl.concat_arr("a", "b").alias("concat_arr(a, b)"),
+    ...         pl.concat_arr("a", pl.first("b")).alias("concat_arr(a, first(b))"),
+    ...     )
+    ... )
+    shape: (3, 4)
+    ┌───────────────┬───────────────┬──────────────────┬─────────────────────────┐
+    │ a             ┆ b             ┆ concat_arr(a, b) ┆ concat_arr(a, first(b)) │
+    │ ---           ┆ ---           ┆ ---              ┆ ---                     │
+    │ array[i64, 1] ┆ array[i64, 1] ┆ array[i64, 2]    ┆ array[i64, 2]           │
+    ╞═══════════════╪═══════════════╪══════════════════╪═════════════════════════╡
+    │ [1]           ┆ [3]           ┆ [1, 3]           ┆ [1, 3]                  │
+    │ [3]           ┆ [null]        ┆ [3, null]        ┆ [3, 3]                  │
+    │ null          ┆ [5]           ┆ null             ┆ null                    │
+    └───────────────┴───────────────┴──────────────────┴─────────────────────────┘
+
+    Concatenate non-array columns:
+
+    >>> (
+    ...     pl.select(
+    ...         c=pl.Series([None, 5, 6], dtype=pl.Int64),
+    ...     )
+    ...     .with_columns(d=pl.col("c").reverse())
+    ...     .with_columns(
+    ...         pl.concat_arr("c", "d").alias("concat_arr(c, d)"),
+    ...     )
+    ... )
+    shape: (3, 3)
+    ┌──────┬──────┬──────────────────┐
+    │ c    ┆ d    ┆ concat_arr(c, d) │
+    │ ---  ┆ ---  ┆ ---              │
+    │ i64  ┆ i64  ┆ array[i64, 2]    │
+    ╞══════╪══════╪══════════════════╡
+    │ null ┆ 6    ┆ [null, 6]        │
+    │ 5    ┆ 5    ┆ [5, 5]           │
+    │ 6    ┆ null ┆ [6, null]        │
+    └──────┴──────┴──────────────────┘
+
+    Concatenate mixed array and non-array columns:
+
+    >>> (
+    ...     pl.select(
+    ...         a=pl.Series([[1], [3], None], dtype=pl.Array(pl.Int64, 1)),
+    ...         b=pl.Series([[3], [None], [5]], dtype=pl.Array(pl.Int64, 1)),
+    ...         c=pl.Series([None, 5, 6], dtype=pl.Int64),
+    ...     ).with_columns(
+    ...         pl.concat_arr("a", "b", "c").alias("concat_arr(a, b, c)"),
+    ...     )
+    ... )
+    shape: (3, 4)
+    ┌───────────────┬───────────────┬──────┬─────────────────────┐
+    │ a             ┆ b             ┆ c    ┆ concat_arr(a, b, c) │
+    │ ---           ┆ ---           ┆ ---  ┆ ---                 │
+    │ array[i64, 1] ┆ array[i64, 1] ┆ i64  ┆ array[i64, 3]       │
+    ╞═══════════════╪═══════════════╪══════╪═════════════════════╡
+    │ [1]           ┆ [3]           ┆ null ┆ [1, 3, null]        │
+    │ [3]           ┆ [null]        ┆ 5    ┆ [3, null, 5]        │
+    │ null          ┆ [5]           ┆ 6    ┆ null                │
+    └───────────────┴───────────────┴──────┴─────────────────────┘
+
+    Unit-length columns are broadcasted:
+
+    >>> (
+    ...     pl.select(
+    ...         a=pl.Series([1, 3, None]),
+    ...     ).with_columns(
+    ...         pl.concat_arr("a", pl.lit(0, dtype=pl.Int64)).alias("concat_arr(a, 0)"),
+    ...         pl.concat_arr("a", pl.sum("a")).alias("concat_arr(a, sum(a))"),
+    ...         pl.concat_arr("a", pl.max("a")).alias("concat_arr(a, max(a))"),
+    ...     )
+    ... )
+    shape: (3, 4)
+    ┌──────┬──────────────────┬───────────────────────┬───────────────────────┐
+    │ a    ┆ concat_arr(a, 0) ┆ concat_arr(a, sum(a)) ┆ concat_arr(a, max(a)) │
+    │ ---  ┆ ---              ┆ ---                   ┆ ---                   │
+    │ i64  ┆ array[i64, 2]    ┆ array[i64, 2]         ┆ array[i64, 2]         │
+    ╞══════╪══════════════════╪═══════════════════════╪═══════════════════════╡
+    │ 1    ┆ [1, 0]           ┆ [1, 4]                ┆ [1, 3]                │
+    │ 3    ┆ [3, 0]           ┆ [3, 4]                ┆ [3, 3]                │
+    │ null ┆ [null, 0]        ┆ [null, 4]             ┆ [null, 3]             │
+    └──────┴──────────────────┴───────────────────────┴───────────────────────┘
+    """
+    exprs = parse_into_list_of_expressions(exprs, *more_exprs)
+    return wrap_expr(plr.concat_arr(exprs))
+
+
 @overload
 def struct(
     *exprs: IntoExpr | Iterable[IntoExpr],
diff --git a/py-polars/src/lib.rs b/py-polars/src/lib.rs
index f73577319545..56bbb14a4e07 100644
--- a/py-polars/src/lib.rs
+++ b/py-polars/src/lib.rs
@@ -168,6 +168,8 @@ fn polars(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(functions::cols)).unwrap();
     m.add_wrapped(wrap_pyfunction!(functions::concat_lf))
         .unwrap();
+    m.add_wrapped(wrap_pyfunction!(functions::concat_arr))
+        .unwrap();
     m.add_wrapped(wrap_pyfunction!(functions::concat_list))
         .unwrap();
     m.add_wrapped(wrap_pyfunction!(functions::concat_str))
         .unwrap();
diff --git a/py-polars/tests/unit/functions/as_datatype/test_concat_arr.py b/py-polars/tests/unit/functions/as_datatype/test_concat_arr.py
new file mode 100644
index 000000000000..54ec79b31873
--- /dev/null
+++ b/py-polars/tests/unit/functions/as_datatype/test_concat_arr.py
@@ -0,0 +1,176 @@
+import pytest
+
+import polars as pl
+from polars.exceptions import ShapeError
+from polars.testing import assert_series_equal
+
+
+def test_concat_arr() -> None:
+    assert_series_equal(
+        pl.select(
+            pl.concat_arr(
+                pl.Series([1, 3, 5]),
+                pl.Series([2, 4, 6]),
+            )
+        ).to_series(),
+        pl.Series([[1, 2], [3, 4], [5, 6]], dtype=pl.Array(pl.Int64, 2)),
+    )
+
+    assert_series_equal(
+        pl.select(
+            pl.concat_arr(
+                pl.Series([1, 3, 5]),
+                pl.Series([2, 4, None]),
+            )
+        ).to_series(),
+        pl.Series([[1, 2], [3, 4], [5, None]], dtype=pl.Array(pl.Int64, 2)),
+    )
+
+    assert_series_equal(
+        pl.select(
+            pl.concat_arr(
+                pl.Series([1, 3, 5]),
+                pl.Series([[2], [None], None], dtype=pl.Array(pl.Int64, 1)),
+            )
+        ).to_series(),
+        pl.Series([[1, 2], [3, None], None], dtype=pl.Array(pl.Int64, 2)),
+    )
+
+
+def test_concat_arr_broadcast() -> None:
+    assert_series_equal(
+        pl.select(
+            pl.concat_arr(
+                pl.Series([1, 3, 5]),
+                pl.lit(None, dtype=pl.Int64),
+            )
+        ).to_series(),
+        pl.Series([[1, None], [3, None], [5, None]], dtype=pl.Array(pl.Int64, 2)),
+    )
+
+    assert_series_equal(
+        pl.select(
+            pl.concat_arr(
+                pl.Series([1, 3, 5]),
+                pl.lit(None, dtype=pl.Array(pl.Int64, 2)),
+            )
+        ).to_series(),
+        pl.Series([None, None, None], dtype=pl.Array(pl.Int64, 3)),
+    )
+
+    assert_series_equal(
+        pl.select(
+            pl.concat_arr(
+                pl.Series([1, 3, 5]),
+                pl.lit([0, None], dtype=pl.Array(pl.Int64, 2)),
+            )
+        ).to_series(),
+        pl.Series(
+            [[1, 0, None], [3, 0, None], [5, 0, None]], dtype=pl.Array(pl.Int64, 3)
+        ),
+    )
+
+    assert_series_equal(
+        pl.select(
+            pl.concat_arr(pl.lit(1, dtype=pl.Int64).alias(""), pl.Series([1, 2, 3]))
+        ).to_series(),
+        pl.Series([[1, 1], [1, 2], [1, 3]], dtype=pl.Array(pl.Int64, 2)),
+    )
+
+    assert_series_equal(
+        pl.select(
+            pl.concat_arr(pl.Series([1, 2, 3]), pl.lit(1, dtype=pl.Int64))
+        ).to_series(),
+        pl.Series([[1, 1], [2, 1], [3, 1]], dtype=pl.Array(pl.Int64, 2)),
+    )
+
+    with pytest.raises(ShapeError, match="length of column.*did not match"):
+        assert_series_equal(
+            pl.select(
+                pl.concat_arr(pl.Series([1, 3, 5]), pl.Series([1, 1]))
+            ).to_series(),
+            pl.Series([None, None, None], dtype=pl.Array(pl.Int64, 3)),
+        )
+
+    assert_series_equal(
+        pl.select(
+            pl.concat_arr(
+                pl.Series(
+                    [{"x": [1], "y": [2]}, {"x": [3], "y": None}],
+                    dtype=pl.Struct({"x": pl.Array(pl.Int64, 1)}),
+                ),
+                pl.lit(
+                    {"x": [9], "y": [11]}, dtype=pl.Struct({"x": pl.Array(pl.Int64, 1)})
+                ),
+            )
+        ).to_series(),
+        pl.Series(
+            [
+                [{"x": [1], "y": [2]}, {"x": [9], "y": [11]}],
+                [{"x": [3], "y": [4]}, {"x": [9], "y": [11]}],
+            ],
+            dtype=pl.Array(pl.Struct({"x": pl.Array(pl.Int64, 1)}), 2),
+        ),
+    )
+
+
+@pytest.mark.parametrize("inner_dtype", [pl.Int64(), pl.Null()])
+def test_concat_arr_validity_combination(inner_dtype: pl.DataType) -> None:
+    assert_series_equal(
+        pl.select(
+            pl.concat_arr(
+                pl.Series([[], [], None, None], dtype=pl.Array(inner_dtype, 0)),
+                pl.Series([[], [], None, None], dtype=pl.Array(inner_dtype, 0)),
+                pl.Series([[None], None, [None], None], dtype=pl.Array(inner_dtype, 1)),
+            ),
+        ).to_series(),
+        pl.Series([[None], None, None, None], dtype=pl.Array(inner_dtype, 1)),
+    )
+
+    assert_series_equal(
+        pl.select(
+            pl.concat_arr(
+                pl.Series([None, None], dtype=inner_dtype),
+                pl.Series([[None], None], dtype=pl.Array(inner_dtype, 1)),
+            ),
+        ).to_series(),
+        pl.Series([[None, None], None], dtype=pl.Array(inner_dtype, 2)),
+    )
+
+
+def test_concat_arr_zero_fields() -> None:
+    assert_series_equal(
+        (
+            pl.Series([[[]], [None]], dtype=pl.Array(pl.Array(pl.Int64, 0), 1))
+            .to_frame()
+            .select(pl.concat_arr(pl.first(), pl.first()))
+            .to_series()
+        ),
+        pl.Series([[[], []], [None, None]], dtype=pl.Array(pl.Array(pl.Int64, 0), 2)),
+    )
+
+    assert_series_equal(
+        (
+            pl.Series([[{}], [None]], dtype=pl.Array(pl.Struct({}), 1))
+            .to_frame()
+            .select(pl.concat_arr(pl.first(), pl.first()))
+            .to_series()
+        ),
+        pl.Series([[{}, {}], [None, None]], dtype=pl.Array(pl.Struct({}), 2)),
+    )
+
+    assert_series_equal(
+        (
+            pl.Series(
+                [[{"x": []}], [{"x": None}], [None]],
+                dtype=pl.Array(pl.Struct({"x": pl.Array(pl.Int64, 0)}), 1),
+            )
+            .to_frame()
+            .select(pl.concat_arr(pl.first(), pl.first()))
+            .to_series()
+        ),
+        pl.Series(
+            [[{"x": []}, {"x": []}], [{"x": None}, {"x": None}], [None, None]],
+            dtype=pl.Array(pl.Struct({"x": pl.Array(pl.Int64, 0)}), 2),
+        ),
+    )