Skip to content

Commit

Permalink
Fix trim warnings for base method implementing interface (#104753)
Browse files Browse the repository at this point in the history
When a type provides an interface implementation via a base method,
there can be new unexpected trim analysis warnings pointing to the
base method. For example:

```csharp
interface I {
    void M();
}

class Base {
    [RequiresUnreferencedCode("Message")]
    public void M() { }
}
```

This code is fine on its own, and produces no trim analysis
warnings. Now if another piece of code has:

```csharp
class Derived : Base, I {}
```

this causes a new warning to appear at Base.M:
```csharp
interface I {
    void M();
}

class Base {
    [RequiresUnreferencedCode("Message")]
    public void M() { } // warning IL2046: Member 'Base.M()' with
                        // 'RequiresUnreferencedCodeAttribute' implements
			// interface member 'I.M()' without
			// 'RequiresUnreferencedCodeAttribute'.
			// 'RequiresUnreferencedCodeAttribute' annotations must
			// match across all interface implementations or
			// overrides.
}
```

In general, the derived class could be defined in another assembly,
leading to trim warnings that "blame" a correctly annotated
assembly. The warning should instead point to `Derived`, similar to
what happens if there's a mismatch in the base/interface signatures:

```csharp
interface I {
    void M();
}

class Base {
    public int M() { }
}

class Derived : Base, I { } // error CS0738: 'Derived' does not implement
                            // interface member 'I.M()'. 'Base.M()' cannot
			    // implement 'I.M()' because it does not have the
			    // matching return type of 'void'.
```

This fixes the problematic behavior by producing the warning from the
derived type instead, in ILLink, Native AOT, and the ILLink Roslyn
analyzer.
  • Loading branch information
sbomer committed Jul 15, 2024
1 parent e48c2b2 commit 961dfb5
Show file tree
Hide file tree
Showing 11 changed files with 445 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ private static bool ScanMethodBodyForFieldAccess(MethodIL body, bool write, out
}
}

internal void ValidateMethodAnnotationsAreSame(MethodDesc method, MethodDesc baseMethod)
internal void ValidateMethodAnnotationsAreSame(MethodDesc method, MethodDesc baseMethod, TypeSystemEntity origin)
{
method = method.GetTypicalMethodDefinition();
baseMethod = baseMethod.GetTypicalMethodDefinition();
Expand All @@ -715,14 +715,14 @@ internal void ValidateMethodAnnotationsAreSame(MethodDesc method, MethodDesc bas
GetAnnotations(baseMethod.OwningType).TryGetAnnotation(baseMethod, out var baseMethodAnnotations);

if (methodAnnotations.ReturnParameterAnnotation != baseMethodAnnotations.ReturnParameterAnnotation)
LogValidationWarning(method.Signature.ReturnType, baseMethod, method);
LogValidationWarning((method.Signature.ReturnType, method), baseMethod, origin);

if (methodAnnotations.ParameterAnnotations != null || baseMethodAnnotations.ParameterAnnotations != null)
{
if (methodAnnotations.ParameterAnnotations == null)
ValidateMethodParametersHaveNoAnnotations(baseMethodAnnotations.ParameterAnnotations!, method, baseMethod, method);
ValidateMethodParametersHaveNoAnnotations(baseMethodAnnotations.ParameterAnnotations!, method, baseMethod, origin);
else if (baseMethodAnnotations.ParameterAnnotations == null)
ValidateMethodParametersHaveNoAnnotations(methodAnnotations.ParameterAnnotations, method, baseMethod, method);
ValidateMethodParametersHaveNoAnnotations(methodAnnotations.ParameterAnnotations, method, baseMethod, origin);
else
{
if (methodAnnotations.ParameterAnnotations.Length != baseMethodAnnotations.ParameterAnnotations.Length)
Expand All @@ -734,17 +734,17 @@ internal void ValidateMethodAnnotationsAreSame(MethodDesc method, MethodDesc bas
LogValidationWarning(
(new MethodProxy(method)).GetParameter((ParameterIndex)parameterIndex),
(new MethodProxy(baseMethod)).GetParameter((ParameterIndex)parameterIndex),
method);
origin);
}
}
}

if (methodAnnotations.GenericParameterAnnotations != null || baseMethodAnnotations.GenericParameterAnnotations != null)
{
if (methodAnnotations.GenericParameterAnnotations == null)
ValidateMethodGenericParametersHaveNoAnnotations(baseMethodAnnotations.GenericParameterAnnotations!, method, baseMethod, method);
ValidateMethodGenericParametersHaveNoAnnotations(baseMethodAnnotations.GenericParameterAnnotations!, method, baseMethod, origin);
else if (baseMethodAnnotations.GenericParameterAnnotations == null)
ValidateMethodGenericParametersHaveNoAnnotations(methodAnnotations.GenericParameterAnnotations, method, baseMethod, method);
ValidateMethodGenericParametersHaveNoAnnotations(methodAnnotations.GenericParameterAnnotations, method, baseMethod, origin);
else
{
if (methodAnnotations.GenericParameterAnnotations.Length != baseMethodAnnotations.GenericParameterAnnotations.Length)
Expand All @@ -757,14 +757,14 @@ internal void ValidateMethodAnnotationsAreSame(MethodDesc method, MethodDesc bas
LogValidationWarning(
method.Instantiation[genericParameterIndex],
baseMethod.Instantiation[genericParameterIndex],
method);
origin);
}
}
}
}
}

