diff --git a/integration/testing-module-override/e2e/circular-dependency/a.module.ts b/integration/testing-module-override/e2e/circular-dependency/a.module.ts new file mode 100644 index 00000000000..695a9edbdf0 --- /dev/null +++ b/integration/testing-module-override/e2e/circular-dependency/a.module.ts @@ -0,0 +1,12 @@ +import { Injectable, Module, forwardRef } from '@nestjs/common'; +import { BModule, BProvider } from './b.module'; + +@Injectable() +export class AProvider {} + +@Module({ + imports: [forwardRef(() => BModule)], + providers: [AProvider], + exports: [AProvider], +}) +export class AModule {} diff --git a/integration/testing-module-override/e2e/circular-dependency/b.module.ts b/integration/testing-module-override/e2e/circular-dependency/b.module.ts new file mode 100644 index 00000000000..6671d82a6af --- /dev/null +++ b/integration/testing-module-override/e2e/circular-dependency/b.module.ts @@ -0,0 +1,12 @@ +import { Injectable, Module, forwardRef } from '@nestjs/common'; +import { AModule, AProvider } from './a.module'; + +@Injectable() +export class BProvider {} + +@Module({ + imports: [forwardRef(() => AModule)], + providers: [BProvider], + exports: [BProvider], +}) +export class BModule {} diff --git a/integration/testing-module-override/e2e/modules-override.spec.ts b/integration/testing-module-override/e2e/modules-override.spec.ts new file mode 100644 index 00000000000..991d9ae6bb6 --- /dev/null +++ b/integration/testing-module-override/e2e/modules-override.spec.ts @@ -0,0 +1,309 @@ +import { + Controller, + Module, + DynamicModule, + forwardRef, + Injectable, + Global, +} from '@nestjs/common'; +import { LazyModuleLoader } from '@nestjs/core'; +import { Test, TestingModule } from '@nestjs/testing'; +import { expect } from 'chai'; + +import { AModule, AProvider } from './circular-dependency/a.module'; +import { BModule, BProvider } from './circular-dependency/b.module'; + +describe('modules override', () => { + describe('top-level module', () => { + @Controller() + class ControllerOverwritten {} + + @Module({ + controllers: [ControllerOverwritten], + }) + class ModuleToBeOverwritten {} + + @Controller() + class ControllerOverride {} + + @Module({ + controllers: [ControllerOverride], + }) + class ModuleOverride {} + + let testingModule: TestingModule; + + beforeEach(async () => { + testingModule = await Test.createTestingModule({ + imports: [ModuleToBeOverwritten], + }) + .overrideModule(ModuleToBeOverwritten) + .useModule(ModuleOverride) + .compile(); + }); + + it('should be possible to override top-level modules using testing module builder', () => { + expect(() => + testingModule.get(ControllerOverwritten), + ).to.throw(); + expect( + testingModule.get(ControllerOverride), + ).to.be.an.instanceof(ControllerOverride); + }); + }); + + describe('dynamic module', () => { + @Controller() + class ControllerOverwritten {} + + @Module({}) + class DynamicModuleToBeOverwritten {} + + const dynamicModuleOverwritten: DynamicModule = { + module: DynamicModuleToBeOverwritten, + controllers: [ControllerOverwritten], + }; + + @Controller() + class ControllerOverride {} + + @Module({}) + class DynamicModuleOverride {} + + const dynamicModuleOverride: DynamicModule = { + module: DynamicModuleOverride, + controllers: [ControllerOverride], + }; + + let testingModule: TestingModule; + + beforeEach(async () => { + testingModule = await Test.createTestingModule({ + imports: [dynamicModuleOverwritten], + }) + .overrideModule(dynamicModuleOverwritten) + .useModule(dynamicModuleOverride) + .compile(); + }); + + it('should be possible to override dynamic modules using testing module builder', () => { + expect(() => + testingModule.get(ControllerOverwritten), + ).to.throw(); + expect( + testingModule.get(ControllerOverride), + ).to.be.an.instanceof(ControllerOverride); + }); + }); + + describe('circular dependency module', () => { + let testingModule: TestingModule; + + @Injectable() + class CProvider {} + + @Module({ + providers: [CProvider], + }) + class CModule {} + + @Injectable() + class BProviderOverride {} + + @Module({ + imports: [forwardRef(() => AModule), forwardRef(() => CModule)], + providers: [BProviderOverride], + exports: [BProviderOverride], + }) + class BModuleOverride {} + + beforeEach(async () => { + testingModule = await Test.createTestingModule({ + imports: [AModule], + }) + .overrideModule(BModule) + .useModule(BModuleOverride) + .compile(); + }); + + it('should be possible to override top-level modules using testing module builder', () => { + expect(testingModule.get(AProvider)).to.be.an.instanceof( + AProvider, + ); + expect(() => testingModule.get(BProvider)).to.throw(); + expect(testingModule.get(CProvider)).to.be.an.instanceof( + CProvider, + ); + expect( + testingModule.get(BProviderOverride), + ).to.be.an.instanceof(BProviderOverride); + }); + }); + + describe('nested module', () => { + let testingModule: TestingModule; + + @Controller() + class OverwrittenNestedModuleController {} + + @Module({ + controllers: [OverwrittenNestedModuleController], + }) + class OverwrittenNestedModule {} + + @Controller() + class OverrideNestedModuleController {} + + @Module({ + controllers: [OverrideNestedModuleController], + }) + class OverrideNestedModule {} + + @Module({ + imports: [OverwrittenNestedModule], + }) + class AppModule {} + + beforeEach(async () => { + testingModule = await Test.createTestingModule({ + imports: [AppModule], + }) + .overrideModule(OverwrittenNestedModule) + .useModule(OverrideNestedModule) + .compile(); + }); + + it('should be possible to override nested modules using testing module builder', () => { + expect( + testingModule.get( + OverrideNestedModuleController, + ), + ).to.be.an.instanceof(OverrideNestedModuleController); + expect(() => + testingModule.get( + OverwrittenNestedModuleController, + ), + ).to.throw(); + }); + }); + + describe('lazy-loaded module', () => { + let testingModule: TestingModule; + + @Injectable() + class OverwrittenLazyProvider { + value() { + return 'overwritten lazy'; + } + } + + @Module({ + providers: [ + { + provide: 'LAZY_PROVIDER', + useClass: OverwrittenLazyProvider, + }, + ], + }) + class OverwrittenLazyModule {} + + @Injectable() + class OverrideLazyProvider { + value() { + return 'override lazy'; + } + } + + @Module({ + providers: [ + { + provide: 'LAZY_PROVIDER', + useClass: OverrideLazyProvider, + }, + ], + }) + class OverrideLazyModule {} + + @Injectable() + class AppService { + constructor(private lazyModuleLoader: LazyModuleLoader) {} + + async value() { + const moduleRef = await this.lazyModuleLoader.load( + () => OverwrittenLazyModule, + ); + return moduleRef.get('LAZY_PROVIDER').value(); + } + } + + @Module({ + imports: [], + providers: [AppService], + }) + class AppModule {} + + beforeEach(async () => { + testingModule = await Test.createTestingModule({ + imports: [AppModule], + }) + .overrideModule(OverwrittenLazyModule) + .useModule(OverrideLazyModule) + .compile(); + }); + + it('should be possible to override lazy loaded modules using testing module builder', async () => { + const result = await testingModule.get(AppService).value(); + expect(result).to.be.equal('override lazy'); + }); + }); + + describe('global module', () => { + let testingModule: TestingModule; + + @Injectable() + class OverwrittenProvider { + value() { + return 'overwritten lazy'; + } + } + + @Global() + @Module({ + providers: [OverwrittenProvider], + exports: [OverwrittenProvider], + }) + class OverwrittenModule {} + + @Injectable() + class OverrideProvider { + value() { + return 'override lazy'; + } + } + + @Global() + @Module({ + providers: [OverrideProvider], + exports: [OverrideProvider], + }) + class OverrideModule {} + + beforeEach(async () => { + testingModule = await Test.createTestingModule({ + imports: [OverwrittenModule], + }) + .overrideModule(OverwrittenModule) + .useModule(OverrideModule) + .compile(); + }); + + it('should be possible to override global modules using testing module builder', () => { + expect( + testingModule.get(OverrideProvider), + ).to.be.an.instanceof(OverrideProvider); + expect(() => + testingModule.get(OverwrittenProvider), + ).to.throw(); + }); + }); +}); diff --git a/integration/testing-module-override/tsconfig.json b/integration/testing-module-override/tsconfig.json new file mode 100644 index 00000000000..e268e37aa3d --- /dev/null +++ b/integration/testing-module-override/tsconfig.json @@ -0,0 +1,17 @@ +{ + "compilerOptions": { + "module": "commonjs", + "declaration": true, + "removeComments": true, + "emitDecoratorMetadata": true, + "experimentalDecorators": true, + "allowSyntheticDefaultImports": true, + "target": "es2017", + "sourceMap": true, + "outDir": "./dist", + "baseUrl": "./", + "incremental": true, + "skipLibCheck": true + }, + "include": ["src/**/*"] +} diff --git a/packages/core/injector/container.ts b/packages/core/injector/container.ts index b50957d384d..d5b7fcbbf0b 100644 --- a/packages/core/injector/container.ts +++ b/packages/core/injector/container.ts @@ -7,7 +7,7 @@ import { CircularDependencyException } from '../errors/exceptions/circular-depen import { UndefinedForwardRefException } from '../errors/exceptions/undefined-forwardref.exception'; import { UnknownModuleException } from '../errors/exceptions/unknown-module.exception'; import { REQUEST } from '../router/request/request-constants'; -import { ModuleCompiler } from './compiler'; +import { ModuleCompiler, ModuleFactory } from './compiler'; import { ContextId } from './instance-wrapper'; import { InternalCoreModule } from './internal-core-module'; import { InternalProvidersStorage } from './internal-providers-storage'; @@ -15,6 +15,9 @@ import { Module } from './module'; import { ModuleTokenFactory } from './module-token-factory'; import { ModulesContainer } from './modules-container'; +type ModuleMetatype = Type | DynamicModule | Promise; +type ModuleScope = Type[]; + export class NestContainer { private readonly globalModules = new Set(); private readonly moduleTokenFactory = new ModuleTokenFactory(); @@ -54,8 +57,8 @@ export class NestContainer { } public async addModule( - metatype: Type | DynamicModule | Promise, - scope: Type[], + metatype: ModuleMetatype, + scope: ModuleScope, ): Promise { // In DependenciesScanner#scanForModules we already check for undefined or invalid modules // We still need to catch the edge-case of `forwardRef(() => undefined)` @@ -68,8 +71,50 @@ export class NestContainer { if (this.modules.has(token)) { return this.modules.get(token); } + + return this.setModule( + { + token, + type, + dynamicMetadata, + }, + scope, + ); + } + + public async replaceModule( + metatypeToReplace: ModuleMetatype, + newMetatype: ModuleMetatype, + scope: ModuleScope, + ): Promise { + // In DependenciesScanner#scanForModules we already check for undefined or invalid modules + // We still need to catch the edge-case of `forwardRef(() => undefined)` + if (!metatypeToReplace || !newMetatype) { + throw new UndefinedForwardRefException(scope); + } + + const { token } = await this.moduleCompiler.compile(metatypeToReplace); + const { type, dynamicMetadata } = await this.moduleCompiler.compile( + newMetatype, + ); + + return this.setModule( + { + token, + type, + dynamicMetadata, + }, + scope, + ); + } + + private async setModule( + { token, dynamicMetadata, type }: ModuleFactory, + scope: ModuleScope, + ): Promise { const moduleRef = new Module(type, this); moduleRef.token = token; + this.modules.set(token, moduleRef); await this.addDynamicMetadata( @@ -81,6 +126,7 @@ export class NestContainer { if (this.isGlobalModule(type, dynamicMetadata)) { this.addGlobalModule(moduleRef); } + return moduleRef; } diff --git a/packages/core/injector/internal-core-module-factory.ts b/packages/core/injector/internal-core-module-factory.ts index 99ecd589a35..daac183697b 100644 --- a/packages/core/injector/internal-core-module-factory.ts +++ b/packages/core/injector/internal-core-module-factory.ts @@ -1,7 +1,7 @@ import { Logger } from '@nestjs/common'; import { ExternalContextCreator } from '../helpers/external-context-creator'; import { HttpAdapterHost } from '../helpers/http-adapter-host'; -import { DependenciesScanner } from '../scanner'; +import { DependenciesScanner, ModuleToOverride } from '../scanner'; import { ModuleCompiler } from './compiler'; import { NestContainer } from './container'; import { InstanceLoader } from './instance-loader'; @@ -15,6 +15,7 @@ export class InternalCoreModuleFactory { scanner: DependenciesScanner, moduleCompiler: ModuleCompiler, httpAdapterHost: HttpAdapterHost, + modulesToOverride?: ModuleToOverride[], ) { return InternalCoreModule.register([ { @@ -45,6 +46,7 @@ export class InternalCoreModuleFactory { instanceLoader, moduleCompiler, container.getModules(), + modulesToOverride, ); }, }, diff --git a/packages/core/injector/lazy-module-loader.ts b/packages/core/injector/lazy-module-loader.ts index caba141163a..3dd11dd1505 100644 --- a/packages/core/injector/lazy-module-loader.ts +++ b/packages/core/injector/lazy-module-loader.ts @@ -1,5 +1,5 @@ import { DynamicModule, Type } from '@nestjs/common'; -import { DependenciesScanner } from '../scanner'; +import { DependenciesScanner, ModuleToOverride } from '../scanner'; import { ModuleCompiler } from './compiler'; import { InstanceLoader } from './instance-loader'; import { Module } from './module'; @@ -12,6 +12,7 @@ export class LazyModuleLoader { private readonly instanceLoader: InstanceLoader, private readonly moduleCompiler: ModuleCompiler, private readonly modulesContainer: ModulesContainer, + private readonly modulesToOverride?: ModuleToOverride[], ) {} public async load( @@ -21,9 +22,10 @@ export class LazyModuleLoader { | DynamicModule, ): Promise { const moduleClassOrDynamicDefinition = await loaderFn(); - const moduleInstances = await this.dependenciesScanner.scanForModules( - moduleClassOrDynamicDefinition, - ); + const moduleInstances = await this.dependenciesScanner.scanForModules({ + moduleDefinition: moduleClassOrDynamicDefinition, + modulesToOverride: this.modulesToOverride, + }); if (moduleInstances.length === 0) { // The module has been loaded already. In this case, we must // retrieve a module reference from the exising container. diff --git a/packages/core/scanner.ts b/packages/core/scanner.ts index ef4ebecc81b..53300858dde 100644 --- a/packages/core/scanner.ts +++ b/packages/core/scanner.ts @@ -53,6 +53,24 @@ interface ApplicationProviderWrapper { scope?: Scope; } +export type ModuleDefinition = + | ForwardReference + | Type + | DynamicModule + | Promise; + +export interface ModuleToOverride { + moduleToReplace: ModuleDefinition; + newModule: ModuleDefinition; +} + +interface ModulesScanParamaters { + moduleDefinition: ModuleDefinition; + scope?: Type[]; + ctxRegistry?: (ForwardReference | DynamicModule | Type)[]; + modulesToOverride?: ModuleToOverride[]; +} + export class DependenciesScanner { private readonly applicationProvidersApplyMap: ApplicationProviderWrapper[] = []; @@ -63,9 +81,9 @@ export class DependenciesScanner { private readonly applicationConfig = new ApplicationConfig(), ) {} - public async scan(module: Type) { - await this.registerCoreModule(); - await this.scanForModules(module); + public async scan(module: Type, modulesToOverride?: ModuleToOverride[]) { + await this.registerCoreModule(modulesToOverride); + await this.scanForModules({ moduleDefinition: module, modulesToOverride }); await this.scanModulesForDependencies(); this.calculateModulesDistance(); @@ -73,16 +91,20 @@ export class DependenciesScanner { this.container.bindGlobalScope(); } - public async scanForModules( - moduleDefinition: - | ForwardReference - | Type - | DynamicModule - | Promise, - scope: Type[] = [], - ctxRegistry: (ForwardReference | DynamicModule | Type)[] = [], - ): Promise { - const moduleInstance = await this.insertModule(moduleDefinition, scope); + public async scanForModules({ + moduleDefinition, + scope = [], + ctxRegistry = [], + modulesToOverride = [], + }: ModulesScanParamaters): Promise { + const moduleInstance = await this.putModule( + moduleDefinition, + modulesToOverride, + scope, + ); + moduleDefinition = + this.getOverrideModuleByModule(moduleDefinition, modulesToOverride) + ?.newModule ?? moduleDefinition; moduleDefinition = moduleDefinition instanceof Promise ? await moduleDefinition @@ -119,11 +141,12 @@ export class DependenciesScanner { if (ctxRegistry.includes(innerModule)) { continue; } - const moduleRefs = await this.scanForModules( - innerModule, - [].concat(scope, moduleDefinition), + const moduleRefs = await this.scanForModules({ + moduleDefinition: innerModule, + scope: [].concat(scope, moduleDefinition), ctxRegistry, - ); + modulesToOverride, + }); registeredModuleRefs = registeredModuleRefs.concat(moduleRefs); } if (!moduleInstance) { @@ -402,14 +425,18 @@ export class DependenciesScanner { return Reflect.getMetadata(metadataKey, metatype) || []; } - public async registerCoreModule() { + public async registerCoreModule(modulesToOverride?: ModuleToOverride[]) { const moduleDefinition = InternalCoreModuleFactory.create( this.container, this, this.container.getModuleCompiler(), this.container.getHttpAdapterHostRef(), + modulesToOverride, ); - const [instance] = await this.scanForModules(moduleDefinition); + const [instance] = await this.scanForModules({ + moduleDefinition, + modulesToOverride, + }); this.container.registerCoreModuleRef(instance); } @@ -515,4 +542,58 @@ export class DependenciesScanner { private isRequestOrTransient(scope: Scope): boolean { return scope === Scope.REQUEST || scope === Scope.TRANSIENT; } + + private putModule( + moduleDefinition: ModuleDefinition, + modulesToOverride: ModuleToOverride[], + scope: Type[], + ): Promise { + const overrideModule = this.getOverrideModuleByModule( + moduleDefinition, + modulesToOverride, + ); + if (overrideModule !== undefined) { + return this.overrideModule( + moduleDefinition, + overrideModule.newModule, + scope, + ); + } + + return this.insertModule(moduleDefinition, scope); + } + + // The 'any' definition and castings is related to forward reference, there is a better way do this? + private getOverrideModuleByModule( + module: ModuleDefinition | any, + modulesToOverride: ModuleToOverride[], + ): ModuleToOverride | undefined { + if (this.isForwardReference(module)) { + return modulesToOverride.find(moduleToOverride => { + return ( + moduleToOverride.moduleToReplace === module.forwardRef() || + (moduleToOverride.moduleToReplace as any).forwardRef?.() === + module.forwardRef() + ); + }); + } + + return modulesToOverride.find( + moduleToOverride => moduleToOverride.moduleToReplace === module, + ); + } + + private async overrideModule( + moduleToOverride: any, + newModule: any, + scope: Type[], + ): Promise { + return this.container.replaceModule( + this.isForwardReference(moduleToOverride) + ? moduleToOverride.forwardRef() + : moduleToOverride, + this.isForwardReference(newModule) ? newModule.forwardRef() : newModule, + scope, + ); + } } diff --git a/packages/core/test/injector/container.spec.ts b/packages/core/test/injector/container.spec.ts index 0d7f181c5e7..eb435142a6b 100644 --- a/packages/core/test/injector/container.spec.ts +++ b/packages/core/test/injector/container.spec.ts @@ -81,6 +81,37 @@ describe('NestContainer', () => { expect(addGlobalModuleSpy.calledOnce).to.be.true; }); }); + + describe('replaceModule', () => { + it('should replace module if already exists in collection', async () => { + @Module({}) + class ReplaceTestModule {} + + const modules = new Map(); + const setSpy = sinon.spy(modules, 'set'); + (container as any).modules = modules; + + await container.addModule(TestModule as any, []); + await container.replaceModule( + TestModule as any, + ReplaceTestModule as any, + [], + ); + + expect(setSpy.calledTwice).to.be.true; + }); + + it('should throws an exception when metatype is not defined', () => { + expect(container.addModule(undefined, [])).to.eventually.throws(); + }); + + it('should add global module when module is global', async () => { + const addGlobalModuleSpy = sinon.spy(container, 'addGlobalModule'); + await container.addModule(GlobalTestModule as any, []); + expect(addGlobalModuleSpy.calledOnce).to.be.true; + }); + }); + describe('isGlobalModule', () => { describe('when module is not globally scoped', () => { it('should return false', () => { diff --git a/packages/core/test/scanner.spec.ts b/packages/core/test/scanner.spec.ts index 8b9364ab8c8..9fe8e5e6aa3 100644 --- a/packages/core/test/scanner.spec.ts +++ b/packages/core/test/scanner.spec.ts @@ -13,7 +13,7 @@ import { UndefinedModuleException } from '../errors/exceptions/undefined-module. import { NestContainer } from '../injector/container'; import { InstanceWrapper } from '../injector/instance-wrapper'; import { MetadataScanner } from '../metadata-scanner'; -import { DependenciesScanner } from '../scanner'; +import { DependenciesScanner, ModuleToOverride } from '../scanner'; describe('DependenciesScanner', () => { class Guard {} @@ -68,11 +68,17 @@ describe('DependenciesScanner', () => { mockContainer.restore(); }); - it('should "insertModule" call twice (2 modules) container method "addModule"', async () => { - const expectation = mockContainer.expects('addModule').twice(); + it('should "putModule" call twice (2 modules) container method "addModule"', async () => { + const expectationCountAddModule = mockContainer + .expects('addModule') + .twice(); + const expectationCountReplaceModule = mockContainer + .expects('replaceModule') + .never(); await scanner.scan(TestModule as any); - expectation.verify(); + expectationCountAddModule.verify(); + expectationCountReplaceModule.verify(); }); it('should "insertProvider" call twice (2 components) container method "addProvider"', async () => { @@ -96,6 +102,134 @@ describe('DependenciesScanner', () => { expectation.verify(); }); + describe('when there is modules overrides', () => { + @Injectable() + class OverwrittenTestComponent {} + + @Controller('') + class OverwrittenControlerOne {} + + @Controller('') + class OverwrittenControllerTwo {} + + @Module({ + controllers: [OverwrittenControlerOne], + providers: [OverwrittenTestComponent], + }) + class OverwrittenModuleOne {} + + @Module({ + controllers: [OverwrittenControllerTwo], + }) + class OverwrittenModuleTwo {} + + @Module({ + imports: [OverwrittenModuleOne, OverwrittenModuleTwo], + }) + class OverrideTestModule {} + + @Injectable() + class OverrideTestComponent {} + + @Controller('') + class OverrideControllerOne {} + + @Controller('') + class OverrideControllerTwo {} + + @Module({ + controllers: [OverwrittenControlerOne], + providers: [OverrideTestComponent], + }) + class OverrideModuleOne {} + + @Module({ + controllers: [OverrideControllerTwo], + }) + class OverrideModuleTwo {} + + const modulesToOverride: ModuleToOverride[] = [ + { moduleToReplace: OverwrittenModuleOne, newModule: OverrideModuleOne }, + { moduleToReplace: OverwrittenModuleTwo, newModule: OverrideModuleTwo }, + ]; + + it('should "putModule" call twice (2 modules) container method "replaceModule"', async () => { + const expectationReplaceModuleFirst = mockContainer + .expects('replaceModule') + .once() + .withArgs(OverwrittenModuleOne, OverrideModuleOne, sinon.match.array); + const expectationReplaceModuleSecond = mockContainer + .expects('replaceModule') + .once() + .withArgs(OverwrittenModuleTwo, OverrideModuleTwo, sinon.match.array); + const expectationCountAddModule = mockContainer + .expects('addModule') + .once(); + + await scanner.scan(OverrideTestModule as any, modulesToOverride); + + expectationReplaceModuleFirst.verify(); + expectationReplaceModuleSecond.verify(); + expectationCountAddModule.verify(); + }); + + it('should "insertProvider" call once container method "addProvider"', async () => { + const expectation = mockContainer.expects('addProvider').once(); + + await scanner.scan(OverrideTestModule as any); + expectation.verify(); + }); + + it('should "insertController" call twice (2 components) container method "addController"', async () => { + const expectation = mockContainer.expects('addController').twice(); + await scanner.scan(OverrideTestModule as any); + expectation.verify(); + }); + + it('should "putModule" call container method "replaceModule" with forwardRef() when forwardRef property exists', async () => { + const overwrittenForwardRefSpy = sinon.spy(); + + @Module({}) + class OverwrittenForwardRef {} + + @Module({}) + class Overwritten { + public static forwardRef() { + overwrittenForwardRefSpy(); + return OverwrittenForwardRef; + } + } + + const overrideForwardRefSpy = sinon.spy(); + + @Module({}) + class OverrideForwardRef {} + + @Module({}) + class Override { + public static forwardRef() { + overrideForwardRefSpy(); + return OverrideForwardRef; + } + } + + @Module({ + imports: [Overwritten], + }) + class OverrideForwardRefTestModule {} + + await scanner.scan(OverrideForwardRefTestModule as any, [ + { + moduleToReplace: Overwritten, + newModule: Override, + }, + ]); + + expect(overwrittenForwardRefSpy.called).to.be.true; + expect(overrideForwardRefSpy.called).to.be.true; + }); + }); + describe('reflectDynamicMetadata', () => { describe('when param has prototype', () => { it('should call "reflectParamInjectables" and "reflectInjectables"', () => { @@ -447,14 +581,20 @@ describe('DependenciesScanner', () => { describe('scanForModules', () => { it('should throw an exception when the imports array includes undefined', () => { try { - scanner.scanForModules(UndefinedModule, [UndefinedModule]); + scanner.scanForModules({ + moduleDefinition: UndefinedModule, + scope: [UndefinedModule], + }); } catch (exception) { expect(exception instanceof UndefinedModuleException).to.be.true; } }); it('should throw an exception when the imports array includes an invalid value', () => { try { - scanner.scanForModules(InvalidModule, [InvalidModule]); + scanner.scanForModules({ + moduleDefinition: InvalidModule, + scope: [InvalidModule], + }); } catch (exception) { expect(exception instanceof InvalidModuleException).to.be.true; } diff --git a/packages/testing/testing-module.builder.ts b/packages/testing/testing-module.builder.ts index e72f6778597..3389a9361bb 100644 --- a/packages/testing/testing-module.builder.ts +++ b/packages/testing/testing-module.builder.ts @@ -3,7 +3,11 @@ import { ModuleMetadata } from '@nestjs/common/interfaces'; import { ApplicationConfig } from '@nestjs/core/application-config'; import { NestContainer } from '@nestjs/core/injector/container'; import { MetadataScanner } from '@nestjs/core/metadata-scanner'; -import { DependenciesScanner } from '@nestjs/core/scanner'; +import { + DependenciesScanner, + ModuleToOverride, + ModuleDefinition, +} from '@nestjs/core/scanner'; import { MockFactory, OverrideBy, @@ -17,6 +21,10 @@ export class TestingModuleBuilder { private readonly applicationConfig = new ApplicationConfig(); private readonly container = new NestContainer(this.applicationConfig); private readonly overloadsMap = new Map(); + private readonly overloadsModuleMap = new Map< + ModuleDefinition, + ModuleDefinition + >(); private readonly scanner: DependenciesScanner; private readonly instanceLoader = new TestingInstanceLoader(this.container); private readonly module: any; @@ -62,9 +70,20 @@ export class TestingModuleBuilder { return this.override(typeOrToken, true); } + public overrideModule(moduleToOverride: ModuleDefinition): { + useModule: (newModule: ModuleDefinition) => TestingModuleBuilder; + } { + return { + useModule: newModule => { + this.overloadsModuleMap.set(moduleToOverride, newModule); + return this; + }, + }; + } + public async compile(): Promise { this.applyLogger(); - await this.scanner.scan(this.module); + await this.scanner.scan(this.module, this.getModuleOverloads()); this.applyOverloadsMap(); await this.instanceLoader.createInstancesOfDependencies( @@ -105,6 +124,15 @@ export class TestingModuleBuilder { }); } + private getModuleOverloads(): ModuleToOverride[] { + return [...this.overloadsModuleMap.entries()].map( + ([moduleToReplace, newModule]) => ({ + moduleToReplace, + newModule, + }), + ); + } + private getRootModule() { const modules = this.container.getModules().values(); return modules.next().value;