From cefb8c1bbb2807fbb420e62f108676eeb80ec198 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Wed, 9 Aug 2023 22:30:12 +0100 Subject: [PATCH] Use ArrayFormatter in cast kernel (#4668) * Use ArrayFormatter in cast kernel * Add test * Clippy --- arrow-cast/src/cast.rs | 86 +++++++++++++++--------------------------- 1 file changed, 31 insertions(+), 55 deletions(-) diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs index c730452a8da5..c7fd082de2e6 100644 --- a/arrow-cast/src/cast.rs +++ b/arrow-cast/src/cast.rs @@ -41,7 +41,7 @@ use chrono::{NaiveTime, Offset, TimeZone, Utc}; use std::cmp::Ordering; use std::sync::Arc; -use crate::display::{array_value_to_string, ArrayFormatter, FormatOptions}; +use crate::display::{ArrayFormatter, FormatOptions}; use crate::parse::{ parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month, string_to_datetime, Parser, @@ -622,21 +622,6 @@ where Ok(Arc::new(array)) } -// cast the List array to Utf8 array -macro_rules! cast_list_to_string { - ($ARRAY:expr, $SIZE:ident) => {{ - let mut value_builder: GenericStringBuilder<$SIZE> = GenericStringBuilder::new(); - for i in 0..$ARRAY.len() { - if $ARRAY.is_null(i) { - value_builder.append_null(); - } else { - value_builder.append_value(array_value_to_string($ARRAY, i)?); - } - } - Ok(Arc::new(value_builder.finish())) - }}; -} - fn make_timestamp_array( array: &PrimitiveArray, unit: TimeUnit, @@ -800,8 +785,8 @@ pub fn cast_with_options( } } (List(_) | LargeList(_), _) => match to_type { - Utf8 => cast_list_to_string!(array, i32), - LargeUtf8 => cast_list_to_string!(array, i64), + Utf8 => value_to_string::(array, cast_options), + LargeUtf8 => value_to_string::(array, cast_options), _ => Err(ArrowError::CastError( "Cannot cast list to non-list data types".to_string(), )), @@ -924,8 +909,8 @@ pub fn cast_with_options( x as f64 / 10_f64.powi(*scale as i32) }) } - Utf8 => value_to_string::(array, Some(&cast_options.format_options)), - LargeUtf8 => value_to_string::(array, Some(&cast_options.format_options)), + Utf8 => value_to_string::(array, cast_options), + LargeUtf8 => value_to_string::(array, cast_options), Null => Ok(new_null_array(to_type, array.len())), _ => Err(ArrowError::CastError(format!( "Casting from {from_type:?} to {to_type:?} not supported" @@ -993,8 +978,8 @@ pub fn cast_with_options( x.to_f64().unwrap() / 10_f64.powi(*scale as i32) }) } - Utf8 => value_to_string::(array, Some(&cast_options.format_options)), - LargeUtf8 => value_to_string::(array, Some(&cast_options.format_options)), + Utf8 => value_to_string::(array, cast_options), + LargeUtf8 => value_to_string::(array, cast_options), Null => Ok(new_null_array(to_type, array.len())), _ => Err(ArrowError::CastError(format!( "Casting from {from_type:?} to {to_type:?} not supported" @@ -1215,24 +1200,8 @@ pub fn cast_with_options( Float16 => cast_bool_to_numeric::(array, cast_options), Float32 => cast_bool_to_numeric::(array, cast_options), Float64 => cast_bool_to_numeric::(array, cast_options), - Utf8 => { - let array = array.as_any().downcast_ref::().unwrap(); - Ok(Arc::new( - array - .iter() - .map(|value| value.map(|value| if value { "true" } else { "false" })) - .collect::(), - )) - } - LargeUtf8 => { - let array = array.as_any().downcast_ref::().unwrap(); - Ok(Arc::new( - array - .iter() - .map(|value| value.map(|value| if value { "true" } else { "false" })) - .collect::(), - )) - } + Utf8 => value_to_string::(array, cast_options), + LargeUtf8 => value_to_string::(array, cast_options), _ => Err(ArrowError::CastError(format!( "Casting from {from_type:?} to {to_type:?} not supported", ))), @@ -1374,8 +1343,8 @@ pub fn cast_with_options( "Casting from {from_type:?} to {to_type:?} not supported", ))), }, - (from_type, LargeUtf8) if from_type.is_primitive() => value_to_string::(array, Some(&cast_options.format_options)), - (from_type, Utf8) if from_type.is_primitive() => value_to_string::(array, Some(&cast_options.format_options)), + (from_type, LargeUtf8) if from_type.is_primitive() => value_to_string::(array, cast_options), + (from_type, Utf8) if from_type.is_primitive() => value_to_string::(array, cast_options), // start numeric casts (UInt8, UInt16) => { cast_numeric_arrays::(array, cast_options) @@ -2461,14 +2430,10 @@ where fn value_to_string( array: &dyn Array, - options: Option<&FormatOptions>, + options: &CastOptions, ) -> Result { let mut builder = GenericStringBuilder::::new(); - let mut fmt_options = &FormatOptions::default(); - if let Some(fmt_opts) = options { - fmt_options = fmt_opts; - }; - let formatter = ArrayFormatter::try_new(array, fmt_options)?; + let formatter = ArrayFormatter::try_new(array, &options.format_options)?; let nulls = array.nulls(); for i in 0..array.len() { match nulls.map(|x| x.is_null(i)).unwrap_or_default() { @@ -7369,14 +7334,10 @@ mod tests { /// Print the `DictionaryArray` `array` as a vector of strings fn array_to_strings(array: &ArrayRef) -> Vec { + let options = FormatOptions::new().with_null("null"); + let formatter = ArrayFormatter::try_new(array.as_ref(), &options).unwrap(); (0..array.len()) - .map(|i| { - if array.is_null(i) { - "null".to_string() - } else { - array_value_to_string(array, i).expect("Convert array to String") - } - }) + .map(|i| formatter.value(i).to_string()) .collect() } @@ -8989,4 +8950,19 @@ mod tests { fn test_const_options() { assert!(CAST_OPTIONS.safe) } + + #[test] + fn test_list_format_options() { + let options = CastOptions { + safe: false, + format_options: FormatOptions::default().with_null("null"), + }; + let array = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(0), Some(1), Some(2)]), + Some(vec![Some(0), None, Some(2)]), + ]); + let a = cast_with_options(&array, &DataType::Utf8, &options).unwrap(); + let r: Vec<_> = a.as_string::().iter().map(|x| x.unwrap()).collect(); + assert_eq!(r, &["[0, 1, 2]", "[0, null, 2]"]); + } }