Skip to content

Commit

Permalink
Minor: Move group accumulator for aggregate function to physical-expr…
Browse files Browse the repository at this point in the history
…-common, and add ahash physical-expr-common (apache#10574)

* ahash workspace

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* move other utils

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* move NullState

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* move PrimitiveGroupsAccumulator

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* move boolop

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* move decimalavg

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* add comment

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fix doc

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

---------

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
  • Loading branch information
jayzhan211 authored May 22, 2024
1 parent 65e281a commit 2eb38bd
Show file tree
Hide file tree
Showing 17 changed files with 222 additions and 220 deletions.
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ version = "38.0.0"
# for the inherited dependency but cannot do the reverse (override from true to false).
#
# See for more details: https://github.com/rust-lang/cargo/issues/11329
ahash = { version = "0.8", default-features = false, features = [
"runtime-rng",
] }
arrow = { version = "51.0.0", features = ["prettyprint"] }
arrow-array = { version = "51.0.0", default-features = false, features = ["chrono-tz"] }
arrow-buffer = { version = "51.0.0", default-features = false }
Expand Down
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 1 addition & 3 deletions datafusion/common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ backtrace = []
pyarrow = ["pyo3", "arrow/pyarrow", "parquet"]

[dependencies]
ahash = { version = "0.8", default-features = false, features = [
"runtime-rng",
] }
ahash = { workspace = true }
apache-avro = { version = "0.16", default-features = false, features = [
"bzip",
"snappy",
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ unicode_expressions = [
]

[dependencies]
ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] }
ahash = { workspace = true }
apache-avro = { version = "0.16", optional = true }
arrow = { workspace = true }
arrow-array = { workspace = true }
Expand Down
4 changes: 1 addition & 3 deletions datafusion/expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ path = "src/lib.rs"
[features]

[dependencies]
ahash = { version = "0.8", default-features = false, features = [
"runtime-rng",
] }
ahash = { workspace = true }
arrow = { workspace = true }
arrow-array = { workspace = true }
chrono = { workspace = true }
Expand Down
1 change: 1 addition & 0 deletions datafusion/physical-expr-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@ path = "src/lib.rs"
arrow = { workspace = true }
datafusion-common = { workspace = true, default-features = true }
datafusion-expr = { workspace = true }
rand = { workspace = true }
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
//!
//! [`GroupsAccumulator`]: datafusion_expr::GroupsAccumulator

use arrow::array::{Array, BooleanArray, BooleanBufferBuilder, PrimitiveArray};
use arrow::buffer::{BooleanBuffer, NullBuffer};
use arrow::datatypes::ArrowPrimitiveType;
use arrow_array::{Array, BooleanArray, PrimitiveArray};
use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer};

use datafusion_expr::EmitTo;
/// Track the accumulator null state per row: if any values for that
Expand Down Expand Up @@ -462,9 +462,9 @@ fn initialize_builder(
mod test {
use super::*;

use arrow_array::UInt32Array;
use hashbrown::HashSet;
use arrow::array::UInt32Array;
use rand::{rngs::ThreadRng, Rng};
use std::collections::HashSet;

#[test]
fn accumulate() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@

use std::sync::Arc;

use arrow::array::AsArray;
use arrow_array::{ArrayRef, BooleanArray};
use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder};
use arrow::array::{ArrayRef, AsArray, BooleanArray, BooleanBufferBuilder};
use arrow::buffer::BooleanBuffer;
use datafusion_common::Result;
use datafusion_expr::{EmitTo, GroupsAccumulator};

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Utilities for implementing GroupsAccumulator

pub mod accumulate;
pub mod bool_op;
pub mod prim_op;
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@

use std::sync::Arc;

use arrow::{array::AsArray, datatypes::ArrowPrimitiveType};
use arrow_array::{ArrayRef, BooleanArray, PrimitiveArray};
use arrow_schema::DataType;
use arrow::array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray};
use arrow::datatypes::ArrowPrimitiveType;
use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_expr::{EmitTo, GroupsAccumulator};

