diff --git a/src/Microsoft.ML.Core/Utilities/MathUtils.cs b/src/Microsoft.ML.Core/Utilities/MathUtils.cs index e2848ea25d..8106ff5a2c 100644 --- a/src/Microsoft.ML.Core/Utilities/MathUtils.cs +++ b/src/Microsoft.ML.Core/Utilities/MathUtils.cs @@ -871,5 +871,16 @@ public static double Cos(double a) var res = Math.Cos(a); return Math.Abs(res) > 1 ? double.NaN : res; } + + /// + /// Returns the smallest integral value that is greater than or equal to the result of the division. + /// + /// Number to be divided. + /// Number with which to divide the numerator. + /// + public static long DivisionCeiling(long numerator, long denomenator) + { + return (checked(numerator + denomenator) - 1) / denomenator; + } } } diff --git a/src/Microsoft.ML.Parquet/ParquetLoader.cs b/src/Microsoft.ML.Parquet/ParquetLoader.cs index f0ecde34dc..21271f6e5c 100644 --- a/src/Microsoft.ML.Parquet/ParquetLoader.cs +++ b/src/Microsoft.ML.Parquet/ParquetLoader.cs @@ -94,7 +94,7 @@ public sealed class Arguments private readonly int _columnChunkReadSize; private readonly Column[] _columnsLoaded; private readonly DataSet _schemaDataSet; - private const int _defaultColumnChunkReadSize = 100; // Should ideally be close to Rowgroup size + private const int _defaultColumnChunkReadSize = 1000000; private bool _disposed; @@ -368,8 +368,8 @@ private sealed class Cursor : RootCursorBase, IRowCursor private readonly Delegate[] _getters; private readonly ReaderOptions _readerOptions; private int _curDataSetRow; - private IEnumerator _dataSetEnumerator; - private IEnumerator _blockEnumerator; + private IEnumerator _dataSetEnumerator; + private IEnumerator _blockEnumerator; private IList[] _columnValues; private IRandom _rand; @@ -390,11 +390,18 @@ public Cursor(ParquetLoader parent, Func predicate, IRandom rand) Columns = _loader._columnsLoaded.Select(i => i.Name).ToArray() }; - int numBlocks = (int)Math.Ceiling(((decimal)parent.GetRowCount() / _readerOptions.Count)); - int[] blockOrder = _rand == null ? Utils.GetIdentityPermutation(numBlocks) : Utils.GetRandomPermutation(rand, numBlocks); + // The number of blocks is calculated based on the specified rows in a block (defaults to 1M). + // Since we want to shuffle the blocks in addition to shuffling the rows in each block, checks + // are put in place to ensure we can produce a shuffle order for the blocks. + var numBlocks = MathUtils.DivisionCeiling((long)parent.GetRowCount(), _readerOptions.Count); + if (numBlocks > int.MaxValue) + { + throw _loader._host.ExceptParam(nameof(Arguments.ColumnChunkReadSize), "Error due to too many blocks. Try increasing block size."); + } + var blockOrder = CreateOrderSequence((int)numBlocks); _blockEnumerator = blockOrder.GetEnumerator(); - _dataSetEnumerator = new int[0].GetEnumerator(); // Initialize an empty enumerator to get started + _dataSetEnumerator = Enumerable.Empty().GetEnumerator(); _columnValues = new IList[_actives.Length]; _getters = new Delegate[_actives.Length]; for (int i = 0; i < _actives.Length; ++i) @@ -472,12 +479,12 @@ protected override bool MoveNextCore() { if (_dataSetEnumerator.MoveNext()) { - _curDataSetRow = (int)_dataSetEnumerator.Current; + _curDataSetRow = _dataSetEnumerator.Current; return true; } else if (_blockEnumerator.MoveNext()) { - _readerOptions.Offset = (int)_blockEnumerator.Current * _readerOptions.Count; + _readerOptions.Offset = (long)_blockEnumerator.Current * _readerOptions.Count; // When current dataset runs out, read the next portion of the parquet file. DataSet ds; @@ -486,9 +493,9 @@ protected override bool MoveNextCore() ds = ParquetReader.Read(_loader._parquetStream, _loader._parquetOptions, _readerOptions); } - int[] dataSetOrder = _rand == null ? Utils.GetIdentityPermutation(ds.RowCount) : Utils.GetRandomPermutation(_rand, ds.RowCount); + var dataSetOrder = CreateOrderSequence(ds.RowCount); _dataSetEnumerator = dataSetOrder.GetEnumerator(); - _curDataSetRow = dataSetOrder[0]; + _curDataSetRow = dataSetOrder.ElementAt(0); // Cache list for each active column for (int i = 0; i < _actives.Length; i++) @@ -533,6 +540,26 @@ public bool IsColumnActive(int col) Ch.CheckParam(0 <= col && col < _colToActivesIndex.Length, nameof(col)); return _colToActivesIndex[col] >= 0; } + + /// + /// Creates a in-order or shuffled sequence, based on whether _rand is specified. + /// If unable to create a shuffle sequence, will default to sequential. + /// + /// Number of elements in the sequence. + /// + private IEnumerable CreateOrderSequence(int size) + { + IEnumerable order; + try + { + order = _rand == null ? Enumerable.Range(0, size) : Utils.GetRandomPermutation(_rand, size); + } + catch (OutOfMemoryException) + { + order = Enumerable.Range(0, size); + } + return order; + } } #region Dispose @@ -671,4 +698,4 @@ private string ConvertListToString(IList list) } } } -} +} \ No newline at end of file