diff --git a/arrow-array/src/builder/generic_bytes_dictionary_builder.rs b/arrow-array/src/builder/generic_bytes_dictionary_builder.rs
index 34b736d65861..4a920f3ee43e 100644
--- a/arrow-array/src/builder/generic_bytes_dictionary_builder.rs
+++ b/arrow-array/src/builder/generic_bytes_dictionary_builder.rs
@@ -268,7 +268,7 @@ where
         let keys = self.keys_builder.finish();
 
         let data_type =
-            DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(DataType::Utf8));
+            DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(T::DATA_TYPE));
 
         let builder = keys
             .into_data()
@@ -285,7 +285,7 @@ where
         let keys = self.keys_builder.finish_cloned();
 
         let data_type =
-            DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(DataType::Utf8));
+            DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(T::DATA_TYPE));
 
         let builder = keys
             .into_data()
diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs
index 8b8244a7c9ac..aa6697a7170d 100644
--- a/arrow-cast/src/cast.rs
+++ b/arrow-cast/src/cast.rs
@@ -3303,6 +3303,8 @@ fn cast_to_dictionary<K: ArrowDictionaryKeyType>(
         ),
         Utf8 => pack_string_to_dictionary::<K>(array, cast_options),
         LargeUtf8 => pack_string_to_dictionary::<K>(array, cast_options),
+        Binary => pack_binary_to_dictionary::<K>(array, cast_options),
+        LargeBinary => pack_binary_to_dictionary::<K>(array, cast_options),
         _ => Err(ArrowError::CastError(format!(
             "Unsupported output type for dictionary packing: {:?}",
             dict_value_type
@@ -3366,6 +3368,30 @@ where
     Ok(Arc::new(b.finish()))
 }
 
+// Packs the data as a BinaryDictionaryArray, if possible, with the
+// key types of K
+fn pack_binary_to_dictionary<K>(
+    array: &ArrayRef,
+    cast_options: &CastOptions,
+) -> Result<ArrayRef, ArrowError>
+where
+    K: ArrowDictionaryKeyType,
+{
+    let cast_values = cast_with_options(array, &DataType::Binary, cast_options)?;
+    let values = cast_values.as_any().downcast_ref::<BinaryArray>().unwrap();
+    let mut b = BinaryDictionaryBuilder::<K>::with_capacity(values.len(), 1024, 1024);
+
+    // copy each element one at a time
+    for i in 0..values.len() {
+        if values.is_null(i) {
+            b.append_null();
+        } else {
+            b.append(values.value(i))?;
+        }
+    }
+    Ok(Arc::new(b.finish()))
+}
+
 /// Helper function that takes a primitive array and casts to a (generic) list array.
 fn cast_primitive_to_list<OffsetSize: OffsetSizeTrait + NumCast>(
     array: &ArrayRef,
diff --git a/arrow/tests/array_cast.rs b/arrow/tests/array_cast.rs
index be37a7636b63..91d2da9985b5 100644
--- a/arrow/tests/array_cast.rs
+++ b/arrow/tests/array_cast.rs
@@ -411,7 +411,9 @@ fn get_all_types() -> Vec<DataType> {
         ),
         Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
         Dictionary(Box::new(DataType::Int16), Box::new(DataType::Utf8)),
+        Dictionary(Box::new(DataType::Int16), Box::new(DataType::Binary)),
         Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
+        Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Binary)),
         Decimal128(38, 0),
         Dictionary(Box::new(DataType::Int8), Box::new(Decimal128(38, 0))),
         Dictionary(Box::new(DataType::Int16), Box::new(Decimal128(38, 0))),