diff --git a/packages/toolkit/src/combineSlices.ts b/packages/toolkit/src/combineSlices.ts index 26ff173e7d..a5353af937 100644 --- a/packages/toolkit/src/combineSlices.ts +++ b/packages/toolkit/src/combineSlices.ts @@ -8,7 +8,7 @@ import type { UnionToIntersection, WithOptionalProp, } from './tsHelpers' -import { emplace } from './utils' +import { getOrInsertComputed } from './utils' type SliceLike = { reducerPath: ReducerPath @@ -324,8 +324,10 @@ const createStateProxy = ( state: State, reducerMap: Partial>, ) => - emplace(stateProxyMap, state, { - insert: () => + getOrInsertComputed( + stateProxyMap, + state, + () => new Proxy(state, { get: (target, prop, receiver) => { if (prop === ORIGINAL_STATE) return target @@ -350,7 +352,7 @@ const createStateProxy = ( return result }, }), - }) as State + ) as State const original = (state: any) => { if (!isStateProxy(state)) { diff --git a/packages/toolkit/src/createSlice.ts b/packages/toolkit/src/createSlice.ts index e3d2c25c56..1d4f3e3712 100644 --- a/packages/toolkit/src/createSlice.ts +++ b/packages/toolkit/src/createSlice.ts @@ -26,7 +26,7 @@ import { createReducer } from './createReducer' import type { ActionReducerMapBuilder, TypedActionCreator } from './mapBuilders' import { executeReducerBuilderCallback } from './mapBuilders' import type { Id, TypeGuard } from './tsHelpers' -import { emplace } from './utils' +import { getOrInsertComputed } from './utils' const asyncThunkSymbol = /* @__PURE__ */ Symbol.for( 'rtk-slice-createasyncthunk', @@ -769,25 +769,25 @@ export function buildCreateSlice({ creators }: BuildCreateSliceConfig = {}) { function getSelectors( selectState: (rootState: any) => State = selectSelf, ) { - const selectorCache = emplace(injectedSelectorCache, injected, { - insert: () => new WeakMap(), - }) - - return emplace(selectorCache, selectState, { - insert: () => { - const map: Record> = {} - for (const [name, selector] of Object.entries( - options.selectors ?? {}, - )) { - map[name] = wrapSelector( - selector, - selectState, - getInitialState, - injected, - ) - } - return map - }, + const selectorCache = getOrInsertComputed( + injectedSelectorCache, + injected, + () => new WeakMap(), + ) + + return getOrInsertComputed(selectorCache, selectState, () => { + const map: Record> = {} + for (const [name, selector] of Object.entries( + options.selectors ?? {}, + )) { + map[name] = wrapSelector( + selector, + selectState, + getInitialState, + injected, + ) + } + return map }) as any } return { diff --git a/packages/toolkit/src/dynamicMiddleware/index.ts b/packages/toolkit/src/dynamicMiddleware/index.ts index ed151b2979..8e61d6769b 100644 --- a/packages/toolkit/src/dynamicMiddleware/index.ts +++ b/packages/toolkit/src/dynamicMiddleware/index.ts @@ -3,7 +3,7 @@ import { compose } from 'redux' import { createAction } from '../createAction' import { isAllOf } from '../matchers' import { nanoid } from '../nanoid' -import { emplace, find } from '../utils' +import { getOrInsertComputed } from '../utils' import type { AddMiddleware, DynamicMiddleware, @@ -23,7 +23,6 @@ const createMiddlewareEntry = < >( middleware: Middleware, ): MiddlewareEntry => ({ - id: nanoid(), middleware, applied: new Map(), }) @@ -38,7 +37,10 @@ export const createDynamicMiddleware = < DispatchType extends Dispatch = Dispatch, >(): DynamicMiddlewareInstance => { const instanceId = nanoid() - const middlewareMap = new Map>() + const middlewareMap = new Map< + Middleware, + MiddlewareEntry + >() const withMiddleware = Object.assign( createAction( @@ -58,14 +60,7 @@ export const createDynamicMiddleware = < ...middlewares: Middleware[] ) { middlewares.forEach((middleware) => { - let entry = find( - Array.from(middlewareMap.values()), - (entry) => entry.middleware === middleware, - ) - if (!entry) { - entry = createMiddlewareEntry(middleware) - } - middlewareMap.set(entry.id, entry) + getOrInsertComputed(middlewareMap, middleware, createMiddlewareEntry) }) }, { withTypes: () => addMiddleware }, @@ -73,7 +68,7 @@ export const createDynamicMiddleware = < const getFinalMiddleware: Middleware<{}, State, DispatchType> = (api) => { const appliedMiddleware = Array.from(middlewareMap.values()).map((entry) => - emplace(entry.applied, api, { insert: () => entry.middleware(api) }), + getOrInsertComputed(entry.applied, api, entry.middleware), ) return compose(...appliedMiddleware) } diff --git a/packages/toolkit/src/dynamicMiddleware/types.ts b/packages/toolkit/src/dynamicMiddleware/types.ts index ee8c37a21b..989c7ffcc0 100644 --- a/packages/toolkit/src/dynamicMiddleware/types.ts +++ b/packages/toolkit/src/dynamicMiddleware/types.ts @@ -59,7 +59,6 @@ export type MiddlewareEntry< State = unknown, DispatchType extends Dispatch = Dispatch, > = { - id: string middleware: Middleware applied: Map< MiddlewareAPI, diff --git a/packages/toolkit/src/listenerMiddleware/index.ts b/packages/toolkit/src/listenerMiddleware/index.ts index efa2912ad3..cfefa17e09 100644 --- a/packages/toolkit/src/listenerMiddleware/index.ts +++ b/packages/toolkit/src/listenerMiddleware/index.ts @@ -4,7 +4,6 @@ import type { ThunkDispatch } from 'redux-thunk' import { createAction } from '../createAction' import { nanoid } from '../nanoid' -import { find } from '../utils' import { TaskAbortError, listenerCancelled, @@ -221,9 +220,8 @@ export const createListenerEntry: TypedCreateListenerEntry = (options: FallbackAddListenerOptions) => { const { type, predicate, effect } = getListenerEntryPropsFrom(options) - const id = nanoid() const entry: ListenerEntry = { - id, + id: nanoid(), effect, type, predicate, @@ -238,6 +236,22 @@ export const createListenerEntry: TypedCreateListenerEntry = { withTypes: () => createListenerEntry }, ) as unknown as TypedCreateListenerEntry +const findListenerEntry = ( + listenerMap: Map, + options: FallbackAddListenerOptions, +) => { + const { type, effect, predicate } = getListenerEntryPropsFrom(options) + + return Array.from(listenerMap.values()).find((entry) => { + const matchPredicateOrType = + typeof type === 'string' + ? entry.type === type + : entry.predicate === predicate + + return matchPredicateOrType && entry.effect === effect + }) +} + const cancelActiveListeners = ( entry: ListenerEntry>, ) => { @@ -330,7 +344,7 @@ export const createListenerMiddleware = < assertFunction(onError, 'onError') const insertEntry = (entry: ListenerEntry) => { - entry.unsubscribe = () => listenerMap.delete(entry!.id) + entry.unsubscribe = () => listenerMap.delete(entry.id) listenerMap.set(entry.id, entry) return (cancelOptions?: UnsubscribeListenerOptions) => { @@ -342,14 +356,9 @@ export const createListenerMiddleware = < } const startListening = ((options: FallbackAddListenerOptions) => { - let entry = find( - Array.from(listenerMap.values()), - (existingEntry) => existingEntry.effect === options.effect, - ) - - if (!entry) { - entry = createListenerEntry(options as any) - } + const entry = + findListenerEntry(listenerMap, options) ?? + createListenerEntry(options as any) return insertEntry(entry) }) as AddListenerOverloads @@ -361,16 +370,7 @@ export const createListenerMiddleware = < const stopListening = ( options: FallbackAddListenerOptions & UnsubscribeListenerOptions, ): boolean => { - const { type, effect, predicate } = getListenerEntryPropsFrom(options) - - const entry = find(Array.from(listenerMap.values()), (entry) => { - const matchPredicateOrType = - typeof type === 'string' - ? entry.type === type - : entry.predicate === predicate - - return matchPredicateOrType && entry.effect === effect - }) + const entry = findListenerEntry(listenerMap, options) if (entry) { entry.unsubscribe() diff --git a/packages/toolkit/src/listenerMiddleware/tests/listenerMiddleware.test.ts b/packages/toolkit/src/listenerMiddleware/tests/listenerMiddleware.test.ts index 56939af639..ad657508e2 100644 --- a/packages/toolkit/src/listenerMiddleware/tests/listenerMiddleware.test.ts +++ b/packages/toolkit/src/listenerMiddleware/tests/listenerMiddleware.test.ts @@ -117,6 +117,7 @@ describe('createListenerMiddleware', () => { const testAction1 = createAction('testAction1') type TestAction1 = ReturnType const testAction2 = createAction('testAction2') + type TestAction2 = ReturnType const testAction3 = createAction('testAction3') beforeAll(() => { @@ -339,6 +340,27 @@ describe('createListenerMiddleware', () => { ]) }) + test('subscribing with the same effect but different predicate is allowed', () => { + const effect = vi.fn((_: TestAction1 | TestAction2) => {}) + + startListening({ + actionCreator: testAction1, + effect, + }) + startListening({ + actionCreator: testAction2, + effect, + }) + + store.dispatch(testAction1('a')) + store.dispatch(testAction2('b')) + + expect(effect.mock.calls).toEqual([ + [testAction1('a'), middlewareApi], + [testAction2('b'), middlewareApi], + ]) + }) + test('unsubscribing via callback', () => { const effect = vi.fn((_: TestAction1) => {}) diff --git a/packages/toolkit/src/listenerMiddleware/types.ts b/packages/toolkit/src/listenerMiddleware/types.ts index b5980e1085..7e6f6c2783 100644 --- a/packages/toolkit/src/listenerMiddleware/types.ts +++ b/packages/toolkit/src/listenerMiddleware/types.ts @@ -578,9 +578,13 @@ export type TypedAddListener< OverrideStateType, unknown, UnknownAction - >, - OverrideExtraArgument = unknown, - >() => TypedAddListener + >, + OverrideExtraArgument = unknown, + >() => TypedAddListener< + OverrideStateType, + OverrideDispatchType, + OverrideExtraArgument + > } /** @@ -641,7 +645,11 @@ export type TypedRemoveListener< UnknownAction >, OverrideExtraArgument = unknown, - >() => TypedRemoveListener + >() => TypedRemoveListener< + OverrideStateType, + OverrideDispatchType, + OverrideExtraArgument + > } /** @@ -701,7 +709,11 @@ export type TypedStartListening< UnknownAction >, OverrideExtraArgument = unknown, - >() => TypedStartListening + >() => TypedStartListening< + OverrideStateType, + OverrideDispatchType, + OverrideExtraArgument + > } /** @@ -756,7 +768,11 @@ export type TypedStopListening< UnknownAction >, OverrideExtraArgument = unknown, - >() => TypedStopListening + >() => TypedStopListening< + OverrideStateType, + OverrideDispatchType, + OverrideExtraArgument + > } /** @@ -813,7 +829,11 @@ export type TypedCreateListenerEntry< UnknownAction >, OverrideExtraArgument = unknown, - >() => TypedStopListening + >() => TypedStopListening< + OverrideStateType, + OverrideDispatchType, + OverrideExtraArgument + > } /** diff --git a/packages/toolkit/src/query/core/buildInitiate.ts b/packages/toolkit/src/query/core/buildInitiate.ts index ff2ab45456..ca2d6854e0 100644 --- a/packages/toolkit/src/query/core/buildInitiate.ts +++ b/packages/toolkit/src/query/core/buildInitiate.ts @@ -16,7 +16,7 @@ import type { QueryDefinition, ResultTypeFrom, } from '../endpointDefinitions' -import { countObjectKeys, isNotNullish } from '../utils' +import { countObjectKeys, getOrInsert, isNotNullish } from '../utils' import type { SubscriptionOptions } from './apiState' import type { QueryResultSelectorResult } from './buildSelectors' import type { MutationThunk, QueryThunk, QueryThunkArg } from './buildThunks' @@ -391,9 +391,8 @@ You must add the middleware for RTK-Query to function correctly!`, ) if (!runningQuery && !skippedSynchronously && !forceQueryFn) { - const running = runningQueries.get(dispatch) || {} + const running = getOrInsert(runningQueries, dispatch, {}) running[queryCacheKey] = statePromise - runningQueries.set(dispatch, running) statePromise.then(() => { delete running[queryCacheKey] diff --git a/packages/toolkit/src/query/utils/getOrInsert.ts b/packages/toolkit/src/query/utils/getOrInsert.ts new file mode 100644 index 0000000000..124da032ea --- /dev/null +++ b/packages/toolkit/src/query/utils/getOrInsert.ts @@ -0,0 +1,15 @@ +export function getOrInsert( + map: WeakMap, + key: K, + value: V, +): V +export function getOrInsert(map: Map, key: K, value: V): V +export function getOrInsert( + map: Map | WeakMap, + key: K, + value: V, +): V { + if (map.has(key)) return map.get(key) as V + + return map.set(key, value).get(key) as V +} diff --git a/packages/toolkit/src/query/utils/index.ts b/packages/toolkit/src/query/utils/index.ts index 0eb7c62ce9..916b32fd60 100644 --- a/packages/toolkit/src/query/utils/index.ts +++ b/packages/toolkit/src/query/utils/index.ts @@ -8,3 +8,4 @@ export * from './isNotNullish' export * from './isOnline' export * from './isValidUrl' export * from './joinUrls' +export * from './getOrInsert' diff --git a/packages/toolkit/src/utils.ts b/packages/toolkit/src/utils.ts index 1f8445bb0f..6607f4b339 100644 --- a/packages/toolkit/src/utils.ts +++ b/packages/toolkit/src/utils.ts @@ -26,19 +26,6 @@ export function delay(ms: number) { return new Promise((resolve) => setTimeout(resolve, ms)) } -export function find( - iterable: Iterable, - comparator: (item: T) => boolean, -): T | undefined { - for (const entry of iterable) { - if (comparator(entry)) { - return entry - } - } - - return undefined -} - export class Tuple = []> extends Array< Items[number] > { @@ -87,81 +74,38 @@ export function freezeDraftable(val: T) { return isDraftable(val) ? createNextState(val, () => {}) : val } -interface WeakMapEmplaceHandler { - /** - * Will be called to get value, if no value is currently in map. - */ - insert?(key: K, map: WeakMap): V - /** - * Will be called to update a value, if one exists already. - */ - update?(previous: V, key: K, map: WeakMap): V -} +export function getOrInsert( + map: WeakMap, + key: K, + value: V, +): V +export function getOrInsert(map: Map, key: K, value: V): V +export function getOrInsert( + map: Map | WeakMap, + key: K, + value: V, +): V { + if (map.has(key)) return map.get(key) as V -interface MapEmplaceHandler { - /** - * Will be called to get value, if no value is currently in map. - */ - insert?(key: K, map: Map): V - /** - * Will be called to update a value, if one exists already. - */ - update?(previous: V, key: K, map: Map): V + return map.set(key, value).get(key) as V } -export function emplace( - map: Map, +export function getOrInsertComputed( + map: WeakMap, key: K, - handler: MapEmplaceHandler, + compute: (key: K) => V, ): V -export function emplace( - map: WeakMap, +export function getOrInsertComputed( + map: Map, key: K, - handler: WeakMapEmplaceHandler, + compute: (key: K) => V, ): V -/** - * Allow inserting a new value, or updating an existing one - * @throws if called for a key with no current value and no `insert` handler is provided - * @returns current value in map (after insertion/updating) - * ```ts - * // return current value if already in map, otherwise initialise to 0 and return that - * const num = emplace(map, key, { - * insert: () => 0 - * }) - * - * // increase current value by one if already in map, otherwise initialise to 0 - * const num = emplace(map, key, { - * update: (n) => n + 1, - * insert: () => 0, - * }) - * - * // only update if value's already in the map - and increase it by one - * if (map.has(key)) { - * const num = emplace(map, key, { - * update: (n) => n + 1, - * }) - * } - * ``` - * - * @remarks - * Based on https://github.com/tc39/proposal-upsert currently in Stage 2 - maybe in a few years we'll be able to replace this with direct method calls - */ -export function emplace( - map: WeakMap, +export function getOrInsertComputed( + map: Map | WeakMap, key: K, - handler: WeakMapEmplaceHandler, + compute: (key: K) => V, ): V { - if (map.has(key)) { - let value = map.get(key) as V - if (handler.update) { - value = handler.update(value, key, map) - map.set(key, value) - } - return value - } - if (!handler.insert) - throw new Error('No insert provided for key not already in map') - const inserted = handler.insert(key, map) - map.set(key, inserted) - return inserted + if (map.has(key)) return map.get(key) as V + + return map.set(key, compute(key)).get(key) as V }