Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Func<> resolution #166

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System;
using Xunit;
using Jab;

Expand Down
109 changes: 109 additions & 0 deletions src/Jab.FunctionalTests.Common/ContainerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,115 @@ public void CanGetMultipleOpenGenericScoped()
partial class CanGetMultipleOpenGenericScopedContainer
{
}

[Fact]
public void SupportsImplicitFunc()
{
SupportsImplicitFuncFactoryContainer c = new();
var transientFunc = c.GetService<Func<IService>>();
var transientFunc2 = c.GetService<Func<IService>>();
var transientService1 = transientFunc();
var transientService2 = transientFunc();

var scope1 = c.CreateScope();
var scopedFunc = scope1.GetService<Func<IService1>>();
var scopedFunc2 = scope1.GetService<Func<IService1>>();
var scopedService1 = scopedFunc();
var scopedService2 = scopedFunc();

var scope2 = c.CreateScope();
var scopedFunc3 = scope2.GetService<Func<IService1>>();
var scopedService3 = scopedFunc3();

var singletonFunc = c.GetService<Func<IService2>>();
var singletonFunc2 = c.GetService<Func<IService2>>();

var singletonService1 = singletonFunc();
var singletonService2 = singletonFunc2();

Assert.Equal(2, c.TransientCount);
Assert.Equal(2, c.ScopedCount);
Assert.Equal(1, c.SingletonCount);

Assert.Same(singletonFunc, singletonFunc2);
Assert.Same(transientFunc, transientFunc2);
Assert.Same(scopedFunc, scopedFunc2);
Assert.NotSame(scopedFunc2, scopedFunc3);

Assert.Same(singletonService1, singletonService2);
Assert.Same(scopedService1, scopedService2);
Assert.NotSame(scopedService1, scopedService3);

Assert.NotSame(transientService1, transientService2);
}

[Fact]
public void SupportsImplicitNamedFunc()
{
SupportsImplicitFuncFactoryContainer c = new();
var transientFunc = c.GetService<Func<IService>>("named");
var transientFunc2 = c.GetService<Func<IService>>("named");
var transientService1 = transientFunc();
var transientService2 = transientFunc();

var singletonFunc = c.GetService<Func<IService2>>("named");
var singletonFunc2 = c.GetService<Func<IService2>>("named");

var singletonService1 = singletonFunc();
var singletonService2 = singletonFunc2();

Assert.Equal(2, c.TransientNamedCount);
Assert.Equal(1, c.SingletonNamedCount);

Assert.Same(singletonFunc, singletonFunc2);
Assert.Same(transientFunc, transientFunc2);

Assert.Same(singletonService1, singletonService2);
Assert.NotSame(transientService1, transientService2);
}

[ServiceProvider(RootServices = new [] { typeof(Func<IService1>) })]
[Transient(typeof(IService), Factory=nameof(TransientNamedFactory), Name = "named")]
[Singleton(typeof(IService2), Factory=nameof(SingletonNamedFactory), Name = "named")]
[Transient(typeof(IService), Factory=nameof(TransientFactory))]
[Scoped(typeof(IService1), Factory=nameof(ScopedFactory))]
[Singleton(typeof(IService2), Factory=nameof(SingletonFactory))]
internal partial class SupportsImplicitFuncFactoryContainer
{
internal int TransientCount = 0;
internal int ScopedCount = 0;
internal int SingletonCount = 0;

internal int TransientNamedCount = 0;
internal int SingletonNamedCount = 0;

internal ServiceImplementation TransientFactory()
{
TransientCount++;
return new();
}
internal ServiceImplementation ScopedFactory()
{
ScopedCount++;
return new();
}
internal ServiceImplementation SingletonFactory()
{
SingletonCount++;
return new();
}

internal ServiceImplementation TransientNamedFactory()
{
TransientNamedCount++;
return new();
}
internal ServiceImplementation SingletonNamedFactory()
{
SingletonNamedCount++;
return new();
}
}

#region Non-generic member factory with parameters
[Fact]
Expand Down
2 changes: 1 addition & 1 deletion src/Jab/ConstructorCallSite.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

