From d874c3caccbab568be2819d647d4c845a67e3c38 Mon Sep 17 00:00:00 2001 From: Lordworms Date: Sat, 21 Sep 2024 21:48:16 -0700 Subject: [PATCH] adding init --- Cargo.toml | 2 + src/common.rs | 66 ++++++++++++++++ src/lib.rs | 2 + src/max_min_by.rs | 189 ++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 259 insertions(+) create mode 100644 src/common.rs create mode 100644 src/lib.rs create mode 100644 src/max_min_by.rs diff --git a/Cargo.toml b/Cargo.toml index bcc16b2..127d997 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,8 @@ rust-version = "1.76.0" [dependencies] datafusion = "42" +paste = "1" + [lints.clippy] dbg_macro = "deny" diff --git a/src/common.rs b/src/common.rs new file mode 100644 index 0000000..a4ed722 --- /dev/null +++ b/src/common.rs @@ -0,0 +1,66 @@ +#[macro_export] +macro_rules! make_udaf_expr { + ($EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { + #[doc = $DOC] + pub fn $EXPR_FN( + $($arg: datafusion::logical_expr::Expr,)* + ) -> datafusion::logical_expr::Expr { + datafusion::logical_expr::Expr::AggregateFunction(AggregateFunction::new_udf( + $AGGREGATE_UDF_FN(), + vec![$($arg),*], + false, + None, + None, + None, + )) + } + }; +} + +#[macro_export] +macro_rules! create_func { + ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => { + create_func!($UDAF, $AGGREGATE_UDF_FN, <$UDAF>::default()); + }; + ($UDAF:ty, $AGGREGATE_UDF_FN:ident, $CREATE:expr) => { + paste::paste! { + #[allow(non_upper_case_globals)] + static [< STATIC_ $UDAF >]: std::sync::OnceLock> = + std::sync::OnceLock::new(); + + #[doc = concat!("AggregateFunction that returns a [`AggregateUDF`](datafusion::logical_expr::AggregateUDF) for [`", stringify!($UDAF), "`]")] + pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc { + [< STATIC_ $UDAF >] + .get_or_init(|| { + std::sync::Arc::new(datafusion::logical_expr::AggregateUDF::from($CREATE)) + }) + .clone() + } + } + } +} + +#[macro_export] +macro_rules! make_udaf_expr_and_func { + ($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { + make_udaf_expr!($EXPR_FN, $($arg)*, $DOC, $AGGREGATE_UDF_FN); + create_func!($UDAF, $AGGREGATE_UDF_FN); + }; + ($UDAF:ty, $EXPR_FN:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { + #[doc = $DOC] + pub fn $EXPR_FN( + args: Vec, + ) -> datafusion::logical_expr::Expr { + datafusion::logical_expr::Expr::AggregateFunction(AggregateFunction::new_udf( + $AGGREGATE_UDF_FN(), + args, + false, + None, + None, + None, + )) + } + + create_func!($UDAF, $AGGREGATE_UDF_FN); + }; +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..c8dfb07 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,2 @@ +pub mod common; +pub mod max_min_by; diff --git a/src/max_min_by.rs b/src/max_min_by.rs new file mode 100644 index 0000000..8436efb --- /dev/null +++ b/src/max_min_by.rs @@ -0,0 +1,189 @@ +use crate::create_func; +use crate::make_udaf_expr; +use crate::make_udaf_expr_and_func; +use datafusion::arrow::datatypes::DataType; +use datafusion::common::{exec_err, Result}; +use datafusion::functions_aggregate::first_last::last_value_udaf; +use datafusion::logical_expr::expr; +use datafusion::logical_expr::expr::AggregateFunction; +use datafusion::logical_expr::expr::Sort; +use datafusion::logical_expr::function; +use datafusion::logical_expr::function::AccumulatorArgs; +use datafusion::logical_expr::simplify::SimplifyInfo; +use datafusion::logical_expr::Expr; +use datafusion::logical_expr::Volatility; +use datafusion::logical_expr::{AggregateUDFImpl, Signature}; +use datafusion::physical_plan::Accumulator; +use std::any::Any; +use std::fmt::Debug; +use std::ops::Deref; + +make_udaf_expr_and_func!( + MaxByFunction, + max_by, + x y, + "Returns the value of the first column corresponding to the maximum value in the second column.", + max_by_udaf +); + +pub struct MaxByFunction { + signature: Signature, +} + +impl Debug for MaxByFunction { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("MaxBy") + .field("name", &self.name()) + .field("signature", &self.signature) + .field("accumulator", &"") + .finish() + } +} +impl Default for MaxByFunction { + fn default() -> Self { + Self::new() + } +} + +impl MaxByFunction { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +fn get_min_max_by_result_type(input_types: &[DataType]) -> Result> { + match &input_types[0] { + DataType::Dictionary(_, dict_value_type) => { + // TODO add checker, if the value type is complex data type + Ok(vec![dict_value_type.deref().clone()]) + } + _ => Ok(input_types.to_vec()), + } +} + +impl AggregateUDFImpl for MaxByFunction { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "max_by" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].to_owned()) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + exec_err!("should not reach here") + } + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + get_min_max_by_result_type(arg_types) + } + + fn simplify(&self) -> Option { + let simplify = |mut aggr_func: expr::AggregateFunction, _: &dyn SimplifyInfo| { + let mut order_by = aggr_func.order_by.unwrap_or_default(); + let (second_arg, first_arg) = (aggr_func.args.remove(1), aggr_func.args.remove(0)); + + order_by.push(Sort::new(second_arg, true, false)); + + Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + last_value_udaf(), + vec![first_arg], + aggr_func.distinct, + aggr_func.filter, + Some(order_by), + aggr_func.null_treatment, + ))) + }; + Some(Box::new(simplify)) + } +} + +make_udaf_expr_and_func!( + MinByFunction, + min_by, + x y, + "Returns the value of the first column corresponding to the minimum value in the second column.", + min_by_udaf +); + +pub struct MinByFunction { + signature: Signature, +} + +impl Debug for MinByFunction { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("MinBy") + .field("name", &self.name()) + .field("signature", &self.signature) + .field("accumulator", &"") + .finish() + } +} + +impl Default for MinByFunction { + fn default() -> Self { + Self::new() + } +} + +impl MinByFunction { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for MinByFunction { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "min_by" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].to_owned()) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + exec_err!("should not reach here") + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + get_min_max_by_result_type(arg_types) + } + + fn simplify(&self) -> Option { + let simplify = |mut aggr_func: expr::AggregateFunction, _: &dyn SimplifyInfo| { + let mut order_by = aggr_func.order_by.unwrap_or_default(); + let (second_arg, first_arg) = (aggr_func.args.remove(1), aggr_func.args.remove(0)); + + order_by.push(Sort::new(second_arg, false, false)); // false for ascending sort + + Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + last_value_udaf(), + vec![first_arg], + aggr_func.distinct, + aggr_func.filter, + Some(order_by), + aggr_func.null_treatment, + ))) + }; + Some(Box::new(simplify)) + } +}