Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable casting between Utf8/LargeUtf8 and Binary/LargeBinary #3542

Merged
merged 6 commits into from
Jan 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 82 additions & 69 deletions arrow-cast/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,16 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(_, Decimal256(_, _)) => false,
(Struct(_), _) => false,
(_, Struct(_)) => false,
(_, Boolean) => DataType::is_numeric(from_type) || from_type == &Utf8,
(Boolean, _) => DataType::is_numeric(to_type) || to_type == &Utf8,
(_, Boolean) => DataType::is_numeric(from_type) || from_type == &Utf8 || from_type == &LargeUtf8,
(Boolean, _) => DataType::is_numeric(to_type) || to_type == &Utf8 || to_type == &LargeUtf8,

(Utf8, LargeUtf8) => true,
(LargeUtf8, Utf8) => true,
(Binary, LargeBinary) => true,
(LargeBinary, Binary) => true,
(Utf8,
Binary
| LargeBinary
| Date32
| Date64
| Time32(TimeUnit::Second)
Expand All @@ -168,7 +171,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
) => true,
(Utf8, _) => DataType::is_numeric(to_type) && to_type != &Float16,
(LargeUtf8,
LargeBinary
Binary
| LargeBinary
| Date32
| Date64
| Time32(TimeUnit::Second)
Expand Down Expand Up @@ -1075,7 +1079,8 @@ pub fn cast_with_options(
Float16 => cast_numeric_to_bool::<Float16Type>(array),
Float32 => cast_numeric_to_bool::<Float32Type>(array),
Float64 => cast_numeric_to_bool::<Float64Type>(array),
Utf8 => cast_utf8_to_boolean(array, cast_options),
Utf8 => cast_utf8_to_boolean::<i32>(array, cast_options),
LargeUtf8 => cast_utf8_to_boolean::<i64>(array, cast_options),
_ => Err(ArrowError::CastError(format!(
"Casting from {:?} to {:?} not supported",
from_type, to_type,
Expand All @@ -1102,13 +1107,22 @@ pub fn cast_with_options(
.collect::<StringArray>(),
))
}
LargeUtf8 => {
let array = array.as_any().downcast_ref::<BooleanArray>().unwrap();
Ok(Arc::new(
array
.iter()
.map(|value| value.map(|value| if value { "1" } else { "0" }))
.collect::<LargeStringArray>(),
))
}
_ => Err(ArrowError::CastError(format!(
"Casting from {:?} to {:?} not supported",
from_type, to_type,
))),
},
(Utf8, _) => match to_type {
LargeUtf8 => cast_str_container::<i32, i64>(&**array),
LargeUtf8 => cast_byte_container::<Utf8Type, LargeUtf8Type, str>(&**array),
UInt8 => cast_string_to_numeric::<UInt8Type, i32>(array, cast_options),
UInt16 => cast_string_to_numeric::<UInt16Type, i32>(array, cast_options),
UInt32 => cast_string_to_numeric::<UInt32Type, i32>(array, cast_options),
Expand All @@ -1121,7 +1135,11 @@ pub fn cast_with_options(
Float64 => cast_string_to_numeric::<Float64Type, i32>(array, cast_options),
Date32 => cast_string_to_date32::<i32>(&**array, cast_options),
Date64 => cast_string_to_date64::<i32>(&**array, cast_options),
Binary => cast_string_to_binary(array),
Binary => Ok(Arc::new(BinaryArray::from(as_string_array(array).clone()))),
LargeBinary => {
let binary = BinaryArray::from(as_string_array(array).clone());
cast_byte_container::<BinaryType, LargeBinaryType, [u8]>(&binary)
}
Time32(TimeUnit::Second) => {
cast_string_to_time32second::<i32>(&**array, cast_options)
}
Expand All @@ -1143,7 +1161,7 @@ pub fn cast_with_options(
))),
},
(_, Utf8) => match from_type {
LargeUtf8 => cast_str_container::<i64, i32>(&**array),
LargeUtf8 => cast_byte_container::<LargeUtf8Type, Utf8Type, str>(&**array),
UInt8 => cast_numeric_to_string::<UInt8Type, i32>(array),
UInt16 => cast_numeric_to_string::<UInt16Type, i32>(array),
UInt32 => cast_numeric_to_string::<UInt32Type, i32>(array),
Expand Down Expand Up @@ -1270,7 +1288,14 @@ pub fn cast_with_options(
Float64 => cast_string_to_numeric::<Float64Type, i64>(array, cast_options),
Date32 => cast_string_to_date32::<i64>(&**array, cast_options),
Date64 => cast_string_to_date64::<i64>(&**array, cast_options),
LargeBinary => cast_string_to_binary(array),
Binary => {
let large_binary =
LargeBinaryArray::from(as_largestring_array(array).clone());
cast_byte_container::<LargeBinaryType, BinaryType, [u8]>(&large_binary)
}
LargeBinary => Ok(Arc::new(LargeBinaryArray::from(
as_largestring_array(array).clone(),
))),
Time32(TimeUnit::Second) => {
cast_string_to_time32second::<i64>(&**array, cast_options)
}
Expand All @@ -1291,7 +1316,22 @@ pub fn cast_with_options(
from_type, to_type,
))),
},

