From 94f2a118ecfcb0d191c9ebc98a9028aa32715d8b Mon Sep 17 00:00:00 2001 From: Dharan Aditya Date: Sat, 28 Sep 2024 00:28:55 +0530 Subject: [PATCH] Add `skewness(x)` in Aggregation function (#6) --- README.md | 1 + src/lib.rs | 3 + src/skewness.rs | 169 ++++++++++++++++++++++++++++++++++++++++++++++++ tests/main.rs | 54 ++++++++++++++++ 4 files changed, 227 insertions(+) create mode 100644 src/skewness.rs diff --git a/README.md b/README.md index db01025..f48be50 100644 --- a/README.md +++ b/README.md @@ -84,4 +84,5 @@ SELECT min_by(x, y) FROM VALUES (1, 10), (2, 5), (3, 15), (4, 8) as tab(x, y); - [x] `mode(expression) -> scalar` - Returns the most frequent (mode) value from a column of data. - [x] `max_by(expression1, expression2) -> scalar` - Returns the value of `expression1` associated with the maximum value of `expression2`. - [x] `min_by(expression1, expression2) -> scalar` - Returns the value of `expression1` associated with the minimum value of `expression2`. +- [x] `skewness(expression) -> scalar` - Computes the skewness value for `expression`. - [x] `kurtois_pop(expression) -> scalar` - Computes the excess kurtosis (Fisher’s definition) without bias correction. diff --git a/src/lib.rs b/src/lib.rs index d9dd246..2541a0f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,11 +29,13 @@ pub mod common; pub mod kurtosis_pop; pub mod max_min_by; pub mod mode; +pub mod skewness; pub mod expr_extra_fn { pub use super::kurtosis_pop::kurtosis_pop; pub use super::max_min_by::max_by; pub use super::max_min_by::min_by; pub use super::mode::mode; + pub use super::skewness::skewness; } pub fn all_extra_aggregate_functions() -> Vec> { @@ -41,6 +43,7 @@ pub fn all_extra_aggregate_functions() -> Vec> { mode_udaf(), max_min_by::max_by_udaf(), max_min_by::min_by_udaf(), + skewness::skewness_udaf(), kurtosis_pop::kurtosis_pop_udaf(), ] } diff --git a/src/skewness.rs b/src/skewness.rs new file mode 100644 index 0000000..484ed1a --- /dev/null +++ b/src/skewness.rs @@ -0,0 +1,169 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, AsArray}; +use arrow::datatypes::{Float64Type, UInt64Type}; +use datafusion::arrow::datatypes::{DataType, Field}; +use datafusion::common::ScalarValue; +use datafusion::logical_expr::{function::AccumulatorArgs, function::StateFieldsArgs}; +use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::fmt::Debug; +use std::ops::{Div, Mul, Sub}; + +make_udaf_expr_and_func!(SkewnessFunc, skewness, x, "Computes the skewness value.", skewness_udaf); + +pub struct SkewnessFunc { + name: String, + signature: Signature, +} + +impl Debug for SkewnessFunc { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SkewnessFunc") + .field("signature", &self.signature) + .finish() + } +} + +impl Default for SkewnessFunc { + fn default() -> Self { + Self::new() + } +} + +impl SkewnessFunc { + pub fn new() -> Self { + Self { + name: "skewness".to_string(), + signature: Signature::coercible(vec![DataType::Float64], Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for SkewnessFunc { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> datafusion::common::Result { + Ok(DataType::Float64) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> datafusion::common::Result> { + Ok(Box::new(SkewnessAccumulator::new())) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> datafusion::common::Result> { + Ok(vec![ + Field::new("count", DataType::UInt64, true), + Field::new("sum", DataType::Float64, true), + Field::new("sum_sqr", DataType::Float64, true), + Field::new("sum_cub", DataType::Float64, true), + ]) + } +} + +/// Accumulator for calculating the skewness +/// This implementation follows the DuckDB implementation: +/// +#[derive(Debug)] +pub struct SkewnessAccumulator { + count: u64, + sum: f64, + sum_sqr: f64, + sum_cub: f64, +} + +impl SkewnessAccumulator { + fn new() -> Self { + Self { + count: 0, + sum: 0f64, + sum_sqr: 0f64, + sum_cub: 0f64, + } + } +} + +impl Accumulator for SkewnessAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion::common::Result<()> { + let array = values[0].as_primitive::(); + for val in array.iter().flatten() { + self.count += 1; + self.sum += val; + self.sum_sqr += val.powi(2); + self.sum_cub += val.powi(3); + } + Ok(()) + } + fn evaluate(&mut self) -> datafusion::common::Result { + if self.count <= 2 { + return Ok(ScalarValue::Float64(None)); + } + let count = self.count as f64; + let t1 = 1f64 / count; + let p = (t1 * (self.sum_sqr - self.sum * self.sum * t1)).powi(3).max(0f64); + let div = p.sqrt(); + if div == 0f64 { + return Ok(ScalarValue::Float64(None)); + } + let t2 = count.mul(count.sub(1f64)).sqrt().div(count.sub(2f64)); + let res = + t2 * t1 * (self.sum_cub - 3f64 * self.sum_sqr * self.sum * t1 + 2f64 * self.sum.powi(3) * t1 * t1) / div; + Ok(ScalarValue::Float64(Some(res))) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + + fn state(&mut self) -> datafusion::common::Result> { + Ok(vec![ + ScalarValue::from(self.count), + ScalarValue::from(self.sum), + ScalarValue::from(self.sum_sqr), + ScalarValue::from(self.sum_cub), + ]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion::common::Result<()> { + let counts = states[0].as_primitive::(); + let sums = states[1].as_primitive::(); + let sum_sqrs = states[2].as_primitive::(); + let sum_cubs = states[3].as_primitive::(); + + for i in 0..counts.len() { + let c = counts.value(i); + if c == 0 { + continue; + } + self.count += c; + self.sum += sums.value(i); + self.sum_sqr += sum_sqrs.value(i); + self.sum_cub += sum_cubs.value(i); + } + Ok(()) + } +} diff --git a/tests/main.rs b/tests/main.rs index 1a1bbc7..98308b0 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -310,3 +310,57 @@ async fn test_kurtosis_pop() { - +--------------------+ "###); } + +#[tokio::test] +async fn test_skewness() { + let mut execution = TestExecution::new().await.unwrap().with_setup(TEST_TABLE).await; + + // Test with int64 + let actual = execution + .run_and_format("SELECT skewness(int64_col) FROM test_table") + .await; + + insta::assert_yaml_snapshot!(actual, @r###" + - +--------------------------------+ + - "| skewness(test_table.int64_col) |" + - +--------------------------------+ + - "| -0.8573214099741201 |" + - +--------------------------------+ + "###); + + // Test with float64 + let actual = execution + .run_and_format("SELECT skewness(float64_col) FROM test_table") + .await; + + insta::assert_yaml_snapshot!(actual, @r###" + - +----------------------------------+ + - "| skewness(test_table.float64_col) |" + - +----------------------------------+ + - "| -0.8573214099741201 |" + - +----------------------------------+ +"###); + + // Test with single value + let actual = execution.run_and_format("SELECT skewness(1.0)").await; + + insta::assert_yaml_snapshot!(actual, @r###" + - +----------------------+ + - "| skewness(Float64(1)) |" + - +----------------------+ + - "| |" + - +----------------------+ + "###); + + let actual = execution + .run_and_format("SELECT skewness(col) FROM VALUES (1.0), (2.0) as tab(col)") + .await; + + insta::assert_yaml_snapshot!(actual, @r###" + - +-------------------+ + - "| skewness(tab.col) |" + - +-------------------+ + - "| |" + - +-------------------+ + "###); +}