Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 1 addition & 18 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -589,33 +589,16 @@ impl DataFrame {
group_expr: Vec<Expr>,
aggr_expr: Vec<Expr>,
) -> Result<DataFrame> {
let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]);
let aggr_expr_len = aggr_expr.len();
let options =
LogicalPlanBuilderOptions::new().with_add_implicit_group_by_exprs(true);
let plan = LogicalPlanBuilder::from(self.plan)
.with_options(options)
.aggregate(group_expr, aggr_expr)?
.build()?;
let plan = if is_grouping_set {
let grouping_id_pos = plan.schema().fields().len() - 1 - aggr_expr_len;
// For grouping sets we do a project to not expose the internal grouping id
let exprs = plan
.schema()
.columns()
.into_iter()
.enumerate()
.filter(|(idx, _)| *idx != grouping_id_pos)
.map(|(_, column)| Expr::Column(column))
.collect::<Vec<_>>();
LogicalPlanBuilder::from(plan).project(exprs)?.build()?
} else {
plan
};
Ok(DataFrame {
session_state: self.session_state,
plan,
projection_requires_validation: !is_grouping_set,
projection_requires_validation: true,
})
}

Expand Down
181 changes: 156 additions & 25 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,16 @@ use datafusion_expr::expr::{
};
use datafusion_expr::expr_rewriter::unnormalize_cols;
use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
use datafusion_expr::utils::grouping_set_to_exprlist;
use datafusion_expr::{
Analyze, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension, FetchType,
Filter, JoinType, RecursiveQuery, SkipType, StringifiedPlan, WindowFrame,
WindowFrameBound, WriteOp,
};
use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
use datafusion_physical_expr::expressions::Literal;
use datafusion_physical_expr::aggregate::{
AggregateExpr, AggregateExprBuilder, GroupingExpr,
};
use datafusion_physical_expr::expressions::{Column, Literal};
use datafusion_physical_expr::{
create_physical_sort_exprs, LexOrdering, PhysicalSortExpr,
};
Expand Down Expand Up @@ -722,6 +725,12 @@ impl DefaultPhysicalPlanner {
session_state,
)?;

let group_by_expr = if groups.is_single() {
None
} else {
Some(group_expr_to_bitmap_index(group_expr)?)
};

let agg_filter = aggr_expr
.iter()
.map(|e| {
Expand All @@ -730,12 +739,32 @@ impl DefaultPhysicalPlanner {
logical_input_schema,
&physical_input_schema,
session_state.execution_props(),
group_by_expr.as_ref(),
)
})
.collect::<Result<Vec<_>>>()?;

let (mut aggregates, filters, _order_bys): (Vec<_>, Vec<_>, Vec<_>) =
multiunzip(agg_filter);
let no_grouping_agg = agg_filter
.iter()
.filter_map(|(e, filters, order_bys)| {
if matches!(e, AggregateExpr::AggregateFunctionExpr(_)) {
Some((e.clone(), filters.clone(), order_bys.clone()))
} else {
None
}
})
.collect::<Vec<_>>();

let (aggregates, filters, _order_bys): (Vec<_>, Vec<_>, Vec<_>) =
multiunzip(no_grouping_agg);

let mut aggregates = aggregates
.into_iter()
.map(|e| match e {
AggregateExpr::AggregateFunctionExpr(e) => e,
_ => unreachable!(),
})
.collect::<Vec<_>>();