private void ValidateMethodParametersHaveNoAnnotations(DynamicallyAccessedMemberTypes[] parameterAnnotations, MethodDesc method, MethodDesc baseMethod, MethodDesc origin)
private void ValidateMethodParametersHaveNoAnnotations(DynamicallyAccessedMemberTypes[] parameterAnnotations, MethodDesc method, MethodDesc baseMethod, TypeSystemEntity origin)
{
for (int parameterIndex = 0; parameterIndex < parameterAnnotations.Length; parameterIndex++)
{
Expand All @@ -777,7 +777,7 @@ private void ValidateMethodParametersHaveNoAnnotations(DynamicallyAccessedMember
}
}

private void ValidateMethodGenericParametersHaveNoAnnotations(DynamicallyAccessedMemberTypes[] genericParameterAnnotations, MethodDesc method, MethodDesc baseMethod, MethodDesc origin)
private void ValidateMethodGenericParametersHaveNoAnnotations(DynamicallyAccessedMemberTypes[] genericParameterAnnotations, MethodDesc method, MethodDesc baseMethod, TypeSystemEntity origin)
{
for (int genericParameterIndex = 0; genericParameterIndex < genericParameterAnnotations.Length; genericParameterIndex++)
{
Expand All @@ -791,7 +791,7 @@ private void ValidateMethodGenericParametersHaveNoAnnotations(DynamicallyAccesse
}
}

private void LogValidationWarning(object provider, object baseProvider, MethodDesc origin)
private void LogValidationWarning(object provider, object baseProvider, TypeSystemEntity origin)
{
switch (provider)
{
Expand All @@ -810,9 +810,9 @@ private void LogValidationWarning(object provider, object baseProvider, MethodDe
genericParameterOverride.Name, DiagnosticUtilities.GetGenericParameterDeclaringMemberDisplayName(genericParameterOverride),
((GenericParameterDesc)baseProvider).Name, DiagnosticUtilities.GetGenericParameterDeclaringMemberDisplayName((GenericParameterDesc)baseProvider));
break;
case TypeDesc:
case (TypeDesc, MethodDesc method):
_logger.LogWarning(origin, DiagnosticId.DynamicallyAccessedMembersMismatchOnMethodReturnValueBetweenOverrides,
DiagnosticUtilities.GetMethodSignatureDisplayName(origin), DiagnosticUtilities.GetMethodSignatureDisplayName((MethodDesc)baseProvider));
DiagnosticUtilities.GetMethodSignatureDisplayName(method), DiagnosticUtilities.GetMethodSignatureDisplayName((MethodDesc)baseProvider));
break;
// No fields - it's not possible to have a virtual field and override it
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,8 @@ public sealed override IEnumerable<CombinedDependencyListEntry> GetConditionalSt
result.Add(new CombinedDependencyListEntry(factory.VirtualMethodUse(interfaceMethod), factory.VariantInterfaceMethodUse(typicalInterfaceMethod), "Interface method"));
}

factory.MetadataManager.NoteOverridingMethod(interfaceMethod, implMethod);
TypeSystemEntity origin = (implMethod.OwningType != defType) ? defType : null;
factory.MetadataManager.NoteOverridingMethod(interfaceMethod, implMethod, origin);

factory.MetadataManager.GetDependenciesForOverridingMethod(ref result, factory, interfaceMethod, implMethod);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@ public override IEnumerable<CombinedDependencyListEntry> SearchDynamicDependenci
else
dynamicDependencies.Add(new CombinedDependencyListEntry(factory.GVMDependencies(implementingMethodInstantiation.GetCanonMethodTarget(CanonicalFormKind.Specific)), null, "ImplementingMethodInstantiation"));

factory.MetadataManager.NoteOverridingMethod(_method, implementingMethodInstantiation);
TypeSystemEntity origin = (implementingMethodInstantiation.OwningType != potentialOverrideType) ? potentialOverrideType : null;
factory.MetadataManager.NoteOverridingMethod(_method, implementingMethodInstantiation, origin);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1186,7 +1186,7 @@ public virtual DependencyList GetDependenciesForCustomAttribute(NodeFactory fact
return null;
}

public virtual void NoteOverridingMethod(MethodDesc baseMethod, MethodDesc overridingMethod)
public virtual void NoteOverridingMethod(MethodDesc baseMethod, MethodDesc overridingMethod, TypeSystemEntity origin = null)
{
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -784,14 +784,16 @@ public bool GeneratesAttributeMetadata(TypeDesc attributeType)
return true;
}

public override void NoteOverridingMethod(MethodDesc baseMethod, MethodDesc overridingMethod)
public override void NoteOverridingMethod(MethodDesc baseMethod, MethodDesc overridingMethod, TypeSystemEntity origin)
{
baseMethod = baseMethod.GetTypicalMethodDefinition();
overridingMethod = overridingMethod.GetTypicalMethodDefinition();

if (baseMethod == overridingMethod)
return;

origin ??= overridingMethod;

bool baseMethodTypeIsInterface = baseMethod.OwningType.IsInterface;
foreach (var requiresAttribute in _requiresAttributeMismatchNameAndId)
{
Expand All @@ -803,15 +805,15 @@ public override void NoteOverridingMethod(MethodDesc baseMethod, MethodDesc over
string message = MessageFormat.FormatRequiresAttributeMismatch(overridingMethod.DoesMethodRequire(requiresAttribute.AttributeName, out _),
baseMethodTypeIsInterface, requiresAttribute.AttributeName, overridingMethodName, baseMethodName);

Logger.LogWarning(overridingMethod, requiresAttribute.Id, message);
Logger.LogWarning(origin, requiresAttribute.Id, message);
}
}

bool baseMethodRequiresDataflow = FlowAnnotations.RequiresVirtualMethodDataflowAnalysis(baseMethod);
bool overridingMethodRequiresDataflow = FlowAnnotations.RequiresVirtualMethodDataflowAnalysis(overridingMethod);
if (baseMethodRequiresDataflow || overridingMethodRequiresDataflow)
{
FlowAnnotations.ValidateMethodAnnotationsAreSame(overridingMethod, baseMethod);
FlowAnnotations.ValidateMethodAnnotationsAreSame(overridingMethod, baseMethod, origin);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,12 @@ void AddRange (DiagnosticId first, DiagnosticId last)

public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics => GetSupportedDiagnostics ();

static Location GetPrimaryLocation (ImmutableArray<Location> locations) => locations.Length > 0 ? locations[0] : Location.None;
static Location GetPrimaryLocation (ImmutableArray<Location>? locations) {
if (locations is null)
return Location.None;

return locations.Value.Length > 0 ? locations.Value[0] : Location.None;
}

public override void Initialize (AnalysisContext context)
{
Expand Down Expand Up @@ -167,7 +172,7 @@ static void VerifyDamOnDerivedAndBaseMethodsMatch (SymbolAnalysisContext context
}
}

static void VerifyDamOnMethodsMatch (SymbolAnalysisContext context, IMethodSymbol overrideMethod, IMethodSymbol baseMethod)
static void VerifyDamOnMethodsMatch (SymbolAnalysisContext context, IMethodSymbol overrideMethod, IMethodSymbol baseMethod, ISymbol? origin = null)
{
var overrideMethodReturnAnnotation = FlowAnnotations.GetMethodReturnValueAnnotation (overrideMethod);
var baseMethodReturnAnnotation = FlowAnnotations.GetMethodReturnValueAnnotation (baseMethod);
Expand All @@ -184,9 +189,10 @@ static void VerifyDamOnMethodsMatch (SymbolAnalysisContext context, IMethodSymbo
&& baseMethod.TryGetReturnAttribute (DynamicallyAccessedMembersAnalyzer.DynamicallyAccessedMembersAttribute, out var _))
) ? (null, null) : CreateArguments (attributableSymbolLocation, missingAttribute);

var returnOrigin = origin ??= overrideMethod;
context.ReportDiagnostic (Diagnostic.Create (
DiagnosticDescriptors.GetDiagnosticDescriptor (DiagnosticId.DynamicallyAccessedMembersMismatchOnMethodReturnValueBetweenOverrides),
GetPrimaryLocation (overrideMethod.Locations), sourceLocation, DAMArgs?.ToImmutableDictionary (), overrideMethod.GetDisplayName (), baseMethod.GetDisplayName ()));
GetPrimaryLocation (returnOrigin.Locations), sourceLocation, DAMArgs?.ToImmutableDictionary (), overrideMethod.GetDisplayName (), baseMethod.GetDisplayName ()));
}

foreach (var overrideParam in overrideMethod.GetMetadataParameters ()) {
Expand All @@ -205,9 +211,10 @@ static void VerifyDamOnMethodsMatch (SymbolAnalysisContext context, IMethodSymbo
&& baseParam.ParameterSymbol!.TryGetAttribute (DynamicallyAccessedMembersAnalyzer.DynamicallyAccessedMembersAttribute, out var _))
) ? (null, null) : CreateArguments (attributableSymbolLocation, missingAttribute);

var parameterOrigin = origin ?? overrideParam.ParameterSymbol;
context.ReportDiagnostic (Diagnostic.Create (
DiagnosticDescriptors.GetDiagnosticDescriptor (DiagnosticId.DynamicallyAccessedMembersMismatchOnMethodParameterBetweenOverrides),
overrideParam.Location, sourceLocation, DAMArgs?.ToImmutableDictionary (),
GetPrimaryLocation (parameterOrigin?.Locations), sourceLocation, DAMArgs?.ToImmutableDictionary (),
overrideParam.GetDisplayName (), overrideMethod.GetDisplayName (), baseParam.GetDisplayName (), baseMethod.GetDisplayName ()));
}
}
Expand All @@ -228,27 +235,38 @@ static void VerifyDamOnMethodsMatch (SymbolAnalysisContext context, IMethodSymbo
&& baseMethod.TypeParameters[i].TryGetAttribute (DynamicallyAccessedMembersAnalyzer.DynamicallyAccessedMembersAttribute, out var _))
) ? (null, null) : CreateArguments (attributableSymbolLocation, missingAttribute);

