Skip to content

Commit

Permalink
feat: add support for TRIM() (#211)
Browse files Browse the repository at this point in the history
* feat: add support for `TRIM()`

* test(trim): improves test cases
  • Loading branch information
tkzt authored Jul 21, 2024
1 parent 5fcf1eb commit 0d3e732
Show file tree
Hide file tree
Showing 7 changed files with 231 additions and 47 deletions.
21 changes: 21 additions & 0 deletions src/binder/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,16 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
self.visit_column_agg_expr(expr)?;
self.visit_column_agg_expr(in_expr)?;
}
// TRIM: visit the trimmed operand first, then the trim-characters expression
// when one is present. An error from either visit is propagated with `?`;
// `trim_where` is not needed here and is ignored via `..`.
ScalarExpression::Trim {
expr,
trim_what_expr,
..
} => {
self.visit_column_agg_expr(expr)?;
if let Some(trim_what_expr) = trim_what_expr {
self.visit_column_agg_expr(trim_what_expr)?;
}
}
ScalarExpression::Constant(_) | ScalarExpression::ColumnRef { .. } => (),
ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(),
ScalarExpression::Tuple(args)
Expand Down Expand Up @@ -365,6 +375,17 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
self.validate_having_orderby(in_expr)?;
Ok(())
}
// TRIM: validate the trimmed operand and, when present, the trim-characters
// expression. The first failing validation is returned to the caller via `?`;
// otherwise the arm yields `Ok(())`.
ScalarExpression::Trim {
expr,
trim_what_expr,
..
} => {
self.validate_having_orderby(expr)?;
if let Some(trim_what_expr) = trim_what_expr {
self.validate_having_orderby(trim_what_expr)?;
}
Ok(())
}
ScalarExpression::Constant(_) => Ok(()),
ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(),
ScalarExpression::Tuple(args)
Expand Down
15 changes: 15 additions & 0 deletions src/binder/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,21 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
expr: Box::new(self.bind_expr(expr)?),
in_expr: Box::new(self.bind_expr(r#in)?),
}),
// TRIM([BOTH | LEADING | TRAILING] [trim_what FROM] expr)
Expr::Trim {
    expr,
    trim_what,
    trim_where,
} => {
    // The optional trim-characters expression is bound before the operand;
    // a binding error in either one aborts the whole arm.
    let trim_what_expr = match trim_what {
        Some(what) => Some(Box::new(self.bind_expr(what)?)),
        None => None,
    };
    let bound_operand = Box::new(self.bind_expr(expr)?);

    Ok(ScalarExpression::Trim {
        expr: bound_operand,
        trim_what_expr,
        // Copied through unchanged from the parsed expression.
        trim_where: *trim_where,
    })
}
Expr::Subquery(subquery) => {
let (sub_query, column) = self.bind_subquery(subquery)?;
let (expr, sub_query) = if !self.context.is_step(&QueryBindStep::Where) {
Expand Down
49 changes: 48 additions & 1 deletion src/expression/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ use crate::types::value::{DataValue, Utf8Type, ValueRef};
use crate::types::LogicalType;
use itertools::Itertools;
use lazy_static::lazy_static;
use sqlparser::ast::CharLengthUnits;
use regex::Regex;
use sqlparser::ast::{CharLengthUnits, TrimWhereField};
use std::cmp;
use std::cmp::Ordering;
use std::sync::Arc;
Expand Down Expand Up @@ -224,6 +225,52 @@ impl ScalarExpression {
str.find(&pattern).map(|pos| pos as i32 + 1).unwrap_or(0),
))))
}
// TRIM([BOTH | LEADING | TRAILING] [trim_what FROM] expr)
//
// Both operands are cast to VARCHAR before trimming. Every repeated occurrence
// of the whole `trim_what` string is stripped from the requested side(s); it is
// treated as one substring, not as a set of characters.
//
// * NULL operand      -> NULL result.
// * no `trim_what`    -> a single space is stripped.
// * no `trim_where`   -> behaves like BOTH.
//
// The stripping is done with `str::trim_start_matches` / `str::trim_end_matches`
// so that no regular expression has to be built and compiled for every row.
ScalarExpression::Trim {
    expr,
    trim_what_expr,
    trim_where,
} => {
    let varchar = LogicalType::Varchar(None, CharLengthUnits::Characters);
    let source = DataValue::clone(expr.eval(tuple, schema)?.as_ref())
        .cast(&varchar)?
        .utf8();

    let value = match source {
        Some(string) => {
            // The trim-characters expression is only evaluated once the operand
            // is known to be non-NULL.
            let trim_what = match trim_what_expr {
                Some(trim_what_expr) => {
                    // NOTE(review): a NULL `trim_what` falls back to the empty string,
                    // i.e. nothing is trimmed. Some SQL engines return NULL here
                    // instead; confirm which behaviour is intended.
                    DataValue::clone(trim_what_expr.eval(tuple, schema)?.as_ref())
                        .cast(&varchar)?
                        .utf8()
                        .unwrap_or_default()
                }
                None => String::from(" "),
            };
            let pattern = trim_what.as_str();

            let trimmed = if pattern.is_empty() {
                // Nothing to strip; the string is returned as is.
                string.as_str()
            } else {
                match trim_where {
                    Some(TrimWhereField::Both) | None => string
                        .trim_start_matches(pattern)
                        .trim_end_matches(pattern),
                    Some(TrimWhereField::Leading) => string.trim_start_matches(pattern),
                    Some(TrimWhereField::Trailing) => string.trim_end_matches(pattern),
                }
            };

            Some(trimmed.to_string())
        }
        None => None,
    };

    Ok(Arc::new(DataValue::Utf8 {
        value,
        ty: Utf8Type::Variable(None),
        unit: CharLengthUnits::Characters,
    }))
}
ScalarExpression::Reference { pos, .. } => {
return Ok(tuple
.values
Expand Down
80 changes: 80 additions & 0 deletions src/expression/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use sqlparser::ast::TrimWhereField;
use std::fmt::{Debug, Formatter};
use std::hash::Hash;
use std::sync::Arc;
Expand Down Expand Up @@ -88,6 +89,11 @@ pub enum ScalarExpression {
expr: Box<ScalarExpression>,
in_expr: Box<ScalarExpression>,
},
// TRIM([BOTH | LEADING | TRAILING] [trim_what FROM] expr)
Trim {
// Expression whose string value is trimmed.
expr: Box<ScalarExpression>,
// Optional expression giving the text to strip; the evaluator strips a
// single space when this is `None`.
trim_what_expr: Option<Box<ScalarExpression>>,
// Side(s) to trim; the evaluator treats `None` the same as `Both`.
trim_where: Option<TrimWhereField>,
},
// Temporary expression used for expression substitution
Empty,
Reference {
Expand Down Expand Up @@ -226,6 +232,16 @@ impl ScalarExpression {
expr.try_reference(output_exprs);
in_expr.try_reference(output_exprs);
}
// TRIM: apply `try_reference` to the trimmed operand and, when present,
// to the trim-characters expression. `trim_where` holds no sub-expression.
ScalarExpression::Trim {
expr,
trim_what_expr,
..
} => {
expr.try_reference(output_exprs);
if let Some(trim_what_expr) = trim_what_expr {
trim_what_expr.try_reference(output_exprs);
}
}
ScalarExpression::Empty => unreachable!(),
ScalarExpression::Constant(_)
| ScalarExpression::ColumnRef(_)
Expand Down Expand Up @@ -379,6 +395,16 @@ impl ScalarExpression {
expr.bind_evaluator()?;
in_expr.bind_evaluator()?;
}
// TRIM: bind evaluators in the trimmed operand and, when present, in the
// trim-characters expression; the first error is propagated with `?`.
ScalarExpression::Trim {
expr,
trim_what_expr,
..
} => {
expr.bind_evaluator()?;
if let Some(trim_what_expr) = trim_what_expr {
trim_what_expr.bind_evaluator()?;
}
}
ScalarExpression::Empty => unreachable!(),
ScalarExpression::Constant(_)
| ScalarExpression::ColumnRef(_)
Expand Down Expand Up @@ -477,6 +503,14 @@ impl ScalarExpression {
ScalarExpression::Position { expr, in_expr } => {
expr.has_count_star() || in_expr.has_count_star()
}
// TRIM contains `COUNT(*)` when either the trimmed operand or the optional
// trim-characters expression does; a missing `trim_what_expr` counts as "no".
ScalarExpression::Trim {
    expr,
    trim_what_expr,
    ..
} => {
    expr.has_count_star()
        || trim_what_expr
            .as_ref()
            .map_or(false, |what| what.has_count_star())
}
ScalarExpression::Empty => unreachable!(),
ScalarExpression::Reference { expr, .. } => expr.has_count_star(),
ScalarExpression::Tuple(args) => args.iter().any(Self::has_count_star),
Expand Down Expand Up @@ -558,6 +592,9 @@ impl ScalarExpression {
LogicalType::Varchar(None, CharLengthUnits::Characters)
}
ScalarExpression::Position { .. } => LogicalType::Integer,
// TRIM always reports a VARCHAR (first argument `None`, character units),
// independent of the operand's own type.
ScalarExpression::Trim { .. } => {
LogicalType::Varchar(None, CharLengthUnits::Characters)
}
ScalarExpression::Alias { expr, .. } | ScalarExpression::Reference { expr, .. } => {
expr.return_type()
}
Expand Down Expand Up @@ -638,6 +675,16 @@ impl ScalarExpression {
columns_collect(expr, vec, only_column_ref);
columns_collect(in_expr, vec, only_column_ref);
}
// TRIM: collect columns from the trimmed operand and, when present, from the
// trim-characters expression, passing `vec` and `only_column_ref` through.
ScalarExpression::Trim {
expr,
trim_what_expr,
..
} => {
columns_collect(expr, vec, only_column_ref);
if let Some(trim_what_expr) = trim_what_expr {
columns_collect(trim_what_expr, vec, only_column_ref);
}
}
ScalarExpression::Constant(_) => (),
ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(),
ScalarExpression::If {
Expand Down Expand Up @@ -730,6 +777,14 @@ impl ScalarExpression {
ScalarExpression::Position { expr, in_expr } => {
expr.has_agg_call() || in_expr.has_agg_call()
}
// TRIM contains an aggregate call when either the trimmed operand or the
// optional trim-characters expression does; a missing `trim_what_expr`
// counts as "no".
ScalarExpression::Trim {
    expr,
    trim_what_expr,
    ..
} => {
    expr.has_agg_call()
        || trim_what_expr
            .as_ref()
            .map_or(false, |what| what.has_agg_call())
}
ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(),
ScalarExpression::Tuple(args)
| ScalarExpression::Function(ScalarFunction { args, .. })
Expand Down Expand Up @@ -871,6 +926,31 @@ impl ScalarExpression {
in_expr.output_name()
)
}
// Renders TRIM as `trim(<qualifier> <operand>)`, where the qualifier is
// `[both|leading|trailing] '<what>' from`.
ScalarExpression::Trim {
    expr,
    trim_what_expr,
    trim_where,
} => {
    // Text being stripped as it is displayed; a single space when no
    // trim-characters expression was given.
    let what = match trim_what_expr {
        Some(what_expr) => what_expr.output_name(),
        None => " ".to_string(),
    };
    let qualifier = match trim_where {
        Some(TrimWhereField::Both) => format!("both '{}' from", what),
        Some(TrimWhereField::Leading) => format!("leading '{}' from", what),
        Some(TrimWhereField::Trailing) => format!("trailing '{}' from", what),
        // Without a side keyword, an empty `what` drops the qualifier entirely.
        None if what.is_empty() => String::new(),
        None => format!("'{}' from", what),
    };
    format!("trim({} {})", qualifier, expr.output_name())
}
ScalarExpression::Reference { expr, .. } => expr.output_name(),
ScalarExpression::Empty => unreachable!(),
ScalarExpression::Tuple(args) => {
Expand Down
2 changes: 2 additions & 0 deletions src/expression/range_detacher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ impl<'a> RangeDetacher<'a> {
| ScalarExpression::Between { expr, .. }
| ScalarExpression::SubString { expr, .. } => self.detach(expr),
ScalarExpression::Position { expr, .. } => self.detach(expr),
// TRIM: detach using the trimmed operand only; the trim-characters
// expression and `trim_where` are not inspected here.
ScalarExpression::Trim { expr, .. } => self.detach(expr),
ScalarExpression::IsNull { expr, negated, .. } => match expr.as_ref() {
ScalarExpression::ColumnRef(column) => {
if let (Some(col_id), Some(col_table)) = (column.id(), column.table_name()) {
Expand All @@ -245,6 +246,7 @@ impl<'a> RangeDetacher<'a> {
| ScalarExpression::Between { .. }
| ScalarExpression::SubString { .. }
| ScalarExpression::Position { .. }
| ScalarExpression::Trim { .. }
| ScalarExpression::Function(_)
| ScalarExpression::If { .. }
| ScalarExpression::IfNull { .. }
Expand Down
21 changes: 21 additions & 0 deletions src/expression/simplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,17 @@ impl ScalarExpression {
ScalarExpression::Position { expr, in_expr } => {
expr.exist_column(table_name, col_id) || in_expr.exist_column(table_name, col_id)
}
// TRIM references the column when either the trimmed operand or the optional
// trim-characters expression does; a missing `trim_what_expr` counts as "no".
ScalarExpression::Trim {
    expr,
    trim_what_expr,
    ..
} => {
    expr.exist_column(table_name, col_id)
        || trim_what_expr
            .as_ref()
            .map_or(false, |what| what.exist_column(table_name, col_id))
}
ScalarExpression::Constant(_) => false,
ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(),
ScalarExpression::If {
Expand Down Expand Up @@ -316,6 +327,16 @@ impl ScalarExpression {
expr.constant_calculation()?;
in_expr.constant_calculation()?;
}
// TRIM: run constant calculation on the trimmed operand and, when present,
// on the trim-characters expression; the first error is propagated with `?`.
ScalarExpression::Trim {
expr,
trim_what_expr,
..
} => {
expr.constant_calculation()?;
if let Some(trim_what_expr) = trim_what_expr {
trim_what_expr.constant_calculation()?;
}
}
ScalarExpression::Tuple(exprs) | ScalarExpression::Coalesce { exprs, .. } => {
for expr in exprs {
expr.constant_calculation()?;
Expand Down
Loading

0 comments on commit 0d3e732

Please sign in to comment.