diff --git a/conn.go b/conn.go index 20b1fd09..3ee5a20b 100644 --- a/conn.go +++ b/conn.go @@ -44,7 +44,7 @@ const ( type watchType int const ( - watchTypeData = iota + watchTypeData watchType = iota watchTypeExist watchTypeChild ) @@ -530,6 +530,33 @@ func (c *Conn) flushRequests(err error) { c.requestsLock.Unlock() } +// Send event to all interested watchers +func (c *Conn) notifyWatches(ev Event) { + var wTypes []watchType + switch ev.Type { + case EventNodeCreated: + wTypes = []watchType{watchTypeExist} + case EventNodeDataChanged: + wTypes = []watchType{watchTypeExist, watchTypeData} + case EventNodeChildrenChanged: + wTypes = []watchType{watchTypeChild} + case EventNodeDeleted: + wTypes = []watchType{watchTypeExist, watchTypeData, watchTypeChild} + } + c.watchersLock.Lock() + defer c.watchersLock.Unlock() + for _, t := range wTypes { + wpt := watchPathType{ev.Path, t} + if watchers := c.watchers[wpt]; len(watchers) > 0 { + for _, ch := range watchers { + ch <- ev + close(ch) + } + delete(c.watchers, wpt) + } + } +} + // Send error to all watchers and clear watchers map func (c *Conn) invalidateWatches(err error) { c.watchersLock.Lock() @@ -812,29 +839,7 @@ func (c *Conn) recvLoop(conn net.Conn) error { Err: nil, } c.sendEvent(ev) - wTypes := make([]watchType, 0, 2) - switch res.Type { - case EventNodeCreated: - wTypes = append(wTypes, watchTypeExist) - case EventNodeDataChanged: - wTypes = append(wTypes, watchTypeExist, watchTypeData) - case EventNodeChildrenChanged: - wTypes = append(wTypes, watchTypeChild) - case EventNodeDeleted: - wTypes = append(wTypes, watchTypeExist, watchTypeData, watchTypeChild) - } - c.watchersLock.Lock() - for _, t := range wTypes { - wpt := watchPathType{res.Path, t} - if watchers := c.watchers[wpt]; len(watchers) > 0 { - for _, ch := range watchers { - ch <- ev - close(ch) - } - delete(c.watchers, wpt) - } - } - c.watchersLock.Unlock() + c.notifyWatches(ev) } else if res.Xid == -2 { // Ping response. Ignore. } else if res.Xid < 0 { diff --git a/conn_test.go b/conn_test.go index 618efb19..630693a6 100644 --- a/conn_test.go +++ b/conn_test.go @@ -2,6 +2,7 @@ package zk import ( "context" + "fmt" "io/ioutil" "sync" "testing" @@ -80,3 +81,88 @@ func TestDeadlockInClose(t *testing.T) { t.Fatal("apparent deadlock!") } } + +func TestNotifyWatches(t *testing.T) { + cases := []struct { + eType EventType + path string + watches map[watchPathType]bool + }{ + { + EventNodeCreated, "/", + map[watchPathType]bool{ + {"/", watchTypeExist}: true, + {"/", watchTypeChild}: false, + {"/", watchTypeData}: false, + }, + }, + { + EventNodeCreated, "/a", + map[watchPathType]bool{ + {"/b", watchTypeExist}: false, + }, + }, + { + EventNodeDataChanged, "/", + map[watchPathType]bool{ + {"/", watchTypeExist}: true, + {"/", watchTypeData}: true, + {"/", watchTypeChild}: false, + }, + }, + { + EventNodeChildrenChanged, "/", + map[watchPathType]bool{ + {"/", watchTypeExist}: false, + {"/", watchTypeData}: false, + {"/", watchTypeChild}: true, + }, + }, + { + EventNodeDeleted, "/", + map[watchPathType]bool{ + {"/", watchTypeExist}: true, + {"/", watchTypeData}: true, + {"/", watchTypeChild}: true, + }, + }, + } + + conn := &Conn{watchers: make(map[watchPathType][]chan Event)} + + for idx, c := range cases { + t.Run(fmt.Sprintf("#%d %s", idx, c.eType), func(t *testing.T) { + c := c + + notifications := make([]struct { + path string + notify bool + ch <-chan Event + }, len(c.watches)) + + var idx int + for wpt, expectEvent := range c.watches { + ch := conn.addWatcher(wpt.path, wpt.wType) + notifications[idx].path = wpt.path + notifications[idx].notify = expectEvent + notifications[idx].ch = ch + idx++ + } + ev := Event{Type: c.eType, Path: c.path} + conn.notifyWatches(ev) + + for _, res := range notifications { + select { + case e := <-res.ch: + if !res.notify || e.Path != res.path { + t.Fatal("unexpeted notification received") + } + default: + if res.notify { + t.Fatal("expected notification not received") + } + } + } + }) + } +}