var typeParameterOrigin = origin ?? overrideMethod.TypeParameters[i];
context.ReportDiagnostic (Diagnostic.Create (
DiagnosticDescriptors.GetDiagnosticDescriptor (DiagnosticId.DynamicallyAccessedMembersMismatchOnGenericParameterBetweenOverrides),
GetPrimaryLocation (overrideMethod.TypeParameters[i].Locations), sourceLocation, DAMArgs?.ToImmutableDictionary (),
GetPrimaryLocation (typeParameterOrigin.Locations), sourceLocation, DAMArgs?.ToImmutableDictionary (),
overrideMethod.TypeParameters[i].GetDisplayName (), overrideMethod.GetDisplayName (),
baseMethod.TypeParameters[i].GetDisplayName (), baseMethod.GetDisplayName ()));
}
}

if (!overrideMethod.IsStatic && overrideMethod.GetDynamicallyAccessedMemberTypes () != baseMethod.GetDynamicallyAccessedMemberTypes ())
if (!overrideMethod.IsStatic && overrideMethod.GetDynamicallyAccessedMemberTypes () != baseMethod.GetDynamicallyAccessedMemberTypes ()) {
var methodOrigin = origin ?? overrideMethod;
context.ReportDiagnostic (Diagnostic.Create (
DiagnosticDescriptors.GetDiagnosticDescriptor (DiagnosticId.DynamicallyAccessedMembersMismatchOnImplicitThisBetweenOverrides),
GetPrimaryLocation (overrideMethod.Locations),
GetPrimaryLocation (methodOrigin.Locations),
overrideMethod.GetDisplayName (), baseMethod.GetDisplayName ()));
}
}

