From 2139b789d4c44253660a1ac69086876900ceeedb Mon Sep 17 00:00:00 2001 From: Curt Hagenlocher Date: Mon, 11 Dec 2023 20:21:15 -0800 Subject: [PATCH 1/2] Add support for dictionaries to Flight implementation --- .../Internal/FlightDataStream.cs | 63 ++++++++++++++----- .../Internal/FlightMessageSerializer.cs | 6 +- .../RecordBatchReaderImplementation.cs | 22 +++++-- .../src/Apache.Arrow/Ipc/ArrowStreamWriter.cs | 13 ++-- .../Apache.Arrow.Flight.Tests/FlightTests.cs | 5 ++ 5 files changed, 77 insertions(+), 32 deletions(-) diff --git a/csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs b/csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs index 72c1551be2917..e755b4a26f621 100644 --- a/csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs +++ b/csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs @@ -14,13 +14,10 @@ // limitations under the License. using System; -using System.Collections.Generic; using System.IO; -using System.Text; using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Flatbuf; -using Apache.Arrow.Flight.Protocol; using Apache.Arrow.Ipc; using Google.FlatBuffers; using Google.Protobuf; @@ -36,6 +33,7 @@ internal class FlightDataStream : ArrowStreamWriter private readonly FlightDescriptor _flightDescriptor; private readonly IAsyncStreamWriter _clientStreamWriter; private Protocol.FlightData _currentFlightData; + private ByteString _currentAppMetadata; public FlightDataStream(IAsyncStreamWriter clientStreamWriter, FlightDescriptor flightDescriptor, Schema schema) : base(new MemoryStream(), schema) @@ -66,29 +64,66 @@ private void ResetStream() this.BaseStream.SetLength(0); } + private void ResetFlightData() + { + _currentFlightData = new Protocol.FlightData(); + } + + private void AddMetadata() + { + if (_currentAppMetadata != null) + { + _currentFlightData.AppMetadata = _currentAppMetadata; + } + } + + private async Task SetFlightDataBodyFromBaseStreamAsync(CancellationToken cancellationToken) + { + BaseStream.Position = 0; + var body = await ByteString.FromStreamAsync(BaseStream, cancellationToken).ConfigureAwait(false); + _currentFlightData.DataBody = body; + } + + private async Task WriteFlightDataAsync() + { + await _clientStreamWriter.WriteAsync(_currentFlightData).ConfigureAwait(false); + } + public async Task Write(RecordBatch recordBatch, ByteString applicationMetadata) { + _currentAppMetadata = applicationMetadata; if (!HasWrittenSchema) { await SendSchema().ConfigureAwait(false); } ResetStream(); + ResetFlightData(); - _currentFlightData = new Protocol.FlightData(); + await WriteRecordBatchAsync(recordBatch).ConfigureAwait(false); + } - if(applicationMetadata != null) - { - _currentFlightData.AppMetadata = applicationMetadata; - } + public override async Task WriteRecordBatchAsync(RecordBatch recordBatch, CancellationToken cancellationToken = default) + { + await WriteRecordBatchInternalAsync(recordBatch, cancellationToken); - await WriteRecordBatchInternalAsync(recordBatch).ConfigureAwait(false); + // Consume the MemoryStream and write to the flight stream + await SetFlightDataBodyFromBaseStreamAsync(cancellationToken).ConfigureAwait(false); + AddMetadata(); + await WriteFlightDataAsync().ConfigureAwait(false); - //Reset stream position - this.BaseStream.Position = 0; - var bodyData = await ByteString.FromStreamAsync(this.BaseStream).ConfigureAwait(false); + HasWrittenDictionaryBatch = false; // force the dictionary to be sent again with the next batch + } - _currentFlightData.DataBody = bodyData; - await _clientStreamWriter.WriteAsync(_currentFlightData).ConfigureAwait(false); + private protected override async Task WriteDictionariesAsync(DictionaryMemo dictionaryMemo, CancellationToken cancellationToken) + { + await base.WriteDictionariesAsync(dictionaryMemo, cancellationToken).ConfigureAwait(false); + + // Consume the MemoryStream and write to the flight stream + await SetFlightDataBodyFromBaseStreamAsync(cancellationToken).ConfigureAwait(false); + await WriteFlightDataAsync().ConfigureAwait(false); + // Reset the stream for the next dictionary or record batch + ResetStream(); + ResetFlightData(); } private protected override ValueTask WriteMessageAsync(MessageHeader headerType, Offset headerOffset, int bodyLength, CancellationToken cancellationToken) diff --git a/csharp/src/Apache.Arrow.Flight/Internal/FlightMessageSerializer.cs b/csharp/src/Apache.Arrow.Flight/Internal/FlightMessageSerializer.cs index 47ffe43d2457a..a1b3445df738c 100644 --- a/csharp/src/Apache.Arrow.Flight/Internal/FlightMessageSerializer.cs +++ b/csharp/src/Apache.Arrow.Flight/Internal/FlightMessageSerializer.cs @@ -15,9 +15,7 @@ using System; using System.Buffers.Binary; -using System.Collections.Generic; using System.IO; -using System.Text; using Apache.Arrow.Ipc; using Google.FlatBuffers; @@ -51,10 +49,8 @@ public static Schema DecodeSchema(ReadOnlyMemory buffer) return schema; } - internal static Schema DecodeSchema(ByteBuffer schemaBuffer) + internal static Schema DecodeSchema(ByteBuffer schemaBuffer, ref DictionaryMemo dictionaryMemo) { - //DictionaryBatch not supported for now - DictionaryMemo dictionaryMemo = null; var schema = MessageSerializer.GetSchema(ArrowReaderImplementation.ReadMessage(schemaBuffer), ref dictionaryMemo); return schema; } diff --git a/csharp/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs b/csharp/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs index 22d0bd84fef77..50756cc8f4ee0 100644 --- a/csharp/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs +++ b/csharp/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs @@ -15,7 +15,6 @@ using System; using System.Collections.Generic; -using System.Text; using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Flatbuf; @@ -84,7 +83,7 @@ public override async ValueTask ReadSchemaAsync(CancellationToken cancellationTo // AppMetadata will never be null, but length 0 if empty // Those are skipped - if(_flightDataStream.Current.AppMetadata.Length > 0) + if (_flightDataStream.Current.AppMetadata.Length > 0) { _applicationMetadatas.Add(_flightDataStream.Current.AppMetadata); } @@ -101,7 +100,7 @@ public override async ValueTask ReadSchemaAsync(CancellationToken cancellationTo switch (message.HeaderType) { case MessageHeader.Schema: - _schema = FlightMessageSerializer.DecodeSchema(message.ByteBuffer); + _schema = FlightMessageSerializer.DecodeSchema(message.ByteBuffer, ref _dictionaryMemo); break; default: throw new Exception($"Expected schema as the first message, but got: {message.HeaderType.ToString()}"); @@ -117,8 +116,10 @@ public override async ValueTask ReadNextRecordBatchAsync(Cancellati { await ReadSchemaAsync(cancellationToken).ConfigureAwait(false); } - var moveNextResult = await _flightDataStream.MoveNext().ConfigureAwait(false); - if (moveNextResult) + + // Keep reading dictionary batches until we get a record batch + var keepGoing = await _flightDataStream.MoveNext().ConfigureAwait(false); + while (keepGoing) { //AppMetadata will never be null, but length 0 if empty //Those are skipped @@ -135,8 +136,17 @@ public override async ValueTask ReadNextRecordBatchAsync(Cancellati case MessageHeader.RecordBatch: var body = _flightDataStream.Current.DataBody.Memory; return CreateArrowObjectFromMessage(message, CreateByteBuffer(body.Slice(0, (int)message.BodyLength)), null); + case MessageHeader.DictionaryBatch: + var dictionaryBody = _flightDataStream.Current.DataBody.Memory; + CreateArrowObjectFromMessage(message, CreateByteBuffer(dictionaryBody.Slice(0, (int)message.BodyLength)), null); + keepGoing = await _flightDataStream.MoveNext().ConfigureAwait(false); + if (!keepGoing) + { + throw new InvalidOperationException("Flight Data Stream ended after reading dictionaries"); + } + break; default: - throw new NotImplementedException(); + throw new NotImplementedException($"Message type {message.HeaderType} is not implemented."); } } return null; diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs index ff0b64c09eeb6..22f97ac950b5b 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs @@ -579,7 +579,7 @@ public void Visit(IArrowArray array) protected bool HasWrittenSchema { get; set; } - private bool HasWrittenDictionaryBatch { get; set; } + protected bool HasWrittenDictionaryBatch { get; set; } private bool HasWrittenStart { get; set; } @@ -668,7 +668,7 @@ private protected void WriteRecordBatchInternal(RecordBatch recordBatch) if (!HasWrittenDictionaryBatch) { - DictionaryCollector.Collect(recordBatch, ref _dictionaryMemo); + DictionaryCollector.Collect(Schema, recordBatch, ref _dictionaryMemo); WriteDictionaries(_dictionaryMemo); HasWrittenDictionaryBatch = true; } @@ -707,7 +707,7 @@ private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBat if (!HasWrittenDictionaryBatch) { - DictionaryCollector.Collect(recordBatch, ref _dictionaryMemo); + DictionaryCollector.Collect(Schema, recordBatch, ref _dictionaryMemo); await WriteDictionariesAsync(_dictionaryMemo, cancellationToken).ConfigureAwait(false); HasWrittenDictionaryBatch = true; } @@ -862,7 +862,7 @@ private protected virtual void FinishedWritingDictionary(long bodyLength, long m { } - private protected void WriteDictionaries(DictionaryMemo dictionaryMemo) + private protected virtual void WriteDictionaries(DictionaryMemo dictionaryMemo) { int fieldCount = dictionaryMemo?.DictionaryCount ?? 0; for (int i = 0; i < fieldCount; i++) @@ -886,7 +886,7 @@ private protected void WriteDictionary(long id, IArrowType valueType, IArrowArra FinishedWritingDictionary(bufferLength, metadataLength); } - private protected async Task WriteDictionariesAsync(DictionaryMemo dictionaryMemo, CancellationToken cancellationToken) + private protected virtual async Task WriteDictionariesAsync(DictionaryMemo dictionaryMemo, CancellationToken cancellationToken) { int fieldCount = dictionaryMemo?.DictionaryCount ?? 0; for (int i = 0; i < fieldCount; i++) @@ -1319,9 +1319,8 @@ public virtual void Dispose() internal static class DictionaryCollector { - internal static void Collect(RecordBatch recordBatch, ref DictionaryMemo dictionaryMemo) + internal static void Collect(Schema schema, RecordBatch recordBatch, ref DictionaryMemo dictionaryMemo) { - Schema schema = recordBatch.Schema; for (int i = 0; i < schema.FieldsList.Count; i++) { Field field = schema.GetFieldByIndex(i); diff --git a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs index 350762c992769..1299009584a5b 100644 --- a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs @@ -21,6 +21,7 @@ using Apache.Arrow.Flight.Client; using Apache.Arrow.Flight.TestWeb; using Apache.Arrow.Tests; +using Apache.Arrow.Types; using Google.Protobuf; using Grpc.Core; using Grpc.Core.Utils; @@ -54,6 +55,10 @@ private RecordBatch CreateTestBatch(int startValue, int length) builder.Append(startValue + i); } batchBuilder.Append("test", true, builder.Build()); + var keys = new UInt16Array.Builder().AppendRange(Enumerable.Range(startValue, length).Select(i => (ushort)i)).Build(); + var dictionary = new StringArray.Builder().AppendRange(Enumerable.Range(startValue, length).Select(i => i.ToString())).Build(); + var dictArray = new DictionaryArray(new DictionaryType(UInt16Type.Default, StringType.Default, false), keys, dictionary); + batchBuilder.Append("dict", true, dictArray); return batchBuilder.Build(); } From 316f4e4aa02c8dcb738c57d9732770c983362eb4 Mon Sep 17 00:00:00 2001 From: Curt Hagenlocher Date: Tue, 15 Oct 2024 17:59:40 -0700 Subject: [PATCH 2/2] enable Flight tests requiring dictionaries --- dev/archery/archery/integration/datagen.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/dev/archery/archery/integration/datagen.py b/dev/archery/archery/integration/datagen.py index bc862963405f2..8414930e97395 100644 --- a/dev/archery/archery/integration/datagen.py +++ b/dev/archery/archery/integration/datagen.py @@ -1955,22 +1955,16 @@ def _temp_path(): generate_dictionary_case() # TODO(https://github.com/apache/arrow-nanoarrow/issues/622) - .skip_tester('nanoarrow') - # TODO(https://github.com/apache/arrow/issues/38045) - .skip_format(SKIP_FLIGHT, 'C#'), + .skip_tester('nanoarrow'), generate_dictionary_unsigned_case() .skip_tester('nanoarrow') - .skip_tester('Java') # TODO(ARROW-9377) - # TODO(https://github.com/apache/arrow/issues/38045) - .skip_format(SKIP_FLIGHT, 'C#'), + .skip_tester('Java'), # TODO(ARROW-9377) generate_nested_dictionary_case() # TODO(https://github.com/apache/arrow-nanoarrow/issues/622) .skip_tester('nanoarrow') - .skip_tester('Java') # TODO(ARROW-7779) - # TODO(https://github.com/apache/arrow/issues/38045) - .skip_format(SKIP_FLIGHT, 'C#'), + .skip_tester('Java'), # TODO(ARROW-7779) generate_run_end_encoded_case() .skip_tester('C#') @@ -1997,9 +1991,7 @@ def _temp_path(): .skip_tester('nanoarrow') # TODO: ensure the extension is registered in the C++ entrypoint .skip_format(SKIP_C_SCHEMA, 'C++') - .skip_format(SKIP_C_ARRAY, 'C++') - # TODO(https://github.com/apache/arrow/issues/38045) - .skip_format(SKIP_FLIGHT, 'C#'), + .skip_format(SKIP_C_ARRAY, 'C++'), ] generated_paths = []