diff --git a/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/Analyzers/CustomMarshallerAttributeAnalyzer.cs b/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/Analyzers/CustomMarshallerAttributeAnalyzer.cs index 9e9f732c3973c6..b3a4f8e424e25a 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/Analyzers/CustomMarshallerAttributeAnalyzer.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/Analyzers/CustomMarshallerAttributeAnalyzer.cs @@ -262,8 +262,8 @@ public void AnalyzeAttribute(SyntaxNodeAnalysisContext context) AttributeSyntax syntax = (AttributeSyntax)context.Node; ISymbol attributedSymbol = context.ContainingSymbol!; - AttributeData attr = GetAttributeData(syntax, attributedSymbol); - if (attr.AttributeClass?.ToDisplayString() == TypeNames.CustomMarshallerAttribute + AttributeData? attr = syntax.FindAttributeData(attributedSymbol); + if (attr?.AttributeClass?.ToDisplayString() == TypeNames.CustomMarshallerAttribute && attr.AttributeConstructor is not null) { DiagnosticReporter managedTypeReporter = DiagnosticReporter.CreateForLocation(syntax.FindArgumentWithNameOrArity("managedType", 0).FindTypeExpressionOrNullLocation(), context.ReportDiagnostic); @@ -313,20 +313,6 @@ private void AnalyzeMarshallerType(DiagnosticReporter diagnosticFactory, INamedT { // TODO: Implement for the V2 shapes } - - private static AttributeData GetAttributeData(AttributeSyntax syntax, ISymbol symbol) - { - if (syntax.FirstAncestorOrSelf().Target?.Identifier.IsKind(SyntaxKind.ReturnKeyword) == true) - { - return ((IMethodSymbol)symbol).GetReturnTypeAttributes().First(attributeSyntaxLocationMatches); - } - return symbol.GetAttributes().First(attributeSyntaxLocationMatches); - - bool attributeSyntaxLocationMatches(AttributeData attrData) - { - return attrData.ApplicationSyntaxReference!.SyntaxTree == syntax.SyntaxTree && attrData.ApplicationSyntaxReference.Span == syntax.Span; - } - } } } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/Analyzers/NativeMarshallingAttributeAnalyzer.cs b/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/Analyzers/NativeMarshallingAttributeAnalyzer.cs index b17e43a1b05ec6..bba752c7d0de55 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/Analyzers/NativeMarshallingAttributeAnalyzer.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/Analyzers/NativeMarshallingAttributeAnalyzer.cs @@ -86,8 +86,8 @@ public void AnalyzeAttribute(SyntaxNodeAnalysisContext context) AttributeSyntax syntax = (AttributeSyntax)context.Node; ISymbol attributedSymbol = context.ContainingSymbol!; - AttributeData attr = GetAttributeData(syntax, attributedSymbol); - if (attr.AttributeClass?.ToDisplayString() == TypeNames.NativeMarshallingAttribute + AttributeData? attr = syntax.FindAttributeData(attributedSymbol); + if (attr?.AttributeClass?.ToDisplayString() == TypeNames.NativeMarshallingAttribute && attr.AttributeConstructor is not null) { INamedTypeSymbol? entryType = (INamedTypeSymbol?)attr.ConstructorArguments[0].Value; @@ -163,20 +163,6 @@ private void AnalyzeManagedTypeMarshallingInfo( } } - private static AttributeData GetAttributeData(AttributeSyntax syntax, ISymbol symbol) - { - if (syntax.FirstAncestorOrSelf().Target?.Identifier.IsKind(SyntaxKind.ReturnKeyword) == true) - { - return ((IMethodSymbol)symbol).GetReturnTypeAttributes().First(attributeSyntaxLocationMatches); - } - return symbol.GetAttributes().First(attributeSyntaxLocationMatches); - - bool attributeSyntaxLocationMatches(AttributeData attrData) - { - return attrData.ApplicationSyntaxReference!.SyntaxTree == syntax.SyntaxTree && attrData.ApplicationSyntaxReference.Span == syntax.Span; - } - } - private static ITypeSymbol GetSymbolType(ISymbol symbol) { return symbol switch diff --git a/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/Analyzers/SyntaxExtensions.cs b/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/Analyzers/SyntaxExtensions.cs index 09de03c2aa1098..00f3d93797e280 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/Analyzers/SyntaxExtensions.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/Analyzers/SyntaxExtensions.cs @@ -25,6 +25,31 @@ public static Location FindTypeExpressionOrNullLocation(this AttributeArgumentSy return walker.TypeExpressionLocation; } + public static AttributeData? FindAttributeData(this AttributeSyntax syntax, ISymbol targetSymbol) + { + AttributeTargetSpecifierSyntax attributeTarget = syntax.FirstAncestorOrSelf().Target; + if (attributeTarget is not null) + { + switch (attributeTarget.Identifier.Kind()) + { + case SyntaxKind.ReturnKeyword: + return ((IMethodSymbol)targetSymbol).GetReturnTypeAttributes().First(attributeSyntaxLocationMatches); + case SyntaxKind.AssemblyKeyword: + return targetSymbol.ContainingAssembly.GetAttributes().First(attributeSyntaxLocationMatches); + case SyntaxKind.ModuleKeyword: + return targetSymbol.ContainingModule.GetAttributes().First(attributeSyntaxLocationMatches); + default: + return null; + } + } + return targetSymbol.GetAttributes().First(attributeSyntaxLocationMatches); + + bool attributeSyntaxLocationMatches(AttributeData attrData) + { + return attrData.ApplicationSyntaxReference!.SyntaxTree == syntax.SyntaxTree && attrData.ApplicationSyntaxReference.Span == syntax.Span; + } + } + private sealed class FindTypeLocationWalker : CSharpSyntaxWalker { public Location? TypeExpressionLocation { get; private set; } diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/NativeMarshallingAttributeAnalyzerTests.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/NativeMarshallingAttributeAnalyzerTests.cs index dd421e74513f9d..49587ad819a574 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/NativeMarshallingAttributeAnalyzerTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/NativeMarshallingAttributeAnalyzerTests.cs @@ -256,5 +256,19 @@ static unsafe class MarshallerType where W : unmanaged await VerifyCS.VerifyAnalyzerAsync(source, VerifyCS.Diagnostic(GenericEntryPointMarshallerTypeMustBeClosedOrMatchArityRule).WithLocation(0).WithArguments("MarshallerType", "ManagedType")); } + + [Fact] + public async Task UnrelatedAssemblyOrModuleTargetDiagnostic_DoesNotCauseException() + { + string source = """ + using System.Reflection; + using System.Runtime.CompilerServices; + + [assembly:AssemblyMetadata("MyKey", "MyValue")] + [module:SkipLocalsInit] + """; + + await VerifyCS.VerifyAnalyzerAsync(source); + } } }