static void VerifyDamOnInterfaceAndImplementationMethodsMatch (SymbolAnalysisContext context, INamedTypeSymbol type)
{
foreach (var (interfaceMember, implementationMember) in type.GetMemberInterfaceImplementationPairs ()) {
if (implementationMember is IMethodSymbol implementationMethod
&& interfaceMember is IMethodSymbol interfaceMethod)
VerifyDamOnMethodsMatch (context, implementationMethod, interfaceMethod);
if (implementationMember is IMethodSymbol implementationMethod && interfaceMember is IMethodSymbol interfaceMethod) {
ISymbol origin = implementationMethod;
INamedTypeSymbol implementationType = implementationMethod.ContainingType;

// If this type implements an interface method through a base class, the origin of the warning is this type,
// not the member on the base class.
if (!implementationType.IsInterface () && !SymbolEqualityComparer.Default.Equals (implementationType, type))
origin = type;

VerifyDamOnMethodsMatch (context, implementationMethod, interfaceMethod, origin);
}
}
}

Expand Down
20 changes: 17 additions & 3 deletions src/tools/illink/src/ILLink.RoslynAnalyzer/RequiresAnalyzerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,21 @@ void CheckMatchingAttributesInInterfaces (
INamedTypeSymbol type)
{
foreach (var memberpair in type.GetMemberInterfaceImplementationPairs ()) {
var implementationType = memberpair.ImplementationMember switch {
IMethodSymbol method => method.ContainingType,
IPropertySymbol property => property.ContainingType,
IEventSymbol @event => @event.ContainingType,
_ => throw new NotSupportedException ()
};
ISymbol origin = memberpair.ImplementationMember;
// If this type implements an interface method through a base class, the origin of the warning is this type,
// not the member on the base class.
if (!implementationType.IsInterface () && !SymbolEqualityComparer.Default.Equals (implementationType, type))
origin = type;
if (HasMismatchingAttributes (memberpair.InterfaceMember, memberpair.ImplementationMember)) {
ReportMismatchInAttributesDiagnostic (symbolAnalysisContext, memberpair.ImplementationMember, memberpair.InterfaceMember, isInterface: true);
ReportMismatchInAttributesDiagnostic (symbolAnalysisContext, memberpair.ImplementationMember, memberpair.InterfaceMember, isInterface: true, origin);
}
}
}
Expand Down Expand Up @@ -230,12 +243,13 @@ private void ReportRequiresOnStaticCtorDiagnostic (SymbolAnalysisContext symbolA
ctor.GetDisplayName ()));
}

private void ReportMismatchInAttributesDiagnostic (SymbolAnalysisContext symbolAnalysisContext, ISymbol member, ISymbol baseMember, bool isInterface = false)
private void ReportMismatchInAttributesDiagnostic (SymbolAnalysisContext symbolAnalysisContext, ISymbol member, ISymbol baseMember, bool isInterface = false, ISymbol? origin = null)
{
origin ??= member;
string message = MessageFormat.FormatRequiresAttributeMismatch (member.HasAttribute (RequiresAttributeName), isInterface, RequiresAttributeName, member.GetDisplayName (), baseMember.GetDisplayName ());
symbolAnalysisContext.ReportDiagnostic (Diagnostic.Create (
RequiresAttributeMismatch,
member.Locations[0],
origin.Locations[0],
message));
}

Expand Down
Loading

0 comments on commit 961dfb5

Please sign in to comment.