Skip to content

Commit

Permalink
For CastOptions.safe as false case, applying optimized casting
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Jan 27, 2023
1 parent f37503a commit 124205b
Showing 1 changed file with 54 additions and 21 deletions.
75 changes: 54 additions & 21 deletions arrow-cast/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3400,34 +3400,67 @@ fn cast_binary_to_generic_string<I, O>(
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError>
where
I: OffsetSizeTrait,
O: OffsetSizeTrait,
I: OffsetSizeTrait + ToPrimitive,
O: OffsetSizeTrait + NumCast,
{
let array = array
.as_any()
.downcast_ref::<GenericByteArray<GenericBinaryType<I>>>()
.unwrap();
Ok(Arc::new(
array

if !cast_options.safe {
let offsets = array.value_offsets();
let values = array.value_data();

// We only need to validate that all values are valid UTF-8
let validated = std::str::from_utf8(values)
.map_err(|_| ArrowError::CastError("Invalid UTF-8 sequence".to_string()))?;

let mut offset_builder = BufferBuilder::<O>::new(offsets.len());
offsets
.iter()
.map(|maybe_value| match maybe_value {
Some(value) => {
let result = std::str::from_utf8(value);
if cast_options.safe {
Ok(result.ok())
} else {
Some(result.map_err(|_| {
ArrowError::CastError(
"Cannot cast binary to string".to_string(),
)
}))
.transpose()
}
.try_for_each::<_, Result<_, ArrowError>>(|offset| {
if !validated.is_char_boundary(offset.as_usize()) {
return Err(ArrowError::CastError(
"Invalid UTF-8 sequence".to_string(),
));
}
None => Ok(None),
})
.collect::<Result<GenericByteArray<GenericStringType<O>>, _>>()?,
))

let offset = <O as NumCast>::from(*offset).ok_or_else(|| {
ArrowError::ComputeError(format!(
"{}Binary array too large to cast to {}String array",
I::PREFIX,
O::PREFIX
))
})?;
offset_builder.append(offset);
Ok(())
})?;

let offset_buffer = offset_builder.finish();

let builder = ArrayData::builder(GenericStringArray::<O>::DATA_TYPE)
.offset(array.offset())
.len(array.len())
.add_buffer(offset_buffer)
.add_buffer(array.data().buffers()[1].clone())
.null_bit_buffer(array.data().null_buffer().cloned());

// SAFETY:
// Validated UTF-8 above
Ok(Arc::new(GenericStringArray::<O>::from(unsafe {
builder.build_unchecked()
})))
} else {
Ok(Arc::new(
array
.iter()
.map(|maybe_value| {
maybe_value.and_then(|value| std::str::from_utf8(value).ok())
})
.collect::<GenericByteArray<GenericStringType<O>>>(),
))
}
}

/// Helper function to cast from one `ByteArrayType` to another and vice versa.
Expand Down

0 comments on commit 124205b

Please sign in to comment.