Skip to content

Commit

Permalink
Support list
Browse files Browse the repository at this point in the history
  • Loading branch information
Dandandan committed May 24, 2021
1 parent 5ff1e73 commit 515a9bc
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 242 deletions.
259 changes: 26 additions & 233 deletions datafusion/src/physical_plan/hash_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use crate::physical_plan::{
use crate::scalar::ScalarValue;

use arrow::{
array::{Array, Date64Array, UInt32Builder},
array::{Array, UInt32Builder},
error::{ArrowError, Result as ArrowResult},
};
use arrow::{
Expand Down Expand Up @@ -969,80 +969,61 @@ fn create_batch_from_map(
if accumulators.is_empty() {
return Ok(RecordBatch::new_empty(Arc::new(output_schema.to_owned())));
}
// 1. for each key
// 2. create single-row ArrayRef with all group expressions
// 3. create single-row ArrayRef with all aggregate states or values
// 4. collect all in a vector per key of vec<ArrayRef>, vec[i][j]
// 5. concatenate the arrays over the second index [j] into a single vec<ArrayRef>.
let (_, (group_bys, accs, _)) = accumulators.iter().nth(0).unwrap();
let group_by_data_types: Vec<DataType> = group_bys
.iter()
.map(|x| ScalarValue::from(x).get_datatype())
.collect();
let mut acc_data_types: Vec<Vec<DataType>> = vec![];
let (_, (_, accs, _)) = accumulators.iter().nth(0).unwrap();
let mut acc_data_types: Vec<usize> = vec![];

// Calculate number/shape of state arrays
match mode {
AggregateMode::Partial => {
for acc in accs.iter() {
let state = acc
.state()
.map_err(DataFusionError::into_arrow_external_error)?;
acc_data_types
.push(state.iter().map(ScalarValue::get_datatype).collect());
acc_data_types.push(state.len());
}
}
AggregateMode::Final | AggregateMode::FinalPartitioned => {
for acc in accs {
acc_data_types.push(vec![acc
.evaluate()
.map_err(DataFusionError::into_arrow_external_error)?
.get_datatype()]);
for _ in accs {
acc_data_types.push(1);
}
}
}

let mut arrays = (0..num_group_expr)
let mut columns = (0..num_group_expr)
.map(|i| {
scalar_to_array(
group_by_data_types[i].clone(),
accumulators.iter().map(|(_, (group_by_values, _, _))| {
ScalarValue::from(&group_by_values[i])
}),
)
ScalarValue::iter_to_array(accumulators.into_iter().map(
|(_, (group_by_values, _, _))| ScalarValue::from(&group_by_values[i]),
))
})
.collect::<Result<Vec<_>>>()
.map_err(|x| x.into_arrow_external_error())?;

// add state / evaluated arrays
for (x, state_dt) in acc_data_types.iter().enumerate() {
for y in 0..state_dt.len() {
for (x, &state_len) in acc_data_types.iter().enumerate() {
for y in 0..state_len {
match mode {
AggregateMode::Partial => {
let res = scalar_to_array(
state_dt[y].clone(),
accumulators.iter().map(|(_, (_, accumulator, _))| {
accumulator[x].state().unwrap()[y].clone()
}),
)
let res = ScalarValue::iter_to_array(accumulators.into_iter().map(
|(_, (_, accumulator, _))| {
let x = accumulator[x].state().unwrap();
x[y].clone()
},
))
.map_err(DataFusionError::into_arrow_external_error)?;
arrays.push(res);

columns.push(res);
}
AggregateMode::Final | AggregateMode::FinalPartitioned => {
let res = scalar_to_array(
state_dt[y].clone(),
accumulators.iter().map(|(_, (_, accumulator, _))| {
accumulator[x].evaluate().unwrap()
}),
)
let res = ScalarValue::iter_to_array(accumulators.into_iter().map(
|(_, (_, accumulator, _))| accumulator[x].evaluate().unwrap(),
))
.map_err(DataFusionError::into_arrow_external_error)?;
arrays.push(res);
columns.push(res);
}
}
}
}

// 5.
let columns = arrays;

// cast output if needed (e.g. for types like Dictionary where
// the intermediate GroupByScalar type was not the same as the
// output
Expand All @@ -1064,194 +1045,6 @@ fn create_accumulators(
.collect::<Result<Vec<_>>>()
}

fn scalar_to_array(
data_type: DataType,
iter: impl IntoIterator<Item = ScalarValue>,
) -> Result<ArrayRef> {
match data_type {
DataType::Int8 => Ok(Arc::new(
iter.into_iter()
.map(|x| match x {
ScalarValue::Int8(val) => val,
v => panic!("Unexpected type in scalar_to_array {:?}", v),
})
.collect::<Int8Array>(),
)),
DataType::Int16 => Ok(Arc::new(
iter.into_iter()
.map(|x| match x {
ScalarValue::Int16(val) => val,
v => panic!("Unexpected type in scalar_to_array {:?}", v),
})
.collect::<Int16Array>(),
)),
DataType::Int32 => Ok(Arc::new(
iter.into_iter()
.map(|x| match x {
ScalarValue::Int32(val) => val,
v => panic!("Unexpected type in scalar_to_array {:?}", v),
})
.collect::<Int32Array>(),
)),
DataType::Int64 => Ok(Arc::new(
iter.into_iter()
.map(|x| match x {
ScalarValue::Int64(val) => val,
v => panic!("Unexpected type in scalar_to_array {:?}", v),
})
.collect::<Int64Array>(),
)),
DataType::Boolean => Ok(Arc::new(
iter.into_iter()
.map(|x| match x {
ScalarValue::Boolean(val) => val,
v => panic!("Unexpected type in scalar_to_array {:?}", v),
})
.collect::<BooleanArray>(),
)),
DataType::UInt8 => Ok(Arc::new(
iter.into_iter()
.map(|x| match x {
ScalarValue::UInt8(val) => val,
v => panic!("Unexpected type in scalar_to_array {:?}", v),
})
.collect::<UInt8Array>(),
)),
DataType::UInt16 => Ok(Arc::new(
iter.into_iter()
.map(|x| match x {
ScalarValue::UInt16(val) => val,
v => panic!("Unexpected type in scalar_to_array {:?}", v),
})
.collect::<UInt16Array>(),
)),
DataType::UInt32 => Ok(Arc::new(
iter.into_iter()
.map(|x| match x {
ScalarValue::UInt32(val) => val,
v => panic!("Unexpected type in scalar_to_array {:?}", v),
})
.collect::<UInt32Array>(),
)),
DataType::UInt64 => Ok(Arc::new(
iter.into_iter()
.map(|x| match x {
ScalarValue::UInt64(val) => val,
v => panic!("Unexpected type in scalar_to_array {:?}", v),
})
.collect::<UInt64Array>(),
)),
DataType::Float32 => Ok(Arc::new(
iter.into_iter()
.map(|x| match x {
ScalarValue::Float32(val) => val,
v => panic!("Unexpected type in scalar_to_array {:?}", v),
})
.collect::<Float32Array>(),
)),
DataType::Float64 => Ok(Arc::new(
iter.into_iter()
.map(|x| match x {
ScalarValue::Float64(val) => val,
v => panic!("Unexpected type in scalar_to_array {:?}", v),
})
.collect::<Float64Array>(),
)),
DataType::Date32 => Ok(Arc::new(
iter.into_iter()
.map(|x| match x {
ScalarValue::Date32(val) => val,
v => panic!("Unexpected type in scalar_to_array {:?}", v),
})
.collect::<Date32Array>(),
)),
DataType::Date64 => Ok(Arc::new(
iter.into_iter()
.map(|x| match x {
ScalarValue::Date64(val) => val,
v => panic!("Unexpected type in scalar_to_array {:?}", v),
})
.collect::<Date64Array>(),
)),
DataType::Timestamp(TimeUnit::Millisecond, None) => Ok(Arc::new(
iter.into_iter()
.map(|x| match x {
ScalarValue::TimestampMillisecond(val) => val,
v => panic!("Unexpected type in scalar_to_array {:?}", v),
})
.collect::<TimestampMillisecondArray>(),
)),
// DataType::Time32(_) => {}
// DataType::Time64(_) => {}
// DataType::Duration(_) => {}
// DataType::Interval(_) => {}
// DataType::Binary => {}
// DataType::FixedSizeBinary(_) => {}
// DataType::LargeBinary => {}
DataType::Utf8 => Ok(Arc::new(
iter.into_iter()
.map(|x| match x {
ScalarValue::Utf8(val) => val,
v => panic!("Unexpected type in scalar_to_array {:?}", v),
})
.collect::<StringArray>(),
)),
DataType::LargeUtf8 => Ok(Arc::new(
iter.into_iter()
.map(|x| match x {
ScalarValue::LargeUtf8(val) => val,
v => panic!("Unexpected type in scalar_to_array {:?}", v),
})
.collect::<LargeStringArray>(),
)),
DataType::Dictionary(key_type, val_type) => {
// Construct array first and then the dictionary array
// TODO (perf): construct dictionary array in one go
let x = scalar_to_array(*val_type, iter)?;
return match *key_type {
DataType::Int8 => Ok(Arc::new(DictionaryArray::<Int8Type>::from(
x.data().clone(),
))),
DataType::Int16 => Ok(Arc::new(DictionaryArray::<Int16Type>::from(
x.data().clone(),
))),
DataType::Int32 => Ok(Arc::new(DictionaryArray::<Int32Type>::from(
x.data().clone(),
))),
DataType::Int64 => Ok(Arc::new(DictionaryArray::<Int64Type>::from(
x.data().clone(),
))),
DataType::UInt8 => Ok(Arc::new(DictionaryArray::<UInt8Type>::from(
x.data().clone(),
))),
DataType::UInt16 => Ok(Arc::new(DictionaryArray::<UInt16Type>::from(
x.data().clone(),
))),
DataType::UInt32 => Ok(Arc::new(DictionaryArray::<UInt32Type>::from(
x.data().clone(),
))),
DataType::UInt64 => Ok(Arc::new(DictionaryArray::<UInt64Type>::from(
x.data().clone(),
))),
_ => Err(DataFusionError::NotImplemented(format!(
"Key type {:?} not supported in scalar_to_array",
key_type
))),
};
}

// DataType::FixedSizeList(_, _) => {}
// DataType::LargeList(_) => {}
// DataType::Struct(_) => {}
// DataType::Union(_) => {}
// DataType::Decimal(_, _) => {}
data_type => Err(DataFusionError::NotImplemented(format!(
"to_array not implemented for {}",
data_type
))),
}
}

/// returns a vector of ArrayRefs, where each entry corresponds to either the
/// final value (mode = Final) or states (mode = Partial)
fn finalize_aggregation(
Expand Down
Loading

0 comments on commit 515a9bc

Please sign in to comment.