diff --git a/internal/decision/engine.go b/internal/decision/engine.go index 4a49c243..620bb868 100644 --- a/internal/decision/engine.go +++ b/internal/decision/engine.go @@ -745,32 +745,19 @@ func (e *Engine) MessageSent(p peer.ID, m bsmsg.BitSwapMessage) { func (e *Engine) PeerConnected(p peer.ID) { e.lock.Lock() defer e.lock.Unlock() - l, ok := e.ledgerMap[p] + + _, ok := e.ledgerMap[p] if !ok { - l = newLedger(p) - e.ledgerMap[p] = l + e.ledgerMap[p] = newLedger(p) } - - l.lk.Lock() - defer l.lk.Unlock() - l.ref++ } // PeerDisconnected is called when a peer disconnects. func (e *Engine) PeerDisconnected(p peer.ID) { e.lock.Lock() defer e.lock.Unlock() - l, ok := e.ledgerMap[p] - if !ok { - return - } - l.lk.Lock() - defer l.lk.Unlock() - l.ref-- - if l.ref <= 0 { - delete(e.ledgerMap, p) - } + delete(e.ledgerMap, p) } // If the want is a want-have, and it's below a certain size, send the full diff --git a/internal/decision/ledger.go b/internal/decision/ledger.go index 8f103bd4..87fedc45 100644 --- a/internal/decision/ledger.go +++ b/internal/decision/ledger.go @@ -43,10 +43,6 @@ type ledger struct { // wantList is a (bounded, small) set of keys that Partner desires. wantList *wl.Wantlist - // ref is the reference count for this ledger, its used to ensure we - // don't drop the reference to this ledger in multi-connection scenarios - ref int - lk sync.RWMutex } diff --git a/internal/messagequeue/messagequeue.go b/internal/messagequeue/messagequeue.go index d42db10d..9fcab6d3 100644 --- a/internal/messagequeue/messagequeue.go +++ b/internal/messagequeue/messagequeue.go @@ -25,7 +25,8 @@ const ( defaultRebroadcastInterval = 30 * time.Second // maxRetries is the number of times to attempt to send a message before // giving up - maxRetries = 10 + maxRetries = 3 + sendTimeout = 30 * time.Second // maxMessageSize is the maximum message size in bytes maxMessageSize = 1024 * 1024 * 2 // sendErrorBackoff is the time to wait before retrying to connect after @@ -46,7 +47,7 @@ const ( // sender. type MessageNetwork interface { ConnectTo(context.Context, peer.ID) error - NewMessageSender(context.Context, peer.ID) (bsnet.MessageSender, error) + NewMessageSender(context.Context, peer.ID, *bsnet.MessageSenderOpts) (bsnet.MessageSender, error) Latency(peer.ID) time.Duration Ping(context.Context, peer.ID) ping.Result Self() peer.ID @@ -55,6 +56,7 @@ type MessageNetwork interface { // MessageQueue implements queue of want messages to send to peers. 
type MessageQueue struct { ctx context.Context + shutdown func() p peer.ID network MessageNetwork dhTimeoutMgr DontHaveTimeoutManager @@ -62,7 +64,6 @@ type MessageQueue struct { sendErrorBackoff time.Duration outgoingWork chan time.Time - done chan struct{} // Take lock whenever any of these variables are modified wllock sync.Mutex @@ -169,8 +170,10 @@ func New(ctx context.Context, p peer.ID, network MessageNetwork, onDontHaveTimeo func newMessageQueue(ctx context.Context, p peer.ID, network MessageNetwork, maxMsgSize int, sendErrorBackoff time.Duration, dhTimeoutMgr DontHaveTimeoutManager) *MessageQueue { + ctx, cancel := context.WithCancel(ctx) mq := &MessageQueue{ ctx: ctx, + shutdown: cancel, p: p, network: network, dhTimeoutMgr: dhTimeoutMgr, @@ -179,7 +182,6 @@ func newMessageQueue(ctx context.Context, p peer.ID, network MessageNetwork, peerWants: newRecallWantList(), cancels: cid.NewSet(), outgoingWork: make(chan time.Time, 1), - done: make(chan struct{}), rebroadcastInterval: defaultRebroadcastInterval, sendErrorBackoff: sendErrorBackoff, priority: maxPriority, @@ -300,12 +302,17 @@ func (mq *MessageQueue) Startup() { // Shutdown stops the processing of messages for a message queue. func (mq *MessageQueue) Shutdown() { - close(mq.done) + mq.shutdown() } func (mq *MessageQueue) onShutdown() { // Shut down the DONT_HAVE timeout manager mq.dhTimeoutMgr.Shutdown() + + // Reset the streamMessageSender + if mq.sender != nil { + _ = mq.sender.Reset() + } } func (mq *MessageQueue) runQueue() { @@ -351,15 +358,7 @@ func (mq *MessageQueue) runQueue() { // in sendMessageDebounce. Send immediately. workScheduled = time.Time{} mq.sendIfReady() - case <-mq.done: - if mq.sender != nil { - mq.sender.Close() - } - return case <-mq.ctx.Done(): - if mq.sender != nil { - _ = mq.sender.Reset() - } return } } @@ -409,12 +408,12 @@ func (mq *MessageQueue) sendIfReady() { } func (mq *MessageQueue) sendMessage() { - err := mq.initializeSender() + sender, err := mq.initializeSender() if err != nil { - log.Infof("cant open message sender to peer %s: %s", mq.p, err) - // TODO: cant connect, what now? - // TODO: should we stop using this connection and clear the want list - // to avoid using up memory? + // If we fail to initialize the sender, the networking layer will + // emit a Disconnect event and the MessageQueue will get cleaned up + log.Infof("Could not open message sender to peer %s: %s", mq.p, err) + mq.Shutdown() return } @@ -423,7 +422,7 @@ func (mq *MessageQueue) sendMessage() { mq.dhTimeoutMgr.Start() // Convert want lists to a Bitswap Message - message, onSent := mq.extractOutgoingMessage(mq.sender.SupportsHave()) + message := mq.extractOutgoingMessage(mq.sender.SupportsHave()) // After processing the message, clear out its fields to save memory defer mq.msg.Reset(false) @@ -435,23 +434,22 @@ func (mq *MessageQueue) sendMessage() { wantlist := message.Wantlist() mq.logOutgoingMessage(wantlist) - // Try to send this message repeatedly - for i := 0; i < maxRetries; i++ { - if mq.attemptSendAndRecovery(message) { - // We were able to send successfully. 
- onSent() + if err := sender.SendMsg(mq.ctx, message); err != nil { + // If the message couldn't be sent, the networking layer will + // emit a Disconnect event and the MessageQueue will get cleaned up + log.Infof("Could not send message to peer %s: %s", mq.p, err) + mq.Shutdown() + return + } - mq.simulateDontHaveWithTimeout(wantlist) + // Set a timer to wait for responses + mq.simulateDontHaveWithTimeout(wantlist) - // If the message was too big and only a subset of wants could be - // sent, schedule sending the rest of the wants in the next - // iteration of the event loop. - if mq.hasPendingWork() { - mq.signalWorkReady() - } - - return - } + // If the message was too big and only a subset of wants could be + // sent, schedule sending the rest of the wants in the next + // iteration of the event loop. + if mq.hasPendingWork() { + mq.signalWorkReady() } } @@ -540,7 +538,7 @@ func (mq *MessageQueue) pendingWorkCount() int { } // Convert the lists of wants into a Bitswap message -func (mq *MessageQueue) extractOutgoingMessage(supportsHave bool) (bsmsg.BitSwapMessage, func()) { +func (mq *MessageQueue) extractOutgoingMessage(supportsHave bool) bsmsg.BitSwapMessage { mq.wllock.Lock() defer mq.wllock.Unlock() @@ -567,7 +565,6 @@ func (mq *MessageQueue) extractOutgoingMessage(supportsHave bool) (bsmsg.BitSwap } // Add each regular want-have / want-block to the message - peerSent := peerEntries[:0] for _, e := range peerEntries { if msgSize >= mq.maxMessageSize { break @@ -579,12 +576,13 @@ func (mq *MessageQueue) extractOutgoingMessage(supportsHave bool) (bsmsg.BitSwap mq.peerWants.RemoveType(e.Cid, pb.Message_Wantlist_Have) } else { msgSize += mq.msg.AddEntry(e.Cid, e.Priority, e.WantType, true) - peerSent = append(peerSent, e) + + // Move the key from pending to sent + mq.peerWants.MarkSent(e) } } // Add each broadcast want-have to the message - bcstSent := bcstEntries[:0] for _, e := range bcstEntries { if msgSize >= mq.maxMessageSize { break @@ -600,89 +598,27 @@ func (mq *MessageQueue) extractOutgoingMessage(supportsHave bool) (bsmsg.BitSwap } msgSize += mq.msg.AddEntry(e.Cid, e.Priority, wantType, false) - bcstSent = append(bcstSent, e) - } - // Called when the message has been successfully sent. 
- onMessageSent := func() { - mq.wllock.Lock() - defer mq.wllock.Unlock() - - // Move the keys from pending to sent - for _, e := range bcstSent { - mq.bcstWants.MarkSent(e) - } - for _, e := range peerSent { - mq.peerWants.MarkSent(e) - } + // Move the key from pending to sent + mq.bcstWants.MarkSent(e) } - return mq.msg, onMessageSent -} - -func (mq *MessageQueue) initializeSender() error { - if mq.sender != nil { - return nil - } - nsender, err := openSender(mq.ctx, mq.network, mq.p) - if err != nil { - return err - } - mq.sender = nsender - return nil + return mq.msg } -func (mq *MessageQueue) attemptSendAndRecovery(message bsmsg.BitSwapMessage) bool { - err := mq.sender.SendMsg(mq.ctx, message) - if err == nil { - return true - } - - log.Infof("bitswap send error: %s", err) - _ = mq.sender.Reset() - mq.sender = nil - - select { - case <-mq.done: - return true - case <-mq.ctx.Done(): - return true - case <-time.After(mq.sendErrorBackoff): - // wait 100ms in case disconnect notifications are still propagating - log.Warn("SendMsg errored but neither 'done' nor context.Done() were set") - } - - err = mq.initializeSender() - if err != nil { - log.Infof("couldnt open sender again after SendMsg(%s) failed: %s", mq.p, err) - return true - } - - // TODO: Is this the same instance for the remote peer? - // If its not, we should resend our entire wantlist to them - /* - if mq.sender.InstanceID() != mq.lastSeenInstanceID { - wlm = mq.getFullWantlistMessage() +func (mq *MessageQueue) initializeSender() (bsnet.MessageSender, error) { + if mq.sender == nil { + opts := &bsnet.MessageSenderOpts{ + MaxRetries: maxRetries, + SendTimeout: sendTimeout, + SendErrorBackoff: sendErrorBackoff, + } + nsender, err := mq.network.NewMessageSender(mq.ctx, mq.p, opts) + if err != nil { + return nil, err } - */ - return false -} - -func openSender(ctx context.Context, network MessageNetwork, p peer.ID) (bsnet.MessageSender, error) { - // allow ten minutes for connections this includes looking them up in the - // dht dialing them, and handshaking - conctx, cancel := context.WithTimeout(ctx, time.Minute*10) - defer cancel() - - err := network.ConnectTo(conctx, p) - if err != nil { - return nil, err - } - nsender, err := network.NewMessageSender(ctx, p) - if err != nil { - return nil, err + mq.sender = nsender } - - return nsender, nil + return mq.sender, nil } diff --git a/internal/messagequeue/messagequeue_test.go b/internal/messagequeue/messagequeue_test.go index 49c1033d..344da41a 100644 --- a/internal/messagequeue/messagequeue_test.go +++ b/internal/messagequeue/messagequeue_test.go @@ -2,7 +2,6 @@ package messagequeue import ( "context" - "errors" "fmt" "math" "math/rand" @@ -31,7 +30,7 @@ func (fmn *fakeMessageNetwork) ConnectTo(context.Context, peer.ID) error { return fmn.connectError } -func (fmn *fakeMessageNetwork) NewMessageSender(context.Context, peer.ID) (bsnet.MessageSender, error) { +func (fmn *fakeMessageNetwork) NewMessageSender(context.Context, peer.ID, *bsnet.MessageSenderOpts) (bsnet.MessageSender, error) { if fmn.messageSenderError == nil { return fmn.messageSender, nil } @@ -83,23 +82,17 @@ func (fp *fakeDontHaveTimeoutMgr) pendingCount() int { type fakeMessageSender struct { lk sync.Mutex - sendError error - fullClosed chan<- struct{} reset chan<- struct{} messagesSent chan<- []bsmsg.Entry - sendErrors chan<- error supportsHave bool } -func newFakeMessageSender(sendError error, fullClosed chan<- struct{}, reset chan<- struct{}, - messagesSent chan<- []bsmsg.Entry, sendErrors chan<- error, 
supportsHave bool) *fakeMessageSender { +func newFakeMessageSender(reset chan<- struct{}, + messagesSent chan<- []bsmsg.Entry, supportsHave bool) *fakeMessageSender { return &fakeMessageSender{ - sendError: sendError, - fullClosed: fullClosed, reset: reset, messagesSent: messagesSent, - sendErrors: sendErrors, supportsHave: supportsHave, } } @@ -108,20 +101,10 @@ func (fms *fakeMessageSender) SendMsg(ctx context.Context, msg bsmsg.BitSwapMess fms.lk.Lock() defer fms.lk.Unlock() - if fms.sendError != nil { - fms.sendErrors <- fms.sendError - return fms.sendError - } fms.messagesSent <- msg.Wantlist() return nil } -func (fms *fakeMessageSender) clearSendError() { - fms.lk.Lock() - defer fms.lk.Unlock() - - fms.sendError = nil -} -func (fms *fakeMessageSender) Close() error { fms.fullClosed <- struct{}{}; return nil } +func (fms *fakeMessageSender) Close() error { return nil } func (fms *fakeMessageSender) Reset() error { fms.reset <- struct{}{}; return nil } func (fms *fakeMessageSender) SupportsHave() bool { return fms.supportsHave } @@ -155,10 +138,8 @@ func totalEntriesLength(messages [][]bsmsg.Entry) int { func TestStartupAndShutdown(t *testing.T) { ctx := context.Background() messagesSent := make(chan []bsmsg.Entry) - sendErrors := make(chan error) resetChan := make(chan struct{}, 1) - fullClosedChan := make(chan struct{}, 1) - fakeSender := newFakeMessageSender(nil, fullClosedChan, resetChan, messagesSent, sendErrors, true) + fakeSender := newFakeMessageSender(resetChan, messagesSent, true) fakenet := &fakeMessageNetwork{nil, nil, fakeSender} peerID := testutil.GeneratePeers(1)[0] messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb) @@ -186,21 +167,17 @@ func TestStartupAndShutdown(t *testing.T) { timeoutctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond) defer cancel() select { - case <-fullClosedChan: case <-resetChan: - t.Fatal("message sender should have been closed but was reset") case <-timeoutctx.Done(): - t.Fatal("message sender should have been closed but wasn't") + t.Fatal("message sender should have been reset but wasn't") } } func TestSendingMessagesDeduped(t *testing.T) { ctx := context.Background() messagesSent := make(chan []bsmsg.Entry) - sendErrors := make(chan error) resetChan := make(chan struct{}, 1) - fullClosedChan := make(chan struct{}, 1) - fakeSender := newFakeMessageSender(nil, fullClosedChan, resetChan, messagesSent, sendErrors, true) + fakeSender := newFakeMessageSender(resetChan, messagesSent, true) fakenet := &fakeMessageNetwork{nil, nil, fakeSender} peerID := testutil.GeneratePeers(1)[0] messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb) @@ -220,10 +197,8 @@ func TestSendingMessagesDeduped(t *testing.T) { func TestSendingMessagesPartialDupe(t *testing.T) { ctx := context.Background() messagesSent := make(chan []bsmsg.Entry) - sendErrors := make(chan error) resetChan := make(chan struct{}, 1) - fullClosedChan := make(chan struct{}, 1) - fakeSender := newFakeMessageSender(nil, fullClosedChan, resetChan, messagesSent, sendErrors, true) + fakeSender := newFakeMessageSender(resetChan, messagesSent, true) fakenet := &fakeMessageNetwork{nil, nil, fakeSender} peerID := testutil.GeneratePeers(1)[0] messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb) @@ -243,10 +218,8 @@ func TestSendingMessagesPartialDupe(t *testing.T) { func TestSendingMessagesPriority(t *testing.T) { ctx := context.Background() messagesSent := make(chan []bsmsg.Entry) - sendErrors := make(chan error) resetChan := make(chan struct{}, 1) - fullClosedChan := make(chan 
struct{}, 1) - fakeSender := newFakeMessageSender(nil, fullClosedChan, resetChan, messagesSent, sendErrors, true) + fakeSender := newFakeMessageSender(resetChan, messagesSent, true) fakenet := &fakeMessageNetwork{nil, nil, fakeSender} peerID := testutil.GeneratePeers(1)[0] messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb) @@ -312,10 +285,8 @@ func TestSendingMessagesPriority(t *testing.T) { func TestCancelOverridesPendingWants(t *testing.T) { ctx := context.Background() messagesSent := make(chan []bsmsg.Entry) - sendErrors := make(chan error) resetChan := make(chan struct{}, 1) - fullClosedChan := make(chan struct{}, 1) - fakeSender := newFakeMessageSender(nil, fullClosedChan, resetChan, messagesSent, sendErrors, true) + fakeSender := newFakeMessageSender(resetChan, messagesSent, true) fakenet := &fakeMessageNetwork{nil, nil, fakeSender} peerID := testutil.GeneratePeers(1)[0] messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb) @@ -364,10 +335,8 @@ func TestCancelOverridesPendingWants(t *testing.T) { func TestWantOverridesPendingCancels(t *testing.T) { ctx := context.Background() messagesSent := make(chan []bsmsg.Entry) - sendErrors := make(chan error) resetChan := make(chan struct{}, 1) - fullClosedChan := make(chan struct{}, 1) - fakeSender := newFakeMessageSender(nil, fullClosedChan, resetChan, messagesSent, sendErrors, true) + fakeSender := newFakeMessageSender(resetChan, messagesSent, true) fakenet := &fakeMessageNetwork{nil, nil, fakeSender} peerID := testutil.GeneratePeers(1)[0] messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb) @@ -412,10 +381,8 @@ func TestWantOverridesPendingCancels(t *testing.T) { func TestWantlistRebroadcast(t *testing.T) { ctx := context.Background() messagesSent := make(chan []bsmsg.Entry) - sendErrors := make(chan error) resetChan := make(chan struct{}, 1) - fullClosedChan := make(chan struct{}, 1) - fakeSender := newFakeMessageSender(nil, fullClosedChan, resetChan, messagesSent, sendErrors, true) + fakeSender := newFakeMessageSender(resetChan, messagesSent, true) fakenet := &fakeMessageNetwork{nil, nil, fakeSender} peerID := testutil.GeneratePeers(1)[0] messageQueue := New(ctx, peerID, fakenet, mockTimeoutCb) @@ -509,10 +476,8 @@ func TestWantlistRebroadcast(t *testing.T) { func TestSendingLargeMessages(t *testing.T) { ctx := context.Background() messagesSent := make(chan []bsmsg.Entry) - sendErrors := make(chan error) resetChan := make(chan struct{}, 1) - fullClosedChan := make(chan struct{}, 1) - fakeSender := newFakeMessageSender(nil, fullClosedChan, resetChan, messagesSent, sendErrors, true) + fakeSender := newFakeMessageSender(resetChan, messagesSent, true) fakenet := &fakeMessageNetwork{nil, nil, fakeSender} dhtm := &fakeDontHaveTimeoutMgr{} peerID := testutil.GeneratePeers(1)[0] @@ -540,10 +505,8 @@ func TestSendingLargeMessages(t *testing.T) { func TestSendToPeerThatDoesntSupportHave(t *testing.T) { ctx := context.Background() messagesSent := make(chan []bsmsg.Entry) - sendErrors := make(chan error) resetChan := make(chan struct{}, 1) - fullClosedChan := make(chan struct{}, 1) - fakeSender := newFakeMessageSender(nil, fullClosedChan, resetChan, messagesSent, sendErrors, false) + fakeSender := newFakeMessageSender(resetChan, messagesSent, false) fakenet := &fakeMessageNetwork{nil, nil, fakeSender} peerID := testutil.GeneratePeers(1)[0] @@ -596,10 +559,8 @@ func TestSendToPeerThatDoesntSupportHave(t *testing.T) { func TestSendToPeerThatDoesntSupportHaveMonitorsTimeouts(t *testing.T) { ctx := context.Background() messagesSent := 
make(chan []bsmsg.Entry) - sendErrors := make(chan error) resetChan := make(chan struct{}, 1) - fullClosedChan := make(chan struct{}, 1) - fakeSender := newFakeMessageSender(nil, fullClosedChan, resetChan, messagesSent, sendErrors, false) + fakeSender := newFakeMessageSender(resetChan, messagesSent, false) fakenet := &fakeMessageNetwork{nil, nil, fakeSender} peerID := testutil.GeneratePeers(1)[0] @@ -626,105 +587,6 @@ func TestSendToPeerThatDoesntSupportHaveMonitorsTimeouts(t *testing.T) { } } -func TestResendAfterError(t *testing.T) { - ctx := context.Background() - messagesSent := make(chan []bsmsg.Entry) - sendErrors := make(chan error) - resetChan := make(chan struct{}, 1) - fullClosedChan := make(chan struct{}, 1) - fakeSender := newFakeMessageSender(nil, fullClosedChan, resetChan, messagesSent, sendErrors, true) - fakenet := &fakeMessageNetwork{nil, nil, fakeSender} - dhtm := &fakeDontHaveTimeoutMgr{} - peerID := testutil.GeneratePeers(1)[0] - sendErrBackoff := 5 * time.Millisecond - messageQueue := newMessageQueue(ctx, peerID, fakenet, maxMessageSize, sendErrBackoff, dhtm) - wantBlocks := testutil.GenerateCids(10) - wantHaves := testutil.GenerateCids(10) - - messageQueue.Startup() - - var errs []error - go func() { - // After the first error is received, clear sendError so that - // subsequent sends will not error - errs = append(errs, <-sendErrors) - fakeSender.clearSendError() - }() - - // Make the first send error out - fakeSender.sendError = errors.New("send err") - messageQueue.AddWants(wantBlocks, wantHaves) - messages := collectMessages(ctx, t, messagesSent, 10*time.Millisecond) - - if len(errs) != 1 { - t.Fatal("Expected first send to error") - } - - if totalEntriesLength(messages) != len(wantHaves)+len(wantBlocks) { - t.Fatal("Expected subsequent send to succeed") - } -} - -func TestResendAfterMaxRetries(t *testing.T) { - ctx := context.Background() - messagesSent := make(chan []bsmsg.Entry) - sendErrors := make(chan error) - resetChan := make(chan struct{}, maxRetries*2) - fullClosedChan := make(chan struct{}, 1) - fakeSender := newFakeMessageSender(nil, fullClosedChan, resetChan, messagesSent, sendErrors, true) - fakenet := &fakeMessageNetwork{nil, nil, fakeSender} - dhtm := &fakeDontHaveTimeoutMgr{} - peerID := testutil.GeneratePeers(1)[0] - sendErrBackoff := 2 * time.Millisecond - messageQueue := newMessageQueue(ctx, peerID, fakenet, maxMessageSize, sendErrBackoff, dhtm) - wantBlocks := testutil.GenerateCids(10) - wantHaves := testutil.GenerateCids(10) - wantBlocks2 := testutil.GenerateCids(10) - wantHaves2 := testutil.GenerateCids(10) - - messageQueue.Startup() - - var lk sync.Mutex - var errs []error - go func() { - lk.Lock() - defer lk.Unlock() - for len(errs) < maxRetries { - err := <-sendErrors - errs = append(errs, err) - } - }() - - // Make the first group of send attempts error out - fakeSender.sendError = errors.New("send err") - messageQueue.AddWants(wantBlocks, wantHaves) - messages := collectMessages(ctx, t, messagesSent, 50*time.Millisecond) - - lk.Lock() - errCount := len(errs) - lk.Unlock() - if errCount != maxRetries { - t.Fatal("Expected maxRetries errors, got", len(errs)) - } - - // No successful send after max retries, so expect no messages sent - if totalEntriesLength(messages) != 0 { - t.Fatal("Expected no messages") - } - - // Clear sendError so that subsequent sends will not error - fakeSender.clearSendError() - - // Add a new batch of wants - messageQueue.AddWants(wantBlocks2, wantHaves2) - messages = collectMessages(ctx, t, messagesSent, 
10*time.Millisecond) - - // All wants from previous and new send should be sent - if totalEntriesLength(messages) != len(wantHaves)+len(wantBlocks)+len(wantHaves2)+len(wantBlocks2) { - t.Fatal("Expected subsequent send to send first and second batches of wants") - } -} - func filterWantTypes(wantlist []bsmsg.Entry) ([]cid.Cid, []cid.Cid, []cid.Cid) { var wbs []cid.Cid var whs []cid.Cid @@ -747,10 +609,8 @@ func BenchmarkMessageQueue(b *testing.B) { createQueue := func() *MessageQueue { messagesSent := make(chan []bsmsg.Entry) - sendErrors := make(chan error) resetChan := make(chan struct{}, 1) - fullClosedChan := make(chan struct{}, 1) - fakeSender := newFakeMessageSender(nil, fullClosedChan, resetChan, messagesSent, sendErrors, true) + fakeSender := newFakeMessageSender(resetChan, messagesSent, true) fakenet := &fakeMessageNetwork{nil, nil, fakeSender} dhtm := &fakeDontHaveTimeoutMgr{} peerID := testutil.GeneratePeers(1)[0] diff --git a/internal/peermanager/peermanager.go b/internal/peermanager/peermanager.go index c2159b19..0cf8b2e3 100644 --- a/internal/peermanager/peermanager.go +++ b/internal/peermanager/peermanager.go @@ -30,17 +30,12 @@ type Session interface { // PeerQueueFactory provides a function that will create a PeerQueue. type PeerQueueFactory func(ctx context.Context, p peer.ID) PeerQueue -type peerQueueInstance struct { - refcnt int - pq PeerQueue -} - // PeerManager manages a pool of peers and sends messages to peers in the pool. type PeerManager struct { // sync access to peerQueues and peerWantManager pqLk sync.RWMutex // peerQueues -- interact through internal utility functions get/set/remove/iterate - peerQueues map[peer.ID]*peerQueueInstance + peerQueues map[peer.ID]PeerQueue pwm *peerWantManager createPeerQueue PeerQueueFactory @@ -57,7 +52,7 @@ type PeerManager struct { func New(ctx context.Context, createPeerQueue PeerQueueFactory, self peer.ID) *PeerManager { wantGauge := metrics.NewCtx(ctx, "wantlist_total", "Number of items in wantlist.").Gauge() return &PeerManager{ - peerQueues: make(map[peer.ID]*peerQueueInstance), + peerQueues: make(map[peer.ID]PeerQueue), pwm: newPeerWantManager(wantGauge), createPeerQueue: createPeerQueue, ctx: ctx, @@ -92,19 +87,15 @@ func (pm *PeerManager) Connected(p peer.ID, initialWantHaves []cid.Cid) { defer pm.pqLk.Unlock() pq := pm.getOrCreate(p) - pq.refcnt++ - - // If this is the first connection to the peer - if pq.refcnt == 1 { - // Inform the peer want manager that there's a new peer - pm.pwm.addPeer(p) - // Record that the want-haves are being sent to the peer - _, wantHaves := pm.pwm.prepareSendWants(p, nil, initialWantHaves) - // Broadcast any live want-haves to the newly connected peers - pq.pq.AddBroadcastWantHaves(wantHaves) - // Inform the sessions that the peer has connected - pm.signalAvailability(p, true) - } + + // Inform the peer want manager that there's a new peer + pm.pwm.addPeer(p) + // Record that the want-haves are being sent to the peer + _, wantHaves := pm.pwm.prepareSendWants(p, nil, initialWantHaves) + // Broadcast any live want-haves to the newly connected peers + pq.AddBroadcastWantHaves(wantHaves) + // Inform the sessions that the peer has connected + pm.signalAvailability(p, true) } // Disconnected is called to remove a peer from the pool. 
@@ -118,17 +109,12 @@ func (pm *PeerManager) Disconnected(p peer.ID) { return } - pq.refcnt-- - if pq.refcnt > 0 { - return - } - // Inform the sessions that the peer has disconnected pm.signalAvailability(p, false) // Clean up the peer delete(pm.peerQueues, p) - pq.pq.Shutdown() + pq.Shutdown() pm.pwm.removePeer(p) } @@ -141,8 +127,8 @@ func (pm *PeerManager) BroadcastWantHaves(ctx context.Context, wantHaves []cid.C defer pm.pqLk.Unlock() for p, ks := range pm.pwm.prepareBroadcastWantHaves(wantHaves) { - if pqi, ok := pm.peerQueues[p]; ok { - pqi.pq.AddBroadcastWantHaves(ks) + if pq, ok := pm.peerQueues[p]; ok { + pq.AddBroadcastWantHaves(ks) } } } @@ -153,9 +139,9 @@ func (pm *PeerManager) SendWants(ctx context.Context, p peer.ID, wantBlocks []ci pm.pqLk.Lock() defer pm.pqLk.Unlock() - if pqi, ok := pm.peerQueues[p]; ok { + if pq, ok := pm.peerQueues[p]; ok { wblks, whvs := pm.pwm.prepareSendWants(p, wantBlocks, wantHaves) - pqi.pq.AddWants(wblks, whvs) + pq.AddWants(wblks, whvs) } } @@ -167,8 +153,8 @@ func (pm *PeerManager) SendCancels(ctx context.Context, cancelKs []cid.Cid) { // Send a CANCEL to each peer that has been sent a want-block or want-have for p, ks := range pm.pwm.prepareSendCancels(cancelKs) { - if pqi, ok := pm.peerQueues[p]; ok { - pqi.pq.AddCancels(ks) + if pq, ok := pm.peerQueues[p]; ok { + pq.AddCancels(ks) } } } @@ -197,15 +183,14 @@ func (pm *PeerManager) CurrentWantHaves() []cid.Cid { return pm.pwm.getWantHaves() } -func (pm *PeerManager) getOrCreate(p peer.ID) *peerQueueInstance { - pqi, ok := pm.peerQueues[p] +func (pm *PeerManager) getOrCreate(p peer.ID) PeerQueue { + pq, ok := pm.peerQueues[p] if !ok { - pq := pm.createPeerQueue(pm.ctx, p) + pq = pm.createPeerQueue(pm.ctx, p) pq.Startup() - pqi = &peerQueueInstance{0, pq} - pm.peerQueues[p] = pqi + pm.peerQueues[p] = pq } - return pqi + return pq } // RegisterSession tells the PeerManager that the given session is interested diff --git a/internal/peermanager/peermanager_test.go b/internal/peermanager/peermanager_test.go index 0305b9f9..f979b2c8 100644 --- a/internal/peermanager/peermanager_test.go +++ b/internal/peermanager/peermanager_test.go @@ -99,7 +99,7 @@ func TestAddingAndRemovingPeers(t *testing.T) { t.Fatal("Peers connected that shouldn't be connected") } - // removing a peer with only one reference + // disconnect a peer peerManager.Disconnected(peer1) connectedPeers = peerManager.ConnectedPeers() @@ -107,13 +107,12 @@ func TestAddingAndRemovingPeers(t *testing.T) { t.Fatal("Peer should have been disconnected but was not") } - // connecting a peer twice, then disconnecting once, should stay in queue - peerManager.Connected(peer2, nil) - peerManager.Disconnected(peer2) + // reconnect peer + peerManager.Connected(peer1, nil) connectedPeers = peerManager.ConnectedPeers() - if !testutil.ContainsPeer(connectedPeers, peer2) { - t.Fatal("Peer was disconnected but should not have been") + if !testutil.ContainsPeer(connectedPeers, peer1) { + t.Fatal("Peer should have been connected but was not") } } diff --git a/network/connecteventmanager.go b/network/connecteventmanager.go new file mode 100644 index 00000000..b28e8e5b --- /dev/null +++ b/network/connecteventmanager.go @@ -0,0 +1,105 @@ +package network + +import ( + "sync" + + "github.com/libp2p/go-libp2p-core/peer" +) + +type ConnectionListener interface { + PeerConnected(peer.ID) + PeerDisconnected(peer.ID) +} + +type connectEventManager struct { + connListener ConnectionListener + lk sync.RWMutex + conns map[peer.ID]*connState +} + +type connState struct 
{ + refs int + responsive bool +} + +func newConnectEventManager(connListener ConnectionListener) *connectEventManager { + return &connectEventManager{ + connListener: connListener, + conns: make(map[peer.ID]*connState), + } +} + +func (c *connectEventManager) Connected(p peer.ID) { + c.lk.Lock() + defer c.lk.Unlock() + + state, ok := c.conns[p] + if !ok { + state = &connState{responsive: true} + c.conns[p] = state + } + state.refs++ + + if state.refs == 1 && state.responsive { + c.connListener.PeerConnected(p) + } +} + +func (c *connectEventManager) Disconnected(p peer.ID) { + c.lk.Lock() + defer c.lk.Unlock() + + state, ok := c.conns[p] + if !ok { + // Should never happen + return + } + state.refs-- + + if state.refs == 0 { + if state.responsive { + c.connListener.PeerDisconnected(p) + } + delete(c.conns, p) + } +} + +func (c *connectEventManager) MarkUnresponsive(p peer.ID) { + c.lk.Lock() + defer c.lk.Unlock() + + state, ok := c.conns[p] + if !ok || !state.responsive { + return + } + state.responsive = false + + c.connListener.PeerDisconnected(p) +} + +func (c *connectEventManager) OnMessage(p peer.ID) { + // This is a frequent operation so to avoid different message arrivals + // getting blocked by a write lock, first take a read lock to check if + // we need to modify state + c.lk.RLock() + state, ok := c.conns[p] + c.lk.RUnlock() + + if !ok || state.responsive { + return + } + + // We need to make a modification so now take a write lock + c.lk.Lock() + defer c.lk.Unlock() + + // Note: state may have changed in the time between when read lock + // was released and write lock taken, so check again + state, ok = c.conns[p] + if !ok || state.responsive { + return + } + + state.responsive = true + c.connListener.PeerConnected(p) +} diff --git a/network/connecteventmanager_test.go b/network/connecteventmanager_test.go new file mode 100644 index 00000000..fb81abee --- /dev/null +++ b/network/connecteventmanager_test.go @@ -0,0 +1,144 @@ +package network + +import ( + "testing" + + "github.com/ipfs/go-bitswap/internal/testutil" + "github.com/libp2p/go-libp2p-core/peer" +) + +type mockConnListener struct { + conns map[peer.ID]int +} + +func newMockConnListener() *mockConnListener { + return &mockConnListener{ + conns: make(map[peer.ID]int), + } +} + +func (cl *mockConnListener) PeerConnected(p peer.ID) { + cl.conns[p]++ +} + +func (cl *mockConnListener) PeerDisconnected(p peer.ID) { + cl.conns[p]-- +} + +func TestConnectEventManagerConnectionCount(t *testing.T) { + connListener := newMockConnListener() + peers := testutil.GeneratePeers(2) + cem := newConnectEventManager(connListener) + + // Peer A: 1 Connection + cem.Connected(peers[0]) + if connListener.conns[peers[0]] != 1 { + t.Fatal("Expected Connected event") + } + + // Peer A: 2 Connections + cem.Connected(peers[0]) + if connListener.conns[peers[0]] != 1 { + t.Fatal("Unexpected no Connected event for the same peer") + } + + // Peer A: 2 Connections + // Peer B: 1 Connection + cem.Connected(peers[1]) + if connListener.conns[peers[1]] != 1 { + t.Fatal("Expected Connected event") + } + + // Peer A: 2 Connections + // Peer B: 0 Connections + cem.Disconnected(peers[1]) + if connListener.conns[peers[1]] != 0 { + t.Fatal("Expected Disconnected event") + } + + // Peer A: 1 Connection + // Peer B: 0 Connections + cem.Disconnected(peers[0]) + if connListener.conns[peers[0]] != 1 { + t.Fatal("Expected no Disconnected event for peer with one remaining conn") + } + + // Peer A: 0 Connections + // Peer B: 0 Connections + cem.Disconnected(peers[0]) + 
if connListener.conns[peers[0]] != 0 { + t.Fatal("Expected Disconnected event") + } +} + +func TestConnectEventManagerMarkUnresponsive(t *testing.T) { + connListener := newMockConnListener() + p := testutil.GeneratePeers(1)[0] + cem := newConnectEventManager(connListener) + + // Peer A: 1 Connection + cem.Connected(p) + if connListener.conns[p] != 1 { + t.Fatal("Expected Connected event") + } + + // Peer A: 1 Connection + cem.MarkUnresponsive(p) + if connListener.conns[p] != 0 { + t.Fatal("Expected Disconnected event") + } + + // Peer A: 2 Connections + cem.Connected(p) + if connListener.conns[p] != 0 { + t.Fatal("Expected no Connected event for unresponsive peer") + } + + // Peer A: 2 Connections + cem.OnMessage(p) + if connListener.conns[p] != 1 { + t.Fatal("Expected Connected event for newly responsive peer") + } + + // Peer A: 2 Connections + cem.OnMessage(p) + if connListener.conns[p] != 1 { + t.Fatal("Expected no further Connected event for subsequent messages") + } + + // Peer A: 1 Connection + cem.Disconnected(p) + if connListener.conns[p] != 1 { + t.Fatal("Expected no Disconnected event for peer with one remaining conn") + } + + // Peer A: 0 Connections + cem.Disconnected(p) + if connListener.conns[p] != 0 { + t.Fatal("Expected Disconnected event") + } +} + +func TestConnectEventManagerDisconnectAfterMarkUnresponsive(t *testing.T) { + connListener := newMockConnListener() + p := testutil.GeneratePeers(1)[0] + cem := newConnectEventManager(connListener) + + // Peer A: 1 Connection + cem.Connected(p) + if connListener.conns[p] != 1 { + t.Fatal("Expected Connected event") + } + + // Peer A: 1 Connection + cem.MarkUnresponsive(p) + if connListener.conns[p] != 0 { + t.Fatal("Expected Disconnected event") + } + + // Peer A: 0 Connections + cem.Disconnected(p) + if connListener.conns[p] != 0 { + t.Fatal("Expected not to receive a second Disconnected event") + } +} diff --git a/network/interface.go b/network/interface.go index 6b2878e3..a350d525 100644 --- a/network/interface.go +++ b/network/interface.go @@ -42,7 +42,7 @@ type BitSwapNetwork interface { ConnectTo(context.Context, peer.ID) error DisconnectFrom(context.Context, peer.ID) error - NewMessageSender(context.Context, peer.ID) (MessageSender, error) + NewMessageSender(context.Context, peer.ID, *MessageSenderOpts) (MessageSender, error) ConnectionManager() connmgr.ConnManager @@ -63,6 +63,12 @@ type MessageSender interface { SupportsHave() bool } +type MessageSenderOpts struct { + MaxRetries int + SendTimeout time.Duration + SendErrorBackoff time.Duration +} + // Receiver is an interface that can receive messages from the BitSwapNetwork. type Receiver interface { ReceiveMessage( diff --git a/network/ipfs_impl.go b/network/ipfs_impl.go index b5661408..e57d37ce 100644 --- a/network/ipfs_impl.go +++ b/network/ipfs_impl.go @@ -2,6 +2,7 @@ package network import ( "context" + "errors" "fmt" "io" "sync/atomic" @@ -22,6 +23,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/protocol/ping" msgio "github.com/libp2p/go-msgio" ma "github.com/multiformats/go-multiaddr" + "github.com/multiformats/go-multistream" ) var log = logging.Logger("bitswap_network") @@ -43,6 +45,7 @@ func NewFromIpfsHost(host host.Host, r routing.ContentRouting, opts ...NetOpt) B supportedProtocols: s.SupportedProtocols, } + return &bitswapNetwork } @@ -71,8 +74,9 @@ type impl struct { // alignment. 
stats Stats - host host.Host - routing routing.ContentRouting + host host.Host + routing routing.ContentRouting + connectEvtMgr *connectEventManager protocolBitswapNoVers protocol.ID protocolBitswapOneZero protocol.ID @@ -86,24 +90,124 @@ type impl struct { } type streamMessageSender struct { - s network.Stream - bsnet *impl + to peer.ID + stream network.Stream + connected bool + bsnet *impl + opts *MessageSenderOpts } -func (s *streamMessageSender) Close() error { - return helpers.FullClose(s.s) +// Open a stream to the remote peer +func (s *streamMessageSender) Connect(ctx context.Context) (network.Stream, error) { + if s.connected { + return s.stream, nil + } + + if err := s.bsnet.ConnectTo(ctx, s.to); err != nil { + return nil, err + } + + stream, err := s.bsnet.newStreamToPeer(ctx, s.to) + if err != nil { + return nil, err + } + + s.stream = stream + s.connected = true + return s.stream, nil } +// Reset the stream func (s *streamMessageSender) Reset() error { - return s.s.Reset() + if s.stream != nil { + err := s.stream.Reset() + s.connected = false + return err + } + return nil } -func (s *streamMessageSender) SendMsg(ctx context.Context, msg bsmsg.BitSwapMessage) error { - return s.bsnet.msgToStream(ctx, s.s, msg) +// Close the stream +func (s *streamMessageSender) Close() error { + return helpers.FullClose(s.stream) } +// Indicates whether the peer supports HAVE / DONT_HAVE messages func (s *streamMessageSender) SupportsHave() bool { - return s.bsnet.SupportsHave(s.s.Protocol()) + return s.bsnet.SupportsHave(s.stream.Protocol()) +} + +// Send a message to the peer, attempting multiple times +func (s *streamMessageSender) SendMsg(ctx context.Context, msg bsmsg.BitSwapMessage) error { + return s.multiAttempt(ctx, func(fnctx context.Context) error { + return s.send(fnctx, msg) + }) +} + +// Perform a function with multiple attempts, and a timeout +func (s *streamMessageSender) multiAttempt(ctx context.Context, fn func(context.Context) error) error { + // Try to call the function repeatedly + var err error + for i := 0; i < s.opts.MaxRetries; i++ { + deadline := time.Now().Add(s.opts.SendTimeout) + sndctx, cancel := context.WithDeadline(ctx, deadline) + + if err = fn(sndctx); err == nil { + cancel() + // Attempt was successful + return nil + } + cancel() + + // Attempt failed + + // If the sender has been closed or the context cancelled, just bail out + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + // Protocol is not supported, so no need to try multiple times + if errors.Is(err, multistream.ErrNotSupported) { + s.bsnet.connectEvtMgr.MarkUnresponsive(s.to) + return err + } + + // Failed to send so reset stream and try again + _ = s.Reset() + + // Failed too many times so mark the peer as unresponsive and return an error + if i == s.opts.MaxRetries-1 { + s.bsnet.connectEvtMgr.MarkUnresponsive(s.to) + return err + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(s.opts.SendErrorBackoff): + // wait a short time in case disconnect notifications are still propagating + log.Infof("send message to %s failed but context was not Done: %s", s.to, err) + } + } + return err +} + +// Send a message to the peer +func (s *streamMessageSender) send(ctx context.Context, msg bsmsg.BitSwapMessage) error { + stream, err := s.Connect(ctx) + if err != nil { + log.Infof("failed to open stream to %s: %s", s.to, err) + return err + } + + if err = s.bsnet.msgToStream(ctx, stream, msg); err != nil { + log.Infof("failed to send message to %s: %s", s.to, err) + 
		return err
+	}
+
+	return nil
 }
 
 func (bsnet *impl) Self() peer.ID {
@@ -164,17 +268,39 @@ func (bsnet *impl) msgToStream(ctx context.Context, s network.Stream, msg bsmsg.
 	return nil
 }
 
-func (bsnet *impl) NewMessageSender(ctx context.Context, p peer.ID) (MessageSender, error) {
-	s, err := bsnet.newStreamToPeer(ctx, p)
+func (bsnet *impl) NewMessageSender(ctx context.Context, p peer.ID, opts *MessageSenderOpts) (MessageSender, error) {
+	opts = setDefaultOpts(opts)
+
+	sender := &streamMessageSender{
+		to:    p,
+		bsnet: bsnet,
+		opts:  opts,
+	}
+
+	err := sender.multiAttempt(ctx, func(fnctx context.Context) error {
+		_, err := sender.Connect(fnctx)
+		return err
+	})
+
 	if err != nil {
 		return nil, err
 	}
-	return &streamMessageSender{s: s, bsnet: bsnet}, nil
+	return sender, nil
 }
 
-func (bsnet *impl) newStreamToPeer(ctx context.Context, p peer.ID) (network.Stream, error) {
-	return bsnet.host.NewStream(ctx, p, bsnet.supportedProtocols...)
+func setDefaultOpts(opts *MessageSenderOpts) *MessageSenderOpts {
+	copy := *opts
+	if opts.MaxRetries == 0 {
+		copy.MaxRetries = 3
+	}
+	if opts.SendTimeout == 0 {
+		copy.SendTimeout = sendMessageTimeout
+	}
+	if opts.SendErrorBackoff == 0 {
+		copy.SendErrorBackoff = 100 * time.Millisecond
+	}
+	return &copy
 }
 
 func (bsnet *impl) SendMessage(
@@ -197,11 +323,15 @@ func (bsnet *impl) SendMessage(
 	//nolint
 	go helpers.AwaitEOF(s)
 	return s.Close()
+}
 
+func (bsnet *impl) newStreamToPeer(ctx context.Context, p peer.ID) (network.Stream, error) {
+	return bsnet.host.NewStream(ctx, p, bsnet.supportedProtocols...)
 }
 
 func (bsnet *impl) SetDelegate(r Receiver) {
 	bsnet.receiver = r
+	bsnet.connectEvtMgr = newConnectEventManager(r)
 	for _, proto := range bsnet.supportedProtocols {
 		bsnet.host.SetStreamHandler(proto, bsnet.handleNewStream)
 	}
@@ -268,6 +398,7 @@ func (bsnet *impl) handleNewStream(s network.Stream) {
 		p := s.Conn().RemotePeer()
 		ctx := context.Background()
 		log.Debugf("bitswap net handleNewStream from %s", s.Conn().RemotePeer())
+		bsnet.connectEvtMgr.OnMessage(s.Conn().RemotePeer())
 		bsnet.receiver.ReceiveMessage(ctx, p, received)
 		atomic.AddUint64(&bsnet.stats.MessagesRecvd, 1)
 	}
@@ -291,10 +422,10 @@ func (nn *netNotifiee) impl() *impl {
 }
 
 func (nn *netNotifiee) Connected(n network.Network, v network.Conn) {
-	nn.impl().receiver.PeerConnected(v.RemotePeer())
+	nn.impl().connectEvtMgr.Connected(v.RemotePeer())
 }
 func (nn *netNotifiee) Disconnected(n network.Network, v network.Conn) {
-	nn.impl().receiver.PeerDisconnected(v.RemotePeer())
+	nn.impl().connectEvtMgr.Disconnected(v.RemotePeer())
 }
 func (nn *netNotifiee) OpenedStream(n network.Network, s network.Stream) {}
 func (nn *netNotifiee) ClosedStream(n network.Network, v network.Stream) {}
diff --git a/network/ipfs_impl_test.go b/network/ipfs_impl_test.go
index 5e0f512b..454bb410 100644
--- a/network/ipfs_impl_test.go
+++ b/network/ipfs_impl_test.go
@@ -2,16 +2,22 @@ package network_test
 
 import (
 	"context"
+	"fmt"
+	"sync"
 	"testing"
 	"time"
 
-	tn "github.com/ipfs/go-bitswap/testnet"
 	bsmsg "github.com/ipfs/go-bitswap/message"
 	pb "github.com/ipfs/go-bitswap/message/pb"
 	bsnet "github.com/ipfs/go-bitswap/network"
+	tn "github.com/ipfs/go-bitswap/testnet"
+	ds "github.com/ipfs/go-datastore"
 	blocksutil "github.com/ipfs/go-ipfs-blocksutil"
 	mockrouting "github.com/ipfs/go-ipfs-routing/mock"
+	"github.com/multiformats/go-multistream"
+
+	"github.com/libp2p/go-libp2p-core/host"
+	"github.com/libp2p/go-libp2p-core/network"
 	"github.com/libp2p/go-libp2p-core/peer"
 	"github.com/libp2p/go-libp2p-core/protocol"
 	tnet
"github.com/libp2p/go-libp2p-testing/net" @@ -22,7 +28,7 @@ import ( type receiver struct { peers map[peer.ID]struct{} messageReceived chan struct{} - connectionEvent chan struct{} + connectionEvent chan bool lastMessage bsmsg.BitSwapMessage lastSender peer.ID } @@ -31,7 +37,7 @@ func newReceiver() *receiver { return &receiver{ peers: make(map[peer.ID]struct{}), messageReceived: make(chan struct{}), - connectionEvent: make(chan struct{}, 1), + connectionEvent: make(chan bool, 1), } } @@ -52,12 +58,96 @@ func (r *receiver) ReceiveError(err error) { func (r *receiver) PeerConnected(p peer.ID) { r.peers[p] = struct{}{} - r.connectionEvent <- struct{}{} + r.connectionEvent <- true } func (r *receiver) PeerDisconnected(p peer.ID) { delete(r.peers, p) - r.connectionEvent <- struct{}{} + r.connectionEvent <- false +} + +var mockNetErr = fmt.Errorf("network err") + +type ErrStream struct { + network.Stream + lk sync.Mutex + err error + timingOut bool +} + +type ErrHost struct { + host.Host + lk sync.Mutex + err error + timingOut bool + streams []*ErrStream +} + +func (es *ErrStream) Write(b []byte) (int, error) { + es.lk.Lock() + defer es.lk.Unlock() + + if es.err != nil { + return 0, es.err + } + if es.timingOut { + return 0, context.DeadlineExceeded + } + return es.Stream.Write(b) +} + +func (eh *ErrHost) Connect(ctx context.Context, pi peer.AddrInfo) error { + eh.lk.Lock() + defer eh.lk.Unlock() + + if eh.err != nil { + return eh.err + } + if eh.timingOut { + return context.DeadlineExceeded + } + return eh.Host.Connect(ctx, pi) +} + +func (eh *ErrHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (network.Stream, error) { + eh.lk.Lock() + defer eh.lk.Unlock() + + if eh.err != nil { + return nil, mockNetErr + } + if eh.timingOut { + return nil, context.DeadlineExceeded + } + stream, err := eh.Host.NewStream(ctx, p, pids...) 
+ estrm := &ErrStream{Stream: stream, err: eh.err, timingOut: eh.timingOut} + + eh.streams = append(eh.streams, estrm) + return estrm, err +} + +func (eh *ErrHost) setError(err error) { + eh.lk.Lock() + defer eh.lk.Unlock() + + eh.err = err + for _, s := range eh.streams { + s.lk.Lock() + s.err = err + s.lk.Unlock() + } +} + +func (eh *ErrHost) setTimeoutState(timingOut bool) { + eh.lk.Lock() + defer eh.lk.Unlock() + + eh.timingOut = timingOut + for _, s := range eh.streams { + s.lk.Lock() + s.timingOut = timingOut + s.lk.Unlock() + } } func TestMessageSendAndReceive(t *testing.T) { @@ -164,13 +254,255 @@ func TestMessageSendAndReceive(t *testing.T) { } } +func TestMessageResendAfterError(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + // create network + mn := mocknet.New(ctx) + mr := mockrouting.NewServer() + streamNet, err := tn.StreamNet(ctx, mn, mr) + if err != nil { + t.Fatal("Unable to setup network") + } + p1 := tnet.RandIdentityOrFatal(t) + p2 := tnet.RandIdentityOrFatal(t) + + h1, err := mn.AddPeer(p1.PrivateKey(), p1.Address()) + if err != nil { + t.Fatal(err) + } + + // Create a special host that we can force to start returning errors + eh := &ErrHost{Host: h1} + routing := mr.ClientWithDatastore(context.TODO(), p1, ds.NewMapDatastore()) + bsnet1 := bsnet.NewFromIpfsHost(eh, routing) + + bsnet2 := streamNet.Adapter(p2) + r1 := newReceiver() + r2 := newReceiver() + bsnet1.SetDelegate(r1) + bsnet2.SetDelegate(r2) + + err = mn.LinkAll() + if err != nil { + t.Fatal(err) + } + err = bsnet1.ConnectTo(ctx, p2.ID()) + if err != nil { + t.Fatal(err) + } + isConnected := <-r1.connectionEvent + if !isConnected { + t.Fatal("Expected connect event") + } + + err = bsnet2.ConnectTo(ctx, p1.ID()) + if err != nil { + t.Fatal(err) + } + + blockGenerator := blocksutil.NewBlockGenerator() + block1 := blockGenerator.Next() + msg := bsmsg.New(false) + msg.AddEntry(block1.Cid(), 1, pb.Message_Wantlist_Block, true) + + testSendErrorBackoff := 100 * time.Millisecond + ms, err := bsnet1.NewMessageSender(ctx, p2.ID(), &bsnet.MessageSenderOpts{ + MaxRetries: 3, + SendTimeout: 100 * time.Millisecond, + SendErrorBackoff: testSendErrorBackoff, + }) + if err != nil { + t.Fatal(err) + } + + // Return an error from the networking layer the next time we try to send + // a message + eh.setError(mockNetErr) + + go func() { + time.Sleep(testSendErrorBackoff / 2) + // Stop throwing errors so that the following attempt to send succeeds + eh.setError(nil) + }() + + // Send message with retries, first one should fail, then subsequent + // message should succeed + err = ms.SendMsg(ctx, msg) + if err != nil { + t.Fatal(err) + } + + select { + case <-ctx.Done(): + t.Fatal("did not receive message sent") + case <-r2.messageReceived: + } +} + +func TestMessageSendTimeout(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + // create network + mn := mocknet.New(ctx) + mr := mockrouting.NewServer() + streamNet, err := tn.StreamNet(ctx, mn, mr) + if err != nil { + t.Fatal("Unable to setup network") + } + p1 := tnet.RandIdentityOrFatal(t) + p2 := tnet.RandIdentityOrFatal(t) + + h1, err := mn.AddPeer(p1.PrivateKey(), p1.Address()) + if err != nil { + t.Fatal(err) + } + + // Create a special host that we can force to start timing out + eh := &ErrHost{Host: h1} + routing := mr.ClientWithDatastore(context.TODO(), p1, ds.NewMapDatastore()) + bsnet1 := bsnet.NewFromIpfsHost(eh, routing) + + bsnet2 := 
streamNet.Adapter(p2)
+	r1 := newReceiver()
+	r2 := newReceiver()
+	bsnet1.SetDelegate(r1)
+	bsnet2.SetDelegate(r2)
+
+	err = mn.LinkAll()
+	if err != nil {
+		t.Fatal(err)
+	}
+	err = bsnet1.ConnectTo(ctx, p2.ID())
+	if err != nil {
+		t.Fatal(err)
+	}
+	isConnected := <-r1.connectionEvent
+	if !isConnected {
+		t.Fatal("Expected connect event")
+	}
+
+	err = bsnet2.ConnectTo(ctx, p1.ID())
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	blockGenerator := blocksutil.NewBlockGenerator()
+	block1 := blockGenerator.Next()
+	msg := bsmsg.New(false)
+	msg.AddEntry(block1.Cid(), 1, pb.Message_Wantlist_Block, true)
+
+	ms, err := bsnet1.NewMessageSender(ctx, p2.ID(), &bsnet.MessageSenderOpts{
+		MaxRetries:       3,
+		SendTimeout:      100 * time.Millisecond,
+		SendErrorBackoff: 100 * time.Millisecond,
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Return a DeadlineExceeded error from the networking layer the next time we try to
+	// send a message
+	eh.setTimeoutState(true)
+
+	// Send message with retries, all attempts should fail
+	err = ms.SendMsg(ctx, msg)
+	if err == nil {
+		t.Fatal("Expected error from SendMsg")
+	}
+
+	select {
+	case <-time.After(500 * time.Millisecond):
+		t.Fatal("Did not receive disconnect event")
+	case isConnected = <-r1.connectionEvent:
+		if isConnected {
+			t.Fatal("Expected disconnect event (got connect event)")
+		}
+	}
+}
+
+func TestMessageSendNotSupportedResponse(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+	defer cancel()
+
+	// create network
+	mn := mocknet.New(ctx)
+	mr := mockrouting.NewServer()
+	streamNet, err := tn.StreamNet(ctx, mn, mr)
+	if err != nil {
+		t.Fatal("Unable to setup network")
+	}
+	p1 := tnet.RandIdentityOrFatal(t)
+	p2 := tnet.RandIdentityOrFatal(t)
+
+	h1, err := mn.AddPeer(p1.PrivateKey(), p1.Address())
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Create a special host that responds with ErrNotSupported
+	eh := &ErrHost{Host: h1}
+	routing := mr.ClientWithDatastore(context.TODO(), p1, ds.NewMapDatastore())
+	bsnet1 := bsnet.NewFromIpfsHost(eh, routing)
+
+	bsnet2 := streamNet.Adapter(p2)
+	r1 := newReceiver()
+	r2 := newReceiver()
+	bsnet1.SetDelegate(r1)
+	bsnet2.SetDelegate(r2)
+
+	err = mn.LinkAll()
+	if err != nil {
+		t.Fatal(err)
+	}
+	err = bsnet1.ConnectTo(ctx, p2.ID())
+	if err != nil {
+		t.Fatal(err)
+	}
+	isConnected := <-r1.connectionEvent
+	if !isConnected {
+		t.Fatal("Expected connect event")
+	}
+
+	err = bsnet2.ConnectTo(ctx, p1.ID())
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	blockGenerator := blocksutil.NewBlockGenerator()
+	block1 := blockGenerator.Next()
+	msg := bsmsg.New(false)
+	msg.AddEntry(block1.Cid(), 1, pb.Message_Wantlist_Block, true)
+
+	eh.setError(multistream.ErrNotSupported)
+	_, err = bsnet1.NewMessageSender(ctx, p2.ID(), &bsnet.MessageSenderOpts{
+		MaxRetries:       3,
+		SendTimeout:      100 * time.Millisecond,
+		SendErrorBackoff: 100 * time.Millisecond,
+	})
+	if err == nil {
+		t.Fatal("Expected ErrNotSupported")
+	}
+
+	select {
+	case <-time.After(500 * time.Millisecond):
+		t.Fatal("Did not receive disconnect event")
+	case isConnected = <-r1.connectionEvent:
+		if isConnected {
+			t.Fatal("Expected disconnect event (got connect event)")
+		}
+	}
+}
+
 func TestSupportsHave(t *testing.T) {
 	ctx := context.Background()
 	mn := mocknet.New(ctx)
 	mr := mockrouting.NewServer()
 	streamNet, err := tn.StreamNet(ctx, mn, mr)
 	if err != nil {
-		t.Fatal("Unable to setup network")
+		t.Fatalf("Unable to setup network: %s", err)
 	}
 
 	type testCase struct {
@@ -199,7 +531,7 @@ func TestSupportsHave(t *testing.T) {
t.Fatal(err) } - senderCurrent, err := bsnet1.NewMessageSender(ctx, p2.ID()) + senderCurrent, err := bsnet1.NewMessageSender(ctx, p2.ID(), &bsnet.MessageSenderOpts{}) if err != nil { t.Fatal(err) } diff --git a/testnet/virtual.go b/testnet/virtual.go index 1e472110..c44b430d 100644 --- a/testnet/virtual.go +++ b/testnet/virtual.go @@ -284,7 +284,7 @@ func (mp *messagePasser) SupportsHave() bool { return false } -func (nc *networkClient) NewMessageSender(ctx context.Context, p peer.ID) (bsnet.MessageSender, error) { +func (nc *networkClient) NewMessageSender(ctx context.Context, p peer.ID, opts *bsnet.MessageSenderOpts) (bsnet.MessageSender, error) { return &messagePasser{ net: nc, target: p,