Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add approx_percentile_cont() aggregation function #1539

Merged
merged 10 commits into from
Jan 31, 2022
3 changes: 2 additions & 1 deletion ballista/rust/core/proto/ballista.proto
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,12 @@ enum AggregateFunction {
STDDEV=11;
STDDEV_POP=12;
CORRELATION=13;
APPROX_PERCENTILE_CONT = 14;
}

message AggregateExprNode {
AggregateFunction aggr_function = 1;
LogicalExprNode expr = 2;
repeated LogicalExprNode expr = 2;
}

enum BuiltInWindowFunction {
Expand Down
6 changes: 5 additions & 1 deletion ballista/rust/core/src/serde/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1065,7 +1065,11 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {

Ok(Expr::AggregateFunction {
fun,
args: vec![parse_required_expr(&expr.expr)?],
args: expr
.expr
.iter()
.map(|e| e.try_into())
.collect::<Result<Vec<_>, _>>()?,
distinct: false, //TODO
})
}
Expand Down
21 changes: 16 additions & 5 deletions ballista/rust/core/src/serde/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,14 @@ mod roundtrip_tests {
use super::super::{super::error::Result, protobuf};
use crate::error::BallistaError;
use core::panic;
use datafusion::arrow::datatypes::UnionMode;
use datafusion::logical_plan::Repartition;
use datafusion::{
arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit},
arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit, UnionMode},
datasource::object_store::local::LocalFileSystem,
logical_plan::{
col, CreateExternalTable, Expr, LogicalPlan, LogicalPlanBuilder,
Partitioning, ToDFSchema,
Partitioning, Repartition, ToDFSchema,
},
physical_plan::functions::BuiltinScalarFunction::Sqrt,
physical_plan::{aggregates, functions::BuiltinScalarFunction::Sqrt},
prelude::*,
scalar::ScalarValue,
sql::parser::FileType,
Expand Down Expand Up @@ -1001,4 +999,17 @@ mod roundtrip_tests {

Ok(())
}

// Round-trip check for an `ApproxPercentileCont` aggregate expression that
// carries two arguments (a column reference and a literal).
#[test]
fn roundtrip_approx_percentile_cont() -> Result<()> {
// Two-argument, non-distinct aggregate: `col("bananas")` and `lit(0.42)`.
let test_expr = Expr::AggregateFunction {
fun: aggregates::AggregateFunction::ApproxPercentileCont,
args: vec![col("bananas"), lit(0.42)],
distinct: false,
};

// NOTE(review): `roundtrip_test!` is defined outside this view; presumably it
// converts the expression to `protobuf::LogicalExprNode` and back to `Expr`
// and compares the result — confirm against the macro definition.
roundtrip_test!(test_expr, protobuf::LogicalExprNode, Expr);

Ok(())
}
}
14 changes: 10 additions & 4 deletions ballista/rust/core/src/serde/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1074,6 +1074,9 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
AggregateFunction::ApproxDistinct => {
protobuf::AggregateFunction::ApproxDistinct
}
AggregateFunction::ApproxPercentileCont => {
protobuf::AggregateFunction::ApproxPercentileCont
}
AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg,
AggregateFunction::Min => protobuf::AggregateFunction::Min,
AggregateFunction::Max => protobuf::AggregateFunction::Max,
Expand All @@ -1099,11 +1102,13 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
}
};

let arg = &args[0];
let aggregate_expr = Box::new(protobuf::AggregateExprNode {
let aggregate_expr = protobuf::AggregateExprNode {
aggr_function: aggr_function.into(),
expr: Some(Box::new(arg.try_into()?)),
});
expr: args
.iter()
.map(|v| v.try_into())
.collect::<Result<Vec<_>, _>>()?,
};
Ok(protobuf::LogicalExprNode {
expr_type: Some(ExprType::AggregateExpr(aggregate_expr)),
})
Expand Down Expand Up @@ -1334,6 +1339,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
AggregateFunction::Stddev => Self::Stddev,
AggregateFunction::StddevPop => Self::StddevPop,
AggregateFunction::Correlation => Self::Correlation,
AggregateFunction::ApproxPercentileCont => Self::ApproxPercentileCont,
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions ballista/rust/core/src/serde/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
protobuf::AggregateFunction::Stddev => AggregateFunction::Stddev,
protobuf::AggregateFunction::StddevPop => AggregateFunction::StddevPop,
protobuf::AggregateFunction::Correlation => AggregateFunction::Correlation,
protobuf::AggregateFunction::ApproxPercentileCont => {
AggregateFunction::ApproxPercentileCont
}
}
}
}
Expand Down
9 changes: 9 additions & 0 deletions datafusion/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1647,6 +1647,15 @@ pub fn approx_distinct(expr: Expr) -> Expr {
}
}

