Skip to content

Commit

Permalink
adding init
Browse files Browse the repository at this point in the history
  • Loading branch information
Lordworms committed Sep 22, 2024
1 parent 06ace1b commit d874c3c
Show file tree
Hide file tree
Showing 4 changed files with 259 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ rust-version = "1.76.0"

[dependencies]
datafusion = "42"
paste = "1"


[lints.clippy]
dbg_macro = "deny"
Expand Down
66 changes: 66 additions & 0 deletions src/common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/// Generates a public expression-builder function for an aggregate UDF.
///
/// # Arguments
/// * `$EXPR_FN` - name of the generated function.
/// * `$arg` - whitespace-separated argument names; each one becomes a
///   `datafusion::logical_expr::Expr` parameter of the generated function.
/// * `$DOC` - documentation string attached to the generated function.
/// * `$AGGREGATE_UDF_FN` - name of a zero-argument function that returns the
///   aggregate UDF to invoke.
///
/// The generated function wraps its arguments in an `Expr::AggregateFunction`
/// built with `distinct = false` and no filter, ordering or null treatment.
///
/// Every type is referenced through its full path so the macro expands
/// correctly even when the calling module has not imported `AggregateFunction`.
#[macro_export]
macro_rules! make_udaf_expr {
    ($EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
        #[doc = $DOC]
        pub fn $EXPR_FN(
            $($arg: datafusion::logical_expr::Expr,)*
        ) -> datafusion::logical_expr::Expr {
            datafusion::logical_expr::Expr::AggregateFunction(
                datafusion::logical_expr::expr::AggregateFunction::new_udf(
                    $AGGREGATE_UDF_FN(),
                    vec![$($arg),*],
                    false,
                    None,
                    None,
                    None,
                ),
            )
        }
    };
}

/// Generates a lazily-initialised, process-wide singleton accessor for an
/// aggregate UDF.
///
/// # Arguments
/// * `$UDAF` - the type implementing the aggregate function.
/// * `$AGGREGATE_UDF_FN` - name of the generated accessor function.
/// * `$CREATE` - (optional) expression that builds the `$UDAF` instance;
///   defaults to `<$UDAF>::default()`.
///
/// The generated accessor stores an `Arc<AggregateUDF>` in a `OnceLock` named
/// `STATIC_<UDAF>` and hands out clones of that `Arc` on every call.
///
/// The two-argument arm recurses through `$crate::` so callers do not have to
/// import `create_func` by name before using it.
#[macro_export]
macro_rules! create_func {
    ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => {
        $crate::create_func!($UDAF, $AGGREGATE_UDF_FN, <$UDAF>::default());
    };
    ($UDAF:ty, $AGGREGATE_UDF_FN:ident, $CREATE:expr) => {
        paste::paste! {
            // The generated static embeds the (CamelCase) type name, hence the allow.
            #[allow(non_upper_case_globals)]
            static [< STATIC_ $UDAF >]: std::sync::OnceLock<std::sync::Arc<datafusion::logical_expr::AggregateUDF>> =
                std::sync::OnceLock::new();

            #[doc = concat!("AggregateFunction that returns a [`AggregateUDF`](datafusion::logical_expr::AggregateUDF) for [`", stringify!($UDAF), "`]")]
            pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc<datafusion::logical_expr::AggregateUDF> {
                [< STATIC_ $UDAF >]
                    .get_or_init(|| {
                        std::sync::Arc::new(datafusion::logical_expr::AggregateUDF::from($CREATE))
                    })
                    .clone()
            }
        }
    }
}

/// Generates both the expression-builder function and the singleton UDF
/// accessor for an aggregate function.
///
/// Two forms are accepted:
///
/// * `(UDAF, expr_fn, arg1 arg2 ..., doc, udaf_fn)` - the builder takes one
///   `Expr` parameter per listed argument name.
/// * `(UDAF, expr_fn, doc, udaf_fn)` - the builder takes a single
///   `Vec<Expr>` parameter named `args`.
///
/// In both forms `udaf_fn` is generated via `create_func!` using
/// `<UDAF>::default()` as the constructor.
///
/// Helper macros are invoked through `$crate::` and `AggregateFunction` is
/// fully qualified, so the calling module needs no extra imports.
#[macro_export]
macro_rules! make_udaf_expr_and_func {
    ($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
        $crate::make_udaf_expr!($EXPR_FN, $($arg)*, $DOC, $AGGREGATE_UDF_FN);
        $crate::create_func!($UDAF, $AGGREGATE_UDF_FN);
    };
    ($UDAF:ty, $EXPR_FN:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
        #[doc = $DOC]
        pub fn $EXPR_FN(
            args: Vec<datafusion::logical_expr::Expr>,
        ) -> datafusion::logical_expr::Expr {
            datafusion::logical_expr::Expr::AggregateFunction(
                datafusion::logical_expr::expr::AggregateFunction::new_udf(
                    $AGGREGATE_UDF_FN(),
                    args,
                    false,
                    None,
                    None,
                    None,
                ),
            )
        }

        $crate::create_func!($UDAF, $AGGREGATE_UDF_FN);
    };
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub mod common;
pub mod max_min_by;
189 changes: 189 additions & 0 deletions src/max_min_by.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
use crate::create_func;
use crate::make_udaf_expr;
use crate::make_udaf_expr_and_func;
use datafusion::arrow::datatypes::DataType;
use datafusion::common::{exec_err, Result};
use datafusion::functions_aggregate::first_last::last_value_udaf;
use datafusion::logical_expr::expr;
use datafusion::logical_expr::expr::AggregateFunction;
use datafusion::logical_expr::expr::Sort;
use datafusion::logical_expr::function;
use datafusion::logical_expr::function::AccumulatorArgs;
use datafusion::logical_expr::simplify::SimplifyInfo;
use datafusion::logical_expr::Expr;
use datafusion::logical_expr::Volatility;
use datafusion::logical_expr::{AggregateUDFImpl, Signature};
use datafusion::physical_plan::Accumulator;
use std::any::Any;
use std::fmt::Debug;
use std::ops::Deref;

// Generates two public items for `MaxByFunction`:
//   * `max_by(x, y)`  - builds an `Expr::AggregateFunction` calling the UDF;
//   * `max_by_udaf()` - returns the shared `Arc<AggregateUDF>` instance.
make_udaf_expr_and_func!(
MaxByFunction,
max_by,
x y,
"Returns the value of the first column corresponding to the maximum value in the second column.",
max_by_udaf
);

/// Aggregate UDF implementing `max_by(x, y)`: the value of `x` associated
/// with the maximum value of `y`.
pub struct MaxByFunction {
// Signature reported to DataFusion through `AggregateUDFImpl::signature`.
signature: Signature,
}

impl Debug for MaxByFunction {
    /// Formats the function as a `MaxBy` struct showing its name and
    /// signature; the accumulator slot is rendered as the placeholder `<FUNC>`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let mut builder = f.debug_struct("MaxBy");
        builder.field("name", &self.name());
        builder.field("signature", &self.signature);
        builder.field("accumulator", &"<FUNC>");
        builder.finish()
    }
}
impl Default for MaxByFunction {
fn default() -> Self {
Self::new()
}
}

