From 9ed2ad1df422c29335dedf09ea9db9a155384e05 Mon Sep 17 00:00:00 2001 From: Mark Erikson Date: Sat, 27 Aug 2022 14:30:32 -0400 Subject: [PATCH] Consolidate RTKQ middleware to simplify stack size Previously, the RTKQ middleware was made up of 7 "sub-middleware", each encapsulating a different responsibility (polling, caching, etc). However, that meant that each middleware was called on _every_ action, even if it wasn't RTKQ-related. That adds to call stack size. Almost all the logic runs _after_ the action is handled by the reducers. So, I've reworked the files to create handlers (conceptually similar to the `(action) => {}` part of a middleware, with a couple extra args), and rewritten the top-level middleware to run those in a loop. --- .../core/buildMiddleware/batchActions.ts | 54 ++-- .../core/buildMiddleware/cacheCollection.ts | 153 +++++----- .../core/buildMiddleware/cacheLifecycle.ts | 264 +++++++++--------- .../core/buildMiddleware/devMiddleware.ts | 53 ++-- .../src/query/core/buildMiddleware/index.ts | 112 +++++--- .../buildMiddleware/invalidationByTags.ts | 75 +++-- .../src/query/core/buildMiddleware/polling.ts | 212 +++++++------- .../core/buildMiddleware/queryLifecycle.ts | 147 +++++----- .../src/query/core/buildMiddleware/types.ts | 11 +- .../buildMiddleware/windowEventHandling.ts | 30 +- 10 files changed, 561 insertions(+), 550 deletions(-) diff --git a/packages/toolkit/src/query/core/buildMiddleware/batchActions.ts b/packages/toolkit/src/query/core/buildMiddleware/batchActions.ts index a002f68cb3..f0bb6bacd9 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/batchActions.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/batchActions.ts @@ -1,5 +1,5 @@ import type { QueryThunk, RejectedAction } from '../buildThunks' -import type { SubMiddlewareBuilder } from './types' +import type { InternalHandlerBuilder } from './types' // Copied from https://github.com/feross/queue-microtask let promise: Promise @@ -14,44 +14,38 @@ const queueMicrotaskShim = }, 0) ) -export const build: SubMiddlewareBuilder = ({ +export const buildBatchedActionsHandler: InternalHandlerBuilder = ({ api, - context: { apiUid }, queryThunk, - reducerPath, }) => { - return (mwApi) => { - let abortedQueryActionsQueue: RejectedAction[] = [] - let dispatchQueued = false + let abortedQueryActionsQueue: RejectedAction[] = [] + let dispatchQueued = false - return (next) => (action) => { - if (queryThunk.rejected.match(action)) { - const { condition, arg } = action.meta + return (action, mwApi) => { + if (queryThunk.rejected.match(action)) { + const { condition, arg } = action.meta - if (condition && arg.subscribe) { - // request was aborted due to condition (another query already running) - // _Don't_ dispatch right away - queue it for a debounced grouped dispatch - abortedQueryActionsQueue.push(action) + if (condition && arg.subscribe) { + // request was aborted due to condition (another query already running) + // _Don't_ dispatch right away - queue it for a debounced grouped dispatch + abortedQueryActionsQueue.push(action) - if (!dispatchQueued) { - queueMicrotaskShim(() => { - mwApi.dispatch( - api.internalActions.subscriptionRequestsRejected( - abortedQueryActionsQueue - ) + if (!dispatchQueued) { + queueMicrotaskShim(() => { + mwApi.dispatch( + api.internalActions.subscriptionRequestsRejected( + abortedQueryActionsQueue ) - abortedQueryActionsQueue = [] - }) - dispatchQueued = true - } - // _Don't_ let the action reach the reducers now! - return + ) + abortedQueryActionsQueue = [] + }) + dispatchQueued = true } + // _Don't_ let the action reach the reducers now! + return false } - - const result = next(action) - - return result } + + return true } } diff --git a/packages/toolkit/src/query/core/buildMiddleware/cacheCollection.ts b/packages/toolkit/src/query/core/buildMiddleware/cacheCollection.ts index 825c2747c9..4f17d37e75 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/cacheCollection.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/cacheCollection.ts @@ -1,12 +1,12 @@ import type { BaseQueryFn } from '../../baseQueryTypes' import type { QueryDefinition } from '../../endpointDefinitions' import type { ConfigState, QueryCacheKey } from '../apiState' -import { QuerySubstateIdentifier } from '../apiState' import type { QueryStateMeta, SubMiddlewareApi, - SubMiddlewareBuilder, TimeoutId, + InternalHandlerBuilder, + ApiMiddlewareInternalHandler, } from './types' export type ReferenceCacheCollection = never @@ -45,7 +45,11 @@ declare module '../../endpointDefinitions' { export const THIRTY_TWO_BIT_MAX_INT = 2_147_483_647 export const THIRTY_TWO_BIT_MAX_TIMER_SECONDS = 2_147_483_647 / 1_000 - 1 -export const build: SubMiddlewareBuilder = ({ reducerPath, api, context }) => { +export const buildCacheCollectionHandler: InternalHandlerBuilder = ({ + reducerPath, + api, + context, +}) => { const { removeQueryResult, unsubscribeQueryResult } = api.internalActions function anySubscriptionsRemainingForKey( @@ -57,88 +61,83 @@ export const build: SubMiddlewareBuilder = ({ reducerPath, api, context }) => { return !!subscriptions && !isObjectEmpty(subscriptions) } - return (mwApi) => { - const currentRemovalTimeouts: QueryStateMeta = {} + const currentRemovalTimeouts: QueryStateMeta = {} - return (next) => - (action): any => { - const result = next(action) + const handler: ApiMiddlewareInternalHandler = (action, mwApi) => { + if (unsubscribeQueryResult.match(action)) { + const state = mwApi.getState()[reducerPath] + const { queryCacheKey } = action.payload - if (unsubscribeQueryResult.match(action)) { - const state = mwApi.getState()[reducerPath] - const { queryCacheKey } = action.payload - - handleUnsubscribe( - queryCacheKey, - state.queries[queryCacheKey]?.endpointName, - mwApi, - state.config - ) - } - - if (api.util.resetApiState.match(action)) { - for (const [key, timeout] of Object.entries(currentRemovalTimeouts)) { - if (timeout) clearTimeout(timeout) - delete currentRemovalTimeouts[key] - } - } - - if (context.hasRehydrationInfo(action)) { - const state = mwApi.getState()[reducerPath] - const { queries } = context.extractRehydrationInfo(action)! - for (const [queryCacheKey, queryState] of Object.entries(queries)) { - // Gotcha: - // If rehydrating before the endpoint has been injected,the global `keepUnusedDataFor` - // will be used instead of the endpoint-specific one. - handleUnsubscribe( - queryCacheKey as QueryCacheKey, - queryState?.endpointName, - mwApi, - state.config - ) - } - } + handleUnsubscribe( + queryCacheKey, + state.queries[queryCacheKey]?.endpointName, + mwApi, + state.config + ) + } - return result + if (api.util.resetApiState.match(action)) { + for (const [key, timeout] of Object.entries(currentRemovalTimeouts)) { + if (timeout) clearTimeout(timeout) + delete currentRemovalTimeouts[key] } + } - function handleUnsubscribe( - queryCacheKey: QueryCacheKey, - endpointName: string | undefined, - api: SubMiddlewareApi, - config: ConfigState - ) { - const endpointDefinition = context.endpointDefinitions[ - endpointName! - ] as QueryDefinition - const keepUnusedDataFor = - endpointDefinition?.keepUnusedDataFor ?? config.keepUnusedDataFor - - if (keepUnusedDataFor === Infinity) { - // Hey, user said keep this forever! - return + if (context.hasRehydrationInfo(action)) { + const state = mwApi.getState()[reducerPath] + const { queries } = context.extractRehydrationInfo(action)! + for (const [queryCacheKey, queryState] of Object.entries(queries)) { + // Gotcha: + // If rehydrating before the endpoint has been injected,the global `keepUnusedDataFor` + // will be used instead of the endpoint-specific one. + handleUnsubscribe( + queryCacheKey as QueryCacheKey, + queryState?.endpointName, + mwApi, + state.config + ) } - // Prevent `setTimeout` timers from overflowing a 32-bit internal int, by - // clamping the max value to be at most 1000ms less than the 32-bit max. - // Look, a 24.8-day keepalive ought to be enough for anybody, right? :) - // Also avoid negative values too. - const finalKeepUnusedDataFor = Math.max( - 0, - Math.min(keepUnusedDataFor, THIRTY_TWO_BIT_MAX_TIMER_SECONDS) - ) + } + } - if (!anySubscriptionsRemainingForKey(queryCacheKey, api)) { - const currentTimeout = currentRemovalTimeouts[queryCacheKey] - if (currentTimeout) { - clearTimeout(currentTimeout) - } - currentRemovalTimeouts[queryCacheKey] = setTimeout(() => { - if (!anySubscriptionsRemainingForKey(queryCacheKey, api)) { - api.dispatch(removeQueryResult({ queryCacheKey })) - } - delete currentRemovalTimeouts![queryCacheKey] - }, finalKeepUnusedDataFor * 1000) + function handleUnsubscribe( + queryCacheKey: QueryCacheKey, + endpointName: string | undefined, + api: SubMiddlewareApi, + config: ConfigState + ) { + const endpointDefinition = context.endpointDefinitions[ + endpointName! + ] as QueryDefinition + const keepUnusedDataFor = + endpointDefinition?.keepUnusedDataFor ?? config.keepUnusedDataFor + + if (keepUnusedDataFor === Infinity) { + // Hey, user said keep this forever! + return + } + // Prevent `setTimeout` timers from overflowing a 32-bit internal int, by + // clamping the max value to be at most 1000ms less than the 32-bit max. + // Look, a 24.8-day keepalive ought to be enough for anybody, right? :) + // Also avoid negative values too. + const finalKeepUnusedDataFor = Math.max( + 0, + Math.min(keepUnusedDataFor, THIRTY_TWO_BIT_MAX_TIMER_SECONDS) + ) + + if (!anySubscriptionsRemainingForKey(queryCacheKey, api)) { + const currentTimeout = currentRemovalTimeouts[queryCacheKey] + if (currentTimeout) { + clearTimeout(currentTimeout) } + currentRemovalTimeouts[queryCacheKey] = setTimeout(() => { + if (!anySubscriptionsRemainingForKey(queryCacheKey, api)) { + api.dispatch(removeQueryResult({ queryCacheKey })) + } + delete currentRemovalTimeouts![queryCacheKey] + }, finalKeepUnusedDataFor * 1000) } } + + return handler } diff --git a/packages/toolkit/src/query/core/buildMiddleware/cacheLifecycle.ts b/packages/toolkit/src/query/core/buildMiddleware/cacheLifecycle.ts index 5fa44cac8d..21a0e4af77 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/cacheLifecycle.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/cacheLifecycle.ts @@ -11,9 +11,10 @@ import type { import { getMutationCacheKey } from '../buildSlice' import type { PatchCollection, Recipe } from '../buildThunks' import type { + ApiMiddlewareInternalHandler, + InternalHandlerBuilder, PromiseWithKnownReason, SubMiddlewareApi, - SubMiddlewareBuilder, } from './types' export type ReferenceCacheLifecycle = never @@ -176,7 +177,7 @@ const neverResolvedError = new Error( message: 'Promise never resolved before cacheEntryRemoved.' } -export const build: SubMiddlewareBuilder = ({ +export const buildCacheLifecycleHandler: InternalHandlerBuilder = ({ api, reducerPath, context, @@ -185,148 +186,145 @@ export const build: SubMiddlewareBuilder = ({ }) => { const isQueryThunk = isAsyncThunkAction(queryThunk) const isMutationThunk = isAsyncThunkAction(mutationThunk) - const isFullfilledThunk = isFulfilled(queryThunk, mutationThunk) + const isFulfilledThunk = isFulfilled(queryThunk, mutationThunk) - return (mwApi) => { - type CacheLifecycle = { - valueResolved?(value: { data: unknown; meta: unknown }): unknown - cacheEntryRemoved(): void - } - const lifecycleMap: Record = {} - - return (next) => - (action): any => { - const stateBefore = mwApi.getState() - - const result = next(action) - - const cacheKey = getCacheKey(action) + type CacheLifecycle = { + valueResolved?(value: { data: unknown; meta: unknown }): unknown + cacheEntryRemoved(): void + } + const lifecycleMap: Record = {} - if (queryThunk.pending.match(action)) { - const oldState = stateBefore[reducerPath].queries[cacheKey] - const state = mwApi.getState()[reducerPath].queries[cacheKey] - if (!oldState && state) { - handleNewKey( - action.meta.arg.endpointName, - action.meta.arg.originalArgs, - cacheKey, - mwApi, - action.meta.requestId - ) - } - } else if (mutationThunk.pending.match(action)) { - const state = mwApi.getState()[reducerPath].mutations[cacheKey] - if (state) { - handleNewKey( - action.meta.arg.endpointName, - action.meta.arg.originalArgs, - cacheKey, - mwApi, - action.meta.requestId - ) - } - } else if (isFullfilledThunk(action)) { - const lifecycle = lifecycleMap[cacheKey] - if (lifecycle?.valueResolved) { - lifecycle.valueResolved({ - data: action.payload, - meta: action.meta.baseQueryMeta, - }) - delete lifecycle.valueResolved - } - } else if ( - api.internalActions.removeQueryResult.match(action) || - api.internalActions.removeMutationResult.match(action) - ) { - const lifecycle = lifecycleMap[cacheKey] - if (lifecycle) { - delete lifecycleMap[cacheKey] - lifecycle.cacheEntryRemoved() - } - } else if (api.util.resetApiState.match(action)) { - for (const [cacheKey, lifecycle] of Object.entries(lifecycleMap)) { - delete lifecycleMap[cacheKey] - lifecycle.cacheEntryRemoved() - } - } + const handler: ApiMiddlewareInternalHandler = ( + action, + mwApi, + stateBefore + ) => { + const cacheKey = getCacheKey(action) - return result + if (queryThunk.pending.match(action)) { + const oldState = stateBefore[reducerPath].queries[cacheKey] + const state = mwApi.getState()[reducerPath].queries[cacheKey] + if (!oldState && state) { + handleNewKey( + action.meta.arg.endpointName, + action.meta.arg.originalArgs, + cacheKey, + mwApi, + action.meta.requestId + ) + } + } else if (mutationThunk.pending.match(action)) { + const state = mwApi.getState()[reducerPath].mutations[cacheKey] + if (state) { + handleNewKey( + action.meta.arg.endpointName, + action.meta.arg.originalArgs, + cacheKey, + mwApi, + action.meta.requestId + ) + } + } else if (isFulfilledThunk(action)) { + const lifecycle = lifecycleMap[cacheKey] + if (lifecycle?.valueResolved) { + lifecycle.valueResolved({ + data: action.payload, + meta: action.meta.baseQueryMeta, + }) + delete lifecycle.valueResolved + } + } else if ( + api.internalActions.removeQueryResult.match(action) || + api.internalActions.removeMutationResult.match(action) + ) { + const lifecycle = lifecycleMap[cacheKey] + if (lifecycle) { + delete lifecycleMap[cacheKey] + lifecycle.cacheEntryRemoved() + } + } else if (api.util.resetApiState.match(action)) { + for (const [cacheKey, lifecycle] of Object.entries(lifecycleMap)) { + delete lifecycleMap[cacheKey] + lifecycle.cacheEntryRemoved() } - - function getCacheKey(action: any) { - if (isQueryThunk(action)) return action.meta.arg.queryCacheKey - if (isMutationThunk(action)) return action.meta.requestId - if (api.internalActions.removeQueryResult.match(action)) - return action.payload.queryCacheKey - if (api.internalActions.removeMutationResult.match(action)) - return getMutationCacheKey(action.payload) - return '' } + } - function handleNewKey( - endpointName: string, - originalArgs: any, - queryCacheKey: string, - mwApi: SubMiddlewareApi, - requestId: string - ) { - const endpointDefinition = context.endpointDefinitions[endpointName] - const onCacheEntryAdded = endpointDefinition?.onCacheEntryAdded - if (!onCacheEntryAdded) return + function getCacheKey(action: any) { + if (isQueryThunk(action)) return action.meta.arg.queryCacheKey + if (isMutationThunk(action)) return action.meta.requestId + if (api.internalActions.removeQueryResult.match(action)) + return action.payload.queryCacheKey + if (api.internalActions.removeMutationResult.match(action)) + return getMutationCacheKey(action.payload) + return '' + } - let lifecycle = {} as CacheLifecycle + function handleNewKey( + endpointName: string, + originalArgs: any, + queryCacheKey: string, + mwApi: SubMiddlewareApi, + requestId: string + ) { + const endpointDefinition = context.endpointDefinitions[endpointName] + const onCacheEntryAdded = endpointDefinition?.onCacheEntryAdded + if (!onCacheEntryAdded) return - const cacheEntryRemoved = new Promise((resolve) => { - lifecycle.cacheEntryRemoved = resolve - }) - const cacheDataLoaded: PromiseWithKnownReason< - { data: unknown; meta: unknown }, - typeof neverResolvedError - > = Promise.race([ - new Promise<{ data: unknown; meta: unknown }>((resolve) => { - lifecycle.valueResolved = resolve - }), - cacheEntryRemoved.then(() => { - throw neverResolvedError - }), - ]) - // prevent uncaught promise rejections from happening. - // if the original promise is used in any way, that will create a new promise that will throw again - cacheDataLoaded.catch(() => {}) - lifecycleMap[queryCacheKey] = lifecycle - const selector = (api.endpoints[endpointName] as any).select( - endpointDefinition.type === DefinitionType.query - ? originalArgs - : queryCacheKey - ) + let lifecycle = {} as CacheLifecycle - const extra = mwApi.dispatch((_, __, extra) => extra) - const lifecycleApi = { - ...mwApi, - getCacheEntry: () => selector(mwApi.getState()), - requestId, - extra, - updateCachedData: (endpointDefinition.type === DefinitionType.query - ? (updateRecipe: Recipe) => - mwApi.dispatch( - api.util.updateQueryData( - endpointName as never, - originalArgs, - updateRecipe - ) - ) - : undefined) as any, + const cacheEntryRemoved = new Promise((resolve) => { + lifecycle.cacheEntryRemoved = resolve + }) + const cacheDataLoaded: PromiseWithKnownReason< + { data: unknown; meta: unknown }, + typeof neverResolvedError + > = Promise.race([ + new Promise<{ data: unknown; meta: unknown }>((resolve) => { + lifecycle.valueResolved = resolve + }), + cacheEntryRemoved.then(() => { + throw neverResolvedError + }), + ]) + // prevent uncaught promise rejections from happening. + // if the original promise is used in any way, that will create a new promise that will throw again + cacheDataLoaded.catch(() => {}) + lifecycleMap[queryCacheKey] = lifecycle + const selector = (api.endpoints[endpointName] as any).select( + endpointDefinition.type === DefinitionType.query + ? originalArgs + : queryCacheKey + ) - cacheDataLoaded, - cacheEntryRemoved, - } + const extra = mwApi.dispatch((_, __, extra) => extra) + const lifecycleApi = { + ...mwApi, + getCacheEntry: () => selector(mwApi.getState()), + requestId, + extra, + updateCachedData: (endpointDefinition.type === DefinitionType.query + ? (updateRecipe: Recipe) => + mwApi.dispatch( + api.util.updateQueryData( + endpointName as never, + originalArgs, + updateRecipe + ) + ) + : undefined) as any, - const runningHandler = onCacheEntryAdded(originalArgs, lifecycleApi) - // if a `neverResolvedError` was thrown, but not handled in the running handler, do not let it leak out further - Promise.resolve(runningHandler).catch((e) => { - if (e === neverResolvedError) return - throw e - }) + cacheDataLoaded, + cacheEntryRemoved, } + + const runningHandler = onCacheEntryAdded(originalArgs, lifecycleApi) + // if a `neverResolvedError` was thrown, but not handled in the running handler, do not let it leak out further + Promise.resolve(runningHandler).catch((e) => { + if (e === neverResolvedError) return + throw e + }) } + + return handler } diff --git a/packages/toolkit/src/query/core/buildMiddleware/devMiddleware.ts b/packages/toolkit/src/query/core/buildMiddleware/devMiddleware.ts index 3149d174a7..caec68a216 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/devMiddleware.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/devMiddleware.ts @@ -1,47 +1,34 @@ -import type { SubMiddlewareBuilder } from './types' +import type { InternalHandlerBuilder } from './types' -export const build: SubMiddlewareBuilder = ({ +export const buildDevCheckHandler: InternalHandlerBuilder = ({ api, context: { apiUid }, reducerPath, }) => { - return (mwApi) => { - let initialized = false - return (next) => (action) => { - if (!initialized) { - initialized = true - // dispatch before any other action - mwApi.dispatch(api.internalActions.middlewareRegistered(apiUid)) - } - - const result = next(action) - - if (api.util.resetApiState.match(action)) { - // dispatch after api reset - mwApi.dispatch(api.internalActions.middlewareRegistered(apiUid)) - } + return (action, mwApi) => { + if (api.util.resetApiState.match(action)) { + // dispatch after api reset + mwApi.dispatch(api.internalActions.middlewareRegistered(apiUid)) + } + if ( + typeof process !== 'undefined' && + process.env.NODE_ENV === 'development' + ) { if ( - typeof process !== 'undefined' && - process.env.NODE_ENV === 'development' + api.internalActions.middlewareRegistered.match(action) && + action.payload === apiUid && + mwApi.getState()[reducerPath]?.config?.middlewareRegistered === + 'conflict' ) { - if ( - api.internalActions.middlewareRegistered.match(action) && - action.payload === apiUid && - mwApi.getState()[reducerPath]?.config?.middlewareRegistered === - 'conflict' - ) { - console.warn(`There is a mismatch between slice and middleware for the reducerPath "${reducerPath}". + console.warn(`There is a mismatch between slice and middleware for the reducerPath "${reducerPath}". You can only have one api per reducer path, this will lead to crashes in various situations!${ - reducerPath === 'api' - ? ` + reducerPath === 'api' + ? ` If you have multiple apis, you *have* to specify the reducerPath option when using createApi!` - : '' - }`) - } + : '' + }`) } - - return result } } } diff --git a/packages/toolkit/src/query/core/buildMiddleware/index.ts b/packages/toolkit/src/query/core/buildMiddleware/index.ts index 7732daac61..840d7a03bb 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/index.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/index.ts @@ -1,5 +1,3 @@ -import { compose } from 'redux' - import type { AnyAction, Middleware, ThunkDispatch } from '@reduxjs/toolkit' import { createAction } from '@reduxjs/toolkit' @@ -9,60 +7,100 @@ import type { } from '../../endpointDefinitions' import type { QueryStatus, QuerySubState, RootState } from '../apiState' import type { QueryThunkArg } from '../buildThunks' -import { build as buildCacheCollection } from './cacheCollection' -import { build as buildInvalidationByTags } from './invalidationByTags' -import { build as buildPolling } from './polling' -import type { BuildMiddlewareInput } from './types' -import { build as buildWindowEventHandling } from './windowEventHandling' -import { build as buildCacheLifecycle } from './cacheLifecycle' -import { build as buildQueryLifecycle } from './queryLifecycle' -import { build as buildDevMiddleware } from './devMiddleware' -import { build as buildBatchActions } from './batchActions' +import { buildCacheCollectionHandler } from './cacheCollection' +import { buildInvalidationByTagsHandler } from './invalidationByTags' +import { buildPollingHandler } from './polling' +import type { BuildMiddlewareInput, InternalHandlerBuilder } from './types' +import { buildWindowEventHandler } from './windowEventHandling' +import { buildCacheLifecycleHandler } from './cacheLifecycle' +import { buildQueryLifecycleHandler } from './queryLifecycle' +import { buildDevCheckHandler } from './devMiddleware' +import { buildBatchedActionsHandler } from './batchActions' export function buildMiddleware< Definitions extends EndpointDefinitions, ReducerPath extends string, TagTypes extends string >(input: BuildMiddlewareInput) { - const { reducerPath, queryThunk } = input + const { reducerPath, queryThunk, api, context } = input + const { apiUid } = context + const actions = { invalidateTags: createAction< Array> >(`${reducerPath}/invalidateTags`), } - const middlewares = [ - buildDevMiddleware, - buildCacheCollection, - buildInvalidationByTags, - buildPolling, - buildWindowEventHandling, - buildCacheLifecycle, - buildQueryLifecycle, - buildBatchActions, - ].map((build) => - build({ + const isThisApiSliceAction = (action: AnyAction) => { + return ( + !!action && + typeof action.type === 'string' && + action.type.startsWith(`${reducerPath}/`) + ) + } + + const handlerBuilders: InternalHandlerBuilder[] = [ + buildDevCheckHandler, + buildCacheCollectionHandler, + buildInvalidationByTagsHandler, + buildPollingHandler, + buildCacheLifecycleHandler, + buildQueryLifecycleHandler, + ] + + const middleware: Middleware< + {}, + RootState, + ThunkDispatch + > = (mwApi) => { + let initialized = false + + const builderArgs = { ...(input as any as BuildMiddlewareInput< EndpointDefinitions, string, string >), refetchQuery, - }) - ) - const middleware: Middleware< - {}, - RootState, - ThunkDispatch - > = (mwApi) => (next) => { - const applied = compose( - ...middlewares.map((middleware) => middleware(mwApi)) - )(next) - return (action) => { - if (mwApi.getState()[reducerPath]) { - return applied(action) + } + + const handlers = handlerBuilders.map((build) => build(builderArgs)) + + const batchedActionsHandler = buildBatchedActionsHandler(builderArgs) + const windowEventsHandler = buildWindowEventHandler(builderArgs) + + return (next) => { + return (action) => { + if (!initialized) { + initialized = true + // dispatch before any other action + mwApi.dispatch(api.internalActions.middlewareRegistered(apiUid)) + } + + const stateBefore = mwApi.getState() + + if (!batchedActionsHandler(action, mwApi, stateBefore)) { + return + } + + const res = next(action) + + if (!!mwApi.getState()[reducerPath]) { + // Only run these checks if the middleware is registered okay + + // This looks for actions that aren't specific to the API slice + windowEventsHandler(action, mwApi, stateBefore) + + if (isThisApiSliceAction(action)) { + // Only run these additional checks if the actions are part of the API slice + for (let handler of handlers) { + handler(action, mwApi, stateBefore) + } + } + } + + return res } - return next(action) } } diff --git a/packages/toolkit/src/query/core/buildMiddleware/invalidationByTags.ts b/packages/toolkit/src/query/core/buildMiddleware/invalidationByTags.ts index ba1dc6f77b..cfaf179679 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/invalidationByTags.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/invalidationByTags.ts @@ -5,9 +5,13 @@ import { calculateProvidedBy } from '../../endpointDefinitions' import type { QueryCacheKey } from '../apiState' import { QueryStatus } from '../apiState' import { calculateProvidedByThunk } from '../buildThunks' -import type { SubMiddlewareApi, SubMiddlewareBuilder } from './types' +import type { + SubMiddlewareApi, + InternalHandlerBuilder, + ApiMiddlewareInternalHandler, +} from './types' -export const build: SubMiddlewareBuilder = ({ +export const buildInvalidationByTagsHandler: InternalHandlerBuilder = ({ reducerPath, context, context: { endpointDefinitions }, @@ -17,45 +21,38 @@ export const build: SubMiddlewareBuilder = ({ refetchQuery, }) => { const { removeQueryResult } = api.internalActions + const isThunkActionWithTags = isAnyOf( + isFulfilled(mutationThunk), + isRejectedWithValue(mutationThunk) + ) - return (mwApi) => - (next) => - (action): any => { - const result = next(action) - - if ( - isAnyOf( - isFulfilled(mutationThunk), - isRejectedWithValue(mutationThunk) - )(action) - ) { - invalidateTags( - calculateProvidedByThunk( - action, - 'invalidatesTags', - endpointDefinitions, - assertTagType - ), - mwApi - ) - } - - if (api.util.invalidateTags.match(action)) { - invalidateTags( - calculateProvidedBy( - action.payload, - undefined, - undefined, - undefined, - undefined, - assertTagType - ), - mwApi - ) - } + const handler: ApiMiddlewareInternalHandler = (action, mwApi) => { + if (isThunkActionWithTags(action)) { + invalidateTags( + calculateProvidedByThunk( + action, + 'invalidatesTags', + endpointDefinitions, + assertTagType + ), + mwApi + ) + } - return result + if (api.util.invalidateTags.match(action)) { + invalidateTags( + calculateProvidedBy( + action.payload, + undefined, + undefined, + undefined, + undefined, + assertTagType + ), + mwApi + ) } + } function invalidateTags( tags: readonly FullTagDescription[], @@ -86,4 +83,6 @@ export const build: SubMiddlewareBuilder = ({ } }) } + + return handler } diff --git a/packages/toolkit/src/query/core/buildMiddleware/polling.ts b/packages/toolkit/src/query/core/buildMiddleware/polling.ts index a30f3f4fe9..8a5ded5271 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/polling.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/polling.ts @@ -3,132 +3,123 @@ import { QueryStatus } from '../apiState' import type { QueryStateMeta, SubMiddlewareApi, - SubMiddlewareBuilder, TimeoutId, + InternalHandlerBuilder, + ApiMiddlewareInternalHandler, } from './types' -export const build: SubMiddlewareBuilder = ({ +export const buildPollingHandler: InternalHandlerBuilder = ({ reducerPath, queryThunk, api, refetchQuery, }) => { - return (mwApi) => { - const currentPolls: QueryStateMeta<{ - nextPollTimestamp: number - timeout?: TimeoutId - pollingInterval: number - }> = {} - - return (next) => - (action): any => { - const result = next(action) - - if ( - api.internalActions.updateSubscriptionOptions.match(action) || - api.internalActions.unsubscribeQueryResult.match(action) - ) { - updatePollingInterval(action.payload, mwApi) - } - - if ( - queryThunk.pending.match(action) || - (queryThunk.rejected.match(action) && action.meta.condition) - ) { - updatePollingInterval(action.meta.arg, mwApi) - } - - if ( - queryThunk.fulfilled.match(action) || - (queryThunk.rejected.match(action) && !action.meta.condition) - ) { - startNextPoll(action.meta.arg, mwApi) - } - - if (api.util.resetApiState.match(action)) { - clearPolls() - } - - return result - } - - function startNextPoll( - { queryCacheKey }: QuerySubstateIdentifier, - api: SubMiddlewareApi + const currentPolls: QueryStateMeta<{ + nextPollTimestamp: number + timeout?: TimeoutId + pollingInterval: number + }> = {} + + const handler: ApiMiddlewareInternalHandler = (action, mwApi) => { + if ( + api.internalActions.updateSubscriptionOptions.match(action) || + api.internalActions.unsubscribeQueryResult.match(action) ) { - const state = api.getState()[reducerPath] - const querySubState = state.queries[queryCacheKey] - const subscriptions = state.subscriptions[queryCacheKey] - - if (!querySubState || querySubState.status === QueryStatus.uninitialized) - return - - const lowestPollingInterval = findLowestPollingInterval(subscriptions) - if (!Number.isFinite(lowestPollingInterval)) return - - const currentPoll = currentPolls[queryCacheKey] - - if (currentPoll?.timeout) { - clearTimeout(currentPoll.timeout) - currentPoll.timeout = undefined - } - - const nextPollTimestamp = Date.now() + lowestPollingInterval - - const currentInterval: typeof currentPolls[number] = (currentPolls[ - queryCacheKey - ] = { - nextPollTimestamp, - pollingInterval: lowestPollingInterval, - timeout: setTimeout(() => { - currentInterval!.timeout = undefined - api.dispatch(refetchQuery(querySubState, queryCacheKey)) - }, lowestPollingInterval), - }) + updatePollingInterval(action.payload, mwApi) } - function updatePollingInterval( - { queryCacheKey }: QuerySubstateIdentifier, - api: SubMiddlewareApi + if ( + queryThunk.pending.match(action) || + (queryThunk.rejected.match(action) && action.meta.condition) ) { - const state = api.getState()[reducerPath] - const querySubState = state.queries[queryCacheKey] - const subscriptions = state.subscriptions[queryCacheKey] - - if ( - !querySubState || - querySubState.status === QueryStatus.uninitialized - ) { - return - } - - const lowestPollingInterval = findLowestPollingInterval(subscriptions) - - if (!Number.isFinite(lowestPollingInterval)) { - cleanupPollForKey(queryCacheKey) - return - } - - const currentPoll = currentPolls[queryCacheKey] - const nextPollTimestamp = Date.now() + lowestPollingInterval - - if (!currentPoll || nextPollTimestamp < currentPoll.nextPollTimestamp) { - startNextPoll({ queryCacheKey }, api) - } + updatePollingInterval(action.meta.arg, mwApi) } - function cleanupPollForKey(key: string) { - const existingPoll = currentPolls[key] - if (existingPoll?.timeout) { - clearTimeout(existingPoll.timeout) - } - delete currentPolls[key] + if ( + queryThunk.fulfilled.match(action) || + (queryThunk.rejected.match(action) && !action.meta.condition) + ) { + startNextPoll(action.meta.arg, mwApi) + } + + if (api.util.resetApiState.match(action)) { + clearPolls() } + } + + function startNextPoll( + { queryCacheKey }: QuerySubstateIdentifier, + api: SubMiddlewareApi + ) { + const state = api.getState()[reducerPath] + const querySubState = state.queries[queryCacheKey] + const subscriptions = state.subscriptions[queryCacheKey] + + if (!querySubState || querySubState.status === QueryStatus.uninitialized) + return + + const lowestPollingInterval = findLowestPollingInterval(subscriptions) + if (!Number.isFinite(lowestPollingInterval)) return + + const currentPoll = currentPolls[queryCacheKey] + + if (currentPoll?.timeout) { + clearTimeout(currentPoll.timeout) + currentPoll.timeout = undefined + } + + const nextPollTimestamp = Date.now() + lowestPollingInterval + + const currentInterval: typeof currentPolls[number] = (currentPolls[ + queryCacheKey + ] = { + nextPollTimestamp, + pollingInterval: lowestPollingInterval, + timeout: setTimeout(() => { + currentInterval!.timeout = undefined + api.dispatch(refetchQuery(querySubState, queryCacheKey)) + }, lowestPollingInterval), + }) + } + + function updatePollingInterval( + { queryCacheKey }: QuerySubstateIdentifier, + api: SubMiddlewareApi + ) { + const state = api.getState()[reducerPath] + const querySubState = state.queries[queryCacheKey] + const subscriptions = state.subscriptions[queryCacheKey] + + if (!querySubState || querySubState.status === QueryStatus.uninitialized) { + return + } + + const lowestPollingInterval = findLowestPollingInterval(subscriptions) + + if (!Number.isFinite(lowestPollingInterval)) { + cleanupPollForKey(queryCacheKey) + return + } + + const currentPoll = currentPolls[queryCacheKey] + const nextPollTimestamp = Date.now() + lowestPollingInterval + + if (!currentPoll || nextPollTimestamp < currentPoll.nextPollTimestamp) { + startNextPoll({ queryCacheKey }, api) + } + } + + function cleanupPollForKey(key: string) { + const existingPoll = currentPolls[key] + if (existingPoll?.timeout) { + clearTimeout(existingPoll.timeout) + } + delete currentPolls[key] + } - function clearPolls() { - for (const key of Object.keys(currentPolls)) { - cleanupPollForKey(key) - } + function clearPolls() { + for (const key of Object.keys(currentPolls)) { + cleanupPollForKey(key) } } @@ -143,4 +134,5 @@ export const build: SubMiddlewareBuilder = ({ } return lowestPollingInterval } + return handler } diff --git a/packages/toolkit/src/query/core/buildMiddleware/queryLifecycle.ts b/packages/toolkit/src/query/core/buildMiddleware/queryLifecycle.ts index c89a29f8d8..0df42c159e 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/queryLifecycle.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/queryLifecycle.ts @@ -8,9 +8,10 @@ import { DefinitionType } from '../../endpointDefinitions' import type { QueryFulfilledRejectionReason } from '../../endpointDefinitions' import type { Recipe } from '../buildThunks' import type { - SubMiddlewareBuilder, PromiseWithKnownReason, PromiseConstructorWithKnownReason, + InternalHandlerBuilder, + ApiMiddlewareInternalHandler, } from './types' export type ReferenceQueryLifecycle = never @@ -200,7 +201,7 @@ declare module '../../endpointDefinitions' { QueryLifecyclePromises {} } -export const build: SubMiddlewareBuilder = ({ +export const buildQueryLifecycleHandler: InternalHandlerBuilder = ({ api, context, queryThunk, @@ -210,83 +211,77 @@ export const build: SubMiddlewareBuilder = ({ const isRejectedThunk = isRejected(queryThunk, mutationThunk) const isFullfilledThunk = isFulfilled(queryThunk, mutationThunk) - return (mwApi) => { - type CacheLifecycle = { - resolve(value: { data: unknown; meta: unknown }): unknown - reject(value: QueryFulfilledRejectionReason): unknown - } - const lifecycleMap: Record = {} - - return (next) => - (action): any => { - const result = next(action) - - if (isPendingThunk(action)) { - const { - requestId, - arg: { endpointName, originalArgs }, - } = action.meta - const endpointDefinition = context.endpointDefinitions[endpointName] - const onQueryStarted = endpointDefinition?.onQueryStarted - if (onQueryStarted) { - const lifecycle = {} as CacheLifecycle - const queryFulfilled = - new (Promise as PromiseConstructorWithKnownReason)< - { data: unknown; meta: unknown }, - QueryFulfilledRejectionReason - >((resolve, reject) => { - lifecycle.resolve = resolve - lifecycle.reject = reject - }) - // prevent uncaught promise rejections from happening. - // if the original promise is used in any way, that will create a new promise that will throw again - queryFulfilled.catch(() => {}) - lifecycleMap[requestId] = lifecycle - const selector = (api.endpoints[endpointName] as any).select( - endpointDefinition.type === DefinitionType.query - ? originalArgs - : requestId - ) + type CacheLifecycle = { + resolve(value: { data: unknown; meta: unknown }): unknown + reject(value: QueryFulfilledRejectionReason): unknown + } + const lifecycleMap: Record = {} - const extra = mwApi.dispatch((_, __, extra) => extra) - const lifecycleApi = { - ...mwApi, - getCacheEntry: () => selector(mwApi.getState()), - requestId, - extra, - updateCachedData: (endpointDefinition.type === - DefinitionType.query - ? (updateRecipe: Recipe) => - mwApi.dispatch( - api.util.updateQueryData( - endpointName as never, - originalArgs, - updateRecipe - ) - ) - : undefined) as any, - queryFulfilled, - } - onQueryStarted(originalArgs, lifecycleApi) - } - } else if (isFullfilledThunk(action)) { - const { requestId, baseQueryMeta } = action.meta - lifecycleMap[requestId]?.resolve({ - data: action.payload, - meta: baseQueryMeta, - }) - delete lifecycleMap[requestId] - } else if (isRejectedThunk(action)) { - const { requestId, rejectedWithValue, baseQueryMeta } = action.meta - lifecycleMap[requestId]?.reject({ - error: action.payload ?? action.error, - isUnhandledError: !rejectedWithValue, - meta: baseQueryMeta as any, + const handler: ApiMiddlewareInternalHandler = (action, mwApi) => { + if (isPendingThunk(action)) { + const { + requestId, + arg: { endpointName, originalArgs }, + } = action.meta + const endpointDefinition = context.endpointDefinitions[endpointName] + const onQueryStarted = endpointDefinition?.onQueryStarted + if (onQueryStarted) { + const lifecycle = {} as CacheLifecycle + const queryFulfilled = + new (Promise as PromiseConstructorWithKnownReason)< + { data: unknown; meta: unknown }, + QueryFulfilledRejectionReason + >((resolve, reject) => { + lifecycle.resolve = resolve + lifecycle.reject = reject }) - delete lifecycleMap[requestId] - } + // prevent uncaught promise rejections from happening. + // if the original promise is used in any way, that will create a new promise that will throw again + queryFulfilled.catch(() => {}) + lifecycleMap[requestId] = lifecycle + const selector = (api.endpoints[endpointName] as any).select( + endpointDefinition.type === DefinitionType.query + ? originalArgs + : requestId + ) - return result + const extra = mwApi.dispatch((_, __, extra) => extra) + const lifecycleApi = { + ...mwApi, + getCacheEntry: () => selector(mwApi.getState()), + requestId, + extra, + updateCachedData: (endpointDefinition.type === DefinitionType.query + ? (updateRecipe: Recipe) => + mwApi.dispatch( + api.util.updateQueryData( + endpointName as never, + originalArgs, + updateRecipe + ) + ) + : undefined) as any, + queryFulfilled, + } + onQueryStarted(originalArgs, lifecycleApi) } + } else if (isFullfilledThunk(action)) { + const { requestId, baseQueryMeta } = action.meta + lifecycleMap[requestId]?.resolve({ + data: action.payload, + meta: baseQueryMeta, + }) + delete lifecycleMap[requestId] + } else if (isRejectedThunk(action)) { + const { requestId, rejectedWithValue, baseQueryMeta } = action.meta + lifecycleMap[requestId]?.reject({ + error: action.payload ?? action.error, + isUnhandledError: !rejectedWithValue, + meta: baseQueryMeta as any, + }) + delete lifecycleMap[requestId] + } } + + return handler } diff --git a/packages/toolkit/src/query/core/buildMiddleware/types.ts b/packages/toolkit/src/query/core/buildMiddleware/types.ts index 8465883629..5d79bd9b29 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/types.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/types.ts @@ -1,6 +1,5 @@ import type { AnyAction, - AsyncThunk, AsyncThunkAction, Middleware, MiddlewareAPI, @@ -61,6 +60,16 @@ export type SubMiddlewareBuilder = ( ThunkDispatch > +export type ApiMiddlewareInternalHandler = ( + action: AnyAction, + mwApi: SubMiddlewareApi, + prevState: RootState +) => ReturnType + +export type InternalHandlerBuilder = ( + input: BuildSubMiddlewareInput +) => ApiMiddlewareInternalHandler + export interface PromiseConstructorWithKnownReason { /** * Creates a new Promise with a known rejection reason. diff --git a/packages/toolkit/src/query/core/buildMiddleware/windowEventHandling.ts b/packages/toolkit/src/query/core/buildMiddleware/windowEventHandling.ts index 088cb47125..0345c2e14c 100644 --- a/packages/toolkit/src/query/core/buildMiddleware/windowEventHandling.ts +++ b/packages/toolkit/src/query/core/buildMiddleware/windowEventHandling.ts @@ -1,9 +1,13 @@ import { QueryStatus } from '../apiState' import type { QueryCacheKey } from '../apiState' import { onFocus, onOnline } from '../setupListeners' -import type { SubMiddlewareApi, SubMiddlewareBuilder } from './types' +import type { + ApiMiddlewareInternalHandler, + InternalHandlerBuilder, + SubMiddlewareApi, +} from './types' -export const build: SubMiddlewareBuilder = ({ +export const buildWindowEventHandler: InternalHandlerBuilder = ({ reducerPath, context, api, @@ -11,20 +15,14 @@ export const build: SubMiddlewareBuilder = ({ }) => { const { removeQueryResult } = api.internalActions - return (mwApi) => - (next) => - (action): any => { - const result = next(action) - - if (onFocus.match(action)) { - refetchValidQueries(mwApi, 'refetchOnFocus') - } - if (onOnline.match(action)) { - refetchValidQueries(mwApi, 'refetchOnReconnect') - } - - return result + const handler: ApiMiddlewareInternalHandler = (action, mwApi) => { + if (onFocus.match(action)) { + refetchValidQueries(mwApi, 'refetchOnFocus') } + if (onOnline.match(action)) { + refetchValidQueries(mwApi, 'refetchOnReconnect') + } + } function refetchValidQueries( api: SubMiddlewareApi, @@ -64,4 +62,6 @@ export const build: SubMiddlewareBuilder = ({ } }) } + + return handler }