diff --git a/arrow-array/src/array/dictionary_array.rs b/arrow-array/src/array/dictionary_array.rs index 60426e5b3c4d..22e99a44c326 100644 --- a/arrow-array/src/array/dictionary_array.rs +++ b/arrow-array/src/array/dictionary_array.rs @@ -15,7 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::builder::StringDictionaryBuilder; +use crate::builder::{PrimitiveDictionaryBuilder, StringDictionaryBuilder}; +use crate::cast::as_primitive_array; use crate::iterator::ArrayIter; use crate::types::*; use crate::{ @@ -394,6 +395,44 @@ impl DictionaryArray { // Offsets were valid before and verified length is greater than or equal Self::from(unsafe { builder.build_unchecked() }) } + + /// Returns `PrimitiveDictionaryBuilder` of this dictionary array for mutating + /// its keys and values if the underlying data buffer is not shared by others. + pub fn into_primitive_dict_builder( + self, + ) -> Result, Self> + where + V: ArrowPrimitiveType, + { + if !self.value_type().is_primitive() { + return Err(self); + } + + let key_array = as_primitive_array::(self.keys()).clone(); + let value_array = as_primitive_array::(self.values()).clone(); + + drop(self.data); + drop(self.keys); + drop(self.values); + + let key_builder = key_array.into_builder(); + let value_builder = value_array.into_builder(); + + match (key_builder, value_builder) { + (Ok(key_builder), Ok(value_builder)) => Ok(unsafe { + PrimitiveDictionaryBuilder::new_from_builders(key_builder, value_builder) + }), + (Err(key_array), Ok(mut value_builder)) => { + Err(Self::try_new(&key_array, &value_builder.finish()).unwrap()) + } + (Ok(mut key_builder), Err(value_array)) => { + Err(Self::try_new(&key_builder.finish(), &value_array).unwrap()) + } + (Err(key_array), Err(value_array)) => { + Err(Self::try_new(&key_array, &value_array).unwrap()) + } + } + } } /// Constructs a `DictionaryArray` from an array data reference. @@ -644,11 +683,13 @@ where mod tests { use super::*; use crate::builder::PrimitiveDictionaryBuilder; + use crate::cast::as_dictionary_array; use crate::types::{ Float32Type, Int16Type, Int32Type, Int8Type, UInt32Type, UInt8Type, }; use crate::{Float32Array, Int16Array, Int32Array, Int8Array}; use arrow_buffer::{Buffer, ToByteSlice}; + use std::sync::Arc; #[test] fn test_dictionary_array() { @@ -930,4 +971,58 @@ mod tests { let a = DictionaryArray::::from_iter(["32"]); let _ = DictionaryArray::::from(a.into_data()); } + + #[test] + fn test_into_primitive_dict_builder() { + let values = Int32Array::from_iter_values([10_i32, 12, 15]); + let keys = Int8Array::from_iter_values([1_i8, 0, 2, 0]); + + let dict_array = DictionaryArray::::try_new(&keys, &values).unwrap(); + + let boxed: ArrayRef = Arc::new(dict_array); + let col: DictionaryArray = as_dictionary_array(&boxed).clone(); + + drop(boxed); + drop(keys); + drop(values); + + let mut builder = col.into_primitive_dict_builder::().unwrap(); + + let slice = builder.values_slice_mut(); + assert_eq!(slice, &[10, 12, 15]); + + slice[0] = 4; + slice[1] = 2; + slice[2] = 1; + + let values = Int32Array::from_iter_values([4_i32, 2, 1]); + let keys = Int8Array::from_iter_values([1_i8, 0, 2, 0]); + + let expected = DictionaryArray::::try_new(&keys, &values).unwrap(); + + let new_array = builder.finish(); + assert_eq!(expected, new_array); + } + + #[test] + fn test_into_primitive_dict_builder_cloned_array() { + let values = Int32Array::from_iter_values([10_i32, 12, 15]); + let keys = Int8Array::from_iter_values([1_i8, 0, 2, 0]); + + let dict_array = DictionaryArray::::try_new(&keys, &values).unwrap(); + + let boxed: ArrayRef = Arc::new(dict_array); + + let col: DictionaryArray = + DictionaryArray::::from(boxed.data().clone()); + let err = col.into_primitive_dict_builder::(); + + let returned = err.unwrap_err(); + + let values = Int32Array::from_iter_values([10_i32, 12, 15]); + let keys = Int8Array::from_iter_values([1_i8, 0, 2, 0]); + + let expected = DictionaryArray::::try_new(&keys, &values).unwrap(); + assert_eq!(expected, returned); + } } diff --git a/arrow-array/src/builder/primitive_dictionary_builder.rs b/arrow-array/src/builder/primitive_dictionary_builder.rs index 742c09d8cc26..9f410994114f 100644 --- a/arrow-array/src/builder/primitive_dictionary_builder.rs +++ b/arrow-array/src/builder/primitive_dictionary_builder.rs @@ -118,7 +118,7 @@ where /// # Panics /// /// This method panics if `keys_builder` or `values_builder` is not empty. - pub fn new_from_builders( + pub fn new_from_empty_builders( keys_builder: PrimitiveBuilder, values_builder: PrimitiveBuilder, ) -> Self { @@ -133,6 +133,30 @@ where } } + /// Creates a new `PrimitiveDictionaryBuilder` from existing `PrimitiveBuilder`s of keys and values. + /// + /// # Safety + /// + /// caller must ensure that the passed in builders are valid for DictionaryArray. + pub unsafe fn new_from_builders( + keys_builder: PrimitiveBuilder, + values_builder: PrimitiveBuilder, + ) -> Self { + let keys = keys_builder.values_slice(); + let values = values_builder.values_slice(); + let mut map = HashMap::with_capacity(values.len()); + + keys.iter().zip(values.iter()).for_each(|(key, value)| { + map.insert(Value(*value), K::Native::to_usize(*key).unwrap()); + }); + + Self { + keys_builder, + values_builder, + map, + } + } + /// Creates a new `PrimitiveDictionaryBuilder` with the provided capacities /// /// `keys_capacity`: the number of keys, i.e. length of array to build @@ -276,6 +300,16 @@ where DictionaryArray::from(unsafe { builder.build_unchecked() }) } + + /// Returns the current dictionary values buffer as a slice + pub fn values_slice(&self) -> &[V::Native] { + self.values_builder.values_slice() + } + + /// Returns the current dictionary values buffer as a mutable slice + pub fn values_slice_mut(&mut self) -> &mut [V::Native] { + self.values_builder.values_slice_mut() + } } impl Extend> @@ -357,7 +391,7 @@ mod tests { let values_builder = Decimal128Builder::new().with_data_type(DataType::Decimal128(1, 2)); let mut builder = - PrimitiveDictionaryBuilder::::new_from_builders( + PrimitiveDictionaryBuilder::::new_from_empty_builders( keys_builder, values_builder, );