diff --git a/src/bunit.core/ComponentFactories/ComponentFactoryCollectionExtensions.cs b/src/bunit.core/ComponentFactories/ComponentFactoryCollectionExtensions.cs
index 2e12c5a00..391b09420 100644
--- a/src/bunit.core/ComponentFactories/ComponentFactoryCollectionExtensions.cs
+++ b/src/bunit.core/ComponentFactories/ComponentFactoryCollectionExtensions.cs
@@ -80,6 +80,26 @@ public static ComponentFactoryCollection UseStubFor(this ComponentFactoryCollect
factories.Add(new StubComponentFactory(componentTypePredicate, options ?? StubOptions.Default));
return factories;
}
+
+ ///
+ /// Configures bUnit to replace all components of type with a component
+ /// of type .
+ ///
+ /// Type of component to replace.
+ /// Type of component to replace with.
+ /// The bUnit to configure.
+ /// A .
+ public static ComponentFactoryCollection UseFor(this ComponentFactoryCollection factories)
+ where TComponent : IComponent
+ where TReplacementComponent : IComponent
+ {
+ if (factories is null)
+ throw new ArgumentNullException(nameof(factories));
+
+ factories.Add(new GenericComponentFactory());
+
+ return factories;
+ }
}
}
#endif
diff --git a/src/bunit.core/ComponentFactories/GenericComponentFactory{TComponent,TReplacementComponent}.cs b/src/bunit.core/ComponentFactories/GenericComponentFactory{TComponent,TReplacementComponent}.cs
new file mode 100644
index 000000000..4ffe7ed4f
--- /dev/null
+++ b/src/bunit.core/ComponentFactories/GenericComponentFactory{TComponent,TReplacementComponent}.cs
@@ -0,0 +1,17 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Microsoft.AspNetCore.Components;
+
+namespace Bunit.ComponentFactories
+{
+ internal sealed class GenericComponentFactory : IComponentFactory
+ where TComponent : IComponent
+ where TReplacementComponent : IComponent
+ {
+ public bool CanCreate(Type componentType) => componentType == typeof(TComponent);
+ public IComponent Create(Type componentType) => (IComponent)Activator.CreateInstance(typeof(TReplacementComponent))!;
+ }
+}
diff --git a/tests/bunit.core.tests/ComponentFactories/GenericComponentFactoryTest.cs b/tests/bunit.core.tests/ComponentFactories/GenericComponentFactoryTest.cs
new file mode 100644
index 000000000..a0dcaf9c8
--- /dev/null
+++ b/tests/bunit.core.tests/ComponentFactories/GenericComponentFactoryTest.cs
@@ -0,0 +1,42 @@
+#if NET5_0_OR_GREATER
+using System;
+using System.Threading.Tasks;
+using Bunit.TestAssets.SampleComponents;
+using Microsoft.AspNetCore.Components;
+using Microsoft.AspNetCore.Components.Rendering;
+using Shouldly;
+using Xunit;
+
+namespace Bunit.ComponentFactories
+{
+ public class GenericComponentFactoryTest : TestContext
+ {
+ [Fact(DisplayName = "UseFor throws when factories is null")]
+ public void Test001()
+ => Should.Throw(() => ComponentFactoryCollectionExtensions.UseFor(null));
+
+ [Fact(DisplayName = "UseFor replaces components of type TComponent with TReplacementComponent")]
+ public void Test002()
+ {
+ ComponentFactories.UseFor();
+
+ var cut = RenderComponent();
+
+ cut.MarkupMatches(@"Has ref = True
");
+ }
+
+ private class FakeSimple1 : Simple1
+ {
+ protected override void OnInitialized() { }
+ protected override Task OnInitializedAsync() => Task.CompletedTask;
+ public override Task SetParametersAsync(ParameterView parameters) => Task.CompletedTask;
+ protected override void OnParametersSet() { }
+ protected override Task OnParametersSetAsync() => Task.CompletedTask;
+ protected override void BuildRenderTree(RenderTreeBuilder builder) { }
+ protected override bool ShouldRender() => false;
+ protected override void OnAfterRender(bool firstRender) { }
+ protected override Task OnAfterRenderAsync(bool firstRender) => Task.CompletedTask;
+ }
+ }
+}
+#endif