Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type_ids in Union datatype #1703

Merged
merged 5 commits into from
May 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions arrow/src/array/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ pub fn make_array(data: ArrayData) -> ArrayRef {
DataType::LargeList(_) => Arc::new(LargeListArray::from(data)) as ArrayRef,
DataType::Struct(_) => Arc::new(StructArray::from(data)) as ArrayRef,
DataType::Map(_, _) => Arc::new(MapArray::from(data)) as ArrayRef,
DataType::Union(_, _) => Arc::new(UnionArray::from(data)) as ArrayRef,
DataType::Union(_, _, _) => Arc::new(UnionArray::from(data)) as ArrayRef,
DataType::FixedSizeList(_, _) => {
Arc::new(FixedSizeListArray::from(data)) as ArrayRef
}
Expand Down Expand Up @@ -535,7 +535,7 @@ pub fn new_null_array(data_type: &DataType, length: usize) -> ArrayRef {
DataType::Map(field, _keys_sorted) => {
new_null_list_array::<i32>(data_type, field.data_type(), length)
}
DataType::Union(_, _) => {
DataType::Union(_, _, _) => {
unimplemented!("Creating null Union array not yet supported")
}
DataType::Dictionary(key, value) => {
Expand Down
35 changes: 24 additions & 11 deletions arrow/src/array/array_union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ use std::any::Any;
/// ];
///
/// let array = UnionArray::try_new(
/// &vec![0, 1],
/// type_id_buffer,
/// Some(value_offsets_buffer),
/// children,
Expand Down Expand Up @@ -90,6 +91,7 @@ use std::any::Any;
/// ];
///
/// let array = UnionArray::try_new(
/// &vec![0, 1],
/// type_id_buffer,
/// None,
/// children,
Expand Down Expand Up @@ -135,6 +137,7 @@ impl UnionArray {
/// `i8` and `i32` values respectively. `Buffer` objects are untyped and no attempt is made
/// to ensure that the data provided is valid.
pub unsafe fn new_unchecked(
field_type_ids: &[i8],
type_ids: Buffer,
value_offsets: Option<Buffer>,
child_arrays: Vec<(Field, ArrayRef)>,
Expand All @@ -149,10 +152,14 @@ impl UnionArray {
UnionMode::Sparse
};

let builder = ArrayData::builder(DataType::Union(field_types, mode))
.add_buffer(type_ids)
.child_data(field_values.into_iter().map(|a| a.data().clone()).collect())
.len(len);
let builder = ArrayData::builder(DataType::Union(
field_types,
Vec::from(field_type_ids),
mode,
))
.add_buffer(type_ids)
.child_data(field_values.into_iter().map(|a| a.data().clone()).collect())
.len(len);

let data = match value_offsets {
Some(b) => builder.add_buffer(b).build_unchecked(),
Expand All @@ -163,6 +170,7 @@ impl UnionArray {

/// Attempts to create a new `UnionArray`, validating the inputs provided.
pub fn try_new(
field_type_ids: &[i8],
type_ids: Buffer,
value_offsets: Option<Buffer>,
child_arrays: Vec<(Field, ArrayRef)>,
Expand Down Expand Up @@ -209,8 +217,9 @@ impl UnionArray {

// Unsafe Justification: arguments were validated above (and
// re-validated as part of data().validate() below)
let new_self =
unsafe { Self::new_unchecked(type_ids, value_offsets, child_arrays) };
let new_self = unsafe {
Self::new_unchecked(field_type_ids, type_ids, value_offsets, child_arrays)
};
new_self.data().validate()?;

Ok(new_self)
Expand Down Expand Up @@ -269,7 +278,7 @@ impl UnionArray {
/// Returns the names of the types in the union.
pub fn type_names(&self) -> Vec<&str> {
match self.data.data_type() {
DataType::Union(fields, _) => fields
DataType::Union(fields, _, _) => fields
.iter()
.map(|f| f.name().as_str())
.collect::<Vec<&str>>(),
Expand All @@ -280,7 +289,7 @@ impl UnionArray {
/// Returns whether the `UnionArray` is dense (or sparse if `false`).
fn is_dense(&self) -> bool {
match self.data.data_type() {
DataType::Union(_, mode) => mode == &UnionMode::Dense,
DataType::Union(_, _, mode) => mode == &UnionMode::Dense,
_ => unreachable!("Union array's data type is not a union!"),
}
}
Expand Down Expand Up @@ -626,9 +635,13 @@ mod tests {
Arc::new(float_array),
),
];
let array =
UnionArray::try_new(type_id_buffer, Some(value_offsets_buffer), children)
.unwrap();
let array = UnionArray::try_new(
&[0, 1, 2],
type_id_buffer,
Some(value_offsets_buffer),
children,
)
.unwrap();

// Check type ids
assert_eq!(Buffer::from_slice_ref(&type_ids), array.data().buffers()[0]);
Expand Down
4 changes: 3 additions & 1 deletion arrow/src/array/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2168,7 +2168,9 @@ impl UnionBuilder {
});
let children: Vec<_> = children.into_iter().map(|(_, b)| b).collect();

UnionArray::try_new(type_id_buffer, value_offsets_buffer, children)
let type_ids: Vec<i8> = (0_i8..children.len() as i8).collect();

UnionArray::try_new(&type_ids, type_id_buffer, value_offsets_buffer, children)
}
}

Expand Down
20 changes: 12 additions & 8 deletions arrow/src/array/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff
MutableBuffer::new(capacity * mem::size_of::<u8>()),
empty_buffer,
],
DataType::Union(_, mode) => {
DataType::Union(_, _, mode) => {
let type_ids = MutableBuffer::new(capacity * mem::size_of::<i8>());
match mode {
UnionMode::Sparse => [type_ids, empty_buffer],
Expand All @@ -220,7 +220,7 @@ pub(crate) fn into_buffers(
| DataType::Binary
| DataType::LargeUtf8
| DataType::LargeBinary => vec![buffer1.into(), buffer2.into()],
DataType::Union(_, mode) => {
DataType::Union(_, _, mode) => {
match mode {
// Based on Union's DataTypeLayout
UnionMode::Sparse => vec![buffer1.into()],
Expand Down Expand Up @@ -581,7 +581,7 @@ impl ArrayData {
DataType::Map(field, _) => {
vec![Self::new_empty(field.data_type())]
}
DataType::Union(fields, _) => fields
DataType::Union(fields, _, _) => fields
.iter()
.map(|field| Self::new_empty(field.data_type()))
.collect(),
Expand Down Expand Up @@ -854,7 +854,7 @@ impl ArrayData {
}
Ok(())
}
DataType::Union(fields, mode) => {
DataType::Union(fields, _, mode) => {
self.validate_num_child_data(fields.len())?;

for (i, field) in fields.iter().enumerate() {
Expand Down Expand Up @@ -1002,7 +1002,7 @@ impl ArrayData {
let child = &self.child_data[0];
self.validate_offsets_full::<i64>(child.len + child.offset)
}
DataType::Union(_, _) => {
DataType::Union(_, _, _) => {
// Validate Union Array as part of implementing new Union semantics
// See comments in `ArrayData::validate()`
// https://github.com/apache/arrow-rs/issues/85
Expand Down Expand Up @@ -1279,7 +1279,7 @@ fn layout(data_type: &DataType) -> DataTypeLayout {
DataType::FixedSizeList(_, _) => DataTypeLayout::new_empty(), // all in child data
DataType::LargeList(_) => DataTypeLayout::new_fixed_width(size_of::<i32>()),
DataType::Struct(_) => DataTypeLayout::new_empty(), // all in child data,
DataType::Union(_, mode) => {
DataType::Union(_, _, mode) => {
let type_ids = BufferSpec::FixedWidth {
byte_width: size_of::<i8>(),
};
Expand Down Expand Up @@ -2431,6 +2431,7 @@ mod tests {
Field::new("field1", DataType::Int32, true),
Field::new("field2", DataType::Int64, true), // data is int32
],
vec![0, 1],
UnionMode::Sparse,
),
2,
Expand Down Expand Up @@ -2462,6 +2463,7 @@ mod tests {
Field::new("field1", DataType::Int32, true),
Field::new("field2", DataType::Int64, true),
],
vec![0, 1],
UnionMode::Sparse,
),
2,
Expand Down Expand Up @@ -2489,6 +2491,7 @@ mod tests {
Field::new("field1", DataType::Int32, true),
Field::new("field2", DataType::Int64, true),
],
vec![0, 1],
UnionMode::Dense,
),
2,
Expand Down Expand Up @@ -2519,6 +2522,7 @@ mod tests {
Field::new("field1", DataType::Int32, true),
Field::new("field2", DataType::Int64, true),
],
vec![0, 1],
UnionMode::Dense,
),
2,
Expand Down Expand Up @@ -2631,8 +2635,8 @@ mod tests {
#[test]
fn test_into_buffers() {
let data_types = vec![
DataType::Union(vec![], UnionMode::Dense),
DataType::Union(vec![], UnionMode::Sparse),
DataType::Union(vec![], vec![], UnionMode::Dense),
DataType::Union(vec![], vec![], UnionMode::Sparse),
];

for data_type in data_types {
Expand Down
2 changes: 1 addition & 1 deletion arrow/src/array/equal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ fn equal_values(
fixed_list_equal(lhs, rhs, lhs_start, rhs_start, len)
}
DataType::Struct(_) => struct_equal(lhs, rhs, lhs_start, rhs_start, len),
DataType::Union(_, _) => union_equal(lhs, rhs, lhs_start, rhs_start, len),
DataType::Union(_, _, _) => union_equal(lhs, rhs, lhs_start, rhs_start, len),
DataType::Dictionary(data_type, _) => match data_type.as_ref() {
DataType::Int8 => dictionary_equal::<i8>(lhs, rhs, lhs_start, rhs_start, len),
DataType::Int16 => {
Expand Down
26 changes: 21 additions & 5 deletions arrow/src/array/equal/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,16 @@ use crate::{array::ArrayData, datatypes::DataType, datatypes::UnionMode};

use super::equal_range;

#[allow(clippy::too_many_arguments)]
fn equal_dense(
lhs: &ArrayData,
rhs: &ArrayData,
lhs_type_ids: &[i8],
rhs_type_ids: &[i8],
lhs_offsets: &[i32],
rhs_offsets: &[i32],
lhs_field_type_ids: &[i8],
rhs_field_type_ids: &[i8],
) -> bool {
let offsets = lhs_offsets.iter().zip(rhs_offsets.iter());

Expand All @@ -34,8 +37,16 @@ fn equal_dense(
.zip(rhs_type_ids.iter())
.zip(offsets)
.all(|((l_type_id, r_type_id), (l_offset, r_offset))| {
let lhs_values = &lhs.child_data()[*l_type_id as usize];
let rhs_values = &rhs.child_data()[*r_type_id as usize];
let lhs_child_index = lhs_field_type_ids
.iter()
.position(|r| r == l_type_id)
.unwrap();
let rhs_child_index = rhs_field_type_ids
.iter()
.position(|r| r == r_type_id)
.unwrap();
let lhs_values = &lhs.child_data()[lhs_child_index];
let rhs_values = &rhs.child_data()[rhs_child_index];

equal_range(
lhs_values,
Expand Down Expand Up @@ -76,7 +87,10 @@ pub(super) fn union_equal(
let rhs_type_id_range = &rhs_type_ids[rhs_start..rhs_start + len];

match (lhs.data_type(), rhs.data_type()) {
(DataType::Union(_, UnionMode::Dense), DataType::Union(_, UnionMode::Dense)) => {
(
DataType::Union(_, lhs_type_ids, UnionMode::Dense),
DataType::Union(_, rhs_type_ids, UnionMode::Dense),
) => {
let lhs_offsets = lhs.buffer::<i32>(1);
let rhs_offsets = rhs.buffer::<i32>(1);

Expand All @@ -91,11 +105,13 @@ pub(super) fn union_equal(
rhs_type_id_range,
lhs_offsets_range,
rhs_offsets_range,
lhs_type_ids,
rhs_type_ids,
)
}
(
DataType::Union(_, UnionMode::Sparse),
DataType::Union(_, UnionMode::Sparse),
DataType::Union(_, _, UnionMode::Sparse),
DataType::Union(_, _, UnionMode::Sparse),
) => {
lhs_type_id_range == rhs_type_id_range
&& equal_sparse(lhs, rhs, lhs_start, rhs_start, len)
Expand Down
2 changes: 1 addition & 1 deletion arrow/src/array/equal/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ pub(super) fn equal_nulls(
#[inline]
pub(super) fn base_equal(lhs: &ArrayData, rhs: &ArrayData) -> bool {
let equal_type = match (lhs.data_type(), rhs.data_type()) {
(DataType::Union(l_fields, l_mode), DataType::Union(r_fields, r_mode)) => {
(DataType::Union(l_fields, _, l_mode), DataType::Union(r_fields, _, r_mode)) => {
l_fields == r_fields && l_mode == r_mode
}
(DataType::Map(l_field, l_sorted), DataType::Map(r_field, r_sorted)) => {
Expand Down
6 changes: 3 additions & 3 deletions arrow/src/array/transform/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ fn build_extend(array: &ArrayData) -> Extend {
DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array),
DataType::Float16 => primitive::build_extend::<f16>(array),
DataType::FixedSizeList(_, _) => fixed_size_list::build_extend(array),
DataType::Union(_, mode) => match mode {
DataType::Union(_, _, mode) => match mode {
UnionMode::Sparse => union::build_extend_sparse(array),
UnionMode::Dense => union::build_extend_dense(array),
},
Expand Down Expand Up @@ -325,7 +325,7 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls {
DataType::FixedSizeBinary(_) => fixed_binary::extend_nulls,
DataType::Float16 => primitive::extend_nulls::<f16>,
DataType::FixedSizeList(_, _) => fixed_size_list::extend_nulls,
DataType::Union(_, mode) => match mode {
DataType::Union(_, _, mode) => match mode {
UnionMode::Sparse => union::extend_nulls_sparse,
UnionMode::Dense => union::extend_nulls_dense,
},
Expand Down Expand Up @@ -524,7 +524,7 @@ impl<'a> MutableArrayData<'a> {
.collect::<Vec<_>>();
vec![MutableArrayData::new(childs, use_nulls, array_capacity)]
}
DataType::Union(fields, _) => (0..fields.len())
DataType::Union(fields, _, _) => (0..fields.len())
.map(|i| {
let child_arrays = arrays
.iter()
Expand Down
1 change: 1 addition & 0 deletions arrow/src/compute/kernels/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4776,6 +4776,7 @@ mod tests {
Field::new("f1", DataType::Int32, false),
Field::new("f2", DataType::Utf8, true),
],
vec![0, 1],
UnionMode::Dense,
),
Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
Expand Down
25 changes: 10 additions & 15 deletions arrow/src/datatypes/datatype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,12 @@ pub enum DataType {
LargeList(Box<Field>),
/// A nested datatype that contains a number of sub-fields.
Struct(Vec<Field>),
/// A nested datatype that can represent slots of differing types.
Union(Vec<Field>, UnionMode),
/// A nested datatype that can represent slots of differing types. Components:
///
/// 1. [`Field`] for each possible child type the Union can hold
/// 2. The corresponding `type_id` used to identify each [`Field`]
/// 3. The type of union (Sparse or Dense)
Union(Vec<Field>, Vec<i8>, UnionMode),
/// A dictionary encoded array (`key_type`, `value_type`), where
/// each array element is an index of `key_type` into an
/// associated dictionary of `value_type`.
Expand Down Expand Up @@ -516,24 +520,15 @@ impl DataType {
.as_array()
.unwrap()
.iter()
.map(|t| t.as_i64().unwrap())
.map(|t| t.as_i64().unwrap() as i8)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we could call `t.as_i8()`?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no `as_i8()` API for `Value`. 😢

.collect::<Vec<_>>();

let default_fields = type_ids
.iter()
.map(|t| {
Field::new("", DataType::Boolean, true).with_metadata(
Some(
[("type_id".to_string(), t.to_string())]
.iter()
.cloned()
.collect(),
),
)
})
.map(|_| default_field.clone())
.collect::<Vec<_>>();

Ok(DataType::Union(default_fields, union_mode))
Ok(DataType::Union(default_fields, type_ids, union_mode))
} else {
Err(ArrowError::ParseError(
"Expecting a typeIds for union ".to_string(),
Expand Down Expand Up @@ -581,7 +576,7 @@ impl DataType {
json!({"name": "fixedsizebinary", "byteWidth": byte_width})
}
DataType::Struct(_) => json!({"name": "struct"}),
DataType::Union(_, _) => json!({"name": "union"}),
DataType::Union(_, _, _) => json!({"name": "union"}),
DataType::List(_) => json!({ "name": "list"}),
DataType::LargeList(_) => json!({ "name": "largelist"}),
DataType::FixedSizeList(_, length) => {
Expand Down
Loading