From 4192c1234d3a0822718c36cea630a211b8f1729c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 15 May 2022 18:19:15 -0700 Subject: [PATCH 1/5] Store type ids in Union datatype --- arrow/src/array/array.rs | 4 ++-- arrow/src/array/array_union.rs | 35 ++++++++++++++++++--------- arrow/src/array/builder.rs | 4 +++- arrow/src/array/data.rs | 20 +++++++++------- arrow/src/array/equal/mod.rs | 2 +- arrow/src/array/equal/union.rs | 9 ++++--- arrow/src/array/equal/utils.rs | 2 +- arrow/src/array/transform/mod.rs | 6 ++--- arrow/src/compute/kernels/cast.rs | 1 + arrow/src/datatypes/datatype.rs | 19 ++++----------- arrow/src/datatypes/field.rs | 39 +++++++++++++++++++------------ arrow/src/datatypes/mod.rs | 18 +++++--------- arrow/src/ipc/convert.rs | 24 +++++++++++++++---- arrow/src/ipc/reader.rs | 5 ++-- arrow/src/ipc/writer.rs | 7 +++--- arrow/src/util/display.rs | 21 +++++++++-------- arrow/src/util/pretty.rs | 6 ++++- integration-testing/src/lib.rs | 35 ++++----------------------- parquet/src/arrow/arrow_writer.rs | 2 +- parquet/src/arrow/levels.rs | 6 ++--- parquet/src/arrow/schema.rs | 2 +- 21 files changed, 141 insertions(+), 126 deletions(-) diff --git a/arrow/src/array/array.rs b/arrow/src/array/array.rs index 421e60f04ac5..ed99a6b9fad8 100644 --- a/arrow/src/array/array.rs +++ b/arrow/src/array/array.rs @@ -364,7 +364,7 @@ pub fn make_array(data: ArrayData) -> ArrayRef { DataType::LargeList(_) => Arc::new(LargeListArray::from(data)) as ArrayRef, DataType::Struct(_) => Arc::new(StructArray::from(data)) as ArrayRef, DataType::Map(_, _) => Arc::new(MapArray::from(data)) as ArrayRef, - DataType::Union(_, _) => Arc::new(UnionArray::from(data)) as ArrayRef, + DataType::Union(_, _, _) => Arc::new(UnionArray::from(data)) as ArrayRef, DataType::FixedSizeList(_, _) => { Arc::new(FixedSizeListArray::from(data)) as ArrayRef } @@ -535,7 +535,7 @@ pub fn new_null_array(data_type: &DataType, length: usize) -> ArrayRef { DataType::Map(field, _keys_sorted) => { new_null_list_array::(data_type, field.data_type(), length) } - DataType::Union(_, _) => { + DataType::Union(_, _, _) => { unimplemented!("Creating null Union array not yet supported") } DataType::Dictionary(key, value) => { diff --git a/arrow/src/array/array_union.rs b/arrow/src/array/array_union.rs index 5ebf3d2d3c25..5cfab0bbf858 100644 --- a/arrow/src/array/array_union.rs +++ b/arrow/src/array/array_union.rs @@ -58,6 +58,7 @@ use std::any::Any; /// ]; /// /// let array = UnionArray::try_new( +/// &vec![0, 1], /// type_id_buffer, /// Some(value_offsets_buffer), /// children, @@ -90,6 +91,7 @@ use std::any::Any; /// ]; /// /// let array = UnionArray::try_new( +/// &vec![0, 1], /// type_id_buffer, /// None, /// children, @@ -135,6 +137,7 @@ impl UnionArray { /// `i8` and `i32` values respectively. `Buffer` objects are untyped and no attempt is made /// to ensure that the data provided is valid. pub unsafe fn new_unchecked( + field_type_ids: &[i8], type_ids: Buffer, value_offsets: Option, child_arrays: Vec<(Field, ArrayRef)>, @@ -149,10 +152,14 @@ impl UnionArray { UnionMode::Sparse }; - let builder = ArrayData::builder(DataType::Union(field_types, mode)) - .add_buffer(type_ids) - .child_data(field_values.into_iter().map(|a| a.data().clone()).collect()) - .len(len); + let builder = ArrayData::builder(DataType::Union( + field_types, + Vec::from(field_type_ids), + mode, + )) + .add_buffer(type_ids) + .child_data(field_values.into_iter().map(|a| a.data().clone()).collect()) + .len(len); let data = match value_offsets { Some(b) => builder.add_buffer(b).build_unchecked(), @@ -163,6 +170,7 @@ impl UnionArray { /// Attempts to create a new `UnionArray`, validating the inputs provided. pub fn try_new( + field_type_ids: &[i8], type_ids: Buffer, value_offsets: Option, child_arrays: Vec<(Field, ArrayRef)>, @@ -209,8 +217,9 @@ impl UnionArray { // Unsafe Justification: arguments were validated above (and // re-revalidated as part of data().validate() below) - let new_self = - unsafe { Self::new_unchecked(type_ids, value_offsets, child_arrays) }; + let new_self = unsafe { + Self::new_unchecked(field_type_ids, type_ids, value_offsets, child_arrays) + }; new_self.data().validate()?; Ok(new_self) @@ -269,7 +278,7 @@ impl UnionArray { /// Returns the names of the types in the union. pub fn type_names(&self) -> Vec<&str> { match self.data.data_type() { - DataType::Union(fields, _) => fields + DataType::Union(fields, _, _) => fields .iter() .map(|f| f.name().as_str()) .collect::>(), @@ -280,7 +289,7 @@ impl UnionArray { /// Returns whether the `UnionArray` is dense (or sparse if `false`). fn is_dense(&self) -> bool { match self.data.data_type() { - DataType::Union(_, mode) => mode == &UnionMode::Dense, + DataType::Union(_, _, mode) => mode == &UnionMode::Dense, _ => unreachable!("Union array's data type is not a union!"), } } @@ -626,9 +635,13 @@ mod tests { Arc::new(float_array), ), ]; - let array = - UnionArray::try_new(type_id_buffer, Some(value_offsets_buffer), children) - .unwrap(); + let array = UnionArray::try_new( + &[0, 1, 2], + type_id_buffer, + Some(value_offsets_buffer), + children, + ) + .unwrap(); // Check type ids assert_eq!(Buffer::from_slice_ref(&type_ids), array.data().buffers()[0]); diff --git a/arrow/src/array/builder.rs b/arrow/src/array/builder.rs index da6d2f1c354b..091a51b15470 100644 --- a/arrow/src/array/builder.rs +++ b/arrow/src/array/builder.rs @@ -2168,7 +2168,9 @@ impl UnionBuilder { }); let children: Vec<_> = children.into_iter().map(|(_, b)| b).collect(); - UnionArray::try_new(type_id_buffer, value_offsets_buffer, children) + let type_ids: Vec = (0_i8..children.len() as i8).collect(); + + UnionArray::try_new(&type_ids, type_id_buffer, value_offsets_buffer, children) } } diff --git a/arrow/src/array/data.rs b/arrow/src/array/data.rs index c0ecef75d1c0..0d444a777268 100644 --- a/arrow/src/array/data.rs +++ b/arrow/src/array/data.rs @@ -194,7 +194,7 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff MutableBuffer::new(capacity * mem::size_of::()), empty_buffer, ], - DataType::Union(_, mode) => { + DataType::Union(_, _, mode) => { let type_ids = MutableBuffer::new(capacity * mem::size_of::()); match mode { UnionMode::Sparse => [type_ids, empty_buffer], @@ -220,7 +220,7 @@ pub(crate) fn into_buffers( | DataType::Binary | DataType::LargeUtf8 | DataType::LargeBinary => vec![buffer1.into(), buffer2.into()], - DataType::Union(_, mode) => { + DataType::Union(_, _, mode) => { match mode { // Based on Union's DataTypeLayout UnionMode::Sparse => vec![buffer1.into()], @@ -581,7 +581,7 @@ impl ArrayData { DataType::Map(field, _) => { vec![Self::new_empty(field.data_type())] } - DataType::Union(fields, _) => fields + DataType::Union(fields, _, _) => fields .iter() .map(|field| Self::new_empty(field.data_type())) .collect(), @@ -854,7 +854,7 @@ impl ArrayData { } Ok(()) } - DataType::Union(fields, mode) => { + DataType::Union(fields, _, mode) => { self.validate_num_child_data(fields.len())?; for (i, field) in fields.iter().enumerate() { @@ -1002,7 +1002,7 @@ impl ArrayData { let child = &self.child_data[0]; self.validate_offsets_full::(child.len + child.offset) } - DataType::Union(_, _) => { + DataType::Union(_, _, _) => { // Validate Union Array as part of implementing new Union semantics // See comments in `ArrayData::validate()` // https://github.com/apache/arrow-rs/issues/85 @@ -1279,7 +1279,7 @@ fn layout(data_type: &DataType) -> DataTypeLayout { DataType::FixedSizeList(_, _) => DataTypeLayout::new_empty(), // all in child data DataType::LargeList(_) => DataTypeLayout::new_fixed_width(size_of::()), DataType::Struct(_) => DataTypeLayout::new_empty(), // all in child data, - DataType::Union(_, mode) => { + DataType::Union(_, _, mode) => { let type_ids = BufferSpec::FixedWidth { byte_width: size_of::(), }; @@ -2431,6 +2431,7 @@ mod tests { Field::new("field1", DataType::Int32, true), Field::new("field2", DataType::Int64, true), // data is int32 ], + vec![0, 1], UnionMode::Sparse, ), 2, @@ -2462,6 +2463,7 @@ mod tests { Field::new("field1", DataType::Int32, true), Field::new("field2", DataType::Int64, true), ], + vec![0, 1], UnionMode::Sparse, ), 2, @@ -2489,6 +2491,7 @@ mod tests { Field::new("field1", DataType::Int32, true), Field::new("field2", DataType::Int64, true), ], + vec![0, 1], UnionMode::Dense, ), 2, @@ -2519,6 +2522,7 @@ mod tests { Field::new("field1", DataType::Int32, true), Field::new("field2", DataType::Int64, true), ], + vec![0, 1], UnionMode::Dense, ), 2, @@ -2631,8 +2635,8 @@ mod tests { #[test] fn test_into_buffers() { let data_types = vec![ - DataType::Union(vec![], UnionMode::Dense), - DataType::Union(vec![], UnionMode::Sparse), + DataType::Union(vec![], vec![], UnionMode::Dense), + DataType::Union(vec![], vec![], UnionMode::Sparse), ]; for data_type in data_types { diff --git a/arrow/src/array/equal/mod.rs b/arrow/src/array/equal/mod.rs index 1a6b9f331407..c45b30cccdba 100644 --- a/arrow/src/array/equal/mod.rs +++ b/arrow/src/array/equal/mod.rs @@ -193,7 +193,7 @@ fn equal_values( fixed_list_equal(lhs, rhs, lhs_start, rhs_start, len) } DataType::Struct(_) => struct_equal(lhs, rhs, lhs_start, rhs_start, len), - DataType::Union(_, _) => union_equal(lhs, rhs, lhs_start, rhs_start, len), + DataType::Union(_, _, _) => union_equal(lhs, rhs, lhs_start, rhs_start, len), DataType::Dictionary(data_type, _) => match data_type.as_ref() { DataType::Int8 => dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len), DataType::Int16 => { diff --git a/arrow/src/array/equal/union.rs b/arrow/src/array/equal/union.rs index 021b0a3b7fe7..55132aafd5f9 100644 --- a/arrow/src/array/equal/union.rs +++ b/arrow/src/array/equal/union.rs @@ -76,7 +76,10 @@ pub(super) fn union_equal( let rhs_type_id_range = &rhs_type_ids[rhs_start..rhs_start + len]; match (lhs.data_type(), rhs.data_type()) { - (DataType::Union(_, UnionMode::Dense), DataType::Union(_, UnionMode::Dense)) => { + ( + DataType::Union(_, _, UnionMode::Dense), + DataType::Union(_, _, UnionMode::Dense), + ) => { let lhs_offsets = lhs.buffer::(1); let rhs_offsets = rhs.buffer::(1); @@ -94,8 +97,8 @@ pub(super) fn union_equal( ) } ( - DataType::Union(_, UnionMode::Sparse), - DataType::Union(_, UnionMode::Sparse), + DataType::Union(_, _, UnionMode::Sparse), + DataType::Union(_, _, UnionMode::Sparse), ) => { lhs_type_id_range == rhs_type_id_range && equal_sparse(lhs, rhs, lhs_start, rhs_start, len) diff --git a/arrow/src/array/equal/utils.rs b/arrow/src/array/equal/utils.rs index 8875239caf52..fed3933a0893 100644 --- a/arrow/src/array/equal/utils.rs +++ b/arrow/src/array/equal/utils.rs @@ -68,7 +68,7 @@ pub(super) fn equal_nulls( #[inline] pub(super) fn base_equal(lhs: &ArrayData, rhs: &ArrayData) -> bool { let equal_type = match (lhs.data_type(), rhs.data_type()) { - (DataType::Union(l_fields, l_mode), DataType::Union(r_fields, r_mode)) => { + (DataType::Union(l_fields, _, l_mode), DataType::Union(r_fields, _, r_mode)) => { l_fields == r_fields && l_mode == r_mode } (DataType::Map(l_field, l_sorted), DataType::Map(r_field, r_sorted)) => { diff --git a/arrow/src/array/transform/mod.rs b/arrow/src/array/transform/mod.rs index aa7d417a19e0..586a4fec2710 100644 --- a/arrow/src/array/transform/mod.rs +++ b/arrow/src/array/transform/mod.rs @@ -274,7 +274,7 @@ fn build_extend(array: &ArrayData) -> Extend { DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array), DataType::Float16 => primitive::build_extend::(array), DataType::FixedSizeList(_, _) => fixed_size_list::build_extend(array), - DataType::Union(_, mode) => match mode { + DataType::Union(_, _, mode) => match mode { UnionMode::Sparse => union::build_extend_sparse(array), UnionMode::Dense => union::build_extend_dense(array), }, @@ -325,7 +325,7 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls { DataType::FixedSizeBinary(_) => fixed_binary::extend_nulls, DataType::Float16 => primitive::extend_nulls::, DataType::FixedSizeList(_, _) => fixed_size_list::extend_nulls, - DataType::Union(_, mode) => match mode { + DataType::Union(_, _, mode) => match mode { UnionMode::Sparse => union::extend_nulls_sparse, UnionMode::Dense => union::extend_nulls_dense, }, @@ -524,7 +524,7 @@ impl<'a> MutableArrayData<'a> { .collect::>(); vec![MutableArrayData::new(childs, use_nulls, array_capacity)] } - DataType::Union(fields, _) => (0..fields.len()) + DataType::Union(fields, _, _) => (0..fields.len()) .map(|i| { let child_arrays = arrays .iter() diff --git a/arrow/src/compute/kernels/cast.rs b/arrow/src/compute/kernels/cast.rs index 2c0ebb1e20f8..c989cd2fe5c7 100644 --- a/arrow/src/compute/kernels/cast.rs +++ b/arrow/src/compute/kernels/cast.rs @@ -4776,6 +4776,7 @@ mod tests { Field::new("f1", DataType::Int32, false), Field::new("f2", DataType::Utf8, true), ], + vec![0, 1], UnionMode::Dense, ), Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)), diff --git a/arrow/src/datatypes/datatype.rs b/arrow/src/datatypes/datatype.rs index c5cc8f017408..9f56a423a109 100644 --- a/arrow/src/datatypes/datatype.rs +++ b/arrow/src/datatypes/datatype.rs @@ -115,7 +115,7 @@ pub enum DataType { /// A nested datatype that contains a number of sub-fields. Struct(Vec), /// A nested datatype that can represent slots of differing types. - Union(Vec, UnionMode), + Union(Vec, Vec, UnionMode), /// A dictionary encoded array (`key_type`, `value_type`), where /// each array element is an index of `key_type` into an /// associated dictionary of `value_type`. @@ -516,24 +516,15 @@ impl DataType { .as_array() .unwrap() .iter() - .map(|t| t.as_i64().unwrap()) + .map(|t| t.as_i64().unwrap() as i8) .collect::>(); let default_fields = type_ids .iter() - .map(|t| { - Field::new("", DataType::Boolean, true).with_metadata( - Some( - [("type_id".to_string(), t.to_string())] - .iter() - .cloned() - .collect(), - ), - ) - }) + .map(|_| default_field.clone()) .collect::>(); - Ok(DataType::Union(default_fields, union_mode)) + Ok(DataType::Union(default_fields, type_ids, union_mode)) } else { Err(ArrowError::ParseError( "Expecting a typeIds for union ".to_string(), @@ -581,7 +572,7 @@ impl DataType { json!({"name": "fixedsizebinary", "byteWidth": byte_width}) } DataType::Struct(_) => json!({"name": "struct"}), - DataType::Union(_, _) => json!({"name": "union"}), + DataType::Union(_, _, _) => json!({"name": "union"}), DataType::List(_) => json!({ "name": "list"}), DataType::LargeList(_) => json!({ "name": "largelist"}), DataType::FixedSizeList(_, length) => { diff --git a/arrow/src/datatypes/field.rs b/arrow/src/datatypes/field.rs index 6471f1ed7e73..5025d32a4f37 100644 --- a/arrow/src/datatypes/field.rs +++ b/arrow/src/datatypes/field.rs @@ -168,7 +168,7 @@ impl Field { let mut collected_fields = vec![]; match dt { - DataType::Struct(fields) | DataType::Union(fields, _) => { + DataType::Struct(fields) | DataType::Union(fields, _, _) => { collected_fields.extend(fields.iter().flat_map(|f| f.fields())) } DataType::List(field) @@ -390,18 +390,11 @@ impl Field { } } } - DataType::Union(fields, mode) => match map.get("children") { + DataType::Union(_, type_ids, mode) => match map.get("children") { Some(Value::Array(values)) => { - let mut union_fields: Vec = + let union_fields: Vec = values.iter().map(Field::from).collect::>()?; - fields.iter().zip(union_fields.iter_mut()).for_each( - |(f, union_field)| { - union_field.set_metadata(Some( - f.metadata().unwrap().clone(), - )); - }, - ); - DataType::Union(union_fields, mode) + DataType::Union(union_fields, type_ids, mode) } Some(_) => { return Err(ArrowError::ParseError( @@ -568,18 +561,34 @@ impl Field { )); } }, - DataType::Union(nested_fields, _) => match &from.data_type { - DataType::Union(from_nested_fields, _) => { - for from_field in from_nested_fields { + DataType::Union(nested_fields, type_ids, _) => match &from.data_type { + DataType::Union(from_nested_fields, from_type_ids, _) => { + for (idx, from_field) in from_nested_fields.iter().enumerate() { let mut is_new_field = true; - for self_field in nested_fields.iter_mut() { + let field_type_id = from_type_ids.get(idx).unwrap(); + + for (self_idx, self_field) in nested_fields.iter_mut().enumerate() + { if from_field == self_field { + let self_type_id = type_ids.get(self_idx).unwrap(); + + // If the nested fields in two unions are the same, they must have same + // type id. + if self_type_id != field_type_id { + return Err(ArrowError::SchemaError( + "Fail to merge schema Field due to conflicting type ids in union datatype" + .to_string(), + )); + } + is_new_field = false; break; } } + if is_new_field { nested_fields.push(from_field.clone()); + type_ids.push(*field_type_id); } } } diff --git a/arrow/src/datatypes/mod.rs b/arrow/src/datatypes/mod.rs index c3015972a8b4..47074633d7e2 100644 --- a/arrow/src/datatypes/mod.rs +++ b/arrow/src/datatypes/mod.rs @@ -435,19 +435,10 @@ mod tests { "my_union", DataType::Union( vec![ - Field::new("f1", DataType::Int32, true).with_metadata(Some( - [("type_id".to_string(), "5".to_string())] - .iter() - .cloned() - .collect(), - )), - Field::new("f2", DataType::Utf8, true).with_metadata(Some( - [("type_id".to_string(), "7".to_string())] - .iter() - .cloned() - .collect(), - )), + Field::new("f1", DataType::Int32, true), + Field::new("f2", DataType::Utf8, true), ], + vec![5, 7], UnionMode::Sparse, ), false, @@ -1444,6 +1435,7 @@ mod tests { Field::new("c11", DataType::Utf8, true), Field::new("c12", DataType::Utf8, true), ], + vec![0, 1], UnionMode::Dense ), false @@ -1455,6 +1447,7 @@ mod tests { Field::new("c12", DataType::Utf8, true), Field::new("c13", DataType::Time64(TimeUnit::Second), true), ], + vec![1, 2], UnionMode::Dense ), false @@ -1468,6 +1461,7 @@ mod tests { Field::new("c12", DataType::Utf8, true), Field::new("c13", DataType::Time64(TimeUnit::Second), true), ], + vec![0, 1, 2], UnionMode::Dense ), false diff --git a/arrow/src/ipc/convert.rs b/arrow/src/ipc/convert.rs index 97ed9ed78829..d5afe877c2d6 100644 --- a/arrow/src/ipc/convert.rs +++ b/arrow/src/ipc/convert.rs @@ -338,7 +338,12 @@ pub(crate) fn get_data_type(field: ipc::Field, may_be_dictionary: bool) -> DataT } }; - DataType::Union(fields, union_mode) + let type_ids: Vec = match union.typeIds() { + None => (0_i8..fields.len() as i8).collect(), + Some(ids) => ids.iter().map(|i| i as i8).collect(), + }; + + DataType::Union(fields, type_ids, union_mode) } t => unimplemented!("Type {:?} not supported", t), } @@ -666,7 +671,7 @@ pub(crate) fn get_fb_field_type<'a>( children: Some(fbb.create_vector(&empty_fields[..])), } } - Union(fields, mode) => { + Union(fields, _, mode) => { let mut children = vec![]; for field in fields { children.push(build_field(fbb, field)); @@ -874,6 +879,7 @@ mod tests { DataType::List(Box::new(Field::new( "union", DataType::Union( + vec![], vec![], UnionMode::Sparse, ), @@ -882,6 +888,7 @@ mod tests { false, ), ], + vec![0, 1], UnionMode::Dense, ), false, @@ -889,13 +896,22 @@ mod tests { false, ), ], + vec![0, 1], UnionMode::Sparse, ), false, ), Field::new("struct<>", DataType::Struct(vec![]), true), - Field::new("union<>", DataType::Union(vec![], UnionMode::Dense), true), - Field::new("union<>", DataType::Union(vec![], UnionMode::Sparse), true), + Field::new( + "union<>", + DataType::Union(vec![], vec![], UnionMode::Dense), + true, + ), + Field::new( + "union<>", + DataType::Union(vec![], vec![], UnionMode::Sparse), + true, + ), Field::new_dict( "dictionary", DataType::Dictionary( diff --git a/arrow/src/ipc/reader.rs b/arrow/src/ipc/reader.rs index 4a73269e5975..662b384c4b38 100644 --- a/arrow/src/ipc/reader.rs +++ b/arrow/src/ipc/reader.rs @@ -195,7 +195,7 @@ fn create_array( value_array.clone(), ) } - Union(fields, mode) => { + Union(fields, field_type_ids, mode) => { let union_node = nodes[node_index]; node_index += 1; @@ -234,7 +234,8 @@ fn create_array( children.push((field.clone(), triple.0)); } - let array = UnionArray::try_new(type_ids, value_offsets, children)?; + let array = + UnionArray::try_new(field_type_ids, type_ids, value_offsets, children)?; Arc::new(array) } Null => { diff --git a/arrow/src/ipc/writer.rs b/arrow/src/ipc/writer.rs index c03d5e449537..f61d4ce4c62a 100644 --- a/arrow/src/ipc/writer.rs +++ b/arrow/src/ipc/writer.rs @@ -221,7 +221,7 @@ impl IpcDataGenerator { write_options, )?; } - DataType::Union(fields, _) => { + DataType::Union(fields, _, _) => { let union = as_union_array(column); for (field, ref column) in fields .iter() @@ -865,7 +865,7 @@ fn write_array_data( // UnionArray does not have a validity buffer if !matches!( array_data.data_type(), - DataType::Null | DataType::Union(_, _) + DataType::Null | DataType::Union(_, _, _) ) { // write null buffer if exists let null_buffer = match array_data.null_buffer() { @@ -1328,7 +1328,8 @@ mod tests { let offsets = Buffer::from_slice_ref(&[0_i32, 1, 2]); let union = - UnionArray::try_new(types, Some(offsets), vec![(dctfield, array)]).unwrap(); + UnionArray::try_new(&[0], types, Some(offsets), vec![(dctfield, array)]) + .unwrap(); let schema = Arc::new(Schema::new(vec![Field::new( "union", diff --git a/arrow/src/util/display.rs b/arrow/src/util/display.rs index b0493b6ce0d3..6da73e4cff67 100644 --- a/arrow/src/util/display.rs +++ b/arrow/src/util/display.rs @@ -396,7 +396,9 @@ pub fn array_value_to_string(column: &array::ArrayRef, row: usize) -> Result union_to_string(column, row, field_vec, mode), + DataType::Union(field_vec, type_ids, mode) => { + union_to_string(column, row, field_vec, type_ids, mode) + } _ => Err(ArrowError::InvalidArgumentError(format!( "Pretty printing not implemented for {:?} type", column.data_type() @@ -409,6 +411,7 @@ fn union_to_string( column: &array::ArrayRef, row: usize, fields: &[Field], + type_ids: &[i8], mode: &UnionMode, ) -> Result { let list = column @@ -420,15 +423,13 @@ fn union_to_string( ) })?; let type_id = list.type_id(row); - let name = fields - .get(type_id as usize) - .ok_or_else(|| { - ArrowError::InvalidArgumentError(format!( - "Repl error: could not get field name for type id: {} in union array.", - type_id, - )) - })? - .name(); + let field_idx = type_ids.iter().position(|t| t == &type_id).ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Repl error: could not get field name for type id: {} in union array.", + type_id, + )) + })?; + let name = fields.get(field_idx).unwrap().name(); let value = array_value_to_string( &list.child(type_id), diff --git a/arrow/src/util/pretty.rs b/arrow/src/util/pretty.rs index 3fa2729ba412..124de6127ddd 100644 --- a/arrow/src/util/pretty.rs +++ b/arrow/src/util/pretty.rs @@ -664,6 +664,7 @@ mod tests { Field::new("a", DataType::Int32, false), Field::new("b", DataType::Float64, false), ], + vec![0, 1], UnionMode::Dense, ), false, @@ -704,6 +705,7 @@ mod tests { Field::new("a", DataType::Int32, false), Field::new("b", DataType::Float64, false), ], + vec![0, 1], UnionMode::Sparse, ), false, @@ -746,6 +748,7 @@ mod tests { Field::new("b", DataType::Int32, false), Field::new("c", DataType::Float64, false), ], + vec![0, 1], UnionMode::Dense, ), false, @@ -760,12 +763,13 @@ mod tests { (inner_field.clone(), Arc::new(inner)), ]; - let outer = UnionArray::try_new(type_ids, None, children).unwrap(); + let outer = UnionArray::try_new(&[0, 1], type_ids, None, children).unwrap(); let schema = Schema::new(vec![Field::new( "Teamsters", DataType::Union( vec![Field::new("a", DataType::Int32, true), inner_field], + vec![0, 1], UnionMode::Sparse, ), false, diff --git a/integration-testing/src/lib.rs b/integration-testing/src/lib.rs index c7045993837b..c57ef32bca04 100644 --- a/integration-testing/src/lib.rs +++ b/integration-testing/src/lib.rs @@ -632,39 +632,13 @@ fn array_from_json( let array = MapArray::from(array_data); Ok(Arc::new(array)) } - DataType::Union(fields, _) => { - let field_type_ids = fields - .iter() - .enumerate() - .into_iter() - .map(|(idx, f)| { - ( - f.metadata() - .and_then(|m| m.get("type_id")) - .unwrap() - .parse::() - .unwrap(), - idx, - ) - }) - .collect::>(); - + DataType::Union(fields, field_type_ids, _) => { let type_ids = if let Some(type_id) = json_col.type_id { type_id - .iter() - .map(|t| { - if field_type_ids.contains_key(t) { - Ok(*(field_type_ids.get(t).unwrap()) as i8) - } else { - Err(ArrowError::JsonError(format!( - "Unable to find type id {:?}", - t - ))) - } - }) - .collect::>()? } else { - vec![] + return Err(ArrowError::JsonError( + "Cannot find expected type_id in json column".to_string(), + )); }; let offset: Option = json_col.offset.map(|offsets| { @@ -680,6 +654,7 @@ fn array_from_json( } let array = UnionArray::try_new( + field_type_ids, Buffer::from(&type_ids.to_byte_slice()), offset, children, diff --git a/parquet/src/arrow/arrow_writer.rs b/parquet/src/arrow/arrow_writer.rs index 7ddd6443230e..1918c967550a 100644 --- a/parquet/src/arrow/arrow_writer.rs +++ b/parquet/src/arrow/arrow_writer.rs @@ -324,7 +324,7 @@ fn write_leaves( ArrowDataType::Float16 => Err(ParquetError::ArrowError( "Float16 arrays not supported".to_string(), )), - ArrowDataType::FixedSizeList(_, _) | ArrowDataType::Union(_, _) => { + ArrowDataType::FixedSizeList(_, _) | ArrowDataType::Union(_, _, _) => { Err(ParquetError::NYI( format!( "Attempting to write an Arrow type {:?} to parquet that is not yet implemented", diff --git a/parquet/src/arrow/levels.rs b/parquet/src/arrow/levels.rs index a1979e591936..be9a5e99323b 100644 --- a/parquet/src/arrow/levels.rs +++ b/parquet/src/arrow/levels.rs @@ -240,7 +240,7 @@ impl LevelInfo { list_level.calculate_array_levels(&child_array, list_field) } DataType::FixedSizeList(_, _) => unimplemented!(), - DataType::Union(_, _) => unimplemented!(), + DataType::Union(_, _, _) => unimplemented!(), } } DataType::Map(map_field, _) => { @@ -310,7 +310,7 @@ impl LevelInfo { }); struct_levels } - DataType::Union(_, _) => unimplemented!(), + DataType::Union(_, _, _) => unimplemented!(), DataType::Dictionary(_, _) => { // Need to check for these cases not implemented in C++: // - "Writing DictionaryArray with nested dictionary type not yet supported" @@ -749,7 +749,7 @@ impl LevelInfo { array_mask, ) } - DataType::FixedSizeList(_, _) | DataType::Union(_, _) => { + DataType::FixedSizeList(_, _) | DataType::Union(_, _, _) => { unimplemented!("Getting offsets not yet implemented") } } diff --git a/parquet/src/arrow/schema.rs b/parquet/src/arrow/schema.rs index 71184e0b6fae..07c50d11c223 100644 --- a/parquet/src/arrow/schema.rs +++ b/parquet/src/arrow/schema.rs @@ -520,7 +520,7 @@ fn arrow_to_parquet_type(field: &Field) -> Result { )) } } - DataType::Union(_, _) => unimplemented!("See ARROW-8817."), + DataType::Union(_, _, _) => unimplemented!("See ARROW-8817."), DataType::Dictionary(_, ref value) => { // Dictionary encoding not handled at the schema level let dict_field = Field::new(name, *value.clone(), field.is_nullable()); From 73c69fc049079cfe22f9b5c68ed9441e1ceea578 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 16 May 2022 09:59:38 -0700 Subject: [PATCH 2/5] Add doc as suggested and put type ids in ipc --- arrow/src/datatypes/datatype.rs | 6 +++++- arrow/src/ipc/convert.rs | 5 ++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/arrow/src/datatypes/datatype.rs b/arrow/src/datatypes/datatype.rs index 9f56a423a109..a740e8ecc019 100644 --- a/arrow/src/datatypes/datatype.rs +++ b/arrow/src/datatypes/datatype.rs @@ -114,7 +114,11 @@ pub enum DataType { LargeList(Box), /// A nested datatype that contains a number of sub-fields. Struct(Vec), - /// A nested datatype that can represent slots of differing types. + /// A nested datatype that can represent slots of differing types. Components: + /// + /// 1. [`Field`] for each possible child type the Union can hold + /// 2. The corresponding `type_id` used to identify which Field + /// 3. The type of union (Sparse or Dense) Union(Vec, Vec, UnionMode), /// A dictionary encoded array (`key_type`, `value_type`), where /// each array element is an index of `key_type` into an diff --git a/arrow/src/ipc/convert.rs b/arrow/src/ipc/convert.rs index d5afe877c2d6..f8d1d02b7ff2 100644 --- a/arrow/src/ipc/convert.rs +++ b/arrow/src/ipc/convert.rs @@ -671,7 +671,7 @@ pub(crate) fn get_fb_field_type<'a>( children: Some(fbb.create_vector(&empty_fields[..])), } } - Union(fields, _, mode) => { + Union(fields, type_ids, mode) => { let mut children = vec![]; for field in fields { children.push(build_field(fbb, field)); @@ -682,8 +682,11 @@ pub(crate) fn get_fb_field_type<'a>( UnionMode::Dense => ipc::UnionMode::Dense, }; + let fbb_type_ids = fbb + .create_vector(&type_ids.iter().map(|t| *t as i32).collect::>()); let mut builder = ipc::UnionBuilder::new(fbb); builder.add_mode(union_mode); + builder.add_typeIds(fbb_type_ids); FBFieldType { type_type: ipc::Type::Union, From 66bebf3bcce404b367680425617ae8f43de80ead Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 16 May 2022 10:11:23 -0700 Subject: [PATCH 3/5] Add test --- arrow/src/ipc/convert.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/arrow/src/ipc/convert.rs b/arrow/src/ipc/convert.rs index f8d1d02b7ff2..c81ea8278c4f 100644 --- a/arrow/src/ipc/convert.rs +++ b/arrow/src/ipc/convert.rs @@ -915,6 +915,18 @@ mod tests { DataType::Union(vec![], vec![], UnionMode::Sparse), true, ), + Field::new( + "union", + DataType::Union( + vec![ + Field::new("int32", DataType::Int32, true), + Field::new("utf8", DataType::Utf8, true), + ], + vec![2, 3], // non-default type ids + UnionMode::Dense, + ), + true, + ), Field::new_dict( "dictionary", DataType::Dictionary( From 46cfbfc974323b0633905aa3cee619e9a07043e5 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 16 May 2022 10:24:50 -0700 Subject: [PATCH 4/5] Fix equal_dense --- arrow/src/array/equal/union.rs | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/arrow/src/array/equal/union.rs b/arrow/src/array/equal/union.rs index 55132aafd5f9..7ecdfb99f987 100644 --- a/arrow/src/array/equal/union.rs +++ b/arrow/src/array/equal/union.rs @@ -26,6 +26,8 @@ fn equal_dense( rhs_type_ids: &[i8], lhs_offsets: &[i32], rhs_offsets: &[i32], + lhs_field_type_ids: &[i8], + rhs_field_type_ids: &[i8], ) -> bool { let offsets = lhs_offsets.iter().zip(rhs_offsets.iter()); @@ -34,8 +36,16 @@ fn equal_dense( .zip(rhs_type_ids.iter()) .zip(offsets) .all(|((l_type_id, r_type_id), (l_offset, r_offset))| { - let lhs_values = &lhs.child_data()[*l_type_id as usize]; - let rhs_values = &rhs.child_data()[*r_type_id as usize]; + let lhs_child_index = lhs_field_type_ids + .iter() + .position(|r| r == l_type_id) + .unwrap(); + let rhs_child_index = rhs_field_type_ids + .iter() + .position(|r| r == r_type_id) + .unwrap(); + let lhs_values = &lhs.child_data()[lhs_child_index]; + let rhs_values = &rhs.child_data()[rhs_child_index]; equal_range( lhs_values, @@ -77,8 +87,8 @@ pub(super) fn union_equal( match (lhs.data_type(), rhs.data_type()) { ( - DataType::Union(_, _, UnionMode::Dense), - DataType::Union(_, _, UnionMode::Dense), + DataType::Union(_, lhs_type_ids, UnionMode::Dense), + DataType::Union(_, rhs_type_ids, UnionMode::Dense), ) => { let lhs_offsets = lhs.buffer::(1); let rhs_offsets = rhs.buffer::(1); @@ -94,6 +104,8 @@ pub(super) fn union_equal( rhs_type_id_range, lhs_offsets_range, rhs_offsets_range, + lhs_type_ids, + rhs_type_ids, ) } ( From 7286c2852d14cdc9d4ba50517beaed07e2b4a290 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 16 May 2022 13:04:49 -0700 Subject: [PATCH 5/5] Fix clippy --- arrow/src/array/equal/union.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/arrow/src/array/equal/union.rs b/arrow/src/array/equal/union.rs index 7ecdfb99f987..e8b9d27b6f0f 100644 --- a/arrow/src/array/equal/union.rs +++ b/arrow/src/array/equal/union.rs @@ -19,6 +19,7 @@ use crate::{array::ArrayData, datatypes::DataType, datatypes::UnionMode}; use super::equal_range; +#[allow(clippy::too_many_arguments)] fn equal_dense( lhs: &ArrayData, rhs: &ArrayData,