Skip to content

Commit

Permalink
ARROW-12426: [Rust] Fix concatentation of arrow dictionaries
Browse files Browse the repository at this point in the history
Signed-off-by: Raphael Taylor-Davies <r.taylordavies@googlemail.com>
  • Loading branch information
tustvold committed Apr 16, 2021
1 parent 715cb57 commit 6542fcb
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 14 deletions.
119 changes: 105 additions & 14 deletions rust/arrow/src/array/transform/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
// specific language governing permissions and limitations
// under the License.

use crate::{buffer::MutableBuffer, datatypes::DataType, util::bit_util};
use crate::{
buffer::MutableBuffer,
datatypes::DataType,
error::{ArrowError, Result},
util::bit_util,
};

use super::{
data::{into_buffers, new_buffers},
Expand Down Expand Up @@ -166,6 +171,65 @@ impl<'a> std::fmt::Debug for MutableArrayData<'a> {
}
}

/// Builds an extend that adds `offset` to the source primitive
/// Additionally validates that `max` fits into the
/// the underlying primitive returning None if not
fn build_extend_dictionary(
array: &ArrayData,
offset: usize,
max: usize,
) -> Option<Extend> {
use crate::datatypes::*;
use std::convert::TryInto;

match array.data_type() {
DataType::Dictionary(child_data_type, _) => match child_data_type.as_ref() {
DataType::UInt8 => {
let _: u8 = max.try_into().ok()?;
let offset: u8 = offset.try_into().ok()?;
Some(primitive::build_extend_with_offset(array, offset))
}
DataType::UInt16 => {
let _: u16 = max.try_into().ok()?;
let offset: u16 = offset.try_into().ok()?;
Some(primitive::build_extend_with_offset(array, offset))
}
DataType::UInt32 => {
let _: u32 = max.try_into().ok()?;
let offset: u32 = offset.try_into().ok()?;
Some(primitive::build_extend_with_offset(array, offset))
}
DataType::UInt64 => {
let _: u64 = max.try_into().ok()?;
let offset: u64 = offset.try_into().ok()?;
Some(primitive::build_extend_with_offset(array, offset))
}
DataType::Int8 => {
let _: i8 = max.try_into().ok()?;
let offset: i8 = offset.try_into().ok()?;
Some(primitive::build_extend_with_offset(array, offset))
}
DataType::Int16 => {
let _: i16 = max.try_into().ok()?;
let offset: i16 = offset.try_into().ok()?;
Some(primitive::build_extend_with_offset(array, offset))
}
DataType::Int32 => {
let _: i32 = max.try_into().ok()?;
let offset: i32 = offset.try_into().ok()?;
Some(primitive::build_extend_with_offset(array, offset))
}
DataType::Int64 => {
let _: i64 = max.try_into().ok()?;
let offset: i64 = offset.try_into().ok()?;
Some(primitive::build_extend_with_offset(array, offset))
}
_ => unreachable!(),
},
_ => None,
}
}

