diff --git a/src/libraries/System.Collections.Immutable/src/System/Collections/Immutable/IImmutableArray.cs b/src/libraries/System.Collections.Immutable/src/System/Collections/Immutable/IImmutableArray.cs index aee7bf152465bf..8e721e7d49b7fd 100644 --- a/src/libraries/System.Collections.Immutable/src/System/Collections/Immutable/IImmutableArray.cs +++ b/src/libraries/System.Collections.Immutable/src/System/Collections/Immutable/IImmutableArray.cs @@ -22,7 +22,5 @@ internal interface IImmutableArray /// Gets an untyped reference to the array. /// Array Array { get; } - - void ThrowInvalidOperationIfNotInitialized(); } } diff --git a/src/libraries/System.Collections.Immutable/src/System/Collections/Immutable/ImmutableArray.cs b/src/libraries/System.Collections.Immutable/src/System/Collections/Immutable/ImmutableArray.cs index 85c3c2530b57de..2b6c3f5aef2b62 100644 --- a/src/libraries/System.Collections.Immutable/src/System/Collections/Immutable/ImmutableArray.cs +++ b/src/libraries/System.Collections.Immutable/src/System/Collections/Immutable/ImmutableArray.cs @@ -106,14 +106,18 @@ public static ImmutableArray CreateRange(IEnumerable items) var immutableArray = items as IImmutableArray; if (immutableArray != null) { - immutableArray.ThrowInvalidOperationIfNotInitialized(); + Array array = immutableArray.Array; + if (array == null) + { + throw new InvalidOperationException(SR.InvalidOperationOnDefaultArray); + } - // immutableArray.Array must not be null at this point, and we know it's an + // `array` must not be null at this point, and we know it's an // ImmutableArray or ImmutableArray as they are // the only types that could be both IEnumerable and IImmutableArray. // As such, we know that items is either an ImmutableArray or // ImmutableArray, and we can cast the array to T[]. - return new ImmutableArray((T[])immutableArray.Array); + return new ImmutableArray((T[])array); } // We don't recognize the source as an array that is safe to use. diff --git a/src/libraries/System.Collections.Immutable/src/System/Collections/Immutable/ImmutableArray_1.Builder.cs b/src/libraries/System.Collections.Immutable/src/System/Collections/Immutable/ImmutableArray_1.Builder.cs index 23034f594a6eae..796c9a24963153 100644 --- a/src/libraries/System.Collections.Immutable/src/System/Collections/Immutable/ImmutableArray_1.Builder.cs +++ b/src/libraries/System.Collections.Immutable/src/System/Collections/Immutable/ImmutableArray_1.Builder.cs @@ -245,6 +245,12 @@ public void AddRange(IEnumerable items) if (items.TryGetCount(out count)) { this.EnsureCapacity(this.Count + count); + + if (items.TryCopyTo(_elements, _count)) + { + _count += count; + return; + } } foreach (var item in items) diff --git a/src/libraries/System.Collections.Immutable/src/System/Collections/Immutable/ImmutableArray_1.cs b/src/libraries/System.Collections.Immutable/src/System/Collections/Immutable/ImmutableArray_1.cs index 6e8ab3208862d6..c6585b17da3d8e 100644 --- a/src/libraries/System.Collections.Immutable/src/System/Collections/Immutable/ImmutableArray_1.cs +++ b/src/libraries/System.Collections.Immutable/src/System/Collections/Immutable/ImmutableArray_1.cs @@ -613,7 +613,7 @@ public ImmutableArray InsertRange(int index, ImmutableArray items) { var self = this; self.ThrowNullRefIfNotInitialized(); - ThrowNullRefIfNotInitialized(items); + items.ThrowNullRefIfNotInitialized(); Requires.Range(index >= 0 && index <= self.Length, nameof(index)); if (self.IsEmpty) @@ -692,6 +692,7 @@ public ImmutableArray AddRange(ImmutableArray items) public ImmutableArray SetItem(int index, T item) { var self = this; + self.ThrowNullRefIfNotInitialized(); Requires.Range(index >= 0 && index < self.Length, nameof(index)); T[] tmp = new T[self.Length]; @@ -728,7 +729,7 @@ public ImmutableArray Replace(T oldValue, T newValue) public ImmutableArray Replace(T oldValue, T newValue, IEqualityComparer equalityComparer) { var self = this; - int index = self.IndexOf(oldValue, equalityComparer); + int index = self.IndexOf(oldValue, 0, self.Length, equalityComparer); if (index < 0) { throw new ArgumentException(SR.CannotFindOldValue, nameof(oldValue)); @@ -764,7 +765,7 @@ public ImmutableArray Remove(T item, IEqualityComparer equalityComparer) { var self = this; self.ThrowNullRefIfNotInitialized(); - int index = self.IndexOf(item, equalityComparer); + int index = self.IndexOf(item, 0, self.Length, equalityComparer); return index < 0 ? self : self.RemoveAt(index); @@ -791,6 +792,7 @@ public ImmutableArray RemoveAt(int index) public ImmutableArray RemoveRange(int index, int length) { var self = this; + self.ThrowNullRefIfNotInitialized(); Requires.Range(index >= 0 && index <= self.Length, nameof(index)); Requires.Range(length >= 0 && index + length <= self.Length, nameof(length)); @@ -836,18 +838,18 @@ public ImmutableArray RemoveRange(IEnumerable items, IEqualityComparer self.ThrowNullRefIfNotInitialized(); Requires.NotNull(items, nameof(items)); - var indexesToRemove = new SortedSet(); + var indicesToRemove = new SortedSet(); foreach (var item in items) { - int index = self.IndexOf(item, equalityComparer); - while (index >= 0 && !indexesToRemove.Add(index) && index + 1 < self.Length) + int index = self.IndexOf(item, 0, self.Length, equalityComparer); + while (index >= 0 && !indicesToRemove.Add(index) && index + 1 < self.Length) { // This is a duplicate of one we've found. Try hard to find another instance in the list to remove. index = self.IndexOf(item, index + 1, equalityComparer); } } - return self.RemoveAtRange(indexesToRemove); + return self.RemoveAtRange(indicesToRemove); } /// @@ -917,22 +919,22 @@ public ImmutableArray RemoveAll(Predicate match) return self; } - List removeIndexes = null; + List removeIndices = null; for (int i = 0; i < self.array.Length; i++) { if (match(self.array[i])) { - if (removeIndexes == null) + if (removeIndices == null) { - removeIndexes = new List(); + removeIndices = new List(); } - removeIndexes.Add(i); + removeIndices.Add(i); } } - return removeIndexes != null ? - self.RemoveAtRange(removeIndexes) : + return removeIndices != null ? + self.RemoveAtRange(removeIndices) : self; } @@ -1566,7 +1568,9 @@ bool IStructuralEquatable.Equals(object other, IEqualityComparer comparer) var theirs = other as IImmutableArray; if (theirs != null) { - if (self.array == null && theirs.Array == null) + otherArray = theirs.Array; + + if (self.array == null && otherArray == null) { return true; } @@ -1574,8 +1578,6 @@ bool IStructuralEquatable.Equals(object other, IEqualityComparer comparer) { return false; } - - otherArray = theirs.Array; } } @@ -1617,16 +1619,16 @@ int IStructuralComparable.CompareTo(object other, IComparer comparer) var theirs = other as IImmutableArray; if (theirs != null) { - if (self.array == null && theirs.Array == null) + otherArray = theirs.Array; + + if (self.array == null && otherArray == null) { return 0; } - else if (self.array == null ^ theirs.Array == null) + else if (self.array == null ^ otherArray == null) { throw new ArgumentException(SR.ArrayInitializedStateNotEqual, nameof(other)); } - - otherArray = theirs.Array; } } @@ -1673,33 +1675,28 @@ private void ThrowInvalidOperationIfNotInitialized() } } - void IImmutableArray.ThrowInvalidOperationIfNotInitialized() - { - this.ThrowInvalidOperationIfNotInitialized(); - } - /// - /// Returns an array with items at the specified indexes removed. + /// Returns an array with items at the specified indices removed. /// - /// A **sorted set** of indexes to elements that should be omitted from the returned array. + /// A **sorted set** of indices to elements that should be omitted from the returned array. /// The new array. - private ImmutableArray RemoveAtRange(ICollection indexesToRemove) + private ImmutableArray RemoveAtRange(ICollection indicesToRemove) { var self = this; self.ThrowNullRefIfNotInitialized(); - Requires.NotNull(indexesToRemove, nameof(indexesToRemove)); + Requires.NotNull(indicesToRemove, nameof(indicesToRemove)); - if (indexesToRemove.Count == 0) + if (indicesToRemove.Count == 0) { // Be sure to return a !IsDefault instance. return self; } - var newArray = new T[self.Length - indexesToRemove.Count]; + var newArray = new T[self.Length - indicesToRemove.Count]; int copied = 0; int removed = 0; int lastIndexRemoved = -1; - foreach (var indexToRemove in indexesToRemove) + foreach (var indexToRemove in indicesToRemove) { int copyLength = lastIndexRemoved == -1 ? indexToRemove : (indexToRemove - lastIndexRemoved - 1); Debug.Assert(indexToRemove > lastIndexRemoved); // We require that the input be a sorted set. @@ -1713,13 +1710,5 @@ private ImmutableArray RemoveAtRange(ICollection indexesToRemove) return new ImmutableArray(newArray); } - - /// - /// Throws a if the specified array is uninitialized. - /// - private static void ThrowNullRefIfNotInitialized(ImmutableArray array) - { - array.ThrowNullRefIfNotInitialized(); - } } } diff --git a/src/libraries/System.Collections.Immutable/src/System/Collections/Immutable/ImmutableExtensions.cs b/src/libraries/System.Collections.Immutable/src/System/Collections/Immutable/ImmutableExtensions.cs index b3f9b392924146..783192bbd5e0b7 100644 --- a/src/libraries/System.Collections.Immutable/src/System/Collections/Immutable/ImmutableExtensions.cs +++ b/src/libraries/System.Collections.Immutable/src/System/Collections/Immutable/ImmutableExtensions.cs @@ -106,7 +106,11 @@ internal static int GetCount(ref IEnumerable sequence) /// internal static bool TryCopyTo(this IEnumerable sequence, T[] array, int arrayIndex) { - // IList is the GCD of what the following 2 types implement. + Debug.Assert(sequence != null); + Debug.Assert(array != null); + Debug.Assert(arrayIndex >= 0 && arrayIndex <= array.Length); + + // IList is the GCD of what the following types implement. var listInterface = sequence as IList; if (listInterface != null) { @@ -117,6 +121,16 @@ internal static bool TryCopyTo(this IEnumerable sequence, T[] array, int a return true; } + // Array.Copy can throw an ArrayTypeMismatchException if the underlying type of + // the destination array is not typeof(T[]), but is assignment-compatible with T[]. + // See https://github.com/dotnet/corefx/issues/2241 for more info. + if (sequence.GetType() == typeof(T[])) + { + var sourceArray = (T[])sequence; + Array.Copy(sourceArray, 0, array, arrayIndex, sourceArray.Length); + return true; + } + if (sequence is ImmutableArray) { var immutable = (ImmutableArray)sequence; diff --git a/src/libraries/System.Collections.Immutable/tests/ImmutableArrayBuilderTest.cs b/src/libraries/System.Collections.Immutable/tests/ImmutableArrayBuilderTest.cs index 8b9fec0efa8153..2bf6bcc7ce7c21 100644 --- a/src/libraries/System.Collections.Immutable/tests/ImmutableArrayBuilderTest.cs +++ b/src/libraries/System.Collections.Immutable/tests/ImmutableArrayBuilderTest.cs @@ -75,6 +75,9 @@ public void AddRangeIEnumerable() builder.AddRange((IEnumerable)new[] { 2 }); Assert.Equal(2, builder.Count); + builder.AddRange((IEnumerable)new int[0]); + Assert.Equal(2, builder.Count); + // Exceed capacity builder.AddRange(Enumerable.Range(3, 2)); // use an enumerable without a breakable Count Assert.Equal(4, builder.Count); diff --git a/src/libraries/System.Collections.Immutable/tests/ImmutableArrayTest.cs b/src/libraries/System.Collections.Immutable/tests/ImmutableArrayTest.cs index 4895474e34aeb7..d28e00006595fd 100644 --- a/src/libraries/System.Collections.Immutable/tests/ImmutableArrayTest.cs +++ b/src/libraries/System.Collections.Immutable/tests/ImmutableArrayTest.cs @@ -1127,7 +1127,7 @@ public void RemoveAtInvalid(IEnumerable source) } [Theory] - [InlineData(-1, Skip = "#14961")] + [InlineData(-1)] [InlineData(0)] [InlineData(1)] public void RemoveAtDefaultInvalid(int index) @@ -1217,7 +1217,7 @@ public void RemoveRangeIndexLengthInvalid(IEnumerable source) } [Theory] - [InlineData(-1, 0, Skip = "#14961")] + [InlineData(-1, 0)] [InlineData(0, -1)] [InlineData(0, 0)] [InlineData(1, -1)] @@ -1446,9 +1446,8 @@ public void ReplaceDefaultInvalid() { Assert.All(SharedEqualityComparers(), comparer => { - // Uncomment when #14961 is fixed. - // Assert.Throws(() => s_emptyDefault.Replace(123, 123)); - // Assert.Throws(() => s_emptyDefault.Replace(123, 123, comparer)); + Assert.Throws(() => s_emptyDefault.Replace(123, 123)); + Assert.Throws(() => s_emptyDefault.Replace(123, 123, comparer)); Assert.Throws(() => ((IImmutableList)s_emptyDefault).Replace(123, 123)); Assert.Throws(() => ((IImmutableList)s_emptyDefault).Replace(123, 123, comparer)); @@ -1485,7 +1484,7 @@ public void SetItemInvalid(IEnumerable source) } [Theory] - [InlineData(-1, Skip = "#14961")] + [InlineData(-1)] [InlineData(0)] [InlineData(1)] public void SetItemDefaultInvalid(int index)