diff --git a/src/vanilla.ts b/src/vanilla.ts index dc19c875..da800e03 100644 --- a/src/vanilla.ts +++ b/src/vanilla.ts @@ -86,6 +86,9 @@ const buildProxyFunction = ( !(x instanceof RegExp) && !(x instanceof ArrayBuffer), + shouldTrapDefineProperty = (desc: PropertyDescriptor) => + desc.configurable && desc.enumerable && desc.writable, + defaultHandlePromise =

>( promise: P & { status?: 'pending' | 'fulfilled' | 'rejected' @@ -244,6 +247,50 @@ const buildProxyFunction = ( const baseObject = Array.isArray(initialObject) ? [] : Object.create(Object.getPrototypeOf(initialObject)) + const trapSet = ( + hasPrevValue: boolean, + prevValue: any, + prop: string | symbol, + value: any, + setValue: (nextValue: any) => void + ) => { + if ( + hasPrevValue && + (objectIs(prevValue, value) || + (proxyCache.has(value) && objectIs(prevValue, proxyCache.get(value)))) + ) { + return + } + removePropListener(prop) + if (isObject(value)) { + value = getUntracked(value) || value + } + let nextValue = value + if (value instanceof Promise) { + value + .then((v) => { + value.status = 'fulfilled' + value.value = v + notifyUpdate(['resolve', [prop], v]) + }) + .catch((e) => { + value.status = 'rejected' + value.reason = e + notifyUpdate(['reject', [prop], e]) + }) + } else { + if (!proxyStateMap.has(value) && canProxy(value)) { + nextValue = proxyFunction(value) + } + const childProxyState = + !refSet.has(nextValue) && proxyStateMap.get(nextValue) + if (childProxyState) { + addPropListener(prop, childProxyState) + } + } + setValue(nextValue) + notifyUpdate(['set', [prop], value, prevValue]) + } const handler: ProxyHandler = { deleteProperty(target: T, prop: string | symbol) { const prevValue = Reflect.get(target, prop) @@ -257,44 +304,35 @@ const buildProxyFunction = ( set(target: T, prop: string | symbol, value: any, receiver: object) { const hasPrevValue = Reflect.has(target, prop) const prevValue = Reflect.get(target, prop, receiver) - if ( - hasPrevValue && - (objectIs(prevValue, value) || - (proxyCache.has(value) && - objectIs(prevValue, proxyCache.get(value)))) - ) { - return true - } - removePropListener(prop) - if (isObject(value)) { - value = getUntracked(value) || value - } - let nextValue = value - if (value instanceof Promise) { - value - .then((v) => { - value.status = 'fulfilled' - value.value = v - notifyUpdate(['resolve', [prop], v]) - }) - .catch((e) => { - value.status = 'rejected' - value.reason = e - notifyUpdate(['reject', [prop], e]) - }) - } else { - if (!proxyStateMap.has(value) && canProxy(value)) { - nextValue = proxyFunction(value) - } - const childProxyState = - !refSet.has(nextValue) && proxyStateMap.get(nextValue) - if (childProxyState) { - addPropListener(prop, childProxyState) + trapSet(hasPrevValue, prevValue, prop, value, (nextValue) => { + Reflect.set(target, prop, nextValue, receiver) + }) + return true + }, + defineProperty( + target: T, + prop: string | symbol, + desc: PropertyDescriptor + ) { + if (shouldTrapDefineProperty(desc)) { + const prevDesc = Reflect.getOwnPropertyDescriptor(target, prop) + if (!prevDesc || shouldTrapDefineProperty(prevDesc)) { + trapSet( + !!prevDesc && 'value' in prevDesc, + prevDesc?.value, + prop, + desc.value, + (nextValue) => { + Reflect.defineProperty(target, prop, { + ...desc, + value: nextValue, + }) + } + ) + return true } } - Reflect.set(target, prop, nextValue, receiver) - notifyUpdate(['set', [prop], value, prevValue]) - return true + return Reflect.defineProperty(target, prop, desc) }, } const proxyObject = newProxy(baseObject, handler) @@ -333,6 +371,7 @@ const buildProxyFunction = ( objectIs, newProxy, canProxy, + shouldTrapDefineProperty, defaultHandlePromise, snapCache, createSnapshot, diff --git a/tests/class.test.tsx b/tests/class.test.tsx index f40f8b00..31b0dae6 100644 --- a/tests/class.test.tsx +++ b/tests/class.test.tsx @@ -355,3 +355,31 @@ it('no extra re-renders with getters', async () => { getByText('sum: 2 (2)') }) }) + +it('support class fields (defineProperty semantics)', async () => { + class Base { + constructor() { + return proxy(this) + } + } + class CountClass extends Base { + counter = { count: 0 } + } + const obj = new CountClass() + + const Counter = () => { + const snap = useSnapshot(obj) + return

count: {snap.counter.count}
+ } + + const { findByText } = render( + + + + ) + + await findByText('count: 0') + + obj.counter.count = 1 + await findByText('count: 1') +})