diff --git a/src/Riok.Mapperly/Descriptors/MappingBuilders/EnumerableMappingBuilder.cs b/src/Riok.Mapperly/Descriptors/MappingBuilders/EnumerableMappingBuilder.cs index a6d3de2254..ea442a9793 100644 --- a/src/Riok.Mapperly/Descriptors/MappingBuilders/EnumerableMappingBuilder.cs +++ b/src/Riok.Mapperly/Descriptors/MappingBuilders/EnumerableMappingBuilder.cs @@ -87,10 +87,16 @@ public static class EnumerableMappingBuilder return CreateForEach(nameof(Queue.Enqueue)); // create a foreach loop with add calls if source is not an array - // and ICollection.Add(T): void is implemented and not explicit + // and void ICollection.Add(T) or bool ISet.Add(T) is implemented and not explicit // ensures add is not called and immutable types - if (!ctx.Target.IsArrayType() && ctx.Target.HasImplicitGenericImplementation(ctx.Types.Get(typeof(ICollection<>)), AddMethodName)) + // ISet.Add(T) is explicitly needed as sets implement the ICollection.Add(T) explicit, + // and override the add method with new + var hasImplicitCollectionAdd = ctx.Target.HasImplicitGenericImplementation(ctx.Types.Get(typeof(ICollection<>)), AddMethodName); + var hasImplicitSetAdd = ctx.Target.HasImplicitGenericImplementation(ctx.Types.Get(typeof(ISet<>)), AddMethodName); + if (!ctx.Target.IsArrayType() && (hasImplicitCollectionAdd || hasImplicitSetAdd)) + { return CreateForEach(AddMethodName); + } // if a mapping could be created for an immutable collection // we diagnostic when it is an existing target mapping @@ -167,22 +173,22 @@ private static LinqConstructorMapping BuildLinqConstructorMapping(MappingBuilder } // create a foreach loop with add calls if source is not an array - // and ICollection.Add(T): void is implemented and not explicit - // ensures add is not called and immutable types - if (!ctx.Target.IsArrayType() && ctx.Target.HasImplicitGenericImplementation(ctx.Types.Get(typeof(ICollection<>)), AddMethodName)) - { - var ensureCapacityStatement = EnsureCapacityBuilder.TryBuildEnsureCapacity(ctx.Source, ctx.Target, ctx.Types); - return new ForEachAddEnumerableMapping( - ctx.Source, - ctx.Target, - elementMapping, - objectFactory, - AddMethodName, - ensureCapacityStatement - ); - } + // and void ICollection.Add(T) or bool ISet.Add(T) is implemented and not explicit + // ensures .Add() is not called on immutable types + var hasImplicitCollectionAdd = ctx.Target.HasImplicitGenericImplementation(ctx.Types.Get(typeof(ICollection<>)), AddMethodName); + var hasImplicitSetAdd = ctx.Target.HasImplicitGenericImplementation(ctx.Types.Get(typeof(ISet<>)), AddMethodName); + if (ctx.Target.IsArrayType() || (!hasImplicitCollectionAdd && !hasImplicitSetAdd)) + return null; - return null; + var ensureCapacityStatement = EnsureCapacityBuilder.TryBuildEnsureCapacity(ctx.Source, ctx.Target, ctx.Types); + return new ForEachAddEnumerableMapping( + ctx.Source, + ctx.Target, + elementMapping, + objectFactory, + AddMethodName, + ensureCapacityStatement + ); } private static (bool CanMapWithLinq, string? CollectMethod) ResolveCollectMethodName( diff --git a/src/Riok.Mapperly/Helpers/SymbolExtensions.cs b/src/Riok.Mapperly/Helpers/SymbolExtensions.cs index 2986807c28..2d5153ed1a 100644 --- a/src/Riok.Mapperly/Helpers/SymbolExtensions.cs +++ b/src/Riok.Mapperly/Helpers/SymbolExtensions.cs @@ -153,17 +153,20 @@ out bool isExplicit } var interfaceSymbol = typedInterface.GetMembers(symbolName).First(); - - var symbolImplementaton = t.FindImplementationForInterfaceMember(interfaceSymbol); + var symbolImplementation = t.FindImplementationForInterfaceMember(interfaceSymbol); // if null then the method is unimplemented // symbol implements genericInterface but has not implemented the corresponding methods - // this can only occur in unit tests - if (symbolImplementaton == null) - throw new NotSupportedException("Symbol implementation cannot be null for objects implementing interface."); + // this is for example the case for arrays (arrays implement several interfaces at runtime) + // and unit tests for which not the full interface is implemented + if (symbolImplementation == null) + { + isExplicit = false; + return false; + } // check if symbol is explicit - isExplicit = symbolImplementaton switch + isExplicit = symbolImplementation switch { IMethodSymbol methodSymbol => methodSymbol.ExplicitInterfaceImplementations.Any(), IPropertySymbol propertySymbol => propertySymbol.ExplicitInterfaceImplementations.Any(), diff --git a/test/Riok.Mapperly.Tests/Mapping/EnumerableSetTest.cs b/test/Riok.Mapperly.Tests/Mapping/EnumerableSetTest.cs new file mode 100644 index 0000000000..7a6277ec3a --- /dev/null +++ b/test/Riok.Mapperly.Tests/Mapping/EnumerableSetTest.cs @@ -0,0 +1,59 @@ +namespace Riok.Mapperly.Tests.Mapping; + +public class EnumerableSetTest +{ + [Fact] + public void ExistingEnumerableToExistingSet() + { + var source = TestSourceBuilder.Mapping( + "A", + "B", + "class A { public IEnumerable Values { get; } }", + "class B { public ISet Values { get; } }" + ); + TestHelper + .GenerateMapper(source) + .Should() + .HaveSingleMethodBody( + """ + var target = new global::B(); + foreach (var item in source.Values) + { + target.Values.Add(item); + } + + return target; + """ + ); + } + + [Fact] + public void ExistingEnumerableToExistingHashSet() + { + var source = TestSourceBuilder.Mapping( + "A", + "B", + "class A { public IEnumerable Values { get; } }", + "class B { public HashSet Values { get; } }" + ); + TestHelper + .GenerateMapper(source) + .Should() + .HaveSingleMethodBody( + """ + var target = new global::B(); + if (global::System.Linq.Enumerable.TryGetNonEnumeratedCount(source.Values, out var sourceCount)) + { + target.Values.EnsureCapacity(sourceCount + target.Values.Count); + } + + foreach (var item in source.Values) + { + target.Values.Add(item); + } + + return target; + """ + ); + } +} diff --git a/test/Riok.Mapperly.Tests/Mapping/EnumerableTest.cs b/test/Riok.Mapperly.Tests/Mapping/EnumerableTest.cs index ed9e85e8c3..1318095c26 100644 --- a/test/Riok.Mapperly.Tests/Mapping/EnumerableTest.cs +++ b/test/Riok.Mapperly.Tests/Mapping/EnumerableTest.cs @@ -671,9 +671,9 @@ public void EnumerableToReadOnlyArrayPropertyShouldDiagnostic() .Should() .HaveSingleMethodBody( """ - var target = new global::B(); - return target; - """ + var target = new global::B(); + return target; + """ ) .HaveDiagnostic(DiagnosticDescriptors.CannotMapToReadOnlyMember); }