diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs index 4852dfe9ed68b..b039cb651d09f 100644 --- a/datafusion/src/physical_plan/hash_aggregate.rs +++ b/datafusion/src/physical_plan/hash_aggregate.rs @@ -35,7 +35,7 @@ use crate::physical_plan::{ use crate::scalar::ScalarValue; use arrow::{ - array::{Array, Date64Array, UInt32Builder}, + array::{Array, UInt32Builder}, error::{ArrowError, Result as ArrowResult}, }; use arrow::{ @@ -969,80 +969,61 @@ fn create_batch_from_map( if accumulators.is_empty() { return Ok(RecordBatch::new_empty(Arc::new(output_schema.to_owned()))); } - // 1. for each key - // 2. create single-row ArrayRef with all group expressions - // 3. create single-row ArrayRef with all aggregate states or values - // 4. collect all in a vector per key of vec, vec[i][j] - // 5. concatenate the arrays over the second index [j] into a single vec. - let (_, (group_bys, accs, _)) = accumulators.iter().nth(0).unwrap(); - let group_by_data_types: Vec = group_bys - .iter() - .map(|x| ScalarValue::from(x).get_datatype()) - .collect(); - let mut acc_data_types: Vec> = vec![]; + let (_, (_, accs, _)) = accumulators.iter().nth(0).unwrap(); + let mut acc_data_types: Vec = vec![]; + + // Calculate number/shape of state arrays match mode { AggregateMode::Partial => { for acc in accs.iter() { let state = acc .state() .map_err(DataFusionError::into_arrow_external_error)?; - acc_data_types - .push(state.iter().map(ScalarValue::get_datatype).collect()); + acc_data_types.push(state.len()); } } AggregateMode::Final | AggregateMode::FinalPartitioned => { - for acc in accs { - acc_data_types.push(vec![acc - .evaluate() - .map_err(DataFusionError::into_arrow_external_error)? - .get_datatype()]); + for _ in accs { + acc_data_types.push(1); } } } - let mut arrays = (0..num_group_expr) + let mut columns = (0..num_group_expr) .map(|i| { - scalar_to_array( - group_by_data_types[i].clone(), - accumulators.iter().map(|(_, (group_by_values, _, _))| { - ScalarValue::from(&group_by_values[i]) - }), - ) + ScalarValue::iter_to_array(accumulators.into_iter().map( + |(_, (group_by_values, _, _))| ScalarValue::from(&group_by_values[i]), + )) }) .collect::>>() .map_err(|x| x.into_arrow_external_error())?; // add state / evaluated arrays - for (x, state_dt) in acc_data_types.iter().enumerate() { - for y in 0..state_dt.len() { + for (x, &state_len) in acc_data_types.iter().enumerate() { + for y in 0..state_len { match mode { AggregateMode::Partial => { - let res = scalar_to_array( - state_dt[y].clone(), - accumulators.iter().map(|(_, (_, accumulator, _))| { - accumulator[x].state().unwrap()[y].clone() - }), - ) + let res = ScalarValue::iter_to_array(accumulators.into_iter().map( + |(_, (_, accumulator, _))| { + let x = accumulator[x].state().unwrap(); + x[y].clone() + }, + )) .map_err(DataFusionError::into_arrow_external_error)?; - arrays.push(res); + + columns.push(res); } AggregateMode::Final | AggregateMode::FinalPartitioned => { - let res = scalar_to_array( - state_dt[y].clone(), - accumulators.iter().map(|(_, (_, accumulator, _))| { - accumulator[x].evaluate().unwrap() - }), - ) + let res = ScalarValue::iter_to_array(accumulators.into_iter().map( + |(_, (_, accumulator, _))| accumulator[x].evaluate().unwrap(), + )) .map_err(DataFusionError::into_arrow_external_error)?; - arrays.push(res); + columns.push(res); } } } } - // 5. - let columns = arrays; - // cast output if needed (e.g. for types like Dictionary where // the intermediate GroupByScalar type was not the same as the // output @@ -1064,194 +1045,6 @@ fn create_accumulators( .collect::>>() } -fn scalar_to_array( - data_type: DataType, - iter: impl IntoIterator, -) -> Result { - match data_type { - DataType::Int8 => Ok(Arc::new( - iter.into_iter() - .map(|x| match x { - ScalarValue::Int8(val) => val, - v => panic!("Unexpected type in scalar_to_array {:?}", v), - }) - .collect::(), - )), - DataType::Int16 => Ok(Arc::new( - iter.into_iter() - .map(|x| match x { - ScalarValue::Int16(val) => val, - v => panic!("Unexpected type in scalar_to_array {:?}", v), - }) - .collect::(), - )), - DataType::Int32 => Ok(Arc::new( - iter.into_iter() - .map(|x| match x { - ScalarValue::Int32(val) => val, - v => panic!("Unexpected type in scalar_to_array {:?}", v), - }) - .collect::(), - )), - DataType::Int64 => Ok(Arc::new( - iter.into_iter() - .map(|x| match x { - ScalarValue::Int64(val) => val, - v => panic!("Unexpected type in scalar_to_array {:?}", v), - }) - .collect::(), - )), - DataType::Boolean => Ok(Arc::new( - iter.into_iter() - .map(|x| match x { - ScalarValue::Boolean(val) => val, - v => panic!("Unexpected type in scalar_to_array {:?}", v), - }) - .collect::(), - )), - DataType::UInt8 => Ok(Arc::new( - iter.into_iter() - .map(|x| match x { - ScalarValue::UInt8(val) => val, - v => panic!("Unexpected type in scalar_to_array {:?}", v), - }) - .collect::(), - )), - DataType::UInt16 => Ok(Arc::new( - iter.into_iter() - .map(|x| match x { - ScalarValue::UInt16(val) => val, - v => panic!("Unexpected type in scalar_to_array {:?}", v), - }) - .collect::(), - )), - DataType::UInt32 => Ok(Arc::new( - iter.into_iter() - .map(|x| match x { - ScalarValue::UInt32(val) => val, - v => panic!("Unexpected type in scalar_to_array {:?}", v), - }) - .collect::(), - )), - DataType::UInt64 => Ok(Arc::new( - iter.into_iter() - .map(|x| match x { - ScalarValue::UInt64(val) => val, - v => panic!("Unexpected type in scalar_to_array {:?}", v), - }) - .collect::(), - )), - DataType::Float32 => Ok(Arc::new( - iter.into_iter() - .map(|x| match x { - ScalarValue::Float32(val) => val, - v => panic!("Unexpected type in scalar_to_array {:?}", v), - }) - .collect::(), - )), - DataType::Float64 => Ok(Arc::new( - iter.into_iter() - .map(|x| match x { - ScalarValue::Float64(val) => val, - v => panic!("Unexpected type in scalar_to_array {:?}", v), - }) - .collect::(), - )), - DataType::Date32 => Ok(Arc::new( - iter.into_iter() - .map(|x| match x { - ScalarValue::Date32(val) => val, - v => panic!("Unexpected type in scalar_to_array {:?}", v), - }) - .collect::(), - )), - DataType::Date64 => Ok(Arc::new( - iter.into_iter() - .map(|x| match x { - ScalarValue::Date64(val) => val, - v => panic!("Unexpected type in scalar_to_array {:?}", v), - }) - .collect::(), - )), - DataType::Timestamp(TimeUnit::Millisecond, None) => Ok(Arc::new( - iter.into_iter() - .map(|x| match x { - ScalarValue::TimestampMillisecond(val) => val, - v => panic!("Unexpected type in scalar_to_array {:?}", v), - }) - .collect::(), - )), - // DataType::Time32(_) => {} - // DataType::Time64(_) => {} - // DataType::Duration(_) => {} - // DataType::Interval(_) => {} - // DataType::Binary => {} - // DataType::FixedSizeBinary(_) => {} - // DataType::LargeBinary => {} - DataType::Utf8 => Ok(Arc::new( - iter.into_iter() - .map(|x| match x { - ScalarValue::Utf8(val) => val, - v => panic!("Unexpected type in scalar_to_array {:?}", v), - }) - .collect::(), - )), - DataType::LargeUtf8 => Ok(Arc::new( - iter.into_iter() - .map(|x| match x { - ScalarValue::LargeUtf8(val) => val, - v => panic!("Unexpected type in scalar_to_array {:?}", v), - }) - .collect::(), - )), - DataType::Dictionary(key_type, val_type) => { - // Construct array first and then the dictionary array - // TODO (perf): construct dictionary array in one go - let x = scalar_to_array(*val_type, iter)?; - return match *key_type { - DataType::Int8 => Ok(Arc::new(DictionaryArray::::from( - x.data().clone(), - ))), - DataType::Int16 => Ok(Arc::new(DictionaryArray::::from( - x.data().clone(), - ))), - DataType::Int32 => Ok(Arc::new(DictionaryArray::::from( - x.data().clone(), - ))), - DataType::Int64 => Ok(Arc::new(DictionaryArray::::from( - x.data().clone(), - ))), - DataType::UInt8 => Ok(Arc::new(DictionaryArray::::from( - x.data().clone(), - ))), - DataType::UInt16 => Ok(Arc::new(DictionaryArray::::from( - x.data().clone(), - ))), - DataType::UInt32 => Ok(Arc::new(DictionaryArray::::from( - x.data().clone(), - ))), - DataType::UInt64 => Ok(Arc::new(DictionaryArray::::from( - x.data().clone(), - ))), - _ => Err(DataFusionError::NotImplemented(format!( - "Key type {:?} not supported in scalar_to_array", - key_type - ))), - }; - } - - // DataType::FixedSizeList(_, _) => {} - // DataType::LargeList(_) => {} - // DataType::Struct(_) => {} - // DataType::Union(_) => {} - // DataType::Decimal(_, _) => {} - data_type => Err(DataFusionError::NotImplemented(format!( - "to_array not implemented for {}", - data_type - ))), - } -} - /// returns a vector of ArrayRefs, where each entry corresponds to either the /// final value (mode = Final) or states (mode = Partial) fn finalize_aggregation( diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index f3fa5b2c5de5c..1374b83766c90 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -311,7 +311,7 @@ impl ScalarValue { /// ]; /// /// // Build an Array from the list of ScalarValues - /// let array = ScalarValue::iter_to_array(scalars.iter()) + /// let array = ScalarValue::iter_to_array(scalars.into_iter()) /// .unwrap(); /// /// let expected: ArrayRef = std::sync::Arc::new( @@ -324,8 +324,8 @@ impl ScalarValue { /// /// assert_eq!(&array, &expected); /// ``` - pub fn iter_to_array<'a>( - scalars: impl IntoIterator, + pub fn iter_to_array( + scalars: impl IntoIterator, ) -> Result { let mut scalars = scalars.into_iter().peekable(); @@ -347,7 +347,7 @@ impl ScalarValue { let values = scalars .map(|sv| { if let ScalarValue::$SCALAR_TY(v) = sv { - Ok(*v) + Ok(v) } else { Err(DataFusionError::Internal(format!( "Inconsistent types in ScalarValue::iter_to_array. \ @@ -394,6 +394,24 @@ impl ScalarValue { }}; } + macro_rules! build_array_list { + ($ARRAY_TY:ident, $SCALAR_TY:ident, $NATIVE_TYPE:ident) => {{ + Arc::new(ListArray::from_iter_primitive::<$ARRAY_TY, _, _>( + scalars.into_iter().map(|x| match x { + ScalarValue::List(xs, _) => xs.map(|x| { + x.iter() + .map(|x| match x { + ScalarValue::$SCALAR_TY(i) => *i, + _ => panic!("xxx"), + }) + .collect::>>() + }), + _ => panic!("xxx"), + }), + )) + }}; + } + let array: ArrayRef = match &data_type { DataType::Boolean => build_array_primitive!(BooleanArray, Boolean), DataType::Float32 => build_array_primitive!(Float32Array, Float32), @@ -430,6 +448,30 @@ impl ScalarValue { DataType::Interval(IntervalUnit::YearMonth) => { build_array_primitive!(IntervalYearMonthArray, IntervalYearMonth) } + DataType::List(fields) if fields.data_type() == &DataType::Int8 => { + build_array_list!(Int8Type, Int8, i8) + } + DataType::List(fields) if fields.data_type() == &DataType::Int16 => { + build_array_list!(Int16Type, Int16, i16) + } + DataType::List(fields) if fields.data_type() == &DataType::Int32 => { + build_array_list!(Int32Type, Int32, i32) + } + DataType::List(fields) if fields.data_type() == &DataType::Int64 => { + build_array_list!(Int64Type, Int64, i64) + } + DataType::List(fields) if fields.data_type() == &DataType::UInt8 => { + build_array_list!(UInt8Type, UInt8, u8) + } + DataType::List(fields) if fields.data_type() == &DataType::UInt16 => { + build_array_list!(UInt16Type, UInt16, u16) + } + DataType::List(fields) if fields.data_type() == &DataType::UInt32 => { + build_array_list!(UInt32Type, UInt32, u32) + } + DataType::List(fields) if fields.data_type() == &DataType::UInt64 => { + build_array_list!(UInt64Type, UInt64, u64) + } _ => { return Err(DataFusionError::Internal(format!( "Unsupported creation of {:?} array from ScalarValue {:?}", @@ -1102,7 +1144,7 @@ mod tests { let scalars: Vec<_> = $INPUT.iter().map(|v| ScalarValue::$SCALAR_T(*v)).collect(); - let array = ScalarValue::iter_to_array(scalars.iter()).unwrap(); + let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); let expected: ArrayRef = Arc::new($ARRAYTYPE::from($INPUT)); @@ -1119,7 +1161,7 @@ mod tests { .map(|v| ScalarValue::$SCALAR_T(v.map(|v| v.to_string()))) .collect(); - let array = ScalarValue::iter_to_array(scalars.iter()).unwrap(); + let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); let expected: ArrayRef = Arc::new($ARRAYTYPE::from($INPUT)); @@ -1136,7 +1178,7 @@ mod tests { .map(|v| ScalarValue::$SCALAR_T(v.map(|v| v.to_vec()))) .collect(); - let array = ScalarValue::iter_to_array(scalars.iter()).unwrap(); + let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap(); let expected: $ARRAYTYPE = $INPUT.iter().map(|v| v.map(|v| v.to_vec())).collect(); @@ -1210,7 +1252,7 @@ mod tests { fn scalar_iter_to_array_empty() { let scalars = vec![] as Vec; - let result = ScalarValue::iter_to_array(scalars.iter()).unwrap_err(); + let result = ScalarValue::iter_to_array(scalars.into_iter()).unwrap_err(); assert!( result .to_string() @@ -1226,7 +1268,7 @@ mod tests { // If the scalar values are not all the correct type, error here let scalars: Vec = vec![Boolean(Some(true)), Int32(Some(5))]; - let result = ScalarValue::iter_to_array(scalars.iter()).unwrap_err(); + let result = ScalarValue::iter_to_array(scalars.into_iter()).unwrap_err(); assert!(result.to_string().contains("Inconsistent types in ScalarValue::iter_to_array. Expected Boolean, got Int32(5)"), "{}", result); }