diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 3a2f9a7b73c8..192201541d3a 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -674,9 +674,9 @@ dependencies = [ [[package]] name = "base64" -version = "0.21.2" +version = "0.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" +checksum = "414dcefbc63d77c526a76b3afcf6fbb9b5e2791c19c3aa2297733208750c6e53" [[package]] name = "base64-simd" @@ -1948,9 +1948,9 @@ checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4" [[package]] name = "libmimalloc-sys" -version = "0.1.33" +version = "0.1.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4ac0e912c8ef1b735e92369695618dc5b1819f5a7bf3f167301a3ba1cea515e" +checksum = "25d058a81af0d1c22d7a1c948576bee6d673f7af3c0f35564abd6c81122f513d" dependencies = [ "cc", "libc", @@ -2026,9 +2026,9 @@ checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" [[package]] name = "mimalloc" -version = "0.1.37" +version = "0.1.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e2894987a3459f3ffb755608bd82188f8ed00d0ae077f1edea29c068d639d98" +checksum = "972e5f23f6716f62665760b0f4cbf592576a80c7b879ba9beaafc0e558894127" dependencies = [ "libmimalloc-sys", ] @@ -2070,14 +2070,13 @@ dependencies = [ [[package]] name = "nix" -version = "0.26.2" +version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfdda3d196821d6af13126e40375cdf7da646a96114af134d5f417a9a1dc8e1a" +checksum = "abbbc55ad7b13aac85f9401c796dcda1b864e07fcad40ad47792eaa8932ea502" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.4.0", "cfg-if", "libc", - "static_assertions", ] [[package]] @@ -2393,9 +2392,9 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.2.12" +version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12cc1b0bf1727a77a54b6654e7b5f1af8604923edc8b81885f8ec92f9e3f0a05" +checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" [[package]] name = "pin-utils" @@ -2575,9 +2574,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.9.3" +version = "1.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81bc1d4caf89fac26a70747fe603c130093b53c773888797a6329091246d651a" +checksum = "12de2eff854e5fa4b1295edd650e227e9d8fb0c9e90b12e7f36d6a6811791a29" dependencies = [ "aho-corasick", "memchr", @@ -2587,9 +2586,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fed1ceff11a1dddaee50c9dc8e4938bd106e9d89ae372f192311e7da498e3b69" +checksum = "49530408a136e16e5b486e883fbb6ba058e8e4e8ae6621a77b048b314336e629" dependencies = [ "aho-corasick", "memchr", @@ -2598,15 +2597,15 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.7.4" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5ea92a5b6195c6ef2a0295ea818b312502c6fc94dde986c5553242e18fd4ce2" +checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" [[package]] name = "reqwest" -version = "0.11.19" +version = "0.11.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20b9b67e2ca7dd9e9f9285b759de30ff538aab981abaaf7bc9bd90b84a0126c3" +checksum = "3e9ad3fe7488d7e34558a2033d45a0c90b72d97b4f80705666fea71472e2e6a1" dependencies = [ "base64", "bytes", @@ -2701,9 +2700,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.8" +version = "0.38.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19ed4fa021d81c8392ce04db050a3da9a60299050b7ae1cf482d862b54a7218f" +checksum = "9bfe0f2582b4931a45d1fa608f8a8722e8b3c7ac54dd6d5f3b3212791fedef49" dependencies = [ "bitflags 2.4.0", "errno", @@ -2873,18 +2872,18 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.185" +version = "1.0.188" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be9b6f69f1dfd54c3b568ffa45c310d6973a5e5148fd40cf515acaf38cf5bc31" +checksum = "cf9e0fcba69a370eed61bcf2b728575f726b50b55cba78064753d708ddc7549e" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.185" +version = "1.0.188" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc59dfdcbad1437773485e0367fea4b090a2e0a16d9ffc46af47764536a298ec" +checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", @@ -3175,9 +3174,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.27" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bb39ee79a6d8de55f48f2293a830e040392f1c5f16e336bdd1788cd0aadce07" +checksum = "17f6bb557fd245c28e6411aa56b6403c689ad95061f50e4be16c274e70a17e48" dependencies = [ "deranged", "serde", @@ -3193,9 +3192,9 @@ checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb" [[package]] name = "time-macros" -version = "0.2.13" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "733d258752e9303d392b94b75230d07b0b9c489350c69b851fc6c065fde3e8f9" +checksum = "1a942f44339478ef67935ab2bbaec2fb0322496cf3cbe84b261e06ac3814c572" dependencies = [ "time-core", ] @@ -3423,9 +3422,9 @@ checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" [[package]] name = "url" -version = "2.4.0" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50bff7831e19200a85b17131d085c25d7811bc4e186efdaf54bbd132994a88cb" +checksum = "143b538f18257fac9cad154828a57c6bf5157e1aa604d4816b5995bf6de87ae5" dependencies = [ "form_urlencoded", "idna", diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 5a32349c65e9..70e25ed8c7fb 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -380,6 +380,10 @@ config_namespace! { /// repartitioning to increase parallelism to leverage more CPU cores pub enable_round_robin_repartition: bool, default = true + /// When set to true, the optimizer will attempt to perform limit operations + /// during aggregations, if possible + pub enable_topk_aggregation: bool, default = true + /// When set to true, the optimizer will insert filters before a join between /// a nullable and non-nullable column to filter out nulls on the nullable side. This /// filter can add additional overhead when the file format does not fully support diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index bfa00b2a9fb3..80d8e5cab197 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -83,7 +83,7 @@ parking_lot = "0.12" parquet = { workspace = true } percent-encoding = "2.2.0" pin-project-lite = "^0.2.7" -rand = "0.8" +rand = { version = "0.8", features = ["small_rng"] } sqlparser = { workspace = true } tempfile = "3" tokio = { version = "1.28", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } @@ -105,6 +105,8 @@ env_logger = "0.10" half = "2.2.1" postgres-protocol = "0.6.4" postgres-types = { version = "0.2.4", features = ["derive", "with-chrono-0_4"] } +ptree = "0.4.0" +rand_distr = "0.4.3" regex = "1.5.4" rstest = "0.18.0" rust_decimal = { version = "1.27.0", features = ["tokio-pg"] } @@ -157,3 +159,7 @@ name = "sql_query_with_io" [[bench]] harness = false name = "sort" + +[[bench]] +harness = false +name = "topk_aggregate" diff --git a/datafusion/core/benches/topk_aggregate.rs b/datafusion/core/benches/topk_aggregate.rs new file mode 100644 index 000000000000..f50a8ec047da --- /dev/null +++ b/datafusion/core/benches/topk_aggregate.rs @@ -0,0 +1,239 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::util::pretty::pretty_format_batches; +use arrow::{datatypes::Schema, record_batch::RecordBatch}; +use arrow_array::builder::{Int64Builder, StringBuilder}; +use arrow_schema::{DataType, Field, SchemaRef}; +use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion::physical_plan::{collect, displayable, ExecutionPlan}; +use datafusion::prelude::SessionContext; +use datafusion::{datasource::MemTable, error::Result}; +use datafusion_common::DataFusionError; +use datafusion_execution::config::SessionConfig; +use datafusion_execution::TaskContext; +use rand_distr::Distribution; +use rand_distr::{Normal, Pareto}; +use std::sync::Arc; +use tokio::runtime::Runtime; + +async fn create_context( + limit: usize, + partition_cnt: i32, + sample_cnt: i32, + asc: bool, + use_topk: bool, +) -> Result<(Arc, Arc)> { + let (schema, parts) = make_data(partition_cnt, sample_cnt, asc).unwrap(); + let mem_table = Arc::new(MemTable::try_new(schema, parts).unwrap()); + + // Create the DataFrame + let mut cfg = SessionConfig::new(); + let opts = cfg.options_mut(); + opts.optimizer.enable_topk_aggregation = use_topk; + let ctx = SessionContext::with_config(cfg); + let _ = ctx.register_table("traces", mem_table)?; + let sql = format!("select trace_id, max(timestamp_ms) from traces group by trace_id order by max(timestamp_ms) desc limit {limit};"); + let df = ctx.sql(sql.as_str()).await?; + let physical_plan = df.create_physical_plan().await?; + let actual_phys_plan = displayable(physical_plan.as_ref()).indent(true).to_string(); + assert_eq!( + actual_phys_plan.contains(&format!("lim=[{limit}]")), + use_topk + ); + + Ok((physical_plan, ctx.task_ctx())) +} + +fn run(plan: Arc, ctx: Arc, asc: bool) { + let rt = Runtime::new().unwrap(); + criterion::black_box( + rt.block_on(async { aggregate(plan.clone(), ctx.clone(), asc).await }), + ) + .unwrap(); +} + +async fn aggregate( + plan: Arc, + ctx: Arc, + asc: bool, +) -> Result<()> { + let batches = collect(plan, ctx).await?; + assert_eq!(batches.len(), 1); + let batch = batches.first().unwrap(); + assert_eq!(batch.num_rows(), 10); + + let actual = format!("{}", pretty_format_batches(&batches)?); + let expected_asc = r#" ++----------------------------------+--------------------------+ +| trace_id | MAX(traces.timestamp_ms) | ++----------------------------------+--------------------------+ +| 5868861a23ed31355efc5200eb80fe74 | 16909009999999 | +| 4040e64656804c3d77320d7a0e7eb1f0 | 16909009999998 | +| 02801bbe533190a9f8713d75222f445d | 16909009999997 | +| 9e31b3b5a620de32b68fefa5aeea57f1 | 16909009999996 | +| 2d88a860e9bd1cfaa632d8e7caeaa934 | 16909009999995 | +| a47edcef8364ab6f191dd9103e51c171 | 16909009999994 | +| 36a3fa2ccfbf8e00337f0b1254384db6 | 16909009999993 | +| 0756be84f57369012e10de18b57d8a2f | 16909009999992 | +| d4d6bf9845fa5897710e3a8db81d5907 | 16909009999991 | +| 3c2cc1abe728a66b61e14880b53482a0 | 16909009999990 | ++----------------------------------+--------------------------+ + "# + .trim(); + if asc { + assert_eq!(actual.trim(), expected_asc); + } + + Ok(()) +} + +fn make_data( + partition_cnt: i32, + sample_cnt: i32, + asc: bool, +) -> Result<(Arc, Vec>), DataFusionError> { + use rand::Rng; + use rand::SeedableRng; + + // constants observed from trace data + let simultaneous_group_cnt = 2000; + let fitted_shape = 12f64; + let fitted_scale = 5f64; + let mean = 0.1; + let stddev = 1.1; + let pareto = Pareto::new(fitted_scale, fitted_shape).unwrap(); + let normal = Normal::new(mean, stddev).unwrap(); + let mut rng = rand::rngs::SmallRng::from_seed([0; 32]); + + // populate data + let schema = test_schema(); + let mut partitions = vec![]; + let mut cur_time = 16909000000000i64; + for _ in 0..partition_cnt { + let mut id_builder = StringBuilder::new(); + let mut ts_builder = Int64Builder::new(); + let gen_id = |rng: &mut rand::rngs::SmallRng| { + rng.gen::<[u8; 16]>() + .iter() + .map(|b| format!("{:02x}", b)) + .collect::() + }; + let gen_sample_cnt = + |mut rng: &mut rand::rngs::SmallRng| pareto.sample(&mut rng).ceil() as u32; + let mut group_ids = (0..simultaneous_group_cnt) + .map(|_| gen_id(&mut rng)) + .collect::>(); + let mut group_sample_cnts = (0..simultaneous_group_cnt) + .map(|_| gen_sample_cnt(&mut rng)) + .collect::>(); + for _ in 0..sample_cnt { + let random_index = rng.gen_range(0..simultaneous_group_cnt); + let trace_id = &mut group_ids[random_index]; + let sample_cnt = &mut group_sample_cnts[random_index]; + *sample_cnt -= 1; + if *sample_cnt == 0 { + *trace_id = gen_id(&mut rng); + *sample_cnt = gen_sample_cnt(&mut rng); + } + + id_builder.append_value(trace_id); + ts_builder.append_value(cur_time); + + if asc { + cur_time += 1; + } else { + let samp: f64 = normal.sample(&mut rng); + let samp = samp.round(); + cur_time += samp as i64; + } + } + + // convert to MemTable + let id_col = Arc::new(id_builder.finish()); + let ts_col = Arc::new(ts_builder.finish()); + let batch = RecordBatch::try_new(schema.clone(), vec![id_col, ts_col])?; + partitions.push(vec![batch]); + } + Ok((schema, partitions)) +} + +fn test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("trace_id", DataType::Utf8, false), + Field::new("timestamp_ms", DataType::Int64, false), + ])) +} + +fn criterion_benchmark(c: &mut Criterion) { + let limit = 10; + let partitions = 10; + let samples = 1_000_000; + + let rt = Runtime::new().unwrap(); + let topk_real = rt.block_on(async { + create_context(limit, partitions, samples, false, true) + .await + .unwrap() + }); + let topk_asc = rt.block_on(async { + create_context(limit, partitions, samples, true, true) + .await + .unwrap() + }); + let real = rt.block_on(async { + create_context(limit, partitions, samples, false, false) + .await + .unwrap() + }); + let asc = rt.block_on(async { + create_context(limit, partitions, samples, true, false) + .await + .unwrap() + }); + + c.bench_function( + format!("aggregate {} time-series rows", partitions * samples).as_str(), + |b| b.iter(|| run(real.0.clone(), real.1.clone(), false)), + ); + + c.bench_function( + format!("aggregate {} worst-case rows", partitions * samples).as_str(), + |b| b.iter(|| run(asc.0.clone(), asc.1.clone(), true)), + ); + + c.bench_function( + format!( + "top k={limit} aggregate {} time-series rows", + partitions * samples + ) + .as_str(), + |b| b.iter(|| run(topk_real.0.clone(), topk_real.1.clone(), false)), + ); + + c.bench_function( + format!( + "top k={limit} aggregate {} worst-case rows", + partitions * samples + ) + .as_str(), + |b| b.iter(|| run(topk_asc.0.clone(), topk_asc.1.clone(), true)), + ); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/core/src/physical_optimizer/mod.rs b/datafusion/core/src/physical_optimizer/mod.rs index f74d4ea0c9a6..ad950e26bbfa 100644 --- a/datafusion/core/src/physical_optimizer/mod.rs +++ b/datafusion/core/src/physical_optimizer/mod.rs @@ -33,6 +33,7 @@ pub mod repartition; pub mod replace_with_order_preserving_variants; pub mod sort_enforcement; mod sort_pushdown; +pub mod topk_aggregation; mod utils; #[cfg(test)] diff --git a/datafusion/core/src/physical_optimizer/optimizer.rs b/datafusion/core/src/physical_optimizer/optimizer.rs index 3f6698c6cf46..d629eb0c8e0a 100644 --- a/datafusion/core/src/physical_optimizer/optimizer.rs +++ b/datafusion/core/src/physical_optimizer/optimizer.rs @@ -28,6 +28,7 @@ use crate::physical_optimizer::join_selection::JoinSelection; use crate::physical_optimizer::pipeline_checker::PipelineChecker; use crate::physical_optimizer::repartition::Repartition; use crate::physical_optimizer::sort_enforcement::EnforceSorting; +use crate::physical_optimizer::topk_aggregation::TopKAggregation; use crate::{error::Result, physical_plan::ExecutionPlan}; /// `PhysicalOptimizerRule` transforms one ['ExecutionPlan'] into another which @@ -101,6 +102,11 @@ impl PhysicalOptimizer { // diagnostic error message when this happens. It makes no changes to the // given query plan; i.e. it only acts as a final gatekeeping rule. Arc::new(PipelineChecker::new()), + // The aggregation limiter will try to find situations where the accumulator count + // is not tied to the cardinality, i.e. when the output of the aggregation is passed + // into an `order by max(x) limit y`. In this case it will copy the limit value down + // to the aggregation, allowing it to use only y number of accumulators. + Arc::new(TopKAggregation::new()), ]; Self::with_rules(rules) diff --git a/datafusion/core/src/physical_optimizer/topk_aggregation.rs b/datafusion/core/src/physical_optimizer/topk_aggregation.rs new file mode 100644 index 000000000000..f862675bf205 --- /dev/null +++ b/datafusion/core/src/physical_optimizer/topk_aggregation.rs @@ -0,0 +1,176 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! An optimizer rule that detects aggregate operations that could use a limited bucket count + +use crate::physical_optimizer::PhysicalOptimizerRule; +use crate::physical_plan::aggregates::AggregateExec; +use crate::physical_plan::coalesce_batches::CoalesceBatchesExec; +use crate::physical_plan::filter::FilterExec; +use crate::physical_plan::repartition::RepartitionExec; +use crate::physical_plan::sorts::sort::SortExec; +use crate::physical_plan::ExecutionPlan; +use arrow_schema::DataType; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::Result; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::PhysicalSortExpr; +use std::sync::Arc; + +/// An optimizer rule that passes a `limit` hint to aggregations if the whole result is not needed +pub struct TopKAggregation {} + +impl TopKAggregation { + /// Create a new `LimitAggregation` + pub fn new() -> Self { + Self {} + } + + fn transform_agg( + aggr: &AggregateExec, + order: &PhysicalSortExpr, + limit: usize, + ) -> Option> { + // ensure the sort direction matches aggregate function + let (field, desc) = aggr.get_minmax_desc()?; + if desc != order.options.descending { + return None; + } + let group_key = aggr.group_expr().expr().first()?; + let kt = group_key.0.data_type(&aggr.input().schema()).ok()?; + if !kt.is_primitive() && kt != DataType::Utf8 { + return None; + } + if aggr.filter_expr.iter().any(|e| e.is_some()) { + return None; + } + + // ensure the sort is on the same field as the aggregate output + let col = order.expr.as_any().downcast_ref::()?; + if col.name() != field.name() { + return None; + } + + // We found what we want: clone, copy the limit down, and return modified node + let mut new_aggr = AggregateExec::try_new( + aggr.mode, + aggr.group_by.clone(), + aggr.aggr_expr.clone(), + aggr.filter_expr.clone(), + aggr.order_by_expr.clone(), + aggr.input.clone(), + aggr.input_schema.clone(), + ) + .expect("Unable to copy Aggregate!"); + new_aggr.limit = Some(limit); + Some(Arc::new(new_aggr)) + } + + fn transform_sort(plan: Arc) -> Option> { + let sort = plan.as_any().downcast_ref::()?; + + let children = sort.children(); + let child = children.first()?; + let order = sort.output_ordering()?; + let order = order.first()?; + let limit = sort.fetch()?; + + let is_cardinality_preserving = |plan: Arc| { + plan.as_any() + .downcast_ref::() + .is_some() + || plan.as_any().downcast_ref::().is_some() + || plan.as_any().downcast_ref::().is_some() + }; + + let mut cardinality_preserved = true; + let mut closure = |plan: Arc| { + if !cardinality_preserved { + return Ok(Transformed::No(plan)); + } + if let Some(aggr) = plan.as_any().downcast_ref::() { + // either we run into an Aggregate and transform it + match Self::transform_agg(aggr, order, limit) { + None => cardinality_preserved = false, + Some(plan) => return Ok(Transformed::Yes(plan)), + } + } else { + // or we continue down whitelisted nodes of other types + if !is_cardinality_preserving(plan.clone()) { + cardinality_preserved = false; + } + } + Ok(Transformed::No(plan)) + }; + let child = transform_down_mut(child.clone(), &mut closure).ok()?; + let sort = SortExec::new(sort.expr().to_vec(), child) + .with_fetch(sort.fetch()) + .with_preserve_partitioning(sort.preserve_partitioning()); + Some(Arc::new(sort)) + } +} + +fn transform_down_mut( + me: Arc, + op: &mut F, +) -> Result> +where + F: FnMut(Arc) -> Result>>, +{ + let after_op = op(me)?.into(); + after_op.map_children(|node| transform_down_mut(node, op)) +} + +impl Default for TopKAggregation { + fn default() -> Self { + Self::new() + } +} + +impl PhysicalOptimizerRule for TopKAggregation { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + let plan = if config.optimizer.enable_topk_aggregation { + plan.transform_down(&|plan| { + Ok( + if let Some(plan) = TopKAggregation::transform_sort(plan.clone()) { + Transformed::Yes(plan) + } else { + Transformed::No(plan) + }, + ) + })? + } else { + plan + }; + Ok(plan) + } + + fn name(&self) -> &str { + "LimitAggregation" + } + + fn schema_check(&self) -> bool { + true + } +} + +// see `aggregate.slt` for tests diff --git a/datafusion/core/src/physical_plan/aggregates/group_values/mod.rs b/datafusion/core/src/physical_plan/aggregates/group_values/mod.rs index 46f372b6ad28..f10f83dfe3c8 100644 --- a/datafusion/core/src/physical_plan/aggregates/group_values/mod.rs +++ b/datafusion/core/src/physical_plan/aggregates/group_values/mod.rs @@ -20,7 +20,7 @@ use arrow_schema::SchemaRef; use datafusion_common::Result; use datafusion_physical_expr::EmitTo; -mod primitive; +pub(crate) mod primitive; use primitive::GroupValuesPrimitive; mod row; diff --git a/datafusion/core/src/physical_plan/aggregates/group_values/primitive.rs b/datafusion/core/src/physical_plan/aggregates/group_values/primitive.rs index 7b8691c67fdd..d7989fb8c4c5 100644 --- a/datafusion/core/src/physical_plan/aggregates/group_values/primitive.rs +++ b/datafusion/core/src/physical_plan/aggregates/group_values/primitive.rs @@ -31,20 +31,20 @@ use hashbrown::raw::RawTable; use std::sync::Arc; /// A trait to allow hashing of floating point numbers -trait HashValue { - fn hash(self, state: &RandomState) -> u64; +pub(crate) trait HashValue { + fn hash(&self, state: &RandomState) -> u64; } macro_rules! hash_integer { ($($t:ty),+) => { $(impl HashValue for $t { #[cfg(not(feature = "force_hash_collisions"))] - fn hash(self, state: &RandomState) -> u64 { + fn hash(&self, state: &RandomState) -> u64 { state.hash_one(self) } #[cfg(feature = "force_hash_collisions")] - fn hash(self, _state: &RandomState) -> u64 { + fn hash(&self, _state: &RandomState) -> u64 { 0 } })+ @@ -57,12 +57,12 @@ macro_rules! hash_float { ($($t:ty),+) => { $(impl HashValue for $t { #[cfg(not(feature = "force_hash_collisions"))] - fn hash(self, state: &RandomState) -> u64 { + fn hash(&self, state: &RandomState) -> u64 { state.hash_one(self.to_bits()) } #[cfg(feature = "force_hash_collisions")] - fn hash(self, _state: &RandomState) -> u64 { + fn hash(&self, _state: &RandomState) -> u64 { 0 } })+ diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs index 78ef5e37b239..14350ce1bba7 100644 --- a/datafusion/core/src/physical_plan/aggregates/mod.rs +++ b/datafusion/core/src/physical_plan/aggregates/mod.rs @@ -50,10 +50,14 @@ mod group_values; mod no_grouping; mod order; mod row_hash; +mod topk; +mod topk_stream; +use crate::physical_plan::aggregates::topk_stream::GroupedTopKAggregateStream; pub use datafusion_expr::AggregateFunction; use datafusion_physical_expr::aggregate::is_order_sensitive; pub use datafusion_physical_expr::expressions::create_aggregate_expr; +use datafusion_physical_expr::expressions::{Max, Min}; use datafusion_physical_expr::utils::{ get_finer_ordering, ordering_satisfy_requirement_concrete, }; @@ -228,14 +232,16 @@ impl PartialEq for PhysicalGroupBy { enum StreamType { AggregateStream(AggregateStream), - GroupedHashAggregateStream(GroupedHashAggregateStream), + GroupedHash(GroupedHashAggregateStream), + GroupedPriorityQueue(GroupedTopKAggregateStream), } impl From for SendableRecordBatchStream { fn from(stream: StreamType) -> Self { match stream { StreamType::AggregateStream(stream) => Box::pin(stream), - StreamType::GroupedHashAggregateStream(stream) => Box::pin(stream), + StreamType::GroupedHash(stream) => Box::pin(stream), + StreamType::GroupedPriorityQueue(stream) => Box::pin(stream), } } } @@ -265,6 +271,8 @@ pub struct AggregateExec { pub(crate) filter_expr: Vec>>, /// (ORDER BY clause) expression for each aggregate expression pub(crate) order_by_expr: Vec>, + /// Set if the output of this aggregation is truncated by a upstream sort/limit clause + pub(crate) limit: Option, /// Input plan, could be a partial aggregate or the input to the aggregate pub(crate) input: Arc, /// Schema after the aggregate is applied @@ -669,6 +677,7 @@ impl AggregateExec { metrics: ExecutionPlanMetricsSet::new(), aggregation_ordering, required_input_ordering, + limit: None, }) } @@ -717,14 +726,35 @@ impl AggregateExec { partition: usize, context: Arc, ) -> Result { + // no group by at all if self.group_by.expr.is_empty() { - Ok(StreamType::AggregateStream(AggregateStream::new( + return Ok(StreamType::AggregateStream(AggregateStream::new( self, context, partition, - )?)) + )?)); + } + + // grouping by an expression that has a sort/limit upstream + if let Some(limit) = self.limit { + return Ok(StreamType::GroupedPriorityQueue( + GroupedTopKAggregateStream::new(self, context, partition, limit)?, + )); + } + + // grouping by something else and we need to just materialize all results + Ok(StreamType::GroupedHash(GroupedHashAggregateStream::new( + self, context, partition, + )?)) + } + + /// Finds the DataType and SortDirection for this Aggregate, if there is one + pub fn get_minmax_desc(&self) -> Option<(Field, bool)> { + let agg_expr = self.aggr_expr.as_slice().first()?; + if let Some(max) = agg_expr.as_any().downcast_ref::() { + Some((max.field().ok()?, true)) + } else if let Some(min) = agg_expr.as_any().downcast_ref::() { + Some((min.field().ok()?, false)) } else { - Ok(StreamType::GroupedHashAggregateStream( - GroupedHashAggregateStream::new(self, context, partition)?, - )) + None } } } @@ -793,6 +823,9 @@ impl DisplayAs for AggregateExec { .map(|agg| agg.name().to_string()) .collect(); write!(f, ", aggr=[{}]", a.join(", "))?; + if let Some(limit) = self.limit { + write!(f, ", lim=[{limit}]")?; + } if let Some(aggregation_ordering) = &self.aggregation_ordering { write!(f, ", ordering_mode={:?}", aggregation_ordering.mode)?; @@ -900,7 +933,7 @@ impl ExecutionPlan for AggregateExec { self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(AggregateExec::try_new( + let mut me = AggregateExec::try_new( self.mode, self.group_by.clone(), self.aggr_expr.clone(), @@ -908,7 +941,9 @@ impl ExecutionPlan for AggregateExec { self.order_by_expr.clone(), children[0].clone(), self.input_schema.clone(), - )?)) + )?; + me.limit = self.limit; + Ok(Arc::new(me)) } fn execute( @@ -1115,7 +1150,7 @@ fn evaluate( } /// Evaluates expressions against a record batch. -fn evaluate_many( +pub(crate) fn evaluate_many( expr: &[Vec>], batch: &RecordBatch, ) -> Result>> { @@ -1138,7 +1173,17 @@ fn evaluate_optional( .collect::>>() } -fn evaluate_group_by( +/// Evaluate a group by expression against a `RecordBatch` +/// +/// Arguments: +/// `group_by`: the expression to evaluate +/// `batch`: the `RecordBatch` to evaluate against +/// +/// Returns: A Vec of Vecs of Array of results +/// The outer Vect appears to be for grouping sets +/// The inner Vect contains the results per expression +/// The inner-inner Array contains the results per row +pub(crate) fn evaluate_group_by( group_by: &PhysicalGroupBy, batch: &RecordBatch, ) -> Result>> { @@ -1798,10 +1843,10 @@ mod tests { assert!(matches!(stream, StreamType::AggregateStream(_))); } 1 => { - assert!(matches!(stream, StreamType::GroupedHashAggregateStream(_))); + assert!(matches!(stream, StreamType::GroupedHash(_))); } 2 => { - assert!(matches!(stream, StreamType::GroupedHashAggregateStream(_))); + assert!(matches!(stream, StreamType::GroupedHash(_))); } _ => panic!("Unknown version: {version}"), } diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs b/datafusion/core/src/physical_plan/aggregates/row_hash.rs index 4613a2e46443..d034bd669e55 100644 --- a/datafusion/core/src/physical_plan/aggregates/row_hash.rs +++ b/datafusion/core/src/physical_plan/aggregates/row_hash.rs @@ -56,7 +56,7 @@ pub(crate) enum ExecutionState { use super::order::GroupOrdering; use super::AggregateExec; -/// Hash based Grouping Aggregator +/// HashTable based Grouping Aggregator /// /// # Design Goals /// @@ -145,7 +145,7 @@ pub(crate) struct GroupedHashAggregateStream { /// accumulator. If present, only those rows for which the filter /// evaluate to true should be included in the aggregate results. /// - /// For example, for an aggregate like `SUM(x FILTER x > 100)`, + /// For example, for an aggregate like `SUM(x) FILTER (WHERE x >= 100)`, /// the filter expression is `x > 100`. filter_expressions: Vec>>, @@ -266,7 +266,7 @@ impl GroupedHashAggregateStream { /// Create an accumulator for `agg_expr` -- a [`GroupsAccumulator`] if /// that is supported by the aggregate, or a /// [`GroupsAccumulatorAdapter`] if not. -fn create_group_accumulator( +pub(crate) fn create_group_accumulator( agg_expr: &Arc, ) -> Result> { if agg_expr.groups_accumulator_supported() { diff --git a/datafusion/core/src/physical_plan/aggregates/topk/hash_table.rs b/datafusion/core/src/physical_plan/aggregates/topk/hash_table.rs new file mode 100644 index 000000000000..ad6a1d1db79b --- /dev/null +++ b/datafusion/core/src/physical_plan/aggregates/topk/hash_table.rs @@ -0,0 +1,434 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! A wrapper around `hashbrown::RawTable` that allows entries to be tracked by index + +use crate::physical_plan::aggregates::group_values::primitive::HashValue; +use crate::physical_plan::aggregates::topk::heap::Comparable; +use ahash::RandomState; +use arrow::datatypes::i256; +use arrow_array::builder::PrimitiveBuilder; +use arrow_array::cast::AsArray; +use arrow_array::{ + downcast_primitive, Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray, StringArray, +}; +use arrow_schema::DataType; +use datafusion_common::DataFusionError; +use datafusion_common::Result; +use half::f16; +use hashbrown::raw::RawTable; +use std::fmt::Debug; +use std::sync::Arc; + +/// A "type alias" for Keys which are stored in our map +pub trait KeyType: Clone + Comparable + Debug {} + +impl KeyType for T where T: Clone + Comparable + Debug {} + +/// An entry in our hash table that: +/// 1. memoizes the hash +/// 2. contains the key (ID) +/// 3. contains the value (heap_idx - an index into the corresponding heap) +pub struct HashTableItem { + hash: u64, + pub id: ID, + pub heap_idx: usize, +} + +/// A custom wrapper around `hashbrown::RawTable` that: +/// 1. limits the number of entries to the top K +/// 2. Allocates a capacity greater than top K to maintain a low-fill factor and prevent resizing +/// 3. Tracks indexes to allow corresponding heap to refer to entries by index vs hash +/// 4. Catches resize events to allow the corresponding heap to update it's indexes +struct TopKHashTable { + map: RawTable>, + limit: usize, +} + +/// An interface to hide the generic type signature of TopKHashTable behind arrow arrays +pub trait ArrowHashTable { + fn set_batch(&mut self, ids: ArrayRef); + fn len(&self) -> usize; + // JUSTIFICATION + // Benefit: ~15% speedup + required to index into RawTable from binary heap + // Soundness: the caller must provide valid indexes + unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]); + // JUSTIFICATION + // Benefit: ~15% speedup + required to index into RawTable from binary heap + // Soundness: the caller must provide a valid index + unsafe fn heap_idx_at(&self, map_idx: usize) -> usize; + fn drain(&mut self) -> (ArrayRef, Vec); + + // JUSTIFICATION + // Benefit: ~15% speedup + required to index into RawTable from binary heap + // Soundness: the caller must provide valid indexes + unsafe fn find_or_insert( + &mut self, + row_idx: usize, + replace_idx: usize, + map: &mut Vec<(usize, usize)>, + ) -> (usize, bool); +} + +// An implementation of ArrowHashTable for String keys +pub struct StringHashTable { + owned: ArrayRef, + map: TopKHashTable>, + rnd: RandomState, +} + +// An implementation of ArrowHashTable for any `ArrowPrimitiveType` key +struct PrimitiveHashTable +where + Option<::Native>: Comparable, +{ + owned: ArrayRef, + map: TopKHashTable>, + rnd: RandomState, +} + +impl StringHashTable { + pub fn new(limit: usize) -> Self { + let vals: Vec<&str> = Vec::new(); + let owned = Arc::new(StringArray::from(vals)); + Self { + owned, + map: TopKHashTable::new(limit, limit * 10), + rnd: ahash::RandomState::default(), + } + } +} + +impl ArrowHashTable for StringHashTable { + fn set_batch(&mut self, ids: ArrayRef) { + self.owned = ids; + } + + fn len(&self) -> usize { + self.map.len() + } + + unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) { + self.map.update_heap_idx(mapper); + } + + unsafe fn heap_idx_at(&self, map_idx: usize) -> usize { + self.map.heap_idx_at(map_idx) + } + + fn drain(&mut self) -> (ArrayRef, Vec) { + let (ids, heap_idxs) = self.map.drain(); + let ids = Arc::new(StringArray::from(ids)); + (ids, heap_idxs) + } + + unsafe fn find_or_insert( + &mut self, + row_idx: usize, + replace_idx: usize, + mapper: &mut Vec<(usize, usize)>, + ) -> (usize, bool) { + let ids = self + .owned + .as_any() + .downcast_ref::() + .expect("StringArray required"); + let id = if ids.is_null(row_idx) { + None + } else { + Some(ids.value(row_idx)) + }; + + let hash = self.rnd.hash_one(id); + if let Some(map_idx) = self + .map + .find(hash, |mi| id == mi.as_ref().map(|id| id.as_str())) + { + return (map_idx, false); + } + + // we're full and this is a better value, so remove the worst + let heap_idx = self.map.remove_if_full(replace_idx); + + // add the new group + let id = id.map(|id| id.to_string()); + let map_idx = self.map.insert(hash, id, heap_idx, mapper); + (map_idx, true) + } +} + +impl PrimitiveHashTable +where + Option<::Native>: Comparable, + Option<::Native>: HashValue, +{ + pub fn new(limit: usize) -> Self { + let owned = Arc::new(PrimitiveArray::::builder(0).finish()); + Self { + owned, + map: TopKHashTable::new(limit, limit * 10), + rnd: ahash::RandomState::default(), + } + } +} + +impl ArrowHashTable for PrimitiveHashTable +where + Option<::Native>: Comparable, + Option<::Native>: HashValue, +{ + fn set_batch(&mut self, ids: ArrayRef) { + self.owned = ids; + } + + fn len(&self) -> usize { + self.map.len() + } + + unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) { + self.map.update_heap_idx(mapper); + } + + unsafe fn heap_idx_at(&self, map_idx: usize) -> usize { + self.map.heap_idx_at(map_idx) + } + + fn drain(&mut self) -> (ArrayRef, Vec) { + let (ids, heap_idxs) = self.map.drain(); + let mut builder: PrimitiveBuilder = PrimitiveArray::builder(ids.len()); + for id in ids.into_iter() { + match id { + None => builder.append_null(), + Some(id) => builder.append_value(id), + } + } + let ids = Arc::new(builder.finish()); + (ids, heap_idxs) + } + + unsafe fn find_or_insert( + &mut self, + row_idx: usize, + replace_idx: usize, + mapper: &mut Vec<(usize, usize)>, + ) -> (usize, bool) { + let ids = self.owned.as_primitive::(); + let id: Option = if ids.is_null(row_idx) { + None + } else { + Some(ids.value(row_idx)) + }; + + let hash: u64 = id.hash(&self.rnd); + if let Some(map_idx) = self.map.find(hash, |mi| id == *mi) { + return (map_idx, false); + } + + // we're full and this is a better value, so remove the worst + let heap_idx = self.map.remove_if_full(replace_idx); + + // add the new group + let map_idx = self.map.insert(hash, id, heap_idx, mapper); + (map_idx, true) + } +} + +impl TopKHashTable { + pub fn new(limit: usize, capacity: usize) -> Self { + Self { + map: RawTable::with_capacity(capacity), + limit, + } + } + + pub fn find(&self, hash: u64, mut eq: impl FnMut(&ID) -> bool) -> Option { + let bucket = self.map.find(hash, |mi| eq(&mi.id))?; + // JUSTIFICATION + // Benefit: ~15% speedup + required to index into RawTable from binary heap + // Soundness: getting the index of a bucket we just found + let idx = unsafe { self.map.bucket_index(&bucket) }; + Some(idx) + } + + pub unsafe fn heap_idx_at(&self, map_idx: usize) -> usize { + let bucket = unsafe { self.map.bucket(map_idx) }; + bucket.as_ref().heap_idx + } + + pub unsafe fn remove_if_full(&mut self, replace_idx: usize) -> usize { + if self.map.len() >= self.limit { + self.map.erase(self.map.bucket(replace_idx)); + 0 // if full, always replace top node + } else { + self.map.len() // if we're not full, always append to end + } + } + + unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) { + for (m, h) in mapper { + self.map.bucket(*m).as_mut().heap_idx = *h + } + } + + pub fn insert( + &mut self, + hash: u64, + id: ID, + heap_idx: usize, + mapper: &mut Vec<(usize, usize)>, + ) -> usize { + let mi = HashTableItem::new(hash, id, heap_idx); + let bucket = self.map.try_insert_no_grow(hash, mi); + let bucket = match bucket { + Ok(bucket) => bucket, + Err(new_item) => { + let bucket = self.map.insert(hash, new_item, |mi| mi.hash); + // JUSTIFICATION + // Benefit: ~15% speedup + required to index into RawTable from binary heap + // Soundness: we're getting indexes of buckets, not dereferencing them + unsafe { + for bucket in self.map.iter() { + let heap_idx = bucket.as_ref().heap_idx; + let map_idx = self.map.bucket_index(&bucket); + mapper.push((heap_idx, map_idx)); + } + } + bucket + } + }; + // JUSTIFICATION + // Benefit: ~15% speedup + required to index into RawTable from binary heap + // Soundness: we're getting indexes of buckets, not dereferencing them + unsafe { self.map.bucket_index(&bucket) } + } + + pub fn len(&self) -> usize { + self.map.len() + } + + pub fn drain(&mut self) -> (Vec, Vec) { + self.map.drain().map(|mi| (mi.id, mi.heap_idx)).unzip() + } +} + +impl HashTableItem { + pub fn new(hash: u64, id: ID, heap_idx: usize) -> Self { + Self { hash, id, heap_idx } + } +} + +#[allow(dead_code)] +#[cfg(test)] +fn map_print(map: &RawTable>) { + use itertools::Itertools; + // JUSTIFICATION + // Benefit: ~15% speedup + required to index into RawTable from binary heap + // Soundness: iterator is safe as long as we don't hold onto it past this stack frame + unsafe { + let mut indexes = vec![]; + for mi in map.iter() { + let mi = mi.as_ref(); + println!("id={:?} heap_idx={}", mi.id, mi.heap_idx); + indexes.push(mi.heap_idx); + } + let indexes: Vec<_> = indexes.iter().unique().collect(); + if indexes.len() != map.len() { + panic!("{} indexes and {} keys", indexes.len(), map.len()); + } + } +} + +impl HashValue for Option { + fn hash(&self, state: &RandomState) -> u64 { + state.hash_one(self) + } +} + +macro_rules! hash_float { + ($($t:ty),+) => { + $(impl HashValue for Option<$t> { + fn hash(&self, state: &RandomState) -> u64 { + self.map(|me| me.hash(state)).unwrap_or(0) + } + })+ + }; +} + +macro_rules! has_integer { + ($($t:ty),+) => { + $(impl HashValue for Option<$t> { + fn hash(&self, state: &RandomState) -> u64 { + self.map(|me| me.hash(state)).unwrap_or(0) + } + })+ + }; +} + +has_integer!(i8, i16, i32, i64, i128, i256); +has_integer!(u8, u16, u32, u64); +hash_float!(f16, f32, f64); + +pub fn new_hash_table(limit: usize, kt: DataType) -> Result> { + macro_rules! downcast_helper { + ($kt:ty, $d:ident) => { + return Ok(Box::new(PrimitiveHashTable::<$kt>::new(limit))) + }; + } + + downcast_primitive! { + kt => (downcast_helper, kt), + DataType::Utf8 => return Ok(Box::new(StringHashTable::new(limit))), + _ => {} + } + + Err(DataFusionError::Execution(format!( + "Can't create HashTable for type: {kt:?}" + ))) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::error::Result; + + #[test] + fn should_resize_properly() -> Result<()> { + let mut map = TopKHashTable::>::new(5, 3); + for (idx, id) in vec!["1", "2", "3", "4", "5"].into_iter().enumerate() { + let mut mapper = vec![]; + map.insert(idx as u64, Some(id.to_string()), idx, &mut mapper); + if idx == 3 { + assert_eq!( + mapper, + vec![(0, 0), (1, 1), (2, 2), (3, 3)], + "Pass {idx} resized incorrectly!" + ); + } else { + assert_eq!(mapper, vec![], "Pass {idx} resized!"); + } + } + + let (ids, indexes) = map.drain(); + assert_eq!( + format!("{:?}", ids), + r#"[Some("1"), Some("2"), Some("3"), Some("4"), Some("5")]"# + ); + assert_eq!(indexes, vec![0, 1, 2, 3, 4]); + + Ok(()) + } +} diff --git a/datafusion/core/src/physical_plan/aggregates/topk/heap.rs b/datafusion/core/src/physical_plan/aggregates/topk/heap.rs new file mode 100644 index 000000000000..5719947aed6a --- /dev/null +++ b/datafusion/core/src/physical_plan/aggregates/topk/heap.rs @@ -0,0 +1,638 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! A custom binary heap implementation for performant top K aggregation + +use arrow::datatypes::i256; +use arrow_array::cast::AsArray; +use arrow_array::{downcast_primitive, ArrayRef, ArrowPrimitiveType, PrimitiveArray}; +use arrow_schema::DataType; +use datafusion_common::DataFusionError; +use datafusion_common::Result; +use datafusion_physical_expr::aggregate::utils::adjust_output_array; +use half::f16; +use std::cmp::Ordering; +use std::fmt::{Debug, Formatter}; +use std::sync::Arc; + +/// A custom version of `Ord` that only exists to we can implement it for the Values in our heap +pub trait Comparable { + fn comp(&self, other: &Self) -> Ordering; +} + +impl Comparable for Option { + fn comp(&self, other: &Self) -> Ordering { + self.cmp(other) + } +} + +/// A "type alias" for Values which are stored in our heap +pub trait ValueType: Comparable + Clone + Debug {} + +impl ValueType for T where T: Comparable + Clone + Debug {} + +/// An entry in our heap, which contains both the value and a index into an external HashTable +struct HeapItem { + val: VAL, + map_idx: usize, +} + +/// A custom heap implementation that allows several things that couldn't be achieved with +/// `collections::BinaryHeap`: +/// 1. It allows values to be updated at arbitrary positions (when group values change) +/// 2. It can be either a min or max heap +/// 3. It can use our `HeapItem` type & `Comparable` trait +/// 4. It is specialized to grow to a certain limit, then always replace without grow & shrink +struct TopKHeap { + desc: bool, + len: usize, + capacity: usize, + heap: Vec>>, +} + +/// An interface to hide the generic type signature of TopKHeap behind arrow arrays +pub trait ArrowHeap { + fn set_batch(&mut self, vals: ArrayRef); + fn is_worse(&self, idx: usize) -> bool; + fn worst_map_idx(&self) -> usize; + fn renumber(&mut self, heap_to_map: &[(usize, usize)]); + fn insert(&mut self, row_idx: usize, map_idx: usize, map: &mut Vec<(usize, usize)>); + fn replace_if_better( + &mut self, + heap_idx: usize, + row_idx: usize, + map: &mut Vec<(usize, usize)>, + ); + fn take_all(&mut self, heap_idxs: Vec) -> ArrayRef; +} + +/// An implementation of `ArrowHeap` that deals with primitive values +pub struct PrimitiveHeap +where + ::Native: Comparable, +{ + batch: ArrayRef, + heap: TopKHeap, + desc: bool, + data_type: DataType, +} + +impl PrimitiveHeap +where + ::Native: Comparable, +{ + pub fn new(limit: usize, desc: bool, data_type: DataType) -> Self { + let owned: ArrayRef = Arc::new(PrimitiveArray::::builder(0).finish()); + Self { + batch: owned, + heap: TopKHeap::new(limit, desc), + desc, + data_type, + } + } +} + +impl ArrowHeap for PrimitiveHeap +where + ::Native: Comparable, +{ + fn set_batch(&mut self, vals: ArrayRef) { + self.batch = vals; + } + + fn is_worse(&self, row_idx: usize) -> bool { + if !self.heap.is_full() { + return false; + } + let vals = self.batch.as_primitive::(); + let new_val = vals.value(row_idx); + let worst_val = self.heap.worst_val().expect("Missing root"); + (!self.desc && new_val > *worst_val) || (self.desc && new_val < *worst_val) + } + + fn worst_map_idx(&self) -> usize { + self.heap.worst_map_idx() + } + + fn renumber(&mut self, heap_to_map: &[(usize, usize)]) { + self.heap.renumber(heap_to_map); + } + + fn insert(&mut self, row_idx: usize, map_idx: usize, map: &mut Vec<(usize, usize)>) { + let vals = self.batch.as_primitive::(); + let new_val = vals.value(row_idx); + self.heap.append_or_replace(new_val, map_idx, map); + } + + fn replace_if_better( + &mut self, + heap_idx: usize, + row_idx: usize, + map: &mut Vec<(usize, usize)>, + ) { + let vals = self.batch.as_primitive::(); + let new_val = vals.value(row_idx); + self.heap.replace_if_better(heap_idx, new_val, map); + } + + fn take_all(&mut self, heap_idxs: Vec) -> ArrayRef { + let vals = self.heap.take_all(heap_idxs); + let vals = Arc::new(PrimitiveArray::::from_iter_values(vals)); + adjust_output_array(&self.data_type, vals).expect("Type is incorrect") + } +} + +impl TopKHeap { + pub fn new(limit: usize, desc: bool) -> Self { + Self { + desc, + capacity: limit, + len: 0, + heap: (0..=limit).map(|_| None).collect::>(), + } + } + + pub fn worst_val(&self) -> Option<&VAL> { + let root = self.heap.first()?; + let hi = match root { + None => return None, + Some(hi) => hi, + }; + Some(&hi.val) + } + + pub fn worst_map_idx(&self) -> usize { + self.heap[0].as_ref().map(|hi| hi.map_idx).unwrap_or(0) + } + + #[allow(dead_code)] + pub fn len(&self) -> usize { + self.len + } + + #[allow(dead_code)] + pub fn is_full(&self) -> bool { + self.len >= self.capacity + } + + pub fn append_or_replace( + &mut self, + new_val: VAL, + map_idx: usize, + map: &mut Vec<(usize, usize)>, + ) { + if self.is_full() { + self.replace_root(new_val, map_idx, map); + } else { + self.append(new_val, map_idx, map); + } + } + + fn append(&mut self, new_val: VAL, map_idx: usize, mapper: &mut Vec<(usize, usize)>) { + let hi = HeapItem::new(new_val, map_idx); + self.heap[self.len] = Some(hi); + self.heapify_up(self.len, mapper); + self.len += 1; + } + + pub fn take_all(&mut self, indexes: Vec) -> Vec { + let res = indexes + .into_iter() + .map(|i| { + let hi: HeapItem = self.heap[i].take().expect("No heap item"); + hi.val + }) + .collect(); + self.len = 0; + res + } + + fn replace_root( + &mut self, + new_val: VAL, + map_idx: usize, + mapper: &mut Vec<(usize, usize)>, + ) { + let hi = self.heap[0].as_mut().expect("No root"); + hi.val = new_val; + hi.map_idx = map_idx; + self.heapify_down(0, mapper); + } + + pub fn replace_if_better( + &mut self, + heap_idx: usize, + new_val: VAL, + mapper: &mut Vec<(usize, usize)>, + ) { + let existing = self.heap[heap_idx].as_mut().expect("Missing heap item"); + if (!self.desc && new_val.comp(&existing.val) != Ordering::Less) + || (self.desc && new_val.comp(&existing.val) != Ordering::Greater) + { + return; + } + existing.val = new_val; + self.heapify_down(heap_idx, mapper); + } + + pub fn renumber(&mut self, heap_to_map: &[(usize, usize)]) { + for (heap_idx, map_idx) in heap_to_map.iter() { + if let Some(Some(hi)) = self.heap.get_mut(*heap_idx) { + hi.map_idx = *map_idx; + } + } + } + + fn heapify_up(&mut self, mut idx: usize, mapper: &mut Vec<(usize, usize)>) { + let desc = self.desc; + while idx != 0 { + let parent_idx = (idx - 1) / 2; + let node = self.heap[idx].as_ref().expect("No heap item"); + let parent = self.heap[parent_idx].as_ref().expect("No heap item"); + if (!desc && node.val.comp(&parent.val) != Ordering::Greater) + || (desc && node.val.comp(&parent.val) != Ordering::Less) + { + return; + } + self.swap(idx, parent_idx, mapper); + idx = parent_idx; + } + } + + fn swap(&mut self, a_idx: usize, b_idx: usize, mapper: &mut Vec<(usize, usize)>) { + let a_hi = self.heap[a_idx].take().expect("Missing heap entry"); + let b_hi = self.heap[b_idx].take().expect("Missing heap entry"); + + mapper.push((a_hi.map_idx, b_idx)); + mapper.push((b_hi.map_idx, a_idx)); + + self.heap[a_idx] = Some(b_hi); + self.heap[b_idx] = Some(a_hi); + } + + fn heapify_down(&mut self, node_idx: usize, mapper: &mut Vec<(usize, usize)>) { + let left_child = node_idx * 2 + 1; + let desc = self.desc; + let entry = self.heap.get(node_idx).expect("Missing node!"); + let entry = entry.as_ref().expect("Missing node!"); + let mut best_idx = node_idx; + let mut best_val = &entry.val; + for child_idx in left_child..=left_child + 1 { + if let Some(Some(child)) = self.heap.get(child_idx) { + if (!desc && child.val.comp(best_val) == Ordering::Greater) + || (desc && child.val.comp(best_val) == Ordering::Less) + { + best_val = &child.val; + best_idx = child_idx; + } + } + } + if best_val.comp(&entry.val) != Ordering::Equal { + self.swap(best_idx, node_idx, mapper); + self.heapify_down(best_idx, mapper); + } + } + + #[cfg(test)] + fn _tree_print(&self, idx: usize, builder: &mut ptree::TreeBuilder) -> bool { + let hi = self.heap.get(idx); + let hi = match hi { + None => return true, + Some(hi) => hi, + }; + let mut valid = true; + if let Some(hi) = hi { + let label = format!("val={:?} idx={}, bucket={}", hi.val, idx, hi.map_idx); + builder.begin_child(label); + valid &= self._tree_print(idx * 2 + 1, builder); // left + valid &= self._tree_print(idx * 2 + 2, builder); // right + builder.end_child(); + if idx != 0 { + let parent_idx = (idx - 1) / 2; + let parent = self.heap[parent_idx].as_ref().expect("Missing parent"); + if (!self.desc && hi.val.comp(&parent.val) == Ordering::Greater) + || (self.desc && hi.val.comp(&parent.val) == Ordering::Less) + { + return false; + } + } + } + valid + } + + #[allow(dead_code)] + #[cfg(test)] + pub fn tree_print(&self) -> String { + let mut builder = ptree::TreeBuilder::new("BinaryHeap".to_string()); + let valid = self._tree_print(0, &mut builder); + let mut actual = Vec::new(); + ptree::write_tree(&builder.build(), &mut actual).unwrap(); + let res = String::from_utf8(actual).unwrap(); + if !valid { + println!("{res}"); + panic!("Heap invariant violated"); + } + res + } +} + +impl HeapItem { + pub fn new(val: VAL, buk_idx: usize) -> Self { + Self { + val, + map_idx: buk_idx, + } + } +} + +impl Debug for HeapItem { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str("bucket=")?; + self.map_idx.fmt(f)?; + f.write_str(" val=")?; + self.val.fmt(f)?; + f.write_str("\n")?; + Ok(()) + } +} + +impl Eq for HeapItem {} + +impl PartialEq for HeapItem { + fn eq(&self, other: &Self) -> bool { + self.cmp(other) == Ordering::Equal + } +} + +impl PartialOrd for HeapItem { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for HeapItem { + fn cmp(&self, other: &Self) -> Ordering { + let res = self.val.comp(&other.val); + if res != Ordering::Equal { + return res; + } + self.map_idx.cmp(&other.map_idx) + } +} + +macro_rules! compare_float { + ($($t:ty),+) => { + $(impl Comparable for Option<$t> { + fn comp(&self, other: &Self) -> Ordering { + match (self, other) { + (Some(me), Some(other)) => me.total_cmp(other), + (Some(_), None) => Ordering::Greater, + (None, Some(_)) => Ordering::Less, + (None, None) => Ordering::Equal, + } + } + })+ + + $(impl Comparable for $t { + fn comp(&self, other: &Self) -> Ordering { + self.total_cmp(other) + } + })+ + }; +} + +macro_rules! compare_integer { + ($($t:ty),+) => { + $(impl Comparable for Option<$t> { + fn comp(&self, other: &Self) -> Ordering { + self.cmp(other) + } + })+ + + $(impl Comparable for $t { + fn comp(&self, other: &Self) -> Ordering { + self.cmp(other) + } + })+ + }; +} + +compare_integer!(i8, i16, i32, i64, i128, i256); +compare_integer!(u8, u16, u32, u64); +compare_float!(f16, f32, f64); + +pub fn new_heap(limit: usize, desc: bool, vt: DataType) -> Result> { + macro_rules! downcast_helper { + ($vt:ty, $d:ident) => { + return Ok(Box::new(PrimitiveHeap::<$vt>::new(limit, desc, vt))) + }; + } + + downcast_primitive! { + vt => (downcast_helper, vt), + _ => {} + } + + Err(DataFusionError::Execution(format!( + "Can't group type: {vt:?}" + ))) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::error::Result; + + #[test] + fn should_append() -> Result<()> { + let mut map = vec![]; + let mut heap = TopKHeap::new(10, false); + heap.append_or_replace(1, 1, &mut map); + + let actual = heap.tree_print(); + let expected = r#" +BinaryHeap +└─ val=1 idx=0, bucket=1 + "#; + assert_eq!(actual.trim(), expected.trim()); + + Ok(()) + } + + #[test] + fn should_heapify_up() -> Result<()> { + let mut map = vec![]; + let mut heap = TopKHeap::new(10, false); + + heap.append_or_replace(1, 1, &mut map); + assert_eq!(map, vec![]); + + heap.append_or_replace(2, 2, &mut map); + assert_eq!(map, vec![(2, 0), (1, 1)]); + + let actual = heap.tree_print(); + let expected = r#" +BinaryHeap +└─ val=2 idx=0, bucket=2 + └─ val=1 idx=1, bucket=1 + "#; + assert_eq!(actual.trim(), expected.trim()); + + Ok(()) + } + + #[test] + fn should_heapify_down() -> Result<()> { + let mut map = vec![]; + let mut heap = TopKHeap::new(3, false); + + heap.append_or_replace(1, 1, &mut map); + heap.append_or_replace(2, 2, &mut map); + heap.append_or_replace(3, 3, &mut map); + let actual = heap.tree_print(); + let expected = r#" +BinaryHeap +└─ val=3 idx=0, bucket=3 + ├─ val=1 idx=1, bucket=1 + └─ val=2 idx=2, bucket=2 + "#; + assert_eq!(actual.trim(), expected.trim()); + + let mut map = vec![]; + heap.append_or_replace(0, 0, &mut map); + let actual = heap.tree_print(); + let expected = r#" +BinaryHeap +└─ val=2 idx=0, bucket=2 + ├─ val=1 idx=1, bucket=1 + └─ val=0 idx=2, bucket=0 + "#; + assert_eq!(actual.trim(), expected.trim()); + assert_eq!(map, vec![(2, 0), (0, 2)]); + + Ok(()) + } + + #[test] + fn should_replace() -> Result<()> { + let mut map = vec![]; + let mut heap = TopKHeap::new(4, false); + + heap.append_or_replace(1, 1, &mut map); + heap.append_or_replace(2, 2, &mut map); + heap.append_or_replace(3, 3, &mut map); + heap.append_or_replace(4, 4, &mut map); + let actual = heap.tree_print(); + let expected = r#" +BinaryHeap +└─ val=4 idx=0, bucket=4 + ├─ val=3 idx=1, bucket=3 + │ └─ val=1 idx=3, bucket=1 + └─ val=2 idx=2, bucket=2 + "#; + assert_eq!(actual.trim(), expected.trim()); + + let mut map = vec![]; + heap.replace_if_better(1, 0, &mut map); + let actual = heap.tree_print(); + let expected = r#" +BinaryHeap +└─ val=4 idx=0, bucket=4 + ├─ val=1 idx=1, bucket=1 + │ └─ val=0 idx=3, bucket=3 + └─ val=2 idx=2, bucket=2 + "#; + assert_eq!(actual.trim(), expected.trim()); + assert_eq!(map, vec![(1, 1), (3, 3)]); + + Ok(()) + } + + #[test] + fn should_find_worst() -> Result<()> { + let mut map = vec![]; + let mut heap = TopKHeap::new(10, false); + + heap.append_or_replace(1, 1, &mut map); + heap.append_or_replace(2, 2, &mut map); + + let actual = heap.tree_print(); + let expected = r#" +BinaryHeap +└─ val=2 idx=0, bucket=2 + └─ val=1 idx=1, bucket=1 + "#; + assert_eq!(actual.trim(), expected.trim()); + + assert_eq!(heap.worst_val(), Some(&2)); + assert_eq!(heap.worst_map_idx(), 2); + + Ok(()) + } + + #[test] + fn should_take_all() -> Result<()> { + let mut map = vec![]; + let mut heap = TopKHeap::new(10, false); + + heap.append_or_replace(1, 1, &mut map); + heap.append_or_replace(2, 2, &mut map); + + let actual = heap.tree_print(); + let expected = r#" +BinaryHeap +└─ val=2 idx=0, bucket=2 + └─ val=1 idx=1, bucket=1 + "#; + assert_eq!(actual.trim(), expected.trim()); + + let vals = heap.take_all(vec![1, 0]); + assert_eq!(vals, vec![1, 2]); + assert_eq!(heap.len(), 0); + + Ok(()) + } + + #[test] + fn should_renumber() -> Result<()> { + let mut map = vec![]; + let mut heap = TopKHeap::new(10, false); + + heap.append_or_replace(1, 1, &mut map); + heap.append_or_replace(2, 2, &mut map); + + let actual = heap.tree_print(); + let expected = r#" +BinaryHeap +└─ val=2 idx=0, bucket=2 + └─ val=1 idx=1, bucket=1 + "#; + assert_eq!(actual.trim(), expected.trim()); + + let numbers = vec![(0, 1), (1, 2)]; + heap.renumber(numbers.as_slice()); + let actual = heap.tree_print(); + let expected = r#" +BinaryHeap +└─ val=2 idx=0, bucket=1 + └─ val=1 idx=1, bucket=2 + "#; + assert_eq!(actual.trim(), expected.trim()); + + Ok(()) + } +} diff --git a/datafusion/core/src/physical_plan/aggregates/topk/mod.rs b/datafusion/core/src/physical_plan/aggregates/topk/mod.rs new file mode 100644 index 000000000000..c6a0f40cc817 --- /dev/null +++ b/datafusion/core/src/physical_plan/aggregates/topk/mod.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! TopK functionality for aggregates + +pub mod hash_table; +pub mod heap; +pub mod priority_map; diff --git a/datafusion/core/src/physical_plan/aggregates/topk/priority_map.rs b/datafusion/core/src/physical_plan/aggregates/topk/priority_map.rs new file mode 100644 index 000000000000..53f553304f1f --- /dev/null +++ b/datafusion/core/src/physical_plan/aggregates/topk/priority_map.rs @@ -0,0 +1,383 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! A `Map` / `PriorityQueue` combo that evicts the worst values after reaching `capacity` + +use crate::physical_plan::aggregates::topk::hash_table::{ + new_hash_table, ArrowHashTable, +}; +use crate::physical_plan::aggregates::topk::heap::{new_heap, ArrowHeap}; +use arrow_array::ArrayRef; +use arrow_schema::DataType; +use datafusion_common::Result; + +/// A `Map` / `PriorityQueue` combo that evicts the worst values after reaching `capacity` +pub struct PriorityMap { + map: Box, + heap: Box, + capacity: usize, + mapper: Vec<(usize, usize)>, +} + +// JUSTIFICATION +// Benefit: ~15% speedup + required to index into RawTable from binary heap +// Soundness: it is only accessed by one thread at a time, and indexes are kept up to date +unsafe impl Send for PriorityMap {} + +impl PriorityMap { + pub fn new( + key_type: DataType, + val_type: DataType, + capacity: usize, + descending: bool, + ) -> Result { + Ok(Self { + map: new_hash_table(capacity, key_type)?, + heap: new_heap(capacity, descending, val_type)?, + capacity, + mapper: Vec::with_capacity(capacity), + }) + } + + pub fn set_batch(&mut self, ids: ArrayRef, vals: ArrayRef) { + self.map.set_batch(ids); + self.heap.set_batch(vals); + } + + pub fn insert(&mut self, row_idx: usize) -> Result<()> { + assert!(self.map.len() <= self.capacity, "Overflow"); + + // if we're full, and the new val is worse than all our values, just bail + if self.heap.is_worse(row_idx) { + return Ok(()); + } + let map = &mut self.mapper; + + // handle new groups we haven't seen yet + map.clear(); + let replace_idx = self.heap.worst_map_idx(); + // JUSTIFICATION + // Benefit: ~15% speedup + required to index into RawTable from binary heap + // Soundness: replace_idx kept valid during resizes + let (map_idx, did_insert) = + unsafe { self.map.find_or_insert(row_idx, replace_idx, map) }; + if did_insert { + self.heap.renumber(map); + map.clear(); + self.heap.insert(row_idx, map_idx, map); + // JUSTIFICATION + // Benefit: ~15% speedup + required to index into RawTable from binary heap + // Soundness: the map was created on the line above, so all the indexes should be valid + unsafe { self.map.update_heap_idx(map) }; + return Ok(()); + }; + + // this is a value for an existing group + map.clear(); + // JUSTIFICATION + // Benefit: ~15% speedup + required to index into RawTable from binary heap + // Soundness: map_idx was just found, so it is valid + let heap_idx = unsafe { self.map.heap_idx_at(map_idx) }; + self.heap.replace_if_better(heap_idx, row_idx, map); + // JUSTIFICATION + // Benefit: ~15% speedup + required to index into RawTable from binary heap + // Soundness: the index map was just built, so it will be valid + unsafe { self.map.update_heap_idx(map) }; + + Ok(()) + } + + pub fn emit(&mut self) -> Result> { + let (ids, heap_idxs) = self.map.drain(); + let vals = self.heap.take_all(heap_idxs); + Ok(vec![ids, vals]) + } + + pub fn is_empty(&self) -> bool { + self.map.len() == 0 + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::error::Result; + use arrow::util::pretty::pretty_format_batches; + use arrow_array::{Int64Array, RecordBatch, StringArray}; + use arrow_schema::Field; + use arrow_schema::Schema; + use arrow_schema::{DataType, SchemaRef}; + use std::sync::Arc; + + #[test] + fn should_append() -> Result<()> { + let ids: ArrayRef = Arc::new(StringArray::from(vec!["1"])); + let vals: ArrayRef = Arc::new(Int64Array::from(vec![1])); + let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 1, false)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_schema(), cols)?; + let actual = format!("{}", pretty_format_batches(&[batch])?); + let expected = r#" ++----------+--------------+ +| trace_id | timestamp_ms | ++----------+--------------+ +| 1 | 1 | ++----------+--------------+ + "# + .trim(); + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn should_ignore_higher_group() -> Result<()> { + let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "2"])); + let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2])); + let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 1, false)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + agg.insert(1)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_schema(), cols)?; + let actual = format!("{}", pretty_format_batches(&[batch])?); + let expected = r#" ++----------+--------------+ +| trace_id | timestamp_ms | ++----------+--------------+ +| 1 | 1 | ++----------+--------------+ + "# + .trim(); + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn should_ignore_lower_group() -> Result<()> { + let ids: ArrayRef = Arc::new(StringArray::from(vec!["2", "1"])); + let vals: ArrayRef = Arc::new(Int64Array::from(vec![2, 1])); + let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 1, true)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + agg.insert(1)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_schema(), cols)?; + let actual = format!("{}", pretty_format_batches(&[batch])?); + let expected = r#" ++----------+--------------+ +| trace_id | timestamp_ms | ++----------+--------------+ +| 2 | 2 | ++----------+--------------+ + "# + .trim(); + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn should_ignore_higher_same_group() -> Result<()> { + let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "1"])); + let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2])); + let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 2, false)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + agg.insert(1)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_schema(), cols)?; + let actual = format!("{}", pretty_format_batches(&[batch])?); + let expected = r#" ++----------+--------------+ +| trace_id | timestamp_ms | ++----------+--------------+ +| 1 | 1 | ++----------+--------------+ + "# + .trim(); + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn should_ignore_lower_same_group() -> Result<()> { + let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "1"])); + let vals: ArrayRef = Arc::new(Int64Array::from(vec![2, 1])); + let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 2, true)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + agg.insert(1)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_schema(), cols)?; + let actual = format!("{}", pretty_format_batches(&[batch])?); + let expected = r#" ++----------+--------------+ +| trace_id | timestamp_ms | ++----------+--------------+ +| 1 | 2 | ++----------+--------------+ + "# + .trim(); + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn should_accept_lower_group() -> Result<()> { + let ids: ArrayRef = Arc::new(StringArray::from(vec!["2", "1"])); + let vals: ArrayRef = Arc::new(Int64Array::from(vec![2, 1])); + let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 1, false)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + agg.insert(1)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_schema(), cols)?; + let actual = format!("{}", pretty_format_batches(&[batch])?); + let expected = r#" ++----------+--------------+ +| trace_id | timestamp_ms | ++----------+--------------+ +| 1 | 1 | ++----------+--------------+ + "# + .trim(); + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn should_accept_higher_group() -> Result<()> { + let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "2"])); + let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2])); + let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 1, true)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + agg.insert(1)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_schema(), cols)?; + let actual = format!("{}", pretty_format_batches(&[batch])?); + let expected = r#" ++----------+--------------+ +| trace_id | timestamp_ms | ++----------+--------------+ +| 2 | 2 | ++----------+--------------+ + "# + .trim(); + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn should_accept_lower_for_group() -> Result<()> { + let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "1"])); + let vals: ArrayRef = Arc::new(Int64Array::from(vec![2, 1])); + let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 2, false)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + agg.insert(1)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_schema(), cols)?; + let actual = format!("{}", pretty_format_batches(&[batch])?); + let expected = r#" ++----------+--------------+ +| trace_id | timestamp_ms | ++----------+--------------+ +| 1 | 1 | ++----------+--------------+ + "# + .trim(); + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn should_accept_higher_for_group() -> Result<()> { + let ids: ArrayRef = Arc::new(StringArray::from(vec!["1", "1"])); + let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2])); + let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 2, true)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + agg.insert(1)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_schema(), cols)?; + let actual = format!("{}", pretty_format_batches(&[batch])?); + let expected = r#" ++----------+--------------+ +| trace_id | timestamp_ms | ++----------+--------------+ +| 1 | 2 | ++----------+--------------+ + "# + .trim(); + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn should_handle_null_ids() -> Result<()> { + let ids: ArrayRef = Arc::new(StringArray::from(vec![Some("1"), None, None])); + let vals: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3])); + let mut agg = PriorityMap::new(DataType::Utf8, DataType::Int64, 2, true)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + agg.insert(1)?; + agg.insert(2)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_schema(), cols)?; + let actual = format!("{}", pretty_format_batches(&[batch])?); + let expected = r#" ++----------+--------------+ +| trace_id | timestamp_ms | ++----------+--------------+ +| 1 | 1 | +| | 3 | ++----------+--------------+ + "# + .trim(); + assert_eq!(actual, expected); + + Ok(()) + } + + fn test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("trace_id", DataType::Utf8, true), + Field::new("timestamp_ms", DataType::Int64, true), + ])) + } +} diff --git a/datafusion/core/src/physical_plan/aggregates/topk_stream.rs b/datafusion/core/src/physical_plan/aggregates/topk_stream.rs new file mode 100644 index 000000000000..de1c02885d0c --- /dev/null +++ b/datafusion/core/src/physical_plan/aggregates/topk_stream.rs @@ -0,0 +1,180 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! A memory-conscious aggregation implementation that limits group buckets to a fixed number + +use crate::physical_plan::aggregates::topk::priority_map::PriorityMap; +use crate::physical_plan::aggregates::{ + aggregate_expressions, evaluate_group_by, evaluate_many, AggregateExec, + PhysicalGroupBy, +}; +use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream}; +use arrow::util::pretty::print_batches; +use arrow_array::{Array, ArrayRef, RecordBatch}; +use arrow_schema::SchemaRef; +use datafusion_common::DataFusionError; +use datafusion_common::Result; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::PhysicalExpr; +use futures::stream::{Stream, StreamExt}; +use log::{trace, Level}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +pub struct GroupedTopKAggregateStream { + partition: usize, + row_count: usize, + started: bool, + schema: SchemaRef, + input: SendableRecordBatchStream, + aggregate_arguments: Vec>>, + group_by: PhysicalGroupBy, + priority_map: PriorityMap, +} + +impl GroupedTopKAggregateStream { + pub fn new( + aggr: &AggregateExec, + context: Arc, + partition: usize, + limit: usize, + ) -> Result { + let agg_schema = Arc::clone(&aggr.schema); + let group_by = aggr.group_by.clone(); + let input = aggr.input.execute(partition, Arc::clone(&context))?; + let aggregate_arguments = + aggregate_expressions(&aggr.aggr_expr, &aggr.mode, group_by.expr.len())?; + let (val_field, desc) = aggr + .get_minmax_desc() + .ok_or_else(|| DataFusionError::Internal("Min/max required".to_string()))?; + + let (expr, _) = &aggr.group_expr().expr()[0]; + let kt = expr.data_type(&aggr.input().schema())?; + let vt = val_field.data_type().clone(); + + let priority_map = PriorityMap::new(kt, vt, limit, desc)?; + + Ok(GroupedTopKAggregateStream { + partition, + started: false, + row_count: 0, + schema: agg_schema, + input, + aggregate_arguments, + group_by, + priority_map, + }) + } +} + +impl RecordBatchStream for GroupedTopKAggregateStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +impl GroupedTopKAggregateStream { + fn intern(&mut self, ids: ArrayRef, vals: ArrayRef) -> Result<()> { + let len = ids.len(); + self.priority_map.set_batch(ids, vals.clone()); + + let has_nulls = vals.null_count() > 0; + for row_idx in 0..len { + if has_nulls && vals.is_null(row_idx) { + continue; + } + self.priority_map.insert(row_idx)?; + } + Ok(()) + } +} + +impl Stream for GroupedTopKAggregateStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + while let Poll::Ready(res) = self.input.poll_next_unpin(cx) { + match res { + // got a batch, convert to rows and append to our TreeMap + Some(Ok(batch)) => { + self.started = true; + trace!( + "partition {} has {} rows and got batch with {} rows", + self.partition, + self.row_count, + batch.num_rows() + ); + if log::log_enabled!(Level::Trace) && batch.num_rows() < 20 { + print_batches(&[batch.clone()])?; + } + self.row_count += batch.num_rows(); + let batches = &[batch]; + let group_by_values = + evaluate_group_by(&self.group_by, batches.first().unwrap())?; + assert_eq!( + group_by_values.len(), + 1, + "Exactly 1 group value required" + ); + assert_eq!( + group_by_values[0].len(), + 1, + "Exactly 1 group value required" + ); + let group_by_values = group_by_values[0][0].clone(); + let input_values = evaluate_many( + &self.aggregate_arguments, + batches.first().unwrap(), + )?; + assert_eq!(input_values.len(), 1, "Exactly 1 input required"); + assert_eq!(input_values[0].len(), 1, "Exactly 1 input required"); + let input_values = input_values[0][0].clone(); + + // iterate over each column of group_by values + (*self).intern(group_by_values, input_values)?; + } + // inner is done, emit all rows and switch to producing output + None => { + if self.priority_map.is_empty() { + trace!("partition {} emit None", self.partition); + return Poll::Ready(None); + } + let cols = self.priority_map.emit()?; + let batch = RecordBatch::try_new(self.schema.clone(), cols)?; + trace!( + "partition {} emit batch with {} rows", + self.partition, + batch.num_rows() + ); + if log::log_enabled!(Level::Trace) { + print_batches(&[batch.clone()])?; + } + return Poll::Ready(Some(Ok(batch))); + } + // inner had error, return to caller + Some(Err(e)) => { + return Poll::Ready(Some(Err(e))); + } + } + } + Poll::Pending + } +} diff --git a/datafusion/core/tests/sql/group_by.rs b/datafusion/core/tests/sql/group_by.rs index ebb8fca930cd..862d2275afc2 100644 --- a/datafusion/core/tests/sql/group_by.rs +++ b/datafusion/core/tests/sql/group_by.rs @@ -16,6 +16,8 @@ // under the License. use super::*; +use arrow::util::pretty::pretty_format_batches; +use arrow_schema::{DataType, TimeUnit}; #[tokio::test] async fn group_by_date_trunc() -> Result<()> { @@ -68,6 +70,95 @@ async fn group_by_date_trunc() -> Result<()> { Ok(()) } +#[tokio::test] +async fn group_by_limit() -> Result<()> { + let tmp_dir = TempDir::new()?; + let ctx = create_groupby_context(&tmp_dir).await?; + + let sql = "SELECT trace_id, MAX(ts) from traces group by trace_id order by MAX(ts) desc limit 4"; + let dataframe = ctx.sql(sql).await?; + + // ensure we see `lim=[4]` + let physical_plan = dataframe.create_physical_plan().await?; + let mut expected_physical_plan = r#" +GlobalLimitExec: skip=0, fetch=4 + SortExec: fetch=4, expr=[MAX(traces.ts)@1 DESC] + AggregateExec: mode=Single, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.ts)], lim=[4] + "#.trim().to_string(); + let actual_phys_plan = + format_plan(physical_plan.clone(), &mut expected_physical_plan); + assert_eq!(actual_phys_plan, expected_physical_plan); + + let batches = collect(physical_plan, ctx.task_ctx()).await?; + let expected = r#" ++----------+----------------------+ +| trace_id | MAX(traces.ts) | ++----------+----------------------+ +| 9 | 2020-12-01T00:00:18Z | +| 8 | 2020-12-01T00:00:17Z | +| 7 | 2020-12-01T00:00:16Z | +| 6 | 2020-12-01T00:00:15Z | ++----------+----------------------+ +"# + .trim(); + let actual = format!("{}", pretty_format_batches(&batches)?); + assert_eq!(actual, expected); + + Ok(()) +} + +fn format_plan( + physical_plan: Arc, + expected_phys_plan: &mut String, +) -> String { + let actual_phys_plan = displayable(physical_plan.as_ref()).indent(true).to_string(); + let last_line = actual_phys_plan + .as_str() + .lines() + .last() + .expect("Plan should not be empty"); + + expected_phys_plan.push('\n'); + expected_phys_plan.push_str(last_line); + expected_phys_plan.push('\n'); + actual_phys_plan +} + +async fn create_groupby_context(tmp_dir: &TempDir) -> Result { + let schema = Arc::new(Schema::new(vec![ + Field::new("trace_id", DataType::Utf8, false), + Field::new( + "ts", + DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())), + false, + ), + ])); + + // generate a file + let filename = "traces.csv"; + let file_path = tmp_dir.path().join(filename); + let mut file = File::create(file_path)?; + + // generate some data + for trace_id in 0..10 { + for ts in 0..10 { + let ts = trace_id + ts; + let data = format!("\"{trace_id}\",2020-12-01T00:00:{ts:02}.000Z\n"); + file.write_all(data.as_bytes())?; + } + } + + let cfg = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::with_config(cfg); + ctx.register_csv( + "traces", + tmp_dir.path().to_str().unwrap(), + CsvReadOptions::new().schema(&schema).has_header(false), + ) + .await?; + Ok(ctx) +} + #[tokio::test] async fn group_by_dictionary() { async fn run_test_case() { diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/physical-expr/src/aggregate/utils.rs index 3ed7905e294d..e1af67071260 100644 --- a/datafusion/physical-expr/src/aggregate/utils.rs +++ b/datafusion/physical-expr/src/aggregate/utils.rs @@ -21,7 +21,10 @@ use crate::{AggregateExpr, PhysicalSortExpr}; use arrow::array::ArrayRef; use arrow::datatypes::{MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION}; use arrow_array::cast::AsArray; -use arrow_array::types::Decimal128Type; +use arrow_array::types::{ + Decimal128Type, TimestampMicrosecondType, TimestampMillisecondType, + TimestampNanosecondType, TimestampSecondType, +}; use arrow_schema::{DataType, Field}; use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr::Accumulator; @@ -126,6 +129,30 @@ pub fn adjust_output_array( .as_primitive::() .clone() .with_precision_and_scale(*p, *s)?, + ) as ArrayRef, + DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, tz) => Arc::new( + array + .as_primitive::() + .clone() + .with_timezone_opt(tz.clone()), + ), + DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, tz) => Arc::new( + array + .as_primitive::() + .clone() + .with_timezone_opt(tz.clone()), + ), + DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, tz) => Arc::new( + array + .as_primitive::() + .clone() + .with_timezone_opt(tz.clone()), + ), + DataType::Timestamp(arrow_schema::TimeUnit::Second, tz) => Arc::new( + array + .as_primitive::() + .clone() + .with_timezone_opt(tz.clone()), ), // no adjustment needed for other arrays _ => array, diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 4e201317e16b..5dcd66a5ab44 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -2282,7 +2282,7 @@ select max(x_dict) from value_dict group by x_dict % 2 order by max(x_dict); 4 5 -# bool aggregtion +# bool aggregation statement ok CREATE TABLE value_bool(x boolean, g int) AS VALUES (NULL, 0), (false, 0), (true, 0), (false, 1), (true, 2), (NULL, 3); @@ -2312,7 +2312,150 @@ false true NULL +# TopK aggregation +statement ok +CREATE TABLE traces(trace_id varchar, timestamp bigint) AS VALUES +(NULL, 0), +('a', NULL), +('a', 1), +('b', 0), +('c', 1), +('c', 2), +('b', 3); + +statement ok +set datafusion.optimizer.enable_topk_aggregation = false; + +query TT +explain select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4; +---- +logical_plan +Limit: skip=0, fetch=4 +--Sort: MAX(traces.timestamp) DESC NULLS FIRST, fetch=4 +----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MAX(traces.timestamp)]] +------TableScan: traces projection=[trace_id, timestamp] +physical_plan +GlobalLimitExec: skip=0, fetch=4 +--SortPreservingMergeExec: [MAX(traces.timestamp)@1 DESC], fetch=4 +----SortExec: fetch=4, expr=[MAX(traces.timestamp)@1 DESC] +------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 +------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] +--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------------MemoryExec: partitions=1, partition_sizes=[1] + + +query TI +select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4; +---- +b 3 +c 2 +a 1 +NULL 0 + +query TI +select trace_id, MIN(timestamp) from traces group by trace_id order by MIN(timestamp) asc limit 4; +---- +NULL 0 +b 0 +c 1 +a 1 + +statement ok +set datafusion.optimizer.enable_topk_aggregation = true; +query TT +explain select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4; +---- +logical_plan +Limit: skip=0, fetch=4 +--Sort: MAX(traces.timestamp) DESC NULLS FIRST, fetch=4 +----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MAX(traces.timestamp)]] +------TableScan: traces projection=[trace_id, timestamp] +physical_plan +GlobalLimitExec: skip=0, fetch=4 +--SortPreservingMergeExec: [MAX(traces.timestamp)@1 DESC], fetch=4 +----SortExec: fetch=4, expr=[MAX(traces.timestamp)@1 DESC] +------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)], lim=[4] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 +------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)], lim=[4] +--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------------MemoryExec: partitions=1, partition_sizes=[1] + +query TT +explain select trace_id, MIN(timestamp) from traces group by trace_id order by MIN(timestamp) desc limit 4; +---- +logical_plan +Limit: skip=0, fetch=4 +--Sort: MIN(traces.timestamp) DESC NULLS FIRST, fetch=4 +----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MIN(traces.timestamp)]] +------TableScan: traces projection=[trace_id, timestamp] +physical_plan +GlobalLimitExec: skip=0, fetch=4 +--SortPreservingMergeExec: [MIN(traces.timestamp)@1 DESC], fetch=4 +----SortExec: fetch=4, expr=[MIN(traces.timestamp)@1 DESC] +------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MIN(traces.timestamp)] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 +------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MIN(traces.timestamp)] +--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------------MemoryExec: partitions=1, partition_sizes=[1] + +query TT +explain select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) asc limit 4; +---- +logical_plan +Limit: skip=0, fetch=4 +--Sort: MAX(traces.timestamp) ASC NULLS LAST, fetch=4 +----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MAX(traces.timestamp)]] +------TableScan: traces projection=[trace_id, timestamp] +physical_plan +GlobalLimitExec: skip=0, fetch=4 +--SortPreservingMergeExec: [MAX(traces.timestamp)@1 ASC NULLS LAST], fetch=4 +----SortExec: fetch=4, expr=[MAX(traces.timestamp)@1 ASC NULLS LAST] +------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 +------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] +--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------------MemoryExec: partitions=1, partition_sizes=[1] + +query TT +explain select trace_id, MAX(timestamp) from traces group by trace_id order by trace_id asc limit 4; +---- +logical_plan +Limit: skip=0, fetch=4 +--Sort: traces.trace_id ASC NULLS LAST, fetch=4 +----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MAX(traces.timestamp)]] +------TableScan: traces projection=[trace_id, timestamp] +physical_plan +GlobalLimitExec: skip=0, fetch=4 +--SortPreservingMergeExec: [trace_id@0 ASC NULLS LAST], fetch=4 +----SortExec: fetch=4, expr=[trace_id@0 ASC NULLS LAST] +------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 +------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] +--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +----------------MemoryExec: partitions=1, partition_sizes=[1] + +query TI +select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4; +---- +b 3 +c 2 +a 1 +NULL 0 + +query TI +select trace_id, MIN(timestamp) from traces group by trace_id order by MIN(timestamp) asc limit 4; +---- +NULL 0 +b 0 +c 1 +a 1 # # regr_*() tests diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index ad9b2be40e9e..44b67c78ed27 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -250,4 +250,5 @@ physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE physical_plan after EnforceSorting SAME TEXT AS ABOVE physical_plan after coalesce_batches SAME TEXT AS ABOVE physical_plan after PipelineChecker SAME TEXT AS ABOVE +physical_plan after LimitAggregation SAME TEXT AS ABOVE physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index 5db305105f53..0bb30dc0bd70 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -182,6 +182,7 @@ datafusion.explain.physical_plan_only false datafusion.optimizer.allow_symmetric_joins_without_pruning true datafusion.optimizer.bounded_order_preserving_variants false datafusion.optimizer.enable_round_robin_repartition true +datafusion.optimizer.enable_topk_aggregation true datafusion.optimizer.filter_null_join_keys false datafusion.optimizer.hash_join_single_partition_threshold 1048576 datafusion.optimizer.max_passes 3 diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 0d3abeac9fbf..c92b72e8b323 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -76,6 +76,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.execution.sort_spill_reservation_bytes | 10485760 | Specifies the reserved memory for each spillable sort operation to facilitate an in-memory merge. When a sort operation spills to disk, the in-memory data must be sorted and merged before being written to a file. This setting reserves a specific amount of memory for that in-memory sort/merge process. Note: This setting is irrelevant if the sort operation cannot spill (i.e., if there's no `DiskManager` configured). | | datafusion.execution.sort_in_place_threshold_bytes | 1048576 | When sorting, below what size should data be concatenated and sorted in a single RecordBatch rather than sorted in batches and merged. | | datafusion.optimizer.enable_round_robin_repartition | true | When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores | +| datafusion.optimizer.enable_topk_aggregation | true | When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible | | datafusion.optimizer.filter_null_join_keys | false | When set to true, the optimizer will insert filters before a join between a nullable and non-nullable column to filter out nulls on the nullable side. This filter can add additional overhead when the file format does not fully support predicate push down. | | datafusion.optimizer.repartition_aggregations | true | Should DataFusion repartition data using the aggregate keys to execute aggregates in parallel using the provided `target_partitions` level | | datafusion.optimizer.repartition_file_min_size | 10485760 | Minimum total files size in bytes to perform file scan repartitioning. |