Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-38045: [C#] Support Dictionaries in Arrow Flight #44426

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 49 additions & 14 deletions csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,10 @@
// limitations under the License.

using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Flatbuf;
using Apache.Arrow.Flight.Protocol;
using Apache.Arrow.Ipc;
using Google.FlatBuffers;
using Google.Protobuf;
Expand All @@ -36,6 +33,7 @@ internal class FlightDataStream : ArrowStreamWriter
private readonly FlightDescriptor _flightDescriptor;
private readonly IAsyncStreamWriter<Protocol.FlightData> _clientStreamWriter;
private Protocol.FlightData _currentFlightData;
private ByteString _currentAppMetadata;

public FlightDataStream(IAsyncStreamWriter<Protocol.FlightData> clientStreamWriter, FlightDescriptor flightDescriptor, Schema schema)
: base(new MemoryStream(), schema)
Expand Down Expand Up @@ -66,29 +64,66 @@ private void ResetStream()
this.BaseStream.SetLength(0);
}

/// <summary>
/// Replaces the in-progress Flight message with a new, empty
/// <see cref="Protocol.FlightData"/> instance.
/// </summary>
private void ResetFlightData() => _currentFlightData = new Protocol.FlightData();

/// <summary>
/// Copies the application metadata held in <c>_currentAppMetadata</c> onto the
/// in-progress Flight message. Does nothing when no metadata is set.
/// </summary>
private void AddMetadata()
{
    if (_currentAppMetadata == null)
    {
        // Nothing to attach; leave the message's AppMetadata untouched.
        return;
    }

    _currentFlightData.AppMetadata = _currentAppMetadata;
}

/// <summary>
/// Rewinds the base stream, reads it into a <see cref="ByteString"/> and stores
/// the result as the body of the in-progress Flight message.
/// </summary>
/// <param name="cancellationToken">Token passed through to the stream read.</param>
private async Task SetFlightDataBodyFromBaseStreamAsync(CancellationToken cancellationToken)
{
    // Rewind so the read starts at the first byte written to the stream.
    BaseStream.Seek(0, SeekOrigin.Begin);

    ByteString payload = await ByteString.FromStreamAsync(BaseStream, cancellationToken).ConfigureAwait(false);
    _currentFlightData.DataBody = payload;
}

/// <summary>
/// Sends the in-progress Flight message to the client stream writer.
/// </summary>
private async Task WriteFlightDataAsync() =>
    await _clientStreamWriter.WriteAsync(_currentFlightData).ConfigureAwait(false);

