From 67d0c2e38011cd883059e3a9fd0ea08088661707 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 23 Jan 2021 06:56:47 -0500 Subject: [PATCH] ARROW-11319: [Rust] [DataFusion] Improve test comparisons to record batch, remove test::format_batch The `test::format_batch` function does not have a wide range of type support (e.g. it doesn't support dictionaries) and its output makes tests hard to read / update, in my opinion. This PR consolidates the datafusion tests to use `arrow::util::pretty::pretty_format_batches`, both to reduce code duplication and to increase type support. This PR removes the `test::format_batch(&batch);` function and replaces it with `arrow::util::pretty::pretty_format_batches` and some macros. It makes no changes to non-test code. This change has the following benefits: 1. Better type support (I can immediately compare RecordBatches with `Dictionary` types in tests without having to update `format_batch` and https://github.com/apache/arrow/pull/9233 gets simpler) 2. Better readability and error reporting (at least I find the code and diffs easier to understand) 3. Easier test update / review: it is easier to update the diffs (you can copy/paste the test output into the source code) and to review them This is a variant of a strategy that I have been using with success in IOx [source link](https://github.com/influxdata/influxdb_iox/blob/main/arrow_deps/src/test_util.rs#L15) and I wanted to contribute it back.
An example failure with this PR: ``` ---- physical_plan::hash_join::tests::join_left_one stdout ---- thread 'physical_plan::hash_join::tests::join_left_one' panicked at 'assertion failed: `(left == right)` left: `["+----+----+----+----+", "| a1 | b2 | c1 | c2 |", "+----+----+----+----+", "| 1 | 1 | 7 | 70 |", "| 2 | 2 | 8 | 80 |", "| 2 | 2 | 9 | 80 |", "+----+----+----+----+"]`, right: `["+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | c2 |", "+----+----+----+----+----+", "| 1 | 4 | 7 | 10 | 70 |", "| 2 | 5 | 8 | 20 | 80 |", "| 3 | 7 | 9 | | |", "+----+----+----+----+----+"]`: expected: [ "+----+----+----+----+", "| a1 | b2 | c1 | c2 |", "+----+----+----+----+", "| 1 | 1 | 7 | 70 |", "| 2 | 2 | 8 | 80 |", "| 2 | 2 | 9 | 80 |", "+----+----+----+----+", ] actual: [ "+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | c2 |", "+----+----+----+----+----+", "| 1 | 4 | 7 | 10 | 70 |", "| 2 | 5 | 8 | 20 | 80 |", "| 3 | 7 | 9 | | |", "+----+----+----+----+----+", ] ``` You can copy/paste the output of `actual` directly into the test code for an update. 
Closes #9264 from alamb/remove_test_format_batch Authored-by: Andrew Lamb Signed-off-by: Andrew Lamb --- rust/datafusion/src/execution/context.rs | 500 +++++++++++------- .../src/physical_plan/hash_aggregate.rs | 37 +- .../datafusion/src/physical_plan/hash_join.rs | 122 +++-- rust/datafusion/src/test/mod.rs | 236 +++------ 4 files changed, 475 insertions(+), 420 deletions(-) diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 5a036935ec73a..5600c55521c15 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -595,16 +595,19 @@ impl FunctionRegistry for ExecutionContextState { mod tests { use super::*; - use crate::logical_plan::{col, create_udf, sum}; use crate::physical_plan::functions::ScalarFunctionImplementation; use crate::physical_plan::{collect, collect_partitioned}; use crate::test; use crate::variable::VarType; + use crate::{ + assert_batches_eq, assert_batches_sorted_eq, + logical_plan::{col, create_udf, sum}, + }; use crate::{ datasource::MemTable, logical_plan::create_udaf, physical_plan::expressions::AvgAccumulator, }; - use arrow::array::{ArrayRef, Float64Array, Int32Array, StringArray}; + use arrow::array::{ArrayRef, Float64Array, Int32Array}; use arrow::compute::add; use arrow::datatypes::*; use arrow::record_batch::RecordBatch; @@ -626,10 +629,56 @@ mod tests { for batch in &results { assert_eq!(batch.num_columns(), 2); assert_eq!(batch.num_rows(), 10); - - assert_eq!(field_names(batch), vec!["c1", "c2"]); } + let expected = vec![ + "+----+----+", + "| c1 | c2 |", + "+----+----+", + "| 3 | 1 |", + "| 3 | 2 |", + "| 3 | 3 |", + "| 3 | 4 |", + "| 3 | 5 |", + "| 3 | 6 |", + "| 3 | 7 |", + "| 3 | 8 |", + "| 3 | 9 |", + "| 3 | 10 |", + "| 2 | 1 |", + "| 2 | 2 |", + "| 2 | 3 |", + "| 2 | 4 |", + "| 2 | 5 |", + "| 2 | 6 |", + "| 2 | 7 |", + "| 2 | 8 |", + "| 2 | 9 |", + "| 2 | 10 |", + "| 1 | 1 |", + "| 1 | 2 |", + "| 1 | 3 |", + "| 1 | 4 |", + "| 1 | 5 |", 
+ "| 1 | 6 |", + "| 1 | 7 |", + "| 1 | 8 |", + "| 1 | 9 |", + "| 1 | 10 |", + "| 0 | 1 |", + "| 0 | 2 |", + "| 0 | 3 |", + "| 0 | 4 |", + "| 0 | 5 |", + "| 0 | 6 |", + "| 0 | 7 |", + "| 0 | 8 |", + "| 0 | 9 |", + "| 0 | 10 |", + "+----+----+", + ]; + assert_batches_sorted_eq!(expected, &results); + Ok(()) } @@ -650,24 +699,14 @@ mod tests { let results = plan_and_collect(&mut ctx, "SELECT @@version, @name FROM dual").await?; - let batch = &results[0]; - assert_eq!(2, batch.num_columns()); - assert_eq!(1, batch.num_rows()); - assert_eq!(field_names(batch), vec!["@@version", "@name"]); - - let version = batch - .column(0) - .as_any() - .downcast_ref::() - .expect("failed to cast version"); - assert_eq!(version.value(0), "system-var-@@version"); - - let name = batch - .column(1) - .as_any() - .downcast_ref::() - .expect("failed to cast name"); - assert_eq!(name.value(0), "user-defined-var-@name"); + let expected = vec![ + "+----------------------+------------------------+", + "| @@version | @name |", + "+----------------------+------------------------+", + "| system-var-@@version | user-defined-var-@name |", + "+----------------------+------------------------+", + ]; + assert_batches_eq!(expected, &results); Ok(()) } @@ -703,6 +742,35 @@ mod tests { assert_eq!(2, num_batches); assert_eq!(20, num_rows); + let results: Vec = results.into_iter().flatten().collect(); + let expected = vec![ + "+----+----+", + "| c1 | c2 |", + "+----+----+", + "| 1 | 1 |", + "| 1 | 10 |", + "| 1 | 2 |", + "| 1 | 3 |", + "| 1 | 4 |", + "| 1 | 5 |", + "| 1 | 6 |", + "| 1 | 7 |", + "| 1 | 8 |", + "| 1 | 9 |", + "| 2 | 1 |", + "| 2 | 10 |", + "| 2 | 2 |", + "| 2 | 3 |", + "| 2 | 4 |", + "| 2 | 5 |", + "| 2 | 6 |", + "| 2 | 7 |", + "| 2 | 8 |", + "| 2 | 9 |", + "+----+----+", + ]; + assert_batches_sorted_eq!(expected, &results); + Ok(()) } @@ -841,14 +909,56 @@ mod tests { execute("SELECT c1, c2 FROM test ORDER BY c1 DESC, c2 ASC", 4).await?; assert_eq!(results.len(), 1); - let batch = 
&results[0]; let expected: Vec<&str> = vec![ - "3,1", "3,2", "3,3", "3,4", "3,5", "3,6", "3,7", "3,8", "3,9", "3,10", "2,1", - "2,2", "2,3", "2,4", "2,5", "2,6", "2,7", "2,8", "2,9", "2,10", "1,1", "1,2", - "1,3", "1,4", "1,5", "1,6", "1,7", "1,8", "1,9", "1,10", "0,1", "0,2", "0,3", - "0,4", "0,5", "0,6", "0,7", "0,8", "0,9", "0,10", + "+----+----+", + "| c1 | c2 |", + "+----+----+", + "| 3 | 1 |", + "| 3 | 2 |", + "| 3 | 3 |", + "| 3 | 4 |", + "| 3 | 5 |", + "| 3 | 6 |", + "| 3 | 7 |", + "| 3 | 8 |", + "| 3 | 9 |", + "| 3 | 10 |", + "| 2 | 1 |", + "| 2 | 2 |", + "| 2 | 3 |", + "| 2 | 4 |", + "| 2 | 5 |", + "| 2 | 6 |", + "| 2 | 7 |", + "| 2 | 8 |", + "| 2 | 9 |", + "| 2 | 10 |", + "| 1 | 1 |", + "| 1 | 2 |", + "| 1 | 3 |", + "| 1 | 4 |", + "| 1 | 5 |", + "| 1 | 6 |", + "| 1 | 7 |", + "| 1 | 8 |", + "| 1 | 9 |", + "| 1 | 10 |", + "| 0 | 1 |", + "| 0 | 2 |", + "| 0 | 3 |", + "| 0 | 4 |", + "| 0 | 5 |", + "| 0 | 6 |", + "| 0 | 7 |", + "| 0 | 8 |", + "| 0 | 9 |", + "| 0 | 10 |", + "+----+----+", ]; - assert_eq!(test::format_batch(batch), expected); + + // Note it is important to NOT use assert_batches_sorted_eq + // here as we are testing the sortedness of the output + assert_batches_eq!(expected, &results); Ok(()) } @@ -871,14 +981,14 @@ mod tests { let results = execute("SELECT SUM(c1), SUM(c2) FROM test", 4).await?; assert_eq!(results.len(), 1); - let batch = &results[0]; - - assert_eq!(field_names(batch), vec!["SUM(c1)", "SUM(c2)"]); - - let expected: Vec<&str> = vec!["60,220"]; - let mut rows = test::format_batch(&batch); - rows.sort(); - assert_eq!(rows, expected); + let expected = vec![ + "+---------+---------+", + "| SUM(c1) | SUM(c2) |", + "+---------+---------+", + "| 60 | 220 |", + "+---------+---------+", + ]; + assert_batches_sorted_eq!(expected, &results); Ok(()) } @@ -891,12 +1001,15 @@ mod tests { .unwrap(); assert_eq!(results.len(), 1); - let batch = &results[0]; - let expected: Vec<&str> = vec!["NULL,NULL"]; - let mut rows = 
test::format_batch(&batch); - rows.sort(); - assert_eq!(rows, expected); + let expected = vec![ + "+---------+---------+", + "| SUM(c1) | SUM(c2) |", + "+---------+---------+", + "| | |", + "+---------+---------+", + ]; + assert_batches_sorted_eq!(expected, &results); Ok(()) } @@ -906,14 +1019,14 @@ mod tests { let results = execute("SELECT AVG(c1), AVG(c2) FROM test", 4).await?; assert_eq!(results.len(), 1); - let batch = &results[0]; - - assert_eq!(field_names(batch), vec!["AVG(c1)", "AVG(c2)"]); - - let expected: Vec<&str> = vec!["1.5,5.5"]; - let mut rows = test::format_batch(&batch); - rows.sort(); - assert_eq!(rows, expected); + let expected = vec![ + "+---------+---------+", + "| AVG(c1) | AVG(c2) |", + "+---------+---------+", + "| 1.5 | 5.5 |", + "+---------+---------+", + ]; + assert_batches_sorted_eq!(expected, &results); Ok(()) } @@ -923,14 +1036,14 @@ mod tests { let results = execute("SELECT MAX(c1), MAX(c2) FROM test", 4).await?; assert_eq!(results.len(), 1); - let batch = &results[0]; - - assert_eq!(field_names(batch), vec!["MAX(c1)", "MAX(c2)"]); - - let expected: Vec<&str> = vec!["3,10"]; - let mut rows = test::format_batch(&batch); - rows.sort(); - assert_eq!(rows, expected); + let expected = vec![ + "+---------+---------+", + "| MAX(c1) | MAX(c2) |", + "+---------+---------+", + "| 3 | 10 |", + "+---------+---------+", + ]; + assert_batches_sorted_eq!(expected, &results); Ok(()) } @@ -940,14 +1053,14 @@ mod tests { let results = execute("SELECT MIN(c1), MIN(c2) FROM test", 4).await?; assert_eq!(results.len(), 1); - let batch = &results[0]; - - assert_eq!(field_names(batch), vec!["MIN(c1)", "MIN(c2)"]); - - let expected: Vec<&str> = vec!["0,1"]; - let mut rows = test::format_batch(&batch); - rows.sort(); - assert_eq!(rows, expected); + let expected = vec![ + "+---------+---------+", + "| MIN(c1) | MIN(c2) |", + "+---------+---------+", + "| 0 | 1 |", + "+---------+---------+", + ]; + assert_batches_sorted_eq!(expected, &results); Ok(()) } @@ 
-957,14 +1070,17 @@ mod tests { let results = execute("SELECT c1, SUM(c2) FROM test GROUP BY c1", 4).await?; assert_eq!(results.len(), 1); - let batch = &results[0]; - - assert_eq!(field_names(batch), vec!["c1", "SUM(c2)"]); - - let expected: Vec<&str> = vec!["0,55", "1,55", "2,55", "3,55"]; - let mut rows = test::format_batch(&batch); - rows.sort(); - assert_eq!(rows, expected); + let expected = vec![ + "+----+---------+", + "| c1 | SUM(c2) |", + "+----+---------+", + "| 0 | 55 |", + "| 1 | 55 |", + "| 2 | 55 |", + "| 3 | 55 |", + "+----+---------+", + ]; + assert_batches_sorted_eq!(expected, &results); Ok(()) } @@ -974,14 +1090,17 @@ mod tests { let results = execute("SELECT c1, AVG(c2) FROM test GROUP BY c1", 4).await?; assert_eq!(results.len(), 1); - let batch = &results[0]; - - assert_eq!(field_names(batch), vec!["c1", "AVG(c2)"]); - - let expected: Vec<&str> = vec!["0,5.5", "1,5.5", "2,5.5", "3,5.5"]; - let mut rows = test::format_batch(&batch); - rows.sort(); - assert_eq!(rows, expected); + let expected = vec![ + "+----+---------+", + "| c1 | AVG(c2) |", + "+----+---------+", + "| 0 | 5.5 |", + "| 1 | 5.5 |", + "| 2 | 5.5 |", + "| 3 | 5.5 |", + "+----+---------+", + ]; + assert_batches_sorted_eq!(expected, &results); Ok(()) } @@ -992,14 +1111,8 @@ mod tests { execute("SELECT c1, AVG(c2) FROM test WHERE c1 = 123 GROUP BY c1", 4).await?; assert_eq!(results.len(), 1); - let batch = &results[0]; - - assert_eq!(field_names(batch), vec!["c1", "AVG(c2)"]); - - let expected: Vec<&str> = vec![]; - let mut rows = test::format_batch(&batch); - rows.sort(); - assert_eq!(rows, expected); + let expected = vec!["++", "||", "++", "++"]; + assert_batches_sorted_eq!(expected, &results); Ok(()) } @@ -1009,14 +1122,17 @@ mod tests { let results = execute("SELECT c1, MAX(c2) FROM test GROUP BY c1", 4).await?; assert_eq!(results.len(), 1); - let batch = &results[0]; - - assert_eq!(field_names(batch), vec!["c1", "MAX(c2)"]); - - let expected: Vec<&str> = vec!["0,10", "1,10", 
"2,10", "3,10"]; - let mut rows = test::format_batch(&batch); - rows.sort(); - assert_eq!(rows, expected); + let expected = vec![ + "+----+---------+", + "| c1 | MAX(c2) |", + "+----+---------+", + "| 0 | 10 |", + "| 1 | 10 |", + "| 2 | 10 |", + "| 3 | 10 |", + "+----+---------+", + ]; + assert_batches_sorted_eq!(expected, &results); Ok(()) } @@ -1026,14 +1142,17 @@ mod tests { let results = execute("SELECT c1, MIN(c2) FROM test GROUP BY c1", 4).await?; assert_eq!(results.len(), 1); - let batch = &results[0]; - - assert_eq!(field_names(batch), vec!["c1", "MIN(c2)"]); - - let expected: Vec<&str> = vec!["0,1", "1,1", "2,1", "3,1"]; - let mut rows = test::format_batch(&batch); - rows.sort(); - assert_eq!(rows, expected); + let expected = vec![ + "+----+---------+", + "| c1 | MIN(c2) |", + "+----+---------+", + "| 0 | 1 |", + "| 1 | 1 |", + "| 2 | 1 |", + "| 3 | 1 |", + "+----+---------+", + ]; + assert_batches_sorted_eq!(expected, &results); Ok(()) } @@ -1043,14 +1162,14 @@ mod tests { let results = execute("SELECT COUNT(c1), COUNT(c2) FROM test", 1).await?; assert_eq!(results.len(), 1); - let batch = &results[0]; - - assert_eq!(field_names(batch), vec!["COUNT(c1)", "COUNT(c2)"]); - - let expected: Vec<&str> = vec!["10,10"]; - let mut rows = test::format_batch(&batch); - rows.sort(); - assert_eq!(rows, expected); + let expected = vec![ + "+-----------+-----------+", + "| COUNT(c1) | COUNT(c2) |", + "+-----------+-----------+", + "| 10 | 10 |", + "+-----------+-----------+", + ]; + assert_batches_sorted_eq!(expected, &results); Ok(()) } @@ -1059,14 +1178,14 @@ mod tests { let results = execute("SELECT COUNT(c1), COUNT(c2) FROM test", 4).await?; assert_eq!(results.len(), 1); - let batch = &results[0]; - - assert_eq!(field_names(batch), vec!["COUNT(c1)", "COUNT(c2)"]); - - let expected: Vec<&str> = vec!["40,40"]; - let mut rows = test::format_batch(&batch); - rows.sort(); - assert_eq!(rows, expected); + let expected = vec![ + "+-----------+-----------+", + "| COUNT(c1) | 
COUNT(c2) |", + "+-----------+-----------+", + "| 40 | 40 |", + "+-----------+-----------+", + ]; + assert_batches_sorted_eq!(expected, &results); Ok(()) } @@ -1075,14 +1194,17 @@ mod tests { let results = execute("SELECT c1, COUNT(c2) FROM test GROUP BY c1", 4).await?; assert_eq!(results.len(), 1); - let batch = &results[0]; - - assert_eq!(field_names(batch), vec!["c1", "COUNT(c2)"]); - - let expected = vec!["0,10", "1,10", "2,10", "3,10"]; - let mut rows = test::format_batch(&batch); - rows.sort(); - assert_eq!(rows, expected); + let expected = vec![ + "+----+-----------+", + "| c1 | COUNT(c2) |", + "+----+-----------+", + "| 0 | 10 |", + "| 1 | 10 |", + "| 2 | 10 |", + "| 3 | 10 |", + "+----+-----------+", + ]; + assert_batches_sorted_eq!(expected, &results); Ok(()) } @@ -1124,15 +1246,15 @@ mod tests { ).await?; assert_eq!(results.len(), 1); - let batch = &results[0]; - - assert_eq!(field_names(batch), vec!["week", "SUM(c2)"]); - - let expected: Vec<&str> = - vec!["2020-12-07T00:00:00,24", "2020-12-14T00:00:00,156"]; - let mut rows = test::format_batch(&batch); - rows.sort(); - assert_eq!(rows, expected); + let expected = vec![ + "+---------------------+---------+", + "| week | SUM(c2) |", + "+---------------------+---------+", + "| 2020-12-07 00:00:00 | 24 |", + "| 2020-12-14 00:00:00 | 156 |", + "+---------------------+---------+", + ]; + assert_batches_sorted_eq!(expected, &results); Ok(()) } @@ -1221,20 +1343,17 @@ mod tests { let results = run_count_distinct_integers_aggregated_scenario(partitions).await?; assert_eq!(results.len(), 1); - let batch = &results[0]; - assert_eq!(batch.num_rows(), 3); - assert_eq!(batch.num_columns(), 10); - let mut result = test::format_batch(&batch); - result.sort_unstable(); - - assert_eq!( - result, - vec![ - "a,3,2,2,2,2,2,2,2,2", - "b,1,1,1,1,1,1,1,1,1", - "c,3,2,2,2,2,2,2,2,2", - ], - ); + let expected = vec! 
+[ + "+---------+-----------------+------------------------+-------------------------+-------------------------+-------------------------+-------------------------+--------------------------+--------------------------+--------------------------+", + "| c_group | COUNT(c_uint64) | COUNT(DISTINCT c_int8) | COUNT(DISTINCT c_int16) | COUNT(DISTINCT c_int32) | COUNT(DISTINCT c_int64) | COUNT(DISTINCT c_uint8) | COUNT(DISTINCT c_uint16) | COUNT(DISTINCT c_uint32) | COUNT(DISTINCT c_uint64) |", + "+---------+-----------------+------------------------+-------------------------+-------------------------+-------------------------+-------------------------+--------------------------+--------------------------+--------------------------+", + "| a | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 2 | 2 |", + "| b | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |", + "| c | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 2 | 2 |", + "+---------+-----------------+------------------------+-------------------------+-------------------------+-------------------------+-------------------------+--------------------------+--------------------------+--------------------------+", +]; + assert_batches_sorted_eq!(expected, &results); Ok(()) } @@ -1252,19 +1371,16 @@ mod tests { let results = run_count_distinct_integers_aggregated_scenario(partitions).await?; assert_eq!(results.len(), 1); - let batch = &results[0]; - assert_eq!(batch.num_rows(), 3); - assert_eq!(batch.num_columns(), 10); - let mut result = test::format_batch(&batch); - result.sort_unstable(); - assert_eq!( - result, - vec![ - "a,5,3,3,3,3,3,3,3,3", - "b,5,4,4,4,4,4,4,4,4", - "c,1,1,1,1,1,1,1,1,1", - ], - ); + let expected = vec![ + "+---------+-----------------+------------------------+-------------------------+-------------------------+-------------------------+-------------------------+--------------------------+--------------------------+--------------------------+", + "| c_group | COUNT(c_uint64) | COUNT(DISTINCT c_int8) | COUNT(DISTINCT c_int16) | COUNT(DISTINCT 
c_int32) | COUNT(DISTINCT c_int64) | COUNT(DISTINCT c_uint8) | COUNT(DISTINCT c_uint16) | COUNT(DISTINCT c_uint32) | COUNT(DISTINCT c_uint64) |", + "+---------+-----------------+------------------------+-------------------------+-------------------------+-------------------------+-------------------------+--------------------------+--------------------------+--------------------------+", + "| a | 5 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 |", + "| b | 5 | 4 | 4 | 4 | 4 | 4 | 4 | 4 | 4 |", + "| c | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |", + "+---------+-----------------+------------------------+-------------------------+-------------------------+-------------------------+-------------------------+--------------------------+--------------------------+--------------------------+", +]; + assert_batches_sorted_eq!(expected, &results); Ok(()) } @@ -1404,8 +1520,14 @@ mod tests { .await?; assert_eq!(results.len(), 1); - assert_eq!(results[0].num_rows(), 1); - assert_eq!(test::format_batch(&results[0]), vec!["10,110,20"]); + let expected = vec![ + "+---------+---------+-----------------+", + "| SUM(c1) | SUM(c2) | COUNT(UInt8(1)) |", + "+---------+---------+-----------------+", + "| 10 | 110 | 20 |", + "+---------+---------+-----------------+", + ]; + assert_batches_eq!(expected, &results); Ok(()) } @@ -1498,11 +1620,19 @@ mod tests { let plan = ctx.create_physical_plan(&plan)?; let result = collect(plan).await?; - let batch = &result[0]; - assert_eq!(3, batch.num_columns()); - assert_eq!(4, batch.num_rows()); - assert_eq!(field_names(batch), vec!["a", "b", "my_add(a,b)"]); + let expected = vec![ + "+-----+-----+-------------+", + "| a | b | my_add(a,b) |", + "+-----+-----+-------------+", + "| 1 | 2 | 3 |", + "| 10 | 12 | 22 |", + "| 10 | 12 | 22 |", + "| 100 | 120 | 220 |", + "+-----+-----+-------------+", + ]; + assert_batches_eq!(expected, &result); + let batch = &result[0]; let a = batch .column(0) .as_any() @@ -1598,18 +1728,15 @@ mod tests { let result = plan_and_collect(&mut 
ctx, "SELECT MY_AVG(a) FROM t").await?; - let batch = &result[0]; - assert_eq!(1, batch.num_columns()); - assert_eq!(1, batch.num_rows()); + let expected = vec![ + "+-----------+", + "| MY_AVG(a) |", + "+-----------+", + "| 3 |", + "+-----------+", + ]; + assert_batches_eq!(expected, &result); - let values = batch - .column(0) - .as_any() - .downcast_ref::() - .expect("failed to cast version"); - assert_eq!(values.len(), 1); - // avg(1,2,3,4,5) = 3.0 - assert_eq!(values.value(0), 3.0_f64); Ok(()) } @@ -1662,15 +1789,6 @@ mod tests { collect(physical_plan).await } - fn field_names(result: &RecordBatch) -> Vec { - result - .schema() - .fields() - .iter() - .map(|x| x.name().clone()) - .collect::>() - } - /// Execute SQL and return results async fn execute(sql: &str, partition_count: usize) -> Result> { let tmp_dir = TempDir::new()?; diff --git a/rust/datafusion/src/physical_plan/hash_aggregate.rs b/rust/datafusion/src/physical_plan/hash_aggregate.rs index 022b5ecea1033..ee23e31d0ba02 100644 --- a/rust/datafusion/src/physical_plan/hash_aggregate.rs +++ b/rust/datafusion/src/physical_plan/hash_aggregate.rs @@ -965,8 +965,8 @@ mod tests { use arrow::array::Float64Array; use super::*; - use crate::physical_plan::common; use crate::physical_plan::expressions::{col, Avg}; + use crate::{assert_batches_sorted_eq, physical_plan::common}; use crate::physical_plan::merge::MergeExec; @@ -1022,9 +1022,16 @@ mod tests { let result = common::collect(partial_aggregate.execute(0).await?).await?; - let mut rows = crate::test::format_batch(&result[0]); - rows.sort(); - assert_eq!(rows, vec!["2,2,2.0", "3,3,7.0", "4,3,11.0"]); + let expected = vec![ + "+---+---------------+-------------+", + "| a | AVG(b)[count] | AVG(b)[sum] |", + "+---+---------------+-------------+", + "| 2 | 2 | 2 |", + "| 3 | 3 | 7 |", + "| 4 | 3 | 11 |", + "+---+---------------+-------------+", + ]; + assert_batches_sorted_eq!(expected, &result); let merge = Arc::new(MergeExec::new(partial_aggregate)); @@ -1049,17 
+1056,17 @@ mod tests { assert_eq!(batch.num_columns(), 2); assert_eq!(batch.num_rows(), 3); - let mut rows = crate::test::format_batch(&batch); - rows.sort(); - - assert_eq!( - rows, - vec![ - "2,1.0", - "3,2.3333333333333335", // 3, (2 + 3 + 2) / 3 - "4,3.6666666666666665" // 4, (3 + 4 + 4) / 3 - ] - ); + let expected = vec![ + "+---+--------------------+", + "| a | AVG(b) |", + "+---+--------------------+", + "| 2 | 1 |", + "| 3 | 2.3333333333333335 |", // 3, (2 + 3 + 2) / 3 + "| 4 | 3.6666666666666665 |", // 4, (3 + 4 + 4) / 3 + "+---+--------------------+", + ]; + + assert_batches_sorted_eq!(&expected, &result); Ok(()) } diff --git a/rust/datafusion/src/physical_plan/hash_join.rs b/rust/datafusion/src/physical_plan/hash_join.rs index 874b9b2ad6840..08ae684f61ff7 100644 --- a/rust/datafusion/src/physical_plan/hash_join.rs +++ b/rust/datafusion/src/physical_plan/hash_join.rs @@ -838,12 +838,12 @@ impl Stream for HashJoinStream { #[cfg(test)] mod tests { use crate::{ + assert_batches_sorted_eq, physical_plan::{common, memory::MemoryExec}, - test::{build_table_i32, columns, format_batch}, + test::{build_table_i32, columns}, }; use super::*; - use std::collections::HashSet; use std::sync::Arc; fn build_table( @@ -869,19 +869,6 @@ mod tests { HashJoinExec::try_new(left, right, &on, join_type) } - /// Asserts that the rows are the same, taking into account that their order - /// is irrelevant - fn assert_same_rows(result: &[String], expected: &[&str]) { - // convert to set since row order is irrelevant - let result = result.iter().cloned().collect::>(); - - let expected = expected - .iter() - .map(|s| s.to_string()) - .collect::>(); - assert_eq!(result, expected); - } - #[tokio::test] async fn join_inner_one() -> Result<()> { let left = build_table( @@ -904,10 +891,16 @@ mod tests { let stream = join.execute(0).await?; let batches = common::collect(stream).await?; - let result = format_batch(&batches[0]); - let expected = vec!["2,5,8,20,80", "3,5,9,20,80", 
"1,4,7,10,70"]; - - assert_same_rows(&result, &expected); + let expected = vec![ + "+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | c2 |", + "+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 70 |", + "| 2 | 5 | 8 | 20 | 80 |", + "| 3 | 5 | 9 | 20 | 80 |", + "+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); Ok(()) } @@ -934,10 +927,17 @@ mod tests { let stream = join.execute(0).await?; let batches = common::collect(stream).await?; - let result = format_batch(&batches[0]); - let expected = vec!["2,5,8,20,5,80", "3,5,9,20,5,80", "1,4,7,10,4,70"]; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 5 | 9 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", + ]; - assert_same_rows(&result, &expected); + assert_batches_sorted_eq!(expected, &batches); Ok(()) } @@ -965,10 +965,17 @@ mod tests { let batches = common::collect(stream).await?; assert_eq!(batches.len(), 1); - let result = format_batch(&batches[0]); - let expected = vec!["1,1,7,70", "2,2,8,80", "2,2,9,80"]; + let expected = vec![ + "+----+----+----+----+", + "| a1 | b2 | c1 | c2 |", + "+----+----+----+----+", + "| 1 | 1 | 7 | 70 |", + "| 2 | 2 | 8 | 80 |", + "| 2 | 2 | 9 | 80 |", + "+----+----+----+----+", + ]; - assert_same_rows(&result, &expected); + assert_batches_sorted_eq!(expected, &batches); Ok(()) } @@ -1004,10 +1011,17 @@ mod tests { let batches = common::collect(stream).await?; assert_eq!(batches.len(), 1); - let result = format_batch(&batches[0]); - let expected = vec!["1,1,7,70", "2,2,8,80", "2,2,9,80"]; + let expected = vec![ + "+----+----+----+----+", + "| a1 | b2 | c1 | c2 |", + "+----+----+----+----+", + "| 1 | 1 | 7 | 70 |", + "| 2 | 2 | 8 | 80 |", + "| 2 | 2 | 9 | 80 |", + "+----+----+----+----+", + ]; - assert_same_rows(&result, &expected); + assert_batches_sorted_eq!(expected, &batches); Ok(()) 
} @@ -1045,18 +1059,29 @@ mod tests { let batches = common::collect(stream).await?; assert_eq!(batches.len(), 1); - let result = format_batch(&batches[0]); - let expected = vec!["1,4,7,10,70"]; - assert_same_rows(&result, &expected); + let expected = vec![ + "+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | c2 |", + "+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 70 |", + "+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); // second part let stream = join.execute(1).await?; let batches = common::collect(stream).await?; assert_eq!(batches.len(), 1); - let result = format_batch(&batches[0]); - let expected = vec!["2,5,8,30,90", "3,5,9,30,90"]; + let expected = vec![ + "+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | c2 |", + "+----+----+----+----+----+", + "| 2 | 5 | 8 | 30 | 90 |", + "| 3 | 5 | 9 | 30 | 90 |", + "+----+----+----+----+----+", + ]; - assert_same_rows(&result, &expected); + assert_batches_sorted_eq!(expected, &batches); Ok(()) } @@ -1083,10 +1108,16 @@ mod tests { let stream = join.execute(0).await?; let batches = common::collect(stream).await?; - let result = format_batch(&batches[0]); - let expected = vec!["1,4,7,10,70", "2,5,8,20,80", "3,7,9,NULL,NULL"]; - - assert_same_rows(&result, &expected); + let expected = vec![ + "+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | c2 |", + "+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 70 |", + "| 2 | 5 | 8 | 20 | 80 |", + "| 3 | 7 | 9 | | |", + "+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); Ok(()) } @@ -1113,10 +1144,17 @@ mod tests { let stream = join.execute(0).await?; let batches = common::collect(stream).await?; - let result = format_batch(&batches[0]); - let expected = vec!["1,7,10,4,70", "2,8,20,5,80", "NULL,NULL,30,6,90"]; - - assert_same_rows(&result, &expected); + let expected = vec![ + "+----+----+----+----+----+", + "| a1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+", + "| | | 30 | 6 | 90 |", + "| 
1 | 7 | 10 | 4 | 70 |", + "| 2 | 8 | 20 | 5 | 80 |", + "+----+----+----+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); Ok(()) } diff --git a/rust/datafusion/src/test/mod.rs b/rust/datafusion/src/test/mod.rs index e589834be5cb9..7628e9f57e75a 100644 --- a/rust/datafusion/src/test/mod.rs +++ b/rust/datafusion/src/test/mod.rs @@ -21,7 +21,7 @@ use crate::datasource::{MemTable, TableProvider}; use crate::error::Result; use crate::logical_plan::{LogicalPlan, LogicalPlanBuilder}; use arrow::array::{self, Int32Array}; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use std::fs::File; use std::io::prelude::*; @@ -106,137 +106,7 @@ pub fn aggr_test_schema() -> SchemaRef { ])) } -/// Format a batch as csv -pub fn format_batch(batch: &RecordBatch) -> Vec { - let mut rows = vec![]; - for row_index in 0..batch.num_rows() { - let mut s = String::new(); - for column_index in 0..batch.num_columns() { - if column_index > 0 { - s.push(','); - } - let array = batch.column(column_index); - - if array.is_null(row_index) { - s.push_str("NULL"); - continue; - } - - match array.data_type() { - DataType::Utf8 => s.push_str( - array - .as_any() - .downcast_ref::() - .unwrap() - .value(row_index), - ), - DataType::Int8 => s.push_str(&format!( - "{:?}", - array - .as_any() - .downcast_ref::() - .unwrap() - .value(row_index) - )), - DataType::Int16 => s.push_str(&format!( - "{:?}", - array - .as_any() - .downcast_ref::() - .unwrap() - .value(row_index) - )), - DataType::Int32 => s.push_str(&format!( - "{:?}", - array - .as_any() - .downcast_ref::() - .unwrap() - .value(row_index) - )), - DataType::Int64 => s.push_str(&format!( - "{:?}", - array - .as_any() - .downcast_ref::() - .unwrap() - .value(row_index) - )), - DataType::UInt8 => s.push_str(&format!( - "{:?}", - array - .as_any() - .downcast_ref::() - .unwrap() - .value(row_index) - )), - 
DataType::UInt16 => s.push_str(&format!( - "{:?}", - array - .as_any() - .downcast_ref::() - .unwrap() - .value(row_index) - )), - DataType::UInt32 => s.push_str(&format!( - "{:?}", - array - .as_any() - .downcast_ref::() - .unwrap() - .value(row_index) - )), - DataType::UInt64 => s.push_str(&format!( - "{:?}", - array - .as_any() - .downcast_ref::() - .unwrap() - .value(row_index) - )), - DataType::Float32 => s.push_str(&format!( - "{:?}", - array - .as_any() - .downcast_ref::() - .unwrap() - .value(row_index) - )), - DataType::Float64 => s.push_str(&format!( - "{:?}", - array - .as_any() - .downcast_ref::() - .unwrap() - .value(row_index) - )), - DataType::Timestamp(TimeUnit::Microsecond, _) => s.push_str(&format!( - "{:?}", - array - .as_any() - .downcast_ref::() - .unwrap() - .value_as_datetime(row_index) - .unwrap() - )), - DataType::Timestamp(TimeUnit::Nanosecond, _) => s.push_str(&format!( - "{:?}", - array - .as_any() - .downcast_ref::() - .unwrap() - .value_as_datetime(row_index) - .unwrap() - )), - _ => s.push('?'), - } - } - rows.push(s); - } - rows -} - -/// all tests share a common table +/// some tests share a common table pub fn test_table_scan() -> Result { let schema = Schema::new(vec![ Field::new("a", DataType::UInt32, false), @@ -287,49 +157,71 @@ pub fn columns(schema: &Schema) -> Vec { pub mod user_defined; pub mod variable; -mod tests { - use super::*; - - use arrow::array::{BooleanArray, Int32Array, StringArray}; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow::record_batch::RecordBatch; - - #[test] - fn test_format_batch() -> Result<()> { - let array_int32 = Int32Array::from(vec![1000, 2000]); - let array_string = StringArray::from(vec!["bow \u{1F3F9}", "arrow \u{2191}"]); - - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Utf8, false), - ]); - - let record_batch = RecordBatch::try_new( - Arc::new(schema), - vec![Arc::new(array_int32), Arc::new(array_string)], - )?; - - 
let result = format_batch(&record_batch); - - assert_eq!(result, vec!["1000,bow \u{1F3F9}", "2000,arrow \u{2191}"]); - - Ok(()) - } - - #[test] - fn test_format_batch_unknown() -> Result<()> { - // Use any Array type not yet handled by format_batch(). - let array_bool = BooleanArray::from(vec![false, true]); +/// Compares the pretty-formatted output of record batches, line by +/// line and in order, with an expected vector of strings and panics +/// on any mismatch. This is a macro so errors appear on the correct line +/// +/// Designed so that failure output can be directly copy/pasted +/// into the test code as expected results. +/// +/// Expects to be called about like this: +/// +/// `assert_batches_eq!(expected_lines: &[&str], batches: &[RecordBatch])` +#[macro_export] +macro_rules! assert_batches_eq { + ($EXPECTED_LINES: expr, $CHUNKS: expr) => { + let expected_lines: Vec<String> = + $EXPECTED_LINES.iter().map(|&s| s.into()).collect(); + + let formatted = arrow::util::pretty::pretty_format_batches($CHUNKS).unwrap(); + + let actual_lines: Vec<&str> = formatted.trim().lines().collect(); + + assert_eq!( + expected_lines, actual_lines, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected_lines, actual_lines + ); + }; +} - let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); +/// Compares the pretty-formatted output of record batches with an expected +/// vector of strings in a way that the order of rows does not matter. +/// This is a macro so errors appear on the correct line +/// +/// Designed so that failure output can be directly copy/pasted +/// into the test code as expected results. +/// +/// Expects to be called about like this: +/// +/// `assert_batches_sorted_eq!(expected_lines: &[&str], batches: &[RecordBatch])` +#[macro_export] +macro_rules!
assert_batches_sorted_eq { + ($EXPECTED_LINES: expr, $CHUNKS: expr) => { + let mut expected_lines: Vec<String> = + $EXPECTED_LINES.iter().map(|&s| s.into()).collect(); + + // sort except for header + footer + let num_lines = expected_lines.len(); + if num_lines > 3 { + expected_lines.as_mut_slice()[2..num_lines - 1].sort_unstable() + } - let record_batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array_bool)])?; + let formatted = arrow::util::pretty::pretty_format_batches($CHUNKS).unwrap(); + // no Windows fix-up needed: `lines()` splits on both \n and \r\n - let result = format_batch(&record_batch); + let mut actual_lines: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!(result, vec!["?", "?"]); + // sort except for header + footer + let num_lines = actual_lines.len(); + if num_lines > 3 { + actual_lines.as_mut_slice()[2..num_lines - 1].sort_unstable() + } - Ok(()) - } + assert_eq!( + expected_lines, actual_lines, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected_lines, actual_lines + ); + }; }