From 773e61d66e06bc0f00bf0d7da1ece13d6d91fadd Mon Sep 17 00:00:00 2001 From: ritchie Date: Fri, 6 Sep 2024 16:56:16 +0200 Subject: [PATCH] wrap-up and add test --- crates/polars-lazy/src/frame/mod.rs | 13 +- .../polars-ops/src/frame/join/iejoin/mod.rs | 103 +++++++----- crates/polars-ops/src/frame/join/mod.rs | 34 +--- .../src/plans/conversion/dsl_to_ir.rs | 21 ++- .../polars-plan/src/plans/conversion/join.rs | 158 ++++++++++++------ .../polars-plan/src/plans/conversion/mod.rs | 2 +- crates/polars-python/src/lazyframe/general.rs | 1 - py-polars/polars/dataframe/frame.py | 1 + py-polars/polars/lazyframe/frame.py | 3 +- .../unit/operations/test_inequality_join.py | 60 +++++++ 10 files changed, 250 insertions(+), 146 deletions(-) diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs index 49ea9014ca02..5aeabaa96cae 100644 --- a/crates/polars-lazy/src/frame/mod.rs +++ b/crates/polars-lazy/src/frame/mod.rs @@ -35,7 +35,6 @@ use polars_expr::{create_physical_expr, ExpressionConversionState}; use polars_io::RowIndex; use polars_mem_engine::{create_physical_plan, Executor}; use polars_ops::frame::JoinCoalesce; -use polars_ops::prelude::InequalityOperator; pub use polars_plan::frame::{AllowedOptimizations, OptFlags}; use polars_plan::global::FETCH_ROWS; use polars_utils::pl_str::PlSmallStr; @@ -2106,7 +2105,6 @@ impl JoinBuilder { LazyFrame::from_logical_plan(lp, opt_state) } - // Finish with join predicates pub fn join_where(self, predicates: Vec) -> LazyFrame { let mut opt_state = self.lf.opt_state; @@ -2126,11 +2124,11 @@ impl JoinBuilder { coalesce: self.coalesce, }; let options = JoinOptions { - allow_parallel: self.allow_parallel, - force_parallel: self.force_parallel, - args, - ..Default::default() - }; + allow_parallel: self.allow_parallel, + force_parallel: self.force_parallel, + args, + ..Default::default() + }; let lp = DslPlan::Join { input_left: Arc::new(self.lf.logical_plan), @@ -2142,6 +2140,5 @@ impl JoinBuilder { }; LazyFrame::from_logical_plan(lp, opt_state) - } } diff --git a/crates/polars-ops/src/frame/join/iejoin/mod.rs b/crates/polars-ops/src/frame/join/iejoin/mod.rs index 531de4bdbe4c..d0698018c5bb 100644 --- a/crates/polars-ops/src/frame/join/iejoin/mod.rs +++ b/crates/polars-ops/src/frame/join/iejoin/mod.rs @@ -7,18 +7,19 @@ use polars_core::chunked_array::ChunkedArray; use polars_core::datatypes::{IdxCa, NumericNative, PolarsNumericType}; use polars_core::frame::DataFrame; use polars_core::prelude::*; +use polars_core::utils::{_set_partition_size, split}; use polars_core::{with_match_physical_numeric_polars_type, POOL}; use polars_error::{polars_err, PolarsResult}; use polars_utils::binary_search::ExponentialSearch; +use polars_utils::itertools::Itertools; use polars_utils::slice::GetSaferUnchecked; -use polars_utils::total_ord::{TotalEq}; +use polars_utils::total_ord::TotalEq; use polars_utils::IdxSize; +use rayon::prelude::*; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -use polars_core::utils::{_set_partition_size, split}; + use crate::frame::_finish_join; -use rayon::prelude::*; -use polars_utils::itertools::Itertools; #[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -138,7 +139,10 @@ pub(super) fn iejoin_par( suffix: Option, slice: Option<(i64, usize)>, ) -> PolarsResult { - let l1_descending = matches!(options.operator1, InequalityOperator::Gt | InequalityOperator::GtEq); + let l1_descending = matches!( + options.operator1, + InequalityOperator::Gt | InequalityOperator::GtEq + ); let l1_sort_options = SortOptions::default() .with_maintain_order(true) @@ -146,34 +150,35 @@ pub(super) fn iejoin_par( .with_order_descending(l1_descending); let sl = &selected_left[0]; - let l1_s_l = sl.arg_sort(l1_sort_options) - .slice( - sl.null_count() as i64, - sl.len() - sl.null_count(), - ); + let l1_s_l = sl + .arg_sort(l1_sort_options) + .slice(sl.null_count() as i64, sl.len() - sl.null_count()); let sr = &selected_right[0]; - let l1_s_r = sr.arg_sort(l1_sort_options) - .slice( - sr.null_count() as i64, - sr.len() - sr.null_count(), - ); + let l1_s_r = sr + .arg_sort(l1_sort_options) + .slice(sr.null_count() as i64, sr.len() - sr.null_count()); // Because we do a cartesian product, the number of partitions is squared. // We take the sqrt, but we don't expect every partition to produce results and work can be // imbalanced, so we multiply the number of partitions by 2, which leads to 2^2= 4 - let n_partitions = (_set_partition_size() as f32).sqrt() as usize * 2; + let n_partitions = (_set_partition_size() as f32).sqrt() as usize * 2; let splitted_a = split(&l1_s_l, n_partitions); let splitted_b = split(&l1_s_r, n_partitions); - let cartesian_prod = splitted_a.iter() - .flat_map(|l| splitted_b.iter().map(move |r| (l, r))).collect::>(); + let cartesian_prod = splitted_a + .iter() + .flat_map(|l| splitted_b.iter().map(move |r| (l, r))) + .collect::>(); let iter = cartesian_prod.par_iter().map(|(l_l1_idx, r_l1_idx)| { if l_l1_idx.is_empty() || r_l1_idx.is_empty() { - return Ok(None) + return Ok(None); } - fn get_extrema<'a>(l1_idx: &'a IdxCa, s: &'a Series) -> Option<(AnyValue<'a>, AnyValue<'a>)> { + fn get_extrema<'a>( + l1_idx: &'a IdxCa, + s: &'a Series, + ) -> Option<(AnyValue<'a>, AnyValue<'a>)> { let first = l1_idx.first()?; let last = l1_idx.last()?; @@ -186,28 +191,39 @@ pub(super) fn iejoin_par( (end, start) }) } - let Some((min_l, max_l)) = get_extrema(l_l1_idx, sl) else {return Ok(None)}; - let Some((min_r, max_r)) = get_extrema(r_l1_idx, sr) else {return Ok(None)}; + let Some((min_l, max_l)) = get_extrema(l_l1_idx, sl) else { + return Ok(None); + }; + let Some((min_r, max_r)) = get_extrema(r_l1_idx, sr) else { + return Ok(None); + }; let include_block = match options.operator1 { InequalityOperator::Lt => min_l < max_r, InequalityOperator::LtEq => min_l <= max_r, InequalityOperator::Gt => max_l > min_r, - InequalityOperator::GtEq => max_l >= min_r + InequalityOperator::GtEq => max_l >= min_r, }; if include_block { let (l, r) = unsafe { - (selected_left.iter().map(|s| s.take_unchecked(l_l1_idx)).collect_vec(), - selected_right.iter().map(|s| s.take_unchecked(r_l1_idx)).collect_vec()) + ( + selected_left + .iter() + .map(|s| s.take_unchecked(l_l1_idx)) + .collect_vec(), + selected_right + .iter() + .map(|s| s.take_unchecked(r_l1_idx)) + .collect_vec(), + ) }; - // Compute the row indexes let (idx_l, idx_r) = iejoin_tuples(l, r, options, None)?; if idx_l.is_empty() { - return Ok(None) + return Ok(None); } // These are row indexes in the slices we have given, so we use those to gather in the @@ -215,7 +231,7 @@ pub(super) fn iejoin_par( unsafe { Ok(Some(( l_l1_idx.take_unchecked(&idx_l), - r_l1_idx.take_unchecked(&idx_r) + r_l1_idx.take_unchecked(&idx_r), ))) } } else { @@ -223,17 +239,13 @@ pub(super) fn iejoin_par( } }); - let row_indices = POOL.install(|| { - iter.collect::>>() - })?; + let row_indices = POOL.install(|| iter.collect::>>())?; let mut left_idx = IdxCa::default(); let mut right_idx = IdxCa::default(); - for opt in row_indices { - if let Some((l, r)) = opt { - left_idx.append(&l)?; - right_idx.append(&r)?; - } + for (l, r) in row_indices.into_iter().flatten() { + left_idx.append(&l)?; + right_idx.append(&r)?; } if let Some((offset, end)) = slice { left_idx = left_idx.slice(offset, end); @@ -241,7 +253,6 @@ pub(super) fn iejoin_par( } unsafe { materialize_join(left, right, &left_idx, &right_idx, suffix) } - } pub(super) fn iejoin( @@ -253,24 +264,28 @@ pub(super) fn iejoin( suffix: Option, slice: Option<(i64, usize)>, ) -> PolarsResult { - - let (left_row_idx, right_row_idx)= iejoin_tuples(selected_left, selected_right, options, slice)?; + let (left_row_idx, right_row_idx) = + iejoin_tuples(selected_left, selected_right, options, slice)?; unsafe { materialize_join(left, right, &left_row_idx, &right_row_idx, suffix) } } -unsafe fn materialize_join(left: &DataFrame, right: &DataFrame, left_row_idx: &IdxCa, right_row_idx: &IdxCa, suffix: Option) -> PolarsResult { +unsafe fn materialize_join( + left: &DataFrame, + right: &DataFrame, + left_row_idx: &IdxCa, + right_row_idx: &IdxCa, + suffix: Option, +) -> PolarsResult { let (join_left, join_right) = { POOL.join( - || left.take_unchecked(&left_row_idx), - || right.take_unchecked(&right_row_idx), + || left.take_unchecked(left_row_idx), + || right.take_unchecked(right_row_idx), ) }; _finish_join(join_left, join_right, suffix) - } - /// Inequality join. Matches rows between two DataFrames using two inequality operators /// (one of [<, <=, >, >=]). /// Based on Khayyat et al. 2015, "Lightning Fast and Space Efficient Inequality Joins" diff --git a/crates/polars-ops/src/frame/join/mod.rs b/crates/polars-ops/src/frame/join/mod.rs index 71d9e903eea5..433bffd232dd 100644 --- a/crates/polars-ops/src/frame/join/mod.rs +++ b/crates/polars-ops/src/frame/join/mod.rs @@ -200,7 +200,8 @@ pub trait DataFrameJoinOps: IntoDf { } if let JoinType::IEJoin(options) = args.how { - let func = if POOL.current_num_threads() > 1 && !left_df.is_empty() && !other.is_empty() { + let func = if POOL.current_num_threads() > 1 && !left_df.is_empty() && !other.is_empty() + { iejoin::iejoin_par } else { iejoin::iejoin @@ -535,34 +536,3 @@ pub fn private_left_join_multiple_keys( let b = prepare_keys_multiple(b.get_columns(), join_nulls)?.into_series(); sort_or_hash_left(&a, &b, false, JoinValidation::ManyToMany, join_nulls) } - -#[test] -fn test_foo() { - let west = df![ - "t_id" => [0, 1, 2, 3, 4, 5], - "time" => [100, 140, 100, 80, 90, 90], - "cost" => [6, 11, 11, 10, 5, 5], - ] - .unwrap(); - - let time = west.column("time").unwrap(); - let cost = west.column("cost").unwrap(); - - let selected = vec![time.clone(), cost.clone()]; - - let out = west - ._join_impl( - &west.clone(), - selected.clone(), - selected, - JoinArgs::new(JoinType::IEJoin(IEJoinOptions { - operator1: InequalityOperator::Gt, - operator2: InequalityOperator::LtEq, - })), - false, - false, - ) - .unwrap(); - - dbg!(out); -} diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs index 86692ee33619..a902b2da1e5d 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs @@ -52,9 +52,7 @@ macro_rules! failed_here { format!("'{}' failed", stringify!($($t)*)).into() } } -pub(super) use failed_input; -pub(super) use failed_here; -pub(super) use failed_input_args; +pub(super) use {failed_here, failed_input, failed_input_args}; pub fn to_alp( lp: DslPlan, @@ -85,7 +83,11 @@ pub(super) struct DslConversionContext<'a> { pub(super) opt_flags: &'a mut OptFlags, } -pub(super) fn run_conversion(lp: IR, ctxt: &mut DslConversionContext, name: &str) -> PolarsResult { +pub(super) fn run_conversion( + lp: IR, + ctxt: &mut DslConversionContext, + name: &str, +) -> PolarsResult { let lp_node = ctxt.lp_arena.add(lp); ctxt.conversion_optimizer .coerce_types(ctxt.expr_arena, ctxt.lp_arena, lp_node) @@ -101,7 +103,6 @@ pub(super) fn run_conversion(lp: IR, ctxt: &mut DslConversionContext, name: &str pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult { let owned = Arc::unwrap_or_clone; - let v = match lp { DslPlan::Scan { paths, @@ -544,7 +545,15 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult predicates, options, } => { - return join::resolve_join(input_left, input_right, left_on, right_on, predicates, options, ctxt) + return join::resolve_join( + input_left, + input_right, + left_on, + right_on, + predicates, + options, + ctxt, + ) }, DslPlan::HStack { input, diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs index b77803f1f6b2..0433b2b14e69 100644 --- a/crates/polars-plan/src/plans/conversion/join.rs +++ b/crates/polars-plan/src/plans/conversion/join.rs @@ -1,21 +1,19 @@ use arrow::legacy::error::PolarsResult; -use polars_utils::arena::Arena; -use crate::dsl::{Expr, FunctionExpr}; -use crate::plans::AExpr; -use crate::prelude::FunctionOptions; + use super::*; +use crate::dsl::Expr; +use crate::plans::AExpr; fn check_join_keys(keys: &[Expr]) -> PolarsResult<()> { for e in keys { if has_expr(e, |e| matches!(e, Expr::Alias(_, _))) { polars_bail!( - InvalidOperation: - "'alias' is not allowed in a join key, use 'with_columns' first", - ) + InvalidOperation: + "'alias' is not allowed in a join key, use 'with_columns' first", + ) } } Ok(()) - } pub fn resolve_join( input_left: Arc, @@ -24,11 +22,11 @@ pub fn resolve_join( right_on: Vec, predicates: Vec, mut options: Arc, - ctxt: &mut DslConversionContext + ctxt: &mut DslConversionContext, ) -> PolarsResult { if !predicates.is_empty() { debug_assert!(left_on.is_empty() && right_on.is_empty()); - return resolve_join_where(input_left, input_right, predicates, options, ctxt) + return resolve_join_where(input_left, input_right, predicates, options, ctxt); } let owned = Arc::unwrap_or_clone; @@ -54,38 +52,37 @@ pub fn resolve_join( options.args.validation.is_valid_join(&options.args.how)?; polars_ensure!( - left_on.len() == right_on.len(), - InvalidOperation: - format!( - "the number of columns given as join key (left: {}, right:{}) should be equal", - left_on.len(), - right_on.len() - ) - ); + left_on.len() == right_on.len(), + InvalidOperation: + format!( + "the number of columns given as join key (left: {}, right:{}) should be equal", + left_on.len(), + right_on.len() + ) + ); } - let input_left = to_alp_impl(owned(input_left), ctxt) - .map_err(|e| e.context(failed_input!(join left)))?; - let input_right = to_alp_impl(owned(input_right), ctxt) - .map_err(|e| e.context(failed_input!(join, right)))?; + let input_left = + to_alp_impl(owned(input_left), ctxt).map_err(|e| e.context(failed_input!(join left)))?; + let input_right = + to_alp_impl(owned(input_right), ctxt).map_err(|e| e.context(failed_input!(join, right)))?; let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena); let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena); - let schema = - det_join_schema(&schema_left, &schema_right, &left_on, &right_on, &options) - .map_err(|e| e.context(failed_here!(join schema resolving)))?; + let schema = det_join_schema(&schema_left, &schema_right, &left_on, &right_on, &options) + .map_err(|e| e.context(failed_here!(join schema resolving)))?; let left_on = to_expr_irs_ignore_alias(left_on, ctxt.expr_arena)?; let right_on = to_expr_irs_ignore_alias(right_on, ctxt.expr_arena)?; let mut joined_on = PlHashSet::new(); for (l, r) in left_on.iter().zip(right_on.iter()) { polars_ensure!( - joined_on.insert((l.output_name(), r.output_name())), - InvalidOperation: "joining with repeated key names; already joined on {} and {}", - l.output_name(), - r.output_name() - ) + joined_on.insert((l.output_name(), r.output_name())), + InvalidOperation: "joining with repeated key names; already joined on {} and {}", + l.output_name(), + r.output_name() + ) } drop(joined_on); @@ -99,9 +96,9 @@ pub fn resolve_join( let all_elementwise = |aexprs: &[ExprIR]| all_streamable(aexprs, &*ctxt.expr_arena, Context::Default); polars_ensure!( - all_elementwise(&left_on) && all_elementwise(&right_on), - InvalidOperation: "All join key expressions must be elementwise." - ); + all_elementwise(&left_on) && all_elementwise(&right_on), + InvalidOperation: "All join key expressions must be elementwise." + ); let lp = IR::Join { input_left, input_right, @@ -129,7 +126,7 @@ fn resolve_join_where( input_right: Arc, predicates: Vec, mut options: Arc, - ctxt: &mut DslConversionContext + ctxt: &mut DslConversionContext, ) -> PolarsResult { check_join_keys(&predicates)?; @@ -155,13 +152,15 @@ fn resolve_join_where( } for pred in predicates.into_iter() { - let Expr::BinaryExpr {left, op, right} = pred.clone() else { polars_bail!(InvalidOperation: "can only join on binary expressions") }; + let Expr::BinaryExpr { left, op, right } = pred.clone() else { + polars_bail!(InvalidOperation: "can only join on binary expressions") + }; polars_ensure!(op.is_comparison(), InvalidOperation: "expected comparison in join predicate"); if let Some(ie_op_) = to_inequality_operator(&op) { // We already have an IEjoin or an Inner join, push to remaining if ie_op.len() >= 2 || !eq_right_on.is_empty() { - remaining_preds.push(Expr::BinaryExpr {left, op, right}) + remaining_preds.push(Expr::BinaryExpr { left, op, right }) } else { ie_left_on.push(owned(left)); ie_right_on.push(owned(right)); @@ -176,13 +175,28 @@ fn resolve_join_where( } let join_node = if !eq_left_on.is_empty() { - let join_node = resolve_join(input_left, input_right, eq_left_on, eq_right_on, vec![], options.clone(), ctxt)?; - - for ((l, op), r) in ie_left_on.into_iter().zip(ie_op.into_iter()).zip(ie_right_on.into_iter()) { - remaining_preds.push(Expr::BinaryExpr {left: Arc::from(l), op: op.into(), right: Arc::from(r)}) + let join_node = resolve_join( + input_left, + input_right, + eq_left_on, + eq_right_on, + vec![], + options.clone(), + ctxt, + )?; + + for ((l, op), r) in ie_left_on + .into_iter() + .zip(ie_op.into_iter()) + .zip(ie_right_on.into_iter()) + { + remaining_preds.push(Expr::BinaryExpr { + left: Arc::from(l), + op: op.into(), + right: Arc::from(r), + }) } join_node - } else if ie_right_on.len() == 2 { let opts = Arc::make_mut(&mut options); opts.args.how = JoinType::IEJoin(IEJoinOptions { @@ -190,17 +204,49 @@ fn resolve_join_where( operator2: ie_op[1], }); - resolve_join(input_left, input_right, ie_left_on, ie_right_on, vec![], options.clone(), ctxt)? + resolve_join( + input_left, + input_right, + ie_left_on, + ie_right_on, + vec![], + options.clone(), + ctxt, + )? } else { let opts = Arc::make_mut(&mut options); opts.args.how = JoinType::Cross; - resolve_join(input_left, input_right, vec![], vec![], vec![], options.clone(), ctxt)? + resolve_join( + input_left, + input_right, + vec![], + vec![], + vec![], + options.clone(), + ctxt, + )? }; - let IR::Join {input_right, ..} = ctxt.lp_arena.get(join_node) else { unreachable!()}; - let schema_right = ctxt.lp_arena.get(*input_right).schema(ctxt.lp_arena).into_owned(); - + let IR::Join { + input_left, + input_right, + .. + } = ctxt.lp_arena.get(join_node) + else { + unreachable!() + }; + let schema_right = ctxt + .lp_arena + .get(*input_right) + .schema(ctxt.lp_arena) + .into_owned(); + + let schema_left = ctxt + .lp_arena + .get(*input_left) + .schema(ctxt.lp_arena) + .into_owned(); let suffix = options.args.suffix(); @@ -209,26 +255,32 @@ fn resolve_join_where( // Ensure that the predicates use the proper suffix for e in remaining_preds { let predicate = to_expr_ir_ignore_alias(e, ctxt.expr_arena)?; - let AExpr::BinaryExpr {left, op, mut right} = *ctxt.expr_arena.get(predicate.node()) else { unreachable!() }; + let AExpr::BinaryExpr { mut right, .. } = *ctxt.expr_arena.get(predicate.node()) else { + unreachable!() + }; let original_right = right; for name in aexpr_to_leaf_names(right, ctxt.expr_arena) { - if !schema_right.contains(name.as_str()) { + if schema_left.contains(name.as_str()) { let new_name = _join_suffix_name(name.as_str(), suffix.as_str()); - polars_ensure!(schema_right.contains(new_name.as_str()), ColumnNotFound: "could not find column {name} in the right table during join operation"); + polars_ensure!(schema_right.contains(name.as_str()), ColumnNotFound: "could not find column {name} in the right table during join operation"); - right = rename_matching_aexpr_leaf_names(right, ctxt.expr_arena, name.as_str(), new_name); + right = rename_matching_aexpr_leaf_names( + right, + ctxt.expr_arena, + name.as_str(), + new_name, + ); } } ctxt.expr_arena.swap(right, original_right); let ir = IR::Filter { input: last_node, - predicate + predicate, }; last_node = ctxt.lp_arena.add(ir); - } Ok(last_node) -} \ No newline at end of file +} diff --git a/crates/polars-plan/src/plans/conversion/mod.rs b/crates/polars-plan/src/plans/conversion/mod.rs index 2bf0138a2e59..89167a124534 100644 --- a/crates/polars-plan/src/plans/conversion/mod.rs +++ b/crates/polars-plan/src/plans/conversion/mod.rs @@ -21,8 +21,8 @@ use polars_core::prelude::*; use polars_utils::vec::ConvertVec; use recursive::recursive; mod functions; -pub(crate) mod type_coercion; mod join; +pub(crate) mod type_coercion; pub(crate) use expr_expansion::{expand_selectors, is_regex_projection, prepare_projection}; diff --git a/crates/polars-python/src/lazyframe/general.rs b/crates/polars-python/src/lazyframe/general.rs index 275c0a687e92..60e020a7b852 100644 --- a/crates/polars-python/src/lazyframe/general.rs +++ b/crates/polars-python/src/lazyframe/general.rs @@ -1219,4 +1219,3 @@ impl PyLazyFrame { Ok(out.into()) } } - diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 13d1ec655e34..848c629a9e82 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -7085,6 +7085,7 @@ def join( .collect(_eager=True) ) + @unstable() def join_where( self, other: DataFrame, diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 927b75c08295..d08bd7a0462c 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -4561,6 +4561,7 @@ def join( ) ) + @unstable() def join_where( self, other: LazyFrame, @@ -4593,7 +4594,7 @@ def join_where( pyexprs = parse_into_list_of_expressions(*predicates) return self._from_pyldf( - self._ldf.inequality_join( + self._ldf.join_where( other._ldf, pyexprs, suffix, diff --git a/py-polars/tests/unit/operations/test_inequality_join.py b/py-polars/tests/unit/operations/test_inequality_join.py index 84b822a25583..7a0108eeb6db 100644 --- a/py-polars/tests/unit/operations/test_inequality_join.py +++ b/py-polars/tests/unit/operations/test_inequality_join.py @@ -174,6 +174,7 @@ def test_join_between() -> None: left = pl.DataFrame( { "id": [0, 1, 2, 3, 4, 5], + "group": [0, 0, 0, 1, 1, 1], "time": [ datetime(2024, 8, 26, 15, 34, 30), datetime(2024, 8, 26, 15, 35, 30), @@ -187,6 +188,7 @@ def test_join_between() -> None: right = pl.DataFrame( { "id": [0, 1, 2], + "group": [0, 1, 1], "start_time": [ datetime(2024, 8, 26, 15, 34, 0), datetime(2024, 8, 26, 15, 35, 0), @@ -214,6 +216,64 @@ def test_join_between() -> None: ) assert_frame_equal(actual, expected, check_row_order=False, check_exact=True) + q = ( + left.lazy() + .join_where( + right.lazy(), + pl.col("time") >= pl.col("start_time"), + pl.col("time") < pl.col("end_time"), + pl.col("group") == pl.col("group"), + ) + .select("id", "id_right", "group") + .sort("id") + ) + + explained = q.explain() + assert "INNER JOIN" in explained + assert "FILTER" in explained + actual = q.collect() + + expected = ( + left.join(right, how="cross") + .filter( + pl.col("time") >= pl.col("start_time"), + pl.col("time") < pl.col("end_time"), + pl.col("group") == pl.col("group_right"), + ) + .select("id", "id_right", "group") + .sort("id") + ) + assert_frame_equal(actual, expected, check_exact=True) + + q = ( + left.lazy() + .join_where( + right.lazy(), + pl.col("time") >= pl.col("start_time"), + pl.col("time") < pl.col("end_time"), + pl.col("group") != pl.col("group"), + ) + .select("id", "id_right", "group") + .sort("id") + ) + + explained = q.explain() + assert "IEJOIN" in explained + assert "FILTER" in explained + actual = q.collect() + + expected = ( + left.join(right, how="cross") + .filter( + pl.col("time") >= pl.col("start_time"), + pl.col("time") < pl.col("end_time"), + pl.col("group") != pl.col("group_right"), + ) + .select("id", "id_right", "group") + .sort("id") + ) + assert_frame_equal(actual, expected, check_exact=True) + def _inequality_expression(col1: str, op: str, col2: str) -> pl.Expr: if op == "<":