diff --git a/Source/SuperLinq/Batch.Buffered.cs b/Source/SuperLinq/Batch.Buffered.cs index 634048fe..494c2443 100644 --- a/Source/SuperLinq/Batch.Buffered.cs +++ b/Source/SuperLinq/Batch.Buffered.cs @@ -155,7 +155,7 @@ public static IEnumerable Batch( ArgumentNullException.ThrowIfNull(array); ArgumentNullException.ThrowIfNull(resultSelector); ArgumentOutOfRangeException.ThrowIfLessThan(size, 1); - ArgumentOutOfRangeException.ThrowIfGreaterThanOrEqual(size, array.Length); + ArgumentOutOfRangeException.ThrowIfGreaterThan(size, array.Length); return BatchImpl(source, array, size, resultSelector); } @@ -166,12 +166,17 @@ private static IEnumerable BatchImpl( int size, Func, TResult> resultSelector) { - if (source is ICollection coll - && coll.Count <= size) + if (source is ICollection coll) { - coll.CopyTo(array, 0); - yield return resultSelector(new ArraySegment(array, 0, coll.Count)); - yield break; + if (coll.Count == 0) + yield break; + + if (coll.Count <= size) + { + coll.CopyTo(array, 0); + yield return resultSelector(new ArraySegment(array, 0, coll.Count)); + yield break; + } } var n = 0; diff --git a/Source/SuperLinq/Batch.cs b/Source/SuperLinq/Batch.cs index b3972b07..818a82e0 100644 --- a/Source/SuperLinq/Batch.cs +++ b/Source/SuperLinq/Batch.cs @@ -65,10 +65,6 @@ static IEnumerable> Core(IEnumerable source, int size) yield break; } } - else if (source.TryGetCollectionCount() == 0) - { - yield break; - } var n = 0; foreach (var item in source) @@ -101,7 +97,7 @@ public BatchIterator(IList source, int size) _size = size; } - public override int Count => ((_source.Count - 1) / _size) + 1; + public override int Count => _source.Count == 0 ? 0 : ((_source.Count - 1) / _size) + 1; protected override IEnumerable> GetEnumerable() { diff --git a/Tests/SuperLinq.Test/BatchTest.cs b/Tests/SuperLinq.Test/BatchTest.cs index f2fe7f0c..d7b1fbbe 100644 --- a/Tests/SuperLinq.Test/BatchTest.cs +++ b/Tests/SuperLinq.Test/BatchTest.cs @@ -2,12 +2,18 @@ public class BatchTest { - #region Regular [Fact] public void BatchIsLazy() { _ = new BreakingSequence().Batch(1); _ = new BreakingSequence().Buffer(1); + + _ = new BreakingSequence() + .Batch(1, BreakingFunc.Of, int>()); + _ = new BreakingSequence() + .Batch(new int[2], BreakingFunc.Of, int>()); + _ = new BreakingSequence() + .Batch(new int[2], 1, BreakingFunc.Of, int>()); } [Fact] @@ -16,6 +22,21 @@ public void BatchValidatesSize() _ = Assert.Throws("size", () => new BreakingSequence() .Batch(0)); + + _ = Assert.Throws("size", + () => new BreakingSequence() + .Batch(0, BreakingFunc.Of, int>())); + + _ = Assert.Throws("size", + () => new BreakingSequence() + .Batch([], 0, BreakingFunc.Of, int>())); + + _ = Assert.Throws("size", + () => new BreakingSequence() + .Batch(new int[5], 6, BreakingFunc.Of, int>())); + + _ = new BreakingSequence() + .Batch(new int[5], 5, BreakingFunc.Of, int>()); } public static IEnumerable GetFourElementSequences() => @@ -77,98 +98,147 @@ public void BatchModifiedDoesNotAffectPreviousBatch(IDisposableEnumerable s } } + public enum BatchMethod + { + Traditional, + BufferSize, + BufferArray, + BufferSizeArray, + } + + private static IEnumerable GetBatchTestSequences(IEnumerable source) + { + foreach (var seq in source.GetListSequences()) + yield return new object[] { seq, BatchMethod.Traditional, }; + yield return new object[] { source.AsTestingSequence(maxEnumerations: 2), BatchMethod.BufferSize, }; + yield return new object[] { source.AsTestingSequence(maxEnumerations: 2), BatchMethod.BufferArray, }; + yield return new object[] { source.AsTestingSequence(maxEnumerations: 2), BatchMethod.BufferSizeArray, }; + } + + private static IEnumerable> GetBatches( + IEnumerable seq, + BatchMethod method, + int size) => + method switch + { + BatchMethod.Traditional => seq.Batch(size), + BatchMethod.BufferSize => seq.Batch(size, arr => arr.ToList()), + BatchMethod.BufferArray => seq.Batch(new T[size], arr => arr.ToList()), + BatchMethod.BufferSizeArray => seq.Batch(new T[size + 10], size, arr => arr.ToList()), + _ => throw new NotSupportedException(), + }; + public static IEnumerable GetEmptySequences() => - Array.Empty() - .GetAllSequences() - .Select(x => new object[] { x }); + GetBatchTestSequences(Enumerable.Empty()); [Theory] [MemberData(nameof(GetEmptySequences))] - public void BatchWithEmptySource(IDisposableEnumerable seq) + public void BatchWithEmptySource(IDisposableEnumerable seq, BatchMethod bm) { using (seq) - Assert.Empty(seq.Batch(1)); - } - - [Fact] - // branch not able to run with `BreakingList<>` - public void BatchWithEmptyIListProvider() - { - Enumerable.Range(0, 0) - .Batch(1) - .AssertSequenceEqual(); + { + var result = GetBatches(seq, bm, 5); + result.AssertSequenceEqual(); + } } public static IEnumerable GetSequences() => - Enumerable.Range(1, 9) - .GetListSequences() - .Select(x => new object[] { x }); + GetBatchTestSequences(Enumerable.Range(1, 9)); [Theory] [MemberData(nameof(GetSequences))] - public void BatchEvenlyDivisibleSequence(IDisposableEnumerable seq) + public void BatchEvenlyDivisibleSequence(IDisposableEnumerable seq, BatchMethod bm) { using (seq) { - var result = seq.Batch(3); - + var result = GetBatches(seq, bm, 3); result.AssertSequenceEqual( - Seq(1, 2, 3), - Seq(4, 5, 6), - Seq(7, 8, 9)); + [1, 2, 3], + [4, 5, 6], + [7, 8, 9]); } } [Theory] [MemberData(nameof(GetSequences))] - public void BatchUnevenlyDivisibleSequence(IDisposableEnumerable seq) + public void BatchUnevenlyDivisibleSequence(IDisposableEnumerable seq, BatchMethod bm) { using (seq) { - var result = seq.Batch(4); - + var result = GetBatches(seq, bm, 4); result.AssertSequenceEqual( - Seq(1, 2, 3, 4), - Seq(5, 6, 7, 8), - Seq(9)); + [1, 2, 3, 4], + [5, 6, 7, 8], + [9]); } } [Theory] [MemberData(nameof(GetSequences))] - public void BatchSmallSequence(IDisposableEnumerable seq) + public void BatchSmallSequence(IDisposableEnumerable seq, BatchMethod bm) { using (seq) { - var result = seq.Batch(10); - + var result = GetBatches(seq, bm, 10); result.AssertSequenceEqual( - Seq(1, 2, 3, 4, 5, 6, 7, 8, 9)); + [1, 2, 3, 4, 5, 6, 7, 8, 9]); } } - [Fact] - public void BatchWithCollectionSmallerThanBatchSize() + public static IEnumerable GetBreakingCollections(IEnumerable source) { - using var seq = new BreakingCollection(Enumerable.Range(1, 9)); - seq.Batch(10).Consume(); + yield return new object[] { source.AsBreakingCollection(), BatchMethod.Traditional, }; + yield return new object[] { source.AsBreakingCollection(), BatchMethod.BufferSize, }; + yield return new object[] { source.AsBreakingCollection(), BatchMethod.BufferArray, }; + yield return new object[] { source.AsBreakingCollection(), BatchMethod.BufferSizeArray, }; } - [Fact] - public void BatchCollectionSizeNotEvaluatedEarly() + [Theory] + [MemberData(nameof(GetBreakingCollections), new int[] { })] + public void BatchWithEmptyCollection(IDisposableEnumerable seq, BatchMethod bm) { - var list = new List(Enumerable.Range(1, 3)); - var result = list.Batch(3); + using (seq) + { + var result = GetBatches(seq, bm, 10); + result.AssertSequenceEqual(); + } + } + + [Theory] + [MemberData(nameof(GetBreakingCollections), new int[] { 1, 2, 3, 4, 5, 6, 7, 8, 9 })] + public void BatchWithCollectionSmallerThanBatchSize(IDisposableEnumerable seq, BatchMethod bm) + { + using (seq) + { + var result = GetBatches(seq, bm, 10); + result.AssertSequenceEqual( + [1, 2, 3, 4, 5, 6, 7, 8, 9]); + } + } + + [Theory] + [InlineData(BatchMethod.Traditional)] + [InlineData(BatchMethod.BufferSize)] + [InlineData(BatchMethod.BufferArray)] + [InlineData(BatchMethod.BufferSizeArray)] + public void BatchCollectionSizeNotEvaluatedEarly(BatchMethod bm) + { + var list = new List() { 1, 2, 3, }; + var result = GetBatches(list, bm, 3); list.Add(4); result.AssertCount(2).Consume(); } - [Fact] - public void BatchUsesCollectionCountAtIterationTime() + [Theory] + [InlineData(BatchMethod.Traditional)] + [InlineData(BatchMethod.BufferSize)] + [InlineData(BatchMethod.BufferArray)] + [InlineData(BatchMethod.BufferSizeArray)] + public void BatchUsesCollectionCountAtIterationTime(BatchMethod bm) { var list = new List(Enumerable.Range(1, 3)); using var ts = new BreakingCollection(list); - var result = ts.Batch(3); + var result = GetBatches(ts, bm, 3); // should use `CopyTo` result.AssertCount(1).Consume(); @@ -208,89 +278,4 @@ public void BatchListUnevenlyDivisibleBehavior() Assert.Equal(Enumerable.Range(9_980, 20), result.ElementAt(^2)); Assert.Equal(Enumerable.Range(10_000, 2), result.ElementAt(^1)); } - #endregion - - #region Buffered - [Fact] - public void BatchBufferedIsLazy() - { - _ = new BreakingSequence() - .Batch(1, BreakingFunc.Of, int>()); - _ = new BreakingSequence() - .Batch(new int[2], BreakingFunc.Of, int>()); - _ = new BreakingSequence() - .Batch(new int[2], 1, BreakingFunc.Of, int>()); - } - - [Fact] - public void BatchBufferedValidatesSize() - { - _ = Assert.Throws("size", - () => new BreakingSequence() - .Batch(0, BreakingFunc.Of, int>())); - _ = Assert.Throws("size", - () => new BreakingSequence() - .Batch(new int[2], 0, BreakingFunc.Of, int>())); - _ = Assert.Throws("size", - () => new BreakingSequence() - .Batch(new int[2], 3, BreakingFunc.Of, int>())); - } - - [Fact] - public void BatchBufferedWithEmptySource() - { - using var xs = TestingSequence.Of(); - Assert.Empty(xs.Batch(1, BreakingFunc.Of, int>())); - } - - [Fact] - public void BatchBufferedEvenlyDivisibleSequence() - { - using var seq = Enumerable.Range(1, 9).AsTestingSequence(); - - var result = seq.Batch(3, l => string.Join(",", l)); - using var reader = result.Read(); - Assert.Equal("1,2,3", reader.Read()); - Assert.Equal("4,5,6", reader.Read()); - Assert.Equal("7,8,9", reader.Read()); - reader.ReadEnd(); - } - - [Fact] - public void BatchBufferedUnevenlyDivisibleSequence() - { - using var seq = Enumerable.Range(1, 9).AsTestingSequence(); - - var result = seq.Batch(4, l => string.Join(",", l)); - using var reader = result.Read(); - Assert.Equal("1,2,3,4", reader.Read()); - Assert.Equal("5,6,7,8", reader.Read()); - Assert.Equal("9", reader.Read()); - reader.ReadEnd(); - } - - [Fact] - public void BatchBufferedWithCollectionSmallerThanBatchSize() - { - using var seq = new BreakingCollection(Enumerable.Range(1, 9)); - seq.Batch(10, i => i.Sum()).Consume(); - } - - [Fact] - public void BatchBufferedUsesCollectionCountAtIterationTime() - { - var list = new List(Enumerable.Range(1, 3)); - using var ts = new BreakingCollection(list); - var result = ts.Batch(3, w => w[0]); - - // should use `CopyTo` - result.AssertCount(1).Consume(); - - list.Add(4); - - // should fail trying to enumerate - _ = Assert.Throws( - () => result.AssertCount(2).Consume()); - } - #endregion }