/// Build an aggregate expression that approximates the given `percentile`
/// of `expr`.
///
/// The result is a non-distinct [`Expr::AggregateFunction`] whose arguments
/// are, in order, `expr` followed by `percentile`.
pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr {
    let fun = aggregates::AggregateFunction::ApproxPercentileCont;
    let args = vec![expr, percentile];
    Expr::AggregateFunction {
        fun,
        args,
        distinct: false,
    }
}

// TODO(kszucs): this seems buggy, unary_scalar_expr! is used for many
// varying arity functions
/// Create a convenience function representing a unary scalar function
Expand Down
14 changes: 7 additions & 7 deletions datafusion/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ pub use builder::{
pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema};
pub use display::display_schema;
pub use expr::{
abs, acos, and, approx_distinct, array, ascii, asin, atan, avg, binary_expr,
bit_length, btrim, case, ceil, character_length, chr, col, columnize_expr,
combine_filters, concat, concat_ws, cos, count, count_distinct, create_udaf,
create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, floor, in_list,
initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, lower, lpad, ltrim,
max, md5, min, normalize_col, normalize_cols, now, octet_length, or, random,
regexp_match, regexp_replace, repeat, replace, replace_col, reverse,
abs, acos, and, approx_distinct, approx_percentile_cont, array, ascii, asin, atan,
avg, binary_expr, bit_length, btrim, case, ceil, character_length, chr, col,
columnize_expr, combine_filters, concat, concat_ws, cos, count, count_distinct,
create_udaf, create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields,
floor, in_list, initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2,
lower, lpad, ltrim, max, md5, min, normalize_col, normalize_cols, now, octet_length,
or, random, regexp_match, regexp_replace, repeat, replace, replace_col, reverse,
rewrite_sort_cols_by_aggs, right, round, rpad, rtrim, sha224, sha256, sha384, sha512,
signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, to_hex,
translate, trim, trunc, unalias, unnormalize_col, unnormalize_cols, upper, when,
Expand Down
87 changes: 83 additions & 4 deletions datafusion/src/physical_plan/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
//! * Return type: a function `(arg_types) -> return_type`. E.g. for min, ([f32]) -> f32, ([f64]) -> f64.

use super::{
functions::{Signature, Volatility},
functions::{Signature, TypeSignature, Volatility},
Accumulator, AggregateExpr, PhysicalExpr,
};
use crate::error::{DataFusionError, Result};
Expand Down Expand Up @@ -80,6 +80,8 @@ pub enum AggregateFunction {
CovariancePop,
/// Correlation
Correlation,
/// Approximate continuous percentile function
ApproxPercentileCont,
}

impl fmt::Display for AggregateFunction {
Expand Down Expand Up @@ -110,6 +112,7 @@ impl FromStr for AggregateFunction {
"covar_samp" => AggregateFunction::Covariance,
"covar_pop" => AggregateFunction::CovariancePop,
"corr" => AggregateFunction::Correlation,
"approx_percentile_cont" => AggregateFunction::ApproxPercentileCont,
_ => {
return Err(DataFusionError::Plan(format!(
"There is no built-in function named {}",
Expand Down Expand Up @@ -157,6 +160,7 @@ pub fn return_type(
coerced_data_types[0].clone(),
true,
)))),
AggregateFunction::ApproxPercentileCont => Ok(coerced_data_types[0].clone()),
}
}

Expand Down Expand Up @@ -331,6 +335,20 @@ pub fn create_aggregate_expr(
"CORR(DISTINCT) aggregations are not available".to_string(),
));
}
(AggregateFunction::ApproxPercentileCont, false) => {
Arc::new(expressions::ApproxPercentileCont::new(
// Pass in the desired percentile expr
coerced_phy_exprs,
name,
return_type,
)?)
}
(AggregateFunction::ApproxPercentileCont, true) => {
return Err(DataFusionError::NotImplemented(
"approx_percentile_cont(DISTINCT) aggregations are not available"
.to_string(),
));
}
})
}

