-
Notifications
You must be signed in to change notification settings - Fork 471
/
Copy pathDoNotCreateTasksWithoutPassingATaskScheduler.cs
97 lines (82 loc) · 5.3 KB
/
DoNotCreateTasksWithoutPassingATaskScheduler.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
// Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Immutable;
using System.Linq;
using Analyzer.Utilities;
using Analyzer.Utilities.Extensions;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Operations;
namespace Microsoft.NetCore.Analyzers.Tasks
{
    /// <summary>
    /// CA2008: Do not create tasks without passing a TaskScheduler.
    /// </summary>
    /// <remarks>
    /// Reports invocations of <c>ContinueWith</c> (declared on <c>System.Threading.Tasks.Task</c> or on a type
    /// whose direct base type is <c>Task</c>) and of <c>StartNew</c> (declared on
    /// <c>System.Threading.Tasks.TaskFactory</c>) when the selected overload has no parameter of type
    /// <c>System.Threading.Tasks.TaskScheduler</c>.
    /// </remarks>
    [DiagnosticAnalyzer(LanguageNames.CSharp, LanguageNames.VisualBasic)]
    public sealed class DoNotCreateTasksWithoutPassingATaskSchedulerAnalyzer : DiagnosticAnalyzer
    {
        /// <summary>Identifier under which the diagnostic is reported.</summary>
        internal const string RuleId = "CA2008";

        private static readonly LocalizableString s_localizableTitle = new LocalizableResourceString(nameof(SystemThreadingTasksAnalyzersResources.DoNotCreateTasksWithoutPassingATaskSchedulerTitle), SystemThreadingTasksAnalyzersResources.ResourceManager, typeof(SystemThreadingTasksAnalyzersResources));
        private static readonly LocalizableString s_localizableMessage = new LocalizableResourceString(nameof(SystemThreadingTasksAnalyzersResources.DoNotCreateTasksWithoutPassingATaskSchedulerMessage), SystemThreadingTasksAnalyzersResources.ResourceManager, typeof(SystemThreadingTasksAnalyzersResources));
        private static readonly LocalizableString s_localizableDescription = new LocalizableResourceString(nameof(SystemThreadingTasksAnalyzersResources.DoNotCreateTasksWithoutPassingATaskSchedulerDescription), SystemThreadingTasksAnalyzersResources.ResourceManager, typeof(SystemThreadingTasksAnalyzersResources));

        /// <summary>
        /// Descriptor for the CA2008 diagnostic. Declared <c>readonly</c> so the shared instance
        /// cannot be reassigned after type initialization.
        /// </summary>
        internal static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor(
            RuleId,
            s_localizableTitle,
            s_localizableMessage,
            DiagnosticCategory.Reliability,
            DiagnosticHelpers.DefaultDiagnosticSeverity,
            isEnabledByDefault: DiagnosticHelpers.EnabledByDefaultIfNotBuildingVSIX,
            description: s_localizableDescription,
            helpLinkUri: null,
            customTags: WellKnownDiagnosticTags.Telemetry);

        /// <summary>The diagnostics this analyzer can produce (only <see cref="Rule"/>).</summary>
        public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics => ImmutableArray.Create(Rule);

        /// <summary>
        /// Registers the analysis callbacks: a compilation-start action that resolves the task-related
        /// types and, when all of them are available, an operation action for every invocation.
        /// </summary>
        /// <param name="context">Analysis context supplied by the analyzer host.</param>
        public sealed override void Initialize(AnalysisContext context)
        {
            context.EnableConcurrentExecution();
            context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None);

            context.RegisterCompilationStartAction(compilationContext =>
            {
                // Check if TPL is available before actually doing the searches
                var compilation = compilationContext.Compilation;
                var taskType = compilation.GetTypeByMetadataName("System.Threading.Tasks.Task");
                var taskFactoryType = compilation.GetTypeByMetadataName("System.Threading.Tasks.TaskFactory");
                var taskSchedulerType = compilation.GetTypeByMetadataName("System.Threading.Tasks.TaskScheduler");
                if (taskType == null || taskFactoryType == null || taskSchedulerType == null)
                {
                    // Without these types there is nothing to analyze in this compilation.
                    return;
                }

                compilationContext.RegisterOperationAction(operationContext =>
                {
                    var invocation = (IInvocationOperation)operationContext.Operation;
                    var targetMethod = invocation.TargetMethod;
                    if (targetMethod == null)
                    {
                        return;
                    }

                    if (!IsMethodOfInterest(targetMethod, taskType, taskFactoryType))
                    {
                        return;
                    }

                    // We want to ensure that all overloads called are explicitly taking a task scheduler
                    if (targetMethod.Parameters.Any(p => p.Type.Equals(taskSchedulerType)))
                    {
                        return;
                    }

                    // The method name is passed as the message format argument.
                    operationContext.ReportDiagnostic(Diagnostic.Create(Rule, invocation.Syntax.GetLocation(), targetMethod.Name));
                }, OperationKind.Invocation);
            });
        }

        /// <summary>
        /// Determines whether <paramref name="methodSymbol"/> is one of the methods this rule inspects.
        /// </summary>
        /// <param name="methodSymbol">Target method of the invocation being analyzed.</param>
        /// <param name="taskType">Symbol resolved for <c>System.Threading.Tasks.Task</c>.</param>
        /// <param name="taskFactoryType">Symbol resolved for <c>System.Threading.Tasks.TaskFactory</c>.</param>
        /// <returns>
        /// <c>true</c> for a method named <c>ContinueWith</c> whose containing type is
        /// <paramref name="taskType"/> or has it as direct base type, or for a method named
        /// <c>StartNew</c> whose containing type equals <paramref name="taskFactoryType"/>;
        /// otherwise <c>false</c>.
        /// </returns>
        private static bool IsMethodOfInterest(IMethodSymbol methodSymbol, INamedTypeSymbol taskType, INamedTypeSymbol taskFactoryType)
        {
            var containingType = methodSymbol.ContainingType;
            if (containingType == null)
            {
                // A method without a containing type cannot be declared on Task or TaskFactory.
                return false;
            }

            switch (methodSymbol.Name)
            {
                case "ContinueWith":
                    // Check if it's a method of Task or a derived type (for Task<T>)
                    return taskType.Equals(containingType) ||
                           taskType.Equals(containingType.BaseType);

                case "StartNew":
                    // NOTE(review): only the non-generic TaskFactory symbol is compared here, so
                    // StartNew on the generic TaskFactory<TResult> does not appear to be matched -
                    // confirm whether that is intended.
                    return containingType.Equals(taskFactoryType);

                default:
                    return false;
            }
        }
    }
}