Skip to content

Commit

Permalink
Fix handling reference types that don't inherit directly from object (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
adamreeve authored May 7, 2024
1 parent cac1a76 commit 56e7ed2
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 5 deletions.
128 changes: 125 additions & 3 deletions csharp.test/TestLogicalTypeFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,44 @@ public static void TestRoundTripCustomDecimal()
Assert.AreEqual(readValues.Select(v => v.Value).ToArray(), values.Select(v => v.Value).ToArray());
}

[Test]
public static void TestRoundTripDerivedValueType()
{
// Test using a custom type that is a reference type but doesn't inherit directly from System.Object
var values = Enumerable.Range(0, 100)
.Select(i => i % 10 == 5 ? null : new VolumeInDollarsDerivedType(i))
.ToArray();
var columns = new Column[]
{
new Column<VolumeInDollarsDerivedType?>("Values")
};

using var buffer = new ResizableBuffer();
using (var outStream = new BufferOutputStream(buffer))
{
using var fileWriter = new ParquetFileWriter(outStream, columns, new WriteTypeFactory())
{
LogicalWriteConverterFactory = new WriteConverterFactory()
};
using var rowGroupWriter = fileWriter.AppendRowGroup();
using var columnWriter = rowGroupWriter.NextColumn().LogicalWriter<VolumeInDollarsDerivedType?>();
columnWriter.WriteBatch(values);
fileWriter.Close();
}

using var input = new BufferReader(buffer);
using var fileReader = new ParquetFileReader(input)
{
LogicalReadConverterFactory = new ReadConverterFactory()
};
using var groupReader = fileReader.RowGroup(0);
using var columnReader = groupReader.Column(0).LogicalReaderOverride<VolumeInDollarsDerivedType?>();

var readValues = columnReader.ReadAll(checked((int) groupReader.MetaData.NumRows));

Assert.AreEqual(readValues, values);
}

private static GroupNode GetNestedSchema()
{
using var noneType = LogicalType.None();
Expand Down Expand Up @@ -470,6 +508,30 @@ public override string ToString()
}
}

private class BaseValue
{
}

private class VolumeInDollarsDerivedType : BaseValue, IEquatable<VolumeInDollarsDerivedType>
{
public VolumeInDollarsDerivedType(float value)
{
Value = value;
}

public readonly float Value;

public bool Equals(VolumeInDollarsDerivedType? other)
{
return other != null && Value.Equals(other.Value);
}

public override string ToString()
{
return $"VolumeInDollars({Value})";
}
}

/// <summary>
/// A logical type factory that supports our user custom type (for the read tests only). Ignore overrides (used by unit tests that cannot provide a columnLogicalTypeOverride).
/// </summary>
Expand Down Expand Up @@ -500,9 +562,30 @@ private sealed class ReadConverterFactory : LogicalReadConverterFactory
public override Delegate GetConverter<TLogical, TPhysical>(ColumnDescriptor columnDescriptor, ColumnChunkMetaData columnChunkMetaData)
{
// VolumeInDollars is bitwise identical to float, so we can reuse the native converter.
if (typeof(TLogical) == typeof(VolumeInDollars)) return LogicalRead.GetNativeConverter<VolumeInDollars, float>();
if (typeof(TLogical) == typeof(VolumeInDollars))
{
return LogicalRead.GetNativeConverter<VolumeInDollars, float>();
}

if (typeof(TLogical) == typeof(VolumeInDollarsDerivedType))
{
return (LogicalRead<VolumeInDollarsDerivedType?, float>.Converter) ConvertVolumeInDollarsDerived;
}

return base.GetConverter<TLogical, TPhysical>(columnDescriptor, columnChunkMetaData);
}

private static void ConvertVolumeInDollarsDerived(
ReadOnlySpan<float> source,
ReadOnlySpan<short> defLevels,
Span<VolumeInDollarsDerivedType?> destination,
short definedLevel)
{
for (int i = 0, src = 0; i < destination.Length; ++i)
{
destination[i] = defLevels.IsEmpty || defLevels[i] == definedLevel ? new VolumeInDollarsDerivedType(source[src++]) : null;
}
}
}

/// <summary>
Expand All @@ -512,7 +595,16 @@ private sealed class WriteTypeFactory : LogicalTypeFactory
{
public override bool TryGetParquetTypes(Type logicalSystemType, out (LogicalType? logicalType, Repetition repetition, PhysicalType physicalType) entry)
{
if (logicalSystemType == typeof(VolumeInDollars)) return base.TryGetParquetTypes(typeof(float), out entry);
if (logicalSystemType == typeof(VolumeInDollars))
{
return base.TryGetParquetTypes(typeof(float), out entry);
}

if (logicalSystemType == typeof(VolumeInDollarsDerivedType))
{
return base.TryGetParquetTypes(typeof(float?), out entry);
}

return base.TryGetParquetTypes(logicalSystemType, out entry);
}
}
Expand Down Expand Up @@ -544,9 +636,39 @@ private sealed class WriteConverterFactory : LogicalWriteConverterFactory
{
public override Delegate GetConverter<TLogical, TPhysical>(ColumnDescriptor columnDescriptor, ByteBuffer? byteBuffer)
{
if (typeof(TLogical) == typeof(VolumeInDollars)) return LogicalWrite.GetNativeConverter<VolumeInDollars, float>();
if (typeof(TLogical) == typeof(VolumeInDollars))
{
return LogicalWrite.GetNativeConverter<VolumeInDollars, float>();
}

if (typeof(TLogical) == typeof(VolumeInDollarsDerivedType) && columnDescriptor.MaxDefinitionLevel > 0)
{
return (LogicalWrite<VolumeInDollarsDerivedType?, float>.Converter) ConvertVolumeInDollarsDerived;
}

return base.GetConverter<TLogical, TPhysical>(columnDescriptor, byteBuffer);
}

private static void ConvertVolumeInDollarsDerived(
ReadOnlySpan<VolumeInDollarsDerivedType?> source,
Span<short> defLevels,
Span<float> destination,
short nullLevel)
{
for (int i = 0, dst = 0; i < source.Length; ++i)
{
var value = source[i];
if (value == null)
{
defLevels[i] = nullLevel;
}
else
{
destination[dst++] = value.Value;
defLevels[i] = (short) (nullLevel + 1);
}
}
}
}

private static readonly float[] Values = {1f, 2f, 3f};
Expand Down
3 changes: 1 addition & 2 deletions csharp/ColumnDescriptor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,7 @@ public TReturn Apply<TReturn>(LogicalTypeFactory typeFactory, Type? columnLogica
}

if (node.Repetition == Repetition.Optional &&
elementType.BaseType != typeof(object) &&
elementType.BaseType != typeof(Array) &&
elementType.IsValueType &&
!TypeUtils.IsNullable(elementType, out _))
{
// Node is optional and the element type is not already a nullable type
Expand Down

0 comments on commit 56e7ed2

Please sign in to comment.