Skip to content

Commit

Permalink
Handle partial stream reads in ManagedRandomAccessFile (#264)
Browse files Browse the repository at this point in the history
* Add test to reproduce errors on partial stream reads

* Handle partial reads from a managed stream and fix netstandard specific path

* Fix ManagedOutputStream TFM symbol and handling nbytes greater than max int

* Switch to _OR_GREATER TFM symbols

* Use correct array size limit
  • Loading branch information
adamreeve authored Apr 22, 2022
1 parent 7cecf04 commit bcd0243
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 14 deletions.
46 changes: 45 additions & 1 deletion csharp.test/TestManagedRandomAccessFile.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.IO;
using System;
using System.IO;
using System.Linq;
using NUnit.Framework;
using ParquetSharp.IO;
Expand Down Expand Up @@ -118,6 +119,36 @@ public static void TestReadExeption()
Contains.Substring("this is an erroneous reader"));
}

[Test]
public static void TestPartialStreamRead()
{
var expected = Enumerable.Range(0, 1024 * 1024).ToArray();
using var buffer = new PartialReadStream();

// Write test data.
using (var output = new ManagedOutputStream(buffer, leaveOpen: true))
{
using var writer = new ParquetFileWriter(output, new Column[] {new Column<int>("ids")});
using var groupWriter = writer.AppendRowGroup();
using var columnWriter = groupWriter.NextColumn().LogicalWriter<int>();

columnWriter.WriteBatch(expected);

writer.Close();
}

// Seek back to start.
buffer.Seek(0, SeekOrigin.Begin);

// Read test data.
using var input = new ManagedRandomAccessFile(buffer, leaveOpen: true);
using var reader = new ParquetFileReader(input);
using var groupReader = reader.RowGroup(0);
using var columnReader = groupReader.Column(0).LogicalReader<int>();

Assert.AreEqual(expected, columnReader.ReadAll(expected.Length));
}

private sealed class ErroneousReaderStream : MemoryStream
{
public override int Read(byte[] buffer, int offset, int count)
Expand All @@ -133,5 +164,18 @@ public override void Write(byte[] buffer, int offset, int count)
throw new IOException("this is an erroneous writer");
}
}

/// <summary>
/// Simulate a stream that only partially fulfills reads sometimes,
/// eg. for data streamed from a cloud service (see https://github.com/G-Research/ParquetSharp/issues/263)
/// </summary>
private sealed class PartialReadStream : MemoryStream
{
public override int Read(byte[] buffer, int offset, int count)
{
count = Math.Min(count, 1024 * 1024);
return base.Read(buffer, offset, count);
}
}
}
}
12 changes: 8 additions & 4 deletions csharp/IO/ManagedOutputStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ private byte Write(IntPtr src, long nbytes, out string? exception)
{
try
{
#if !NETSTANDARD20
var buffer = new byte[(int) nbytes];
#if !NETSTANDARD2_1_OR_GREATER
var buffer = new byte[(int) Math.Min(nbytes, MaxArraySize)];
#endif

while (nbytes > 0)
{
var ibytes = (int) nbytes;
var ibytes = (int) Math.Min(nbytes, MaxArraySize);

#if NETSTANDARD20
#if NETSTANDARD2_1_OR_GREATER
unsafe
{
_stream.Write(new Span<byte>(src.ToPointer(), ibytes));
Expand Down Expand Up @@ -180,5 +180,9 @@ private static extern IntPtr ManagedOutputStream_Create(
// ReSharper disable NotAccessedField.Local
private string? _exceptionMessage;
// ReSharper restore NotAccessedField.Local

// Maximum size of a byte array,
// see https://docs.microsoft.com/en-us/dotnet/framework/configure-apps/file-schema/runtime/gcallowverylargeobjects-element#remarks
private const long MaxArraySize = 2_147_483_591;
}
}
36 changes: 27 additions & 9 deletions csharp/IO/ManagedRandomAccessFile.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,32 @@ private byte Read(long nbytes, IntPtr bytesRead, IntPtr dest, out string? except
{
try
{
#if NETSTANDARD20
unsafe
#if !NETSTANDARD2_1_OR_GREATER
var buffer = new byte[(int) Math.Min(nbytes, MaxArraySize)];
#endif
var totalRead = 0L;
while (totalRead < nbytes)
{
var read = Stream.Read(new Span<byte>(dest.ToPointer(), (int)nbytes));
Marshal.WriteInt64(bytes_read, read);
}
var bytesToRead = (int) Math.Min(nbytes - totalRead, MaxArraySize);
int read;
#if NETSTANDARD2_1_OR_GREATER
unsafe
{
read = _stream.Read(new Span<byte>(dest.ToPointer(), bytesToRead));
}
#else
var buffer = new byte[(int) nbytes];
var read = _stream.Read(buffer, 0, (int) nbytes);
Marshal.Copy(buffer, 0, dest, read);
Marshal.WriteInt64(bytesRead, read);
read = _stream.Read(buffer, 0, bytesToRead);
Marshal.Copy(buffer, 0, dest, read);
#endif
if (read == 0)
{
break;
}
totalRead += read;
dest = IntPtr.Add(dest, read);
}

Marshal.WriteInt64(bytesRead, totalRead);
exception = null;
return 0;
}
Expand Down Expand Up @@ -188,5 +202,9 @@ private static extern IntPtr ManagedRandomAccessFile_Create(
// ReSharper disable NotAccessedField.Local
private string? _exceptionMessage;
// ReSharper restore NotAccessedField.Local

// Maximum size of a byte array,
// see https://docs.microsoft.com/en-us/dotnet/framework/configure-apps/file-schema/runtime/gcallowverylargeobjects-element#remarks
private const long MaxArraySize = 2_147_483_591;
}
}

0 comments on commit bcd0243

Please sign in to comment.