Skip to content

Commit

Permalink
fix: Fix join literal behavior (#20477)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Dec 27, 2024
1 parent 4539173 commit 8e27477
Show file tree
Hide file tree
Showing 9 changed files with 234 additions and 97 deletions.
6 changes: 5 additions & 1 deletion crates/polars-core/src/frame/column/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,11 @@ impl ScalarColumn {
pub fn from_single_value_series(series: Series, length: usize) -> Self {
debug_assert!(series.len() <= 1);

let value = series.get(0).map_or(AnyValue::Null, |av| av.into_static());
let value = if series.is_empty() {
AnyValue::Null
} else {
unsafe { series.get_unchecked(0) }.into_static()
};
let value = Scalar::new(series.dtype().clone(), value);
ScalarColumn::new(series.name().clone(), value, length)
}
Expand Down
1 change: 1 addition & 0 deletions crates/polars-plan/src/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use polars_utils::pl_str::PlSmallStr;

pub static MAP_LIST_NAME: &str = "map_list";
pub static CSE_REPLACED: &str = "__POLARS_CSER_";
pub static POLARS_TMP_PREFIX: &str = "_POLARS_";
pub const LEN: &str = "len";
const LITERAL_NAME: &str = "literal";
pub const UNLIMITED_CACHE: u32 = u32::MAX;
Expand Down
14 changes: 3 additions & 11 deletions crates/polars-plan/src/plans/builder_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,21 +269,13 @@ impl<'a> IRBuilder<'a> {
let schema_left = self.schema();
let schema_right = self.lp_arena.get(other).schema(self.lp_arena);

let left_on_exprs = left_on
.iter()
.map(|e| e.to_expr(self.expr_arena))
.collect::<Vec<_>>();
let right_on_exprs = right_on
.iter()
.map(|e| e.to_expr(self.expr_arena))
.collect::<Vec<_>>();

let schema = det_join_schema(
&schema_left,
&schema_right,
&left_on_exprs,
&right_on_exprs,
&left_on,
&right_on,
&options,
self.expr_arena,
)
.unwrap();

Expand Down
1 change: 1 addition & 0 deletions crates/polars-plan/src/plans/conversion/dsl_to_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,7 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult
ctxt,
)
.map_err(|e| e.context(failed_here!(join)))
.map(|t| t.0)
},
DslPlan::HStack {
input,
Expand Down
173 changes: 144 additions & 29 deletions crates/polars-plan/src/plans/conversion/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ use either::Either;
use polars_core::chunked_array::cast::CastOptions;
use polars_core::error::feature_gated;
use polars_core::utils::get_numeric_upcast_supertype_lossless;
use polars_utils::format_pl_smallstr;
use polars_utils::itertools::Itertools;

use super::*;
use crate::constants::POLARS_TMP_PREFIX;
use crate::dsl::Expr;
#[cfg(feature = "iejoin")]
use crate::plans::AExpr;
Expand All @@ -20,6 +23,8 @@ fn check_join_keys(keys: &[Expr]) -> PolarsResult<()> {
}
Ok(())
}

/// Returns: left: join_node, right: last_node (often both the same)
pub fn resolve_join(
input_left: Either<Arc<DslPlan>, Node>,
input_right: Either<Arc<DslPlan>, Node>,
Expand All @@ -28,7 +33,7 @@ pub fn resolve_join(
predicates: Vec<Expr>,
mut options: Arc<JoinOptions>,
ctxt: &mut DslConversionContext,
) -> PolarsResult<Node> {
) -> PolarsResult<(Node, Node)> {
if !predicates.is_empty() {
feature_gated!("iejoin", {
debug_assert!(left_on.is_empty() && right_on.is_empty());
Expand Down Expand Up @@ -101,9 +106,6 @@ pub fn resolve_join(
);
}

let schema = det_join_schema(&schema_left, &schema_right, &left_on, &right_on, &options)
.map_err(|e| e.context(failed_here!(join schema resolving)))?;

let mut left_on = to_expr_irs_ignore_alias(left_on, ctxt.expr_arena)?;
let mut right_on = to_expr_irs_ignore_alias(right_on, ctxt.expr_arena)?;
let mut joined_on = PlHashSet::new();
Expand Down Expand Up @@ -147,15 +149,102 @@ pub fn resolve_join(
.get_type($schema, Context::Default, ctxt.expr_arena)
};
}
// # Resolve scalars
//
// Scalars need to be expanded. We translate them to temporary columns added with
// `with_columns` and remove them later with `project`
// This way the backends don't have to expand the literals in the join implementation

let has_scalars = left_on
.iter()
.chain(right_on.iter())
.any(|e| e.is_scalar(ctxt.expr_arena));

let (schema_left, schema_right) = if has_scalars {
let mut as_with_columns_l = vec![];
let mut as_with_columns_r = vec![];
for (i, e) in left_on.iter().enumerate() {
if e.is_scalar(ctxt.expr_arena) {
as_with_columns_l.push((i, e.clone()));
}
}
for (i, e) in right_on.iter().enumerate() {
if e.is_scalar(ctxt.expr_arena) {
as_with_columns_r.push((i, e.clone()));
}
}

let mut count = 0;
let get_tmp_name = |i| format_pl_smallstr!("{POLARS_TMP_PREFIX}{i}");

// Early clone because of bck.
let mut schema_right_new = if !as_with_columns_r.is_empty() {
(**schema_right).clone()
} else {
Default::default()
};
if !as_with_columns_l.is_empty() {
let mut schema_left_new = (**schema_left).clone();

let mut exprs = Vec::with_capacity(as_with_columns_l.len());
for (i, mut e) in as_with_columns_l {
let tmp_name = get_tmp_name(count);
count += 1;
e.set_alias(tmp_name.clone());
let dtype = e.dtype(&schema_left_new, Context::Default, ctxt.expr_arena)?;
schema_left_new.with_column(tmp_name.clone(), dtype.clone());

let col = ctxt.expr_arena.add(AExpr::Column(tmp_name));
left_on[i] = ExprIR::from_node(col, ctxt.expr_arena);
exprs.push(e);
}
input_left = ctxt.lp_arena.add(IR::HStack {
input: input_left,
exprs,
schema: Arc::new(schema_left_new),
options: ProjectionOptions::default(),
})
}
if !as_with_columns_r.is_empty() {
let mut exprs = Vec::with_capacity(as_with_columns_r.len());
for (i, mut e) in as_with_columns_r {
let tmp_name = get_tmp_name(count);
count += 1;
e.set_alias(tmp_name.clone());
let dtype = e.dtype(&schema_right_new, Context::Default, ctxt.expr_arena)?;
schema_right_new.with_column(tmp_name.clone(), dtype.clone());

let col = ctxt.expr_arena.add(AExpr::Column(tmp_name));
right_on[i] = ExprIR::from_node(col, ctxt.expr_arena);
exprs.push(e);
}
input_right = ctxt.lp_arena.add(IR::HStack {
input: input_right,
exprs,
schema: Arc::new(schema_right_new),
options: ProjectionOptions::default(),
})
}

(
ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena),
ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena),
)
} else {
(schema_left, schema_right)
};

// # Cast lossless
//
// If we do a full join and keys are coalesced, the casted keys must be added up front.
let key_cols_coalesced =
options.args.should_coalesce() && matches!(&options.args.how, JoinType::Full);
let mut to_cast_left = vec![];
let mut to_cast_right = vec![];
let mut to_cast_indices = vec![];
let mut as_with_columns_l = vec![];
let mut as_with_columns_r = vec![];
for (lnode, rnode) in left_on.iter_mut().zip(right_on.iter_mut()) {
//polars_ensure!(!lnode.is_scalar(&ctxt.expr_arena), InvalidOperation: "joining on scalars is not allowed, consider using 'join_where'");
//polars_ensure!(!rnode.is_scalar(&ctxt.expr_arena), InvalidOperation: "joining on scalars is not allowed, consider using 'join_where'");

for (i, (lnode, rnode)) in left_on.iter_mut().zip(right_on.iter_mut()).enumerate() {
let ltype = get_dtype!(lnode, &schema_left)?;
let rtype = get_dtype!(rnode, &schema_right)?;

Expand Down Expand Up @@ -186,9 +275,8 @@ pub fn resolve_join(
lnode.set_node(casted_l);
rnode.set_node(casted_r);

to_cast_indices.push(i);
to_cast_right.push(rnode);
to_cast_left.push(lnode);
as_with_columns_r.push(rnode);
as_with_columns_l.push(lnode);
} else {
lnode.set_node(casted_l);
rnode.set_node(casted_r);
Expand All @@ -214,42 +302,70 @@ pub fn resolve_join(
let schema_left = schema_left.into_owned();
let schema_right = schema_right.into_owned();

let key_cols_coalesced =
options.args.should_coalesce() && matches!(&options.args.how, JoinType::Full);
let join_schema = det_join_schema(
&schema_left,
&schema_right,
&left_on,
&right_on,
&options,
ctxt.expr_arena,
)
.map_err(|e| e.context(failed_here!(join schema resolving)))?;

if key_cols_coalesced {
input_left = if to_cast_left.is_empty() {
input_left = if as_with_columns_l.is_empty() {
input_left
} else {
ctxt.lp_arena.add(IR::HStack {
input: input_left,
exprs: to_cast_left,
exprs: as_with_columns_l,
schema: schema_left,
options: ProjectionOptions::default(),
})
};

input_right = if to_cast_right.is_empty() {
input_right = if as_with_columns_r.is_empty() {
input_right
} else {
ctxt.lp_arena.add(IR::HStack {
input: input_right,
exprs: to_cast_right,
exprs: as_with_columns_r,
schema: schema_right,
options: ProjectionOptions::default(),
})
};
}

let lp = IR::Join {
let ir = IR::Join {
input_left,
input_right,
schema,
schema: join_schema.clone(),
left_on,
right_on,
options,
};
Ok(ctxt.lp_arena.add(lp))
let join_node = ctxt.lp_arena.add(ir);

if has_scalars {
let names = join_schema
.iter_names()
.filter_map(|n| {
if n.starts_with(POLARS_TMP_PREFIX) {
None
} else {
Some(n.clone())
}
})
.collect_vec();

let builder = IRBuilder::new(join_node, ctxt.expr_arena, ctxt.lp_arena);
let ir = builder.project_simple(names).map(|b| b.build())?;
let select_node = ctxt.lp_arena.add(ir);

Ok((select_node, join_node))
} else {
Ok((join_node, join_node))
}
}

#[cfg(feature = "iejoin")]
Expand All @@ -265,13 +381,14 @@ impl From<InequalityOperator> for Operator {
}

#[cfg(feature = "iejoin")]
/// Returns: left: join_node, right: last_node (often both the same)
fn resolve_join_where(
input_left: Arc<DslPlan>,
input_right: Arc<DslPlan>,
predicates: Vec<Expr>,
mut options: Arc<JoinOptions>,
ctxt: &mut DslConversionContext,
) -> PolarsResult<Node> {
) -> PolarsResult<(Node, Node)> {
check_join_keys(&predicates)?;
let input_left = to_alp_impl(Arc::unwrap_or_clone(input_left), ctxt)
.map_err(|e| e.context(failed_here!(join left)))?;
Expand Down Expand Up @@ -499,10 +616,10 @@ fn resolve_join_where(
}
}

let join_node = if !eq_left_on.is_empty() {
let (mut last_node, join_node) = if !eq_left_on.is_empty() {
// We found one or more equality predicates. Go into a default equi join
// as those are cheapest on avg.
let join_node = resolve_join(
let (last_node, join_node) = resolve_join(
Either::Right(input_left),
Either::Right(input_right),
eq_left_on,
Expand All @@ -520,7 +637,7 @@ fn resolve_join_where(
&schema_right,
&suffix,
);
join_node
(last_node, join_node)
} else if ie_right_on.len() >= 2 {
// Do an IEjoin.
let opts = Arc::make_mut(&mut options);
Expand All @@ -529,7 +646,7 @@ fn resolve_join_where(
operator2: Some(ie_op[1]),
});

let join_node = resolve_join(
let (last_node, join_node) = resolve_join(
Either::Right(input_left),
Either::Right(input_right),
ie_left_on[..2].to_vec(),
Expand All @@ -550,7 +667,7 @@ fn resolve_join_where(

remaining_preds.push(to_binary_post_join(l, op.into(), r, &schema_right, &suffix))
}
join_node
(last_node, join_node)
} else if ie_right_on.len() == 1 {
// For a single inequality comparison, we use the piecewise merge join algorithm
let opts = Arc::make_mut(&mut options);
Expand Down Expand Up @@ -605,8 +722,6 @@ fn resolve_join_where(
.schema(ctxt.lp_arena)
.into_owned();

let mut last_node = join_node;

// Ensure that the predicates use the proper suffix
for e in remaining_preds {
let predicate = to_expr_ir_ignore_alias(e, ctxt.expr_arena)?;
Expand Down Expand Up @@ -637,5 +752,5 @@ fn resolve_join_where(
};
last_node = ctxt.lp_arena.add(ir);
}
Ok(last_node)
Ok((last_node, join_node))
}
1 change: 0 additions & 1 deletion crates/polars-plan/src/plans/expr_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,6 @@ impl ExprIR {
self.output_dtype = OnceLock::new();
}

#[cfg(feature = "cse")]
pub(crate) fn set_alias(&mut self, name: PlSmallStr) {
self.output_name = OutputName::Alias(name)
}
Expand Down
Loading

0 comments on commit 8e27477

Please sign in to comment.