From d4028cb08d72af25362d31f29ca83ad026fd1053 Mon Sep 17 00:00:00 2001 From: GengTeng Date: Fri, 25 Oct 2024 15:24:56 +0800 Subject: [PATCH] feat: add support for converting LargeUtf8/LargeString in try_cast_to (#71) --- .../src/schema_cast/lists_cast.rs | 13 +-- .../src/schema_cast/record_convert.rs | 97 +++++++++++++++++-- .../src/schema_cast/struct_cast.rs | 7 +- 3 files changed, 100 insertions(+), 17 deletions(-) diff --git a/datafusion-federation/src/schema_cast/lists_cast.rs b/datafusion-federation/src/schema_cast/lists_cast.rs index 9a63b28..8c07d99 100644 --- a/datafusion-federation/src/schema_cast/lists_cast.rs +++ b/datafusion-federation/src/schema_cast/lists_cast.rs @@ -1,4 +1,5 @@ use arrow_json::ReaderBuilder; +use datafusion::arrow::array::{GenericStringArray, OffsetSizeTrait}; use datafusion::arrow::{ array::{ Array, ArrayRef, BooleanArray, BooleanBuilder, FixedSizeListBuilder, Float32Array, @@ -193,13 +194,13 @@ macro_rules! cast_string_to_fixed_size_list_array { }}; } -pub(crate) fn cast_string_to_list( +pub(crate) fn cast_string_to_list<T: OffsetSizeTrait>( array: &dyn Array, list_item_field: &FieldRef, ) -> Result<ArrayRef> { let string_array = array .as_any() - .downcast_ref::<StringArray>() + .downcast_ref::<GenericStringArray<T>>() .ok_or_else(|| { ArrowError::CastError( "Failed to decode value: unable to downcast to StringArray".to_string(), @@ -297,13 +298,13 @@ pub(crate) fn cast_string_to_list( } } -pub(crate) fn cast_string_to_large_list( +pub(crate) fn cast_string_to_large_list<T: OffsetSizeTrait>( array: &dyn Array, list_item_field: &FieldRef, ) -> Result<ArrayRef> { let string_array = array .as_any() - .downcast_ref::<StringArray>() + .downcast_ref::<GenericStringArray<T>>() .ok_or_else(|| { ArrowError::CastError( "Failed to decode value: unable to downcast to StringArray".to_string(), @@ -401,14 +402,14 @@ pub(crate) fn cast_string_to_large_list( } } -pub(crate) fn cast_string_to_fixed_size_list( +pub(crate) fn cast_string_to_fixed_size_list<T: OffsetSizeTrait>( array: &dyn Array, list_item_field: &FieldRef, value_length: i32, ) -> Result<ArrayRef> { let string_array = array .as_any() - 
.downcast_ref::<StringArray>() + .downcast_ref::<GenericStringArray<T>>() .ok_or_else(|| { ArrowError::CastError( "Failed to decode value: unable to downcast to StringArray".to_string(), diff --git a/datafusion-federation/src/schema_cast/record_convert.rs b/datafusion-federation/src/schema_cast/record_convert.rs index a20401a..b2b2e0a 100644 --- a/datafusion-federation/src/schema_cast/record_convert.rs +++ b/datafusion-federation/src/schema_cast/record_convert.rs @@ -68,25 +68,47 @@ pub fn try_cast_to(record_batch: RecordBatch, expected_schema: SchemaRef) -> Res match (record_batch_col.data_type(), expected_field.data_type()) { (DataType::Utf8, DataType::List(item_type)) => { - cast_string_to_list(&Arc::clone(record_batch_col), item_type) + cast_string_to_list::<i32>(&Arc::clone(record_batch_col), item_type) .map_err(|e| Error::UnableToConvertRecordBatch { source: e }) } (DataType::Utf8, DataType::LargeList(item_type)) => { - cast_string_to_large_list(&Arc::clone(record_batch_col), item_type) + cast_string_to_large_list::<i32>(&Arc::clone(record_batch_col), item_type) .map_err(|e| Error::UnableToConvertRecordBatch { source: e }) } (DataType::Utf8, DataType::FixedSizeList(item_type, value_length)) => { - cast_string_to_fixed_size_list( + cast_string_to_fixed_size_list::<i32>( &Arc::clone(record_batch_col), item_type, *value_length, ) .map_err(|e| Error::UnableToConvertRecordBatch { source: e }) } - (DataType::Utf8, DataType::Struct(_)) => { - cast_string_to_struct(&Arc::clone(record_batch_col), expected_field.clone()) + (DataType::Utf8, DataType::Struct(_)) => cast_string_to_struct::<i32>( + &Arc::clone(record_batch_col), + expected_field.clone(), + ) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }), + (DataType::LargeUtf8, DataType::List(item_type)) => { + cast_string_to_list::<i64>(&Arc::clone(record_batch_col), item_type) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }) + } + (DataType::LargeUtf8, DataType::LargeList(item_type)) => { + 
cast_string_to_large_list::<i64>(&Arc::clone(record_batch_col), item_type) .map_err(|e| Error::UnableToConvertRecordBatch { source: e }) } + (DataType::LargeUtf8, DataType::FixedSizeList(item_type, value_length)) => { + cast_string_to_fixed_size_list::<i64>( + &Arc::clone(record_batch_col), + item_type, + *value_length, + ) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }) + } + (DataType::LargeUtf8, DataType::Struct(_)) => cast_string_to_struct::<i64>( + &Arc::clone(record_batch_col), + expected_field.clone(), + ) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }), ( DataType::Interval(IntervalUnit::MonthDayNano), DataType::Interval(IntervalUnit::YearMonth), @@ -109,12 +131,13 @@ pub fn try_cast_to(record_batch: RecordBatch, expected_schema: SchemaRef) -> Res #[cfg(test)] mod test { + use super::*; + use datafusion::arrow::array::LargeStringArray; use datafusion::arrow::{ array::{Int32Array, StringArray}, datatypes::{DataType, Field, Schema, TimeUnit}, }; - - use super::*; + use datafusion::assert_batches_eq; fn schema() -> SchemaRef { Arc::new(Schema::new(vec![ @@ -151,6 +174,64 @@ mod test { #[test] fn test_string_to_timestamp_conversion() { let result = try_cast_to(batch_input(), to_schema()).expect("converted"); - assert_eq!(3, result.num_rows()); + let expected = vec![ + "+---+-----+---------------------+", + "| a | b | c |", + "+---+-----+---------------------+", + "| 1 | foo | 2024-01-13T03:18:09 |", + "| 2 | bar | 2024-01-13T03:18:09 |", + "| 3 | baz | 2024-01-13T03:18:09 |", + "+---+-----+---------------------+", + ]; + + assert_batches_eq!(expected, &[result]); + } + + fn large_string_from_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::LargeUtf8, false), + Field::new("c", DataType::LargeUtf8, false), + ])) + } + + fn large_string_to_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::LargeUtf8, 
false), + Field::new("c", DataType::Timestamp(TimeUnit::Microsecond, None), false), + ])) + } + + fn large_string_batch_input() -> RecordBatch { + RecordBatch::try_new( + large_string_from_schema(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(LargeStringArray::from(vec!["foo", "bar", "baz"])), + Arc::new(LargeStringArray::from(vec![ + "2024-01-13 03:18:09.000000", + "2024-01-13 03:18:09", + "2024-01-13 03:18:09.000", + ])), + ], + ) + .expect("record batch should not panic") + } + + #[test] + fn test_large_string_to_timestamp_conversion() { + let result = + try_cast_to(large_string_batch_input(), large_string_to_schema()).expect("converted"); + let expected = vec![ + "+---+-----+---------------------+", + "| a | b | c |", + "+---+-----+---------------------+", + "| 1 | foo | 2024-01-13T03:18:09 |", + "| 2 | bar | 2024-01-13T03:18:09 |", + "| 3 | baz | 2024-01-13T03:18:09 |", + "+---+-----+---------------------+", + ]; + assert_batches_eq!(expected, &[result]); } } diff --git a/datafusion-federation/src/schema_cast/struct_cast.rs b/datafusion-federation/src/schema_cast/struct_cast.rs index 85ad369..e9c3206 100644 --- a/datafusion-federation/src/schema_cast/struct_cast.rs +++ b/datafusion-federation/src/schema_cast/struct_cast.rs @@ -1,6 +1,7 @@ use arrow_json::ReaderBuilder; +use datafusion::arrow::array::{GenericStringArray, OffsetSizeTrait}; use datafusion::arrow::{ - array::{Array, ArrayRef, StringArray}, + array::{Array, ArrayRef}, datatypes::Field, error::ArrowError, }; @@ -8,13 +9,13 @@ use std::sync::Arc; pub type Result<T, E = ArrowError> = std::result::Result<T, E>; -pub(crate) fn cast_string_to_struct( +pub(crate) fn cast_string_to_struct<T: OffsetSizeTrait>( array: &dyn Array, struct_field: Arc<Field>, ) -> Result<ArrayRef> { let string_array = array .as_any() - .downcast_ref::<StringArray>() + .downcast_ref::<GenericStringArray<T>>() .ok_or_else(|| ArrowError::CastError("Failed to downcast to StringArray".to_string()))?; let mut decoder = ReaderBuilder::new_with_field(struct_field)