Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plan LATERAL subqueries #11456

Merged
merged 4 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions datafusion/sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ pub struct PlannerContext {
ctes: HashMap<String, Arc<LogicalPlan>>,
/// The query schema of the outer query plan, used to resolve the columns in subquery
outer_query_schema: Option<DFSchemaRef>,
/// The joined schemas of all FROM clauses planned so far. When planning LATERAL
/// FROM clauses, this should become a suffix of the `outer_query_schema`.
outer_from_schema: Option<DFSchemaRef>,
}

impl Default for PlannerContext {
Expand All @@ -150,6 +153,7 @@ impl PlannerContext {
prepare_param_data_types: Arc::new(vec![]),
ctes: HashMap::new(),
outer_query_schema: None,
outer_from_schema: None,
}
}

Expand Down Expand Up @@ -177,6 +181,29 @@ impl PlannerContext {
schema
}

/// Returns the outer FROM schema currently stored in this context, if any.
///
/// The schema is reference-counted, so the returned value is a new handle to
/// the same schema rather than a deep copy.
pub fn outer_from_schema(&self) -> Option<Arc<DFSchema>> {
    self.outer_from_schema.as_ref().map(Arc::clone)
}

/// Installs `schema` as the outer FROM schema and hands back whatever was
/// stored before (`None` if nothing was set).
pub fn set_outer_from_schema(
    &mut self,
    schema: Option<DFSchemaRef>,
) -> Option<DFSchemaRef> {
    std::mem::replace(&mut self.outer_from_schema, schema)
}

/// extends the outer FROM schema by joining the given schema onto the existing one, or sets it to the given schema if none exists
pub fn extend_outer_from_schema(&mut self, schema: &DFSchemaRef) -> Result<()> {
self.outer_from_schema = match self.outer_from_schema.as_ref() {
Some(from_schema) => Some(Arc::new(from_schema.join(schema)?)),
Copy link
Contributor Author

@aalexandrov aalexandrov Jul 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See my other comment.

Actually, looking at the Postgres behavior it seems that keeping join here is better (the j1_id column reference is then reported as ambiguous):

SELECT j1_string, j2_string FROM
  j1 AS x JOIN j1 as y USING(j1_string)
  LEFT JOIN LATERAL (SELECT * FROM j2 WHERE j1_id < j2_id) AS j2 ON(true);

None => Some(Arc::clone(schema)),
};
Ok(())
}

