From d7b580d1f05c6468761600c1e7548be5154e0d5e Mon Sep 17 00:00:00 2001 From: Aurora Gaffney Date: Sun, 10 Nov 2024 07:27:20 -0600 Subject: [PATCH] fix: protect against send on closed channel in protocols (#783) --- protocol/blockfetch/client.go | 18 ++++++++++++++++++ protocol/localstatequery/client.go | 18 ++++++++++++++++++ protocol/localtxsubmission/client.go | 12 ++++++++++++ protocol/txsubmission/server.go | 12 ++++++++++++ 4 files changed, 60 insertions(+) diff --git a/protocol/blockfetch/client.go b/protocol/blockfetch/client.go index c96b8d06..e9f688d8 100644 --- a/protocol/blockfetch/client.go +++ b/protocol/blockfetch/client.go @@ -206,6 +206,12 @@ func (c *Client) handleStartBatch() error { "role", "client", "connection_id", c.callbackContext.ConnectionId.String(), ) + // Check for shutdown + select { + case <-c.Protocol.DoneChan(): + return protocol.ProtocolShuttingDownError + default: + } c.startBatchResultChan <- nil return nil } @@ -218,6 +224,12 @@ func (c *Client) handleNoBlocks() error { "role", "client", "connection_id", c.callbackContext.ConnectionId.String(), ) + // Check for shutdown + select { + case <-c.Protocol.DoneChan(): + return protocol.ProtocolShuttingDownError + default: + } err := fmt.Errorf("block(s) not found") c.startBatchResultChan <- err return nil @@ -244,6 +256,12 @@ func (c *Client) handleBlock(msgGeneric protocol.Message) error { if err != nil { return err } + // Check for shutdown + select { + case <-c.Protocol.DoneChan(): + return protocol.ProtocolShuttingDownError + default: + } // We use the callback when requesting ranges and the internal channel for a single block if c.blockUseCallback { if err := c.config.BlockFunc(c.callbackContext, wrappedBlock.Type, blk); err != nil { diff --git a/protocol/localstatequery/client.go b/protocol/localstatequery/client.go index 0bd565ae..85be0a68 100644 --- a/protocol/localstatequery/client.go +++ b/protocol/localstatequery/client.go @@ -849,6 +849,12 @@ func (c *Client) handleAcquired() error { "role", "client", "connection_id", c.callbackContext.ConnectionId.String(), ) + // Check for shutdown + select { + case <-c.Protocol.DoneChan(): + return protocol.ProtocolShuttingDownError + default: + } c.acquired = true c.acquireResultChan <- nil c.currentEra = -1 @@ -863,6 +869,12 @@ func (c *Client) handleFailure(msg protocol.Message) error { "role", "client", "connection_id", c.callbackContext.ConnectionId.String(), ) + // Check for shutdown + select { + case <-c.Protocol.DoneChan(): + return protocol.ProtocolShuttingDownError + default: + } msgFailure := msg.(*MsgFailure) switch msgFailure.Failure { case AcquireFailurePointTooOld: @@ -883,6 +895,12 @@ func (c *Client) handleResult(msg protocol.Message) error { "role", "client", "connection_id", c.callbackContext.ConnectionId.String(), ) + // Check for shutdown + select { + case <-c.Protocol.DoneChan(): + return protocol.ProtocolShuttingDownError + default: + } msgResult := msg.(*MsgResult) c.queryResultChan <- msgResult.Result return nil diff --git a/protocol/localtxsubmission/client.go b/protocol/localtxsubmission/client.go index 37b1b11e..f2c02fb2 100644 --- a/protocol/localtxsubmission/client.go +++ b/protocol/localtxsubmission/client.go @@ -155,6 +155,12 @@ func (c *Client) handleAcceptTx() error { "role", "client", "connection_id", c.callbackContext.ConnectionId.String(), ) + // Check for shutdown + select { + case <-c.Protocol.DoneChan(): + return protocol.ProtocolShuttingDownError + default: + } c.submitResultChan <- nil return nil } @@ -167,6 +173,12 @@ func (c *Client) handleRejectTx(msg protocol.Message) error { "role", "client", "connection_id", c.callbackContext.ConnectionId.String(), ) + // Check for shutdown + select { + case <-c.Protocol.DoneChan(): + return protocol.ProtocolShuttingDownError + default: + } msgRejectTx := msg.(*MsgRejectTx) rejectErr, err := ledger.NewTxSubmitErrorFromCbor(msgRejectTx.Reason) if err != nil { diff --git a/protocol/txsubmission/server.go b/protocol/txsubmission/server.go index 67ca6240..ae72d135 100644 --- a/protocol/txsubmission/server.go +++ b/protocol/txsubmission/server.go @@ -163,6 +163,12 @@ func (s *Server) handleReplyTxIds(msg protocol.Message) error { "role", "server", "connection_id", s.callbackContext.ConnectionId.String(), ) + // Check for shutdown + select { + case <-s.Protocol.DoneChan(): + return protocol.ProtocolShuttingDownError + default: + } msgReplyTxIds := msg.(*MsgReplyTxIds) s.requestTxIdsResultChan <- msgReplyTxIds.TxIds return nil @@ -176,6 +182,12 @@ func (s *Server) handleReplyTxs(msg protocol.Message) error { "role", "server", "connection_id", s.callbackContext.ConnectionId.String(), ) + // Check for shutdown + select { + case <-s.Protocol.DoneChan(): + return protocol.ProtocolShuttingDownError + default: + } msgReplyTxs := msg.(*MsgReplyTxs) s.requestTxsResultChan <- msgReplyTxs.Txs return nil