Skip to content

Commit

Permalink
Merge pull request #35 from smartcontractkit/fix/potential-write-dead…
Browse files Browse the repository at this point in the history
…lock

Fix potential deadlock in client & server writes
  • Loading branch information
skubakdj authored May 16, 2023
2 parents b5d02d0 + 1d7aa8b commit fef92b5
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 17 deletions.
4 changes: 2 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ func (cc *ClientConn) handleMessageRequest(r *message.Request) {
cc.conn.mu.RUnlock()
cc.mu.RUnlock()

if err := tr.Write(replyMsg); err != nil {
if err := tr.Write(ctx, replyMsg); err != nil {
cc.dopts.logger.Errorf("error writing to transport: %s", err)
}
}
Expand Down Expand Up @@ -376,7 +376,7 @@ func (cc *ClientConn) Invoke(ctx context.Context, method string, args interface{
cc.conn.mu.RUnlock()
cc.mu.RUnlock()

if err := tr.Write(reqB); err != nil {
if err := tr.Write(ctx, reqB); err != nil {
return err
}

Expand Down
4 changes: 2 additions & 2 deletions internal/transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type ClientTransport interface {
Read() <-chan []byte

// Write sends a message to the stream.
Write(msg []byte) error
Write(ctx context.Context, msg []byte) error

// Close tears down this transport. Once it returns, the transport
// should not be accessed any more.
Expand Down Expand Up @@ -73,7 +73,7 @@ type ServerTransport interface {
Read() <-chan []byte

// Write sends a message to the stream.
Write(msg []byte) error
Write(ctx context.Context, msg []byte) error

// Close tears down the transport. Once it is called, the transport
// should not be accessed any more.
Expand Down
15 changes: 11 additions & 4 deletions internal/transport/websocket_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,17 @@ func (c *WebsocketClient) Read() <-chan []byte {
}

// Write writes a message the websocket connection.
func (c *WebsocketClient) Write(msg []byte) error {
c.write <- msg

return nil
func (c *WebsocketClient) Write(ctx context.Context, msg []byte) error {
select {
case <-c.done:
return fmt.Errorf("[wsrpc] could not write message, websocket is closed")
case <-c.interrupt:
return fmt.Errorf("[wsrpc] could not write message, transport is closed")
case <-ctx.Done():
return fmt.Errorf("[wsrpc] could not write message, context is done")
case c.write <- msg:
return nil
}
}

// Close closes the websocket connection and cleans up pump goroutines.
Expand Down
18 changes: 13 additions & 5 deletions internal/transport/websocket_server.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package transport

import (
"context"
"fmt"
"log"
"sync"
"time"
Expand Down Expand Up @@ -61,11 +63,17 @@ func (s *WebsocketServer) Read() <-chan []byte {
}

// Write writes a message the websocket connection.
func (s *WebsocketServer) Write(msg []byte) error {
// Send the message to the channel
s.write <- msg

return nil
func (s *WebsocketServer) Write(ctx context.Context, msg []byte) error {
select {
case <-s.done:
return fmt.Errorf("[wsrpc] could not write message, websocket is closed")
case <-s.interrupt:
return fmt.Errorf("[wsrpc] could not write message, transport is closed")
case <-ctx.Done():
return fmt.Errorf("[wsrpc] could not write message, context is done")
case s.write <- msg:
return nil
}
}

// Close closes the websocket connection and cleans up pump goroutines. Notifies
Expand Down
8 changes: 4 additions & 4 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func (s *Server) wshandler(w http.ResponseWriter, r *http.Request) {
}

// sendMsg writes the message to the connection which matches the public key.
func (s *Server) sendMsg(pub [32]byte, msg []byte) error {
func (s *Server) sendMsg(ctx context.Context, pub [32]byte, msg []byte) error {
// Find the transport matching the public key
s.mu.RLock()
tr, err := s.connMgr.getTransport(pub)
Expand All @@ -186,7 +186,7 @@ func (s *Server) sendMsg(pub [32]byte, msg []byte) error {
return err
}

return tr.Write(msg)
return tr.Write(ctx, msg)
}

// handleRead listens to the transport read channel and passes the message to the
Expand Down Expand Up @@ -248,7 +248,7 @@ func (s *Server) handleMessageRequest(pubKey credentials.StaticSizedPublicKey, r
return
}

if err := s.sendMsg(pubKey, replyMsg); err != nil {
if err := s.sendMsg(ctx, pubKey, replyMsg); err != nil {
log.Printf("error sending message: %s", err)
}
}
Expand Down Expand Up @@ -314,7 +314,7 @@ func (s *Server) Invoke(ctx context.Context, method string, args interface{}, re
}
pubKey := p.PublicKey

if err = s.sendMsg(pubKey, req); err != nil {
if err = s.sendMsg(ctx, pubKey, req); err != nil {
return err
}

Expand Down

0 comments on commit fef92b5

Please sign in to comment.