/// <summary>
/// Sends one record batch, with optional application metadata, over the Flight stream.
/// The schema is sent first if it has not been written yet.
/// </summary>
/// <param name="recordBatch">The record batch to send.</param>
/// <param name="applicationMetadata">
/// Metadata to attach to the record batch message; may be <c>null</c>, in which case
/// no metadata is attached.
/// </param>
public async Task Write(RecordBatch recordBatch, ByteString applicationMetadata)
{
    // Remember the metadata for the duration of this call. WriteRecordBatchAsync
    // attaches it (via AddMetadata) to the record batch message, not to the
    // dictionary messages that precede it.
    _currentAppMetadata = applicationMetadata;
    try
    {
        if (!HasWrittenSchema)
        {
            await SendSchema().ConfigureAwait(false);
        }

        // Start the batch from an empty buffer and a fresh Flight message.
        ResetStream();
        ResetFlightData();

        await WriteRecordBatchAsync(recordBatch).ConfigureAwait(false);
    }
    finally
    {
        // Do not let this call's metadata leak into a later batch that is
        // written through WriteRecordBatchAsync directly.
        _currentAppMetadata = null;
    }
}
/// <summary>
/// Serializes a record batch (writing its dictionaries first when they have not
/// been written) and sends the result as a Flight message.
/// </summary>
/// <param name="recordBatch">The record batch to send.</param>
/// <param name="cancellationToken">Token used to cancel serialization and the stream read.</param>
public override async Task WriteRecordBatchAsync(RecordBatch recordBatch, CancellationToken cancellationToken = default)
{
    // Serialize exactly once; a second call would write the batch into the buffer twice.
    await WriteRecordBatchInternalAsync(recordBatch, cancellationToken).ConfigureAwait(false);

    // Consume the MemoryStream and write to the flight stream
    await SetFlightDataBodyFromBaseStreamAsync(cancellationToken).ConfigureAwait(false);
    AddMetadata();
    await WriteFlightDataAsync().ConfigureAwait(false);

    // Force the dictionary to be sent again with the next batch.
    HasWrittenDictionaryBatch = false;
}
/// <summary>
/// Writes the dictionaries held in <paramref name="dictionaryMemo"/> and sends them
/// over the Flight stream as their own message, ahead of the record batch.
/// </summary>
/// <param name="dictionaryMemo">Dictionaries collected for the current batch; may be <c>null</c>.</param>
/// <param name="cancellationToken">Token used to cancel serialization and the stream read.</param>
private protected override async Task WriteDictionariesAsync(DictionaryMemo dictionaryMemo, CancellationToken cancellationToken)
{
    await base.WriteDictionariesAsync(dictionaryMemo, cancellationToken).ConfigureAwait(false);

    // With no dictionaries the base implementation's loop writes nothing, so there
    // is no dictionary message to send. Sending anyway would put a FlightData
    // message with no dictionary header on the wire before every batch.
    int dictionaryCount = dictionaryMemo?.DictionaryCount ?? 0;
    if (dictionaryCount > 0)
    {
        // NOTE(review): the base call writes every dictionary before this single
        // send; confirm the WriteMessageAsync override copes with more than one
        // dictionary per batch.

        // Consume the MemoryStream and write to the flight stream
        await SetFlightDataBodyFromBaseStreamAsync(cancellationToken).ConfigureAwait(false);
        await WriteFlightDataAsync().ConfigureAwait(false);
    }

    // Reset the stream for the next dictionary or record batch
    ResetStream();
    ResetFlightData();
}

private protected override ValueTask<long> WriteMessageAsync<T>(MessageHeader headerType, Offset<T> headerOffset, int bodyLength, CancellationToken cancellationToken)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@

using System;
using System.Buffers.Binary;
using System.Collections.Generic;
using System.IO;
using System.Text;
using Apache.Arrow.Ipc;
using Google.FlatBuffers;

Expand Down Expand Up @@ -51,10 +49,8 @@ public static Schema DecodeSchema(ReadOnlyMemory<byte> buffer)
return schema;
}

/// <summary>
/// Decodes a schema message without handing a dictionary memo back to the caller.
/// </summary>
/// <param name="schemaBuffer">Buffer containing the flatbuffer-encoded schema message.</param>
/// <returns>The decoded schema.</returns>
internal static Schema DecodeSchema(ByteBuffer schemaBuffer)
{
    // Kept for callers that do not track dictionaries; the memo is discarded.
    DictionaryMemo dictionaryMemo = null;
    return DecodeSchema(schemaBuffer, ref dictionaryMemo);
}