let mut async_exprs = Vec::new();
let num_input_columns = physical_input_schema.fields().len();
Expand Down Expand Up @@ -813,22 +842,72 @@ impl DefaultPhysicalPlanner {

let final_grouping_set = initial_aggr.group_expr().as_final();

Arc::new(AggregateExec::try_new(
let final_agg = Arc::new(AggregateExec::try_new(
next_partition_mode,
final_grouping_set,
updated_aggregates,
filters,
initial_aggr,
Arc::clone(&physical_input_schema),
)?)
)?);

if groups.is_single()
&& !agg_filter
.iter()
.any(|(e, _, _)| matches!(e, AggregateExpr::GroupingExpr(_)))
{
final_agg
} else {
// Need to project out __grouping_id column and compute GROUPING expressions
let mut proj_exprs = Vec::new();
let num_group_exprs = groups.expr().len();

let schema = final_agg.schema();

// Add group columns
for i in 0..num_group_exprs {
let field = schema.field(i);
proj_exprs.push(ProjectionExpr {
expr: Arc::new(Column::new(field.name(), i)),
alias: field.name().to_string(),
});
}

// Skip __grouping_id at position num_group_exprs
// Add aggregate expressions (either computed GROUPING or column references)
let mut agg_col_idx = num_group_exprs + 1; // Start after __grouping_id

for (agg_expr, _, _) in &agg_filter {
match agg_expr {
AggregateExpr::GroupingExpr(grouping_expr) => {
// Use the GroupingExpr directly as a physical expression
proj_exprs.push(ProjectionExpr {
expr: Arc::clone(grouping_expr)
as Arc<dyn PhysicalExpr>,
alias: agg_expr.name().to_string(),
});
}
AggregateExpr::AggregateFunctionExpr(_) => {
// Reference the aggregate function column
let field = schema.field(agg_col_idx);
proj_exprs.push(ProjectionExpr {
expr: Arc::new(Column::new(
field.name(),
agg_col_idx,
)),
alias: field.name().to_string(),
});
agg_col_idx += 1;
}
}
}
Arc::new(ProjectionExec::try_new(proj_exprs, final_agg)?)
}
}
LogicalPlan::Projection(Projection { input, expr, .. }) => {
let child = children.one()?;
self.create_project_physical_exec(session_state, child, input, expr)?
}
LogicalPlan::Projection(Projection { input, expr, .. }) => self
.create_project_physical_exec(
session_state,
children.one()?,
input,
expr,
)?,
LogicalPlan::Filter(Filter {
predicate, input, ..
}) => {
Expand Down Expand Up @@ -1749,23 +1828,68 @@ pub fn create_window_expr(
}

type AggregateExprWithOptionalArgs = (
Arc<AggregateFunctionExpr>,
AggregateExpr,
// The filter clause, if any
Option<Arc<dyn PhysicalExpr>>,
// Expressions in the ORDER BY clause
Vec<PhysicalSortExpr>,
);

/// Build a lookup from each grouping expression to its index in the internal
/// grouping id.
///
/// The expression list produced by [`grouping_set_to_exprlist`] is walked in
/// reverse, so the last expression in that list is assigned index `0`, the
/// one before it index `1`, and so on.
///
/// For more details on how the grouping id bitmap works, see the documentation
/// for [`datafusion_physical_expr::aggregate::INTERNAL_GROUPING_ID`].
///
/// # Errors
///
/// Propagates any error returned by [`grouping_set_to_exprlist`].
fn group_expr_to_bitmap_index(group_expr: &[Expr]) -> Result<HashMap<&Expr, usize>> {
    let flattened = grouping_set_to_exprlist(group_expr)?;

    let mut bitmap_index = HashMap::with_capacity(flattened.len());
    // Indices are assigned from the end of the list towards the front.
    for (position, expr) in flattened.into_iter().rev().enumerate() {
        bitmap_index.insert(expr, position);
    }

    Ok(bitmap_index)
}

/// Create an aggregate expression with a name from a logical expression
pub fn create_aggregate_expr_with_name_and_maybe_filter(
e: &Expr,
name: Option<String>,
human_displan: String,
human_display: String,
logical_input_schema: &DFSchema,
physical_input_schema: &Schema,
execution_props: &ExecutionProps,
group_by_expr: Option<&HashMap<&Expr, usize>>,
) -> Result<AggregateExprWithOptionalArgs> {
let name = if let Some(name) = name {
name
} else {
physical_name(e)?
};
match e {
Expr::AggregateFunction(AggregateFunction { func, params })
if func.name() == "grouping" =>
{
match group_by_expr {
Some(group_by_expr) => {
let indices = params
.args
.iter()
.map(|expr| match group_by_expr.get(expr) {
Some(idx) => Ok(*idx as i32),
None => plan_err!(
"Grouping function argument {} not in grouping columns",
expr
),
})
.collect::<Result<Vec<i32>>>()?;
let grouping_expr =
GroupingExpr::new(name, human_display, Some(indices));
Ok((Arc::new(grouping_expr).into(), None, Vec::new()))
}
None => {
let grouping_expr = GroupingExpr::new(name, human_display, None);
Ok((Arc::new(grouping_expr).into(), None, Vec::new()))
}
}
}
Expr::AggregateFunction(AggregateFunction {
func,
params:
Expand All @@ -1777,12 +1901,6 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
null_treatment,
},
}) => {
let name = if let Some(name) = name {
name
} else {
physical_name(e)?
};

let physical_args =
create_physical_exprs(args, logical_input_schema, execution_props)?;
let filter = match filter {
Expand All @@ -1809,7 +1927,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
.order_by(order_bys.clone())
.schema(Arc::new(physical_input_schema.to_owned()))
.alias(name)
.human_display(human_displan)
.human_display(human_display)
.with_ignore_nulls(ignore_nulls)
.with_distinct(*distinct)
.build()
Expand All @@ -1818,7 +1936,11 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
(agg_expr, filter, order_bys)
};

Ok((agg_expr, filter, order_bys))
Ok((
AggregateExpr::AggregateFunctionExpr(agg_expr),
filter,
order_bys,
))
}
other => internal_err!("Invalid aggregate expression '{other:?}'"),
}
Expand All @@ -1830,6 +1952,7 @@ pub fn create_aggregate_expr_and_maybe_filter(
logical_input_schema: &DFSchema,
physical_input_schema: &Schema,
execution_props: &ExecutionProps,
group_by_expr: Option<&HashMap<&Expr, usize>>,
) -> Result<AggregateExprWithOptionalArgs> {
// Unpack (potentially nested) aliased logical expressions, e.g. "sum(col) as total"
// Some functions like `count_all()` create internal aliases,
Expand All @@ -1854,6 +1977,7 @@ pub fn create_aggregate_expr_and_maybe_filter(
logical_input_schema,
physical_input_schema,
execution_props,
group_by_expr,
)
}

Expand Down Expand Up @@ -2796,13 +2920,20 @@ mod tests {
.build()?;

let execution_plan = plan(&logical_plan).await?;
let final_hash_agg = execution_plan
let projection = execution_plan
.as_any()
.downcast_ref::<ProjectionExec>()
.expect("projection");

let final_hash_agg = projection
.input()
.as_any()
.downcast_ref::<AggregateExec>()
.expect("hash aggregate");

assert_eq!(
"sum(aggregate_test_100.c3)",
final_hash_agg.schema().field(3).name()
projection.schema().field(2).name()
);
// we need access to the input to the partial aggregate so that other projects can
// implement serde
Expand Down
Loading