Skip to content

Commit

Permalink
feat: add support for converting LargeUtf8/LargeString in try_cast_to…
Browse files Browse the repository at this point in the history
… function
  • Loading branch information
gengteng committed Oct 15, 2024
1 parent e1084a9 commit 9ae73a2
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 16 deletions.
13 changes: 7 additions & 6 deletions datafusion-federation/src/schema_cast/lists_cast.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use arrow_json::ReaderBuilder;
use datafusion::arrow::array::{GenericStringArray, OffsetSizeTrait};
use datafusion::arrow::{
array::{
Array, ArrayRef, BooleanArray, BooleanBuilder, FixedSizeListBuilder, Float32Array,
Expand Down Expand Up @@ -193,13 +194,13 @@ macro_rules! cast_string_to_fixed_size_list_array {
}};
}

pub(crate) fn cast_string_to_list(
pub(crate) fn cast_string_to_list<StringOffsetSize: OffsetSizeTrait>(
array: &dyn Array,
list_item_field: &FieldRef,
) -> Result<ArrayRef, ArrowError> {
let string_array = array
.as_any()
.downcast_ref::<StringArray>()
.downcast_ref::<GenericStringArray<StringOffsetSize>>()
.ok_or_else(|| {
ArrowError::CastError(
"Failed to decode value: unable to downcast to StringArray".to_string(),
Expand Down Expand Up @@ -297,13 +298,13 @@ pub(crate) fn cast_string_to_list(
}
}

pub(crate) fn cast_string_to_large_list(
pub(crate) fn cast_string_to_large_list<StringOffsetSize: OffsetSizeTrait>(
array: &dyn Array,
list_item_field: &FieldRef,
) -> Result<ArrayRef, ArrowError> {
let string_array = array
.as_any()
.downcast_ref::<StringArray>()
.downcast_ref::<GenericStringArray<StringOffsetSize>>()
.ok_or_else(|| {
ArrowError::CastError(
"Failed to decode value: unable to downcast to StringArray".to_string(),
Expand Down Expand Up @@ -401,14 +402,14 @@ pub(crate) fn cast_string_to_large_list(
}
}

pub(crate) fn cast_string_to_fixed_size_list(
pub(crate) fn cast_string_to_fixed_size_list<StringOffsetSize: OffsetSizeTrait>(
array: &dyn Array,
list_item_field: &FieldRef,
value_length: i32,
) -> Result<ArrayRef, ArrowError> {
let string_array = array
.as_any()
.downcast_ref::<StringArray>()
.downcast_ref::<GenericStringArray<StringOffsetSize>>()
.ok_or_else(|| {
ArrowError::CastError(
"Failed to decode value: unable to downcast to StringArray".to_string(),
Expand Down
75 changes: 68 additions & 7 deletions datafusion-federation/src/schema_cast/record_convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,25 +68,47 @@ pub fn try_cast_to(record_batch: RecordBatch, expected_schema: SchemaRef) -> Res

match (record_batch_col.data_type(), expected_field.data_type()) {
(DataType::Utf8, DataType::List(item_type)) => {
cast_string_to_list(&Arc::clone(record_batch_col), item_type)
cast_string_to_list::<i32>(&Arc::clone(record_batch_col), item_type)
.map_err(|e| Error::UnableToConvertRecordBatch { source: e })
}
(DataType::Utf8, DataType::LargeList(item_type)) => {
cast_string_to_large_list(&Arc::clone(record_batch_col), item_type)
cast_string_to_large_list::<i32>(&Arc::clone(record_batch_col), item_type)
.map_err(|e| Error::UnableToConvertRecordBatch { source: e })
}
(DataType::Utf8, DataType::FixedSizeList(item_type, value_length)) => {
cast_string_to_fixed_size_list(
cast_string_to_fixed_size_list::<i32>(
&Arc::clone(record_batch_col),
item_type,
*value_length,
)
.map_err(|e| Error::UnableToConvertRecordBatch { source: e })
}
(DataType::Utf8, DataType::Struct(_)) => {
cast_string_to_struct(&Arc::clone(record_batch_col), expected_field.clone())
(DataType::Utf8, DataType::Struct(_)) => cast_string_to_struct::<i32>(
&Arc::clone(record_batch_col),
expected_field.clone(),
)
.map_err(|e| Error::UnableToConvertRecordBatch { source: e }),
(DataType::LargeUtf8, DataType::List(item_type)) => {
cast_string_to_list::<i64>(&Arc::clone(record_batch_col), item_type)
.map_err(|e| Error::UnableToConvertRecordBatch { source: e })
}
(DataType::LargeUtf8, DataType::LargeList(item_type)) => {
cast_string_to_large_list::<i64>(&Arc::clone(record_batch_col), item_type)
.map_err(|e| Error::UnableToConvertRecordBatch { source: e })
}
(DataType::LargeUtf8, DataType::FixedSizeList(item_type, value_length)) => {
cast_string_to_fixed_size_list::<i64>(
&Arc::clone(record_batch_col),
item_type,
*value_length,
)
.map_err(|e| Error::UnableToConvertRecordBatch { source: e })
}
(DataType::LargeUtf8, DataType::Struct(_)) => cast_string_to_struct::<i64>(
&Arc::clone(record_batch_col),
expected_field.clone(),
)
.map_err(|e| Error::UnableToConvertRecordBatch { source: e }),
(
DataType::Interval(IntervalUnit::MonthDayNano),
DataType::Interval(IntervalUnit::YearMonth),
Expand All @@ -109,13 +131,13 @@ pub fn try_cast_to(record_batch: RecordBatch, expected_schema: SchemaRef) -> Res

#[cfg(test)]
mod test {
use super::*;
use datafusion::arrow::array::LargeStringArray;
use datafusion::arrow::{
array::{Int32Array, StringArray},
datatypes::{DataType, Field, Schema, TimeUnit},
};

use super::*;

fn schema() -> SchemaRef {
Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, false),
Expand Down Expand Up @@ -153,4 +175,43 @@ mod test {
let result = try_cast_to(batch_input(), to_schema()).expect("converted");
assert_eq!(3, result.num_rows());
}

/// Builds the input schema for the LargeUtf8 conversion test: an `Int32`
/// id column plus two `LargeUtf8` string columns.
fn large_string_schema() -> SchemaRef {
    let fields = vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::LargeUtf8, false),
        Field::new("c", DataType::LargeUtf8, false),
    ];
    Arc::new(Schema::new(fields))
}

/// Builds the target schema the test casts into: widen "a" to `Int64`,
/// keep "b" as `LargeUtf8`, and parse "c" into a microsecond timestamp.
fn large_string_to_schema() -> SchemaRef {
    let a = Field::new("a", DataType::Int64, false);
    let b = Field::new("b", DataType::LargeUtf8, false);
    let c = Field::new("c", DataType::Timestamp(TimeUnit::Microsecond, None), false);
    Arc::new(Schema::new(vec![a, b, c]))
}

/// Builds a three-row batch matching `large_string_schema()`. Column "c"
/// holds timestamp strings with varying fractional-second precision to
/// exercise the string-to-timestamp parsing path.
fn large_string_batch_input() -> RecordBatch {
    let ids = Int32Array::from(vec![1, 2, 3]);
    let names = LargeStringArray::from(vec!["foo", "bar", "baz"]);
    let timestamps = LargeStringArray::from(vec![
        "2024-01-13 03:18:09.000000",
        "2024-01-13 03:18:09",
        "2024-01-13 03:18:09.000",
    ]);
    RecordBatch::try_new(
        large_string_schema(),
        vec![Arc::new(ids), Arc::new(names), Arc::new(timestamps)],
    )
    .expect("record batch should not panic")
}

#[test]
fn test_large_string_to_timestamp_conversion() {
    // Casting LargeUtf8 columns (including timestamp strings) must succeed
    // and preserve the row count.
    let converted = try_cast_to(large_string_batch_input(), large_string_to_schema())
        .expect("converted");
    assert_eq!(converted.num_rows(), 3);
}
}
7 changes: 4 additions & 3 deletions datafusion-federation/src/schema_cast/struct_cast.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
use arrow_json::ReaderBuilder;
use datafusion::arrow::array::{GenericStringArray, OffsetSizeTrait};
use datafusion::arrow::{
array::{Array, ArrayRef, StringArray},
array::{Array, ArrayRef},
datatypes::Field,
error::ArrowError,
};
use std::sync::Arc;

pub type Result<T, E = crate::schema_cast::record_convert::Error> = std::result::Result<T, E>;

pub(crate) fn cast_string_to_struct(
pub(crate) fn cast_string_to_struct<StringOffsetSize: OffsetSizeTrait>(
array: &dyn Array,
struct_field: Arc<Field>,
) -> Result<ArrayRef, ArrowError> {
let string_array = array
.as_any()
.downcast_ref::<StringArray>()
.downcast_ref::<GenericStringArray<StringOffsetSize>>()
.ok_or_else(|| ArrowError::CastError("Failed to downcast to StringArray".to_string()))?;

let mut decoder = ReaderBuilder::new_with_field(struct_field)
Expand Down

0 comments on commit 9ae73a2

Please sign in to comment.