From caa872f68b2cc3847ba0eb2703b6ca8788506d9d Mon Sep 17 00:00:00 2001 From: Hannah Howard Date: Wed, 8 Jul 2020 00:39:26 -0700 Subject: [PATCH] All changes to date including pause requests & start paused, along with new adds for cleanups and checking of execution (#75) * WIP * feat(graphsync): pause/unpause requests Allow graphsync requests to be paused and unpaused via request cancelling and DoNotSendCIDs extension * fix(requestmanager): refactor executor remove extraneous allocation of closure functions * feat(graphsync): support external request pauses allow pausing requests imperatively via PauseRequest function * fix(lint): fix lint errors * feat(responsemanager): start requests paused add the ability for a hook to pause a response right when it's first received * feat(responsemanager): improve cancellation UX provide a mechanism for responders to learn a requestor cancelled a request and for requestors to learn a request was cancelled * feat(requestmanager): process request cancelled status process the responder returning a request cancelled error code and also support sentinel errors * feat(executor): refactor to remove loader remove loader, also only fire restart request as needed * fix(asyncloader): load requests synchronously when possible * fix(responsemanager): fix external pause Fix external pauses missing a block * fix(responsemanager): do not delay complete listener Run complete listener in same thread as response processing, making it less susceptible to interruption via cancel * fix(responsemanager): fix context check fix checking for context cancellation errors based off of the way ipld-prime does not wrap errors * fix(responsemanager): more precise cancel make cancels only get recorded if actual blocks are not sent -- otherwise the request is considered complete -- and the complete hook always runs * fix(requestmanager): handle non processed pauses Handle the case where a pause is requested but never actually takes place * refactor(responsemanager): handle cancels, correctly this time Properly handle cancels for both paused and unpaused states * fix(errors): remove regex because it appears to be very slow * fix(traverser): fix race condition for shutdown make sure that the traverser is finished in the request executor * fix(deps): mod tidy * fix(executor): add back network error --- graphsync.go | 53 ++ impl/graphsync.go | 102 ++-- impl/graphsync_test.go | 46 ++ ipldutil/traverser.go | 38 +- message/message.go | 44 +- message/message_test.go | 93 ++++ requestmanager/asyncloader/asyncloader.go | 12 +- requestmanager/executor/executor.go | 243 ++++++++-- requestmanager/executor/executor_test.go | 441 +++++++++++++++++ requestmanager/hooks/hooks_test.go | 11 + requestmanager/hooks/responsehooks.go | 9 + requestmanager/loader/loader.go | 84 ---- requestmanager/loader/loader_test.go | 135 ------ requestmanager/requestmanager.go | 213 ++++++-- requestmanager/requestmanager_test.go | 453 ++++++++++-------- requestmanager/testloader/asyncloader.go | 157 ++++++ responsemanager/hooks/blockhooks.go | 8 +- responsemanager/hooks/hooks_test.go | 15 +- .../{completedlisteners.go => listeners.go} | 32 ++ responsemanager/hooks/requesthook.go | 7 + .../peerresponsemanager/peerresponsesender.go | 20 +- responsemanager/queryexecutor.go | 113 +++-- responsemanager/responsemanager.go | 134 ++++-- responsemanager/responsemanager_test.go | 100 +++- .../runtraversal/runtraversal_test.go | 4 + testutil/testchain.go | 16 + 26 files changed, 1922 insertions(+), 661 deletions(-) create mode 100644 
requestmanager/executor/executor_test.go delete mode 100644 requestmanager/loader/loader.go delete mode 100644 requestmanager/loader/loader_test.go create mode 100644 requestmanager/testloader/asyncloader.go rename responsemanager/hooks/{completedlisteners.go => listeners.go} (53%) diff --git a/graphsync.go b/graphsync.go index 724a4191..83aa8202 100644 --- a/graphsync.go +++ b/graphsync.go @@ -86,8 +86,45 @@ const ( RequestFailedLegal = ResponseStatusCode(33) // RequestFailedContentNotFound means the respondent does not have the content. RequestFailedContentNotFound = ResponseStatusCode(34) + // RequestCancelled means the responder was processing the request but decided to stop, for whatever reason + RequestCancelled = ResponseStatusCode(35) ) +// RequestFailedBusyErr is an error message received on the error channel when the peer is busy +type RequestFailedBusyErr struct{} + +func (e RequestFailedBusyErr) Error() string { + return "Request Failed - Peer Is Busy" +} + +// RequestFailedContentNotFoundErr is an error message received on the error channel when the content is not found +type RequestFailedContentNotFoundErr struct{} + +func (e RequestFailedContentNotFoundErr) Error() string { + return "Request Failed - Content Not Found" +} + +// RequestFailedLegalErr is an error message received on the error channel when the request fails for legal reasons +type RequestFailedLegalErr struct{} + +func (e RequestFailedLegalErr) Error() string { + return "Request Failed - For Legal Reasons" +} + +// RequestFailedUnknownErr is an error message received on the error channel when the request fails for unknown reasons +type RequestFailedUnknownErr struct{} + +func (e RequestFailedUnknownErr) Error() string { + return "Request Failed - Unknown Reason" +} + +// RequestCancelledErr is an error message received on the error channel that indicates the responder cancelled a request +type RequestCancelledErr struct{} + +func (e RequestCancelledErr) Error() string { + return "Request Failed - Responder Cancelled" +} + var ( // ErrExtensionAlreadyRegistered means a user extension can be registered only once ErrExtensionAlreadyRegistered = errors.New("extension already registered") @@ -158,6 +195,7 @@ type IncomingRequestHookActions interface { UseLinkTargetNodeStyleChooser(traversal.LinkTargetNodeStyleChooser) TerminateWithError(error) ValidateRequest() + PauseResponse() } // OutgoingBlockHookActions are actions that an outgoing block hook can take to @@ -187,6 +225,7 @@ type IncomingResponseHookActions interface { type IncomingBlockHookActions interface { TerminateWithError(error) UpdateRequestWithExtensions(...ExtensionData) + PauseRequest() } // RequestUpdatedHookActions are actions that can be taken in a request updated hook to @@ -236,6 +275,9 @@ type OnRequestUpdatedHook func(p peer.ID, request RequestData, updateRequest Req // OnResponseCompletedListener provides a way to listen for when responder has finished serving a response type OnResponseCompletedListener func(p peer.ID, request RequestData, status ResponseStatusCode) +// OnRequestorCancelledListener provides a way to listen for responses the requestor cancels +type OnRequestorCancelledListener func(p peer.ID, request RequestData) + // UnregisterHookFunc is a function call to unregister a hook that was previously registered type UnregisterHookFunc func() @@ -268,6 +310,17 @@ type GraphExchange interface { // RegisterCompletedResponseListener adds a listener on the responder for completed responses RegisterCompletedResponseListener(listener 
OnResponseCompletedListener) UnregisterHookFunc + // RegisterRequestorCancelledListener adds a listener on the responder for + // responses cancelled by the requestor + RegisterRequestorCancelledListener(listener OnRequestorCancelledListener) UnregisterHookFunc + + // UnpauseRequest unpauses a request that was paused in a block hook based request ID + // Can also send extensions with unpause + UnpauseRequest(RequestID, ...ExtensionData) error + + // PauseRequest pauses an in progress request (may take 1 or more blocks to process) + PauseRequest(RequestID) error + // UnpauseResponse unpauses a response that was paused in a block hook based on peer ID and request ID // Can also send extensions with unpause UnpauseResponse(peer.ID, RequestID, ...ExtensionData) error diff --git a/impl/graphsync.go b/impl/graphsync.go index 8eb61327..bff7070f 100644 --- a/impl/graphsync.go +++ b/impl/graphsync.go @@ -29,26 +29,27 @@ const maxRecursionDepth = 100 // GraphSync is an instance of a GraphSync exchange that implements // the graphsync protocol. type GraphSync struct { - network gsnet.GraphSyncNetwork - loader ipld.Loader - storer ipld.Storer - requestManager *requestmanager.RequestManager - responseManager *responsemanager.ResponseManager - asyncLoader *asyncloader.AsyncLoader - peerResponseManager *peerresponsemanager.PeerResponseManager - peerTaskQueue *peertaskqueue.PeerTaskQueue - peerManager *peermanager.PeerMessageManager - incomingRequestHooks *responderhooks.IncomingRequestHooks - outgoingBlockHooks *responderhooks.OutgoingBlockHooks - requestUpdatedHooks *responderhooks.RequestUpdatedHooks - completedResponseListeners *responderhooks.CompletedResponseListeners - incomingResponseHooks *requestorhooks.IncomingResponseHooks - outgoingRequestHooks *requestorhooks.OutgoingRequestHooks - incomingBlockHooks *requestorhooks.IncomingBlockHooks - persistenceOptions *persistenceoptions.PersistenceOptions - ctx context.Context - cancel context.CancelFunc - unregisterDefaultValidator graphsync.UnregisterHookFunc + network gsnet.GraphSyncNetwork + loader ipld.Loader + storer ipld.Storer + requestManager *requestmanager.RequestManager + responseManager *responsemanager.ResponseManager + asyncLoader *asyncloader.AsyncLoader + peerResponseManager *peerresponsemanager.PeerResponseManager + peerTaskQueue *peertaskqueue.PeerTaskQueue + peerManager *peermanager.PeerMessageManager + incomingRequestHooks *responderhooks.IncomingRequestHooks + outgoingBlockHooks *responderhooks.OutgoingBlockHooks + requestUpdatedHooks *responderhooks.RequestUpdatedHooks + completedResponseListeners *responderhooks.CompletedResponseListeners + requestorCancelledListeners *responderhooks.RequestorCancelledListeners + incomingResponseHooks *requestorhooks.IncomingResponseHooks + outgoingRequestHooks *requestorhooks.OutgoingRequestHooks + incomingBlockHooks *requestorhooks.IncomingBlockHooks + persistenceOptions *persistenceoptions.PersistenceOptions + ctx context.Context + cancel context.CancelFunc + unregisterDefaultValidator graphsync.UnregisterHookFunc } // Option defines the functional option type that can be used to configure @@ -88,29 +89,31 @@ func New(parent context.Context, network gsnet.GraphSyncNetwork, outgoingBlockHooks := responderhooks.NewBlockHooks() requestUpdatedHooks := responderhooks.NewUpdateHooks() completedResponseListeners := responderhooks.NewCompletedResponseListeners() - responseManager := responsemanager.New(ctx, loader, peerResponseManager, peerTaskQueue, incomingRequestHooks, outgoingBlockHooks, 
requestUpdatedHooks, completedResponseListeners) + requestorCancelledListeners := responderhooks.NewRequestorCancelledListeners() + responseManager := responsemanager.New(ctx, loader, peerResponseManager, peerTaskQueue, incomingRequestHooks, outgoingBlockHooks, requestUpdatedHooks, completedResponseListeners, requestorCancelledListeners) unregisterDefaultValidator := incomingRequestHooks.Register(selectorvalidator.SelectorValidator(maxRecursionDepth)) graphSync := &GraphSync{ - network: network, - loader: loader, - storer: storer, - asyncLoader: asyncLoader, - requestManager: requestManager, - peerManager: peerManager, - persistenceOptions: persistenceOptions, - incomingRequestHooks: incomingRequestHooks, - outgoingBlockHooks: outgoingBlockHooks, - requestUpdatedHooks: requestUpdatedHooks, - completedResponseListeners: completedResponseListeners, - incomingResponseHooks: incomingResponseHooks, - outgoingRequestHooks: outgoingRequestHooks, - incomingBlockHooks: incomingBlockHooks, - peerTaskQueue: peerTaskQueue, - peerResponseManager: peerResponseManager, - responseManager: responseManager, - ctx: ctx, - cancel: cancel, - unregisterDefaultValidator: unregisterDefaultValidator, + network: network, + loader: loader, + storer: storer, + asyncLoader: asyncLoader, + requestManager: requestManager, + peerManager: peerManager, + persistenceOptions: persistenceOptions, + incomingRequestHooks: incomingRequestHooks, + outgoingBlockHooks: outgoingBlockHooks, + requestUpdatedHooks: requestUpdatedHooks, + completedResponseListeners: completedResponseListeners, + requestorCancelledListeners: requestorCancelledListeners, + incomingResponseHooks: incomingResponseHooks, + outgoingRequestHooks: outgoingRequestHooks, + incomingBlockHooks: incomingBlockHooks, + peerTaskQueue: peerTaskQueue, + peerResponseManager: peerResponseManager, + responseManager: responseManager, + ctx: ctx, + cancel: cancel, + unregisterDefaultValidator: unregisterDefaultValidator, } for _, option := range options { @@ -177,6 +180,23 @@ func (gs *GraphSync) RegisterIncomingBlockHook(hook graphsync.OnIncomingBlockHoo return gs.incomingBlockHooks.Register(hook) } +// RegisterRequestorCancelledListener adds a listener on the responder for +// responses cancelled by the requestor +func (gs *GraphSync) RegisterRequestorCancelledListener(listener graphsync.OnRequestorCancelledListener) graphsync.UnregisterHookFunc { + return gs.requestorCancelledListeners.Register(listener) +} + +// UnpauseRequest unpauses a request that was paused in a block hook based request ID +// Can also send extensions with unpause +func (gs *GraphSync) UnpauseRequest(requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error { + return gs.requestManager.UnpauseRequest(requestID, extensions...) +} + +// PauseRequest pauses an in progress request (may take 1 or more blocks to process) +func (gs *GraphSync) PauseRequest(requestID graphsync.RequestID) error { + return gs.requestManager.PauseRequest(requestID) +} + // UnpauseResponse unpauses a response that was paused in a block hook based on peer ID and request ID func (gs *GraphSync) UnpauseResponse(p peer.ID, requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error { return gs.responseManager.UnpauseResponse(p, requestID, extensions...) 
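For reviewers who want to exercise the requestor-side pause API added above, the sketch below (not part of the patch) pauses a request from the new PauseRequest hook action on IncomingBlockHookActions and later resumes it with UnpauseRequest, mirroring the flow used in the tests. The helper name fetchWithPause, its parameters, and the resume channel are hypothetical; it assumes a GraphExchange already constructed with New and a peer able to serve the given root and selector.

```go
package example // illustrative sketch only; not part of this patch

import (
	"context"

	"github.com/ipfs/go-graphsync"
	ipld "github.com/ipld/go-ipld-prime"
	peer "github.com/libp2p/go-libp2p-core/peer"
)

// fetchWithPause (hypothetical) pauses a request from the incoming block hook
// after stopAt blocks, then resumes it when the caller signals on resume.
func fetchWithPause(ctx context.Context, gs graphsync.GraphExchange, p peer.ID, root ipld.Link, selector ipld.Node, stopAt int, resume <-chan struct{}) error {
	requestIDChan := make(chan graphsync.RequestID, 1)
	received := 0
	unregister := gs.RegisterIncomingBlockHook(func(_ peer.ID, response graphsync.ResponseData, _ graphsync.BlockData, ha graphsync.IncomingBlockHookActions) {
		select {
		case requestIDChan <- response.RequestID():
		default:
		}
		received++
		if received == stopAt {
			ha.PauseRequest() // hook action added in this patch
		}
	})
	defer unregister()

	progress, errs := gs.Request(ctx, p, root, selector)

	go func() {
		select {
		case <-ctx.Done():
			return
		case <-resume:
		}
		// Resuming re-sends the request; the executor attaches a DoNotSendCIDs
		// extension covering blocks already received so they are not retransmitted.
		_ = gs.UnpauseRequest(<-requestIDChan)
	}()

	// Drain both channels; each is closed when the request finishes.
	var lastErr error
	for progress != nil || errs != nil {
		select {
		case _, ok := <-progress:
			if !ok {
				progress = nil
			}
		case err, ok := <-errs:
			if !ok {
				errs = nil
			} else {
				lastErr = err
			}
		}
	}
	return lastErr
}
```

PauseRequest(requestID), also added above, covers the imperative case where the pause comes from outside a hook; as its doc comment notes, one or more additional blocks may be processed before the pause takes effect.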
diff --git a/impl/graphsync_test.go b/impl/graphsync_test.go index c6902226..80d8a47e 100644 --- a/impl/graphsync_test.go +++ b/impl/graphsync_test.go @@ -384,6 +384,52 @@ func TestPauseResume(t *testing.T) { require.Len(t, td.blockStore1, blockChainLength, "did not store all blocks") } +func TestPauseResumeRequest(t *testing.T) { + // create network + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 2*time.Second) + defer cancel() + td := newGsTestData(ctx, t) + + // initialize graphsync on first node to make requests + requestor := td.GraphSyncHost1() + + // setup receiving peer to just record message coming in + blockChainLength := 100 + blockSize := 100 + blockChain := testutil.SetupBlockChain(ctx, t, td.loader2, td.storer2, uint64(blockSize), blockChainLength) + + // initialize graphsync on second node to response to requests + _ = td.GraphSyncHost2() + + stopPoint := 50 + blocksReceived := 0 + requestIDChan := make(chan graphsync.RequestID, 1) + requestor.RegisterIncomingBlockHook(func(p peer.ID, responseData graphsync.ResponseData, blockData graphsync.BlockData, hookActions graphsync.IncomingBlockHookActions) { + select { + case requestIDChan <- responseData.RequestID(): + default: + } + blocksReceived++ + if blocksReceived == stopPoint { + hookActions.PauseRequest() + } + }) + + progressChan, errChan := requestor.Request(ctx, td.host2.ID(), blockChain.TipLink, blockChain.Selector(), td.extension) + + blockChain.VerifyResponseRange(ctx, progressChan, 0, stopPoint-1) + timer := time.NewTimer(100 * time.Millisecond) + testutil.AssertDoesReceiveFirst(t, timer.C, "should pause request", progressChan) + + requestID := <-requestIDChan + err := requestor.UnpauseRequest(requestID, td.extensionUpdate) + require.NoError(t, err) + + blockChain.VerifyRemainder(ctx, progressChan, stopPoint-1) + testutil.VerifyEmptyErrors(ctx, t, errChan) + require.Len(t, td.blockStore1, blockChainLength, "did not store all blocks") +} func TestPauseResumeViaUpdate(t *testing.T) { // create network diff --git a/ipldutil/traverser.go b/ipldutil/traverser.go index e5e2fe9d..d5e22c50 100644 --- a/ipldutil/traverser.go +++ b/ipldutil/traverser.go @@ -12,6 +12,14 @@ import ( var defaultVisitor traversal.AdvVisitFn = func(traversal.Progress, ipld.Node, traversal.VisitReason) error { return nil } +// ContextCancelError is a sentinel that indicates the passed in context +// was cancelled +type ContextCancelError struct{} + +func (cp ContextCancelError) Error() string { + return "Context cancelled" +} + // TraversalBuilder defines parameters for an iterative traversal type TraversalBuilder struct { Root ipld.Link @@ -31,6 +39,8 @@ type Traverser interface { Advance(reader io.Reader) error // Error errors the traversal by returning the given error as the result of the next IPLD load Error(err error) + // Shutdown cancels the traversal + Shutdown(ctx context.Context) } type state struct { @@ -47,9 +57,12 @@ type nextResponse struct { // Start initiates the traversal (run in a go routine because the regular // selector traversal expects a call back) -func (tb TraversalBuilder) Start(ctx context.Context) Traverser { +func (tb TraversalBuilder) Start(parentCtx context.Context) Traverser { + ctx, cancel := context.WithCancel(parentCtx) t := &traverser{ + parentCtx: parentCtx, ctx: ctx, + cancel: cancel, root: tb.Root, selector: tb.Selector, visitor: defaultVisitor, @@ -57,6 +70,7 @@ func (tb TraversalBuilder) Start(ctx context.Context) Traverser { awaitRequest: make(chan struct{}, 1), stateChan: make(chan 
state, 1), responses: make(chan nextResponse), + stopped: make(chan struct{}), } if tb.Visitor != nil { t.visitor = tb.Visitor @@ -71,7 +85,9 @@ func (tb TraversalBuilder) Start(ctx context.Context) Traverser { // traverser is a class to perform a selector traversal that stops every time a new block is loaded // and waits for manual input (in the form of advance or error) type traverser struct { + parentCtx context.Context ctx context.Context + cancel func() root ipld.Link selector ipld.Node visitor traversal.AdvVisitFn @@ -83,6 +99,7 @@ type traverser struct { awaitRequest chan struct{} stateChan chan state responses chan nextResponse + stopped chan struct{} } func (t *traverser) checkState() { @@ -91,7 +108,7 @@ func (t *traverser) checkState() { select { case <-t.ctx.Done(): t.isDone = true - t.completionErr = errors.New("Context cancelled") + t.completionErr = ContextCancelError{} case newState := <-t.stateChan: t.isDone = newState.isDone t.completionErr = newState.completionErr @@ -116,15 +133,16 @@ func (t *traverser) start() { case t.awaitRequest <- struct{}{}: } go func() { + defer close(t.stopped) loader := func(lnk ipld.Link, lnkCtx ipld.LinkContext) (io.Reader, error) { select { case <-t.ctx.Done(): - return nil, errors.New("Context cancelled") + return nil, ContextCancelError{} case t.stateChan <- state{false, nil, lnk, lnkCtx}: } select { case <-t.ctx.Done(): - return nil, errors.New("Context cancelled") + return nil, ContextCancelError{} case response := <-t.responses: return response.input, response.err } @@ -158,6 +176,14 @@ func (t *traverser) start() { }() } +func (t *traverser) Shutdown(ctx context.Context) { + t.cancel() + select { + case <-ctx.Done(): + case <-t.stopped: + } +} + // IsComplete returns true if a traversal is complete func (t *traverser) IsComplete() (bool, error) { t.checkState() @@ -179,12 +205,12 @@ func (t *traverser) Advance(reader io.Reader) error { } select { case <-t.ctx.Done(): - return errors.New("context cancelled") + return ContextCancelError{} case t.awaitRequest <- struct{}{}: } select { case <-t.ctx.Done(): - return errors.New("context cancelled") + return ContextCancelError{} case t.responses <- nextResponse{reader, nil}: } return nil diff --git a/message/message.go b/message/message.go index d62675c8..82f3fa1b 100644 --- a/message/message.go +++ b/message/message.go @@ -28,7 +28,8 @@ func IsTerminalFailureCode(status graphsync.ResponseStatusCode) bool { return status == graphsync.RequestFailedBusy || status == graphsync.RequestFailedContentNotFound || status == graphsync.RequestFailedLegal || - status == graphsync.RequestFailedUnknown + status == graphsync.RequestFailedUnknown || + status == graphsync.RequestCancelled } // IsTerminalResponseCode returns true if the response code signals @@ -388,3 +389,44 @@ func (gsr GraphSyncResponse) Extension(name graphsync.ExtensionName) ([]byte, bo return val, true } + +// ReplaceExtensions merges the extensions given extensions into the request to create a new request, +// but always uses new data +func (gsr GraphSyncRequest) ReplaceExtensions(extensions []graphsync.ExtensionData) GraphSyncRequest { + req, _ := gsr.MergeExtensions(extensions, func(name graphsync.ExtensionName, oldData []byte, newData []byte) ([]byte, error) { + return newData, nil + }) + return req +} + +// MergeExtensions merges the given list of extensions to produce a new request with the combination of the old request +// plus the new extensions. 
When an old extension and a new extension are both present, mergeFunc is called to produce +// the result +func (gsr GraphSyncRequest) MergeExtensions(extensions []graphsync.ExtensionData, mergeFunc func(name graphsync.ExtensionName, oldData []byte, newData []byte) ([]byte, error)) (GraphSyncRequest, error) { + if gsr.extensions == nil { + return newRequest(gsr.id, gsr.root, gsr.selector, gsr.priority, gsr.isCancel, gsr.isUpdate, toExtensionsMap(extensions)), nil + } + newExtensionMap := toExtensionsMap(extensions) + combinedExtensions := make(map[string][]byte) + for name, newData := range newExtensionMap { + oldData, ok := gsr.extensions[name] + if !ok { + combinedExtensions[name] = newData + continue + } + resultData, err := mergeFunc(graphsync.ExtensionName(name), oldData, newData) + if err != nil { + return GraphSyncRequest{}, err + } + combinedExtensions[name] = resultData + } + + for name, oldData := range gsr.extensions { + _, ok := combinedExtensions[name] + if ok { + continue + } + combinedExtensions[name] = oldData + } + return newRequest(gsr.id, gsr.root, gsr.selector, gsr.priority, gsr.isCancel, gsr.isUpdate, combinedExtensions), nil +} diff --git a/message/message_test.go b/message/message_test.go index ece03ea2..b14bfec9 100644 --- a/message/message_test.go +++ b/message/message_test.go @@ -2,6 +2,7 @@ package message import ( "bytes" + "errors" "math/rand" "testing" @@ -281,3 +282,95 @@ func TestToNetFromNetEquivalency(t *testing.T) { require.True(t, ok) } } + +func TestMergeExtensions(t *testing.T) { + extensionName1 := graphsync.ExtensionName("graphsync/1") + extensionName2 := graphsync.ExtensionName("graphsync/2") + extensionName3 := graphsync.ExtensionName("graphsync/3") + initialExtensions := []graphsync.ExtensionData{ + { + Name: extensionName1, + Data: []byte("applesauce"), + }, + { + Name: extensionName2, + Data: []byte("hello"), + }, + } + replacementExtensions := []graphsync.ExtensionData{ + { + Name: extensionName2, + Data: []byte("world"), + }, + { + Name: extensionName3, + Data: []byte("cheese"), + }, + } + defaultMergeFunc := func(name graphsync.ExtensionName, oldData []byte, newData []byte) ([]byte, error) { + return []byte(string(oldData) + " " + string(newData)), nil + } + root := testutil.GenerateCids(1)[0] + ssb := builder.NewSelectorSpecBuilder(basicnode.Style.Any) + selector := ssb.Matcher().Node() + id := graphsync.RequestID(rand.Int31()) + priority := graphsync.Priority(rand.Int31()) + defaultRequest := NewRequest(id, root, selector, priority, initialExtensions...) 
+ t.Run("when merging into empty", func(t *testing.T) { + emptyRequest := NewRequest(id, root, selector, priority) + resultRequest, err := emptyRequest.MergeExtensions(replacementExtensions, defaultMergeFunc) + require.NoError(t, err) + require.Equal(t, emptyRequest.ID(), resultRequest.ID()) + require.Equal(t, emptyRequest.Priority(), resultRequest.Priority()) + require.Equal(t, emptyRequest.Root().String(), resultRequest.Root().String()) + require.Equal(t, emptyRequest.Selector(), resultRequest.Selector()) + _, has := resultRequest.Extension(extensionName1) + require.False(t, has) + extData2, has := resultRequest.Extension(extensionName2) + require.True(t, has) + require.Equal(t, []byte("world"), extData2) + extData3, has := resultRequest.Extension(extensionName3) + require.True(t, has) + require.Equal(t, []byte("cheese"), extData3) + }) + t.Run("when merging two requests", func(t *testing.T) { + resultRequest, err := defaultRequest.MergeExtensions(replacementExtensions, defaultMergeFunc) + require.NoError(t, err) + require.Equal(t, defaultRequest.ID(), resultRequest.ID()) + require.Equal(t, defaultRequest.Priority(), resultRequest.Priority()) + require.Equal(t, defaultRequest.Root().String(), resultRequest.Root().String()) + require.Equal(t, defaultRequest.Selector(), resultRequest.Selector()) + extData1, has := resultRequest.Extension(extensionName1) + require.True(t, has) + require.Equal(t, []byte("applesauce"), extData1) + extData2, has := resultRequest.Extension(extensionName2) + require.True(t, has) + require.Equal(t, []byte("hello world"), extData2) + extData3, has := resultRequest.Extension(extensionName3) + require.True(t, has) + require.Equal(t, []byte("cheese"), extData3) + }) + t.Run("when merging errors", func(t *testing.T) { + errorMergeFunc := func(name graphsync.ExtensionName, oldData []byte, newData []byte) ([]byte, error) { + return nil, errors.New("something went wrong") + } + _, err := defaultRequest.MergeExtensions(replacementExtensions, errorMergeFunc) + require.Error(t, err) + }) + t.Run("when merging with replace", func(t *testing.T) { + resultRequest := defaultRequest.ReplaceExtensions(replacementExtensions) + require.Equal(t, defaultRequest.ID(), resultRequest.ID()) + require.Equal(t, defaultRequest.Priority(), resultRequest.Priority()) + require.Equal(t, defaultRequest.Root().String(), resultRequest.Root().String()) + require.Equal(t, defaultRequest.Selector(), resultRequest.Selector()) + extData1, has := resultRequest.Extension(extensionName1) + require.True(t, has) + require.Equal(t, []byte("applesauce"), extData1) + extData2, has := resultRequest.Extension(extensionName2) + require.True(t, has) + require.Equal(t, []byte("world"), extData2) + extData3, has := resultRequest.Extension(extensionName3) + require.True(t, has) + require.Equal(t, []byte("cheese"), extData3) + }) +} diff --git a/requestmanager/asyncloader/asyncloader.go b/requestmanager/asyncloader/asyncloader.go index 2ccc83c5..9e4a0069 100644 --- a/requestmanager/asyncloader/asyncloader.go +++ b/requestmanager/asyncloader/asyncloader.go @@ -123,12 +123,17 @@ func (al *AsyncLoader) ProcessResponse(responses map[graphsync.RequestID]metadat // for errors -- only one message will be sent over either. 
func (al *AsyncLoader) AsyncLoad(requestID graphsync.RequestID, link ipld.Link) <-chan types.AsyncLoadResult { resultChan := make(chan types.AsyncLoadResult, 1) + response := make(chan struct{}, 1) lr := loadattemptqueue.NewLoadRequest(requestID, link, resultChan) select { case <-al.ctx.Done(): resultChan <- types.AsyncLoadResult{Data: nil, Err: errors.New("Context closed")} close(resultChan) - case al.incomingMessages <- &loadRequestMessage{requestID, lr}: + case al.incomingMessages <- &loadRequestMessage{response, requestID, lr}: + } + select { + case <-al.ctx.Done(): + case <-response: } return resultChan } @@ -154,6 +159,7 @@ func (al *AsyncLoader) CleanupRequest(requestID graphsync.RequestID) { } type loadRequestMessage struct { + response chan struct{} requestID graphsync.RequestID loadRequest loadattemptqueue.LoadRequest } @@ -239,6 +245,10 @@ func (lrm *loadRequestMessage) handle(al *AsyncLoader) { _, retry := al.activeRequests[lrm.requestID] loadAttemptQueue := al.getLoadAttemptQueue(al.requestQueues[lrm.requestID]) loadAttemptQueue.AttemptLoad(lrm.loadRequest, retry) + select { + case <-al.ctx.Done(): + case lrm.response <- struct{}{}: + } } func (rpom *registerPersistenceOptionMessage) register(al *AsyncLoader) error { diff --git a/requestmanager/executor/executor.go b/requestmanager/executor/executor.go index 0d78320a..30c96478 100644 --- a/requestmanager/executor/executor.go +++ b/requestmanager/executor/executor.go @@ -1,40 +1,66 @@ package executor import ( + "bytes" "context" + "strings" + "sync/atomic" + "github.com/ipfs/go-cid" "github.com/ipfs/go-graphsync" + "github.com/ipfs/go-graphsync/cidset" "github.com/ipfs/go-graphsync/ipldutil" gsmsg "github.com/ipfs/go-graphsync/message" - "github.com/ipfs/go-graphsync/requestmanager/loader" + "github.com/ipfs/go-graphsync/requestmanager/hooks" + "github.com/ipfs/go-graphsync/requestmanager/types" ipld "github.com/ipld/go-ipld-prime" cidlink "github.com/ipld/go-ipld-prime/linking/cid" "github.com/ipld/go-ipld-prime/traversal" + peer "github.com/libp2p/go-libp2p-core/peer" ) -// RequestExecution runs a single graphsync request with data loaded from the -// asynchronous loader +// AsyncLoadFn is a function which given a request id and an ipld.Link, returns +// a channel which will eventually return data for the link or an err +type AsyncLoadFn func(graphsync.RequestID, ipld.Link) <-chan types.AsyncLoadResult + +// ExecutionEnv are request parameters that last between requests +type ExecutionEnv struct { + Ctx context.Context + SendRequest func(peer.ID, gsmsg.GraphSyncRequest) + RunBlockHooks func(p peer.ID, response graphsync.ResponseData, blk graphsync.BlockData) error + TerminateRequest func(graphsync.RequestID) + WaitForMessages func(ctx context.Context, resumeMessages chan graphsync.ExtensionData) ([]graphsync.ExtensionData, error) + Loader AsyncLoadFn +} + +// RequestExecution are parameters for a single request execution type RequestExecution struct { + Ctx context.Context + P peer.ID + NetworkError chan error Request gsmsg.GraphSyncRequest - SendRequest func(gsmsg.GraphSyncRequest) - Loader loader.AsyncLoadFn - RunBlockHooks func(blk graphsync.BlockData) error - TerminateRequest func() + LastResponse *atomic.Value + DoNotSendCids *cid.Set NodeStyleChooser traversal.LinkTargetNodeStyleChooser + ResumeMessages chan []graphsync.ExtensionData + PauseMessages chan struct{} } // Start begins execution of a request in a go routine -func (re RequestExecution) Start(ctx context.Context) (chan graphsync.ResponseProgress, chan error) 
{ +func (ee ExecutionEnv) Start(re RequestExecution) (chan graphsync.ResponseProgress, chan error) { executor := &requestExecutor{ inProgressChan: make(chan graphsync.ResponseProgress), inProgressErr: make(chan error), - ctx: ctx, + ctx: re.Ctx, + p: re.P, + networkError: re.NetworkError, request: re.Request, - sendRequest: re.SendRequest, - loader: re.Loader, - runBlockHooks: re.RunBlockHooks, - terminateRequest: re.TerminateRequest, + lastResponse: re.LastResponse, + doNotSendCids: re.DoNotSendCids, nodeStyleChooser: re.NodeStyleChooser, + resumeMessages: re.ResumeMessages, + pauseMessages: re.PauseMessages, + env: ee, } executor.sendRequest(executor.request) go executor.run() @@ -42,15 +68,20 @@ func (re RequestExecution) Start(ctx context.Context) (chan graphsync.ResponsePr } type requestExecutor struct { - inProgressChan chan graphsync.ResponseProgress - inProgressErr chan error - ctx context.Context - request gsmsg.GraphSyncRequest - sendRequest func(gsmsg.GraphSyncRequest) - loader loader.AsyncLoadFn - runBlockHooks func(blk graphsync.BlockData) error - terminateRequest func() - nodeStyleChooser traversal.LinkTargetNodeStyleChooser + inProgressChan chan graphsync.ResponseProgress + inProgressErr chan error + ctx context.Context + p peer.ID + networkError chan error + request gsmsg.GraphSyncRequest + lastResponse *atomic.Value + nodeStyleChooser traversal.LinkTargetNodeStyleChooser + resumeMessages chan []graphsync.ExtensionData + pauseMessages chan struct{} + doNotSendCids *cid.Set + env ExecutionEnv + restartNeeded bool + pendingExtensions []graphsync.ExtensionData } func (re *requestExecutor) visitor(tp traversal.Progress, node ipld.Node, tr traversal.VisitReason) error { @@ -65,20 +96,178 @@ func (re *requestExecutor) visitor(tp traversal.Progress, node ipld.Node, tr tra return nil } +func (re *requestExecutor) traverse() error { + traverser := ipldutil.TraversalBuilder{ + Root: cidlink.Link{Cid: re.request.Root()}, + Selector: re.request.Selector(), + Visitor: re.visitor, + Chooser: re.nodeStyleChooser, + }.Start(re.ctx) + defer traverser.Shutdown(context.Background()) + for { + isComplete, err := traverser.IsComplete() + if isComplete { + return err + } + lnk, _ := traverser.CurrentRequest() + resultChan := re.env.Loader(re.request.ID(), lnk) + var result types.AsyncLoadResult + select { + case result = <-resultChan: + default: + err := re.sendRestartAsNeeded() + if err != nil { + return err + } + select { + case <-re.ctx.Done(): + return ipldutil.ContextCancelError{} + case result = <-resultChan: + } + } + err = re.processResult(traverser, lnk, result) + if _, ok := err.(hooks.ErrPaused); ok { + err = re.waitForResume() + if err != nil { + return err + } + err = traverser.Advance(bytes.NewReader(result.Data)) + if err != nil { + return err + } + } else if err != nil { + return err + } + } +} + func (re *requestExecutor) run() { - selector, _ := ipldutil.ParseSelector(re.request.Selector()) - loaderFn := loader.WrapAsyncLoader(re.ctx, re.loader, re.request.ID(), re.inProgressErr, re.runBlockHooks) - err := ipldutil.Traverse(re.ctx, loaderFn, re.nodeStyleChooser, cidlink.Link{Cid: re.request.Root()}, selector, re.visitor) + err := re.traverse() if err != nil { - _, isContextErr := err.(loader.ContextCancelError) - if !isContextErr { + if !isContextErr(err) { select { case <-re.ctx.Done(): case re.inProgressErr <- err: } } } + select { + case networkError := <-re.networkError: + select { + case re.inProgressErr <- networkError: + case <-re.env.Ctx.Done(): + } + default: + } 
re.terminateRequest() close(re.inProgressChan) close(re.inProgressErr) } + +func (re *requestExecutor) sendRequest(request gsmsg.GraphSyncRequest) { + re.env.SendRequest(re.p, request) +} + +func (re *requestExecutor) terminateRequest() { + re.env.TerminateRequest(re.request.ID()) +} + +func (re *requestExecutor) runBlockHooks(blk graphsync.BlockData) error { + response := re.lastResponse.Load().(gsmsg.GraphSyncResponse) + return re.env.RunBlockHooks(re.p, response, blk) +} + +func (re *requestExecutor) waitForResume() error { + select { + case <-re.ctx.Done(): + return ipldutil.ContextCancelError{} + case re.pendingExtensions = <-re.resumeMessages: + re.restartNeeded = true + return nil + } +} + +func (re *requestExecutor) onNewBlockWithPause(block graphsync.BlockData) error { + err := re.onNewBlock(block) + select { + case <-re.pauseMessages: + re.sendRequest(gsmsg.CancelRequest(re.request.ID())) + if err == nil { + err = hooks.ErrPaused{} + } + default: + } + return err +} + +func (re *requestExecutor) onNewBlock(block graphsync.BlockData) error { + re.doNotSendCids.Add(block.Link().(cidlink.Link).Cid) + return re.runBlockHooks(block) +} + +func (re *requestExecutor) processResult(traverser ipldutil.Traverser, link ipld.Link, result types.AsyncLoadResult) error { + if result.Err != nil { + select { + case <-re.ctx.Done(): + return ipldutil.ContextCancelError{} + case re.inProgressErr <- result.Err: + traverser.Error(traversal.SkipMe{}) + return nil + } + } + err := re.onNewBlockWithPause(&blockData{link, result.Local, uint64(len(result.Data))}) + if err != nil { + return err + } + err = traverser.Advance(bytes.NewReader(result.Data)) + if err != nil { + return err + } + return nil +} + +func (re *requestExecutor) sendRestartAsNeeded() error { + if !re.restartNeeded { + return nil + } + extensions := re.pendingExtensions + re.pendingExtensions = nil + re.restartNeeded = false + cidsData, err := cidset.EncodeCidSet(re.doNotSendCids) + if err != nil { + return err + } + extensions = append(extensions, graphsync.ExtensionData{Name: graphsync.ExtensionDoNotSendCIDs, Data: cidsData}) + re.request = re.request.ReplaceExtensions(extensions) + re.sendRequest(re.request) + return nil +} + +func isContextErr(err error) bool { + // TODO: Match with errors.Is when https://github.com/ipld/go-ipld-prime/issues/58 is resolved + return strings.Contains(err.Error(), ipldutil.ContextCancelError{}.Error()) +} + +type blockData struct { + link ipld.Link + local bool + size uint64 +} + +// Link is the link/cid for the block +func (bd *blockData) Link() ipld.Link { + return bd.link +} + +// BlockSize specifies the size of the block +func (bd *blockData) BlockSize() uint64 { + return bd.size +} + +// BlockSize specifies the amount of data actually transmitted over the network +func (bd *blockData) BlockSizeOnWire() uint64 { + if bd.local { + return 0 + } + return bd.size +} diff --git a/requestmanager/executor/executor_test.go b/requestmanager/executor/executor_test.go new file mode 100644 index 00000000..0db6d0e7 --- /dev/null +++ b/requestmanager/executor/executor_test.go @@ -0,0 +1,441 @@ +package executor_test + +import ( + "context" + "errors" + "math/rand" + "sync/atomic" + "testing" + "time" + + "github.com/ipfs/go-cid" + "github.com/ipfs/go-graphsync" + "github.com/ipfs/go-graphsync/cidset" + "github.com/ipfs/go-graphsync/ipldutil" + gsmsg "github.com/ipfs/go-graphsync/message" + "github.com/ipfs/go-graphsync/requestmanager/executor" + "github.com/ipfs/go-graphsync/requestmanager/hooks" + 
"github.com/ipfs/go-graphsync/requestmanager/testloader" + "github.com/ipfs/go-graphsync/requestmanager/types" + "github.com/ipfs/go-graphsync/testutil" + "github.com/ipld/go-ipld-prime" + cidlink "github.com/ipld/go-ipld-prime/linking/cid" + basicnode "github.com/ipld/go-ipld-prime/node/basic" + peer "github.com/libp2p/go-libp2p-core/peer" + "github.com/stretchr/testify/require" +) + +type configureLoaderFn func(p peer.ID, requestID graphsync.RequestID, tbc *testutil.TestBlockChain, fal *testloader.FakeAsyncLoader, startStop [2]int) + +func TestRequestExecutionBlockChain(t *testing.T) { + testCases := map[string]struct { + configureLoader configureLoaderFn + configureRequestExecution func(p peer.ID, requestID graphsync.RequestID, tbc *testutil.TestBlockChain, ree *requestExecutionEnv) + verifyResults func(t *testing.T, tbc *testutil.TestBlockChain, ree *requestExecutionEnv, responses []graphsync.ResponseProgress, receivedErrors []error) + }{ + "simple success case": { + verifyResults: func(t *testing.T, tbc *testutil.TestBlockChain, ree *requestExecutionEnv, responses []graphsync.ResponseProgress, receivedErrors []error) { + tbc.VerifyWholeChainSync(responses) + require.Empty(t, receivedErrors) + require.Equal(t, 0, ree.currentWaitForResumeResult) + require.Equal(t, []requestSent{{ree.p, ree.request}}, ree.requestsSent) + require.Len(t, ree.blookHooksCalled, 10) + require.Equal(t, ree.request.ID(), ree.terminateRequested) + require.True(t, ree.nodeStyleChooserCalled) + }, + }, + "error at block hook": { + configureRequestExecution: func(p peer.ID, requestID graphsync.RequestID, tbc *testutil.TestBlockChain, ree *requestExecutionEnv) { + ree.blockHookResults[blockHookKey{p, requestID, tbc.LinkTipIndex(5)}] = errors.New("something went wrong") + }, + verifyResults: func(t *testing.T, tbc *testutil.TestBlockChain, ree *requestExecutionEnv, responses []graphsync.ResponseProgress, receivedErrors []error) { + tbc.VerifyResponseRangeSync(responses, 0, 5) + require.Len(t, receivedErrors, 1) + require.Regexp(t, "something went wrong", receivedErrors[0].Error()) + require.Equal(t, 0, ree.currentWaitForResumeResult) + require.Equal(t, []requestSent{{ree.p, ree.request}}, ree.requestsSent) + require.Len(t, ree.blookHooksCalled, 6) + require.Equal(t, ree.request.ID(), ree.terminateRequested) + require.True(t, ree.nodeStyleChooserCalled) + }, + }, + "context cancelled": { + configureRequestExecution: func(p peer.ID, requestID graphsync.RequestID, tbc *testutil.TestBlockChain, ree *requestExecutionEnv) { + ree.blockHookResults[blockHookKey{p, requestID, tbc.LinkTipIndex(5)}] = ipldutil.ContextCancelError{} + }, + verifyResults: func(t *testing.T, tbc *testutil.TestBlockChain, ree *requestExecutionEnv, responses []graphsync.ResponseProgress, receivedErrors []error) { + tbc.VerifyResponseRangeSync(responses, 0, 5) + require.Empty(t, receivedErrors) + require.Equal(t, 0, ree.currentWaitForResumeResult) + require.Equal(t, []requestSent{{ree.p, ree.request}}, ree.requestsSent) + require.Len(t, ree.blookHooksCalled, 6) + require.Equal(t, ree.request.ID(), ree.terminateRequested) + require.True(t, ree.nodeStyleChooserCalled) + }, + }, + "simple pause": { + configureRequestExecution: func(p peer.ID, requestID graphsync.RequestID, tbc *testutil.TestBlockChain, ree *requestExecutionEnv) { + ree.blockHookResults[blockHookKey{p, requestID, tbc.LinkTipIndex(5)}] = hooks.ErrPaused{} + ree.waitForResumeResults = append(ree.waitForResumeResults, nil) + ree.loaderRanges = [][2]int{{0, 6}, {6, 10}} + }, + verifyResults: 
func(t *testing.T, tbc *testutil.TestBlockChain, ree *requestExecutionEnv, responses []graphsync.ResponseProgress, receivedErrors []error) { + tbc.VerifyWholeChainSync(responses) + require.Empty(t, receivedErrors) + require.Equal(t, 1, ree.currentWaitForResumeResult) + require.Equal(t, ree.request, ree.requestsSent[0].request) + doNotSendCidsExt, has := ree.requestsSent[1].request.Extension(graphsync.ExtensionDoNotSendCIDs) + require.True(t, has) + cidSet, err := cidset.DecodeCidSet(doNotSendCidsExt) + require.NoError(t, err) + require.Equal(t, 6, cidSet.Len()) + require.Len(t, ree.blookHooksCalled, 10) + require.Equal(t, ree.request.ID(), ree.terminateRequested) + require.True(t, ree.nodeStyleChooserCalled) + }, + }, + "multiple pause": { + configureRequestExecution: func(p peer.ID, requestID graphsync.RequestID, tbc *testutil.TestBlockChain, ree *requestExecutionEnv) { + ree.blockHookResults[blockHookKey{p, requestID, tbc.LinkTipIndex(5)}] = hooks.ErrPaused{} + ree.blockHookResults[blockHookKey{p, requestID, tbc.LinkTipIndex(7)}] = hooks.ErrPaused{} + ree.waitForResumeResults = append(ree.waitForResumeResults, nil, nil) + ree.loaderRanges = [][2]int{{0, 6}, {6, 8}, {8, 10}} + }, + verifyResults: func(t *testing.T, tbc *testutil.TestBlockChain, ree *requestExecutionEnv, responses []graphsync.ResponseProgress, receivedErrors []error) { + tbc.VerifyWholeChainSync(responses) + require.Empty(t, receivedErrors) + require.Equal(t, 2, ree.currentWaitForResumeResult) + require.Equal(t, ree.request, ree.requestsSent[0].request) + doNotSendCidsExt, has := ree.requestsSent[1].request.Extension(graphsync.ExtensionDoNotSendCIDs) + require.True(t, has) + cidSet, err := cidset.DecodeCidSet(doNotSendCidsExt) + require.NoError(t, err) + require.Equal(t, 6, cidSet.Len()) + doNotSendCidsExt, has = ree.requestsSent[2].request.Extension(graphsync.ExtensionDoNotSendCIDs) + require.True(t, has) + cidSet, err = cidset.DecodeCidSet(doNotSendCidsExt) + require.NoError(t, err) + require.Equal(t, 8, cidSet.Len()) + require.Len(t, ree.blookHooksCalled, 10) + require.Equal(t, ree.request.ID(), ree.terminateRequested) + require.True(t, ree.nodeStyleChooserCalled) + }, + }, + "multiple pause with extensions": { + configureRequestExecution: func(p peer.ID, requestID graphsync.RequestID, tbc *testutil.TestBlockChain, ree *requestExecutionEnv) { + ree.blockHookResults[blockHookKey{p, requestID, tbc.LinkTipIndex(5)}] = hooks.ErrPaused{} + ree.blockHookResults[blockHookKey{p, requestID, tbc.LinkTipIndex(7)}] = hooks.ErrPaused{} + ree.waitForResumeResults = append(ree.waitForResumeResults, []graphsync.ExtensionData{ + { + Name: graphsync.ExtensionName("applesauce"), + Data: []byte("cheese 1"), + }, + }, []graphsync.ExtensionData{ + { + Name: graphsync.ExtensionName("applesauce"), + Data: []byte("cheese 2"), + }, + }) + ree.loaderRanges = [][2]int{{0, 6}, {6, 8}, {8, 10}} + }, + verifyResults: func(t *testing.T, tbc *testutil.TestBlockChain, ree *requestExecutionEnv, responses []graphsync.ResponseProgress, receivedErrors []error) { + tbc.VerifyWholeChainSync(responses) + require.Empty(t, receivedErrors) + require.Equal(t, 2, ree.currentWaitForResumeResult) + require.Equal(t, ree.request, ree.requestsSent[0].request) + testExtData, has := ree.requestsSent[1].request.Extension(graphsync.ExtensionName("applesauce")) + require.True(t, has) + require.Equal(t, "cheese 1", string(testExtData)) + doNotSendCidsExt, has := ree.requestsSent[1].request.Extension(graphsync.ExtensionDoNotSendCIDs) + require.True(t, has) + cidSet, err := 
cidset.DecodeCidSet(doNotSendCidsExt) + require.NoError(t, err) + require.Equal(t, 6, cidSet.Len()) + testExtData, has = ree.requestsSent[2].request.Extension(graphsync.ExtensionName("applesauce")) + require.True(t, has) + require.Equal(t, "cheese 2", string(testExtData)) + doNotSendCidsExt, has = ree.requestsSent[2].request.Extension(graphsync.ExtensionDoNotSendCIDs) + require.True(t, has) + cidSet, err = cidset.DecodeCidSet(doNotSendCidsExt) + require.NoError(t, err) + require.Equal(t, 8, cidSet.Len()) + require.Len(t, ree.blookHooksCalled, 10) + require.Equal(t, ree.request.ID(), ree.terminateRequested) + require.True(t, ree.nodeStyleChooserCalled) + }, + }, + "preexisting do not send cids": { + configureRequestExecution: func(p peer.ID, requestID graphsync.RequestID, tbc *testutil.TestBlockChain, ree *requestExecutionEnv) { + ree.doNotSendCids.Add(tbc.GenisisLink.(cidlink.Link).Cid) + ree.blockHookResults[blockHookKey{p, requestID, tbc.LinkTipIndex(5)}] = hooks.ErrPaused{} + ree.waitForResumeResults = append(ree.waitForResumeResults, nil) + ree.loaderRanges = [][2]int{{0, 6}, {6, 10}} + }, + verifyResults: func(t *testing.T, tbc *testutil.TestBlockChain, ree *requestExecutionEnv, responses []graphsync.ResponseProgress, receivedErrors []error) { + tbc.VerifyWholeChainSync(responses) + require.Empty(t, receivedErrors) + require.Equal(t, 1, ree.currentWaitForResumeResult) + require.Equal(t, ree.request, ree.requestsSent[0].request) + doNotSendCidsExt, has := ree.requestsSent[1].request.Extension(graphsync.ExtensionDoNotSendCIDs) + require.True(t, has) + cidSet, err := cidset.DecodeCidSet(doNotSendCidsExt) + require.NoError(t, err) + require.Equal(t, 7, cidSet.Len()) + require.Len(t, ree.blookHooksCalled, 10) + require.Equal(t, ree.request.ID(), ree.terminateRequested) + require.True(t, ree.nodeStyleChooserCalled) + }, + }, + "pause but request is cancelled": { + configureRequestExecution: func(p peer.ID, requestID graphsync.RequestID, tbc *testutil.TestBlockChain, ree *requestExecutionEnv) { + ree.blockHookResults[blockHookKey{p, requestID, tbc.LinkTipIndex(5)}] = hooks.ErrPaused{} + }, + verifyResults: func(t *testing.T, tbc *testutil.TestBlockChain, ree *requestExecutionEnv, responses []graphsync.ResponseProgress, receivedErrors []error) { + tbc.VerifyResponseRangeSync(responses, 0, 5) + require.Empty(t, receivedErrors) + require.Equal(t, 0, ree.currentWaitForResumeResult) + require.Equal(t, []requestSent{{ree.p, ree.request}}, ree.requestsSent) + require.Len(t, ree.blookHooksCalled, 6) + require.Equal(t, ree.request.ID(), ree.terminateRequested) + require.True(t, ree.nodeStyleChooserCalled) + }, + }, + "pause externally": { + configureRequestExecution: func(p peer.ID, requestID graphsync.RequestID, tbc *testutil.TestBlockChain, ree *requestExecutionEnv) { + ree.externalPauses = append(ree.externalPauses, pauseKey{requestID, tbc.LinkTipIndex(5)}) + ree.waitForResumeResults = append(ree.waitForResumeResults, nil) + ree.loaderRanges = [][2]int{{0, 6}, {6, 10}} + }, + verifyResults: func(t *testing.T, tbc *testutil.TestBlockChain, ree *requestExecutionEnv, responses []graphsync.ResponseProgress, receivedErrors []error) { + tbc.VerifyWholeChainSync(responses) + require.Empty(t, receivedErrors) + require.Equal(t, 1, ree.currentPauseResult) + require.Equal(t, 1, ree.currentWaitForResumeResult) + require.Equal(t, ree.request, ree.requestsSent[0].request) + require.True(t, ree.requestsSent[1].request.IsCancel()) + doNotSendCidsExt, has := 
ree.requestsSent[2].request.Extension(graphsync.ExtensionDoNotSendCIDs) + require.True(t, has) + cidSet, err := cidset.DecodeCidSet(doNotSendCidsExt) + require.NoError(t, err) + require.Equal(t, 6, cidSet.Len()) + require.Len(t, ree.blookHooksCalled, 10) + require.Equal(t, ree.request.ID(), ree.terminateRequested) + require.True(t, ree.nodeStyleChooserCalled) + }, + }, + "pause externally multiple": { + configureRequestExecution: func(p peer.ID, requestID graphsync.RequestID, tbc *testutil.TestBlockChain, ree *requestExecutionEnv) { + ree.externalPauses = append(ree.externalPauses, pauseKey{requestID, tbc.LinkTipIndex(5)}, pauseKey{requestID, tbc.LinkTipIndex(7)}) + ree.waitForResumeResults = append(ree.waitForResumeResults, nil, nil) + ree.loaderRanges = [][2]int{{0, 6}, {6, 8}, {8, 10}} + }, + verifyResults: func(t *testing.T, tbc *testutil.TestBlockChain, ree *requestExecutionEnv, responses []graphsync.ResponseProgress, receivedErrors []error) { + tbc.VerifyWholeChainSync(responses) + require.Empty(t, receivedErrors) + require.Equal(t, 2, ree.currentPauseResult) + require.Equal(t, 2, ree.currentWaitForResumeResult) + require.Equal(t, ree.request, ree.requestsSent[0].request) + require.True(t, ree.requestsSent[1].request.IsCancel()) + doNotSendCidsExt, has := ree.requestsSent[2].request.Extension(graphsync.ExtensionDoNotSendCIDs) + require.True(t, has) + cidSet, err := cidset.DecodeCidSet(doNotSendCidsExt) + require.NoError(t, err) + require.Equal(t, 6, cidSet.Len()) + require.True(t, ree.requestsSent[3].request.IsCancel()) + doNotSendCidsExt, has = ree.requestsSent[4].request.Extension(graphsync.ExtensionDoNotSendCIDs) + require.True(t, has) + cidSet, err = cidset.DecodeCidSet(doNotSendCidsExt) + require.NoError(t, err) + require.Equal(t, 8, cidSet.Len()) + require.Len(t, ree.blookHooksCalled, 10) + require.Equal(t, ree.request.ID(), ree.terminateRequested) + require.True(t, ree.nodeStyleChooserCalled) + }, + }, + } + for testCase, data := range testCases { + t.Run(testCase, func(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + loader, storer := testutil.NewTestStore(make(map[ipld.Link][]byte)) + tbc := testutil.SetupBlockChain(ctx, t, loader, storer, 100, 10) + fal := testloader.NewFakeAsyncLoader() + requestID := graphsync.RequestID(rand.Int31()) + p := testutil.GeneratePeers(1)[0] + configureLoader := data.configureLoader + if configureLoader == nil { + configureLoader = func(p peer.ID, requestID graphsync.RequestID, tbc *testutil.TestBlockChain, fal *testloader.FakeAsyncLoader, startStop [2]int) { + fal.SuccessResponseOn(requestID, tbc.Blocks(startStop[0], startStop[1])) + } + } + requestCtx, requestCancel := context.WithCancel(ctx) + ree := &requestExecutionEnv{ + ctx: requestCtx, + cancelFn: requestCancel, + p: p, + resumeMessages: make(chan []graphsync.ExtensionData, 1), + pauseMessages: make(chan struct{}, 1), + blockHookResults: make(map[blockHookKey]error), + doNotSendCids: cid.NewSet(), + request: gsmsg.NewRequest(requestID, tbc.TipLink.(cidlink.Link).Cid, tbc.Selector(), graphsync.Priority(rand.Int31())), + fal: fal, + tbc: tbc, + configureLoader: configureLoader, + } + fal.OnAsyncLoad(ree.checkPause) + if data.configureRequestExecution != nil { + data.configureRequestExecution(p, requestID, tbc, ree) + } + if len(ree.loaderRanges) == 0 { + ree.loaderRanges = [][2]int{{0, 10}} + } + inProgress, inProgressErr := ree.requestExecution() + var responsesReceived []graphsync.ResponseProgress + var 
errorsReceived []error + var inProgressDone, inProgressErrDone bool + for !inProgressDone || !inProgressErrDone { + select { + case response, ok := <-inProgress: + if !ok { + inProgress = nil + inProgressDone = true + } else { + responsesReceived = append(responsesReceived, response) + } + case err, ok := <-inProgressErr: + if !ok { + inProgressErr = nil + inProgressErrDone = true + } else { + errorsReceived = append(errorsReceived, err) + } + case <-ctx.Done(): + t.Fatal("did not complete request") + } + } + data.verifyResults(t, tbc, ree, responsesReceived, errorsReceived) + }) + } +} + +type requestSent struct { + p peer.ID + request gsmsg.GraphSyncRequest +} + +type blockHookKey struct { + p peer.ID + requestID graphsync.RequestID + link ipld.Link +} + +type pauseKey struct { + requestID graphsync.RequestID + link ipld.Link +} + +type requestExecutionEnv struct { + // params + ctx context.Context + cancelFn func() + request gsmsg.GraphSyncRequest + p peer.ID + blockHookResults map[blockHookKey]error + doNotSendCids *cid.Set + waitForResumeResults [][]graphsync.ExtensionData + resumeMessages chan []graphsync.ExtensionData + pauseMessages chan struct{} + externalPauses []pauseKey + loaderRanges [][2]int + + // results + currentPauseResult int + currentWaitForResumeResult int + requestsSent []requestSent + blookHooksCalled []blockHookKey + terminateRequested graphsync.RequestID + nodeStyleChooserCalled bool + + // deps + configureLoader configureLoaderFn + tbc *testutil.TestBlockChain + fal *testloader.FakeAsyncLoader +} + +func (ree *requestExecutionEnv) terminateRequest(requestID graphsync.RequestID) { + ree.terminateRequested = requestID +} + +func (ree *requestExecutionEnv) waitForResume() ([]graphsync.ExtensionData, error) { + if len(ree.waitForResumeResults) <= ree.currentWaitForResumeResult { + return nil, ipldutil.ContextCancelError{} + } + extensions := ree.waitForResumeResults[ree.currentWaitForResumeResult] + ree.currentWaitForResumeResult++ + return extensions, nil +} + +func (ree *requestExecutionEnv) sendRequest(p peer.ID, request gsmsg.GraphSyncRequest) { + ree.requestsSent = append(ree.requestsSent, requestSent{p, request}) + if ree.currentWaitForResumeResult < len(ree.loaderRanges) && !request.IsCancel() { + ree.configureLoader(ree.p, ree.request.ID(), ree.tbc, ree.fal, ree.loaderRanges[ree.currentWaitForResumeResult]) + } +} + +func (ree *requestExecutionEnv) nodeStyleChooser(ipld.Link, ipld.LinkContext) (ipld.NodeStyle, error) { + ree.nodeStyleChooserCalled = true + return basicnode.Style.Any, nil +} + +func (ree *requestExecutionEnv) checkPause(requestID graphsync.RequestID, link ipld.Link, result <-chan types.AsyncLoadResult) { + if ree.currentPauseResult >= len(ree.externalPauses) { + return + } + currentPause := ree.externalPauses[ree.currentPauseResult] + if currentPause.link == link && currentPause.requestID == requestID { + ree.currentPauseResult++ + ree.pauseMessages <- struct{}{} + extensions, err := ree.waitForResume() + if err != nil { + ree.cancelFn() + } else { + ree.resumeMessages <- extensions + } + } +} + +func (ree *requestExecutionEnv) runBlockHooks(p peer.ID, response graphsync.ResponseData, blk graphsync.BlockData) error { + bhk := blockHookKey{p, response.RequestID(), blk.Link()} + ree.blookHooksCalled = append(ree.blookHooksCalled, bhk) + err := ree.blockHookResults[bhk] + if _, ok := err.(hooks.ErrPaused); ok { + extensions, err := ree.waitForResume() + if err != nil { + ree.cancelFn() + } else { + ree.resumeMessages <- extensions + } + } + return 
err +} + +func (ree *requestExecutionEnv) requestExecution() (chan graphsync.ResponseProgress, chan error) { + var lastResponse atomic.Value + lastResponse.Store(gsmsg.NewResponse(ree.request.ID(), graphsync.RequestAcknowledged)) + return executor.ExecutionEnv{ + SendRequest: ree.sendRequest, + RunBlockHooks: ree.runBlockHooks, + TerminateRequest: ree.terminateRequest, + Loader: ree.fal.AsyncLoad, + }.Start(executor.RequestExecution{ + Ctx: ree.ctx, + P: ree.p, + LastResponse: &lastResponse, + Request: ree.request, + DoNotSendCids: ree.doNotSendCids, + NodeStyleChooser: ree.nodeStyleChooser, + ResumeMessages: ree.resumeMessages, + PauseMessages: ree.pauseMessages, + }) +} diff --git a/requestmanager/hooks/hooks_test.go b/requestmanager/hooks/hooks_test.go index fa1b2ea7..addd0c00 100644 --- a/requestmanager/hooks/hooks_test.go +++ b/requestmanager/hooks/hooks_test.go @@ -140,6 +140,17 @@ func TestBlockHookProcessing(t *testing.T) { require.EqualError(t, result.Err, "something went wrong") }, }, + "pause request": { + configure: func(t *testing.T, hooks *hooks.IncomingBlockHooks) { + hooks.Register(func(p peer.ID, responseData graphsync.ResponseData, blockData graphsync.BlockData, hookActions graphsync.IncomingBlockHookActions) { + hookActions.PauseRequest() + }) + }, + assert: func(t *testing.T, result hooks.UpdateResult) { + require.Empty(t, result.Extensions) + require.EqualError(t, result.Err, hooks.ErrPaused{}.Error()) + }, + }, "hooks update with extensions": { configure: func(t *testing.T, hooks *hooks.IncomingBlockHooks) { hooks.Register(func(p peer.ID, responseData graphsync.ResponseData, blockData graphsync.BlockData, hookActions graphsync.IncomingBlockHookActions) { diff --git a/requestmanager/hooks/responsehooks.go b/requestmanager/hooks/responsehooks.go index 2af0afe9..643578a5 100644 --- a/requestmanager/hooks/responsehooks.go +++ b/requestmanager/hooks/responsehooks.go @@ -7,6 +7,11 @@ import ( "github.com/ipfs/go-graphsync" ) +// ErrPaused indicates a request should stop processing, but only because it's paused +type ErrPaused struct{} + +func (e ErrPaused) Error() string { return "request has been paused" } + // IncomingResponseHooks is a set of incoming response hooks that can be processed type IncomingResponseHooks struct { pubSub *pubsub.PubSub @@ -67,3 +72,7 @@ func (rha *updateHookActions) TerminateWithError(err error) { func (rha *updateHookActions) UpdateRequestWithExtensions(extensions ...graphsync.ExtensionData) { rha.extensions = append(rha.extensions, extensions...) 
} + +func (rha *updateHookActions) PauseRequest() { + rha.err = ErrPaused{} +} diff --git a/requestmanager/loader/loader.go b/requestmanager/loader/loader.go deleted file mode 100644 index 089d7e7e..00000000 --- a/requestmanager/loader/loader.go +++ /dev/null @@ -1,84 +0,0 @@ -package loader - -import ( - "bytes" - "context" - "io" - - "github.com/ipfs/go-graphsync" - "github.com/ipfs/go-graphsync/requestmanager/types" - ipld "github.com/ipld/go-ipld-prime" - "github.com/ipld/go-ipld-prime/traversal" -) - -// ContextCancelError is a sentinel that indicates the passed in request context -// was cancelled -type ContextCancelError struct{} - -func (ContextCancelError) Error() string { - return "request context cancelled" -} - -// AsyncLoadFn is a function which given a request id and an ipld.Link, returns -// a channel which will eventually return data for the link or an err -type AsyncLoadFn func(graphsync.RequestID, ipld.Link) <-chan types.AsyncLoadResult - -// OnNewBlockFn is a function that is called whenever a new block is successfully loaded -// before the loader completes -type OnNewBlockFn func(graphsync.BlockData) error - -// WrapAsyncLoader creates a regular ipld link laoder from an asynchronous load -// function, with the given cancellation context, for the given requests, and will -// transmit load errors on the given channel -func WrapAsyncLoader( - ctx context.Context, - asyncLoadFn AsyncLoadFn, - requestID graphsync.RequestID, - errorChan chan error, - onNewBlockFn OnNewBlockFn) ipld.Loader { - return func(link ipld.Link, linkContext ipld.LinkContext) (io.Reader, error) { - resultChan := asyncLoadFn(requestID, link) - select { - case <-ctx.Done(): - return nil, ContextCancelError{} - case result := <-resultChan: - if result.Err != nil { - select { - case <-ctx.Done(): - return nil, ContextCancelError{} - case errorChan <- result.Err: - return nil, traversal.SkipMe{} - } - } - err := onNewBlockFn(&blockData{link, result.Local, uint64(len(result.Data))}) - if err != nil { - return nil, err - } - return bytes.NewReader(result.Data), nil - } - } -} - -type blockData struct { - link ipld.Link - local bool - size uint64 -} - -// Link is the link/cid for the block -func (bd *blockData) Link() ipld.Link { - return bd.link -} - -// BlockSize specifies the size of the block -func (bd *blockData) BlockSize() uint64 { - return bd.size -} - -// BlockSize specifies the amount of data actually transmitted over the network -func (bd *blockData) BlockSizeOnWire() uint64 { - if bd.local { - return 0 - } - return bd.size -} diff --git a/requestmanager/loader/loader_test.go b/requestmanager/loader/loader_test.go deleted file mode 100644 index 23d9d8d1..00000000 --- a/requestmanager/loader/loader_test.go +++ /dev/null @@ -1,135 +0,0 @@ -package loader - -import ( - "context" - "errors" - "io" - "io/ioutil" - "math/rand" - "testing" - "time" - - "github.com/ipfs/go-graphsync" - "github.com/ipfs/go-graphsync/requestmanager/types" - "github.com/stretchr/testify/require" - - "github.com/ipfs/go-graphsync/testutil" - "github.com/ipld/go-ipld-prime" - "github.com/ipld/go-ipld-prime/traversal" -) - -type callParams struct { - requestID graphsync.RequestID - link ipld.Link -} - -func makeAsyncLoadFn(responseChan chan types.AsyncLoadResult, calls chan callParams) AsyncLoadFn { - return func(requestID graphsync.RequestID, link ipld.Link) <-chan types.AsyncLoadResult { - calls <- callParams{requestID, link} - return responseChan - } -} - -func TestWrappedAsyncLoaderReturnsValues(t *testing.T) { - ctx := 
context.Background() - ctx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - responseChan := make(chan types.AsyncLoadResult, 1) - calls := make(chan callParams, 1) - asyncLoadFn := makeAsyncLoadFn(responseChan, calls) - errChan := make(chan error) - requestID := graphsync.RequestID(rand.Int31()) - onNewBlockFn := func(graphsync.BlockData) error { return nil } - loader := WrapAsyncLoader(ctx, asyncLoadFn, requestID, errChan, onNewBlockFn) - - link := testutil.NewTestLink() - - data := testutil.RandomBytes(100) - responseChan <- types.AsyncLoadResult{Data: data, Err: nil} - stream, err := loader(link, ipld.LinkContext{}) - require.NoError(t, err, "should load") - returnedData, err := ioutil.ReadAll(stream) - require.NoError(t, err, "stream did not read") - require.Equal(t, data, returnedData, "should return correct data") -} - -func TestWrappedAsyncLoaderSideChannelsErrors(t *testing.T) { - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - responseChan := make(chan types.AsyncLoadResult, 1) - calls := make(chan callParams, 1) - asyncLoadFn := makeAsyncLoadFn(responseChan, calls) - errChan := make(chan error, 1) - requestID := graphsync.RequestID(rand.Int31()) - onNewBlockFn := func(graphsync.BlockData) error { return nil } - loader := WrapAsyncLoader(ctx, asyncLoadFn, requestID, errChan, onNewBlockFn) - - link := testutil.NewTestLink() - err := errors.New("something went wrong") - responseChan <- types.AsyncLoadResult{Data: nil, Err: err} - stream, loadErr := loader(link, ipld.LinkContext{}) - require.Nil(t, stream, "should return nil reader") - _, isSkipErr := loadErr.(traversal.SkipMe) - require.True(t, isSkipErr) - var returnedErr error - testutil.AssertReceive(ctx, t, errChan, &returnedErr, "should return an error on side channel") - require.EqualError(t, returnedErr, err.Error()) -} - -func TestWrappedAsyncLoaderContextCancels(t *testing.T) { - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - subCtx, subCancel := context.WithCancel(ctx) - responseChan := make(chan types.AsyncLoadResult, 1) - calls := make(chan callParams, 1) - asyncLoadFn := makeAsyncLoadFn(responseChan, calls) - errChan := make(chan error, 1) - requestID := graphsync.RequestID(rand.Int31()) - onNewBlockFn := func(graphsync.BlockData) error { return nil } - loader := WrapAsyncLoader(subCtx, asyncLoadFn, requestID, errChan, onNewBlockFn) - link := testutil.NewTestLink() - resultsChan := make(chan struct { - io.Reader - error - }) - go func() { - stream, err := loader(link, ipld.LinkContext{}) - resultsChan <- struct { - io.Reader - error - }{stream, err} - }() - subCancel() - - var result struct { - io.Reader - error - } - testutil.AssertReceive(ctx, t, resultsChan, &result, "should return from sub context cancelling") - require.Nil(t, result.Reader) - require.Error(t, result.error, "should error from sub context cancelling") -} - -func TestWrappedAsyncLoaderBlockHookErrors(t *testing.T) { - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - responseChan := make(chan types.AsyncLoadResult, 1) - calls := make(chan callParams, 1) - asyncLoadFn := makeAsyncLoadFn(responseChan, calls) - errChan := make(chan error, 1) - requestID := graphsync.RequestID(rand.Int31()) - blockHookErr := errors.New("Something went wrong") - onNewBlockFn := func(graphsync.BlockData) error { return blockHookErr } - loader := WrapAsyncLoader(ctx, asyncLoadFn, 
requestID, errChan, onNewBlockFn) - - link := testutil.NewTestLink() - - data := testutil.RandomBytes(100) - responseChan <- types.AsyncLoadResult{Data: data, Err: nil} - stream, err := loader(link, ipld.LinkContext{}) - require.Nil(t, stream, "should return nil reader") - require.EqualError(t, err, blockHookErr.Error()) -} diff --git a/requestmanager/requestmanager.go b/requestmanager/requestmanager.go index f59cb41b..4453e914 100644 --- a/requestmanager/requestmanager.go +++ b/requestmanager/requestmanager.go @@ -2,9 +2,12 @@ package requestmanager import ( "context" + "errors" "fmt" "sync/atomic" + "github.com/ipfs/go-cid" + "github.com/ipfs/go-graphsync/cidset" "github.com/ipfs/go-graphsync/requestmanager/executor" "github.com/ipfs/go-graphsync/requestmanager/hooks" @@ -28,11 +31,14 @@ const ( ) type inProgressRequestStatus struct { - ctx context.Context - cancelFn func() - p peer.ID - networkError chan error - lastResponse atomic.Value + ctx context.Context + cancelFn func() + p peer.ID + networkError chan error + resumeMessages chan []graphsync.ExtensionData + pauseMessages chan struct{} + paused bool + lastResponse atomic.Value } // PeerHandler is an interface that can send requests to peers @@ -181,6 +187,7 @@ func (rm *RequestManager) singleErrorResponse(err error) (chan graphsync.Respons type cancelRequestMessage struct { requestID graphsync.RequestID + isPause bool } func (rm *RequestManager) cancelRequest(requestID graphsync.RequestID, @@ -189,7 +196,7 @@ func (rm *RequestManager) cancelRequest(requestID graphsync.RequestID, cancelMessageChannel := rm.messages for cancelMessageChannel != nil || incomingResponses != nil || incomingErrors != nil { select { - case cancelMessageChannel <- &cancelRequestMessage{requestID}: + case cancelMessageChannel <- &cancelRequestMessage{requestID, false}: cancelMessageChannel = nil // clear out any remaining responses, in case and "incoming reponse" // messages get processed before our cancel message @@ -223,6 +230,44 @@ func (rm *RequestManager) ProcessResponses(p peer.ID, responses []gsmsg.GraphSyn } } +type unpauseRequestMessage struct { + id graphsync.RequestID + extensions []graphsync.ExtensionData + response chan error +} + +// UnpauseRequest unpauses a request that was paused in a block hook based request ID +// Can also send extensions with unpause +func (rm *RequestManager) UnpauseRequest(requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error { + response := make(chan error, 1) + return rm.sendSyncMessage(&unpauseRequestMessage{requestID, extensions, response}, response) +} + +type pauseRequestMessage struct { + id graphsync.RequestID + response chan error +} + +// PauseRequest pauses an in progress request (may take 1 or more blocks to process) +func (rm *RequestManager) PauseRequest(requestID graphsync.RequestID) error { + response := make(chan error, 1) + return rm.sendSyncMessage(&pauseRequestMessage{requestID, response}, response) +} + +func (rm *RequestManager) sendSyncMessage(message requestManagerMessage, response chan error) error { + select { + case <-rm.ctx.Done(): + return errors.New("Context Cancelled") + case rm.messages <- message: + } + select { + case <-rm.ctx.Done(): + return errors.New("Context Cancelled") + case err := <-response: + return err + } +} + // Startup starts processing for the WantManager. 
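A minimal usage sketch of the pause/unpause API added above, assuming the caller holds the *RequestManager as rm and the graphsync.RequestID of an in-progress request; the helper name and extension values below are illustrative only, not part of this change:

// pauseAndResume pauses an in-progress request and later resumes it, optionally
// piggy-backing extension data onto the resume. In this change, pausing sends a
// cancel on the wire while keeping local state, and unpausing re-requests with a
// do-not-send-CIDs extension so blocks already received are not retransmitted.
func pauseAndResume(rm *requestmanager.RequestManager, requestID graphsync.RequestID, extensions ...graphsync.ExtensionData) error {
	if err := rm.PauseRequest(requestID); err != nil {
		return err // e.g. "request is already paused" or "request not found"
	}
	// ... do other work while the responder stops sending blocks ...
	return rm.UnpauseRequest(requestID, extensions...)
}
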
func (rm *RequestManager) Startup() { go rm.run() @@ -258,40 +303,58 @@ type terminateRequestMessage struct { requestID graphsync.RequestID } +func (nrm *newRequestMessage) setupRequest(requestID graphsync.RequestID, rm *RequestManager) (chan graphsync.ResponseProgress, chan error) { + request, hooksResult, err := rm.validateRequest(requestID, nrm.p, nrm.root, nrm.selector, nrm.extensions) + if err != nil { + return rm.singleErrorResponse(err) + } + doNotSendCidsData, has := request.Extension(graphsync.ExtensionDoNotSendCIDs) + var doNotSendCids *cid.Set + if has { + doNotSendCids, err = cidset.DecodeCidSet(doNotSendCidsData) + if err != nil { + return rm.singleErrorResponse(err) + } + } else { + doNotSendCids = cid.NewSet() + } + ctx, cancel := context.WithCancel(rm.ctx) + p := nrm.p + resumeMessages := make(chan []graphsync.ExtensionData, 1) + pauseMessages := make(chan struct{}, 1) + networkError := make(chan error, 1) + requestStatus := &inProgressRequestStatus{ + ctx: ctx, cancelFn: cancel, p: p, resumeMessages: resumeMessages, pauseMessages: pauseMessages, networkError: networkError, + } + lastResponse := &requestStatus.lastResponse + lastResponse.Store(gsmsg.NewResponse(request.ID(), graphsync.RequestAcknowledged)) + rm.inProgressRequestStatuses[request.ID()] = requestStatus + incoming, incomingError := executor.ExecutionEnv{ + Ctx: rm.ctx, + SendRequest: rm.peerHandler.SendRequest, + TerminateRequest: rm.terminateRequest, + RunBlockHooks: rm.processBlockHooks, + Loader: rm.asyncLoader.AsyncLoad, + }.Start( + executor.RequestExecution{ + Ctx: ctx, + P: p, + Request: request, + NetworkError: networkError, + LastResponse: lastResponse, + DoNotSendCids: doNotSendCids, + NodeStyleChooser: hooksResult.CustomChooser, + ResumeMessages: resumeMessages, + PauseMessages: pauseMessages, + }) + return incoming, incomingError +} + func (nrm *newRequestMessage) handle(rm *RequestManager) { var ipr inProgressRequest ipr.requestID = rm.nextRequestID rm.nextRequestID++ - request, hooksResult, err := rm.validateRequest(ipr.requestID, nrm.p, nrm.root, nrm.selector, nrm.extensions) - if err != nil { - ipr.incoming, ipr.incomingError = rm.singleErrorResponse(err) - } else { - ctx, cancel := context.WithCancel(rm.ctx) - p := nrm.p - requestStatus := &inProgressRequestStatus{ - ctx: ctx, cancelFn: cancel, p: p, - } - lastResponse := &requestStatus.lastResponse - lastResponse.Store(gsmsg.NewResponse(request.ID(), graphsync.RequestAcknowledged)) - rm.inProgressRequestStatuses[request.ID()] = requestStatus - ipr.incoming, ipr.incomingError = executor.RequestExecution{ - Request: request, - SendRequest: func(gsRequest gsmsg.GraphSyncRequest) { rm.peerHandler.SendRequest(p, gsRequest) }, - Loader: rm.asyncLoader.AsyncLoad, - RunBlockHooks: func(bd graphsync.BlockData) error { - response := lastResponse.Load().(gsmsg.GraphSyncResponse) - return rm.processBlockHooks(p, response, bd) - }, - TerminateRequest: func() { - select { - case <-ctx.Done(): - case rm.messages <- &terminateRequestMessage{request.ID()}: - } - }, - NodeStyleChooser: hooksResult.CustomChooser, - }.Start(ctx) - requestStatus.networkError = ipr.incomingError - } + ipr.incoming, ipr.incomingError = nrm.setupRequest(ipr.requestID, rm) select { case nrm.inProgressRequestChan <- ipr: @@ -311,8 +374,11 @@ func (crm *cancelRequestMessage) handle(rm *RequestManager) { } rm.peerHandler.SendRequest(inProgressRequestStatus.p, gsmsg.CancelRequest(crm.requestID)) - delete(rm.inProgressRequestStatuses, crm.requestID) - inProgressRequestStatus.cancelFn() + 
if crm.isPause { + inProgressRequestStatus.paused = true + } else { + inProgressRequestStatus.cancelFn() + } } func (prm *processResponseMessage) handle(rm *RequestManager) { @@ -386,7 +452,6 @@ func (rm *RequestManager) processTerminations(responses []gsmsg.GraphSyncRespons requestStatus.cancelFn() } rm.asyncLoader.CompleteResponsesFor(response.RequestID()) - delete(rm.inProgressRequestStatuses, response.RequestID()) } } } @@ -394,13 +459,15 @@ func (rm *RequestManager) processTerminations(responses []gsmsg.GraphSyncRespons func (rm *RequestManager) generateResponseErrorFromStatus(status graphsync.ResponseStatusCode) error { switch status { case graphsync.RequestFailedBusy: - return fmt.Errorf("Request Failed - Peer Is Busy") + return graphsync.RequestFailedBusyErr{} case graphsync.RequestFailedContentNotFound: - return fmt.Errorf("Request Failed - Content Not Found") + return graphsync.RequestFailedContentNotFoundErr{} case graphsync.RequestFailedLegal: - return fmt.Errorf("Request Failed - For Legal Reasons") + return graphsync.RequestFailedLegalErr{} case graphsync.RequestFailedUnknown: - return fmt.Errorf("Request Failed - Unknown Reason") + return graphsync.RequestFailedUnknownErr{} + case graphsync.RequestCancelled: + return graphsync.RequestCancelledErr{} default: return fmt.Errorf("Unknown") } @@ -413,14 +480,22 @@ func (rm *RequestManager) processBlockHooks(p peer.ID, response graphsync.Respon rm.peerHandler.SendRequest(p, updateRequest) } if result.Err != nil { + _, isPause := result.Err.(hooks.ErrPaused) select { case <-rm.ctx.Done(): - case rm.messages <- &cancelRequestMessage{response.RequestID()}: + case rm.messages <- &cancelRequestMessage{response.RequestID(), isPause}: } } return result.Err } +func (rm *RequestManager) terminateRequest(requestID graphsync.RequestID) { + select { + case <-rm.ctx.Done(): + case rm.messages <- &terminateRequestMessage{requestID}: + } +} + func (rm *RequestManager) validateRequest(requestID graphsync.RequestID, p peer.ID, root ipld.Link, selectorSpec ipld.Node, extensions []graphsync.ExtensionData) (gsmsg.GraphSyncRequest, hooks.RequestResult, error) { _, err := ipldutil.EncodeNode(selectorSpec) if err != nil { @@ -442,3 +517,53 @@ func (rm *RequestManager) validateRequest(requestID graphsync.RequestID, p peer. 
} return request, hooksResult, nil } + +func (urm *unpauseRequestMessage) unpause(rm *RequestManager) error { + inProgressRequestStatus, ok := rm.inProgressRequestStatuses[urm.id] + if !ok { + return errors.New("request not found") + } + if !inProgressRequestStatus.paused { + return errors.New("request is not paused") + } + inProgressRequestStatus.paused = false + select { + case <-inProgressRequestStatus.pauseMessages: + rm.peerHandler.SendRequest(inProgressRequestStatus.p, gsmsg.UpdateRequest(urm.id, urm.extensions...)) + return nil + case <-rm.ctx.Done(): + return errors.New("context cancelled") + case inProgressRequestStatus.resumeMessages <- urm.extensions: + return nil + } +} +func (urm *unpauseRequestMessage) handle(rm *RequestManager) { + err := urm.unpause(rm) + select { + case <-rm.ctx.Done(): + case urm.response <- err: + } +} +func (prm *pauseRequestMessage) pause(rm *RequestManager) error { + inProgressRequestStatus, ok := rm.inProgressRequestStatuses[prm.id] + if !ok { + return errors.New("request not found") + } + if inProgressRequestStatus.paused { + return errors.New("request is already paused") + } + inProgressRequestStatus.paused = true + select { + case <-rm.ctx.Done(): + return errors.New("context cancelled") + case inProgressRequestStatus.pauseMessages <- struct{}{}: + return nil + } +} +func (prm *pauseRequestMessage) handle(rm *RequestManager) { + err := prm.pause(rm) + select { + case <-rm.ctx.Done(): + case prm.response <- err: + } +} diff --git a/requestmanager/requestmanager_test.go b/requestmanager/requestmanager_test.go index 7cbe14e0..471e67ab 100644 --- a/requestmanager/requestmanager_test.go +++ b/requestmanager/requestmanager_test.go @@ -4,10 +4,12 @@ import ( "context" "errors" "fmt" - "sync" "testing" "time" + "github.com/ipfs/go-graphsync/cidset" + "github.com/ipfs/go-graphsync/requestmanager/testloader" + "github.com/ipfs/go-graphsync" "github.com/ipfs/go-graphsync/requestmanager/hooks" "github.com/ipfs/go-graphsync/requestmanager/types" @@ -42,111 +44,6 @@ func (fph *fakePeerHandler) SendRequest(p peer.ID, } } -type requestKey struct { - requestID graphsync.RequestID - link ipld.Link -} - -type storeKey struct { - requestID graphsync.RequestID - storeName string -} - -type fakeAsyncLoader struct { - responseChannelsLk sync.RWMutex - responseChannels map[requestKey]chan types.AsyncLoadResult - responses chan map[graphsync.RequestID]metadata.Metadata - blks chan []blocks.Block - storesRequestedLk sync.RWMutex - storesRequested map[storeKey]struct{} -} - -func newFakeAsyncLoader() *fakeAsyncLoader { - return &fakeAsyncLoader{ - responseChannels: make(map[requestKey]chan types.AsyncLoadResult), - responses: make(chan map[graphsync.RequestID]metadata.Metadata, 1), - blks: make(chan []blocks.Block, 1), - storesRequested: make(map[storeKey]struct{}), - } -} - -func (fal *fakeAsyncLoader) StartRequest(requestID graphsync.RequestID, name string) error { - fal.storesRequestedLk.Lock() - fal.storesRequested[storeKey{requestID, name}] = struct{}{} - fal.storesRequestedLk.Unlock() - return nil -} - -func (fal *fakeAsyncLoader) ProcessResponse(responses map[graphsync.RequestID]metadata.Metadata, - blks []blocks.Block) { - fal.responses <- responses - fal.blks <- blks -} -func (fal *fakeAsyncLoader) verifyLastProcessedBlocks(ctx context.Context, t *testing.T, expectedBlocks []blocks.Block) { - var processedBlocks []blocks.Block - testutil.AssertReceive(ctx, t, fal.blks, &processedBlocks, "did not process blocks") - require.Equal(t, expectedBlocks, processedBlocks, 
"did not process correct blocks") -} - -func (fal *fakeAsyncLoader) verifyLastProcessedResponses(ctx context.Context, t *testing.T, - expectedResponses map[graphsync.RequestID]metadata.Metadata) { - var responses map[graphsync.RequestID]metadata.Metadata - testutil.AssertReceive(ctx, t, fal.responses, &responses, "did not process responses") - require.Equal(t, expectedResponses, responses, "did not process correct responses") -} - -func (fal *fakeAsyncLoader) verifyNoRemainingData(t *testing.T, requestID graphsync.RequestID) { - fal.responseChannelsLk.Lock() - for key := range fal.responseChannels { - require.NotEqual(t, key.requestID, requestID, "did not clean up request properly") - } - fal.responseChannelsLk.Unlock() -} - -func (fal *fakeAsyncLoader) verifyStoreUsed(t *testing.T, requestID graphsync.RequestID, storeName string) { - fal.storesRequestedLk.RLock() - _, ok := fal.storesRequested[storeKey{requestID, storeName}] - require.True(t, ok, "request should load from correct store") - fal.storesRequestedLk.RUnlock() -} - -func (fal *fakeAsyncLoader) asyncLoad(requestID graphsync.RequestID, link ipld.Link) chan types.AsyncLoadResult { - fal.responseChannelsLk.Lock() - responseChannel, ok := fal.responseChannels[requestKey{requestID, link}] - if !ok { - responseChannel = make(chan types.AsyncLoadResult, 1) - fal.responseChannels[requestKey{requestID, link}] = responseChannel - } - fal.responseChannelsLk.Unlock() - return responseChannel -} - -func (fal *fakeAsyncLoader) AsyncLoad(requestID graphsync.RequestID, link ipld.Link) <-chan types.AsyncLoadResult { - return fal.asyncLoad(requestID, link) -} -func (fal *fakeAsyncLoader) CompleteResponsesFor(requestID graphsync.RequestID) {} -func (fal *fakeAsyncLoader) CleanupRequest(requestID graphsync.RequestID) { - fal.responseChannelsLk.Lock() - for key := range fal.responseChannels { - if key.requestID == requestID { - delete(fal.responseChannels, key) - } - } - fal.responseChannelsLk.Unlock() -} - -func (fal *fakeAsyncLoader) responseOn(requestID graphsync.RequestID, link ipld.Link, result types.AsyncLoadResult) { - responseChannel := fal.asyncLoad(requestID, link) - responseChannel <- result - close(responseChannel) -} - -func (fal *fakeAsyncLoader) successResponseOn(requestID graphsync.RequestID, blks []blocks.Block) { - for _, block := range blks { - fal.responseOn(requestID, cidlink.Link{Cid: block.Cid()}, types.AsyncLoadResult{Data: block.RawData(), Local: false, Err: nil}) - } -} - func readNNetworkRequests(ctx context.Context, t *testing.T, requestRecordChan <-chan requestRecord, @@ -225,13 +122,13 @@ func TestNormalSimultaneousFetch(t *testing.T) { } td.requestManager.ProcessResponses(peers[0], firstResponses, firstBlocks) - td.fal.verifyLastProcessedBlocks(ctx, t, firstBlocks) - td.fal.verifyLastProcessedResponses(ctx, t, map[graphsync.RequestID]metadata.Metadata{ + td.fal.VerifyLastProcessedBlocks(ctx, t, firstBlocks) + td.fal.VerifyLastProcessedResponses(ctx, t, map[graphsync.RequestID]metadata.Metadata{ requestRecords[0].gsr.ID(): firstMetadata1, requestRecords[1].gsr.ID(): firstMetadata2, }) - td.fal.successResponseOn(requestRecords[0].gsr.ID(), td.blockChain.AllBlocks()) - td.fal.successResponseOn(requestRecords[1].gsr.ID(), blockChain2.Blocks(0, 3)) + td.fal.SuccessResponseOn(requestRecords[0].gsr.ID(), td.blockChain.AllBlocks()) + td.fal.SuccessResponseOn(requestRecords[1].gsr.ID(), blockChain2.Blocks(0, 3)) td.blockChain.VerifyWholeChain(requestCtx, returnedResponseChan1) blockChain2.VerifyResponseRange(requestCtx, 
returnedResponseChan2, 0, 3) @@ -248,12 +145,12 @@ func TestNormalSimultaneousFetch(t *testing.T) { } td.requestManager.ProcessResponses(peers[0], moreResponses, moreBlocks) - td.fal.verifyLastProcessedBlocks(ctx, t, moreBlocks) - td.fal.verifyLastProcessedResponses(ctx, t, map[graphsync.RequestID]metadata.Metadata{ + td.fal.VerifyLastProcessedBlocks(ctx, t, moreBlocks) + td.fal.VerifyLastProcessedResponses(ctx, t, map[graphsync.RequestID]metadata.Metadata{ requestRecords[1].gsr.ID(): moreMetadata, }) - td.fal.successResponseOn(requestRecords[1].gsr.ID(), moreBlocks) + td.fal.SuccessResponseOn(requestRecords[1].gsr.ID(), moreBlocks) blockChain2.VerifyRemainder(requestCtx, returnedResponseChan2, 3) testutil.VerifyEmptyErrors(requestCtx, t, returnedErrorChan1) @@ -284,8 +181,8 @@ func TestCancelRequestInProgress(t *testing.T) { td.requestManager.ProcessResponses(peers[0], firstResponses, firstBlocks) - td.fal.successResponseOn(requestRecords[0].gsr.ID(), firstBlocks) - td.fal.successResponseOn(requestRecords[1].gsr.ID(), firstBlocks) + td.fal.SuccessResponseOn(requestRecords[0].gsr.ID(), firstBlocks) + td.fal.SuccessResponseOn(requestRecords[1].gsr.ID(), firstBlocks) td.blockChain.VerifyResponseRange(requestCtx1, returnedResponseChan1, 0, 3) cancel1() rr := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1)[0] @@ -300,8 +197,8 @@ func TestCancelRequestInProgress(t *testing.T) { gsmsg.NewResponse(requestRecords[1].gsr.ID(), graphsync.RequestCompletedFull, moreMetadata), } td.requestManager.ProcessResponses(peers[0], moreResponses, moreBlocks) - td.fal.successResponseOn(requestRecords[0].gsr.ID(), moreBlocks) - td.fal.successResponseOn(requestRecords[1].gsr.ID(), moreBlocks) + td.fal.SuccessResponseOn(requestRecords[0].gsr.ID(), moreBlocks) + td.fal.SuccessResponseOn(requestRecords[1].gsr.ID(), moreBlocks) testutil.VerifyEmptyResponse(requestCtx, t, returnedResponseChan1) td.blockChain.VerifyWholeChain(requestCtx, returnedResponseChan2) @@ -327,7 +224,7 @@ func TestCancelManagerExitsGracefully(t *testing.T) { gsmsg.NewResponse(rr.gsr.ID(), graphsync.PartialResponse, firstMetadata), } td.requestManager.ProcessResponses(peers[0], firstResponses, firstBlocks) - td.fal.successResponseOn(rr.gsr.ID(), firstBlocks) + td.fal.SuccessResponseOn(rr.gsr.ID(), firstBlocks) td.blockChain.VerifyResponseRange(ctx, returnedResponseChan, 0, 3) managerCancel() @@ -337,7 +234,7 @@ func TestCancelManagerExitsGracefully(t *testing.T) { gsmsg.NewResponse(rr.gsr.ID(), graphsync.RequestCompletedFull, moreMetadata), } td.requestManager.ProcessResponses(peers[0], moreResponses, moreBlocks) - td.fal.successResponseOn(rr.gsr.ID(), moreBlocks) + td.fal.SuccessResponseOn(rr.gsr.ID(), moreBlocks) testutil.VerifyEmptyResponse(requestCtx, t, returnedResponseChan) testutil.VerifyEmptyErrors(requestCtx, t, returnedErrorChan) } @@ -374,7 +271,7 @@ func TestLocallyFulfilledFirstRequestFailsLater(t *testing.T) { rr := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1)[0] // async loaded response responds immediately - td.fal.successResponseOn(rr.gsr.ID(), td.blockChain.AllBlocks()) + td.fal.SuccessResponseOn(rr.gsr.ID(), td.blockChain.AllBlocks()) td.blockChain.VerifyWholeChain(requestCtx, returnedResponseChan) @@ -401,7 +298,7 @@ func TestLocallyFulfilledFirstRequestSucceedsLater(t *testing.T) { rr := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1)[0] // async loaded response responds immediately - td.fal.successResponseOn(rr.gsr.ID(), td.blockChain.AllBlocks()) + 
td.fal.SuccessResponseOn(rr.gsr.ID(), td.blockChain.AllBlocks()) td.blockChain.VerifyWholeChain(requestCtx, returnedResponseChan) @@ -411,7 +308,7 @@ func TestLocallyFulfilledFirstRequestSucceedsLater(t *testing.T) { } td.requestManager.ProcessResponses(peers[0], firstResponses, td.blockChain.AllBlocks()) - td.fal.verifyNoRemainingData(t, rr.gsr.ID()) + td.fal.VerifyNoRemainingData(t, rr.gsr.ID()) testutil.VerifyEmptyErrors(ctx, t, returnedErrorChan) } @@ -433,7 +330,7 @@ func TestRequestReturnsMissingBlocks(t *testing.T) { } td.requestManager.ProcessResponses(peers[0], firstResponses, nil) for _, block := range td.blockChain.AllBlocks() { - td.fal.responseOn(rr.gsr.ID(), cidlink.Link{Cid: block.Cid()}, types.AsyncLoadResult{Data: nil, Err: fmt.Errorf("Terrible Thing")}) + td.fal.ResponseOn(rr.gsr.ID(), cidlink.Link{Cid: block.Cid()}, types.AsyncLoadResult{Data: nil, Err: fmt.Errorf("Terrible Thing")}) } testutil.VerifyEmptyResponse(ctx, t, returnedResponseChan) errs := testutil.CollectErrors(ctx, t, returnedErrorChan) @@ -448,24 +345,11 @@ func TestEncodingExtensions(t *testing.T) { defer cancel() peers := testutil.GeneratePeers(1) - extensionData1 := testutil.RandomBytes(100) - extensionName1 := graphsync.ExtensionName("AppleSauce/McGee") - extension1 := graphsync.ExtensionData{ - Name: extensionName1, - Data: extensionData1, - } - extensionData2 := testutil.RandomBytes(100) - extensionName2 := graphsync.ExtensionName("HappyLand/Happenstance") - extension2 := graphsync.ExtensionData{ - Name: extensionName2, - Data: extensionData2, - } - expectedError := make(chan error, 2) receivedExtensionData := make(chan []byte, 2) expectedUpdateChan := make(chan []graphsync.ExtensionData, 2) hook := func(p peer.ID, responseData graphsync.ResponseData, hookActions graphsync.IncomingResponseHookActions) { - data, has := responseData.Extension(extensionName1) + data, has := responseData.Extension(td.extensionName1) require.True(t, has, "did not receive extension data in response") receivedExtensionData <- data err := <-expectedError @@ -478,18 +362,18 @@ func TestEncodingExtensions(t *testing.T) { } } td.responseHooks.Register(hook) - returnedResponseChan, returnedErrorChan := td.requestManager.SendRequest(requestCtx, peers[0], td.blockChain.TipLink, td.blockChain.Selector(), extension1, extension2) + returnedResponseChan, returnedErrorChan := td.requestManager.SendRequest(requestCtx, peers[0], td.blockChain.TipLink, td.blockChain.Selector(), td.extension1, td.extension2) rr := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1)[0] gsr := rr.gsr - returnedData1, found := gsr.Extension(extensionName1) + returnedData1, found := gsr.Extension(td.extensionName1) require.True(t, found) - require.Equal(t, extensionData1, returnedData1, "did not encode first extension correctly") + require.Equal(t, td.extensionData1, returnedData1, "did not encode first extension correctly") - returnedData2, found := gsr.Extension(extensionName2) + returnedData2, found := gsr.Extension(td.extensionName2) require.True(t, found) - require.Equal(t, extensionData2, returnedData2, "did not encode second extension correctly") + require.Equal(t, td.extensionData2, returnedData2, "did not encode second extension correctly") t.Run("responding to extensions", func(t *testing.T) { expectedData := testutil.RandomBytes(100) @@ -501,7 +385,7 @@ func TestEncodingExtensions(t *testing.T) { Data: nil, }, graphsync.ExtensionData{ - Name: extensionName1, + Name: td.extensionName1, Data: expectedData, }, ), @@ -509,7 +393,7 @@ func 
TestEncodingExtensions(t *testing.T) { expectedError <- nil expectedUpdateChan <- []graphsync.ExtensionData{ { - Name: extensionName1, + Name: td.extensionName1, Data: expectedUpdate, }, } @@ -519,7 +403,7 @@ func TestEncodingExtensions(t *testing.T) { require.Equal(t, expectedData, received, "did not receive correct extension data from resposne") rr = readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1)[0] - receivedUpdateData, has := rr.gsr.Extension(extensionName1) + receivedUpdateData, has := rr.gsr.Extension(td.extensionName1) require.True(t, has) require.Equal(t, expectedUpdate, receivedUpdateData, "should have updated with correct extension") @@ -534,7 +418,7 @@ func TestEncodingExtensions(t *testing.T) { Data: nil, }, graphsync.ExtensionData{ - Name: extensionName1, + Name: td.extensionName1, Data: nextExpectedData, }, ), @@ -542,11 +426,11 @@ func TestEncodingExtensions(t *testing.T) { expectedError <- errors.New("a terrible thing happened") expectedUpdateChan <- []graphsync.ExtensionData{ { - Name: extensionName1, + Name: td.extensionName1, Data: nextExpectedUpdate1, }, { - Name: extensionName2, + Name: td.extensionName2, Data: nextExpectedUpdate2, }, } @@ -555,10 +439,10 @@ func TestEncodingExtensions(t *testing.T) { require.Equal(t, nextExpectedData, received, "did not receive correct extension data from resposne") rr = readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1)[0] - receivedUpdateData, has = rr.gsr.Extension(extensionName1) + receivedUpdateData, has = rr.gsr.Extension(td.extensionName1) require.True(t, has) require.Equal(t, nextExpectedUpdate1, receivedUpdateData, "should have updated with correct extension") - receivedUpdateData, has = rr.gsr.Extension(extensionName2) + receivedUpdateData, has = rr.gsr.Extension(td.extensionName2) require.True(t, has) require.Equal(t, nextExpectedUpdate2, receivedUpdateData, "should have updated with correct extension") @@ -575,19 +459,6 @@ func TestBlockHooks(t *testing.T) { defer cancel() peers := testutil.GeneratePeers(1) - extensionData1 := testutil.RandomBytes(100) - extensionName1 := graphsync.ExtensionName("AppleSauce/McGee") - extension1 := graphsync.ExtensionData{ - Name: extensionName1, - Data: extensionData1, - } - extensionData2 := testutil.RandomBytes(100) - extensionName2 := graphsync.ExtensionName("HappyLand/Happenstance") - extension2 := graphsync.ExtensionData{ - Name: extensionName2, - Data: extensionData2, - } - receivedBlocks := make(chan graphsync.BlockData, 4) receivedResponses := make(chan graphsync.ResponseData, 4) expectedError := make(chan error, 4) @@ -605,18 +476,18 @@ func TestBlockHooks(t *testing.T) { } } td.blockHooks.Register(hook) - returnedResponseChan, returnedErrorChan := td.requestManager.SendRequest(requestCtx, peers[0], td.blockChain.TipLink, td.blockChain.Selector(), extension1, extension2) + returnedResponseChan, returnedErrorChan := td.requestManager.SendRequest(requestCtx, peers[0], td.blockChain.TipLink, td.blockChain.Selector(), td.extension1, td.extension2) rr := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1)[0] gsr := rr.gsr - returnedData1, found := gsr.Extension(extensionName1) + returnedData1, found := gsr.Extension(td.extensionName1) require.True(t, found) - require.Equal(t, extensionData1, returnedData1, "did not encode first extension correctly") + require.Equal(t, td.extensionData1, returnedData1, "did not encode first extension correctly") - returnedData2, found := gsr.Extension(extensionName2) + returnedData2, found := 
gsr.Extension(td.extensionName2) require.True(t, found) - require.Equal(t, extensionData2, returnedData2, "did not encode second extension correctly") + require.Equal(t, td.extensionData2, returnedData2, "did not encode second extension correctly") t.Run("responding to extensions", func(t *testing.T) { expectedData := testutil.RandomBytes(100) @@ -633,7 +504,7 @@ func TestBlockHooks(t *testing.T) { Data: firstMetadataEncoded, }, graphsync.ExtensionData{ - Name: extensionName1, + Name: td.extensionName1, Data: expectedData, }, ), @@ -644,7 +515,7 @@ func TestBlockHooks(t *testing.T) { if i == len(firstBlocks)-1 { update = []graphsync.ExtensionData{ { - Name: extensionName1, + Name: td.extensionName1, Data: expectedUpdate, }, } @@ -653,14 +524,14 @@ func TestBlockHooks(t *testing.T) { } td.requestManager.ProcessResponses(peers[0], firstResponses, firstBlocks) - td.fal.verifyLastProcessedBlocks(ctx, t, firstBlocks) - td.fal.verifyLastProcessedResponses(ctx, t, map[graphsync.RequestID]metadata.Metadata{ + td.fal.VerifyLastProcessedBlocks(ctx, t, firstBlocks) + td.fal.VerifyLastProcessedResponses(ctx, t, map[graphsync.RequestID]metadata.Metadata{ rr.gsr.ID(): firstMetadata, }) - td.fal.successResponseOn(rr.gsr.ID(), firstBlocks) + td.fal.SuccessResponseOn(rr.gsr.ID(), firstBlocks) ur := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1)[0] - receivedUpdateData, has := ur.gsr.Extension(extensionName1) + receivedUpdateData, has := ur.gsr.Extension(td.extensionName1) require.True(t, has) require.Equal(t, expectedUpdate, receivedUpdateData, "should have updated with correct extension") @@ -672,7 +543,7 @@ func TestBlockHooks(t *testing.T) { metadata, has := receivedResponse.Extension(graphsync.ExtensionMetadata) require.True(t, has) require.Equal(t, firstMetadataEncoded, metadata, "should receive correct metadata") - receivedExtensionData, _ := receivedResponse.Extension(extensionName1) + receivedExtensionData, _ := receivedResponse.Extension(td.extensionName1) require.Equal(t, expectedData, receivedExtensionData, "should receive correct response extension data") var receivedBlock graphsync.BlockData testutil.AssertReceive(ctx, t, receivedBlocks, &receivedBlock, "did not receive block data") @@ -694,7 +565,7 @@ func TestBlockHooks(t *testing.T) { Data: nextMetadataEncoded, }, graphsync.ExtensionData{ - Name: extensionName1, + Name: td.extensionName1, Data: nextExpectedData, }, ), @@ -705,11 +576,11 @@ func TestBlockHooks(t *testing.T) { if i == len(nextBlocks)-1 { update = []graphsync.ExtensionData{ { - Name: extensionName1, + Name: td.extensionName1, Data: nextExpectedUpdate1, }, { - Name: extensionName2, + Name: td.extensionName2, Data: nextExpectedUpdate2, }, } @@ -717,17 +588,17 @@ func TestBlockHooks(t *testing.T) { expectedUpdateChan <- update } td.requestManager.ProcessResponses(peers[0], secondResponses, nextBlocks) - td.fal.verifyLastProcessedBlocks(ctx, t, nextBlocks) - td.fal.verifyLastProcessedResponses(ctx, t, map[graphsync.RequestID]metadata.Metadata{ + td.fal.VerifyLastProcessedBlocks(ctx, t, nextBlocks) + td.fal.VerifyLastProcessedResponses(ctx, t, map[graphsync.RequestID]metadata.Metadata{ rr.gsr.ID(): nextMetadata, }) - td.fal.successResponseOn(rr.gsr.ID(), nextBlocks) + td.fal.SuccessResponseOn(rr.gsr.ID(), nextBlocks) ur = readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1)[0] - receivedUpdateData, has = ur.gsr.Extension(extensionName1) + receivedUpdateData, has = ur.gsr.Extension(td.extensionName1) require.True(t, has) require.Equal(t, nextExpectedUpdate1, 
receivedUpdateData, "should have updated with correct extension") - receivedUpdateData, has = ur.gsr.Extension(extensionName2) + receivedUpdateData, has = ur.gsr.Extension(td.extensionName2) require.True(t, has) require.Equal(t, nextExpectedUpdate2, receivedUpdateData, "should have updated with correct extension") @@ -739,7 +610,7 @@ func TestBlockHooks(t *testing.T) { metadata, has := receivedResponse.Extension(graphsync.ExtensionMetadata) require.True(t, has) require.Equal(t, nextMetadataEncoded, metadata, "should receive correct metadata") - receivedExtensionData, _ := receivedResponse.Extension(extensionName1) + receivedExtensionData, _ := receivedResponse.Extension(td.extensionName1) require.Equal(t, nextExpectedData, receivedExtensionData, "should receive correct response extension data") var receivedBlock graphsync.BlockData testutil.AssertReceive(ctx, t, receivedBlocks, &receivedBlock, "did not receive block data") @@ -760,14 +631,8 @@ func TestOutgoingRequestHooks(t *testing.T) { defer cancel() peers := testutil.GeneratePeers(1) - extensionName1 := graphsync.ExtensionName("blockchain") - extension1 := graphsync.ExtensionData{ - Name: extensionName1, - Data: nil, - } - hook := func(p peer.ID, r graphsync.RequestData, ha graphsync.OutgoingRequestHookActions) { - _, has := r.Extension(extensionName1) + _, has := r.Extension(td.extensionName1) if has { ha.UseLinkTargetNodeStyleChooser(td.blockChain.Chooser) ha.UsePersistenceOption("chainstore") @@ -775,7 +640,7 @@ func TestOutgoingRequestHooks(t *testing.T) { } td.requestHooks.Register(hook) - returnedResponseChan1, returnedErrorChan1 := td.requestManager.SendRequest(requestCtx, peers[0], td.blockChain.TipLink, td.blockChain.Selector(), extension1) + returnedResponseChan1, returnedErrorChan1 := td.requestManager.SendRequest(requestCtx, peers[0], td.blockChain.TipLink, td.blockChain.Selector(), td.extension1) returnedResponseChan2, returnedErrorChan2 := td.requestManager.SendRequest(requestCtx, peers[0], td.blockChain.TipLink, td.blockChain.Selector()) requestRecords := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 2) @@ -792,26 +657,196 @@ func TestOutgoingRequestHooks(t *testing.T) { gsmsg.NewResponse(requestRecords[1].gsr.ID(), graphsync.RequestCompletedFull, mdExt), } td.requestManager.ProcessResponses(peers[0], responses, td.blockChain.AllBlocks()) - td.fal.verifyLastProcessedBlocks(ctx, t, td.blockChain.AllBlocks()) - td.fal.verifyLastProcessedResponses(ctx, t, map[graphsync.RequestID]metadata.Metadata{ + td.fal.VerifyLastProcessedBlocks(ctx, t, td.blockChain.AllBlocks()) + td.fal.VerifyLastProcessedResponses(ctx, t, map[graphsync.RequestID]metadata.Metadata{ requestRecords[0].gsr.ID(): md, requestRecords[1].gsr.ID(): md, }) - td.fal.successResponseOn(requestRecords[0].gsr.ID(), td.blockChain.AllBlocks()) - td.fal.successResponseOn(requestRecords[1].gsr.ID(), td.blockChain.AllBlocks()) + td.fal.SuccessResponseOn(requestRecords[0].gsr.ID(), td.blockChain.AllBlocks()) + td.fal.SuccessResponseOn(requestRecords[1].gsr.ID(), td.blockChain.AllBlocks()) td.blockChain.VerifyWholeChainWithTypes(requestCtx, returnedResponseChan1) td.blockChain.VerifyWholeChain(requestCtx, returnedResponseChan2) testutil.VerifyEmptyErrors(ctx, t, returnedErrorChan1) testutil.VerifyEmptyErrors(ctx, t, returnedErrorChan2) - td.fal.verifyStoreUsed(t, requestRecords[0].gsr.ID(), "chainstore") - td.fal.verifyStoreUsed(t, requestRecords[1].gsr.ID(), "") + td.fal.VerifyStoreUsed(t, requestRecords[0].gsr.ID(), "chainstore") + td.fal.VerifyStoreUsed(t, 
requestRecords[1].gsr.ID(), "") +} + +func TestPauseResume(t *testing.T) { + ctx := context.Background() + td := newTestData(ctx, t) + + requestCtx, cancel := context.WithCancel(ctx) + defer cancel() + peers := testutil.GeneratePeers(1) + + blocksReceived := 0 + holdForResumeAttempt := make(chan struct{}) + holdForPause := make(chan struct{}) + pauseAt := 3 + + // set up a hook that pauses at the 3rd block (and holds at the 2nd block so we can attempt an unpause while the request is not yet paused) + hook := func(p peer.ID, responseData graphsync.ResponseData, blockData graphsync.BlockData, hookActions graphsync.IncomingBlockHookActions) { + blocksReceived++ + if blocksReceived == pauseAt-1 { + <-holdForResumeAttempt + } + if blocksReceived == pauseAt { + hookActions.PauseRequest() + close(holdForPause) + } + } + td.blockHooks.Register(hook) + + // Start request + returnedResponseChan, returnedErrorChan := td.requestManager.SendRequest(requestCtx, peers[0], td.blockChain.TipLink, td.blockChain.Selector()) + + rr := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1)[0] + + // Start processing responses + md := metadataForBlocks(td.blockChain.AllBlocks(), true) + mdEncoded, err := metadata.EncodeMetadata(md) + require.NoError(t, err) + responses := []gsmsg.GraphSyncResponse{ + gsmsg.NewResponse(rr.gsr.ID(), graphsync.RequestCompletedFull, graphsync.ExtensionData{ + Name: graphsync.ExtensionMetadata, + Data: mdEncoded, + }), + } + td.requestManager.ProcessResponses(peers[0], responses, td.blockChain.AllBlocks()) + td.fal.SuccessResponseOn(rr.gsr.ID(), td.blockChain.AllBlocks()) + + // attempt to unpause while request is not paused (note: hook on second block will keep it from + reaching pause point) + err = td.requestManager.UnpauseRequest(rr.gsr.ID()) + require.EqualError(t, err, "request is not paused") + close(holdForResumeAttempt) + // verify the responses read cover ONLY the blocks BEFORE the pause + td.blockChain.VerifyResponseRange(ctx, returnedResponseChan, 0, pauseAt-1) + // wait for the pause to occur + <-holdForPause + + // read the outgoing cancel request + pauseCancel := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1)[0] + require.True(t, pauseCancel.gsr.IsCancel()) + + // verify no further responses come through + time.Sleep(100 * time.Millisecond) + testutil.AssertChannelEmpty(t, returnedResponseChan, "no response should be sent while request is paused") + td.fal.CleanupRequest(rr.gsr.ID()) + + // unpause + err = td.requestManager.UnpauseRequest(rr.gsr.ID(), td.extension1, td.extension2) + require.NoError(t, err) + + // verify the correct new request with do-not-send-CIDs & other extensions + resumedRequest := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1)[0] + doNotSendCidsData, has := resumedRequest.gsr.Extension(graphsync.ExtensionDoNotSendCIDs) + doNotSendCids, err := cidset.DecodeCidSet(doNotSendCidsData) + require.NoError(t, err) + require.Equal(t, pauseAt, doNotSendCids.Len()) + require.True(t, has) + ext1Data, has := resumedRequest.gsr.Extension(td.extensionName1) + require.True(t, has) + require.Equal(t, td.extensionData1, ext1Data) + ext2Data, has := resumedRequest.gsr.Extension(td.extensionName2) + require.True(t, has) + require.Equal(t, td.extensionData2, ext2Data) + + // process responses + td.requestManager.ProcessResponses(peers[0], responses, td.blockChain.RemainderBlocks(pauseAt)) + td.fal.SuccessResponseOn(rr.gsr.ID(), td.blockChain.AllBlocks()) + + // verify the correct results are returned, picking up where the request was paused + td.blockChain.VerifyRemainder(ctx,
returnedResponseChan, pauseAt-1) + testutil.VerifyEmptyErrors(ctx, t, returnedErrorChan) +} +func TestPauseResumeExternal(t *testing.T) { + ctx := context.Background() + td := newTestData(ctx, t) + + requestCtx, cancel := context.WithCancel(ctx) + defer cancel() + peers := testutil.GeneratePeers(1) + + blocksReceived := 0 + holdForPause := make(chan struct{}) + pauseAt := 3 + + // set up a hook that pauses the request externally (via PauseRequest) at the 3rd block + hook := func(p peer.ID, responseData graphsync.ResponseData, blockData graphsync.BlockData, hookActions graphsync.IncomingBlockHookActions) { + blocksReceived++ + if blocksReceived == pauseAt { + err := td.requestManager.PauseRequest(responseData.RequestID()) + require.NoError(t, err) + close(holdForPause) + } + } + td.blockHooks.Register(hook) + + // Start request + returnedResponseChan, returnedErrorChan := td.requestManager.SendRequest(requestCtx, peers[0], td.blockChain.TipLink, td.blockChain.Selector()) + + rr := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1)[0] + + // Start processing responses + md := metadataForBlocks(td.blockChain.AllBlocks(), true) + mdEncoded, err := metadata.EncodeMetadata(md) + require.NoError(t, err) + responses := []gsmsg.GraphSyncResponse{ + gsmsg.NewResponse(rr.gsr.ID(), graphsync.RequestCompletedFull, graphsync.ExtensionData{ + Name: graphsync.ExtensionMetadata, + Data: mdEncoded, + }), + } + td.requestManager.ProcessResponses(peers[0], responses, td.blockChain.AllBlocks()) + td.fal.SuccessResponseOn(rr.gsr.ID(), td.blockChain.AllBlocks()) + // verify the responses read cover ONLY the blocks BEFORE the pause + td.blockChain.VerifyResponseRange(ctx, returnedResponseChan, 0, pauseAt-1) + // wait for the pause to occur + <-holdForPause + + // read the outgoing cancel request + pauseCancel := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1)[0] + require.True(t, pauseCancel.gsr.IsCancel()) + + // verify no further responses come through + time.Sleep(100 * time.Millisecond) + testutil.AssertChannelEmpty(t, returnedResponseChan, "no response should be sent while request is paused") + td.fal.CleanupRequest(rr.gsr.ID()) + + // unpause + err = td.requestManager.UnpauseRequest(rr.gsr.ID(), td.extension1, td.extension2) + require.NoError(t, err) + + // verify the correct new request with do-not-send-CIDs & other extensions + resumedRequest := readNNetworkRequests(requestCtx, t, td.requestRecordChan, 1)[0] + doNotSendCidsData, has := resumedRequest.gsr.Extension(graphsync.ExtensionDoNotSendCIDs) + doNotSendCids, err := cidset.DecodeCidSet(doNotSendCidsData) + require.NoError(t, err) + require.Equal(t, pauseAt, doNotSendCids.Len()) + require.True(t, has) + ext1Data, has := resumedRequest.gsr.Extension(td.extensionName1) + require.True(t, has) + require.Equal(t, td.extensionData1, ext1Data) + ext2Data, has := resumedRequest.gsr.Extension(td.extensionName2) + require.True(t, has) + require.Equal(t, td.extensionData2, ext2Data) + + // process responses + td.requestManager.ProcessResponses(peers[0], responses, td.blockChain.RemainderBlocks(pauseAt)) + td.fal.SuccessResponseOn(rr.gsr.ID(), td.blockChain.AllBlocks()) + + // verify the correct results are returned, picking up where the request was paused + td.blockChain.VerifyRemainder(ctx, returnedResponseChan, pauseAt-1) + testutil.VerifyEmptyErrors(ctx, t, returnedErrorChan) +} type testData struct { requestRecordChan chan requestRecord fph *fakePeerHandler - fal *fakeAsyncLoader + fal *testloader.FakeAsyncLoader requestHooks
*hooks.OutgoingRequestHooks responseHooks *hooks.IncomingResponseHooks blockHooks *hooks.IncomingBlockHooks @@ -820,13 +855,19 @@ type testData struct { loader ipld.Loader storer ipld.Storer blockChain *testutil.TestBlockChain + extensionName1 graphsync.ExtensionName + extensionData1 []byte + extension1 graphsync.ExtensionData + extensionName2 graphsync.ExtensionName + extensionData2 []byte + extension2 graphsync.ExtensionData } func newTestData(ctx context.Context, t *testing.T) *testData { td := &testData{} td.requestRecordChan = make(chan requestRecord, 3) td.fph = &fakePeerHandler{td.requestRecordChan} - td.fal = newFakeAsyncLoader() + td.fal = testloader.NewFakeAsyncLoader() td.requestHooks = hooks.NewRequestHooks() td.responseHooks = hooks.NewResponseHooks() td.blockHooks = hooks.NewBlockHooks() @@ -836,5 +877,17 @@ func newTestData(ctx context.Context, t *testing.T) *testData { td.blockStore = make(map[ipld.Link][]byte) td.loader, td.storer = testutil.NewTestStore(td.blockStore) td.blockChain = testutil.SetupBlockChain(ctx, t, td.loader, td.storer, 100, 5) + td.extensionData1 = testutil.RandomBytes(100) + td.extensionName1 = graphsync.ExtensionName("AppleSauce/McGee") + td.extension1 = graphsync.ExtensionData{ + Name: td.extensionName1, + Data: td.extensionData1, + } + td.extensionData2 = testutil.RandomBytes(100) + td.extensionName2 = graphsync.ExtensionName("HappyLand/Happenstance") + td.extension2 = graphsync.ExtensionData{ + Name: td.extensionName2, + Data: td.extensionData2, + } return td } diff --git a/requestmanager/testloader/asyncloader.go b/requestmanager/testloader/asyncloader.go new file mode 100644 index 00000000..2a068e83 --- /dev/null +++ b/requestmanager/testloader/asyncloader.go @@ -0,0 +1,157 @@ +package testloader + +import ( + "context" + "sync" + "testing" + + blocks "github.com/ipfs/go-block-format" + "github.com/ipfs/go-graphsync" + "github.com/ipfs/go-graphsync/metadata" + "github.com/ipfs/go-graphsync/requestmanager/types" + "github.com/ipfs/go-graphsync/testutil" + "github.com/ipld/go-ipld-prime" + cidlink "github.com/ipld/go-ipld-prime/linking/cid" + "github.com/stretchr/testify/require" +) + +type requestKey struct { + requestID graphsync.RequestID + link ipld.Link +} + +type storeKey struct { + requestID graphsync.RequestID + storeName string +} + +// FakeAsyncLoader simulates the requestmanager.AsyncLoader interface +// with mocked responses and can also be used to simulate an +// executor.AsyncLoadFn -- all responses are stubbed and no actual processing is +// done +type FakeAsyncLoader struct { + responseChannelsLk sync.RWMutex + responseChannels map[requestKey]chan types.AsyncLoadResult + responses chan map[graphsync.RequestID]metadata.Metadata + blks chan []blocks.Block + storesRequestedLk sync.RWMutex + storesRequested map[storeKey]struct{} + cb func(graphsync.RequestID, ipld.Link, <-chan types.AsyncLoadResult) +} + +// NewFakeAsyncLoader returns a new FakeAsyncLoader instance +func NewFakeAsyncLoader() *FakeAsyncLoader { + return &FakeAsyncLoader{ + responseChannels: make(map[requestKey]chan types.AsyncLoadResult), + responses: make(chan map[graphsync.RequestID]metadata.Metadata, 1), + blks: make(chan []blocks.Block, 1), + storesRequested: make(map[storeKey]struct{}), + } +} + +// StartRequest just records what store was requested for a given requestID +func (fal *FakeAsyncLoader) StartRequest(requestID graphsync.RequestID, name string) error { + fal.storesRequestedLk.Lock() + fal.storesRequested[storeKey{requestID, name}] = struct{}{} +
fal.storesRequestedLk.Unlock() + return nil +} + +// ProcessResponse just records values passed to verify expectations later +func (fal *FakeAsyncLoader) ProcessResponse(responses map[graphsync.RequestID]metadata.Metadata, + blks []blocks.Block) { + fal.responses <- responses + fal.blks <- blks +} + +// VerifyLastProcessedBlocks verifies the blocks passed to the last call to ProcessResponse +// match the expected ones +func (fal *FakeAsyncLoader) VerifyLastProcessedBlocks(ctx context.Context, t *testing.T, expectedBlocks []blocks.Block) { + var processedBlocks []blocks.Block + testutil.AssertReceive(ctx, t, fal.blks, &processedBlocks, "did not process blocks") + require.Equal(t, expectedBlocks, processedBlocks, "did not process correct blocks") +} + +// VerifyLastProcessedResponses verifies the responses passed to the last call to ProcessResponse +// match the expected ones +func (fal *FakeAsyncLoader) VerifyLastProcessedResponses(ctx context.Context, t *testing.T, + expectedResponses map[graphsync.RequestID]metadata.Metadata) { + var responses map[graphsync.RequestID]metadata.Metadata + testutil.AssertReceive(ctx, t, fal.responses, &responses, "did not process responses") + require.Equal(t, expectedResponses, responses, "did not process correct responses") +} + +// VerifyNoRemainingData verifies no outstanding response channels are open for the given +// RequestID (CleanupRequest was called last) +func (fal *FakeAsyncLoader) VerifyNoRemainingData(t *testing.T, requestID graphsync.RequestID) { + fal.responseChannelsLk.RLock() + for key := range fal.responseChannels { + require.NotEqual(t, key.requestID, requestID, "did not clean up request properly") + } + fal.responseChannelsLk.RUnlock() +} + +// VerifyStoreUsed verifies the given store was used for the given request +func (fal *FakeAsyncLoader) VerifyStoreUsed(t *testing.T, requestID graphsync.RequestID, storeName string) { + fal.storesRequestedLk.RLock() + _, ok := fal.storesRequested[storeKey{requestID, storeName}] + require.True(t, ok, "request should load from correct store") + fal.storesRequestedLk.RUnlock() +} + +func (fal *FakeAsyncLoader) asyncLoad(requestID graphsync.RequestID, link ipld.Link) chan types.AsyncLoadResult { + fal.responseChannelsLk.Lock() + responseChannel, ok := fal.responseChannels[requestKey{requestID, link}] + if !ok { + responseChannel = make(chan types.AsyncLoadResult, 1) + fal.responseChannels[requestKey{requestID, link}] = responseChannel + } + fal.responseChannelsLk.Unlock() + return responseChannel +} + +// OnAsyncLoad allows you to listen for load requests to the loader and perform other actions or tests +func (fal *FakeAsyncLoader) OnAsyncLoad(cb func(graphsync.RequestID, ipld.Link, <-chan types.AsyncLoadResult)) { + fal.cb = cb +} + +// AsyncLoad simulates an asynchronous load with responses stubbed by ResponseOn & SuccessResponseOn +func (fal *FakeAsyncLoader) AsyncLoad(requestID graphsync.RequestID, link ipld.Link) <-chan types.AsyncLoadResult { + res := fal.asyncLoad(requestID, link) + if fal.cb != nil { + fal.cb(requestID, link, res) + } + return res +} + +// CompleteResponsesFor in the case of the test loader does nothing +func (fal *FakeAsyncLoader) CompleteResponsesFor(requestID graphsync.RequestID) {} + +// CleanupRequest simulates the effect of cleaning up the request by removing any response channels +// for the request +func (fal *FakeAsyncLoader) CleanupRequest(requestID graphsync.RequestID) { + fal.responseChannelsLk.Lock() + for key := range fal.responseChannels { + if key.requestID == 
requestID { + delete(fal.responseChannels, key) + } + } + fal.responseChannelsLk.Unlock() +} + +// ResponseOn sets the value returned when the given link is loaded for the given request. Because it's an +// "asynchronous" load, this can be called AFTER the attempt to load this link -- and the client will only get +// the response at that point +func (fal *FakeAsyncLoader) ResponseOn(requestID graphsync.RequestID, link ipld.Link, result types.AsyncLoadResult) { + responseChannel := fal.asyncLoad(requestID, link) + responseChannel <- result + close(responseChannel) +} + +// SuccessResponseOn is convenience function for setting several asynchronous responses at once as all successes +// and returning the given blocks +func (fal *FakeAsyncLoader) SuccessResponseOn(requestID graphsync.RequestID, blks []blocks.Block) { + for _, block := range blks { + fal.ResponseOn(requestID, cidlink.Link{Cid: block.Cid()}, types.AsyncLoadResult{Data: block.RawData(), Local: false, Err: nil}) + } +} diff --git a/responsemanager/hooks/blockhooks.go b/responsemanager/hooks/blockhooks.go index 45d55dc5..62a0352b 100644 --- a/responsemanager/hooks/blockhooks.go +++ b/responsemanager/hooks/blockhooks.go @@ -1,15 +1,15 @@ package hooks import ( - "errors" - "github.com/hannahhoward/go-pubsub" "github.com/ipfs/go-graphsync" peer "github.com/libp2p/go-libp2p-core/peer" ) // ErrPaused indicates a request should stop processing, but only cause it's paused -var ErrPaused = errors.New("request has been paused") +type ErrPaused struct{} + +func (e ErrPaused) Error() string { return "request has been paused" } // OutgoingBlockHooks is a set of outgoing block hooks that can be processed type OutgoingBlockHooks struct { @@ -71,5 +71,5 @@ func (bha *blockHookActions) TerminateWithError(err error) { } func (bha *blockHookActions) PauseResponse() { - bha.err = ErrPaused + bha.err = ErrPaused{} } diff --git a/responsemanager/hooks/hooks_test.go b/responsemanager/hooks/hooks_test.go index e4f0cbbf..273a4f45 100644 --- a/responsemanager/hooks/hooks_test.go +++ b/responsemanager/hooks/hooks_test.go @@ -170,6 +170,19 @@ func TestRequestHookProcessing(t *testing.T) { }, }, "hooks alter the node builder chooser": { + configure: func(t *testing.T, requestHooks *hooks.IncomingRequestHooks) { + requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { + hookActions.PauseResponse() + hookActions.ValidateRequest() + }) + }, + assert: func(t *testing.T, result hooks.RequestResult) { + require.True(t, result.IsValidated) + require.True(t, result.IsPaused) + require.NoError(t, result.Err) + }, + }, + "hooks start request paused": { configure: func(t *testing.T, requestHooks *hooks.IncomingRequestHooks) { requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { if _, found := requestData.Extension(extensionName); found { @@ -262,7 +275,7 @@ func TestBlockHookProcessing(t *testing.T) { }, assert: func(t *testing.T, result hooks.BlockResult) { require.Empty(t, result.Extensions) - require.EqualError(t, result.Err, hooks.ErrPaused.Error()) + require.EqualError(t, result.Err, hooks.ErrPaused{}.Error()) }, }, } diff --git a/responsemanager/hooks/completedlisteners.go b/responsemanager/hooks/listeners.go similarity index 53% rename from responsemanager/hooks/completedlisteners.go rename to responsemanager/hooks/listeners.go index 5326704b..72cf58b3 100644 --- a/responsemanager/hooks/completedlisteners.go +++ 
b/responsemanager/hooks/listeners.go @@ -38,3 +38,35 @@ func (crl *CompletedResponseListeners) Register(listener graphsync.OnResponseCom func (crl *CompletedResponseListeners) NotifyCompletedListeners(p peer.ID, request graphsync.RequestData, status graphsync.ResponseStatusCode) { _ = crl.pubSub.Publish(internalCompletedResponseEvent{p, request, status}) } + +// RequestorCancelledListeners is a set of listeners for when requestors cancel +type RequestorCancelledListeners struct { + pubSub *pubsub.PubSub +} + +type internalRequestorCancelledEvent struct { + p peer.ID + request graphsync.RequestData +} + +func requestorCancelledDispatcher(event pubsub.Event, subscriberFn pubsub.SubscriberFn) error { + ie := event.(internalRequestorCancelledEvent) + listener := subscriberFn.(graphsync.OnRequestorCancelledListener) + listener(ie.p, ie.request) + return nil +} + +// NewRequestorCancelledListeners returns a new list of listeners for when requestors cancel +func NewRequestorCancelledListeners() *RequestorCancelledListeners { + return &RequestorCancelledListeners{pubSub: pubsub.New(requestorCancelledDispatcher)} +} + +// Register registers a listener for when requestors cancel +func (rcl *RequestorCancelledListeners) Register(listener graphsync.OnRequestorCancelledListener) graphsync.UnregisterHookFunc { + return graphsync.UnregisterHookFunc(rcl.pubSub.Subscribe(listener)) +} + +// NotifyCancelledListeners notifies all listeners that a requestor cancelled a response +func (rcl *RequestorCancelledListeners) NotifyCancelledListeners(p peer.ID, request graphsync.RequestData) { + _ = rcl.pubSub.Publish(internalRequestorCancelledEvent{p, request}) +} diff --git a/responsemanager/hooks/requesthook.go b/responsemanager/hooks/requesthook.go index ba52eeb3..74273936 100644 --- a/responsemanager/hooks/requesthook.go +++ b/responsemanager/hooks/requesthook.go @@ -50,6 +50,7 @@ func (irh *IncomingRequestHooks) Register(hook graphsync.OnIncomingRequestHook) // RequestResult is the outcome of running requesthooks type RequestResult struct { IsValidated bool + IsPaused bool CustomLoader ipld.Loader CustomChooser traversal.LinkTargetNodeStyleChooser Err error @@ -68,6 +69,7 @@ func (irh *IncomingRequestHooks) ProcessRequestHooks(p peer.ID, request graphsyn type requestHookActions struct { persistenceOptions PersistenceOptions isValidated bool + isPaused bool err error loader ipld.Loader chooser traversal.LinkTargetNodeStyleChooser @@ -77,6 +79,7 @@ type requestHookActions struct { func (ha *requestHookActions) result() RequestResult { return RequestResult{ IsValidated: ha.isValidated, + IsPaused: ha.isPaused, CustomLoader: ha.loader, CustomChooser: ha.chooser, Err: ha.err, @@ -108,3 +111,7 @@ func (ha *requestHookActions) UsePersistenceOption(name string) { func (ha *requestHookActions) UseLinkTargetNodeStyleChooser(chooser traversal.LinkTargetNodeStyleChooser) { ha.chooser = chooser } + +func (ha *requestHookActions) PauseResponse() { + ha.isPaused = true +} diff --git a/responsemanager/peerresponsemanager/peerresponsesender.go b/responsemanager/peerresponsemanager/peerresponsesender.go index 39a438ee..8831434f 100644 --- a/responsemanager/peerresponsemanager/peerresponsesender.go +++ b/responsemanager/peerresponsemanager/peerresponsesender.go @@ -59,6 +59,7 @@ type PeerResponseSender interface { data []byte, ) graphsync.BlockData SendExtensionData(graphsync.RequestID, graphsync.ExtensionData) + FinishWithCancel(requestID graphsync.RequestID) FinishRequest(requestID graphsync.RequestID) 
graphsync.ResponseStatusCode FinishWithError(requestID graphsync.RequestID, status graphsync.ResponseStatusCode) // Transaction calls multiple operations at once so they end up in a single response @@ -74,6 +75,7 @@ type PeerResponseTransactionSender interface { data []byte, ) graphsync.BlockData SendExtensionData(graphsync.ExtensionData) + FinishWithCancel() FinishRequest() graphsync.ResponseStatusCode FinishWithError(status graphsync.ResponseStatusCode) PauseRequest() @@ -177,6 +179,10 @@ func (prts *peerResponseTransactionSender) PauseRequest() { prts.operations = append(prts.operations, statusOperation{prts.requestID, graphsync.RequestPaused}) } +func (prts *peerResponseTransactionSender) FinishWithCancel() { + _ = prts.prs.finishTracking(prts.requestID) +} + func (prs *peerResponseSender) Transaction(requestID graphsync.RequestID, transaction Transaction) error { prts := &peerResponseTransactionSender{ requestID: requestID, @@ -266,10 +272,14 @@ func (fo statusOperation) size() uint64 { return 0 } -func (prs *peerResponseSender) setupFinishOperation(requestID graphsync.RequestID) statusOperation { +func (prs *peerResponseSender) finishTracking(requestID graphsync.RequestID) bool { prs.linkTrackerLk.Lock() - isComplete := prs.linkTracker.FinishRequest(requestID) - prs.linkTrackerLk.Unlock() + defer prs.linkTrackerLk.Unlock() + return prs.linkTracker.FinishRequest(requestID) +} + +func (prs *peerResponseSender) setupFinishOperation(requestID graphsync.RequestID) statusOperation { + isComplete := prs.finishTracking(requestID) var status graphsync.ResponseStatusCode if isComplete { status = graphsync.RequestCompletedFull @@ -303,6 +313,10 @@ func (prs *peerResponseSender) PauseRequest(requestID graphsync.RequestID) { prs.execute([]responseOperation{statusOperation{requestID, graphsync.RequestPaused}}) } +func (prs *peerResponseSender) FinishWithCancel(requestID graphsync.RequestID) { + _ = prs.finishTracking(requestID) +} + func (prs *peerResponseSender) buildResponse(blkSize uint64, buildResponseFn func(*responsebuilder.ResponseBuilder)) bool { prs.responseBuildersLk.Lock() defer prs.responseBuildersLk.Unlock() diff --git a/responsemanager/queryexecutor.go b/responsemanager/queryexecutor.go index d14e3691..2b859d11 100644 --- a/responsemanager/queryexecutor.go +++ b/responsemanager/queryexecutor.go @@ -3,6 +3,7 @@ package responsemanager import ( "context" "errors" + "strings" "time" "github.com/ipfs/go-cid" @@ -18,18 +19,22 @@ import ( "github.com/libp2p/go-libp2p-core/peer" ) +var errCancelledByCommand = errors.New("response cancelled by responder") + // TODO: Move this into a seperate module and fully seperate from the ResponseManager type queryExecutor struct { - requestHooks RequestHooks - blockHooks BlockHooks - updateHooks UpdateHooks - peerManager PeerManager - loader ipld.Loader - queryQueue QueryQueue - messages chan responseManagerMessage - ctx context.Context - workSignal chan struct{} - ticker *time.Ticker + requestHooks RequestHooks + blockHooks BlockHooks + updateHooks UpdateHooks + completedListeners CompletedListeners + cancelledListeners CancelledListeners + peerManager PeerManager + loader ipld.Loader + queryQueue QueryQueue + messages chan responseManagerMessage + ctx context.Context + workSignal chan struct{} + ticker *time.Ticker } func (qe *queryExecutor) processQueriesWorker() { @@ -66,6 +71,13 @@ func (qe *queryExecutor) processQueriesWorker() { continue } status, err := qe.executeTask(key, taskData) + _, isPaused := err.(hooks.ErrPaused) + isCancelled := err != 
nil && isContextErr(err) + if isCancelled { + qe.cancelledListeners.NotifyCancelledListeners(key.p, taskData.request) + } else if !isPaused { + qe.completedListeners.NotifyCompletedListeners(key.p, taskData.request, status) + } select { case qe.messages <- &finishTaskRequest{key, status, err}: case <-qe.ctx.Done(): @@ -82,7 +94,8 @@ func (qe *queryExecutor) executeTask(key responseKey, taskData responseTaskData) loader := taskData.loader traverser := taskData.traverser if loader == nil || traverser == nil { - loader, traverser, err = qe.prepareQuery(taskData.ctx, key.p, taskData.request) + var isPaused bool + loader, traverser, isPaused, err = qe.prepareQuery(taskData.ctx, key.p, taskData.request) if err != nil { return graphsync.RequestFailedUnknown, err } @@ -91,34 +104,41 @@ func (qe *queryExecutor) executeTask(key responseKey, taskData responseTaskData) return graphsync.RequestFailedUnknown, errors.New("context cancelled") case qe.messages <- &setResponseDataRequest{key, loader, traverser}: } + if isPaused { + return graphsync.RequestPaused, hooks.ErrPaused{} + } } - return qe.executeQuery(key.p, taskData.request, loader, traverser, taskData.pauseSignal, taskData.updateSignal) + return qe.executeQuery(key.p, taskData.request, loader, traverser, taskData.signals) } func (qe *queryExecutor) prepareQuery(ctx context.Context, p peer.ID, - request gsmsg.GraphSyncRequest) (ipld.Loader, ipldutil.Traverser, error) { + request gsmsg.GraphSyncRequest) (ipld.Loader, ipldutil.Traverser, bool, error) { result := qe.requestHooks.ProcessRequestHooks(p, request) peerResponseSender := qe.peerManager.SenderForPeer(p) - var validationErr error + var transactionError error + var isPaused bool err := peerResponseSender.Transaction(request.ID(), func(transaction peerresponsemanager.PeerResponseTransactionSender) error { for _, extension := range result.Extensions { transaction.SendExtensionData(extension) } if result.Err != nil || !result.IsValidated { transaction.FinishWithError(graphsync.RequestFailedUnknown) - validationErr = errors.New("request not valid") + transactionError = errors.New("request not valid") + } else if result.IsPaused { + transaction.PauseRequest() + isPaused = true } return nil }) if err != nil { - return nil, nil, err + return nil, nil, false, err } - if validationErr != nil { - return nil, nil, validationErr + if transactionError != nil { + return nil, nil, false, transactionError } if err := qe.processDoNoSendCids(request, peerResponseSender); err != nil { - return nil, nil, err + return nil, nil, false, err } rootLink := cidlink.Link{Cid: request.Root()} traverser := ipldutil.TraversalBuilder{ @@ -130,7 +150,7 @@ func (qe *queryExecutor) prepareQuery(ctx context.Context, if loader == nil { loader = qe.loader } - return loader, traverser, nil + return loader, traverser, isPaused, nil } func (qe *queryExecutor) processDoNoSendCids(request gsmsg.GraphSyncRequest, peerResponseSender peerresponsemanager.PeerResponseSender) error { @@ -160,18 +180,14 @@ func (qe *queryExecutor) executeQuery( request gsmsg.GraphSyncRequest, loader ipld.Loader, traverser ipldutil.Traverser, - pauseSignal chan struct{}, - updateSignal chan struct{}) (graphsync.ResponseStatusCode, error) { + signals signals) (graphsync.ResponseStatusCode, error) { updateChan := make(chan []gsmsg.GraphSyncRequest) peerResponseSender := qe.peerManager.SenderForPeer(p) err := runtraversal.RunTraversal(loader, traverser, func(link ipld.Link, data []byte) error { var err error _ = peerResponseSender.Transaction(request.ID(), 
func(transaction peerresponsemanager.PeerResponseTransactionSender) error { - err = qe.checkForUpdates(p, request, pauseSignal, updateSignal, updateChan, transaction) - if err != nil { - if err == hooks.ErrPaused { - transaction.PauseRequest() - } + err = qe.checkForUpdates(p, request, signals, updateChan, transaction) + if _, ok := err.(hooks.ErrPaused); !ok && err != nil { return nil } blockData := transaction.SendResponse(link, data) @@ -180,21 +196,32 @@ func (qe *queryExecutor) executeQuery( for _, extension := range result.Extensions { transaction.SendExtensionData(extension) } - if result.Err == hooks.ErrPaused { + if _, ok := result.Err.(hooks.ErrPaused); ok { transaction.PauseRequest() } - err = result.Err + if result.Err != nil { + err = result.Err + } } return nil }) return err }) if err != nil { - if err != hooks.ErrPaused { - peerResponseSender.FinishWithError(request.ID(), graphsync.RequestFailedUnknown) - return graphsync.RequestFailedUnknown, err + _, isPaused := err.(hooks.ErrPaused) + if isPaused { + return graphsync.RequestPaused, err } - return graphsync.RequestPaused, err + if isContextErr(err) { + peerResponseSender.FinishWithCancel(request.ID()) + return graphsync.RequestCancelled, err + } + if err == errCancelledByCommand { + peerResponseSender.FinishWithError(request.ID(), graphsync.RequestCancelled) + return graphsync.RequestCancelled, err + } + peerResponseSender.FinishWithError(request.ID(), graphsync.RequestFailedUnknown) + return graphsync.RequestFailedUnknown, err } return peerResponseSender.FinishRequest(request.ID()), nil } @@ -202,15 +229,20 @@ func (qe *queryExecutor) executeQuery( func (qe *queryExecutor) checkForUpdates( p peer.ID, request gsmsg.GraphSyncRequest, - pauseSignal chan struct{}, - updateSignal chan struct{}, + signals signals, updateChan chan []gsmsg.GraphSyncRequest, peerResponseSender peerresponsemanager.PeerResponseTransactionSender) error { for { select { - case <-pauseSignal: - return hooks.ErrPaused - case <-updateSignal: + case selfCancelled := <-signals.stopSignal: + if selfCancelled { + return errCancelledByCommand + } + return ipldutil.ContextCancelError{} + case <-signals.pauseSignal: + peerResponseSender.PauseRequest() + return hooks.ErrPaused{} + case <-signals.updateSignal: select { case qe.messages <- &responseUpdateRequest{responseKey{p, request.ID()}, updateChan}: case <-qe.ctx.Done(): @@ -233,3 +265,8 @@ func (qe *queryExecutor) checkForUpdates( } } } + +func isContextErr(err error) bool { + // TODO: Match with errors.Is when https://github.com/ipld/go-ipld-prime/issues/58 is resolved + return strings.Contains(err.Error(), ipldutil.ContextCancelError{}.Error()) +} diff --git a/responsemanager/responsemanager.go b/responsemanager/responsemanager.go index 3df278f8..46fd1568 100644 --- a/responsemanager/responsemanager.go +++ b/responsemanager/responsemanager.go @@ -26,15 +26,14 @@ const ( ) type inProgressResponseStatus struct { - ctx context.Context - cancelFn func() - request gsmsg.GraphSyncRequest - loader ipld.Loader - traverser ipldutil.Traverser - pauseSignal chan struct{} - updateSignal chan struct{} - updates []gsmsg.GraphSyncRequest - isPaused bool + ctx context.Context + cancelFn func() + request gsmsg.GraphSyncRequest + loader ipld.Loader + traverser ipldutil.Traverser + signals signals + updates []gsmsg.GraphSyncRequest + isPaused bool } type responseKey struct { @@ -42,14 +41,19 @@ type responseKey struct { requestID graphsync.RequestID } -type responseTaskData struct { - empty bool - ctx context.Context - 
request gsmsg.GraphSyncRequest - loader ipld.Loader - traverser ipldutil.Traverser +type signals struct { pauseSignal chan struct{} updateSignal chan struct{} + stopSignal chan bool +} + +type responseTaskData struct { + empty bool + ctx context.Context + request gsmsg.GraphSyncRequest + loader ipld.Loader + traverser ipldutil.Traverser + signals signals } // QueryQueue is an interface that can receive new selector query tasks @@ -82,6 +86,11 @@ type CompletedListeners interface { NotifyCompletedListeners(p peer.ID, request graphsync.RequestData, status graphsync.ResponseStatusCode) } +// CancelledListeners is an interface for notifying listeners that requestor cancelled +type CancelledListeners interface { + NotifyCancelledListeners(p peer.ID, request graphsync.RequestData) +} + // PeerManager is an interface that returns sender interfaces for peer responses. type PeerManager interface { SenderForPeer(p peer.ID) peerresponsemanager.PeerResponseSender @@ -99,6 +108,7 @@ type ResponseManager struct { peerManager PeerManager queryQueue QueryQueue updateHooks UpdateHooks + cancelledListeners CancelledListeners completedListeners CompletedListeners messages chan responseManagerMessage workSignal chan struct{} @@ -115,21 +125,25 @@ func New(ctx context.Context, requestHooks RequestHooks, blockHooks BlockHooks, updateHooks UpdateHooks, - completedListeners CompletedListeners) *ResponseManager { + completedListeners CompletedListeners, + cancelledListeners CancelledListeners, +) *ResponseManager { ctx, cancelFn := context.WithCancel(ctx) messages := make(chan responseManagerMessage, 16) workSignal := make(chan struct{}, 1) qe := &queryExecutor{ - requestHooks: requestHooks, - blockHooks: blockHooks, - updateHooks: updateHooks, - peerManager: peerManager, - loader: loader, - queryQueue: queryQueue, - messages: messages, - ctx: ctx, - workSignal: workSignal, - ticker: time.NewTicker(thawSpeed), + requestHooks: requestHooks, + blockHooks: blockHooks, + updateHooks: updateHooks, + completedListeners: completedListeners, + cancelledListeners: cancelledListeners, + peerManager: peerManager, + loader: loader, + queryQueue: queryQueue, + messages: messages, + ctx: ctx, + workSignal: workSignal, + ticker: time.NewTicker(thawSpeed), } return &ResponseManager{ ctx: ctx, @@ -138,6 +152,7 @@ func New(ctx context.Context, queryQueue: queryQueue, updateHooks: updateHooks, completedListeners: completedListeners, + cancelledListeners: cancelledListeners, messages: messages, workSignal: workSignal, qe: qe, @@ -283,7 +298,7 @@ func (rm *ResponseManager) processUpdate(key responseKey, update gsmsg.GraphSync if !response.isPaused { response.updates = append(response.updates, update) select { - case response.updateSignal <- struct{}{}: + case response.signals.updateSignal <- struct{}{}: default: } return @@ -343,16 +358,39 @@ func (rm *ResponseManager) unpauseRequest(p peer.ID, requestID graphsync.Request return nil } +func (rm *ResponseManager) cancelRequest(p peer.ID, requestID graphsync.RequestID, selfCancel bool) error { + key := responseKey{p, requestID} + rm.queryQueue.Remove(key, key.p) + response, ok := rm.inProgressResponses[key] + if !ok { + return errors.New("could not find request") + } + + if response.isPaused { + peerResponseSender := rm.peerManager.SenderForPeer(key.p) + if selfCancel { + rm.completedListeners.NotifyCompletedListeners(p, response.request, graphsync.RequestCancelled) + peerResponseSender.FinishWithError(requestID, graphsync.RequestCancelled) + } else { + 
rm.cancelledListeners.NotifyCancelledListeners(p, response.request) + peerResponseSender.FinishWithCancel(requestID) + } + delete(rm.inProgressResponses, key) + response.cancelFn() + return nil + } + select { + case response.signals.stopSignal <- selfCancel: + default: + } + return nil +} + func (prm *processRequestMessage) handle(rm *ResponseManager) { for _, request := range prm.requests { key := responseKey{p: prm.p, requestID: request.ID()} if request.IsCancel() { - rm.queryQueue.Remove(key, key.p) - response, ok := rm.inProgressResponses[key] - if ok { - response.cancelFn() - delete(rm.inProgressResponses, key) - } + _ = rm.cancelRequest(prm.p, request.ID(), false) continue } if request.IsUpdate() { @@ -362,11 +400,14 @@ func (prm *processRequestMessage) handle(rm *ResponseManager) { ctx, cancelFn := context.WithCancel(rm.ctx) rm.inProgressResponses[key] = &inProgressResponseStatus{ - ctx: ctx, - cancelFn: cancelFn, - request: request, - pauseSignal: make(chan struct{}, 1), - updateSignal: make(chan struct{}, 1), + ctx: ctx, + cancelFn: cancelFn, + request: request, + signals: signals{ + pauseSignal: make(chan struct{}, 1), + updateSignal: make(chan struct{}, 1), + stopSignal: make(chan bool, 1), + }, } // TODO: Use a better work estimation metric. rm.queryQueue.PushTasks(prm.p, peertask.Task{Topic: key, Priority: int(request.Priority()), Work: 1}) @@ -381,7 +422,7 @@ func (rdr *responseDataRequest) handle(rm *ResponseManager) { response, ok := rm.inProgressResponses[rdr.key] var taskData responseTaskData if ok { - taskData = responseTaskData{false, response.ctx, response.request, response.loader, response.traverser, response.pauseSignal, response.updateSignal} + taskData = responseTaskData{false, response.ctx, response.request, response.loader, response.traverser, response.signals} } else { taskData = responseTaskData{empty: true} } @@ -396,11 +437,10 @@ func (ftr *finishTaskRequest) handle(rm *ResponseManager) { if !ok { return } - if ftr.err == hooks.ErrPaused { + if _, ok := ftr.err.(hooks.ErrPaused); ok { response.isPaused = true return } - rm.completedListeners.NotifyCompletedListeners(ftr.key.p, response.request, ftr.status) if ftr.err != nil { log.Infof("response failed: %w", ftr.err) } @@ -457,7 +497,7 @@ func (prm *pauseRequestMessage) pauseRequest(rm *ResponseManager) error { return errors.New("request is already paused") } select { - case inProgressResponse.pauseSignal <- struct{}{}: + case inProgressResponse.signals.pauseSignal <- struct{}{}: default: } return nil @@ -472,15 +512,9 @@ func (prm *pauseRequestMessage) handle(rm *ResponseManager) { } func (crm *cancelRequestMessage) handle(rm *ResponseManager) { - key := responseKey{crm.p, crm.requestID} - rm.queryQueue.Remove(key, key.p) - inProgressResponse, ok := rm.inProgressResponses[key] - if ok { - inProgressResponse.cancelFn() - delete(rm.inProgressResponses, key) - } + err := rm.cancelRequest(crm.p, crm.requestID, true) select { case <-rm.ctx.Done(): - case crm.response <- nil: + case crm.response <- err: } } diff --git a/responsemanager/responsemanager_test.go b/responsemanager/responsemanager_test.go index 4523f611..393d3dc6 100644 --- a/responsemanager/responsemanager_test.go +++ b/responsemanager/responsemanager_test.go @@ -102,11 +102,16 @@ type pausedRequest struct { requestID graphsync.RequestID } +type cancelledRequest struct { + requestID graphsync.RequestID +} + type fakePeerResponseSender struct { sentResponses chan sentResponse sentExtensions chan sentExtension lastCompletedRequest chan completedRequest 
pausedRequests chan pausedRequest + cancelledRequests chan cancelledRequest ignoredLinks chan []ipld.Link } @@ -163,6 +168,10 @@ func (fprs *fakePeerResponseSender) PauseRequest(requestID graphsync.RequestID) fprs.pausedRequests <- pausedRequest{requestID} } +func (fprs *fakePeerResponseSender) FinishWithCancel(requestID graphsync.RequestID) { + fprs.cancelledRequests <- cancelledRequest{requestID} +} + func (fprs *fakePeerResponseSender) Transaction(requestID graphsync.RequestID, transaction peerresponsemanager.Transaction) error { fprts := &fakePeerResponseTransactionSender{requestID, fprs} return transaction(fprts) @@ -193,12 +202,15 @@ func (fprts *fakePeerResponseTransactionSender) PauseRequest() { fprts.prs.PauseRequest(fprts.requestID) } +func (fprts *fakePeerResponseTransactionSender) FinishWithCancel() { + fprts.prs.FinishWithCancel(fprts.requestID) +} func TestIncomingQuery(t *testing.T) { td := newTestData(t) defer td.cancel() blks := td.blockChain.AllBlocks() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners) td.requestHooks.Register(selectorvalidator.SelectorValidator(100)) responseManager.Startup() @@ -219,8 +231,12 @@ func TestCancellationQueryInProgress(t *testing.T) { td := newTestData(t) defer td.cancel() blks := td.blockChain.AllBlocks() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners) td.requestHooks.Register(selectorvalidator.SelectorValidator(100)) + cancelledListenerCalled := make(chan struct{}, 1) + td.cancelledListeners.Register(func(p peer.ID, request graphsync.RequestData) { + cancelledListenerCalled <- struct{}{} + }) responseManager.Startup() responseManager.ProcessRequests(td.ctx, td.p, td.requests) @@ -241,6 +257,8 @@ func TestCancellationQueryInProgress(t *testing.T) { responseManager.synchronize() + testutil.AssertDoesReceive(td.ctx, t, cancelledListenerCalled, "should call cancelled listener") + // at this point we should receive at most one more block, then traversal // should complete additionalBlocks := 0 @@ -255,7 +273,7 @@ func TestCancellationQueryInProgress(t *testing.T) { require.Equal(t, blks[blockIndex].RawData(), sentResponse.data, "sent incorrect data") require.Equal(t, td.requestID, sentResponse.requestID, "incorrect response id") additionalBlocks++ - case <-td.completedRequestChan: + case <-td.cancelledRequests: require.LessOrEqual(t, additionalBlocks, 1, "should send at most 1 additional block") return } @@ -266,7 +284,7 @@ func TestCancellationViaCommand(t *testing.T) { td := newTestData(t) defer td.cancel() blks := td.blockChain.AllBlocks() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners) td.requestHooks.Register(selectorvalidator.SelectorValidator(100)) responseManager.Startup() responseManager.ProcessRequests(td.ctx, td.p, td.requests) @@ -298,7 +316,8 @@ 
func TestCancellationViaCommand(t *testing.T) { require.Equal(t, blks[blockIndex].RawData(), sentResponse.data, "sent incorrect data") require.Equal(t, td.requestID, sentResponse.requestID, "incorrect response id") additionalBlocks++ - case <-td.completedRequestChan: + case completed := <-td.completedRequestChan: + require.Equal(t, completed.result, graphsync.RequestCancelled) require.LessOrEqual(t, additionalBlocks, 1, "should send at most 1 additional block") return } @@ -309,7 +328,7 @@ func TestEarlyCancellation(t *testing.T) { td := newTestData(t) defer td.cancel() td.queryQueue.popWait.Add(1) - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners) responseManager.Startup() responseManager.ProcessRequests(td.ctx, td.p, td.requests) @@ -324,7 +343,7 @@ func TestEarlyCancellation(t *testing.T) { // unblock popping from queue td.queryQueue.popWait.Done() - timer := time.NewTimer(time.Second) + timer := time.NewTimer(200 * time.Millisecond) // verify no responses processed testutil.AssertDoesReceiveFirst(t, timer.C, "should not process more responses", td.sentResponses, td.completedRequestChan) } @@ -333,7 +352,7 @@ func TestValidationAndExtensions(t *testing.T) { t.Run("on its own, should fail validation", func(t *testing.T) { td := newTestData(t) defer td.cancel() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners) responseManager.Startup() responseManager.ProcessRequests(td.ctx, td.p, td.requests) var lastRequest completedRequest @@ -344,7 +363,7 @@ func TestValidationAndExtensions(t *testing.T) { t.Run("if non validating hook succeeds, does not pass validation", func(t *testing.T) { td := newTestData(t) defer td.cancel() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners) responseManager.Startup() td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { hookActions.SendExtensionData(td.extensionResponse) @@ -361,7 +380,7 @@ func TestValidationAndExtensions(t *testing.T) { t.Run("if validating hook succeeds, should pass validation", func(t *testing.T) { td := newTestData(t) defer td.cancel() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners) responseManager.Startup() td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { hookActions.ValidateRequest() @@ -379,7 +398,7 @@ func TestValidationAndExtensions(t *testing.T) { t.Run("if any hook fails, should fail", func(t *testing.T) { td := newTestData(t) defer td.cancel() 
- responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners) responseManager.Startup() td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { hookActions.ValidateRequest() @@ -400,7 +419,7 @@ func TestValidationAndExtensions(t *testing.T) { t.Run("hooks can be unregistered", func(t *testing.T) { td := newTestData(t) defer td.cancel() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners) responseManager.Startup() unregister := td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { hookActions.ValidateRequest() @@ -430,7 +449,7 @@ func TestValidationAndExtensions(t *testing.T) { defer td.cancel() obs := make(map[ipld.Link][]byte) oloader, _ := testutil.NewTestStore(obs) - responseManager := New(td.ctx, oloader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners) + responseManager := New(td.ctx, oloader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners) responseManager.Startup() // add validating hook -- so the request SHOULD succeed td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { @@ -464,7 +483,7 @@ func TestValidationAndExtensions(t *testing.T) { t.Run("hooks can alter the node builder chooser", func(t *testing.T) { td := newTestData(t) defer td.cancel() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners) responseManager.Startup() customChooserCallCount := 0 @@ -506,7 +525,7 @@ func TestValidationAndExtensions(t *testing.T) { t.Run("do-not-send-cids extension", func(t *testing.T) { td := newTestData(t) defer td.cancel() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners) responseManager.Startup() td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { hookActions.ValidateRequest() @@ -536,11 +555,32 @@ func TestValidationAndExtensions(t *testing.T) { require.True(t, set.Has(link.(cidlink.Link).Cid)) } }) + t.Run("test pause/resume", func(t *testing.T) { + td := newTestData(t) + defer td.cancel() + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners) + responseManager.Startup() + td.requestHooks.Register(func(p peer.ID, requestData 
graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { + hookActions.ValidateRequest() + hookActions.PauseResponse() + }) + responseManager.ProcessRequests(td.ctx, td.p, td.requests) + var pauseRequest pausedRequest + testutil.AssertReceive(td.ctx, t, td.pausedRequests, &pauseRequest, "should pause immediately") + timer := time.NewTimer(100 * time.Millisecond) + testutil.AssertDoesReceiveFirst(t, timer.C, "should not complete request while paused", td.completedRequestChan) + testutil.AssertChannelEmpty(t, td.sentResponses, "should not send more blocks") + err := responseManager.UnpauseResponse(td.p, td.requestID) + require.NoError(t, err) + var lastRequest completedRequest + testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request") + require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed") + }) t.Run("test block hook processing", func(t *testing.T) { t.Run("can send extension data", func(t *testing.T) { td := newTestData(t) defer td.cancel() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners) responseManager.Startup() td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { hookActions.ValidateRequest() @@ -562,7 +602,7 @@ func TestValidationAndExtensions(t *testing.T) { t.Run("can send errors", func(t *testing.T) { td := newTestData(t) defer td.cancel() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners) responseManager.Startup() td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { hookActions.ValidateRequest() @@ -579,7 +619,7 @@ func TestValidationAndExtensions(t *testing.T) { t.Run("can pause/unpause", func(t *testing.T) { td := newTestData(t) defer td.cancel() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners) responseManager.Startup() td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { hookActions.ValidateRequest() @@ -614,7 +654,7 @@ func TestValidationAndExtensions(t *testing.T) { t.Run("can pause/unpause externally", func(t *testing.T) { td := newTestData(t) defer td.cancel() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners) responseManager.Startup() td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { hookActions.ValidateRequest() @@ -631,7 +671,7 @@ func 
TestValidationAndExtensions(t *testing.T) { responseManager.ProcessRequests(td.ctx, td.p, td.requests) timer := time.NewTimer(100 * time.Millisecond) testutil.AssertDoesReceiveFirst(t, timer.C, "should not complete request while paused", td.completedRequestChan) - for i := 0; i < blockCount; i++ { + for i := 0; i < blockCount+1; i++ { testutil.AssertDoesReceive(td.ctx, t, td.sentResponses, "should sent block") } testutil.AssertChannelEmpty(t, td.sentResponses, "should not send more blocks") @@ -639,6 +679,9 @@ func TestValidationAndExtensions(t *testing.T) { testutil.AssertReceive(td.ctx, t, td.pausedRequests, &pausedRequest, "should pause request") err := responseManager.UnpauseResponse(td.p, td.requestID) require.NoError(t, err) + for i := blockCount + 1; i < td.blockChainLength; i++ { + testutil.AssertDoesReceive(td.ctx, t, td.sentResponses, "should send block") + } var lastRequest completedRequest testutil.AssertReceive(td.ctx, t, td.completedRequestChan, &lastRequest, "should complete request") require.True(t, gsmsg.IsTerminalSuccessCode(lastRequest.result), "request should succeed") @@ -650,7 +693,7 @@ func TestValidationAndExtensions(t *testing.T) { t.Run("can pause/unpause", func(t *testing.T) { td := newTestData(t) defer td.cancel() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners) responseManager.Startup() td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { hookActions.ValidateRequest() @@ -689,7 +732,7 @@ func TestValidationAndExtensions(t *testing.T) { t.Run("when unpaused", func(t *testing.T) { td := newTestData(t) defer td.cancel() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners) responseManager.Startup() td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { hookActions.ValidateRequest() @@ -726,7 +769,7 @@ func TestValidationAndExtensions(t *testing.T) { t.Run("when paused", func(t *testing.T) { td := newTestData(t) defer td.cancel() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners) responseManager.Startup() td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { hookActions.ValidateRequest() @@ -771,7 +814,7 @@ func TestValidationAndExtensions(t *testing.T) { t.Run("when unpaused", func(t *testing.T) { td := newTestData(t) defer td.cancel() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners) responseManager.Startup() 
td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { hookActions.ValidateRequest() @@ -805,7 +848,7 @@ func TestValidationAndExtensions(t *testing.T) { t.Run("when paused", func(t *testing.T) { td := newTestData(t) defer td.cancel() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners) responseManager.Startup() td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { hookActions.ValidateRequest() @@ -851,7 +894,7 @@ func TestValidationAndExtensions(t *testing.T) { t.Run("final response status listeners", func(t *testing.T) { td := newTestData(t) defer td.cancel() - responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners) + responseManager := New(td.ctx, td.loader, td.peerManager, td.queryQueue, td.requestHooks, td.blockHooks, td.updateHooks, td.completedListeners, td.cancelledListeners) responseManager.Startup() td.requestHooks.Register(func(p peer.ID, requestData graphsync.RequestData, hookActions graphsync.IncomingRequestHookActions) { hookActions.ValidateRequest() @@ -885,6 +928,7 @@ type testData struct { sentResponses chan sentResponse sentExtensions chan sentExtension pausedRequests chan pausedRequest + cancelledRequests chan cancelledRequest ignoredLinks chan []ipld.Link peerManager *fakePeerManager queryQueue *fakeQueryQueue @@ -904,6 +948,7 @@ type testData struct { blockHooks *hooks.OutgoingBlockHooks updateHooks *hooks.RequestUpdatedHooks completedListeners *hooks.CompletedResponseListeners + cancelledListeners *hooks.RequestorCancelledListeners } func newTestData(t *testing.T) testData { @@ -920,12 +965,14 @@ func newTestData(t *testing.T) testData { td.sentResponses = make(chan sentResponse, td.blockChainLength*2) td.sentExtensions = make(chan sentExtension, td.blockChainLength*2) td.pausedRequests = make(chan pausedRequest, 1) + td.cancelledRequests = make(chan cancelledRequest, 1) td.ignoredLinks = make(chan []ipld.Link, 1) fprs := &fakePeerResponseSender{ lastCompletedRequest: td.completedRequestChan, sentResponses: td.sentResponses, sentExtensions: td.sentExtensions, pausedRequests: td.pausedRequests, + cancelledRequests: td.cancelledRequests, ignoredLinks: td.ignoredLinks, } td.peerManager = &fakePeerManager{peerResponseSender: fprs} @@ -960,5 +1007,6 @@ func newTestData(t *testing.T) testData { td.blockHooks = hooks.NewBlockHooks() td.updateHooks = hooks.NewUpdateHooks() td.completedListeners = hooks.NewCompletedResponseListeners() + td.cancelledListeners = hooks.NewRequestorCancelledListeners() return td } diff --git a/responsemanager/runtraversal/runtraversal_test.go b/responsemanager/runtraversal/runtraversal_test.go index d4ca9f11..8ce42f46 100644 --- a/responsemanager/runtraversal/runtraversal_test.go +++ b/responsemanager/runtraversal/runtraversal_test.go @@ -2,6 +2,7 @@ package runtraversal import ( "bytes" + "context" "errors" "io" "testing" @@ -96,6 +97,9 @@ func (ft *fakeTraverser) Error(err error) { ft.receivedOutcomes = append(ft.receivedOutcomes, traverseOutcome{true, err, nil}) } +// Shutdown cancels the traversal if still in progress +func (ft *fakeTraverser) Shutdown(ctx 
context.Context) {} + func (ft *fakeTraverser) verifyExpectations(t *testing.T) { require.Equal(t, ft.expectedOutcomes, ft.receivedOutcomes) } diff --git a/testutil/testchain.go b/testutil/testchain.go index 992f3e34..07e0c952 100644 --- a/testutil/testchain.go +++ b/testutil/testchain.go @@ -206,6 +206,22 @@ func (tbc *TestBlockChain) VerifyResponseRange(ctx context.Context, responseChan tbc.checkResponses(responses, from, to, false) } +// VerifyWholeChainSync verifies the given set of read responses are the expected responses for the whole chain +func (tbc *TestBlockChain) VerifyWholeChainSync(responses []graphsync.ResponseProgress) { + tbc.VerifyRemainderSync(responses, 0) +} + +// VerifyRemainderSync verifies the given set of read responses are the remainder of the chain starting at the nth block from the tip +func (tbc *TestBlockChain) VerifyRemainderSync(responses []graphsync.ResponseProgress, from int) { + tbc.checkResponses(responses, from, tbc.blockChainLength, false) +} + +// VerifyResponseRangeSync verifies given set of read responses match responses for the given range of the blockchain, indexed from the tip +// (with possibly more data left in the channel) +func (tbc *TestBlockChain) VerifyResponseRangeSync(responses []graphsync.ResponseProgress, from int, to int) { + tbc.checkResponses(responses, from, to, false) +} + // VerifyWholeChainWithTypes verifies the given response channel returns the expected responses for the whole chain // and that the types in the response are the expected types for a block chain func (tbc *TestBlockChain) VerifyWholeChainWithTypes(ctx context.Context, responseChan <-chan graphsync.ResponseProgress) {