From 73b390faca5d1d053dbf8c564c51ae0bdeae0b27 Mon Sep 17 00:00:00 2001 From: PapaCharlie Date: Wed, 24 Aug 2022 13:34:30 -0700 Subject: [PATCH] Support persistent watches Implements the new persistent watch types introduced in 3.6, along with corresponding TreeCache and NodeCache utilities that try to keep in sync with the remote state. --- Makefile | 2 + cache_utils.go | 146 +++++++++++++ conn.go | 204 ++++++++++++++++--- conn_test.go | 108 ++++++---- constants.go | 42 ++++ nodecache.go | 109 ++++++++++ nodecache_test.go | 114 +++++++++++ server_help_test.go | 50 ++++- structs.go | 33 ++- treecache.go | 417 ++++++++++++++++++++++++++++++++++++++ treecache_test.go | 316 +++++++++++++++++++++++++++++ unlimited_channel.go | 88 ++++++++ unlimited_channel_test.go | 54 +++++ zk_test.go | 159 +++++++++++++-- 14 files changed, 1740 insertions(+), 102 deletions(-) create mode 100644 cache_utils.go create mode 100644 nodecache.go create mode 100644 nodecache_test.go create mode 100644 treecache.go create mode 100644 treecache_test.go create mode 100644 unlimited_channel.go create mode 100644 unlimited_channel_test.go diff --git a/Makefile b/Makefile index f0b7965c..a8361dc0 100644 --- a/Makefile +++ b/Makefile @@ -20,9 +20,11 @@ $(ZK): tar -zxf $(ZK).tar.gz rm $(ZK).tar.gz +.PHONY: zookeeper zookeeper: $(ZK) # we link to a standard directory path so then the tests dont need to find based on version # in the test code. this allows backward compatable testing. + rm -f zookeeper ln -s $(ZK) zookeeper .PHONY: setup diff --git a/cache_utils.go b/cache_utils.go new file mode 100644 index 00000000..661f6bf3 --- /dev/null +++ b/cache_utils.go @@ -0,0 +1,146 @@ +package zk + +import ( + "math" + "math/rand" + "strings" + "time" +) + +type RetryPolicy interface { + // ShouldRetry checks whether a given failed call should be retried based on how many times it was attempted and the + // last error encountered. See ExecuteWithRetries for more details. + ShouldRetry(attempt int, lastError error) (backoff time.Duration) +} + +// ExecuteWithRetries simply retries the given function as many times as the given policy will allow, waiting in between +// invocations according to the backoff given by the policy. If the policy returns a negative backoff or stopChan is +// closed, the last encountered error is returned. +func ExecuteWithRetries(policy RetryPolicy, stopChan chan struct{}, f func() (err error)) (err error) { + for attempt := 0; ; attempt++ { + err = f() + if err == nil { + return nil + } + backoff := policy.ShouldRetry(attempt, err) + if backoff < 0 { + return err + } + + select { + case <-stopChan: + return err + case <-time.After(backoff): + continue + } + } +} + +// The DefaultWatcherRetryPolicy is an ExponentialBackoffPolicy with infinite retries on all but three error types: +// +// - zk.ErrNoNode: Retrying fetches on a node that doesn't exist isn't going to yield very interesting results, +// especially in the context of a watch where an eventual zk.EventNodeCreated will notify the watcher of the node's +// reappearance. +// +// - zk.ErrConnectionClosed: This error is returned by any call made after Close() is called on a zk.Conn. This call +// will never succeed. +// +// - zk.ErrNoAuth: If a zk.Conn does not have the required authentication to access a node, retrying the call will not +// succeed until authentication is added. It's best to report this as early as possible instead of blocking the process. 
+// +// The reasoning behind infinite retries by default is that if any network connectivity issues arise, the watcher itself +// will likely be impacted or stop receiving events altogether. Retrying forever is the best bet to keep everything in +// sync. +var DefaultWatcherRetryPolicy RetryPolicy = &ExponentialBackoffPolicy{ + InitialBackoff: 100 * time.Millisecond, + MaxBackoff: 5 * time.Second, + MaxAttempts: math.MaxInt64, + IsErrorRetryable: func(err error) bool { + return err != ErrNoNode && err != ErrConnectionClosed && err != ErrNoAuth + }, +} + +type RetryPolicyFunc func(attempt int, lastError error) time.Duration + +func (r RetryPolicyFunc) ShouldRetry(attempt int, lastError error) (backoff time.Duration) { + return r(attempt, lastError) +} + +// ExponentialBackoffPolicy is a RetryPolicy that implements exponential backoff and jitter (see "Full Jitter" in +// https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/). It gives an option to dynamically decide +// whether to retry specific error types. +type ExponentialBackoffPolicy struct { + // The initial backoff duration and the value that will be multiplied when calculating the backoff for a specific + // attempt. + InitialBackoff time.Duration + // The maximum duration to backoff. + MaxBackoff time.Duration + // How many times to retry a given call before bailing. + MaxAttempts int + // If non-nil, this function is called to check if an error can be retried + IsErrorRetryable func(err error) bool + // If non-nil, this rand.Rand will be used to generate the jitter. Otherwise, the global rand is used. + Rand *rand.Rand +} + +func (e *ExponentialBackoffPolicy) ShouldRetry(retryCount int, err error) (backoff time.Duration) { + if (e.IsErrorRetryable != nil && !e.IsErrorRetryable(err)) || retryCount > e.MaxAttempts { + return -1 + } + + backoff = e.InitialBackoff << retryCount + if backoff < e.InitialBackoff /* check for overflow from left shift */ || backoff > e.MaxBackoff { + backoff = e.MaxBackoff + } + + if e.Rand != nil { + backoff = time.Duration(e.Rand.Int63n(int64(backoff))) + } else { + backoff = time.Duration(rand.Int63n(int64(backoff))) + + } + + return backoff +} + +func getNodeData(policy RetryPolicy, stopChan chan struct{}, nodePath string, conn *Conn) (data []byte, err error) { + err = ExecuteWithRetries(policy, stopChan, func() (err error) { + data, _, err = conn.Get(nodePath) + return err + }) + return data, err +} + +func getNodeDataAndChildren(policy RetryPolicy, stopChan chan struct{}, nodePath string, conn *Conn) (data []byte, children []string, err error) { + err = ExecuteWithRetries(policy, stopChan, func() (err error) { + // Execute both calls in the same attempt so the data and children are as in-sync as possible + data, _, err = conn.Get(nodePath) + if err != nil { + return err + } + + children, _, err = conn.Children(nodePath) + return err + }) + return data, children, err +} + +func JoinPath(parent, child string) string { + if !strings.HasSuffix(parent, "/") { + parent += "/" + } + if strings.HasPrefix(child, "/") { + child = child[1:] + } + return parent + child +} + +func SplitPath(path string) (dir, name string) { + i := strings.LastIndex(path, "/") + if i == 0 { + dir, name = "/", path[1:] + } else { + dir, name = path[:i], path[i+1:] + } + return dir, name +} diff --git a/conn.go b/conn.go index 9afd2d27..1f996066 100644 --- a/conn.go +++ b/conn.go @@ -47,8 +47,14 @@ const ( watchTypeData watchType = iota watchTypeExist watchTypeChild + watchTypePersistent + 
watchTypePersistentRecursive ) +func (w watchType) isPersistent() bool { + return w == watchTypePersistent || w == watchTypePersistentRecursive +} + type watchPathType struct { path string wType watchType @@ -530,29 +536,47 @@ func (c *Conn) flushRequests(err error) { c.requestsLock.Unlock() } +var eventWatchTypes = map[EventType][]watchType{ + EventNodeCreated: {watchTypeExist, watchTypePersistent, watchTypePersistentRecursive}, + EventNodeDataChanged: {watchTypeExist, watchTypeData, watchTypePersistent, watchTypePersistentRecursive}, + EventNodeChildrenChanged: {watchTypeChild, watchTypePersistent}, + EventNodeDeleted: {watchTypeExist, watchTypeData, watchTypeChild, watchTypePersistent, watchTypePersistentRecursive}, +} +var persistentWatchTypes = []watchType{watchTypePersistent, watchTypePersistentRecursive} + // Send event to all interested watchers func (c *Conn) notifyWatches(ev Event) { - var wTypes []watchType - switch ev.Type { - case EventNodeCreated: - wTypes = []watchType{watchTypeExist} - case EventNodeDataChanged: - wTypes = []watchType{watchTypeExist, watchTypeData} - case EventNodeChildrenChanged: - wTypes = []watchType{watchTypeChild} - case EventNodeDeleted: - wTypes = []watchType{watchTypeExist, watchTypeData, watchTypeChild} + wTypes := eventWatchTypes[ev.Type] + if len(wTypes) == 0 { + return } + c.watchersLock.Lock() defer c.watchersLock.Unlock() - for _, t := range wTypes { - wpt := watchPathType{ev.Path, t} - if watchers := c.watchers[wpt]; len(watchers) > 0 { - for _, ch := range watchers { - ch <- ev + + broadcast := func(wpt watchPathType) { + for _, ch := range c.watchers[wpt] { + ch <- ev + if !wpt.wType.isPersistent() { close(ch) } - delete(c.watchers, wpt) + } + } + + for _, t := range wTypes { + if t == watchTypePersistentRecursive { + for p := ev.Path; ; p, _ = SplitPath(p) { + broadcast(watchPathType{p, t}) + if p == "/" { + break + } + } + } else { + wpt := watchPathType{ev.Path, t} + broadcast(wpt) + if !t.isPersistent() { + delete(c.watchers, wpt) + } } } } @@ -562,8 +586,15 @@ func (c *Conn) invalidateWatches(err error) { c.watchersLock.Lock() defer c.watchersLock.Unlock() - if len(c.watchers) >= 0 { + if len(c.watchers) > 0 { for pathType, watchers := range c.watchers { + if err == ErrSessionExpired && pathType.wType.isPersistent() { + // Ignore ErrSessionExpired for persistent watchers as the client will either automatically reconnect, + // or this is a shutdown-worthy error in which case there will be a followup invocation of this method + // with ErrClosing + continue + } + ev := Event{Type: EventNotWatching, State: StateDisconnected, Path: pathType.path, Err: err} c.sendEvent(ev) // also publish globally for _, ch := range watchers { @@ -610,12 +641,7 @@ func (c *Conn) sendSetWatches() { reqs = append(reqs, req) } sizeSoFar = 28 // fixed overhead of a set-watches packet - req = &setWatchesRequest{ - RelativeZxid: c.lastZxid, - DataWatches: make([]string, 0), - ExistWatches: make([]string, 0), - ChildWatches: make([]string, 0), - } + req = &setWatchesRequest{RelativeZxid: c.lastZxid} } sizeSoFar += addlLen switch pathType.wType { @@ -625,6 +651,10 @@ func (c *Conn) sendSetWatches() { req.ExistWatches = append(req.ExistWatches, pathType.path) case watchTypeChild: req.ChildWatches = append(req.ChildWatches, pathType.path) + case watchTypePersistent: + req.PersistentWatches = append(req.PersistentWatches, pathType.path) + case watchTypePersistentRecursive: + req.PersistentRecursiveWatches = append(req.PersistentRecursiveWatches, pathType.path) } n++ } 
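For context, the recursive branch above walks the event path up through its ancestors with SplitPath, so a single watcher registered high in the tree observes changes to any descendant. A minimal usage sketch of the API this enables (not part of the patch; the import path, server address and znode paths are illustrative assumptions):

package main

import (
	"log"
	"time"

	"github.com/go-zookeeper/zk" // assumed import path for this module
)

func main() {
	conn, _, err := zk.Connect([]string{"127.0.0.1:2181"}, time.Second) // address is illustrative
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	// One recursive watch at an ancestor covers the entire subtree below it.
	events, err := conn.AddPersistentWatch("/app", zk.AddWatchModePersistentRecursive)
	if err != nil {
		log.Fatal(err)
	}
	for e := range events {
		switch e.Type {
		case zk.EventNodeCreated, zk.EventNodeDataChanged, zk.EventNodeDeleted:
			// Fires for /app and for any descendant such as /app/nodes/n1.
			log.Printf("%s at %s", e.Type, e.Path)
		case zk.EventNotWatching:
			// The channel is about to close; e.Err explains why.
			log.Printf("watch lost: %v", e.Err)
			return
		}
	}
}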
@@ -646,7 +676,37 @@ func (c *Conn) sendSetWatches() { // aren't failure modes where a blocking write to the channel of requests // could hang indefinitely and cause this goroutine to leak... for _, req := range reqs { - _, err := c.request(opSetWatches, req, res, nil) + var op int32 = opSetWatches + if len(req.PersistentWatches) > 0 || len(req.PersistentRecursiveWatches) > 0 { + // to maintain compatibility with older servers, only send opSetWatches2 if persistent watches are used + op = opSetWatches2 + } + + _, err := c.request(op, req, res, func(r *request, header *responseHeader, err error) { + if err == nil && op == opSetWatches2 { + // If the setWatches was successful, notify the persistent watchers they've been reconnected. + // Because we process responses in one routine, we know that the following will execute before + // subsequent responses are processed. This means we won't end up in a situation where events are + // sent to watchers before the reconnect event is sent. + c.watchersLock.Lock() + defer c.watchersLock.Unlock() + for _, wt := range persistentWatchTypes { + var paths []string + if wt == watchTypePersistent { + paths = req.PersistentWatches + } else { + paths = req.PersistentRecursiveWatches + } + for _, p := range paths { + e := Event{Type: EventWatching, State: StateConnected, Path: p} + c.sendEvent(e) // also publish globally + for _, ch := range c.watchers[watchPathType{path: p, wType: wt}] { + ch <- e + } + } + } + } + }) if err != nil { c.logger.Printf("Failed to set previous watches: %v", err) break @@ -1050,7 +1110,7 @@ func (c *Conn) Set(path string, data []byte, version int32) (*Stat, error) { return &res.Stat, err } -// Create creates a znode. +// Create creates a znode. If acl is nil, it uses the global WorldACL with PermAll // The returned path is the new path assigned by the server, it may not be the // same as the input, for example when creating a sequence znode the returned path // will be the input path with a sequence number appended. 
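Note that setWatches2 only re-arms the watch: changes made while the client was disconnected are not replayed (TestPersistentWatchOnReconnect's sanity check relies on this), so the EventWatching notification sent above is the consumer's cue to re-read state. A hedged sketch of that pattern follows (not part of the patch; connection details and the /config path are illustrative, and NodeCache/TreeCache in this patch implement the same idea):

package main

import (
	"log"
	"time"

	"github.com/go-zookeeper/zk" // assumed import path for this module
)

func main() {
	conn, _, err := zk.Connect([]string{"127.0.0.1:2181"}, time.Second) // address is illustrative
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	const path = "/config" // illustrative znode
	events, err := conn.AddPersistentWatch(path, zk.AddWatchModePersistent)
	if err != nil {
		log.Fatal(err)
	}

	resync := func() {
		if data, _, err := conn.Get(path); err == nil {
			log.Printf("current data for %s: %q", path, data)
		}
	}
	resync() // the watch alone does not deliver the node's existing state

	for e := range events {
		switch e.Type {
		case zk.EventWatching:
			// Re-armed after a reconnect; changes made while disconnected were not replayed.
			resync()
		case zk.EventNodeCreated, zk.EventNodeDataChanged:
			resync()
		case zk.EventNotWatching:
			log.Printf("watch removed or connection closed: %v", e.Err)
			return
		}
	}
}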
@@ -1059,6 +1119,10 @@ func (c *Conn) Create(path string, data []byte, flags int32, acl []ACL) (string, return "", err } + if acl == nil { + acl = WorldACL(PermAll) + } + res := &createResponse{} _, err := c.request(opCreate, &CreateRequest{path, data, acl, flags}, res, nil) if err == ErrConnectionClosed { @@ -1331,6 +1395,96 @@ func (c *Conn) Server() string { return c.server } +func (c *Conn) AddPersistentWatch(path string, mode AddWatchMode) (ch <-chan Event, err error) { + if err = validatePath(path, false); err != nil { + return nil, err + } + + res := &addWatchResponse{} + _, err = c.request(opAddWatch, &addWatchRequest{Path: path, Mode: mode}, res, func(r *request, header *responseHeader, err error) { + if err == nil { + var wt watchType + if mode == AddWatchModePersistent { + wt = watchTypePersistent + } else { + wt = watchTypePersistentRecursive + } + ch = c.addWatcher(path, wt) + } + }) + if err == ErrConnectionClosed { + return nil, err + } + return ch, err +} + +func (c *Conn) RemovePersistentWatch(path string, ch <-chan Event) (err error) { + if err = validatePath(path, false); err != nil { + return err + } + + deleted := false + + res := &checkWatchesResponse{} + _, err = c.request(opCheckWatches, &checkWatchesRequest{Path: path, Type: WatcherTypeAny}, res, func(r *request, header *responseHeader, err error) { + if err != nil { + return + } + + c.watchersLock.Lock() + defer c.watchersLock.Unlock() + + for _, wt := range persistentWatchTypes { + wpt := watchPathType{path: path, wType: wt} + for i, w := range c.watchers[wpt] { + if w == ch { + deleted = true + c.watchers[wpt] = append(c.watchers[wpt][:i], c.watchers[wpt][i+1:]...) + w <- Event{Type: EventNotWatching, State: c.State(), Path: path, Err: ErrNoWatcher} + close(w) + return + } + } + } + }) + + if err != nil { + return err + } + + if !deleted { + return ErrNoWatcher + } + + return nil +} + +func (c *Conn) RemoveAllPersistentWatches(path string) (err error) { + if err = validatePath(path, false); err != nil { + return err + } + + res := &checkWatchesResponse{} + _, err = c.request(opRemoveWatches, &checkWatchesRequest{Path: path, Type: WatcherTypeAny}, res, func(r *request, header *responseHeader, err error) { + if err != nil { + return + } + + c.watchersLock.Lock() + defer c.watchersLock.Unlock() + for _, wt := range persistentWatchTypes { + wpt := watchPathType{path: path, wType: wt} + e := Event{Type: EventNotWatching, State: StateDisconnected, Path: path, Err: ErrNoWatcher} + for _, ch := range c.watchers[wpt] { + ch <- e + close(ch) + } + delete(c.watchers, wpt) + } + }) + return err +} + func resendZkAuth(ctx context.Context, c *Conn) error { shouldCancel := func() bool { select { diff --git a/conn_test.go b/conn_test.go index 96299280..75e8cd27 100644 --- a/conn_test.go +++ b/conn_test.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "io/ioutil" + "log" + "strings" "sync" "testing" "time" @@ -119,77 +121,109 @@ func TestNotifyWatches(t *testing.T) { watches map[watchPathType]bool }{ { - EventNodeCreated, "/", - map[watchPathType]bool{ - {"/", watchTypeExist}: true, - {"/", watchTypeChild}: false, - {"/", watchTypeData}: false, - }, - }, - { - EventNodeCreated, "/a", - map[watchPathType]bool{ + eType: EventNodeCreated, + path: "/a", + watches: map[watchPathType]bool{ + {"/a", watchTypeExist}: true, {"/b", watchTypeExist}: false, + + {"/a", watchTypeChild}: false, + + {"/a", watchTypeData}: false, + + {"/a", watchTypePersistent}: true, + {"/", watchTypePersistent}: false, + + {"/a", watchTypePersistentRecursive}: 
true, + {"/", watchTypePersistentRecursive}: true, }, }, { - EventNodeDataChanged, "/", - map[watchPathType]bool{ - {"/", watchTypeExist}: true, - {"/", watchTypeData}: true, - {"/", watchTypeChild}: false, + eType: EventNodeDataChanged, + path: "/a", + watches: map[watchPathType]bool{ + {"/a", watchTypeExist}: true, + {"/a", watchTypeData}: true, + {"/a", watchTypeChild}: false, + + {"/a", watchTypePersistent}: true, + {"/", watchTypePersistent}: false, + + {"/a", watchTypePersistentRecursive}: true, + {"/", watchTypePersistentRecursive}: true, }, }, { - EventNodeChildrenChanged, "/", - map[watchPathType]bool{ - {"/", watchTypeExist}: false, - {"/", watchTypeData}: false, - {"/", watchTypeChild}: true, + eType: EventNodeChildrenChanged, + path: "/a", + watches: map[watchPathType]bool{ + {"/a", watchTypeExist}: false, + {"/a", watchTypeData}: false, + {"/a", watchTypeChild}: true, + {"/a", watchTypePersistent}: true, + {"/a", watchTypePersistentRecursive}: false, + + {"/a", watchTypePersistent}: true, + {"/", watchTypePersistent}: false, + + {"/a", watchTypePersistentRecursive}: false, + {"/", watchTypePersistentRecursive}: false, }, }, { - EventNodeDeleted, "/", - map[watchPathType]bool{ - {"/", watchTypeExist}: true, - {"/", watchTypeData}: true, - {"/", watchTypeChild}: true, + eType: EventNodeDeleted, + path: "/a", + watches: map[watchPathType]bool{ + {"/a", watchTypeExist}: true, + {"/a", watchTypeData}: true, + {"/a", watchTypeChild}: true, + + {"/a", watchTypePersistent}: true, + {"/", watchTypePersistent}: false, + + {"/a", watchTypePersistentRecursive}: true, + {"/", watchTypePersistentRecursive}: true, }, }, } - conn := &Conn{watchers: make(map[watchPathType][]chan Event)} + log.SetFlags(log.Flags() | log.Lshortfile) for idx, c := range cases { + c := c t.Run(fmt.Sprintf("#%d %s", idx, c.eType), func(t *testing.T) { - c := c - notifications := make([]struct { - path string + watchPathType notify bool ch <-chan Event }, len(c.watches)) + conn := &Conn{watchers: make(map[watchPathType][]chan Event)} + var idx int for wpt, expectEvent := range c.watches { - ch := conn.addWatcher(wpt.path, wpt.wType) - notifications[idx].path = wpt.path + notifications[idx].watchPathType = wpt notifications[idx].notify = expectEvent - notifications[idx].ch = ch + notifications[idx].ch = conn.addWatcher(wpt.path, wpt.wType) idx++ } - ev := Event{Type: c.eType, Path: c.path} - conn.notifyWatches(ev) + + conn.notifyWatches(Event{Type: c.eType, Path: c.path}) for _, res := range notifications { select { case e := <-res.ch: - if !res.notify || e.Path != res.path { - t.Fatal("unexpeted notification received") + isPathCorrect := + (res.wType == watchTypePersistentRecursive && strings.HasPrefix(e.Path, res.path)) || + e.Path == res.path + if !res.notify || !isPathCorrect { + t.Logf("unexpeted notification received by %+v: %+v", res, e) + t.Fail() } default: if res.notify { - t.Fatal("expected notification not received") + t.Logf("expected notification not received for %+v", res) + t.Fail() } } } diff --git a/constants.go b/constants.go index 84455d2b..6df68088 100644 --- a/constants.go +++ b/constants.go @@ -27,11 +27,15 @@ const ( opCheck = 13 opMulti = 14 opReconfig = 16 + opCheckWatches = 17 + opRemoveWatches = 18 opCreateContainer = 19 opCreateTTL = 21 opClose = -11 opSetAuth = 100 opSetWatches = 101 + opSetWatches2 = 105 + opAddWatch = 106 opError = -1 // Not in protocol, used internally opWatcherEvent = -2 @@ -47,6 +51,7 @@ const ( // EventSession represents a session event. 
EventSession EventType = -1 EventNotWatching EventType = -2 + EventWatching EventType = -3 ) var ( @@ -57,6 +62,7 @@ var ( EventNodeChildrenChanged: "EventNodeChildrenChanged", EventSession: "EventSession", EventNotWatching: "EventNotWatching", + EventWatching: "EventWatching", } ) @@ -129,6 +135,8 @@ var ( ErrSessionMoved = errors.New("zk: session moved to another server, so operation is ignored") ErrReconfigDisabled = errors.New("attempts to perform a reconfiguration operation when reconfiguration feature is disabled") ErrBadArguments = errors.New("invalid arguments") + ErrNoWatcher = errors.New("zk: no such watcher") + ErrUnimplemented = errors.New("zk: Not implemented") // ErrInvalidCallback = errors.New("zk: invalid callback specified") errCodeToError = map[ErrCode]error{ @@ -147,8 +155,10 @@ var ( errClosing: ErrClosing, errNothing: ErrNothing, errSessionMoved: ErrSessionMoved, + errNoWatcher: ErrNoWatcher, errZReconfigDisabled: ErrReconfigDisabled, errBadArguments: ErrBadArguments, + errUnimplemented: ErrUnimplemented, } ) @@ -186,6 +196,7 @@ const ( errClosing ErrCode = -116 errNothing ErrCode = -117 errSessionMoved ErrCode = -118 + errNoWatcher ErrCode = -121 // Attempts to perform a reconfiguration operation when reconfiguration feature is disabled errZReconfigDisabled ErrCode = -123 ) @@ -224,6 +235,7 @@ var ( opClose: "close", opSetAuth: "setAuth", opSetWatches: "setWatches", + opAddWatch: "addWatch", opWatcherEvent: "watcherEvent", } @@ -263,3 +275,33 @@ var ( ModeStandalone: "standalone", } ) + +// AddWatchMode asd +type AddWatchMode int32 + +func (m AddWatchMode) String() string { + if name, ok := addWatchModeNames[m]; ok { + return name + } + return "unknown" +} + +const ( + AddWatchModePersistent AddWatchMode = 0 + AddWatchModePersistentRecursive AddWatchMode = 1 +) + +var ( + addWatchModeNames = map[AddWatchMode]string{ + AddWatchModePersistent: "persistent", + AddWatchModePersistentRecursive: "persistentRecursive", + } +) + +type WatcherType int32 + +const ( + WatcherTypeChildren = WatcherType(1) + WatcherTypeData = WatcherType(2) + WatcherTypeAny = WatcherType(3) +) diff --git a/nodecache.go b/nodecache.go new file mode 100644 index 00000000..2aed18eb --- /dev/null +++ b/nodecache.go @@ -0,0 +1,109 @@ +package zk + +import ( + "sync" +) + +type NodeCacheOpts struct { + // The retry policy to use when fetching nodes' data and children. If nil, uses DefaultWatcherRetryPolicy + RetryPolicy RetryPolicy + // used for debugging, receives events after they are processed + outChan chan Event +} + +// A NodeCache attempts to stay up-to-date with a given node's state using a persistent watch. +type NodeCache struct { + Path string + Conn *Conn + + retryPolicy RetryPolicy + lock sync.RWMutex + data []byte + events <-chan Event + err error + stopChan chan struct{} + outChan chan Event +} + +func NewNodeCache(conn *Conn, nodePath string) (nc *NodeCache, err error) { + return NewNodeCacheWithOpts(conn, nodePath, NodeCacheOpts{}) +} + +func NewNodeCacheWithOpts(conn *Conn, nodePath string, opts NodeCacheOpts) (nc *NodeCache, err error) { + nc = &NodeCache{ + Path: nodePath, + retryPolicy: opts.RetryPolicy, + Conn: conn, + stopChan: make(chan struct{}), + outChan: opts.outChan, + } + if nc.retryPolicy == nil { + nc.retryPolicy = DefaultWatcherRetryPolicy + } + + nc.events, err = conn.AddPersistentWatch(nodePath, AddWatchModePersistent) + if err != nil { + return nil, err + } + + nc.start() + + return nc, nil +} + +// Stop removes the persistent watch that was created for this node. 
+func (nc *NodeCache) Stop() (err error) { + err = nc.Conn.RemovePersistentWatch(nc.Path, nc.events) + if err != nil { + close(nc.stopChan) + } + return err +} + +// Get returns the most up-to-date data it has available. Returns zk.ErrNoWatcher if Stop has been called. +func (nc *NodeCache) Get() ([]byte, error) { + nc.lock.RLock() + defer nc.lock.RUnlock() + return nc.data, nc.err +} + +// Refresh forces a refresh of the node's data. +func (nc *NodeCache) Refresh() ([]byte, error) { + data, err := getNodeData(nc.retryPolicy, nc.stopChan, nc.Path, nc.Conn) + nc.lock.Lock() + defer nc.lock.Unlock() + nc.data, nc.err = data, err + return nc.data, nc.err +} + +func (nc *NodeCache) start() { + ch := toUnlimitedChannel(nc.events) + + // Initial data fetch blocks startup to ensure a sane first read + nc.Refresh() + + go func() { + for e := range ch { + switch e.Type { + case EventNodeCreated, EventNodeDataChanged, EventWatching: + nc.Refresh() + case EventNodeDeleted: + nc.lock.Lock() + nc.data, nc.err = nil, ErrNoNode + nc.lock.Unlock() + case EventNotWatching: + // EventNotWatching means that channel's about to close, and we can report the error that caused the + // closure. We don't zero out the data and stat to reflect the last known state of the node. + nc.lock.Lock() + nc.err = e.Err + nc.lock.Unlock() + } + if nc.outChan != nil { + nc.outChan <- e + } + } + if nc.outChan != nil { + close(nc.outChan) + } + }() +} diff --git a/nodecache_test.go b/nodecache_test.go new file mode 100644 index 00000000..4116190e --- /dev/null +++ b/nodecache_test.go @@ -0,0 +1,114 @@ +package zk + +import ( + "bytes" + "testing" + "time" +) + +func TestNodeCache(t *testing.T) { + RequireMinimumZkVersion(t, "3.6") + WithTestCluster(t, 10*time.Second, func(ts *TestCluster, zk *Conn) { + zk.reconnectLatch = make(chan struct{}) + + outChan := make(chan Event) + nc, err := NewNodeCacheWithOpts(zk, "/foo", NodeCacheOpts{outChan: outChan}) + requireNoErrorf(t, err, "Failed to initialize node cache") + + _, err = nc.Get() + if err != ErrNoNode { + t.Fatalf("Get did not return zk.ErrNoNode: %+v", err) + } + + testData := []byte{1, 2, 3, 4} + _, err = zk.Create(nc.Path, testData, 0, nil) + requireNoErrorf(t, err, "Failed to set data for %q", nc.Path) + + select { + case e := <-outChan: + if e.Type != EventNodeCreated { + t.Fatalf("Unexpected event: %+v", e) + } + case <-time.After(1 * time.Second): + t.Fatalf("Did not get create event") + } + + data, err := nc.Get() + if err != nil { + t.Fatalf("Get returned an error: %+v", err) + } + if !bytes.Equal(data, testData) { + t.Fatalf("Get did not return the correct data, expected %+v, got %+v", testData, data) + } + + err = zk.Delete(nc.Path, -1) + requireNoErrorf(t, err, "Failed to delete %q", nc.Path) + + select { + case e := <-outChan: + if e.Type != EventNodeDeleted { + t.Fatalf("Unexpected event: %+v", e) + } + case <-time.After(1 * time.Second): + t.Fatalf("Did not get create event") + } + + _, err = nc.Get() + if err != ErrNoNode { + t.Fatalf("Get returned a node after a delete") + } + + zk.conn.Close() + + zk2, _, err := ts.ConnectAll() + if err != nil { + t.Fatalf("create returned an error: %+v", err) + } + + _, err = zk2.Create(nc.Path, testData, 0, nil) + if err != nil { + t.Fatalf("create returned an error: %+v", err) + } + + close(zk.reconnectLatch) + + select { + case e := <-outChan: + if e.Type != EventWatching { + t.Fatalf("Unexpected event: %+v", e) + } + case <-time.After(2 * time.Second): + t.Fatalf("Did not get reconnect event") + } + + data, err = 
nc.Get() + if err != nil { + t.Fatalf("Get returned an error: %+v", err) + } + if !bytes.Equal(data, testData) { + t.Fatalf("Get did not return the correct data, expected %+v, got %+v", testData, data) + } + }) +} + +func TestNodeCacheStartup(t *testing.T) { + RequireMinimumZkVersion(t, "3.6") + WithTestCluster(t, 10*time.Second, func(ts *TestCluster, zk *Conn) { + testData := []byte{1, 2, 3, 4} + + const path = "/foo" + _, err := zk.Create(path, testData, 0, nil) + requireNoErrorf(t, err, "Failed to create %q", path) + + nc, err := NewNodeCache(zk, path) + requireNoErrorf(t, err, "Failed to initialize node cache") + + data, err := nc.Get() + if err != nil { + t.Fatalf("Get returned an error: %+v", err) + } + if !bytes.Equal(data, testData) { + t.Fatalf("Get did not return the correct data, expected %+v, got %+v", testData, data) + } + }) +} diff --git a/server_help_test.go b/server_help_test.go index 067b9a75..a97b9611 100644 --- a/server_help_test.go +++ b/server_help_test.go @@ -35,6 +35,33 @@ type TestCluster struct { Servers []TestServer } +func WithTestCluster(t *testing.T, testTimeout time.Duration, f func(ts *TestCluster, zk *Conn)) { + ts, err := StartTestCluster(t, 1, nil, logWriter{t: t, p: "[ZKERR] "}) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + ts.Stop() + }) + zk, _, err := ts.ConnectAll() + if err != nil { + t.Fatalf("Connect returned error: %+v", err) + } + t.Cleanup(func() { + zk.Close() + }) + doneChan := make(chan struct{}) + go func() { + defer close(doneChan) + f(ts, zk) + }() + select { + case <-doneChan: + case <-time.After(testTimeout): + t.Fatalf("Test did not complete within timeout") + } +} + // TODO: pull this into its own package to allow for better isolation of integration tests vs. unit // testing. This should be used on CI systems and local only when needed whereas unit tests should remain // fast and not rely on external dependencies. 
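NodeCache packages the persistent-watch plumbing above for a single znode: it performs a blocking initial fetch, then keeps its copy current from watch events and re-fetches on EventWatching. A short usage sketch (not part of the patch; connection details and the /config path are illustrative assumptions):

package main

import (
	"log"
	"time"

	"github.com/go-zookeeper/zk" // assumed import path for this module
)

func main() {
	conn, _, err := zk.Connect([]string{"127.0.0.1:2181"}, time.Second) // address is illustrative
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	// Blocks for the initial fetch, then follows /config via a persistent watch.
	nc, err := zk.NewNodeCache(conn, "/config")
	if err != nil {
		log.Fatal(err)
	}
	defer nc.Stop()

	// Served from memory; returns zk.ErrNoNode while the node does not exist.
	data, err := nc.Get()
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("config: %q", data)
}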
@@ -54,7 +81,7 @@ func StartTestCluster(t *testing.T, size int, stdout, stderr io.Writer) (*TestCl } tmpPath, err := ioutil.TempDir("", "gozk") - requireNoError(t, err, "failed to create tmp dir for test server setup") + requireNoErrorf(t, err, "failed to create tmp dir for test server setup") success := false startPort := int(rand.Int31n(6000) + 10000) @@ -68,7 +95,7 @@ func StartTestCluster(t *testing.T, size int, stdout, stderr io.Writer) (*TestCl for serverN := 0; serverN < size; serverN++ { srvPath := filepath.Join(tmpPath, fmt.Sprintf("srv%d", serverN+1)) - requireNoError(t, os.Mkdir(srvPath, 0700), "failed to make server path") + requireNoErrorf(t, os.Mkdir(srvPath, 0700), "failed to make server path") port := startPort + serverN*3 cfg := ServerConfig{ @@ -89,20 +116,20 @@ func StartTestCluster(t *testing.T, size int, stdout, stderr io.Writer) (*TestCl cfgPath := filepath.Join(srvPath, _testConfigName) fi, err := os.Create(cfgPath) - requireNoError(t, err) + requireNoErrorf(t, err) - requireNoError(t, cfg.Marshall(fi)) + requireNoErrorf(t, cfg.Marshall(fi)) fi.Close() fi, err = os.Create(filepath.Join(srvPath, _testMyIDFileName)) - requireNoError(t, err) + requireNoErrorf(t, err) _, err = fmt.Fprintf(fi, "%d\n", serverN+1) fi.Close() - requireNoError(t, err) + requireNoErrorf(t, err) srv, err := NewIntegrationTestServer(t, cfgPath, stdout, stderr) - requireNoError(t, err) + requireNoErrorf(t, err) if err := srv.Start(); err != nil { return nil, err @@ -252,14 +279,15 @@ func (tc *TestCluster) StopAllServers() error { return nil } -func requireNoError(t *testing.T, err error, msgAndArgs ...interface{}) { +func requireNoErrorf(t *testing.T, err error, msgAndArgs ...interface{}) { if err != nil { + t.Helper() t.Logf("received unexpected error: %v", err) - t.Fatal(msgAndArgs...) + t.Fatalf(msgAndArgs[0].(string), msgAndArgs[1:]...) } } -func requireMinimumZkVersion(t *testing.T, minimum string) { +func RequireMinimumZkVersion(t *testing.T, minimum string) { if val, ok := os.LookupEnv("ZK_VERSION"); ok { split := func(v string) (parts []int) { for _, s := range strings.Split(minimum, ".") { @@ -281,6 +309,6 @@ func requireMinimumZkVersion(t *testing.T, minimum string) { } } } else { - t.Skip("did not detect zk_version from env. skipping reconfig test") + t.Skip("did not detect zk_version from env. skipping test") } } diff --git a/structs.go b/structs.go index 8eb41e39..3cbf1e32 100644 --- a/structs.go +++ b/structs.go @@ -3,6 +3,7 @@ package zk import ( "encoding/binary" "errors" + "fmt" "log" "reflect" "runtime" @@ -18,8 +19,8 @@ var ( type defaultLogger struct{} -func (defaultLogger) Printf(format string, a ...interface{}) { - log.Printf(format, a...) 
+func (defaultLogger) Printf(format string, v ...interface{}) { + log.Output(3, fmt.Sprintf(format, v...)) } type ACL struct { @@ -256,10 +257,12 @@ type setSaslResponse struct { } type setWatchesRequest struct { - RelativeZxid int64 - DataWatches []string - ExistWatches []string - ChildWatches []string + RelativeZxid int64 + DataWatches []string + ExistWatches []string + ChildWatches []string + PersistentWatches []string + PersistentRecursiveWatches []string } type setWatchesResponse struct{} @@ -301,6 +304,20 @@ type reconfigRequest struct { type reconfigReponse getDataResponse +type addWatchRequest struct { + Path string + Mode AddWatchMode +} + +type addWatchResponse struct{} + +type checkWatchesRequest struct { + Path string + Type WatcherType +} + +type checkWatchesResponse struct{} + func (r *multiRequest) Encode(buf []byte) (int, error) { total := 0 for _, op := range r.Ops { @@ -622,7 +639,7 @@ func requestStructForOp(op int32) interface{} { return &setAclRequest{} case opSetData: return &SetDataRequest{} - case opSetWatches: + case opSetWatches, opSetWatches2: return &setWatchesRequest{} case opSync: return &syncRequest{} @@ -634,6 +651,8 @@ func requestStructForOp(op int32) interface{} { return &multiRequest{} case opReconfig: return &reconfigRequest{} + case opAddWatch: + return &addWatchRequest{} } return nil } diff --git a/treecache.go b/treecache.go new file mode 100644 index 00000000..f71d2773 --- /dev/null +++ b/treecache.go @@ -0,0 +1,417 @@ +package zk + +import ( + "fmt" + "strings" + "sync" +) + +type TreeCacheOpts struct { + // The retry policy to use when fetching nodes' data and children. If nil, uses DefaultWatcherRetryPolicy + RetryPolicy RetryPolicy + // Provides a depth limit when bootstrapping and listening to watch events. The root is considered depth 1, so any + // negative or zero value disables this feature. + MaxDepth int + // used for debugging, receives events after they are processed + outChan chan Event +} + +type TreeCache struct { + RootPath string + Conn *Conn + + root *treeCacheNode + retryPolicy RetryPolicy + maxDepth int + + events <-chan Event + stopChan chan struct{} + outChan chan Event +} + +type treeCacheNode struct { + lock sync.RWMutex + data []byte + children map[string]*treeCacheNode + err error +} + +func NewTreeCache(conn *Conn, rootPath string) (tc *TreeCache, err error) { + return NewTreeCacheWithOpts(conn, rootPath, TreeCacheOpts{}) +} + +func NewTreeCacheWithOpts(conn *Conn, rootPath string, opts TreeCacheOpts) (tc *TreeCache, err error) { + tc = &TreeCache{ + RootPath: rootPath, + retryPolicy: opts.RetryPolicy, + maxDepth: opts.MaxDepth, + root: newTreeCacheNode(), + Conn: conn, + stopChan: make(chan struct{}), + outChan: opts.outChan, + } + if tc.retryPolicy == nil { + tc.retryPolicy = DefaultWatcherRetryPolicy + } + + tc.events, err = conn.AddPersistentWatch(rootPath, AddWatchModePersistentRecursive) + if err != nil { + return nil, err + } + tc.start() + + return tc, nil +} + +func (tc *TreeCache) start() { + ch := toUnlimitedChannel(tc.events) + + tc.bootstrapRoot() + + go func() { + for e := range ch { + switch e.Type { + case EventWatching: + tc.bootstrapRoot() + case EventNodeCreated: + tc.nodeCreated(e.Path) + case EventNodeDataChanged: + tc.nodeDataChanged(e.Path) + case EventNodeDeleted: + tc.nodeDeleted(e.Path) + case EventNotWatching: + // EventNotWatching means that channel's about to close, and we can report the error that caused the + // closure. 
We don't zero out the data and stat to reflect the last known state of all the nodes. + var notWatching func(*treeCacheNode) + notWatching = func(n *treeCacheNode) { + n.lock.Lock() + defer n.lock.Unlock() + n.err = e.Err + for _, c := range n.children { + notWatching(c) + } + } + notWatching(tc.root) + } + if tc.outChan != nil { + tc.outChan <- e + } + } + if tc.outChan != nil { + close(tc.outChan) + } + }() +} + +func (tc *TreeCache) isPastMaxDepth(nodePath string) bool { + if tc.maxDepth <= 0 || nodePath == tc.RootPath { + return false + } + + nodePath = strings.TrimPrefix(nodePath, tc.RootPath+"/") + return strings.Count(nodePath, "/")+1 >= tc.maxDepth +} + +func (tc *TreeCache) nodeCreated(nodePath string) { + if tc.isPastMaxDepth(nodePath) { + return + } + + var n *treeCacheNode + if nodePath == tc.RootPath { + n = tc.root + } else { + dir, name := SplitPath(nodePath) + parent := tc.get(dir) + if parent == nil { + // This can happen if the node was created then immediately deleted while a bootstrap was occurring. + return + } + + child, ok := parent.children[name] + if ok { + n = child + } else { + n = newTreeCacheNode() + defer func() { // after the new node's data is updated, add it to its parent's children + parent.lock.Lock() + defer parent.lock.Unlock() + parent.children[name] = n + }() + } + } + + data, err := getNodeData(tc.retryPolicy, tc.stopChan, nodePath, tc.Conn) + if err == ErrNoNode { + tc.nodeDeleted(nodePath) + return + } + + n.lock.Lock() + defer n.lock.Unlock() + if err != nil { + n.err = err + } else { + n.data, n.err = data, nil + } +} + +func (tc *TreeCache) nodeDataChanged(nodePath string) { + if tc.isPastMaxDepth(nodePath) { + return + } + + var n *treeCacheNode + if nodePath == tc.RootPath { + n = tc.root + } else { + n = tc.get(nodePath) + if n == nil { + // This can happen if a number of now redundant events were queued up during a .bootstrap() call + return + } + } + + data, err := getNodeData(tc.retryPolicy, tc.stopChan, nodePath, tc.Conn) + if err == ErrNoNode { + tc.nodeDeleted(nodePath) + return + } + + n.lock.Lock() + defer n.lock.Unlock() + if err != nil { + n.err = err + } else { + n.data, n.err = data, nil + } +} + +func (tc *TreeCache) nodeDeleted(nodePath string) { + if tc.isPastMaxDepth(nodePath) { + return + } + + if nodePath == tc.RootPath { + tc.root.lock.Lock() + tc.root.data = nil + tc.root.children = map[string]*treeCacheNode{} + tc.root.err = ErrNoNode + tc.root.lock.Unlock() + } else { + dir, name := SplitPath(nodePath) + parent := tc.get(dir) + if parent == nil { + // This can happen if a number of now redundant events were queued up during a .bootstrap() call + return + } + parent.lock.Lock() + parent.lock.Unlock() + _, ok := parent.children[name] + if ok { + delete(parent.children, name) + } + } +} + +var ErrNotInWatchedSubtree = fmt.Errorf("zk: node path is not in watched subtree") +var ErrPastMaxDepth = fmt.Errorf("zk: node path is past maximum configured depth") + +func (tc *TreeCache) cleanAndCheckPath(nodePath string) (string, error) { + if !strings.HasPrefix(nodePath, tc.RootPath) { + return "", ErrNotInWatchedSubtree + } + if strings.HasSuffix(nodePath, "/") && nodePath != "/" { + nodePath = nodePath[:len(nodePath)-1] + } + if tc.isPastMaxDepth(nodePath) { + return "", ErrPastMaxDepth + } + return nodePath, nil +} + +// Get returns the node's most up-to-date +func (tc *TreeCache) Get(nodePath string) (data []byte, children []string, err error) { + nodePath, err = tc.cleanAndCheckPath(nodePath) + if err != nil { + return nil, 
nil, err
+	}
+
+	n := tc.get(nodePath)
+	if n == nil {
+		return nil, nil, ErrNoNode
+	}
+
+	n.lock.RLock()
+	defer n.lock.RUnlock()
+	data = n.data
+	for k := range n.children {
+		children = append(children, k)
+	}
+	return n.data, children, n.err
+}
+
+func (tc *TreeCache) get(nodePath string) *treeCacheNode {
+	if nodePath == tc.RootPath {
+		return tc.root
+	}
+
+	var relativeNodePath string
+	if tc.RootPath == "/" {
+		relativeNodePath = nodePath[1:]
+	} else {
+		relativeNodePath = nodePath[len(tc.RootPath)+1:]
+	}
+	segments := strings.Split(relativeNodePath, "/")
+
+	node := tc.root
+	for _, s := range segments {
+		node.lock.RLock()
+		newNode, ok := node.children[s]
+		node.lock.RUnlock()
+		if !ok {
+			return nil
+		}
+		node = newNode
+	}
+
+	return node
+}
+
+// Refresh forces the immediate refresh of a node's data (not its children).
+func (tc *TreeCache) Refresh(nodePath string) (data []byte, err error) {
+	nodePath, err = tc.cleanAndCheckPath(nodePath)
+	if err != nil {
+		return nil, err
+	}
+
+	n := tc.get(nodePath)
+	if n == nil {
+		return nil, ErrNoNode
+	}
+
+	data, err = getNodeData(tc.retryPolicy, tc.stopChan, nodePath, tc.Conn)
+
+	n.lock.Lock()
+	defer n.lock.Unlock()
+	n.data, n.err = data, err
+	return n.data, n.err
+}
+
+type NodeData struct {
+	Err  error
+	Data []byte
+}
+
+// Children constructs a map of all the children of the given node.
+func (tc *TreeCache) Children(root string) (m map[string]NodeData) {
+	root, err := tc.cleanAndCheckPath(root)
+	if err != nil {
+		return map[string]NodeData{root: {Err: err}}
+	}
+	n := tc.get(root)
+	if n == nil {
+		return map[string]NodeData{root: {Err: ErrNoNode}}
+	}
+
+	m = map[string]NodeData{}
+
+	n.lock.RLock()
+	defer n.lock.RUnlock()
+
+	for k, v := range n.children {
+		v.lock.RLock()
+		m[k] = NodeData{
+			Err:  v.err,
+			Data: v.data,
+		}
+		v.lock.RUnlock()
+	}
+
+	return m
+}
+
+// Tree recursively constructs a map of all the nodes the cache is aware of, starting at the given root.
+func (tc *TreeCache) Tree(root string) (m map[string]NodeData) { + root, err := tc.cleanAndCheckPath(root) + if err != nil { + return map[string]NodeData{root: {Err: err}} + } + n := tc.get(root) + if n == nil { + return map[string]NodeData{root: {Err: ErrNoNode}} + } + + m = map[string]NodeData{} + + var tree func(string, *treeCacheNode) + tree = func(nodePath string, n *treeCacheNode) { + n.lock.RLock() + defer n.lock.RUnlock() + + m[nodePath] = NodeData{ + Err: n.err, + Data: n.data, + } + + for k, v := range n.children { + tree(JoinPath(nodePath, k), v) + } + } + + tree(root, n) + + return m +} + +func newTreeCacheNode() *treeCacheNode { + return &treeCacheNode{children: map[string]*treeCacheNode{}} +} + +func (tc *TreeCache) bootstrapRoot() { + tc.bootstrap(tc.RootPath, tc.root, 1) +} + +func (tc *TreeCache) bootstrap(nodePath string, n *treeCacheNode, depth int) (deleted bool) { + n.lock.Lock() + defer n.lock.Unlock() + + var data []byte + var children []string + var err error + if tc.maxDepth > 0 && depth >= tc.maxDepth { + data, err = getNodeData(tc.retryPolicy, tc.stopChan, nodePath, tc.Conn) + } else { + data, children, err = getNodeDataAndChildren(tc.retryPolicy, tc.stopChan, nodePath, tc.Conn) + } + if err != nil { + n.err = err + return n.err == ErrNoNode + } + + var childrenMap map[string]*treeCacheNode + + childrenMap = make(map[string]*treeCacheNode, len(children)) + for _, c := range children { + childrenMap[c] = newTreeCacheNode() + } + + for k, v := range childrenMap { + if tc.bootstrap(JoinPath(nodePath, k), v, depth+1) { + delete(childrenMap, k) + } + } + + n.data, n.children, n.err = data, childrenMap, nil + + return false +} + +// Stop removes the persistent watch that was created for this path. Returns zk.ErrNoWatcher if called more than once. 
+func (tc *TreeCache) Stop() (err error) { + err = tc.Conn.RemovePersistentWatch(tc.RootPath, tc.events) + if err == nil { + close(tc.stopChan) + } + return err +} diff --git a/treecache_test.go b/treecache_test.go new file mode 100644 index 00000000..b3b6debd --- /dev/null +++ b/treecache_test.go @@ -0,0 +1,316 @@ +package zk + +import ( + "fmt" + "math" + "reflect" + "sync" + "testing" + "time" +) + +func TestTreeCache(t *testing.T) { + RequireMinimumZkVersion(t, "3.6") + WithTestCluster(t, 10*time.Second, func(ts *TestCluster, zk *Conn) { + zk.reconnectLatch = make(chan struct{}) + nodes := []string{ + "/root", + "/root/foo", + "/root/foo/bar", + "/root/foo/bar/baz", + "/root/asd", + } + + for _, node := range nodes { + _, err := zk.Create(node, []byte(node), 0, nil) + if err != nil { + t.Fatalf("Failed to create node %q: %+v", node, err) + } + } + + outChan := make(chan Event) + tc, err := NewTreeCacheWithOpts(zk, "/root", TreeCacheOpts{outChan: outChan}) + if err != nil { + t.Fatalf("Failed to create TreeCache: %+v", err) + } + + var tree map[string]NodeData + checkTree := func(k string) { + t.Helper() + v, ok := tree[k] + if !ok { + t.Fatalf("Could not find %q in tree", k) + } + if v.Err != nil { + t.Fatalf("%q had an error: %+v", k, v.Err) + } + if k != string(v.Data) { + t.Fatalf("Unexpected data for %q: expected %v, got %v", k, []byte(k), v.Data) + } + } + + tree = tc.Tree("/root") + if len(nodes) != len(tree) { + t.Fatalf("Incorrect node count from tree: expected %d got %d (%v)", len(nodes), len(tree), tree) + } + for _, node := range nodes { + checkTree(node) + } + + tree = tc.Tree("/root/foo/bar") + if len(tree) != 2 { + t.Fatalf("Incorrect node count from tree: expected 2 got %d", len(tree)) + } + checkTree("/root/foo/bar") + checkTree("/root/foo/bar/baz") + + zk.conn.Close() + + zk2, _, err := ts.ConnectAll() + if err != nil { + t.Fatalf("create returned an error: %+v", err) + } + + newNode := "/root/foo/bar/foo" + _, err = zk2.Create(newNode, []byte(newNode), 0, nil) + if err != nil { + t.Fatalf("create returned an error: %+v", err) + } + + close(zk.reconnectLatch) + // wait for reconnect + select { + case e := <-outChan: + if e.Type != EventWatching { + t.Fatalf("Unexpected event %+v", e) + } + case <-time.After(2 * time.Second): + t.Fatalf("Did not reconnect!") + } + + tree = tc.Tree("/root/foo/bar") + if len(tree) != 3 { + t.Fatalf("Incorrect node count from tree: expected 3 got %d (%v)", len(tree), tree) + } + checkTree("/root/foo/bar") + checkTree("/root/foo/bar/baz") + checkTree(newNode) + + tree = tc.Children("/root/foo") + if len(tree) != 1 { + t.Fatalf("Incorrect node count from tree: expected %d got %d (%v)", len(nodes), len(tree), tree) + } + if err = tree["bar"].Err; err != nil { + t.Fatalf("Unexpected error in %q: %+v", "/root/foo/bar", err) + } + if data := string(tree["bar"].Data); data != "/root/foo/bar" { + t.Fatalf("Unexpected data in %q: %q", "/root/foo/bar", data) + } + + err = tc.Stop() + if err != nil { + t.Fatalf("Stop returned an error: %+v", err) + } + + err = tc.Stop() + if err != ErrNoWatcher { + t.Fatalf("Unexpected error returned from Stop: %+v", err) + } + + select { + case e := <-outChan: + if e.Type != EventNotWatching { + t.Fatalf("Unexpected event %+v", e) + } + case <-time.After(2 * time.Second): + t.Fatalf("Did not stop watching!") + } + + tree = tc.Tree(tc.RootPath) + if len(nodes)+1 != len(tree) { + t.Fatalf("Incorrect node count from tree: expected %d got %d (%v)", len(nodes), len(tree), tree) + } + + for k, v := range tree { + if v.Err 
== nil { + t.Fatalf("No error on %q", k) + } + } + + data, children, err := tc.Get(tc.RootPath) + if err == nil { + t.Fatalf("Get after stop did not return an error") + } + if string(data) != tc.RootPath { + t.Fatalf("Unepxected data in %q: %q", tc.RootPath, string(data)) + } + if !reflect.DeepEqual(children, []string{"asd", "foo"}) { + t.Fatalf("Unepxected children in %q: %q", tc.RootPath, children) + } + }) +} + +func TestTreeCacheStopDuringRetry(t *testing.T) { + RequireMinimumZkVersion(t, "3.6") + WithTestCluster(t, 10*time.Second, func(ts *TestCluster, zk *Conn) { + const root = "/foo" + + var inRetry sync.WaitGroup + inRetry.Add(1) + + bootStrapped := false + + outChan := make(chan Event) + tc, err := NewTreeCacheWithOpts(zk, root, TreeCacheOpts{ + RetryPolicy: RetryPolicyFunc(func(attempt int, lastError error) time.Duration { + // Don't block during bootstrap + if !bootStrapped { + bootStrapped = true + return -1 + } else { + inRetry.Done() + return math.MaxInt64 + } + }), + outChan: outChan, + }) + requireNoErrorf(t, err, "Could not create TreeCache: %+v", err) + + _, err = zk.Create(root, nil, 0, nil) + requireNoErrorf(t, err, "Could not create %q: %+v", root, err) + + // Updating the node then deleting will queue an EventNodeDataChanged then an EventNodeDeleted. + _, err = zk.Set(root, []byte{1}, -1) + requireNoErrorf(t, err, "Could not update %q: %+v", root, err) + + err = zk.Delete(root, -1) + requireNoErrorf(t, err, "Could not delete %q: %+v", root, err) + + // Because the cache is blocked until we read outChan, we know that when it tries to process the queued + // EventNodeDataChanged, it will return an ErrNoNode, forcing it to execute the RetryPolicy. This policy returns + // an infinite timeout so this test is now stuck until Stop is called + select { + case e := <-outChan: + if e.Type != EventNodeCreated { + t.Fatalf("Unexpected event: %+v", e) + } + case <-time.After(2 * time.Second): + t.Fatalf("Did not get node creation") + } + + // Wait until the retry loop has started to stop the cache + inRetry.Wait() + err = tc.Stop() + requireNoErrorf(t, err, "Stop returned an unexpected error") + + select { + case e := <-outChan: + if e.Type != EventNodeDataChanged { + t.Fatalf("Unexpected event: %+v", e) + } + case <-time.After(1 * time.Second): + t.Fatalf("Did not early exit retry loop") + } + }) +} + +func TestTreeCacheMaxDepth(t *testing.T) { + RequireMinimumZkVersion(t, "3.6") + WithTestCluster(t, 10*time.Second, func(ts *TestCluster, zk *Conn) { + nodePath := "/root" + + expectedPaths := map[string]bool{} + + for i := 0; i < 10; i++ { + _, err := zk.Create(nodePath, nil, 0, nil) + requireNoErrorf(t, err, "could not create %q", nodePath) + if i < 5 { + expectedPaths[nodePath] = true + } + nodePath += fmt.Sprintf("/%d", i) + } + + outChan := make(chan Event) + tc, err := NewTreeCacheWithOpts(zk, "/root", TreeCacheOpts{MaxDepth: 5, outChan: outChan}) + requireNoErrorf(t, err, "could not create TreeCache") + + tree := tc.Tree("/root") + for k, v := range tree { + requireNoErrorf(t, v.Err, "%q has an error", k) + if !expectedPaths[k] { + t.Fatalf("TreeCache did not respect max depth: %q", k) + } else { + delete(expectedPaths, k) + } + } + if len(expectedPaths) > 0 { + t.Fatalf("Not all expected paths present in TreeCache: %+v", expectedPaths) + } + + newPath := "/root/foo" + _, err = zk.Create(newPath, nil, 0, nil) + requireNoErrorf(t, err, "could not create %q", newPath) + + t.Log(<-outChan) + + tree = tc.Tree("/root") + if v, ok := tree[newPath]; !ok || v.Err != nil { + 
t.Fatalf("Tree did not contain new node shallower than max depth (got %v)", tree) + } + + newPath = "/root/0/1/2/3/foo" // this has depth of 6 so it should be ignored + _, err = zk.Create(newPath, nil, 0, nil) + requireNoErrorf(t, err, "could not create %q", newPath) + + t.Log(<-outChan) + + tree = tc.Tree("/root") + if _, ok := tree[newPath]; ok { + t.Fatalf("Tree did not ignore node deeper than max depth: %q", newPath) + } + }) +} + +func TestTreeCache_isPastMaxDepth(t *testing.T) { + tests := []struct { + maxDepth int + paths map[string]bool + }{ + { + maxDepth: 0, + paths: map[string]bool{ + "/root": false, + "/root/foo": false, + "/root/foo/bar": false, + "/root/foo/bar/baz": false, + }, + }, + { + maxDepth: 1, + paths: map[string]bool{ + "/root": false, + "/root/foo": true, + "/root/foo/bar": true, + "/root/foo/bar/baz": true, + }, + }, + { + maxDepth: 2, + paths: map[string]bool{ + "/root": false, + "/root/foo": false, + "/root/foo/bar": true, + "/root/foo/bar/baz": true, + }, + }, + } + tc := &TreeCache{RootPath: "/root"} + for _, test := range tests { + tc.maxDepth = test.maxDepth + for path, isPast := range test.paths { + if actual := tc.isPastMaxDepth(path); actual != isPast { + t.Fatalf("isPastMaxDepth(%q) for maxDepth=%d returned %v, expected %v", path, tc.maxDepth, actual, isPast) + } + } + } +} diff --git a/unlimited_channel.go b/unlimited_channel.go new file mode 100644 index 00000000..6c920420 --- /dev/null +++ b/unlimited_channel.go @@ -0,0 +1,88 @@ +package zk + +import ( + "sync" +) + +type unlimitedChannelNode struct { + event Event + next *unlimitedChannelNode +} + +type unlimitedChannel struct { + head *unlimitedChannelNode + tail *unlimitedChannelNode + cond *sync.Cond + closed bool +} + +// toUnlimitedChannel uses a backing unlimitedChannel used to effectively turn a buffered channel into a channel with an +// infinite buffer by storing all incoming elements into a singly-linked queue and popping them as they are read. 
+func toUnlimitedChannel(in <-chan Event) <-chan Event { + q := &unlimitedChannel{cond: sync.NewCond(new(sync.Mutex))} + + go func() { + defer q.close() + for e := range in { + q.push(e) + } + }() + + out := make(chan Event) + go func() { + for { + e, closed := q.next() + if closed { + close(out) + return + } + out <- e + } + }() + + return out +} + +func (q *unlimitedChannel) push(e Event) { + q.cond.L.Lock() + defer q.cond.L.Unlock() + + if q.closed { + // Panic like a closed channel + panic("send on closed unlimited channel") + } + + if q.head == nil { + q.head = &unlimitedChannelNode{event: e} + q.tail = q.head + } else { + q.tail.next = &unlimitedChannelNode{event: e} + q.tail = q.tail.next + } + q.cond.Signal() +} + +func (q *unlimitedChannel) close() { + q.cond.L.Lock() + defer q.cond.L.Unlock() + q.closed = true + q.cond.Signal() +} + +func (q *unlimitedChannel) next() (Event, bool) { + q.cond.L.Lock() + defer q.cond.L.Unlock() + + // Wait until the queue either has an element or has been closed + for q.head == nil && !q.closed { + q.cond.Wait() + } + + if q.head != nil { + e := q.head.event + q.head = q.head.next + return e, false + } else { // we know from the condition check above that if the head is nil, then the queue is closed + return Event{}, true + } +} diff --git a/unlimited_channel_test.go b/unlimited_channel_test.go new file mode 100644 index 00000000..7ccb77cc --- /dev/null +++ b/unlimited_channel_test.go @@ -0,0 +1,54 @@ +package zk + +import ( + "fmt" + "reflect" + "testing" +) + +func newEvent(i int) Event { + return Event{Path: fmt.Sprintf("/%d", i)} +} + +func TestQueue(t *testing.T) { + in := make(chan Event) + out := toUnlimitedChannel(in) + + // check that events can be pushed without consumers + for i := 0; i < 100; i++ { + in <- newEvent(i) + } + + events := 0 + for actual := range out { + expected := newEvent(events) + if !reflect.DeepEqual(actual, expected) { + t.Fatalf("Did not receive expected event from queue: actual %+v expected %+v", actual, expected) + } + events++ + if events == 100 { + close(in) + } + } + if events != 100 { + t.Fatalf("Did not receive 100 events") + } +} + +func TestQueueClose(t *testing.T) { + in := make(chan Event) + out := toUnlimitedChannel(in) + + in <- Event{} + close(in) + + _, ok := <-out + if !ok { + t.Fatalf("Closed inifinite queue did not drain remamining events") + } + + _, ok = <-out + if ok { + t.Fatalf("Too many events returned by closed queue") + } +} diff --git a/zk_test.go b/zk_test.go index 131118fd..01a7d7f9 100644 --- a/zk_test.go +++ b/zk_test.go @@ -186,22 +186,22 @@ func TestCreateContainer(t *testing.T) { } func TestIncrementalReconfig(t *testing.T) { - requireMinimumZkVersion(t, "3.5") + RequireMinimumZkVersion(t, "3.5") ts, err := StartTestCluster(t, 3, nil, logWriter{t: t, p: "[ZKERR] "}) - requireNoError(t, err, "failed to setup test cluster") + requireNoErrorf(t, err, "failed to setup test cluster") defer ts.Stop() // start and add a new server. 
tmpPath, err := ioutil.TempDir("", "gozk") - requireNoError(t, err, "failed to create tmp dir for test server setup") + requireNoErrorf(t, err, "failed to create tmp dir for test server setup") defer os.RemoveAll(tmpPath) startPort := int(rand.Int31n(6000) + 10000) srvPath := filepath.Join(tmpPath, fmt.Sprintf("srv4")) if err := os.Mkdir(srvPath, 0700); err != nil { - requireNoError(t, err, "failed to make server path") + requireNoErrorf(t, err, "failed to make server path") } testSrvConfig := ServerConfigServer{ ID: 4, @@ -218,35 +218,35 @@ func TestIncrementalReconfig(t *testing.T) { // TODO: clean all this server creating up to a better helper method cfgPath := filepath.Join(srvPath, _testConfigName) fi, err := os.Create(cfgPath) - requireNoError(t, err) + requireNoErrorf(t, err) - requireNoError(t, cfg.Marshall(fi)) + requireNoErrorf(t, cfg.Marshall(fi)) fi.Close() fi, err = os.Create(filepath.Join(srvPath, _testMyIDFileName)) - requireNoError(t, err) + requireNoErrorf(t, err) _, err = fmt.Fprintln(fi, "4") fi.Close() - requireNoError(t, err) + requireNoErrorf(t, err) testServer, err := NewIntegrationTestServer(t, cfgPath, nil, nil) - requireNoError(t, err) - requireNoError(t, testServer.Start()) + requireNoErrorf(t, err) + requireNoErrorf(t, testServer.Start()) defer testServer.Stop() zk, events, err := ts.ConnectAll() - requireNoError(t, err, "failed to connect to cluster") + requireNoErrorf(t, err, "failed to connect to cluster") defer zk.Close() err = zk.AddAuth("digest", []byte("super:test")) - requireNoError(t, err, "failed to auth to cluster") + requireNoErrorf(t, err, "failed to auth to cluster") waitCtx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() err = waitForSession(waitCtx, events) - requireNoError(t, err, "failed to wail for session") + requireNoErrorf(t, err, "failed to wail for session") _, _, err = zk.Get("/zookeeper/config") if err != nil { @@ -262,7 +262,7 @@ func TestIncrementalReconfig(t *testing.T) { if err != nil && err == ErrConnectionClosed { t.Log("conneciton closed is fine since the cluster re-elects and we dont reconnect") } else { - requireNoError(t, err, "failed to remove node from cluster") + requireNoErrorf(t, err, "failed to remove node from cluster") } // add node a new 4th node @@ -271,30 +271,30 @@ func TestIncrementalReconfig(t *testing.T) { if err != nil && err == ErrConnectionClosed { t.Log("conneciton closed is fine since the cluster re-elects and we dont reconnect") } else { - requireNoError(t, err, "failed to add new server to cluster") + requireNoErrorf(t, err, "failed to add new server to cluster") } } func TestReconfig(t *testing.T) { - requireMinimumZkVersion(t, "3.5") + RequireMinimumZkVersion(t, "3.5") // This test enures we can do an non-incremental reconfig ts, err := StartTestCluster(t, 3, nil, logWriter{t: t, p: "[ZKERR] "}) - requireNoError(t, err, "failed to setup test cluster") + requireNoErrorf(t, err, "failed to setup test cluster") defer ts.Stop() zk, events, err := ts.ConnectAll() - requireNoError(t, err, "failed to connect to cluster") + requireNoErrorf(t, err, "failed to connect to cluster") defer zk.Close() err = zk.AddAuth("digest", []byte("super:test")) - requireNoError(t, err, "failed to auth to cluster") + requireNoErrorf(t, err, "failed to auth to cluster") waitCtx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() err = waitForSession(waitCtx, events) - requireNoError(t, err, "failed to wail for session") + requireNoErrorf(t, err, "failed to wail for 
session") _, _, err = zk.Get("/zookeeper/config") if err != nil { @@ -308,7 +308,7 @@ func TestReconfig(t *testing.T) { } _, err = zk.Reconfig(s, -1) - requireNoError(t, err, "failed to reconfig cluster") + requireNoErrorf(t, err, "failed to reconfig cluster") // reconfig to all the hosts again s = []string{} @@ -317,7 +317,7 @@ func TestReconfig(t *testing.T) { } _, err = zk.Reconfig(s, -1) - requireNoError(t, err, "failed to reconfig cluster") + requireNoErrorf(t, err, "failed to reconfig cluster") } func TestOpsAfterCloseDontDeadlock(t *testing.T) { @@ -440,6 +440,121 @@ func TestIfAuthdataSurvivesReconnect(t *testing.T) { } } +func TestPersistentWatchOnReconnect(t *testing.T) { + RequireMinimumZkVersion(t, "3.6") + WithTestCluster(t, 10*time.Second, func(ts *TestCluster, zk *Conn) { + zk.reconnectLatch = make(chan struct{}) + + zk2, _, err := ts.ConnectAll() + if err != nil { + t.Fatalf("Connect returned error: %+v", err) + } + defer zk2.Close() + + const testNode = "/gozk-test" + + if err := zk.Delete(testNode, -1); err != nil && err != ErrNoNode { + t.Fatalf("Delete returned error: %+v", err) + } + + watchEventsCh, err := zk.AddPersistentWatch(testNode, AddWatchModePersistent) + if err != nil { + t.Fatalf("AddPersistentWatch returned error: %+v", err) + } + + _, err = zk2.Create(testNode, []byte{1}, 0, WorldACL(PermAll)) + if err != nil { + t.Fatalf("Create returned an error: %+v", err) + } + + // check to see that we received the node creation event + select { + case ev := <-watchEventsCh: + if ev.Type != EventNodeCreated { + t.Fatalf("Second event on persistent watch was not a node creation event: %+v", ev) + } + case <-time.After(100 * time.Millisecond): + t.Fatalf("Persistent watcher for %q did not receive node creation event", testNode) + } + + // Simulate network error by brutally closing the network connection. 
+ zk.conn.Close() + + _, err = zk2.Set(testNode, []byte{2}, -1) + if err != nil { + t.Fatalf("Set returned error: %+v", err) + } + + // zk should still be waiting to reconnect, so none of the watches should have been triggered + select { + case <-watchEventsCh: + t.Fatalf("Persistent watcher for %q should not have triggered yet", testNode) + case <-time.After(100 * time.Millisecond): + } + + // now we let the reconnect occur and make sure it resets watches + close(zk.reconnectLatch) + + // wait for reconnect event + select { + case ev := <-watchEventsCh: + if ev.Type != EventWatching { + t.Fatalf("Persistent watcher did not receive reconnect event: %+v", ev) + } + case <-time.After(5 * time.Second): + t.Fatalf("Persistent watcher for %q did not receive connection event", testNode) + } + + eventsReceived := 0 + timeout := time.After(2 * time.Second) + secondTimeout := time.After(4 * time.Second) + for { + select { + case e := <-watchEventsCh: + if e.Type != EventNodeDataChanged { + t.Fatalf("Unexpected event received by persistent watcher: %+v", e) + } + eventsReceived++ + case <-timeout: + _, err = zk2.Set(testNode, []byte{3}, -1) + if err != nil { + t.Fatalf("Set returned error: %+v", err) + } + case <-secondTimeout: + switch eventsReceived { + case 2: + t.Fatalf("Sanity check failed: the persistent watch logic is based around the assumption that the " + + "setWatchers call _does not_ bootstrap you on reconnect based on the relative Zxid (unlike " + + "standard watches).") + case 1: + return + default: + t.Fatalf("Received no events after reconnect") + } + } + } + }) +} + +func TestPersistentWatchOnClose(t *testing.T) { + RequireMinimumZkVersion(t, "3.6") + WithTestCluster(t, 10*time.Second, func(_ *TestCluster, zk *Conn) { + ch, err := zk.AddPersistentWatch("/", AddWatchModePersistent) + if err != nil { + t.Fatalf("Could not add persistent watch: %+v", err) + } + zk.Close() + select { + case e := <-ch: + if e.Type != EventNotWatching { + t.Fatalf("Unexpected event: %+v", e) + } + case <-time.After(2 * time.Second): + t.Fatalf("Did not get disconnect event") + } + }) +} + func TestMultiFailures(t *testing.T) { // This test case ensures that we return the errors associated with each // opeThis in the event a call to Multi() fails.
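To round out the picture, a hedged usage sketch of TreeCache, which mirrors NodeCache for a whole subtree via a persistent recursive watch (not part of the patch; the connection details, /app path and MaxDepth value are illustrative assumptions):

package main

import (
	"log"
	"time"

	"github.com/go-zookeeper/zk" // assumed import path for this module
)

func main() {
	conn, _, err := zk.Connect([]string{"127.0.0.1:2181"}, time.Second) // address is illustrative
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	// Bootstraps the subtree under /app synchronously, then keeps it current through
	// a persistent recursive watch. MaxDepth: 3 caches /app plus two levels below it.
	tc, err := zk.NewTreeCacheWithOpts(conn, "/app", zk.TreeCacheOpts{MaxDepth: 3})
	if err != nil {
		log.Fatal(err)
	}
	defer tc.Stop()

	for path, node := range tc.Tree("/app") {
		if node.Err != nil {
			log.Printf("%s: %v", path, node.Err)
			continue
		}
		log.Printf("%s: %q", path, node.Data)
	}
}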