diff --git a/packages/react-dom/src/__tests__/ReactComponentLifeCycle-test.js b/packages/react-dom/src/__tests__/ReactComponentLifeCycle-test.js index 8879bca9987e3c..2bbd3c57c258f5 100644 --- a/packages/react-dom/src/__tests__/ReactComponentLifeCycle-test.js +++ b/packages/react-dom/src/__tests__/ReactComponentLifeCycle-test.js @@ -1201,6 +1201,40 @@ describe('ReactComponentLifeCycle', () => { expect(log).toEqual([]); }); + it('should pass previous state to shouldComponentUpdate even with getDerivedStateFromProps', () => { + const divRef = React.createRef(); + class SimpleComponent extends React.Component { + constructor(props) { + super(props); + this.state = { + value: props.value, + }; + } + + static getDerivedStateFromProps(nextProps, prevState) { + if (nextProps.value === prevState.value) { + return null; + } + return {value: nextProps.value}; + } + + shouldComponentUpdate(nextProps, nextState) { + return nextState.value !== this.state.value; + } + + render() { + return
value: {this.state.value}
; + } + } + + const div = document.createElement('div'); + + ReactDOM.render(, div); + expect(divRef.current.textContent).toBe('value: initial'); + ReactDOM.render(, div); + expect(divRef.current.textContent).toBe('value: updated'); + }); + it('should call getSnapshotBeforeUpdate before mutations are committed', () => { const log = []; diff --git a/packages/react-test-renderer/src/ReactShallowRenderer.js b/packages/react-test-renderer/src/ReactShallowRenderer.js index a8ec56f014aee1..778e5943ec9c1b 100644 --- a/packages/react-test-renderer/src/ReactShallowRenderer.js +++ b/packages/react-test-renderer/src/ReactShallowRenderer.js @@ -529,7 +529,20 @@ class ReactShallowRenderer { this._updater, ); - this._updateStateFromStaticLifecycle(element.props); + if (typeof element.type.getDerivedStateFromProps === 'function') { + const partialState = element.type.getDerivedStateFromProps.call( + null, + element.props, + this._instance.state, + ); + if (partialState != null) { + this._instance.state = Object.assign( + {}, + this._instance.state, + partialState, + ); + } + } if (element.type.hasOwnProperty('contextTypes')) { currentlyValidatingElement = element; @@ -653,10 +666,19 @@ class ReactShallowRenderer { } } } - this._updateStateFromStaticLifecycle(props); // Read state after cWRP in case it calls setState - const state = this._newState || oldState; + let state = this._newState || oldState; + if (typeof type.getDerivedStateFromProps === 'function') { + const partialState = type.getDerivedStateFromProps.call( + null, + props, + state, + ); + if (partialState != null) { + state = Object.assign({}, state, partialState); + } + } let shouldUpdate = true; if (this._forcedUpdate) { @@ -692,6 +714,7 @@ class ReactShallowRenderer { this._instance.context = context; this._instance.props = props; this._instance.state = state; + this._newState = null; if (shouldUpdate) { this._rendered = this._instance.render(); @@ -699,27 +722,6 @@ class ReactShallowRenderer { // Intentionally do not call componentDidUpdate() // because DOM refs are not available. } - - _updateStateFromStaticLifecycle(props: Object) { - if (this._element === null) { - return; - } - const {type} = this._element; - - if (typeof type.getDerivedStateFromProps === 'function') { - const oldState = this._newState || this._instance.state; - const partialState = type.getDerivedStateFromProps.call( - null, - props, - oldState, - ); - - if (partialState != null) { - const newState = Object.assign({}, oldState, partialState); - this._instance.state = this._newState = newState; - } - } - } } let currentlyValidatingElement = null; diff --git a/packages/react-test-renderer/src/__tests__/ReactShallowRenderer-test.js b/packages/react-test-renderer/src/__tests__/ReactShallowRenderer-test.js index b6c4259af9f407..d83dada160f909 100644 --- a/packages/react-test-renderer/src/__tests__/ReactShallowRenderer-test.js +++ b/packages/react-test-renderer/src/__tests__/ReactShallowRenderer-test.js @@ -942,6 +942,42 @@ describe('ReactShallowRenderer', () => { expect(result).toEqual(
value:1
); }); + it('should pass previous state to shouldComponentUpdate even with getDerivedStateFromProps', () => { + class SimpleComponent extends React.Component { + constructor(props) { + super(props); + this.state = { + value: props.value, + }; + } + + static getDerivedStateFromProps(nextProps, prevState) { + if (nextProps.value === prevState.value) { + return null; + } + return {value: nextProps.value}; + } + + shouldComponentUpdate(nextProps, nextState) { + return nextState.value !== this.state.value; + } + + render() { + return
{`value:${this.state.value}`}
; + } + } + + const shallowRenderer = createRenderer(); + const initialResult = shallowRenderer.render( + , + ); + expect(initialResult).toEqual(
value:initial
); + const updatedResult = shallowRenderer.render( + , + ); + expect(updatedResult).toEqual(
value:updated
); + }); + it('can setState with an updater function', () => { let instance;