diff --git a/protocol/blockfetch/client.go b/protocol/blockfetch/client.go index d1bcad7a..c1ac6767 100644 --- a/protocol/blockfetch/client.go +++ b/protocol/blockfetch/client.go @@ -73,6 +73,7 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { go func() { <-c.Protocol.DoneChan() close(c.blockChan) + close(c.startBatchResultChan) }() return c } @@ -95,7 +96,10 @@ func (c *Client) GetBlockRange(start common.Point, end common.Point) error { c.busyMutex.Unlock() return err } - err := <-c.startBatchResultChan + err, ok := <-c.startBatchResultChan + if !ok { + return protocol.ProtocolShuttingDownError + } if err != nil { c.busyMutex.Unlock() return err @@ -112,7 +116,10 @@ func (c *Client) GetBlock(point common.Point) (ledger.Block, error) { c.busyMutex.Unlock() return nil, err } - err := <-c.startBatchResultChan + err, ok := <-c.startBatchResultChan + if !ok { + return nil, protocol.ProtocolShuttingDownError + } if err != nil { c.busyMutex.Unlock() return nil, err diff --git a/protocol/chainsync/client.go b/protocol/chainsync/client.go index 3592e9f1..94865a2c 100644 --- a/protocol/chainsync/client.go +++ b/protocol/chainsync/client.go @@ -146,9 +146,15 @@ func (c *Client) GetCurrentTip() (*Tip, error) { if err := c.SendMessage(msg); err != nil { return nil, err } - tip := <-c.currentTipChan + tip, ok := <-c.currentTipChan + if !ok { + return nil, protocol.ProtocolShuttingDownError + } // Clear out intersect result channel to prevent blocking - <-c.intersectResultChan + _, ok = <-c.intersectResultChan + if !ok { + return nil, protocol.ProtocolShuttingDownError + } c.wantCurrentTip = false return &tip, nil } @@ -171,6 +177,8 @@ func (c *Client) GetAvailableBlockRange( gotIntersectResult := false for { select { + case <-c.DoneChan(): + return start, end, protocol.ProtocolShuttingDownError case tip := <-c.currentTipChan: end = tip.Point c.wantCurrentTip = false @@ -200,6 +208,8 @@ func (c *Client) GetAvailableBlockRange( } for { select { + case <-c.DoneChan(): + return start, end, protocol.ProtocolShuttingDownError case tip := <-c.currentTipChan: end = tip.Point c.wantCurrentTip = false @@ -237,7 +247,9 @@ func (c *Client) Sync(intersectPoints []common.Point) error { if err := c.SendMessage(msg); err != nil { return err } - if err := <-c.intersectResultChan; err != nil { + if err, ok := <-c.intersectResultChan; !ok { + return protocol.ProtocolShuttingDownError + } else if err != nil { return err } // Pipeline the initial block requests to speed things up a bit diff --git a/protocol/keepalive/client.go b/protocol/keepalive/client.go index dc10856f..8910d206 100644 --- a/protocol/keepalive/client.go +++ b/protocol/keepalive/client.go @@ -60,15 +60,9 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { // Start goroutine to cleanup resources on protocol shutdown go func() { <-c.Protocol.DoneChan() + // Stop any existing timer if c.timer != nil { - // Stop timer and drain channel - if ok := c.timer.Stop(); !ok { - // Read item from channel, if available - select { - case <-c.timer.C: - default: - } - } + c.timer.Stop() } }() return c @@ -93,13 +87,7 @@ func (c *Client) sendKeepAlive() { func (c *Client) startTimer() { // Stop any existing timer if c.timer != nil { - if ok := c.timer.Stop(); !ok { - // Read item from channel, if available - select { - case <-c.timer.C: - default: - } - } + c.timer.Stop() } // Create new timer c.timer = time.AfterFunc(c.config.Period, c.sendKeepAlive) diff --git a/protocol/localstatequery/client.go b/protocol/localstatequery/client.go index 91649030..8b80612a 100644 --- a/protocol/localstatequery/client.go +++ b/protocol/localstatequery/client.go @@ -154,7 +154,10 @@ func (c *Client) acquire(point *common.Point) error { if err := c.SendMessage(msg); err != nil { return err } - err := <-c.acquireResultChan + err, ok := <-c.acquireResultChan + if !ok { + return protocol.ProtocolShuttingDownError + } return err } @@ -178,7 +181,10 @@ func (c *Client) runQuery(query interface{}, result interface{}) error { if err := c.SendMessage(msg); err != nil { return err } - resultCbor := <-c.queryResultChan + resultCbor, ok := <-c.queryResultChan + if !ok { + return protocol.ProtocolShuttingDownError + } if _, err := cbor.Decode(resultCbor, result); err != nil { return err } diff --git a/protocol/localtxmonitor/client.go b/protocol/localtxmonitor/client.go index 0d1dcad1..c2e953c1 100644 --- a/protocol/localtxmonitor/client.go +++ b/protocol/localtxmonitor/client.go @@ -110,7 +110,10 @@ func (c *Client) acquire() error { return err } // Wait for reply - <-c.acquireResultChan + _, ok := <-c.acquireResultChan + if !ok { + return protocol.ProtocolShuttingDownError + } return nil } diff --git a/protocol/localtxsubmission/client.go b/protocol/localtxsubmission/client.go index 75eaf6e6..50b52cff 100644 --- a/protocol/localtxsubmission/client.go +++ b/protocol/localtxsubmission/client.go @@ -94,7 +94,10 @@ func (c *Client) SubmitTx(eraId uint16, tx []byte) error { if err := c.SendMessage(msg); err != nil { return err } - err := <-c.submitResultChan + err, ok := <-c.submitResultChan + if !ok { + return protocol.ProtocolShuttingDownError + } return err } diff --git a/protocol/peersharing/client.go b/protocol/peersharing/client.go index b65c9e18..0a2d3039 100644 --- a/protocol/peersharing/client.go +++ b/protocol/peersharing/client.go @@ -65,7 +65,10 @@ func (c *Client) GetPeers(amount uint8) ([]interface{}, error) { if err := c.SendMessage(msg); err != nil { return nil, err } - peers := <-c.sharePeersChan + peers, ok := <-c.sharePeersChan + if !ok { + return nil, protocol.ProtocolShuttingDownError + } return peers, nil }