Skip to content

Commit

Permalink
fix(stream): handle 0.0 case of approx percentile, and fix computation of `quantile_count` (#18546)
Browse files Browse the repository at this point in the history
  • Loading branch information
kwannoel authored Sep 16, 2024
1 parent 9413b28 commit 9167768
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 4 deletions.
10 changes: 9 additions & 1 deletion src/expr/impl/src/aggregate/approx_percentile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use std::ops::Range;

use bytes::{Buf, Bytes};
use risingwave_common::array::*;
use risingwave_common::bail;
use risingwave_common::row::Row;
use risingwave_common::types::*;
use risingwave_common_estimate_size::EstimateSize;
Expand All @@ -38,6 +39,12 @@ fn build(agg: &AggCall) -> Result<Box<dyn AggregateFunction>> {
.literal()
.map(|x| (*x.as_float64()).into())
.unwrap();
if relative_error <= 0.0 || relative_error >= 1.0 {
bail!(
"relative_error must be in the range (0, 1), got {}",
relative_error
)
}
let base = (1.0 + relative_error) / (1.0 - relative_error);
Ok(Box::new(ApproxPercentile { quantile, base }))
}
Expand Down Expand Up @@ -156,7 +163,8 @@ impl AggregateFunction for ApproxPercentile {
// approximate quantile bucket on the fly.
async fn get_result(&self, state: &AggregateState) -> Result<Datum> {
let state = state.downcast_ref::<State>();
let quantile_count = (state.count as f64 * self.quantile).floor() as u64;
let quantile_count =
((state.count.saturating_sub(1)) as f64 * self.quantile).floor() as u64;
let mut acc_count = 0;
for (bucket_id, count) in state.neg_buckets.iter().rev() {
acc_count += count;
Expand Down
20 changes: 19 additions & 1 deletion src/frontend/planner_test/tests/testdata/input/agg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1100,4 +1100,22 @@
expected_outputs:
- logical_plan
- batch_plan
- stream_plan
- stream_plan
- name: test approx percentile with invalid relative_error
sql: |
CREATE TABLE t (v1 int);
SELECT approx_percentile(0.5, 0.0) WITHIN GROUP (order by v1) from t;
expected_outputs:
- binder_error
- name: test approx percentile with invalid relative_error 0.0
sql: |
CREATE TABLE t (v1 int);
SELECT approx_percentile(0.5, 0.0) WITHIN GROUP (order by v1) from t;
expected_outputs:
- binder_error
- name: test approx percentile with invalid relative_error 1.0 with group by.
sql: |
CREATE TABLE t (v1 int, v2 int);
SELECT approx_percentile(0.0, 1.0) WITHIN GROUP (order by v1) from t group by v2;
expected_outputs:
- binder_error
27 changes: 27 additions & 0 deletions src/frontend/planner_test/tests/testdata/output/agg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2198,3 +2198,30 @@
└─StreamLocalApproxPercentile { percentile_col: $expr1, quantile: 0.5:Float64, relative_error: 0.01:Float64 }
└─StreamProject { exprs: [t.v1::Float64 as $expr1, t._row_id] }
└─StreamTableScan { table: t, columns: [t.v1, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) }
- name: test approx percentile with invalid relative_error
sql: |
CREATE TABLE t (v1 int);
SELECT approx_percentile(0.5, 0.0) WITHIN GROUP (order by v1) from t;
binder_error: |
Failed to bind expression: approx_percentile(0.5, 0.0) WITHIN GROUP (ORDER BY v1)
Caused by:
relative_error=0 does not satisfy 0.0 < relative_error < 1.0
- name: test approx percentile with invalid relative_error 0.0
sql: |
CREATE TABLE t (v1 int);
SELECT approx_percentile(0.5, 0.0) WITHIN GROUP (order by v1) from t;
binder_error: |
Failed to bind expression: approx_percentile(0.5, 0.0) WITHIN GROUP (ORDER BY v1)
Caused by:
relative_error=0 does not satisfy 0.0 < relative_error < 1.0
- name: test approx percentile with invalid relative_error 1.0 with group by.
sql: |
CREATE TABLE t (v1 int, v2 int);
SELECT approx_percentile(0.0, 1.0) WITHIN GROUP (order by v1) from t group by v2;
binder_error: |
Failed to bind expression: approx_percentile(0.0, 1.0) WITHIN GROUP (ORDER BY v1)
Caused by:
relative_error=1 does not satisfy 0.0 < relative_error < 1.0
13 changes: 12 additions & 1 deletion src/frontend/src/binder/expr/function/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
// limitations under the License.

use itertools::Itertools;
use risingwave_common::bail_not_implemented;
use risingwave_common::types::{DataType, ScalarImpl};
use risingwave_common::{bail, bail_not_implemented};
use risingwave_expr::aggregate::{agg_kinds, AggKind, PbAggKind};
use risingwave_sqlparser::ast::{self, FunctionArgExpr};

Expand Down Expand Up @@ -158,6 +158,17 @@ impl Binder {
2 => {
let relative_error = &mut direct_args[1];
decimal_to_float64(relative_error, kind)?;
if let Some(relative_error) = relative_error.as_literal()
&& let Some(relative_error) = relative_error.get_data()
{
let relative_error = relative_error.as_float64().0;
if relative_error <= 0.0 || relative_error >= 1.0 {
bail!(
"relative_error={} does not satisfy 0.0 < relative_error < 1.0",
relative_error,
)
}
}
}
1 => {
let relative_error: ExprImpl = Literal::new(
Expand Down
2 changes: 1 addition & 1 deletion src/stream/src/executor/approx_percentile/global_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ impl BucketTableCache {
}

pub fn get_output(&self, row_count: i64, quantile: f64, base: f64) -> Datum {
let quantile_count = (row_count as f64 * quantile).floor() as i64;
let quantile_count = ((row_count - 1) as f64 * quantile).floor() as i64;
let mut acc_count = 0;
for (bucket_id, count) in self.neg_buckets.iter().rev() {
acc_count += count;
Expand Down

0 comments on commit 9167768

Please sign in to comment.