diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs index aa6697a7170d..b86562c1bdde 100644 --- a/arrow-cast/src/cast.rs +++ b/arrow-cast/src/cast.rs @@ -151,13 +151,16 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (_, Decimal256(_, _)) => false, (Struct(_), _) => false, (_, Struct(_)) => false, - (_, Boolean) => DataType::is_numeric(from_type) || from_type == &Utf8, - (Boolean, _) => DataType::is_numeric(to_type) || to_type == &Utf8, + (_, Boolean) => DataType::is_numeric(from_type) || from_type == &Utf8 || from_type == &LargeUtf8, + (Boolean, _) => DataType::is_numeric(to_type) || to_type == &Utf8 || to_type == &LargeUtf8, (Utf8, LargeUtf8) => true, (LargeUtf8, Utf8) => true, + (Binary, LargeBinary) => true, + (LargeBinary, Binary) => true, (Utf8, Binary + | LargeBinary | Date32 | Date64 | Time32(TimeUnit::Second) @@ -168,7 +171,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { ) => true, (Utf8, _) => DataType::is_numeric(to_type) && to_type != &Float16, (LargeUtf8, - LargeBinary + Binary + | LargeBinary | Date32 | Date64 | Time32(TimeUnit::Second) @@ -1075,7 +1079,8 @@ pub fn cast_with_options( Float16 => cast_numeric_to_bool::(array), Float32 => cast_numeric_to_bool::(array), Float64 => cast_numeric_to_bool::(array), - Utf8 => cast_utf8_to_boolean(array, cast_options), + Utf8 => cast_utf8_to_boolean::(array, cast_options), + LargeUtf8 => cast_utf8_to_boolean::(array, cast_options), _ => Err(ArrowError::CastError(format!( "Casting from {:?} to {:?} not supported", from_type, to_type, @@ -1102,13 +1107,22 @@ pub fn cast_with_options( .collect::(), )) } + LargeUtf8 => { + let array = array.as_any().downcast_ref::().unwrap(); + Ok(Arc::new( + array + .iter() + .map(|value| value.map(|value| if value { "1" } else { "0" })) + .collect::(), + )) + } _ => Err(ArrowError::CastError(format!( "Casting from {:?} to {:?} not supported", from_type, to_type, ))), }, (Utf8, _) => match to_type { - LargeUtf8 => cast_str_container::(&**array), + LargeUtf8 => cast_byte_container::(&**array), UInt8 => cast_string_to_numeric::(array, cast_options), UInt16 => cast_string_to_numeric::(array, cast_options), UInt32 => cast_string_to_numeric::(array, cast_options), @@ -1121,7 +1135,8 @@ pub fn cast_with_options( Float64 => cast_string_to_numeric::(array, cast_options), Date32 => cast_string_to_date32::(&**array, cast_options), Date64 => cast_string_to_date64::(&**array, cast_options), - Binary => cast_string_to_binary(array), + Binary => cast_byte_container::(array), + LargeBinary => cast_byte_container::(array), Time32(TimeUnit::Second) => { cast_string_to_time32second::(&**array, cast_options) } @@ -1143,7 +1158,7 @@ pub fn cast_with_options( ))), }, (_, Utf8) => match from_type { - LargeUtf8 => cast_str_container::(&**array), + LargeUtf8 => cast_byte_container::(&**array), UInt8 => cast_numeric_to_string::(array), UInt16 => cast_numeric_to_string::(array), UInt32 => cast_numeric_to_string::(array), @@ -1270,7 +1285,8 @@ pub fn cast_with_options( Float64 => cast_string_to_numeric::(array, cast_options), Date32 => cast_string_to_date32::(&**array, cast_options), Date64 => cast_string_to_date64::(&**array, cast_options), - LargeBinary => cast_string_to_binary(array), + Binary => cast_byte_container::(array), + LargeBinary => cast_byte_container::(array), Time32(TimeUnit::Second) => { cast_string_to_time32second::(&**array, cast_options) } @@ -1291,7 +1307,20 @@ pub fn cast_with_options( from_type, to_type, ))), }, - + (Binary, _) => match to_type { + LargeBinary => cast_byte_container::(&**array), + _ => Err(ArrowError::CastError(format!( + "Casting from {:?} to {:?} not supported", + from_type, to_type, + ))), + }, + (LargeBinary, _) => match to_type { + Binary => cast_byte_container::(&**array), + _ => Err(ArrowError::CastError(format!( + "Casting from {:?} to {:?} not supported", + from_type, to_type, + ))), + }, // start numeric casts (UInt8, UInt16) => { cast_numeric_arrays::(array, cast_options) @@ -2007,41 +2036,6 @@ pub fn cast_with_options( } } -/// Cast to string array to binary array -fn cast_string_to_binary(array: &ArrayRef) -> Result { - let from_type = array.data_type(); - match *from_type { - DataType::Utf8 => { - let data = unsafe { - array - .data() - .clone() - .into_builder() - .data_type(DataType::Binary) - .build_unchecked() - }; - - Ok(Arc::new(BinaryArray::from(data)) as ArrayRef) - } - DataType::LargeUtf8 => { - let data = unsafe { - array - .data() - .clone() - .into_builder() - .data_type(DataType::LargeBinary) - .build_unchecked() - }; - - Ok(Arc::new(LargeBinaryArray::from(data)) as ArrayRef) - } - _ => Err(ArrowError::InvalidArgumentError(format!( - "{:?} cannot be converted to binary array", - from_type - ))), - } -} - /// Get the time unit as a multiple of a second const fn time_unit_multiple(unit: &TimeUnit) -> i64 { match unit { @@ -2843,11 +2837,17 @@ fn cast_string_to_timestamp_ns( } /// Casts Utf8 to Boolean -fn cast_utf8_to_boolean( +fn cast_utf8_to_boolean( from: &ArrayRef, cast_options: &CastOptions, -) -> Result { - let array = as_string_array(from); +) -> Result +where + OffsetSize: OffsetSizeTrait, +{ + let array = from + .as_any() + .downcast_ref::>() + .unwrap(); let output_array = array .iter() @@ -2861,7 +2861,7 @@ fn cast_utf8_to_boolean( invalid_value => match cast_options.safe { true => Ok(None), false => Err(ArrowError::CastError(format!( - "Cannot cast string '{}' to value of Boolean type", + "Cannot cast value '{}' to value of Boolean type", invalid_value, ))), }, @@ -3460,39 +3460,41 @@ fn cast_list_inner( Ok(Arc::new(list) as ArrayRef) } -/// Helper function to cast from `Utf8` to `LargeUtf8` and vice versa. If the `LargeUtf8` is too large for -/// a `Utf8` array it will return an Error. -fn cast_str_container( - array: &dyn Array, -) -> Result +/// Helper function to cast from one `ByteArrayType` to another and vice versa. +/// If the target one (e.g., `LargeUtf8`) is too large for the source array it will return an Error. +fn cast_byte_container(array: &dyn Array) -> Result where - OffsetSizeFrom: OffsetSizeTrait + ToPrimitive, - OffsetSizeTo: OffsetSizeTrait + NumCast + ArrowNativeType, + FROM: ByteArrayType, + TO: ByteArrayType, + FROM::Offset: OffsetSizeTrait + ToPrimitive, + TO::Offset: OffsetSizeTrait + NumCast + ArrowNativeType, { let data = array.data(); - assert_eq!( - data.data_type(), - &GenericStringArray::::DATA_TYPE - ); + assert_eq!(data.data_type(), &FROM::DATA_TYPE); let str_values_buf = data.buffers()[1].clone(); - let offsets = data.buffers()[0].typed_data::(); + let offsets = data.buffers()[0].typed_data::(); - let mut offset_builder = BufferBuilder::::new(offsets.len()); + let mut offset_builder = BufferBuilder::::new(offsets.len()); offsets .iter() .try_for_each::<_, Result<_, ArrowError>>(|offset| { - let offset = OffsetSizeTo::from(*offset).ok_or_else(|| { - ArrowError::ComputeError( - "large-utf8 array too large to cast to utf8-array".into(), - ) - })?; + let offset = <::Offset as NumCast>::from(*offset) + .ok_or_else(|| { + ArrowError::ComputeError(format!( + "{}{} array too large to cast to {}{} array", + FROM::Offset::PREFIX, + FROM::PREFIX, + TO::Offset::PREFIX, + TO::PREFIX + )) + })?; offset_builder.append(offset); Ok(()) })?; let offset_buffer = offset_builder.finish(); - let dtype = GenericStringArray::::DATA_TYPE; + let dtype = TO::DATA_TYPE; let builder = ArrayData::builder(dtype) .offset(array.offset()) @@ -3503,9 +3505,7 @@ where let array_data = unsafe { builder.build_unchecked() }; - Ok(Arc::new(GenericStringArray::::from( - array_data, - ))) + Ok(Arc::new(GenericByteArray::::from(array_data))) } /// Cast the container type of List/Largelist array but not the inner types. diff --git a/arrow/tests/array_cast.rs b/arrow/tests/array_cast.rs index 91d2da9985b5..7936ecb1e01e 100644 --- a/arrow/tests/array_cast.rs +++ b/arrow/tests/array_cast.rs @@ -411,9 +411,13 @@ fn get_all_types() -> Vec { ), Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)), Dictionary(Box::new(DataType::Int16), Box::new(DataType::Utf8)), + Dictionary(Box::new(DataType::Int16), Box::new(DataType::LargeUtf8)), Dictionary(Box::new(DataType::Int16), Box::new(DataType::Binary)), + Dictionary(Box::new(DataType::Int16), Box::new(DataType::LargeBinary)), Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)), + Dictionary(Box::new(DataType::UInt32), Box::new(DataType::LargeUtf8)), Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Binary)), + Dictionary(Box::new(DataType::UInt32), Box::new(DataType::LargeBinary)), Decimal128(38, 0), Dictionary(Box::new(DataType::Int8), Box::new(Decimal128(38, 0))), Dictionary(Box::new(DataType::Int16), Box::new(Decimal128(38, 0))),