From 9ca0d44ab3d56bd8ed16e2baba93c14da2d2fcd2 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Wed, 17 Jul 2024 13:34:07 +0800 Subject: [PATCH] Remove element's nullability of array_agg function (#11447) * rm null Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * fix test Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion/core/tests/sql/aggregates.rs | 2 +- .../physical-expr/src/aggregate/array_agg.rs | 23 +++--------- .../src/aggregate/array_agg_distinct.rs | 23 +++--------- .../src/aggregate/array_agg_ordered.rs | 37 +++++-------------- .../physical-expr/src/aggregate/build_in.rs | 12 +----- .../physical-plan/src/aggregates/mod.rs | 1 - 6 files changed, 23 insertions(+), 75 deletions(-) diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index 86032dc9bc963..1f4f9e77d5dc5 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -36,7 +36,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> { *actual[0].schema(), Schema::new(vec![Field::new_list( "ARRAY_AGG(DISTINCT aggregate_test_100.c2)", - Field::new("item", DataType::UInt32, false), + Field::new("item", DataType::UInt32, true), true ),]) ); diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index 38a9738029335..0d5ed730e2834 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -24,7 +24,7 @@ use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field}; use arrow_array::Array; use datafusion_common::cast::as_list_array; -use datafusion_common::utils::array_into_list_array; +use datafusion_common::utils::array_into_list_array_nullable; use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::Accumulator; @@ -40,8 +40,6 @@ pub struct ArrayAgg { input_data_type: DataType, /// The input expression expr: Arc, - /// If the input expression can have NULLs - nullable: bool, } impl ArrayAgg { @@ -50,13 +48,11 @@ impl ArrayAgg { expr: Arc, name: impl Into, data_type: DataType, - nullable: bool, ) -> Self { Self { name: name.into(), input_data_type: data_type, expr, - nullable, } } } @@ -70,7 +66,7 @@ impl AggregateExpr for ArrayAgg { Ok(Field::new_list( &self.name, // This should be the same as return type of AggregateFunction::ArrayAgg - Field::new("item", self.input_data_type.clone(), self.nullable), + Field::new("item", self.input_data_type.clone(), true), true, )) } @@ -78,14 +74,13 @@ impl AggregateExpr for ArrayAgg { fn create_accumulator(&self) -> Result> { Ok(Box::new(ArrayAggAccumulator::try_new( &self.input_data_type, - self.nullable, )?)) } fn state_fields(&self) -> Result> { Ok(vec![Field::new_list( format_state_name(&self.name, "array_agg"), - Field::new("item", self.input_data_type.clone(), self.nullable), + Field::new("item", self.input_data_type.clone(), true), true, )]) } @@ -116,16 +111,14 @@ impl PartialEq for ArrayAgg { pub(crate) struct ArrayAggAccumulator { values: Vec, datatype: DataType, - nullable: bool, } impl ArrayAggAccumulator { /// new array_agg accumulator based on given item data type - pub fn try_new(datatype: &DataType, nullable: bool) -> Result { + pub fn try_new(datatype: &DataType) -> Result { Ok(Self { values: vec![], datatype: datatype.clone(), - nullable, }) } } @@ -169,15 +162,11 @@ impl Accumulator for ArrayAggAccumulator { self.values.iter().map(|a| a.as_ref()).collect(); if element_arrays.is_empty() { - return Ok(ScalarValue::new_null_list( - self.datatype.clone(), - self.nullable, - 1, - )); + return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1)); } let concated_array = arrow::compute::concat(&element_arrays)?; - let list_array = array_into_list_array(concated_array, self.nullable); + let list_array = array_into_list_array_nullable(concated_array); Ok(ScalarValue::List(Arc::new(list_array))) } diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index 368d11d7421ab..eca6e4ce4f656 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -42,8 +42,6 @@ pub struct DistinctArrayAgg { input_data_type: DataType, /// The input expression expr: Arc, - /// If the input expression can have NULLs - nullable: bool, } impl DistinctArrayAgg { @@ -52,14 +50,12 @@ impl DistinctArrayAgg { expr: Arc, name: impl Into, input_data_type: DataType, - nullable: bool, ) -> Self { let name = name.into(); Self { name, input_data_type, expr, - nullable, } } } @@ -74,7 +70,7 @@ impl AggregateExpr for DistinctArrayAgg { Ok(Field::new_list( &self.name, // This should be the same as return type of AggregateFunction::ArrayAgg - Field::new("item", self.input_data_type.clone(), self.nullable), + Field::new("item", self.input_data_type.clone(), true), true, )) } @@ -82,14 +78,13 @@ impl AggregateExpr for DistinctArrayAgg { fn create_accumulator(&self) -> Result> { Ok(Box::new(DistinctArrayAggAccumulator::try_new( &self.input_data_type, - self.nullable, )?)) } fn state_fields(&self) -> Result> { Ok(vec![Field::new_list( format_state_name(&self.name, "distinct_array_agg"), - Field::new("item", self.input_data_type.clone(), self.nullable), + Field::new("item", self.input_data_type.clone(), true), true, )]) } @@ -120,15 +115,13 @@ impl PartialEq for DistinctArrayAgg { struct DistinctArrayAggAccumulator { values: HashSet, datatype: DataType, - nullable: bool, } impl DistinctArrayAggAccumulator { - pub fn try_new(datatype: &DataType, nullable: bool) -> Result { + pub fn try_new(datatype: &DataType) -> Result { Ok(Self { values: HashSet::new(), datatype: datatype.clone(), - nullable, }) } } @@ -166,13 +159,9 @@ impl Accumulator for DistinctArrayAggAccumulator { fn evaluate(&mut self) -> Result { let values: Vec = self.values.iter().cloned().collect(); if values.is_empty() { - return Ok(ScalarValue::new_null_list( - self.datatype.clone(), - self.nullable, - 1, - )); + return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1)); } - let arr = ScalarValue::new_list(&values, &self.datatype, self.nullable); + let arr = ScalarValue::new_list(&values, &self.datatype, true); Ok(ScalarValue::List(arr)) } @@ -255,7 +244,6 @@ mod tests { col("a", &schema)?, "bla".to_string(), datatype, - true, )); let actual = aggregate(&batch, agg)?; compare_list_contents(expected, actual) @@ -272,7 +260,6 @@ mod tests { col("a", &schema)?, "bla".to_string(), datatype, - true, )); let mut accum1 = agg.create_accumulator()?; diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs index d44811192f667..992c06f5bf628 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs @@ -33,7 +33,7 @@ use arrow::datatypes::{DataType, Field}; use arrow_array::cast::AsArray; use arrow_array::{new_empty_array, Array, ArrayRef, StructArray}; use arrow_schema::Fields; -use datafusion_common::utils::{array_into_list_array, get_row_at_idx}; +use datafusion_common::utils::{array_into_list_array_nullable, get_row_at_idx}; use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::utils::AggregateOrderSensitivity; use datafusion_expr::Accumulator; @@ -50,8 +50,6 @@ pub struct OrderSensitiveArrayAgg { input_data_type: DataType, /// The input expression expr: Arc, - /// If the input expression can have `NULL`s - nullable: bool, /// Ordering data types order_by_data_types: Vec, /// Ordering requirement @@ -66,7 +64,6 @@ impl OrderSensitiveArrayAgg { expr: Arc, name: impl Into, input_data_type: DataType, - nullable: bool, order_by_data_types: Vec, ordering_req: LexOrdering, ) -> Self { @@ -74,7 +71,6 @@ impl OrderSensitiveArrayAgg { name: name.into(), input_data_type, expr, - nullable, order_by_data_types, ordering_req, reverse: false, @@ -90,8 +86,8 @@ impl AggregateExpr for OrderSensitiveArrayAgg { fn field(&self) -> Result { Ok(Field::new_list( &self.name, - // This should be the same as return type of AggregateFunction::ArrayAgg - Field::new("item", self.input_data_type.clone(), self.nullable), + // This should be the same as return type of AggregateFunction::OrderSensitiveArrayAgg + Field::new("item", self.input_data_type.clone(), true), true, )) } @@ -102,7 +98,6 @@ impl AggregateExpr for OrderSensitiveArrayAgg { &self.order_by_data_types, self.ordering_req.clone(), self.reverse, - self.nullable, ) .map(|acc| Box::new(acc) as _) } @@ -110,17 +105,13 @@ impl AggregateExpr for OrderSensitiveArrayAgg { fn state_fields(&self) -> Result> { let mut fields = vec![Field::new_list( format_state_name(&self.name, "array_agg"), - Field::new("item", self.input_data_type.clone(), self.nullable), + Field::new("item", self.input_data_type.clone(), true), true, // This should be the same as field() )]; let orderings = ordering_fields(&self.ordering_req, &self.order_by_data_types); fields.push(Field::new_list( format_state_name(&self.name, "array_agg_orderings"), - Field::new( - "item", - DataType::Struct(Fields::from(orderings)), - self.nullable, - ), + Field::new("item", DataType::Struct(Fields::from(orderings)), true), false, )); Ok(fields) @@ -147,7 +138,6 @@ impl AggregateExpr for OrderSensitiveArrayAgg { name: self.name.to_string(), input_data_type: self.input_data_type.clone(), expr: Arc::clone(&self.expr), - nullable: self.nullable, order_by_data_types: self.order_by_data_types.clone(), // Reverse requirement: ordering_req: reverse_order_bys(&self.ordering_req), @@ -186,8 +176,6 @@ pub(crate) struct OrderSensitiveArrayAggAccumulator { ordering_req: LexOrdering, /// Whether the aggregation is running in reverse. reverse: bool, - /// Whether the input expr is nullable - nullable: bool, } impl OrderSensitiveArrayAggAccumulator { @@ -198,7 +186,6 @@ impl OrderSensitiveArrayAggAccumulator { ordering_dtypes: &[DataType], ordering_req: LexOrdering, reverse: bool, - nullable: bool, ) -> Result { let mut datatypes = vec![datatype.clone()]; datatypes.extend(ordering_dtypes.iter().cloned()); @@ -208,7 +195,6 @@ impl OrderSensitiveArrayAggAccumulator { datatypes, ordering_req, reverse, - nullable, }) } } @@ -312,7 +298,7 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { if self.values.is_empty() { return Ok(ScalarValue::new_null_list( self.datatypes[0].clone(), - self.nullable, + true, 1, )); } @@ -322,14 +308,10 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { ScalarValue::new_list_from_iter( values.into_iter().rev(), &self.datatypes[0], - self.nullable, + true, ) } else { - ScalarValue::new_list_from_iter( - values.into_iter(), - &self.datatypes[0], - self.nullable, - ) + ScalarValue::new_list_from_iter(values.into_iter(), &self.datatypes[0], true) }; Ok(ScalarValue::List(array)) } @@ -385,9 +367,8 @@ impl OrderSensitiveArrayAggAccumulator { column_wise_ordering_values, None, )?; - Ok(ScalarValue::List(Arc::new(array_into_list_array( + Ok(ScalarValue::List(Arc::new(array_into_list_array_nullable( Arc::new(ordering_array), - self.nullable, )))) } } diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 68c9b4859f1f8..ef21b3d0f7883 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -62,16 +62,14 @@ pub fn create_aggregate_expr( Ok(match (fun, distinct) { (AggregateFunction::ArrayAgg, false) => { let expr = Arc::clone(&input_phy_exprs[0]); - let nullable = expr.nullable(input_schema)?; if ordering_req.is_empty() { - Arc::new(expressions::ArrayAgg::new(expr, name, data_type, nullable)) + Arc::new(expressions::ArrayAgg::new(expr, name, data_type)) } else { Arc::new(expressions::OrderSensitiveArrayAgg::new( expr, name, data_type, - nullable, ordering_types, ordering_req.to_vec(), )) @@ -84,13 +82,7 @@ pub fn create_aggregate_expr( ); } let expr = Arc::clone(&input_phy_exprs[0]); - let is_expr_nullable = expr.nullable(input_schema)?; - Arc::new(expressions::DistinctArrayAgg::new( - expr, - name, - data_type, - is_expr_nullable, - )) + Arc::new(expressions::DistinctArrayAgg::new(expr, name, data_type)) } (AggregateFunction::Min, _) => Arc::new(expressions::Min::new( Arc::clone(&input_phy_exprs[0]), diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 8bf808af3b5b8..5f780f1ff8019 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -2231,7 +2231,6 @@ mod tests { Arc::clone(col_a), "array_agg", DataType::Int32, - false, vec![], order_by_expr.unwrap_or_default(), )) as _