Expand Down Expand Up @@ -389,17 +407,25 @@ pub fn signature(fun: &AggregateFunction) -> Signature {
AggregateFunction::Correlation => {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::ApproxPercentileCont => Signature::one_of(
// Accept any numeric value paired with a float64 percentile
NUMERICS
.iter()
.map(|t| TypeSignature::Exact(vec![t.clone(), DataType::Float64]))
domodwyer marked this conversation as resolved.
Show resolved Hide resolved
.collect(),
Volatility::Immutable,
),
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::error::Result;
use crate::physical_plan::expressions::{
ApproxDistinct, ArrayAgg, Avg, Correlation, Count, Covariance, DistinctArrayAgg,
DistinctCount, Max, Min, Stddev, Sum, Variance,
ApproxDistinct, ApproxPercentileCont, ArrayAgg, Avg, Correlation, Count,
Covariance, DistinctArrayAgg, DistinctCount, Max, Min, Stddev, Sum, Variance,
};
use crate::{error::Result, scalar::ScalarValue};

#[test]
fn test_count_arragg_approx_expr() -> Result<()> {
Expand Down Expand Up @@ -513,6 +539,59 @@ mod tests {
Ok(())
}

// For each type in `NUMERICS`, building an `ApproxPercentileCont` aggregate
// over a column of that type with a 0.2 percentile literal must succeed and
// yield an expression named "c1" whose field has the same data type.
#[test]
fn test_agg_approx_percentile_phy_expr() {
    for data_type in NUMERICS {
        let schema = Schema::new(vec![Field::new("c1", data_type.clone(), true)]);

        // First argument: the column being aggregated.
        let column: Arc<dyn PhysicalExpr> =
            Arc::new(expressions::Column::new_with_schema("c1", &schema).unwrap());
        // Second argument: the percentile, as a Float64 literal.
        let percentile: Arc<dyn PhysicalExpr> =
            Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.2))));
        let phy_exprs = vec![column, percentile];

        let agg = create_aggregate_expr(
            &AggregateFunction::ApproxPercentileCont,
            false,
            &phy_exprs[..],
            &schema,
            "c1",
        )
        .expect("failed to create aggregate expr");

        assert!(agg.as_any().is::<ApproxPercentileCont>());
        assert_eq!("c1", agg.name());

        let expected_field = Field::new("c1", data_type.clone(), false);
        assert_eq!(expected_field, agg.field().unwrap());
    }
}

// For each type in `NUMERICS`, a percentile literal of 4.2 must be rejected
// at construction time with a `DataFusionError::Plan` error.
#[test]
fn test_agg_approx_percentile_invalid_phy_expr() {
    for data_type in NUMERICS {
        let schema = Schema::new(vec![Field::new("c1", data_type.clone(), true)]);

        let column: Arc<dyn PhysicalExpr> =
            Arc::new(expressions::Column::new_with_schema("c1", &schema).unwrap());
        // The percentile value this test expects to be refused.
        let bad_percentile: Arc<dyn PhysicalExpr> =
            Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(4.2))));
        let phy_exprs = vec![column, bad_percentile];

        let outcome = create_aggregate_expr(
            &AggregateFunction::ApproxPercentileCont,
            false,
            &phy_exprs[..],
            &schema,
            "c1",
        );
        let err = outcome.expect_err("should fail due to invalid percentile");

        assert!(matches!(err, DataFusionError::Plan(_)));
    }
}

#[test]
fn test_min_max_expr() -> Result<()> {
let funcs = vec![AggregateFunction::Min, AggregateFunction::Max];
Expand Down
Loading