diff --git a/README.md b/README.md index 4b1c61507..8199c1595 100644 --- a/README.md +++ b/README.md @@ -100,8 +100,6 @@ There are 2 parts to the locator design: * **Locator.Current** The property to use to **retrieve** services. Locator.Current is a static variable that can be set on startup, to adapt Splat to other DI/IoC frameworks. We're currently working from v7 onward to make it easier to use your DI/IoC framework of choice. (see below) * **Locator.CurrentMutable** The property to use to **register** services -**Note:** Currently these properties point to the same object and you can use CurrentMutable to also GetServices, but this is not the intention and the interfaces may be adjusted in future to lock this down (and make it more obvious what the use cases are). - To get a service: ```cs @@ -132,6 +130,10 @@ Locator.CurrentMutable.RegisterLazySingleton(() => new LazyToaster(), typeof(ITo ### Dependency Resolver Packages For each of the provided dependency resolver adapters, there is a specific package that allows the service locator to be implemented by another ioc container. +Please note: If you are adjusting behaviours of Splat by working with your custom container directly. Please read the relevant projects documentation on +REPLACING the registration. If the container supports appending\ multiple registrations you may get undesired behaviours, such as the wrong logger factory +being used. + | Container | NuGet | Read Me |---------|-------|-------| | [Splat.Autofac][SplatAutofacNuGet] | [![SplatAutofacBadge]][SplatAutofacNuGet] | [Setup Autofac][SplatAutofacReadme] diff --git a/src/Splat.Autofac.Tests/DependencyResolverTests.cs b/src/Splat.Autofac.Tests/DependencyResolverTests.cs index e94975471..ff379fd6f 100644 --- a/src/Splat.Autofac.Tests/DependencyResolverTests.cs +++ b/src/Splat.Autofac.Tests/DependencyResolverTests.cs @@ -10,6 +10,7 @@ using Autofac; using Shouldly; using Splat.Common.Test; +using Splat.Tests.ServiceLocation; using Xunit; namespace Splat.Autofac.Tests @@ -17,7 +18,7 @@ namespace Splat.Autofac.Tests /// /// Tests to show the works correctly. /// - public class DependencyResolverTests + public class DependencyResolverTests : BaseDependencyResolverTests { /// /// Shoulds the resolve views. @@ -93,45 +94,61 @@ public void AutofacDependencyResolver_Should_Resolve_Screen() /// Should throw an exception if service registration call back called. /// [Fact] - public void AutofacDependencyResolver_Should_Throw_If_UnregisterCurrent_Called() + public void AutofacDependencyResolver_Should_Throw_If_ServiceRegistionCallback_Called() { var container = new ContainerBuilder(); container.UseAutofacDependencyResolver(); var result = Record.Exception(() => - Locator.CurrentMutable.UnregisterCurrent(typeof(IScreen))); + Locator.CurrentMutable.ServiceRegistrationCallback(typeof(IScreen), disposable => { })); result.ShouldBeOfType(); } /// - /// Should unregister all. + /// Check to ensure the correct logger is returned. /// + /// + /// Introduced for Splat #331. + /// [Fact] - public void AutofacDependencyResolver_Should_UnregisterAll_Called() + public void AutofacDependencyResolver_Should_ReturnRegisteredLogger() { var container = new ContainerBuilder(); container.UseAutofacDependencyResolver(); - var result = Record.Exception(() => - Locator.CurrentMutable.UnregisterCurrent(typeof(IScreen))); + Locator.CurrentMutable.RegisterConstant( + new FuncLogManager(type => new WrappingFullLogger(new ConsoleLogger())), + typeof(ILogManager)); - result.ShouldBeOfType(); + var d = Splat.Locator.Current.GetService(); + Assert.IsType(d); } /// - /// Should throw an exception if service registration call back called. + /// Test that a pre-init logger isn't overriden. /// + /// + /// Introduced for Splat #331. + /// [Fact] - public void AutofacDependencyResolver_Should_Throw_If_ServiceRegistionCallback_Called() + public void AutofacDependencyResolver_PreInit_Should_ReturnRegisteredLogger() { - var container = new ContainerBuilder(); - container.UseAutofacDependencyResolver(); + var builder = new ContainerBuilder(); + builder.Register(_ => new FuncLogManager(type => new WrappingFullLogger(new ConsoleLogger()))).As(typeof(ILogManager)) + .AsImplementedInterfaces(); - var result = Record.Exception(() => - Locator.CurrentMutable.ServiceRegistrationCallback(typeof(IScreen), disposable => { })); + builder.UseAutofacDependencyResolver(); - result.ShouldBeOfType(); + var d = Splat.Locator.Current.GetService(); + Assert.IsType(d); + } + + /// + protected override AutofacDependencyResolver GetDependencyResolver() + { + var container = new ContainerBuilder(); + return new AutofacDependencyResolver(container.Build()); } } } diff --git a/src/Splat.Autofac.Tests/Splat.Autofac.Tests.csproj b/src/Splat.Autofac.Tests/Splat.Autofac.Tests.csproj index f52b515e5..755368780 100644 --- a/src/Splat.Autofac.Tests/Splat.Autofac.Tests.csproj +++ b/src/Splat.Autofac.Tests/Splat.Autofac.Tests.csproj @@ -15,6 +15,7 @@ + diff --git a/src/Splat.Autofac/AutofacDependencyResolver.cs b/src/Splat.Autofac/AutofacDependencyResolver.cs index 82229ebd1..77d983578 100644 --- a/src/Splat.Autofac/AutofacDependencyResolver.cs +++ b/src/Splat.Autofac/AutofacDependencyResolver.cs @@ -10,6 +10,7 @@ using System.Linq; using Autofac; using Autofac.Core; +using Autofac.Core.Registration; #pragma warning disable CS0618 // Obsolete values. @@ -20,7 +21,8 @@ namespace Splat.Autofac /// public class AutofacDependencyResolver : IDependencyResolver { - private readonly IComponentContext _componentContext; + private readonly object _lockObject = new object(); + private IComponentContext _componentContext; /// /// Initializes a new instance of the class. @@ -34,32 +36,47 @@ public AutofacDependencyResolver(IComponentContext componentContext) /// public virtual object GetService(Type serviceType, string contract = null) { - try + lock (_lockObject) { - return string.IsNullOrEmpty(contract) - ? _componentContext.Resolve(serviceType) - : _componentContext.ResolveNamed(contract, serviceType); - } - catch (DependencyResolutionException) - { - return null; + try + { + return string.IsNullOrEmpty(contract) + ? _componentContext.Resolve(serviceType) + : _componentContext.ResolveNamed(contract, serviceType); + } + catch (DependencyResolutionException) + { + return null; + } } } /// public virtual IEnumerable GetServices(Type serviceType, string contract = null) { - try + lock (_lockObject) { - var enumerableType = typeof(IEnumerable<>).MakeGenericType(serviceType); - object instance = string.IsNullOrEmpty(contract) - ? _componentContext.Resolve(enumerableType) - : _componentContext.ResolveNamed(contract, enumerableType); - return ((IEnumerable)instance).Cast(); + try + { + var enumerableType = typeof(IEnumerable<>).MakeGenericType(serviceType); + object instance = string.IsNullOrEmpty(contract) + ? _componentContext.Resolve(enumerableType) + : _componentContext.ResolveNamed(contract, enumerableType); + return ((IEnumerable)instance).Cast(); + } + catch (DependencyResolutionException) + { + return null; + } } - catch (DependencyResolutionException) + } + + /// + public bool HasRegistration(Type serviceType) + { + lock (_lockObject) { - return null; + return _componentContext.IsRegistered(serviceType); } } @@ -75,17 +92,20 @@ public virtual IEnumerable GetServices(Type serviceType, string contract /// A optional contract value which will indicates to only generate the value if this contract is specified. public virtual void Register(Func factory, Type serviceType, string contract = null) { - var builder = new ContainerBuilder(); - if (string.IsNullOrEmpty(contract)) - { - builder.Register(x => factory()).As(serviceType).AsImplementedInterfaces(); - } - else + lock (_lockObject) { - builder.Register(x => factory()).Named(contract, serviceType).AsImplementedInterfaces(); - } + var builder = new ContainerBuilder(); + if (string.IsNullOrEmpty(contract)) + { + builder.Register(x => factory()).As(serviceType).AsImplementedInterfaces(); + } + else + { + builder.Register(x => factory()).Named(contract, serviceType).AsImplementedInterfaces(); + } - builder.Update(_componentContext.ComponentRegistry); + builder.Update(_componentContext.ComponentRegistry); + } } /// @@ -98,7 +118,78 @@ public virtual void Register(Func factory, Type serviceType, string cont /// public virtual void UnregisterCurrent(Type serviceType, string contract = null) { - throw new NotImplementedException(); + lock (_lockObject) + { + var registrations = _componentContext.ComponentRegistry.Registrations.ToList(); + var registrationCount = registrations.Count; + if (registrationCount < 1) + { + return; + } + + var candidatesForRemoval = new List(registrationCount); + var registrationIndex = 0; + while (registrationIndex < registrationCount) + { + var componentRegistration = registrations[registrationIndex]; + + var isCandidateForRemoval = GetWhetherServiceIsCandidateForRemoval( + componentRegistration.Services, + serviceType, + contract); + if (isCandidateForRemoval) + { + registrations.RemoveAt(registrationIndex); + candidatesForRemoval.Add(componentRegistration); + registrationCount--; + } + else + { + registrationIndex++; + } + } + + if (candidatesForRemoval.Count == 0) + { + // nothing to remove + return; + } + + if (candidatesForRemoval.Count > 1) + { + // need to re-add some registrations + var reAdd = candidatesForRemoval.Take(candidatesForRemoval.Count - 1); + registrations.AddRange(reAdd); + + /* + // check for multi service registration + // in future might want to just remove a single service from a component + // rather than do the whole component. + var lastCandidate = candidatesForRemoval.Last(); + var lastCandidateRegisteredServices = lastCandidate.Services.ToArray(); + if (lastCandidateRegisteredServices.Length > 1) + { + // + // builder.RegisterType() + // .AsSelf() + // .As() + // .As(); + var survivingServices = lastCandidateRegisteredServices.Where(s => s.GetType() != serviceType); + var newRegistration = new ComponentRegistration( + lastCandidate.Id, + lastCandidate.Activator, + lastCandidate.Lifetime, + lastCandidate.Sharing, + lastCandidate.Ownership, + survivingServices, + lastCandidate.Metadata); + registrations.Add(newRegistration); + } + */ + } + + RemoveAndRebuild(registrations); + } } /// @@ -111,7 +202,46 @@ public virtual void UnregisterCurrent(Type serviceType, string contract = null) /// public virtual void UnregisterAll(Type serviceType, string contract = null) { - throw new NotImplementedException(); + lock (_lockObject) + { + // prevent multiple enumerations + var registrations = _componentContext.ComponentRegistry.Registrations.ToList(); + var registrationCount = registrations.Count; + if (registrationCount < 1) + { + return; + } + + if (!string.IsNullOrEmpty(contract)) + { + RemoveAndRebuild( + registrationCount, + registrations, + x => x.Services.All(s => + { + if (!(s is TypedService typedService)) + { + return false; + } + + return typedService.ServiceType != serviceType || !HasMatchingContract(s, contract); + })); + return; + } + + RemoveAndRebuild( + registrationCount, + registrations, + x => x.Services.All(s => + { + if (!(s is TypedService typedService)) + { + return false; + } + + return typedService.ServiceType != serviceType; + })); + } } /// @@ -134,10 +264,98 @@ public void Dispose() /// Whether or not the instance is disposing. protected virtual void Dispose(bool disposing) { - if (disposing) + lock (_lockObject) { - _componentContext.ComponentRegistry?.Dispose(); + if (disposing) + { + _componentContext.ComponentRegistry?.Dispose(); + } } } + + private static bool GetWhetherServiceIsCandidateForRemoval( + IEnumerable componentRegistrationServices, + Type serviceType, + string contract) + { + foreach (var componentRegistrationService in componentRegistrationServices) + { + if (!(componentRegistrationService is TypedService typedService)) + { + continue; + } + + if (typedService.ServiceType != serviceType) + { + continue; + } + + // right type + if (string.IsNullOrEmpty(contract)) + { + if (!HasNoContract(componentRegistrationService)) + { + continue; + } + + // candidate for removal + return true; + } + + if (!HasMatchingContract(typedService, contract)) + { + continue; + } + + // candidate for removal + return true; + } + + return false; + } + + private static bool HasMatchingContract(Service service, string contract) + { + // you can't directly access the name key. shame. + return service.Description.StartsWith($"({contract})", StringComparison.Ordinal); + } + + private static bool HasNoContract(Service service) + { + return !service.Description.StartsWith("(", StringComparison.Ordinal); + } + + private void RemoveAndRebuild( + int registrationCount, + IList registrations, + Func predicate) + { + var survivingComponents = registrations.Where(predicate).ToArray(); + + if (survivingComponents.Length == registrationCount) + { + // not removing anything + // drop out + return; + } + + RemoveAndRebuild(survivingComponents); + } + + private void RemoveAndRebuild(IEnumerable survivingComponents) + { + var builder = new ContainerBuilder(); + foreach (var c in survivingComponents) + { + builder.RegisterComponent(c); + } + + foreach (var source in _componentContext.ComponentRegistry.Sources) + { + builder.RegisterSource(source); + } + + _componentContext = builder.Build(); + } } } diff --git a/src/Splat.DryIoc.Tests/DependencyResolverTests.cs b/src/Splat.DryIoc.Tests/DependencyResolverTests.cs index 10a72d380..92193f635 100644 --- a/src/Splat.DryIoc.Tests/DependencyResolverTests.cs +++ b/src/Splat.DryIoc.Tests/DependencyResolverTests.cs @@ -170,5 +170,42 @@ public void DryIocDependencyResolver_Should_Throw_If_ServiceRegistionCallback_Ca result.ShouldBeOfType(); } + + /// + /// Check to ensure the correct logger is returned. + /// + /// + /// Introduced for Splat #331. + /// + [Fact] + public void DryIocDependencyResolver_Should_ReturnRegisteredLogger() + { + var c = new Container(); + c.UseDryIocDependencyResolver(); + c.Register(ifAlreadyRegistered: IfAlreadyRegistered.Replace); + Locator.CurrentMutable.RegisterConstant( + new FuncLogManager(type => new WrappingFullLogger(new ConsoleLogger())), + typeof(ILogManager)); + + var d = Splat.Locator.Current.GetService(); + Assert.IsType(d); + } + + /// + /// Test that a pre-init logger isn't overriden. + /// + /// + /// Introduced for Splat #331. + /// + [Fact] + public void DryIocDependencyResolver_PreInit_Should_ReturnRegisteredLogger() + { + var c = new Container(); + c.UseInstance(typeof(ILogManager), new FuncLogManager(type => new WrappingFullLogger(new ConsoleLogger()))); + c.UseDryIocDependencyResolver(); + + var d = Splat.Locator.Current.GetService(); + Assert.IsType(d); + } } } diff --git a/src/Splat.DryIoc/DryIocDependencyResolver.cs b/src/Splat.DryIoc/DryIocDependencyResolver.cs index 83ebc3200..cf9f37bac 100644 --- a/src/Splat.DryIoc/DryIocDependencyResolver.cs +++ b/src/Splat.DryIoc/DryIocDependencyResolver.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; +using System.Linq; using DryIoc; namespace Splat.DryIoc @@ -30,8 +31,8 @@ public DryIocDependencyResolver(IContainer container = null) /// public virtual object GetService(Type serviceType, string contract = null) => string.IsNullOrEmpty(contract) - ? _container.Resolve(serviceType, IfUnresolved.ReturnDefault) - : _container.Resolve(serviceType, contract, IfUnresolved.ReturnDefault); + ? _container.ResolveMany(serviceType).LastOrDefault() + : _container.ResolveMany(serviceType, serviceKey: contract).LastOrDefault(); /// public virtual IEnumerable GetServices(Type serviceType, string contract = null) => @@ -39,6 +40,12 @@ public virtual IEnumerable GetServices(Type serviceType, string contract ? _container.ResolveMany(serviceType) : _container.ResolveMany(serviceType, serviceKey: contract); + /// + public bool HasRegistration(Type serviceType) + { + return _container.GetServiceRegistrations().Any(x => x.ServiceType == serviceType); + } + /// public virtual void Register(Func factory, Type serviceType, string contract = null) { diff --git a/src/Splat.Ninject/NinjectDependencyResolver.cs b/src/Splat.Ninject/NinjectDependencyResolver.cs index 92e5f0c75..a070bf8db 100644 --- a/src/Splat.Ninject/NinjectDependencyResolver.cs +++ b/src/Splat.Ninject/NinjectDependencyResolver.cs @@ -38,6 +38,12 @@ public virtual IEnumerable GetServices(Type serviceType, string contract ? _kernel.GetAll(serviceType) : _kernel.GetAll(serviceType, contract); + /// + public bool HasRegistration(Type serviceType) + { + return _kernel.CanResolve(serviceType); + } + /// public virtual void Register(Func factory, Type serviceType, string contract = null) => _kernel.Bind(serviceType).ToMethod(_ => factory()); diff --git a/src/Splat.SimpleInjector/SimpleInjectorDependencyResolver.cs b/src/Splat.SimpleInjector/SimpleInjectorDependencyResolver.cs index 498d473dc..eff26b8b1 100644 --- a/src/Splat.SimpleInjector/SimpleInjectorDependencyResolver.cs +++ b/src/Splat.SimpleInjector/SimpleInjectorDependencyResolver.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; +using System.Linq; using SimpleInjector; namespace Splat.SimpleInjector @@ -27,12 +28,21 @@ public SimpleInjectorDependencyResolver(Container container) } /// - public object GetService(Type serviceType, string contract = null) => _container.GetInstance(serviceType); + public object GetService(Type serviceType, string contract = null) + { + return _container.GetInstance(serviceType); + } /// public IEnumerable GetServices(Type serviceType, string contract = null) => _container.GetAllInstances(serviceType); + /// + public bool HasRegistration(Type serviceType) + { + return _container.GetCurrentRegistrations().Any(x => x.ServiceType == serviceType); + } + /// public void Register(Func factory, Type serviceType, string contract = null) => _container.Register(serviceType, factory); diff --git a/src/Splat.Tests/API/ApiApprovalTests.SplatProject.net472.approved.txt b/src/Splat.Tests/API/ApiApprovalTests.SplatProject.net472.approved.txt index 83786430f..83fddc9ba 100644 --- a/src/Splat.Tests/API/ApiApprovalTests.SplatProject.net472.approved.txt +++ b/src/Splat.Tests/API/ApiApprovalTests.SplatProject.net472.approved.txt @@ -181,7 +181,6 @@ namespace Splat { public static T GetService(this Splat.IReadonlyDependencyResolver resolver, string contract = null) { } public static System.Collections.Generic.IEnumerable GetServices(this Splat.IReadonlyDependencyResolver resolver, string contract = null) { } - public static void InitializeSplat(this Splat.IMutableDependencyResolver resolver) { } public static void Register(this Splat.IMutableDependencyResolver resolver, System.Func factory, string contract = null) { } public static void RegisterConstant(this Splat.IMutableDependencyResolver resolver, object value, System.Type serviceType, string contract = null) { } public static void RegisterConstant(this Splat.IMutableDependencyResolver resolver, T value, string contract = null) { } @@ -217,6 +216,7 @@ namespace Splat protected virtual void Dispose(bool isDisposing) { } public object GetService(System.Type serviceType, string contract = null) { } public System.Collections.Generic.IEnumerable GetServices(System.Type serviceType, string contract = null) { } + public bool HasRegistration(System.Type serviceType) { } public void Register(System.Func factory, System.Type serviceType, string contract = null) { } public System.IDisposable ServiceRegistrationCallback(System.Type serviceType, string contract, System.Action callback) { } public void UnregisterAll(System.Type serviceType, string contract = null) { } @@ -440,6 +440,7 @@ namespace Splat } public interface IMutableDependencyResolver { + bool HasRegistration(System.Type serviceType); void Register(System.Func factory, System.Type serviceType, string contract = null); System.IDisposable ServiceRegistrationCallback(System.Type serviceType, string contract, System.Action callback); void UnregisterAll(System.Type serviceType, string contract = null); @@ -683,12 +684,15 @@ namespace Splat public class ModernDependencyResolver : Splat.IDependencyResolver, Splat.IMutableDependencyResolver, Splat.IReadonlyDependencyResolver, System.IDisposable { public ModernDependencyResolver() { } - protected ModernDependencyResolver(System.Collections.Generic.Dictionary, System.Collections.Generic.List>> registry) { } + protected ModernDependencyResolver([System.Runtime.CompilerServices.TupleElementNamesAttribute(new string[] { + "serviceType", + "contract"})] System.Collections.Generic.Dictionary, System.Collections.Generic.List>> registry) { } public void Dispose() { } protected virtual void Dispose(bool isDisposing) { } public Splat.ModernDependencyResolver Duplicate() { } public object GetService(System.Type serviceType, string contract = null) { } public System.Collections.Generic.IEnumerable GetServices(System.Type serviceType, string contract = null) { } + public bool HasRegistration(System.Type serviceType) { } public void Register(System.Func factory, System.Type serviceType, string contract = null) { } public System.IDisposable ServiceRegistrationCallback(System.Type serviceType, string contract, System.Action callback) { } public void UnregisterAll(System.Type serviceType, string contract = null) { } @@ -756,6 +760,10 @@ namespace Splat public static System.Tuple DivideWithPadding(this System.Drawing.RectangleF value, float sliceAmount, float padding, Splat.RectEdge fromEdge) { } public static System.Drawing.RectangleF InvertWithin(this System.Drawing.RectangleF value, System.Drawing.RectangleF containingRect) { } } + public class static ServiceLocationInitialization + { + public static void InitializeSplat(this Splat.IMutableDependencyResolver resolver) { } + } public class static SizeExtensions { public static System.Drawing.SizeF FromNative(this System.Windows.Size value) { } diff --git a/src/Splat.Tests/API/ApiApprovalTests.SplatProject.netcoreapp2.1.approved.txt b/src/Splat.Tests/API/ApiApprovalTests.SplatProject.netcoreapp2.1.approved.txt index 52ac5b8e3..46b5583d0 100644 --- a/src/Splat.Tests/API/ApiApprovalTests.SplatProject.netcoreapp2.1.approved.txt +++ b/src/Splat.Tests/API/ApiApprovalTests.SplatProject.netcoreapp2.1.approved.txt @@ -170,7 +170,6 @@ namespace Splat { public static T GetService(this Splat.IReadonlyDependencyResolver resolver, string contract = null) { } public static System.Collections.Generic.IEnumerable GetServices(this Splat.IReadonlyDependencyResolver resolver, string contract = null) { } - public static void InitializeSplat(this Splat.IMutableDependencyResolver resolver) { } public static void Register(this Splat.IMutableDependencyResolver resolver, System.Func factory, string contract = null) { } public static void RegisterConstant(this Splat.IMutableDependencyResolver resolver, object value, System.Type serviceType, string contract = null) { } public static void RegisterConstant(this Splat.IMutableDependencyResolver resolver, T value, string contract = null) { } @@ -206,6 +205,7 @@ namespace Splat protected virtual void Dispose(bool isDisposing) { } public object GetService(System.Type serviceType, string contract = null) { } public System.Collections.Generic.IEnumerable GetServices(System.Type serviceType, string contract = null) { } + public bool HasRegistration(System.Type serviceType) { } public void Register(System.Func factory, System.Type serviceType, string contract = null) { } public System.IDisposable ServiceRegistrationCallback(System.Type serviceType, string contract, System.Action callback) { } public void UnregisterAll(System.Type serviceType, string contract = null) { } @@ -429,6 +429,7 @@ namespace Splat } public interface IMutableDependencyResolver { + bool HasRegistration(System.Type serviceType); void Register(System.Func factory, System.Type serviceType, string contract = null); System.IDisposable ServiceRegistrationCallback(System.Type serviceType, string contract, System.Action callback); void UnregisterAll(System.Type serviceType, string contract = null); @@ -672,12 +673,15 @@ namespace Splat public class ModernDependencyResolver : Splat.IDependencyResolver, Splat.IMutableDependencyResolver, Splat.IReadonlyDependencyResolver, System.IDisposable { public ModernDependencyResolver() { } - protected ModernDependencyResolver(System.Collections.Generic.Dictionary, System.Collections.Generic.List>> registry) { } + protected ModernDependencyResolver([System.Runtime.CompilerServices.TupleElementNamesAttribute(new string[] { + "serviceType", + "contract"})] System.Collections.Generic.Dictionary, System.Collections.Generic.List>> registry) { } public void Dispose() { } protected virtual void Dispose(bool isDisposing) { } public Splat.ModernDependencyResolver Duplicate() { } public object GetService(System.Type serviceType, string contract = null) { } public System.Collections.Generic.IEnumerable GetServices(System.Type serviceType, string contract = null) { } + public bool HasRegistration(System.Type serviceType) { } public void Register(System.Func factory, System.Type serviceType, string contract = null) { } public System.IDisposable ServiceRegistrationCallback(System.Type serviceType, string contract, System.Action callback) { } public void UnregisterAll(System.Type serviceType, string contract = null) { } @@ -720,6 +724,10 @@ namespace Splat public static System.Tuple DivideWithPadding(this System.Drawing.RectangleF value, float sliceAmount, float padding, Splat.RectEdge fromEdge) { } public static System.Drawing.RectangleF InvertWithin(this System.Drawing.RectangleF value, System.Drawing.RectangleF containingRect) { } } + public class static ServiceLocationInitialization + { + public static void InitializeSplat(this Splat.IMutableDependencyResolver resolver) { } + } public class static SizeMathExtensions { public static System.Drawing.SizeF ScaledBy(this System.Drawing.SizeF value, float factor) { } diff --git a/src/Splat.Tests/API/ApiApprovalTests.SplatProject.netcoreapp3.0.approved.txt b/src/Splat.Tests/API/ApiApprovalTests.SplatProject.netcoreapp3.0.approved.txt index 438fbb056..763cf31ad 100644 --- a/src/Splat.Tests/API/ApiApprovalTests.SplatProject.netcoreapp3.0.approved.txt +++ b/src/Splat.Tests/API/ApiApprovalTests.SplatProject.netcoreapp3.0.approved.txt @@ -181,7 +181,6 @@ namespace Splat { public static T GetService(this Splat.IReadonlyDependencyResolver resolver, string contract = null) { } public static System.Collections.Generic.IEnumerable GetServices(this Splat.IReadonlyDependencyResolver resolver, string contract = null) { } - public static void InitializeSplat(this Splat.IMutableDependencyResolver resolver) { } public static void Register(this Splat.IMutableDependencyResolver resolver, System.Func factory, string contract = null) { } public static void RegisterConstant(this Splat.IMutableDependencyResolver resolver, object value, System.Type serviceType, string contract = null) { } public static void RegisterConstant(this Splat.IMutableDependencyResolver resolver, T value, string contract = null) { } @@ -217,6 +216,7 @@ namespace Splat protected virtual void Dispose(bool isDisposing) { } public object GetService(System.Type serviceType, string contract = null) { } public System.Collections.Generic.IEnumerable GetServices(System.Type serviceType, string contract = null) { } + public bool HasRegistration(System.Type serviceType) { } public void Register(System.Func factory, System.Type serviceType, string contract = null) { } public System.IDisposable ServiceRegistrationCallback(System.Type serviceType, string contract, System.Action callback) { } public void UnregisterAll(System.Type serviceType, string contract = null) { } @@ -440,6 +440,7 @@ namespace Splat } public interface IMutableDependencyResolver { + bool HasRegistration(System.Type serviceType); void Register(System.Func factory, System.Type serviceType, string contract = null); System.IDisposable ServiceRegistrationCallback(System.Type serviceType, string contract, System.Action callback); void UnregisterAll(System.Type serviceType, string contract = null); @@ -683,12 +684,15 @@ namespace Splat public class ModernDependencyResolver : Splat.IDependencyResolver, Splat.IMutableDependencyResolver, Splat.IReadonlyDependencyResolver, System.IDisposable { public ModernDependencyResolver() { } - protected ModernDependencyResolver(System.Collections.Generic.Dictionary, System.Collections.Generic.List>> registry) { } + protected ModernDependencyResolver([System.Runtime.CompilerServices.TupleElementNamesAttribute(new string[] { + "serviceType", + "contract"})] System.Collections.Generic.Dictionary, System.Collections.Generic.List>> registry) { } public void Dispose() { } protected virtual void Dispose(bool isDisposing) { } public Splat.ModernDependencyResolver Duplicate() { } public object GetService(System.Type serviceType, string contract = null) { } public System.Collections.Generic.IEnumerable GetServices(System.Type serviceType, string contract = null) { } + public bool HasRegistration(System.Type serviceType) { } public void Register(System.Func factory, System.Type serviceType, string contract = null) { } public System.IDisposable ServiceRegistrationCallback(System.Type serviceType, string contract, System.Action callback) { } public void UnregisterAll(System.Type serviceType, string contract = null) { } @@ -756,6 +760,10 @@ namespace Splat public static System.Tuple DivideWithPadding(this System.Drawing.RectangleF value, float sliceAmount, float padding, Splat.RectEdge fromEdge) { } public static System.Drawing.RectangleF InvertWithin(this System.Drawing.RectangleF value, System.Drawing.RectangleF containingRect) { } } + public class static ServiceLocationInitialization + { + public static void InitializeSplat(this Splat.IMutableDependencyResolver resolver) { } + } public class static SizeExtensions { public static System.Drawing.SizeF FromNative(this System.Windows.Size value) { } diff --git a/src/Splat.Tests/ServiceLocation/BaseDependencyResolverTests.cs b/src/Splat.Tests/ServiceLocation/BaseDependencyResolverTests.cs new file mode 100644 index 000000000..78c455aa0 --- /dev/null +++ b/src/Splat.Tests/ServiceLocation/BaseDependencyResolverTests.cs @@ -0,0 +1,100 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Xunit; + +namespace Splat.Tests.ServiceLocation +{ + /// + /// Common tests for Dependency Resolver interaction with Splat. + /// + /// The dependency resolver to test. + public abstract class BaseDependencyResolverTests + where T : IDependencyResolver + { + /// + /// Test to ensure Unregister doesn't cause an IndexOutOfRangeException. + /// + [Fact] + public void UnregisterCurrent_Doesnt_Throw_When_List_Empty() + { + var resolver = GetDependencyResolver(); + var type = typeof(ILogManager); + resolver.Register(() => new DefaultLogManager(), type); + resolver.Register(() => new DefaultLogManager(), type, "named"); + resolver.UnregisterCurrent(type); + resolver.UnregisterCurrent(type); + } + + /// + /// Test to ensure UnregisterCurrent removes last entry. + /// + [Fact] + public void UnregisterCurrent_Remove_Last() + { + var resolver = GetDependencyResolver(); + var type = typeof(ILogManager); + resolver.Register(() => new DefaultLogManager(), type); + resolver.Register(() => new FuncLogManager(_ => new WrappingFullLogger(new DebugLogger())), type); + resolver.Register(() => new DefaultLogManager(), type, "named"); + + var service = resolver.GetService(type); + Assert.IsType(service); + + resolver.UnregisterCurrent(type); + + service = resolver.GetService(type); + Assert.IsType(service); + } + + /// + /// Test to ensure Unregister doesn't cause an IndexOutOfRangeException. + /// + [Fact] + public void UnregisterCurrentByName_Doesnt_Throw_When_List_Empty() + { + var resolver = GetDependencyResolver(); + var type = typeof(ILogManager); + var contract = "named"; + resolver.Register(() => new DefaultLogManager(), type); + resolver.Register(() => new DefaultLogManager(), type, contract); + resolver.UnregisterCurrent(type, contract); + resolver.UnregisterCurrent(type, contract); + } + + /// + /// Test to ensure Unregister doesn't cause an IndexOutOfRangeException. + /// + [Fact] + public void UnregisterAll_UnregisterCurrent_Doesnt_Throw_When_List_Empty() + { + var resolver = GetDependencyResolver(); + var type = typeof(ILogManager); + resolver.Register(() => new DefaultLogManager(), type); + resolver.Register(() => new DefaultLogManager(), type, "named"); + resolver.UnregisterAll(type); + resolver.UnregisterCurrent(type); + } + + /// + /// Test to ensure Unregister doesn't cause an IndexOutOfRangeException. + /// + [Fact] + public void UnregisterAllByContract_UnregisterCurrent_Doesnt_Throw_When_List_Empty() + { + var resolver = GetDependencyResolver(); + var type = typeof(ILogManager); + var contract = "named"; + resolver.Register(() => new DefaultLogManager(), type); + resolver.Register(() => new DefaultLogManager(), type, contract); + resolver.UnregisterAll(type, contract); + resolver.UnregisterCurrent(type, contract); + } + + /// + /// Gets an instance of a dependency resolver to test. + /// + /// Dependency Resolver. + protected abstract T GetDependencyResolver(); + } +} diff --git a/src/Splat.Tests/ServiceLocation/ModernDependencyResolverTests.cs b/src/Splat.Tests/ServiceLocation/ModernDependencyResolverTests.cs new file mode 100644 index 000000000..20138c411 --- /dev/null +++ b/src/Splat.Tests/ServiceLocation/ModernDependencyResolverTests.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Xunit; + +namespace Splat.Tests.ServiceLocation +{ + /// + /// Unit Tests for the Modern Dependency Resolver. + /// + public sealed class ModernDependencyResolverTests : BaseDependencyResolverTests + { + /// + protected override ModernDependencyResolver GetDependencyResolver() => new ModernDependencyResolver(); + } +} diff --git a/src/Splat/ServiceLocation/DependencyResolverMixins.cs b/src/Splat/ServiceLocation/DependencyResolverMixins.cs index a5c6d5f61..69a95b9e8 100644 --- a/src/Splat/ServiceLocation/DependencyResolverMixins.cs +++ b/src/Splat/ServiceLocation/DependencyResolverMixins.cs @@ -155,20 +155,5 @@ public static void UnregisterAll(this IMutableDependencyResolver resolver, st { resolver.UnregisterAll(typeof(T), contract); } - - /// - /// Registers all the default registrations that are needed by the Splat module. - /// - /// The resolver to register the needed service types against. - public static void InitializeSplat(this IMutableDependencyResolver resolver) - { - resolver.Register(() => new DefaultLogManager(), typeof(ILogManager)); - resolver.RegisterConstant(new DebugLogger(), typeof(ILogger)); - -#if !NETSTANDARD - // not supported in netstandard2.0 - resolver.RegisterLazySingleton(() => new PlatformBitmapLoader(), typeof(IBitmapLoader)); -#endif - } } } diff --git a/src/Splat/ServiceLocation/FuncDependencyResolver.cs b/src/Splat/ServiceLocation/FuncDependencyResolver.cs index 130d7adba..a079e19ad 100644 --- a/src/Splat/ServiceLocation/FuncDependencyResolver.cs +++ b/src/Splat/ServiceLocation/FuncDependencyResolver.cs @@ -60,6 +60,12 @@ public IEnumerable GetServices(Type serviceType, string contract = null) return _innerGetServices(serviceType, contract); } + /// + public bool HasRegistration(Type serviceType) + { + return _innerGetServices(serviceType, null) != null; + } + /// public void Register(Func factory, Type serviceType, string contract = null) { diff --git a/src/Splat/ServiceLocation/IMutableDependencyResolver.cs b/src/Splat/ServiceLocation/IMutableDependencyResolver.cs index 673533f5c..fe87f1d3f 100644 --- a/src/Splat/ServiceLocation/IMutableDependencyResolver.cs +++ b/src/Splat/ServiceLocation/IMutableDependencyResolver.cs @@ -12,6 +12,13 @@ namespace Splat /// public interface IMutableDependencyResolver { + /// + /// Check to see if a resolvers has a registration for a type. + /// + /// The type to check for registration. + /// Whether there is a registration for the type. + bool HasRegistration(Type serviceType); + /// /// Register a function with the resolver which will generate a object /// for the specified service type. diff --git a/src/Splat/ServiceLocation/ModernDependencyResolver.cs b/src/Splat/ServiceLocation/ModernDependencyResolver.cs index 3f7df3f4f..13b944801 100644 --- a/src/Splat/ServiceLocation/ModernDependencyResolver.cs +++ b/src/Splat/ServiceLocation/ModernDependencyResolver.cs @@ -23,8 +23,8 @@ namespace Splat /// public class ModernDependencyResolver : IDependencyResolver { - private Dictionary, List>> _registry; - private Dictionary, List>> _callbackRegistry; + private Dictionary<(Type serviceType, string contract), List>> _registry; + private Dictionary<(Type serviceType, string contract), List>> _callbackRegistry; private bool _isDisposed; @@ -40,19 +40,26 @@ public ModernDependencyResolver() /// Initializes a new instance of the class. /// /// A registry of services. - protected ModernDependencyResolver(Dictionary, List>> registry) + protected ModernDependencyResolver(Dictionary<(Type serviceType, string contract), List>> registry) { _registry = registry != null ? registry.ToDictionary(k => k.Key, v => v.Value.ToList()) : - new Dictionary, List>>(); + new Dictionary<(Type serviceType, string contract), List>>(); - _callbackRegistry = new Dictionary, List>>(); + _callbackRegistry = new Dictionary<(Type serviceType, string contract), List>>(); + } + + /// + public bool HasRegistration(Type serviceType) + { + var pair = GetKey(serviceType); + return _registry.ContainsKey(pair); } /// public void Register(Func factory, Type serviceType, string contract = null) { - var pair = Tuple.Create(serviceType, contract ?? string.Empty); + var pair = GetKey(serviceType, contract); if (!_registry.ContainsKey(pair)) { _registry[pair] = new List>(); @@ -94,7 +101,7 @@ public void Register(Func factory, Type serviceType, string contract = n /// public object GetService(Type serviceType, string contract = null) { - var pair = Tuple.Create(serviceType, contract ?? string.Empty); + var pair = GetKey(serviceType, contract); if (!_registry.ContainsKey(pair)) { return default(object); @@ -107,7 +114,7 @@ public object GetService(Type serviceType, string contract = null) /// public IEnumerable GetServices(Type serviceType, string contract = null) { - var pair = Tuple.Create(serviceType, contract ?? string.Empty); + var pair = GetKey(serviceType, contract); if (!_registry.ContainsKey(pair)) { return Enumerable.Empty(); @@ -119,20 +126,26 @@ public IEnumerable GetServices(Type serviceType, string contract = null) /// public void UnregisterCurrent(Type serviceType, string contract = null) { - var pair = Tuple.Create(serviceType, contract ?? string.Empty); + var pair = GetKey(serviceType, contract); if (!_registry.TryGetValue(pair, out var list)) { return; } - list.RemoveAt(list.Count - 1); + var position = list.Count - 1; + if (position < 0) + { + return; + } + + list.RemoveAt(position); } /// public void UnregisterAll(Type serviceType, string contract = null) { - var pair = Tuple.Create(serviceType, contract ?? string.Empty); + var pair = GetKey(serviceType, contract); _registry[pair] = new List>(); } @@ -140,7 +153,7 @@ public void UnregisterAll(Type serviceType, string contract = null) /// public IDisposable ServiceRegistrationCallback(Type serviceType, string contract, Action callback) { - var pair = Tuple.Create(serviceType, contract ?? string.Empty); + var pair = GetKey(serviceType, contract); if (!_callbackRegistry.ContainsKey(pair)) { @@ -197,5 +210,10 @@ protected virtual void Dispose(bool isDisposing) _isDisposed = true; } + + private static (Type, string) GetKey( + Type serviceType, + string contract = null) => + (serviceType, contract ?? string.Empty); } } diff --git a/src/Splat/ServiceLocation/ServiceLocationInitialization.cs b/src/Splat/ServiceLocation/ServiceLocationInitialization.cs new file mode 100644 index 000000000..ca124cc4e --- /dev/null +++ b/src/Splat/ServiceLocation/ServiceLocationInitialization.cs @@ -0,0 +1,53 @@ +// Copyright (c) 2019 .NET Foundation and Contributors. All rights reserved. +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for full license information. + +namespace Splat +{ + /// + /// Initialization logic for Splat interacting with Dependency Resolvers. + /// + public static class ServiceLocationInitialization + { + /// + /// Registers all the default registrations that are needed by the Splat module. + /// + /// The resolver to register the needed service types against. + public static void InitializeSplat(this IMutableDependencyResolver resolver) + { + RegisterDefaultLogManager(resolver); + RegisterLogger(resolver); +#if !NETSTANDARD + RegisterPlatformBitmapLoader(resolver); +#endif + } + + private static void RegisterDefaultLogManager(IMutableDependencyResolver resolver) + { + if (!resolver.HasRegistration(typeof(ILogManager))) + { + resolver.Register(() => new DefaultLogManager(), typeof(ILogManager)); + } + } + + private static void RegisterLogger(IMutableDependencyResolver resolver) + { + if (!resolver.HasRegistration(typeof(ILogger))) + { + resolver.RegisterConstant(new DebugLogger(), typeof(ILogger)); + } + } + +#if !NETSTANDARD + private static void RegisterPlatformBitmapLoader(IMutableDependencyResolver resolver) + { + // not supported in netstandard2.0 + if (!resolver.HasRegistration(typeof(IBitmapLoader))) + { + resolver.RegisterLazySingleton(() => new PlatformBitmapLoader(), typeof(IBitmapLoader)); + } + } +#endif + } +} diff --git a/version.json b/version.json index 406a71c11..60ebdd7ee 100644 --- a/version.json +++ b/version.json @@ -1,5 +1,5 @@ { - "version": "8.0", + "version": "8.1", "publicReleaseRefSpec": [ "^refs/heads/master$", // we release out of master "^refs/heads/preview/.*", // we release previews