Skip to content

Commit

Permalink
Merge pull request #651 from aryaemami59/typed-result-equality-check
Browse files Browse the repository at this point in the history
  • Loading branch information
markerikson authored Dec 1, 2023
2 parents 6a03653 + d7632a6 commit 5af3050
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 12 deletions.
14 changes: 7 additions & 7 deletions src/defaultMemoize.ts
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ export function createCacheKeyComparator(equalityCheck: EqualityFn) {
/**
* @public
*/
export interface DefaultMemoizeOptions {
export interface DefaultMemoizeOptions<T = any> {
/**
* Used to compare the individual arguments of the provided calculation function.
*
Expand All @@ -142,7 +142,7 @@ export interface DefaultMemoizeOptions {
* use case, where an update to another field in the original data causes a recalculation
* due to changed references, but the output is still effectively the same.
*/
resultEqualityCheck?: EqualityFn
resultEqualityCheck?: EqualityFn<T>
/**
* The cache size for the selector. If greater than 1, the selector will use an LRU cache internally.
*
Expand All @@ -167,7 +167,7 @@ export interface DefaultMemoizeOptions {
*/
export function defaultMemoize<Func extends AnyFunction>(
func: Func,
equalityCheckOrOptions?: EqualityFn | DefaultMemoizeOptions
equalityCheckOrOptions?: EqualityFn | DefaultMemoizeOptions<ReturnType<Func>>
) {
const providedOptions =
typeof equalityCheckOrOptions === 'object'
Expand All @@ -191,20 +191,20 @@ export function defaultMemoize<Func extends AnyFunction>(

// we reference arguments instead of spreading them for performance reasons
function memoized() {
let value = cache.get(arguments)
let value = cache.get(arguments) as ReturnType<Func>
if (value === NOT_FOUND) {
// @ts-ignore
value = func.apply(null, arguments)
value = func.apply(null, arguments) as ReturnType<Func>
resultsCount++

if (resultEqualityCheck) {
const entries = cache.getEntries()
const matchingEntry = entries.find(entry =>
resultEqualityCheck(entry.value, value)
resultEqualityCheck(entry.value as ReturnType<Func>, value)
)

if (matchingEntry) {
value = matchingEntry.value
value = matchingEntry.value as ReturnType<Func>
resultsCount--
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ export type Combiner<InputSelectors extends SelectorArray, Result> = Distribute<
*
* @public
*/
export type EqualityFn = (a: any, b: any) => boolean
export type EqualityFn<T = any> = (a: T, b: T) => boolean

/**
* The frequency of input stability checks.
Expand Down
11 changes: 7 additions & 4 deletions src/weakMapMemoize.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ function createCacheNode<T>(): CacheNode<T> {
/**
* @public
*/
export interface WeakMapMemoizeOptions {
export interface WeakMapMemoizeOptions<T = any> {
/**
* If provided, used to compare a newly generated output value against previous values in the cache.
* If a match is found, the old value is returned. This addresses the common
Expand All @@ -82,7 +82,7 @@ export interface WeakMapMemoizeOptions {
* use case, where an update to another field in the original data causes a recalculation
* due to changed references, but the output is still effectively the same.
*/
resultEqualityCheck?: EqualityFn
resultEqualityCheck?: EqualityFn<T>
}

/**
Expand Down Expand Up @@ -160,7 +160,7 @@ export interface WeakMapMemoizeOptions {
*/
export function weakMapMemoize<Func extends AnyFunction>(
func: Func,
options: WeakMapMemoizeOptions = {}
options: WeakMapMemoizeOptions<ReturnType<Func>> = {}
) {
let fnNode = createCacheNode()
const { resultEqualityCheck } = options
Expand Down Expand Up @@ -222,7 +222,10 @@ export function weakMapMemoize<Func extends AnyFunction>(

if (resultEqualityCheck) {
const lastResultValue = lastResult?.deref() ?? lastResult
if (lastResultValue != null && resultEqualityCheck(lastResultValue, result)) {
if (
lastResultValue != null &&
resultEqualityCheck(lastResultValue as ReturnType<Func>, result)
) {
result = lastResultValue
resultsCount !== 0 && resultsCount--
}
Expand Down

0 comments on commit 5af3050

Please sign in to comment.