diff --git a/README.md b/README.md index f48be50..9a4efb2 100644 --- a/README.md +++ b/README.md @@ -86,3 +86,4 @@ SELECT min_by(x, y) FROM VALUES (1, 10), (2, 5), (3, 15), (4, 8) as tab(x, y); - [x] `min_by(expression1, expression2) -> scalar` - Returns the value of `expression1` associated with the minimum value of `expression2`. - [x] `skewness(expression) -> scalar` - Computes the skewness value for `expression`. - [x] `kurtois_pop(expression) -> scalar` - Computes the excess kurtosis (Fisher’s definition) without bias correction. +- [x] `kurtosis(expression) -> scalar` - Computes the excess kurtosis (Fisher’s definition) with bias correction according to the sample size. diff --git a/src/kurtosis.rs b/src/kurtosis.rs new file mode 100644 index 0000000..be5a722 --- /dev/null +++ b/src/kurtosis.rs @@ -0,0 +1,195 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, Float64Array, UInt64Array}; +use arrow::datatypes::{DataType, Field}; +use datafusion::arrow; +use std::any::Any; +use std::fmt::Debug; + +use datafusion::common::cast::as_float64_array; +use datafusion::common::downcast_value; +use datafusion::common::DataFusionError; +use datafusion::error::Result; +use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; +use datafusion::scalar::ScalarValue; + +make_udaf_expr_and_func!( + KurtosisFunction, + kurtosis, + x, + "Calculates the excess kurtosis (Fisher’s definition) with bias correction according to the sample size.", + kurtosis_udaf +); + +pub struct KurtosisFunction { + signature: Signature, +} + +impl Debug for KurtosisFunction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("KurtosisFunction") + .field("signature", &self.signature) + .finish() + } +} + +impl Default for KurtosisFunction { + fn default() -> Self { + Self::new() + } +} + +impl KurtosisFunction { + pub fn new() -> Self { + Self { + signature: Signature::coercible(vec![DataType::Float64], Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for KurtosisFunction { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "kurtosis" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(KurtosisAccumulator::new())) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + Ok(vec![ + Field::new("count", DataType::UInt64, true), + Field::new("sum", DataType::Float64, true), + Field::new("sum_sqr", DataType::Float64, true), + Field::new("sum_cub", DataType::Float64, true), + Field::new("sum_four", DataType::Float64, true), + ]) + } +} + +/// Accumulator for calculating the excess kurtosis (Fisher’s definition) with bias correction according to the sample size. +/// This implementation follows the [DuckDB implementation]: +/// +#[derive(Debug, Default)] +pub struct KurtosisAccumulator { + count: u64, + sum: f64, + sum_sqr: f64, + sum_cub: f64, + sum_four: f64, +} + +impl KurtosisAccumulator { + pub fn new() -> Self { + Self { + count: 0, + sum: 0.0, + sum_sqr: 0.0, + sum_cub: 0.0, + sum_four: 0.0, + } + } +} + +impl Accumulator for KurtosisAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let array = as_float64_array(&values[0])?; + for value in array.iter().flatten() { + self.count += 1; + self.sum += value; + self.sum_sqr += value.powi(2); + self.sum_cub += value.powi(3); + self.sum_four += value.powi(4); + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let counts = downcast_value!(states[0], UInt64Array); + let sums = downcast_value!(states[1], Float64Array); + let sum_sqrs = downcast_value!(states[2], Float64Array); + let sum_cubs = downcast_value!(states[3], Float64Array); + let sum_fours = downcast_value!(states[4], Float64Array); + + for i in 0..counts.len() { + let c = counts.value(i); + if c == 0 { + continue; + } + self.count += c; + self.sum += sums.value(i); + self.sum_sqr += sum_sqrs.value(i); + self.sum_cub += sum_cubs.value(i); + self.sum_four += sum_fours.value(i); + } + + Ok(()) + } + + fn evaluate(&mut self) -> Result { + if self.count <= 3 { + return Ok(ScalarValue::Float64(None)); + } + + let count_64 = 1_f64 / self.count as f64; + let m4 = count_64 + * (self.sum_four - 4.0 * self.sum_cub * self.sum * count_64 + + 6.0 * self.sum_sqr * self.sum.powi(2) * count_64.powi(2) + - 3.0 * self.sum.powi(4) * count_64.powi(3)); + + let m2 = (self.sum_sqr - self.sum.powi(2) * count_64) * count_64; + if m2 <= 0.0 { + return Ok(ScalarValue::Float64(None)); + } + + let count = self.count as f64; + let numerator = (count - 1.0) * ((count + 1.0) * m4 / m2.powi(2) - 3.0 * (count - 1.0)); + let denominator = (count - 2.0) * (count - 3.0); + + let target = numerator / denominator; + + Ok(ScalarValue::Float64(Some(target))) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + + fn state(&mut self) -> Result> { + Ok(vec![ + ScalarValue::from(self.count), + ScalarValue::from(self.sum), + ScalarValue::from(self.sum_sqr), + ScalarValue::from(self.sum_cub), + ScalarValue::from(self.sum_four), + ]) + } +} diff --git a/src/lib.rs b/src/lib.rs index 2541a0f..7f1516f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -26,11 +26,13 @@ use datafusion::logical_expr::AggregateUDF; #[macro_use] pub mod macros; pub mod common; +pub mod kurtosis; pub mod kurtosis_pop; pub mod max_min_by; pub mod mode; pub mod skewness; pub mod expr_extra_fn { + pub use super::kurtosis::kurtosis; pub use super::kurtosis_pop::kurtosis_pop; pub use super::max_min_by::max_by; pub use super::max_min_by::min_by; @@ -43,6 +45,7 @@ pub fn all_extra_aggregate_functions() -> Vec> { mode_udaf(), max_min_by::max_by_udaf(), max_min_by::min_by_udaf(), + kurtosis::kurtosis_udaf(), skewness::skewness_udaf(), kurtosis_pop::kurtosis_pop_udaf(), ] diff --git a/tests/main.rs b/tests/main.rs index 98308b0..3a67f56 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -364,3 +364,74 @@ async fn test_skewness() { - +-------------------+ "###); } + +#[tokio::test] +async fn test_kurtosis() { + let mut execution = TestExecution::new().await.unwrap(); + + let actual = execution + .run_and_format("SELECT kurtosis(col) FROM VALUES (1.0), (10.0), (100.0), (10.0), (1.0) as tab(col);") + .await; + + insta::assert_yaml_snapshot!(actual, @r###" + - +-------------------+ + - "| kurtosis(tab.col) |" + - +-------------------+ + - "| 4.777292927667962 |" + - +-------------------+ + "###); + + let actual = execution + .run_and_format("SELECT kurtosis(col) FROM VALUES ('1'), ('10'), ('100'), ('10'), ('1') as tab(col);") + .await; + + insta::assert_yaml_snapshot!(actual, @r###" + - +-------------------+ + - "| kurtosis(tab.col) |" + - +-------------------+ + - "| 4.777292927667962 |" + - +-------------------+ + "###); + + let actual = execution + .run_and_format("SELECT kurtosis(col) FROM VALUES (1.0), (2.0), (3.0) as tab(col);") + .await; + + insta::assert_yaml_snapshot!(actual, @r###" + - +-------------------+ + - "| kurtosis(tab.col) |" + - +-------------------+ + - "| |" + - +-------------------+ + "###); + + let actual = execution.run_and_format("SELECT kurtosis(1);").await; + + insta::assert_yaml_snapshot!(actual, @r###" + - +--------------------+ + - "| kurtosis(Int64(1)) |" + - +--------------------+ + - "| |" + - +--------------------+ + "###); + + let actual = execution.run_and_format("SELECT kurtosis(1.0);").await; + + insta::assert_yaml_snapshot!(actual, @r###" + - +----------------------+ + - "| kurtosis(Float64(1)) |" + - +----------------------+ + - "| |" + - +----------------------+ + "###); + + let actual = execution.run_and_format("SELECT kurtosis(null);").await; + + insta::assert_yaml_snapshot!(actual, @r###" + - +----------------+ + - "| kurtosis(NULL) |" + - +----------------+ + - "| |" + - +----------------+ + "###); +}