Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-38045: [C#] Add flight dictionary support #38046

Closed
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions csharp/examples/Examples.sln
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "FluentBuilderExample", "Flu
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Apache.Arrow", "..\src\Apache.Arrow\Apache.Arrow.csproj", "{1FE1DE95-FF6E-4895-82E7-909713C53524}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "FlightAspServerExample", "FlightAspServerExample\FlightAspServerExample.csproj", "{51701AC8-5C3C-47EA-B481-56F46B8C5673}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "FlightClientExample", "FlightClientExample\FlightClientExample.csproj", "{9F54DCD2-68C2-47A9-ABE2-816068176328}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand All @@ -21,6 +25,14 @@ Global
{1FE1DE95-FF6E-4895-82E7-909713C53524}.Debug|Any CPU.Build.0 = Debug|Any CPU
{1FE1DE95-FF6E-4895-82E7-909713C53524}.Release|Any CPU.ActiveCfg = Release|Any CPU
{1FE1DE95-FF6E-4895-82E7-909713C53524}.Release|Any CPU.Build.0 = Release|Any CPU
{51701AC8-5C3C-47EA-B481-56F46B8C5673}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{51701AC8-5C3C-47EA-B481-56F46B8C5673}.Debug|Any CPU.Build.0 = Debug|Any CPU
{51701AC8-5C3C-47EA-B481-56F46B8C5673}.Release|Any CPU.ActiveCfg = Release|Any CPU
{51701AC8-5C3C-47EA-B481-56F46B8C5673}.Release|Any CPU.Build.0 = Release|Any CPU
{9F54DCD2-68C2-47A9-ABE2-816068176328}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{9F54DCD2-68C2-47A9-ABE2-816068176328}.Debug|Any CPU.Build.0 = Debug|Any CPU
{9F54DCD2-68C2-47A9-ABE2-816068176328}.Release|Any CPU.ActiveCfg = Release|Any CPU
{9F54DCD2-68C2-47A9-ABE2-816068176328}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ public abstract class FlightRecordBatchStreamReader : IAsyncStreamReader<RecordB
//Temporary until .NET 5.0 upgrade
private static ValueTask CompletedValueTask = new ValueTask();

private readonly RecordBatcReaderImplementation _arrowReaderImplementation;
private readonly RecordBatchReaderImplementation _arrowReaderImplementation;

private protected FlightRecordBatchStreamReader(IAsyncStreamReader<Protocol.FlightData> flightDataStream)
{
_arrowReaderImplementation = new RecordBatcReaderImplementation(flightDataStream);
_arrowReaderImplementation = new RecordBatchReaderImplementation(flightDataStream);
}

public ValueTask<Schema> Schema => _arrowReaderImplementation.ReadSchema();
Expand Down
61 changes: 47 additions & 14 deletions csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,14 @@ internal class FlightDataStream : ArrowStreamWriter
private readonly FlightDescriptor _flightDescriptor;
private readonly IAsyncStreamWriter<Protocol.FlightData> _clientStreamWriter;
private Protocol.FlightData _currentFlightData;
private ByteString _currentAppMetadata;

public FlightDataStream(IAsyncStreamWriter<Protocol.FlightData> clientStreamWriter, FlightDescriptor flightDescriptor, Schema schema)
: base(new MemoryStream(), schema)
{
_clientStreamWriter = clientStreamWriter;
_flightDescriptor = flightDescriptor;
AlwaysWriteDictionaries = true;
}

private async Task SendSchema()
Expand All @@ -66,29 +68,42 @@ private void ResetStream()
this.BaseStream.SetLength(0);
}

public async Task Write(RecordBatch recordBatch, ByteString applicationMetadata)
private void ResetFlightData()
{
if (!HasWrittenSchema)
_currentFlightData = new Protocol.FlightData();
}

private void AddMetadata()
{
if (_currentAppMetadata != null)
{
await SendSchema().ConfigureAwait(false);
_currentFlightData.AppMetadata = _currentAppMetadata;
}
ResetStream();
}

_currentFlightData = new Protocol.FlightData();
private async Task SetFlightDataBodyFromBaseStreamAsync()
{
BaseStream.Position = 0;
var body = await ByteString.FromStreamAsync(BaseStream).ConfigureAwait(false);
_currentFlightData.DataBody = body;
}

private async Task WriteFlightDataAsync()
{
await _clientStreamWriter.WriteAsync(_currentFlightData).ConfigureAwait(false);
}

