Skip to content

Commit

Permalink
simplify the code by removing the i32/i64 specializaton via traits an…
Browse files Browse the repository at this point in the history
…d just implement two versions of the function
  • Loading branch information
andygrove committed Apr 27, 2024
1 parent d9b1fb6 commit 317ea08
Showing 1 changed file with 104 additions and 157 deletions.
261 changes: 104 additions & 157 deletions core/src/execution/datafusion/expressions/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,15 +249,11 @@ fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> CometResult<Option<i16>
}

fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> CometResult<Option<i32>> {
let mut accum = CastStringToInt32::default();
do_cast_string_to_int(&mut accum, str, eval_mode, "INT")?;
Ok(accum.result)
Ok(do_cast_string_to_i32(str, eval_mode, "INT", i32::MIN)?.map(|n| n as i32))
}

fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> CometResult<Option<i64>> {
let mut accum = CastStringToInt64::default();
do_cast_string_to_int(&mut accum, str, eval_mode, "BIGINT")?;
Ok(accum.result)
do_cast_string_to_i64(str, eval_mode, "BIGINT", i64::MIN)
}

fn cast_string_to_int_with_range_check(
Expand All @@ -267,188 +263,116 @@ fn cast_string_to_int_with_range_check(
min: i32,
max: i32,
) -> CometResult<Option<i32>> {
let mut accum = CastStringToInt32::default();
do_cast_string_to_int(&mut accum, str, eval_mode, type_name)?;
match accum.result {
match do_cast_string_to_i32(str, eval_mode, type_name, i32::MIN)? {
None => Ok(None),
Some(v) if v >= min && v <= max => Ok(Some(v)),
Some(v) if v >= min && v <= max => Ok(Some(v as i32)),
_ if eval_mode == EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
_ => Ok(None),
}
}

/// We support parsing strings to i32 and i64 to match Spark's logic. Support for i8 and i16 is
/// implemented by first parsing as i32 and then downcasting. The CastStringToInt trait is
/// introduced so that we can have the parsing logic delegate either to an i32 or i64 accumulator
/// and avoid the need to use macros here.
trait CastStringToInt {
fn accumulate(
&mut self,
eval_mode: EvalMode,
type_name: &str,
str: &str,
digit: u32,
) -> CometResult<()>;

fn reset(&mut self);

fn finish(
&mut self,
eval_mode: EvalMode,
type_name: &str,
str: &str,
negative: bool,
) -> CometResult<()>;
}
struct CastStringToInt32 {
negative: bool,
result: Option<i32>,
radix: i32,
}
fn do_cast_string_to_i32(
str: &str,
eval_mode: EvalMode,
type_name: &str,
min_value: i32,
) -> CometResult<Option<i32>> {
let chars: Vec<char> = str.chars().collect();
let mut i = 0;
let mut end = chars.len();

impl Default for CastStringToInt32 {
fn default() -> Self {
Self {
negative: false,
result: Some(0),
radix: 10,
}
// skip leading whitespace
while i < end && chars[i].is_whitespace() {
i += 1;
}
}

impl CastStringToInt for CastStringToInt32 {
fn accumulate(
&mut self,
eval_mode: EvalMode,
type_name: &str,
str: &str,
digit: u32,
) -> CometResult<()> {
// We are going to process the new digit and accumulate the result. However, before doing
// this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix),
// then result * 10 will definitely be smaller than minValue, and we can stop
if let Some(r) = self.result {
let stop_value = i32::MIN / self.radix;
if r < stop_value {
self.reset();
return none_or_err(eval_mode, type_name, str);
}
}
// Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix),
// we can just use `result > 0` to check overflow. If result overflows, we should stop
let v = self.result.unwrap_or(0) * self.radix;
match v.checked_sub(digit as i32) {
Some(x) if x <= 0 => self.result = Some(x),
_ => {
self.reset();
return none_or_err(eval_mode, type_name, str);
}
}
Ok(())
// skip trailing whitespace
while end > i && chars[end - 1].is_whitespace() {
end -= 1;
}
fn reset(&mut self) {
self.result = None;

// check for empty string
if i == end {
return none_or_err(eval_mode, type_name, str);
}

fn finish(
&mut self,
eval_mode: EvalMode,
type_name: &str,
str: &str,
negative: bool,
) -> CometResult<()> {
if !negative {
if let Some(r) = self.result {
let negated = r.checked_neg().unwrap_or(-1);
if negated < 0 {
self.reset();
return none_or_err(eval_mode, type_name, str);
}
self.result = Some(negated);
}
// skip + or -
let negative = chars[i] == '-';
if negative || chars[i] == '+' {
i += 1;
if i == end {
return none_or_err(eval_mode, type_name, str);
}
Ok(())
}
}

struct CastStringToInt64 {
negative: bool,
result: Option<i64>,
radix: i64,
}
let mut result = 0;
let radix = 10;
let stop_value = min_value / radix;
while i < end {
let b = chars[i];
i += 1;

impl Default for CastStringToInt64 {
fn default() -> Self {
Self {
negative: false,
result: Some(0),
radix: 10,
if b == '.' && eval_mode == EvalMode::Legacy {
// truncate decimal in legacy mode
break;
}
}
}

impl CastStringToInt for CastStringToInt64 {
fn accumulate(
&mut self,
eval_mode: EvalMode,
type_name: &str,
str: &str,
digit: u32,
) -> CometResult<()> {
let digit = if b.is_ascii_digit() {
(b as u32) - ('0' as u32)
} else {
return none_or_err(eval_mode, type_name, str);
};

// We are going to process the new digit and accumulate the result. However, before doing
// this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix),
// then result * 10 will definitely be smaller than minValue, and we can stop
if let Some(r) = self.result {
let stop_value = i64::MIN / self.radix;
if r < stop_value {
self.reset();
return none_or_err(eval_mode, type_name, str);
}
if result < stop_value {
return none_or_err(eval_mode, type_name, str);
}

// Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix),
// we can just use `result > 0` to check overflow. If result overflows, we should stop
let v = self.result.unwrap_or(0) * self.radix;
match v.checked_sub(digit as i64) {
Some(x) if x <= 0 => self.result = Some(x),
let v = result * radix;
match v.checked_sub(digit as i32) {
Some(x) if x <= 0 => result = x,
_ => {
self.reset();
return none_or_err(eval_mode, type_name, str);
}
}
Ok(())
}

fn reset(&mut self) {
self.result = None;
// This is the case when we've encountered a decimal separator. The fractional
// part will not change the number, but we will verify that the fractional part
// is well-formed.
while i < end {
let b = chars[i];
if !b.is_ascii_digit() {
return none_or_err(eval_mode, type_name, str);
}
i += 1;
}

fn finish(
&mut self,
eval_mode: EvalMode,
type_name: &str,
str: &str,
negative: bool,
) -> CometResult<()> {
if !negative {
if let Some(r) = self.result {
let negated = r.checked_neg().unwrap_or(-1);
if negated < 0 {
self.reset();
return none_or_err(eval_mode, type_name, str);
}
self.result = Some(negated);
if !negative {
if let Some(x) = result.checked_neg() {
if x < 0 {
return none_or_err(eval_mode, type_name, str);
}
result = x;
} else {
return none_or_err(eval_mode, type_name, str);
}
Ok(())
}

Ok(Some(result))
}

fn do_cast_string_to_int(
accumulator: &mut dyn CastStringToInt,
/// This is a copy of do_cast_string_to_i32 but with the type changed to i64
fn do_cast_string_to_i64(
str: &str,
eval_mode: EvalMode,
type_name: &str,
) -> CometResult<()> {
min_value: i64,
) -> CometResult<Option<i64>> {
let chars: Vec<char> = str.chars().collect();
let mut i = 0;
let mut end = chars.len();
Expand All @@ -465,7 +389,6 @@ fn do_cast_string_to_int(

// check for empty string
if i == end {
accumulator.reset();
return none_or_err(eval_mode, type_name, str);
}

Expand All @@ -474,11 +397,13 @@ fn do_cast_string_to_int(
if negative || chars[i] == '+' {
i += 1;
if i == end {
accumulator.reset();
return none_or_err(eval_mode, type_name, str);
}
}

let mut result = 0;
let radix = 10;
let stop_value = min_value / radix;
while i < end {
let b = chars[i];
i += 1;
Expand All @@ -491,11 +416,25 @@ fn do_cast_string_to_int(
let digit = if b.is_ascii_digit() {
(b as u32) - ('0' as u32)
} else {
accumulator.reset();
return none_or_err(eval_mode, type_name, str);
};

accumulator.accumulate(eval_mode, type_name, str, digit)?;
// We are going to process the new digit and accumulate the result. However, before doing
// this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix),
// then result * 10 will definitely be smaller than minValue, and we can stop
if result < stop_value {
return none_or_err(eval_mode, type_name, str);
}

// Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix),
// we can just use `result > 0` to check overflow. If result overflows, we should stop
let v = result * radix;
match v.checked_sub(digit as i64) {
Some(x) if x <= 0 => result = x,
_ => {
return none_or_err(eval_mode, type_name, str);
}
}
}

// This is the case when we've encountered a decimal separator. The fractional
Expand All @@ -504,22 +443,30 @@ fn do_cast_string_to_int(
while i < end {
let b = chars[i];
if !b.is_ascii_digit() {
accumulator.reset();
return none_or_err(eval_mode, type_name, str);
}
i += 1;
}

accumulator.finish(eval_mode, type_name, str, negative)?;
if !negative {
if let Some(x) = result.checked_neg() {
if x < 0 {
return none_or_err(eval_mode, type_name, str);
}
result = x;
} else {
return none_or_err(eval_mode, type_name, str);
}
}

Ok(())
Ok(Some(result))
}

/// Either return Ok(None) or Err(CometError::CastInvalidValue) depending on the evaluation mode
fn none_or_err(eval_mode: EvalMode, type_name: &str, str: &str) -> CometResult<()> {
fn none_or_err<T>(eval_mode: EvalMode, type_name: &str, str: &str) -> CometResult<Option<T>> {
match eval_mode {
EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
_ => Ok(()),
_ => Ok(None),
}
}

Expand Down

0 comments on commit 317ea08

Please sign in to comment.