diff --git a/datafusion/core/src/execution/memory_manager.rs b/datafusion/core/src/execution/memory_manager.rs index 48d4ca3c3d32..e7148b06606c 100644 --- a/datafusion/core/src/execution/memory_manager.rs +++ b/datafusion/core/src/execution/memory_manager.rs @@ -178,10 +178,8 @@ pub trait MemoryConsumer: Send + Sync { self.id(), ); - let can_grow_directly = self - .memory_manager() - .can_grow_directly(required, current) - .await; + let can_grow_directly = + self.memory_manager().can_grow_directly(required, current); if !can_grow_directly { debug!( "Failed to grow memory of {} directly from consumer {}, spilling first ...", @@ -334,7 +332,7 @@ impl MemoryManager { } /// Grow memory attempt from a consumer, return if we could grant that much to it - async fn can_grow_directly(&self, required: usize, current: usize) -> bool { + fn can_grow_directly(&self, required: usize, current: usize) -> bool { let num_rqt = self.requesters.lock().len(); let mut rqt_current_used = self.requesters_total.lock(); let mut rqt_max = self.max_mem_for_requesters(); diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs index 43e75e352010..6ce58592d83b 100644 --- a/datafusion/core/src/physical_plan/aggregates/mod.rs +++ b/datafusion/core/src/physical_plan/aggregates/mod.rs @@ -348,7 +348,7 @@ impl ExecutionPlan for AggregateExec { context: Arc, ) -> Result { let batch_size = context.session_config().batch_size(); - let input = self.input.execute(partition, context)?; + let input = self.input.execute(partition, Arc::clone(&context))?; let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); @@ -369,6 +369,8 @@ impl ExecutionPlan for AggregateExec { input, baseline_metrics, batch_size, + context, + partition, )?)) } else { Ok(Box::pin(GroupedHashAggregateStream::new( @@ -689,7 +691,8 @@ fn evaluate_group_by( #[cfg(test)] mod tests { - use crate::execution::context::TaskContext; + use 
crate::execution::context::{SessionConfig, TaskContext}; + use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use crate::from_slice::FromSlice; use crate::physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, @@ -700,7 +703,7 @@ mod tests { use crate::{assert_batches_sorted_eq, physical_plan::common}; use arrow::array::{Float64Array, UInt32Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use arrow::error::Result as ArrowResult; + use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_physical_expr::expressions::{lit, Count}; @@ -1081,6 +1084,63 @@ mod tests { check_grouping_sets(input).await } + #[tokio::test] + async fn test_oom() -> Result<()> { + let input: Arc = + Arc::new(TestYieldingExec { yield_first: true }); + let input_schema = input.schema(); + + let session_ctx = SessionContext::with_config_rt( + SessionConfig::default(), + Arc::new( + RuntimeEnv::new(RuntimeConfig::default().with_memory_limit(1, 1.0)) + .unwrap(), + ), + ); + let task_ctx = session_ctx.task_ctx(); + + let groups = PhysicalGroupBy { + expr: vec![(col("a", &input_schema)?, "a".to_string())], + null_expr: vec![], + groups: vec![vec![false]], + }; + + let aggregates: Vec> = vec![Arc::new(Avg::new( + col("b", &input_schema)?, + "AVG(b)".to_string(), + DataType::Float64, + ))]; + + let partial_aggregate = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + groups, + aggregates, + input, + input_schema.clone(), + )?); + + let err = common::collect(partial_aggregate.execute(0, task_ctx.clone())?) + .await + .unwrap_err(); + + // error root cause traversal is a bit complicated, see #4172. 
+ if let DataFusionError::ArrowError(ArrowError::ExternalError(err)) = err { + if let Some(err) = err.downcast_ref::() { + assert!( + matches!(err, DataFusionError::ResourcesExhausted(_)), + "Wrong inner error type: {}", + err, + ); + } else { + panic!("Wrong arrow error type: {err}") + } + } else { + panic!("Wrong outer error type: {err}") + } + + Ok(()) + } + #[tokio::test] async fn test_drop_cancel_without_groups() -> Result<()> { let session_ctx = SessionContext::new(); diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs b/datafusion/core/src/physical_plan/aggregates/row_hash.rs index aefc6571b068..c6658b2a6ee5 100644 --- a/datafusion/core/src/physical_plan/aggregates/row_hash.rs +++ b/datafusion/core/src/physical_plan/aggregates/row_hash.rs @@ -22,12 +22,14 @@ use std::task::{Context, Poll}; use std::vec; use ahash::RandomState; -use futures::{ - ready, - stream::{Stream, StreamExt}, -}; +use async_trait::async_trait; +use futures::stream::BoxStream; +use futures::stream::{Stream, StreamExt}; use crate::error::Result; +use crate::execution::context::TaskContext; +use crate::execution::memory_manager::ConsumerType; +use crate::execution::{MemoryConsumer, MemoryConsumerId, MemoryManager}; use crate::physical_plan::aggregates::{ evaluate_group_by, evaluate_many, group_schema, AccumulatorItemV2, AggregateMode, PhysicalGroupBy, @@ -45,13 +47,13 @@ use arrow::{ error::{ArrowError, Result as ArrowResult}, }; use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; -use datafusion_common::ScalarValue; +use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_row::accessor::RowAccessor; use datafusion_row::layout::RowLayout; use datafusion_row::reader::{read_row, RowReader}; use datafusion_row::writer::{write_row, RowWriter}; use datafusion_row::{MutableRecordBatch, RowType}; -use hashbrown::raw::RawTable; +use hashbrown::raw::{Bucket, RawTable}; /// Grouping aggregate with row-format aggregation states inside. 
/// @@ -70,6 +72,16 @@ use hashbrown::raw::RawTable; /// [Compact]: datafusion_row::layout::RowType::Compact /// [WordAligned]: datafusion_row::layout::RowType::WordAligned pub(crate) struct GroupedHashAggregateStreamV2 { + stream: BoxStream<'static, ArrowResult>, + schema: SchemaRef, +} + +/// Actual implementation of [`GroupedHashAggregateStreamV2`]. +/// +/// This is wrapped into yet another struct because we need to interact with the async memory management subsystem +/// during poll. To have as little code "weirdness" as possible, we chose to just use [`BoxStream`] together with +/// [`futures::stream::unfold`]. The latter requires a state object, which is [`GroupedHashAggregateStreamV2Inner`]. +struct GroupedHashAggregateStreamV2Inner { schema: SchemaRef, input: SendableRecordBatchStream, mode: AggregateMode, @@ -102,6 +114,7 @@ fn aggr_state_schema(aggr_expr: &[Arc]) -> Result impl GroupedHashAggregateStreamV2 { /// Create a new GroupedRowHashAggregateStream + #[allow(clippy::too_many_arguments)] pub fn new( mode: AggregateMode, schema: SchemaRef, @@ -110,6 +123,8 @@ impl GroupedHashAggregateStreamV2 { input: SendableRecordBatchStream, baseline_metrics: BaselineMetrics, batch_size: usize, + context: Arc, + partition: usize, ) -> Result { let timer = baseline_metrics.elapsed_compute().timer(); @@ -125,10 +140,24 @@ impl GroupedHashAggregateStreamV2 { let aggr_schema = aggr_state_schema(&aggr_expr)?; let aggr_layout = Arc::new(RowLayout::new(&aggr_schema, RowType::WordAligned)); + + let aggr_state = AggregationState { + memory_consumer: AggregationStateMemoryConsumer { + id: MemoryConsumerId::new(partition), + memory_manager: Arc::clone(&context.runtime_env().memory_manager), + used: 0, + }, + map: RawTable::with_capacity(0), + group_states: Vec::with_capacity(0), + }; + context + .runtime_env() + .register_requester(aggr_state.memory_consumer.id()); + timer.done(); - Ok(Self { - schema, + let inner = GroupedHashAggregateStreamV2Inner { + schema: 
Arc::clone(&schema), mode, input, group_by, @@ -138,11 +167,87 @@ impl GroupedHashAggregateStreamV2 { aggr_layout, baseline_metrics, aggregate_expressions, - aggr_state: Default::default(), + aggr_state, random_state: Default::default(), batch_size, row_group_skip_position: 0, - }) + }; + + let stream = futures::stream::unfold(inner, |mut this| async move { + let elapsed_compute = this.baseline_metrics.elapsed_compute(); + + loop { + let result: ArrowResult> = + match this.input.next().await { + Some(Ok(batch)) => { + let timer = elapsed_compute.timer(); + let result = group_aggregate_batch( + &this.mode, + &this.random_state, + &this.group_by, + &mut this.accumulators, + &this.group_schema, + this.aggr_layout.clone(), + batch, + &mut this.aggr_state, + &this.aggregate_expressions, + ); + + timer.done(); + + // allocate memory + // This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with + // overshooting a bit. Also this means we either store the whole record batch or not. 
+ let result = match result { + Ok(allocated) => { + this.aggr_state.memory_consumer.alloc(allocated).await + } + Err(e) => Err(e), + }; + + match result { + Ok(()) => continue, + Err(e) => Err(ArrowError::ExternalError(Box::new(e))), + } + } + Some(Err(e)) => Err(e), + None => { + let timer = this.baseline_metrics.elapsed_compute().timer(); + let result = create_batch_from_map( + &this.mode, + &this.group_schema, + &this.aggr_schema, + this.batch_size, + this.row_group_skip_position, + &mut this.aggr_state, + &mut this.accumulators, + &this.schema, + ); + + timer.done(); + result + } + }; + + this.row_group_skip_position += this.batch_size; + match result { + Ok(Some(result)) => { + return Some(( + Ok(result.record_output(&this.baseline_metrics)), + this, + )); + } + Ok(None) => return None, + Err(error) => return Some((Err(error), this)), + } + } + }); + + // seems like some consumers call this stream even after it returned `None`, so let's fuse the stream. + let stream = stream.fuse(); + let stream = Box::pin(stream); + + Ok(Self { schema, stream }) } } @@ -154,63 +259,7 @@ impl Stream for GroupedHashAggregateStreamV2 { cx: &mut Context<'_>, ) -> Poll> { let this = &mut *self; - - let elapsed_compute = this.baseline_metrics.elapsed_compute(); - - loop { - let result: ArrowResult> = - match ready!(this.input.poll_next_unpin(cx)) { - Some(Ok(batch)) => { - let timer = elapsed_compute.timer(); - let result = group_aggregate_batch( - &this.mode, - &this.random_state, - &this.group_by, - &mut this.accumulators, - &this.group_schema, - this.aggr_layout.clone(), - batch, - &mut this.aggr_state, - &this.aggregate_expressions, - ); - - timer.done(); - - match result { - Ok(_) => continue, - Err(e) => Err(ArrowError::ExternalError(Box::new(e))), - } - } - Some(Err(e)) => Err(e), - None => { - let timer = this.baseline_metrics.elapsed_compute().timer(); - let result = create_batch_from_map( - &this.mode, - &this.group_schema, - &this.aggr_schema, - this.batch_size, - 
this.row_group_skip_position, - &mut this.aggr_state, - &mut this.accumulators, - &this.schema, - ); - - timer.done(); - result - } - }; - - this.row_group_skip_position += this.batch_size; - match result { - Ok(Some(result)) => { - return Poll::Ready(Some(Ok( - result.record_output(&this.baseline_metrics) - ))) - } - Ok(None) => return Poll::Ready(None), - Err(error) => return Poll::Ready(Some(Err(error))), - } - } + this.stream.poll_next_unpin(cx) } } @@ -220,6 +269,10 @@ impl RecordBatchStream for GroupedHashAggregateStreamV2 { } } +/// Perform group-by aggregation for the given [`RecordBatch`]. +/// +/// If successful, this returns the additional number of bytes that were allocated during this process. +/// /// TODO: Make this a member function of [`GroupedHashAggregateStreamV2`] #[allow(clippy::too_many_arguments)] fn group_aggregate_batch( @@ -232,10 +285,15 @@ fn group_aggregate_batch( batch: RecordBatch, aggr_state: &mut AggregationState, aggregate_expressions: &[Vec>], -) -> Result<()> { +) -> Result { // evaluate the grouping expressions let grouping_by_values = evaluate_group_by(grouping_set, &batch)?; + let AggregationState { + map, group_states, .. 
+ } = aggr_state; + let mut allocated = 0usize; + for group_values in grouping_by_values { let group_rows: Vec> = create_group_rows(group_values, group_schema); @@ -256,8 +314,6 @@ fn group_aggregate_batch( create_row_hashes(&group_rows, random_state, &mut batch_hashes)?; for (row, hash) in batch_hashes.into_iter().enumerate() { - let AggregationState { map, group_states } = aggr_state; - let entry = map.get_mut(hash, |(_hash, group_idx)| { // verify that a group that we are inserting with hash is // actually the same key value as the group in @@ -270,11 +326,15 @@ fn group_aggregate_batch( // Existing entry for this group value Some((_hash, group_idx)) => { let group_state = &mut group_states[*group_idx]; + // 1.3 if group_state.indices.is_empty() { groups_with_rows.push(*group_idx); }; - group_state.indices.push(row as u32); // remember this row + + group_state + .indices + .push_accounted(row as u32, &mut allocated); // remember this row } // 1.2 Need to create new entry None => { @@ -285,11 +345,25 @@ fn group_aggregate_batch( indices: vec![row as u32], // 1.3 }; let group_idx = group_states.len(); - group_states.push(group_state); - groups_with_rows.push(group_idx); + + // NOTE: do NOT include the `RowGroupState` struct size in here because this is captured by + // `group_states` (see allocation down below) + allocated += (std::mem::size_of::() + * group_state.group_by_values.capacity()) + + (std::mem::size_of::() + * group_state.aggregation_buffer.capacity()) + + (std::mem::size_of::() * group_state.indices.capacity()); // for hasher function, use precomputed hash value - map.insert(hash, (hash, group_idx), |(hash, _group_idx)| *hash); + map.insert_accounted( + (hash, group_idx), + |(hash, _group_index)| *hash, + &mut allocated, + ); + + group_states.push_accounted(group_state, &mut allocated); + + groups_with_rows.push(group_idx); } }; } @@ -299,7 +373,7 @@ fn group_aggregate_batch( let mut offsets = vec![0]; let mut offset_so_far = 0; for group_idx in 
groups_with_rows.iter() { - let indices = &aggr_state.group_states[*group_idx].indices; + let indices = &group_states[*group_idx].indices; batch_indices.append_slice(indices); offset_so_far += indices.len(); offsets.push(offset_so_far); @@ -334,7 +408,7 @@ fn group_aggregate_batch( .iter() .zip(offsets.windows(2)) .try_for_each(|(group_idx, offsets)| { - let group_state = &mut aggr_state.group_states[*group_idx]; + let group_state = &mut group_states[*group_idx]; // 2.2 accumulators .iter_mut() @@ -374,7 +448,7 @@ fn group_aggregate_batch( })?; } - Ok(()) + Ok(allocated) } /// The state that is built for each output group. @@ -392,8 +466,9 @@ struct RowGroupState { } /// The state of all the groups -#[derive(Default)] struct AggregationState { + memory_consumer: AggregationStateMemoryConsumer, + /// Logically maps group values to an index in `group_states` /// /// Uses the raw API of hashbrown to avoid actually storing the @@ -418,6 +493,130 @@ impl std::fmt::Debug for AggregationState { } } +/// Accounting data structure for memory usage. +struct AggregationStateMemoryConsumer { + /// Consumer ID. + id: MemoryConsumerId, + + /// Linked memory manager. + memory_manager: Arc, + + /// Currently used size in bytes. 
+ used: usize, +} + +#[async_trait] +impl MemoryConsumer for AggregationStateMemoryConsumer { + fn name(&self) -> String { + "AggregationState".to_owned() + } + + fn id(&self) -> &crate::execution::MemoryConsumerId { + &self.id + } + + fn memory_manager(&self) -> Arc { + Arc::clone(&self.memory_manager) + } + + fn type_(&self) -> &ConsumerType { + &ConsumerType::Tracking + } + + async fn spill(&self) -> Result { + Err(DataFusionError::ResourcesExhausted( + "Cannot spill AggregationState".to_owned(), + )) + } + + fn mem_used(&self) -> usize { + self.used + } +} + +impl AggregationStateMemoryConsumer { + async fn alloc(&mut self, bytes: usize) -> Result<()> { + self.try_grow(bytes).await?; + self.used = self.used.checked_add(bytes).expect("overflow"); + Ok(()) + } +} + +impl Drop for AggregationStateMemoryConsumer { + fn drop(&mut self) { + self.memory_manager + .drop_consumer(self.id(), self.mem_used()); + } +} + +trait VecAllocExt { + type T; + + fn push_accounted(&mut self, x: Self::T, accounting: &mut usize); +} + +impl VecAllocExt for Vec { + type T = T; + + fn push_accounted(&mut self, x: Self::T, accounting: &mut usize) { + if self.capacity() == self.len() { + // allocate more + + // growth factor: 2, but at least 2 elements + let bump_elements = (self.capacity() * 2).max(2); + let bump_size = std::mem::size_of::() * bump_elements; + self.reserve(bump_elements); + *accounting = (*accounting).checked_add(bump_size).expect("overflow"); + } + + self.push(x); + } +} + +trait RawTableAllocExt { + type T; + + fn insert_accounted( + &mut self, + x: Self::T, + hasher: impl Fn(&Self::T) -> u64, + accounting: &mut usize, + ) -> Bucket; +} + +impl RawTableAllocExt for RawTable { + type T = T; + + fn insert_accounted( + &mut self, + x: Self::T, + hasher: impl Fn(&Self::T) -> u64, + accounting: &mut usize, + ) -> Bucket { + let hash = hasher(&x); + + match self.try_insert_no_grow(hash, x) { + Ok(bucket) => bucket, + Err(x) => { + // need to request more memory + + let 
bump_elements = (self.capacity() * 2).max(16); + let bump_size = bump_elements * std::mem::size_of::(); + *accounting = (*accounting).checked_add(bump_size).expect("overflow"); + + self.reserve(bump_elements, hasher); + + // still need to insert the element since first try failed + // Note: cannot use `.expect` here because `T` may not implement `Debug` + match self.try_insert_no_grow(hash, x) { + Ok(bucket) => bucket, + Err(_) => panic!("just grew the container"), + } + } + } + } +} + /// Create grouping rows fn create_group_rows(arrays: Vec, schema: &Schema) -> Vec> { let mut writer = RowWriter::new(schema, RowType::Compact);