if(applicationMetadata != null)
public async Task Write(RecordBatch recordBatch, ByteString applicationMetadata)
{
_currentAppMetadata = applicationMetadata;
if (!HasWrittenSchema)
{
_currentFlightData.AppMetadata = applicationMetadata;
await SendSchema().ConfigureAwait(false);
}
ResetStream();
ResetFlightData();

await WriteRecordBatchInternalAsync(recordBatch).ConfigureAwait(false);

//Reset stream position
this.BaseStream.Position = 0;
var bodyData = await ByteString.FromStreamAsync(this.BaseStream).ConfigureAwait(false);

_currentFlightData.DataBody = bodyData;
await _clientStreamWriter.WriteAsync(_currentFlightData).ConfigureAwait(false);
}

private protected override ValueTask<long> WriteMessageAsync<T>(MessageHeader headerType, Offset<T> headerOffset, int bodyLength, CancellationToken cancellationToken)
Expand All @@ -105,5 +120,23 @@ private protected override ValueTask<long> WriteMessageAsync<T>(MessageHeader he

return new ValueTask<long>(0);
}

private protected override async Task PostRecordBatchAsync()
{
// Consume the MemoryStream and write to the flight stream
await SetFlightDataBodyFromBaseStreamAsync().ConfigureAwait(false);
AddMetadata();
await WriteFlightDataAsync().ConfigureAwait(false);
}

private protected override async Task PostDictionaryAsync()
{
// Consume the MemoryStream and write to the flight stream
await SetFlightDataBodyFromBaseStreamAsync().ConfigureAwait(false);
await WriteFlightDataAsync().ConfigureAwait(false);
// Reset the stream for the next dictionary or record batch
ResetStream();
ResetFlightData();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,8 @@ public static Schema DecodeSchema(ReadOnlyMemory<byte> buffer)
return schema;
}

