diff --git a/client/wsclient.go b/client/wsclient.go index d07f7bda..7ace7697 100644 --- a/client/wsclient.go +++ b/client/wsclient.go @@ -3,6 +3,7 @@ package client import ( "context" "errors" + "fmt" "net/http" "net/url" "sync" @@ -17,6 +18,9 @@ import ( "github.com/open-telemetry/opamp-go/protobufs" ) +var errStopping = errors.New("client is stopping or stopped, no more messages can be sent") +var errEarlyStop = errors.New("context canceled before shutdown could complete") + // wsClient is an OpAMP Client implementation for WebSocket transport. // See specification: https://github.com/open-telemetry/opamp-spec/blob/main/specification.md#websocket-transport type wsClient struct { @@ -35,8 +39,29 @@ type wsClient struct { // The sender is responsible for sending portion of the OpAMP protocol. sender *internal.WSSender + + // Indicates whether the client is open for more messages to be sent. + // Should be protected by connectionOpenMutex. + connectionOpen bool + // Indicates the connection is being written to. + // A read lock on this mutex indicates that a message is being queued for writing. + // A write lock on this mutex indicates that the connection is being shut down. + connectionOpenMutex sync.RWMutex + + // Sends a signal to the background processors controller thread to stop + // all background processors. + stopBGProcessing chan struct{} + // Responds to a signal from stopBGProcessing indicating that all processors + // have been stopped. + bgProcessingStopped chan struct{} + + // Network connection timeout used for the WebSocket closing handshake. + // This field is currently only modified during testing. + connShutdownTimeout time.Duration } +var _ OpAMPClient = &wsClient{} + // NewWebSocket creates a new OpAMP Client that uses WebSocket transport. func NewWebSocket(logger types.Logger) *wsClient { if logger == nil { @@ -45,8 +70,12 @@ func NewWebSocket(logger types.Logger) *wsClient { sender := internal.NewSender(logger) w := &wsClient{ - common: internal.NewClientCommon(logger, sender), - sender: sender, + common: internal.NewClientCommon(logger, sender), + sender: sender, + connectionOpen: true, + stopBGProcessing: make(chan struct{}, 1), + bgProcessingStopped: make(chan struct{}, 1), + connShutdownTimeout: 10 * time.Second, } return w } @@ -80,13 +109,63 @@ func (c *wsClient) Start(ctx context.Context, settings types.StartSettings) erro } func (c *wsClient) Stop(ctx context.Context) error { + // Prevent any additional writers from writing to the connection + // and stop reconnecting if the connection closes. + c.connectionOpenMutex.Lock() + c.connectionOpen = false + c.connectionOpenMutex.Unlock() + // Close connection if any. c.connMutex.RLock() conn := c.conn c.connMutex.RUnlock() if conn != nil { - _ = conn.Close() + // Shut down the sender and any other background processors. + c.stopBGProcessing <- struct{}{} + select { + case <-c.bgProcessingStopped: + case <-ctx.Done(): + c.closeConnection() + return errEarlyStop + } + + // At this point all other writers to the connection should be stopped. + // We can write to the connection without any risk of contention. + + defaultCloseHandler := conn.CloseHandler() + closed := make(chan struct{}) + + // The server should respond with a close message of its own, which will + // trigger this callback. At this point the close sequence has been + // completed and the TCP connection can be gracefully closed. 
+ conn.SetCloseHandler(func(code int, text string) error { + err := defaultCloseHandler(code, text) + closed <- struct{}{} + return err + }) + + // Start the closing handshake by writing a close message to the server. + // If the server responds with its own close message, the connection reader will + // shut down and there will be no more reads from or writes to the connection. + message := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "") + err := conn.WriteControl(websocket.CloseMessage, message, time.Now().Add(c.connShutdownTimeout)) + + if err != nil { + c.closeConnection() + return fmt.Errorf("could not write close message to WebSocket, connection closed without performing closing handshake: %w", err) + } + + select { + case <-closed: + // runOneCycle will close the connection if the closing handshake completed, + // so there's no need to close it here. + case <-time.After(c.connShutdownTimeout): + c.closeConnection() + case <-ctx.Done(): + c.closeConnection() + return errEarlyStop + } } return c.common.Stop(ctx) @@ -97,22 +176,47 @@ func (c *wsClient) AgentDescription() *protobufs.AgentDescription { } func (c *wsClient) SetAgentDescription(descr *protobufs.AgentDescription) error { + c.connectionOpenMutex.RLock() + defer c.connectionOpenMutex.RUnlock() + if !c.connectionOpen { + return errStopping + } return c.common.SetAgentDescription(descr) } func (c *wsClient) SetHealth(health *protobufs.AgentHealth) error { + c.connectionOpenMutex.RLock() + defer c.connectionOpenMutex.RUnlock() + if !c.connectionOpen { + return errStopping + } return c.common.SetHealth(health) } func (c *wsClient) UpdateEffectiveConfig(ctx context.Context) error { + c.connectionOpenMutex.RLock() + defer c.connectionOpenMutex.RUnlock() + if !c.connectionOpen { + return errStopping + } return c.common.UpdateEffectiveConfig(ctx) } func (c *wsClient) SetRemoteConfigStatus(status *protobufs.RemoteConfigStatus) error { + c.connectionOpenMutex.RLock() + defer c.connectionOpenMutex.RUnlock() + if !c.connectionOpen { + return errStopping + } return c.common.SetRemoteConfigStatus(status) } func (c *wsClient) SetPackageStatuses(statuses *protobufs.PackageStatuses) error { + c.connectionOpenMutex.RLock() + defer c.connectionOpenMutex.RUnlock() + if !c.connectionOpen { + return errStopping + } return c.common.SetPackageStatuses(statuses) } @@ -192,12 +296,29 @@ func (c *wsClient) ensureConnected(ctx context.Context) error { } } +func (c *wsClient) closeConnection() { + c.connMutex.Lock() + defer c.connMutex.Unlock() + + if c.conn == nil { + return + } + + // Close the connection. + _ = c.conn.Close() + + // Unset the field to indicate that the connection is closed. + c.conn = nil +} + // runOneCycle performs the following actions: -// 1. connect (try until succeeds). -// 2. send first status report. -// 3. receive and process messages until error happens. +// 1. connect (try until succeeds). +// 2. set up a background processor to send messages. +// 3. send first status report. +// 4. receive and process messages until an error occurs or the connection closes. +// // If it encounters an error it closes the connection and returns. -// Will stop and return if Stop() is called (ctx is cancelled, isStopping is set). +// Will stop and return if Stop() is called. func (c *wsClient) runOneCycle(ctx context.Context) { if err := c.ensureConnected(ctx); err != nil { // Can't connect, so can't move forward. 
This currently happens when we @@ -205,8 +326,9 @@ func (c *wsClient) runOneCycle(ctx context.Context) { return } + defer c.closeConnection() + if c.common.IsStopping() { - _ = c.conn.Close() return } @@ -220,12 +342,24 @@ func (c *wsClient) runOneCycle(ctx context.Context) { // Create a cancellable context for background processors. procCtx, procCancel := context.WithCancel(ctx) + // Stop background processors if we receive a signal to do so. + // Note that the receiver does not respond to signals and + // will only stop when the connection closes or errors. + go func() { + select { + case <-c.stopBGProcessing: + procCancel() + c.sender.WaitToStop() + close(c.bgProcessingStopped) + case <-procCtx.Done(): + } + }() + // Connected successfully. Start the sender. This will also send the first // status report. if err := c.sender.Start(procCtx, c.conn); err != nil { c.common.Logger.Errorf("Failed to send first status report: %v", err) // We could not send the report, the only thing we can do is start over. - _ = c.conn.Close() procCancel() return } @@ -242,15 +376,13 @@ func (c *wsClient) runOneCycle(ctx context.Context) { ) r.ReceiverLoop(ctx) + // If we exited receiverLoop it means there is a connection error or the closing handshake + // has completed. We cannot read messages anymore, so clean up the connection. + // If there is a connection error we will need to start over. + // Stop the background processors. procCancel() - // If we exited receiverLoop it means there is a connection error, we cannot - // read messages anymore. We need to start over. - - // Close the connection to unblock the WSSender as well. - _ = c.conn.Close() - // Wait for WSSender to stop. c.sender.WaitToStop() } @@ -258,9 +390,12 @@ func (c *wsClient) runOneCycle(ctx context.Context) { func (c *wsClient) runUntilStopped(ctx context.Context) { // Iterates until we detect that the client is stopping. 
for { - if c.common.IsStopping() { + c.connectionOpenMutex.RLock() + if c.common.IsStopping() || !c.connectionOpen { + c.connectionOpenMutex.RUnlock() return } + c.connectionOpenMutex.RUnlock() c.runOneCycle(ctx) } diff --git a/client/wsclient_test.go b/client/wsclient_test.go index 715b140f..b7bd48da 100644 --- a/client/wsclient_test.go +++ b/client/wsclient_test.go @@ -4,11 +4,14 @@ import ( "context" "fmt" "strings" + "sync" "sync/atomic" "testing" + "time" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" "github.com/open-telemetry/opamp-go/client/internal" @@ -177,3 +180,294 @@ func TestVerifyWSCompress(t *testing.T) { }) } } + +func TestHandlesStopBeforeStart(t *testing.T) { + client := NewWebSocket(nil) + require.Error(t, client.Stop(context.Background())) +} + +func TestPerformsClosingHandshake(t *testing.T) { + srv := internal.StartMockServer(t) + var wsConn *websocket.Conn + connected := make(chan struct{}) + closed := make(chan struct{}) + + srv.OnWSConnect = func(conn *websocket.Conn) { + wsConn = conn + connected <- struct{}{} + } + + client := NewWebSocket(nil) + startClient(t, types.StartSettings{ + OpAMPServerURL: srv.GetHTTPTestServer().URL, + }, client) + + select { + case <-connected: + case <-time.After(2 * time.Second): + require.Fail(t, "Connection never established") + } + + require.Eventually(t, func() bool { + client.connMutex.RLock() + conn := client.conn + client.connMutex.RUnlock() + return conn != nil + }, 2*time.Second, 250*time.Millisecond) + + defHandler := wsConn.CloseHandler() + + wsConn.SetCloseHandler(func(code int, _ string) error { + require.Equal(t, websocket.CloseNormalClosure, code, "Client sent non-normal closing code") + + err := defHandler(code, "") + closed <- struct{}{} + return err + }) + + client.Stop(context.Background()) + + select { + case <-closed: + case <-time.After(2 * time.Second): + require.Fail(t, "Connection never closed") + } +} + +func TestHandlesSlowCloseMessageFromServer(t *testing.T) { + srv := internal.StartMockServer(t) + var wsConn *websocket.Conn + connected := make(chan struct{}) + closed := make(chan struct{}) + + srv.OnWSConnect = func(conn *websocket.Conn) { + wsConn = conn + connected <- struct{}{} + } + + client := NewWebSocket(nil) + client.connShutdownTimeout = 100 * time.Millisecond + startClient(t, types.StartSettings{ + OpAMPServerURL: srv.GetHTTPTestServer().URL, + }, client) + + select { + case <-connected: + case <-time.After(2 * time.Second): + require.Fail(t, "Connection never established") + } + + require.Eventually(t, func() bool { + client.connMutex.RLock() + conn := client.conn + client.connMutex.RUnlock() + return conn != nil + }, 2*time.Second, 250*time.Millisecond) + + defHandler := wsConn.CloseHandler() + + wsConn.SetCloseHandler(func(code int, _ string) error { + require.Equal(t, websocket.CloseNormalClosure, code, "Client sent non-normal closing code") + + time.Sleep(200 * time.Millisecond) + err := defHandler(code, "") + closed <- struct{}{} + return err + }) + + client.Stop(context.Background()) + + select { + case <-closed: + case <-time.After(1 * time.Second): + require.Fail(t, "Connection never closed") + } +} + +func TestHandlesNoCloseMessageFromServer(t *testing.T) { + srv := internal.StartMockServer(t) + var wsConn *websocket.Conn + connected := make(chan struct{}) + closed := make(chan struct{}) + + srv.OnWSConnect = func(conn *websocket.Conn) { + wsConn = conn + connected <- struct{}{} + } + + 
client := NewWebSocket(nil) + client.connShutdownTimeout = 100 * time.Millisecond + startClient(t, types.StartSettings{ + OpAMPServerURL: srv.GetHTTPTestServer().URL, + }, client) + + select { + case <-connected: + case <-time.After(2 * time.Second): + require.Fail(t, "Connection never established") + } + + require.Eventually(t, func() bool { + client.connMutex.RLock() + conn := client.conn + client.connMutex.RUnlock() + return conn != nil + }, 2*time.Second, 250*time.Millisecond) + + wsConn.SetCloseHandler(func(code int, _ string) error { + // Don't send close message + return nil + }) + + go func() { + client.Stop(context.Background()) + closed <- struct{}{} + }() + + select { + case <-closed: + case <-time.After(1 * time.Second): + require.Fail(t, "Connection never closed") + } +} + +func TestHandlesConnectionError(t *testing.T) { + srv := internal.StartMockServer(t) + var wsConn *websocket.Conn + connected := make(chan struct{}) + + srv.OnWSConnect = func(conn *websocket.Conn) { + wsConn = conn + connected <- struct{}{} + } + + client := NewWebSocket(nil) + startClient(t, types.StartSettings{ + OpAMPServerURL: srv.GetHTTPTestServer().URL, + }, client) + + select { + case <-connected: + case <-time.After(2 * time.Second): + require.Fail(t, "Connection never established") + } + + require.Eventually(t, func() bool { + client.connMutex.RLock() + conn := client.conn + client.connMutex.RUnlock() + return conn != nil + }, 2*time.Second, 250*time.Millisecond) + + // Write an invalid message to the connection. The client + // will take this as an error and reconnect to the server. + writer, err := wsConn.NextWriter(websocket.BinaryMessage) + require.NoError(t, err) + n, err := writer.Write([]byte{99, 1, 2, 3, 4, 5}) + require.NoError(t, err) + require.Equal(t, 6, n) + err = writer.Close() + require.NoError(t, err) + + select { + case <-connected: + case <-time.After(2 * time.Second): + require.Fail(t, "Connection never re-established") + } + + require.Eventually(t, func() bool { + client.connMutex.RLock() + conn := client.conn + client.connMutex.RUnlock() + return conn != nil + }, 2*time.Second, 250*time.Millisecond) + + err = client.Stop(context.Background()) + require.NoError(t, err) +} + +func TestDisallowsSendingAfterStopped(t *testing.T) { + srv := internal.StartMockServer(t) + var wsConn *websocket.Conn + connected := make(chan struct{}) + closed := make(chan struct{}) + + srv.OnWSConnect = func(conn *websocket.Conn) { + wsConn = conn + connected <- struct{}{} + } + + client := NewWebSocket(nil) + startClient(t, types.StartSettings{ + OpAMPServerURL: srv.GetHTTPTestServer().URL, + }, client) + + select { + case <-connected: + case <-time.After(2 * time.Second): + require.Fail(t, "Connection never established") + } + + require.Eventually(t, func() bool { + client.connMutex.RLock() + conn := client.conn + client.connMutex.RUnlock() + return conn != nil + }, 2*time.Second, 250*time.Millisecond) + + wg := sync.WaitGroup{} + send := make(chan struct{}) + + defHandler := wsConn.CloseHandler() + wsConn.SetCloseHandler(func(code int, _ string) error { + close(send) + // Pause the stopping process to ensure that sends are disallowed while the client + // is stopping, not necessarily just after it has stopped. 
+ wg.Wait() + err := defHandler(code, "") + closed <- struct{}{} + return err + }) + + wg.Add(5) + go func() { + err := client.Stop(context.Background()) + require.NoError(t, err) + }() + go func() { + <-send + err := client.SetAgentDescription(&protobufs.AgentDescription{}) + require.Error(t, err) + wg.Done() + }() + go func() { + <-send + err := client.SetHealth(&protobufs.AgentHealth{}) + require.Error(t, err) + wg.Done() + }() + go func() { + <-send + err := client.UpdateEffectiveConfig(context.Background()) + require.Error(t, err) + wg.Done() + }() + go func() { + <-send + err := client.SetRemoteConfigStatus(&protobufs.RemoteConfigStatus{}) + require.Error(t, err) + wg.Done() + }() + go func() { + <-send + err := client.SetPackageStatuses(&protobufs.PackageStatuses{}) + require.Error(t, err) + wg.Done() + }() + + select { + case <-closed: + case <-time.After(5 * time.Second): + t.Error("Connection failed to close") + } +} diff --git a/server/serverimpl_test.go b/server/serverimpl_test.go index 04827d38..b6dec72c 100644 --- a/server/serverimpl_test.go +++ b/server/serverimpl_test.go @@ -751,20 +751,12 @@ func TestConnectionAllowsConcurrentWrites(t *testing.T) { defer conn.Close() - timeout, cancel := context.WithTimeout(context.Background(), 10*time.Second) - - select { - case <-timeout.Done(): - t.Error("Client failed to connect before timeout") - default: - if _, ok := srvConnVal.Load().(types.Connection); ok == true { - break - } - } - - cancel() + require.Eventually(t, func() bool { + return srvConnVal.Load() != nil + }, 2*time.Second, 250*time.Millisecond) - srvConn := srvConnVal.Load().(types.Connection) + srvConn, ok := srvConnVal.Load().(types.Connection) + require.True(t, ok, "The server connection is not a types.Connection") for i := 0; i < 20; i++ { go func() { defer func() {
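Not part of the patch above: a minimal caller-side sketch of the shutdown behavior this change introduces. Stop now drives the WebSocket closing handshake, honors cancellation of the context passed to it, and rejects further Set*/Update* calls once shutdown has begun. The server URL, the empty AgentDescription, and the 5-second timeout are placeholder values; a real agent fills in the remaining StartSettings fields (instance UID, callbacks, etc.) as its use case requires.

package main

import (
	"context"
	"log"
	"time"

	"github.com/open-telemetry/opamp-go/client"
	"github.com/open-telemetry/opamp-go/client/types"
	"github.com/open-telemetry/opamp-go/protobufs"
)

func main() {
	// NewWebSocket accepts a nil logger; a default logger is substituted internally.
	c := client.NewWebSocket(nil)

	// The client expects an agent description before Start; empty here only for illustration.
	if err := c.SetAgentDescription(&protobufs.AgentDescription{}); err != nil {
		log.Fatalf("set agent description failed: %v", err)
	}

	settings := types.StartSettings{
		OpAMPServerURL: "ws://localhost:4320/v1/opamp", // placeholder endpoint
	}
	if err := c.Start(context.Background(), settings); err != nil {
		log.Fatalf("start failed: %v", err)
	}

	// ... the agent runs and reports status over the WebSocket ...

	// Bound how long the caller is willing to wait for shutdown. If this context
	// is cancelled before the closing handshake finishes, Stop closes the TCP
	// connection and returns an error; if the server never answers the close
	// message, Stop falls back to a hard close after connShutdownTimeout.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := c.Stop(ctx); err != nil {
		log.Printf("client did not shut down cleanly: %v", err)
	}

	// Once Stop has begun, message-producing calls such as SetHealth or
	// SetAgentDescription return an error instead of racing the closing handshake.
}

The client's own connShutdownTimeout (10 seconds by default in this change, shortened in the tests above) independently bounds how long Stop waits for the server's close message, regardless of the caller's context.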