From 6542fcb4f50ec3ab626d426ad9342c06da372303 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Fri, 16 Apr 2021 16:55:29 +0100 Subject: [PATCH] ARROW-12426: [Rust] Fix concatentation of arrow dictionaries Signed-off-by: Raphael Taylor-Davies --- rust/arrow/src/array/transform/mod.rs | 119 +++++++++++++++++--- rust/arrow/src/array/transform/primitive.rs | 15 +++ rust/arrow/src/compute/kernels/concat.rs | 47 ++++++++ 3 files changed, 167 insertions(+), 14 deletions(-) diff --git a/rust/arrow/src/array/transform/mod.rs b/rust/arrow/src/array/transform/mod.rs index 4dc7b56d1c37c..6083077493366 100644 --- a/rust/arrow/src/array/transform/mod.rs +++ b/rust/arrow/src/array/transform/mod.rs @@ -15,7 +15,12 @@ // specific language governing permissions and limitations // under the License. -use crate::{buffer::MutableBuffer, datatypes::DataType, util::bit_util}; +use crate::{ + buffer::MutableBuffer, + datatypes::DataType, + error::{ArrowError, Result}, + util::bit_util, +}; use super::{ data::{into_buffers, new_buffers}, @@ -166,6 +171,65 @@ impl<'a> std::fmt::Debug for MutableArrayData<'a> { } } +/// Builds an extend that adds `offset` to the source primitive +/// Additionally validates that `max` fits into the +/// the underlying primitive returning None if not +fn build_extend_dictionary( + array: &ArrayData, + offset: usize, + max: usize, +) -> Option { + use crate::datatypes::*; + use std::convert::TryInto; + + match array.data_type() { + DataType::Dictionary(child_data_type, _) => match child_data_type.as_ref() { + DataType::UInt8 => { + let _: u8 = max.try_into().ok()?; + let offset: u8 = offset.try_into().ok()?; + Some(primitive::build_extend_with_offset(array, offset)) + } + DataType::UInt16 => { + let _: u16 = max.try_into().ok()?; + let offset: u16 = offset.try_into().ok()?; + Some(primitive::build_extend_with_offset(array, offset)) + } + DataType::UInt32 => { + let _: u32 = max.try_into().ok()?; + let offset: u32 = offset.try_into().ok()?; + Some(primitive::build_extend_with_offset(array, offset)) + } + DataType::UInt64 => { + let _: u64 = max.try_into().ok()?; + let offset: u64 = offset.try_into().ok()?; + Some(primitive::build_extend_with_offset(array, offset)) + } + DataType::Int8 => { + let _: i8 = max.try_into().ok()?; + let offset: i8 = offset.try_into().ok()?; + Some(primitive::build_extend_with_offset(array, offset)) + } + DataType::Int16 => { + let _: i16 = max.try_into().ok()?; + let offset: i16 = offset.try_into().ok()?; + Some(primitive::build_extend_with_offset(array, offset)) + } + DataType::Int32 => { + let _: i32 = max.try_into().ok()?; + let offset: i32 = offset.try_into().ok()?; + Some(primitive::build_extend_with_offset(array, offset)) + } + DataType::Int64 => { + let _: i64 = max.try_into().ok()?; + let offset: i64 = offset.try_into().ok()?; + Some(primitive::build_extend_with_offset(array, offset)) + } + _ => unreachable!(), + }, + _ => None, + } +} + fn build_extend(array: &ArrayData) -> Extend { use crate::datatypes::*; match array.data_type() { @@ -199,17 +263,7 @@ fn build_extend(array: &ArrayData) -> Extend { } DataType::List(_) => list::build_extend::(array), DataType::LargeList(_) => list::build_extend::(array), - DataType::Dictionary(child_data_type, _) => match child_data_type.as_ref() { - DataType::UInt8 => primitive::build_extend::(array), - DataType::UInt16 => primitive::build_extend::(array), - DataType::UInt32 => primitive::build_extend::(array), - DataType::UInt64 => primitive::build_extend::(array), - DataType::Int8 => primitive::build_extend::(array), - DataType::Int16 => primitive::build_extend::(array), - DataType::Int32 => primitive::build_extend::(array), - DataType::Int64 => primitive::build_extend::(array), - _ => unreachable!(), - }, + DataType::Dictionary(_, _) => unreachable!("should use build_extend_dictionary"), DataType::Struct(_) => structure::build_extend(array), DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array), DataType::Float16 => unreachable!(), @@ -339,7 +393,28 @@ impl<'a> MutableArrayData<'a> { }; let dictionary = match &data_type { - DataType::Dictionary(_, _) => Some(arrays[0].child_data()[0].clone()), + DataType::Dictionary(_, _) => match arrays.len() { + 0 => unreachable!(), + 1 => Some(arrays[0].child_data()[0].clone()), + _ => { + // Concat dictionaries together + let dictionaries: Vec<_> = + arrays.iter().map(|array| &array.child_data()[0]).collect(); + let lengths: Vec<_> = dictionaries + .iter() + .map(|dictionary| dictionary.len()) + .collect(); + let capacity = lengths.iter().sum(); + + let mut mutable = MutableArrayData::new(dictionaries, false, capacity); + + for (i, len) in lengths.iter().enumerate() { + mutable.extend(i, 0, *len) + } + + Some(mutable.freeze()) + } + } _ => None, }; @@ -353,7 +428,23 @@ impl<'a> MutableArrayData<'a> { let null_bytes = bit_util::ceil(capacity, 8); let null_buffer = MutableBuffer::from_len_zeroed(null_bytes); - let extend_values = arrays.iter().map(|array| build_extend(array)).collect(); + let extend_values = match &data_type { + DataType::Dictionary(_, _) => { + let mut next_offset = 0; + let extend_values: Result> = arrays + .iter() + .map(|array| { + let offset = next_offset; + next_offset += array.child_data()[0].len(); + Ok(build_extend_dictionary(array, offset, next_offset) + .ok_or(ArrowError::DictionaryKeyOverflowError)?) + }) + .collect(); + + extend_values.expect("") + } + _ => arrays.iter().map(|array| build_extend(array)).collect(), + }; let data = _MutableArrayData { data_type: data_type.clone(), diff --git a/rust/arrow/src/array/transform/primitive.rs b/rust/arrow/src/array/transform/primitive.rs index 032bb4a877940..4c765c0c0d958 100644 --- a/rust/arrow/src/array/transform/primitive.rs +++ b/rust/arrow/src/array/transform/primitive.rs @@ -16,6 +16,7 @@ // under the License. use std::mem::size_of; +use std::ops::Add; use crate::{array::ArrayData, datatypes::ArrowNativeType}; @@ -32,6 +33,20 @@ pub(super) fn build_extend(array: &ArrayData) -> Extend { ) } +pub(super) fn build_extend_with_offset(array: &ArrayData, offset: T) -> Extend +where + T: ArrowNativeType + Add, +{ + let values = array.buffer::(0); + Box::new( + move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| { + mutable + .buffer1 + .extend(values[start..start + len].iter().map(|x| *x + offset)); + }, + ) +} + pub(super) fn extend_nulls( mutable: &mut _MutableArrayData, len: usize, diff --git a/rust/arrow/src/compute/kernels/concat.rs b/rust/arrow/src/compute/kernels/concat.rs index 32880286a7247..40fb3026f8450 100644 --- a/rust/arrow/src/compute/kernels/concat.rs +++ b/rust/arrow/src/compute/kernels/concat.rs @@ -384,4 +384,51 @@ mod tests { Ok(()) } + + fn collect_string_dictionary(dictionary: &DictionaryArray) -> Vec> { + let values = dictionary.values(); + let values = values.as_any().downcast_ref::().unwrap(); + + (0..dictionary.len()) + .map(move |i| { + match dictionary.keys().is_valid(i) { + true => { + let key = dictionary.keys().value(i); + Some(values.value(key as _).to_string()) + } + false => None + } + }) + .collect() + } + + #[test] + fn test_string_dictionary_array() -> Result<()> { + let input_1: DictionaryArray = + vec!["hello", "A", "B", "hello", "hello", "C"] + .into_iter() + .collect(); + let input_2: DictionaryArray = + vec!["hello", "E", "E", "hello", "F", "E"] + .into_iter() + .collect(); + + let concat = concat(&[&input_1 as _, &input_2 as _]).unwrap(); + let concat = concat + .as_any() + .downcast_ref::>() + .unwrap(); + + let concat_collected = collect_string_dictionary(concat); + let input_1_collected = collect_string_dictionary(&input_1); + let input_2_collected = collect_string_dictionary(&input_2); + let expected: Vec<_> = input_1_collected + .into_iter() + .chain(input_2_collected.into_iter()) + .collect(); + + assert_eq!(concat_collected, expected); + + Ok(()) + } }