internal record ConstructorCallSite : ServiceCallSite
{
public ConstructorCallSite(ServiceIdentity identity, INamedTypeSymbol implementationType, ServiceCallSite[] parameters, KeyValuePair<IParameterSymbol, ServiceCallSite>[] optionalParameters, ServiceLifetime lifetime, int? reverseIndex, bool? isDisposable)
public ConstructorCallSite(ServiceIdentity identity, INamedTypeSymbol implementationType, ServiceCallSite[] parameters, KeyValuePair<IParameterSymbol, ServiceCallSite>[] optionalParameters, ServiceLifetime lifetime, bool? isDisposable)
: base(identity, implementationType, lifetime, isDisposable)
{
Parameters = parameters;
Expand Down
31 changes: 20 additions & 11 deletions src/Jab/ContainerGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,20 @@ private void GenerateCallSiteWithCache(CodeWriter codeWriter, string rootReferen
if (serviceCallSite.Lifetime != ServiceLifetime.Transient)
{
var cacheLocation = GetCacheLocation(serviceCallSite.Identity);
codeWriter.Line($"if ({cacheLocation} == null)");
codeWriter.Line($"lock (this)");
using (codeWriter.Scope($"if ({cacheLocation} == null)"))
var locking = serviceCallSite is not FuncCallSite;
if (locking)
{
GenerateCallSite(
codeWriter,
rootReference,
serviceCallSite,
(w, v) =>
{
w.Line($"{cacheLocation} = {v};");
});
codeWriter.Line($"if ({cacheLocation} == null)");
codeWriter.Line($"lock (this)");
}
GenerateCallSite(
codeWriter,
rootReference,
serviceCallSite,
(w, v) =>
{
w.Line($"{cacheLocation} ??= {v};");
});

if (serviceCallSite.ImplementationType.IsValueType)
{
Expand Down Expand Up @@ -146,6 +147,14 @@ private void GenerateCallSite(CodeWriter codeWriter, string rootReference, Servi
w.Append($")");
});
break;

case FuncCallSite funcCallSite:
valueCallback(codeWriter, w =>
{
w.Append($"() => ");
WriteResolutionCall(codeWriter, funcCallSite.Inner.Identity, "this");
});
break;
case MemberCallSite memberCallSite:
valueCallback(codeWriter, w =>
{
Expand Down
18 changes: 18 additions & 0 deletions src/Jab/FuncCallSite.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
namespace Jab;

internal record FuncCallSite : ServiceCallSite
{
public FuncCallSite(ServiceIdentity identity, ServiceCallSite inner)
: base(identity, identity.Type, GetFuncLifetime(inner.Lifetime), false)
{
Inner = inner;
}

public ServiceCallSite Inner { get; }

private static ServiceLifetime GetFuncLifetime(ServiceLifetime innerLifetime) => innerLifetime switch
{
ServiceLifetime.Scoped => ServiceLifetime.Scoped,
_ => ServiceLifetime.Singleton
};
}
3 changes: 3 additions & 0 deletions src/Jab/KnownTypes.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ internal class KnownTypes

private const string IAsyncDisposableMetadataName = "System.IAsyncDisposable";
private const string IEnumerableMetadataName = "System.Collections.Generic.IEnumerable`1";
private const string FuncMetadataName = "System.Func`1";
private const string IServiceProviderMetadataName = "System.IServiceProvider";
private const string IServiceScopeMetadataName = "Microsoft.Extensions.DependencyInjection.IServiceScope";
private const string IKeyedServiceProviderMetadataName = "Microsoft.Extensions.DependencyInjection.IKeyedServiceProvider";
Expand All @@ -59,6 +60,7 @@ internal class KnownTypes
"Microsoft.Extensions.DependencyInjection.IServiceProviderIsService";

public INamedTypeSymbol IEnumerableType { get; }
public INamedTypeSymbol FuncType { get; }
public INamedTypeSymbol IServiceProviderType { get; }
public INamedTypeSymbol CompositionRootAttributeType { get; }
public INamedTypeSymbol TransientAttributeType { get; }
Expand Down Expand Up @@ -102,6 +104,7 @@ static INamedTypeSymbol GetTypeFromCompilationByMetadataNameOrThrow(Compilation
?? throw new InvalidOperationException($"Type with metadata '{fullyQualifiedMetadataName}' not found");

IEnumerableType = GetTypeFromCompilationByMetadataNameOrThrow(compilation, IEnumerableMetadataName);
FuncType = GetTypeFromCompilationByMetadataNameOrThrow(compilation, FuncMetadataName);
IServiceProviderType = GetTypeFromCompilationByMetadataNameOrThrow(compilation, IServiceProviderMetadataName);
IServiceScopeType = compilation.GetTypeByMetadataName(IServiceScopeMetadataName);
IAsyncDisposableType = compilation.GetTypeByMetadataName(IAsyncDisposableMetadataName);
Expand Down
33 changes: 32 additions & 1 deletion src/Jab/ServiceProviderBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ private void EmitTypeDiagnostics(ITypeSymbol typeSymbol)
return TryCreateSpecial(serviceType, name, context) ??
TryCreateExact(serviceType, name, null, context) ??
TryCreateEnumerable(serviceType, name, context) ??
TryCreateFunc(serviceType, name, context) ??
TryCreateGeneric(serviceType, name, context);
}
finally
Expand Down Expand Up @@ -432,6 +433,37 @@ static ServiceLifetime GetCommonLifetime(IEnumerable<ServiceCallSite> callSites)
return null;
}


private ServiceCallSite? TryCreateFunc(ITypeSymbol serviceType, string? name, ServiceResolutionContext context)
{
if (serviceType is INamedTypeSymbol { IsGenericType: true } genericType &&
SymbolEqualityComparer.Default.Equals(genericType.ConstructedFrom, _knownTypes.FuncType))
{
var identity = new ServiceIdentity(genericType, name, null);

if (context.CallSiteCache.TryGet(identity, out var callSite))
{
return callSite;
}

var innerType = genericType.TypeArguments[0];
var inner = GetCallSite(innerType, name, context);

if (inner == null)
{
return null;
}

callSite = new FuncCallSite(identity, inner);

context.CallSiteCache.Add(callSite);

return callSite;
}

return null;
}

private ServiceCallSite? TryCreateExact(
ITypeSymbol serviceType,
string? name,
Expand Down Expand Up @@ -612,7 +644,6 @@ private ServiceCallSite CreateConstructorCallSite(
parameters.ToArray(),
namedParameters.ToArray(),
registration.Lifetime,
identity.ReverseIndex,
// TODO: this can be optimized to avoid check for all the types
isDisposable: null
);
Expand Down
Loading