From 534649969b45fb3b965910a243d870b313fe0873 Mon Sep 17 00:00:00 2001 From: sundyli <543950155@qq.com> Date: Fri, 29 Apr 2022 08:29:16 +0800 Subject: [PATCH] fix(functions): use drop guard to ensure the states dropped --- .../src/aggregates/aggregator_common.rs | 58 +++++++++++-------- .../aggregator/aggregator_single_key.rs | 23 ++++---- .../processors/transforms/aggregator/mod.rs | 6 +- .../transforms/transform_aggregator.rs | 4 +- 4 files changed, 50 insertions(+), 41 deletions(-) diff --git a/common/functions/src/aggregates/aggregator_common.rs b/common/functions/src/aggregates/aggregator_common.rs index e51b550e9be4..172bea24c693 100644 --- a/common/functions/src/aggregates/aggregator_common.rs +++ b/common/functions/src/aggregates/aggregator_common.rs @@ -23,6 +23,8 @@ use common_exception::ErrorCode; use common_exception::Result; use super::AggregateFunctionFactory; +use super::AggregateFunctionRef; +use super::StateAddr; pub fn assert_unary_params(name: D, actual: usize) -> Result<()> { if actual != 1 { @@ -78,6 +80,33 @@ pub fn assert_variadic_arguments( Ok(()) } +struct EvalAggr { + addr: StateAddr, + _arena: Bump, + func: AggregateFunctionRef, +} + +impl EvalAggr { + fn new(func: AggregateFunctionRef) -> Self { + let _arena = Bump::new(); + let place = _arena.alloc_layout(func.state_layout()); + let addr = place.into(); + func.init_state(addr); + + Self { _arena, func, addr } + } +} + +impl Drop for EvalAggr { + fn drop(&mut self) { + if self.func.need_manual_drop_state() { + unsafe { + self.func.drop_state(self.addr); + } + } + } +} + pub fn eval_aggr( name: &str, params: Vec, @@ -91,28 +120,9 @@ pub fn eval_aggr( let func = factory.get(name, params, arguments)?; let data_type = func.return_type()?; - let arena = Bump::new(); - let place = arena.alloc_layout(func.state_layout()); - let addr = place.into(); - func.init_state(addr); - - let f = func.clone(); - // we need a temporary function to catch the errors - let apply = || -> Result { - func.accumulate(addr, &cols, None, rows)?; - let mut builder = data_type.create_mutable(1024); - func.merge_result(addr, builder.as_mut())?; - - Ok(builder.to_column()) - }; - - let result = apply(); - - if f.need_manual_drop_state() { - unsafe { - f.drop_state(addr); - } - } - drop(arena); - result + let eval = EvalAggr::new(func.clone()); + func.accumulate(eval.addr, &cols, None, rows)?; + let mut builder = data_type.create_mutable(1024); + func.merge_result(eval.addr, builder.as_mut())?; + Ok(builder.to_column()) } diff --git a/query/src/pipelines/new/processors/transforms/aggregator/aggregator_single_key.rs b/query/src/pipelines/new/processors/transforms/aggregator/aggregator_single_key.rs index f791c474cda1..eb742891abba 100644 --- a/query/src/pipelines/new/processors/transforms/aggregator/aggregator_single_key.rs +++ b/query/src/pipelines/new/processors/transforms/aggregator/aggregator_single_key.rs @@ -35,16 +35,15 @@ use common_functions::aggregates::StateAddr; use crate::pipelines::new::processors::transforms::transform_aggregator::Aggregator; use crate::pipelines::new::processors::AggregatorParams; -pub type FinalSingleKeyAggregator = SingleKeyAggregator; -pub type PartialSingleKeyAggregator = SingleKeyAggregator; +pub type FinalSingleStateAggregator = SingleStateAggregator; +pub type PartialSingleStateAggregator = SingleStateAggregator; /// SELECT COUNT | SUM FROM table; -#[allow(dead_code)] -pub struct SingleKeyAggregator { +pub struct SingleStateAggregator { funcs: Vec, arg_names: Vec>, schema: DataSchemaRef, - arena: Bump, + _arena: Bump, places: Vec, // used for deserialization only, so we can reuse it during the loop temp_places: Vec, @@ -52,7 +51,7 @@ pub struct SingleKeyAggregator { states_dropped: bool, } -impl SingleKeyAggregator { +impl SingleStateAggregator { pub fn try_create(params: &Arc) -> Result { let arena = Bump::new(); let (layout, offsets_aggregate_states) = @@ -76,7 +75,7 @@ impl SingleKeyAggregator { let temp_places = get_places(); Ok(Self { - arena, + _arena: arena, places, funcs: params.aggregate_functions.clone(), arg_names: params.aggregate_functions_arguments_name.clone(), @@ -106,8 +105,8 @@ impl SingleKeyAggregator { } } -impl Aggregator for SingleKeyAggregator { - const NAME: &'static str = "FinalSingleKeyAggregator"; +impl Aggregator for SingleStateAggregator { + const NAME: &'static str = "FinalSingleStateAggregator"; fn consume(&mut self, block: DataBlock) -> Result<()> { for (index, func) in self.funcs.iter().enumerate() { @@ -156,8 +155,8 @@ impl Aggregator for SingleKeyAggregator { } } -impl Aggregator for SingleKeyAggregator { - const NAME: &'static str = "PartialSingleKeyAggregator"; +impl Aggregator for SingleStateAggregator { + const NAME: &'static str = "PartialSingleStateAggregator"; fn consume(&mut self, block: DataBlock) -> Result<()> { let rows = block.num_rows(); @@ -197,7 +196,7 @@ impl Aggregator for SingleKeyAggregator { } } -impl Drop for SingleKeyAggregator { +impl Drop for SingleStateAggregator { fn drop(&mut self) { self.drop_states(); } diff --git a/query/src/pipelines/new/processors/transforms/aggregator/mod.rs b/query/src/pipelines/new/processors/transforms/aggregator/mod.rs index b320f51f0e19..402299b47997 100644 --- a/query/src/pipelines/new/processors/transforms/aggregator/mod.rs +++ b/query/src/pipelines/new/processors/transforms/aggregator/mod.rs @@ -33,6 +33,6 @@ pub use aggregator_partial::KeysU8PartialAggregator; pub use aggregator_partial::PartialAggregator; pub use aggregator_partial::SerializerPartialAggregator; pub use aggregator_partial::SingleStringPartialAggregator; -pub use aggregator_single_key::FinalSingleKeyAggregator; -pub use aggregator_single_key::PartialSingleKeyAggregator; -pub use aggregator_single_key::SingleKeyAggregator; +pub use aggregator_single_key::FinalSingleStateAggregator; +pub use aggregator_single_key::PartialSingleStateAggregator; +pub use aggregator_single_key::SingleStateAggregator; diff --git a/query/src/pipelines/new/processors/transforms/transform_aggregator.rs b/query/src/pipelines/new/processors/transforms/transform_aggregator.rs index 5f780f50a31f..802ebee2f0c6 100644 --- a/query/src/pipelines/new/processors/transforms/transform_aggregator.rs +++ b/query/src/pipelines/new/processors/transforms/transform_aggregator.rs @@ -41,7 +41,7 @@ impl TransformAggregator { return AggregatorTransform::create( input_port, output_port, - FinalSingleKeyAggregator::try_create(&aggregator_params)?, + FinalSingleStateAggregator::try_create(&aggregator_params)?, ); } @@ -124,7 +124,7 @@ impl TransformAggregator { return AggregatorTransform::create( input_port, output_port, - PartialSingleKeyAggregator::try_create(&aggregator_params)?, + PartialSingleStateAggregator::try_create(&aggregator_params)?, ); }