fn build_extend(array: &ArrayData) -> Extend {
use crate::datatypes::*;
match array.data_type() {
Expand Down Expand Up @@ -199,17 +263,7 @@ fn build_extend(array: &ArrayData) -> Extend {
}
DataType::List(_) => list::build_extend::<i32>(array),
DataType::LargeList(_) => list::build_extend::<i64>(array),
DataType::Dictionary(child_data_type, _) => match child_data_type.as_ref() {
DataType::UInt8 => primitive::build_extend::<u8>(array),
DataType::UInt16 => primitive::build_extend::<u16>(array),
DataType::UInt32 => primitive::build_extend::<u32>(array),
DataType::UInt64 => primitive::build_extend::<u64>(array),
DataType::Int8 => primitive::build_extend::<i8>(array),
DataType::Int16 => primitive::build_extend::<i16>(array),
DataType::Int32 => primitive::build_extend::<i32>(array),
DataType::Int64 => primitive::build_extend::<i64>(array),
_ => unreachable!(),
},
DataType::Dictionary(_, _) => unreachable!("should use build_extend_dictionary"),
DataType::Struct(_) => structure::build_extend(array),
DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array),
DataType::Float16 => unreachable!(),
Expand Down Expand Up @@ -339,7 +393,28 @@ impl<'a> MutableArrayData<'a> {
};

let dictionary = match &data_type {
DataType::Dictionary(_, _) => Some(arrays[0].child_data()[0].clone()),
DataType::Dictionary(_, _) => match arrays.len() {
0 => unreachable!(),
1 => Some(arrays[0].child_data()[0].clone()),
_ => {
// Concat dictionaries together
let dictionaries: Vec<_> =
arrays.iter().map(|array| &array.child_data()[0]).collect();
let lengths: Vec<_> = dictionaries
.iter()
.map(|dictionary| dictionary.len())
.collect();
let capacity = lengths.iter().sum();

let mut mutable = MutableArrayData::new(dictionaries, false, capacity);

for (i, len) in lengths.iter().enumerate() {
mutable.extend(i, 0, *len)
}

Some(mutable.freeze())
}
}
_ => None,
};

Expand All @@ -353,7 +428,23 @@ impl<'a> MutableArrayData<'a> {
let null_bytes = bit_util::ceil(capacity, 8);
let null_buffer = MutableBuffer::from_len_zeroed(null_bytes);

let extend_values = arrays.iter().map(|array| build_extend(array)).collect();
let extend_values = match &data_type {
DataType::Dictionary(_, _) => {
let mut next_offset = 0;
let extend_values: Result<Vec<_>> = arrays
.iter()
.map(|array| {
let offset = next_offset;
next_offset += array.child_data()[0].len();
Ok(build_extend_dictionary(array, offset, next_offset)
.ok_or(ArrowError::DictionaryKeyOverflowError)?)
})
.collect();

extend_values.expect("")
}
_ => arrays.iter().map(|array| build_extend(array)).collect(),
};

let data = _MutableArrayData {
data_type: data_type.clone(),
Expand Down
15 changes: 15 additions & 0 deletions rust/arrow/src/array/transform/primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

use std::mem::size_of;
use std::ops::Add;

use crate::{array::ArrayData, datatypes::ArrowNativeType};

Expand All @@ -32,6 +33,20 @@ pub(super) fn build_extend<T: ArrowNativeType>(array: &ArrayData) -> Extend {
)
}

pub(super) fn build_extend_with_offset<T>(array: &ArrayData, offset: T) -> Extend
where
T: ArrowNativeType + Add<Output = T>,
{
let values = array.buffer::<T>(0);
Box::new(
move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| {
mutable
.buffer1
.extend(values[start..start + len].iter().map(|x| *x + offset));
},
)
}

pub(super) fn extend_nulls<T: ArrowNativeType>(
mutable: &mut _MutableArrayData,
len: usize,
Expand Down
47 changes: 47 additions & 0 deletions rust/arrow/src/compute/kernels/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -384,4 +384,51 @@ mod tests {

Ok(())
}

fn collect_string_dictionary(dictionary: &DictionaryArray<Int32Type>) -> Vec<Option<String>> {
let values = dictionary.values();
let values = values.as_any().downcast_ref::<StringArray>().unwrap();

(0..dictionary.len())
.map(move |i| {
match dictionary.keys().is_valid(i) {
true => {
let key = dictionary.keys().value(i);
Some(values.value(key as _).to_string())
}
false => None
}
})
.collect()
}

#[test]
fn test_string_dictionary_array() -> Result<()> {
let input_1: DictionaryArray<Int32Type> =
vec!["hello", "A", "B", "hello", "hello", "C"]
.into_iter()
.collect();
let input_2: DictionaryArray<Int32Type> =
vec!["hello", "E", "E", "hello", "F", "E"]
.into_iter()
.collect();

let concat = concat(&[&input_1 as _, &input_2 as _]).unwrap();
let concat = concat
.as_any()
.downcast_ref::<DictionaryArray<Int32Type>>()
.unwrap();

let concat_collected = collect_string_dictionary(concat);
let input_1_collected = collect_string_dictionary(&input_1);
let input_2_collected = collect_string_dictionary(&input_2);
let expected: Vec<_> = input_1_collected
.into_iter()
.chain(input_2_collected.into_iter())
.collect();

assert_eq!(concat_collected, expected);

Ok(())
}
}

0 comments on commit 6542fcb

Please sign in to comment.