/// Return the types of parameters (`$1`, `$2`, etc) if known
pub fn prepare_param_data_types(&self) -> &[DataType] {
&self.prepare_param_data_types
Expand Down
49 changes: 45 additions & 4 deletions datafusion/sql/src/relation/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
use datafusion_common::{not_impl_err, Column, Result};
use datafusion_expr::{JoinType, LogicalPlan, LogicalPlanBuilder};
use sqlparser::ast::{Join, JoinConstraint, JoinOperator, TableWithJoins};
use sqlparser::ast::{Join, JoinConstraint, JoinOperator, TableFactor, TableWithJoins};
use std::collections::HashSet;

impl<'a, S: ContextProvider> SqlToRel<'a, S> {
Expand All @@ -27,10 +27,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
t: TableWithJoins,
planner_context: &mut PlannerContext,
) -> Result<LogicalPlan> {
let mut left = self.create_relation(t.relation, planner_context)?;
for join in t.joins.into_iter() {
let mut left = if is_lateral(&t.relation) {
self.create_relation_subquery(t.relation, planner_context)?
} else {
self.create_relation(t.relation, planner_context)?
};
let old_outer_from_schema = planner_context.outer_from_schema();
for join in t.joins {
planner_context.extend_outer_from_schema(left.schema())?;
left = self.parse_relation_join(left, join, planner_context)?;
}
planner_context.set_outer_from_schema(old_outer_from_schema);
Ok(left)
}

Expand All @@ -40,7 +47,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
join: Join,
planner_context: &mut PlannerContext,
) -> Result<LogicalPlan> {
let right = self.create_relation(join.relation, planner_context)?;
let right = if is_lateral_join(&join)? {
self.create_relation_subquery(join.relation, planner_context)?
} else {
self.create_relation(join.relation, planner_context)?
};
match join.join_operator {
JoinOperator::LeftOuter(constraint) => {
self.parse_join(left, right, constraint, JoinType::Left, planner_context)
Expand Down Expand Up @@ -144,3 +155,33 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
}
}

/// Reports whether the given [`TableFactor`] carries the `LATERAL` flag.
///
/// Only the `Derived` and `Function` variants are inspected; every other
/// variant is treated as non-lateral.
pub(crate) fn is_lateral(factor: &TableFactor) -> bool {
    matches!(
        factor,
        TableFactor::Derived { lateral: true, .. }
            | TableFactor::Function { lateral: true, .. }
    )
}

/// Decides whether the right-hand side of `join` has to be planned as a
/// lateral relation.
///
/// A join counts as lateral when its relation is flagged `LATERAL` (see
/// [`is_lateral`]) or when the operator is `CROSS APPLY` / `OUTER APPLY`.
///
/// # Errors
///
/// Returns a "not implemented" error when a `LATERAL` relation is combined
/// with a FULL OUTER or RIGHT (OUTER / ANTI / SEMI) join operator.
pub(crate) fn is_lateral_join(join: &Join) -> Result<bool> {
    let lateral_relation = is_lateral(&join.relation);

    // Guard clause: these operators are rejected only for LATERAL relations.
    let unsupported_operator = matches!(
        join.join_operator,
        JoinOperator::FullOuter(..)
            | JoinOperator::RightOuter(..)
            | JoinOperator::RightAnti(..)
            | JoinOperator::RightSemi(..)
    );
    if lateral_relation && unsupported_operator {
        return not_impl_err!(
            "LATERAL syntax is not supported for \
            FULL OUTER and RIGHT [OUTER | ANTI | SEMI] joins"
        );
    }

    let apply_operator = matches!(
        join.join_operator,
        JoinOperator::CrossApply | JoinOperator::OuterApply
    );
    Ok(lateral_relation || apply_operator)
}
51 changes: 51 additions & 0 deletions datafusion/sql/src/relation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use crate::planner::{ContextProvider, PlannerContext, SqlToRel};

use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{not_impl_err, plan_err, DFSchema, Result, TableReference};
use datafusion_expr::builder::subquery_alias;
use datafusion_expr::{expr::Unnest, Expr, LogicalPlan, LogicalPlanBuilder};
use datafusion_expr::{Subquery, SubqueryAlias};
use sqlparser::ast::{FunctionArg, FunctionArgExpr, TableFactor};

mod join;
Expand Down Expand Up @@ -153,6 +157,53 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
Ok(optimized_plan)
}
}

/// Plans a lateral table factor and wraps the result in a
/// [`LogicalPlan::Subquery`].
///
/// While `subquery` is being planned, the outer FROM schema is taken out of
/// `planner_context` and used to build the outer query schema, so that the
/// lateral relation can refer to columns of the FROM items planned before it.
/// Both schemas are put back before this method returns, including when
/// planning the relation fails.
///
/// If the planned relation is a `SubqueryAlias`, the `Subquery` node is placed
/// underneath the alias so the alias stays the outermost node.
///
/// # Errors
///
/// Propagates any error from planning the relation or from rebuilding the
/// subquery alias.
pub(crate) fn create_relation_subquery(
    &self,
    subquery: TableFactor,
    planner_context: &mut PlannerContext,
) -> Result<LogicalPlan> {
    // Lateral table factors normally follow at least one other FROM item, in
    // which case the loops in `plan_from_tables` / `plan_table_with_joins`
    // have already populated the outer FROM schema. If it is not set we fall
    // back to an empty schema instead of panicking.
    let old_from_schema = planner_context
        .set_outer_from_schema(None)
        .unwrap_or_else(|| Arc::new(DFSchema::empty()));

    // The schema visible to the lateral relation is the FROM schema with the
    // previous outer query schema (if any) merged into it.
    let new_query_schema = match planner_context.outer_query_schema() {
        Some(old_query_schema) => {
            let mut new_query_schema = old_from_schema.as_ref().clone();
            new_query_schema.merge(old_query_schema);
            Some(Arc::new(new_query_schema))
        }
        None => Some(Arc::clone(&old_from_schema)),
    };
    let old_query_schema = planner_context.set_outer_query_schema(new_query_schema);

    // Hold on to the result so the context is restored even when planning
    // fails; the error is propagated only after the restore below.
    let plan = self.create_relation(subquery, planner_context);

    planner_context.set_outer_query_schema(old_query_schema);
    planner_context.set_outer_from_schema(Some(old_from_schema));

    let plan = plan?;
    let outer_ref_columns = plan.all_out_ref_exprs();

    match plan {
        // Keep the alias on the outside and insert the subquery below it.
        LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => {
            subquery_alias(
                LogicalPlan::Subquery(Subquery {
                    subquery: input,
                    outer_ref_columns,
                }),
                alias,
            )
        }
        plan => Ok(LogicalPlan::Subquery(Subquery {
            subquery: Arc::new(plan),
            outer_ref_columns,
        })),
    }
}
}

