From 56609a835f1d4f6acfa199f3e265bf971c6d6b4d Mon Sep 17 00:00:00 2001 From: Kun FAN Date: Wed, 20 Apr 2022 19:05:53 +0800 Subject: [PATCH] feat(functions): support aggregate function retention --- .../src/aggregates/aggregate_retention.rs | 224 ++++++++++++++++++ common/functions/src/aggregates/aggregator.rs | 3 + common/functions/src/aggregates/mod.rs | 2 + .../aggregate-retention.md | 60 +++++ ...2_0000_function_aggregate_retention.result | 3 + .../02_0000_function_aggregate_retention.sql | 13 + 6 files changed, 305 insertions(+) create mode 100644 common/functions/src/aggregates/aggregate_retention.rs create mode 100644 docs/doc/30-reference/20-functions/50-aggregate-functions/aggregate-retention.md create mode 100644 tests/suites/0_stateless/02_function/02_0000_function_aggregate_retention.result create mode 100644 tests/suites/0_stateless/02_function/02_0000_function_aggregate_retention.sql diff --git a/common/functions/src/aggregates/aggregate_retention.rs b/common/functions/src/aggregates/aggregate_retention.rs new file mode 100644 index 0000000000000..4b8a814d0c0aa --- /dev/null +++ b/common/functions/src/aggregates/aggregate_retention.rs @@ -0,0 +1,224 @@ +// Copyright 2022 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+use std::alloc::Layout;
+use std::fmt;
+use std::sync::Arc;
+
+use bytes::BytesMut;
+use common_datavalues::prelude::*;
+use common_exception::ErrorCode;
+use common_exception::Result;
+use common_io::prelude::*;
+use serde_json::json;
+use serde_json::Value as JsonValue;
+
+use super::aggregate_function::AggregateFunction;
+use super::aggregate_function::AggregateFunctionRef;
+use super::aggregate_function_factory::AggregateFunctionDescription;
+use super::StateAddr;
+use crate::aggregates::aggregator_common::assert_variadic_arguments;
+
+/// Per-group state of `retention`: a bitmask in which bit `i` is set once
+/// condition `i` has been true for at least one accumulated row.
+struct AggregateRetentionState {
+    pub events: u32,
+}
+
+impl AggregateRetentionState {
+    /// Sets the bit that belongs to condition `event`.
+    ///
+    /// `event` must be below 32, the bit width of `events`.
+    #[inline(always)]
+    fn add(&mut self, event: u8) {
+        self.events |= 1 << event;
+    }
+
+    /// ORs the bits of `other` into this state.
+    fn merge(&mut self, other: &Self) {
+        self.events |= other.events;
+    }
+
+    /// Writes the bitmask into `writer`.
+    fn serialize(&self, writer: &mut BytesMut) -> Result<()> {
+        serialize_into_buf(writer, &self.events)
+    }
+
+    /// Replaces the bitmask with the value decoded from `reader`.
+    fn deserialize(&mut self, reader: &mut &[u8]) -> Result<()> {
+        self.events = deserialize_from_slice(reader)?;
+        Ok(())
+    }
+}
+
+/// Aggregate function `retention(cond1, cond2, ...)`.
+#[derive(Clone)]
+pub struct AggregateRetentionFunction {
+    display_name: String,
+    // Number of condition arguments; condition `i` owns bit `i` of the state.
+    events_size: u8,
+    _arguments: Vec<DataField>,
+}
+
+/// Downcasts every argument column to a `BooleanColumn`.
+///
+/// A failed downcast is returned to the caller as an error instead of
+/// panicking in the middle of an aggregation.
+fn boolean_columns(columns: &[common_datavalues::ColumnRef]) -> Result<Vec<&BooleanColumn>> {
+    let mut casted = Vec::with_capacity(columns.len());
+    for column in columns {
+        casted.push(Series::check_get(column)?);
+    }
+    Ok(casted)
+}
+
+impl AggregateFunction for AggregateRetentionFunction {
+    /// Fixed internal name of this aggregate function.
+    fn name(&self) -> &str {
+        "AggregateRetentionFunction"
+    }
+
+    /// The result type is the data type of `JsonValue`.
+    fn return_type(&self) -> Result<DataTypePtr> {
+        Ok(JsonValue::to_data_type())
+    }
+
+    /// Writes a state with no condition bits set at `place`.
+    fn init_state(&self, place: StateAddr) {
+        place.write(|| AggregateRetentionState { events: 0 });
+    }
+
+    /// Memory layout of one `AggregateRetentionState`.
+    fn state_layout(&self) -> std::alloc::Layout {
+        Layout::new::<AggregateRetentionState>()
+    }
+
+    /// Folds `input_rows` rows into the single state at `place`: bit `j` is
+    /// set as soon as condition column `j` is true in any row.
+    ///
+    /// `_validity` is not consulted.
+    fn accumulate(
+        &self,
+        place: StateAddr,
+        columns: &[common_datavalues::ColumnRef],
+        _validity: Option<&common_arrow::arrow::bitmap::Bitmap>,
+        input_rows: usize,
+    ) -> Result<()> {
+        let state = place.get::<AggregateRetentionState>();
+        let new_columns = boolean_columns(columns)?;
+        for i in 0..input_rows {
+            for j in 0..self.events_size {
+                if new_columns[j as usize].get_data(i) {
+                    state.add(j);
+                }
+            }
+        }
+        Ok(())
+    }
+
+    /// Folds the single row `row` into the state at `place`.
+    fn accumulate_row(
+        &self,
+        place: StateAddr,
+        columns: &[common_datavalues::ColumnRef],
+        row: usize,
+    ) -> Result<()> {
+        let state = place.get::<AggregateRetentionState>();
+        let new_columns = boolean_columns(columns)?;
+        for j in 0..self.events_size {
+            if new_columns[j as usize].get_data(row) {
+                state.add(j);
+            }
+        }
+        Ok(())
+    }
+
+    /// Serializes the state at `place` into `writer`.
+    fn serialize(&self, place: StateAddr, writer: &mut BytesMut) -> Result<()> {
+        let state = place.get::<AggregateRetentionState>();
+        state.serialize(writer)
+    }
+
+    /// Overwrites the state at `place` with the one decoded from `reader`.
+    fn deserialize(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> {
+        let state = place.get::<AggregateRetentionState>();
+        state.deserialize(reader)
+    }
+
+    /// Merges the state at `rhs` into the state at `place`.
+    fn merge(&self, place: StateAddr, rhs: StateAddr) -> Result<()> {
+        let rhs = rhs.get::<AggregateRetentionState>();
+        let state = place.get::<AggregateRetentionState>();
+        state.merge(rhs);
+        Ok(())
+    }
+
+    /// Appends the final result for the state at `place` to `array` as a JSON
+    /// array holding one `0`/`1` entry per condition.
+    fn merge_result(
+        &self,
+        place: StateAddr,
+        array: &mut dyn common_datavalues::MutableColumn,
+    ) -> Result<()> {
+        let state = place.get::<AggregateRetentionState>();
+        let builder: &mut MutableObjectColumn<JsonValue> = Series::check_get_mutable_column(array)?;
+        let mut vec: Vec<u8> = vec![0; self.events_size as usize];
+        // A later condition only counts when the first condition was met;
+        // otherwise the whole result stays zero.
+        if state.events & 1 == 1 {
+            vec[0] = 1;
+            for i in 1..self.events_size {
+                if state.events & (1 << i) != 0 {
+                    vec[i as usize] = 1;
+                }
+            }
+        }
+        builder.append_value(json!(vec));
+        Ok(())
+    }
+
+    /// Folds row `row` into the state found at `places[row]` advanced by
+    /// `offset`, for every entry of `places`.
+    fn accumulate_keys(
+        &self,
+        places: &[StateAddr],
+        offset: usize,
+        columns: &[common_datavalues::ColumnRef],
+        _input_rows: usize,
+    ) -> Result<()> {
+        let new_columns = boolean_columns(columns)?;
+        for (row, place) in places.iter().enumerate() {
+            let place = place.next(offset);
+            let state = place.get::<AggregateRetentionState>();
+            for j in 0..self.events_size {
+                if new_columns[j as usize].get_data(row) {
+                    state.add(j);
+                }
+            }
+        }
+        Ok(())
+    }
+}
+
+impl fmt::Display for AggregateRetentionFunction {
+    /// Prints the display name the function was created with.
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "{}", self.display_name)
+    }
+}
+
+impl AggregateRetentionFunction {
+    /// Builds the function for the given condition arguments.
+    ///
+    /// The argument count is stored as a `u8` and used as a bit index into a
+    /// `u32` state, so callers must keep it within 32; the factory function
+    /// below checks the count against `(1, 32)` before calling this.
+    pub fn try_create(
+        display_name: &str,
+        arguments: Vec<DataField>,
+    ) -> Result<AggregateFunctionRef> {
+        Ok(Arc::new(Self {
+            display_name: display_name.to_owned(),
+            events_size: arguments.len() as u8,
+            _arguments: arguments,
+        }))
+    }
+}
+
+/// Factory entry point for `retention`.
+///
+/// # Errors
+///
+/// Fails when the argument count is rejected by `assert_variadic_arguments`
+/// for the bounds `(1, 32)`, or with `ErrorCode::BadArguments` when any
+/// argument is not of type Boolean.
+pub fn try_create_aggregate_retention_function(
+    display_name: &str,
+    _params: Vec<DataValue>,
+    arguments: Vec<DataField>,
+) -> Result<AggregateFunctionRef> {
+    assert_variadic_arguments(display_name, arguments.len(), (1, 32))?;
+
+    for argument in arguments.iter() {
+        let data_type = argument.data_type();
+        if data_type.data_type_id() != TypeID::Boolean {
+            return Err(ErrorCode::BadArguments(
+                "The arguments of AggregateRetention should be an expression which returns a Boolean result"
+            ));
+        }
+    }
+
+    AggregateRetentionFunction::try_create(display_name, arguments)
+}
+
+/// Description used to register `retention` with the aggregate function factory.
+pub fn aggregate_retention_function_desc() -> AggregateFunctionDescription {
+    AggregateFunctionDescription::creator(Box::new(try_create_aggregate_retention_function))
+}
diff --git a/common/functions/src/aggregates/aggregator.rs b/common/functions/src/aggregates/aggregator.rs index 66dec7f2cc3bb..ed3d077145bfb 100644 --- a/common/functions/src/aggregates/aggregator.rs +++ b/common/functions/src/aggregates/aggregator.rs @@ -25,6 +25,7 @@ use super::aggregate_window_funnel::aggregate_window_funnel_function_desc; use super::AggregateCountFunction; use super::AggregateFunctionFactory; use super::AggregateIfCombinator; +use crate::aggregates::aggregate_retention::aggregate_retention_function_desc; use crate::aggregates::aggregate_sum::aggregate_sum_function_desc; pub struct Aggregators; @@ -50,6 +51,8 @@ impl Aggregators { factory.register("window_funnel", aggregate_window_funnel_function_desc()); factory.register("uniq", AggregateDistinctCombinator::uniq_desc()); + + factory.register("retention", aggregate_retention_function_desc()); } pub fn register_combinator(factory: &mut AggregateFunctionFactory) { diff 
--git a/common/functions/src/aggregates/mod.rs b/common/functions/src/aggregates/mod.rs index faff0f9480cb7..98b7df67872d2 100644 --- a/common/functions/src/aggregates/mod.rs +++ b/common/functions/src/aggregates/mod.rs @@ -34,6 +34,7 @@ mod aggregate_combinator_if; mod aggregate_covariance; mod aggregate_min_max; mod aggregate_null_result; +mod aggregate_retention; mod aggregate_scalar_state; mod aggregate_stddev_pop; mod aggregate_window_funnel; @@ -53,6 +54,7 @@ pub use aggregate_function_state::StateAddr; pub use aggregate_function_state::StateAddrs; pub use aggregate_min_max::AggregateMinMaxFunction; pub use aggregate_null_result::AggregateNullResultFunction; +pub use aggregate_retention::AggregateRetentionFunction; pub use aggregate_stddev_pop::AggregateStddevPopFunction; pub use aggregate_sum::AggregateSumFunction; pub use aggregate_window_funnel::AggregateWindowFunnelFunction; diff --git a/docs/doc/30-reference/20-functions/50-aggregate-functions/aggregate-retention.md b/docs/doc/30-reference/20-functions/50-aggregate-functions/aggregate-retention.md new file mode 100644 index 0000000000000..1b83952a1b785 --- /dev/null +++ b/docs/doc/30-reference/20-functions/50-aggregate-functions/aggregate-retention.md @@ -0,0 +1,60 @@ +--- +title: RETENTION +--- + +Aggregate function + +The RETENTION() function takes from 1 to 32 conditions as arguments, each an expression of Boolean type that indicates whether a certain condition was met for the event. + +Any condition can be specified as an argument (as in WHERE). + +The conditions, except the first, apply in pairs: the result of the second will be true if the first and second are true, the result of the third will be true if the first and third are true, and so on. + +## Syntax + +``` +RETENTION(cond1, cond2, ..., cond32); +``` + +## Arguments + +| Arguments | Description | +| ----------- | ----------- | +| cond | An expression that returns a Boolean result | + +## Return Type + +An array of 1s and 0s, with one element per condition. 
+ +## Examples + +``` +CREATE TABLE retention_test(date DATE, uid INT) ENGINE = Memory; +INSERT INTO retention_test SELECT '2018-08-06', number FROM numbers(80); +INSERT INTO retention_test SELECT '2018-08-07', number FROM numbers(50); +INSERT INTO retention_test SELECT '2018-08-08', number FROM numbers(60); +``` + +``` +SELECT sum(get(r, 0)::TINYINT) as r1, sum(get(r, 1)::TINYINT) as r2 FROM (SELECT uid, retention(date = '2018-08-06', date = '2018-08-07') AS r FROM retention_test WHERE date = '2018-08-06' or date = '2018-08-07' GROUP BY uid); ++------+------+ +| r1 | r2 | ++------+------+ +| 80 | 50 | ++------+------+ + +SELECT sum(get(r, 0)::TINYINT) as r1, sum(get(r, 1)::TINYINT) as r2 FROM (SELECT uid, retention(date = '2018-08-06', date = '2018-08-08') AS r FROM retention_test WHERE date = '2018-08-06' or date = '2018-08-08' GROUP BY uid); ++------+------+ +| r1 | r2 | ++------+------+ +| 80 | 60 | ++------+------+ + +SELECT sum(get(r, 0)::TINYINT) as r1, sum(get(r, 1)::TINYINT) as r2, sum(get(r, 2)::TINYINT) as r3 FROM (SELECT uid, retention(date = '2018-08-06', date = '2018-08-07', date = '2018-08-08') AS r FROM retention_test GROUP BY uid); ++------+------+------+ +| r1 | r2 | r3 | ++------+------+------+ +| 80 | 50 | 60 | ++------+------+------+ + +``` \ No newline at end of file diff --git a/tests/suites/0_stateless/02_function/02_0000_function_aggregate_retention.result b/tests/suites/0_stateless/02_function/02_0000_function_aggregate_retention.result new file mode 100644 index 0000000000000..e0dce2baba7f5 --- /dev/null +++ b/tests/suites/0_stateless/02_function/02_0000_function_aggregate_retention.result @@ -0,0 +1,3 @@ +80 50 +80 60 +80 50 60 diff --git a/tests/suites/0_stateless/02_function/02_0000_function_aggregate_retention.sql b/tests/suites/0_stateless/02_function/02_0000_function_aggregate_retention.sql new file mode 100644 index 0000000000000..d495620fabadb --- /dev/null +++ 
b/tests/suites/0_stateless/02_function/02_0000_function_aggregate_retention.sql @@ -0,0 +1,13 @@ +DROP TABLE IF EXISTS retention_test; + +CREATE TABLE retention_test(date DATE, uid INT)ENGINE = Memory; + +INSERT INTO retention_test SELECT '2018-08-06', number FROM numbers(80); +INSERT INTO retention_test SELECT '2018-08-07', number FROM numbers(50); +INSERT INTO retention_test SELECT '2018-08-08', number FROM numbers(60); + +SELECT sum(r[0]::TINYINT) as r1, sum(r[1]::TINYINT) as r2 FROM (SELECT uid, retention(date = '2018-08-06', date = '2018-08-07') AS r FROM retention_test WHERE date = '2018-08-06' or date = '2018-08-07' GROUP BY uid); +SELECT sum(r[0]::TINYINT) as r1, sum(r[1]::TINYINT) as r2 FROM (SELECT uid, retention(date = '2018-08-06', date = '2018-08-08') AS r FROM retention_test WHERE date = '2018-08-06' or date = '2018-08-08' GROUP BY uid); +SELECT sum(r[0]::TINYINT) as r1, sum(r[1]::TINYINT) as r2, sum(r[2]::TINYINT) as r3 FROM (SELECT uid, retention(date = '2018-08-06', date = '2018-08-07', date = '2018-08-08') AS r FROM retention_test GROUP BY uid); + +DROP TABLE retention_test;