Skip to content

Commit

Permalink
Enable casting between Utf8/LargeUtf8 and Binary/LargeBinary (#3542)
Browse files Browse the repository at this point in the history
* Enable casting between Utf8/LargeUtf8 and Binary/LargeBinary

* For review

* Add native bound restrict

* Use From for Utf8 -> Binary and LargeUtf8 -> LargeBinary.

* Restrict the input and output native types to be the same.
  • Loading branch information
viirya authored Jan 25, 2023
1 parent d938cd9 commit bf21ad9
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 69 deletions.
151 changes: 82 additions & 69 deletions arrow-cast/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,16 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(_, Decimal256(_, _)) => false,
(Struct(_), _) => false,
(_, Struct(_)) => false,
(_, Boolean) => DataType::is_numeric(from_type) || from_type == &Utf8,
(Boolean, _) => DataType::is_numeric(to_type) || to_type == &Utf8,
(_, Boolean) => DataType::is_numeric(from_type) || from_type == &Utf8 || from_type == &LargeUtf8,
(Boolean, _) => DataType::is_numeric(to_type) || to_type == &Utf8 || to_type == &LargeUtf8,

(Utf8, LargeUtf8) => true,
(LargeUtf8, Utf8) => true,
(Binary, LargeBinary) => true,
(LargeBinary, Binary) => true,
(Utf8,
Binary
| LargeBinary
| Date32
| Date64
| Time32(TimeUnit::Second)
Expand All @@ -168,7 +171,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
) => true,
(Utf8, _) => DataType::is_numeric(to_type) && to_type != &Float16,
(LargeUtf8,
LargeBinary
Binary
| LargeBinary
| Date32
| Date64
| Time32(TimeUnit::Second)
Expand Down Expand Up @@ -1075,7 +1079,8 @@ pub fn cast_with_options(
Float16 => cast_numeric_to_bool::<Float16Type>(array),
Float32 => cast_numeric_to_bool::<Float32Type>(array),
Float64 => cast_numeric_to_bool::<Float64Type>(array),
Utf8 => cast_utf8_to_boolean(array, cast_options),
Utf8 => cast_utf8_to_boolean::<i32>(array, cast_options),
LargeUtf8 => cast_utf8_to_boolean::<i64>(array, cast_options),
_ => Err(ArrowError::CastError(format!(
"Casting from {:?} to {:?} not supported",
from_type, to_type,
Expand All @@ -1102,13 +1107,22 @@ pub fn cast_with_options(
.collect::<StringArray>(),
))
}
LargeUtf8 => {
let array = array.as_any().downcast_ref::<BooleanArray>().unwrap();
Ok(Arc::new(
array
.iter()
.map(|value| value.map(|value| if value { "1" } else { "0" }))
.collect::<LargeStringArray>(),
))
}
_ => Err(ArrowError::CastError(format!(
"Casting from {:?} to {:?} not supported",
from_type, to_type,
))),
},
(Utf8, _) => match to_type {
LargeUtf8 => cast_str_container::<i32, i64>(&**array),
LargeUtf8 => cast_byte_container::<Utf8Type, LargeUtf8Type, str>(&**array),
UInt8 => cast_string_to_numeric::<UInt8Type, i32>(array, cast_options),
UInt16 => cast_string_to_numeric::<UInt16Type, i32>(array, cast_options),
UInt32 => cast_string_to_numeric::<UInt32Type, i32>(array, cast_options),
Expand All @@ -1121,7 +1135,11 @@ pub fn cast_with_options(
Float64 => cast_string_to_numeric::<Float64Type, i32>(array, cast_options),
Date32 => cast_string_to_date32::<i32>(&**array, cast_options),
Date64 => cast_string_to_date64::<i32>(&**array, cast_options),
Binary => cast_string_to_binary(array),
Binary => Ok(Arc::new(BinaryArray::from(as_string_array(array).clone()))),
LargeBinary => {
let binary = BinaryArray::from(as_string_array(array).clone());
cast_byte_container::<BinaryType, LargeBinaryType, [u8]>(&binary)
}
Time32(TimeUnit::Second) => {
cast_string_to_time32second::<i32>(&**array, cast_options)
}
Expand All @@ -1143,7 +1161,7 @@ pub fn cast_with_options(
))),
},
(_, Utf8) => match from_type {
LargeUtf8 => cast_str_container::<i64, i32>(&**array),
LargeUtf8 => cast_byte_container::<LargeUtf8Type, Utf8Type, str>(&**array),
UInt8 => cast_numeric_to_string::<UInt8Type, i32>(array),
UInt16 => cast_numeric_to_string::<UInt16Type, i32>(array),
UInt32 => cast_numeric_to_string::<UInt32Type, i32>(array),
Expand Down Expand Up @@ -1270,7 +1288,14 @@ pub fn cast_with_options(
Float64 => cast_string_to_numeric::<Float64Type, i64>(array, cast_options),
Date32 => cast_string_to_date32::<i64>(&**array, cast_options),
Date64 => cast_string_to_date64::<i64>(&**array, cast_options),
LargeBinary => cast_string_to_binary(array),
Binary => {
let large_binary =
LargeBinaryArray::from(as_largestring_array(array).clone());
cast_byte_container::<LargeBinaryType, BinaryType, [u8]>(&large_binary)
}
LargeBinary => Ok(Arc::new(LargeBinaryArray::from(
as_largestring_array(array).clone(),
))),
Time32(TimeUnit::Second) => {
cast_string_to_time32second::<i64>(&**array, cast_options)
}
Expand All @@ -1291,7 +1316,22 @@ pub fn cast_with_options(
from_type, to_type,
))),
},

(Binary, _) => match to_type {
LargeBinary => {
cast_byte_container::<BinaryType, LargeBinaryType, [u8]>(&**array)
}
_ => Err(ArrowError::CastError(format!(
"Casting from {:?} to {:?} not supported",
from_type, to_type,
))),
},
(LargeBinary, _) => match to_type {
Binary => cast_byte_container::<LargeBinaryType, BinaryType, [u8]>(&**array),
_ => Err(ArrowError::CastError(format!(
"Casting from {:?} to {:?} not supported",
from_type, to_type,
))),
},
// start numeric casts
(UInt8, UInt16) => {
cast_numeric_arrays::<UInt8Type, UInt16Type>(array, cast_options)
Expand Down Expand Up @@ -2007,41 +2047,6 @@ pub fn cast_with_options(
}
}

/// Casts a string array to the binary array type of the same offset width.
///
/// A `Utf8` input is rebuilt as a `BinaryArray` and a `LargeUtf8` input as a
/// `LargeBinaryArray`. In both cases the array's `ArrayData` is cloned and
/// only its data type is replaced; the buffers themselves are not rewritten
/// here.
///
/// # Errors
///
/// Returns `ArrowError::InvalidArgumentError` when `array` is neither `Utf8`
/// nor `LargeUtf8`.
fn cast_string_to_binary(array: &ArrayRef) -> Result<ArrayRef, ArrowError> {
let from_type = array.data_type();
match *from_type {
DataType::Utf8 => {
// SAFETY: NOTE(review) - `build_unchecked` skips validation of the rebuilt
// data. Only the data type changes (Utf8 -> Binary); this presumably relies
// on both types sharing the same offsets/values buffer layout - confirm.
let data = unsafe {
array
.data()
.clone()
.into_builder()
.data_type(DataType::Binary)
.build_unchecked()
};

Ok(Arc::new(BinaryArray::from(data)) as ArrayRef)
}
DataType::LargeUtf8 => {
// SAFETY: NOTE(review) - same reasoning as the Utf8 arm, for the
// LargeUtf8 -> LargeBinary pair; only the data type is swapped - confirm.
let data = unsafe {
array
.data()
.clone()
.into_builder()
.data_type(DataType::LargeBinary)
.build_unchecked()
};

Ok(Arc::new(LargeBinaryArray::from(data)) as ArrayRef)
}
// Any non-string input is rejected rather than reinterpreted.
_ => Err(ArrowError::InvalidArgumentError(format!(
"{:?} cannot be converted to binary array",
from_type
))),
}
}

/// Get the time unit as a multiple of a second
const fn time_unit_multiple(unit: &TimeUnit) -> i64 {
match unit {
Expand Down Expand Up @@ -2843,11 +2848,17 @@ fn cast_string_to_timestamp_ns<Offset: OffsetSizeTrait>(
}

/// Casts Utf8 to Boolean
fn cast_utf8_to_boolean(
fn cast_utf8_to_boolean<OffsetSize>(
from: &ArrayRef,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError> {
let array = as_string_array(from);
) -> Result<ArrayRef, ArrowError>
where
OffsetSize: OffsetSizeTrait,
{
let array = from
.as_any()
.downcast_ref::<GenericStringArray<OffsetSize>>()
.unwrap();

let output_array = array
.iter()
Expand All @@ -2861,7 +2872,7 @@ fn cast_utf8_to_boolean(
invalid_value => match cast_options.safe {
true => Ok(None),
false => Err(ArrowError::CastError(format!(
"Cannot cast string '{}' to value of Boolean type",
"Cannot cast value '{}' to value of Boolean type",
invalid_value,
))),
},
Expand Down Expand Up @@ -3447,39 +3458,43 @@ fn cast_list_inner<OffsetSize: OffsetSizeTrait>(
Ok(Arc::new(list) as ArrayRef)
}

/// Helper function to cast from `Utf8` to `LargeUtf8` and vice versa. If the `LargeUtf8` is too large for
/// a `Utf8` array it will return an Error.
fn cast_str_container<OffsetSizeFrom, OffsetSizeTo>(
/// Helper function to cast from one `ByteArrayType` to another with the same native type.
/// If the source array is too large for the target type's offsets (e.g., `LargeUtf8` to `Utf8`), it will return an Error.
fn cast_byte_container<FROM, TO, N: ?Sized>(
array: &dyn Array,
) -> Result<ArrayRef, ArrowError>
where
OffsetSizeFrom: OffsetSizeTrait + ToPrimitive,
OffsetSizeTo: OffsetSizeTrait + NumCast + ArrowNativeType,
FROM: ByteArrayType<Native = N>,
TO: ByteArrayType<Native = N>,
FROM::Offset: OffsetSizeTrait + ToPrimitive,
TO::Offset: OffsetSizeTrait + NumCast,
{
let data = array.data();
assert_eq!(
data.data_type(),
&GenericStringArray::<OffsetSizeFrom>::DATA_TYPE
);
assert_eq!(data.data_type(), &FROM::DATA_TYPE);
let str_values_buf = data.buffers()[1].clone();
let offsets = data.buffers()[0].typed_data::<OffsetSizeFrom>();
let offsets = data.buffers()[0].typed_data::<FROM::Offset>();

let mut offset_builder = BufferBuilder::<OffsetSizeTo>::new(offsets.len());
let mut offset_builder = BufferBuilder::<TO::Offset>::new(offsets.len());
offsets
.iter()
.try_for_each::<_, Result<_, ArrowError>>(|offset| {
let offset = OffsetSizeTo::from(*offset).ok_or_else(|| {
ArrowError::ComputeError(
"large-utf8 array too large to cast to utf8-array".into(),
)
})?;
let offset = <<TO as ByteArrayType>::Offset as NumCast>::from(*offset)
.ok_or_else(|| {
ArrowError::ComputeError(format!(
"{}{} array too large to cast to {}{} array",
FROM::Offset::PREFIX,
FROM::PREFIX,
TO::Offset::PREFIX,
TO::PREFIX
))
})?;
offset_builder.append(offset);
Ok(())
})?;

let offset_buffer = offset_builder.finish();

let dtype = GenericStringArray::<OffsetSizeTo>::DATA_TYPE;
let dtype = TO::DATA_TYPE;

let builder = ArrayData::builder(dtype)
.offset(array.offset())
Expand All @@ -3490,9 +3505,7 @@ where

let array_data = unsafe { builder.build_unchecked() };

Ok(Arc::new(GenericStringArray::<OffsetSizeTo>::from(
array_data,
)))
Ok(Arc::new(GenericByteArray::<TO>::from(array_data)))
}

/// Cast the container type of List/Largelist array but not the inner types.
Expand Down Expand Up @@ -4813,7 +4826,7 @@ mod tests {
Ok(_) => panic!("expected error"),
Err(e) => {
assert!(e.to_string().contains(
"Cast error: Cannot cast string 'invalid' to value of Boolean type"
"Cast error: Cannot cast value 'invalid' to value of Boolean type"
))
}
}
Expand Down
2 changes: 2 additions & 0 deletions arrow/tests/array_cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,9 @@ fn get_all_types() -> Vec<DataType> {
vec![
Dictionary(Box::new(key_type.clone()), Box::new(Int32)),
Dictionary(Box::new(key_type.clone()), Box::new(Utf8)),
Dictionary(Box::new(key_type.clone()), Box::new(LargeUtf8)),
Dictionary(Box::new(key_type.clone()), Box::new(Binary)),
Dictionary(Box::new(key_type.clone()), Box::new(LargeBinary)),
Dictionary(Box::new(key_type.clone()), Box::new(Decimal128(38, 0))),
Dictionary(Box::new(key_type), Box::new(Decimal256(76, 0))),
]
Expand Down

0 comments on commit bf21ad9

Please sign in to comment.