Remove Accumulator::update and Accumulator::merge #1582

Merged: 1 commit, Jan 18, 2022
README.md (2 changes: 1 addition & 1 deletion)
@@ -49,7 +49,7 @@ the convenience of an SQL interface or a DataFrame API.

## Known Uses

- Projects that adapt to or service as plugins to DataFusion:
+ Projects that adapt to or serve as plugins to DataFusion:

- [datafusion-python](https://github.com/datafusion-contrib/datafusion-python)
- [datafusion-java](https://github.com/datafusion-contrib/datafusion-java)
datafusion-examples/examples/simple_udaf.rs (58 changes: 43 additions & 15 deletions)
@@ -18,7 +18,7 @@
/// In this example we will declare a single-type, single return type UDAF that computes the geometric mean.
/// The geometric mean is described here: https://en.wikipedia.org/wiki/Geometric_mean
use datafusion::arrow::{
- array::Float32Array, array::Float64Array, datatypes::DataType,
+ array::ArrayRef, array::Float32Array, array::Float64Array, datatypes::DataType,
record_batch::RecordBatch,
};

@@ -66,20 +66,6 @@ impl GeometricMean {
pub fn new() -> Self {
GeometricMean { n: 0, prod: 1.0 }
}
- }
-
- // UDAFs are built using the trait `Accumulator`, that offers DataFusion the necessary functions
- // to use them.
- impl Accumulator for GeometricMean {
- // this function serializes our state to `ScalarValue`, which DataFusion uses
- // to pass this state between execution stages.
- // Note that this can be arbitrary data.
- fn state(&self) -> Result<Vec<ScalarValue>> {
- Ok(vec![
- ScalarValue::from(self.prod),
- ScalarValue::from(self.n),
- ])
- }

// this function receives one entry per argument of this accumulator.
// DataFusion calls this function on every row, and expects this function to update the accumulator's state.
@@ -114,6 +100,20 @@ impl Accumulator for GeometricMean {
};
Ok(())
}
+ }
+
+ // UDAFs are built using the trait `Accumulator`, which offers DataFusion the necessary functions
+ // to use them.
+ impl Accumulator for GeometricMean {
+ // This function serializes our state to `ScalarValue`, which DataFusion uses
+ // to pass this state between execution stages.
+ // Note that this can be arbitrary data.
+ fn state(&self) -> Result<Vec<ScalarValue>> {
+ Ok(vec![
+ ScalarValue::from(self.prod),
+ ScalarValue::from(self.n),
+ ])
+ }

// DataFusion expects this function to return the final value of this aggregator.
// in this case, this is the formula of the geometric mean
@@ -122,9 +122,37 @@ impl Accumulator for GeometricMean {
Ok(ScalarValue::from(value))
}

+ // DataFusion calls this function to update the accumulator's state for a batch
+ // of input rows. In this case the product is updated with values from the first column
+ // and the count is updated based on the row count
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ if values.is_empty() {
+ return Ok(());
+ };
+ (0..values[0].len()).try_for_each(|index| {
+ let v = values
+ .iter()
+ .map(|array| ScalarValue::try_from_array(array, index))
+ .collect::<Result<Vec<_>>>()?;
+ self.update(&v)
+ })
+ }
+
// Optimization hint: this trait also supports `update_batch` and `merge_batch`,
// which can be used to perform these operations on arrays instead of single values.
// By default, these methods call `update` and `merge` row by row
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ if states.is_empty() {
+ return Ok(());
+ };
+ (0..states[0].len()).try_for_each(|index| {
+ let v = states
+ .iter()
+ .map(|array| ScalarValue::try_from_array(array, index))
+ .collect::<Result<Vec<_>>>()?;
+ self.merge(&v)
+ })
+ }
}

#[tokio::main]
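With `update` and `merge` no longer on the trait, the example keeps them as inherent methods and bridges to them row by row via `ScalarValue::try_from_array`. For a numeric UDAF like this one, `update_batch` can also operate on the Arrow array directly, which is exactly the optimization the batch API enables. A minimal sketch of that variant, assuming the single argument is coerced to `Float64` (illustration only, not part of this PR):

```rust
// Hypothetical vectorized update_batch: downcast the one input column and
// fold over it without materializing a ScalarValue per row. The final
// geometric mean is still prod^(1/n), computed in `evaluate`.
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
    if values.is_empty() {
        return Ok(());
    }
    let arr = values[0]
        .as_any()
        .downcast_ref::<Float64Array>()
        .expect("signature coerces the argument to Float64");
    for i in 0..arr.len() {
        if arr.is_valid(i) {
            // accumulate the running product and the row count, skipping nulls
            self.prod *= arr.value(i);
            self.n += 1;
        }
    }
    Ok(())
}
```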
datafusion/src/datasource/file_format/parquet.rs (48 changes: 37 additions & 11 deletions)
@@ -36,6 +36,9 @@ use parquet::file::statistics::Statistics as ParquetStatistics;

use super::FileFormat;
use super::FileScanConfig;
+ use crate::arrow::array::{
+ BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array,
+ };
use crate::arrow::datatypes::{DataType, Field};
use crate::datasource::object_store::{ObjectReader, ObjectReaderStream};
use crate::datasource::{create_max_min_accs, get_col_stats};
@@ -47,7 +50,6 @@ use crate::physical_plan::expressions::{MaxAccumulator, MinAccumulator};
use crate::physical_plan::file_format::ParquetExec;
use crate::physical_plan::ExecutionPlan;
use crate::physical_plan::{Accumulator, Statistics};
- use crate::scalar::ScalarValue;

/// The default file extension of parquet files
pub const DEFAULT_PARQUET_EXTENSION: &str = ".parquet";
@@ -132,15 +134,19 @@ fn summarize_min_max(
if let DataType::Boolean = fields[i].data_type() {
if s.has_min_max_set() {
if let Some(max_value) = &mut max_values[i] {
- match max_value.update(&[ScalarValue::Boolean(Some(*s.max()))]) {
+ match max_value.update_batch(&[Arc::new(BooleanArray::from(
+ vec![Some(*s.max())],
+ ))]) {
Ok(_) => {}
Err(_) => {
max_values[i] = None;
}
}
}
if let Some(min_value) = &mut min_values[i] {
- match min_value.update(&[ScalarValue::Boolean(Some(*s.min()))]) {
+ match min_value.update_batch(&[Arc::new(BooleanArray::from(
+ vec![Some(*s.min())],
+ ))]) {
Ok(_) => {}
Err(_) => {
min_values[i] = None;
@@ -154,15 +160,21 @@
if let DataType::Int32 = fields[i].data_type() {
if s.has_min_max_set() {
if let Some(max_value) = &mut max_values[i] {
- match max_value.update(&[ScalarValue::Int32(Some(*s.max()))]) {
+ match max_value.update_batch(&[Arc::new(Int32Array::from_value(
+ *s.max(),
+ 1,
+ ))]) {
Ok(_) => {}
Err(_) => {
max_values[i] = None;
}
}
}
if let Some(min_value) = &mut min_values[i] {
- match min_value.update(&[ScalarValue::Int32(Some(*s.min()))]) {
+ match min_value.update_batch(&[Arc::new(Int32Array::from_value(
+ *s.min(),
+ 1,
+ ))]) {
Ok(_) => {}
Err(_) => {
min_values[i] = None;
@@ -176,15 +188,21 @@
if let DataType::Int64 = fields[i].data_type() {
if s.has_min_max_set() {
if let Some(max_value) = &mut max_values[i] {
- match max_value.update(&[ScalarValue::Int64(Some(*s.max()))]) {
+ match max_value.update_batch(&[Arc::new(Int64Array::from_value(
+ *s.max(),
+ 1,
+ ))]) {
Ok(_) => {}
Err(_) => {
max_values[i] = None;
}
}
}
if let Some(min_value) = &mut min_values[i] {
- match min_value.update(&[ScalarValue::Int64(Some(*s.min()))]) {
+ match min_value.update_batch(&[Arc::new(Int64Array::from_value(
+ *s.min(),
+ 1,
+ ))]) {
Ok(_) => {}
Err(_) => {
min_values[i] = None;
@@ -198,15 +216,19 @@
if let DataType::Float32 = fields[i].data_type() {
if s.has_min_max_set() {
if let Some(max_value) = &mut max_values[i] {
- match max_value.update(&[ScalarValue::Float32(Some(*s.max()))]) {
+ match max_value.update_batch(&[Arc::new(Float32Array::from(
+ vec![Some(*s.max())],
+ ))]) {
Ok(_) => {}
Err(_) => {
max_values[i] = None;
}
}
}
if let Some(min_value) = &mut min_values[i] {
- match min_value.update(&[ScalarValue::Float32(Some(*s.min()))]) {
+ match min_value.update_batch(&[Arc::new(Float32Array::from(
+ vec![Some(*s.min())],
+ ))]) {
Ok(_) => {}
Err(_) => {
min_values[i] = None;
@@ -220,15 +242,19 @@
if let DataType::Float64 = fields[i].data_type() {
if s.has_min_max_set() {
if let Some(max_value) = &mut max_values[i] {
- match max_value.update(&[ScalarValue::Float64(Some(*s.max()))]) {
+ match max_value.update_batch(&[Arc::new(Float64Array::from(
+ vec![Some(*s.max())],
+ ))]) {
Ok(_) => {}
Err(_) => {
max_values[i] = None;
}
}
}
if let Some(min_value) = &mut min_values[i] {
- match min_value.update(&[ScalarValue::Float64(Some(*s.min()))]) {
+ match min_value.update_batch(&[Arc::new(Float64Array::from(
+ vec![Some(*s.min())],
+ ))]) {
Ok(_) => {}
Err(_) => {
min_values[i] = None;
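Every branch in `summarize_min_max` now follows the same shape: wrap one parquet statistic in a single-element Arrow array, hand it to the accumulator through `update_batch`, and drop the accumulator if the update fails. A hedged sketch of that pattern factored into a helper (hypothetical, not in this PR; shown for the max side, the min side is symmetric):

```rust
// Hypothetical helper: feed one single-element array into an optional
// accumulator, clearing the slot on error so a failed update can never
// surface as a bogus file-level max.
fn update_or_invalidate(slot: &mut Option<MaxAccumulator>, value: ArrayRef) {
    if let Some(acc) = slot {
        if acc.update_batch(&[value]).is_err() {
            *slot = None;
        }
    }
}

// e.g. for an Int64 statistic:
// update_or_invalidate(&mut max_values[i], Arc::new(Int64Array::from_value(*s.max(), 1)));
```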
datafusion/src/datasource/mod.rs (4 changes: 2 additions & 2 deletions)
@@ -71,7 +71,7 @@ pub async fn get_statistics_with_limit(

if let Some(max_value) = &mut max_values[i] {
if let Some(file_max) = cs.max_value.clone() {
- match max_value.update(&[file_max]) {
+ match max_value.update_batch(&[file_max.to_array()]) {
Ok(_) => {}
Err(_) => {
max_values[i] = None;
@@ -82,7 +82,7 @@

if let Some(min_value) = &mut min_values[i] {
if let Some(file_min) = cs.min_value.clone() {
- match min_value.update(&[file_min]) {
+ match min_value.update_batch(&[file_min.to_array()]) {
Ok(_) => {}
Err(_) => {
min_values[i] = None;
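This change leans on the `ScalarValue` to Arrow round trip: `to_array` wraps a scalar in a one-element array, and `ScalarValue::try_from_array` reads it back out. A small sketch of that contract as assumed here:

```rust
// Round-trip sketch: a ScalarValue becomes a one-element ArrayRef, and
// row 0 of that array converts back into an equal ScalarValue.
let file_max = ScalarValue::Float64(Some(4.2));
let arr = file_max.to_array();
assert_eq!(ScalarValue::try_from_array(&arr, 0)?, file_max);
```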
datafusion/src/physical_plan/distinct_expressions.rs (50 changes: 43 additions & 7 deletions)
@@ -17,13 +17,13 @@

//! Implementations for DISTINCT expressions, e.g. `COUNT(DISTINCT c)`

+ use arrow::array::ArrayRef;
+ use arrow::datatypes::{DataType, Field};
use std::any::Any;
use std::fmt::Debug;
use std::hash::Hash;
use std::sync::Arc;

- use arrow::datatypes::{DataType, Field};
-
use ahash::RandomState;
use std::collections::HashSet;

@@ -130,8 +130,7 @@ struct DistinctCountAccumulator {
state_data_types: Vec<DataType>,
count_data_type: DataType,
}

- impl Accumulator for DistinctCountAccumulator {
+ impl DistinctCountAccumulator {
fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
// If a row has a NULL, it is not included in the final count.
if !values.iter().any(|v| v.is_null()) {
@@ -165,7 +164,33 @@ impl Accumulator for DistinctCountAccumulator {
self.update(&row_values)
})
}
+ }

+ impl Accumulator for DistinctCountAccumulator {
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ if values.is_empty() {
+ return Ok(());
+ };
+ (0..values[0].len()).try_for_each(|index| {
+ let v = values
+ .iter()
+ .map(|array| ScalarValue::try_from_array(array, index))
+ .collect::<Result<Vec<_>>>()?;
+ self.update(&v)
+ })
+ }
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ if states.is_empty() {
+ return Ok(());
+ };
+ (0..states[0].len()).try_for_each(|index| {
+ let v = states
+ .iter()
+ .map(|array| ScalarValue::try_from_array(array, index))
+ .collect::<Result<Vec<_>>>()?;
+ self.merge(&v)
+ })
+ }
fn state(&self) -> Result<Vec<ScalarValue>> {
let mut cols_out = self
.state_data_types
@@ -317,9 +342,20 @@ mod tests {

let mut accum = agg.create_accumulator()?;

- for row in rows.iter() {
- accum.update(row)?
- }
+ let cols = (0..rows[0].len())
+ .map(|i| {
+ rows.iter()
+ .map(|inner| inner[i].clone())
+ .collect::<Vec<ScalarValue>>()
+ })
+ .collect::<Vec<_>>();
+
+ let arrays: Vec<ArrayRef> = cols
+ .iter()
+ .map(|c| ScalarValue::iter_to_array(c.clone()))
+ .collect::<Result<Vec<ArrayRef>>>()?;
+
+ accum.update_batch(&arrays)?;

Ok((accum.state()?, accum.evaluate()?))
}
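The updated test transposes its row-major fixtures into one array per column before a single `update_batch` call. The same transpose as a standalone helper, sketched under the assumption that every row has the same arity (hypothetical name, not in the PR):

```rust
// Hypothetical helper mirroring the test: convert rows of ScalarValues
// into the column-oriented Vec<ArrayRef> that update_batch expects.
fn rows_to_columns(rows: &[Vec<ScalarValue>]) -> Result<Vec<ArrayRef>> {
    (0..rows[0].len())
        .map(|col| ScalarValue::iter_to_array(rows.iter().map(|row| row[col].clone())))
        .collect()
}
```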
datafusion/src/physical_plan/expressions/approx_distinct.rs (17 changes: 0 additions & 17 deletions)
@@ -217,23 +217,6 @@ impl<T: Hash> TryFrom<&ScalarValue> for HyperLogLog<T> {

macro_rules! default_accumulator_impl {
() => {
- fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
- self.update_batch(
- values
- .iter()
- .map(|s| s.to_array() as ArrayRef)
- .collect::<Vec<_>>()
- .as_slice(),
- )
- }
-
- fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {
- assert_eq!(1, states.len(), "expect only 1 element in the states");
- let other = HyperLogLog::try_from(&states[0])?;
- self.hll.merge(&other);
- Ok(())
- }
-
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
assert_eq!(1, states.len(), "expect only 1 element in the states");
let binary_array = states[0].as_any().downcast_ref::<BinaryArray>().unwrap();
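Here the removal is pure deletion: the macro's row-oriented `update` and `merge` are dropped, and only the batch methods remain (the surviving `merge_batch` is partially visible above). For context, a sketch of that batch-merge shape over a serialized-HLL state column; the `TryFrom<&[u8]>` conversion for `HyperLogLog` is an assumption made for illustration:

```rust
// Sketch only: merge every serialized HyperLogLog in a Binary state column
// into the local sketch, skipping null slots.
fn merge_serialized_states(hll: &mut HyperLogLog<u64>, states: &[ArrayRef]) -> Result<()> {
    let binary_array = states[0]
        .as_any()
        .downcast_ref::<BinaryArray>()
        .expect("HLL state is stored as a Binary column");
    for bytes in binary_array.iter().flatten() {
        let other = HyperLogLog::try_from(bytes)?; // assumed TryFrom<&[u8]> impl
        hll.merge(&other);
    }
    Ok(())
}
```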