/// <summary>
/// Decodes a schema message, passing <paramref name="dictionaryMemo"/> by reference
/// to <c>MessageSerializer.GetSchema</c>.
/// </summary>
/// <param name="schemaBuffer">Buffer containing the flatbuffer-encoded schema message.</param>
/// <param name="dictionaryMemo">
/// Dictionary memo shared with the caller; presumably filled in by <c>GetSchema</c> for
/// dictionary-encoded fields (confirm against <c>MessageSerializer</c>).
/// </param>
/// <returns>The decoded schema.</returns>
internal static Schema DecodeSchema(ByteBuffer schemaBuffer, ref DictionaryMemo dictionaryMemo)
{
    var schemaMessage = ArrowReaderImplementation.ReadMessage<Flatbuf.Schema>(schemaBuffer);
    return MessageSerializer.GetSchema(schemaMessage, ref dictionaryMemo);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

using System;
using System.Collections.Generic;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Flatbuf;
Expand Down Expand Up @@ -84,7 +83,7 @@ public override async ValueTask ReadSchemaAsync(CancellationToken cancellationTo

// AppMetadata will never be null, but length 0 if empty
// Those are skipped
if(_flightDataStream.Current.AppMetadata.Length > 0)
if (_flightDataStream.Current.AppMetadata.Length > 0)
{
_applicationMetadatas.Add(_flightDataStream.Current.AppMetadata);
}
Expand All @@ -101,7 +100,7 @@ public override async ValueTask ReadSchemaAsync(CancellationToken cancellationTo
switch (message.HeaderType)
{
case MessageHeader.Schema:
_schema = FlightMessageSerializer.DecodeSchema(message.ByteBuffer);
_schema = FlightMessageSerializer.DecodeSchema(message.ByteBuffer, ref _dictionaryMemo);
break;
default:
throw new Exception($"Expected schema as the first message, but got: {message.HeaderType.ToString()}");
Expand All @@ -117,8 +116,10 @@ public override async ValueTask<RecordBatch> ReadNextRecordBatchAsync(Cancellati
{
await ReadSchemaAsync(cancellationToken).ConfigureAwait(false);
}
var moveNextResult = await _flightDataStream.MoveNext().ConfigureAwait(false);
if (moveNextResult)

// Keep reading dictionary batches until we get a record batch
var keepGoing = await _flightDataStream.MoveNext().ConfigureAwait(false);
while (keepGoing)
{
//AppMetadata will never be null, but length 0 if empty
//Those are skipped
Expand All @@ -135,8 +136,17 @@ public override async ValueTask<RecordBatch> ReadNextRecordBatchAsync(Cancellati
case MessageHeader.RecordBatch:
var body = _flightDataStream.Current.DataBody.Memory;
return CreateArrowObjectFromMessage(message, CreateByteBuffer(body.Slice(0, (int)message.BodyLength)), null);
case MessageHeader.DictionaryBatch:
var dictionaryBody = _flightDataStream.Current.DataBody.Memory;
CreateArrowObjectFromMessage(message, CreateByteBuffer(dictionaryBody.Slice(0, (int)message.BodyLength)), null);
keepGoing = await _flightDataStream.MoveNext().ConfigureAwait(false);
if (!keepGoing)
{
throw new InvalidOperationException("Flight Data Stream ended after reading dictionaries");
}
break;
default:
throw new NotImplementedException();
throw new NotImplementedException($"Message type {message.HeaderType} is not implemented.");
}
}
return null;
Expand Down
13 changes: 6 additions & 7 deletions csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ public void Visit(IArrowArray array)

protected bool HasWrittenSchema { get; set; }

/// <summary>
/// Whether the dictionaries have already been written. While <c>false</c>, the next
/// record batch write collects and writes dictionaries first and then sets this to
/// <c>true</c>. Protected so that derived writers can reset it.
/// </summary>
protected bool HasWrittenDictionaryBatch { get; set; }

private bool HasWrittenStart { get; set; }

Expand Down Expand Up @@ -668,7 +668,7 @@ private protected void WriteRecordBatchInternal(RecordBatch recordBatch)

if (!HasWrittenDictionaryBatch)
{
DictionaryCollector.Collect(recordBatch, ref _dictionaryMemo);
DictionaryCollector.Collect(Schema, recordBatch, ref _dictionaryMemo);
WriteDictionaries(_dictionaryMemo);
HasWrittenDictionaryBatch = true;
}
Expand Down Expand Up @@ -707,7 +707,7 @@ private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBat

if (!HasWrittenDictionaryBatch)
{
DictionaryCollector.Collect(recordBatch, ref _dictionaryMemo);
DictionaryCollector.Collect(Schema, recordBatch, ref _dictionaryMemo);
await WriteDictionariesAsync(_dictionaryMemo, cancellationToken).ConfigureAwait(false);
HasWrittenDictionaryBatch = true;
}
Expand Down Expand Up @@ -862,7 +862,7 @@ private protected virtual void FinishedWritingDictionary(long bodyLength, long m
{
}

private protected void WriteDictionaries(DictionaryMemo dictionaryMemo)
private protected virtual void WriteDictionaries(DictionaryMemo dictionaryMemo)
{
int fieldCount = dictionaryMemo?.DictionaryCount ?? 0;
for (int i = 0; i < fieldCount; i++)
Expand All @@ -886,7 +886,7 @@ private protected void WriteDictionary(long id, IArrowType valueType, IArrowArra
FinishedWritingDictionary(bufferLength, metadataLength);
}

private protected async Task WriteDictionariesAsync(DictionaryMemo dictionaryMemo, CancellationToken cancellationToken)
private protected virtual async Task WriteDictionariesAsync(DictionaryMemo dictionaryMemo, CancellationToken cancellationToken)
{
int fieldCount = dictionaryMemo?.DictionaryCount ?? 0;
for (int i = 0; i < fieldCount; i++)
Expand Down Expand Up @@ -1319,9 +1319,8 @@ public virtual void Dispose()

internal static class DictionaryCollector
{
internal static void Collect(RecordBatch recordBatch, ref DictionaryMemo dictionaryMemo)
internal static void Collect(Schema schema, RecordBatch recordBatch, ref DictionaryMemo dictionaryMemo)
{
Schema schema = recordBatch.Schema;
for (int i = 0; i < schema.FieldsList.Count; i++)
{
Field field = schema.GetFieldByIndex(i);
Expand Down
5 changes: 5 additions & 0 deletions csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
using Apache.Arrow.Flight.Client;
using Apache.Arrow.Flight.TestWeb;
using Apache.Arrow.Tests;
using Apache.Arrow.Types;
using Google.Protobuf;
using Grpc.Core;
using Grpc.Core.Utils;
Expand Down Expand Up @@ -54,6 +55,10 @@ private RecordBatch CreateTestBatch(int startValue, int length)
builder.Append(startValue + i);
}
batchBuilder.Append("test", true, builder.Build());
var keys = new UInt16Array.Builder().AppendRange(Enumerable.Range(startValue, length).Select(i => (ushort)i)).Build();
var dictionary = new StringArray.Builder().AppendRange(Enumerable.Range(startValue, length).Select(i => i.ToString())).Build();
var dictArray = new DictionaryArray(new DictionaryType(UInt16Type.Default, StringType.Default, false), keys, dictionary);
batchBuilder.Append("dict", true, dictArray);
return batchBuilder.Build();
}

Expand Down
16 changes: 4 additions & 12 deletions dev/archery/archery/integration/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1955,22 +1955,16 @@ def _temp_path():

generate_dictionary_case()
# TODO(https://github.com/apache/arrow-nanoarrow/issues/622)
.skip_tester('nanoarrow')
# TODO(https://github.com/apache/arrow/issues/38045)
.skip_format(SKIP_FLIGHT, 'C#'),
.skip_tester('nanoarrow'),

generate_dictionary_unsigned_case()
.skip_tester('nanoarrow')
.skip_tester('Java') # TODO(ARROW-9377)
# TODO(https://github.com/apache/arrow/issues/38045)
.skip_format(SKIP_FLIGHT, 'C#'),
.skip_tester('Java'), # TODO(ARROW-9377)

generate_nested_dictionary_case()
# TODO(https://github.com/apache/arrow-nanoarrow/issues/622)
.skip_tester('nanoarrow')
.skip_tester('Java') # TODO(ARROW-7779)
# TODO(https://github.com/apache/arrow/issues/38045)
.skip_format(SKIP_FLIGHT, 'C#'),
.skip_tester('Java'), # TODO(ARROW-7779)

generate_run_end_encoded_case()
.skip_tester('C#')
Expand All @@ -1997,9 +1991,7 @@ def _temp_path():
.skip_tester('nanoarrow')
# TODO: ensure the extension is registered in the C++ entrypoint
.skip_format(SKIP_C_SCHEMA, 'C++')
.skip_format(SKIP_C_ARRAY, 'C++')
# TODO(https://github.com/apache/arrow/issues/38045)
.skip_format(SKIP_FLIGHT, 'C#'),
.skip_format(SKIP_C_ARRAY, 'C++'),
]

generated_paths = []
Expand Down
Loading