From 0ff2609e2823e0d663e8338cadfcb3a6cc1dfdff Mon Sep 17 00:00:00 2001 From: Yuhao Su Date: Thu, 12 Sep 2024 12:05:49 +0800 Subject: [PATCH 1/8] iinit --- proto/plan_common.proto | 2 + src/batch/src/executor/join/mod.rs | 4 +- src/frontend/src/binder/relation/join.rs | 2 + .../optimizer/plan_node/batch_hash_join.rs | 4 +- .../src/optimizer/plan_node/generic/join.rs | 34 +- .../src/optimizer/plan_node/logical_join.rs | 36 +- src/frontend/src/optimizer/plan_node/mod.rs | 1 + .../optimizer/plan_node/stream_asof_join.rs | 350 ++++++++++++++++++ .../optimizer/plan_node/stream_hash_join.rs | 6 +- .../plan_visitor/cardinality_visitor.rs | 5 + .../rule/apply_join_transpose_rule.rs | 35 +- .../src/optimizer/rule/join_commute_rule.rs | 3 + .../optimizer/rule/translate_apply_rule.rs | 24 +- src/sqlparser/src/ast/query.rs | 16 + src/sqlparser/src/keywords.rs | 1 + src/sqlparser/src/parser.rs | 17 +- src/stream/src/from_proto/hash_join.rs | 4 +- 17 files changed, 512 insertions(+), 32 deletions(-) create mode 100644 src/frontend/src/optimizer/plan_node/stream_asof_join.rs diff --git a/proto/plan_common.proto b/proto/plan_common.proto index bc2e60503f10..22bda72497a5 100644 --- a/proto/plan_common.proto +++ b/proto/plan_common.proto @@ -139,6 +139,8 @@ enum JoinType { JOIN_TYPE_LEFT_ANTI = 6; JOIN_TYPE_RIGHT_SEMI = 7; JOIN_TYPE_RIGHT_ANTI = 8; + JOIN_TYPE_ASOF_INNER = 9; + JOIN_TYPE_ASOF_LEFT_OUTER = 10; } // https://github.com/tokio-rs/prost/issues/80 diff --git a/src/batch/src/executor/join/mod.rs b/src/batch/src/executor/join/mod.rs index cf2388314d8f..4ac630489a55 100644 --- a/src/batch/src/executor/join/mod.rs +++ b/src/batch/src/executor/join/mod.rs @@ -62,7 +62,9 @@ impl JoinType { PbJoinType::RightSemi => JoinType::RightSemi, PbJoinType::RightAnti => JoinType::RightAnti, PbJoinType::FullOuter => JoinType::FullOuter, - PbJoinType::Unspecified => unreachable!(), + PbJoinType::AsofInner | PbJoinType::AsofLeftOuter | PbJoinType::Unspecified => { + unreachable!() + } } } } 
diff --git a/src/frontend/src/binder/relation/join.rs b/src/frontend/src/binder/relation/join.rs index d13b683be08b..30bd0a290622 100644 --- a/src/frontend/src/binder/relation/join.rs +++ b/src/frontend/src/binder/relation/join.rs @@ -92,6 +92,8 @@ impl Binder { JoinOperator::FullOuter(constraint) => (constraint, JoinType::FullOuter), // Cross join equals to inner join with with no constraint. JoinOperator::CrossJoin => (JoinConstraint::None, JoinType::Inner), + JoinOperator::AsOfInner(constraint) => (constraint, JoinType::AsofInner), + JoinOperator::AsOfLeft(constraint) => (constraint, JoinType::AsofLeftOuter), }; let right: Relation; let cond: ExprImpl; diff --git a/src/frontend/src/optimizer/plan_node/batch_hash_join.rs b/src/frontend/src/optimizer/plan_node/batch_hash_join.rs index 399817336e4c..bb5bca88d2b1 100644 --- a/src/frontend/src/optimizer/plan_node/batch_hash_join.rs +++ b/src/frontend/src/optimizer/plan_node/batch_hash_join.rs @@ -66,7 +66,9 @@ impl BatchHashJoin { // we can not derive the hash distribution from the side where outer join can generate a // NULL row (Distribution::HashShard(_), Distribution::HashShard(_)) => match join.join_type { - JoinType::Unspecified => unreachable!(), + JoinType::AsofInner | JoinType::AsofLeftOuter | JoinType::Unspecified => { + unreachable!() + } JoinType::FullOuter => Distribution::SomeShard, JoinType::Inner | JoinType::LeftOuter | JoinType::LeftSemi | JoinType::LeftAnti => { let l2o = join.l2i_col_mapping().composite(&join.i2o_col_mapping()); diff --git a/src/frontend/src/optimizer/plan_node/generic/join.rs b/src/frontend/src/optimizer/plan_node/generic/join.rs index 105f8bebb32b..f7ce096e73eb 100644 --- a/src/frontend/src/optimizer/plan_node/generic/join.rs +++ b/src/frontend/src/optimizer/plan_node/generic/join.rs @@ -277,7 +277,7 @@ impl GenericPlanNode for Join { .rewrite_functional_dependency_set(right_fd_set) }; let fd_set: FunctionalDependencySet = match self.join_type { - JoinType::Inner => { + 
JoinType::Inner | JoinType::AsofInner => { let mut fd_set = FunctionalDependencySet::new(full_out_col_num); for i in &self.on.conjunctions { if let Some((col, _)) = i.as_eq_const() { @@ -300,7 +300,7 @@ impl GenericPlanNode for Join { .for_each(|fd| fd_set.add_functional_dependency(fd)); fd_set } - JoinType::LeftOuter => get_new_left_fd_set(left_fd_set), + JoinType::LeftOuter | JoinType::AsofLeftOuter => get_new_left_fd_set(left_fd_set), JoinType::RightOuter => get_new_right_fd_set(right_fd_set), JoinType::FullOuter => FunctionalDependencySet::new(full_out_col_num), JoinType::LeftSemi | JoinType::LeftAnti => left_fd_set, @@ -325,9 +325,12 @@ impl Join { pub fn full_out_col_num(left_len: usize, right_len: usize, join_type: JoinType) -> usize { match join_type { - JoinType::Inner | JoinType::LeftOuter | JoinType::RightOuter | JoinType::FullOuter => { - left_len + right_len - } + JoinType::Inner + | JoinType::LeftOuter + | JoinType::RightOuter + | JoinType::FullOuter + | JoinType::AsofInner + | JoinType::AsofLeftOuter => left_len + right_len, JoinType::LeftSemi | JoinType::LeftAnti => left_len, JoinType::RightSemi | JoinType::RightAnti => right_len, JoinType::Unspecified => unreachable!(), @@ -371,7 +374,12 @@ impl Join { let right_len = self.right.schema().len(); match self.join_type { - JoinType::Inner | JoinType::LeftOuter | JoinType::RightOuter | JoinType::FullOuter => { + JoinType::Inner + | JoinType::LeftOuter + | JoinType::RightOuter + | JoinType::FullOuter + | JoinType::AsofInner + | JoinType::AsofLeftOuter => { ColIndexMapping::identity_or_none(left_len + right_len, left_len) } @@ -389,7 +397,12 @@ impl Join { let right_len = self.right.schema().len(); match self.join_type { - JoinType::Inner | JoinType::LeftOuter | JoinType::RightOuter | JoinType::FullOuter => { + JoinType::Inner + | JoinType::LeftOuter + | JoinType::RightOuter + | JoinType::FullOuter + | JoinType::AsofInner + | JoinType::AsofLeftOuter => { ColIndexMapping::with_shift_offset(left_len + 
right_len, -(left_len as isize)) } JoinType::LeftSemi | JoinType::LeftAnti => ColIndexMapping::empty(left_len, right_len), @@ -445,13 +458,16 @@ impl Join { pub fn add_which_join_key_to_pk(&self) -> EitherOrBoth<(), ()> { match self.join_type { - JoinType::Inner => { + JoinType::Inner | JoinType::AsofInner => { // Theoretically adding either side is ok, but the distribution key of the inner // join derived based on the left side by default, so we choose the left side here // to ensure the pk comprises the distribution key. EitherOrBoth::Left(()) } - JoinType::LeftOuter | JoinType::LeftSemi | JoinType::LeftAnti => EitherOrBoth::Left(()), + JoinType::LeftOuter + | JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::AsofLeftOuter => EitherOrBoth::Left(()), JoinType::RightSemi | JoinType::RightAnti | JoinType::RightOuter => { EitherOrBoth::Right(()) } diff --git a/src/frontend/src/optimizer/plan_node/logical_join.rs b/src/frontend/src/optimizer/plan_node/logical_join.rs index 2b64b5fd93ad..2f749cbacb19 100644 --- a/src/frontend/src/optimizer/plan_node/logical_join.rs +++ b/src/frontend/src/optimizer/plan_node/logical_join.rs @@ -837,14 +837,13 @@ impl PredicatePushdown for LogicalJoin { } impl LogicalJoin { - fn to_stream_hash_join( + fn get_stream_input_for_hash_join( &self, - predicate: EqJoinPredicate, + predicate: &EqJoinPredicate, ctx: &mut ToStreamContext, - ) -> Result { + ) -> Result<(PlanRef, PlanRef)> { use super::stream::prelude::*; - assert!(predicate.has_eq()); let mut right = self.right().to_stream_with_dist_required( &RequiredDist::shard_by_key(self.right().schema().len(), &predicate.right_eq_indexes()), ctx, @@ -888,6 +887,18 @@ impl LogicalJoin { } _ => unreachable!(), } + Ok((left, right)) + } + + fn to_stream_hash_join( + &self, + predicate: EqJoinPredicate, + ctx: &mut ToStreamContext, + ) -> Result { + use super::stream::prelude::*; + + assert!(predicate.has_eq()); + let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?; let 
logical_join = self.clone_with_left_right(left, right); @@ -1259,6 +1270,23 @@ impl LogicalJoin { .to_batch_lookup_join(predicate, logical_join) .expect("Fail to convert to lookup join") .into()) + } + + fn to_stream_asof_join( + &self, + predicate: EqJoinPredicate, + ctx: &mut ToStreamContext, + ) -> Result { + use super::stream::prelude::*; + + assert!(predicate.has_eq()); + let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?; + let logical_join = self.clone_with_left_right(left, right); + + + + + } } diff --git a/src/frontend/src/optimizer/plan_node/mod.rs b/src/frontend/src/optimizer/plan_node/mod.rs index db1200de2a27..54ad5a6c8f0f 100644 --- a/src/frontend/src/optimizer/plan_node/mod.rs +++ b/src/frontend/src/optimizer/plan_node/mod.rs @@ -897,6 +897,7 @@ mod stream_global_approx_percentile; mod stream_group_topn; mod stream_hash_agg; mod stream_hash_join; +mod stream_asof_join; mod stream_hop_window; mod stream_local_approx_percentile; mod stream_materialize; diff --git a/src/frontend/src/optimizer/plan_node/stream_asof_join.rs b/src/frontend/src/optimizer/plan_node/stream_asof_join.rs new file mode 100644 index 000000000000..7331978a703b --- /dev/null +++ b/src/frontend/src/optimizer/plan_node/stream_asof_join.rs @@ -0,0 +1,350 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use fixedbitset::FixedBitSet; +use itertools::Itertools; +use pretty_xmlish::{Pretty, XmlNode}; +use risingwave_common::util::iter_util::ZipEqFast; +use risingwave_pb::plan_common::JoinType; +use risingwave_pb::stream_plan::stream_node::NodeBody; +use risingwave_pb::stream_plan::{DeltaExpression, HashJoinNode, PbInequalityPair}; + +use super::generic::Join; +use super::stream::prelude::*; +use super::utils::{childless_record, plan_node_name, watermark_pretty, Distill}; +use super::{ + generic, ExprRewritable, PlanBase, PlanRef, PlanTreeNodeBinary, StreamDeltaJoin, StreamNode, +}; +use crate::expr::{Expr, ExprDisplay, ExprRewriter, ExprVisitor, InequalityInputPair}; +use crate::optimizer::plan_node::expr_visitable::ExprVisitable; +use crate::optimizer::plan_node::utils::IndicesDisplay; +use crate::optimizer::plan_node::{EqJoinPredicate, EqJoinPredicateDisplay}; +use crate::optimizer::property::{Distribution, MonotonicityMap}; +use crate::stream_fragmenter::BuildFragmentGraphState; +use crate::utils::ColIndexMappingRewriteExt; + +/// [`StreamHashJoin`] implements [`super::LogicalJoin`] with hash table. It builds a hash table +/// from inner (right-side) relation and probes with data from outer (left-side) relation to +/// get output rows. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct StreamHashJoin { + pub base: PlanBase, + core: generic::Join, + + /// The join condition must be equivalent to `logical.on`, but separated into equal and + /// non-equal parts to facilitate execution later + eq_join_predicate: EqJoinPredicate, + + /// Whether can optimize for append-only stream. + /// It is true if input of both side is append-only + is_append_only: bool, + + /// `(do_state_cleaning, InequalityInputPair {key_required_larger, key_required_smaller, + /// delta_expression})`. View struct `InequalityInputPair` for details. 
+ inequality_pairs: Vec<(bool, InequalityInputPair)>, +} + +impl StreamHashJoin { + pub fn new(core: generic::Join, eq_join_predicate: EqJoinPredicate) -> Self { + // Inner join won't change the append-only behavior of the stream. The rest might. + let append_only = match core.join_type { + JoinType::Inner => core.left.append_only() && core.right.append_only(), + _ => false, + }; + + let dist = Self::derive_dist(core.left.distribution(), core.right.distribution(), &core); + + // TODO: derive watermarks + let watermark_columns = FixedBitSet::new(); + + // TODO: derive from input + let base = PlanBase::new_stream_with_core( + &core, + dist, + append_only, + false, // TODO(rc): derive EOWC property from input + watermark_columns, + MonotonicityMap::new(), // TODO: derive monotonicity + ); + + Self { + base, + core, + eq_join_predicate, + inequality_pairs, + is_append_only: append_only, + } + } + + /// Get join type + pub fn join_type(&self) -> JoinType { + self.core.join_type + } + + /// Get a reference to the batch hash join's eq join predicate. 
+ pub fn eq_join_predicate(&self) -> &EqJoinPredicate { + &self.eq_join_predicate + } + + pub(super) fn derive_dist( + left: &Distribution, + right: &Distribution, + logical: &generic::Join, + ) -> Distribution { + match (left, right) { + (Distribution::Single, Distribution::Single) => Distribution::Single, + (Distribution::HashShard(_), Distribution::HashShard(_)) => { + // we can not derive the hash distribution from the side where outer join can + // generate a NULL row + match logical.join_type { + JoinType::Unspecified| JoinType::AsofInner + | JoinType::AsofLeftOuter => unreachable!(), + JoinType::FullOuter => Distribution::SomeShard, + JoinType::Inner + | JoinType::LeftOuter + | JoinType::LeftSemi + | JoinType::LeftAnti + => { + let l2o = logical + .l2i_col_mapping() + .composite(&logical.i2o_col_mapping()); + l2o.rewrite_provided_distribution(left) + } + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightOuter => { + let r2o = logical + .r2i_col_mapping() + .composite(&logical.i2o_col_mapping()); + r2o.rewrite_provided_distribution(right) + } + } + } + (_, _) => unreachable!( + "suspicious distribution: left: {:?}, right: {:?}", + left, right + ), + } + } + + /// Convert this hash join to a delta join plan + pub fn into_delta_join(self) -> StreamDeltaJoin { + StreamDeltaJoin::new(self.core, self.eq_join_predicate) + } + + pub fn derive_dist_key_in_join_key(&self) -> Vec { + let left_dk_indices = self.left().distribution().dist_column_indices().to_vec(); + let right_dk_indices = self.right().distribution().dist_column_indices().to_vec(); + let left_jk_indices = self.eq_join_predicate.left_eq_indexes(); + let right_jk_indices = self.eq_join_predicate.right_eq_indexes(); + + assert_eq!(left_jk_indices.len(), right_jk_indices.len()); + + let mut dk_indices_in_jk = vec![]; + + for (l_dk_idx, r_dk_idx) in left_dk_indices.iter().zip_eq_fast(right_dk_indices.iter()) { + for dk_idx_in_jk in left_jk_indices.iter().positions(|idx| idx == l_dk_idx) { + if 
right_jk_indices[dk_idx_in_jk] == *r_dk_idx { + dk_indices_in_jk.push(dk_idx_in_jk); + break; + } + } + } + + assert_eq!(dk_indices_in_jk.len(), left_dk_indices.len()); + dk_indices_in_jk + } + + pub fn inequality_pairs(&self) -> &Vec<(bool, InequalityInputPair)> { + &self.inequality_pairs + } +} + +impl Distill for StreamHashJoin { + fn distill<'a>(&self) -> XmlNode<'a> { + let (ljk, rjk) = self + .eq_join_predicate + .eq_indexes() + .first() + .cloned() + .expect("first join key"); + + let name = plan_node_name!("StreamHashJoin", + { "window", self.left().watermark_columns().contains(ljk) && self.right().watermark_columns().contains(rjk) }, + { "interval", self.clean_left_state_conjunction_idx.is_some() && self.clean_right_state_conjunction_idx.is_some() }, + { "append_only", self.is_append_only }, + ); + let verbose = self.base.ctx().is_explain_verbose(); + let mut vec = Vec::with_capacity(6); + vec.push(("type", Pretty::debug(&self.core.join_type))); + + let concat_schema = self.core.concat_schema(); + vec.push(( + "predicate", + Pretty::debug(&EqJoinPredicateDisplay { + eq_join_predicate: self.eq_join_predicate(), + input_schema: &concat_schema, + }), + )); + + let get_cond = |conjunction_idx| { + Pretty::debug(&ExprDisplay { + expr: &self.eq_join_predicate().other_cond().conjunctions[conjunction_idx], + input_schema: &concat_schema, + }) + }; + if let Some(i) = self.clean_left_state_conjunction_idx { + vec.push(("conditions_to_clean_left_state_table", get_cond(i))); + } + if let Some(i) = self.clean_right_state_conjunction_idx { + vec.push(("conditions_to_clean_right_state_table", get_cond(i))); + } + if let Some(ow) = watermark_pretty(self.base.watermark_columns(), self.schema()) { + vec.push(("output_watermarks", ow)); + } + + if verbose { + let data = IndicesDisplay::from_join(&self.core, &concat_schema); + vec.push(("output", data)); + } + + childless_record(name, vec) + } +} + +impl PlanTreeNodeBinary for StreamHashJoin { + fn left(&self) -> PlanRef { + 
self.core.left.clone() + } + + fn right(&self) -> PlanRef { + self.core.right.clone() + } + + fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self { + let mut core = self.core.clone(); + core.left = left; + core.right = right; + Self::new(core, self.eq_join_predicate.clone()) + } +} + +impl_plan_tree_node_for_binary! { StreamHashJoin } + +impl StreamNode for StreamHashJoin { + fn to_stream_prost_body(&self, state: &mut BuildFragmentGraphState) -> NodeBody { + let left_jk_indices = self.eq_join_predicate.left_eq_indexes(); + let right_jk_indices = self.eq_join_predicate.right_eq_indexes(); + let left_jk_indices_prost = left_jk_indices.iter().map(|idx| *idx as i32).collect_vec(); + let right_jk_indices_prost = right_jk_indices.iter().map(|idx| *idx as i32).collect_vec(); + + let dk_indices_in_jk = self.derive_dist_key_in_join_key(); + + let (left_table, left_degree_table, left_deduped_input_pk_indices) = + Join::infer_internal_and_degree_table_catalog( + self.left().plan_base(), + left_jk_indices, + dk_indices_in_jk.clone(), + ); + let (right_table, right_degree_table, right_deduped_input_pk_indices) = + Join::infer_internal_and_degree_table_catalog( + self.right().plan_base(), + right_jk_indices, + dk_indices_in_jk, + ); + + let left_deduped_input_pk_indices = left_deduped_input_pk_indices + .iter() + .map(|idx| *idx as u32) + .collect_vec(); + + let right_deduped_input_pk_indices = right_deduped_input_pk_indices + .iter() + .map(|idx| *idx as u32) + .collect_vec(); + + let (left_table, left_degree_table) = ( + left_table.with_id(state.gen_table_id_wrapped()), + left_degree_table.with_id(state.gen_table_id_wrapped()), + ); + let (right_table, right_degree_table) = ( + right_table.with_id(state.gen_table_id_wrapped()), + right_degree_table.with_id(state.gen_table_id_wrapped()), + ); + + let null_safe_prost = self.eq_join_predicate.null_safes().into_iter().collect(); + + NodeBody::HashJoin(HashJoinNode { + join_type: self.core.join_type as i32, + 
left_key: left_jk_indices_prost, + right_key: right_jk_indices_prost, + null_safe: null_safe_prost, + condition: self + .eq_join_predicate + .other_cond() + .as_expr_unless_true() + .map(|x| x.to_expr_proto()), + inequality_pairs: self + .inequality_pairs + .iter() + .map( + |( + do_state_clean, + InequalityInputPair { + key_required_larger, + key_required_smaller, + delta_expression, + }, + )| { + PbInequalityPair { + key_required_larger: *key_required_larger as u32, + key_required_smaller: *key_required_smaller as u32, + clean_state: *do_state_clean, + delta_expression: delta_expression.as_ref().map( + |(delta_type, delta)| DeltaExpression { + delta_type: *delta_type as i32, + delta: Some(delta.to_expr_proto()), + }, + ), + } + }, + ) + .collect_vec(), + left_table: Some(left_table.to_internal_table_prost()), + right_table: Some(right_table.to_internal_table_prost()), + left_degree_table: Some(left_degree_table.to_internal_table_prost()), + right_degree_table: Some(right_degree_table.to_internal_table_prost()), + left_deduped_input_pk_indices, + right_deduped_input_pk_indices, + output_indices: self.core.output_indices.iter().map(|&x| x as u32).collect(), + is_append_only: self.is_append_only, + }) + } +} + +impl ExprRewritable for StreamHashJoin { + fn has_rewritable_expr(&self) -> bool { + true + } + + fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef { + let mut core = self.core.clone(); + core.rewrite_exprs(r); + Self::new(core, self.eq_join_predicate.rewrite_exprs(r)).into() + } +} + +impl ExprVisitable for StreamHashJoin { + fn visit_exprs(&self, v: &mut dyn ExprVisitor) { + self.core.visit_exprs(v); + } +} diff --git a/src/frontend/src/optimizer/plan_node/stream_hash_join.rs b/src/frontend/src/optimizer/plan_node/stream_hash_join.rs index cbce1e1caf45..63211d9791ef 100644 --- a/src/frontend/src/optimizer/plan_node/stream_hash_join.rs +++ b/src/frontend/src/optimizer/plan_node/stream_hash_join.rs @@ -231,12 +231,14 @@ impl StreamHashJoin { // we 
can not derive the hash distribution from the side where outer join can // generate a NULL row match logical.join_type { - JoinType::Unspecified => unreachable!(), + JoinType::Unspecified| JoinType::AsofInner + | JoinType::AsofLeftOuter => unreachable!(), JoinType::FullOuter => Distribution::SomeShard, JoinType::Inner | JoinType::LeftOuter | JoinType::LeftSemi - | JoinType::LeftAnti => { + | JoinType::LeftAnti + => { let l2o = logical .l2i_col_mapping() .composite(&logical.i2o_col_mapping()); diff --git a/src/frontend/src/optimizer/plan_visitor/cardinality_visitor.rs b/src/frontend/src/optimizer/plan_visitor/cardinality_visitor.rs index 07459b59b1d5..d167a2e7e536 100644 --- a/src/frontend/src/optimizer/plan_visitor/cardinality_visitor.rs +++ b/src/frontend/src/optimizer/plan_visitor/cardinality_visitor.rs @@ -171,6 +171,11 @@ impl PlanVisitor for CardinalityVisitor { // TODO: refine the cardinality of full outer join JoinType::FullOuter => Cardinality::unknown(), + + // For each row from one side, we match `0..=1` rows from the other side. + JoinType::AsofInner => left.mul(right.min(0..=1)), + // For each row from left side, we match exactly 1 row from the right side or a `NULL` row. + JoinType::AsofLeftOuter => left, } } diff --git a/src/frontend/src/optimizer/rule/apply_join_transpose_rule.rs b/src/frontend/src/optimizer/rule/apply_join_transpose_rule.rs index 3da034893623..fc6cbdd47753 100644 --- a/src/frontend/src/optimizer/rule/apply_join_transpose_rule.rs +++ b/src/frontend/src/optimizer/rule/apply_join_transpose_rule.rs @@ -130,7 +130,10 @@ impl Rule for ApplyJoinTransposeRule { let (push_left, push_right) = match join.join_type() { // `LeftSemi`, `LeftAnti`, `LeftOuter` can only push to left side if it's right side has // no correlated id. Otherwise push to both sides. 
- JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftOuter => { + JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::LeftOuter + | JoinType::AsofLeftOuter => { if !join_right_has_correlated_id { (true, false) } else { @@ -147,7 +150,7 @@ impl Rule for ApplyJoinTransposeRule { } } // `Inner` can push to one side if the other side is not dependent on it. - JoinType::Inner => { + JoinType::Inner | JoinType::AsofInner => { if join_cond_has_correlated_id && !join_right_has_correlated_id && !join_left_has_correlated_id @@ -236,7 +239,12 @@ impl ApplyJoinTransposeRule { JoinType::LeftSemi | JoinType::LeftAnti => { left_apply_condition.extend(apply_on); } - JoinType::Inner | JoinType::LeftOuter | JoinType::RightOuter | JoinType::FullOuter => { + JoinType::Inner + | JoinType::LeftOuter + | JoinType::RightOuter + | JoinType::FullOuter + | JoinType::AsofInner + | JoinType::AsofLeftOuter => { let apply_len = apply_left_len + join.schema().len(); let mut d_t1_bit_set = FixedBitSet::with_capacity(apply_len); d_t1_bit_set.set_range(0..apply_left_len + join_left_len, true); @@ -316,7 +324,12 @@ impl ApplyJoinTransposeRule { JoinType::RightSemi | JoinType::RightAnti => { right_apply_condition.extend(apply_on); } - JoinType::Inner | JoinType::LeftOuter | JoinType::RightOuter | JoinType::FullOuter => { + JoinType::Inner + | JoinType::LeftOuter + | JoinType::RightOuter + | JoinType::FullOuter + | JoinType::AsofInner + | JoinType::AsofLeftOuter => { let apply_len = apply_left_len + join.schema().len(); let mut d_t2_bit_set = FixedBitSet::with_capacity(apply_len); d_t2_bit_set.set_range(0..apply_left_len, true); @@ -456,7 +469,12 @@ impl ApplyJoinTransposeRule { JoinType::RightSemi | JoinType::RightAnti => { right_apply_condition.extend(apply_on); } - JoinType::Inner | JoinType::LeftOuter | JoinType::RightOuter | JoinType::FullOuter => { + JoinType::Inner + | JoinType::LeftOuter + | JoinType::RightOuter + | JoinType::FullOuter + | JoinType::AsofInner + | 
JoinType::AsofLeftOuter => { let apply_len = apply_left_len + join.schema().len(); let mut d_t1_bit_set = FixedBitSet::with_capacity(apply_len); let mut d_t2_bit_set = FixedBitSet::with_capacity(apply_len); @@ -555,7 +573,12 @@ impl ApplyJoinTransposeRule { JoinType::LeftSemi | JoinType::LeftAnti | JoinType::RightSemi | JoinType::RightAnti => { new_join.into() } - JoinType::Inner | JoinType::LeftOuter | JoinType::RightOuter | JoinType::FullOuter => { + JoinType::Inner + | JoinType::LeftOuter + | JoinType::RightOuter + | JoinType::FullOuter + | JoinType::AsofInner + | JoinType::AsofLeftOuter => { let mut output_indices_mapping = ColIndexMapping::new( output_indices.iter().map(|x| Some(*x)).collect(), target_size, diff --git a/src/frontend/src/optimizer/rule/join_commute_rule.rs b/src/frontend/src/optimizer/rule/join_commute_rule.rs index 405e28d6825f..55b975ccb971 100644 --- a/src/frontend/src/optimizer/rule/join_commute_rule.rs +++ b/src/frontend/src/optimizer/rule/join_commute_rule.rs @@ -72,6 +72,8 @@ impl Rule for JoinCommuteRule { | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::FullOuter + | JoinType::AsofInner + | JoinType::AsofLeftOuter | JoinType::Unspecified => None, } } @@ -116,6 +118,7 @@ impl JoinCommuteRule { JoinType::LeftAnti => JoinType::RightAnti, JoinType::RightSemi => JoinType::LeftSemi, JoinType::RightAnti => JoinType::LeftAnti, + JoinType::AsofInner | JoinType::AsofLeftOuter => unreachable!(), } } } diff --git a/src/frontend/src/optimizer/rule/translate_apply_rule.rs b/src/frontend/src/optimizer/rule/translate_apply_rule.rs index 876ca7d6285b..87ccbd592472 100644 --- a/src/frontend/src/optimizer/rule/translate_apply_rule.rs +++ b/src/frontend/src/optimizer/rule/translate_apply_rule.rs @@ -233,8 +233,9 @@ impl TranslateApplyRule { | JoinType::RightSemi | JoinType::LeftAnti | JoinType::RightAnti - | JoinType::RightOuter => rewrite(join.right(), right_idxs, true), - JoinType::LeftOuter | JoinType::FullOuter => None, + | JoinType::RightOuter + 
| JoinType::AsofInner => rewrite(join.right(), right_idxs, true), + JoinType::LeftOuter | JoinType::FullOuter | JoinType::AsofLeftOuter => None, JoinType::Unspecified => unreachable!(), } } @@ -246,7 +247,9 @@ impl TranslateApplyRule { | JoinType::RightSemi | JoinType::LeftAnti | JoinType::RightAnti - | JoinType::LeftOuter => rewrite(join.left(), left_idxs, false), + | JoinType::LeftOuter + | JoinType::AsofInner + | JoinType::AsofLeftOuter => rewrite(join.left(), left_idxs, false), JoinType::RightOuter | JoinType::FullOuter => None, JoinType::Unspecified => unreachable!(), } @@ -258,14 +261,18 @@ impl TranslateApplyRule { | JoinType::LeftSemi | JoinType::RightSemi | JoinType::LeftAnti - | JoinType::RightAnti => { + | JoinType::RightAnti + | JoinType::AsofInner => { let left = rewrite(join.left(), left_idxs, false)?; let right = rewrite(join.right(), right_idxs, true)?; let new_join = LogicalJoin::new(left, right, join.join_type(), Condition::true_cond()); Some(new_join.into()) } - JoinType::LeftOuter | JoinType::RightOuter | JoinType::FullOuter => None, + JoinType::LeftOuter + | JoinType::RightOuter + | JoinType::FullOuter + | JoinType::AsofLeftOuter => None, JoinType::Unspecified => unreachable!(), } } @@ -300,7 +307,12 @@ impl TranslateApplyRule { if !left_idxs.is_empty() && right_idxs.is_empty() { // Deal with multi scalar subqueries match apply.join_type() { - JoinType::Inner | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftOuter => { + JoinType::Inner + | JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::LeftOuter + | JoinType::AsofInner + | JoinType::AsofLeftOuter => { let plan = apply.left(); Self::rewrite(&plan, left_idxs, offset, index_mapping, data_types, index) } diff --git a/src/sqlparser/src/ast/query.rs b/src/sqlparser/src/ast/query.rs index b16a3075f90d..be03a5f1133e 100644 --- a/src/sqlparser/src/ast/query.rs +++ b/src/sqlparser/src/ast/query.rs @@ -584,6 +584,20 @@ impl fmt::Display for Join { suffix(constraint) ), 
JoinOperator::CrossJoin => write!(f, " CROSS JOIN {}", self.relation), + JoinOperator::AsOfInner(constraint) => write!( + f, + " {}ASOF JOIN {}{}", + prefix(constraint), + self.relation, + suffix(constraint) + ), + JoinOperator::AsOfLeft(constraint) => write!( + f, + " {}ASOF LEFT JOIN {}{}", + prefix(constraint), + self.relation, + suffix(constraint) + ), } } } @@ -596,6 +610,8 @@ pub enum JoinOperator { RightOuter(JoinConstraint), FullOuter(JoinConstraint), CrossJoin, + AsOfInner(JoinConstraint), + AsOfLeft(JoinConstraint), } #[derive(Debug, Clone, PartialEq, Eq, Hash)] diff --git a/src/sqlparser/src/keywords.rs b/src/sqlparser/src/keywords.rs index 014d100b1f95..8626df2021bc 100644 --- a/src/sqlparser/src/keywords.rs +++ b/src/sqlparser/src/keywords.rs @@ -88,6 +88,7 @@ define_keywords!( AS, ASC, ASENSITIVE, + ASOF, ASYMMETRIC, ASYNC, AT, diff --git a/src/sqlparser/src/parser.rs b/src/sqlparser/src/parser.rs index 1ee9b0639216..affae211944f 100644 --- a/src/sqlparser/src/parser.rs +++ b/src/sqlparser/src/parser.rs @@ -4610,7 +4610,13 @@ impl Parser<'_> { join_operator, } } else { - let natural = self.parse_keyword(Keyword::NATURAL); + let (natural, asof) = + match self.parse_one_of_keywords(&[Keyword::NATURAL, Keyword::ASOF]) { + Some(Keyword::NATURAL) => (true, false), + Some(Keyword::ASOF) => (false, true), + Some(_) => unreachable!(), + None => (false, false), + }; let peek_keyword = if let Token::Word(w) = self.peek_token().token { w.keyword } else { @@ -4621,7 +4627,11 @@ impl Parser<'_> { Keyword::INNER | Keyword::JOIN => { let _ = self.parse_keyword(Keyword::INNER); self.expect_keyword(Keyword::JOIN)?; - JoinOperator::Inner + if asof { + JoinOperator::AsOfInner + } else { + JoinOperator::Inner + } } kw @ Keyword::LEFT | kw @ Keyword::RIGHT | kw @ Keyword::FULL => { let _ = self.next_token(); @@ -4640,6 +4650,9 @@ impl Parser<'_> { _ if natural => { return self.expected("a join type after NATURAL"); } + _ if asof => { + return self.expected("a join type 
after ASOF"); + } _ => break, }; let relation = self.parse_table_factor()?; diff --git a/src/stream/src/from_proto/hash_join.rs b/src/stream/src/from_proto/hash_join.rs index 2d421274cec3..42034b64b0af 100644 --- a/src/stream/src/from_proto/hash_join.rs +++ b/src/stream/src/from_proto/hash_join.rs @@ -223,7 +223,9 @@ impl HashKeyDispatcher for HashJoinExecutorDispatcherArgs { }; } match self.join_type_proto { - JoinTypeProto::Unspecified => unreachable!(), + JoinTypeProto::AsofInner + | JoinTypeProto::AsofLeftOuter + | JoinTypeProto::Unspecified => unreachable!(), JoinTypeProto::Inner => build!(Inner), JoinTypeProto::LeftOuter => build!(LeftOuter), JoinTypeProto::RightOuter => build!(RightOuter), From d50157735ac9f32a005f949e2154bc8cf00af4bb Mon Sep 17 00:00:00 2001 From: Yuhao Su Date: Tue, 24 Sep 2024 23:50:53 +0800 Subject: [PATCH 2/8] stream node --- proto/stream_plan.proto | 1 + src/common/src/util/stream_graph_visitor.rs | 6 + .../tests/testdata/input/asof_join.yaml | 28 ++ .../tests/testdata/output/asof_join.yaml | 26 ++ .../src/optimizer/plan_node/logical_join.rs | 27 +- src/frontend/src/optimizer/plan_node/mod.rs | 5 +- .../optimizer/plan_node/stream_asof_join.rs | 301 ++++++++++-------- .../optimizer/plan_node/stream_hash_join.rs | 8 +- src/prost/build.rs | 1 + src/sqlparser/src/keywords.rs | 1 + src/sqlparser/src/parser.rs | 33 +- src/sqlparser/tests/testdata/asof_join.yaml | 17 + src/stream/src/from_proto/mod.rs | 2 + 13 files changed, 310 insertions(+), 146 deletions(-) create mode 100644 src/frontend/planner_test/tests/testdata/input/asof_join.yaml create mode 100644 src/frontend/planner_test/tests/testdata/output/asof_join.yaml create mode 100644 src/sqlparser/tests/testdata/asof_join.yaml diff --git a/proto/stream_plan.proto b/proto/stream_plan.proto index d6a6ae0ed67e..1006fd17aef7 100644 --- a/proto/stream_plan.proto +++ b/proto/stream_plan.proto @@ -885,6 +885,7 @@ message StreamNode { LocalApproxPercentileNode local_approx_percentile = 144; 
GlobalApproxPercentileNode global_approx_percentile = 145; RowMergeNode row_merge = 146; + AsOfJoinNode as_of_join = 147; } // The id for the operator. This is local per mview. // TODO: should better be a uint32. diff --git a/src/common/src/util/stream_graph_visitor.rs b/src/common/src/util/stream_graph_visitor.rs index 5b990c018640..a525b459d013 100644 --- a/src/common/src/util/stream_graph_visitor.rs +++ b/src/common/src/util/stream_graph_visitor.rs @@ -269,6 +269,12 @@ pub fn visit_stream_node_tables_inner( always!(node.bucket_state_table, "GlobalApproxPercentileBucketState"); always!(node.count_state_table, "GlobalApproxPercentileCountState"); } + + // AsOf join + NodeBody::AsOfJoin(node) => { + always!(node.left_table, "AsOfJoinLeft"); + always!(node.right_table, "AsOfJoinRight"); + } _ => {} } }; diff --git a/src/frontend/planner_test/tests/testdata/input/asof_join.yaml b/src/frontend/planner_test/tests/testdata/input/asof_join.yaml new file mode 100644 index 000000000000..0ab19b56422f --- /dev/null +++ b/src/frontend/planner_test/tests/testdata/input/asof_join.yaml @@ -0,0 +1,28 @@ +- sql: + CREATE TABLE t1(v1 varchar, v2 int, v3 int); + CREATE TABLE t2(v1 varchar, v2 int, v3 int); + SELECT * FROM t1 ASOF JOIN t2 ON t1.v1 = t2.v1; + expected_outputs: + - stream_error + +- sql: + CREATE TABLE t1(v1 varchar, v2 int, v3 int); + CREATE TABLE t2(v1 varchar, v2 int, v3 int); + SELECT t1.v1 t1_v1, t1.v2 t1_v2, t2.v1 t2_v1, t2.v2 t2_v2 FROM t1 ASOF JOIN t2 ON t1.v1 = t2.v1 || 'a' and t1.v2 > t2.v2; + expected_outputs: + - batch_error + - stream_plan + +- sql: + CREATE TABLE t1(v1 varchar, v2 int, v3 int); + CREATE TABLE t2(v1 varchar, v2 int, v3 int); + SELECT t1.v1 t1_v1, t1.v2 t1_v2, t2.v1 t2_v1, t2.v2 t2_v2 FROM t1 ASOF LEFT JOIN t2 ON t1.v1 = t2.v1 and t1.v2 *2 < t2.v2; + expected_outputs: + - stream_plan + +- sql: + CREATE TABLE t1(v1 varchar, v2 int, v3 int); + CREATE TABLE t2(v1 varchar, v2 int, v3 int); + SELECT t1.v1 t1_v1, t1.v2 t1_v2, t2.v1 t2_v1, t2.v2 
t2_v2 FROM t1 ASOF JOIN t2 ON t1.v1 = t2.v1 and t1.v2 < t2.v2 and t1.v3 < t2.v3; + expected_outputs: + - stream_error diff --git a/src/frontend/planner_test/tests/testdata/output/asof_join.yaml b/src/frontend/planner_test/tests/testdata/output/asof_join.yaml new file mode 100644 index 000000000000..7dd2862065e9 --- /dev/null +++ b/src/frontend/planner_test/tests/testdata/output/asof_join.yaml @@ -0,0 +1,26 @@ +# This file is automatically generated. See `src/frontend/planner_test/README.md` for more information. +- sql: CREATE TABLE t1(v1 varchar, v2 int, v3 int); CREATE TABLE t2(v1 varchar, v2 int, v3 int); SELECT * FROM t1 ASOF JOIN t2 ON t1.v1 = t2.v1; + stream_error: 'Invalid input syntax: AsOf join requires exactly 1 ineuqual condition' +- sql: CREATE TABLE t1(v1 varchar, v2 int, v3 int); CREATE TABLE t2(v1 varchar, v2 int, v3 int); SELECT t1.v1 t1_v1, t1.v2 t1_v2, t2.v1 t2_v1, t2.v2 t2_v2 FROM t1 ASOF JOIN t2 ON t1.v1 = t2.v1 || 'a' and t1.v2 > t2.v2; + stream_plan: |- + StreamMaterialize { columns: [t1_v1, t1_v2, t2_v1, t2_v2, t1._row_id(hidden), t2._row_id(hidden)], stream_key: [t1._row_id, t2._row_id, t1_v1], pk_columns: [t1._row_id, t2._row_id, t1_v1], pk_conflict: NoCheck } + └─StreamAsOfJoin { type: AsofInner, predicate: t1.v1 = $expr1 AND (t1.v2 > t2.v2), output: [t1.v1, t1.v2, t2.v1, t2.v2, t1._row_id, t2._row_id] } + ├─StreamExchange { dist: HashShard(t1.v1) } + │ └─StreamTableScan { table: t1, columns: [t1.v1, t1.v2, t1._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t1._row_id], pk: [_row_id], dist: UpstreamHashShard(t1._row_id) } + └─StreamExchange { dist: HashShard($expr1) } + └─StreamProject { exprs: [t2.v1, t2.v2, ConcatOp(t2.v1, 'a':Varchar) as $expr1, t2._row_id] } + └─StreamTableScan { table: t2, columns: [t2.v1, t2.v2, t2._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t2._row_id], pk: [_row_id], dist: UpstreamHashShard(t2._row_id) } + batch_error: |- + Not supported: AsOf join in batch query + HINT: AsOf join is 
only supported in streaming query +- sql: CREATE TABLE t1(v1 varchar, v2 int, v3 int); CREATE TABLE t2(v1 varchar, v2 int, v3 int); SELECT t1.v1 t1_v1, t1.v2 t1_v2, t2.v1 t2_v1, t2.v2 t2_v2 FROM t1 ASOF LEFT JOIN t2 ON t1.v1 = t2.v1 and t1.v2 *2 < t2.v2; + stream_plan: |- + StreamMaterialize { columns: [t1_v1, t1_v2, t2_v1, t2_v2, t1._row_id(hidden), t2._row_id(hidden)], stream_key: [t1._row_id, t2._row_id, t1_v1], pk_columns: [t1._row_id, t2._row_id, t1_v1], pk_conflict: NoCheck } + └─StreamAsOfJoin { type: AsofLeftOuter, predicate: t1.v1 = t2.v1 AND ($expr1 < t2.v2), output: [t1.v1, t1.v2, t2.v1, t2.v2, t1._row_id, t2._row_id] } + ├─StreamExchange { dist: HashShard(t1.v1) } + │ └─StreamProject { exprs: [t1.v1, t1.v2, (t1.v2 * 2:Int32) as $expr1, t1._row_id] } + │ └─StreamTableScan { table: t1, columns: [t1.v1, t1.v2, t1._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t1._row_id], pk: [_row_id], dist: UpstreamHashShard(t1._row_id) } + └─StreamExchange { dist: HashShard(t2.v1) } + └─StreamTableScan { table: t2, columns: [t2.v1, t2.v2, t2._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t2._row_id], pk: [_row_id], dist: UpstreamHashShard(t2._row_id) } +- sql: CREATE TABLE t1(v1 varchar, v2 int, v3 int); CREATE TABLE t2(v1 varchar, v2 int, v3 int); SELECT t1.v1 t1_v1, t1.v2 t1_v2, t2.v1 t2_v1, t2.v2 t2_v2 FROM t1 ASOF JOIN t2 ON t1.v1 = t2.v1 and t1.v2 < t2.v2 and t1.v3 < t2.v3; + stream_error: 'Invalid input syntax: AsOf join requires exactly 1 ineuqual condition' diff --git a/src/frontend/src/optimizer/plan_node/logical_join.rs b/src/frontend/src/optimizer/plan_node/logical_join.rs index 2f749cbacb19..63b2a6041630 100644 --- a/src/frontend/src/optimizer/plan_node/logical_join.rs +++ b/src/frontend/src/optimizer/plan_node/logical_join.rs @@ -33,6 +33,7 @@ use crate::error::{ErrorCode, Result, RwError}; use crate::expr::{CollectInputRef, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, InputRef}; use 
crate::optimizer::plan_node::expr_visitable::ExprVisitable; use crate::optimizer::plan_node::generic::DynamicFilter; +use crate::optimizer::plan_node::stream_asof_join::StreamAsOfJoin; use crate::optimizer::plan_node::utils::IndicesDisplay; use crate::optimizer::plan_node::{ BatchHashJoin, BatchLookupJoin, BatchNestedLoopJoin, ColumnPruningContext, EqJoinPredicate, @@ -1276,22 +1277,34 @@ impl LogicalJoin { &self, predicate: EqJoinPredicate, ctx: &mut ToStreamContext, - ) -> Result { + ) -> Result { use super::stream::prelude::*; - + assert!(predicate.has_eq()); let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?; + let left_len = left.schema().len(); let logical_join = self.clone_with_left_right(left, right); - + let inequality_desc = + StreamAsOfJoin::get_inequality_desc_from_predicate(predicate.clone(), left_len)?; - - + Ok(StreamAsOfJoin::new( + logical_join.core.clone(), + predicate, + inequality_desc, + )) } } impl ToBatch for LogicalJoin { fn to_batch(&self) -> Result { + if JoinType::AsofInner == self.join_type() || JoinType::AsofLeftOuter == self.join_type() { + return Err(ErrorCode::NotSupported( + "AsOf join in batch query".to_string(), + "AsOf join is only supported in streaming query".to_string(), + ) + .into()); + } let predicate = EqJoinPredicate::create( self.left().schema().len(), self.right().schema().len(), @@ -1348,7 +1361,9 @@ impl ToStream for LogicalJoin { self.on().clone(), ); - if predicate.has_eq() { + if self.join_type() == JoinType::AsofInner || self.join_type() == JoinType::AsofLeftOuter { + self.to_stream_asof_join(predicate, ctx).map(|x| x.into()) + } else if predicate.has_eq() { if !predicate.eq_keys_are_type_aligned() { return Err(ErrorCode::InternalError(format!( "Join eq keys are not aligned for predicate: {predicate:?}" diff --git a/src/frontend/src/optimizer/plan_node/mod.rs b/src/frontend/src/optimizer/plan_node/mod.rs index 54ad5a6c8f0f..85c4e3066f7c 100644 --- 
a/src/frontend/src/optimizer/plan_node/mod.rs +++ b/src/frontend/src/optimizer/plan_node/mod.rs @@ -883,6 +883,7 @@ mod logical_topn; mod logical_union; mod logical_update; mod logical_values; +mod stream_asof_join; mod stream_changelog; mod stream_dedup; mod stream_delta_join; @@ -897,7 +898,6 @@ mod stream_global_approx_percentile; mod stream_group_topn; mod stream_hash_agg; mod stream_hash_join; -mod stream_asof_join; mod stream_hop_window; mod stream_local_approx_percentile; mod stream_materialize; @@ -995,6 +995,7 @@ pub use logical_topn::LogicalTopN; pub use logical_union::LogicalUnion; pub use logical_update::LogicalUpdate; pub use logical_values::LogicalValues; +pub use stream_asof_join::StreamAsOfJoin; pub use stream_cdc_table_scan::StreamCdcTableScan; pub use stream_changelog::StreamChangeLog; pub use stream_dedup::StreamDedup; @@ -1160,6 +1161,7 @@ macro_rules! for_all_plan_nodes { , { Stream, GlobalApproxPercentile } , { Stream, LocalApproxPercentile } , { Stream, RowMerge } + , { Stream, AsOfJoin } } }; } @@ -1289,6 +1291,7 @@ macro_rules! 
for_stream_plan_nodes { , { Stream, GlobalApproxPercentile } , { Stream, LocalApproxPercentile } , { Stream, RowMerge } + , { Stream, AsOfJoin } } }; } diff --git a/src/frontend/src/optimizer/plan_node/stream_asof_join.rs b/src/frontend/src/optimizer/plan_node/stream_asof_join.rs index 7331978a703b..6a6bc7658fa6 100644 --- a/src/frontend/src/optimizer/plan_node/stream_asof_join.rs +++ b/src/frontend/src/optimizer/plan_node/stream_asof_join.rs @@ -16,29 +16,32 @@ use fixedbitset::FixedBitSet; use itertools::Itertools; use pretty_xmlish::{Pretty, XmlNode}; use risingwave_common::util::iter_util::ZipEqFast; -use risingwave_pb::plan_common::JoinType; +use risingwave_common::util::sort_util::OrderType; +use risingwave_expr::bail; +use risingwave_pb::expr::expr_node::PbType; +use risingwave_pb::plan_common::{AsOfJoinDesc, AsOfJoinType, JoinType, PbAsOfJoinInequalityType}; use risingwave_pb::stream_plan::stream_node::NodeBody; -use risingwave_pb::stream_plan::{DeltaExpression, HashJoinNode, PbInequalityPair}; +use risingwave_pb::stream_plan::AsOfJoinNode; -use super::generic::Join; +use super::generic::GenericPlanNode; use super::stream::prelude::*; -use super::utils::{childless_record, plan_node_name, watermark_pretty, Distill}; -use super::{ - generic, ExprRewritable, PlanBase, PlanRef, PlanTreeNodeBinary, StreamDeltaJoin, StreamNode, +use super::utils::{ + childless_record, plan_node_name, watermark_pretty, Distill, TableCatalogBuilder, }; -use crate::expr::{Expr, ExprDisplay, ExprRewriter, ExprVisitor, InequalityInputPair}; +use super::{generic, ExprRewritable, PlanBase, PlanRef, PlanTreeNodeBinary, StreamNode}; +use crate::error::{ErrorCode, Result}; +use crate::expr::{ExprImpl, ExprRewriter, ExprVisitor}; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; use crate::optimizer::plan_node::utils::IndicesDisplay; use crate::optimizer::plan_node::{EqJoinPredicate, EqJoinPredicateDisplay}; use crate::optimizer::property::{Distribution, MonotonicityMap}; use 
crate::stream_fragmenter::BuildFragmentGraphState; use crate::utils::ColIndexMappingRewriteExt; +use crate::TableCatalog; -/// [`StreamHashJoin`] implements [`super::LogicalJoin`] with hash table. It builds a hash table -/// from inner (right-side) relation and probes with data from outer (left-side) relation to -/// get output rows. +/// [`StreamAsOfJoin`] implements [`super::LogicalJoin`] with hash tables. #[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct StreamHashJoin { +pub struct StreamAsOfJoin { pub base: PlanBase, core: generic::Join, @@ -46,17 +49,20 @@ pub struct StreamHashJoin { /// non-equal parts to facilitate execution later eq_join_predicate: EqJoinPredicate, - /// Whether can optimize for append-only stream. + /// Whether can optimize for append-only stream. /// It is true if input of both side is append-only is_append_only: bool, - /// `(do_state_cleaning, InequalityInputPair {key_required_larger, key_required_smaller, - /// delta_expression})`. View struct `InequalityInputPair` for details. - inequality_pairs: Vec<(bool, InequalityInputPair)>, + /// inequality description + inequality_desc: AsOfJoinDesc, } -impl StreamHashJoin { - pub fn new(core: generic::Join, eq_join_predicate: EqJoinPredicate) -> Self { +impl StreamAsOfJoin { + pub fn new( + core: generic::Join, + eq_join_predicate: EqJoinPredicate, + inequality_desc: AsOfJoinDesc, + ) -> Self { // Inner join won't change the append-only behavior of the stream. The rest might. 
let append_only = match core.join_type { JoinType::Inner => core.left.append_only() && core.right.append_only(), @@ -66,7 +72,7 @@ impl StreamHashJoin { let dist = Self::derive_dist(core.left.distribution(), core.right.distribution(), &core); // TODO: derive watermarks - let watermark_columns = FixedBitSet::new(); + let watermark_columns = FixedBitSet::with_capacity(core.schema().len()); // TODO: derive from input let base = PlanBase::new_stream_with_core( @@ -82,8 +88,54 @@ impl StreamHashJoin { base, core, eq_join_predicate, - inequality_pairs, is_append_only: append_only, + inequality_desc, + } + } + + pub fn get_inequality_desc_from_predicate( + predicate: EqJoinPredicate, + left_input_len: usize, + ) -> Result { + if predicate.eq_keys().is_empty() { + Err(ErrorCode::InvalidInputSyntax( + "AsOf join requires at least 1 equal condition".to_string(), + ) + .into()) + } else { + let expr: ExprImpl = predicate.other_cond().clone().into(); + if let Some((left_input_ref, expr_type, right_input_ref)) = expr.as_comparison_cond() { + if left_input_ref.index() < left_input_len + && right_input_ref.index() >= left_input_len + { + Ok(AsOfJoinDesc { + left_idx: left_input_ref.index() as u32, + right_idx: (right_input_ref.index() - left_input_len) as u32, + inequality_type: Self::expr_type_to_comparison_type(expr_type)?.into(), + }) + } else { + bail!("inequal condition from the same side should be push down in optimizer"); + } + } else { + Err(ErrorCode::InvalidInputSyntax( + "AsOf join requires exactly 1 ineuqual condition".to_string(), + ) + .into()) + } + } + } + + fn expr_type_to_comparison_type(expr_type: PbType) -> Result { + match expr_type { + PbType::LessThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLt), + PbType::LessThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeLe), + PbType::GreaterThan => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGt), + PbType::GreaterThanOrEqual => Ok(PbAsOfJoinInequalityType::AsOfInequalityTypeGe), + _ => 
Err(ErrorCode::InvalidInputSyntax(format!( + "Invalid comparison type: {}", + expr_type.as_str_name() + )) + .into()), } } @@ -104,31 +156,23 @@ impl StreamHashJoin { ) -> Distribution { match (left, right) { (Distribution::Single, Distribution::Single) => Distribution::Single, - (Distribution::HashShard(_), Distribution::HashShard(_)) => { - // we can not derive the hash distribution from the side where outer join can - // generate a NULL row - match logical.join_type { - JoinType::Unspecified| JoinType::AsofInner - | JoinType::AsofLeftOuter => unreachable!(), - JoinType::FullOuter => Distribution::SomeShard, - JoinType::Inner - | JoinType::LeftOuter - | JoinType::LeftSemi - | JoinType::LeftAnti - => { - let l2o = logical - .l2i_col_mapping() - .composite(&logical.i2o_col_mapping()); - l2o.rewrite_provided_distribution(left) - } - JoinType::RightSemi | JoinType::RightAnti | JoinType::RightOuter => { - let r2o = logical - .r2i_col_mapping() - .composite(&logical.i2o_col_mapping()); - r2o.rewrite_provided_distribution(right) - } + (Distribution::HashShard(_), Distribution::HashShard(_)) => match logical.join_type { + JoinType::Unspecified + | JoinType::FullOuter + | JoinType::Inner + | JoinType::LeftOuter + | JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::RightSemi + | JoinType::RightAnti + | JoinType::RightOuter => unreachable!(), + JoinType::AsofInner | JoinType::AsofLeftOuter => { + let l2o = logical + .l2i_col_mapping() + .composite(&logical.i2o_col_mapping()); + l2o.rewrite_provided_distribution(left) } - } + }, (_, _) => unreachable!( "suspicious distribution: left: {:?}, right: {:?}", left, right @@ -136,11 +180,6 @@ impl StreamHashJoin { } } - /// Convert this hash join to a delta join plan - pub fn into_delta_join(self) -> StreamDeltaJoin { - StreamDeltaJoin::new(self.core, self.eq_join_predicate) - } - pub fn derive_dist_key_in_join_key(&self) -> Vec { let left_dk_indices = self.left().distribution().dist_column_indices().to_vec(); let 
right_dk_indices = self.right().distribution().dist_column_indices().to_vec(); @@ -164,12 +203,58 @@ impl StreamHashJoin { dk_indices_in_jk } - pub fn inequality_pairs(&self) -> &Vec<(bool, InequalityInputPair)> { - &self.inequality_pairs + /// Return stream hash join internal table catalog. + pub fn infer_internal_table_catalog( + input: I, + join_key_indices: Vec, + dk_indices_in_jk: Vec, + inequality_key_idx: usize, + ) -> (TableCatalog, Vec) { + let schema = input.schema(); + + let internal_table_dist_keys = dk_indices_in_jk + .iter() + .map(|idx| join_key_indices[*idx]) + .collect_vec(); + + // The pk of AsOf join internal table should be join_key + inequality_key + input_pk. + let join_key_len = join_key_indices.len(); + let mut pk_indices = join_key_indices; + + // dedup the pk in dist key.. + let mut deduped_input_pk_indices = vec![]; + for input_pk_idx in input.stream_key().unwrap() { + if !pk_indices.contains(input_pk_idx) + && !deduped_input_pk_indices.contains(input_pk_idx) + { + deduped_input_pk_indices.push(*input_pk_idx); + } + } + + pk_indices.push(inequality_key_idx); + pk_indices.extend(deduped_input_pk_indices.clone()); + + // Build internal table + let mut internal_table_catalog_builder = TableCatalogBuilder::default(); + let internal_columns_fields = schema.fields().to_vec(); + + internal_columns_fields.iter().for_each(|field| { + internal_table_catalog_builder.add_column(field); + }); + pk_indices.iter().for_each(|idx| { + internal_table_catalog_builder.add_order_column(*idx, OrderType::ascending()) + }); + + internal_table_catalog_builder.set_dist_key_in_pk(dk_indices_in_jk.clone()); + + ( + internal_table_catalog_builder.build(internal_table_dist_keys, join_key_len), + deduped_input_pk_indices, + ) } } -impl Distill for StreamHashJoin { +impl Distill for StreamAsOfJoin { fn distill<'a>(&self) -> XmlNode<'a> { let (ljk, rjk) = self .eq_join_predicate @@ -178,9 +263,8 @@ impl Distill for StreamHashJoin { .cloned() .expect("first join key"); - 
let name = plan_node_name!("StreamHashJoin", + let name = plan_node_name!("StreamAsOfJoin", { "window", self.left().watermark_columns().contains(ljk) && self.right().watermark_columns().contains(rjk) }, - { "interval", self.clean_left_state_conjunction_idx.is_some() && self.clean_right_state_conjunction_idx.is_some() }, { "append_only", self.is_append_only }, ); let verbose = self.base.ctx().is_explain_verbose(); @@ -196,18 +280,6 @@ impl Distill for StreamHashJoin { }), )); - let get_cond = |conjunction_idx| { - Pretty::debug(&ExprDisplay { - expr: &self.eq_join_predicate().other_cond().conjunctions[conjunction_idx], - input_schema: &concat_schema, - }) - }; - if let Some(i) = self.clean_left_state_conjunction_idx { - vec.push(("conditions_to_clean_left_state_table", get_cond(i))); - } - if let Some(i) = self.clean_right_state_conjunction_idx { - vec.push(("conditions_to_clean_right_state_table", get_cond(i))); - } if let Some(ow) = watermark_pretty(self.base.watermark_columns(), self.schema()) { vec.push(("output_watermarks", ow)); } @@ -221,7 +293,7 @@ impl Distill for StreamHashJoin { } } -impl PlanTreeNodeBinary for StreamHashJoin { +impl PlanTreeNodeBinary for StreamAsOfJoin { fn left(&self) -> PlanRef { self.core.left.clone() } @@ -234,13 +306,13 @@ impl PlanTreeNodeBinary for StreamHashJoin { let mut core = self.core.clone(); core.left = left; core.right = right; - Self::new(core, self.eq_join_predicate.clone()) + Self::new(core, self.eq_join_predicate.clone(), self.inequality_desc) } } -impl_plan_tree_node_for_binary! { StreamHashJoin } +impl_plan_tree_node_for_binary! 
{ StreamAsOfJoin } -impl StreamNode for StreamHashJoin { +impl StreamNode for StreamAsOfJoin { fn to_stream_prost_body(&self, state: &mut BuildFragmentGraphState) -> NodeBody { let left_jk_indices = self.eq_join_predicate.left_eq_indexes(); let right_jk_indices = self.eq_join_predicate.right_eq_indexes(); @@ -249,18 +321,18 @@ impl StreamNode for StreamHashJoin { let dk_indices_in_jk = self.derive_dist_key_in_join_key(); - let (left_table, left_degree_table, left_deduped_input_pk_indices) = - Join::infer_internal_and_degree_table_catalog( - self.left().plan_base(), - left_jk_indices, - dk_indices_in_jk.clone(), - ); - let (right_table, right_degree_table, right_deduped_input_pk_indices) = - Join::infer_internal_and_degree_table_catalog( - self.right().plan_base(), - right_jk_indices, - dk_indices_in_jk, - ); + let (left_table, left_deduped_input_pk_indices) = Self::infer_internal_table_catalog( + self.left().plan_base(), + left_jk_indices, + dk_indices_in_jk.clone(), + self.inequality_desc.left_idx as usize, + ); + let (right_table, right_deduped_input_pk_indices) = Self::infer_internal_table_catalog( + self.right().plan_base(), + right_jk_indices, + dk_indices_in_jk, + self.inequality_desc.right_idx as usize, + ); let left_deduped_input_pk_indices = left_deduped_input_pk_indices .iter() @@ -272,66 +344,33 @@ impl StreamNode for StreamHashJoin { .map(|idx| *idx as u32) .collect_vec(); - let (left_table, left_degree_table) = ( - left_table.with_id(state.gen_table_id_wrapped()), - left_degree_table.with_id(state.gen_table_id_wrapped()), - ); - let (right_table, right_degree_table) = ( - right_table.with_id(state.gen_table_id_wrapped()), - right_degree_table.with_id(state.gen_table_id_wrapped()), - ); + let left_table = left_table.with_id(state.gen_table_id_wrapped()); + let right_table = right_table.with_id(state.gen_table_id_wrapped()); let null_safe_prost = self.eq_join_predicate.null_safes().into_iter().collect(); - NodeBody::HashJoin(HashJoinNode { - join_type: 
self.core.join_type as i32, + let asof_join_type = match self.core.join_type { + JoinType::AsofInner => AsOfJoinType::Inner, + JoinType::AsofLeftOuter => AsOfJoinType::LeftOuter, + _ => unreachable!(), + }; + + NodeBody::AsOfJoin(AsOfJoinNode { + join_type: asof_join_type.into(), left_key: left_jk_indices_prost, right_key: right_jk_indices_prost, null_safe: null_safe_prost, - condition: self - .eq_join_predicate - .other_cond() - .as_expr_unless_true() - .map(|x| x.to_expr_proto()), - inequality_pairs: self - .inequality_pairs - .iter() - .map( - |( - do_state_clean, - InequalityInputPair { - key_required_larger, - key_required_smaller, - delta_expression, - }, - )| { - PbInequalityPair { - key_required_larger: *key_required_larger as u32, - key_required_smaller: *key_required_smaller as u32, - clean_state: *do_state_clean, - delta_expression: delta_expression.as_ref().map( - |(delta_type, delta)| DeltaExpression { - delta_type: *delta_type as i32, - delta: Some(delta.to_expr_proto()), - }, - ), - } - }, - ) - .collect_vec(), left_table: Some(left_table.to_internal_table_prost()), right_table: Some(right_table.to_internal_table_prost()), - left_degree_table: Some(left_degree_table.to_internal_table_prost()), - right_degree_table: Some(right_degree_table.to_internal_table_prost()), left_deduped_input_pk_indices, right_deduped_input_pk_indices, output_indices: self.core.output_indices.iter().map(|&x| x as u32).collect(), - is_append_only: self.is_append_only, + asof_desc: Some(self.inequality_desc), }) } } -impl ExprRewritable for StreamHashJoin { +impl ExprRewritable for StreamAsOfJoin { fn has_rewritable_expr(&self) -> bool { true } @@ -339,11 +378,17 @@ impl ExprRewritable for StreamHashJoin { fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef { let mut core = self.core.clone(); core.rewrite_exprs(r); - Self::new(core, self.eq_join_predicate.rewrite_exprs(r)).into() + let eq_join_predicate = self.eq_join_predicate.rewrite_exprs(r); + let desc = 
Self::get_inequality_desc_from_predicate( + eq_join_predicate.clone(), + core.left.schema().len(), + ) + .unwrap(); + Self::new(core, eq_join_predicate, desc).into() } } -impl ExprVisitable for StreamHashJoin { +impl ExprVisitable for StreamAsOfJoin { fn visit_exprs(&self, v: &mut dyn ExprVisitor) { self.core.visit_exprs(v); } diff --git a/src/frontend/src/optimizer/plan_node/stream_hash_join.rs b/src/frontend/src/optimizer/plan_node/stream_hash_join.rs index 63211d9791ef..4d8951199f74 100644 --- a/src/frontend/src/optimizer/plan_node/stream_hash_join.rs +++ b/src/frontend/src/optimizer/plan_node/stream_hash_join.rs @@ -231,14 +231,14 @@ impl StreamHashJoin { // we can not derive the hash distribution from the side where outer join can // generate a NULL row match logical.join_type { - JoinType::Unspecified| JoinType::AsofInner - | JoinType::AsofLeftOuter => unreachable!(), + JoinType::Unspecified | JoinType::AsofInner | JoinType::AsofLeftOuter => { + unreachable!() + } JoinType::FullOuter => Distribution::SomeShard, JoinType::Inner | JoinType::LeftOuter | JoinType::LeftSemi - | JoinType::LeftAnti - => { + | JoinType::LeftAnti => { let l2o = logical .l2i_col_mapping() .composite(&logical.i2o_col_mapping()); diff --git a/src/prost/build.rs b/src/prost/build.rs index 0afbaef2ea73..ee04705ef19e 100644 --- a/src/prost/build.rs +++ b/src/prost/build.rs @@ -166,6 +166,7 @@ fn main() -> Result<(), Box> { "plan_common.AdditionalCollectionName", "#[derive(Eq, Hash)]", ) + .type_attribute("plan_common.AsOfJoinDesc", "#[derive(Eq, Hash)]") .type_attribute("common.ColumnOrder", "#[derive(Eq, Hash)]") .type_attribute("common.OrderType", "#[derive(Eq, Hash)]") .type_attribute("common.Buffer", "#[derive(Eq)]") diff --git a/src/sqlparser/src/keywords.rs b/src/sqlparser/src/keywords.rs index 8626df2021bc..5a8941483313 100644 --- a/src/sqlparser/src/keywords.rs +++ b/src/sqlparser/src/keywords.rs @@ -616,6 +616,7 @@ pub const RESERVED_FOR_TABLE_ALIAS: &[Keyword] = &[ Keyword::LEFT, 
Keyword::RIGHT, Keyword::NATURAL, + Keyword::ASOF, Keyword::USING, Keyword::CLUSTER, // for MSSQL-specific OUTER APPLY (seems reserved in most dialects) diff --git a/src/sqlparser/src/parser.rs b/src/sqlparser/src/parser.rs index affae211944f..9dd9963363a0 100644 --- a/src/sqlparser/src/parser.rs +++ b/src/sqlparser/src/parser.rs @@ -4634,14 +4634,26 @@ impl Parser<'_> { } } kw @ Keyword::LEFT | kw @ Keyword::RIGHT | kw @ Keyword::FULL => { + let checkpoint = *self; let _ = self.next_token(); let _ = self.parse_keyword(Keyword::OUTER); self.expect_keyword(Keyword::JOIN)?; - match kw { - Keyword::LEFT => JoinOperator::LeftOuter, - Keyword::RIGHT => JoinOperator::RightOuter, - Keyword::FULL => JoinOperator::FullOuter, - _ => unreachable!(), + if asof { + if Keyword::LEFT == kw { + JoinOperator::AsOfLeft + } else { + return self.expected_at( + checkpoint, + "LEFT after ASOF. RIGHT or FULL are not supported", + ); + } + } else { + match kw { + Keyword::LEFT => JoinOperator::LeftOuter, + Keyword::RIGHT => JoinOperator::RightOuter, + Keyword::FULL => JoinOperator::FullOuter, + _ => unreachable!(), + } } } Keyword::OUTER => { @@ -4658,9 +4670,16 @@ impl Parser<'_> { let relation = self.parse_table_factor()?; let join_constraint = self.parse_join_constraint(natural)?; let join_operator = join_operator_type(join_constraint); - if let JoinOperator::Inner(JoinConstraint::None) = join_operator { - return self.expected("join constraint after INNER JOIN"); + let need_constraint = match join_operator { + JoinOperator::Inner(JoinConstraint::None) => Some("INNER JOIN"), + JoinOperator::AsOfInner(JoinConstraint::None) => Some("ASOF INNER JOIN"), + JoinOperator::AsOfLeft(JoinConstraint::None) => Some("ASOF LEFT JOIN"), + _ => None, + }; + if let Some(join_type) = need_constraint { + return self.expected(&format!("join constraint after {join_type}")); } + Join { relation, join_operator, diff --git a/src/sqlparser/tests/testdata/asof_join.yaml 
b/src/sqlparser/tests/testdata/asof_join.yaml new file mode 100644 index 000000000000..b7ee5b1461b7 --- /dev/null +++ b/src/sqlparser/tests/testdata/asof_join.yaml @@ -0,0 +1,17 @@ +# This file is automatically generated by `src/sqlparser/tests/parser_test.rs`. +- input: SELECT * FROM t1 asof JOIN t2 where t1.v1 = t2.v1 + error_msg: |- + sql parser error: expected join constraint after ASOF INNER JOIN, found: where + LINE 1: SELECT * FROM t1 asof JOIN t2 where t1.v1 = t2.v1 + ^ +- input: SELECT * FROM t1 asof LEFT JOIN t2 ON t1.v1 = t2.v1 and t1.v2 > t2.v2 + formatted_sql: SELECT * FROM t1 ASOF LEFT JOIN t2 ON t1.v1 = t2.v1 AND t1.v2 > t2.v2 + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [Wildcard(None)], from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: "t1", quote_style: None }]), alias: None, as_of: None }, joins: [Join { relation: Table { name: ObjectName([Ident { value: "t2", quote_style: None }]), alias: None, as_of: None }, join_operator: AsOfLeft(On(BinaryOp { left: BinaryOp { left: CompoundIdentifier([Ident { value: "t1", quote_style: None }, Ident { value: "v1", quote_style: None }]), op: Eq, right: CompoundIdentifier([Ident { value: "t2", quote_style: None }, Ident { value: "v1", quote_style: None }]) }, op: And, right: BinaryOp { left: CompoundIdentifier([Ident { value: "t1", quote_style: None }, Ident { value: "v2", quote_style: None }]), op: Gt, right: CompoundIdentifier([Ident { value: "t2", quote_style: None }, Ident { value: "v2", quote_style: None }]) } })) }] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' +- input: SELECT * FROM t1 asof INNER JOIN t2 ON t1.v1 = t2.v1 and t1.v2 > t2.v2 + formatted_sql: SELECT * FROM t1 ASOF JOIN t2 ON t1.v1 = t2.v1 AND t1.v2 > t2.v2 + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [Wildcard(None)], from: [TableWithJoins { 
relation: Table { name: ObjectName([Ident { value: "t1", quote_style: None }]), alias: None, as_of: None }, joins: [Join { relation: Table { name: ObjectName([Ident { value: "t2", quote_style: None }]), alias: None, as_of: None }, join_operator: AsOfInner(On(BinaryOp { left: BinaryOp { left: CompoundIdentifier([Ident { value: "t1", quote_style: None }, Ident { value: "v1", quote_style: None }]), op: Eq, right: CompoundIdentifier([Ident { value: "t2", quote_style: None }, Ident { value: "v1", quote_style: None }]) }, op: And, right: BinaryOp { left: CompoundIdentifier([Ident { value: "t1", quote_style: None }, Ident { value: "v2", quote_style: None }]), op: Gt, right: CompoundIdentifier([Ident { value: "t2", quote_style: None }, Ident { value: "v2", quote_style: None }]) } })) }] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' +- input: SELECT * FROM t1 asof RIGHT JOIN t2 ON t1.v1 = t2.v1 and t1.v2 > t2.v2 + error_msg: |- + sql parser error: expected LEFT after ASOF. 
RIGHT or FULL are not supported, found: RIGHT + LINE 1: SELECT * FROM t1 asof RIGHT JOIN t2 ON t1.v1 = t2.v1 and t1.v2 > t2.v2 + ^ diff --git a/src/stream/src/from_proto/mod.rs b/src/stream/src/from_proto/mod.rs index 9a51dd10ddfb..5ac5379ca57c 100644 --- a/src/stream/src/from_proto/mod.rs +++ b/src/stream/src/from_proto/mod.rs @@ -67,6 +67,7 @@ use risingwave_storage::StateStore; use self::append_only_dedup::*; use self::approx_percentile::global::*; use self::approx_percentile::local::*; +use self::asof_join::AsOfJoinExecutorBuilder; use self::barrier_recv::*; use self::batch_query::*; use self::cdc_filter::CdcFilterExecutorBuilder; @@ -186,5 +187,6 @@ pub async fn create_executor( NodeBody::GlobalApproxPercentile => GlobalApproxPercentileExecutorBuilder, NodeBody::LocalApproxPercentile => LocalApproxPercentileExecutorBuilder, NodeBody::RowMerge => RowMergeExecutorBuilder, + NodeBody::AsOfJoin => AsOfJoinExecutorBuilder, } } From 4153f2bca20624532305e0f8ecd1cdc9849acfa9 Mon Sep 17 00:00:00 2001 From: Yuhao Su Date: Wed, 25 Sep 2024 14:38:08 +0800 Subject: [PATCH 3/8] e2e --- e2e_test/streaming/asof_join.slt | 143 +++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) create mode 100644 e2e_test/streaming/asof_join.slt diff --git a/e2e_test/streaming/asof_join.slt b/e2e_test/streaming/asof_join.slt new file mode 100644 index 000000000000..6e35d5aa7d40 --- /dev/null +++ b/e2e_test/streaming/asof_join.slt @@ -0,0 +1,143 @@ +statement ok +SET RW_IMPLICIT_FLUSH TO true; + +# asof inner join + +statement ok +create table t1 (v1 int, v2 int, v3 int primary key); + +statement ok +create table t2 (v1 int, v2 int, v3 int primary key); + +statement ok +create materialized view mv1 as SELECT t1.v1 t1_v1, t1.v2 t1_v2, t1.v3 t1_v3, t2.v1 t2_v1, t2.v2 t2_v2, t2.v3 t2_v3 FROM t1 ASOF JOIN t2 ON t1.v1 = t2.v1 and t1.v2 <= t2.v2; + +statement ok +insert into t1 values (1, 2, 3); + +statement ok +insert into t2 values (1, 3, 4); + +query III +select * from mv1; +---- +1 
2 3 1 3 4 + +statement ok +insert into t2 values (1, 2, 3); + +query III +select * from mv1; +---- +1 2 3 1 2 3 + +statement ok +delete from t1 where v3 = 3; + +query III +select * from mv1; +---- + + +statement ok +insert into t1 values (2, 3, 4); + +statement ok +insert into t2 values (2, 3, 6), (2, 3, 7), (2, 3, 5); + +query III +select * from mv1; +---- +2 3 4 2 3 5 + +statement ok +insert into t2 values (2, 3, 1), (2, 3, 2); + +query III +select * from mv1; +---- +2 3 4 2 3 1 + +statement ok +drop materialized view mv1; + +statement ok +drop table t1; + +statement ok +drop table t2; + + +# asof left join + +statement ok +create table t1 (v1 int, v2 int, v3 int primary key); + +statement ok +create table t2 (v1 int, v2 int, v3 int primary key); + +statement ok +create materialized view mv1 as SELECT t1.v1 t1_v1, t1.v2 t1_v2, t1.v3 t1_v3, t2.v1 t2_v1, t2.v2 t2_v2, t2.v3 t2_v3 FROM t1 ASOF LEFT JOIN t2 ON t1.v1 = t2.v1 and t1.v2 > t2.v2; + +statement ok +insert into t1 values (1, 2, 3); + +statement ok +insert into t2 values (1, 2, 4); + +query III +select * from mv1; +---- +1 2 3 NULL NULL NULL + +statement ok +insert into t2 values (1, 1, 3); + +query III +select * from mv1; +---- +1 2 3 1 1 3 + +statement ok +delete from t1 where v3 = 3; + +query III +select * from mv1; +---- + + +statement ok +insert into t1 values (2, 3, 4); + +statement ok +insert into t2 values (2, 2, 6), (2, 2, 7), (2, 2, 5); + +query III +select * from mv1; +---- +2 3 4 2 2 5 + +statement ok +insert into t2 values (2, 2, 1), (2, 2, 2); + +query III +select * from mv1; +---- +2 3 4 2 2 1 + +statement ok +delete from t2 where v1 = 2; + +query III +select * from mv1; +---- +2 3 4 NULL NULL NULL + +statement ok +drop materialized view mv1; + +statement ok +drop table t1; + +statement ok +drop table t2; From 40b0e8faa0a4ad9222d7c8caf589a8b1e358886c Mon Sep 17 00:00:00 2001 From: Yuhao Su Date: Fri, 27 Sep 2024 15:34:44 +0800 Subject: [PATCH 4/8] fix --- 
.../planner_test/tests/testdata/input/asof_join.yaml | 7 +++++++ .../planner_test/tests/testdata/output/asof_join.yaml | 6 ++++-- src/frontend/src/optimizer/plan_node/logical_join.rs | 8 +++++++- src/frontend/src/optimizer/plan_node/stream_asof_join.rs | 9 +-------- 4 files changed, 19 insertions(+), 11 deletions(-) diff --git a/src/frontend/planner_test/tests/testdata/input/asof_join.yaml b/src/frontend/planner_test/tests/testdata/input/asof_join.yaml index 0ab19b56422f..f6ca65716c2e 100644 --- a/src/frontend/planner_test/tests/testdata/input/asof_join.yaml +++ b/src/frontend/planner_test/tests/testdata/input/asof_join.yaml @@ -26,3 +26,10 @@ SELECT t1.v1 t1_v1, t1.v2 t1_v2, t2.v1 t2_v1, t2.v2 t2_v2 FROM t1 ASOF JOIN t2 ON t1.v1 = t2.v1 and t1.v2 < t2.v2 and t1.v3 < t2.v3; expected_outputs: - stream_error + +- sql: + CREATE TABLE t1(v1 varchar, v2 int, v3 int); + CREATE TABLE t2(v1 varchar, v2 int, v3 int); + SELECT t1.v1 t1_v1, t1.v2 t1_v2, t2.v1 t2_v1, t2.v2 t2_v2 FROM t1 ASOF JOIN t2 ON t1.v2 < t2.v2; + expected_outputs: + - stream_error diff --git a/src/frontend/planner_test/tests/testdata/output/asof_join.yaml b/src/frontend/planner_test/tests/testdata/output/asof_join.yaml index 7dd2862065e9..508c9de04f18 100644 --- a/src/frontend/planner_test/tests/testdata/output/asof_join.yaml +++ b/src/frontend/planner_test/tests/testdata/output/asof_join.yaml @@ -1,6 +1,6 @@ # This file is automatically generated. See `src/frontend/planner_test/README.md` for more information. 
- sql: CREATE TABLE t1(v1 varchar, v2 int, v3 int); CREATE TABLE t2(v1 varchar, v2 int, v3 int); SELECT * FROM t1 ASOF JOIN t2 ON t1.v1 = t2.v1; - stream_error: 'Invalid input syntax: AsOf join requires exactly 1 ineuqual condition' + stream_error: 'Invalid input syntax: AsOf join requires exactly 1 ineuquality condition' - sql: CREATE TABLE t1(v1 varchar, v2 int, v3 int); CREATE TABLE t2(v1 varchar, v2 int, v3 int); SELECT t1.v1 t1_v1, t1.v2 t1_v2, t2.v1 t2_v1, t2.v2 t2_v2 FROM t1 ASOF JOIN t2 ON t1.v1 = t2.v1 || 'a' and t1.v2 > t2.v2; stream_plan: |- StreamMaterialize { columns: [t1_v1, t1_v2, t2_v1, t2_v2, t1._row_id(hidden), t2._row_id(hidden)], stream_key: [t1._row_id, t2._row_id, t1_v1], pk_columns: [t1._row_id, t2._row_id, t1_v1], pk_conflict: NoCheck } @@ -23,4 +23,6 @@ └─StreamExchange { dist: HashShard(t2.v1) } └─StreamTableScan { table: t2, columns: [t2.v1, t2.v2, t2._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t2._row_id], pk: [_row_id], dist: UpstreamHashShard(t2._row_id) } - sql: CREATE TABLE t1(v1 varchar, v2 int, v3 int); CREATE TABLE t2(v1 varchar, v2 int, v3 int); SELECT t1.v1 t1_v1, t1.v2 t1_v2, t2.v1 t2_v1, t2.v2 t2_v2 FROM t1 ASOF JOIN t2 ON t1.v1 = t2.v1 and t1.v2 < t2.v2 and t1.v3 < t2.v3; - stream_error: 'Invalid input syntax: AsOf join requires exactly 1 ineuqual condition' + stream_error: 'Invalid input syntax: AsOf join requires exactly 1 ineuquality condition' +- sql: CREATE TABLE t1(v1 varchar, v2 int, v3 int); CREATE TABLE t2(v1 varchar, v2 int, v3 int); SELECT t1.v1 t1_v1, t1.v2 t1_v2, t2.v1 t2_v1, t2.v2 t2_v2 FROM t1 ASOF JOIN t2 ON t1.v2 < t2.v2; + stream_error: 'Invalid input syntax: AsOf join requires at least 1 equal condition' diff --git a/src/frontend/src/optimizer/plan_node/logical_join.rs b/src/frontend/src/optimizer/plan_node/logical_join.rs index 63b2a6041630..07219d5cd071 100644 --- a/src/frontend/src/optimizer/plan_node/logical_join.rs +++ b/src/frontend/src/optimizer/plan_node/logical_join.rs @@ -1280,7 
+1280,13 @@ impl LogicalJoin { ) -> Result { use super::stream::prelude::*; - assert!(predicate.has_eq()); + if predicate.eq_keys().is_empty() { + return Err(ErrorCode::InvalidInputSyntax( + "AsOf join requires at least 1 equal condition".to_string(), + ) + .into()) + } + let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?; let left_len = left.schema().len(); let logical_join = self.clone_with_left_right(left, right); diff --git a/src/frontend/src/optimizer/plan_node/stream_asof_join.rs b/src/frontend/src/optimizer/plan_node/stream_asof_join.rs index 6a6bc7658fa6..fa74c8ba6389 100644 --- a/src/frontend/src/optimizer/plan_node/stream_asof_join.rs +++ b/src/frontend/src/optimizer/plan_node/stream_asof_join.rs @@ -97,12 +97,6 @@ impl StreamAsOfJoin { predicate: EqJoinPredicate, left_input_len: usize, ) -> Result { - if predicate.eq_keys().is_empty() { - Err(ErrorCode::InvalidInputSyntax( - "AsOf join requires at least 1 equal condition".to_string(), - ) - .into()) - } else { let expr: ExprImpl = predicate.other_cond().clone().into(); if let Some((left_input_ref, expr_type, right_input_ref)) = expr.as_comparison_cond() { if left_input_ref.index() < left_input_len @@ -118,11 +112,10 @@ impl StreamAsOfJoin { } } else { Err(ErrorCode::InvalidInputSyntax( - "AsOf join requires exactly 1 ineuqual condition".to_string(), + "AsOf join requires exactly 1 ineuquality condition".to_string(), ) .into()) } - } } fn expr_type_to_comparison_type(expr_type: PbType) -> Result { From 45fe957831800377f15276f896aa0158db043dd2 Mon Sep 17 00:00:00 2001 From: Yuhao Su Date: Fri, 27 Sep 2024 17:17:47 +0800 Subject: [PATCH 5/8] fmt --- .../src/optimizer/plan_node/logical_join.rs | 2 +- .../optimizer/plan_node/stream_asof_join.rs | 33 +++++++++---------- 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/src/frontend/src/optimizer/plan_node/logical_join.rs b/src/frontend/src/optimizer/plan_node/logical_join.rs index 07219d5cd071..0f642e3c3e88 100644 --- 
a/src/frontend/src/optimizer/plan_node/logical_join.rs +++ b/src/frontend/src/optimizer/plan_node/logical_join.rs @@ -1284,7 +1284,7 @@ impl LogicalJoin { return Err(ErrorCode::InvalidInputSyntax( "AsOf join requires at least 1 equal condition".to_string(), ) - .into()) + .into()); } let (left, right) = self.get_stream_input_for_hash_join(&predicate, ctx)?; diff --git a/src/frontend/src/optimizer/plan_node/stream_asof_join.rs b/src/frontend/src/optimizer/plan_node/stream_asof_join.rs index fa74c8ba6389..3739c0ef8ba9 100644 --- a/src/frontend/src/optimizer/plan_node/stream_asof_join.rs +++ b/src/frontend/src/optimizer/plan_node/stream_asof_join.rs @@ -97,25 +97,24 @@ impl StreamAsOfJoin { predicate: EqJoinPredicate, left_input_len: usize, ) -> Result { - let expr: ExprImpl = predicate.other_cond().clone().into(); - if let Some((left_input_ref, expr_type, right_input_ref)) = expr.as_comparison_cond() { - if left_input_ref.index() < left_input_len - && right_input_ref.index() >= left_input_len - { - Ok(AsOfJoinDesc { - left_idx: left_input_ref.index() as u32, - right_idx: (right_input_ref.index() - left_input_len) as u32, - inequality_type: Self::expr_type_to_comparison_type(expr_type)?.into(), - }) - } else { - bail!("inequal condition from the same side should be push down in optimizer"); - } + let expr: ExprImpl = predicate.other_cond().clone().into(); + if let Some((left_input_ref, expr_type, right_input_ref)) = expr.as_comparison_cond() { + if left_input_ref.index() < left_input_len && right_input_ref.index() >= left_input_len + { + Ok(AsOfJoinDesc { + left_idx: left_input_ref.index() as u32, + right_idx: (right_input_ref.index() - left_input_len) as u32, + inequality_type: Self::expr_type_to_comparison_type(expr_type)?.into(), + }) } else { - Err(ErrorCode::InvalidInputSyntax( - "AsOf join requires exactly 1 ineuquality condition".to_string(), - ) - .into()) + bail!("inequal condition from the same side should be push down in optimizer"); } + } else { + 
Err(ErrorCode::InvalidInputSyntax( + "AsOf join requires exactly 1 ineuquality condition".to_string(), + ) + .into()) + } } fn expr_type_to_comparison_type(expr_type: PbType) -> Result { From f5abfdd4fc057318e0781ae1bebd3c2de8c24daf Mon Sep 17 00:00:00 2001 From: Yuhao Su Date: Sun, 29 Sep 2024 13:07:43 +0800 Subject: [PATCH 6/8] try bump node --- .github/workflows/dashboard.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/dashboard.yml b/.github/workflows/dashboard.yml index 759a9bb83885..64655f1c825b 100644 --- a/.github/workflows/dashboard.yml +++ b/.github/workflows/dashboard.yml @@ -14,7 +14,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-node@v4 with: - node-version: 18 + node-version: 20 - uses: arduino/setup-protoc@v3 with: version: "3.x" From 878d4828bd51f1f093b7ee0f04b4bcae1b77829a Mon Sep 17 00:00:00 2001 From: Yuhao Su Date: Sun, 29 Sep 2024 14:21:59 +0800 Subject: [PATCH 7/8] revert bump --- .github/workflows/dashboard.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/dashboard.yml b/.github/workflows/dashboard.yml index 64655f1c825b..759a9bb83885 100644 --- a/.github/workflows/dashboard.yml +++ b/.github/workflows/dashboard.yml @@ -14,7 +14,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-node@v4 with: - node-version: 20 + node-version: 18 - uses: arduino/setup-protoc@v3 with: version: "3.x" From bc8991fea921d37023d6d3a41a16469d72e9a601 Mon Sep 17 00:00:00 2001 From: Yuhao Su Date: Mon, 30 Sep 2024 17:34:40 +0800 Subject: [PATCH 8/8] improve --- src/frontend/src/optimizer/plan_node/mod.rs | 2 + .../optimizer/plan_node/stream_asof_join.rs | 73 ++++----------- .../optimizer/plan_node/stream_delta_join.rs | 2 +- .../optimizer/plan_node/stream_hash_join.rs | 73 +++------------ .../optimizer/plan_node/stream_join_common.rs | 88 +++++++++++++++++++ 5 files changed, 122 insertions(+), 116 deletions(-) create mode 100644 
src/frontend/src/optimizer/plan_node/stream_join_common.rs diff --git a/src/frontend/src/optimizer/plan_node/mod.rs b/src/frontend/src/optimizer/plan_node/mod.rs index 85c4e3066f7c..0ec266cd2339 100644 --- a/src/frontend/src/optimizer/plan_node/mod.rs +++ b/src/frontend/src/optimizer/plan_node/mod.rs @@ -899,6 +899,7 @@ mod stream_group_topn; mod stream_hash_agg; mod stream_hash_join; mod stream_hop_window; +mod stream_join_common; mod stream_local_approx_percentile; mod stream_materialize; mod stream_now; @@ -1012,6 +1013,7 @@ pub use stream_group_topn::StreamGroupTopN; pub use stream_hash_agg::StreamHashAgg; pub use stream_hash_join::StreamHashJoin; pub use stream_hop_window::StreamHopWindow; +use stream_join_common::StreamJoinCommon; pub use stream_local_approx_percentile::StreamLocalApproxPercentile; pub use stream_materialize::StreamMaterialize; pub use stream_now::StreamNow; diff --git a/src/frontend/src/optimizer/plan_node/stream_asof_join.rs b/src/frontend/src/optimizer/plan_node/stream_asof_join.rs index 3739c0ef8ba9..f24176916860 100644 --- a/src/frontend/src/optimizer/plan_node/stream_asof_join.rs +++ b/src/frontend/src/optimizer/plan_node/stream_asof_join.rs @@ -15,7 +15,6 @@ use fixedbitset::FixedBitSet; use itertools::Itertools; use pretty_xmlish::{Pretty, XmlNode}; -use risingwave_common::util::iter_util::ZipEqFast; use risingwave_common::util::sort_util::OrderType; use risingwave_expr::bail; use risingwave_pb::expr::expr_node::PbType; @@ -28,15 +27,16 @@ use super::stream::prelude::*; use super::utils::{ childless_record, plan_node_name, watermark_pretty, Distill, TableCatalogBuilder, }; -use super::{generic, ExprRewritable, PlanBase, PlanRef, PlanTreeNodeBinary, StreamNode}; +use super::{ + generic, ExprRewritable, PlanBase, PlanRef, PlanTreeNodeBinary, StreamJoinCommon, StreamNode, +}; use crate::error::{ErrorCode, Result}; use crate::expr::{ExprImpl, ExprRewriter, ExprVisitor}; use crate::optimizer::plan_node::expr_visitable::ExprVisitable; use 
crate::optimizer::plan_node::utils::IndicesDisplay; use crate::optimizer::plan_node::{EqJoinPredicate, EqJoinPredicateDisplay}; -use crate::optimizer::property::{Distribution, MonotonicityMap}; +use crate::optimizer::property::MonotonicityMap; use crate::stream_fragmenter::BuildFragmentGraphState; -use crate::utils::ColIndexMappingRewriteExt; use crate::TableCatalog; /// [`StreamAsOfJoin`] implements [`super::LogicalJoin`] with hash tables. @@ -63,13 +63,19 @@ impl StreamAsOfJoin { eq_join_predicate: EqJoinPredicate, inequality_desc: AsOfJoinDesc, ) -> Self { + assert!(core.join_type == JoinType::AsofInner || core.join_type == JoinType::AsofLeftOuter); + // Inner join won't change the append-only behavior of the stream. The rest might. let append_only = match core.join_type { JoinType::Inner => core.left.append_only() && core.right.append_only(), _ => false, }; - let dist = Self::derive_dist(core.left.distribution(), core.right.distribution(), &core); + let dist = StreamJoinCommon::derive_dist( + core.left.distribution(), + core.right.distribution(), + &core, + ); // TODO: derive watermarks let watermark_columns = FixedBitSet::with_capacity(core.schema().len()); @@ -136,66 +142,23 @@ impl StreamAsOfJoin { self.core.join_type } - /// Get a reference to the batch hash join's eq join predicate. + /// Get a reference to the `AsOf` join's eq join predicate. 
pub fn eq_join_predicate(&self) -> &EqJoinPredicate { &self.eq_join_predicate } - pub(super) fn derive_dist( - left: &Distribution, - right: &Distribution, - logical: &generic::Join, - ) -> Distribution { - match (left, right) { - (Distribution::Single, Distribution::Single) => Distribution::Single, - (Distribution::HashShard(_), Distribution::HashShard(_)) => match logical.join_type { - JoinType::Unspecified - | JoinType::FullOuter - | JoinType::Inner - | JoinType::LeftOuter - | JoinType::LeftSemi - | JoinType::LeftAnti - | JoinType::RightSemi - | JoinType::RightAnti - | JoinType::RightOuter => unreachable!(), - JoinType::AsofInner | JoinType::AsofLeftOuter => { - let l2o = logical - .l2i_col_mapping() - .composite(&logical.i2o_col_mapping()); - l2o.rewrite_provided_distribution(left) - } - }, - (_, _) => unreachable!( - "suspicious distribution: left: {:?}, right: {:?}", - left, right - ), - } - } - pub fn derive_dist_key_in_join_key(&self) -> Vec { let left_dk_indices = self.left().distribution().dist_column_indices().to_vec(); let right_dk_indices = self.right().distribution().dist_column_indices().to_vec(); - let left_jk_indices = self.eq_join_predicate.left_eq_indexes(); - let right_jk_indices = self.eq_join_predicate.right_eq_indexes(); - - assert_eq!(left_jk_indices.len(), right_jk_indices.len()); - let mut dk_indices_in_jk = vec![]; - - for (l_dk_idx, r_dk_idx) in left_dk_indices.iter().zip_eq_fast(right_dk_indices.iter()) { - for dk_idx_in_jk in left_jk_indices.iter().positions(|idx| idx == l_dk_idx) { - if right_jk_indices[dk_idx_in_jk] == *r_dk_idx { - dk_indices_in_jk.push(dk_idx_in_jk); - break; - } - } - } - - assert_eq!(dk_indices_in_jk.len(), left_dk_indices.len()); - dk_indices_in_jk + StreamJoinCommon::get_dist_key_in_join_key( + &left_dk_indices, + &right_dk_indices, + self.eq_join_predicate(), + ) } - /// Return stream hash join internal table catalog. + /// Return stream asof join internal table catalog. 
pub fn infer_internal_table_catalog( input: I, join_key_indices: Vec, diff --git a/src/frontend/src/optimizer/plan_node/stream_delta_join.rs b/src/frontend/src/optimizer/plan_node/stream_delta_join.rs index f53d4331ae61..84592aee1829 100644 --- a/src/frontend/src/optimizer/plan_node/stream_delta_join.rs +++ b/src/frontend/src/optimizer/plan_node/stream_delta_join.rs @@ -86,7 +86,7 @@ impl StreamDeltaJoin { } } - /// Get a reference to the batch hash join's eq join predicate. + /// Get a reference to the delta hash join's eq join predicate. pub fn eq_join_predicate(&self) -> &EqJoinPredicate { &self.eq_join_predicate } diff --git a/src/frontend/src/optimizer/plan_node/stream_hash_join.rs b/src/frontend/src/optimizer/plan_node/stream_hash_join.rs index 4d8951199f74..0d7863a247d9 100644 --- a/src/frontend/src/optimizer/plan_node/stream_hash_join.rs +++ b/src/frontend/src/optimizer/plan_node/stream_hash_join.rs @@ -15,13 +15,13 @@ use fixedbitset::FixedBitSet; use itertools::Itertools; use pretty_xmlish::{Pretty, XmlNode}; -use risingwave_common::util::iter_util::ZipEqFast; use risingwave_pb::plan_common::JoinType; use risingwave_pb::stream_plan::stream_node::NodeBody; use risingwave_pb::stream_plan::{DeltaExpression, HashJoinNode, PbInequalityPair}; use super::generic::Join; use super::stream::prelude::*; +use super::stream_join_common::StreamJoinCommon; use super::utils::{childless_record, plan_node_name, watermark_pretty, Distill}; use super::{ generic, ExprRewritable, PlanBase, PlanRef, PlanTreeNodeBinary, StreamDeltaJoin, StreamNode, @@ -30,7 +30,7 @@ use crate::expr::{Expr, ExprDisplay, ExprRewriter, ExprVisitor, InequalityInputP use crate::optimizer::plan_node::expr_visitable::ExprVisitable; use crate::optimizer::plan_node::utils::IndicesDisplay; use crate::optimizer::plan_node::{EqJoinPredicate, EqJoinPredicateDisplay}; -use crate::optimizer::property::{Distribution, MonotonicityMap}; +use crate::optimizer::property::MonotonicityMap; use 
crate::stream_fragmenter::BuildFragmentGraphState; use crate::utils::ColIndexMappingRewriteExt; @@ -72,7 +72,11 @@ impl StreamHashJoin { _ => false, }; - let dist = Self::derive_dist(core.left.distribution(), core.right.distribution(), &core); + let dist = StreamJoinCommon::derive_dist( + core.left.distribution(), + core.right.distribution(), + &core, + ); let mut inequality_pairs = vec![]; let mut clean_left_state_conjunction_idx = None; @@ -215,50 +219,11 @@ impl StreamHashJoin { self.core.join_type } - /// Get a reference to the batch hash join's eq join predicate. + /// Get a reference to the hash join's eq join predicate. pub fn eq_join_predicate(&self) -> &EqJoinPredicate { &self.eq_join_predicate } - pub(super) fn derive_dist( - left: &Distribution, - right: &Distribution, - logical: &generic::Join, - ) -> Distribution { - match (left, right) { - (Distribution::Single, Distribution::Single) => Distribution::Single, - (Distribution::HashShard(_), Distribution::HashShard(_)) => { - // we can not derive the hash distribution from the side where outer join can - // generate a NULL row - match logical.join_type { - JoinType::Unspecified | JoinType::AsofInner | JoinType::AsofLeftOuter => { - unreachable!() - } - JoinType::FullOuter => Distribution::SomeShard, - JoinType::Inner - | JoinType::LeftOuter - | JoinType::LeftSemi - | JoinType::LeftAnti => { - let l2o = logical - .l2i_col_mapping() - .composite(&logical.i2o_col_mapping()); - l2o.rewrite_provided_distribution(left) - } - JoinType::RightSemi | JoinType::RightAnti | JoinType::RightOuter => { - let r2o = logical - .r2i_col_mapping() - .composite(&logical.i2o_col_mapping()); - r2o.rewrite_provided_distribution(right) - } - } - } - (_, _) => unreachable!( - "suspicious distribution: left: {:?}, right: {:?}", - left, right - ), - } - } - /// Convert this hash join to a delta join plan pub fn into_delta_join(self) -> StreamDeltaJoin { StreamDeltaJoin::new(self.core, self.eq_join_predicate) @@ -267,24 +232,12 @@ 
impl StreamHashJoin { pub fn derive_dist_key_in_join_key(&self) -> Vec { let left_dk_indices = self.left().distribution().dist_column_indices().to_vec(); let right_dk_indices = self.right().distribution().dist_column_indices().to_vec(); - let left_jk_indices = self.eq_join_predicate.left_eq_indexes(); - let right_jk_indices = self.eq_join_predicate.right_eq_indexes(); - - assert_eq!(left_jk_indices.len(), right_jk_indices.len()); - - let mut dk_indices_in_jk = vec![]; - - for (l_dk_idx, r_dk_idx) in left_dk_indices.iter().zip_eq_fast(right_dk_indices.iter()) { - for dk_idx_in_jk in left_jk_indices.iter().positions(|idx| idx == l_dk_idx) { - if right_jk_indices[dk_idx_in_jk] == *r_dk_idx { - dk_indices_in_jk.push(dk_idx_in_jk); - break; - } - } - } - assert_eq!(dk_indices_in_jk.len(), left_dk_indices.len()); - dk_indices_in_jk + StreamJoinCommon::get_dist_key_in_join_key( + &left_dk_indices, + &right_dk_indices, + self.eq_join_predicate(), + ) } pub fn inequality_pairs(&self) -> &Vec<(bool, InequalityInputPair)> { diff --git a/src/frontend/src/optimizer/plan_node/stream_join_common.rs b/src/frontend/src/optimizer/plan_node/stream_join_common.rs new file mode 100644 index 000000000000..f44ab8291f44 --- /dev/null +++ b/src/frontend/src/optimizer/plan_node/stream_join_common.rs @@ -0,0 +1,88 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use itertools::Itertools; +use risingwave_common::util::iter_util::ZipEqFast; +use risingwave_pb::plan_common::JoinType; + +use super::{generic, EqJoinPredicate}; +use crate::optimizer::property::Distribution; +use crate::utils::ColIndexMappingRewriteExt; +use crate::PlanRef; + +pub struct StreamJoinCommon; + +impl StreamJoinCommon { + pub(super) fn get_dist_key_in_join_key( + left_dk_indices: &[usize], + right_dk_indices: &[usize], + eq_join_predicate: &EqJoinPredicate, + ) -> Vec { + let left_jk_indices = eq_join_predicate.left_eq_indexes(); + let right_jk_indices = &eq_join_predicate.right_eq_indexes(); + assert_eq!(left_jk_indices.len(), right_jk_indices.len()); + let mut dk_indices_in_jk = vec![]; + for (l_dk_idx, r_dk_idx) in left_dk_indices.iter().zip_eq_fast(right_dk_indices.iter()) { + for dk_idx_in_jk in left_jk_indices.iter().positions(|idx| idx == l_dk_idx) { + if right_jk_indices[dk_idx_in_jk] == *r_dk_idx { + dk_indices_in_jk.push(dk_idx_in_jk); + break; + } + } + } + assert_eq!(dk_indices_in_jk.len(), left_dk_indices.len()); + dk_indices_in_jk + } + + pub(super) fn derive_dist( + left: &Distribution, + right: &Distribution, + logical: &generic::Join, + ) -> Distribution { + match (left, right) { + (Distribution::Single, Distribution::Single) => Distribution::Single, + (Distribution::HashShard(_), Distribution::HashShard(_)) => { + // we can not derive the hash distribution from the side where outer join can + // generate a NULL row + match logical.join_type { + JoinType::Unspecified => { + unreachable!() + } + JoinType::FullOuter => Distribution::SomeShard, + JoinType::Inner + | JoinType::LeftOuter + | JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::AsofInner + | JoinType::AsofLeftOuter => { + let l2o = logical + .l2i_col_mapping() + .composite(&logical.i2o_col_mapping()); + l2o.rewrite_provided_distribution(left) + } + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightOuter => { + let r2o = logical + .r2i_col_mapping() + 
.composite(&logical.i2o_col_mapping()); + r2o.rewrite_provided_distribution(right) + } + } + } + (_, _) => unreachable!( + "suspicious distribution: left: {:?}, right: {:?}", + left, right + ), + } + } +}