diff --git a/query/worker.go b/query/worker.go index dc15a18c..5c2d2517 100644 --- a/query/worker.go +++ b/query/worker.go @@ -24,11 +24,12 @@ var ( // queryJob is the internal struct that wraps the Query to work on, in // addition to some information about the query. type queryJob struct { - tries uint8 - index uint64 - timeout time.Duration - encoding wire.MessageEncoding - cancelChan <-chan struct{} + tries uint8 + index uint64 + timeout time.Duration + encoding wire.MessageEncoding + cancelChan <-chan struct{} + internalCancelChan <-chan struct{} *Request } @@ -128,6 +129,16 @@ func (w *worker) Run(results chan<- *jobResult, quit <-chan struct{}) { // result will be sent back. break + case <-job.internalCancelChan: + log.Tracef("Worker %v found job with index %v "+ + "already internally canceled (batch timed out)", + peer.Addr(), job.Index()) + + // We break to the below loop, where we'll check the + // internal cancel channel again and the ErrJobCanceled + // result will be sent back. + break + // We received a non-canceled query job, send it to the peer. default: log.Tracef("Worker %v queuing job %T with index %v", @@ -214,6 +225,13 @@ func (w *worker) Run(results chan<- *jobResult, quit <-chan struct{}) { jobErr = ErrJobCanceled break Loop + case <-job.internalCancelChan: + log.Tracef("Worker %v job %v internally "+ + "canceled", peer.Addr(), job.Index()) + + jobErr = ErrJobCanceled + break Loop + case <-quit: return } diff --git a/query/workmanager.go b/query/workmanager.go index 9217f49a..88e91e02 100644 --- a/query/workmanager.go +++ b/query/workmanager.go @@ -183,6 +183,7 @@ func (w *peerWorkManager) workDispatcher() { timeout <-chan time.Time rem int errChan chan error + cancelChan chan struct{} } // We set up a batch index counter to keep track of batches that still @@ -309,7 +310,20 @@ Loop: // turns out to be an error. batchNum := currentQueries[result.job.index] delete(currentQueries, result.job.index) - batch := currentBatches[batchNum] + + // In case the batch is already canceled we return + // early. + batch, ok := currentBatches[batchNum] + if !ok { + log.Warnf("Query(%d) result from peer %v "+ + "discarded with retries %d, because "+ + "batch already canceled: %v", + result.job.index, + result.peer.Addr(), + result.job.tries, result.err) + + continue Loop + } switch { // If the query ended because it was canceled, drop it. @@ -322,30 +336,34 @@ Loop: // was canceled, forward the error on the // batch's error channel. We do this since a // cancellation applies to the whole batch. - if batch != nil { - batch.errChan <- result.err - delete(currentBatches, batchNum) + batch.errChan <- result.err + delete(currentBatches, batchNum) - log.Debugf("Canceled batch %v", - batchNum) - continue Loop - } + log.Debugf("Canceled batch %v", batchNum) + continue Loop // If the query ended with any other error, put it back // into the work queue if it has not reached the // maximum number of retries. case result.err != nil: - // Punish the peer for the failed query. - w.cfg.Ranking.Punish(result.peer.Addr()) + // Refresh peer rank on disconnect. + if result.err == ErrPeerDisconnected { + w.cfg.Ranking.ResetRanking( + result.peer.Addr(), + ) + } else { + // Punish the peer for the failed query. + w.cfg.Ranking.Punish(result.peer.Addr()) + } - if batch != nil && !batch.noRetryMax { + if !batch.noRetryMax { result.job.tries++ } // Check if this query has reached its maximum // number of retries. If so, remove it from the // batch and don't reschedule it. 
- if batch != nil && !batch.noRetryMax && + if !batch.noRetryMax && result.job.tries >= batch.maxRetries { log.Warnf("Query(%d) from peer %v "+ @@ -380,11 +398,6 @@ Loop: result.job.timeout = newTimeout } - // Refresh peer rank on disconnect. - if result.err == ErrPeerDisconnected { - w.cfg.Ranking.ResetRanking(result.peer.Addr()) - } - heap.Push(work, result.job) currentQueries[result.job.index] = batchNum @@ -396,42 +409,47 @@ Loop: // Decrement the number of queries remaining in // the batch. - if batch != nil { - batch.rem-- - log.Tracef("Remaining jobs for batch "+ - "%v: %v ", batchNum, batch.rem) - - // If this was the last query in flight - // for this batch, we can notify that - // it finished, and delete it. - if batch.rem == 0 { - batch.errChan <- nil - delete(currentBatches, batchNum) - - log.Tracef("Batch %v done", - batchNum) - continue Loop - } + batch.rem-- + log.Tracef("Remaining jobs for batch "+ + "%v: %v ", batchNum, batch.rem) + + // If this was the last query in flight + // for this batch, we can notify that + // it finished, and delete it. + if batch.rem == 0 { + batch.errChan <- nil + delete(currentBatches, batchNum) + + log.Tracef("Batch %v done", + batchNum) + continue Loop } } // If the total timeout for this batch has passed, // return an error. - if batch != nil { - select { - case <-batch.timeout: - batch.errChan <- ErrQueryTimeout - delete(currentBatches, batchNum) + select { + case <-batch.timeout: + batch.errChan <- ErrQueryTimeout + delete(currentBatches, batchNum) + + // When deleting the particular batch + // number we need to make sure to cancel + // all queued and ongoing queryJobs + // to not waste resources when the batch + // call is already canceled. + if batch.cancelChan != nil { + close(batch.cancelChan) + } - log.Warnf("Query(%d) failed with "+ - "error: %v. Timing out.", - result.job.index, result.err) + log.Warnf("Query(%d) failed with "+ + "error: %v. Timing out.", + result.job.index, result.err) - log.Debugf("Batch %v timed out", - batchNum) + log.Warnf("Batch %v timed out", + batchNum) - default: - } + default: } // A new batch of queries where scheduled. @@ -442,13 +460,17 @@ Loop: log.Debugf("Adding new batch(%d) of %d queries to "+ "work queue", batchIndex, len(batch.requests)) + // Internal cancel channel of a batch request. 
+			cancelChan := make(chan struct{})
+
 			for _, q := range batch.requests {
 				heap.Push(work, &queryJob{
-					index:      queryIndex,
-					timeout:    minQueryTimeout,
-					encoding:   batch.options.encoding,
-					cancelChan: batch.options.cancelChan,
-					Request:    q,
+					index:              queryIndex,
+					timeout:            minQueryTimeout,
+					encoding:           batch.options.encoding,
+					cancelChan:         batch.options.cancelChan,
+					internalCancelChan: cancelChan,
+					Request:            q,
 				})
 				currentQueries[queryIndex] = batchIndex
 				queryIndex++
@@ -457,9 +479,12 @@
 			currentBatches[batchIndex] = &batchProgress{
 				noRetryMax: batch.options.noRetryMax,
 				maxRetries: batch.options.numRetries,
-				timeout:    time.After(batch.options.timeout),
+				timeout: time.After(
+					batch.options.timeout,
+				),
 				rem:        len(batch.requests),
 				errChan:    batch.errChan,
+				cancelChan: cancelChan,
 			}
 
 			batchIndex++
diff --git a/query/workmanager_test.go b/query/workmanager_test.go
index 13e81c76..c88b2f18 100644
--- a/query/workmanager_test.go
+++ b/query/workmanager_test.go
@@ -3,6 +3,7 @@ package query
 import (
 	"fmt"
 	"sort"
+	"sync"
 	"testing"
 	"time"
 
@@ -479,3 +480,148 @@ func TestWorkManagerWorkRankingScheduling(t *testing.T) {
 		}
 	}
 }
+
+// queryJobWithWorkerIndex records which worker a job was sent to, so that the
+// result can be signaled back on that worker's result channel.
+type queryJobWithWorkerIndex struct {
+	worker int
+	job    *queryJob
+}
+
+// mergeWorkChannels merges the nextJob channels of all the workers into a
+// single channel for better control of the concurrency during testing.
+func mergeWorkChannels(workers []*mockWorker) <-chan queryJobWithWorkerIndex {
+	var wg sync.WaitGroup
+	merged := make(chan queryJobWithWorkerIndex)
+
+	// Copy every job from a worker channel to the merged channel.
+	readFromWorker := func(input <-chan *queryJob, worker int) {
+		defer wg.Done()
+		for {
+			value, ok := <-input
+			if !ok {
+				// The channel is closed, exit the loop.
+				return
+			}
+			merged <- queryJobWithWorkerIndex{
+				worker: worker,
+				job:    value,
+			}
+		}
+	}
+
+	// Start a goroutine for each worker channel.
+	wg.Add(len(workers))
+	for i, work := range workers {
+		go readFromWorker(work.nextJob, i)
+	}
+
+	// Wait for all copying to be done, then close the merged channel.
+	go func() {
+		wg.Wait()
+		close(merged)
+	}()
+
+	return merged
+}
+
+// TestWorkManagerTimeOutBatch tests that as soon as a batch times out, all
+// ongoing queries already registered with workers, as well as the queued-up
+// ones, are canceled.
+func TestWorkManagerTimeOutBatch(t *testing.T) {
+	const numQueries = 100
+	const numWorkers = 10
+
+	// Start the workDispatcher goroutine.
+	wm, workers := startWorkManager(t, numWorkers)
+
+	// mergeChan receives, one at a time, all the queryJobs that are sent
+	// to the registered workers.
+	mergeChan := mergeWorkChannels(workers)
+
+	// activeQueries are the queries currently registered with the workers.
+	var activeQueries []queryJobWithWorkerIndex
+
+	// Schedule a batch of queries.
+	var queries []*Request
+	for i := 0; i < numQueries; i++ {
+		q := &Request{}
+		queries = append(queries, q)
+	}
+
+	// Send the batch query consisting of numQueries requests, together
+	// with a timeout for the whole batch.
+	//
+	// NOTE: We let the batch time out to simulate a slow peer connection
+	// and make sure we cancel all ongoing queries, including the ones
+	// which are still queued up.
+	errChan := wm.Query(queries, Timeout(1*time.Second))
+
+	// Send a query to every active worker.
+	for i := 0; i < numWorkers; i++ {
+		select {
+		case jobQuery := <-mergeChan:
+			activeQueries = append(activeQueries, jobQuery)
+		case <-errChan:
+			t.Fatalf("did not expect a message on errChan")
+		case <-time.After(5 * time.Second):
+			t.Fatalf("next job not received")
+		}
+	}
+
+	// We wait for the batch timeout to be exceeded before sending the
+	// result for one of the queries.
+	time.Sleep(2 * time.Second)
+
+	// We need to signal a result for one of the active workers so that
+	// the batch timeout is triggered.
+	workerIndex := activeQueries[0].worker
+	workers[workerIndex].results <- &jobResult{
+		job: activeQueries[0].job,
+		err: nil,
+	}
+
+	// As soon as the batch times out, an error is sent via the errChan.
+	select {
+	case err := <-errChan:
+		require.ErrorIs(t, err, ErrQueryTimeout)
+	case <-time.After(time.Second):
+		t.Fatalf("expected the errChan to signal")
+	}
+
+	// The cancelChan is closed when the batch times out, so all the
+	// ongoing queries are canceled as well.
+	for i := 1; i < numWorkers; i++ {
+		job := activeQueries[i].job
+		select {
+		case <-job.internalCancelChan:
+			workers[i].results <- &jobResult{
+				job: job,
+				err: nil,
+			}
+		case <-time.After(time.Second):
+			t.Fatalf("expected the cancelChan to close")
+		}
+	}
+
+	// Also make sure that all the queued queries for this batch are
+	// canceled as well.
+	for i := numWorkers; i < numQueries; i++ {
+		select {
+		case res := <-mergeChan:
+			job := res.job
+			workerIndex := res.worker
+			select {
+			case <-job.internalCancelChan:
+				workers[workerIndex].results <- &jobResult{
+					job: job,
+					err: nil,
+				}
+			case <-time.After(time.Second):
+				t.Fatalf("expected the cancelChan to close")
+			}
+		case <-time.After(time.Second):
+			t.Fatalf("next job not received")
+		}
+	}
+}
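// Editor's note: the snippet below is an illustrative, self-contained sketch
// and is NOT part of the diff above. It shows the close-to-broadcast pattern
// that the new internalCancelChan relies on: the dispatcher closes a single
// shared channel per batch, and every worker selecting on that channel
// observes the cancellation, whether its job is still queued or already in
// flight. All names here (job, runWorker, the dispatcher loop in main) are
// hypothetical stand-ins, not identifiers from the query package.
package main

import (
	"fmt"
	"sync"
	"time"
)

// job carries the per-batch cancel channel, mirroring how queryJob carries
// internalCancelChan in the diff.
type job struct {
	index      int
	cancelChan <-chan struct{}
}

// runWorker drains jobs and checks the batch cancel channel for every job it
// picks up, so canceled work is skipped instead of being executed.
func runWorker(id int, jobs <-chan job, wg *sync.WaitGroup) {
	defer wg.Done()

	for j := range jobs {
		select {
		case <-j.cancelChan:
			// The batch was canceled, skip the job.
			fmt.Printf("worker %d: job %d canceled\n", id, j.index)

		case <-time.After(10 * time.Millisecond):
			// Simulate the job completing normally.
			fmt.Printf("worker %d: job %d done\n", id, j.index)
		}
	}
}

func main() {
	cancelChan := make(chan struct{})
	jobs := make(chan job)

	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go runWorker(i, jobs, &wg)
	}

	// Queue a batch of jobs that all share the same cancel channel.
	go func() {
		defer close(jobs)
		for i := 0; i < 10; i++ {
			jobs <- job{index: i, cancelChan: cancelChan}
		}
	}()

	// Simulate the batch timeout firing: a single close is observed by
	// every in-flight and still-queued job.
	time.Sleep(25 * time.Millisecond)
	close(cancelChan)

	wg.Wait()
}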