(Binary, _) => match to_type {
LargeBinary => {
cast_byte_container::<BinaryType, LargeBinaryType, [u8]>(&**array)
}
_ => Err(ArrowError::CastError(format!(
"Casting from {:?} to {:?} not supported",
from_type, to_type,
))),
},
(LargeBinary, _) => match to_type {
Binary => cast_byte_container::<LargeBinaryType, BinaryType, [u8]>(&**array),
_ => Err(ArrowError::CastError(format!(
"Casting from {:?} to {:?} not supported",
from_type, to_type,
))),
},
// start numeric casts
(UInt8, UInt16) => {
cast_numeric_arrays::<UInt8Type, UInt16Type>(array, cast_options)
Expand Down Expand Up @@ -2007,41 +2047,6 @@ pub fn cast_with_options(
}
}

/// Casts a string array (`Utf8` or `LargeUtf8`) to the corresponding binary array
/// (`Binary` or `LargeBinary`, respectively).
///
/// The input's `ArrayData` is cloned and rebuilt with only the data type replaced.
///
/// # Errors
///
/// Returns `ArrowError::InvalidArgumentError` if the input's data type is neither
/// `Utf8` nor `LargeUtf8`.
fn cast_string_to_binary(array: &ArrayRef) -> Result<ArrayRef, ArrowError> {
let from_type = array.data_type();
match *from_type {
DataType::Utf8 => {
// Re-tag the cloned array data as `Binary`; `build_unchecked` skips validation.
// NOTE(review): presumably sound because `Utf8` and `Binary` share the same
// buffer layout — confirm against the Arrow format specification.
let data = unsafe {
array
.data()
.clone()
.into_builder()
.data_type(DataType::Binary)
.build_unchecked()
};

Ok(Arc::new(BinaryArray::from(data)) as ArrayRef)
}
DataType::LargeUtf8 => {
// Same re-tagging as above, targeting `LargeBinary`.
// NOTE(review): presumably relies on `LargeUtf8` and `LargeBinary` sharing a
// buffer layout — confirm.
let data = unsafe {
array
.data()
.clone()
.into_builder()
.data_type(DataType::LargeBinary)
.build_unchecked()
};

Ok(Arc::new(LargeBinaryArray::from(data)) as ArrayRef)
}
// Any other input type is rejected rather than reinterpreted.
_ => Err(ArrowError::InvalidArgumentError(format!(
"{:?} cannot be converted to binary array",
from_type
))),
}
}

