Fix shutdown bug in #412 (#422)

Merged · 3 commits · Apr 13, 2023
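What this PR does: before this change, a MessageQueue whose run loop exited (for example after exhausting its send retries on a dropped connection) stayed registered in the PeerManager, so later requests to the same peer were handed a dead queue. The fix threads an onShutdown callback from the PeerManager through the queue factories into MessageQueue: when runQueue returns, its deferred cleanup invokes the callback, the manager deletes its map entry, and the next request builds a fresh queue. The sketch below illustrates the mechanism in isolation — simplified stand-in types, not the actual graphsync code:

package main

import (
	"fmt"
	"sync"
)

type peerID string

// queue stands in for messagequeue.MessageQueue: it reports its own exit
// through the onShutdown callback, as this PR adds.
type queue struct {
	p          peerID
	onShutdown func(peerID)
}

func (q *queue) run(stop <-chan struct{}) {
	defer q.onShutdown(q.p) // deregister once the run loop is done
	<-stop                  // stand-in for the real send/retry loop
}

// manager stands in for peermanager.PeerManager.
type manager struct {
	mu     sync.Mutex
	wg     sync.WaitGroup
	queues map[peerID]*queue
}

// getOrCreate hands each new queue a callback that removes the queue from
// the map when it shuts down, mirroring PeerManager.getOrCreate.
func (m *manager) getOrCreate(p peerID, stop <-chan struct{}) *queue {
	m.mu.Lock()
	defer m.mu.Unlock()
	if q, ok := m.queues[p]; ok {
		return q
	}
	q := &queue{p: p, onShutdown: m.onQueueShutdown}
	m.queues[p] = q
	m.wg.Add(1)
	go func() {
		defer m.wg.Done()
		q.run(stop)
	}()
	return q
}

// onQueueShutdown mirrors the method this PR adds to PeerManager.
func (m *manager) onQueueShutdown(p peerID) {
	m.mu.Lock()
	defer m.mu.Unlock()
	delete(m.queues, p)
}

func main() {
	m := &manager{queues: map[peerID]*queue{}}
	stop := make(chan struct{})
	m.getOrCreate("peer-a", stop)

	close(stop) // simulate the queue's run loop exiting
	m.wg.Wait() // wait for the deferred onShutdown to run

	_, alive := m.queues["peer-a"]
	fmt.Println("still registered:", alive) // false: the next request rebuilds the queue
}

The new block in TestNetworkDisconnect below exercises the same path end to end: relink and reconnect the hosts, issue a second request, and verify the whole chain transfers without errors.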
4 changes: 2 additions & 2 deletions impl/graphsync.go
@@ -237,8 +237,8 @@ func New(parent context.Context, network gsnet.GraphSyncNetwork,
 		incomingRequestHooks.Register(selectorvalidator.SelectorValidator(maxRecursionDepth))
 	}
 	responseAllocator := allocator.NewAllocator(gsConfig.totalMaxMemoryResponder, gsConfig.maxMemoryPerPeerResponder)
-	createMessageQueue := func(ctx context.Context, p peer.ID) peermanager.PeerQueue {
-		return messagequeue.New(ctx, p, network, responseAllocator, gsConfig.messageSendRetries, gsConfig.sendMessageTimeout)
+	createMessageQueue := func(ctx context.Context, p peer.ID, onShutdown func(peer.ID)) peermanager.PeerQueue {
+		return messagequeue.New(ctx, p, network, responseAllocator, gsConfig.messageSendRetries, gsConfig.sendMessageTimeout, onShutdown)
 	}
 	peerManager := peermanager.NewMessageManager(ctx, createMessageQueue)
14 changes: 14 additions & 0 deletions impl/graphsync_test.go
@@ -1039,6 +1039,20 @@ func TestNetworkDisconnect(t *testing.T) {
 	drain(requestor)
 	drain(responder)

+	// verify we can execute a request after disconnection
+	_, err = td.mn.LinkPeers(td.host1.ID(), td.host2.ID())
+	require.NoError(t, err)
+	_, err = td.mn.ConnectPeers(td.host1.ID(), td.host2.ID())
+	require.NoError(t, err)
+	requestCtx, requestCancel = context.WithTimeout(ctx, 1*time.Second)
+	defer requestCancel()
+	progressChan, errChan = requestor.Request(requestCtx, td.host2.ID(), blockChain.TipLink, blockChain.Selector(), td.extension)
+	blockChain.VerifyWholeChain(ctx, progressChan)
+	testutil.VerifyEmptyErrors(ctx, t, errChan)
+
+	drain(requestor)
+	drain(responder)
+
 	tracing := collectTracing(t)

 	traceStrings := tracing.TracesToStrings()
5 changes: 4 additions & 1 deletion messagequeue/messagequeue.go
@@ -79,10 +79,11 @@ type MessageQueue struct {
 	allocator          Allocator
 	maxRetries         int
 	sendMessageTimeout time.Duration
+	onShutdown         func(peer.ID)
 }

 // New creats a new MessageQueue.
-func New(ctx context.Context, p peer.ID, network MessageNetwork, allocator Allocator, maxRetries int, sendMessageTimeout time.Duration) *MessageQueue {
+func New(ctx context.Context, p peer.ID, network MessageNetwork, allocator Allocator, maxRetries int, sendMessageTimeout time.Duration, onShutdown func(peer.ID)) *MessageQueue {
 	return &MessageQueue{
 		ctx:     ctx,
 		network: network,
@@ -93,6 +94,7 @@ func New(ctx context.Context, p peer.ID, network MessageNetwork, allocator Alloc
 		allocator:          allocator,
 		maxRetries:         maxRetries,
 		sendMessageTimeout: sendMessageTimeout,
+		onShutdown:         onShutdown,
 	}
 }
@@ -154,6 +156,7 @@ func (mq *MessageQueue) runQueue() {
 	defer func() {
 		_ = mq.allocator.ReleasePeerMemory(mq.p)
 		mq.eventPublisher.Shutdown()
+		mq.onShutdown(mq.p)
 	}()
 	mq.eventPublisher.Startup()
 	for {
26 changes: 13 additions & 13 deletions messagequeue/messagequeue_test.go
@@ -28,7 +28,7 @@ func TestStartupAndShutdown(t *testing.T) {
 	ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
 	defer cancel()

-	peer := testutil.GeneratePeers(1)[0]
+	targetPeer := testutil.GeneratePeers(1)[0]
 	messagesSent := make(chan gsmsg.GraphSyncMessage)
 	resetChan := make(chan struct{}, 1)
 	fullClosedChan := make(chan struct{}, 1)
@@ -37,7 +37,7 @@ func TestStartupAndShutdown(t *testing.T) {
 	messageNetwork := &fakeMessageNetwork{nil, nil, messageSender, &waitGroup}
 	allocator := allocator2.NewAllocator(1<<30, 1<<30)

-	messageQueue := New(ctx, peer, messageNetwork, allocator, messageSendRetries, sendMessageTimeout)
+	messageQueue := New(ctx, targetPeer, messageNetwork, allocator, messageSendRetries, sendMessageTimeout, func(peer.ID) {})
 	messageQueue.Startup()
 	id := graphsync.NewRequestID()
 	priority := graphsync.Priority(rand.Int31())
@@ -62,7 +62,7 @@ func TestShutdownDuringMessageSend(t *testing.T) {
 	ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
 	defer cancel()

-	peer := testutil.GeneratePeers(1)[0]
+	targetPeer := testutil.GeneratePeers(1)[0]
 	messagesSent := make(chan gsmsg.GraphSyncMessage)
 	resetChan := make(chan struct{}, 1)
 	fullClosedChan := make(chan struct{}, 1)
@@ -75,7 +75,7 @@ func TestShutdownDuringMessageSend(t *testing.T) {
 	messageNetwork := &fakeMessageNetwork{nil, nil, messageSender, &waitGroup}
 	allocator := allocator2.NewAllocator(1<<30, 1<<30)

-	messageQueue := New(ctx, peer, messageNetwork, allocator, messageSendRetries, sendMessageTimeout)
+	messageQueue := New(ctx, targetPeer, messageNetwork, allocator, messageSendRetries, sendMessageTimeout, func(peer.ID) {})
 	messageQueue.Startup()
 	id := graphsync.NewRequestID()
 	priority := graphsync.Priority(rand.Int31())
@@ -114,7 +114,7 @@ func TestProcessingNotification(t *testing.T) {
 	ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
 	defer cancel()

-	peer := testutil.GeneratePeers(1)[0]
+	targetPeer := testutil.GeneratePeers(1)[0]
 	messagesSent := make(chan gsmsg.GraphSyncMessage)
 	resetChan := make(chan struct{}, 1)
 	fullClosedChan := make(chan struct{}, 1)
@@ -123,7 +123,7 @@ func TestProcessingNotification(t *testing.T) {
 	messageNetwork := &fakeMessageNetwork{nil, nil, messageSender, &waitGroup}
 	allocator := allocator2.NewAllocator(1<<30, 1<<30)

-	messageQueue := New(ctx, peer, messageNetwork, allocator, messageSendRetries, sendMessageTimeout)
+	messageQueue := New(ctx, targetPeer, messageNetwork, allocator, messageSendRetries, sendMessageTimeout, func(peer.ID) {})
 	messageQueue.Startup()
 	waitGroup.Add(1)
 	blks := testutil.GenerateBlocksOfSize(3, 128)
@@ -187,7 +187,7 @@ func TestDedupingMessages(t *testing.T) {
 	ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
 	defer cancel()

-	peer := testutil.GeneratePeers(1)[0]
+	targetPeer := testutil.GeneratePeers(1)[0]
 	messagesSent := make(chan gsmsg.GraphSyncMessage)
 	resetChan := make(chan struct{}, 1)
 	fullClosedChan := make(chan struct{}, 1)
@@ -196,7 +196,7 @@ func TestDedupingMessages(t *testing.T) {
 	messageNetwork := &fakeMessageNetwork{nil, nil, messageSender, &waitGroup}
 	allocator := allocator2.NewAllocator(1<<30, 1<<30)

-	messageQueue := New(ctx, peer, messageNetwork, allocator, messageSendRetries, sendMessageTimeout)
+	messageQueue := New(ctx, targetPeer, messageNetwork, allocator, messageSendRetries, sendMessageTimeout, func(peer.ID) {})
 	messageQueue.Startup()
 	waitGroup.Add(1)
 	id := graphsync.NewRequestID()
@@ -265,7 +265,7 @@ func TestSendsVeryLargeBlocksResponses(t *testing.T) {
 	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
 	defer cancel()

-	peer := testutil.GeneratePeers(1)[0]
+	targetPeer := testutil.GeneratePeers(1)[0]
 	messagesSent := make(chan gsmsg.GraphSyncMessage)
 	resetChan := make(chan struct{}, 1)
 	fullClosedChan := make(chan struct{}, 1)
@@ -274,7 +274,7 @@ func TestSendsVeryLargeBlocksResponses(t *testing.T) {
 	messageNetwork := &fakeMessageNetwork{nil, nil, messageSender, &waitGroup}
 	allocator := allocator2.NewAllocator(1<<30, 1<<30)

-	messageQueue := New(ctx, peer, messageNetwork, allocator, messageSendRetries, sendMessageTimeout)
+	messageQueue := New(ctx, targetPeer, messageNetwork, allocator, messageSendRetries, sendMessageTimeout, func(peer.ID) {})
 	messageQueue.Startup()
 	waitGroup.Add(1)

@@ -334,7 +334,7 @@ func TestSendsResponsesMemoryPressure(t *testing.T) {
 	// use allocator with very small limit
 	allocator := allocator2.NewAllocator(1000, 1000)

-	messageQueue := New(ctx, p, messageNetwork, allocator, messageSendRetries, sendMessageTimeout)
+	messageQueue := New(ctx, p, messageNetwork, allocator, messageSendRetries, sendMessageTimeout, func(peer.ID) {})
 	messageQueue.Startup()
 	waitGroup.Add(1)

@@ -381,7 +381,7 @@ func TestNetworkErrorClearResponses(t *testing.T) {
 	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
 	defer cancel()

-	peer := testutil.GeneratePeers(1)[0]
+	targetPeer := testutil.GeneratePeers(1)[0]
 	messagesSent := make(chan gsmsg.GraphSyncMessage)
 	resetChan := make(chan struct{}, 1)
 	fullClosedChan := make(chan struct{}, 1)
@@ -393,7 +393,7 @@ func TestNetworkErrorClearResponses(t *testing.T) {
 	allocator := allocator2.NewAllocator(1<<30, 1<<30)

 	// we use only a retry count of 1 to avoid multiple send attempts for each message
-	messageQueue := New(ctx, peer, messageNetwork, allocator, 1, sendMessageTimeout)
+	messageQueue := New(ctx, targetPeer, messageNetwork, allocator, 1, sendMessageTimeout, func(peer.ID) {})
 	messageQueue.Startup()
 	waitGroup.Add(1)
15 changes: 13 additions & 2 deletions peermanager/peermanager.go
@@ -16,7 +16,7 @@ type PeerProcess interface {
 type PeerHandler interface{}

 // PeerProcessFactory provides a function that will create a PeerQueue.
-type PeerProcessFactory func(ctx context.Context, p peer.ID) PeerHandler
+type PeerProcessFactory func(ctx context.Context, p peer.ID, onShutdown func(peer.ID)) PeerHandler

 type peerProcessInstance struct {
 	refcnt int
@@ -105,7 +105,7 @@ func (pm *PeerManager) GetProcess(
 func (pm *PeerManager) getOrCreate(p peer.ID) *peerProcessInstance {
 	pqi, ok := pm.peerProcesses[p]
 	if !ok {
-		pq := pm.createPeerProcess(pm.ctx, p)
+		pq := pm.createPeerProcess(pm.ctx, p, pm.onQueueShutdown)
 		if pprocess, ok := pq.(PeerProcess); ok {
 			pprocess.Startup()
 		}
@@ -114,3 +114,14 @@ func (pm *PeerManager) getOrCreate(p peer.ID) *peerProcessInstance {
 	}
 	return pqi
 }
+
+func (pm *PeerManager) onQueueShutdown(p peer.ID) {
+	pm.peerProcessesLk.Lock()
+	defer pm.peerProcessesLk.Unlock()
+	_, ok := pm.peerProcesses[p]
+	if !ok {
+		return
+	}
+	delete(pm.peerProcesses, p)
+}
2 changes: 1 addition & 1 deletion peermanager/peermanager_test.go
@@ -17,7 +17,7 @@ func (fp *fakePeerProcess) Shutdown() {}

 func TestAddingAndRemovingPeers(t *testing.T) {
 	ctx := context.Background()
-	peerProcessFatory := func(ctx context.Context, p peer.ID) PeerHandler {
+	peerProcessFatory := func(ctx context.Context, p peer.ID, onShutdown func(peer.ID)) PeerHandler {
 		return &fakePeerProcess{}
 	}
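For completeness, a regression test in the spirit of this file could pin down the new behavior directly. The sketch below is hypothetical, not part of the PR: it assumes the test lives in the peermanager package (so it can reach the unexported getOrCreate and onQueueShutdown) and that testutil and require are imported as in the repo's other tests:

func TestProcessRecreatedAfterShutdown(t *testing.T) {
	ctx := context.Background()
	created := 0
	factory := func(ctx context.Context, p peer.ID, onShutdown func(peer.ID)) PeerHandler {
		created++
		return &fakePeerProcess{}
	}
	pm := New(ctx, factory)
	p := testutil.GeneratePeers(1)[0]

	pm.getOrCreate(p)     // first use: the factory runs once
	pm.onQueueShutdown(p) // simulate the queue's run loop exiting
	pm.getOrCreate(p)     // must build a fresh process, not return the dead one

	require.Equal(t, 2, created)
}

Note that getOrCreate is normally entered via GetProcess, which presumably holds peerProcessesLk; calling it directly is only safe here because the sketch is single-goroutine.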
6 changes: 3 additions & 3 deletions peermanager/peermessagemanager.go
@@ -15,7 +15,7 @@ type PeerQueue interface {
 }

 // PeerQueueFactory provides a function that will create a PeerQueue.
-type PeerQueueFactory func(ctx context.Context, p peer.ID) PeerQueue
+type PeerQueueFactory func(ctx context.Context, p peer.ID, onShutdown func(peer.ID)) PeerQueue

 // PeerMessageManager manages message queues for peers
 type PeerMessageManager struct {
@@ -25,8 +25,8 @@ type PeerMessageManager struct {
 // NewMessageManager generates a new manger for sending messages
 func NewMessageManager(ctx context.Context, createPeerQueue PeerQueueFactory) *PeerMessageManager {
 	return &PeerMessageManager{
-		PeerManager: New(ctx, func(ctx context.Context, p peer.ID) PeerHandler {
-			return createPeerQueue(ctx, p)
+		PeerManager: New(ctx, func(ctx context.Context, p peer.ID, onShutdown func(peer.ID)) PeerHandler {
+			return createPeerQueue(ctx, p, onShutdown)
 		}),
 	}
 }
4 changes: 3 additions & 1 deletion peermanager/peermessagemanager_test.go
@@ -27,6 +27,7 @@ var _ PeerQueue = (*fakePeer)(nil)
 type fakePeer struct {
 	p            peer.ID
 	messagesSent chan messageSent
+	onShutdown   func(peer.ID)
 }

 func (fp *fakePeer) AllocateAndBuildMessage(blkSize uint64, buildMessage func(b *messagequeue.Builder)) {
@@ -50,10 +51,11 @@ func (fp *fakePeer) Shutdown() {}
 //}

 func makePeerQueueFactory(messagesSent chan messageSent) PeerQueueFactory {
-	return func(ctx context.Context, p peer.ID) PeerQueue {
+	return func(ctx context.Context, p peer.ID, onShutdown func(peer.ID)) PeerQueue {
 		return &fakePeer{
 			p:            p,
 			messagesSent: messagesSent,
+			onShutdown:   onShutdown,
 		}
 	}
 }
2 changes: 1 addition & 1 deletion responsemanager/responseassembler/responseassembler.go
@@ -69,7 +69,7 @@ type ResponseAssembler struct {
 // New generates a new ResponseAssembler for sending responses
 func New(ctx context.Context, peerHandler PeerMessageHandler) *ResponseAssembler {
 	return &ResponseAssembler{
-		PeerManager: peermanager.New(ctx, func(ctx context.Context, p peer.ID) peermanager.PeerHandler {
+		PeerManager: peermanager.New(ctx, func(ctx context.Context, p peer.ID, onShutdown func(peer.ID)) peermanager.PeerHandler {
 			return newTracker()
 		}),
 		peerHandler: peerHandler,