From 849f26e628efab3781e38a081f56442cbe1d18fe Mon Sep 17 00:00:00 2001 From: haoqixu Date: Tue, 21 Nov 2023 23:24:30 +0800 Subject: [PATCH] add testcase from #205 --- client/wsclient.go | 11 ++- client/wsclient_test.go | 207 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 215 insertions(+), 3 deletions(-) diff --git a/client/wsclient.go b/client/wsclient.go index 121354ea..eb6e5887 100644 --- a/client/wsclient.go +++ b/client/wsclient.go @@ -39,6 +39,10 @@ type wsClient struct { // The sender is responsible for sending portion of the OpAMP protocol. sender *internal.WSSender + + // Network connection timeout used for the WebSocket closing handshake. + // This field is currently only modified during testing. + connShutdownTimeout time.Duration } // NewWebSocket creates a new OpAMP Client that uses WebSocket transport. @@ -49,8 +53,9 @@ func NewWebSocket(logger types.Logger) *wsClient { sender := internal.NewSender(logger) w := &wsClient{ - common: internal.NewClientCommon(logger, sender), - sender: sender, + common: internal.NewClientCommon(logger, sender), + sender: sender, + connShutdownTimeout: defaultShutdownTimeout, } return w } @@ -259,7 +264,7 @@ func (c *wsClient) runOneCycle(ctx context.Context) { select { case <-r.IsStopped(): c.common.Logger.Debugf("shutdown handshake complete.") - case <-time.After(defaultShutdownTimeout): + case <-time.After(c.connShutdownTimeout): c.common.Logger.Debugf("timeout waiting for close message.") // not receive close message from the server, close the connection to force the receive loop to stop _ = c.conn.Close() diff --git a/client/wsclient_test.go b/client/wsclient_test.go index 7696e9e6..15540871 100644 --- a/client/wsclient_test.go +++ b/client/wsclient_test.go @@ -6,9 +6,11 @@ import ( "strings" "sync/atomic" "testing" + "time" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" "github.com/open-telemetry/opamp-go/client/internal" @@ -177,3 +179,208 @@ func TestVerifyWSCompress(t *testing.T) { }) } } + +func TestHandlesStopBeforeStart(t *testing.T) { + client := NewWebSocket(nil) + require.Error(t, client.Stop(context.Background())) +} + +func TestPerformsClosingHandshake(t *testing.T) { + srv := internal.StartMockServer(t) + var wsConn *websocket.Conn + connected := make(chan struct{}) + closed := make(chan struct{}) + + srv.OnWSConnect = func(conn *websocket.Conn) { + wsConn = conn + connected <- struct{}{} + } + + client := NewWebSocket(nil) + startClient(t, types.StartSettings{ + OpAMPServerURL: srv.GetHTTPTestServer().URL, + }, client) + + select { + case <-connected: + case <-time.After(2 * time.Second): + require.Fail(t, "Connection never established") + } + + require.Eventually(t, func() bool { + client.connMutex.RLock() + conn := client.conn + client.connMutex.RUnlock() + return conn != nil + }, 2*time.Second, 250*time.Millisecond) + + defHandler := wsConn.CloseHandler() + + wsConn.SetCloseHandler(func(code int, _ string) error { + require.Equal(t, websocket.CloseNormalClosure, code, "Client sent non-normal closing code") + + err := defHandler(code, "") + closed <- struct{}{} + return err + }) + + client.Stop(context.Background()) + + select { + case <-closed: + case <-time.After(2 * time.Second): + require.Fail(t, "Connection never closed") + } +} + +func TestHandlesSlowCloseMessageFromServer(t *testing.T) { + srv := internal.StartMockServer(t) + var wsConn *websocket.Conn + connected := make(chan struct{}) + closed := make(chan struct{}) + + srv.OnWSConnect = func(conn *websocket.Conn) { + wsConn = conn + connected <- struct{}{} + } + + client := NewWebSocket(nil) + client.connShutdownTimeout = 100 * time.Millisecond + startClient(t, types.StartSettings{ + OpAMPServerURL: srv.GetHTTPTestServer().URL, + }, client) + + select { + case <-connected: + case <-time.After(2 * time.Second): + require.Fail(t, "Connection never established") + } + + require.Eventually(t, func() bool { + client.connMutex.RLock() + conn := client.conn + client.connMutex.RUnlock() + return conn != nil + }, 2*time.Second, 250*time.Millisecond) + + defHandler := wsConn.CloseHandler() + + wsConn.SetCloseHandler(func(code int, _ string) error { + require.Equal(t, websocket.CloseNormalClosure, code, "Client sent non-normal closing code") + + time.Sleep(200 * time.Millisecond) + err := defHandler(code, "") + closed <- struct{}{} + return err + }) + + client.Stop(context.Background()) + + select { + case <-closed: + case <-time.After(1 * time.Second): + require.Fail(t, "Connection never closed") + } +} + +func TestHandlesNoCloseMessageFromServer(t *testing.T) { + srv := internal.StartMockServer(t) + var wsConn *websocket.Conn + connected := make(chan struct{}) + closed := make(chan struct{}) + + srv.OnWSConnect = func(conn *websocket.Conn) { + wsConn = conn + connected <- struct{}{} + } + + client := NewWebSocket(nil) + client.connShutdownTimeout = 100 * time.Millisecond + startClient(t, types.StartSettings{ + OpAMPServerURL: srv.GetHTTPTestServer().URL, + }, client) + + select { + case <-connected: + case <-time.After(2 * time.Second): + require.Fail(t, "Connection never established") + } + + require.Eventually(t, func() bool { + client.connMutex.RLock() + conn := client.conn + client.connMutex.RUnlock() + return conn != nil + }, 2*time.Second, 250*time.Millisecond) + + wsConn.SetCloseHandler(func(code int, _ string) error { + // Don't send close message + return nil + }) + + go func() { + client.Stop(context.Background()) + closed <- struct{}{} + }() + + select { + case <-closed: + case <-time.After(1 * time.Second): + require.Fail(t, "Connection never closed") + } +} + +func TestHandlesConnectionError(t *testing.T) { + srv := internal.StartMockServer(t) + var wsConn *websocket.Conn + connected := make(chan struct{}) + + srv.OnWSConnect = func(conn *websocket.Conn) { + wsConn = conn + connected <- struct{}{} + } + + client := NewWebSocket(nil) + startClient(t, types.StartSettings{ + OpAMPServerURL: srv.GetHTTPTestServer().URL, + }, client) + + select { + case <-connected: + case <-time.After(2 * time.Second): + require.Fail(t, "Connection never established") + } + + require.Eventually(t, func() bool { + client.connMutex.RLock() + conn := client.conn + client.connMutex.RUnlock() + return conn != nil + }, 2*time.Second, 250*time.Millisecond) + + // Write an invalid message to the connection. The client + // will take this as an error and reconnect to the server. + writer, err := wsConn.NextWriter(websocket.BinaryMessage) + require.NoError(t, err) + n, err := writer.Write([]byte{99, 1, 2, 3, 4, 5}) + require.NoError(t, err) + require.Equal(t, 6, n) + err = writer.Close() + require.NoError(t, err) + + select { + case <-connected: + case <-time.After(2 * time.Second): + require.Fail(t, "Connection never re-established") + } + + require.Eventually(t, func() bool { + client.connMutex.RLock() + conn := client.conn + client.connMutex.RUnlock() + return conn != nil + }, 2*time.Second, 250*time.Millisecond) + + err = client.Stop(context.Background()) + require.NoError(t, err) +}