Expand Down
1 change: 1 addition & 0 deletions datafusion/physical-expr-common/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

pub mod groups_accumulator;
pub mod stats;
pub mod utils;

Expand Down
162 changes: 161 additions & 1 deletion datafusion/physical-expr-common/src/aggregate/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,18 @@

use std::{any::Any, sync::Arc};

use arrow::datatypes::ArrowNativeType;
use arrow::{
array::{ArrayRef, ArrowNativeTypeOp, AsArray},
compute::SortOptions,
datatypes::{DataType, Field},
datatypes::{
DataType, Decimal128Type, DecimalType, Field, TimeUnit, TimestampMicrosecondType,
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
ToByteSlice,
},
};
use datafusion_common::{exec_err, DataFusionError, Result};
use datafusion_expr::Accumulator;

use crate::sort_expr::PhysicalSortExpr;

Expand All @@ -43,6 +51,60 @@ pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any {
}
}

/// Materialize an accumulator's intermediate state as single-row arrays.
///
/// Calls [`Accumulator::state`] and turns every scalar it returns into an
/// array of length 1, preserving the order of the state values.
///
/// # Errors
///
/// Returns the first error produced either by `state()` or by converting
/// one of the scalars with `to_array_of_size`.
pub fn get_accum_scalar_values_as_arrays(
    accum: &mut dyn Accumulator,
) -> Result<Vec<ArrayRef>> {
    let state = accum.state()?;
    let mut arrays = Vec::with_capacity(state.len());
    for scalar in &state {
        // Stop at the first conversion failure, like a fallible `collect`.
        arrays.push(scalar.to_array_of_size(1)?);
    }
    Ok(arrays)
}

/// Re-tag `array` with the type metadata carried by `data_type`, if needed.
///
/// Since `Decimal128Arrays` created from `Vec<NativeType>` have default
/// precision and scale, decimal outputs get the precision and scale from
/// `data_type` applied. Timestamp outputs likewise get the time zone from
/// `data_type` applied. Arrays of any other data type are returned as-is.
///
/// # Errors
///
/// Propagates the error from `with_precision_and_scale` when the requested
/// precision/scale cannot be applied to the decimal array.
///
/// NOTE(review): `as_primitive` is a downcast; it presumably panics when
/// `array` does not actually hold the primitive type implied by
/// `data_type` — confirm against the arrow `AsArray` documentation.
pub fn adjust_output_array(data_type: &DataType, array: ArrayRef) -> Result<ArrayRef> {
    match data_type {
        DataType::Decimal128(precision, scale) => {
            let decimals = array.as_primitive::<Decimal128Type>().clone();
            let decimals = decimals.with_precision_and_scale(*precision, *scale)?;
            Ok(Arc::new(decimals) as ArrayRef)
        }
        DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
            let timestamps = array.as_primitive::<TimestampNanosecondType>().clone();
            Ok(Arc::new(timestamps.with_timezone_opt(tz.clone())) as ArrayRef)
        }
        DataType::Timestamp(TimeUnit::Microsecond, tz) => {
            let timestamps = array.as_primitive::<TimestampMicrosecondType>().clone();
            Ok(Arc::new(timestamps.with_timezone_opt(tz.clone())) as ArrayRef)
        }
        DataType::Timestamp(TimeUnit::Millisecond, tz) => {
            let timestamps = array.as_primitive::<TimestampMillisecondType>().clone();
            Ok(Arc::new(timestamps.with_timezone_opt(tz.clone())) as ArrayRef)
        }
        DataType::Timestamp(TimeUnit::Second, tz) => {
            let timestamps = array.as_primitive::<TimestampSecondType>().clone();
            Ok(Arc::new(timestamps.with_timezone_opt(tz.clone())) as ArrayRef)
        }
        // Every other data type needs no metadata adjustment.
        _ => Ok(array),
    }
}

