From 9aa67b272b6e19005b1f374e7810c59410f4137f Mon Sep 17 00:00:00 2001 From: Curt Hagenlocher Date: Mon, 25 Sep 2023 06:04:18 -0700 Subject: [PATCH] GH-36795: [C#] Implement support for dense and sparse unions (#36797) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes are included in this PR? Support dense and sparse unions in the C# implementation. Adds Archery support for C# unions. ### Are these changes tested? Yes ### Are there any user-facing changes? Unions are now supported in the C# implementation. **This PR includes breaking changes to public APIs.** The public APIs for the UnionArray and UnionType were changed fairly substantially. As these were previously not implemented properly, the impact of the changes ought to be minimal. The ChunkedArray and Column classes were changed to hold IArrowArrays instead of Arrays. To accommodate this, a constructor was added which may introduce ambiguity in calling code. This could be avoided by changing the overloaded constructor to instead be a factory method. This didn't seem worthwhile but could be reconsidered. The metadata version was finally increased to V5.   
* Closes: #36795 Authored-by: Curt Hagenlocher Signed-off-by: David Li --- csharp/src/Apache.Arrow/Arrays/Array.cs | 13 +-- .../Arrays/ArrayDataConcatenator.cs | 62 ++++++++++- .../Arrays/ArrayDataTypeComparer.cs | 12 +- .../Apache.Arrow/Arrays/ArrowArrayFactory.cs | 16 ++- .../Apache.Arrow/Arrays/DenseUnionArray.cs | 52 +++++++++ .../Arrays/PrimitiveArrayBuilder.cs | 3 + .../Apache.Arrow/Arrays/SparseUnionArray.cs | 46 ++++++++ csharp/src/Apache.Arrow/Arrays/UnionArray.cs | 77 ++++++++++--- .../src/Apache.Arrow/C/CArrowArrayImporter.cs | 38 +++++++ .../Apache.Arrow/C/CArrowSchemaExporter.cs | 18 +++ .../Apache.Arrow/C/CArrowSchemaImporter.cs | 56 +++++++--- csharp/src/Apache.Arrow/ChunkedArray.cs | 30 +++-- csharp/src/Apache.Arrow/Column.cs | 24 ++-- .../Extensions/FlatbufExtensions.cs | 10 ++ .../Apache.Arrow/Interfaces/IArrowArray.cs | 4 - .../Ipc/ArrowReaderImplementation.cs | 75 ++++++++----- .../src/Apache.Arrow/Ipc/ArrowStreamWriter.cs | 19 +++- .../Ipc/ArrowTypeFlatbufferBuilder.cs | 14 ++- .../src/Apache.Arrow/Ipc/MessageSerializer.cs | 4 + csharp/src/Apache.Arrow/Table.cs | 4 +- csharp/src/Apache.Arrow/Types/UnionType.cs | 11 +- .../IntegrationCommand.cs | 63 ++++++++++- .../Apache.Arrow.IntegrationTest/JsonFile.cs | 4 + .../Apache.Arrow.Tests/ArrayTypeComparer.cs | 19 +++- .../ArrowArrayConcatenatorTests.cs | 104 +++++++++++++++++- .../Apache.Arrow.Tests/ArrowReaderVerifier.cs | 19 ++++ .../CDataInterfacePythonTests.cs | 36 ++++-- csharp/test/Apache.Arrow.Tests/ColumnTests.cs | 2 +- csharp/test/Apache.Arrow.Tests/TableTests.cs | 10 +- csharp/test/Apache.Arrow.Tests/TestData.cs | 64 +++++++++++ dev/archery/archery/integration/datagen.py | 3 +- docs/source/status.rst | 4 +- 32 files changed, 797 insertions(+), 119 deletions(-) create mode 100644 csharp/src/Apache.Arrow/Arrays/DenseUnionArray.cs create mode 100644 csharp/src/Apache.Arrow/Arrays/SparseUnionArray.cs diff --git a/csharp/src/Apache.Arrow/Arrays/Array.cs b/csharp/src/Apache.Arrow/Arrays/Array.cs 
index a453b0807267f..0838134b19c6d 100644 --- a/csharp/src/Apache.Arrow/Arrays/Array.cs +++ b/csharp/src/Apache.Arrow/Arrays/Array.cs @@ -62,16 +62,7 @@ internal static void Accept(T array, IArrowArrayVisitor visitor) public Array Slice(int offset, int length) { - if (offset > Length) - { - throw new ArgumentException($"Offset {offset} cannot be greater than Length {Length} for Array.Slice"); - } - - length = Math.Min(Data.Length - offset, length); - offset += Data.Offset; - - ArrayData newData = Data.Slice(offset, length); - return ArrowArrayFactory.BuildArray(newData) as Array; + return ArrowArrayFactory.Slice(this, offset, length) as Array; } public void Dispose() @@ -88,4 +79,4 @@ protected virtual void Dispose(bool disposing) } } } -} \ No newline at end of file +} diff --git a/csharp/src/Apache.Arrow/Arrays/ArrayDataConcatenator.cs b/csharp/src/Apache.Arrow/Arrays/ArrayDataConcatenator.cs index 8859ecd7f05b9..806defdc7ce66 100644 --- a/csharp/src/Apache.Arrow/Arrays/ArrayDataConcatenator.cs +++ b/csharp/src/Apache.Arrow/Arrays/ArrayDataConcatenator.cs @@ -49,7 +49,8 @@ private class ArrayDataConcatenationVisitor : IArrowTypeVisitor, IArrowTypeVisitor, IArrowTypeVisitor, - IArrowTypeVisitor + IArrowTypeVisitor, + IArrowTypeVisitor { public ArrayData Result { get; private set; } private readonly IReadOnlyList _arrayDataList; @@ -123,6 +124,33 @@ public void Visit(StructType type) Result = new ArrayData(type, _arrayDataList[0].Length, _arrayDataList[0].NullCount, 0, _arrayDataList[0].Buffers, children); } + public void Visit(UnionType type) + { + int bufferCount = type.Mode switch + { + UnionMode.Sparse => 1, + UnionMode.Dense => 2, + _ => throw new InvalidOperationException("TODO"), + }; + + CheckData(type, bufferCount); + List children = new List(type.Fields.Count); + + for (int i = 0; i < type.Fields.Count; i++) + { + children.Add(Concatenate(SelectChildren(i), _allocator)); + } + + ArrowBuffer[] buffers = new ArrowBuffer[bufferCount]; + buffers[0] = 
ConcatenateUnionTypeBuffer(); + if (bufferCount > 1) + { + buffers[1] = ConcatenateUnionOffsetBuffer(); + } + + Result = new ArrayData(type, _totalLength, _totalNullCount, 0, buffers, children); + } + public void Visit(IArrowType type) { throw new NotImplementedException($"Concatenation for {type.Name} is not supported yet."); @@ -231,6 +259,38 @@ private ArrowBuffer ConcatenateOffsetBuffer() return builder.Build(_allocator); } + private ArrowBuffer ConcatenateUnionTypeBuffer() + { + var builder = new ArrowBuffer.Builder(_totalLength); + + foreach (ArrayData arrayData in _arrayDataList) + { + builder.Append(arrayData.Buffers[0]); + } + + return builder.Build(_allocator); + } + + private ArrowBuffer ConcatenateUnionOffsetBuffer() + { + var builder = new ArrowBuffer.Builder(_totalLength); + int baseOffset = 0; + + foreach (ArrayData arrayData in _arrayDataList) + { + ReadOnlySpan span = arrayData.Buffers[1].Span.CastTo(); + foreach (int offset in span) + { + builder.Append(baseOffset + offset); + } + + // The next offset must start from the current last offset. 
+ baseOffset += span[arrayData.Length]; + } + + return builder.Build(_allocator); + } + private List SelectChildren(int index) { var children = new List(_arrayDataList.Count); diff --git a/csharp/src/Apache.Arrow/Arrays/ArrayDataTypeComparer.cs b/csharp/src/Apache.Arrow/Arrays/ArrayDataTypeComparer.cs index 8a6bfed29abb6..6b54ec1edb573 100644 --- a/csharp/src/Apache.Arrow/Arrays/ArrayDataTypeComparer.cs +++ b/csharp/src/Apache.Arrow/Arrays/ArrayDataTypeComparer.cs @@ -27,7 +27,8 @@ internal sealed class ArrayDataTypeComparer : IArrowTypeVisitor, IArrowTypeVisitor, IArrowTypeVisitor, - IArrowTypeVisitor + IArrowTypeVisitor, + IArrowTypeVisitor { private readonly IArrowType _expectedType; private bool _dataTypeMatch; @@ -122,6 +123,15 @@ public void Visit(StructType actualType) } } + public void Visit(UnionType actualType) + { + if (_expectedType is UnionType expectedType + && CompareNested(expectedType, actualType)) + { + _dataTypeMatch = true; + } + } + private static bool CompareNested(NestedType expectedType, NestedType actualType) { if (expectedType.Fields.Count != actualType.Fields.Count) diff --git a/csharp/src/Apache.Arrow/Arrays/ArrowArrayFactory.cs b/csharp/src/Apache.Arrow/Arrays/ArrowArrayFactory.cs index f82037bff47b1..aa407203d1858 100644 --- a/csharp/src/Apache.Arrow/Arrays/ArrowArrayFactory.cs +++ b/csharp/src/Apache.Arrow/Arrays/ArrowArrayFactory.cs @@ -62,7 +62,7 @@ public static IArrowArray BuildArray(ArrayData data) case ArrowTypeId.Struct: return new StructArray(data); case ArrowTypeId.Union: - return new UnionArray(data); + return UnionArray.Create(data); case ArrowTypeId.Date64: return new Date64Array(data); case ArrowTypeId.Date32: @@ -91,5 +91,19 @@ public static IArrowArray BuildArray(ArrayData data) throw new NotSupportedException($"An ArrowArray cannot be built for type {data.DataType.TypeId}."); } } + + public static IArrowArray Slice(IArrowArray array, int offset, int length) + { + if (offset > array.Length) + { + throw new 
ArgumentException($"Offset {offset} cannot be greater than Length {array.Length} for Array.Slice"); + } + + length = Math.Min(array.Data.Length - offset, length); + offset += array.Data.Offset; + + ArrayData newData = array.Data.Slice(offset, length); + return BuildArray(newData); + } } } diff --git a/csharp/src/Apache.Arrow/Arrays/DenseUnionArray.cs b/csharp/src/Apache.Arrow/Arrays/DenseUnionArray.cs new file mode 100644 index 0000000000000..1aacbe11f08b9 --- /dev/null +++ b/csharp/src/Apache.Arrow/Arrays/DenseUnionArray.cs @@ -0,0 +1,52 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +using Apache.Arrow.Types; +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Apache.Arrow +{ + public class DenseUnionArray : UnionArray + { + public ArrowBuffer ValueOffsetBuffer => Data.Buffers[1]; + + public ReadOnlySpan ValueOffsets => ValueOffsetBuffer.Span.CastTo(); + + public DenseUnionArray( + IArrowType dataType, + int length, + IEnumerable children, + ArrowBuffer typeIds, + ArrowBuffer valuesOffsetBuffer, + int nullCount = 0, + int offset = 0) + : base(new ArrayData( + dataType, length, nullCount, offset, new[] { typeIds, valuesOffsetBuffer }, + children.Select(child => child.Data))) + { + _fields = children.ToArray(); + ValidateMode(UnionMode.Dense, Type.Mode); + } + + public DenseUnionArray(ArrayData data) + : base(data) + { + ValidateMode(UnionMode.Dense, Type.Mode); + data.EnsureBufferCount(2); + } + } +} diff --git a/csharp/src/Apache.Arrow/Arrays/PrimitiveArrayBuilder.cs b/csharp/src/Apache.Arrow/Arrays/PrimitiveArrayBuilder.cs index a50d4b52c3257..67fe46633c18f 100644 --- a/csharp/src/Apache.Arrow/Arrays/PrimitiveArrayBuilder.cs +++ b/csharp/src/Apache.Arrow/Arrays/PrimitiveArrayBuilder.cs @@ -137,6 +137,9 @@ public TBuilder Append(T value) return Instance; } + public TBuilder Append(T? value) => + (value == null) ? AppendNull() : Append(value.Value); + public TBuilder Append(ReadOnlySpan span) { int len = ValueBuffer.Length; diff --git a/csharp/src/Apache.Arrow/Arrays/SparseUnionArray.cs b/csharp/src/Apache.Arrow/Arrays/SparseUnionArray.cs new file mode 100644 index 0000000000000..b79c44c979e47 --- /dev/null +++ b/csharp/src/Apache.Arrow/Arrays/SparseUnionArray.cs @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. 
+// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Apache.Arrow.Types; +using System.Collections.Generic; +using System.Linq; + +namespace Apache.Arrow +{ + public class SparseUnionArray : UnionArray + { + public SparseUnionArray( + IArrowType dataType, + int length, + IEnumerable children, + ArrowBuffer typeIds, + int nullCount = 0, + int offset = 0) + : base(new ArrayData( + dataType, length, nullCount, offset, new[] { typeIds }, + children.Select(child => child.Data))) + { + _fields = children.ToArray(); + ValidateMode(UnionMode.Sparse, Type.Mode); + } + + public SparseUnionArray(ArrayData data) + : base(data) + { + ValidateMode(UnionMode.Sparse, Type.Mode); + data.EnsureBufferCount(1); + } + } +} diff --git a/csharp/src/Apache.Arrow/Arrays/UnionArray.cs b/csharp/src/Apache.Arrow/Arrays/UnionArray.cs index 8bccea2b59e31..0a7ae288fd0c5 100644 --- a/csharp/src/Apache.Arrow/Arrays/UnionArray.cs +++ b/csharp/src/Apache.Arrow/Arrays/UnionArray.cs @@ -15,37 +15,88 @@ using Apache.Arrow.Types; using System; +using System.Collections.Generic; +using System.Threading; namespace Apache.Arrow { - public class UnionArray: Array + public abstract class UnionArray : IArrowArray { - public UnionType Type => Data.DataType as UnionType; + protected IReadOnlyList _fields; - public UnionMode Mode => Type.Mode; + public IReadOnlyList Fields => + LazyInitializer.EnsureInitialized(ref _fields, () => InitializeFields()); + + public ArrayData Data { 
get; } - public ArrowBuffer TypeBuffer => Data.Buffers[1]; + public UnionType Type => (UnionType)Data.DataType; - public ArrowBuffer ValueOffsetBuffer => Data.Buffers[2]; + public UnionMode Mode => Type.Mode; + + public ArrowBuffer TypeBuffer => Data.Buffers[0]; public ReadOnlySpan TypeIds => TypeBuffer.Span; - public ReadOnlySpan ValueOffsets => ValueOffsetBuffer.Span.CastTo().Slice(0, Length + 1); + public int Length => Data.Length; + + public int Offset => Data.Offset; - public UnionArray(ArrayData data) - : base(data) + public int NullCount => Data.NullCount; + + public bool IsValid(int index) => NullCount == 0 || Fields[TypeIds[index]].IsValid(index); + + public bool IsNull(int index) => !IsValid(index); + + protected UnionArray(ArrayData data) { + Data = data; data.EnsureDataType(ArrowTypeId.Union); - data.EnsureBufferCount(3); } - public IArrowArray GetChild(int index) + public static UnionArray Create(ArrayData data) { - // TODO: Implement - throw new NotImplementedException(); + return ((UnionType)data.DataType).Mode switch + { + UnionMode.Dense => new DenseUnionArray(data), + UnionMode.Sparse => new SparseUnionArray(data), + _ => throw new InvalidOperationException("unknown union mode in array creation") + }; } - public override void Accept(IArrowArrayVisitor visitor) => Accept(this, visitor); + public void Accept(IArrowArrayVisitor visitor) => Array.Accept(this, visitor); + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (disposing) + { + Data.Dispose(); + } + } + + protected static void ValidateMode(UnionMode expected, UnionMode actual) + { + if (expected != actual) + { + throw new ArgumentException( + $"Specified union mode <{actual}> does not match expected mode <{expected}>", + "Mode"); + } + } + + private IReadOnlyList InitializeFields() + { + IArrowArray[] result = new IArrowArray[Data.Children.Length]; + for (int i = 0; i < Data.Children.Length; i++) + { + 
result[i] = ArrowArrayFactory.BuildArray(Data.Children[i]); + } + return result; + } } } diff --git a/csharp/src/Apache.Arrow/C/CArrowArrayImporter.cs b/csharp/src/Apache.Arrow/C/CArrowArrayImporter.cs index 9b7bcb7abe5a5..da1b0f31b8f08 100644 --- a/csharp/src/Apache.Arrow/C/CArrowArrayImporter.cs +++ b/csharp/src/Apache.Arrow/C/CArrowArrayImporter.cs @@ -170,6 +170,15 @@ private ArrayData GetAsArrayData(CArrowArray* cArray, IArrowType type) buffers = new ArrowBuffer[] { ImportValidityBuffer(cArray) }; break; case ArrowTypeId.Union: + UnionType unionType = (UnionType)type; + children = ProcessStructChildren(cArray, unionType.Fields); + buffers = unionType.Mode switch + { + UnionMode.Dense => ImportDenseUnionBuffers(cArray), + UnionMode.Sparse => ImportSparseUnionBuffers(cArray), + _ => throw new InvalidOperationException("unknown union mode in import") + }; ; + break; case ArrowTypeId.Map: break; case ArrowTypeId.Null: @@ -286,6 +295,35 @@ private ArrowBuffer[] ImportFixedSizeListBuffers(CArrowArray* cArray) return buffers; } + private ArrowBuffer[] ImportDenseUnionBuffers(CArrowArray* cArray) + { + if (cArray->n_buffers != 2) + { + throw new InvalidOperationException("Dense union arrays are expected to have exactly two children"); + } + int length = checked((int)cArray->length); + int offsetsLength = length * 4; + + ArrowBuffer[] buffers = new ArrowBuffer[2]; + buffers[0] = new ArrowBuffer(AddMemory((IntPtr)cArray->buffers[0], 0, length)); + buffers[1] = new ArrowBuffer(AddMemory((IntPtr)cArray->buffers[1], 0, offsetsLength)); + + return buffers; + } + + private ArrowBuffer[] ImportSparseUnionBuffers(CArrowArray* cArray) + { + if (cArray->n_buffers != 1) + { + throw new InvalidOperationException("Sparse union arrays are expected to have exactly one child"); + } + + ArrowBuffer[] buffers = new ArrowBuffer[1]; + buffers[0] = new ArrowBuffer(AddMemory((IntPtr)cArray->buffers[0], 0, checked((int)cArray->length))); + + return buffers; + } + private ArrowBuffer[] 
ImportFixedWidthBuffers(CArrowArray* cArray, int bitWidth) { if (cArray->n_buffers != 2) diff --git a/csharp/src/Apache.Arrow/C/CArrowSchemaExporter.cs b/csharp/src/Apache.Arrow/C/CArrowSchemaExporter.cs index 66142da331ac8..c1a12362a942a 100644 --- a/csharp/src/Apache.Arrow/C/CArrowSchemaExporter.cs +++ b/csharp/src/Apache.Arrow/C/CArrowSchemaExporter.cs @@ -124,6 +124,23 @@ public static unsafe void ExportSchema(Schema schema, CArrowSchema* out_schema) _ => throw new InvalidDataException($"Unsupported time unit for export: {unit}"), }; + private static string FormatUnion(UnionType unionType) + { + StringBuilder builder = new StringBuilder(); + builder.Append(unionType.Mode switch + { + UnionMode.Sparse => "+us:", + UnionMode.Dense => "+ud:", + _ => throw new InvalidDataException($"Unsupported union mode for export: {unionType.Mode}"), + }); + for (int i = 0; i < unionType.TypeIds.Length; i++) + { + if (i > 0) { builder.Append(','); } + builder.Append(unionType.TypeIds[i]); + } + return builder.ToString(); + } + private static string GetFormat(IArrowType datatype) { switch (datatype) @@ -170,6 +187,7 @@ private static string GetFormat(IArrowType datatype) case FixedSizeListType fixedListType: return $"+w:{fixedListType.ListSize}"; case StructType _: return "+s"; + case UnionType u: return FormatUnion(u); // Dictionary case DictionaryType dictionaryType: return GetFormat(dictionaryType.IndexType); diff --git a/csharp/src/Apache.Arrow/C/CArrowSchemaImporter.cs b/csharp/src/Apache.Arrow/C/CArrowSchemaImporter.cs index 2a750d5e8250d..42c8cdd5ef548 100644 --- a/csharp/src/Apache.Arrow/C/CArrowSchemaImporter.cs +++ b/csharp/src/Apache.Arrow/C/CArrowSchemaImporter.cs @@ -184,21 +184,7 @@ public ArrowType GetAsType() } else if (format == "+s") { - var child_schemas = new ImportedArrowSchema[_cSchema->n_children]; - - for (int i = 0; i < _cSchema->n_children; i++) - { - if (_cSchema->GetChild(i) == null) - { - throw new InvalidDataException("Expected struct type child to 
be non-null."); - } - child_schemas[i] = new ImportedArrowSchema(_cSchema->GetChild(i), isRoot: false); - } - - - List childFields = child_schemas.Select(schema => schema.GetAsField()).ToList(); - - return new StructType(childFields); + return new StructType(ParseChildren("struct")); } else if (format.StartsWith("+w:")) { @@ -265,6 +251,30 @@ public ArrowType GetAsType() return new FixedSizeBinaryType(width); } + // Unions + if (format.StartsWith("+ud:") || format.StartsWith("+us:")) + { + UnionMode unionMode = format[2] == 'd' ? UnionMode.Dense : UnionMode.Sparse; + List typeIds = new List(); + int pos = 4; + do + { + int next = format.IndexOf(',', pos); + if (next < 0) { next = format.Length; } + + int code; + if (!int.TryParse(format.Substring(pos, next - pos), out code)) + { + throw new InvalidDataException($"Invalid type code for union import: {format.Substring(pos, next - pos)}"); + } + typeIds.Add(code); + + pos = next + 1; + } while (pos < format.Length); + + return new UnionType(ParseChildren("union"), typeIds, unionMode); + } + return format switch { // Primitives @@ -324,6 +334,22 @@ public Schema GetAsSchema() } } + private List ParseChildren(string typeName) + { + var child_schemas = new ImportedArrowSchema[_cSchema->n_children]; + + for (int i = 0; i < _cSchema->n_children; i++) + { + if (_cSchema->GetChild(i) == null) + { + throw new InvalidDataException($"Expected {typeName} type child to be non-null."); + } + child_schemas[i] = new ImportedArrowSchema(_cSchema->GetChild(i), isRoot: false); + } + + return child_schemas.Select(schema => schema.GetAsField()).ToList(); + } + private unsafe static IReadOnlyDictionary GetMetadata(byte* metadata) { if (metadata == null) diff --git a/csharp/src/Apache.Arrow/ChunkedArray.cs b/csharp/src/Apache.Arrow/ChunkedArray.cs index 5f25acfe04a2f..f5909f5adfe48 100644 --- a/csharp/src/Apache.Arrow/ChunkedArray.cs +++ b/csharp/src/Apache.Arrow/ChunkedArray.cs @@ -15,7 +15,6 @@ using System; using 
System.Collections.Generic; -using Apache.Arrow; using Apache.Arrow.Types; namespace Apache.Arrow @@ -25,7 +24,7 @@ namespace Apache.Arrow /// public class ChunkedArray { - private IList Arrays { get; } + private IList Arrays { get; } public IArrowType DataType { get; } public long Length { get; } public long NullCount { get; } @@ -35,9 +34,16 @@ public int ArrayCount get => Arrays.Count; } - public Array Array(int index) => Arrays[index]; + public Array Array(int index) => Arrays[index] as Array; + + public IArrowArray ArrowArray(int index) => Arrays[index]; public ChunkedArray(IList arrays) + : this(Cast(arrays)) + { + } + + public ChunkedArray(IList arrays) { Arrays = arrays ?? throw new ArgumentNullException(nameof(arrays)); if (arrays.Count < 1) @@ -45,14 +51,14 @@ public ChunkedArray(IList arrays) throw new ArgumentException($"Count must be at least 1. Got {arrays.Count} instead"); } DataType = arrays[0].Data.DataType; - foreach (Array array in arrays) + foreach (IArrowArray array in arrays) { Length += array.Length; NullCount += array.NullCount; } } - public ChunkedArray(Array array) : this(new[] { array }) { } + public ChunkedArray(Array array) : this(new IArrowArray[] { array }) { } public ChunkedArray Slice(long offset, long length) { @@ -69,10 +75,10 @@ public ChunkedArray Slice(long offset, long length) curArrayIndex++; } - IList newArrays = new List(); + IList newArrays = new List(); while (curArrayIndex < numArrays && length > 0) { - newArrays.Add(Arrays[curArrayIndex].Slice((int)offset, + newArrays.Add(ArrowArrayFactory.Slice(Arrays[curArrayIndex], (int)offset, length > Arrays[curArrayIndex].Length ? 
Arrays[curArrayIndex].Length : (int)length)); length -= Arrays[curArrayIndex].Length - offset; offset = 0; @@ -86,6 +92,16 @@ public ChunkedArray Slice(long offset) return Slice(offset, Length - offset); } + private static IArrowArray[] Cast(IList arrays) + { + IArrowArray[] arrowArrays = new IArrowArray[arrays.Count]; + for (int i = 0; i < arrays.Count; i++) + { + arrowArrays[i] = arrays[i]; + } + return arrowArrays; + } + // TODO: Flatten for Structs } } diff --git a/csharp/src/Apache.Arrow/Column.cs b/csharp/src/Apache.Arrow/Column.cs index 4eaf9a559e75d..0709b9142cafd 100644 --- a/csharp/src/Apache.Arrow/Column.cs +++ b/csharp/src/Apache.Arrow/Column.cs @@ -28,19 +28,23 @@ public class Column public ChunkedArray Data { get; } public Column(Field field, IList arrays) + : this(field, new ChunkedArray(arrays), doValidation: true) + { + } + + public Column(Field field, IList arrays) + : this(field, new ChunkedArray(arrays), doValidation: true) { - Data = new ChunkedArray(arrays); - Field = field; - if (!ValidateArrayDataTypes()) - { - throw new ArgumentException($"{Field.DataType} must match {Data.DataType}"); - } } - private Column(Field field, ChunkedArray arrays) + private Column(Field field, ChunkedArray data, bool doValidation = false) { + Data = data; Field = field; - Data = arrays; + if (doValidation && !ValidateArrayDataTypes()) + { + throw new ArgumentException($"{Field.DataType} must match {Data.DataType}"); + } } public long Length => Data.Length; @@ -64,12 +68,12 @@ private bool ValidateArrayDataTypes() for (int i = 0; i < Data.ArrayCount; i++) { - if (Data.Array(i).Data.DataType.TypeId != Field.DataType.TypeId) + if (Data.ArrowArray(i).Data.DataType.TypeId != Field.DataType.TypeId) { return false; } - Data.Array(i).Data.DataType.Accept(dataTypeComparer); + Data.ArrowArray(i).Data.DataType.Accept(dataTypeComparer); if (!dataTypeComparer.DataTypeMatch) { diff --git a/csharp/src/Apache.Arrow/Extensions/FlatbufExtensions.cs 
b/csharp/src/Apache.Arrow/Extensions/FlatbufExtensions.cs index d2a70bca9e4ec..35c5b3e55157d 100644 --- a/csharp/src/Apache.Arrow/Extensions/FlatbufExtensions.cs +++ b/csharp/src/Apache.Arrow/Extensions/FlatbufExtensions.cs @@ -80,6 +80,16 @@ public static Types.TimeUnit ToArrow(this Flatbuf.TimeUnit unit) throw new ArgumentException($"Unexpected Flatbuf TimeUnit", nameof(unit)); } } + + public static Types.UnionMode ToArrow(this Flatbuf.UnionMode mode) + { + return mode switch + { + Flatbuf.UnionMode.Dense => Types.UnionMode.Dense, + Flatbuf.UnionMode.Sparse => Types.UnionMode.Sparse, + _ => throw new ArgumentException($"Unsupported Flatbuf UnionMode", nameof(mode)), + }; + } } } diff --git a/csharp/src/Apache.Arrow/Interfaces/IArrowArray.cs b/csharp/src/Apache.Arrow/Interfaces/IArrowArray.cs index 50fbc3af6dd72..9bcee36ef4eaf 100644 --- a/csharp/src/Apache.Arrow/Interfaces/IArrowArray.cs +++ b/csharp/src/Apache.Arrow/Interfaces/IArrowArray.cs @@ -32,9 +32,5 @@ public interface IArrowArray : IDisposable ArrayData Data { get; } void Accept(IArrowArrayVisitor visitor); - - //IArrowArray Slice(int offset); - - //IArrowArray Slice(int offset, int length); } } diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs index c9c1b21673316..d3115da52cc6c 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs @@ -116,11 +116,11 @@ protected RecordBatch CreateArrowObjectFromMessage( break; case Flatbuf.MessageHeader.DictionaryBatch: Flatbuf.DictionaryBatch dictionaryBatch = message.Header().Value; - ReadDictionaryBatch(dictionaryBatch, bodyByteBuffer, memoryOwner); + ReadDictionaryBatch(message.Version, dictionaryBatch, bodyByteBuffer, memoryOwner); break; case Flatbuf.MessageHeader.RecordBatch: Flatbuf.RecordBatch rb = message.Header().Value; - List arrays = BuildArrays(Schema, bodyByteBuffer, rb); + List arrays = 
BuildArrays(message.Version, Schema, bodyByteBuffer, rb); return new RecordBatch(Schema, memoryOwner, arrays, (int)rb.Length); default: // NOTE: Skip unsupported message type @@ -136,7 +136,11 @@ internal static ByteBuffer CreateByteBuffer(ReadOnlyMemory buffer) return new ByteBuffer(new ReadOnlyMemoryBufferAllocator(buffer), 0); } - private void ReadDictionaryBatch(Flatbuf.DictionaryBatch dictionaryBatch, ByteBuffer bodyByteBuffer, IMemoryOwner memoryOwner) + private void ReadDictionaryBatch( + MetadataVersion version, + Flatbuf.DictionaryBatch dictionaryBatch, + ByteBuffer bodyByteBuffer, + IMemoryOwner memoryOwner) { long id = dictionaryBatch.Id; IArrowType valueType = DictionaryMemo.GetDictionaryType(id); @@ -149,7 +153,7 @@ private void ReadDictionaryBatch(Flatbuf.DictionaryBatch dictionaryBatch, ByteBu Field valueField = new Field("dummy", valueType, true); var schema = new Schema(new[] { valueField }, default); - IList arrays = BuildArrays(schema, bodyByteBuffer, recordBatch.Value); + IList arrays = BuildArrays(version, schema, bodyByteBuffer, recordBatch.Value); if (arrays.Count != 1) { @@ -167,6 +171,7 @@ private void ReadDictionaryBatch(Flatbuf.DictionaryBatch dictionaryBatch, ByteBu } private List BuildArrays( + MetadataVersion version, Schema schema, ByteBuffer messageBuffer, Flatbuf.RecordBatch recordBatchMessage) @@ -187,8 +192,8 @@ private List BuildArrays( Flatbuf.FieldNode fieldNode = recordBatchEnumerator.CurrentNode; ArrayData arrayData = field.DataType.IsFixedPrimitive() - ? LoadPrimitiveField(ref recordBatchEnumerator, field, in fieldNode, messageBuffer, bufferCreator) - : LoadVariableField(ref recordBatchEnumerator, field, in fieldNode, messageBuffer, bufferCreator); + ? 
LoadPrimitiveField(version, ref recordBatchEnumerator, field, in fieldNode, messageBuffer, bufferCreator) + : LoadVariableField(version, ref recordBatchEnumerator, field, in fieldNode, messageBuffer, bufferCreator); arrays.Add(ArrowArrayFactory.BuildArray(arrayData)); } while (recordBatchEnumerator.MoveNextNode()); @@ -225,6 +230,7 @@ private IBufferCreator GetBufferCreator(BodyCompression? compression) } private ArrayData LoadPrimitiveField( + MetadataVersion version, ref RecordBatchEnumerator recordBatchEnumerator, Field field, in Flatbuf.FieldNode fieldNode, @@ -245,31 +251,44 @@ private ArrayData LoadPrimitiveField( throw new InvalidDataException("Null count length must be >= 0"); // TODO:Localize exception message } - if (field.DataType.TypeId == ArrowTypeId.Null) + int buffers; + switch (field.DataType.TypeId) { - return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, System.Array.Empty()); - } - - ArrowBuffer nullArrowBuffer = BuildArrowBuffer(bodyData, recordBatchEnumerator.CurrentBuffer, bufferCreator); - if (!recordBatchEnumerator.MoveNextBuffer()) - { - throw new Exception("Unable to move to the next buffer."); + case ArrowTypeId.Null: + return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, System.Array.Empty()); + case ArrowTypeId.Union: + if (version < MetadataVersion.V5) + { + if (fieldNullCount > 0) + { + if (recordBatchEnumerator.CurrentBuffer.Length > 0) + { + // With older metadata we can get a validity bitmap. Fixing up union data is hard, + // so we will just quit. + throw new NotSupportedException("Cannot read pre-1.0.0 Union array with top-level validity bitmap"); + } + } + recordBatchEnumerator.MoveNextBuffer(); + } + buffers = ((UnionType)field.DataType).Mode == Types.UnionMode.Dense ? 
2 : 1; + break; + case ArrowTypeId.Struct: + case ArrowTypeId.FixedSizeList: + buffers = 1; + break; + default: + buffers = 2; + break; } - ArrowBuffer[] arrowBuff; - if (field.DataType.TypeId == ArrowTypeId.Struct || field.DataType.TypeId == ArrowTypeId.FixedSizeList) + ArrowBuffer[] arrowBuff = new ArrowBuffer[buffers]; + for (int i = 0; i < buffers; i++) { - arrowBuff = new[] { nullArrowBuffer }; - } - else - { - ArrowBuffer valueArrowBuffer = BuildArrowBuffer(bodyData, recordBatchEnumerator.CurrentBuffer, bufferCreator); + arrowBuff[i] = BuildArrowBuffer(bodyData, recordBatchEnumerator.CurrentBuffer, bufferCreator); recordBatchEnumerator.MoveNextBuffer(); - - arrowBuff = new[] { nullArrowBuffer, valueArrowBuffer }; } - ArrayData[] children = GetChildren(ref recordBatchEnumerator, field, bodyData, bufferCreator); + ArrayData[] children = GetChildren(version, ref recordBatchEnumerator, field, bodyData, bufferCreator); IArrowArray dictionary = null; if (field.DataType.TypeId == ArrowTypeId.Dictionary) @@ -282,6 +301,7 @@ private ArrayData LoadPrimitiveField( } private ArrayData LoadVariableField( + MetadataVersion version, ref RecordBatchEnumerator recordBatchEnumerator, Field field, in Flatbuf.FieldNode fieldNode, @@ -316,7 +336,7 @@ private ArrayData LoadVariableField( } ArrowBuffer[] arrowBuff = new[] { nullArrowBuffer, offsetArrowBuffer, valueArrowBuffer }; - ArrayData[] children = GetChildren(ref recordBatchEnumerator, field, bodyData, bufferCreator); + ArrayData[] children = GetChildren(version, ref recordBatchEnumerator, field, bodyData, bufferCreator); IArrowArray dictionary = null; if (field.DataType.TypeId == ArrowTypeId.Dictionary) @@ -329,6 +349,7 @@ private ArrayData LoadVariableField( } private ArrayData[] GetChildren( + MetadataVersion version, ref RecordBatchEnumerator recordBatchEnumerator, Field field, ByteBuffer bodyData, @@ -345,8 +366,8 @@ private ArrayData[] GetChildren( Field childField = type.Fields[index]; ArrayData child = 
childField.DataType.IsFixedPrimitive() - ? LoadPrimitiveField(ref recordBatchEnumerator, childField, in childFieldNode, bodyData, bufferCreator) - : LoadVariableField(ref recordBatchEnumerator, childField, in childFieldNode, bodyData, bufferCreator); + ? LoadPrimitiveField(version, ref recordBatchEnumerator, childField, in childFieldNode, bodyData, bufferCreator) + : LoadVariableField(version, ref recordBatchEnumerator, childField, in childFieldNode, bodyData, bufferCreator); children[index] = child; } diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs index a5d8db3f509d7..2b3815af71142 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs @@ -55,6 +55,7 @@ internal class ArrowRecordBatchFlatBufferBuilder : IArrowArrayVisitor, IArrowArrayVisitor, IArrowArrayVisitor, + IArrowArrayVisitor, IArrowArrayVisitor, IArrowArrayVisitor, IArrowArrayVisitor, @@ -156,6 +157,22 @@ public void Visit(StructArray array) } } + public void Visit(UnionArray array) + { + _buffers.Add(CreateBuffer(array.TypeBuffer)); + + ArrowBuffer? offsets = (array as DenseUnionArray)?.ValueOffsetBuffer; + if (offsets != null) + { + _buffers.Add(CreateBuffer(offsets.Value)); + } + + for (int i = 0; i < array.Fields.Count; i++) + { + array.Fields[i].Accept(this); + } + } + public void Visit(DictionaryArray array) { // Dictionary is serialized separately in Dictionary serialization. 
@@ -218,7 +235,7 @@ public void Visit(IArrowArray array) private readonly bool _leaveOpen; private readonly IpcOptions _options; - private protected const Flatbuf.MetadataVersion CurrentMetadataVersion = Flatbuf.MetadataVersion.V4; + private protected const Flatbuf.MetadataVersion CurrentMetadataVersion = Flatbuf.MetadataVersion.V5; private static readonly byte[] s_padding = new byte[64]; diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs b/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs index 203aa72d93ea3..b11467538dd04 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs @@ -120,7 +120,9 @@ public void Visit(FixedSizeListType type) public void Visit(UnionType type) { - throw new NotImplementedException(); + Result = FieldType.Build( + Flatbuf.Type.Union, + Flatbuf.Union.CreateUnion(Builder, ToFlatBuffer(type.Mode), Flatbuf.Union.CreateTypeIdsVector(Builder, type.TypeIds))); } public void Visit(StringType type) @@ -279,5 +281,15 @@ private static Flatbuf.TimeUnit ToFlatBuffer(TimeUnit unit) return result; } + + private static Flatbuf.UnionMode ToFlatBuffer(Types.UnionMode mode) + { + return mode switch + { + Types.UnionMode.Dense => Flatbuf.UnionMode.Dense, + Types.UnionMode.Sparse => Flatbuf.UnionMode.Sparse, + _ => throw new ArgumentException($"unsupported union mode <{mode}>", nameof(mode)), + }; + } } } diff --git a/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs b/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs index 8ca69b61165bf..6249063ba81f4 100644 --- a/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs +++ b/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs @@ -203,6 +203,10 @@ private static Types.IArrowType GetFieldArrowType(Flatbuf.Field field, Field[] c case Flatbuf.Type.Struct_: Debug.Assert(childFields != null); return new Types.StructType(childFields); + case Flatbuf.Type.Union: + Debug.Assert(childFields != null); + Flatbuf.Union 
unionMetadata = field.Type().Value; + return new Types.UnionType(childFields, unionMetadata.GetTypeIdsArray(), unionMetadata.Mode.ToArrow()); default: throw new InvalidDataException($"Arrow primitive '{field.TypeType}' is unsupported."); } diff --git a/csharp/src/Apache.Arrow/Table.cs b/csharp/src/Apache.Arrow/Table.cs index 0b9f31557bec8..939ec23f54ff2 100644 --- a/csharp/src/Apache.Arrow/Table.cs +++ b/csharp/src/Apache.Arrow/Table.cs @@ -37,10 +37,10 @@ public static Table TableFromRecordBatches(Schema schema, IList rec List columns = new List(nColumns); for (int icol = 0; icol < nColumns; icol++) { - List columnArrays = new List(nBatches); + List columnArrays = new List(nBatches); for (int jj = 0; jj < nBatches; jj++) { - columnArrays.Add(recordBatches[jj].Column(icol) as Array); + columnArrays.Add(recordBatches[jj].Column(icol)); } columns.Add(new Column(schema.GetFieldByIndex(icol), columnArrays)); } diff --git a/csharp/src/Apache.Arrow/Types/UnionType.cs b/csharp/src/Apache.Arrow/Types/UnionType.cs index 293271018aa26..23fa3b45ab278 100644 --- a/csharp/src/Apache.Arrow/Types/UnionType.cs +++ b/csharp/src/Apache.Arrow/Types/UnionType.cs @@ -24,20 +24,21 @@ public enum UnionMode Dense } - public sealed class UnionType : ArrowType + public sealed class UnionType : NestedType { public override ArrowTypeId TypeId => ArrowTypeId.Union; public override string Name => "union"; public UnionMode Mode { get; } - - public IEnumerable TypeCodes { get; } + + public int[] TypeIds { get; } public UnionType( - IEnumerable fields, IEnumerable typeCodes, + IEnumerable fields, IEnumerable typeIds, UnionMode mode = UnionMode.Sparse) + : base(fields.ToArray()) { - TypeCodes = typeCodes.ToList(); + TypeIds = typeIds.ToArray(); Mode = mode; } diff --git a/csharp/test/Apache.Arrow.IntegrationTest/IntegrationCommand.cs b/csharp/test/Apache.Arrow.IntegrationTest/IntegrationCommand.cs index abf7451e5e98c..1e76ee505a516 100644 --- 
a/csharp/test/Apache.Arrow.IntegrationTest/IntegrationCommand.cs +++ b/csharp/test/Apache.Arrow.IntegrationTest/IntegrationCommand.cs @@ -128,7 +128,7 @@ private RecordBatch CreateRecordBatch(Schema schema, JsonRecordBatch jsonRecordB for (int i = 0; i < jsonRecordBatch.Columns.Count; i++) { JsonFieldData data = jsonRecordBatch.Columns[i]; - Field field = schema.GetFieldByName(data.Name); + Field field = schema.FieldsList[i]; ArrayCreator creator = new ArrayCreator(data); field.DataType.Accept(creator); arrays.Add(creator.Array); @@ -188,6 +188,7 @@ private static IArrowType ToArrowType(JsonArrowType type, Field[] children) "list" => ToListArrowType(type, children), "fixedsizelist" => ToFixedSizeListArrowType(type, children), "struct" => ToStructArrowType(type, children), + "union" => ToUnionArrowType(type, children), "null" => NullType.Default, _ => throw new NotSupportedException($"JsonArrowType not supported: {type.Name}") }; @@ -281,6 +282,17 @@ private static IArrowType ToStructArrowType(JsonArrowType type, Field[] children return new StructType(children); } + private static IArrowType ToUnionArrowType(JsonArrowType type, Field[] children) + { + UnionMode mode = type.Mode switch + { + "SPARSE" => UnionMode.Sparse, + "DENSE" => UnionMode.Dense, + _ => throw new NotSupportedException($"Union mode not supported: {type.Mode}"), + }; + return new UnionType(children, type.TypeIds, mode); + } + private class ArrayCreator : IArrowTypeVisitor, IArrowTypeVisitor, @@ -306,6 +318,7 @@ private class ArrayCreator : IArrowTypeVisitor, IArrowTypeVisitor, IArrowTypeVisitor, + IArrowTypeVisitor, IArrowTypeVisitor { private JsonFieldData JsonFieldData { get; set; } @@ -556,6 +569,43 @@ public void Visit(StructType type) Array = new StructArray(arrayData); } + public void Visit(UnionType type) + { + ArrowBuffer[] buffers; + if (type.Mode == UnionMode.Dense) + { + buffers = new ArrowBuffer[2]; + buffers[1] = GetOffsetBuffer(); + } + else + { + buffers = new ArrowBuffer[1]; + } + 
buffers[0] = GetTypeIdBuffer(); + + ArrayData[] children = GetChildren(type); + + int nullCount = 0; + ArrayData arrayData = new ArrayData(type, JsonFieldData.Count, nullCount, 0, buffers, children); + Array = UnionArray.Create(arrayData); + } + + private ArrayData[] GetChildren(NestedType type) + { + ArrayData[] children = new ArrayData[type.Fields.Count]; + + var data = JsonFieldData; + for (int i = 0; i < children.Length; i++) + { + JsonFieldData = data.Children[i]; + type.Fields[i].DataType.Accept(this); + children[i] = Array.Data; + } + JsonFieldData = data; + + return children; + } + private static byte[] ConvertHexStringToByteArray(string hexString) { byte[] data = new byte[hexString.Length / 2]; @@ -619,11 +669,22 @@ private void GenerateLongArray(Func valueOffsets = new ArrowBuffer.Builder(JsonFieldData.Offset.Length); valueOffsets.AppendRange(JsonFieldData.Offset); return valueOffsets.Build(default); } + private ArrowBuffer GetTypeIdBuffer() + { + ArrowBuffer.Builder typeIds = new ArrowBuffer.Builder(JsonFieldData.TypeId.Length); + for (int i = 0; i < JsonFieldData.TypeId.Length; i++) + { + typeIds.Append(checked((byte)JsonFieldData.TypeId[i])); + } + return typeIds.Build(default); + } + private ArrowBuffer GetValidityBuffer(out int nullCount) { if (JsonFieldData.Validity == null) diff --git a/csharp/test/Apache.Arrow.IntegrationTest/JsonFile.cs b/csharp/test/Apache.Arrow.IntegrationTest/JsonFile.cs index f0f63d3e19b8c..112eeabcb9931 100644 --- a/csharp/test/Apache.Arrow.IntegrationTest/JsonFile.cs +++ b/csharp/test/Apache.Arrow.IntegrationTest/JsonFile.cs @@ -71,6 +71,10 @@ public class JsonArrowType // FixedSizeList fields public int ListSize { get; set; } + // union fields + public string Mode { get; set; } + public int[] TypeIds { get; set; } + [JsonExtensionData] public Dictionary ExtensionData { get; set; } } diff --git a/csharp/test/Apache.Arrow.Tests/ArrayTypeComparer.cs b/csharp/test/Apache.Arrow.Tests/ArrayTypeComparer.cs index 
77584aefb1bf4..c8bcc3cee0f99 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrayTypeComparer.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrayTypeComparer.cs @@ -28,7 +28,8 @@ public class ArrayTypeComparer : IArrowTypeVisitor, IArrowTypeVisitor, IArrowTypeVisitor, - IArrowTypeVisitor + IArrowTypeVisitor, + IArrowTypeVisitor { private readonly IArrowType _expectedType; @@ -114,6 +115,22 @@ public void Visit(StructType actualType) CompareNested(expectedType, actualType); } + public void Visit(UnionType actualType) + { + Assert.IsAssignableFrom(_expectedType); + UnionType expectedType = (UnionType)_expectedType; + + Assert.Equal(expectedType.Mode, actualType.Mode); + + Assert.Equal(expectedType.TypeIds.Length, actualType.TypeIds.Length); + for (int i = 0; i < expectedType.TypeIds.Length; i++) + { + Assert.Equal(expectedType.TypeIds[i], actualType.TypeIds[i]); + } + + CompareNested(expectedType, actualType); + } + private static void CompareNested(NestedType expectedType, NestedType actualType) { Assert.Equal(expectedType.Fields.Count, actualType.Fields.Count); diff --git a/csharp/test/Apache.Arrow.Tests/ArrowArrayConcatenatorTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowArrayConcatenatorTests.cs index 36cffe7eb4da1..f5a2c345e2ae6 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowArrayConcatenatorTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowArrayConcatenatorTests.cs @@ -77,6 +77,22 @@ private static IEnumerable, IArrowArray>> GenerateTestDa new Field.Builder().Name("Ints").DataType(Int32Type.Default).Nullable(true).Build() }), new FixedSizeListType(Int32Type.Default, 1), + new UnionType( + new List{ + new Field.Builder().Name("Strings").DataType(StringType.Default).Nullable(true).Build(), + new Field.Builder().Name("Ints").DataType(Int32Type.Default).Nullable(true).Build() + }, + new[] { 0, 1 }, + UnionMode.Sparse + ), + new UnionType( + new List{ + new Field.Builder().Name("Strings").DataType(StringType.Default).Nullable(true).Build(), + new 
Field.Builder().Name("Ints").DataType(Int32Type.Default).Nullable(true).Build() + }, + new[] { 0, 1 }, + UnionMode.Dense + ), }; foreach (IArrowType type in targetTypes) @@ -119,7 +135,8 @@ private class TestDataGenerator : IArrowTypeVisitor, IArrowTypeVisitor, IArrowTypeVisitor, - IArrowTypeVisitor + IArrowTypeVisitor, + IArrowTypeVisitor { private List> _baseData; @@ -392,6 +409,91 @@ public void Visit(StructType type) ExpectedArray = new StructArray(type, 3, new List { resultStringArray, resultInt32Array }, nullBitmapBuffer, 1); } + public void Visit(UnionType type) + { + bool isDense = type.Mode == UnionMode.Dense; + + StringArray.Builder stringResultBuilder = new StringArray.Builder().Reserve(_baseDataTotalElementCount); + Int32Array.Builder intResultBuilder = new Int32Array.Builder().Reserve(_baseDataTotalElementCount); + ArrowBuffer.Builder typeResultBuilder = new ArrowBuffer.Builder().Reserve(_baseDataTotalElementCount); + ArrowBuffer.Builder offsetResultBuilder = new ArrowBuffer.Builder().Reserve(_baseDataTotalElementCount); + int resultNullCount = 0; + + for (int i = 0; i < _baseDataListCount; i++) + { + List dataList = _baseData[i]; + StringArray.Builder stringBuilder = new StringArray.Builder().Reserve(dataList.Count); + Int32Array.Builder intBuilder = new Int32Array.Builder().Reserve(dataList.Count); + ArrowBuffer.Builder typeBuilder = new ArrowBuffer.Builder().Reserve(dataList.Count); + ArrowBuffer.Builder offsetBuilder = new ArrowBuffer.Builder().Reserve(dataList.Count); + int nullCount = 0; + + for (int j = 0; j < dataList.Count; j++) + { + byte index = (byte)Math.Max(j % 3, 1); + int? intValue = (index == 1) ? dataList[j] : null; + string stringValue = (index == 1) ? 
null : dataList[j]?.ToString(); + typeBuilder.Append(index); + + if (isDense) + { + if (index == 0) + { + offsetBuilder.Append(stringBuilder.Length); + offsetResultBuilder.Append(stringResultBuilder.Length); + stringBuilder.Append(stringValue); + stringResultBuilder.Append(stringValue); + } + else + { + offsetBuilder.Append(intBuilder.Length); + offsetResultBuilder.Append(intResultBuilder.Length); + intBuilder.Append(intValue); + intResultBuilder.Append(intValue); + } + } + else + { + stringBuilder.Append(stringValue); + stringResultBuilder.Append(stringValue); + intBuilder.Append(intValue); + intResultBuilder.Append(intValue); + } + + if (dataList[j] == null) + { + nullCount++; + resultNullCount++; + } + } + + ArrowBuffer[] buffers; + if (isDense) + { + buffers = new[] { typeBuilder.Build(), offsetBuilder.Build() }; + } + else + { + buffers = new[] { typeBuilder.Build() }; + } + TestTargetArrayList.Add(UnionArray.Create(new ArrayData( + type, dataList.Count, nullCount, 0, buffers, + new[] { stringBuilder.Build().Data, intBuilder.Build().Data }))); + } + + ArrowBuffer[] resultBuffers; + if (isDense) + { + resultBuffers = new[] { typeResultBuilder.Build(), offsetResultBuilder.Build() }; + } + else + { + resultBuffers = new[] { typeResultBuilder.Build() }; + } + ExpectedArray = UnionArray.Create(new ArrayData( + type, _baseDataTotalElementCount, resultNullCount, 0, resultBuffers, + new[] { stringResultBuilder.Build().Data, intResultBuilder.Build().Data })); + } public void Visit(IArrowType type) { diff --git a/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs b/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs index e588eab51e1fc..8b41763a70ac8 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs @@ -91,6 +91,7 @@ private class ArrayComparer : IArrowArrayVisitor, IArrowArrayVisitor, IArrowArrayVisitor, + IArrowArrayVisitor, IArrowArrayVisitor, IArrowArrayVisitor, IArrowArrayVisitor, @@ 
-151,6 +152,24 @@ public void Visit(StructArray array) } } + public void Visit(UnionArray array) + { + Assert.IsAssignableFrom(_expectedArray); + UnionArray expectedArray = (UnionArray)_expectedArray; + + Assert.Equal(expectedArray.Mode, array.Mode); + Assert.Equal(expectedArray.Length, array.Length); + Assert.Equal(expectedArray.NullCount, array.NullCount); + Assert.Equal(expectedArray.Offset, array.Offset); + Assert.Equal(expectedArray.Data.Children.Length, array.Data.Children.Length); + Assert.Equal(expectedArray.Fields.Count, array.Fields.Count); + + for (int i = 0; i < array.Fields.Count; i++) + { + array.Fields[i].Accept(new ArrayComparer(expectedArray.Fields[i], _strictCompare)); + } + } + public void Visit(DictionaryArray array) { Assert.IsAssignableFrom(_expectedArray); diff --git a/csharp/test/Apache.Arrow.Tests/CDataInterfacePythonTests.cs b/csharp/test/Apache.Arrow.Tests/CDataInterfacePythonTests.cs index 29b1b9e7db74a..f28b89a9cd17e 100644 --- a/csharp/test/Apache.Arrow.Tests/CDataInterfacePythonTests.cs +++ b/csharp/test/Apache.Arrow.Tests/CDataInterfacePythonTests.cs @@ -112,6 +112,9 @@ private static Schema GetTestSchema() .Field(f => f.Name("dict_string_ordered").DataType(new DictionaryType(Int32Type.Default, StringType.Default, true)).Nullable(false)) .Field(f => f.Name("list_dict_string").DataType(new ListType(new DictionaryType(Int32Type.Default, StringType.Default, false))).Nullable(false)) + .Field(f => f.Name("dense_union").DataType(new UnionType(new[] { new Field("i64", Int64Type.Default, false), new Field("f32", FloatType.Default, true), }, new[] { 0, 1 }, UnionMode.Dense))) + .Field(f => f.Name("sparse_union").DataType(new UnionType(new[] { new Field("i32", Int32Type.Default, true), new Field("f64", DoubleType.Default, false), }, new[] { 0, 1 }, UnionMode.Sparse))) + // Checking wider characters. 
.Field(f => f.Name("hello 你好 😄").DataType(BooleanType.Default).Nullable(true)) @@ -172,6 +175,9 @@ private static IEnumerable GetPythonFields() yield return pa.field("dict_string_ordered", pa.dictionary(pa.int32(), pa.utf8(), true), false); yield return pa.field("list_dict_string", pa.list_(pa.dictionary(pa.int32(), pa.utf8(), false)), false); + yield return pa.field("dense_union", pa.dense_union(List(pa.field("i64", pa.int64(), false), pa.field("f32", pa.float32(), true)))); + yield return pa.field("sparse_union", pa.sparse_union(List(pa.field("i32", pa.int32(), true), pa.field("f64", pa.float64(), false)))); + yield return pa.field("hello 你好 😄", pa.bool_(), true); } } @@ -485,22 +491,29 @@ public unsafe void ImportRecordBatch() pa.array(List(0.0, 1.4, 2.5, 3.6, 4.7)), pa.array(new PyObject[] { List(1, 2), List(3, 4), PyObject.None, PyObject.None, List(5, 4, 3) }), pa.StructArray.from_arrays( - new PyList(new PyObject[] - { + List( List(10, 9, null, null, null), List("banana", "apple", "orange", "cherry", "grape"), - List(null, 4.3, -9, 123.456, 0), - }), + List(null, 4.3, -9, 123.456, 0) + ), new[] { "fld1", "fld2", "fld3" }), pa.DictionaryArray.from_arrays( pa.array(List(1, 0, 1, 1, null)), - pa.array(List("foo", "bar")) - ), + pa.array(List("foo", "bar"))), pa.FixedSizeListArray.from_arrays( pa.array(List(1, 2, 3, 4, null, 6, 7, null, null, null)), 2), + pa.UnionArray.from_dense( + pa.array(List(0, 1, 1, 0, 0), type: "int8"), + pa.array(List(0, 0, 1, 1, 2), type: "int32"), + List( + pa.array(List(1, 4, null)), + pa.array(List("two", "three")) + ), + /* field name */ List("i32", "s"), + /* type codes */ List(3, 2)), }), - new[] { "col1", "col2", "col3", "col4", "col5", "col6", "col7", "col8" }); + new[] { "col1", "col2", "col3", "col4", "col5", "col6", "col7", "col8", "col9" }); dynamic batch = table.to_batches()[0]; @@ -568,6 +581,10 @@ public unsafe void ImportRecordBatch() Assert.Equal(new long[] { 1, 2, 3, 4, 0, 6, 7, 0, 0, 0 }, col8a.Values.ToArray()); 
Assert.True(col8a.IsValid(3)); Assert.False(col8a.IsValid(9)); + + UnionArray col9 = (UnionArray)recordBatch.Column("col9"); + Assert.Equal(5, col9.Length); + Assert.True(col9 is DenseUnionArray); } [SkippableFact] @@ -789,6 +806,11 @@ private static PyObject List(params string[] values) return new PyList(values.Select(i => i == null ? PyObject.None : new PyString(i)).ToArray()); } + private static PyObject List(params PyObject[] values) + { + return new PyList(values); + } + sealed class TestArrayStream : IArrowArrayStream { private readonly RecordBatch[] _batches; diff --git a/csharp/test/Apache.Arrow.Tests/ColumnTests.cs b/csharp/test/Apache.Arrow.Tests/ColumnTests.cs index b90c681622d5f..2d867b79176aa 100644 --- a/csharp/test/Apache.Arrow.Tests/ColumnTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ColumnTests.cs @@ -39,7 +39,7 @@ public void TestColumn() Array intArrayCopy = MakeIntArray(10); Field field = new Field.Builder().Name("f0").DataType(Int32Type.Default).Build(); - Column column = new Column(field, new[] { intArray, intArrayCopy }); + Column column = new Column(field, new IArrowArray[] { intArray, intArrayCopy }); Assert.True(column.Name == field.Name); Assert.True(column.Field == field); diff --git a/csharp/test/Apache.Arrow.Tests/TableTests.cs b/csharp/test/Apache.Arrow.Tests/TableTests.cs index b4c4b1faed190..8b07a38c1b8c0 100644 --- a/csharp/test/Apache.Arrow.Tests/TableTests.cs +++ b/csharp/test/Apache.Arrow.Tests/TableTests.cs @@ -30,7 +30,7 @@ public static Table MakeTableWithOneColumnOfTwoIntArrays(int lengthOfEachArray) Field field = new Field.Builder().Name("f0").DataType(Int32Type.Default).Build(); Schema s0 = new Schema.Builder().Field(field).Build(); - Column column = new Column(field, new List { intArray, intArrayCopy }); + Column column = new Column(field, new List { intArray, intArrayCopy }); Table table = new Table(s0, new List { column }); return table; } @@ -60,7 +60,7 @@ public void TestTableFromRecordBatches() Table table1 = 
Table.TableFromRecordBatches(recordBatch1.Schema, recordBatches); Assert.Equal(20, table1.RowCount); - Assert.Equal(24, table1.ColumnCount); + Assert.Equal(26, table1.ColumnCount); FixedSizeBinaryType type = new FixedSizeBinaryType(17); Field newField1 = new Field(type.Name, type, false); @@ -86,13 +86,13 @@ public void TestTableAddRemoveAndSetColumn() Array nonEqualLengthIntArray = ColumnTests.MakeIntArray(10); Field field1 = new Field.Builder().Name("f1").DataType(Int32Type.Default).Build(); - Column nonEqualLengthColumn = new Column(field1, new[] { nonEqualLengthIntArray}); + Column nonEqualLengthColumn = new Column(field1, new IArrowArray[] { nonEqualLengthIntArray }); Assert.Throws(() => table.InsertColumn(-1, nonEqualLengthColumn)); Assert.Throws(() => table.InsertColumn(1, nonEqualLengthColumn)); Array equalLengthIntArray = ColumnTests.MakeIntArray(20); Field field2 = new Field.Builder().Name("f2").DataType(Int32Type.Default).Build(); - Column equalLengthColumn = new Column(field2, new[] { equalLengthIntArray}); + Column equalLengthColumn = new Column(field2, new IArrowArray[] { equalLengthIntArray }); Column existingColumn = table.Column(0); Table newTable = table.InsertColumn(0, equalLengthColumn); @@ -118,7 +118,7 @@ public void TestBuildFromRecordBatch() RecordBatch batch = TestData.CreateSampleRecordBatch(schema, 10); Table table = Table.TableFromRecordBatches(schema, new[] { batch }); - Assert.NotNull(table.Column(0).Data.Array(0) as Int64Array); + Assert.NotNull(table.Column(0).Data.ArrowArray(0) as Int64Array); } } diff --git a/csharp/test/Apache.Arrow.Tests/TestData.cs b/csharp/test/Apache.Arrow.Tests/TestData.cs index 41507311f6a04..9e2061e3428a9 100644 --- a/csharp/test/Apache.Arrow.Tests/TestData.cs +++ b/csharp/test/Apache.Arrow.Tests/TestData.cs @@ -60,6 +60,8 @@ public static RecordBatch CreateSampleRecordBatch(int length, int columnSetCount builder.Field(CreateField(new DictionaryType(Int32Type.Default, StringType.Default, false), i)); 
builder.Field(CreateField(new FixedSizeBinaryType(16), i)); builder.Field(CreateField(new FixedSizeListType(Int32Type.Default, 3), i)); + builder.Field(CreateField(new UnionType(new[] { CreateField(StringType.Default, i), CreateField(Int32Type.Default, i) }, new[] { 0, 1 }, UnionMode.Sparse), i)); + builder.Field(CreateField(new UnionType(new[] { CreateField(StringType.Default, i), CreateField(Int32Type.Default, i) }, new[] { 0, 1 }, UnionMode.Dense), -i)); } //builder.Field(CreateField(HalfFloatType.Default)); @@ -125,6 +127,7 @@ private class ArrayCreator : IArrowTypeVisitor, IArrowTypeVisitor, IArrowTypeVisitor, + IArrowTypeVisitor, IArrowTypeVisitor, IArrowTypeVisitor, IArrowTypeVisitor, @@ -315,6 +318,67 @@ public void Visit(StructType type) Array = new StructArray(type, Length, childArrays, nullBitmap.Build()); } + public void Visit(UnionType type) + { + int[] lengths = new int[type.Fields.Count]; + if (type.Mode == UnionMode.Sparse) + { + for (int i = 0; i < lengths.Length; i++) + { + lengths[i] = Length; + } + } + else + { + int totalLength = Length; + int oneLength = Length / lengths.Length; + for (int i = 1; i < lengths.Length; i++) + { + lengths[i] = oneLength; + totalLength -= oneLength; + } + lengths[0] = totalLength; + } + + ArrayData[] childArrays = new ArrayData[type.Fields.Count]; + for (int i = 0; i < childArrays.Length; i++) + { + childArrays[i] = CreateArray(type.Fields[i], lengths[i]).Data; + } + + ArrowBuffer.Builder typeIdBuilder = new ArrowBuffer.Builder(Length); + byte index = 0; + for (int i = 0; i < Length; i++) + { + typeIdBuilder.Append(index); + index++; + if (index == lengths.Length) + { + index = 0; + } + } + + ArrowBuffer[] buffers; + if (type.Mode == UnionMode.Sparse) + { + buffers = new ArrowBuffer[1]; + } + else + { + ArrowBuffer.Builder offsetBuilder = new ArrowBuffer.Builder(Length); + for (int i = 0; i < Length; i++) + { + offsetBuilder.Append(i / lengths.Length); + } + + buffers = new ArrowBuffer[2]; + buffers[1] = 
offsetBuilder.Build(); + } + buffers[0] = typeIdBuilder.Build(); + + Array = UnionArray.Create(new ArrayData(type, Length, 0, 0, buffers, childArrays)); + } + public void Visit(DictionaryType type) { Int32Array.Builder indicesBuilder = new Int32Array.Builder().Reserve(Length); diff --git a/dev/archery/archery/integration/datagen.py b/dev/archery/archery/integration/datagen.py index 5ac32da56a8de..299881c4b613a 100644 --- a/dev/archery/archery/integration/datagen.py +++ b/dev/archery/archery/integration/datagen.py @@ -1833,8 +1833,7 @@ def _temp_path(): .skip_tester('C#') .skip_tester('JS'), - generate_unions_case() - .skip_tester('C#'), + generate_unions_case(), generate_custom_metadata_case() .skip_tester('C#'), diff --git a/docs/source/status.rst b/docs/source/status.rst index 36c29fcdc4da6..6314fd4c8d31f 100644 --- a/docs/source/status.rst +++ b/docs/source/status.rst @@ -83,9 +83,9 @@ Data Types +-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ | Map | ✓ | ✓ | ✓ | ✓ | | ✓ | ✓ | | +-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ -| Dense Union | ✓ | ✓ | ✓ | ✓ | | ✓ | ✓ | | +| Dense Union | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | +-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ -| Sparse Union | ✓ | ✓ | ✓ | ✓ | | ✓ | ✓ | | +| Sparse Union | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | | +-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +-------------------+-------+-------+-------+------------+-------+-------+-------+-------+