diff --git a/src/components/connect.js b/src/components/connect.js index 02510299b..20753133b 100644 --- a/src/components/connect.js +++ b/src/components/connect.js @@ -18,51 +18,33 @@ function getDisplayName(WrappedComponent) { return WrappedComponent.displayName || WrappedComponent.name || 'Component' } +function checkStateShape(stateProps, dispatch) { + invariant( + isPlainObject(stateProps), + '`%sToProps` must return an object. Instead received %s.', + dispatch ? 'mapDispatch' : 'mapState', + stateProps + ) + return stateProps +} + // Helps track hot reloading. let nextVersion = 0 export default function connect(mapStateToProps, mapDispatchToProps, mergeProps, options = {}) { const shouldSubscribe = Boolean(mapStateToProps) - const finalMapStateToProps = mapStateToProps || defaultMapStateToProps - const finalMapDispatchToProps = isPlainObject(mapDispatchToProps) ? + const mapState = mapStateToProps || defaultMapStateToProps + const mapDispatch = isPlainObject(mapDispatchToProps) ? wrapActionCreators(mapDispatchToProps) : mapDispatchToProps || defaultMapDispatchToProps + const finalMergeProps = mergeProps || defaultMergeProps - const doStatePropsDependOnOwnProps = finalMapStateToProps.length !== 1 - const doDispatchPropsDependOnOwnProps = finalMapDispatchToProps.length !== 1 + const checkMergedEquals = finalMergeProps !== defaultMergeProps const { pure = true, withRef = false } = options // Helps track hot reloading. const version = nextVersion++ - function computeStateProps(store, props) { - const state = store.getState() - const stateProps = doStatePropsDependOnOwnProps ? - finalMapStateToProps(state, props) : - finalMapStateToProps(state) - - invariant( - isPlainObject(stateProps), - '`mapStateToProps` must return an object. Instead received %s.', - stateProps - ) - return stateProps - } - - function computeDispatchProps(store, props) { - const { dispatch } = store - const dispatchProps = doDispatchPropsDependOnOwnProps ? - finalMapDispatchToProps(dispatch, props) : - finalMapDispatchToProps(dispatch) - - invariant( - isPlainObject(dispatchProps), - '`mapDispatchToProps` must return an object. Instead received %s.', - dispatchProps - ) - return dispatchProps - } - function computeMergedProps(stateProps, dispatchProps, parentProps) { const mergedProps = finalMergeProps(stateProps, dispatchProps, parentProps) invariant( @@ -96,8 +78,47 @@ export default function connect(mapStateToProps, mapDispatchToProps, mergeProps, this.clearCache() } + computeStateProps(store, props) { + if (!this.finalMapStateToProps) { + return this.configureFinalMapState(store, props) + } + const state = store.getState() + const stateProps = this.doStatePropsDependOnOwnProps ? + this.finalMapStateToProps(state, props) : + this.finalMapStateToProps(state) + + return checkStateShape(stateProps) + } + + configureFinalMapState(store, props) { + const mappedState = mapState(store.getState(), props) + const isFactory = typeof mappedState === 'function' + this.finalMapStateToProps = isFactory ? mappedState : mapState + this.doStatePropsDependOnOwnProps = this.finalMapStateToProps.length !== 1 + return isFactory ? this.computeStateProps(store, props) : checkStateShape(mappedState) + } + + computeDispatchProps(store, props) { + if (!this.finalMapDispatchToProps) { + return this.configureFinalMapDispatch(store, props) + } + const { dispatch } = store + const dispatchProps = this.doDispatchPropsDependOnOwnProps ? + this.finalMapDispatchToProps(dispatch, props) : + this.finalMapDispatchToProps(dispatch) + return checkStateShape(dispatchProps, true) + } + + configureFinalMapDispatch(store, props) { + const mappedDispatch = mapDispatch(store.dispatch, props) + const isFactory = typeof mappedDispatch === 'function' + this.finalMapDispatchToProps = isFactory ? mappedDispatch : mapDispatch + this.doDispatchPropsDependOnOwnProps = this.finalMapDispatchToProps.length !== 1 + return isFactory ? this.computeDispatchProps(store, props) : checkStateShape(mappedDispatch, true) + } + updateStatePropsIfNeeded() { - const nextStateProps = computeStateProps(this.store, this.props) + const nextStateProps = this.computeStateProps(this.store, this.props) if (this.stateProps && shallowEqual(nextStateProps, this.stateProps)) { return false } @@ -107,7 +128,7 @@ export default function connect(mapStateToProps, mapDispatchToProps, mergeProps, } updateDispatchPropsIfNeeded() { - const nextDispatchProps = computeDispatchProps(this.store, this.props) + const nextDispatchProps = this.computeDispatchProps(this.store, this.props) if (this.dispatchProps && shallowEqual(nextDispatchProps, this.dispatchProps)) { return false } @@ -116,12 +137,14 @@ export default function connect(mapStateToProps, mapDispatchToProps, mergeProps, return true } - updateMergedProps() { - this.mergedProps = computeMergedProps( - this.stateProps, - this.dispatchProps, - this.props - ) + updateMergedPropsIfNeeded() { + const nextMergedProps = computeMergedProps(this.stateProps, this.dispatchProps, this.props) + if (this.mergedProps && checkMergedEquals && shallowEqual(nextMergedProps, this.mergedProps)) { + return false + } + + this.mergedProps = nextMergedProps + return true } isSubscribed() { @@ -164,6 +187,8 @@ export default function connect(mapStateToProps, mapDispatchToProps, mergeProps, this.haveOwnPropsChanged = true this.hasStoreStateChanged = true this.renderedElement = null + this.finalMapDispatchToProps = null + this.finalMapStateToProps = null } handleChange() { @@ -203,10 +228,10 @@ export default function connect(mapStateToProps, mapDispatchToProps, mergeProps, let shouldUpdateDispatchProps = true if (pure && renderedElement) { shouldUpdateStateProps = hasStoreStateChanged || ( - haveOwnPropsChanged && doStatePropsDependOnOwnProps + haveOwnPropsChanged && this.doStatePropsDependOnOwnProps ) shouldUpdateDispatchProps = - haveOwnPropsChanged && doDispatchPropsDependOnOwnProps + haveOwnPropsChanged && this.doDispatchPropsDependOnOwnProps } let haveStatePropsChanged = false @@ -224,7 +249,7 @@ export default function connect(mapStateToProps, mapDispatchToProps, mergeProps, haveDispatchPropsChanged || haveOwnPropsChanged ) { - this.updateMergedProps() + haveMergedPropsChanged = this.updateMergedPropsIfNeeded() } else { haveMergedPropsChanged = false } diff --git a/test/components/connect.spec.js b/test/components/connect.spec.js index 5fefb76da..765b2422e 100644 --- a/test/components/connect.spec.js +++ b/test/components/connect.spec.js @@ -1509,5 +1509,108 @@ describe('React', () => { // But render is not because it did not make any actual changes expect(renderCalls).toBe(1) }) + + it('should allow providing a factory function to mapStateToProps', () => { + let updatedCount = 0 + let memoizedReturnCount = 0 + const store = createStore(() => ({ value: 1 })) + + const mapStateFactory = () => { + let lastProp, lastVal, lastResult + return (state, props) => { + if (props.name === lastProp && lastVal === state.value) { + memoizedReturnCount++ + return lastResult + } + lastProp = props.name + lastVal = state.value + return lastResult = { someObject: { prop: props.name, stateVal: state.value } } + } + } + + @connect(mapStateFactory) + class Container extends Component { + componentWillUpdate() { + updatedCount++ + } + render() { + return
+ } + } + + TestUtils.renderIntoDocument( + +
+ + +
+
+ ) + + store.dispatch({ type: 'test' }) + expect(updatedCount).toBe(0) + expect(memoizedReturnCount).toBe(2) + }) + + it('should allow providing a factory function to mapDispatchToProps', () => { + let updatedCount = 0 + let memoizedReturnCount = 0 + const store = createStore(() => ({ value: 1 })) + + const mapDispatchFactory = () => { + let lastProp, lastResult + return (dispatch, props) => { + if (props.name === lastProp) { + memoizedReturnCount++ + return lastResult + } + lastProp = props.name + return lastResult = { someObject: { dispatchFn: dispatch } } + } + } + function mergeParentDispatch(stateProps, dispatchProps, parentProps) { + return { ...stateProps, ...dispatchProps, name: parentProps.name } + } + + @connect(null, mapDispatchFactory, mergeParentDispatch) + class Passthrough extends Component { + componentWillUpdate() { + updatedCount++ + } + render() { + return
+ } + } + + class Container extends Component { + constructor(props) { + super(props) + this.state = { count: 0 } + } + componentDidMount() { + this.setState({ count: 1 }) + } + render() { + const { count } = this.state + return ( +
+ + +
+ ) + } + } + + TestUtils.renderIntoDocument( + + + + ) + + store.dispatch({ type: 'test' }) + expect(updatedCount).toBe(0) + expect(memoizedReturnCount).toBe(2) + }) + }) })