/// Construct corresponding fields for lexicographical ordering requirement expression
pub fn ordering_fields(
ordering_req: &[PhysicalSortExpr],
Expand All @@ -67,3 +129,101 @@ pub fn ordering_fields(
pub fn get_sort_options(ordering_req: &[PhysicalSortExpr]) -> Vec<SortOptions> {
ordering_req.iter().map(|item| item.options).collect()
}

/// Newtype that lets values without a native `Hash`/`Eq` implementation
/// (such as floats) be used as hash keys.
///
/// Hashing is performed over the value's byte representation and equality
/// is delegated to `ArrowNativeTypeOp::is_eq`.
#[derive(Copy, Clone, Debug)]
pub struct Hashable<T>(pub T);

impl<T> std::hash::Hash for Hashable<T>
where
    T: ToByteSlice,
{
    /// Feeds the raw bytes of the wrapped value into `state`.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let bytes = self.0.to_byte_slice();
        std::hash::Hash::hash(bytes, state)
    }
}

impl<T> PartialEq for Hashable<T>
where
    T: ArrowNativeTypeOp,
{
    /// Two wrappers are equal when `is_eq` says the wrapped values are.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.0, other.0);
        lhs.is_eq(rhs)
    }
}

// Marker impl: equality is taken entirely from the `PartialEq` impl above.
impl<T> Eq for Hashable<T> where T: ArrowNativeTypeOp {}

/// Computes averages for `Decimal128`/`Decimal256` values, checking for overflow
///
/// This is needed because different precisions for Decimal128/Decimal256 can
/// store different ranges of values, and thus `sum / count` may not fit in
/// the target type.
///
/// For example, with a precision of 3 the largest representable value is
/// `999` and the smallest is `-999`.
pub struct DecimalAverager<T: DecimalType> {
    /// scale factor for sum values (10^sum_scale)
    sum_mul: T::Native,
    /// scale factor for target (10^target_scale)
    target_mul: T::Native,
    /// the output precision
    target_precision: u8,
}

impl<T: DecimalType> DecimalAverager<T> {
    /// Create a new `DecimalAverager`:
    ///
    /// * sum_scale: the scale of `sum` values passed to [`Self::avg`]
    /// * target_precision: the output precision
    /// * target_scale: the output scale
    ///
    /// # Errors
    ///
    /// * if either scale is negative: the multipliers are computed as
    ///   `10^scale` with an unsigned exponent, which cannot express a
    ///   negative power
    /// * if `10` cannot be converted to the native type
    /// * if `10^target_scale < 10^sum_scale`, i.e. the sum cannot be
    ///   rescaled to the target scale by a multiplication
    pub fn try_new(
        sum_scale: i8,
        target_precision: u8,
        target_scale: i8,
    ) -> Result<Self> {
        // `scale as u32` on a negative `i8` sign-extends to a huge exponent,
        // so `pow_wrapping` would yield a meaningless multiplier. Reject
        // negative scales up front instead of building a broken averager.
        if sum_scale < 0 || target_scale < 0 {
            return exec_err!(
                "Negative scale is not supported in DecimalAverager (sum_scale: {}, target_scale: {})",
                sum_scale,
                target_scale
            );
        }

        // Both scales are known to be non-negative here, so the casts to
        // `u32` below are lossless. The error messages are built lazily so
        // the success path does not allocate.
        let sum_mul = T::Native::from_usize(10_usize)
            .map(|base| base.pow_wrapping(sum_scale as u32))
            .ok_or_else(|| {
                DataFusionError::Internal(
                    "Failed to compute sum_mul in DecimalAverager".to_string(),
                )
            })?;

        let target_mul = T::Native::from_usize(10_usize)
            .map(|base| base.pow_wrapping(target_scale as u32))
            .ok_or_else(|| {
                DataFusionError::Internal(
                    "Failed to compute target_mul in DecimalAverager".to_string(),
                )
            })?;

        // `avg` rescales the sum by `target_mul / sum_mul`, so the target
        // multiplier must be at least as large as the sum multiplier.
        if target_mul < sum_mul {
            // the sum cannot be converted to the returned data type
            return exec_err!("Arithmetic Overflow in AvgAccumulator");
        }

        Ok(Self {
            sum_mul,
            target_mul,
            target_precision,
        })
    }

