Skip to content

Commit

Permalink
feat(query): add error_or function (#14980)
Browse files Browse the repository at this point in the history
* feat(query): add error_or function

* feat(query): add error_or function

* feat(query): add error_or function

* feat(query): add error_or function
  • Loading branch information
sundy-li authored Mar 17, 2024
1 parent adeec64 commit fdacd61
Show file tree
Hide file tree
Showing 11 changed files with 235 additions and 82 deletions.
180 changes: 116 additions & 64 deletions src/query/expression/src/evaluator.rs

Large diffs are not rendered by default.

25 changes: 19 additions & 6 deletions src/query/expression/src/filter/selector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use crate::filter::SelectExpr;
use crate::filter::SelectOp;
use crate::types::DataType;
use crate::EvalContext;
use crate::EvaluateOptions;
use crate::Evaluator;
use crate::Expr;
use crate::Scalar;
Expand Down Expand Up @@ -282,7 +283,8 @@ impl<'a> Selector<'a> {
*mutable_false_idx + count,
&select_strategy,
);
let children = self.evaluator.get_children(exprs, selection)?;
let mut eval_options = EvaluateOptions::new(selection);
let children = self.evaluator.get_children(exprs, &mut eval_options)?;
let (left_value, left_data_type) = children[0].clone();
let (right_value, right_data_type) = children[1].clone();
let left_data_type = self
Expand Down Expand Up @@ -332,7 +334,11 @@ impl<'a> Selector<'a> {
*mutable_false_idx + count,
&select_strategy,
);
let result = self.evaluator.eval_if(args, generics, None, selection)?;
let mut eval_options = EvaluateOptions::new(selection);

let result = self
.evaluator
.eval_if(args, generics, None, &mut eval_options)?;
let data_type = self
.evaluator
.remove_generics_data_type(generics, &function.signature.return_type);
Expand Down Expand Up @@ -366,9 +372,12 @@ impl<'a> Selector<'a> {
*mutable_false_idx + count,
&select_strategy,
);
let mut eval_options = EvaluateOptions::new(selection)
.with_suppress_error(function.signature.name == "is_not_error");

let args = args
.iter()
.map(|expr| self.evaluator.partial_run(expr, None, selection))
.map(|expr| self.evaluator.partial_run(expr, None, &mut eval_options))
.collect::<Result<Vec<_>>>()?;
assert!(
args.iter()
Expand All @@ -385,6 +394,7 @@ impl<'a> Selector<'a> {
validity: None,
errors: None,
func_ctx: self.evaluator.func_ctx(),
suppress_error: eval_options.suppress_error,
};
let (_, eval) = function.eval.as_scalar().unwrap();
let result = (eval)(cols_ref.as_slice(), &mut ctx);
Expand Down Expand Up @@ -422,7 +432,8 @@ impl<'a> Selector<'a> {
*mutable_false_idx + count,
&select_strategy,
);
let value = self.evaluator.get_select_child(expr, selection)?.0;
let mut eval_options = EvaluateOptions::new(selection);
let value = self.evaluator.get_select_child(expr, &mut eval_options)?.0;
let result = if *is_try {
self.evaluator
.run_try_cast(*span, expr.data_type(), dest_type, value)?
Expand All @@ -433,7 +444,7 @@ impl<'a> Selector<'a> {
dest_type,
value,
None,
selection,
&mut eval_options,
)?
};
self.select_value(
Expand Down Expand Up @@ -461,9 +472,11 @@ impl<'a> Selector<'a> {
*mutable_false_idx + count,
&select_strategy,
);
let mut eval_options = EvaluateOptions::new(selection);

let args = args
.iter()
.map(|expr| self.evaluator.partial_run(expr, None, selection))
.map(|expr| self.evaluator.partial_run(expr, None, &mut eval_options))
.collect::<Result<Vec<_>>>()?;
assert!(
args.iter()
Expand Down
4 changes: 4 additions & 0 deletions src/query/expression/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ pub struct EvalContext<'a> {
/// default value in nullable's inner column.
pub validity: Option<Bitmap>,
pub errors: Option<(MutableBitmap, String)>,
pub suppress_error: bool,
}

/// `FunctionID` is a unique identifier for a function in the registry. It's used to
Expand Down Expand Up @@ -564,6 +565,9 @@ impl<'a> EvalContext<'a> {
func_name: &str,
selection: Option<&[u32]>,
) -> Result<()> {
if self.suppress_error {
return Ok(());
}
match &self.errors {
Some((valids, error)) => {
let first_error_row = if let Some(selection) = selection {
Expand Down
12 changes: 12 additions & 0 deletions src/query/functions/src/scalars/control.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,16 @@ pub fn register(registry: &mut FunctionRegistry) {
ValueRef::Scalar(Some(_)) => Value::Scalar(true),
},
);

registry.register_1_arg_core::<GenericType<0>, BooleanType, _, _>(
"is_not_error",
|_, _| FunctionDomain::Full,
|arg, ctx| match ctx.errors.take() {
Some((bitmap, _)) => match arg {
ValueRef::Column(_) => Value::Column(bitmap.into()),
ValueRef::Scalar(_) => Value::Scalar(bitmap.get(0)),
},
None => Value::Scalar(true),
},
);
}
1 change: 1 addition & 0 deletions src/query/functions/src/scalars/decimal/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ fn convert_to_decimal_domain(
func_ctx,
validity: None,
errors: None,
suppress_error: false,
};
let dest_size = dest_type.size();
let res = convert_to_decimal(&value.as_ref(), &mut ctx, &from_type, dest_type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2182,6 +2182,7 @@ Functions overloads:
1 is_float(Variant NULL) :: Boolean NULL
0 is_integer(Variant) :: Boolean
1 is_integer(Variant NULL) :: Boolean NULL
0 is_not_error(T0) :: Boolean
0 is_not_null(NULL) :: Boolean
1 is_not_null(T0 NULL) :: Boolean
0 is_null_value(Variant) :: Boolean
Expand Down
1 change: 1 addition & 0 deletions src/query/sql/src/evaluator/cse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ pub fn apply_cse(
fn count_expressions(expr: &Expr, counter: &mut HashMap<Expr, usize>) {
match expr {
Expr::FunctionCall { function, .. } if function.signature.name == "if" => {}
Expr::FunctionCall { function, .. } if function.signature.name == "is_not_error" => {}
Expr::FunctionCall { args, .. } | Expr::LambdaFunctionCall { args, .. } => {
let entry = counter.entry(expr.clone()).or_insert(0);
*entry += 1;
Expand Down
76 changes: 68 additions & 8 deletions src/query/sql/src/planner/semantic/type_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2453,6 +2453,8 @@ impl<'a> TypeChecker<'a> {
"nvl",
"nvl2",
"is_null",
"is_error",
"error_or",
"coalesce",
"last_query_id",
"array_sort",
Expand Down Expand Up @@ -2597,6 +2599,59 @@ impl<'a> TypeChecker<'a> {
.await,
)
}
("is_error", &[arg_x]) => {
// Rewrite is_error(x) to not(is_not_error(x))
Some(
self.resolve_unary_op(span, &UnaryOperator::Not, &Expr::FunctionCall {
span,
func: ASTFunctionCall {
distinct: false,
name: Identifier {
name: "is_not_error".to_string(),
quote: None,
span,
},
args: vec![arg_x.clone()],
params: vec![],
window: None,
lambda: None,
},
})
.await,
)
}
("error_or", args) => {
// error_or(arg0, arg1, ..., argN) is essentially
// if(is_not_error(arg0), arg0, is_not_error(arg1), arg1, ..., argN)
let mut new_args = Vec::with_capacity(args.len() * 2 + 1);

for arg in args.iter() {
let is_not_error = Expr::FunctionCall {
span,
func: ASTFunctionCall {
distinct: false,
name: Identifier {
name: "is_not_error".to_string(),
quote: None,
span,
},
args: vec![(*arg).clone()],
params: vec![],
window: None,
lambda: None,
},
};
new_args.push(is_not_error);
new_args.push((*arg).clone());
}
new_args.push(Expr::Literal {
span,
lit: Literal::Null,
});

let args_ref: Vec<&Expr> = new_args.iter().collect();
Some(self.resolve_function(span, "if", vec![], &args_ref).await)
}
("coalesce", args) => {
// coalesce(arg0, arg1, ..., argN) is essentially
// if(is_not_null(arg0), assume_not_null(arg0), is_not_null(arg1), assume_not_null(arg1), ..., argN)
Expand Down Expand Up @@ -2641,14 +2696,19 @@ impl<'a> TypeChecker<'a> {
span,
lit: Literal::Null,
});
new_args.push(Expr::Literal {
span,
lit: Literal::Null,
});
new_args.push(Expr::Literal {
span,
lit: Literal::Null,
});

// coalesce(all_null) => null
if new_args.len() == 1 {
new_args.push(Expr::Literal {
span,
lit: Literal::Null,
});
new_args.push(Expr::Literal {
span,
lit: Literal::Null,
});
}

let args_ref: Vec<&Expr> = new_args.iter().collect();
Some(self.resolve_function(span, "if", vec![], &args_ref).await)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ AggregateFinal
├── estimated rows: 0.00
├── EvalScalar
│ ├── output columns: [t.id (#0), de (#8)]
│ ├── expressions: [if(CAST(is_not_null(sum(tb.de) (#7)) AS Boolean NULL), CAST(assume_not_null(sum(tb.de) (#7)) AS Int64 NULL), true, 0, NULL, NULL, NULL)]
│ ├── expressions: [if(CAST(is_not_null(sum(tb.de) (#7)) AS Boolean NULL), CAST(assume_not_null(sum(tb.de) (#7)) AS Int64 NULL), true, 0, NULL)]
│ ├── estimated rows: 0.00
│ └── AggregateFinal
│ ├── output columns: [sum(tb.de) (#7), t.id (#0)]
Expand Down Expand Up @@ -174,7 +174,7 @@ AggregateFinal
│ │ ├── estimated rows: 0.00
│ │ └── EvalScalar
│ │ ├── output columns: [t2.sid (#1), sum_arg_0 (#4)]
│ │ ├── expressions: [if(CAST(is_not_null(t3.val (#2)) AS Boolean NULL), CAST(assume_not_null(t3.val (#2)) AS Int32 NULL), true, 0, NULL, NULL, NULL)]
│ │ ├── expressions: [if(CAST(is_not_null(t3.val (#2)) AS Boolean NULL), CAST(assume_not_null(t3.val (#2)) AS Int32 NULL), true, 0, NULL)]
│ │ ├── estimated rows: 0.00
│ │ └── TableScan
│ │ ├── table: default.default.t2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ AggregateFinal
├── estimated rows: 0.00
├── EvalScalar
│ ├── output columns: [t.id (#0), de (#8)]
│ ├── expressions: [if(CAST(is_not_null(sum(tb.de) (#7)) AS Boolean NULL), CAST(assume_not_null(sum(tb.de) (#7)) AS Int64 NULL), true, 0, NULL, NULL, NULL)]
│ ├── expressions: [if(CAST(is_not_null(sum(tb.de) (#7)) AS Boolean NULL), CAST(assume_not_null(sum(tb.de) (#7)) AS Int64 NULL), true, 0, NULL)]
│ ├── estimated rows: 0.00
│ └── AggregateFinal
│ ├── output columns: [sum(tb.de) (#7), t.id (#0)]
Expand Down Expand Up @@ -166,7 +166,7 @@ AggregateFinal
│ │ ├── estimated rows: 0.00
│ │ └── EvalScalar
│ │ ├── output columns: [t2.sid (#1), sum_arg_0 (#4)]
│ │ ├── expressions: [if(CAST(is_not_null(t3.val (#2)) AS Boolean NULL), CAST(assume_not_null(t3.val (#2)) AS Int32 NULL), true, 0, NULL, NULL, NULL)]
│ │ ├── expressions: [if(CAST(is_not_null(t3.val (#2)) AS Boolean NULL), CAST(assume_not_null(t3.val (#2)) AS Int32 NULL), true, 0, NULL)]
│ │ ├── estimated rows: 0.00
│ │ └── TableScan
│ │ ├── table: default.default.t2
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
query BBB
select is_error(from_base64('aj')), is_not_error(from_base64('ac')), is_error(3);
----
1 0 0

query T
select error_or(from_base64('aak') , from_base64('aaj'), from_base64('MzQz'));
----
333433

0 comments on commit fdacd61

Please sign in to comment.