impl MaxByFunction {
    /// Creates a `max_by` function with a user-defined, immutable signature.
    pub fn new() -> Self {
        let signature = Signature::user_defined(Volatility::Immutable);
        Self { signature }
    }
}

fn get_min_max_by_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
match &input_types[0] {
DataType::Dictionary(_, dict_value_type) => {
// TODO add checker, if the value type is complex data type
Ok(vec![dict_value_type.deref().clone()])
}
_ => Ok(input_types.to_vec()),
}
}

impl AggregateUDFImpl for MaxByFunction {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"max_by"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
Ok(arg_types[0].to_owned())
}

fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
exec_err!("should not reach here")
}
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
get_min_max_by_result_type(arg_types)
}

fn simplify(&self) -> Option<function::AggregateFunctionSimplification> {
let simplify = |mut aggr_func: expr::AggregateFunction, _: &dyn SimplifyInfo| {
let mut order_by = aggr_func.order_by.unwrap_or_default();
let (second_arg, first_arg) = (aggr_func.args.remove(1), aggr_func.args.remove(0));

order_by.push(Sort::new(second_arg, true, false));

Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
last_value_udaf(),
vec![first_arg],
aggr_func.distinct,
aggr_func.filter,
Some(order_by),
aggr_func.null_treatment,
)))
};
Some(Box::new(simplify))
}
}

// Generates two public items for `MinByFunction`:
//   * `min_by(x, y)`  - builds an `Expr::AggregateFunction` calling the UDF;
//   * `min_by_udaf()` - returns the shared `Arc<AggregateUDF>` instance.
make_udaf_expr_and_func!(
MinByFunction,
min_by,
x y,
"Returns the value of the first column corresponding to the minimum value in the second column.",
min_by_udaf
);

/// Aggregate UDF implementing `min_by(x, y)`: the value of `x` associated
/// with the minimum value of `y`.
pub struct MinByFunction {
// Signature reported to DataFusion through `AggregateUDFImpl::signature`.
signature: Signature,
}

impl Debug for MinByFunction {
    /// Formats the function as a `MinBy` struct showing its name and
    /// signature; the accumulator slot is rendered as the placeholder `<FUNC>`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let mut builder = f.debug_struct("MinBy");
        builder.field("name", &self.name());
        builder.field("signature", &self.signature);
        builder.field("accumulator", &"<FUNC>");
        builder.finish()
    }
}

impl Default for MinByFunction {
fn default() -> Self {
Self::new()
}
}

impl MinByFunction {
    /// Creates a `min_by` function with a user-defined, immutable signature.
    pub fn new() -> Self {
        let signature = Signature::user_defined(Volatility::Immutable);
        Self { signature }
    }
}

impl AggregateUDFImpl for MinByFunction {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"min_by"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
Ok(arg_types[0].to_owned())
}

fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
exec_err!("should not reach here")
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
get_min_max_by_result_type(arg_types)
}

fn simplify(&self) -> Option<function::AggregateFunctionSimplification> {
let simplify = |mut aggr_func: expr::AggregateFunction, _: &dyn SimplifyInfo| {
let mut order_by = aggr_func.order_by.unwrap_or_default();
let (second_arg, first_arg) = (aggr_func.args.remove(1), aggr_func.args.remove(0));

order_by.push(Sort::new(second_arg, false, false)); // false for ascending sort

Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
last_value_udaf(),
vec![first_arg],
aggr_func.distinct,
aggr_func.filter,
Some(order_by),
aggr_func.null_treatment,
)))
};
Some(Box::new(simplify))
}
}

0 comments on commit d874c3c

Please sign in to comment.