Skip to content

Commit

Permalink
Use ArrayFormatter in cast kernel (#4668)
Browse files Browse the repository at this point in the history
* Use ArrayFormatter in cast kernel

* Add test

* Clippy
  • Loading branch information
tustvold authored Aug 9, 2023
1 parent 5023ea8 commit cefb8c1
Showing 1 changed file with 31 additions and 55 deletions.
86 changes: 31 additions & 55 deletions arrow-cast/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use chrono::{NaiveTime, Offset, TimeZone, Utc};
use std::cmp::Ordering;
use std::sync::Arc;

use crate::display::{array_value_to_string, ArrayFormatter, FormatOptions};
use crate::display::{ArrayFormatter, FormatOptions};
use crate::parse::{
parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month,
string_to_datetime, Parser,
Expand Down Expand Up @@ -622,21 +622,6 @@ where
Ok(Arc::new(array))
}

// cast the List array to Utf8 array
macro_rules! cast_list_to_string {
($ARRAY:expr, $SIZE:ident) => {{
let mut value_builder: GenericStringBuilder<$SIZE> = GenericStringBuilder::new();
for i in 0..$ARRAY.len() {
if $ARRAY.is_null(i) {
value_builder.append_null();
} else {
value_builder.append_value(array_value_to_string($ARRAY, i)?);
}
}
Ok(Arc::new(value_builder.finish()))
}};
}

fn make_timestamp_array(
array: &PrimitiveArray<Int64Type>,
unit: TimeUnit,
Expand Down Expand Up @@ -800,8 +785,8 @@ pub fn cast_with_options(
}
}
(List(_) | LargeList(_), _) => match to_type {
Utf8 => cast_list_to_string!(array, i32),
LargeUtf8 => cast_list_to_string!(array, i64),
Utf8 => value_to_string::<i32>(array, cast_options),
LargeUtf8 => value_to_string::<i64>(array, cast_options),
_ => Err(ArrowError::CastError(
"Cannot cast list to non-list data types".to_string(),
)),
Expand Down Expand Up @@ -924,8 +909,8 @@ pub fn cast_with_options(
x as f64 / 10_f64.powi(*scale as i32)
})
}
Utf8 => value_to_string::<i32>(array, Some(&cast_options.format_options)),
LargeUtf8 => value_to_string::<i64>(array, Some(&cast_options.format_options)),
Utf8 => value_to_string::<i32>(array, cast_options),
LargeUtf8 => value_to_string::<i64>(array, cast_options),
Null => Ok(new_null_array(to_type, array.len())),
_ => Err(ArrowError::CastError(format!(
"Casting from {from_type:?} to {to_type:?} not supported"
Expand Down Expand Up @@ -993,8 +978,8 @@ pub fn cast_with_options(
x.to_f64().unwrap() / 10_f64.powi(*scale as i32)
})
}
Utf8 => value_to_string::<i32>(array, Some(&cast_options.format_options)),
LargeUtf8 => value_to_string::<i64>(array, Some(&cast_options.format_options)),
Utf8 => value_to_string::<i32>(array, cast_options),
LargeUtf8 => value_to_string::<i64>(array, cast_options),
Null => Ok(new_null_array(to_type, array.len())),
_ => Err(ArrowError::CastError(format!(
"Casting from {from_type:?} to {to_type:?} not supported"
Expand Down Expand Up @@ -1215,24 +1200,8 @@ pub fn cast_with_options(
Float16 => cast_bool_to_numeric::<Float16Type>(array, cast_options),
Float32 => cast_bool_to_numeric::<Float32Type>(array, cast_options),
Float64 => cast_bool_to_numeric::<Float64Type>(array, cast_options),
Utf8 => {
let array = array.as_any().downcast_ref::<BooleanArray>().unwrap();
Ok(Arc::new(
array
.iter()
.map(|value| value.map(|value| if value { "true" } else { "false" }))
.collect::<StringArray>(),
))
}
LargeUtf8 => {
let array = array.as_any().downcast_ref::<BooleanArray>().unwrap();
Ok(Arc::new(
array
.iter()
.map(|value| value.map(|value| if value { "true" } else { "false" }))
.collect::<LargeStringArray>(),
))
}
Utf8 => value_to_string::<i32>(array, cast_options),
LargeUtf8 => value_to_string::<i64>(array, cast_options),
_ => Err(ArrowError::CastError(format!(
"Casting from {from_type:?} to {to_type:?} not supported",
))),
Expand Down Expand Up @@ -1374,8 +1343,8 @@ pub fn cast_with_options(
"Casting from {from_type:?} to {to_type:?} not supported",
))),
},
(from_type, LargeUtf8) if from_type.is_primitive() => value_to_string::<i64>(array, Some(&cast_options.format_options)),
(from_type, Utf8) if from_type.is_primitive() => value_to_string::<i32>(array, Some(&cast_options.format_options)),
(from_type, LargeUtf8) if from_type.is_primitive() => value_to_string::<i64>(array, cast_options),
(from_type, Utf8) if from_type.is_primitive() => value_to_string::<i32>(array, cast_options),
// start numeric casts
(UInt8, UInt16) => {
cast_numeric_arrays::<UInt8Type, UInt16Type>(array, cast_options)
Expand Down Expand Up @@ -2461,14 +2430,10 @@ where

fn value_to_string<O: OffsetSizeTrait>(
array: &dyn Array,
options: Option<&FormatOptions>,
options: &CastOptions,
) -> Result<ArrayRef, ArrowError> {
let mut builder = GenericStringBuilder::<O>::new();
let mut fmt_options = &FormatOptions::default();
if let Some(fmt_opts) = options {
fmt_options = fmt_opts;
};
let formatter = ArrayFormatter::try_new(array, fmt_options)?;
let formatter = ArrayFormatter::try_new(array, &options.format_options)?;
let nulls = array.nulls();
for i in 0..array.len() {
match nulls.map(|x| x.is_null(i)).unwrap_or_default() {
Expand Down Expand Up @@ -7369,14 +7334,10 @@ mod tests {

/// Print the `DictionaryArray` `array` as a vector of strings
fn array_to_strings(array: &ArrayRef) -> Vec<String> {
let options = FormatOptions::new().with_null("null");
let formatter = ArrayFormatter::try_new(array.as_ref(), &options).unwrap();
(0..array.len())
.map(|i| {
if array.is_null(i) {
"null".to_string()
} else {
array_value_to_string(array, i).expect("Convert array to String")
}
})
.map(|i| formatter.value(i).to_string())
.collect()
}

Expand Down Expand Up @@ -8989,4 +8950,19 @@ mod tests {
fn test_const_options() {
assert!(CAST_OPTIONS.safe)
}

#[test]
fn test_list_format_options() {
let options = CastOptions {
safe: false,
format_options: FormatOptions::default().with_null("null"),
};
let array = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(0), Some(1), Some(2)]),
Some(vec![Some(0), None, Some(2)]),
]);
let a = cast_with_options(&array, &DataType::Utf8, &options).unwrap();
let r: Vec<_> = a.as_string::<i32>().iter().map(|x| x.unwrap()).collect();
assert_eq!(r, &["[0, 1, 2]", "[0, null, 2]"]);
}
}

0 comments on commit cefb8c1

Please sign in to comment.