fn optimize_subquery_sort(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
Expand Down
31 changes: 21 additions & 10 deletions datafusion/sql/src/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -496,19 +496,30 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
match from.len() {
0 => Ok(LogicalPlanBuilder::empty(true).build()?),
1 => {
let from = from.remove(0);
self.plan_table_with_joins(from, planner_context)
let input = from.remove(0);
self.plan_table_with_joins(input, planner_context)
}
_ => {
let mut plans = from
.into_iter()
.map(|t| self.plan_table_with_joins(t, planner_context));

let mut left = LogicalPlanBuilder::from(plans.next().unwrap()?);

for right in plans {
left = left.cross_join(right?)?;
let mut from = from.into_iter();

let mut left = LogicalPlanBuilder::from({
let input = from.next().unwrap();
self.plan_table_with_joins(input, planner_context)?
});
let old_outer_from_schema = {
let left_schema = Some(Arc::clone(left.schema()));
planner_context.set_outer_from_schema(left_schema)
};
for input in from {
// Join `input` with the current result (`left`).
let right = self.plan_table_with_joins(input, planner_context)?;
left = left.cross_join(right)?;
// Update the outer FROM schema.
let left_schema = Some(Arc::clone(left.schema()));
planner_context.set_outer_from_schema(left_schema);
}
planner_context.set_outer_from_schema(old_outer_from_schema);

Ok(left.build()?)
}
}
Expand Down
108 changes: 108 additions & 0 deletions datafusion/sql/tests/sql_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3103,6 +3103,114 @@ fn join_on_complex_condition() {
quick_test(sql, expected);
}

// Comma join of `j1` with a LATERAL subquery that selects only a constant and
// therefore has no outer references. The expected plan text places a
// `Subquery` node under the `SubqueryAlias` for `j2`.
#[test]
fn lateral_constant() {
let sql = "SELECT * FROM j1, LATERAL (SELECT 1) AS j2";
let expected = "Projection: *\
\n CrossJoin:\
\n TableScan: j1\
\n SubqueryAlias: j2\
\n Subquery:\
\n Projection: Int64(1)\
\n EmptyRelation";
quick_test(sql, expected);
}

// Comma join where the LATERAL subquery's filter uses `j1_id`, a column of the
// FROM item that precedes it. The expected plan text shows that column as
// `outer_ref(j1.j1_id)`.
#[test]
fn lateral_comma_join() {
let sql = "SELECT j1_string, j2_string FROM
j1, \
LATERAL (SELECT * FROM j2 WHERE j1_id < j2_id) AS j2";
let expected = "Projection: j1.j1_string, j2.j2_string\
\n CrossJoin:\
\n TableScan: j1\
\n SubqueryAlias: j2\
\n Subquery:\
\n Projection: *\
\n Filter: outer_ref(j1.j1_id) < j2.j2_id\
\n TableScan: j2";
quick_test(sql, expected);
}

// The first FROM item is itself a join whose right side is a nested
// `j2 JOIN j3`. The LATERAL subquery uses `j2_string`, which comes from that
// nested right side; the expected plan text shows it as
// `outer_ref(j2.j2_string)`.
#[test]
fn lateral_comma_join_referencing_join_rhs() {
let sql = "SELECT * FROM\
\n j1 JOIN (j2 JOIN j3 ON(j2_id = j3_id - 2)) ON(j1_id = j2_id),\
\n LATERAL (SELECT * FROM j3 WHERE j3_string = j2_string) as j4;";
let expected = "Projection: *\
\n CrossJoin:\
\n Inner Join: Filter: j1.j1_id = j2.j2_id\
\n TableScan: j1\
\n Inner Join: Filter: j2.j2_id = j3.j3_id - Int64(2)\
\n TableScan: j2\
\n TableScan: j3\
\n SubqueryAlias: j4\
\n Subquery:\
\n Projection: *\
\n Filter: j3.j3_string = outer_ref(j2.j2_string)\
\n TableScan: j3";
quick_test(sql, expected);
}

// Two nested LATERAL subqueries where `j1` appears both in the outermost FROM
// list and inside the first LATERAL subquery. The expected plan text contains
// a `Subquery` node at each nesting level and an `outer_ref(j1.j1_id)` in the
// innermost filter.
#[test]
fn lateral_comma_join_with_shadowing() {
// The j1_id on line 3 references the (closest) j1 definition from line 2.
let sql = "\
SELECT * FROM j1, LATERAL (\
SELECT * FROM j1, LATERAL (\
SELECT * FROM j2 WHERE j1_id = j2_id\
) as j2\
) as j2;";
let expected = "Projection: *\
\n CrossJoin:\
\n TableScan: j1\
\n SubqueryAlias: j2\
\n Subquery:\
\n Projection: *\
\n CrossJoin:\
\n TableScan: j1\
\n SubqueryAlias: j2\
\n Subquery:\
\n Projection: *\
\n Filter: outer_ref(j1.j1_id) = j2.j2_id\
\n TableScan: j2";
quick_test(sql, expected);
}

// `LEFT JOIN LATERAL ... ON(true)` where the subquery uses `j1_id` from the
// left side of the join. The expected plan text is a `Left Join` filtered on
// `Boolean(true)` whose right side is the aliased `Subquery`, with the column
// shown as `outer_ref(j1.j1_id)`.
#[test]
fn lateral_left_join() {
let sql = "SELECT j1_string, j2_string FROM \
j1 \
LEFT JOIN LATERAL (SELECT * FROM j2 WHERE j1_id < j2_id) AS j2 ON(true);";
let expected = "Projection: j1.j1_string, j2.j2_string\
\n Left Join: Filter: Boolean(true)\
\n TableScan: j1\
\n SubqueryAlias: j2\
\n Subquery:\
\n Projection: *\
\n Filter: outer_ref(j1.j1_id) < j2.j2_id\
\n TableScan: j2";
quick_test(sql, expected);
}

// The LATERAL subquery sits inside a parenthesized `LEFT JOIN` that is the
// second item of a comma join. Its filter uses `j1_id` (from the earlier comma
// item) and `j2_id` (from the left side of the enclosing join); the expected
// plan text shows both as `outer_ref(...)`.
#[test]
fn lateral_nested_left_join() {
let sql = "SELECT * FROM
j1, \
(j2 LEFT JOIN LATERAL (SELECT * FROM j3 WHERE j1_id + j2_id = j3_id) AS j3 ON(true))";
let expected = "Projection: *\
\n CrossJoin:\
\n TableScan: j1\
\n Left Join: Filter: Boolean(true)\
\n TableScan: j2\
\n SubqueryAlias: j3\
\n Subquery:\
\n Projection: *\
\n Filter: outer_ref(j1.j1_id) + outer_ref(j2.j2_id) = j3.j3_id\
\n TableScan: j3";
quick_test(sql, expected);
}

#[test]
fn hive_aggregate_with_filter() -> Result<()> {
let dialect = &HiveDialect {};
Expand Down
Loading