Skip to content

Commit

Permalink
perf: Unset maintain_order pass
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 21, 2024
1 parent dbd2c5f commit 9346989
Show file tree
Hide file tree
Showing 7 changed files with 172 additions and 17 deletions.
12 changes: 2 additions & 10 deletions crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,7 @@ pub(crate) fn insert_streaming_nodes(
state.operators_sinks.push(PipelineNode::Operator(root));
stack.push(StackFrame::new(*input, state, current_idx))
},
HStack { input, exprs, .. }
if exprs
.iter()
.all(|e| is_elementwise_rec(expr_arena.get(e.node()), expr_arena)) =>
{
HStack { input, exprs, .. } if all_elementwise(exprs, expr_arena) => {
state.streamable = true;
state.operators_sinks.push(PipelineNode::Operator(root));
stack.push(StackFrame::new(*input, state, current_idx))
Expand All @@ -198,11 +194,7 @@ pub(crate) fn insert_streaming_nodes(
state.operators_sinks.push(PipelineNode::Sink(root));
stack.push(StackFrame::new(*input, state, current_idx))
},
Select { input, expr, .. }
if expr
.iter()
.all(|e| is_elementwise_rec(expr_arena.get(e.node()), expr_arena)) =>
{
Select { input, expr, .. } if all_elementwise(expr, expr_arena) => {
state.streamable = true;
state.operators_sinks.push(PipelineNode::Operator(root));
stack.push(StackFrame::new(*input, state, current_idx))
Expand Down
9 changes: 9 additions & 0 deletions crates/polars-plan/src/plans/aexpr/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@ pub fn is_elementwise(stack: &mut UnitVec<Node>, ae: &AExpr, expr_arena: &Arena<
true
}

pub fn all_elementwise<'a, N>(nodes: &'a [N], expr_arena: &Arena<AExpr>) -> bool
where
Node: From<&'a N>,
{
nodes
.iter()
.all(|n| is_elementwise_rec(expr_arena.get(n.into()), expr_arena))
}

/// Recursive variant of `is_elementwise`
pub fn is_elementwise_rec<'a>(mut ae: &'a AExpr, expr_arena: &'a Arena<AExpr>) -> bool {
let mut stack = unitvec![];
Expand Down
7 changes: 1 addition & 6 deletions crates/polars-plan/src/plans/conversion/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,14 +180,9 @@ pub fn resolve_join(
}
// Every expression must be elementwise so that we are
// guaranteed the keys for a join are all the same length.
let all_elementwise = |aexprs: &[ExprIR]| {
aexprs
.iter()
.all(|e| is_elementwise_rec(ctxt.expr_arena.get(e.node()), ctxt.expr_arena))
};

polars_ensure!(
all_elementwise(&left_on) && all_elementwise(&right_on),
all_elementwise(&left_on, ctxt.expr_arena) && all_elementwise(&right_on, ctxt.expr_arena),
InvalidOperation: "all join key expressions must be elementwise."
);

Expand Down
15 changes: 15 additions & 0 deletions crates/polars-plan/src/plans/optimizer/collect_members.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ pub(super) struct MemberCollector {
pub(crate) has_cache: bool,
pub(crate) has_ext_context: bool,
pub(crate) has_filter_with_join_input: bool,
pub(crate) has_distinct: bool,
pub(crate) has_sort: bool,
pub(crate) has_group_by: bool,
#[cfg(feature = "cse")]
scans: UniqueScans,
}
Expand All @@ -38,6 +41,9 @@ impl MemberCollector {
has_cache: false,
has_ext_context: false,
has_filter_with_join_input: false,
has_distinct: false,
has_sort: false,
has_group_by: false,
#[cfg(feature = "cse")]
scans: UniqueScans::default(),
}
Expand All @@ -50,6 +56,15 @@ impl MemberCollector {
Filter { input, .. } => {
self.has_filter_with_join_input |= matches!(lp_arena.get(*input), Join { options, .. } if options.args.how.is_cross())
},
Distinct { .. } => {
self.has_distinct = true;
},
GroupBy { .. } => {
self.has_group_by = true;
},
Sort { .. } => {
self.has_sort = true;
},
Cache { .. } => self.has_cache = true,
ExtContext { .. } => self.has_ext_context = true,
#[cfg(feature = "cse")]
Expand Down
6 changes: 6 additions & 0 deletions crates/polars-plan/src/plans/optimizer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ mod fused;
mod join_utils;
mod predicate_pushdown;
mod projection_pushdown;
mod set_order;
mod simplify_expr;
mod slice_pushdown_expr;
mod slice_pushdown_lp;
Expand All @@ -34,6 +35,7 @@ use slice_pushdown_lp::SlicePushDown;
pub use stack_opt::{OptimizationRule, StackOptimizer};

use self::flatten_union::FlattenUnionRule;
use self::set_order::set_order_flags;
pub use crate::frame::{AllowedOptimizations, OptFlags};
pub use crate::plans::conversion::type_coercion::TypeCoercionRule;
use crate::plans::optimizer::count_star::CountStar;
Expand Down Expand Up @@ -208,6 +210,10 @@ pub fn optimize(
cache_states::set_cache_states(lp_top, lp_arena, expr_arena, scratch, expr_eval, verbose)?;
}

if members.has_group_by | members.has_sort | members.has_distinct {
set_order_flags(lp_top, lp_arena, expr_arena, scratch);
}

// This one should run (nearly) last as this modifies the projections
#[cfg(feature = "cse")]
if comm_subexpr_elim && !members.has_ext_context {
Expand Down
134 changes: 134 additions & 0 deletions crates/polars-plan/src/plans/optimizer/set_order.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
use polars_utils::unitvec;

use super::*;

// Can give false positives.
fn is_order_dependent_top_level(ae: &AExpr, ctx: Context) -> bool {
match ae {
AExpr::Agg(agg) => match agg {
IRAggExpr::Min { .. } => false,
IRAggExpr::Max { .. } => false,
IRAggExpr::Median(_) => false,
IRAggExpr::NUnique(_) => false,
IRAggExpr::First(_) => true,
IRAggExpr::Last(_) => true,
IRAggExpr::Mean(_) => false,
IRAggExpr::Implode(_) => false,
IRAggExpr::Quantile { .. } => false,
IRAggExpr::Sum(_) => false,
IRAggExpr::Count(_, _) => false,
IRAggExpr::Std(_, _) => false,
IRAggExpr::Var(_, _) => false,
IRAggExpr::AggGroups(_) => true,
},
AExpr::Column(_) => matches!(ctx, Context::Aggregation),
_ => true,
}
}

// Can give false positives.
fn is_order_dependent<'a>(mut ae: &'a AExpr, expr_arena: &'a Arena<AExpr>, ctx: Context) -> bool {
let mut stack = unitvec![];

loop {
if !is_order_dependent_top_level(ae, ctx) {
return false;
}

let Some(node) = stack.pop() else {
break;
};

ae = expr_arena.get(node);
}

true
}

// Can give false negatives.
pub(crate) fn all_order_independent<'a, N>(
nodes: &'a [N],
expr_arena: &Arena<AExpr>,
ctx: Context,
) -> bool
where
Node: From<&'a N>,
{
nodes
.iter()
.all(|n| !is_order_dependent(expr_arena.get(n.into()), expr_arena, ctx))
}

pub(super) fn set_order_flags(
root: Node,
ir_arena: &mut Arena<IR>,
expr_arena: &Arena<AExpr>,
scratch: &mut Vec<Node>,
) {
scratch.clear();
scratch.push(root);

let mut maintain_order_above = true;

while let Some(node) = scratch.pop() {
let ir = ir_arena.get_mut(node);
ir.copy_inputs(scratch);

match ir {
IR::Sort { .. } => {
maintain_order_above = false;
},
IR::Distinct { options, .. } => {
if !maintain_order_above {
options.maintain_order = false;
continue;
}
if !options.maintain_order {
maintain_order_above = false;
}
},
IR::Union { options, .. } => {
options.maintain_order = maintain_order_above;
},
IR::GroupBy {
keys,
aggs,
maintain_order,
options,
apply,
..
} => {
if !maintain_order_above && *maintain_order {
*maintain_order = false;
continue;
}

if apply.is_some()
|| *maintain_order
|| options.rolling.is_some()
|| options.dynamic.is_some()
{
maintain_order_above = true;
continue;
}
if all_elementwise(keys, expr_arena)
&& !all_order_independent(aggs, expr_arena, Context::Aggregation)
{
maintain_order_above = false;
continue;
}
maintain_order_above = true;
},
// Conservative now.
IR::HStack { exprs, .. } | IR::Select { expr: exprs, .. } => {
if !maintain_order_above && all_elementwise(exprs, expr_arena) {
continue;
}
maintain_order_above = true;
},
_ => {
maintain_order_above = true;
},
}
}
}
6 changes: 5 additions & 1 deletion crates/polars-plan/src/plans/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,13 @@ pub struct FileScanOptions {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct UnionOptions {
pub slice: Option<(i64, usize)>,
pub parallel: bool,
// known row_output, estimated row output
pub rows: (Option<usize>, usize),
pub parallel: bool,
pub from_partitioned_ds: bool,
pub flattened_by_opt: bool,
pub rechunk: bool,
pub maintain_order: bool,
}

#[derive(Clone, Debug, Copy, Default, Eq, PartialEq, Hash)]
Expand Down Expand Up @@ -387,6 +388,7 @@ pub struct UnionArgs {
pub diagonal: bool,
// If it is a union from a scan over multiple files.
pub from_partitioned_ds: bool,
pub maintain_order: bool,
}

impl Default for UnionArgs {
Expand All @@ -397,6 +399,7 @@ impl Default for UnionArgs {
to_supertypes: false,
diagonal: false,
from_partitioned_ds: false,
maintain_order: true,
}
}
}
Expand All @@ -410,6 +413,7 @@ impl From<UnionArgs> for UnionOptions {
from_partitioned_ds: args.from_partitioned_ds,
flattened_by_opt: false,
rechunk: args.rechunk,
maintain_order: args.maintain_order,
}
}
}
Expand Down

0 comments on commit 9346989

Please sign in to comment.