Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve cross cluster components shutdown logic #4662

Merged
merged 3 commits into from
Dec 1, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions service/history/task/cross_cluster_task_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ type (
crossClusterTaskProcessors []*crossClusterTaskProcessor

crossClusterTaskProcessor struct {
ctx context.Context
ctxCancel context.CancelFunc
shard shard.Context
taskProcessor Processor
taskExecutor Executor
Expand Down Expand Up @@ -117,6 +119,7 @@ func newCrossClusterTaskProcessor(
taskFetcher Fetcher,
options *CrossClusterTaskProcessorOptions,
) *crossClusterTaskProcessor {
ctx, cancel := context.WithCancel(context.Background())
sourceCluster := taskFetcher.GetSourceCluster()
logger := shard.GetLogger().WithTags(
tag.ComponentCrossClusterTaskProcessor,
Expand All @@ -130,6 +133,8 @@ func newCrossClusterTaskProcessor(
retryPolicy.SetMaximumInterval(time.Second)
retryPolicy.SetExpirationInterval(options.TaskWaitInterval())
return &crossClusterTaskProcessor{
ctx: ctx,
ctxCancel: cancel,
shard: shard,
taskProcessor: taskProcessor,
taskExecutor: NewCrossClusterTargetTaskExecutor(
Expand Down Expand Up @@ -190,6 +195,7 @@ func (p *crossClusterTaskProcessor) Stop() {
}

close(p.shutdownCh)
p.ctxCancel()
p.redispatcher.Stop()

if success := common.AwaitWaitGroup(&p.shutdownWG, time.Minute); !success {
Expand Down Expand Up @@ -219,9 +225,13 @@ func (p *crossClusterTaskProcessor) processLoop() {
sw := p.metricsScope.StartTimer(metrics.CrossClusterFetchLatency)

var taskRequests []*types.CrossClusterTaskRequest
err := p.taskFetcher.Fetch(p.shard.GetShardID()).Get(context.Background(), &taskRequests)
yycptt marked this conversation as resolved.
Show resolved Hide resolved
err := p.taskFetcher.Fetch(p.shard.GetShardID()).Get(p.ctx, &taskRequests)
sw.Stop()
if err != nil {
if err == errTaskFetcherShutdown {
return
}

p.logger.Error("Unable to fetch cross cluster tasks", tag.Error(err))
if common.IsServiceBusyError(err) {
p.metricsScope.IncCounter(metrics.CrossClusterFetchServiceBusyFailures)
Expand All @@ -231,6 +241,7 @@ func (p *crossClusterTaskProcessor) processLoop() {
))
} else {
p.metricsScope.IncCounter(metrics.CrossClusterFetchFailures)
// note we rely on the aggregation interval in task fetcher as the backoff
}
continue
}
Expand Down Expand Up @@ -273,7 +284,7 @@ func (p *crossClusterTaskProcessor) processTaskRequests(
TargetCluster: p.shard.GetClusterMetadata().GetCurrentClusterName(),
FetchNewTasks: p.numPendingTasks() < p.options.MaxPendingTasks(),
}
taskWaitContext, cancel := context.WithTimeout(context.Background(), p.options.TaskWaitInterval())
taskWaitContext, cancel := context.WithTimeout(p.ctx, p.options.TaskWaitInterval())
deadlineExceeded := false
for taskID, taskFuture := range taskFutures {
if deadlineExceeded && !taskFuture.IsReady() {
Expand Down Expand Up @@ -303,6 +314,10 @@ func (p *crossClusterTaskProcessor) processTaskRequests(
}
cancel()

if p.ctx.Err() != nil {
return
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this be better above, in the loop? Or is that unsafe to interrupt? (Though if anything under it checks the context, it may still be interrupted.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. I don't have a strong opinion on this. Moving it into the for loop also works and can probably save us some iterations during the shutdown.

I was thinking the for loop is already non-blocking once the context is cancelled on shutdown, and we just need one check before sending the response back instead of checking it in each iteration.


successfullyRespondedTaskIDs := make(map[int64]struct{})
var respondResponse *types.RespondCrossClusterTasksCompletedResponse
var respondErr error
Expand Down Expand Up @@ -361,7 +376,13 @@ func (p *crossClusterTaskProcessor) respondPendingTaskLoop() {
for taskID, taskFuture := range p.pendingTasks {
if taskFuture.IsReady() {
var taskResponse types.CrossClusterTaskResponse
if err := taskFuture.Get(context.Background(), &taskResponse); err != nil {
if err := taskFuture.Get(p.ctx, &taskResponse); err != nil {
if p.ctx.Err() != nil {
// we are in shutdown logic
p.taskLock.Unlock()
return
}

// this case should not happen,
// task failure should be converted to FailCause in the response by the processing logic
taskResponse = types.CrossClusterTaskResponse{
Expand Down Expand Up @@ -433,7 +454,7 @@ func (p *crossClusterTaskProcessor) respondTaskCompletedWithRetry(

var response *types.RespondCrossClusterTasksCompletedResponse
op := func() error {
ctx, cancel := context.WithTimeout(context.Background(), respondCrossClusterTaskTimeout)
ctx, cancel := context.WithTimeout(p.ctx, respondCrossClusterTaskTimeout)
defer cancel()
var err error
response, err = p.sourceAdminClient.RespondCrossClusterTasksCompleted(ctx, request)
Expand All @@ -443,7 +464,7 @@ func (p *crossClusterTaskProcessor) respondTaskCompletedWithRetry(
}
return err
}
err := p.throttleRetry.Do(context.Background(), op)
err := p.throttleRetry.Do(p.ctx, op)

return response, err
}
Expand Down
26 changes: 20 additions & 6 deletions service/history/task/fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ type (
}

fetchTaskFunc func(
ctx context.Context,
clientBean client.Bean,
sourceCluster string,
currentCluster string,
Expand All @@ -76,7 +77,9 @@ type (
shutdownCh chan struct{}
requestChan chan fetchRequest

fetchTaskFunc fetchTaskFunc
fetchCtx context.Context
fetchCtxCancel context.CancelFunc
fetchTaskFunc fetchTaskFunc
}
)

Expand Down Expand Up @@ -111,6 +114,7 @@ func NewCrossClusterTaskFetchers(
}

func crossClusterTaskFetchFn(
ctx context.Context,
clientBean client.Bean,
sourceCluster string,
currentCluster string,
Expand All @@ -128,7 +132,7 @@ func crossClusterTaskFetchFn(
ShardIDs: shardIDs,
TargetCluster: currentCluster,
}
ctx, cancel := context.WithTimeout(context.Background(), defaultFetchTimeout)
ctx, cancel := context.WithTimeout(ctx, defaultFetchTimeout)
defer cancel()
resp, err := adminClient.GetCrossClusterTasks(ctx, request)
if err != nil {
Expand Down Expand Up @@ -199,6 +203,7 @@ func newTaskFetcher(
metricsClient metrics.Client,
logger log.Logger,
) *fetcherImpl {
fetchCtx, fetchCtxCancel := context.WithCancel(context.Background())
return &fetcherImpl{
status: common.DaemonStatusInitialized,
currentCluster: currentCluster,
Expand All @@ -213,9 +218,11 @@ func newTaskFetcher(
tag.ComponentCrossClusterTaskFetcher,
tag.SourceCluster(sourceCluster),
),
shutdownCh: make(chan struct{}),
requestChan: make(chan fetchRequest, defaultRequestChanBufferSize),
fetchTaskFunc: fetchTaskFunc,
shutdownCh: make(chan struct{}),
requestChan: make(chan fetchRequest, defaultRequestChanBufferSize),
fetchCtx: fetchCtx,
fetchCtxCancel: fetchCtxCancel,
fetchTaskFunc: fetchTaskFunc,
}
}

Expand All @@ -239,6 +246,7 @@ func (f *fetcherImpl) Stop() {
}

close(f.shutdownCh)
f.fetchCtxCancel()
if success := common.AwaitWaitGroup(&f.shutdownWG, time.Minute); !success {
f.logger.Warn("Task fetcher timedout on shutdown.", tag.LifeCycleStopTimedout)
}
Expand Down Expand Up @@ -338,7 +346,13 @@ func (f *fetcherImpl) fetch(
sw := f.metricsScope.StartTimer(metrics.CrossClusterFetchLatency)
defer sw.Stop()

responseByShard, err := f.fetchTaskFunc(f.clientBean, f.sourceCluster, f.currentCluster, outstandingRequests)
responseByShard, err := f.fetchTaskFunc(
f.fetchCtx,
f.clientBean,
f.sourceCluster,
f.currentCluster,
outstandingRequests,
)
if err != nil {
return err
}
Expand Down
1 change: 1 addition & 0 deletions service/history/task/fetcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ func (s *fetcherSuite) TestAggregator() {
}

func (s *fetcherSuite) testFetchTaskFn(
ctx context.Context,
clientBean client.Bean,
sourceCluster string,
currentCluster string,
Expand Down