Skip to content

Commit

Permalink
fix(functions): use drop guard to ensure the states dropped
Browse files Browse the repository at this point in the history
  • Loading branch information
sundy-li committed Apr 29, 2022
1 parent 0e63755 commit 5346499
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 41 deletions.
58 changes: 34 additions & 24 deletions common/functions/src/aggregates/aggregator_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ use common_exception::ErrorCode;
use common_exception::Result;

use super::AggregateFunctionFactory;
use super::AggregateFunctionRef;
use super::StateAddr;

pub fn assert_unary_params<D: Display>(name: D, actual: usize) -> Result<()> {
if actual != 1 {
Expand Down Expand Up @@ -78,6 +80,33 @@ pub fn assert_variadic_arguments<D: Display>(
Ok(())
}

struct EvalAggr {
addr: StateAddr,
_arena: Bump,
func: AggregateFunctionRef,
}

impl EvalAggr {
fn new(func: AggregateFunctionRef) -> Self {
let _arena = Bump::new();
let place = _arena.alloc_layout(func.state_layout());
let addr = place.into();
func.init_state(addr);

Self { _arena, func, addr }
}
}

impl Drop for EvalAggr {
fn drop(&mut self) {
if self.func.need_manual_drop_state() {
unsafe {
self.func.drop_state(self.addr);
}
}
}
}

pub fn eval_aggr(
name: &str,
params: Vec<DataValue>,
Expand All @@ -91,28 +120,9 @@ pub fn eval_aggr(
let func = factory.get(name, params, arguments)?;
let data_type = func.return_type()?;

let arena = Bump::new();
let place = arena.alloc_layout(func.state_layout());
let addr = place.into();
func.init_state(addr);

let f = func.clone();
// we need a temporary function to catch the errors
let apply = || -> Result<ColumnRef> {
func.accumulate(addr, &cols, None, rows)?;
let mut builder = data_type.create_mutable(1024);
func.merge_result(addr, builder.as_mut())?;

Ok(builder.to_column())
};

let result = apply();

if f.need_manual_drop_state() {
unsafe {
f.drop_state(addr);
}
}
drop(arena);
result
let eval = EvalAggr::new(func.clone());
func.accumulate(eval.addr, &cols, None, rows)?;
let mut builder = data_type.create_mutable(1024);
func.merge_result(eval.addr, builder.as_mut())?;
Ok(builder.to_column())
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,23 @@ use common_functions::aggregates::StateAddr;
use crate::pipelines::new::processors::transforms::transform_aggregator::Aggregator;
use crate::pipelines::new::processors::AggregatorParams;

pub type FinalSingleKeyAggregator = SingleKeyAggregator<true>;
pub type PartialSingleKeyAggregator = SingleKeyAggregator<false>;
pub type FinalSingleStateAggregator = SingleStateAggregator<true>;
pub type PartialSingleStateAggregator = SingleStateAggregator<false>;

/// SELECT COUNT | SUM FROM table;
#[allow(dead_code)]
pub struct SingleKeyAggregator<const FINAL: bool> {
pub struct SingleStateAggregator<const FINAL: bool> {
funcs: Vec<AggregateFunctionRef>,
arg_names: Vec<Vec<String>>,
schema: DataSchemaRef,
arena: Bump,
_arena: Bump,
places: Vec<StateAddr>,
// used for deserialization only, so we can reuse it during the loop
temp_places: Vec<StateAddr>,
is_finished: bool,
states_dropped: bool,
}

impl<const FINAL: bool> SingleKeyAggregator<FINAL> {
impl<const FINAL: bool> SingleStateAggregator<FINAL> {
pub fn try_create(params: &Arc<AggregatorParams>) -> Result<Self> {
let arena = Bump::new();
let (layout, offsets_aggregate_states) =
Expand All @@ -76,7 +75,7 @@ impl<const FINAL: bool> SingleKeyAggregator<FINAL> {
let temp_places = get_places();

Ok(Self {
arena,
_arena: arena,
places,
funcs: params.aggregate_functions.clone(),
arg_names: params.aggregate_functions_arguments_name.clone(),
Expand Down Expand Up @@ -106,8 +105,8 @@ impl<const FINAL: bool> SingleKeyAggregator<FINAL> {
}
}

impl Aggregator for SingleKeyAggregator<true> {
const NAME: &'static str = "FinalSingleKeyAggregator";
impl Aggregator for SingleStateAggregator<true> {
const NAME: &'static str = "FinalSingleStateAggregator";

fn consume(&mut self, block: DataBlock) -> Result<()> {
for (index, func) in self.funcs.iter().enumerate() {
Expand Down Expand Up @@ -156,8 +155,8 @@ impl Aggregator for SingleKeyAggregator<true> {
}
}

impl Aggregator for SingleKeyAggregator<false> {
const NAME: &'static str = "PartialSingleKeyAggregator";
impl Aggregator for SingleStateAggregator<false> {
const NAME: &'static str = "PartialSingleStateAggregator";

fn consume(&mut self, block: DataBlock) -> Result<()> {
let rows = block.num_rows();
Expand Down Expand Up @@ -197,7 +196,7 @@ impl Aggregator for SingleKeyAggregator<false> {
}
}

impl<const FINAL: bool> Drop for SingleKeyAggregator<FINAL> {
impl<const FINAL: bool> Drop for SingleStateAggregator<FINAL> {
fn drop(&mut self) {
self.drop_states();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@ pub use aggregator_partial::KeysU8PartialAggregator;
pub use aggregator_partial::PartialAggregator;
pub use aggregator_partial::SerializerPartialAggregator;
pub use aggregator_partial::SingleStringPartialAggregator;
pub use aggregator_single_key::FinalSingleKeyAggregator;
pub use aggregator_single_key::PartialSingleKeyAggregator;
pub use aggregator_single_key::SingleKeyAggregator;
pub use aggregator_single_key::FinalSingleStateAggregator;
pub use aggregator_single_key::PartialSingleStateAggregator;
pub use aggregator_single_key::SingleStateAggregator;
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ impl TransformAggregator {
return AggregatorTransform::create(
input_port,
output_port,
FinalSingleKeyAggregator::try_create(&aggregator_params)?,
FinalSingleStateAggregator::try_create(&aggregator_params)?,
);
}

Expand Down Expand Up @@ -124,7 +124,7 @@ impl TransformAggregator {
return AggregatorTransform::create(
input_port,
output_port,
PartialSingleKeyAggregator::try_create(&aggregator_params)?,
PartialSingleStateAggregator::try_create(&aggregator_params)?,
);
}

Expand Down

0 comments on commit 5346499

Please sign in to comment.