diff --git a/Makefile b/Makefile index f0b7965c..20a40f7a 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ # make file to hold the logic of build and test setup -ZK_VERSION ?= 3.5.6 +ZK_VERSION ?= 3.6.3 # Apache changed the name of the archive in version 3.5.x and seperated out # src and binary packages @@ -20,10 +20,12 @@ $(ZK): tar -zxf $(ZK).tar.gz rm $(ZK).tar.gz +.PHONY: zookeeper zookeeper: $(ZK) # we link to a standard directory path so then the tests dont need to find based on version # in the test code. this allows backward compatable testing. - ln -s $(ZK) zookeeper + rm -f $@ + ln -s $(ZK) $@ .PHONY: setup setup: zookeeper diff --git a/conn.go b/conn.go index 9afd2d27..0550c452 100644 --- a/conn.go +++ b/conn.go @@ -47,8 +47,14 @@ const ( watchTypeData watchType = iota watchTypeExist watchTypeChild + watchTypePersistent + watchTypePersistentRecursive ) +func (w watchType) isPersistent() bool { + return w == watchTypePersistent || w == watchTypePersistentRecursive +} + type watchPathType struct { path string wType watchType @@ -89,6 +95,7 @@ type Conn struct { recvTimeout time.Duration connectTimeout time.Duration maxBufferSize int + metricReceiver MetricReceiver creds []authCreds credsMu sync.Mutex // protects server @@ -96,7 +103,7 @@ type Conn struct { sendChan chan *request requests map[int32]*request // Xid -> pending request requestsLock sync.Mutex - watchers map[watchPathType][]chan Event + watchers map[watchPathType][]EventQueue watchersLock sync.Mutex closeChan chan struct{} // channel to tell send loop stop @@ -199,12 +206,13 @@ func Connect(servers []string, sessionTimeout time.Duration, options ...connOpti connectTimeout: 1 * time.Second, sendChan: make(chan *request, sendChanSize), requests: make(map[int32]*request), - watchers: make(map[watchPathType][]chan Event), + watchers: make(map[watchPathType][]EventQueue), passwd: emptyPassword, logger: DefaultLogger, logInfo: true, // default is true for backwards compatability buf: make([]byte, bufferSize), resendZkAuthFn: resendZkAuth, + metricReceiver: UnimplementedMetricReceiver{}, } // Set provided options. @@ -310,6 +318,12 @@ func WithMaxConnBufferSize(maxBufferSize int) connOption { } } +func WithMetricReceiver(mr MetricReceiver) connOption { + return func(c *Conn) { + c.metricReceiver = mr + } +} + // Close will submit a close request with ZK and signal the connection to stop // sending and receiving packets. func (c *Conn) Close() { @@ -530,29 +544,44 @@ func (c *Conn) flushRequests(err error) { c.requestsLock.Unlock() } +var eventWatchTypes = map[EventType][]watchType{ + EventNodeCreated: {watchTypeExist, watchTypePersistent, watchTypePersistentRecursive}, + EventNodeDataChanged: {watchTypeExist, watchTypeData, watchTypePersistent, watchTypePersistentRecursive}, + EventNodeChildrenChanged: {watchTypeChild, watchTypePersistent}, + EventNodeDeleted: {watchTypeExist, watchTypeData, watchTypeChild, watchTypePersistent, watchTypePersistentRecursive}, +} +var persistentWatchTypes = []watchType{watchTypePersistent, watchTypePersistentRecursive} + // Send event to all interested watchers func (c *Conn) notifyWatches(ev Event) { - var wTypes []watchType - switch ev.Type { - case EventNodeCreated: - wTypes = []watchType{watchTypeExist} - case EventNodeDataChanged: - wTypes = []watchType{watchTypeExist, watchTypeData} - case EventNodeChildrenChanged: - wTypes = []watchType{watchTypeChild} - case EventNodeDeleted: - wTypes = []watchType{watchTypeExist, watchTypeData, watchTypeChild} + wTypes := eventWatchTypes[ev.Type] + if len(wTypes) == 0 { + return } + c.watchersLock.Lock() defer c.watchersLock.Unlock() + + broadcast := func(wpt watchPathType) { + for _, ch := range c.watchers[wpt] { + ch.push(ev) + if !wpt.wType.isPersistent() { + ch.close() + delete(c.watchers, wpt) + } + } + } + for _, t := range wTypes { - wpt := watchPathType{ev.Path, t} - if watchers := c.watchers[wpt]; len(watchers) > 0 { - for _, ch := range watchers { - ch <- ev - close(ch) + if t == watchTypePersistentRecursive { + for p := ev.Path; ; p, _ = SplitPath(p) { + broadcast(watchPathType{p, t}) + if p == "/" { + break + } } - delete(c.watchers, wpt) + } else { + broadcast(watchPathType{ev.Path, t}) } } } @@ -562,16 +591,23 @@ func (c *Conn) invalidateWatches(err error) { c.watchersLock.Lock() defer c.watchersLock.Unlock() - if len(c.watchers) >= 0 { + if len(c.watchers) > 0 { for pathType, watchers := range c.watchers { + if err == ErrSessionExpired && pathType.wType.isPersistent() { + // Ignore ErrSessionExpired for persistent watchers as the client will either automatically reconnect, + // or this is a shutdown-worthy error in which case there will be a followup invocation of this method + // with ErrClosing + continue + } + ev := Event{Type: EventNotWatching, State: StateDisconnected, Path: pathType.path, Err: err} c.sendEvent(ev) // also publish globally for _, ch := range watchers { - ch <- ev - close(ch) + ch.push(ev) + ch.close() } + delete(c.watchers, pathType) } - c.watchers = make(map[watchPathType][]chan Event) } } @@ -610,12 +646,7 @@ func (c *Conn) sendSetWatches() { reqs = append(reqs, req) } sizeSoFar = 28 // fixed overhead of a set-watches packet - req = &setWatchesRequest{ - RelativeZxid: c.lastZxid, - DataWatches: make([]string, 0), - ExistWatches: make([]string, 0), - ChildWatches: make([]string, 0), - } + req = &setWatchesRequest{RelativeZxid: c.lastZxid} } sizeSoFar += addlLen switch pathType.wType { @@ -625,6 +656,10 @@ func (c *Conn) sendSetWatches() { req.ExistWatches = append(req.ExistWatches, pathType.path) case watchTypeChild: req.ChildWatches = append(req.ChildWatches, pathType.path) + case watchTypePersistent: + req.PersistentWatches = append(req.PersistentWatches, pathType.path) + case watchTypePersistentRecursive: + req.PersistentRecursiveWatches = append(req.PersistentRecursiveWatches, pathType.path) } n++ } @@ -646,7 +681,37 @@ func (c *Conn) sendSetWatches() { // aren't failure modes where a blocking write to the channel of requests // could hang indefinitely and cause this goroutine to leak... for _, req := range reqs { - _, err := c.request(opSetWatches, req, res, nil) + var op int32 = opSetWatches + if len(req.PersistentWatches) > 0 || len(req.PersistentRecursiveWatches) > 0 { + // to maintain compatibility with older servers, only send opSetWatches2 if persistent watches are used + op = opSetWatches2 + } + + _, err := c.request(op, req, res, func(r *request, header *responseHeader, err error) { + if err == nil && op == opSetWatches2 { + // If the setWatches was successful, notify the persistent watchers they've been reconnected. + // Because we process responses in one routine, we know that the following will execute before + // subsequent responses are processed. This means we won't end up in a situation where events are + // sent to watchers before the reconnect event is sent. + c.watchersLock.Lock() + defer c.watchersLock.Unlock() + for _, wt := range persistentWatchTypes { + var paths []string + if wt == watchTypePersistent { + paths = req.PersistentWatches + } else { + paths = req.PersistentRecursiveWatches + } + for _, p := range paths { + e := Event{Type: EventWatching, State: StateConnected, Path: p} + c.sendEvent(e) // also publish globally + for _, ch := range c.watchers[watchPathType{path: p, wType: wt}] { + ch.push(e) + } + } + } + } + }) if err != nil { c.logger.Printf("Failed to set previous watches: %v", err) break @@ -784,6 +849,7 @@ func (c *Conn) sendLoop() error { c.conn.Close() return err } + c.metricReceiver.PingSent() case <-c.closeChan: return nil } @@ -827,21 +893,22 @@ func (c *Conn) recvLoop(conn net.Conn) error { } if res.Xid == -1 { - res := &watcherEvent{} - _, err = decodePacket(buf[16:blen], res) + we := &watcherEvent{} + _, err = decodePacket(buf[16:blen], we) if err != nil { return err } ev := Event{ - Type: res.Type, - State: res.State, - Path: res.Path, + Type: we.Type, + State: we.State, + Path: we.Path, Err: nil, } c.sendEvent(ev) c.notifyWatches(ev) } else if res.Xid == -2 { // Ping response. Ignore. + c.metricReceiver.PongReceived() } else if res.Xid < 0 { c.logger.Printf("Xid < 0 (%d) but not ping or watcher event", res.Xid) } else { @@ -880,14 +947,15 @@ func (c *Conn) nextXid() int32 { return int32(atomic.AddUint32(&c.xid, 1) & 0x7fffffff) } -func (c *Conn) addWatcher(path string, watchType watchType) <-chan Event { +func (c *Conn) addWatcher(path string, watchType watchType, ch EventQueue) { c.watchersLock.Lock() defer c.watchersLock.Unlock() - ch := make(chan Event, 1) wpt := watchPathType{path, watchType} c.watchers[wpt] = append(c.watchers[wpt], ch) - return ch + if watchType.isPersistent() { + ch.push(Event{Type: EventWatching, State: StateConnected, Path: path}) + } } func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) <-chan response { @@ -928,7 +996,12 @@ func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recv return rq.recvChan } -func (c *Conn) request(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) (int64, error) { +func (c *Conn) request(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) (_ int64, err error) { + start := time.Now() + defer func() { + c.metricReceiver.RequestCompleted(time.Now().Sub(start), err) + }() + recv := c.queueRequest(opcode, req, res, recvFunc) select { case r := <-recv: @@ -981,7 +1054,7 @@ func (c *Conn) Children(path string) ([]string, *Stat, error) { if err == ErrConnectionClosed { return nil, nil, err } - return res.Children, &res.Stat, err + return res.Children, res.Stat, err } // ChildrenW returns the children of a znode and sets a watch. @@ -990,17 +1063,18 @@ func (c *Conn) ChildrenW(path string) ([]string, *Stat, <-chan Event, error) { return nil, nil, nil, err } - var ech <-chan Event + var ech chanEventQueue res := &getChildren2Response{} _, err := c.request(opGetChildren2, &getChildren2Request{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) { if err == nil { - ech = c.addWatcher(path, watchTypeChild) + ech = newChanEventChannel() + c.addWatcher(path, watchTypeChild, ech) } }) if err != nil { return nil, nil, nil, err } - return res.Children, &res.Stat, ech, err + return res.Children, res.Stat, ech, err } // Get gets the contents of a znode. @@ -1014,7 +1088,7 @@ func (c *Conn) Get(path string) ([]byte, *Stat, error) { if err == ErrConnectionClosed { return nil, nil, err } - return res.Data, &res.Stat, err + return res.Data, res.Stat, err } // GetW returns the contents of a znode and sets a watch @@ -1023,17 +1097,18 @@ func (c *Conn) GetW(path string) ([]byte, *Stat, <-chan Event, error) { return nil, nil, nil, err } - var ech <-chan Event + var ech chanEventQueue res := &getDataResponse{} _, err := c.request(opGetData, &getDataRequest{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) { if err == nil { - ech = c.addWatcher(path, watchTypeData) + ech = newChanEventChannel() + c.addWatcher(path, watchTypeData, ech) } }) if err != nil { return nil, nil, nil, err } - return res.Data, &res.Stat, ech, err + return res.Data, res.Stat, ech, err } // Set updates the contents of a znode. @@ -1047,10 +1122,10 @@ func (c *Conn) Set(path string, data []byte, version int32) (*Stat, error) { if err == ErrConnectionClosed { return nil, err } - return &res.Stat, err + return res.Stat, err } -// Create creates a znode. +// Create creates a znode. If acl is empty, it uses the global WorldACL with PermAll // The returned path is the new path assigned by the server, it may not be the // same as the input, for example when creating a sequence znode the returned path // will be the input path with a sequence number appended. @@ -1059,6 +1134,10 @@ func (c *Conn) Create(path string, data []byte, flags int32, acl []ACL) (string, return "", err } + if len(acl) == 0 { + acl = WorldACL(PermAll) + } + res := &createResponse{} _, err := c.request(opCreate, &CreateRequest{path, data, acl, flags}, res, nil) if err == ErrConnectionClosed { @@ -1067,6 +1146,24 @@ func (c *Conn) Create(path string, data []byte, flags int32, acl []ACL) (string, return res.Path, err } +// CreateAndReturnStat is the equivalent of Create, but it also returns the Stat of the created node. +func (c *Conn) CreateAndReturnStat(path string, data []byte, flags int32, acl []ACL) (string, *Stat, error) { + if err := validatePath(path, flags&FlagSequence == FlagSequence); err != nil { + return "", nil, err + } + + if len(acl) == 0 { + acl = WorldACL(PermAll) + } + + res := &create2Response{} + _, err := c.request(opCreate2, &CreateRequest{path, data, acl, flags}, res, nil) + if err == ErrConnectionClosed { + return "", nil, err + } + return res.Path, res.Stat, err +} + // CreateContainer creates a container znode and returns the path. func (c *Conn) CreateContainer(path string, data []byte, flags int32, acl []ACL) (string, error) { if err := validatePath(path, flags&FlagSequence == FlagSequence); err != nil { @@ -1170,7 +1267,7 @@ func (c *Conn) Exists(path string) (bool, *Stat, error) { exists = false err = nil } - return exists, &res.Stat, err + return exists, res.Stat, err } // ExistsW tells the existence of a znode and sets a watch. @@ -1179,13 +1276,14 @@ func (c *Conn) ExistsW(path string) (bool, *Stat, <-chan Event, error) { return false, nil, nil, err } - var ech <-chan Event + var ech chanEventQueue res := &existsResponse{} _, err := c.request(opExists, &existsRequest{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) { + ech = newChanEventChannel() if err == nil { - ech = c.addWatcher(path, watchTypeData) + c.addWatcher(path, watchTypeData, ech) } else if err == ErrNoNode { - ech = c.addWatcher(path, watchTypeExist) + c.addWatcher(path, watchTypeExist, ech) } }) exists := true @@ -1196,7 +1294,7 @@ func (c *Conn) ExistsW(path string) (bool, *Stat, <-chan Event, error) { if err != nil { return false, nil, nil, err } - return exists, &res.Stat, ech, err + return exists, res.Stat, ech, err } // GetACL gets the ACLs of a znode. @@ -1210,7 +1308,7 @@ func (c *Conn) GetACL(path string) ([]ACL, *Stat, error) { if err == ErrConnectionClosed { return nil, nil, err } - return res.Acl, &res.Stat, err + return res.Acl, res.Stat, err } // SetACL updates the ACLs of a znode. @@ -1224,7 +1322,7 @@ func (c *Conn) SetACL(path string, acl []ACL, version int32) (*Stat, error) { if err == ErrConnectionClosed { return nil, err } - return &res.Stat, err + return res.Stat, err } // Sync flushes the channel between process and the leader of a given znode, @@ -1255,8 +1353,7 @@ type MultiResponse struct { // *CheckVersionRequest. func (c *Conn) Multi(ops ...interface{}) ([]MultiResponse, error) { req := &multiRequest{ - Ops: make([]multiRequestOp, 0, len(ops)), - DoneHeader: multiHeader{Type: -1, Done: true, Err: -1}, + Ops: make([]multiRequestOp, 0, len(ops)), } for _, op := range ops { var opCode int32 @@ -1286,6 +1383,44 @@ func (c *Conn) Multi(ops ...interface{}) ([]MultiResponse, error) { return mr, err } +// MultiRead executes multiple ZooKeeper read operations at once. The provided ops must be one of GetDataOp or +// GetChildrenOp. Returns an error on network or connectivity errors, not on any op errors such as ErrNoNode. To check +// if any ops failed, check the corresponding MultiReadResponse.Err. +func (c *Conn) MultiRead(ops ...ReadOp) ([]MultiReadResponse, error) { + req := &multiRequest{ + Ops: make([]multiRequestOp, len(ops)), + } + for i, op := range ops { + req.Ops[i] = multiRequestOp{ + Header: multiHeader{op.opCode(), false, -1}, + Op: pathWatchRequest{Path: op.GetPath()}, + } + } + res := &multiReadResponse{} + _, err := c.request(opMultiRead, req, res, nil) + return res.OpResults, err +} + +// GetDataAndChildren executes a multi-read to get the given node's data and its children in one call. +func (c *Conn) GetDataAndChildren(path string) ([]byte, *Stat, []string, error) { + if err := validatePath(path, false); err != nil { + return nil, nil, nil, err + } + + opResults, err := c.MultiRead(GetDataOp(path), GetChildrenOp(path)) + if err != nil { + return nil, nil, nil, err + } + + for _, r := range opResults { + if r.Err != nil { + return nil, nil, nil, r.Err + } + } + + return opResults[0].Data, opResults[0].Stat, opResults[1].Children, nil +} + // IncrementalReconfig is the zookeeper reconfiguration api that allows adding and removing servers // by lists of members. For more info refer to the ZK documentation. // @@ -1321,7 +1456,7 @@ func (c *Conn) Reconfig(members []string, version int64) (*Stat, error) { func (c *Conn) internalReconfig(request *reconfigRequest) (*Stat, error) { response := &reconfigReponse{} _, err := c.request(opReconfig, request, response, nil) - return &response.Stat, err + return response.Stat, err } // Server returns the current or last-connected server name. @@ -1331,6 +1466,97 @@ func (c *Conn) Server() string { return c.server } +func (c *Conn) AddPersistentWatch(path string, mode AddWatchMode) (ch EventQueue, err error) { + if err = validatePath(path, false); err != nil { + return nil, err + } + + res := &addWatchResponse{} + _, err = c.request(opAddWatch, &addWatchRequest{Path: path, Mode: mode}, res, func(r *request, header *responseHeader, err error) { + if err == nil { + var wt watchType + if mode == AddWatchModePersistent { + wt = watchTypePersistent + } else { + wt = watchTypePersistentRecursive + } + + ch = newUnlimitedEventQueue() + c.addWatcher(path, wt, ch) + } + }) + if err == ErrConnectionClosed { + return nil, err + } + return ch, err +} + +func (c *Conn) RemovePersistentWatch(path string, ch EventQueue) (err error) { + if err = validatePath(path, false); err != nil { + return err + } + + deleted := false + + res := &checkWatchesResponse{} + _, err = c.request(opCheckWatches, &checkWatchesRequest{Path: path, Type: WatcherTypeAny}, res, func(r *request, header *responseHeader, err error) { + if err != nil { + return + } + + c.watchersLock.Lock() + defer c.watchersLock.Unlock() + + for _, wt := range persistentWatchTypes { + wpt := watchPathType{path: path, wType: wt} + for i, w := range c.watchers[wpt] { + if w == ch { + deleted = true + c.watchers[wpt] = append(c.watchers[wpt][:i], c.watchers[wpt][i+1:]...) + w.push(Event{Type: EventNotWatching, State: c.State(), Path: path, Err: ErrNoWatcher}) + w.close() + return + } + } + } + }) + + if err != nil { + return err + } + + if !deleted { + return ErrNoWatcher + } + + return nil +} + +func (c *Conn) RemoveAllPersistentWatches(path string) (err error) { + if err = validatePath(path, false); err != nil { + return err + } + + res := &checkWatchesResponse{} + _, err = c.request(opRemoveWatches, &checkWatchesRequest{Path: path, Type: WatcherTypeAny}, res, func(r *request, header *responseHeader, err error) { + if err != nil { + return + } + + c.watchersLock.Lock() + defer c.watchersLock.Unlock() + for _, wt := range persistentWatchTypes { + wpt := watchPathType{path: path, wType: wt} + for _, ch := range c.watchers[wpt] { + ch.push(Event{Type: EventNotWatching, State: c.State(), Path: path, Err: ErrNoWatcher}) + ch.close() + } + delete(c.watchers, wpt) + } + }) + return err +} + func resendZkAuth(ctx context.Context, c *Conn) error { shouldCancel := func() bool { select { @@ -1389,3 +1615,23 @@ func resendZkAuth(ctx context.Context, c *Conn) error { return nil } + +func JoinPath(parent, child string) string { + if !strings.HasSuffix(parent, "/") { + parent += "/" + } + if strings.HasPrefix(child, "/") { + child = child[1:] + } + return parent + child +} + +func SplitPath(path string) (dir, name string) { + i := strings.LastIndex(path, "/") + if i == 0 { + dir, name = "/", path[1:] + } else { + dir, name = path[:i], path[i+1:] + } + return dir, name +} diff --git a/conn_test.go b/conn_test.go index 96299280..a6cc998e 100644 --- a/conn_test.go +++ b/conn_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io/ioutil" + "strings" "sync" "testing" "time" @@ -113,85 +114,144 @@ func TestDeadlockInClose(t *testing.T) { } func TestNotifyWatches(t *testing.T) { + queueImpls := []struct { + name string + new func() EventQueue + }{ + { + name: "chan", + new: func() EventQueue { return newChanEventChannel() }, + }, + { + name: "unlimited", + new: func() EventQueue { return newUnlimitedEventQueue() }, + }, + } + cases := []struct { eType EventType path string watches map[watchPathType]bool }{ { - EventNodeCreated, "/", - map[watchPathType]bool{ - {"/", watchTypeExist}: true, - {"/", watchTypeChild}: false, - {"/", watchTypeData}: false, - }, - }, - { - EventNodeCreated, "/a", - map[watchPathType]bool{ + eType: EventNodeCreated, + path: "/a", + watches: map[watchPathType]bool{ + {"/a", watchTypeExist}: true, {"/b", watchTypeExist}: false, + + {"/a", watchTypeChild}: false, + + {"/a", watchTypeData}: false, + + {"/a", watchTypePersistent}: true, + {"/", watchTypePersistent}: false, + + {"/a", watchTypePersistentRecursive}: true, + {"/", watchTypePersistentRecursive}: true, }, }, { - EventNodeDataChanged, "/", - map[watchPathType]bool{ - {"/", watchTypeExist}: true, - {"/", watchTypeData}: true, - {"/", watchTypeChild}: false, + eType: EventNodeDataChanged, + path: "/a", + watches: map[watchPathType]bool{ + {"/a", watchTypeExist}: true, + {"/a", watchTypeData}: true, + {"/a", watchTypeChild}: false, + + {"/a", watchTypePersistent}: true, + {"/", watchTypePersistent}: false, + + {"/a", watchTypePersistentRecursive}: true, + {"/", watchTypePersistentRecursive}: true, }, }, { - EventNodeChildrenChanged, "/", - map[watchPathType]bool{ - {"/", watchTypeExist}: false, - {"/", watchTypeData}: false, - {"/", watchTypeChild}: true, + eType: EventNodeChildrenChanged, + path: "/a", + watches: map[watchPathType]bool{ + {"/a", watchTypeExist}: false, + {"/a", watchTypeData}: false, + {"/a", watchTypeChild}: true, + {"/a", watchTypePersistent}: true, + {"/a", watchTypePersistentRecursive}: false, + + {"/a", watchTypePersistent}: true, + {"/", watchTypePersistent}: false, + + {"/a", watchTypePersistentRecursive}: false, + {"/", watchTypePersistentRecursive}: false, }, }, { - EventNodeDeleted, "/", - map[watchPathType]bool{ - {"/", watchTypeExist}: true, - {"/", watchTypeData}: true, - {"/", watchTypeChild}: true, + eType: EventNodeDeleted, + path: "/a", + watches: map[watchPathType]bool{ + {"/a", watchTypeExist}: true, + {"/a", watchTypeData}: true, + {"/a", watchTypeChild}: true, + + {"/a", watchTypePersistent}: true, + {"/", watchTypePersistent}: false, + + {"/a", watchTypePersistentRecursive}: true, + {"/", watchTypePersistentRecursive}: true, }, }, } - conn := &Conn{watchers: make(map[watchPathType][]chan Event)} - - for idx, c := range cases { - t.Run(fmt.Sprintf("#%d %s", idx, c.eType), func(t *testing.T) { - c := c - - notifications := make([]struct { - path string - notify bool - ch <-chan Event - }, len(c.watches)) - - var idx int - for wpt, expectEvent := range c.watches { - ch := conn.addWatcher(wpt.path, wpt.wType) - notifications[idx].path = wpt.path - notifications[idx].notify = expectEvent - notifications[idx].ch = ch - idx++ - } - ev := Event{Type: c.eType, Path: c.path} - conn.notifyWatches(ev) - - for _, res := range notifications { - select { - case e := <-res.ch: - if !res.notify || e.Path != res.path { - t.Fatal("unexpeted notification received") + for _, impl := range queueImpls { + t.Run(impl.name, func(t *testing.T) { + for idx, c := range cases { + c := c + t.Run(fmt.Sprintf("#%d %s", idx, c.eType), func(t *testing.T) { + notifications := make([]struct { + watchPathType + notify bool + ch EventQueue + }, len(c.watches)) + + conn := &Conn{watchers: make(map[watchPathType][]EventQueue)} + + var idx int + for wpt, expectEvent := range c.watches { + notifications[idx].watchPathType = wpt + notifications[idx].notify = expectEvent + ch := impl.new() + conn.addWatcher(wpt.path, wpt.wType, ch) + notifications[idx].ch = ch + if wpt.wType.isPersistent() { + e, _ := ch.Next(context.Background()) + if e.Type != EventWatching { + t.Fatalf("First event on persistent watcher should always be EventWatching") + } + } + idx++ } - default: - if res.notify { - t.Fatal("expected notification not received") + + conn.notifyWatches(Event{Type: c.eType, Path: c.path}) + + for _, res := range notifications { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + t.Cleanup(cancel) + + e, err := res.ch.Next(ctx) + if err == nil { + isPathCorrect := + (res.wType == watchTypePersistentRecursive && strings.HasPrefix(e.Path, res.path)) || + e.Path == res.path + if !res.notify || !isPathCorrect { + t.Logf("unexpeted notification received by %+v: %+v", res, e) + t.Fail() + } + } else { + if res.notify { + t.Logf("expected notification not received for %+v", res) + t.Fail() + } + } } - } + }) } }) } diff --git a/constants.go b/constants.go index 84455d2b..0add0b7e 100644 --- a/constants.go +++ b/constants.go @@ -26,12 +26,18 @@ const ( opGetChildren2 = 12 opCheck = 13 opMulti = 14 + opCreate2 = 15 opReconfig = 16 + opCheckWatches = 17 + opRemoveWatches = 18 opCreateContainer = 19 opCreateTTL = 21 + opMultiRead = 22 opClose = -11 opSetAuth = 100 opSetWatches = 101 + opSetWatches2 = 105 + opAddWatch = 106 opError = -1 // Not in protocol, used internally opWatcherEvent = -2 @@ -47,6 +53,7 @@ const ( // EventSession represents a session event. EventSession EventType = -1 EventNotWatching EventType = -2 + EventWatching EventType = -3 ) var ( @@ -57,6 +64,7 @@ var ( EventNodeChildrenChanged: "EventNodeChildrenChanged", EventSession: "EventSession", EventNotWatching: "EventNotWatching", + EventWatching: "EventWatching", } ) @@ -129,6 +137,8 @@ var ( ErrSessionMoved = errors.New("zk: session moved to another server, so operation is ignored") ErrReconfigDisabled = errors.New("attempts to perform a reconfiguration operation when reconfiguration feature is disabled") ErrBadArguments = errors.New("invalid arguments") + ErrNoWatcher = errors.New("zk: no such watcher") + ErrUnimplemented = errors.New("zk: Not implemented") // ErrInvalidCallback = errors.New("zk: invalid callback specified") errCodeToError = map[ErrCode]error{ @@ -147,8 +157,10 @@ var ( errClosing: ErrClosing, errNothing: ErrNothing, errSessionMoved: ErrSessionMoved, + errNoWatcher: ErrNoWatcher, errZReconfigDisabled: ErrReconfigDisabled, errBadArguments: ErrBadArguments, + errUnimplemented: ErrUnimplemented, } ) @@ -186,6 +198,7 @@ const ( errClosing ErrCode = -116 errNothing ErrCode = -117 errSessionMoved ErrCode = -118 + errNoWatcher ErrCode = -121 // Attempts to perform a reconfiguration operation when reconfiguration feature is disabled errZReconfigDisabled ErrCode = -123 ) @@ -224,6 +237,7 @@ var ( opClose: "close", opSetAuth: "setAuth", opSetWatches: "setWatches", + opAddWatch: "addWatch", opWatcherEvent: "watcherEvent", } @@ -263,3 +277,33 @@ var ( ModeStandalone: "standalone", } ) + +// AddWatchMode asd +type AddWatchMode int32 + +func (m AddWatchMode) String() string { + if name, ok := addWatchModeNames[m]; ok { + return name + } + return "unknown" +} + +const ( + AddWatchModePersistent AddWatchMode = iota + AddWatchModePersistentRecursive AddWatchMode = iota +) + +var ( + addWatchModeNames = map[AddWatchMode]string{ + AddWatchModePersistent: "persistent", + AddWatchModePersistentRecursive: "persistentRecursive", + } +) + +type WatcherType int32 + +const ( + WatcherTypeChildren = WatcherType(1) + WatcherTypeData = WatcherType(2) + WatcherTypeAny = WatcherType(3) +) diff --git a/go.mod b/go.mod index a2662730..f47e0825 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ -module github.com/go-zookeeper/zk +module github.com/PapaCharlie/zk -go 1.13 +go 1.18 diff --git a/metrics.go b/metrics.go new file mode 100644 index 00000000..9294d3ed --- /dev/null +++ b/metrics.go @@ -0,0 +1,20 @@ +package zk + +import ( + "time" +) + +type MetricReceiver interface { + PingSent() + PongReceived() + RequestCompleted(duration time.Duration, err error) +} + +var _ MetricReceiver = UnimplementedMetricReceiver{} + +type UnimplementedMetricReceiver struct { +} + +func (u UnimplementedMetricReceiver) PingSent() {} +func (u UnimplementedMetricReceiver) PongReceived() {} +func (u UnimplementedMetricReceiver) RequestCompleted(time.Duration, error) {} diff --git a/server_help_test.go b/server_help_test.go index 6a49ad2b..4ea1e7d0 100644 --- a/server_help_test.go +++ b/server_help_test.go @@ -7,6 +7,8 @@ import ( "math/rand" "os" "path/filepath" + "runtime/debug" + "strconv" "strings" "testing" "time" @@ -34,6 +36,38 @@ type TestCluster struct { Servers []TestServer } +func WithTestCluster(t *testing.T, testTimeout time.Duration, f func(ts *TestCluster, zk *Conn)) { + ts, err := StartTestCluster(t, 1, nil, logWriter{t: t, p: "[ZKERR] "}) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + ts.Stop() + }) + zk, _, err := ts.ConnectAll() + if err != nil { + t.Fatalf("Connect returned error: %+v", err) + } + t.Cleanup(func() { + zk.Close() + }) + doneChan := make(chan struct{}) + go func() { + defer func() { + close(doneChan) + if r := recover(); r != nil { + t.Error(r, string(debug.Stack())) + } + }() + f(ts, zk) + }() + select { + case <-doneChan: + case <-time.After(testTimeout): + t.Fatalf("Test did not complete within timeout") + } +} + // TODO: pull this into its own package to allow for better isolation of integration tests vs. unit // testing. This should be used on CI systems and local only when needed whereas unit tests should remain // fast and not rely on external dependencies. @@ -53,7 +87,7 @@ func StartTestCluster(t *testing.T, size int, stdout, stderr io.Writer) (*TestCl } tmpPath, err := ioutil.TempDir("", "gozk") - requireNoError(t, err, "failed to create tmp dir for test server setup") + requireNoErrorf(t, err, "failed to create tmp dir for test server setup") success := false startPort := int(rand.Int31n(6000) + 10000) @@ -67,7 +101,7 @@ func StartTestCluster(t *testing.T, size int, stdout, stderr io.Writer) (*TestCl for serverN := 0; serverN < size; serverN++ { srvPath := filepath.Join(tmpPath, fmt.Sprintf("srv%d", serverN+1)) - requireNoError(t, os.Mkdir(srvPath, 0700), "failed to make server path") + requireNoErrorf(t, os.Mkdir(srvPath, 0700), "failed to make server path") port := startPort + serverN*3 cfg := ServerConfig{ @@ -88,20 +122,20 @@ func StartTestCluster(t *testing.T, size int, stdout, stderr io.Writer) (*TestCl cfgPath := filepath.Join(srvPath, _testConfigName) fi, err := os.Create(cfgPath) - requireNoError(t, err) + requireNoErrorf(t, err) - requireNoError(t, cfg.Marshall(fi)) + requireNoErrorf(t, cfg.Marshall(fi)) fi.Close() fi, err = os.Create(filepath.Join(srvPath, _testMyIDFileName)) - requireNoError(t, err) + requireNoErrorf(t, err) _, err = fmt.Fprintf(fi, "%d\n", serverN+1) fi.Close() - requireNoError(t, err) + requireNoErrorf(t, err) srv, err := NewIntegrationTestServer(t, cfgPath, stdout, stderr) - requireNoError(t, err) + requireNoErrorf(t, err) if err := srv.Start(); err != nil { return nil, err @@ -251,9 +285,36 @@ func (tc *TestCluster) StopAllServers() error { return nil } -func requireNoError(t *testing.T, err error, msgAndArgs ...interface{}) { +func requireNoErrorf(t *testing.T, err error, msgAndArgs ...interface{}) { if err != nil { + t.Helper() t.Logf("received unexpected error: %v", err) - t.Fatal(msgAndArgs...) + t.Fatalf(msgAndArgs[0].(string), msgAndArgs[1:]...) + } +} + +func RequireMinimumZkVersion(t *testing.T, minimum string) { + if val, ok := os.LookupEnv("ZK_VERSION"); ok { + split := func(v string) (parts []int) { + for _, s := range strings.Split(minimum, ".") { + i, err := strconv.Atoi(s) + if err != nil { + t.Fatalf("invalid version segment: %q", s) + } + parts = append(parts, i) + } + return parts + } + + minimumV, actualV := split(minimum), split(val) + for i, p := range minimumV { + if actualV[i] < p { + if !strings.HasPrefix(val, minimum) { + t.Skipf("running with zookeeper that does not support this api (requires at least %s)", minimum) + } + } + } + } else { + t.Skip("did not detect zk_version from env. skipping test") } } diff --git a/structs.go b/structs.go index 8eb41e39..66c10463 100644 --- a/structs.go +++ b/structs.go @@ -3,6 +3,7 @@ package zk import ( "encoding/binary" "errors" + "fmt" "log" "reflect" "runtime" @@ -18,8 +19,8 @@ var ( type defaultLogger struct{} -func (defaultLogger) Printf(format string, a ...interface{}) { - log.Printf(format, a...) +func (defaultLogger) Printf(format string, v ...interface{}) { + log.Output(3, fmt.Sprintf(format, v...)) } type ACL struct { @@ -135,7 +136,7 @@ type pathResponse struct { } type statResponse struct { - Stat Stat + Stat *Stat } // @@ -177,6 +178,10 @@ type CreateTTLRequest struct { } type createResponse pathResponse +type create2Response struct { + Path string + Stat *Stat +} type DeleteRequest PathVersionRequest type deleteResponse struct{} @@ -190,7 +195,7 @@ type getAclRequest pathRequest type getAclResponse struct { Acl []ACL - Stat Stat + Stat *Stat } type getChildrenRequest pathRequest @@ -199,18 +204,59 @@ type getChildrenResponse struct { Children []string } +type ReadOp interface { + GetPath() string + IsGetData() bool + IsGetChildren() bool + opCode() int32 +} + type getChildren2Request pathWatchRequest +type GetChildrenOp string + +func (g GetChildrenOp) IsGetData() bool { + return false +} + +func (g GetChildrenOp) IsGetChildren() bool { + return true +} + +func (g GetChildrenOp) GetPath() string { + return string(g) +} + +func (g GetChildrenOp) opCode() int32 { + return opGetChildren +} type getChildren2Response struct { Children []string - Stat Stat + Stat *Stat } type getDataRequest pathWatchRequest +type GetDataOp string + +func (g GetDataOp) IsGetData() bool { + return true +} + +func (g GetDataOp) IsGetChildren() bool { + return false +} + +func (g GetDataOp) GetPath() string { + return string(g) +} + +func (g GetDataOp) opCode() int32 { + return opGetData +} type getDataResponse struct { Data []byte - Stat Stat + Stat *Stat } type getMaxChildrenRequest pathRequest @@ -256,10 +302,12 @@ type setSaslResponse struct { } type setWatchesRequest struct { - RelativeZxid int64 - DataWatches []string - ExistWatches []string - ChildWatches []string + RelativeZxid int64 + DataWatches []string + ExistWatches []string + ChildWatches []string + PersistentWatches []string + PersistentRecursiveWatches []string } type setWatchesResponse struct{} @@ -275,8 +323,7 @@ type multiRequestOp struct { Op interface{} } type multiRequest struct { - Ops []multiRequestOp - DoneHeader multiHeader + Ops []multiRequestOp } type multiResponseOp struct { Header multiHeader @@ -285,8 +332,15 @@ type multiResponseOp struct { Err ErrCode } type multiResponse struct { - Ops []multiResponseOp - DoneHeader multiHeader + Ops []multiResponseOp +} +type MultiReadResponse struct { + getDataResponse + getChildrenResponse + Err error +} +type multiReadResponse struct { + OpResults []MultiReadResponse } // zk version 3.5 reconfig API @@ -301,6 +355,20 @@ type reconfigRequest struct { type reconfigReponse getDataResponse +type addWatchRequest struct { + Path string + Mode AddWatchMode +} + +type addWatchResponse struct{} + +type checkWatchesRequest struct { + Path string + Type WatcherType +} + +type checkWatchesResponse struct{} + func (r *multiRequest) Encode(buf []byte) (int, error) { total := 0 for _, op := range r.Ops { @@ -311,8 +379,7 @@ func (r *multiRequest) Encode(buf []byte) (int, error) { } total += n } - r.DoneHeader.Done = true - n, err := encodePacketValue(buf[total:], reflect.ValueOf(r.DoneHeader)) + n, err := encodePacketValue(buf[total:], reflect.ValueOf(multiHeader{Type: -1, Done: true, Err: -1})) if err != nil { return total, err } @@ -323,7 +390,6 @@ func (r *multiRequest) Encode(buf []byte) (int, error) { func (r *multiRequest) Decode(buf []byte) (int, error) { r.Ops = make([]multiRequestOp, 0) - r.DoneHeader = multiHeader{-1, true, -1} total := 0 for { header := &multiHeader{} @@ -333,7 +399,6 @@ func (r *multiRequest) Decode(buf []byte) (int, error) { } total += n if header.Done { - r.DoneHeader = *header break } @@ -355,7 +420,6 @@ func (r *multiResponse) Decode(buf []byte) (int, error) { var multiErr error r.Ops = make([]multiResponseOp, 0) - r.DoneHeader = multiHeader{-1, true, -1} total := 0 for { header := &multiHeader{} @@ -365,7 +429,6 @@ func (r *multiResponse) Decode(buf []byte) (int, error) { } total += n if header.Done { - r.DoneHeader = *header break } @@ -399,6 +462,48 @@ func (r *multiResponse) Decode(buf []byte) (int, error) { return total, multiErr } +func (r *multiReadResponse) Decode(buf []byte) (total int, multiErr error) { + for { + header := &multiHeader{} + n, err := decodePacketValue(buf[total:], reflect.ValueOf(header)) + if err != nil { + return total, err + } + total += n + if header.Done { + break + } + + var res MultiReadResponse + var errCode ErrCode + var w reflect.Value + switch header.Type { + case opGetData: + w = reflect.ValueOf(&res.getDataResponse) + case opGetChildren: + w = reflect.ValueOf(&res.getChildrenResponse) + case opError: + w = reflect.ValueOf(&errCode) + default: + return total, ErrAPIError + } + + n, err = decodePacketValue(buf[total:], w) + if err != nil { + return total, err + } + total += n + + if errCode != errOk { + res.Err = errCode.toError() + } + + r.OpResults = append(r.OpResults, res) + } + + return total, nil +} + type watcherEvent struct { Type EventType State State @@ -598,7 +703,7 @@ func requestStructForOp(op int32) interface{} { switch op { case opClose: return &closeRequest{} - case opCreate: + case opCreate, opCreate2: return &CreateRequest{} case opCreateContainer: return &CreateContainerRequest{} @@ -622,7 +727,7 @@ func requestStructForOp(op int32) interface{} { return &setAclRequest{} case opSetData: return &SetDataRequest{} - case opSetWatches: + case opSetWatches, opSetWatches2: return &setWatchesRequest{} case opSync: return &syncRequest{} @@ -634,6 +739,8 @@ func requestStructForOp(op int32) interface{} { return &multiRequest{} case opReconfig: return &reconfigRequest{} + case opAddWatch: + return &addWatchRequest{} } return nil } diff --git a/structs_test.go b/structs_test.go index 3a38ab45..9c4e9024 100644 --- a/structs_test.go +++ b/structs_test.go @@ -10,7 +10,7 @@ func TestEncodeDecodePacket(t *testing.T) { encodeDecodeTest(t, &requestHeader{-2, 5}) encodeDecodeTest(t, &connectResponse{1, 2, 3, nil}) encodeDecodeTest(t, &connectResponse{1, 2, 3, []byte{4, 5, 6}}) - encodeDecodeTest(t, &getAclResponse{[]ACL{{12, "s", "anyone"}}, Stat{}}) + encodeDecodeTest(t, &getAclResponse{[]ACL{{12, "s", "anyone"}}, &Stat{}}) encodeDecodeTest(t, &getChildrenResponse{[]string{"foo", "bar"}}) encodeDecodeTest(t, &pathWatchRequest{"path", true}) encodeDecodeTest(t, &pathWatchRequest{"path", false}) diff --git a/unlimited_channel.go b/unlimited_channel.go new file mode 100644 index 00000000..d3307d89 --- /dev/null +++ b/unlimited_channel.go @@ -0,0 +1,107 @@ +package zk + +import ( + "context" + "errors" + "sync" +) + +var ErrEventQueueClosed = errors.New("zk: event queue closed") + +type EventQueue interface { + Next(ctx context.Context) (Event, error) + push(e Event) + close() +} + +type chanEventQueue chan Event + +func (c chanEventQueue) Next(ctx context.Context) (Event, error) { + select { + case <-ctx.Done(): + return Event{}, ctx.Err() + case e, ok := <-c: + if !ok { + return Event{}, ErrEventQueueClosed + } else { + return e, nil + } + } +} + +func (c chanEventQueue) push(e Event) { + c <- e +} + +func (c chanEventQueue) close() { + close(c) +} + +func newChanEventChannel() chanEventQueue { + return make(chan Event, 1) +} + +type unlimitedEventQueue struct { + lock sync.Mutex + newEvent chan struct{} + events []Event +} + +func newUnlimitedEventQueue() *unlimitedEventQueue { + return &unlimitedEventQueue{ + newEvent: make(chan struct{}), + } +} + +func (q *unlimitedEventQueue) push(e Event) { + q.lock.Lock() + defer q.lock.Unlock() + + if q.newEvent == nil { + // Panic like a closed channel + panic("send on closed unlimited channel") + } + + q.events = append(q.events, e) + close(q.newEvent) + q.newEvent = make(chan struct{}) +} + +func (q *unlimitedEventQueue) close() { + q.lock.Lock() + defer q.lock.Unlock() + + if q.newEvent == nil { + // Panic like a closed channel + panic("close of closed EventQueue") + } + + close(q.newEvent) + q.newEvent = nil +} + +func (q *unlimitedEventQueue) Next(ctx context.Context) (Event, error) { + for { + q.lock.Lock() + if len(q.events) > 0 { + e := q.events[0] + q.events = q.events[1:] + q.lock.Unlock() + return e, nil + } + + ch := q.newEvent + if ch == nil { + q.lock.Unlock() + return Event{}, ErrEventQueueClosed + } + q.lock.Unlock() + + select { + case <-ctx.Done(): + return Event{}, ctx.Err() + case <-ch: + continue + } + } +} diff --git a/unlimited_channel_test.go b/unlimited_channel_test.go new file mode 100644 index 00000000..f0b292d3 --- /dev/null +++ b/unlimited_channel_test.go @@ -0,0 +1,94 @@ +//go:build go1.18 + +package zk + +import ( + "context" + "errors" + "fmt" + "reflect" + "testing" + "time" +) + +func newEvent(i int) Event { + return Event{Path: fmt.Sprintf("/%d", i)} +} + +func TestUnlimitedChannel(t *testing.T) { + names := []string{"notClosedAfterPushes", "closeAfterPushes"} + for i, closeAfterPushes := range []bool{false, true} { + t.Run(names[i], func(t *testing.T) { + ch := newUnlimitedEventQueue() + const eventCount = 10 + + // check that events can be pushed without consumers + for i := 0; i < eventCount; i++ { + ch.push(newEvent(i)) + } + if closeAfterPushes { + ch.close() + } + + for events := 0; events < eventCount; events++ { + actual, err := ch.Next(context.Background()) + if err != nil { + t.Fatalf("Unexpected error returned from Next (events %d): %+v", events, err) + } + expected := newEvent(events) + if !reflect.DeepEqual(actual, expected) { + t.Fatalf("Did not receive expected event from queue: actual %+v expected %+v", actual, expected) + } + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10) + t.Cleanup(cancel) + + _, err := ch.Next(ctx) + if closeAfterPushes { + if err != ErrEventQueueClosed { + t.Fatalf("Did not receive expected error (%v) from Next: %v", ErrEventQueueClosed, err) + } + } else { + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("Next did not exit with cancelled context: %+v", err) + } + } + }) + } + t.Run("interleaving", func(t *testing.T) { + ch := newUnlimitedEventQueue() + + for i := 0; i < 10; i++ { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10) + t.Cleanup(cancel) + + expected := newEvent(i) + + ctx = &customContext{ + Context: ctx, + f: func() { + ch.push(expected) + }, + } + + actual, err := ch.Next(ctx) + if err != nil { + t.Fatalf("Received unexpected error from Next: %+v", err) + } + if !reflect.DeepEqual(expected, actual) { + t.Fatalf("Unexpected event received from Next (expected %+v, actual %+v", expected, actual) + } + } + }) +} + +type customContext struct { + context.Context + f func() +} + +func (c *customContext) Done() <-chan struct{} { + c.f() + return c.Context.Done() +} diff --git a/watch.go b/watch.go new file mode 100644 index 00000000..b44c6ce4 --- /dev/null +++ b/watch.go @@ -0,0 +1 @@ +package zk diff --git a/watch_test.go b/watch_test.go new file mode 100644 index 00000000..b44c6ce4 --- /dev/null +++ b/watch_test.go @@ -0,0 +1 @@ +package zk diff --git a/zk_test.go b/zk_test.go index 9129c766..94cab58e 100644 --- a/zk_test.go +++ b/zk_test.go @@ -1,8 +1,8 @@ package zk import ( + "bytes" "context" - "encoding/hex" "fmt" "io" "io/ioutil" @@ -187,27 +187,22 @@ func TestCreateContainer(t *testing.T) { } func TestIncrementalReconfig(t *testing.T) { - if val, ok := os.LookupEnv("zk_version"); ok { - if !strings.HasPrefix(val, "3.5") { - t.Skip("running with zookeeper that does not support this api") - } - } else { - t.Skip("did not detect zk_version from env. skipping reconfig test") - } + RequireMinimumZkVersion(t, "3.5") + ts, err := StartTestCluster(t, 3, nil, logWriter{t: t, p: "[ZKERR] "}) - requireNoError(t, err, "failed to setup test cluster") + requireNoErrorf(t, err, "failed to setup test cluster") defer ts.Stop() // start and add a new server. tmpPath, err := ioutil.TempDir("", "gozk") - requireNoError(t, err, "failed to create tmp dir for test server setup") + requireNoErrorf(t, err, "failed to create tmp dir for test server setup") defer os.RemoveAll(tmpPath) startPort := int(rand.Int31n(6000) + 10000) srvPath := filepath.Join(tmpPath, fmt.Sprintf("srv4")) if err := os.Mkdir(srvPath, 0700); err != nil { - requireNoError(t, err, "failed to make server path") + requireNoErrorf(t, err, "failed to make server path") } testSrvConfig := ServerConfigServer{ ID: 4, @@ -224,35 +219,35 @@ func TestIncrementalReconfig(t *testing.T) { // TODO: clean all this server creating up to a better helper method cfgPath := filepath.Join(srvPath, _testConfigName) fi, err := os.Create(cfgPath) - requireNoError(t, err) + requireNoErrorf(t, err) - requireNoError(t, cfg.Marshall(fi)) + requireNoErrorf(t, cfg.Marshall(fi)) fi.Close() fi, err = os.Create(filepath.Join(srvPath, _testMyIDFileName)) - requireNoError(t, err) + requireNoErrorf(t, err) _, err = fmt.Fprintln(fi, "4") fi.Close() - requireNoError(t, err) + requireNoErrorf(t, err) testServer, err := NewIntegrationTestServer(t, cfgPath, nil, nil) - requireNoError(t, err) - requireNoError(t, testServer.Start()) + requireNoErrorf(t, err) + requireNoErrorf(t, testServer.Start()) defer testServer.Stop() zk, events, err := ts.ConnectAll() - requireNoError(t, err, "failed to connect to cluster") + requireNoErrorf(t, err, "failed to connect to cluster") defer zk.Close() err = zk.AddAuth("digest", []byte("super:test")) - requireNoError(t, err, "failed to auth to cluster") + requireNoErrorf(t, err, "failed to auth to cluster") waitCtx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() err = waitForSession(waitCtx, events) - requireNoError(t, err, "failed to wail for session") + requireNoErrorf(t, err, "failed to wail for session") _, _, err = zk.Get("/zookeeper/config") if err != nil { @@ -268,7 +263,7 @@ func TestIncrementalReconfig(t *testing.T) { if err != nil && err == ErrConnectionClosed { t.Log("conneciton closed is fine since the cluster re-elects and we dont reconnect") } else { - requireNoError(t, err, "failed to remove node from cluster") + requireNoErrorf(t, err, "failed to remove node from cluster") } // add node a new 4th node @@ -277,36 +272,30 @@ func TestIncrementalReconfig(t *testing.T) { if err != nil && err == ErrConnectionClosed { t.Log("conneciton closed is fine since the cluster re-elects and we dont reconnect") } else { - requireNoError(t, err, "failed to add new server to cluster") + requireNoErrorf(t, err, "failed to add new server to cluster") } } func TestReconfig(t *testing.T) { - if val, ok := os.LookupEnv("zk_version"); ok { - if !strings.HasPrefix(val, "3.5") { - t.Skip("running with zookeeper that does not support this api") - } - } else { - t.Skip("did not detect zk_version from env. skipping reconfig test") - } + RequireMinimumZkVersion(t, "3.5") // This test enures we can do an non-incremental reconfig ts, err := StartTestCluster(t, 3, nil, logWriter{t: t, p: "[ZKERR] "}) - requireNoError(t, err, "failed to setup test cluster") + requireNoErrorf(t, err, "failed to setup test cluster") defer ts.Stop() zk, events, err := ts.ConnectAll() - requireNoError(t, err, "failed to connect to cluster") + requireNoErrorf(t, err, "failed to connect to cluster") defer zk.Close() err = zk.AddAuth("digest", []byte("super:test")) - requireNoError(t, err, "failed to auth to cluster") + requireNoErrorf(t, err, "failed to auth to cluster") waitCtx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() err = waitForSession(waitCtx, events) - requireNoError(t, err, "failed to wail for session") + requireNoErrorf(t, err, "failed to wail for session") _, _, err = zk.Get("/zookeeper/config") if err != nil { @@ -320,7 +309,7 @@ func TestReconfig(t *testing.T) { } _, err = zk.Reconfig(s, -1) - requireNoError(t, err, "failed to reconfig cluster") + requireNoErrorf(t, err, "failed to reconfig cluster") // reconfig to all the hosts again s = []string{} @@ -329,7 +318,7 @@ func TestReconfig(t *testing.T) { } _, err = zk.Reconfig(s, -1) - requireNoError(t, err, "failed to reconfig cluster") + requireNoErrorf(t, err, "failed to reconfig cluster") } func TestOpsAfterCloseDontDeadlock(t *testing.T) { @@ -400,6 +389,134 @@ func TestMulti(t *testing.T) { } } +func TestMultiRead(t *testing.T) { + RequireMinimumZkVersion(t, "3.6") + WithTestCluster(t, 10*time.Second, func(ts *TestCluster, zk *Conn) { + nodeChildren := map[string][]string{} + nodeData := map[string][]byte{} + var ops []ReadOp + + create := func(path string, data []byte) { + if _, err := zk.Create(path, data, 0, nil); err != nil { + requireNoErrorf(t, err, "create returned an error") + } else { + dir, name := SplitPath(path) + nodeChildren[dir] = append(nodeChildren[dir], name) + nodeData[path] = data + ops = append(ops, GetDataOp(path), GetChildrenOp(path)) + } + } + + root := "/gozk-test" + create(root, nil) + + for i := byte(0); i < 10; i++ { + child := JoinPath(root, fmt.Sprint(i)) + create(child, []byte{i}) + } + + const foo = "foo" + create(JoinPath(JoinPath(root, "0"), foo), []byte(foo)) + + opResults, err := zk.MultiRead(ops...) + if err != nil { + t.Fatalf("MultiRead returned error: %+v", err) + } else if len(opResults) != len(ops) { + t.Fatalf("Expected %d responses got %d", len(ops), len(opResults)) + } + + nodeStats := map[string]*Stat{} + for k := range nodeData { + _, nodeStats[k], err = zk.Exists(k) + requireNoErrorf(t, err, "exists returned an error") + } + + for i, res := range opResults { + opPath := ops[i].GetPath() + switch op := ops[i].(type) { + case GetDataOp: + if res.Err != nil { + t.Fatalf("GetDataOp(%q) returned an error: %+v", op, res.Err) + } + if !bytes.Equal(res.Data, nodeData[opPath]) { + t.Fatalf("GetDataOp(%q).Data did not return %+v, got %+v", op, nodeData[opPath], res.Data) + } + if !reflect.DeepEqual(res.Stat, nodeStats[opPath]) { + t.Fatalf("GetDataOp(%q).Stat did not return %+v, got %+v", op, nodeStats[opPath], res.Stat) + } + case GetChildrenOp: + if res.Err != nil { + t.Fatalf("GetChildrenOp(%q) returned an error: %+v", opPath, res.Err) + } + // Cannot use DeepEqual here because it fails for []string{} == nil, even though in practice they are + // the same. + actual, expected := res.Children, nodeChildren[opPath] + if len(actual) != len(expected) { + t.Fatalf("GetChildrenOp(%q) did not return %+v, got %+v", opPath, expected, actual) + } + sort.Strings(actual) + sort.Strings(expected) + for i, c := range expected { + if actual[i] != c { + t.Fatalf("GetChildrenOp(%q) did not return %+v, got %+v", opPath, expected, actual) + } + } + } + } + + opResults, err = zk.MultiRead(GetDataOp("/invalid"), GetDataOp(root)) + requireNoErrorf(t, err, "MultiRead returned error") + + if opResults[0].Err != ErrNoNode { + t.Fatalf("MultiRead on invalid node did not return error") + } + if opResults[1].Err != nil { + t.Fatalf("MultiRead on valid node did not return error") + } + if !reflect.DeepEqual(opResults[1].Data, nodeData[root]) { + t.Fatalf("MultiRead on valid node did not return correct data") + } + }) +} + +func TestGetDataAndChildren(t *testing.T) { + RequireMinimumZkVersion(t, "3.6") + WithTestCluster(t, 10*time.Second, func(ts *TestCluster, zk *Conn) { + + const path = "/test" + _, _, _, err := zk.GetDataAndChildren(path) + if err != ErrNoNode { + t.Fatalf("GetDataAndChildren(%q) did not return an error", path) + } + + create := func(path string, data []byte) { + if _, err := zk.Create(path, data, 0, nil); err != nil { + requireNoErrorf(t, err, "create returned an error") + } + } + expectedData := []byte{1, 2, 3, 4} + create(path, expectedData) + var expectedChildren []string + for i := 0; i < 10; i++ { + child := fmt.Sprint(i) + create(JoinPath(path, child), nil) + expectedChildren = append(expectedChildren, child) + } + + data, _, children, err := zk.GetDataAndChildren(path) + requireNoErrorf(t, err, "GetDataAndChildren return an error") + + if !bytes.Equal(data, expectedData) { + t.Fatalf("GetDataAndChildren(%q) did not return expected data (expected %v): %v", path, expectedData, data) + } + sort.Strings(children) + if !reflect.DeepEqual(children, expectedChildren) { + t.Fatalf("GetDataAndChildren(%q) did not return expected children (expected %v): %v", + path, expectedChildren, children) + } + }) +} + func TestIfAuthdataSurvivesReconnect(t *testing.T) { // This test case ensures authentication data is being resubmited after // reconnect. @@ -452,6 +569,135 @@ func TestIfAuthdataSurvivesReconnect(t *testing.T) { } } +func TestPersistentWatchOnReconnect(t *testing.T) { + RequireMinimumZkVersion(t, "3.6") + WithTestCluster(t, 10*time.Second, func(ts *TestCluster, zk *Conn) { + zk.reconnectLatch = make(chan struct{}) + + zk2, _, err := ts.ConnectAll() + if err != nil { + t.Fatalf("Connect returned error: %+v", err) + } + defer zk2.Close() + + const testNode = "/gozk-test" + + if err := zk.Delete(testNode, -1); err != nil && err != ErrNoNode { + t.Fatalf("Delete returned error: %+v", err) + } + + watchEventsQueue, err := zk.AddPersistentWatch(testNode, AddWatchModePersistent) + if err != nil { + t.Fatalf("AddPersistentWatch returned error: %+v", err) + } + + watchEventsCh := make(chan Event) + go func() { + e, err := watchEventsQueue.Next(context.Background()) + if err != nil { + close(watchEventsCh) + return + } else { + watchEventsCh <- e + } + }() + + _, err = zk2.Create(testNode, []byte{1}, 0, WorldACL(PermAll)) + if err != nil { + t.Fatalf("Create returned an error: %+v", err) + } + + // check to see that we received the node creation event + select { + case ev := <-watchEventsCh: + if ev.Type != EventNodeCreated { + t.Fatalf("Second event on persistent watch was not a node creation event: %+v", ev) + } + case <-time.After(100 * time.Millisecond): + t.Fatalf("Persistent watcher for %q did not receive node creation event", testNode) + } + + // Simulate network error by brutally closing the network connection. + zk.conn.Close() + + _, err = zk2.Set(testNode, []byte{2}, -1) + if err != nil { + t.Fatalf("Set returned error: %+v", err) + } + + // zk should still be waiting to reconnect, so none of the watches should have been triggered + select { + case <-watchEventsCh: + t.Fatalf("Persistent watcher for %q should not have triggered yet", testNode) + case <-time.After(100 * time.Millisecond): + } + + // now we let the reconnect occur and make sure it resets watches + close(zk.reconnectLatch) + + // wait for reconnect event + select { + case ev := <-watchEventsCh: + if ev.Type != EventWatching { + t.Fatalf("Persistent watcher did not receive reconnect event: %+v", ev) + } + case <-time.After(5 * time.Second): + t.Fatalf("Persistent watcher for %q did not receive connection event", testNode) + } + + eventsReceived := 0 + timeout := time.After(2 * time.Second) + secondTimeout := time.After(4 * time.Second) + for { + select { + case e := <-watchEventsCh: + if e.Type != EventNodeDataChanged { + t.Fatalf("Unexpected event received by persistent watcher: %+v", e) + } + eventsReceived++ + case <-timeout: + _, err = zk2.Set(testNode, []byte{3}, -1) + if err != nil { + t.Fatalf("Set returned error: %+v", err) + } + case <-secondTimeout: + switch eventsReceived { + case 2: + t.Fatalf("Sanity check failed: the persistent watch logic is based around the assumption that the " + + "setWatchers call _does not_ bootstrap you on reconnect based on the relative Zxid (unlike " + + "standard watches).") + case 1: + return + default: + t.Fatalf("Received no events after reconnect") + } + } + } + }) +} + +func TestPersistentWatchOnClose(t *testing.T) { + RequireMinimumZkVersion(t, "3.6") + WithTestCluster(t, 10*time.Second, func(_ *TestCluster, zk *Conn) { + ch, err := zk.AddPersistentWatch("/", AddWatchModePersistent) + if err != nil { + t.Fatalf("Could not add persistent watch: %+v", err) + } + zk.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + t.Cleanup(cancel) + + e, err := ch.Next(ctx) + if err != nil { + t.Fatalf("Did not get disconnect event (%+v)", err) + } + if e.Type != EventNotWatching { + t.Fatalf("Unexpected event: %+v", e) + } + }) +} + func TestMultiFailures(t *testing.T) { // This test case ensures that we return the errors associated with each // opeThis in the event a call to Multi() fails. @@ -604,6 +850,7 @@ func TestAuth(t *testing.T) { } } +// Tests that we correctly handle a response larger than the default buffer size func TestChildren(t *testing.T) { ts, err := StartTestCluster(t, 1, nil, logWriter{t: t, p: "[ZKERR] "}) if err != nil { @@ -622,38 +869,44 @@ func TestChildren(t *testing.T) { } } - deleteNode("/gozk-test-big") + testNode := "/gozk-test-big" + deleteNode(testNode) - if path, err := zk.Create("/gozk-test-big", []byte{1, 2, 3, 4}, 0, WorldACL(PermAll)); err != nil { + if _, err := zk.Create(testNode, nil, 0, WorldACL(PermAll)); err != nil { t.Fatalf("Create returned error: %+v", err) - } else if path != "/gozk-test-big" { - t.Fatalf("Create returned different path '%s' != '/gozk-test-big'", path) } - rb := make([]byte, 1000) - hb := make([]byte, 2000) - prefix := []byte("/gozk-test-big/") - for i := 0; i < 10000; i++ { - _, err := rand.Read(rb) - if err != nil { - t.Fatal("Cannot create random znode name") - } - hex.Encode(hb, rb) + const ( + nodesToCreate = 100 + // By creating many nodes with long names, the response from the Children call should be significantly longer + // than the buffer size, forcing recvLoop to allocate a bigger buffer + nameLength = 2 * bufferSize / nodesToCreate + ) + + format := fmt.Sprintf("%%0%dd", nameLength) + if name := fmt.Sprintf(format, 0); len(name) != nameLength { + // Sanity check that the generated format string creates strings of the right length + t.Fatalf("Length of generated name was not %d, got %d", nameLength, len(name)) + } - expect := string(append(prefix, hb...)) - if path, err := zk.Create(expect, []byte{1, 2, 3, 4}, 0, WorldACL(PermAll)); err != nil { + var createdNodes []string + for i := 0; i < nodesToCreate; i++ { + name := fmt.Sprintf(format, i) + createdNodes = append(createdNodes, name) + path := testNode + "/" + name + if _, err := zk.Create(path, nil, 0, WorldACL(PermAll)); err != nil { t.Fatalf("Create returned error: %+v", err) - } else if path != expect { - t.Fatalf("Create returned different path '%s' != '%s'", path, expect) } - defer deleteNode(string(expect)) + defer deleteNode(path) } - children, _, err := zk.Children("/gozk-test-big") + children, _, err := zk.Children(testNode) if err != nil { t.Fatalf("Children returned error: %+v", err) - } else if len(children) != 10000 { - t.Fatal("Children returned wrong number of nodes") + } + sort.Strings(children) + if !reflect.DeepEqual(children, createdNodes) { + t.Fatal("Children did not return expected nodes") } } @@ -765,10 +1018,16 @@ func TestSetWatchers(t *testing.T) { } }() - // we create lots of paths to watch, to make sure a "set watches" request - // on re-create will be too big and be required to span multiple packets - for i := 0; i < 1000; i++ { - testPath, err := zk.Create(fmt.Sprintf("/gozk-test-%d", i), []byte{}, 0, WorldACL(PermAll)) + // we create lots of long paths to watch, to make sure a "set watches" request on will be too big and be broken + // into multiple packets. The size is chosen such that each packet can hold exactly 2 watches, meaning we should + // see half as many packets as there are watches. + const ( + watches = 50 + watchedNodeNameFormat = "/gozk-test-%0450d" + ) + + for i := 0; i < watches; i++ { + testPath, err := zk.Create(fmt.Sprintf(watchedNodeNameFormat, i), []byte{}, 0, WorldACL(PermAll)) if err != nil { t.Fatalf("Create returned: %+v", err) } @@ -852,9 +1111,9 @@ func TestSetWatchers(t *testing.T) { buf := make([]byte, bufferSize) totalWatches := 0 actualReqs := setWatchReqs.Load().([]*setWatchesRequest) - if len(actualReqs) < 12 { - // sanity check: we should have generated *at least* 12 requests to reset watches - t.Fatalf("too few setWatchesRequest messages: %d", len(actualReqs)) + if len(actualReqs) != watches/2 { + // sanity check: we should have generated exactly 25 requests to reset watches + t.Fatalf("Did not send exactly %d setWatches requests, got %d instead", watches/2, len(actualReqs)) } for _, r := range actualReqs { totalWatches += len(r.ChildWatches) + len(r.DataWatches) + len(r.ExistWatches)