diff --git a/src/Microsoft.ML.Core/Utilities/MathUtils.cs b/src/Microsoft.ML.Core/Utilities/MathUtils.cs
index e2848ea25d..8106ff5a2c 100644
--- a/src/Microsoft.ML.Core/Utilities/MathUtils.cs
+++ b/src/Microsoft.ML.Core/Utilities/MathUtils.cs
@@ -871,5 +871,16 @@ public static double Cos(double a)
var res = Math.Cos(a);
return Math.Abs(res) > 1 ? double.NaN : res;
}
+
+ ///
+ /// Returns the smallest integral value that is greater than or equal to the result of the division.
+ ///
+ /// Number to be divided.
+ /// Number with which to divide the numerator.
+ ///
+ public static long DivisionCeiling(long numerator, long denomenator)
+ {
+ return (checked(numerator + denomenator) - 1) / denomenator;
+ }
}
}
diff --git a/src/Microsoft.ML.Parquet/ParquetLoader.cs b/src/Microsoft.ML.Parquet/ParquetLoader.cs
index f0ecde34dc..21271f6e5c 100644
--- a/src/Microsoft.ML.Parquet/ParquetLoader.cs
+++ b/src/Microsoft.ML.Parquet/ParquetLoader.cs
@@ -94,7 +94,7 @@ public sealed class Arguments
private readonly int _columnChunkReadSize;
private readonly Column[] _columnsLoaded;
private readonly DataSet _schemaDataSet;
- private const int _defaultColumnChunkReadSize = 100; // Should ideally be close to Rowgroup size
+ private const int _defaultColumnChunkReadSize = 1000000;
private bool _disposed;
@@ -368,8 +368,8 @@ private sealed class Cursor : RootCursorBase, IRowCursor
private readonly Delegate[] _getters;
private readonly ReaderOptions _readerOptions;
private int _curDataSetRow;
- private IEnumerator _dataSetEnumerator;
- private IEnumerator _blockEnumerator;
+ private IEnumerator _dataSetEnumerator;
+ private IEnumerator _blockEnumerator;
private IList[] _columnValues;
private IRandom _rand;
@@ -390,11 +390,18 @@ public Cursor(ParquetLoader parent, Func predicate, IRandom rand)
Columns = _loader._columnsLoaded.Select(i => i.Name).ToArray()
};
- int numBlocks = (int)Math.Ceiling(((decimal)parent.GetRowCount() / _readerOptions.Count));
- int[] blockOrder = _rand == null ? Utils.GetIdentityPermutation(numBlocks) : Utils.GetRandomPermutation(rand, numBlocks);
+ // The number of blocks is calculated based on the specified rows in a block (defaults to 1M).
+ // Since we want to shuffle the blocks in addition to shuffling the rows in each block, checks
+ // are put in place to ensure we can produce a shuffle order for the blocks.
+ var numBlocks = MathUtils.DivisionCeiling((long)parent.GetRowCount(), _readerOptions.Count);
+ if (numBlocks > int.MaxValue)
+ {
+ throw _loader._host.ExceptParam(nameof(Arguments.ColumnChunkReadSize), "Error due to too many blocks. Try increasing block size.");
+ }
+ var blockOrder = CreateOrderSequence((int)numBlocks);
_blockEnumerator = blockOrder.GetEnumerator();
- _dataSetEnumerator = new int[0].GetEnumerator(); // Initialize an empty enumerator to get started
+ _dataSetEnumerator = Enumerable.Empty().GetEnumerator();
_columnValues = new IList[_actives.Length];
_getters = new Delegate[_actives.Length];
for (int i = 0; i < _actives.Length; ++i)
@@ -472,12 +479,12 @@ protected override bool MoveNextCore()
{
if (_dataSetEnumerator.MoveNext())
{
- _curDataSetRow = (int)_dataSetEnumerator.Current;
+ _curDataSetRow = _dataSetEnumerator.Current;
return true;
}
else if (_blockEnumerator.MoveNext())
{
- _readerOptions.Offset = (int)_blockEnumerator.Current * _readerOptions.Count;
+ _readerOptions.Offset = (long)_blockEnumerator.Current * _readerOptions.Count;
// When current dataset runs out, read the next portion of the parquet file.
DataSet ds;
@@ -486,9 +493,9 @@ protected override bool MoveNextCore()
ds = ParquetReader.Read(_loader._parquetStream, _loader._parquetOptions, _readerOptions);
}
- int[] dataSetOrder = _rand == null ? Utils.GetIdentityPermutation(ds.RowCount) : Utils.GetRandomPermutation(_rand, ds.RowCount);
+ var dataSetOrder = CreateOrderSequence(ds.RowCount);
_dataSetEnumerator = dataSetOrder.GetEnumerator();
- _curDataSetRow = dataSetOrder[0];
+ _curDataSetRow = dataSetOrder.ElementAt(0);
// Cache list for each active column
for (int i = 0; i < _actives.Length; i++)
@@ -533,6 +540,26 @@ public bool IsColumnActive(int col)
Ch.CheckParam(0 <= col && col < _colToActivesIndex.Length, nameof(col));
return _colToActivesIndex[col] >= 0;
}
+
+ ///
+ /// Creates a in-order or shuffled sequence, based on whether _rand is specified.
+ /// If unable to create a shuffle sequence, will default to sequential.
+ ///
+ /// Number of elements in the sequence.
+ ///
+ private IEnumerable CreateOrderSequence(int size)
+ {
+ IEnumerable order;
+ try
+ {
+ order = _rand == null ? Enumerable.Range(0, size) : Utils.GetRandomPermutation(_rand, size);
+ }
+ catch (OutOfMemoryException)
+ {
+ order = Enumerable.Range(0, size);
+ }
+ return order;
+ }
}
#region Dispose
@@ -671,4 +698,4 @@ private string ConvertListToString(IList list)
}
}
}
-}
+}
\ No newline at end of file