Skip to content

Commit

Permalink
Merge pull request #34 from SteveDunn/33-disallow-new-default-from-la…
Browse files Browse the repository at this point in the history
…mbdas

Fixes #33 disallow new and default from lambda returns
  • Loading branch information
SteveDunn authored Dec 31, 2021
2 parents 6d996a4 + 3c94e4f commit 1f57dfa
Show file tree
Hide file tree
Showing 8 changed files with 310 additions and 43 deletions.
67 changes: 55 additions & 12 deletions src/Vogen/CreationUsingDefaultLiteralAnalyzer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ namespace Vogen;
[Generator]
public class CreationUsingDefaultLiteralAnalyzer : IIncrementalGenerator
{
public record struct FoundItem(Location Location, INamedTypeSymbol VoClass);
private record struct FoundItem(Location Location, INamedTypeSymbol VoClass);

public void Initialize(IncrementalGeneratorInitializationContext context)
{
IncrementalValuesProvider<FoundItem?> targets = GetTargets(context);

IncrementalValueProvider<(Compilation, ImmutableArray<FoundItem?>)> compilationAndTypes
= context.CompilationProvider.Combine(targets.Collect());

context.RegisterSourceOutput(compilationAndTypes,
static (spc, source) => Execute(source.Item2, spc));
}
Expand All @@ -42,22 +42,37 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
return null;
}

var typeSyntax = GetTypeFromVariableOrParameter(literalExpressionSyntax);
if (typeSyntax is null)
var typeSyntax = TryGetTypeFromVariableOrParameter(ctx, literalExpressionSyntax);

if (typeSyntax is not null)
{
return null;
var classFromSyntax = VoFilter.TryGetValueObjectClass(ctx, typeSyntax);

if (classFromSyntax is not null)
{
return new FoundItem(typeSyntax.GetLocation(), classFromSyntax);
}
}

INamedTypeSymbol? voClass = VoFilter.TryGetValueObjectClass(ctx, typeSyntax);

return voClass is null ? null : new FoundItem(typeSyntax.GetLocation(), voClass);
INamedTypeSymbol? classFromModel = TryGetTypeFromModel(ctx, literalExpressionSyntax);

return classFromModel is null ? null : new FoundItem(literalExpressionSyntax.GetLocation(), classFromModel);
}

private static INamedTypeSymbol? TryGetTypeFromModel(GeneratorSyntaxContext ctx, LiteralExpressionSyntax literalExpressionSyntax)
{
// for lambdas, we need the semantic model...
var voClass = TryGetFromLambda(ctx, literalExpressionSyntax);
return voClass;

}

// A default literal expression can be for a variable (CustomerId id = default), or
// a parameter (void DoSomething(CustomerId id = default)).
// We need to try to find the 'Type' from either one of those type.
private static TypeSyntax? GetTypeFromVariableOrParameter(LiteralExpressionSyntax literalExpressionSyntax)
private static TypeSyntax? TryGetTypeFromVariableOrParameter(
GeneratorSyntaxContext ctx,
LiteralExpressionSyntax literalExpressionSyntax)
{
// first, see if it's an array
var ancestor = literalExpressionSyntax.Ancestors(false)
Expand All @@ -84,7 +99,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
return methodSyntax.ReturnType;
}


ancestor = literalExpressionSyntax.Ancestors(false)
.FirstOrDefault(a => a.IsKind(SyntaxKind.VariableDeclaration));

Expand All @@ -104,6 +118,35 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
return null;
}

private static INamedTypeSymbol? TryGetFromLambda(
GeneratorSyntaxContext ctx,
SyntaxNode literalExpressionSyntax)
{
var ancestor = literalExpressionSyntax.Ancestors(false)
.FirstOrDefault(a => a.IsKind(SyntaxKind.ParenthesizedLambdaExpression));

if (ancestor is not ParenthesizedLambdaExpressionSyntax lambdaExpressionSyntax)
{
return null;
}

var info = ctx.SemanticModel.GetSymbolInfo(lambdaExpressionSyntax);

if (info.Symbol is not IMethodSymbol ms)
{
return null;
}

var returnTypeSymbol = ms.ReturnType as INamedTypeSymbol;

if (VoFilter.TryGetValueObjectClass(ctx, returnTypeSymbol))
{
return returnTypeSymbol;
}

return null;
}

