Skip to content

Commit

Permalink
Minor: Remove some clone in TypeCoercion (apache#10203)
Browse files Browse the repository at this point in the history
* Remove some clone in TypeCoercion

* Less clone

* less copy
  • Loading branch information
alamb authored Apr 24, 2024
1 parent 4edbdd7 commit deebda7
Showing 1 changed file with 37 additions and 52 deletions.
89 changes: 37 additions & 52 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,26 +171,26 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
))))
}
Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op(
&expr,
*expr,
&self.schema,
)?))),
Expr::IsTrue(expr) => Ok(Transformed::yes(is_true(
get_casted_expr_for_bool_op(&expr, &self.schema)?,
get_casted_expr_for_bool_op(*expr, &self.schema)?,
))),
Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true(
get_casted_expr_for_bool_op(&expr, &self.schema)?,
get_casted_expr_for_bool_op(*expr, &self.schema)?,
))),
Expr::IsFalse(expr) => Ok(Transformed::yes(is_false(
get_casted_expr_for_bool_op(&expr, &self.schema)?,
get_casted_expr_for_bool_op(*expr, &self.schema)?,
))),
Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false(
get_casted_expr_for_bool_op(&expr, &self.schema)?,
get_casted_expr_for_bool_op(*expr, &self.schema)?,
))),
Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown(
get_casted_expr_for_bool_op(&expr, &self.schema)?,
get_casted_expr_for_bool_op(*expr, &self.schema)?,
))),
Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown(
get_casted_expr_for_bool_op(&expr, &self.schema)?,
get_casted_expr_for_bool_op(*expr, &self.schema)?,
))),
Expr::Like(Like {
negated,
Expand Down Expand Up @@ -308,15 +308,12 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def {
ScalarFunctionDefinition::UDF(fun) => {
let new_expr = coerce_arguments_for_signature(
args.as_slice(),
args,
&self.schema,
fun.signature(),
)?;
let new_expr = coerce_arguments_for_fun(
new_expr.as_slice(),
&self.schema,
&fun,
)?;
let new_expr =
coerce_arguments_for_fun(new_expr, &self.schema, &fun)?;
Ok(Transformed::yes(Expr::ScalarFunction(
ScalarFunction::new_udf(fun, new_expr),
)))
Expand All @@ -336,7 +333,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
AggregateFunctionDefinition::BuiltIn(fun) => {
let new_expr = coerce_agg_exprs_for_signature(
&fun,
&args,
args,
&self.schema,
&fun.signature(),
)?;
Expand All @@ -353,7 +350,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
}
AggregateFunctionDefinition::UDF(fun) => {
let new_expr = coerce_arguments_for_signature(
args.as_slice(),
args,
&self.schema,
fun.signature(),
)?;
Expand Down Expand Up @@ -387,7 +384,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
expr::WindowFunctionDefinition::AggregateFunction(fun) => {
coerce_agg_exprs_for_signature(
fun,
&args,
args,
&self.schema,
&fun.signature(),
)?
Expand Down Expand Up @@ -454,12 +451,12 @@ fn coerce_scalar(target_type: &DataType, value: &ScalarValue) -> Result<ScalarVa
/// Downstream code uses this signal to treat these values as *unbounded*.
fn coerce_scalar_range_aware(
target_type: &DataType,
value: &ScalarValue,
value: ScalarValue,
) -> Result<ScalarValue> {
coerce_scalar(target_type, value).or_else(|err| {
coerce_scalar(target_type, &value).or_else(|err| {
// If type coercion fails, check if the largest type in family works:
if let Some(largest_type) = get_widest_type_in_family(target_type) {
coerce_scalar(largest_type, value).map_or_else(
coerce_scalar(largest_type, &value).map_or_else(
|_| exec_err!("Cannot cast {value:?} to {target_type:?}"),
|_| ScalarValue::try_from(target_type),
)
Expand All @@ -484,7 +481,7 @@ fn get_widest_type_in_family(given_type: &DataType) -> Option<&DataType> {
/// Coerces the given (window frame) `bound` to `target_type`.
fn coerce_frame_bound(
target_type: &DataType,
bound: &WindowFrameBound,
bound: WindowFrameBound,
) -> Result<WindowFrameBound> {
match bound {
WindowFrameBound::Preceding(v) => {
Expand Down Expand Up @@ -530,31 +527,30 @@ fn coerce_window_frame(
}
WindowFrameUnits::Rows | WindowFrameUnits::Groups => &DataType::UInt64,
};
window_frame.start_bound =
coerce_frame_bound(target_type, &window_frame.start_bound)?;
window_frame.end_bound = coerce_frame_bound(target_type, &window_frame.end_bound)?;
window_frame.start_bound = coerce_frame_bound(target_type, window_frame.start_bound)?;
window_frame.end_bound = coerce_frame_bound(target_type, window_frame.end_bound)?;
Ok(window_frame)
}

// Support the `IsTrue` `IsNotTrue` `IsFalse` `IsNotFalse` type coercion.
// The above op will be rewrite to the binary op when creating the physical op.
fn get_casted_expr_for_bool_op(expr: &Expr, schema: &DFSchemaRef) -> Result<Expr> {
fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchemaRef) -> Result<Expr> {
let left_type = expr.get_type(schema)?;
get_input_types(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)?;
cast_expr(expr, &DataType::Boolean, schema)
expr.cast_to(&DataType::Boolean, schema)
}

/// Returns `expressions` coerced to types compatible with
/// `signature`, if possible.
///
/// See the module level documentation for more detail on coercion.
fn coerce_arguments_for_signature(
expressions: &[Expr],
expressions: Vec<Expr>,
schema: &DFSchema,
signature: &Signature,
) -> Result<Vec<Expr>> {
if expressions.is_empty() {
return Ok(vec![]);
return Ok(expressions);
}

let current_types = expressions
Expand All @@ -565,58 +561,47 @@ fn coerce_arguments_for_signature(
let new_types = data_types(&current_types, signature)?;

expressions
.iter()
.into_iter()
.enumerate()
.map(|(i, expr)| cast_expr(expr, &new_types[i], schema))
.collect::<Result<Vec<_>>>()
.map(|(i, expr)| expr.cast_to(&new_types[i], schema))
.collect()
}

fn coerce_arguments_for_fun(
expressions: &[Expr],
expressions: Vec<Expr>,
schema: &DFSchema,
fun: &Arc<ScalarUDF>,
) -> Result<Vec<Expr>> {
if expressions.is_empty() {
return Ok(vec![]);
}
let mut expressions: Vec<Expr> = expressions.to_vec();

// Cast Fixedsizelist to List for array functions
if fun.name() == "make_array" {
expressions = expressions
expressions
.into_iter()
.map(|expr| {
let data_type = expr.get_type(schema).unwrap();
if let DataType::FixedSizeList(field, _) = data_type {
let field = field.as_ref().clone();
let to_type = DataType::List(Arc::new(field));
let to_type = DataType::List(field.clone());
expr.cast_to(&to_type, schema)
} else {
Ok(expr)
}
})
.collect::<Result<Vec<_>>>()?;
.collect()
} else {
Ok(expressions)
}

Ok(expressions)
}

/// Cast `expr` to the specified type, if possible
fn cast_expr(expr: &Expr, to_type: &DataType, schema: &DFSchema) -> Result<Expr> {
expr.clone().cast_to(to_type, schema)
}

/// Returns the coerced exprs for each `input_exprs`.
/// Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the
/// data type of `input_exprs` need to be coerced.
fn coerce_agg_exprs_for_signature(
agg_fun: &AggregateFunction,
input_exprs: &[Expr],
input_exprs: Vec<Expr>,
schema: &DFSchema,
signature: &Signature,
) -> Result<Vec<Expr>> {
if input_exprs.is_empty() {
return Ok(vec![]);
return Ok(input_exprs);
}
let current_types = input_exprs
.iter()
Expand All @@ -627,10 +612,10 @@ fn coerce_agg_exprs_for_signature(
type_coercion::aggregates::coerce_types(agg_fun, &current_types, signature)?;

input_exprs
.iter()
.into_iter()
.enumerate()
.map(|(i, expr)| cast_expr(expr, &coerced_types[i], schema))
.collect::<Result<Vec<_>>>()
.map(|(i, expr)| expr.cast_to(&coerced_types[i], schema))
.collect()
}

fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result<Case> {
Expand Down

0 comments on commit deebda7

Please sign in to comment.