Skip to content

Commit

Permalink
Merge pull request #5442 from oasisprotocol/peternose/bugfix/conn-cal…
Browse files Browse the repository at this point in the history
…l-deadlock

go/runtime/host/protocol/connection: Cancel call if connection is closed
  • Loading branch information
peternose authored Nov 13, 2023
2 parents 12a5d62 + 61710b5 commit 13f9e38
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 52 deletions.
1 change: 1 addition & 0 deletions .changelog/5442.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
go/runtime/host/protocol/connection: Cancel call if connection is closed
103 changes: 54 additions & 49 deletions go/runtime/host/protocol/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ type connection struct { // nolint: maligned
handler Handler

state state
pendingRequests map[uint64]chan *Body
pendingRequests map[uint64]chan<- *Body
nextRequestID uint64

info *RuntimeInfoResponse
Expand Down Expand Up @@ -293,51 +293,39 @@ func (c *connection) call(ctx context.Context, body *Body) (result *Body, err er
}
}()

respCh, err := c.makeRequest(ctx, body)
if err != nil {
return nil, err
}

select {
case resp, ok := <-respCh:
if !ok {
return nil, fmt.Errorf("channel closed")
}

if resp.Error != nil {
// Decode error.
err = errors.FromCode(resp.Error.Module, resp.Error.Code, resp.Error.Message)
return nil, err
}

return resp, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}

func (c *connection) makeRequest(ctx context.Context, body *Body) (<-chan *Body, error) {
// Create channel for sending the response and grab next request identifier.
ch := make(chan *Body, 1)
respCh := make(chan *Body, 1)

c.Lock()
id := c.nextRequestID
c.nextRequestID++
c.pendingRequests[id] = ch
c.pendingRequests[id] = respCh
c.Unlock()

defer func() {
c.Lock()
defer c.Unlock()
delete(c.pendingRequests, id)
}()

msg := Message{
ID: id,
MessageType: MessageRequest,
Body: *body,
}

// Queue the message.
if err := c.sendMessage(ctx, &msg); err != nil {
if err = c.sendMessage(ctx, &msg); err != nil {
return nil, fmt.Errorf("failed to send message: %w", err)
}

return ch, nil
// Await a response.
resp, err := c.readResponse(ctx, respCh)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}

return resp, nil
}

func (c *connection) sendMessage(ctx context.Context, msg *Message) error {
Expand All @@ -351,9 +339,23 @@ func (c *connection) sendMessage(ctx context.Context, msg *Message) error {
}
}

func (c *connection) workerOutgoing() {
defer c.quitWg.Done()
func (c *connection) readResponse(ctx context.Context, respCh <-chan *Body) (*Body, error) {
select {
case resp := <-respCh:
if resp.Error != nil {
// Decode error.
return nil, errors.FromCode(resp.Error.Module, resp.Error.Code, resp.Error.Message)
}

return resp, nil
case <-c.closeCh:
return nil, fmt.Errorf("connection closed")
case <-ctx.Done():
return nil, ctx.Err()
}
}

func (c *connection) workerOutgoing() {
for {
select {
case msg := <-c.outCh:
Expand Down Expand Up @@ -450,7 +452,6 @@ func (c *connection) handleMessage(ctx context.Context, message *Message) {
}

respCh <- &message.Body
close(respCh)
default:
c.logger.Warn("received a malformed message from worker, ignoring",
"message", fmt.Sprintf("%+v", message),
Expand All @@ -459,24 +460,18 @@ func (c *connection) handleMessage(ctx context.Context, message *Message) {
}

func (c *connection) workerIncoming() {
// Wait for request handlers to finish.
var wg sync.WaitGroup
defer wg.Wait()

// Cancel all request handlers.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

defer func() {
// Close connection and signal that connection is closed.
_ = c.conn.Close()
close(c.closeCh)

// Cancel all request handlers.
cancel()

// Close all pending request channels.
c.Lock()
for id, ch := range c.pendingRequests {
close(ch)
delete(c.pendingRequests, id)
}
c.Unlock()

c.quitWg.Done()
}()

for {
Expand All @@ -491,7 +486,11 @@ func (c *connection) workerIncoming() {
}

// Handle message in a separate goroutine.
go c.handleMessage(ctx, &message)
wg.Add(1)
go func() {
defer wg.Done()
c.handleMessage(ctx, &message)
}()
}
}

Expand All @@ -507,8 +506,14 @@ func (c *connection) initConn(conn net.Conn) {
c.codec = cbor.NewMessageCodec(conn, moduleName)

c.quitWg.Add(2)
go c.workerIncoming()
go c.workerOutgoing()
go func() {
defer c.quitWg.Done()
c.workerIncoming()
}()
go func() {
defer c.quitWg.Done()
c.workerOutgoing()
}()

// Change protocol state to Initializing so that some of the requests are allowed.
c.setStateLocked(stateInitializing)
Expand Down Expand Up @@ -583,7 +588,7 @@ func NewConnection(logger *logging.Logger, runtimeID common.Namespace, handler H
runtimeID: runtimeID,
handler: handler,
state: stateUninitialized,
pendingRequests: make(map[uint64]chan *Body),
pendingRequests: make(map[uint64]chan<- *Body),
outCh: make(chan *Message),
closeCh: make(chan struct{}),
logger: logger,
Expand Down
7 changes: 4 additions & 3 deletions go/runtime/host/sgx/sgx.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ const (

// Runtime RAK initialization timeout.
//
// This can take a long time in deployments that run multiple
// nodes on a single machine, all sharing the same EPC.
runtimeRAKTimeout = 60 * time.Second
// This can take a long time in deployments that run multiple nodes on a single machine, all
// sharing the same EPC. Additionally, this includes time to do the initial consensus light
// client sync and freshness verification which can take some time.
runtimeRAKTimeout = 5 * time.Minute
// Runtime attest interval.
defaultRuntimeAttestInterval = 2 * time.Hour
)
Expand Down

0 comments on commit 13f9e38

Please sign in to comment.