static void Execute(ImmutableArray<FoundItem?> typeDeclarations, SourceProductionContext context)
{
foreach (FoundItem? eachFoundItem in typeDeclarations)
Expand Down
65 changes: 59 additions & 6 deletions src/Vogen/CreationUsingImplicitNewAnalyzer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace Vogen;
[Generator]
public class CreationUsingImplicitNewAnalyzer : IIncrementalGenerator
{
public record struct FoundItem(Location Location, INamedTypeSymbol VoClass);
private record struct FoundItem(Location Location, INamedTypeSymbol VoClass);

public void Initialize(IncrementalGeneratorInitializationContext context)
{
Expand All @@ -24,7 +24,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
= context.CompilationProvider.Combine(targets.Collect());

context.RegisterSourceOutput(compilationAndTypes,
static (spc, source) => Execute(source.Item1, source.Item2, spc));
static (spc, source) => Execute(source.Item2, spc));
}

private static IncrementalValuesProvider<FoundItem?> GetTargets(IncrementalGeneratorInitializationContext context) =>
Expand All @@ -36,13 +36,67 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
private static FoundItem? TryGetTarget(GeneratorSyntaxContext ctx)
{
var syntax = (ImplicitObjectCreationExpressionSyntax) ctx.Node;

var foundItem = TryGetTargetFromSyntax(ctx, syntax);

if (foundItem is not null)
{
return foundItem;
}

INamedTypeSymbol? voClass = TryGetTypeFromModel(ctx, syntax);


return voClass is null ? null : new FoundItem(syntax.GetLocation(), voClass);
}

private static INamedTypeSymbol? TryGetTypeFromModel(GeneratorSyntaxContext ctx, ImplicitObjectCreationExpressionSyntax implicitNewSyntax)
{
// for lambdas, we need the semantic model...
var voClass = TryGetFromLambda(ctx, implicitNewSyntax);
return voClass;
}

private static INamedTypeSymbol? TryGetFromLambda(
GeneratorSyntaxContext ctx,
ImplicitObjectCreationExpressionSyntax implicitNewSyntax)
{
var ancestor = implicitNewSyntax.Ancestors(false)
.FirstOrDefault(a => a.IsKind(SyntaxKind.ParenthesizedLambdaExpression));

if (ancestor is not ParenthesizedLambdaExpressionSyntax lambdaExpressionSyntax)
{
return null;
}

var info = ctx.SemanticModel.GetSymbolInfo(lambdaExpressionSyntax);

if (info.Symbol is not IMethodSymbol ms)
{
return null;
}

var returnTypeSymbol = ms.ReturnType as INamedTypeSymbol;

if(VoFilter.TryGetValueObjectClass(ctx, returnTypeSymbol))
{
return returnTypeSymbol;
}

return null;
}

private static FoundItem? TryGetTargetFromSyntax(
GeneratorSyntaxContext ctx,
SyntaxNode syntax)
{
var ancestor = syntax.Ancestors(false)
.FirstOrDefault(a => a.IsKind(SyntaxKind.VariableDeclaration));

if (ancestor is VariableDeclarationSyntax variableDeclarationSyntax)
{
TypeSyntax t = variableDeclarationSyntax.Type;

INamedTypeSymbol? voClass = VoFilter.TryGetValueObjectClass(ctx, t);

return voClass == null ? null : new FoundItem
Expand All @@ -58,6 +112,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
if (ancestor is MethodDeclarationSyntax methodSyntax)
{
TypeSyntax t = methodSyntax.ReturnType;

INamedTypeSymbol? voClass = VoFilter.TryGetValueObjectClass(ctx, t);

return voClass == null ? null : new FoundItem
Expand All @@ -73,6 +128,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
if (ancestor is LocalFunctionStatementSyntax localFunctionStatementSyntax)
{
TypeSyntax t = localFunctionStatementSyntax.ReturnType;

INamedTypeSymbol? voClass = VoFilter.TryGetValueObjectClass(ctx, t);

return voClass == null ? null : new FoundItem
Expand All @@ -85,10 +141,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
return null;
}

static void Execute(
Compilation _,
ImmutableArray<FoundItem?> typeDeclarations,
SourceProductionContext context)
static void Execute(ImmutableArray<FoundItem?> typeDeclarations, SourceProductionContext context)
{
foreach (FoundItem? eachFoundItem in typeDeclarations)
{
Expand Down
19 changes: 12 additions & 7 deletions src/Vogen/VoFilter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,27 +49,32 @@ public static bool IsTarget(SyntaxNode syntaxNode) =>
return null;
}

public static INamedTypeSymbol? TryGetValueObjectClass(GeneratorSyntaxContext context, SyntaxNode node)
public static INamedTypeSymbol? TryGetValueObjectClass(GeneratorSyntaxContext context, SyntaxNode syntaxNode)
{
SymbolInfo typeSymbolInfo = context.SemanticModel.GetSymbolInfo(node);
SymbolInfo typeSymbolInfo = context.SemanticModel.GetSymbolInfo(syntaxNode);

var voClass = typeSymbolInfo.Symbol as INamedTypeSymbol;
var symbol = typeSymbolInfo.Symbol as INamedTypeSymbol;

return TryGetValueObjectClass(context, symbol) ? symbol : null;
}

public static bool TryGetValueObjectClass(GeneratorSyntaxContext context, INamedTypeSymbol? voClass)
{
if (voClass == null)
{
return null;
return false;
}

var attributes = voClass.GetAttributes();

if (attributes.Length == 0)
{
return null;
return false;
}

AttributeData? voAttribute =
attributes.SingleOrDefault(a => a.AttributeClass?.FullName() is "Vogen.ValueObjectAttribute");

return voAttribute is null ? null : voClass;
return voAttribute is not null;
}
}
1 change: 0 additions & 1 deletion src/Vogen/VoWorkItem.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,4 @@ public class VoWorkItem
public List<InstanceProperties> InstanceProperties { get; set; } = new();

public string FullNamespace { get; set; } = string.Empty;
public INamedTypeSymbol? ContainingType { get; set; }
}
10 changes: 0 additions & 10 deletions src/Vogen/WriteWorkItems.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,4 @@ private static IGenerateSourceCode GetGenerator(VoWorkItem item) =>
StructDeclarationSyntax => _structGeneratorForValueAndReferenceTypes,
_ => throw new InvalidOperationException("Don't know how to get the generator!")
};

private static void ReportErrors(SourceProductionContext context,
DiagnosticCollection syntaxReceiverDiagnosticMessages)
{
foreach (var eachDiag in syntaxReceiverDiagnosticMessages)
{
context.ReportDiagnostic(eachDiag);
}
}

}
5 changes: 2 additions & 3 deletions tests/Testbench/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
* to analyze and generate types for.
*/

using System;
using Vogen;

var c = GetCustomer();
CustomerId GetCustomer() => CustomerId.From(123);
// Task<CustomerId> t = Task.FromResult<CustomerId>(new());

Console.ReadLine();

Expand All @@ -19,4 +19,3 @@ public partial struct CustomerId { }




Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using System.Linq;
using System.Threading.Tasks;
using FluentAssertions;
using Microsoft.CodeAnalysis;
using VerifyXunit;
Expand All @@ -9,11 +8,11 @@
namespace Vogen.IntegrationTests.SnapshotTests;

[UsesVerify]
public class DefaultingTests
public class DisallowDefaultingTests
{
private readonly ITestOutputHelper _output;

public DefaultingTests(ITestOutputHelper output) => _output = output;
public DisallowDefaultingTests(ITestOutputHelper output) => _output = output;

[Fact]
public void Disallows_default_parameters()
Expand Down Expand Up @@ -140,4 +139,49 @@ class Foo {
diagnostic.ToString().Should().Be("(8,19): error VOG009: Type 'CustomerId' cannot be constructed with default as it is prohibited.");
}

[Fact]
public void Disallows_default_literal_from_func()
{
// The source code to test
var source = @"using System;
using Vogen;
Func<CustomerId> f = () => default;
[ValueObject(typeof(int))]
public partial struct CustomerId { }
";

var (diagnostics, output) = TestHelper.GetGeneratedOutput<CreationUsingDefaultLiteralAnalyzer>(source);

_output.WriteLine(output);

diagnostics.Should().HaveCount(1);
Diagnostic diagnostic = diagnostics.Single();

diagnostic.Id.Should().Be("VOG009");
diagnostic.ToString().Should().Be("(3,28): error VOG009: Type 'CustomerId' cannot be constructed with default as it is prohibited.");
}

[Fact]
public void Disallows_default_from_func()
{
// The source code to test
var source = @"using Vogen;
Func<CustomerId> f = () => default(CustomerId);
[ValueObject(typeof(int))]
public partial struct CustomerId { }
";

var (diagnostics, output) = TestHelper.GetGeneratedOutput<CreationUsingDefaultAnalyzer>(source);

_output.WriteLine(output);

diagnostics.Should().HaveCount(1);
Diagnostic diagnostic = diagnostics.Single();

diagnostic.Id.Should().Be("VOG009");
diagnostic.ToString().Should().Be("(2,36): error VOG009: Type 'CustomerId' cannot be constructed with default as it is prohibited.");
}

}
Loading

0 comments on commit 1f57dfa

Please sign in to comment.