Skip to content

Commit

Permalink
feat: cache attribute symbols
Browse files Browse the repository at this point in the history
  • Loading branch information
TimothyMakkison committed Jun 6, 2023
1 parent b6836b2 commit beb7a53
Show file tree
Hide file tree
Showing 13 changed files with 132 additions and 197 deletions.
13 changes: 6 additions & 7 deletions src/Riok.Mapperly/Configuration/AttributeDataAccessor.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.Reflection;
using Microsoft.CodeAnalysis;
using Riok.Mapperly.Descriptors;

namespace Riok.Mapperly.Configuration;

Expand All @@ -8,27 +9,25 @@ namespace Riok.Mapperly.Configuration;
/// </summary>
internal static class AttributeDataAccessor
{
public static T? AccessFirstOrDefault<T>(Compilation compilation, ISymbol symbol)
where T : Attribute => Access<T, T>(compilation, symbol).FirstOrDefault();
public static T? AccessFirstOrDefault<T>(WellKnownTypes knownTypes, ISymbol symbol)
where T : Attribute => Access<T, T>(knownTypes, symbol).FirstOrDefault();

/// <summary>
/// Reads the attribute data and sets it on a newly created instance of <see cref="TData"/>.
/// If <see cref="TAttribute"/> has n type parameters,
/// <see cref="TData"/> needs to have an accessible ctor with the parameters 0 to n-1 to be of type <see cref="ITypeSymbol"/>.
/// </summary>
/// <param name="compilation">The compilation.</param>
/// <param name="knownTypes">The knownTypes used to get the type symbol.</param>
/// <param name="symbol">The symbol on which the attributes should be read.</param>
/// <typeparam name="TAttribute">The type of the attribute.</typeparam>
/// <typeparam name="TData">The type of the data class. If no type parameters are involved, this is usually the same as <see cref="TAttribute"/>.</typeparam>
/// <returns>The attribute data.</returns>
/// <exception cref="InvalidOperationException">If a property or ctor argument of <see cref="TData"/> could not be read on the attribute.</exception>
public static IEnumerable<TData> Access<TAttribute, TData>(Compilation compilation, ISymbol symbol)
public static IEnumerable<TData> Access<TAttribute, TData>(WellKnownTypes knownTypes, ISymbol symbol)
where TAttribute : Attribute
{
var attrType = typeof(TAttribute);
var attrSymbol = compilation.GetTypeByMetadataName($"{attrType.Namespace}.{attrType.Name}");
if (attrSymbol == null)
yield break;
var attrSymbol = knownTypes.Get($"{attrType.Namespace}.{attrType.Name}");

var attrDatas = symbol
.GetAttributes()
Expand Down
10 changes: 5 additions & 5 deletions src/Riok.Mapperly/Descriptors/Configuration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ public class Configuration
/// </summary>
private readonly Dictionary<Type, object> _defaultConfigurations = new();

private readonly Compilation _compilation;
private readonly WellKnownTypes _knownTypes;

public Configuration(Compilation compilation, INamedTypeSymbol mapperSymbol)
public Configuration(WellKnownTypes knownTypes, INamedTypeSymbol mapperSymbol)
{
_compilation = compilation;
Mapper = AttributeDataAccessor.AccessFirstOrDefault<MapperAttribute>(compilation, mapperSymbol) ?? new();
_knownTypes = knownTypes;
Mapper = AttributeDataAccessor.AccessFirstOrDefault<MapperAttribute>(knownTypes, mapperSymbol) ?? new();
InitDefaultConfigurations();
}

Expand All @@ -34,7 +34,7 @@ public T GetOrDefault<T>(IMethodSymbol? userSymbol)
public IEnumerable<TData> ListConfiguration<T, TData>(IMethodSymbol? userSymbol)
where T : Attribute
{
return userSymbol == null ? Enumerable.Empty<TData>() : AttributeDataAccessor.Access<T, TData>(_compilation, userSymbol);
return userSymbol == null ? Enumerable.Empty<TData>() : AttributeDataAccessor.Access<T, TData>(_knownTypes, userSymbol);
}

private void InitDefaultConfigurations()
Expand Down
10 changes: 6 additions & 4 deletions src/Riok.Mapperly/Descriptors/DescriptorBuilder.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Riok.Mapperly.Abstractions.ReferenceHandling;
using Riok.Mapperly.Descriptors.MappingBodyBuilders;
using Riok.Mapperly.Descriptors.MappingBuilders;
using Riok.Mapperly.Descriptors.ObjectFactories;
Expand All @@ -22,15 +23,16 @@ public DescriptorBuilder(
SourceProductionContext sourceContext,
Compilation compilation,
ClassDeclarationSyntax mapperSyntax,
INamedTypeSymbol mapperSymbol
INamedTypeSymbol mapperSymbol,
WellKnownTypes wellKnownTypes
)
{
_mapperDescriptor = new MapperDescriptor(mapperSyntax, mapperSymbol, _methodNameBuilder);
_mappingBodyBuilder = new MappingBodyBuilder(_mappings);
_builderContext = new SimpleMappingBuilderContext(
compilation,
new Configuration(compilation, mapperSymbol),
new WellKnownTypes(compilation),
new Configuration(wellKnownTypes, mapperSymbol),
wellKnownTypes,
_mapperDescriptor,
sourceContext,
new MappingBuilder(_mappings),
Expand Down Expand Up @@ -95,7 +97,7 @@ private void BuildReferenceHandlingParameters()

foreach (var methodMapping in _mappings.MethodMappings)
{
methodMapping.EnableReferenceHandling(_builderContext.Types.IReferenceHandler);
methodMapping.EnableReferenceHandling(_builderContext.Types.Get<IReferenceHandler>());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@ public static class EnsureCapacityBuilder
if (TryGetNonEnumeratedCount(sourceType, types, out var sourceSizeProperty))
return new EnsureCapacityMember(targetSizeProperty, sourceSizeProperty);

sourceType.ImplementsGeneric(types.IEnumerableT, out var iEnumerable);
sourceType.ImplementsGeneric(types.Get(typeof(IEnumerable<>)), out var iEnumerable);

var nonEnumeratedCountMethod = types.Enumerable
var nonEnumeratedCountMethod = types
.Get(typeof(Enumerable))
.GetMembers(TryGetNonEnumeratedCountMethodName)
.OfType<IMethodSymbol>()
.FirstOrDefault(
Expand All @@ -59,13 +60,19 @@ private static bool TryGetNonEnumeratedCount(ITypeSymbol value, WellKnownTypes t
return true;
}

if (value.ImplementsGeneric(types.ICollectionT, CountPropertyName, out _, out var hasCollectionCount) && !hasCollectionCount)
if (
value.ImplementsGeneric(types.Get(typeof(ICollection<>)), CountPropertyName, out _, out var hasCollectionCount)
&& !hasCollectionCount
)
{
expression = CountPropertyName;
return true;
}

if (value.ImplementsGeneric(types.IReadOnlyCollectionT, CountPropertyName, out _, out var hasReadOnlyCount) && !hasReadOnlyCount)
if (
value.ImplementsGeneric(types.Get(typeof(IReadOnlyCollection<>)), CountPropertyName, out _, out var hasReadOnlyCount)
&& !hasReadOnlyCount
)
{
expression = CountPropertyName;
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,15 +181,15 @@ private static void BuildConstructorMapping(INewInstanceBuilderContext<IMapping>
// ctors annotated with [Obsolete] are considered last unless they have a MapperConstructor attribute set
var ctorCandidates = namedTargetType.InstanceConstructors
.Where(ctor => ctor.IsAccessible())
.OrderByDescending(x => x.HasAttribute(ctx.BuilderContext.Types.MapperConstructorAttribute))
.ThenBy(x => x.HasAttribute(ctx.BuilderContext.Types.ObsoleteAttribute))
.OrderByDescending(x => x.HasAttribute(ctx.BuilderContext.Types.Get<MapperConstructorAttribute>()))
.ThenBy(x => x.HasAttribute(ctx.BuilderContext.Types.Get<ObsoleteAttribute>()))
.ThenByDescending(x => x.Parameters.Length == 0)
.ThenByDescending(x => x.Parameters.Length);
foreach (var ctorCandidate in ctorCandidates)
{
if (!TryBuildConstructorMapping(ctx, ctorCandidate, out var mappedTargetMemberNames, out var constructorParameterMappings))
{
if (ctorCandidate.HasAttribute(ctx.BuilderContext.Types.MapperConstructorAttribute))
if (ctorCandidate.HasAttribute(ctx.BuilderContext.Types.Get<MapperConstructorAttribute>()))
{
ctx.BuilderContext.ReportDiagnostic(
DiagnosticDescriptors.CannotMapToConfiguredConstructor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public static class DictionaryMappingBuilder
.GetAllProperties(CountPropertyName)
.Any(x => !x.IsStatic && !x.IsIndexer && !x.IsWriteOnly && x.Type.SpecialType == SpecialType.System_Int32);

var targetDictionarySymbol = ctx.Types.DictionaryT.Construct(keyMapping.TargetType, valueMapping.TargetType);
var targetDictionarySymbol = ctx.Types.Get(typeof(Dictionary<,>)).Construct(keyMapping.TargetType, valueMapping.TargetType);
ctx.ObjectFactories.TryFindObjectFactory(ctx.Source, ctx.Target, out var dictionaryObjectFactory);
return new ForEachSetDictionaryMapping(
ctx.Source,
Expand Down Expand Up @@ -62,7 +62,7 @@ public static class DictionaryMappingBuilder
return null;
}

if (!ctx.Target.ImplementsGeneric(ctx.Types.IDictionaryT, out _))
if (!ctx.Target.ImplementsGeneric(ctx.Types.Get(typeof(IDictionary<,>)), out _))
return null;

var ensureCapacityStatement = EnsureCapacityBuilder.TryBuildEnsureCapacity(ctx.Source, ctx.Target, ctx.Types);
Expand All @@ -84,14 +84,14 @@ public static class DictionaryMappingBuilder
if (!ctx.IsConversionEnabled(MappingConversionType.Dictionary))
return null;

if (!ctx.Target.ImplementsGeneric(ctx.Types.IDictionaryT, out _))
if (!ctx.Target.ImplementsGeneric(ctx.Types.Get(typeof(IDictionary<,>)), out _))
return null;

if (BuildKeyValueMapping(ctx) is not var (keyMapping, valueMapping))
return null;

// if target is an immutable dictionary then don't create a foreach loop
if (ctx.Target.OriginalDefinition.ImplementsGeneric(ctx.Types.IImmutableDictionaryT, out _))
if (ctx.Target.OriginalDefinition.ImplementsGeneric(ctx.Types.Get(typeof(IImmutableDictionary<,>)), out _))
{
ctx.ReportDiagnostic(DiagnosticDescriptors.CannotMapToReadOnlyMember);
return null;
Expand Down Expand Up @@ -134,19 +134,19 @@ private static bool IsDictionaryType(MappingBuilderContext ctx, ITypeSymbol symb
if (symbol is not INamedTypeSymbol namedSymbol)
return false;

return SymbolEqualityComparer.Default.Equals(namedSymbol.ConstructedFrom, ctx.Types.DictionaryT)
|| SymbolEqualityComparer.Default.Equals(namedSymbol.ConstructedFrom, ctx.Types.IDictionaryT)
|| SymbolEqualityComparer.Default.Equals(namedSymbol.ConstructedFrom, ctx.Types.IReadOnlyDictionaryT);
return SymbolEqualityComparer.Default.Equals(namedSymbol.ConstructedFrom, ctx.Types.Get(typeof(Dictionary<,>)))
|| SymbolEqualityComparer.Default.Equals(namedSymbol.ConstructedFrom, ctx.Types.Get(typeof(IDictionary<,>)))
|| SymbolEqualityComparer.Default.Equals(namedSymbol.ConstructedFrom, ctx.Types.Get(typeof(IReadOnlyDictionary<,>)));
}

private static (ITypeSymbol, ITypeSymbol)? GetDictionaryKeyValueTypes(MappingBuilderContext ctx, ITypeSymbol t)
{
if (t.ImplementsGeneric(ctx.Types.IDictionaryT, out var dictionaryImpl))
if (t.ImplementsGeneric(ctx.Types.Get(typeof(IDictionary<,>)), out var dictionaryImpl))
{
return (dictionaryImpl.TypeArguments[0], dictionaryImpl.TypeArguments[1]);
}

if (t.ImplementsGeneric(ctx.Types.IReadOnlyDictionaryT, out var readOnlyDictionaryImpl))
if (t.ImplementsGeneric(ctx.Types.Get(typeof(IReadOnlyDictionary<,>)), out var readOnlyDictionaryImpl))
{
return (readOnlyDictionaryImpl.TypeArguments[0], readOnlyDictionaryImpl.TypeArguments[1]);
}
Expand All @@ -156,13 +156,13 @@ private static (ITypeSymbol, ITypeSymbol)? GetDictionaryKeyValueTypes(MappingBui

private static (ITypeSymbol, ITypeSymbol)? GetEnumerableKeyValueTypes(MappingBuilderContext ctx, ITypeSymbol t)
{
if (!t.ImplementsGeneric(ctx.Types.IEnumerableT, out var enumerableImpl))
if (!t.ImplementsGeneric(ctx.Types.Get(typeof(IEnumerable<>)), out var enumerableImpl))
return null;

if (enumerableImpl.TypeArguments[0] is not INamedTypeSymbol enumeratedType)
return null;

if (!SymbolEqualityComparer.Default.Equals(enumeratedType.ConstructedFrom, ctx.Types.KeyValuePairT))
if (!SymbolEqualityComparer.Default.Equals(enumeratedType.ConstructedFrom, ctx.Types.Get(typeof(KeyValuePair<,>))))
return null;

return (enumeratedType.TypeArguments[0], enumeratedType.TypeArguments[1]);
Expand All @@ -171,8 +171,12 @@ private static (ITypeSymbol, ITypeSymbol)? GetEnumerableKeyValueTypes(MappingBui
private static INamedTypeSymbol? GetExplicitIndexer(MappingBuilderContext ctx)
{
if (
ctx.Target.ImplementsGeneric(ctx.Types.IDictionaryT, SetterIndexerPropertyName, out var typedInter, out var isExplicit)
&& !isExplicit
ctx.Target.ImplementsGeneric(
ctx.Types.Get(typeof(IDictionary<,>)),
SetterIndexerPropertyName,
out var typedInter,
out var isExplicit
) && !isExplicit
)
return null;

Expand All @@ -185,24 +189,24 @@ private static (ITypeSymbol, ITypeSymbol)? GetEnumerableKeyValueTypes(MappingBui
ITypeMapping valueMapping
)
{
if (SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.ImmutableSortedDictionaryT))
if (SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.Get(typeof(ImmutableSortedDictionary<,>))))
return new LinqDicitonaryMapping(
ctx.Source,
ctx.Target,
ctx.Types.ImmutableSortedDictionary.GetStaticGenericMethod(ToImmutableSortedDictionaryMethodName)!,
ctx.Types.Get(typeof(ImmutableSortedDictionary)).GetStaticGenericMethod(ToImmutableSortedDictionaryMethodName)!,
keyMapping,
valueMapping
);

// if target is an ImmutableDictionary or IImmutableDictionary
if (
SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.IImmutableDictionaryT)
|| SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.ImmutableDictionaryT)
SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.Get(typeof(IImmutableDictionary<,>)))
|| SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.Get(typeof(ImmutableDictionary<,>)))
)
return new LinqDicitonaryMapping(
ctx.Source,
ctx.Target,
ctx.Types.ImmutableDictionary.GetStaticGenericMethod(ToImmutableDictionaryMethodName)!,
ctx.Types.Get(typeof(ImmutableDictionary)).GetStaticGenericMethod(ToImmutableDictionaryMethodName)!,
keyMapping,
valueMapping
);
Expand Down
Loading

0 comments on commit beb7a53

Please sign in to comment.