Skip to content

Commit

Permalink
fix: protect against send on closed channel in protocols (#783)
Browse files Browse the repository at this point in the history
  • Loading branch information
agaffney authored Nov 10, 2024
1 parent b8c2661 commit d7b580d
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 0 deletions.
18 changes: 18 additions & 0 deletions protocol/blockfetch/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,12 @@ func (c *Client) handleStartBatch() error {
"role", "client",
"connection_id", c.callbackContext.ConnectionId.String(),
)
// Check for shutdown
select {
case <-c.Protocol.DoneChan():
return protocol.ProtocolShuttingDownError
default:
}
c.startBatchResultChan <- nil
return nil
}
Expand All @@ -218,6 +224,12 @@ func (c *Client) handleNoBlocks() error {
"role", "client",
"connection_id", c.callbackContext.ConnectionId.String(),
)
// Check for shutdown
select {
case <-c.Protocol.DoneChan():
return protocol.ProtocolShuttingDownError
default:
}
err := fmt.Errorf("block(s) not found")
c.startBatchResultChan <- err
return nil
Expand All @@ -244,6 +256,12 @@ func (c *Client) handleBlock(msgGeneric protocol.Message) error {
if err != nil {
return err
}
// Check for shutdown
select {
case <-c.Protocol.DoneChan():
return protocol.ProtocolShuttingDownError
default:
}
// We use the callback when requesting ranges and the internal channel for a single block
if c.blockUseCallback {
if err := c.config.BlockFunc(c.callbackContext, wrappedBlock.Type, blk); err != nil {
Expand Down
18 changes: 18 additions & 0 deletions protocol/localstatequery/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,12 @@ func (c *Client) handleAcquired() error {
"role", "client",
"connection_id", c.callbackContext.ConnectionId.String(),
)
// Check for shutdown
select {
case <-c.Protocol.DoneChan():
return protocol.ProtocolShuttingDownError
default:
}
c.acquired = true
c.acquireResultChan <- nil
c.currentEra = -1
Expand All @@ -863,6 +869,12 @@ func (c *Client) handleFailure(msg protocol.Message) error {
"role", "client",
"connection_id", c.callbackContext.ConnectionId.String(),
)
// Check for shutdown
select {
case <-c.Protocol.DoneChan():
return protocol.ProtocolShuttingDownError
default:
}
msgFailure := msg.(*MsgFailure)
switch msgFailure.Failure {
case AcquireFailurePointTooOld:
Expand All @@ -883,6 +895,12 @@ func (c *Client) handleResult(msg protocol.Message) error {
"role", "client",
"connection_id", c.callbackContext.ConnectionId.String(),
)
// Check for shutdown
select {
case <-c.Protocol.DoneChan():
return protocol.ProtocolShuttingDownError
default:
}
msgResult := msg.(*MsgResult)
c.queryResultChan <- msgResult.Result
return nil
Expand Down
12 changes: 12 additions & 0 deletions protocol/localtxsubmission/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,12 @@ func (c *Client) handleAcceptTx() error {
"role", "client",
"connection_id", c.callbackContext.ConnectionId.String(),
)
// Check for shutdown
select {
case <-c.Protocol.DoneChan():
return protocol.ProtocolShuttingDownError
default:
}
c.submitResultChan <- nil
return nil
}
Expand All @@ -167,6 +173,12 @@ func (c *Client) handleRejectTx(msg protocol.Message) error {
"role", "client",
"connection_id", c.callbackContext.ConnectionId.String(),
)
// Check for shutdown
select {
case <-c.Protocol.DoneChan():
return protocol.ProtocolShuttingDownError
default:
}
msgRejectTx := msg.(*MsgRejectTx)
rejectErr, err := ledger.NewTxSubmitErrorFromCbor(msgRejectTx.Reason)
if err != nil {
Expand Down
12 changes: 12 additions & 0 deletions protocol/txsubmission/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,12 @@ func (s *Server) handleReplyTxIds(msg protocol.Message) error {
"role", "server",
"connection_id", s.callbackContext.ConnectionId.String(),
)
// Check for shutdown
select {
case <-s.Protocol.DoneChan():
return protocol.ProtocolShuttingDownError
default:
}
msgReplyTxIds := msg.(*MsgReplyTxIds)
s.requestTxIdsResultChan <- msgReplyTxIds.TxIds
return nil
Expand All @@ -176,6 +182,12 @@ func (s *Server) handleReplyTxs(msg protocol.Message) error {
"role", "server",
"connection_id", s.callbackContext.ConnectionId.String(),
)
// Check for shutdown
select {
case <-s.Protocol.DoneChan():
return protocol.ProtocolShuttingDownError
default:
}
msgReplyTxs := msg.(*MsgReplyTxs)
s.requestTxsResultChan <- msgReplyTxs.Txs
return nil
Expand Down

0 comments on commit d7b580d

Please sign in to comment.