diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index d730fbf89b72..f18db2f31b14 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -2039,7 +2039,17 @@ impl ScalarValue { } } - /// Retrieve ScalarValue for each row in `array` + /// Retrieve `ScalarValue` for each row in `array` + /// + /// Convert `ListArray` into a 2 dimensional to `Vec>`, first `Vec` is for rows, + /// second `Vec` is for elements in the list. + /// + /// See [`Self::convert_non_list_array_to_scalars`] for converting non Lists + /// + /// This method is an optimization to unwrap nested ListArrays to nested Rust structures without + /// converting them twice + /// + /// Return `Err` if `array` is not `ListArray` /// /// Example /// ``` @@ -2053,7 +2063,7 @@ impl ScalarValue { /// Some(vec![Some(4), Some(5)]) /// ]); /// - /// let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&list_arr).unwrap(); + /// let scalar_vec = ScalarValue::convert_list_array_to_scalar_vec::(&list_arr).unwrap(); /// /// let expected = vec![ /// vec![ @@ -2067,30 +2077,78 @@ impl ScalarValue { /// /// assert_eq!(scalar_vec, expected); /// ``` - pub fn convert_array_to_scalar_vec(array: &dyn Array) -> Result>> { - let mut scalars = Vec::with_capacity(array.len()); + pub fn convert_list_array_to_scalar_vec( + array: &dyn Array, + ) -> Result>> { + if array.as_list_opt::().is_some() { + Self::convert_list_array_to_scalar_vec_internal::(array) + } else { + _internal_err!("Expected GenericListArray but found: {array:?}") + } + } - for index in 0..array.len() { - let scalar_values = match array.data_type() { - DataType::List(_) => { - let list_array = as_list_array(array); - match list_array.is_null(index) { - true => Vec::new(), - false => { - let nested_array = list_array.value(index); - ScalarValue::convert_array_to_scalar_vec(&nested_array)? - .into_iter() - .flatten() - .collect() - } + fn convert_list_array_to_scalar_vec_internal( + array: &dyn Array, + ) -> Result>> { + let mut scalars_vec = Vec::with_capacity(array.len()); + + if let Some(list_arr) = array.as_list_opt::() { + for index in 0..list_arr.len() { + let scalars = match list_arr.is_null(index) { + true => Vec::new(), + false => { + let nested_array = list_arr.value(index); + Self::convert_list_array_to_scalar_vec_internal::( + &nested_array, + )? + .into_iter() + .flatten() + .collect() } - } - _ => { - let scalar = ScalarValue::try_from_array(array, index)?; - vec![scalar] - } - }; - scalars.push(scalar_values); + }; + scalars_vec.push(scalars); + } + } else { + let scalars = ScalarValue::convert_non_list_array_to_scalars(array)?; + scalars_vec.push(scalars); + } + + Ok(scalars_vec) + } + + /// Convert non-ListArray to `Vec` + /// + /// Return Err if `array` is ListArray + /// + /// Example + /// ``` + /// use datafusion_common::ScalarValue; + /// use arrow::array::Int32Array; + /// + /// let list_arr = Int32Array::from(vec![Some(1), Some(2), Some(3), None, Some(4), Some(5)]); + /// + /// let scalar_vec = ScalarValue::convert_non_list_array_to_scalars(&list_arr).unwrap(); + /// + /// let expected = vec![ + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(Some(2)), + /// ScalarValue::Int32(Some(3)), + /// ScalarValue::Int32(None), + /// ScalarValue::Int32(Some(4)), + /// ScalarValue::Int32(Some(5)), + /// ]; + /// + /// assert_eq!(scalar_vec, expected); + /// ``` + pub fn convert_non_list_array_to_scalars(array: &dyn Array) -> Result> { + if array.as_list_opt::().is_some() || array.as_list_opt::().is_some() { + return _internal_err!("Expected non ListArray but found: {array:?}"); + } + + let mut scalars = Vec::with_capacity(array.len()); + for index in 0..array.len() { + let scalar = ScalarValue::try_from_array(array, index)?; + scalars.push(scalar); } Ok(scalars) } @@ -3129,6 +3187,44 @@ mod tests { use arrow_array::ArrowNumericType; use crate::cast::{as_string_array, as_uint32_array, as_uint64_array}; + use crate::utils::arrays_into_list_array; + + #[test] + fn convert_list_array_to_scalar_vec_nested() { + let l1 = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(4), Some(5)]), + ]); + + let l2 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(6), + Some(7), + Some(8), + ])]); + + let l1 = Arc::new(l1) as ArrayRef; + let l2 = Arc::new(l2) as ArrayRef; + + let l12 = arrays_into_list_array([l1, l2]).unwrap(); + let arr = Arc::new(l12) as ArrayRef; + + let actual = ScalarValue::convert_list_array_to_scalar_vec::(&arr).unwrap(); + let expected = vec![ + vec![ + ScalarValue::Int32(Some(1)), + ScalarValue::Int32(Some(2)), + ScalarValue::Int32(Some(3)), + ScalarValue::Int32(Some(4)), + ScalarValue::Int32(Some(5)), + ], + vec![ + ScalarValue::Int32(Some(6)), + ScalarValue::Int32(Some(7)), + ScalarValue::Int32(Some(8)), + ], + ]; + assert_eq!(actual, expected); + } #[test] fn test_to_array_of_size_for_list() { @@ -3188,6 +3284,7 @@ mod tests { "arrow", "data-fusion", ]))); + let result = as_list_array(&array); assert_eq!(result, &expected); } diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index af6d0d5f4e24..b30dffb50f2c 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -45,13 +45,17 @@ async fn csv_query_array_agg_distinct() -> Result<()> { let column = actual[0].column(0); assert_eq!(column.len(), 1); - let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&column)?; - let mut scalars = scalar_vec[0].clone(); + // 1 row + let scalar_vec = ScalarValue::convert_list_array_to_scalar_vec::(&column)?; + let mut scalars = scalar_vec.first().unwrap().to_owned(); + // workaround lack of Ord of ScalarValue let cmp = |a: &ScalarValue, b: &ScalarValue| { a.partial_cmp(b).expect("Can compare ScalarValues") }; + scalars.sort_by(cmp); + assert_eq!( scalars, vec![ diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index 1efae424cc69..cbf2b713074f 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -137,9 +137,25 @@ impl Accumulator for DistinctArrayAggAccumulator { assert_eq!(values.len(), 1, "batch input should only include 1 column!"); let array = &values[0]; - let scalars = ScalarValue::convert_array_to_scalar_vec(array)?; - for scalar in scalars { - self.values.extend(scalar) + match array.data_type() { + DataType::List(_) => { + let scalar_vec = + ScalarValue::convert_list_array_to_scalar_vec::(array)?; + for scalars in scalar_vec { + self.values.extend(scalars); + } + } + DataType::LargeList(_) => { + let scalar_vec = + ScalarValue::convert_list_array_to_scalar_vec::(array)?; + for scalars in scalar_vec { + self.values.extend(scalars); + } + } + _ => { + let scalars = ScalarValue::convert_non_list_array_to_scalars(array)?; + self.values.extend(scalars); + } } Ok(()) } @@ -149,18 +165,7 @@ impl Accumulator for DistinctArrayAggAccumulator { return Ok(()); } - assert_eq!( - states.len(), - 1, - "array_agg_distinct states must contain single array" - ); - - let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&states[0])?; - for scalars in scalar_vec { - self.values.extend(scalars) - } - - Ok(()) + self.update_batch(states) } fn evaluate(&self) -> Result { diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs index eb5ae8b0b0c3..4956c4901da6 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs @@ -225,13 +225,14 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { partition_ordering_values.push(self.ordering_values.clone()); let array_agg_res = - ScalarValue::convert_array_to_scalar_vec(array_agg_values)?; + ScalarValue::convert_list_array_to_scalar_vec::(array_agg_values)?; for v in array_agg_res.into_iter() { partition_values.push(v); } - let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?; + let orderings = + ScalarValue::convert_list_array_to_scalar_vec::(agg_orderings)?; for partition_ordering_rows in orderings.into_iter() { // Extract value from struct to ordering_rows for each group/partition diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index f5242d983d4c..d2a18869ae99 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -167,7 +167,8 @@ impl Accumulator for DistinctCountAccumulator { return Ok(()); } assert_eq!(states.len(), 1, "array_agg states must be singleton!"); - let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&states[0])?; + let scalar_vec = + ScalarValue::convert_list_array_to_scalar_vec::(&states[0])?; for scalars in scalar_vec.into_iter() { self.values.extend(scalars) }