From 9cb97524204dfbe3ce80785a66471567ef46160b Mon Sep 17 00:00:00 2001
From: PapaCharlie
Date: Thu, 1 Sep 2022 17:49:57 -0400
Subject: [PATCH 1/4] Fix the runtime of TestChildren and TestSetWatchers

TestChildren was introduced to test what happens when the response is too large
for the default allocated buffer size. It used to create 10k nodes, which could
take a very significant amount of time. The same effect can be achieved with far
fewer nodes, significantly speeding up the test's runtime. The test now runs in
5s on my machine instead of sometimes multiple minutes...

TestSetWatchers tested something similar, except that it checked that the
outgoing setWatchers packet is broken up into multiple packets when it's too
large. Using a similar trick, we can generate names of specific lengths to test
that the behavior is correct. It was also flaky because if your local ZK
deployment is a little slow, deleting all the nodes can take longer than the
session timeout, spuriously failing the test. This has also been fixed, and the
test now runs in a little over 5 seconds as well, instead of failing.

Finally, standardize the ZK server version checking to be a bit more flexible
and friendlier towards future versions of ZooKeeper (note: the original
implementation doesn't even work because the env variable name is incorrect...
It's `ZK_VERSION`, not `zk_version`)
--- Makefile | 4 +- server_help_test.go | 46 ++++++++++++---- zk_test.go | 127 ++++++++++++++++++++++---------------------- 3 files changed, 104 insertions(+), 73 deletions(-) diff --git a/Makefile b/Makefile index f0b7965c..1856aeb9 100644 --- a/Makefile +++ b/Makefile @@ -20,10 +20,12 @@ $(ZK): tar -zxf $(ZK).tar.gz rm $(ZK).tar.gz +.PHONY: zookeeper zookeeper: $(ZK) # we link to a standard directory path so then the tests dont need to find based on version # in the test code. this allows backward compatable testing.
- ln -s $(ZK) zookeeper + rm -f $@ + ln -s $(ZK) $@ .PHONY: setup setup: zookeeper diff --git a/server_help_test.go b/server_help_test.go index 6a49ad2b..deb780bd 100644 --- a/server_help_test.go +++ b/server_help_test.go @@ -7,6 +7,7 @@ import ( "math/rand" "os" "path/filepath" + "strconv" "strings" "testing" "time" @@ -53,7 +54,7 @@ func StartTestCluster(t *testing.T, size int, stdout, stderr io.Writer) (*TestCl } tmpPath, err := ioutil.TempDir("", "gozk") - requireNoError(t, err, "failed to create tmp dir for test server setup") + requireNoErrorf(t, err, "failed to create tmp dir for test server setup") success := false startPort := int(rand.Int31n(6000) + 10000) @@ -67,7 +68,7 @@ func StartTestCluster(t *testing.T, size int, stdout, stderr io.Writer) (*TestCl for serverN := 0; serverN < size; serverN++ { srvPath := filepath.Join(tmpPath, fmt.Sprintf("srv%d", serverN+1)) - requireNoError(t, os.Mkdir(srvPath, 0700), "failed to make server path") + requireNoErrorf(t, os.Mkdir(srvPath, 0700), "failed to make server path") port := startPort + serverN*3 cfg := ServerConfig{ @@ -88,20 +89,20 @@ func StartTestCluster(t *testing.T, size int, stdout, stderr io.Writer) (*TestCl cfgPath := filepath.Join(srvPath, _testConfigName) fi, err := os.Create(cfgPath) - requireNoError(t, err) + requireNoErrorf(t, err) - requireNoError(t, cfg.Marshall(fi)) + requireNoErrorf(t, cfg.Marshall(fi)) fi.Close() fi, err = os.Create(filepath.Join(srvPath, _testMyIDFileName)) - requireNoError(t, err) + requireNoErrorf(t, err) _, err = fmt.Fprintf(fi, "%d\n", serverN+1) fi.Close() - requireNoError(t, err) + requireNoErrorf(t, err) srv, err := NewIntegrationTestServer(t, cfgPath, stdout, stderr) - requireNoError(t, err) + requireNoErrorf(t, err) if err := srv.Start(); err != nil { return nil, err @@ -251,9 +252,36 @@ func (tc *TestCluster) StopAllServers() error { return nil } -func requireNoError(t *testing.T, err error, msgAndArgs ...interface{}) { +func requireNoErrorf(t *testing.T, err error, msgAndArgs ...interface{}) { if err != nil { + t.Helper() t.Logf("received unexpected error: %v", err) - t.Fatal(msgAndArgs...) + if len(msgAndArgs) == 0 { + t.FailNow() + } + t.Fatalf(msgAndArgs[0].(string), msgAndArgs[1:]...) + } +} + +func RequireMinimumZkVersion(t *testing.T, minimum string) { + if val, ok := os.LookupEnv("ZK_VERSION"); ok { + split := func(v string) (parts []int) { + for _, s := range strings.Split(v, ".") { + i, err := strconv.Atoi(s) + if err != nil { + t.Fatalf("invalid version segment: %q", s) + } + parts = append(parts, i) + } + return parts + } + + minimumV, actualV := split(minimum), split(val) + for i, p := range minimumV { + if i >= len(actualV) || actualV[i] < p { + t.Skipf("running with zookeeper that does not support this api (requires at least %s)", minimum) + } + if actualV[i] > p { + break + } + } + } else { + t.Skip("did not detect ZK_VERSION from env. skipping test") } } diff --git a/zk_test.go b/zk_test.go index 9129c766..603a91c5 100644 --- a/zk_test.go +++ b/zk_test.go @@ -2,7 +2,6 @@ package zk import ( "context" - "encoding/hex" "fmt" "io" "io/ioutil" @@ -187,27 +186,22 @@ func TestCreateContainer(t *testing.T) { } func TestIncrementalReconfig(t *testing.T) { - if val, ok := os.LookupEnv("zk_version"); ok { - if !strings.HasPrefix(val, "3.5") { - t.Skip("running with zookeeper that does not support this api") - } - } else { - t.Skip("did not detect zk_version from env. 
skipping reconfig test") - } + RequireMinimumZkVersion(t, "3.5") + ts, err := StartTestCluster(t, 3, nil, logWriter{t: t, p: "[ZKERR] "}) - requireNoError(t, err, "failed to setup test cluster") + requireNoErrorf(t, err, "failed to setup test cluster") defer ts.Stop() // start and add a new server. tmpPath, err := ioutil.TempDir("", "gozk") - requireNoError(t, err, "failed to create tmp dir for test server setup") + requireNoErrorf(t, err, "failed to create tmp dir for test server setup") defer os.RemoveAll(tmpPath) startPort := int(rand.Int31n(6000) + 10000) srvPath := filepath.Join(tmpPath, fmt.Sprintf("srv4")) if err := os.Mkdir(srvPath, 0700); err != nil { - requireNoError(t, err, "failed to make server path") + requireNoErrorf(t, err, "failed to make server path") } testSrvConfig := ServerConfigServer{ ID: 4, @@ -224,35 +218,35 @@ func TestIncrementalReconfig(t *testing.T) { // TODO: clean all this server creating up to a better helper method cfgPath := filepath.Join(srvPath, _testConfigName) fi, err := os.Create(cfgPath) - requireNoError(t, err) + requireNoErrorf(t, err) - requireNoError(t, cfg.Marshall(fi)) + requireNoErrorf(t, cfg.Marshall(fi)) fi.Close() fi, err = os.Create(filepath.Join(srvPath, _testMyIDFileName)) - requireNoError(t, err) + requireNoErrorf(t, err) _, err = fmt.Fprintln(fi, "4") fi.Close() - requireNoError(t, err) + requireNoErrorf(t, err) testServer, err := NewIntegrationTestServer(t, cfgPath, nil, nil) - requireNoError(t, err) - requireNoError(t, testServer.Start()) + requireNoErrorf(t, err) + requireNoErrorf(t, testServer.Start()) defer testServer.Stop() zk, events, err := ts.ConnectAll() - requireNoError(t, err, "failed to connect to cluster") + requireNoErrorf(t, err, "failed to connect to cluster") defer zk.Close() err = zk.AddAuth("digest", []byte("super:test")) - requireNoError(t, err, "failed to auth to cluster") + requireNoErrorf(t, err, "failed to auth to cluster") waitCtx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() err = waitForSession(waitCtx, events) - requireNoError(t, err, "failed to wail for session") + requireNoErrorf(t, err, "failed to wail for session") _, _, err = zk.Get("/zookeeper/config") if err != nil { @@ -268,7 +262,7 @@ func TestIncrementalReconfig(t *testing.T) { if err != nil && err == ErrConnectionClosed { t.Log("conneciton closed is fine since the cluster re-elects and we dont reconnect") } else { - requireNoError(t, err, "failed to remove node from cluster") + requireNoErrorf(t, err, "failed to remove node from cluster") } // add node a new 4th node @@ -277,36 +271,30 @@ func TestIncrementalReconfig(t *testing.T) { if err != nil && err == ErrConnectionClosed { t.Log("conneciton closed is fine since the cluster re-elects and we dont reconnect") } else { - requireNoError(t, err, "failed to add new server to cluster") + requireNoErrorf(t, err, "failed to add new server to cluster") } } func TestReconfig(t *testing.T) { - if val, ok := os.LookupEnv("zk_version"); ok { - if !strings.HasPrefix(val, "3.5") { - t.Skip("running with zookeeper that does not support this api") - } - } else { - t.Skip("did not detect zk_version from env. 
skipping reconfig test") - } + RequireMinimumZkVersion(t, "3.5") // This test enures we can do an non-incremental reconfig ts, err := StartTestCluster(t, 3, nil, logWriter{t: t, p: "[ZKERR] "}) - requireNoError(t, err, "failed to setup test cluster") + requireNoErrorf(t, err, "failed to setup test cluster") defer ts.Stop() zk, events, err := ts.ConnectAll() - requireNoError(t, err, "failed to connect to cluster") + requireNoErrorf(t, err, "failed to connect to cluster") defer zk.Close() err = zk.AddAuth("digest", []byte("super:test")) - requireNoError(t, err, "failed to auth to cluster") + requireNoErrorf(t, err, "failed to auth to cluster") waitCtx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() err = waitForSession(waitCtx, events) - requireNoError(t, err, "failed to wail for session") + requireNoErrorf(t, err, "failed to wail for session") _, _, err = zk.Get("/zookeeper/config") if err != nil { @@ -320,7 +308,7 @@ func TestReconfig(t *testing.T) { } _, err = zk.Reconfig(s, -1) - requireNoError(t, err, "failed to reconfig cluster") + requireNoErrorf(t, err, "failed to reconfig cluster") // reconfig to all the hosts again s = []string{} @@ -329,7 +317,7 @@ func TestReconfig(t *testing.T) { } _, err = zk.Reconfig(s, -1) - requireNoError(t, err, "failed to reconfig cluster") + requireNoErrorf(t, err, "failed to reconfig cluster") } func TestOpsAfterCloseDontDeadlock(t *testing.T) { @@ -604,6 +592,7 @@ func TestAuth(t *testing.T) { } } +// Tests that we correctly handle a response larger than the default buffer size func TestChildren(t *testing.T) { ts, err := StartTestCluster(t, 1, nil, logWriter{t: t, p: "[ZKERR] "}) if err != nil { @@ -622,38 +611,44 @@ func TestChildren(t *testing.T) { } } - deleteNode("/gozk-test-big") + testNode := "/gozk-test-big" + deleteNode(testNode) - if path, err := zk.Create("/gozk-test-big", []byte{1, 2, 3, 4}, 0, WorldACL(PermAll)); err != nil { + if _, err := zk.Create(testNode, nil, 0, WorldACL(PermAll)); err != nil { t.Fatalf("Create returned error: %+v", err) - } else if path != "/gozk-test-big" { - t.Fatalf("Create returned different path '%s' != '/gozk-test-big'", path) } - rb := make([]byte, 1000) - hb := make([]byte, 2000) - prefix := []byte("/gozk-test-big/") - for i := 0; i < 10000; i++ { - _, err := rand.Read(rb) - if err != nil { - t.Fatal("Cannot create random znode name") - } - hex.Encode(hb, rb) + const ( + nodesToCreate = 100 + // By creating many nodes with long names, the response from the Children call should be significantly longer + // than the buffer size, forcing recvLoop to allocate a bigger buffer + nameLength = 2 * bufferSize / nodesToCreate + ) + + format := fmt.Sprintf("%%0%dd", nameLength) + if name := fmt.Sprintf(format, 0); len(name) != nameLength { + // Sanity check that the generated format string creates strings of the right length + t.Fatalf("Length of generated name was not %d, got %d", nameLength, len(name)) + } - expect := string(append(prefix, hb...)) - if path, err := zk.Create(expect, []byte{1, 2, 3, 4}, 0, WorldACL(PermAll)); err != nil { + var createdNodes []string + for i := 0; i < nodesToCreate; i++ { + name := fmt.Sprintf(format, i) + createdNodes = append(createdNodes, name) + path := testNode + "/" + name + if _, err := zk.Create(path, nil, 0, WorldACL(PermAll)); err != nil { t.Fatalf("Create returned error: %+v", err) - } else if path != expect { - t.Fatalf("Create returned different path '%s' != '%s'", path, expect) } - defer deleteNode(string(expect)) + defer deleteNode(path) } 
- children, _, err := zk.Children("/gozk-test-big") + children, _, err := zk.Children(testNode) if err != nil { t.Fatalf("Children returned error: %+v", err) - } else if len(children) != 10000 { - t.Fatal("Children returned wrong number of nodes") + } + sort.Strings(children) + if !reflect.DeepEqual(children, createdNodes) { + t.Fatal("Children did not return expected nodes") } } @@ -765,10 +760,16 @@ func TestSetWatchers(t *testing.T) { } }() - // we create lots of paths to watch, to make sure a "set watches" request - // on re-create will be too big and be required to span multiple packets - for i := 0; i < 1000; i++ { - testPath, err := zk.Create(fmt.Sprintf("/gozk-test-%d", i), []byte{}, 0, WorldACL(PermAll)) + // we create lots of long paths to watch, to make sure a "set watches" request on will be too big and be broken + // into multiple packets. The size is chosen such that each packet can hold exactly 2 watches, meaning we should + // see half as many packets as there are watches. + const ( + watches = 50 + watchedNodeNameFormat = "/gozk-test-%0450d" + ) + + for i := 0; i < watches; i++ { + testPath, err := zk.Create(fmt.Sprintf(watchedNodeNameFormat, i), []byte{}, 0, WorldACL(PermAll)) if err != nil { t.Fatalf("Create returned: %+v", err) } @@ -852,9 +853,9 @@ func TestSetWatchers(t *testing.T) { buf := make([]byte, bufferSize) totalWatches := 0 actualReqs := setWatchReqs.Load().([]*setWatchesRequest) - if len(actualReqs) < 12 { - // sanity check: we should have generated *at least* 12 requests to reset watches - t.Fatalf("too few setWatchesRequest messages: %d", len(actualReqs)) + if len(actualReqs) != watches/2 { + // sanity check: we should have generated exactly 25 requests to reset watches + t.Fatalf("Did not send exactly %d setWatches requests, got %d instead", watches/2, len(actualReqs)) } for _, r := range actualReqs { totalWatches += len(r.ChildWatches) + len(r.DataWatches) + len(r.ExistWatches) From 1d3218bbfefd23840b6d51f532082e6009912395 Mon Sep 17 00:00:00 2001 From: PapaCharlie Date: Thu, 1 Sep 2022 17:49:57 -0400 Subject: [PATCH 2/4] Support persistent watches Implements the new persistent watch types introduced in 3.6, along with some utilities that are critical when implementing local caches. 
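Example usage (illustrative sketch only, not part of this diff; the watchConfig helper and the "/app/config" path are made up):

    // watchConfig consumes a persistent, recursive watch until its queue is closed.
    func watchConfig(conn *Conn) error {
        // A single watch covers "/app/config" and every node below it, and it
        // stays armed across events and across reconnects.
        events, err := conn.AddPersistentWatch("/app/config", AddWatchModePersistentRecursive)
        if err != nil {
            return err
        }
        defer conn.RemovePersistentWatch("/app/config", events)

        for {
            e, ok := <-events.Next()
            if !ok {
                // The queue is closed once the watch is removed or the connection is closed.
                return nil
            }
            switch e.Type {
            case EventWatching:
                // Sent when the watch is first armed and again after every reconnect;
                // a local cache should re-sync its state here.
            case EventNodeCreated, EventNodeDataChanged, EventNodeDeleted:
                // React to the change reported at e.Path.
            }
        }
    }

The queue returned by AddPersistentWatch is designed to buffer events internally, so a slow consumer does not block the connection's receive loop.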
--- Makefile | 2 +- conn.go | 345 +++++++++++++++++++++++++++++++------- conn_test.go | 171 ++++++++++++------- constants.go | 44 +++++ go.mod | 4 +- server_help_test.go | 33 ++++ structs.go | 151 ++++++++++++++--- structs_test.go | 2 +- unlimited_channel.go | 104 ++++++++++++ unlimited_channel_test.go | 63 +++++++ watch.go | 1 + watch_test.go | 1 + zk_test.go | 244 +++++++++++++++++++++++++++ 13 files changed, 1025 insertions(+), 140 deletions(-) create mode 100644 unlimited_channel.go create mode 100644 unlimited_channel_test.go create mode 100644 watch.go create mode 100644 watch_test.go diff --git a/Makefile b/Makefile index 1856aeb9..20a40f7a 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ # make file to hold the logic of build and test setup -ZK_VERSION ?= 3.5.6 +ZK_VERSION ?= 3.6.3 # Apache changed the name of the archive in version 3.5.x and seperated out # src and binary packages diff --git a/conn.go b/conn.go index 9afd2d27..b58090f9 100644 --- a/conn.go +++ b/conn.go @@ -47,8 +47,14 @@ const ( watchTypeData watchType = iota watchTypeExist watchTypeChild + watchTypePersistent + watchTypePersistentRecursive ) +func (w watchType) isPersistent() bool { + return w == watchTypePersistent || w == watchTypePersistentRecursive +} + type watchPathType struct { path string wType watchType @@ -96,7 +102,7 @@ type Conn struct { sendChan chan *request requests map[int32]*request // Xid -> pending request requestsLock sync.Mutex - watchers map[watchPathType][]chan Event + watchers map[watchPathType][]EventQueue watchersLock sync.Mutex closeChan chan struct{} // channel to tell send loop stop @@ -199,7 +205,7 @@ func Connect(servers []string, sessionTimeout time.Duration, options ...connOpti connectTimeout: 1 * time.Second, sendChan: make(chan *request, sendChanSize), requests: make(map[int32]*request), - watchers: make(map[watchPathType][]chan Event), + watchers: make(map[watchPathType][]EventQueue), passwd: emptyPassword, logger: DefaultLogger, logInfo: true, // default is true for backwards compatability @@ -530,31 +536,46 @@ func (c *Conn) flushRequests(err error) { c.requestsLock.Unlock() } +var eventWatchTypes = map[EventType][]watchType{ + EventNodeCreated: {watchTypeExist, watchTypePersistent, watchTypePersistentRecursive}, + EventNodeDataChanged: {watchTypeExist, watchTypeData, watchTypePersistent, watchTypePersistentRecursive}, + EventNodeChildrenChanged: {watchTypeChild, watchTypePersistent}, + EventNodeDeleted: {watchTypeExist, watchTypeData, watchTypeChild, watchTypePersistent, watchTypePersistentRecursive}, +} +var persistentWatchTypes = []watchType{watchTypePersistent, watchTypePersistentRecursive} + // Send event to all interested watchers func (c *Conn) notifyWatches(ev Event) { - var wTypes []watchType - switch ev.Type { - case EventNodeCreated: - wTypes = []watchType{watchTypeExist} - case EventNodeDataChanged: - wTypes = []watchType{watchTypeExist, watchTypeData} - case EventNodeChildrenChanged: - wTypes = []watchType{watchTypeChild} - case EventNodeDeleted: - wTypes = []watchType{watchTypeExist, watchTypeData, watchTypeChild} + wTypes := eventWatchTypes[ev.Type] + if len(wTypes) == 0 { + return } + c.watchersLock.Lock() defer c.watchersLock.Unlock() - for _, t := range wTypes { - wpt := watchPathType{ev.Path, t} - if watchers := c.watchers[wpt]; len(watchers) > 0 { - for _, ch := range watchers { - ch <- ev - close(ch) + + broadcast := func(wpt watchPathType) { + for _, ch := range c.watchers[wpt] { + ch.push(ev) + if !wpt.wType.isPersistent() { + ch.close() } 
delete(c.watchers, wpt) } } + + for _, t := range wTypes { + if t == watchTypePersistentRecursive { + for p := ev.Path; ; p, _ = SplitPath(p) { + broadcast(watchPathType{p, t}) + if p == "/" { + break + } + } + } else { + broadcast(watchPathType{ev.Path, t}) + } + } } // Send error to all watchers and clear watchers map @@ -562,16 +583,23 @@ func (c *Conn) invalidateWatches(err error) { c.watchersLock.Lock() defer c.watchersLock.Unlock() - if len(c.watchers) >= 0 { + if len(c.watchers) > 0 { for pathType, watchers := range c.watchers { + if err == ErrSessionExpired && pathType.wType.isPersistent() { + // Ignore ErrSessionExpired for persistent watchers as the client will either automatically reconnect, + // or this is a shutdown-worthy error in which case there will be a followup invocation of this method + // with ErrClosing + continue + } + ev := Event{Type: EventNotWatching, State: StateDisconnected, Path: pathType.path, Err: err} c.sendEvent(ev) // also publish globally for _, ch := range watchers { - ch <- ev - close(ch) + ch.push(ev) + ch.close() } + delete(c.watchers, pathType) } - c.watchers = make(map[watchPathType][]chan Event) } } @@ -610,12 +638,7 @@ func (c *Conn) sendSetWatches() { reqs = append(reqs, req) } sizeSoFar = 28 // fixed overhead of a set-watches packet - req = &setWatchesRequest{ - RelativeZxid: c.lastZxid, - DataWatches: make([]string, 0), - ExistWatches: make([]string, 0), - ChildWatches: make([]string, 0), - } + req = &setWatchesRequest{RelativeZxid: c.lastZxid} } sizeSoFar += addlLen switch pathType.wType { @@ -625,6 +648,10 @@ func (c *Conn) sendSetWatches() { req.ExistWatches = append(req.ExistWatches, pathType.path) case watchTypeChild: req.ChildWatches = append(req.ChildWatches, pathType.path) + case watchTypePersistent: + req.PersistentWatches = append(req.PersistentWatches, pathType.path) + case watchTypePersistentRecursive: + req.PersistentRecursiveWatches = append(req.PersistentRecursiveWatches, pathType.path) } n++ } @@ -646,7 +673,37 @@ func (c *Conn) sendSetWatches() { // aren't failure modes where a blocking write to the channel of requests // could hang indefinitely and cause this goroutine to leak... for _, req := range reqs { - _, err := c.request(opSetWatches, req, res, nil) + var op int32 = opSetWatches + if len(req.PersistentWatches) > 0 || len(req.PersistentRecursiveWatches) > 0 { + // to maintain compatibility with older servers, only send opSetWatches2 if persistent watches are used + op = opSetWatches2 + } + + _, err := c.request(op, req, res, func(r *request, header *responseHeader, err error) { + if err == nil && op == opSetWatches2 { + // If the setWatches was successful, notify the persistent watchers they've been reconnected. + // Because we process responses in one routine, we know that the following will execute before + // subsequent responses are processed. This means we won't end up in a situation where events are + // sent to watchers before the reconnect event is sent. 
+ c.watchersLock.Lock() + defer c.watchersLock.Unlock() + for _, wt := range persistentWatchTypes { + var paths []string + if wt == watchTypePersistent { + paths = req.PersistentWatches + } else { + paths = req.PersistentRecursiveWatches + } + for _, p := range paths { + e := Event{Type: EventWatching, State: StateConnected, Path: p} + c.sendEvent(e) // also publish globally + for _, ch := range c.watchers[watchPathType{path: p, wType: wt}] { + ch.push(e) + } + } + } + } + }) if err != nil { c.logger.Printf("Failed to set previous watches: %v", err) break @@ -827,15 +884,15 @@ func (c *Conn) recvLoop(conn net.Conn) error { } if res.Xid == -1 { - res := &watcherEvent{} - _, err = decodePacket(buf[16:blen], res) + we := &watcherEvent{} + _, err = decodePacket(buf[16:blen], we) if err != nil { return err } ev := Event{ - Type: res.Type, - State: res.State, - Path: res.Path, + Type: we.Type, + State: we.State, + Path: we.Path, Err: nil, } c.sendEvent(ev) @@ -880,14 +937,15 @@ func (c *Conn) nextXid() int32 { return int32(atomic.AddUint32(&c.xid, 1) & 0x7fffffff) } -func (c *Conn) addWatcher(path string, watchType watchType) <-chan Event { +func (c *Conn) addWatcher(path string, watchType watchType, ch EventQueue) { c.watchersLock.Lock() defer c.watchersLock.Unlock() - ch := make(chan Event, 1) wpt := watchPathType{path, watchType} c.watchers[wpt] = append(c.watchers[wpt], ch) - return ch + if watchType.isPersistent() { + ch.push(Event{Type: EventWatching, State: StateConnected, Path: path}) + } } func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) <-chan response { @@ -981,7 +1039,7 @@ func (c *Conn) Children(path string) ([]string, *Stat, error) { if err == ErrConnectionClosed { return nil, nil, err } - return res.Children, &res.Stat, err + return res.Children, res.Stat, err } // ChildrenW returns the children of a znode and sets a watch. @@ -990,17 +1048,18 @@ func (c *Conn) ChildrenW(path string) ([]string, *Stat, <-chan Event, error) { return nil, nil, nil, err } - var ech <-chan Event + var ech chanEventQueue res := &getChildren2Response{} _, err := c.request(opGetChildren2, &getChildren2Request{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) { if err == nil { - ech = c.addWatcher(path, watchTypeChild) + ech = newChanEventChannel() + c.addWatcher(path, watchTypeChild, ech) } }) if err != nil { return nil, nil, nil, err } - return res.Children, &res.Stat, ech, err + return res.Children, res.Stat, ech, err } // Get gets the contents of a znode. @@ -1014,7 +1073,7 @@ func (c *Conn) Get(path string) ([]byte, *Stat, error) { if err == ErrConnectionClosed { return nil, nil, err } - return res.Data, &res.Stat, err + return res.Data, res.Stat, err } // GetW returns the contents of a znode and sets a watch @@ -1023,17 +1082,18 @@ func (c *Conn) GetW(path string) ([]byte, *Stat, <-chan Event, error) { return nil, nil, nil, err } - var ech <-chan Event + var ech chanEventQueue res := &getDataResponse{} _, err := c.request(opGetData, &getDataRequest{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) { if err == nil { - ech = c.addWatcher(path, watchTypeData) + ech = newChanEventChannel() + c.addWatcher(path, watchTypeData, ech) } }) if err != nil { return nil, nil, nil, err } - return res.Data, &res.Stat, ech, err + return res.Data, res.Stat, ech, err } // Set updates the contents of a znode. 
@@ -1047,10 +1107,10 @@ func (c *Conn) Set(path string, data []byte, version int32) (*Stat, error) { if err == ErrConnectionClosed { return nil, err } - return &res.Stat, err + return res.Stat, err } -// Create creates a znode. +// Create creates a znode. If acl is empty, it uses the global WorldACL with PermAll // The returned path is the new path assigned by the server, it may not be the // same as the input, for example when creating a sequence znode the returned path // will be the input path with a sequence number appended. @@ -1059,6 +1119,10 @@ func (c *Conn) Create(path string, data []byte, flags int32, acl []ACL) (string, return "", err } + if len(acl) == 0 { + acl = WorldACL(PermAll) + } + res := &createResponse{} _, err := c.request(opCreate, &CreateRequest{path, data, acl, flags}, res, nil) if err == ErrConnectionClosed { @@ -1067,6 +1131,24 @@ func (c *Conn) Create(path string, data []byte, flags int32, acl []ACL) (string, return res.Path, err } +// CreateAndReturnStat is the equivalent of Create, but it also returns the Stat of the created node. +func (c *Conn) CreateAndReturnStat(path string, data []byte, flags int32, acl []ACL) (string, *Stat, error) { + if err := validatePath(path, flags&FlagSequence == FlagSequence); err != nil { + return "", nil, err + } + + if len(acl) == 0 { + acl = WorldACL(PermAll) + } + + res := &create2Response{} + _, err := c.request(opCreate2, &CreateRequest{path, data, acl, flags}, res, nil) + if err == ErrConnectionClosed { + return "", nil, err + } + return res.Path, res.Stat, err +} + // CreateContainer creates a container znode and returns the path. func (c *Conn) CreateContainer(path string, data []byte, flags int32, acl []ACL) (string, error) { if err := validatePath(path, flags&FlagSequence == FlagSequence); err != nil { @@ -1170,7 +1252,7 @@ func (c *Conn) Exists(path string) (bool, *Stat, error) { exists = false err = nil } - return exists, &res.Stat, err + return exists, res.Stat, err } // ExistsW tells the existence of a znode and sets a watch. @@ -1179,13 +1261,14 @@ func (c *Conn) ExistsW(path string) (bool, *Stat, <-chan Event, error) { return false, nil, nil, err } - var ech <-chan Event + var ech chanEventQueue res := &existsResponse{} _, err := c.request(opExists, &existsRequest{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) { + ech = newChanEventChannel() if err == nil { - ech = c.addWatcher(path, watchTypeData) + c.addWatcher(path, watchTypeData, ech) } else if err == ErrNoNode { - ech = c.addWatcher(path, watchTypeExist) + c.addWatcher(path, watchTypeExist, ech) } }) exists := true @@ -1196,7 +1279,7 @@ func (c *Conn) ExistsW(path string) (bool, *Stat, <-chan Event, error) { if err != nil { return false, nil, nil, err } - return exists, &res.Stat, ech, err + return exists, res.Stat, ech, err } // GetACL gets the ACLs of a znode. @@ -1210,7 +1293,7 @@ func (c *Conn) GetACL(path string) ([]ACL, *Stat, error) { if err == ErrConnectionClosed { return nil, nil, err } - return res.Acl, &res.Stat, err + return res.Acl, res.Stat, err } // SetACL updates the ACLs of a znode. @@ -1224,7 +1307,7 @@ func (c *Conn) SetACL(path string, acl []ACL, version int32) (*Stat, error) { if err == ErrConnectionClosed { return nil, err } - return &res.Stat, err + return res.Stat, err } // Sync flushes the channel between process and the leader of a given znode, @@ -1255,8 +1338,7 @@ type MultiResponse struct { // *CheckVersionRequest. 
func (c *Conn) Multi(ops ...interface{}) ([]MultiResponse, error) { req := &multiRequest{ - Ops: make([]multiRequestOp, 0, len(ops)), - DoneHeader: multiHeader{Type: -1, Done: true, Err: -1}, + Ops: make([]multiRequestOp, 0, len(ops)), } for _, op := range ops { var opCode int32 @@ -1286,6 +1368,44 @@ func (c *Conn) Multi(ops ...interface{}) ([]MultiResponse, error) { return mr, err } +// MultiRead executes multiple ZooKeeper read operations at once. The provided ops must be one of GetDataOp or +// GetChildrenOp. Returns an error on network or connectivity errors, not on any op errors such as ErrNoNode. To check +// if any ops failed, check the corresponding MultiReadResponse.Err. +func (c *Conn) MultiRead(ops ...ReadOp) ([]MultiReadResponse, error) { + req := &multiRequest{ + Ops: make([]multiRequestOp, len(ops)), + } + for i, op := range ops { + req.Ops[i] = multiRequestOp{ + Header: multiHeader{op.opCode(), false, -1}, + Op: pathWatchRequest{Path: op.GetPath()}, + } + } + res := &multiReadResponse{} + _, err := c.request(opMultiRead, req, res, nil) + return res.OpResults, err +} + +// GetDataAndChildren executes a multi-read to get the given node's data and its children in one call. +func (c *Conn) GetDataAndChildren(path string) ([]byte, *Stat, []string, error) { + if err := validatePath(path, false); err != nil { + return nil, nil, nil, err + } + + opResults, err := c.MultiRead(GetDataOp(path), GetChildrenOp(path)) + if err != nil { + return nil, nil, nil, err + } + + for _, r := range opResults { + if r.Err != nil { + return nil, nil, nil, r.Err + } + } + + return opResults[0].Data, opResults[0].Stat, opResults[1].Children, nil +} + // IncrementalReconfig is the zookeeper reconfiguration api that allows adding and removing servers // by lists of members. For more info refer to the ZK documentation. // @@ -1321,7 +1441,7 @@ func (c *Conn) Reconfig(members []string, version int64) (*Stat, error) { func (c *Conn) internalReconfig(request *reconfigRequest) (*Stat, error) { response := &reconfigReponse{} _, err := c.request(opReconfig, request, response, nil) - return &response.Stat, err + return response.Stat, err } // Server returns the current or last-connected server name. 
@@ -1331,6 +1451,97 @@ func (c *Conn) Server() string { return c.server } +func (c *Conn) AddPersistentWatch(path string, mode AddWatchMode) (ch EventQueue, err error) { + if err = validatePath(path, false); err != nil { + return nil, err + } + + res := &addWatchResponse{} + _, err = c.request(opAddWatch, &addWatchRequest{Path: path, Mode: mode}, res, func(r *request, header *responseHeader, err error) { + if err == nil { + var wt watchType + if mode == AddWatchModePersistent { + wt = watchTypePersistent + } else { + wt = watchTypePersistentRecursive + } + + ch = newUnlimitedEventQueue() + c.addWatcher(path, wt, ch) + } + }) + if err == ErrConnectionClosed { + return nil, err + } + return ch, err +} + +func (c *Conn) RemovePersistentWatch(path string, ch EventQueue) (err error) { + if err = validatePath(path, false); err != nil { + return err + } + + deleted := false + + res := &checkWatchesResponse{} + _, err = c.request(opCheckWatches, &checkWatchesRequest{Path: path, Type: WatcherTypeAny}, res, func(r *request, header *responseHeader, err error) { + if err != nil { + return + } + + c.watchersLock.Lock() + defer c.watchersLock.Unlock() + + for _, wt := range persistentWatchTypes { + wpt := watchPathType{path: path, wType: wt} + for i, w := range c.watchers[wpt] { + if w == ch { + deleted = true + c.watchers[wpt] = append(c.watchers[wpt][:i], c.watchers[wpt][i+1:]...) + w.push(Event{Type: EventNotWatching, State: c.State(), Path: path, Err: ErrNoWatcher}) + w.close() + return + } + } + } + }) + + if err != nil { + return err + } + + if !deleted { + return ErrNoWatcher + } + + return nil +} + +func (c *Conn) RemoveAllPersistentWatches(path string) (err error) { + if err = validatePath(path, false); err != nil { + return err + } + + res := &checkWatchesResponse{} + _, err = c.request(opRemoveWatches, &checkWatchesRequest{Path: path, Type: WatcherTypeAny}, res, func(r *request, header *responseHeader, err error) { + if err != nil { + return + } + + c.watchersLock.Lock() + defer c.watchersLock.Unlock() + for _, wt := range persistentWatchTypes { + wpt := watchPathType{path: path, wType: wt} + for _, ch := range c.watchers[wpt] { + ch.push(Event{Type: EventNotWatching, State: c.State(), Path: path, Err: ErrNoWatcher}) + ch.close() + } + delete(c.watchers, wpt) + } + }) + return err +} + func resendZkAuth(ctx context.Context, c *Conn) error { shouldCancel := func() bool { select { @@ -1389,3 +1600,23 @@ func resendZkAuth(ctx context.Context, c *Conn) error { return nil } + +func JoinPath(parent, child string) string { + if !strings.HasSuffix(parent, "/") { + parent += "/" + } + if strings.HasPrefix(child, "/") { + child = child[1:] + } + return parent + child +} + +func SplitPath(path string) (dir, name string) { + i := strings.LastIndex(path, "/") + if i == 0 { + dir, name = "/", path[1:] + } else { + dir, name = path[:i], path[i+1:] + } + return dir, name +} diff --git a/conn_test.go b/conn_test.go index 96299280..1fb6787a 100644 --- a/conn_test.go +++ b/conn_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io/ioutil" + "strings" "sync" "testing" "time" @@ -113,85 +114,141 @@ func TestDeadlockInClose(t *testing.T) { } func TestNotifyWatches(t *testing.T) { + queueImpls := []struct { + name string + new func() EventQueue + }{ + { + name: "chan", + new: func() EventQueue { return newChanEventChannel() }, + }, + { + name: "unlimited", + new: func() EventQueue { return newUnlimitedEventQueue() }, + }, + } + cases := []struct { eType EventType path string watches map[watchPathType]bool }{ { - 
EventNodeCreated, "/", - map[watchPathType]bool{ - {"/", watchTypeExist}: true, - {"/", watchTypeChild}: false, - {"/", watchTypeData}: false, - }, - }, - { - EventNodeCreated, "/a", - map[watchPathType]bool{ + eType: EventNodeCreated, + path: "/a", + watches: map[watchPathType]bool{ + {"/a", watchTypeExist}: true, {"/b", watchTypeExist}: false, + + {"/a", watchTypeChild}: false, + + {"/a", watchTypeData}: false, + + {"/a", watchTypePersistent}: true, + {"/", watchTypePersistent}: false, + + {"/a", watchTypePersistentRecursive}: true, + {"/", watchTypePersistentRecursive}: true, }, }, { - EventNodeDataChanged, "/", - map[watchPathType]bool{ - {"/", watchTypeExist}: true, - {"/", watchTypeData}: true, - {"/", watchTypeChild}: false, + eType: EventNodeDataChanged, + path: "/a", + watches: map[watchPathType]bool{ + {"/a", watchTypeExist}: true, + {"/a", watchTypeData}: true, + {"/a", watchTypeChild}: false, + + {"/a", watchTypePersistent}: true, + {"/", watchTypePersistent}: false, + + {"/a", watchTypePersistentRecursive}: true, + {"/", watchTypePersistentRecursive}: true, }, }, { - EventNodeChildrenChanged, "/", - map[watchPathType]bool{ - {"/", watchTypeExist}: false, - {"/", watchTypeData}: false, - {"/", watchTypeChild}: true, + eType: EventNodeChildrenChanged, + path: "/a", + watches: map[watchPathType]bool{ + {"/a", watchTypeExist}: false, + {"/a", watchTypeData}: false, + {"/a", watchTypeChild}: true, + {"/a", watchTypePersistent}: true, + {"/a", watchTypePersistentRecursive}: false, + + {"/a", watchTypePersistent}: true, + {"/", watchTypePersistent}: false, + + {"/a", watchTypePersistentRecursive}: false, + {"/", watchTypePersistentRecursive}: false, }, }, { - EventNodeDeleted, "/", - map[watchPathType]bool{ - {"/", watchTypeExist}: true, - {"/", watchTypeData}: true, - {"/", watchTypeChild}: true, + eType: EventNodeDeleted, + path: "/a", + watches: map[watchPathType]bool{ + {"/a", watchTypeExist}: true, + {"/a", watchTypeData}: true, + {"/a", watchTypeChild}: true, + + {"/a", watchTypePersistent}: true, + {"/", watchTypePersistent}: false, + + {"/a", watchTypePersistentRecursive}: true, + {"/", watchTypePersistentRecursive}: true, }, }, } - conn := &Conn{watchers: make(map[watchPathType][]chan Event)} - - for idx, c := range cases { - t.Run(fmt.Sprintf("#%d %s", idx, c.eType), func(t *testing.T) { - c := c - - notifications := make([]struct { - path string - notify bool - ch <-chan Event - }, len(c.watches)) - - var idx int - for wpt, expectEvent := range c.watches { - ch := conn.addWatcher(wpt.path, wpt.wType) - notifications[idx].path = wpt.path - notifications[idx].notify = expectEvent - notifications[idx].ch = ch - idx++ - } - ev := Event{Type: c.eType, Path: c.path} - conn.notifyWatches(ev) - - for _, res := range notifications { - select { - case e := <-res.ch: - if !res.notify || e.Path != res.path { - t.Fatal("unexpeted notification received") + for _, impl := range queueImpls { + t.Run(impl.name, func(t *testing.T) { + for idx, c := range cases { + c := c + t.Run(fmt.Sprintf("#%d %s", idx, c.eType), func(t *testing.T) { + notifications := make([]struct { + watchPathType + notify bool + ch EventQueue + }, len(c.watches)) + + conn := &Conn{watchers: make(map[watchPathType][]EventQueue)} + + var idx int + for wpt, expectEvent := range c.watches { + notifications[idx].watchPathType = wpt + notifications[idx].notify = expectEvent + ch := impl.new() + conn.addWatcher(wpt.path, wpt.wType, ch) + notifications[idx].ch = ch + if wpt.wType.isPersistent() { + e := <-ch.Next() + if 
e.Type != EventWatching { + t.Fatalf("First event on persistent watcher should always be EventWatching") + } + } + idx++ } - default: - if res.notify { - t.Fatal("expected notification not received") + + conn.notifyWatches(Event{Type: c.eType, Path: c.path}) + + for _, res := range notifications { + select { + case e := <-res.ch.Next(): + isPathCorrect := + (res.wType == watchTypePersistentRecursive && strings.HasPrefix(e.Path, res.path)) || + e.Path == res.path + if !res.notify || !isPathCorrect { + t.Logf("unexpeted notification received by %+v: %+v", res, e) + t.Fail() + } + default: + if res.notify { + t.Logf("expected notification not received for %+v", res) + t.Fail() + } + } } - } + }) } }) } diff --git a/constants.go b/constants.go index 84455d2b..0add0b7e 100644 --- a/constants.go +++ b/constants.go @@ -26,12 +26,18 @@ const ( opGetChildren2 = 12 opCheck = 13 opMulti = 14 + opCreate2 = 15 opReconfig = 16 + opCheckWatches = 17 + opRemoveWatches = 18 opCreateContainer = 19 opCreateTTL = 21 + opMultiRead = 22 opClose = -11 opSetAuth = 100 opSetWatches = 101 + opSetWatches2 = 105 + opAddWatch = 106 opError = -1 // Not in protocol, used internally opWatcherEvent = -2 @@ -47,6 +53,7 @@ const ( // EventSession represents a session event. EventSession EventType = -1 EventNotWatching EventType = -2 + EventWatching EventType = -3 ) var ( @@ -57,6 +64,7 @@ var ( EventNodeChildrenChanged: "EventNodeChildrenChanged", EventSession: "EventSession", EventNotWatching: "EventNotWatching", + EventWatching: "EventWatching", } ) @@ -129,6 +137,8 @@ var ( ErrSessionMoved = errors.New("zk: session moved to another server, so operation is ignored") ErrReconfigDisabled = errors.New("attempts to perform a reconfiguration operation when reconfiguration feature is disabled") ErrBadArguments = errors.New("invalid arguments") + ErrNoWatcher = errors.New("zk: no such watcher") + ErrUnimplemented = errors.New("zk: Not implemented") // ErrInvalidCallback = errors.New("zk: invalid callback specified") errCodeToError = map[ErrCode]error{ @@ -147,8 +157,10 @@ var ( errClosing: ErrClosing, errNothing: ErrNothing, errSessionMoved: ErrSessionMoved, + errNoWatcher: ErrNoWatcher, errZReconfigDisabled: ErrReconfigDisabled, errBadArguments: ErrBadArguments, + errUnimplemented: ErrUnimplemented, } ) @@ -186,6 +198,7 @@ const ( errClosing ErrCode = -116 errNothing ErrCode = -117 errSessionMoved ErrCode = -118 + errNoWatcher ErrCode = -121 // Attempts to perform a reconfiguration operation when reconfiguration feature is disabled errZReconfigDisabled ErrCode = -123 ) @@ -224,6 +237,7 @@ var ( opClose: "close", opSetAuth: "setAuth", opSetWatches: "setWatches", + opAddWatch: "addWatch", opWatcherEvent: "watcherEvent", } @@ -263,3 +277,33 @@ var ( ModeStandalone: "standalone", } ) + +// AddWatchMode asd +type AddWatchMode int32 + +func (m AddWatchMode) String() string { + if name, ok := addWatchModeNames[m]; ok { + return name + } + return "unknown" +} + +const ( + AddWatchModePersistent AddWatchMode = iota + AddWatchModePersistentRecursive AddWatchMode = iota +) + +var ( + addWatchModeNames = map[AddWatchMode]string{ + AddWatchModePersistent: "persistent", + AddWatchModePersistentRecursive: "persistentRecursive", + } +) + +type WatcherType int32 + +const ( + WatcherTypeChildren = WatcherType(1) + WatcherTypeData = WatcherType(2) + WatcherTypeAny = WatcherType(3) +) diff --git a/go.mod b/go.mod index a2662730..f47e0825 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ -module github.com/go-zookeeper/zk +module 
github.com/PapaCharlie/zk -go 1.13 +go 1.18 diff --git a/server_help_test.go b/server_help_test.go index deb780bd..4ea1e7d0 100644 --- a/server_help_test.go +++ b/server_help_test.go @@ -7,6 +7,7 @@ import ( "math/rand" "os" "path/filepath" + "runtime/debug" "strconv" "strings" "testing" @@ -35,6 +36,38 @@ type TestCluster struct { Servers []TestServer } +func WithTestCluster(t *testing.T, testTimeout time.Duration, f func(ts *TestCluster, zk *Conn)) { + ts, err := StartTestCluster(t, 1, nil, logWriter{t: t, p: "[ZKERR] "}) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + ts.Stop() + }) + zk, _, err := ts.ConnectAll() + if err != nil { + t.Fatalf("Connect returned error: %+v", err) + } + t.Cleanup(func() { + zk.Close() + }) + doneChan := make(chan struct{}) + go func() { + defer func() { + close(doneChan) + if r := recover(); r != nil { + t.Error(r, string(debug.Stack())) + } + }() + f(ts, zk) + }() + select { + case <-doneChan: + case <-time.After(testTimeout): + t.Fatalf("Test did not complete within timeout") + } +} + // TODO: pull this into its own package to allow for better isolation of integration tests vs. unit // testing. This should be used on CI systems and local only when needed whereas unit tests should remain // fast and not rely on external dependencies. diff --git a/structs.go b/structs.go index 8eb41e39..66c10463 100644 --- a/structs.go +++ b/structs.go @@ -3,6 +3,7 @@ package zk import ( "encoding/binary" "errors" + "fmt" "log" "reflect" "runtime" @@ -18,8 +19,8 @@ var ( type defaultLogger struct{} -func (defaultLogger) Printf(format string, a ...interface{}) { - log.Printf(format, a...) +func (defaultLogger) Printf(format string, v ...interface{}) { + log.Output(3, fmt.Sprintf(format, v...)) } type ACL struct { @@ -135,7 +136,7 @@ type pathResponse struct { } type statResponse struct { - Stat Stat + Stat *Stat } // @@ -177,6 +178,10 @@ type CreateTTLRequest struct { } type createResponse pathResponse +type create2Response struct { + Path string + Stat *Stat +} type DeleteRequest PathVersionRequest type deleteResponse struct{} @@ -190,7 +195,7 @@ type getAclRequest pathRequest type getAclResponse struct { Acl []ACL - Stat Stat + Stat *Stat } type getChildrenRequest pathRequest @@ -199,18 +204,59 @@ type getChildrenResponse struct { Children []string } +type ReadOp interface { + GetPath() string + IsGetData() bool + IsGetChildren() bool + opCode() int32 +} + type getChildren2Request pathWatchRequest +type GetChildrenOp string + +func (g GetChildrenOp) IsGetData() bool { + return false +} + +func (g GetChildrenOp) IsGetChildren() bool { + return true +} + +func (g GetChildrenOp) GetPath() string { + return string(g) +} + +func (g GetChildrenOp) opCode() int32 { + return opGetChildren +} type getChildren2Response struct { Children []string - Stat Stat + Stat *Stat } type getDataRequest pathWatchRequest +type GetDataOp string + +func (g GetDataOp) IsGetData() bool { + return true +} + +func (g GetDataOp) IsGetChildren() bool { + return false +} + +func (g GetDataOp) GetPath() string { + return string(g) +} + +func (g GetDataOp) opCode() int32 { + return opGetData +} type getDataResponse struct { Data []byte - Stat Stat + Stat *Stat } type getMaxChildrenRequest pathRequest @@ -256,10 +302,12 @@ type setSaslResponse struct { } type setWatchesRequest struct { - RelativeZxid int64 - DataWatches []string - ExistWatches []string - ChildWatches []string + RelativeZxid int64 + DataWatches []string + ExistWatches []string + ChildWatches []string + PersistentWatches []string 
+ PersistentRecursiveWatches []string } type setWatchesResponse struct{} @@ -275,8 +323,7 @@ type multiRequestOp struct { Op interface{} } type multiRequest struct { - Ops []multiRequestOp - DoneHeader multiHeader + Ops []multiRequestOp } type multiResponseOp struct { Header multiHeader @@ -285,8 +332,15 @@ type multiResponseOp struct { Err ErrCode } type multiResponse struct { - Ops []multiResponseOp - DoneHeader multiHeader + Ops []multiResponseOp +} +type MultiReadResponse struct { + getDataResponse + getChildrenResponse + Err error +} +type multiReadResponse struct { + OpResults []MultiReadResponse } // zk version 3.5 reconfig API @@ -301,6 +355,20 @@ type reconfigRequest struct { type reconfigReponse getDataResponse +type addWatchRequest struct { + Path string + Mode AddWatchMode +} + +type addWatchResponse struct{} + +type checkWatchesRequest struct { + Path string + Type WatcherType +} + +type checkWatchesResponse struct{} + func (r *multiRequest) Encode(buf []byte) (int, error) { total := 0 for _, op := range r.Ops { @@ -311,8 +379,7 @@ func (r *multiRequest) Encode(buf []byte) (int, error) { } total += n } - r.DoneHeader.Done = true - n, err := encodePacketValue(buf[total:], reflect.ValueOf(r.DoneHeader)) + n, err := encodePacketValue(buf[total:], reflect.ValueOf(multiHeader{Type: -1, Done: true, Err: -1})) if err != nil { return total, err } @@ -323,7 +390,6 @@ func (r *multiRequest) Encode(buf []byte) (int, error) { func (r *multiRequest) Decode(buf []byte) (int, error) { r.Ops = make([]multiRequestOp, 0) - r.DoneHeader = multiHeader{-1, true, -1} total := 0 for { header := &multiHeader{} @@ -333,7 +399,6 @@ func (r *multiRequest) Decode(buf []byte) (int, error) { } total += n if header.Done { - r.DoneHeader = *header break } @@ -355,7 +420,6 @@ func (r *multiResponse) Decode(buf []byte) (int, error) { var multiErr error r.Ops = make([]multiResponseOp, 0) - r.DoneHeader = multiHeader{-1, true, -1} total := 0 for { header := &multiHeader{} @@ -365,7 +429,6 @@ func (r *multiResponse) Decode(buf []byte) (int, error) { } total += n if header.Done { - r.DoneHeader = *header break } @@ -399,6 +462,48 @@ func (r *multiResponse) Decode(buf []byte) (int, error) { return total, multiErr } +func (r *multiReadResponse) Decode(buf []byte) (total int, multiErr error) { + for { + header := &multiHeader{} + n, err := decodePacketValue(buf[total:], reflect.ValueOf(header)) + if err != nil { + return total, err + } + total += n + if header.Done { + break + } + + var res MultiReadResponse + var errCode ErrCode + var w reflect.Value + switch header.Type { + case opGetData: + w = reflect.ValueOf(&res.getDataResponse) + case opGetChildren: + w = reflect.ValueOf(&res.getChildrenResponse) + case opError: + w = reflect.ValueOf(&errCode) + default: + return total, ErrAPIError + } + + n, err = decodePacketValue(buf[total:], w) + if err != nil { + return total, err + } + total += n + + if errCode != errOk { + res.Err = errCode.toError() + } + + r.OpResults = append(r.OpResults, res) + } + + return total, nil +} + type watcherEvent struct { Type EventType State State @@ -598,7 +703,7 @@ func requestStructForOp(op int32) interface{} { switch op { case opClose: return &closeRequest{} - case opCreate: + case opCreate, opCreate2: return &CreateRequest{} case opCreateContainer: return &CreateContainerRequest{} @@ -622,7 +727,7 @@ func requestStructForOp(op int32) interface{} { return &setAclRequest{} case opSetData: return &SetDataRequest{} - case opSetWatches: + case opSetWatches, opSetWatches2: return 
&setWatchesRequest{} case opSync: return &syncRequest{} @@ -634,6 +739,8 @@ func requestStructForOp(op int32) interface{} { return &multiRequest{} case opReconfig: return &reconfigRequest{} + case opAddWatch: + return &addWatchRequest{} } return nil } diff --git a/structs_test.go b/structs_test.go index 3a38ab45..9c4e9024 100644 --- a/structs_test.go +++ b/structs_test.go @@ -10,7 +10,7 @@ func TestEncodeDecodePacket(t *testing.T) { encodeDecodeTest(t, &requestHeader{-2, 5}) encodeDecodeTest(t, &connectResponse{1, 2, 3, nil}) encodeDecodeTest(t, &connectResponse{1, 2, 3, []byte{4, 5, 6}}) - encodeDecodeTest(t, &getAclResponse{[]ACL{{12, "s", "anyone"}}, Stat{}}) + encodeDecodeTest(t, &getAclResponse{[]ACL{{12, "s", "anyone"}}, &Stat{}}) encodeDecodeTest(t, &getChildrenResponse{[]string{"foo", "bar"}}) encodeDecodeTest(t, &pathWatchRequest{"path", true}) encodeDecodeTest(t, &pathWatchRequest{"path", false}) diff --git a/unlimited_channel.go b/unlimited_channel.go new file mode 100644 index 00000000..4bb91427 --- /dev/null +++ b/unlimited_channel.go @@ -0,0 +1,104 @@ +//go:build go1.18 +// +build go1.18 + +package zk + +import ( + "sync" +) + +type EventQueue interface { + Next() <-chan Event + push(e Event) + close() +} + +type chanEventQueue chan Event + +func (c chanEventQueue) Next() <-chan Event { + return c +} + +func (c chanEventQueue) push(e Event) { + c <- e +} + +func (c chanEventQueue) close() { + close(c) +} + +func newChanEventChannel() chanEventQueue { + return make(chan Event, 1) +} + +func newUnlimitedChannelNode() *unlimitedEventQueueNode { + return &unlimitedEventQueueNode{event: make(chan Event, 1)} +} + +type unlimitedEventQueueNode struct { + event chan Event + next *unlimitedEventQueueNode +} + +type unlimitedEventQueue struct { + lock sync.Mutex + head *unlimitedEventQueueNode + tail *unlimitedEventQueueNode +} + +// newUnlimitedEventQueue uses a backing unlimitedEventQueue used to effectively turn a buffered channel into a channel +// with an infinite buffer by storing all incoming elements into a singly-linked queue and popping them as they are +// read. 
+func newUnlimitedEventQueue() *unlimitedEventQueue { + head := newUnlimitedChannelNode() + q := &unlimitedEventQueue{ + head: head, + tail: head, + } + return q +} + +func (q *unlimitedEventQueue) push(e Event) { + q.lock.Lock() + defer q.lock.Unlock() + + if q.tail == nil { + // Panic like a closed channel + panic("send on closed unlimited channel") + } + + next := newUnlimitedChannelNode() + tail := q.tail + tail.next = next + q.tail = next + tail.event <- e +} + +func (q *unlimitedEventQueue) close() { + q.lock.Lock() + defer q.lock.Unlock() + close(q.tail.event) + q.tail = nil +} + +var closedChannel = func() <-chan Event { + ch := make(chan Event) + close(ch) + return ch +}() + +func (q *unlimitedEventQueue) Next() <-chan Event { + q.lock.Lock() + defer q.lock.Unlock() + + if q.head == nil && q.tail == nil { + return closedChannel + } + + node := q.head + if node.next == nil { + node.next = newUnlimitedChannelNode() + } + q.head = q.head.next + return node.event +} diff --git a/unlimited_channel_test.go b/unlimited_channel_test.go new file mode 100644 index 00000000..aad7f9cc --- /dev/null +++ b/unlimited_channel_test.go @@ -0,0 +1,63 @@ +//go:build go1.18 + +package zk + +import ( + "fmt" + "reflect" + "testing" + "time" +) + +func newEvent(i int) Event { + return Event{Path: fmt.Sprintf("/%d", i)} +} + +func TestUnlimitedChannel(t *testing.T) { + names := []string{"notClosedAfterPushes", "closeAfterPushes"} + for i, closeAfterPushes := range []bool{false, true} { + t.Run(names[i], func(t *testing.T) { + ch := newUnlimitedEventQueue() + const eventCount = 2 + + // check that events can be pushed without consumers + for i := 0; i < eventCount; i++ { + ch.push(newEvent(i)) + } + if closeAfterPushes { + ch.close() + } + + events := 0 + for { + actual, ok := <-ch.Next() + expected := newEvent(events) + if !reflect.DeepEqual(actual, expected) { + t.Fatalf("Did not receive expected event from queue (ok: %+v): actual %+v expected %+v", + ok, actual, expected) + } + events++ + if events == eventCount { + if closeAfterPushes { + select { + case _, ok := <-ch.Next(): + if ok { + t.Fatal("Next did not return closed channel") + } + case <-time.After(time.Second): + t.Fatal("Next never closed") + } + } else { + select { + case e, ok := <-ch.Next(): + t.Fatalf("Next received unexpected value (%+v) or was closed (%+v)", e, ok) + case <-time.After(time.Millisecond * 10): + return + } + } + break + } + } + }) + } +} diff --git a/watch.go b/watch.go new file mode 100644 index 00000000..b44c6ce4 --- /dev/null +++ b/watch.go @@ -0,0 +1 @@ +package zk diff --git a/watch_test.go b/watch_test.go new file mode 100644 index 00000000..b44c6ce4 --- /dev/null +++ b/watch_test.go @@ -0,0 +1 @@ +package zk diff --git a/zk_test.go b/zk_test.go index 603a91c5..fdeacbc1 100644 --- a/zk_test.go +++ b/zk_test.go @@ -1,6 +1,7 @@ package zk import ( + "bytes" "context" "fmt" "io" @@ -388,6 +389,134 @@ func TestMulti(t *testing.T) { } } +func TestMultiRead(t *testing.T) { + RequireMinimumZkVersion(t, "3.6") + WithTestCluster(t, 10*time.Second, func(ts *TestCluster, zk *Conn) { + nodeChildren := map[string][]string{} + nodeData := map[string][]byte{} + var ops []ReadOp + + create := func(path string, data []byte) { + if _, err := zk.Create(path, data, 0, nil); err != nil { + requireNoErrorf(t, err, "create returned an error") + } else { + dir, name := SplitPath(path) + nodeChildren[dir] = append(nodeChildren[dir], name) + nodeData[path] = data + ops = append(ops, GetDataOp(path), GetChildrenOp(path)) + } + } + + root := 
"/gozk-test" + create(root, nil) + + for i := byte(0); i < 10; i++ { + child := JoinPath(root, fmt.Sprint(i)) + create(child, []byte{i}) + } + + const foo = "foo" + create(JoinPath(JoinPath(root, "0"), foo), []byte(foo)) + + opResults, err := zk.MultiRead(ops...) + if err != nil { + t.Fatalf("MultiRead returned error: %+v", err) + } else if len(opResults) != len(ops) { + t.Fatalf("Expected %d responses got %d", len(ops), len(opResults)) + } + + nodeStats := map[string]*Stat{} + for k := range nodeData { + _, nodeStats[k], err = zk.Exists(k) + requireNoErrorf(t, err, "exists returned an error") + } + + for i, res := range opResults { + opPath := ops[i].GetPath() + switch op := ops[i].(type) { + case GetDataOp: + if res.Err != nil { + t.Fatalf("GetDataOp(%q) returned an error: %+v", op, res.Err) + } + if !bytes.Equal(res.Data, nodeData[opPath]) { + t.Fatalf("GetDataOp(%q).Data did not return %+v, got %+v", op, nodeData[opPath], res.Data) + } + if !reflect.DeepEqual(res.Stat, nodeStats[opPath]) { + t.Fatalf("GetDataOp(%q).Stat did not return %+v, got %+v", op, nodeStats[opPath], res.Stat) + } + case GetChildrenOp: + if res.Err != nil { + t.Fatalf("GetChildrenOp(%q) returned an error: %+v", opPath, res.Err) + } + // Cannot use DeepEqual here because it fails for []string{} == nil, even though in practice they are + // the same. + actual, expected := res.Children, nodeChildren[opPath] + if len(actual) != len(expected) { + t.Fatalf("GetChildrenOp(%q) did not return %+v, got %+v", opPath, expected, actual) + } + sort.Strings(actual) + sort.Strings(expected) + for i, c := range expected { + if actual[i] != c { + t.Fatalf("GetChildrenOp(%q) did not return %+v, got %+v", opPath, expected, actual) + } + } + } + } + + opResults, err = zk.MultiRead(GetDataOp("/invalid"), GetDataOp(root)) + requireNoErrorf(t, err, "MultiRead returned error") + + if opResults[0].Err != ErrNoNode { + t.Fatalf("MultiRead on invalid node did not return error") + } + if opResults[1].Err != nil { + t.Fatalf("MultiRead on valid node did not return error") + } + if !reflect.DeepEqual(opResults[1].Data, nodeData[root]) { + t.Fatalf("MultiRead on valid node did not return correct data") + } + }) +} + +func TestGetDataAndChildren(t *testing.T) { + RequireMinimumZkVersion(t, "3.6") + WithTestCluster(t, 10*time.Second, func(ts *TestCluster, zk *Conn) { + + const path = "/test" + _, _, _, err := zk.GetDataAndChildren(path) + if err != ErrNoNode { + t.Fatalf("GetDataAndChildren(%q) did not return an error", path) + } + + create := func(path string, data []byte) { + if _, err := zk.Create(path, data, 0, nil); err != nil { + requireNoErrorf(t, err, "create returned an error") + } + } + expectedData := []byte{1, 2, 3, 4} + create(path, expectedData) + var expectedChildren []string + for i := 0; i < 10; i++ { + child := fmt.Sprint(i) + create(JoinPath(path, child), nil) + expectedChildren = append(expectedChildren, child) + } + + data, _, children, err := zk.GetDataAndChildren(path) + requireNoErrorf(t, err, "GetDataAndChildren return an error") + + if !bytes.Equal(data, expectedData) { + t.Fatalf("GetDataAndChildren(%q) did not return expected data (expected %v): %v", path, expectedData, data) + } + sort.Strings(children) + if !reflect.DeepEqual(children, expectedChildren) { + t.Fatalf("GetDataAndChildren(%q) did not return expected children (expected %v): %v", + path, expectedChildren, children) + } + }) +} + func TestIfAuthdataSurvivesReconnect(t *testing.T) { // This test case ensures authentication data is being resubmited after // 
reconnect. @@ -440,6 +569,121 @@ func TestIfAuthdataSurvivesReconnect(t *testing.T) { } } +func TestPersistentWatchOnReconnect(t *testing.T) { + RequireMinimumZkVersion(t, "3.6") + WithTestCluster(t, 10*time.Second, func(ts *TestCluster, zk *Conn) { + zk.reconnectLatch = make(chan struct{}) + + zk2, _, err := ts.ConnectAll() + if err != nil { + t.Fatalf("Connect returned error: %+v", err) + } + defer zk2.Close() + + const testNode = "/gozk-test" + + if err := zk.Delete(testNode, -1); err != nil && err != ErrNoNode { + t.Fatalf("Delete returned error: %+v", err) + } + + watchEventsCh, err := zk.AddPersistentWatch(testNode, AddWatchModePersistent) + if err != nil { + t.Fatalf("AddPersistentWatch returned error: %+v", err) + } + + _, err = zk2.Create(testNode, []byte{1}, 0, WorldACL(PermAll)) + if err != nil { + t.Fatalf("Create returned an error: %+v", err) + } + + // check to see that we received the node creation event + select { + case ev := <-watchEventsCh.Next(): + if ev.Type != EventNodeCreated { + t.Fatalf("Second event on persistent watch was not a node creation event: %+v", ev) + } + case <-time.After(100 * time.Millisecond): + t.Fatalf("Persistent watcher for %q did not receive node creation event", testNode) + } + + // Simulate network error by brutally closing the network connection. + zk.conn.Close() + + _, err = zk2.Set(testNode, []byte{2}, -1) + if err != nil { + t.Fatalf("Set returned error: %+v", err) + } + + // zk should still be waiting to reconnect, so none of the watches should have been triggered + select { + case <-watchEventsCh.Next(): + t.Fatalf("Persistent watcher for %q should not have triggered yet", testNode) + case <-time.After(100 * time.Millisecond): + } + + // now we let the reconnect occur and make sure it resets watches + close(zk.reconnectLatch) + + // wait for reconnect event + select { + case ev := <-watchEventsCh.Next(): + if ev.Type != EventWatching { + t.Fatalf("Persistent watcher did not receive reconnect event: %+v", ev) + } + case <-time.After(5 * time.Second): + t.Fatalf("Persistent watcher for %q did not receive connection event", testNode) + } + + eventsReceived := 0 + timeout := time.After(2 * time.Second) + secondTimeout := time.After(4 * time.Second) + for { + select { + case e := <-watchEventsCh.Next(): + if e.Type != EventNodeDataChanged { + t.Fatalf("Unexpected event received by persistent watcher: %+v", e) + } + eventsReceived++ + case <-timeout: + _, err = zk2.Set(testNode, []byte{3}, -1) + if err != nil { + t.Fatalf("Set returned error: %+v", err) + } + case <-secondTimeout: + switch eventsReceived { + case 2: + t.Fatalf("Sanity check failed: the persistent watch logic is based around the assumption that the " + + "setWatchers call _does not_ bootstrap you on reconnect based on the relative Zxid (unlike " + + "standard watches).") + case 1: + return + default: + t.Fatalf("Received no events after reconnect") + } + } + } + }) +} + +func TestPersistentWatchOnClose(t *testing.T) { + RequireMinimumZkVersion(t, "3.6") + WithTestCluster(t, 10*time.Second, func(_ *TestCluster, zk *Conn) { + ch, err := zk.AddPersistentWatch("/", AddWatchModePersistent) + if err != nil { + t.Fatalf("Could not add persistent watch: %+v", err) + } + zk.Close() + select { + case e := <-ch.Next(): + if e.Type != EventNotWatching { + t.Fatalf("Unexpected event: %+v", e) + } + case <-time.After(2 * time.Second): + t.Fatalf("Did not get disconnect event") + } + }) +} + func TestMultiFailures(t *testing.T) { // This test case ensures that we return the errors associated 
with each // opeThis in the event a call to Multi() fails. From ab6c1f2b07a9c515eaf113c634e5cb116d3163bf Mon Sep 17 00:00:00 2001 From: PapaCharlie Date: Tue, 16 May 2023 15:24:04 -0400 Subject: [PATCH 3/4] Add metrics receiver and fix persistent watch EventQueue --- conn.go | 17 +++++- conn_test.go | 11 ++-- metrics.go | 20 +++++++ unlimited_channel.go | 107 ++++++++++++++++++++------------------ unlimited_channel_test.go | 83 ++++++++++++++++++++--------- zk_test.go | 38 +++++++++----- 6 files changed, 181 insertions(+), 95 deletions(-) create mode 100644 metrics.go diff --git a/conn.go b/conn.go index b58090f9..cb8067ed 100644 --- a/conn.go +++ b/conn.go @@ -95,6 +95,7 @@ type Conn struct { recvTimeout time.Duration connectTimeout time.Duration maxBufferSize int + metricReceiver MetricReceiver creds []authCreds credsMu sync.Mutex // protects server @@ -211,6 +212,7 @@ func Connect(servers []string, sessionTimeout time.Duration, options ...connOpti logInfo: true, // default is true for backwards compatability buf: make([]byte, bufferSize), resendZkAuthFn: resendZkAuth, + metricReceiver: UnimplementedMetricReceiver{}, } // Set provided options. @@ -316,6 +318,12 @@ func WithMaxConnBufferSize(maxBufferSize int) connOption { } } +func WithMetricReceiver(mr MetricReceiver) connOption { + return func(c *Conn) { + c.metricReceiver = mr + } +} + // Close will submit a close request with ZK and signal the connection to stop // sending and receiving packets. func (c *Conn) Close() { @@ -559,8 +567,8 @@ func (c *Conn) notifyWatches(ev Event) { ch.push(ev) if !wpt.wType.isPersistent() { ch.close() + delete(c.watchers, wpt) } - delete(c.watchers, wpt) } } @@ -841,6 +849,7 @@ func (c *Conn) sendLoop() error { c.conn.Close() return err } + c.metricReceiver.PingSent() case <-c.closeChan: return nil } @@ -899,6 +908,7 @@ func (c *Conn) recvLoop(conn net.Conn) error { c.notifyWatches(ev) } else if res.Xid == -2 { // Ping response. Ignore. 
+ c.metricReceiver.PongReceived() } else if res.Xid < 0 { c.logger.Printf("Xid < 0 (%d) but not ping or watcher event", res.Xid) } else { @@ -987,6 +997,11 @@ func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recv } func (c *Conn) request(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) (int64, error) { + start := time.Now() + defer func() { + c.metricReceiver.RequestCompleted(time.Now().Sub(start)) + }() + recv := c.queueRequest(opcode, req, res, recvFunc) select { case r := <-recv: diff --git a/conn_test.go b/conn_test.go index 1fb6787a..a6cc998e 100644 --- a/conn_test.go +++ b/conn_test.go @@ -221,7 +221,7 @@ func TestNotifyWatches(t *testing.T) { conn.addWatcher(wpt.path, wpt.wType, ch) notifications[idx].ch = ch if wpt.wType.isPersistent() { - e := <-ch.Next() + e, _ := ch.Next(context.Background()) if e.Type != EventWatching { t.Fatalf("First event on persistent watcher should always be EventWatching") } @@ -232,8 +232,11 @@ func TestNotifyWatches(t *testing.T) { conn.notifyWatches(Event{Type: c.eType, Path: c.path}) for _, res := range notifications { - select { - case e := <-res.ch.Next(): + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + t.Cleanup(cancel) + + e, err := res.ch.Next(ctx) + if err == nil { isPathCorrect := (res.wType == watchTypePersistentRecursive && strings.HasPrefix(e.Path, res.path)) || e.Path == res.path @@ -241,7 +244,7 @@ func TestNotifyWatches(t *testing.T) { t.Logf("unexpeted notification received by %+v: %+v", res, e) t.Fail() } - default: + } else { if res.notify { t.Logf("expected notification not received for %+v", res) t.Fail() diff --git a/metrics.go b/metrics.go new file mode 100644 index 00000000..ca7886d6 --- /dev/null +++ b/metrics.go @@ -0,0 +1,20 @@ +package zk + +import ( + "time" +) + +type MetricReceiver interface { + PingSent() + PongReceived() + RequestCompleted(duration time.Duration) +} + +var _ MetricReceiver = UnimplementedMetricReceiver{} + +type UnimplementedMetricReceiver struct { +} + +func (u UnimplementedMetricReceiver) PingSent() {} +func (u UnimplementedMetricReceiver) PongReceived() {} +func (u UnimplementedMetricReceiver) RequestCompleted(time.Duration) {} diff --git a/unlimited_channel.go b/unlimited_channel.go index 4bb91427..d3307d89 100644 --- a/unlimited_channel.go +++ b/unlimited_channel.go @@ -1,22 +1,32 @@ -//go:build go1.18 -// +build go1.18 - package zk import ( + "context" + "errors" "sync" ) +var ErrEventQueueClosed = errors.New("zk: event queue closed") + type EventQueue interface { - Next() <-chan Event + Next(ctx context.Context) (Event, error) push(e Event) close() } type chanEventQueue chan Event -func (c chanEventQueue) Next() <-chan Event { - return c +func (c chanEventQueue) Next(ctx context.Context) (Event, error) { + select { + case <-ctx.Done(): + return Event{}, ctx.Err() + case e, ok := <-c: + if !ok { + return Event{}, ErrEventQueueClosed + } else { + return e, nil + } + } } func (c chanEventQueue) push(e Event) { @@ -31,74 +41,67 @@ func newChanEventChannel() chanEventQueue { return make(chan Event, 1) } -func newUnlimitedChannelNode() *unlimitedEventQueueNode { - return &unlimitedEventQueueNode{event: make(chan Event, 1)} -} - -type unlimitedEventQueueNode struct { - event chan Event - next *unlimitedEventQueueNode -} - type unlimitedEventQueue struct { - lock sync.Mutex - head *unlimitedEventQueueNode - tail *unlimitedEventQueueNode + lock sync.Mutex + newEvent chan struct{} + events []Event } 
-// newUnlimitedEventQueue uses a backing unlimitedEventQueue used to effectively turn a buffered channel into a channel -// with an infinite buffer by storing all incoming elements into a singly-linked queue and popping them as they are -// read. func newUnlimitedEventQueue() *unlimitedEventQueue { - head := newUnlimitedChannelNode() - q := &unlimitedEventQueue{ - head: head, - tail: head, + return &unlimitedEventQueue{ + newEvent: make(chan struct{}), } - return q } func (q *unlimitedEventQueue) push(e Event) { q.lock.Lock() defer q.lock.Unlock() - if q.tail == nil { + if q.newEvent == nil { // Panic like a closed channel panic("send on closed unlimited channel") } - next := newUnlimitedChannelNode() - tail := q.tail - tail.next = next - q.tail = next - tail.event <- e + q.events = append(q.events, e) + close(q.newEvent) + q.newEvent = make(chan struct{}) } func (q *unlimitedEventQueue) close() { q.lock.Lock() defer q.lock.Unlock() - close(q.tail.event) - q.tail = nil -} -var closedChannel = func() <-chan Event { - ch := make(chan Event) - close(ch) - return ch -}() - -func (q *unlimitedEventQueue) Next() <-chan Event { - q.lock.Lock() - defer q.lock.Unlock() - - if q.head == nil && q.tail == nil { - return closedChannel + if q.newEvent == nil { + // Panic like a closed channel + panic("close of closed EventQueue") } - node := q.head - if node.next == nil { - node.next = newUnlimitedChannelNode() + close(q.newEvent) + q.newEvent = nil +} + +func (q *unlimitedEventQueue) Next(ctx context.Context) (Event, error) { + for { + q.lock.Lock() + if len(q.events) > 0 { + e := q.events[0] + q.events = q.events[1:] + q.lock.Unlock() + return e, nil + } + + ch := q.newEvent + if ch == nil { + q.lock.Unlock() + return Event{}, ErrEventQueueClosed + } + q.lock.Unlock() + + select { + case <-ctx.Done(): + return Event{}, ctx.Err() + case <-ch: + continue + } } - q.head = q.head.next - return node.event } diff --git a/unlimited_channel_test.go b/unlimited_channel_test.go index aad7f9cc..f0b292d3 100644 --- a/unlimited_channel_test.go +++ b/unlimited_channel_test.go @@ -3,6 +3,8 @@ package zk import ( + "context" + "errors" "fmt" "reflect" "testing" @@ -18,7 +20,7 @@ func TestUnlimitedChannel(t *testing.T) { for i, closeAfterPushes := range []bool{false, true} { t.Run(names[i], func(t *testing.T) { ch := newUnlimitedEventQueue() - const eventCount = 2 + const eventCount = 10 // check that events can be pushed without consumers for i := 0; i < eventCount; i++ { @@ -28,36 +30,65 @@ func TestUnlimitedChannel(t *testing.T) { ch.close() } - events := 0 - for { - actual, ok := <-ch.Next() + for events := 0; events < eventCount; events++ { + actual, err := ch.Next(context.Background()) + if err != nil { + t.Fatalf("Unexpected error returned from Next (events %d): %+v", events, err) + } expected := newEvent(events) if !reflect.DeepEqual(actual, expected) { - t.Fatalf("Did not receive expected event from queue (ok: %+v): actual %+v expected %+v", - ok, actual, expected) + t.Fatalf("Did not receive expected event from queue: actual %+v expected %+v", actual, expected) } - events++ - if events == eventCount { - if closeAfterPushes { - select { - case _, ok := <-ch.Next(): - if ok { - t.Fatal("Next did not return closed channel") - } - case <-time.After(time.Second): - t.Fatal("Next never closed") - } - } else { - select { - case e, ok := <-ch.Next(): - t.Fatalf("Next received unexpected value (%+v) or was closed (%+v)", e, ok) - case <-time.After(time.Millisecond * 10): - return - } - } - break + } + + ctx, cancel 
:= context.WithTimeout(context.Background(), time.Millisecond*10) + t.Cleanup(cancel) + + _, err := ch.Next(ctx) + if closeAfterPushes { + if err != ErrEventQueueClosed { + t.Fatalf("Did not receive expected error (%v) from Next: %v", ErrEventQueueClosed, err) + } + } else { + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("Next did not exit with cancelled context: %+v", err) } } }) } + t.Run("interleaving", func(t *testing.T) { + ch := newUnlimitedEventQueue() + + for i := 0; i < 10; i++ { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10) + t.Cleanup(cancel) + + expected := newEvent(i) + + ctx = &customContext{ + Context: ctx, + f: func() { + ch.push(expected) + }, + } + + actual, err := ch.Next(ctx) + if err != nil { + t.Fatalf("Received unexpected error from Next: %+v", err) + } + if !reflect.DeepEqual(expected, actual) { + t.Fatalf("Unexpected event received from Next (expected %+v, actual %+v", expected, actual) + } + } + }) +} + +type customContext struct { + context.Context + f func() +} + +func (c *customContext) Done() <-chan struct{} { + c.f() + return c.Context.Done() } diff --git a/zk_test.go b/zk_test.go index fdeacbc1..94cab58e 100644 --- a/zk_test.go +++ b/zk_test.go @@ -586,11 +586,22 @@ func TestPersistentWatchOnReconnect(t *testing.T) { t.Fatalf("Delete returned error: %+v", err) } - watchEventsCh, err := zk.AddPersistentWatch(testNode, AddWatchModePersistent) + watchEventsQueue, err := zk.AddPersistentWatch(testNode, AddWatchModePersistent) if err != nil { t.Fatalf("AddPersistentWatch returned error: %+v", err) } + watchEventsCh := make(chan Event) + go func() { + e, err := watchEventsQueue.Next(context.Background()) + if err != nil { + close(watchEventsCh) + return + } else { + watchEventsCh <- e + } + }() + _, err = zk2.Create(testNode, []byte{1}, 0, WorldACL(PermAll)) if err != nil { t.Fatalf("Create returned an error: %+v", err) @@ -598,7 +609,7 @@ func TestPersistentWatchOnReconnect(t *testing.T) { // check to see that we received the node creation event select { - case ev := <-watchEventsCh.Next(): + case ev := <-watchEventsCh: if ev.Type != EventNodeCreated { t.Fatalf("Second event on persistent watch was not a node creation event: %+v", ev) } @@ -616,7 +627,7 @@ func TestPersistentWatchOnReconnect(t *testing.T) { // zk should still be waiting to reconnect, so none of the watches should have been triggered select { - case <-watchEventsCh.Next(): + case <-watchEventsCh: t.Fatalf("Persistent watcher for %q should not have triggered yet", testNode) case <-time.After(100 * time.Millisecond): } @@ -626,7 +637,7 @@ func TestPersistentWatchOnReconnect(t *testing.T) { // wait for reconnect event select { - case ev := <-watchEventsCh.Next(): + case ev := <-watchEventsCh: if ev.Type != EventWatching { t.Fatalf("Persistent watcher did not receive reconnect event: %+v", ev) } @@ -639,7 +650,7 @@ func TestPersistentWatchOnReconnect(t *testing.T) { secondTimeout := time.After(4 * time.Second) for { select { - case e := <-watchEventsCh.Next(): + case e := <-watchEventsCh: if e.Type != EventNodeDataChanged { t.Fatalf("Unexpected event received by persistent watcher: %+v", e) } @@ -673,13 +684,16 @@ func TestPersistentWatchOnClose(t *testing.T) { t.Fatalf("Could not add persistent watch: %+v", err) } zk.Close() - select { - case e := <-ch.Next(): - if e.Type != EventNotWatching { - t.Fatalf("Unexpected event: %+v", e) - } - case <-time.After(2 * time.Second): - t.Fatalf("Did not get disconnect event") + + ctx, cancel := 
context.WithTimeout(context.Background(), 2*time.Second) + t.Cleanup(cancel) + + e, err := ch.Next(ctx) + if err != nil { + t.Fatalf("Did not get disconnect event (%+v)", err) + } + if e.Type != EventNotWatching { + t.Fatalf("Unexpected event: %+v", e) } }) } From f77548697bfedea2a4515dbf6b64bb3af0eb6ca3 Mon Sep 17 00:00:00 2001 From: PapaCharlie Date: Wed, 17 May 2023 16:19:20 -0400 Subject: [PATCH 4/4] Add error to RequestCompleted --- conn.go | 4 ++-- metrics.go | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/conn.go b/conn.go index cb8067ed..0550c452 100644 --- a/conn.go +++ b/conn.go @@ -996,10 +996,10 @@ func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recv return rq.recvChan } -func (c *Conn) request(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) (int64, error) { +func (c *Conn) request(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) (_ int64, err error) { start := time.Now() defer func() { - c.metricReceiver.RequestCompleted(time.Now().Sub(start)) + c.metricReceiver.RequestCompleted(time.Now().Sub(start), err) }() recv := c.queueRequest(opcode, req, res, recvFunc) diff --git a/metrics.go b/metrics.go index ca7886d6..9294d3ed 100644 --- a/metrics.go +++ b/metrics.go @@ -7,7 +7,7 @@ import ( type MetricReceiver interface { PingSent() PongReceived() - RequestCompleted(duration time.Duration) + RequestCompleted(duration time.Duration, err error) } var _ MetricReceiver = UnimplementedMetricReceiver{} @@ -15,6 +15,6 @@ var _ MetricReceiver = UnimplementedMetricReceiver{} type UnimplementedMetricReceiver struct { } -func (u UnimplementedMetricReceiver) PingSent() {} -func (u UnimplementedMetricReceiver) PongReceived() {} -func (u UnimplementedMetricReceiver) RequestCompleted(time.Duration) {} +func (u UnimplementedMetricReceiver) PingSent() {} +func (u UnimplementedMetricReceiver) PongReceived() {} +func (u UnimplementedMetricReceiver) RequestCompleted(time.Duration, error) {}
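
As a rough illustration of the metrics hook added in this series, a consumer could implement MetricReceiver and hand it to Connect via WithMetricReceiver. This is only a sketch, not part of the patch: the import path, server address, and counter fields are assumptions; only MetricReceiver, UnimplementedMetricReceiver, WithMetricReceiver, and the three hook methods come from the diffs above.

```go
package main

import (
	"sync/atomic"
	"time"

	"github.com/go-zookeeper/zk" // import path assumed for this sketch
)

// countingReceiver is a hypothetical MetricReceiver that just keeps counters.
// Embedding zk.UnimplementedMetricReceiver instead would let a receiver
// implement only the hooks it cares about.
type countingReceiver struct {
	pingsSent     int64
	pongsReceived int64
	requestErrors int64
}

func (r *countingReceiver) PingSent()     { atomic.AddInt64(&r.pingsSent, 1) }
func (r *countingReceiver) PongReceived() { atomic.AddInt64(&r.pongsReceived, 1) }

func (r *countingReceiver) RequestCompleted(d time.Duration, err error) {
	if err != nil {
		atomic.AddInt64(&r.requestErrors, 1)
	}
	_ = d // a real receiver would record the latency in a histogram or similar
}

func main() {
	recv := &countingReceiver{}
	// Server address and session timeout are placeholders; WithMetricReceiver
	// is the connection option added by this series.
	conn, _, err := zk.Connect([]string{"127.0.0.1:2181"}, 10*time.Second, zk.WithMetricReceiver(recv))
	if err != nil {
		panic(err)
	}
	defer conn.Close()
}
```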
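
Likewise, a sketch of how the context-based EventQueue reworked in this series might be consumed by callers of AddPersistentWatch. The path, timeout, and logging are illustrative, and the return type of AddPersistentWatch is inferred from the tests above; only AddPersistentWatch, AddWatchModePersistent, Next, and ErrEventQueueClosed come from the patches themselves.

```go
package main

import (
	"context"
	"errors"
	"log"
	"time"

	"github.com/go-zookeeper/zk" // import path assumed for this sketch
)

func main() {
	conn, _, err := zk.Connect([]string{"127.0.0.1:2181"}, 10*time.Second)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	// AddPersistentWatch now hands back an EventQueue rather than a raw channel.
	queue, err := conn.AddPersistentWatch("/gozk-test", zk.AddWatchModePersistent)
	if err != nil {
		log.Fatal(err)
	}

	for {
		// Next blocks until an event is available, the context is done, or the
		// queue is closed (for example when the connection is closed).
		e, err := queue.Next(context.Background())
		if errors.Is(err, zk.ErrEventQueueClosed) {
			return
		} else if err != nil {
			log.Fatal(err)
		}
		log.Printf("watch event: %+v", e)
	}
}
```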