internal static Schema DecodeSchema(ByteBuffer schemaBuffer)
internal static Schema DecodeSchema(ByteBuffer schemaBuffer, ref DictionaryMemo dictionaryMemo)
{
//DictionaryBatch not supported for now
DictionaryMemo dictionaryMemo = null;
var schema = MessageSerializer.GetSchema(ArrowReaderImplementation.ReadMessage<Flatbuf.Schema>(schemaBuffer), ref dictionaryMemo);
return schema;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@

namespace Apache.Arrow.Flight.Internal
{
internal class RecordBatcReaderImplementation : ArrowReaderImplementation
internal class RecordBatchReaderImplementation : ArrowReaderImplementation
{
private readonly IAsyncStreamReader<Protocol.FlightData> _flightDataStream;
private FlightDescriptor _flightDescriptor;
private readonly List<ByteString> _applicationMetadatas;

public RecordBatcReaderImplementation(IAsyncStreamReader<Protocol.FlightData> streamReader)
public RecordBatchReaderImplementation(IAsyncStreamReader<Protocol.FlightData> streamReader)
{
_flightDataStream = streamReader;
_applicationMetadatas = new List<ByteString>();
Expand Down Expand Up @@ -87,7 +87,7 @@ public async ValueTask<Schema> ReadSchema()
switch (message.HeaderType)
{
case MessageHeader.Schema:
Schema = FlightMessageSerializer.DecodeSchema(message.ByteBuffer);
Schema = FlightMessageSerializer.DecodeSchema(message.ByteBuffer, ref _dictionaryMemo);
break;
default:
throw new Exception($"Expected schema as the first message, but got: {message.HeaderType.ToString()}");
Expand All @@ -103,8 +103,10 @@ public override async ValueTask<RecordBatch> ReadNextRecordBatchAsync(Cancellati
{
await ReadSchema().ConfigureAwait(false);
}
var moveNextResult = await _flightDataStream.MoveNext().ConfigureAwait(false);
if (moveNextResult)

// Keep reading dictionary batches until we get a record batch
var keepGoing = await _flightDataStream.MoveNext().ConfigureAwait(false);
while (keepGoing)
{
//AppMetadata will never be null, but length 0 if empty
//Those are skipped
Expand All @@ -121,8 +123,17 @@ public override async ValueTask<RecordBatch> ReadNextRecordBatchAsync(Cancellati
case MessageHeader.RecordBatch:
var body = _flightDataStream.Current.DataBody.Memory;
return CreateArrowObjectFromMessage(message, CreateByteBuffer(body.Slice(0, (int)message.BodyLength)), null);
case MessageHeader.DictionaryBatch:
var dictionaryBody = _flightDataStream.Current.DataBody.Memory;
CreateArrowObjectFromMessage(message, CreateByteBuffer(dictionaryBody.Slice(0, (int)message.BodyLength)), null);
keepGoing = await _flightDataStream.MoveNext().ConfigureAwait(false);
if (!keepGoing)
{
throw new InvalidOperationException("Flight Data Stream ended after reading dictionaries");
}
break;
default:
throw new NotImplementedException();
throw new NotImplementedException($"Message type {message.HeaderType} is not implemented.");
}
}
return null;
Expand Down
33 changes: 30 additions & 3 deletions csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ public void Visit(IArrowArray array)

private bool HasWrittenDictionaryBatch { get; set; }

protected bool AlwaysWriteDictionaries { get; set; }

private bool HasWrittenStart { get; set; }

private bool HasWrittenEnd { get; set; }
Expand Down Expand Up @@ -314,7 +316,7 @@ private protected void WriteRecordBatchInternal(RecordBatch recordBatch)
HasWrittenSchema = true;
}

if (!HasWrittenDictionaryBatch)
if (!HasWrittenDictionaryBatch || AlwaysWriteDictionaries)
{
DictionaryCollector.Collect(recordBatch, ref _dictionaryMemo);
WriteDictionaries(recordBatch);
Expand All @@ -340,6 +342,8 @@ private protected void WriteRecordBatchInternal(RecordBatch recordBatch)
long bufferLength = WriteBufferData(recordBatchBuilder.Buffers);

FinishedWritingRecordBatch(bufferLength, metadataLength);

PostRecordBatch();
}

private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBatch,
Expand All @@ -353,7 +357,7 @@ private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBat
HasWrittenSchema = true;
}

if (!HasWrittenDictionaryBatch)
if (!HasWrittenDictionaryBatch || AlwaysWriteDictionaries)
{
DictionaryCollector.Collect(recordBatch, ref _dictionaryMemo);
await WriteDictionariesAsync(recordBatch, cancellationToken).ConfigureAwait(false);
Expand All @@ -380,6 +384,8 @@ private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBat
long bufferLength = await WriteBufferDataAsync(recordBatchBuilder.Buffers, cancellationToken).ConfigureAwait(false);

FinishedWritingRecordBatch(bufferLength, metadataLength);

await PostRecordBatchAsync().ConfigureAwait(false);
}

private long WriteBufferData(IReadOnlyList<ArrowRecordBatchFlatBufferBuilder.Buffer> buffers)
Expand Down Expand Up @@ -490,7 +496,6 @@ private Tuple<ArrowRecordBatchFlatBufferBuilder, VectorOffset> PreparingWritingR
return Tuple.Create(recordBatchBuilder, fieldNodesVectorOffset);
}


private protected void WriteDictionaries(RecordBatch recordBatch)
{
foreach (Field field in recordBatch.Schema.FieldsList)
Expand Down Expand Up @@ -520,6 +525,8 @@ private protected void WriteDictionary(Field field)
dictionaryBatchOffset, recordBatchBuilder.TotalLength);

WriteBufferData(recordBatchBuilder.Buffers);

PostDictionary();
}

private protected async Task WriteDictionariesAsync(RecordBatch recordBatch, CancellationToken cancellationToken)
Expand Down Expand Up @@ -551,6 +558,8 @@ await WriteMessageAsync(Flatbuf.MessageHeader.DictionaryBatch,
dictionaryBatchOffset, recordBatchBuilder.TotalLength, cancellationToken).ConfigureAwait(false);

await WriteBufferDataAsync(recordBatchBuilder.Buffers, cancellationToken).ConfigureAwait(false);

await PostDictionaryAsync().ConfigureAwait(false);
}

private Tuple<ArrowRecordBatchFlatBufferBuilder, Offset<Flatbuf.DictionaryBatch>> CreateDictionaryBatchOffset(Field field)
Expand Down Expand Up @@ -614,6 +623,24 @@ private protected virtual void FinishedWritingRecordBatch(long bodyLength, long
{
}

private protected virtual void PostRecordBatch()
{
}

private protected virtual void PostDictionary()
{
}

private protected virtual async Task PostRecordBatchAsync()
{
await Task.CompletedTask;
}

private protected virtual async Task PostDictionaryAsync()
{
await Task.CompletedTask;
}

public virtual void WriteRecordBatch(RecordBatch recordBatch)
{
WriteRecordBatchInternal(recordBatch);
Expand Down
12 changes: 6 additions & 6 deletions csharp/src/Apache.Arrow/Ipc/DictionaryMemo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ class DictionaryMemo
{
private readonly Dictionary<long, IArrowArray> _idToDictionary;
private readonly Dictionary<long, IArrowType> _idToValueType;
private readonly Dictionary<Field, long> _fieldToId;
private readonly Dictionary<string, long> _fieldToId;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this change is technically correct because it ignores the fact that e.g. two nested structs could have dictionary-encoded fields with the same name and these would now collide. Obviously the existing implementation isn't great either in that it implicitly relies on reference equality (which I assume is what causes problems for Flight). I don't have any immediate suggestions though.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this. When I have time to get back to this I will write a test for this and look into it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can check out my attempt at an implementation at https://github.com/CurtHagenlocher/arrow/tree/dev/curth/FlightDictionaries. It changes the write side of flight server so that the same schema is used to collect the dictionaries for each batch. This continues to let reference equality work for calculating dictionary IDs. I also refactored the "Post" changes to be (what I think is) a little cleaner.

Disclaimer: I am not very familiar with the Flight code or protocol, but your changed tests are passing.


public DictionaryMemo()
{
_idToDictionary = new Dictionary<long, IArrowArray>();
_idToValueType = new Dictionary<long, IArrowType>();
_fieldToId = new Dictionary<Field, long>();
_fieldToId = new Dictionary<string, long>();
}

public IArrowType GetDictionaryType(long id)
Expand All @@ -53,7 +53,7 @@ public IArrowArray GetDictionary(long id)

public void AddField(long id, Field field)
{
if (_fieldToId.ContainsKey(field))
if (_fieldToId.ContainsKey(field.Name))
{
throw new ArgumentException($"Field {field.Name} is already in Memo");
}
Expand All @@ -73,13 +73,13 @@ public void AddField(long id, Field field)
}
}

_fieldToId.Add(field, id);
_fieldToId.Add(field.Name, id);
_idToValueType.Add(id, valueType);
}

public long GetId(Field field)
{
if (!_fieldToId.TryGetValue(field, out long id))
if (!_fieldToId.TryGetValue(field.Name, out long id))
{
throw new ArgumentException($"Field with name {field.Name} not found");
}
Expand All @@ -88,7 +88,7 @@ public long GetId(Field field)

public long GetOrAssignId(Field field)
{
if (!_fieldToId.TryGetValue(field, out long id))
if (!_fieldToId.TryGetValue(field.Name, out long id))
{
id = _fieldToId.Count + 1;
AddField(id, field);
Expand Down
9 changes: 7 additions & 2 deletions csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
using Apache.Arrow.Flight.Client;
using Apache.Arrow.Flight.TestWeb;
using Apache.Arrow.Tests;
using Apache.Arrow.Types;
using Google.Protobuf;
using Grpc.Core.Utils;
using Xunit;
Expand Down Expand Up @@ -52,6 +53,10 @@ private RecordBatch CreateTestBatch(int startValue, int length)
builder.Append(startValue + i);
}
batchBuilder.Append("test", true, builder.Build());
var keys = new UInt16Array.Builder().AppendRange(Enumerable.Range(startValue, length).Select(i => (ushort)i)).Build();
var dictionary = new StringArray.Builder().AppendRange(Enumerable.Range(startValue, length).Select(i => i.ToString())).Build();
var dictArray = new DictionaryArray(new DictionaryType(UInt16Type.Default, StringType.Default, false), keys, dictionary);
jduo marked this conversation as resolved.
Show resolved Hide resolved
batchBuilder.Append("dict", true, dictArray);
return batchBuilder.Build();
}

Expand Down Expand Up @@ -187,8 +192,8 @@ public async Task TestGetFlightMetadata()

var getStream = _flightClient.GetStream(endpoint.Ticket);

List<ByteString> actualMetadata = new List<ByteString>();
while(await getStream.ResponseStream.MoveNext(default))
List<ByteString> actualMetadata = new List<ByteString>();
while (await getStream.ResponseStream.MoveNext(default))
{
actualMetadata.AddRange(getStream.ResponseStream.ApplicationMetadata);
}
Expand Down