From 35e34d4d7581c150acb2016477a6115e0de9987c Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Wed, 18 Jan 2023 02:53:25 -0500 Subject: [PATCH] Fix column indices in EnforceDistribution optimizer in Partial AggregateMode (#4878) (#4959) --- .../physical_optimizer/dist_enforcement.rs | 25 +++++--- datafusion/core/tests/sql/joins.rs | 58 +++++++++++++++++++ 2 files changed, 76 insertions(+), 7 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/dist_enforcement.rs b/datafusion/core/src/physical_optimizer/dist_enforcement.rs index aa8b07569cd5..cc94ad14a01f 100644 --- a/datafusion/core/src/physical_optimizer/dist_enforcement.rs +++ b/datafusion/core/src/physical_optimizer/dist_enforcement.rs @@ -431,11 +431,22 @@ fn reorder_aggregate_keys( None }; if let Some(partial_agg) = new_partial_agg { - let mut new_group_exprs = vec![]; - for idx in positions.into_iter() { - new_group_exprs.push(group_by.expr()[idx].clone()); - } - let new_group_by = PhysicalGroupBy::new_single(new_group_exprs); + // Build new group expressions that correspond to the output of partial_agg + let new_final_group: Vec<Arc<dyn PhysicalExpr>> = + partial_agg.output_group_expr(); + let new_group_by = PhysicalGroupBy::new_single( + new_final_group + .iter() + .enumerate() + .map(|(i, expr)| { + ( + expr.clone(), + partial_agg.group_expr().expr()[i].1.clone(), + ) + }) + .collect(), + ); + let new_final_agg = Arc::new(AggregateExec::try_new( AggregateMode::FinalPartitioned, new_group_by, @@ -1494,7 +1505,7 @@ mod tests { let expected = &[ "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"b1\", index: 1 }, Column { name: \"b\", index: 0 }), (Column { name: \"a1\", index: 0 }, Column { name: \"a\", index: 1 })]", "ProjectionExec: expr=[a1@1 as a1, b1@0 as b1]", - "AggregateExec: mode=FinalPartitioned, gby=[b1@1 as b1, a1@0 as a1], aggr=[]", + "AggregateExec: mode=FinalPartitioned, gby=[b1@0 as b1, a1@1 as a1], aggr=[]", "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 0 }, 
Column { name: \"a1\", index: 1 }], 10)", "AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[]", "ParquetExec: limit=None, partitions={1 group: [[x]]}, projection=[a, b, c, d, e]", @@ -2057,7 +2068,7 @@ mod tests { "SortExec: [b3@1 ASC,a3@0 ASC]", "ProjectionExec: expr=[a1@0 as a3, b1@1 as b3]", "ProjectionExec: expr=[a1@1 as a1, b1@0 as b1]", - "AggregateExec: mode=FinalPartitioned, gby=[b1@1 as b1, a1@0 as a1], aggr=[]", + "AggregateExec: mode=FinalPartitioned, gby=[b1@0 as b1, a1@1 as a1], aggr=[]", "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 0 }, Column { name: \"a1\", index: 1 }], 10)", "AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[]", "ParquetExec: limit=None, partitions={1 group: [[x]]}, projection=[a, b, c, d, e]", diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 1de20c29cd07..db5c706d3353 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -2810,3 +2810,61 @@ async fn type_coercion_join_with_filter_and_equi_expr() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_cross_join_to_groupby_with_different_key_ordering() -> Result<()> { + // Regression test for GH #4873 + let col1 = Arc::new(StringArray::from(vec![ + "A", "A", "A", "A", "A", "A", "A", "A", "BB", "BB", "BB", "BB", + ])) as ArrayRef; + + let col2 = + Arc::new(UInt64Array::from(vec![1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6])) as ArrayRef; + + let col3 = + Arc::new(UInt64Array::from(vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])) as ArrayRef; + + let schema = Arc::new(Schema::new(vec![ + Field::new("col1", DataType::Utf8, true), + Field::new("col2", DataType::UInt64, true), + Field::new("col3", DataType::UInt64, true), + ])) as SchemaRef; + + let batch = RecordBatch::try_new(schema.clone(), vec![col1, col2, col3]).unwrap(); + let mem_table = MemTable::try_new(schema, vec![vec![batch]]).unwrap(); + + // Create context and register table + let ctx = SessionContext::new(); 
+ ctx.register_table("tbl", Arc::new(mem_table)).unwrap(); + + let sql = "select col1, col2, coalesce(sum_col3, 0) as sum_col3 \ + from (select distinct col2 from tbl) AS q1 \ + cross join (select distinct col1 from tbl) AS q2 \ + left outer join (SELECT col1, col2, sum(col3) as sum_col3 FROM tbl GROUP BY col1, col2) AS q3 \ + USING(col2, col1) \ + ORDER BY col1, col2"; + + let expected = vec![ + "+------+------+----------+", + "| col1 | col2 | sum_col3 |", + "+------+------+----------+", + "| A    | 1    | 2        |", + "| A    | 2    | 2        |", + "| A    | 3    | 2        |", + "| A    | 4    | 2        |", + "| A    | 5    | 0        |", + "| A    | 6    | 0        |", + "| BB   | 1    | 0        |", + "| BB   | 2    | 0        |", + "| BB   | 3    | 0        |", + "| BB   | 4    | 0        |", + "| BB   | 5    | 2        |", + "| BB   | 6    | 2        |", + "+------+------+----------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +}