From 0f1b64ca75898cb0f179b70e1b55f953b306a858 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Sun, 25 Aug 2024 09:03:18 +0200 Subject: [PATCH] Improve documenataion and comments --- .../groups_accumulator/accumulate.rs | 69 +++++++++++-------- .../functions-aggregate/src/variance.rs | 4 +- 2 files changed, 43 insertions(+), 30 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs index 5206ab48fcfb..a0475fe8e446 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs @@ -91,36 +91,9 @@ impl NullState { /// * `opt_filter`: if present, only rows for which is Some(true) are included /// * `value_fn`: function invoked for (group_index, value) where value is non null /// - /// # Example + /// See [`accumulate`], for more details on how value_fn is called /// - /// ```text - /// ┌─────────┐ ┌─────────┐ ┌ ─ ─ ─ ─ ┐ - /// │ ┌─────┐ │ │ ┌─────┐ │ ┌─────┐ - /// │ │ 2 │ │ │ │ 200 │ │ │ │ t │ │ - /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ - /// │ │ 2 │ │ │ │ 100 │ │ │ │ f │ │ - /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ - /// │ │ 0 │ │ │ │ 200 │ │ │ │ t │ │ - /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ - /// │ │ 1 │ │ │ │ 200 │ │ │ │NULL │ │ - /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ - /// │ │ 0 │ │ │ │ 300 │ │ │ │ t │ │ - /// │ └─────┘ │ │ └─────┘ │ └─────┘ - /// └─────────┘ └─────────┘ └ ─ ─ ─ ─ ┘ - /// - /// group_indices values opt_filter - /// ``` - /// - /// In the example above, `value_fn` is invoked for each (group_index, - /// value) pair where `opt_filter[i]` is true and values is non null - /// - /// ```text - /// value_fn(2, 200) - /// value_fn(0, 200) - /// value_fn(0, 300) - /// ``` - /// - /// It also sets + /// When value_fn is called it also sets /// /// 1. `self.seen_values[group_index]` to true for all rows that had a non null vale pub fn accumulate( @@ -260,6 +233,44 @@ impl NullState { } } +/// Invokes `value_fn(group_index, value)` for each non null, non +/// filtered value of `value`, +/// +/// # Arguments: +/// +/// * `group_indices`: To which groups do the rows in `values` belong, (aka group_index) +/// * `values`: the input arguments to the accumulator +/// * `opt_filter`: if present, only rows for which is Some(true) are included +/// * `value_fn`: function invoked for (group_index, value) where value is non null +/// +/// # Example +/// +/// ```text +/// ┌─────────┐ ┌─────────┐ ┌ ─ ─ ─ ─ ┐ +/// │ ┌─────┐ │ │ ┌─────┐ │ ┌─────┐ +/// │ │ 2 │ │ │ │ 200 │ │ │ │ t │ │ +/// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ +/// │ │ 2 │ │ │ │ 100 │ │ │ │ f │ │ +/// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ +/// │ │ 0 │ │ │ │ 200 │ │ │ │ t │ │ +/// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ +/// │ │ 1 │ │ │ │ 200 │ │ │ │NULL │ │ +/// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ +/// │ │ 0 │ │ │ │ 300 │ │ │ │ t │ │ +/// │ └─────┘ │ │ └─────┘ │ └─────┘ +/// └─────────┘ └─────────┘ └ ─ ─ ─ ─ ┘ +/// +/// group_indices values opt_filter +/// ``` +/// +/// In the example above, `value_fn` is invoked for each (group_index, +/// value) pair where `opt_filter[i]` is true and values is non null +/// +/// ```text +/// value_fn(2, 200) +/// value_fn(0, 200) +/// value_fn(0, 300) +/// ``` pub fn accumulate( group_indices: &[usize], values: &PrimitiveArray, diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 32e54d92e6c3..f5f2d06e3837 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -464,6 +464,8 @@ impl VarianceGroupsAccumulator { emit_to: datafusion_expr::EmitTo, ) -> (Vec, NullBuffer) { let mut counts = emit_to.take_needed(&mut self.counts); + // means are only needed for updating m2s and are not needed for the final result. + // But we still need to take them to ensure the internal state is consistent. let _ = emit_to.take_needed(&mut self.means); let m2s = emit_to.take_needed(&mut self.m2s); @@ -517,7 +519,7 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 3, "two arguments to merge_batch"); - // first batch is counts, second is partial sums + // first batch is counts, second is partial means, third is partial m2s let partial_counts = downcast_value!(values[0], UInt64Array); let partial_means = downcast_value!(values[1], Float64Array); let partial_m2s = downcast_value!(values[2], Float64Array);