/// Get the time unit as a multiple of a second
const fn time_unit_multiple(unit: &TimeUnit) -> i64 {
match unit {
Expand Down Expand Up @@ -2843,11 +2848,17 @@ fn cast_string_to_timestamp_ns<Offset: OffsetSizeTrait>(
}

/// Casts Utf8 / LargeUtf8 to Boolean
fn cast_utf8_to_boolean(
fn cast_utf8_to_boolean<OffsetSize>(
from: &ArrayRef,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError> {
let array = as_string_array(from);
) -> Result<ArrayRef, ArrowError>
where
OffsetSize: OffsetSizeTrait,
{
let array = from
.as_any()
.downcast_ref::<GenericStringArray<OffsetSize>>()
.unwrap();

let output_array = array
.iter()
Expand All @@ -2861,7 +2872,7 @@ fn cast_utf8_to_boolean(
invalid_value => match cast_options.safe {
true => Ok(None),
false => Err(ArrowError::CastError(format!(
"Cannot cast string '{}' to value of Boolean type",
"Cannot cast value '{}' to value of Boolean type",
invalid_value,
))),
},
Expand Down Expand Up @@ -3447,39 +3458,43 @@ fn cast_list_inner<OffsetSize: OffsetSizeTrait>(
Ok(Arc::new(list) as ArrayRef)
}

/// Helper function to cast from `Utf8` to `LargeUtf8` and vice versa. If the `LargeUtf8` is too large for
/// a `Utf8` array it will return an Error.
fn cast_str_container<OffsetSizeFrom, OffsetSizeTo>(
/// Helper function to cast from one `ByteArrayType` to another (e.g. `Utf8` to `LargeUtf8`, or vice versa).
/// If the source array (e.g., `LargeUtf8`) is too large for the target offset type, it will return an Error.
fn cast_byte_container<FROM, TO, N: ?Sized>(
array: &dyn Array,
) -> Result<ArrayRef, ArrowError>
where
OffsetSizeFrom: OffsetSizeTrait + ToPrimitive,
OffsetSizeTo: OffsetSizeTrait + NumCast + ArrowNativeType,
FROM: ByteArrayType<Native = N>,
TO: ByteArrayType<Native = N>,
FROM::Offset: OffsetSizeTrait + ToPrimitive,
TO::Offset: OffsetSizeTrait + NumCast,
{
let data = array.data();
assert_eq!(
data.data_type(),
&GenericStringArray::<OffsetSizeFrom>::DATA_TYPE
);
assert_eq!(data.data_type(), &FROM::DATA_TYPE);
let str_values_buf = data.buffers()[1].clone();
let offsets = data.buffers()[0].typed_data::<OffsetSizeFrom>();
let offsets = data.buffers()[0].typed_data::<FROM::Offset>();

let mut offset_builder = BufferBuilder::<OffsetSizeTo>::new(offsets.len());
let mut offset_builder = BufferBuilder::<TO::Offset>::new(offsets.len());
offsets
.iter()
.try_for_each::<_, Result<_, ArrowError>>(|offset| {
let offset = OffsetSizeTo::from(*offset).ok_or_else(|| {
ArrowError::ComputeError(
"large-utf8 array too large to cast to utf8-array".into(),
)
})?;
let offset = <<TO as ByteArrayType>::Offset as NumCast>::from(*offset)
.ok_or_else(|| {
ArrowError::ComputeError(format!(
"{}{} array too large to cast to {}{} array",
FROM::Offset::PREFIX,
FROM::PREFIX,
TO::Offset::PREFIX,
TO::PREFIX
))
})?;
offset_builder.append(offset);
Ok(())
})?;

let offset_buffer = offset_builder.finish();

let dtype = GenericStringArray::<OffsetSizeTo>::DATA_TYPE;
let dtype = TO::DATA_TYPE;

let builder = ArrayData::builder(dtype)
.offset(array.offset())
Expand All @@ -3490,9 +3505,7 @@ where

let array_data = unsafe { builder.build_unchecked() };

Ok(Arc::new(GenericStringArray::<OffsetSizeTo>::from(
array_data,
)))
Ok(Arc::new(GenericByteArray::<TO>::from(array_data)))
}

/// Cast the container type of List/LargeList array but not the inner types.
Expand Down Expand Up @@ -4813,7 +4826,7 @@ mod tests {
Ok(_) => panic!("expected error"),
Err(e) => {
assert!(e.to_string().contains(
"Cast error: Cannot cast string 'invalid' to value of Boolean type"
"Cast error: Cannot cast value 'invalid' to value of Boolean type"
))
}
}
Expand Down
2 changes: 2 additions & 0 deletions arrow/tests/array_cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,9 @@ fn get_all_types() -> Vec<DataType> {
vec![
Dictionary(Box::new(key_type.clone()), Box::new(Int32)),
Dictionary(Box::new(key_type.clone()), Box::new(Utf8)),
Dictionary(Box::new(key_type.clone()), Box::new(LargeUtf8)),
Dictionary(Box::new(key_type.clone()), Box::new(Binary)),
Dictionary(Box::new(key_type.clone()), Box::new(LargeBinary)),
Dictionary(Box::new(key_type.clone()), Box::new(Decimal128(38, 0))),
Dictionary(Box::new(key_type), Box::new(Decimal256(76, 0))),
]
Expand Down