diff --git a/src/vanilla.ts b/src/vanilla.ts
index dc19c875..da800e03 100644
--- a/src/vanilla.ts
+++ b/src/vanilla.ts
@@ -86,6 +86,9 @@ const buildProxyFunction = (
!(x instanceof RegExp) &&
!(x instanceof ArrayBuffer),
+ shouldTrapDefineProperty = (desc: PropertyDescriptor) =>
+ desc.configurable && desc.enumerable && desc.writable,
+
defaultHandlePromise =
>(
promise: P & {
status?: 'pending' | 'fulfilled' | 'rejected'
@@ -244,6 +247,50 @@ const buildProxyFunction = (
const baseObject = Array.isArray(initialObject)
? []
: Object.create(Object.getPrototypeOf(initialObject))
+ const trapSet = (
+ hasPrevValue: boolean,
+ prevValue: any,
+ prop: string | symbol,
+ value: any,
+ setValue: (nextValue: any) => void
+ ) => {
+ if (
+ hasPrevValue &&
+ (objectIs(prevValue, value) ||
+ (proxyCache.has(value) && objectIs(prevValue, proxyCache.get(value))))
+ ) {
+ return
+ }
+ removePropListener(prop)
+ if (isObject(value)) {
+ value = getUntracked(value) || value
+ }
+ let nextValue = value
+ if (value instanceof Promise) {
+ value
+ .then((v) => {
+ value.status = 'fulfilled'
+ value.value = v
+ notifyUpdate(['resolve', [prop], v])
+ })
+ .catch((e) => {
+ value.status = 'rejected'
+ value.reason = e
+ notifyUpdate(['reject', [prop], e])
+ })
+ } else {
+ if (!proxyStateMap.has(value) && canProxy(value)) {
+ nextValue = proxyFunction(value)
+ }
+ const childProxyState =
+ !refSet.has(nextValue) && proxyStateMap.get(nextValue)
+ if (childProxyState) {
+ addPropListener(prop, childProxyState)
+ }
+ }
+ setValue(nextValue)
+ notifyUpdate(['set', [prop], value, prevValue])
+ }
const handler: ProxyHandler = {
deleteProperty(target: T, prop: string | symbol) {
const prevValue = Reflect.get(target, prop)
@@ -257,44 +304,35 @@ const buildProxyFunction = (
set(target: T, prop: string | symbol, value: any, receiver: object) {
const hasPrevValue = Reflect.has(target, prop)
const prevValue = Reflect.get(target, prop, receiver)
- if (
- hasPrevValue &&
- (objectIs(prevValue, value) ||
- (proxyCache.has(value) &&
- objectIs(prevValue, proxyCache.get(value))))
- ) {
- return true
- }
- removePropListener(prop)
- if (isObject(value)) {
- value = getUntracked(value) || value
- }
- let nextValue = value
- if (value instanceof Promise) {
- value
- .then((v) => {
- value.status = 'fulfilled'
- value.value = v
- notifyUpdate(['resolve', [prop], v])
- })
- .catch((e) => {
- value.status = 'rejected'
- value.reason = e
- notifyUpdate(['reject', [prop], e])
- })
- } else {
- if (!proxyStateMap.has(value) && canProxy(value)) {
- nextValue = proxyFunction(value)
- }
- const childProxyState =
- !refSet.has(nextValue) && proxyStateMap.get(nextValue)
- if (childProxyState) {
- addPropListener(prop, childProxyState)
+ trapSet(hasPrevValue, prevValue, prop, value, (nextValue) => {
+ Reflect.set(target, prop, nextValue, receiver)
+ })
+ return true
+ },
+ defineProperty(
+ target: T,
+ prop: string | symbol,
+ desc: PropertyDescriptor
+ ) {
+ if (shouldTrapDefineProperty(desc)) {
+ const prevDesc = Reflect.getOwnPropertyDescriptor(target, prop)
+ if (!prevDesc || shouldTrapDefineProperty(prevDesc)) {
+ trapSet(
+ !!prevDesc && 'value' in prevDesc,
+ prevDesc?.value,
+ prop,
+ desc.value,
+ (nextValue) => {
+ Reflect.defineProperty(target, prop, {
+ ...desc,
+ value: nextValue,
+ })
+ }
+ )
+ return true
}
}
- Reflect.set(target, prop, nextValue, receiver)
- notifyUpdate(['set', [prop], value, prevValue])
- return true
+ return Reflect.defineProperty(target, prop, desc)
},
}
const proxyObject = newProxy(baseObject, handler)
@@ -333,6 +371,7 @@ const buildProxyFunction = (
objectIs,
newProxy,
canProxy,
+ shouldTrapDefineProperty,
defaultHandlePromise,
snapCache,
createSnapshot,
diff --git a/tests/class.test.tsx b/tests/class.test.tsx
index f40f8b00..31b0dae6 100644
--- a/tests/class.test.tsx
+++ b/tests/class.test.tsx
@@ -355,3 +355,31 @@ it('no extra re-renders with getters', async () => {
getByText('sum: 2 (2)')
})
})
+
+it('support class fields (defineProperty semantics)', async () => {
+ class Base {
+ constructor() {
+ return proxy(this)
+ }
+ }
+ class CountClass extends Base {
+ counter = { count: 0 }
+ }
+ const obj = new CountClass()
+
+ const Counter = () => {
+ const snap = useSnapshot(obj)
+ return count: {snap.counter.count}
+ }
+
+ const { findByText } = render(
+
+
+
+ )
+
+ await findByText('count: 0')
+
+ obj.counter.count = 1
+ await findByText('count: 1')
+})