diff --git a/src/when.js b/src/when.js index 39f1742..230ba48 100644 --- a/src/when.js +++ b/src/when.js @@ -88,8 +88,9 @@ class WhenMock { return mockImplementation(...args) } } - - return defaultImplementation ? defaultImplementation(...args) : undefined + return defaultImplementation ? defaultImplementation(...args) + : (typeof fn.__whenMock__._origMock === 'function' + ? fn.__whenMock__._origMock(...args) : undefined) }) return { diff --git a/src/when.test.js b/src/when.test.js index 506be97..0b81bb1 100644 --- a/src/when.test.js +++ b/src/when.test.js @@ -819,5 +819,30 @@ describe('When', () => { const returnValue = theInstance.theMethod(1) expect(returnValue).toBe('mock') }) + + it('keeps default function implementation when not matched', () => { + class TheClass { + fn () { + return 'real' + } + } + const instance = new TheClass() + const spy = jest.spyOn(instance, 'fn') + when(spy) + .calledWith(1) + .mockReturnValue('mock') + expect(instance.fn(2)).toBe('real') + }) + + it('keeps default mock implementation when not matched', () => { + const fn = jest.fn(() => { + return 'real' + }) + when(fn) + .calledWith(1) + .mockReturnValue('mock') + expect(fn(1)).toBe('mock') + expect(fn(2)).toBe('real') + }) }) })