    /// Returns `sum / count` as a native (i128/i256) decimal value with
    /// target_scale and target_precision, reporting overflow.
    ///
    /// * sum: The total sum value stored as a decimal with sum_scale
    ///   (passed to `Self::try_new`)
    /// * count: total count, stored as a i128/i256 (*NOT* a Decimal128/Decimal256 value)
    ///
    /// # Errors
    ///
    /// Returns an execution error if rescaling `sum` overflows the native
    /// type, or if the result does not fit in `target_precision`.
    ///
    /// NOTE(review): `count` is not checked against zero here; callers are
    /// presumably expected to pass a non-zero count — confirm at call sites.
    #[inline(always)]
    pub fn avg(&self, sum: T::Native, count: T::Native) -> Result<T::Native> {
        // Factor that brings a value from sum_scale up to target_scale.
        let rescale = self.target_mul.div_wrapping(self.sum_mul);

        let scaled_sum = match sum.mul_checked(rescale) {
            Ok(value) => value,
            // the rescaled sum does not fit in the native type
            Err(_) => return exec_err!("Arithmetic Overflow in AvgAccumulator"),
        };

        let average = scaled_sum.div_wrapping(count);

        // The average must also fit in the requested output precision.
        match T::validate_decimal_precision(average, self.target_precision) {
            Ok(_) => Ok(average),
            Err(_) => exec_err!("Arithmetic Overflow in AvgAccumulator"),
        }
    }
}
4 changes: 1 addition & 3 deletions datafusion/physical-expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ encoding_expressions = ["base64", "hex"]
regex_expressions = ["regex"]

[dependencies]
ahash = { version = "0.8", default-features = false, features = [
"runtime-rng",
] }
ahash = { workspace = true }
arrow = { workspace = true }
arrow-array = { workspace = true }
arrow-buffer = { workspace = true }
Expand Down
17 changes: 13 additions & 4 deletions datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,19 @@
// specific language governing permissions and limitations
// under the License.

pub(crate) mod accumulate;
mod adapter;
pub use accumulate::NullState;
pub use adapter::GroupsAccumulatorAdapter;

pub(crate) mod bool_op;
pub(crate) mod prim_op;
// Backward compatibility: the implementation now lives in
// `datafusion_physical_expr_common`; this module re-exports
// `accumulate_indices` and `NullState` under the `accumulate` name
// within this crate.
pub(crate) mod accumulate {
pub use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::{accumulate_indices, NullState};
}

pub use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::NullState;

// Crate-private re-export of `BooleanGroupsAccumulator` from
// `datafusion_physical_expr_common` under the `bool_op` module name.
pub(crate) mod bool_op {
pub use datafusion_physical_expr_common::aggregate::groups_accumulator::bool_op::BooleanGroupsAccumulator;
}
// Crate-private re-export of `PrimitiveGroupsAccumulator` from
// `datafusion_physical_expr_common` under the `prim_op` module name.
pub(crate) mod prim_op {
pub use datafusion_physical_expr_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
}
7 changes: 6 additions & 1 deletion datafusion/physical-expr/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,12 @@ pub(crate) mod variance;

pub mod build_in;
pub mod moving_min_max;
pub mod utils;
/// Aggregate helper utilities.
///
/// Public re-exports of the helpers defined in
/// `datafusion_physical_expr_common::aggregate::utils`, exposed here under
/// the `utils` module name.
pub mod utils {
pub use datafusion_physical_expr_common::aggregate::utils::{
adjust_output_array, down_cast_any_ref, get_accum_scalar_values_as_arrays,
get_sort_options, ordering_fields, DecimalAverager, Hashable,
};
}

/// Checks whether the given aggregate expression is order-sensitive.
/// For instance, a `SUM` aggregation doesn't depend on the order of its inputs.
Expand Down
Loading

0 comments on commit 2eb38bd

Please sign in to comment.