diff --git a/conn.go b/conn.go index 3ee5a20b..9afd2d27 100644 --- a/conn.go +++ b/conn.go @@ -929,8 +929,10 @@ func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recv } func (c *Conn) request(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) (int64, error) { - r := <-c.queueRequest(opcode, req, res, recvFunc) + recv := c.queueRequest(opcode, req, res, recvFunc) select { + case r := <-recv: + return r.zxid, r.err case <-c.shouldQuit: // queueRequest() can be racy, double-check for the race here and avoid // a potential data-race. otherwise the client of this func may try to @@ -938,8 +940,6 @@ func (c *Conn) request(opcode int32, req interface{}, res interface{}, recvFunc // NOTE: callers of this func should check for (at least) ErrConnectionClosed // and avoid accessing fields of the response object if such error is present. return -1, ErrConnectionClosed - default: - return r.zxid, r.err } } diff --git a/conn_test.go b/conn_test.go index 630693a6..96299280 100644 --- a/conn_test.go +++ b/conn_test.go @@ -57,6 +57,36 @@ func TestRecurringReAuthHang(t *testing.T) { } } +func TestConcurrentReadAndClose(t *testing.T) { + WithListenServer(t, func(server string) { + conn, _, err := Connect([]string{server}, 15*time.Second) + if err != nil { + t.Fatalf("Failed to create Connection %s", err) + } + + okChan := make(chan struct{}) + var setErr error + go func() { + _, setErr = conn.Create("/test-path", []byte("test data"), 0, WorldACL(PermAll)) + close(okChan) + }() + + go func() { + time.Sleep(1 * time.Second) + conn.Close() + }() + + select { + case <-okChan: + if setErr != ErrConnectionClosed { + t.Fatalf("unexpected error returned from Set %v", setErr) + } + case <-time.After(3 * time.Second): + t.Fatal("apparent deadlock!") + } + }) +} + func TestDeadlockInClose(t *testing.T) { c := &Conn{ shouldQuit: make(chan struct{}), diff --git a/tcp_server_test.go b/tcp_server_test.go new file mode 100644 index 00000000..09254948 --- /dev/null +++ b/tcp_server_test.go @@ -0,0 +1,36 @@ +package zk + +import ( + "fmt" + "math/rand" + "net" + "testing" + "time" +) + +func WithListenServer(t *testing.T, test func(server string)) { + startPort := int(rand.Int31n(6000) + 10000) + server := fmt.Sprintf("localhost:%d", startPort) + l, err := net.Listen("tcp", server) + if err != nil { + t.Fatalf("Failed to start listen server: %v", err) + } + defer l.Close() + + go func() { + conn, err := l.Accept() + if err != nil { + t.Logf("Failed to accept connection: %s", err.Error()) + } + + handleRequest(conn) + }() + + test(server) +} + +// Handles incoming requests. +func handleRequest(conn net.Conn) { + time.Sleep(5 * time.Second) + conn.Close() +}