Skip to content

Commit

Permalink
Improve cross cluster components shutdown logic (#4662)
Browse files Browse the repository at this point in the history
  • Loading branch information
yycptt authored Dec 1, 2021
1 parent 624a1fc commit c894177
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 11 deletions.
34 changes: 29 additions & 5 deletions service/history/task/cross_cluster_task_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ type (
crossClusterTaskProcessors []*crossClusterTaskProcessor

crossClusterTaskProcessor struct {
ctx context.Context
ctxCancel context.CancelFunc
shard shard.Context
taskProcessor Processor
taskExecutor Executor
Expand Down Expand Up @@ -117,6 +119,7 @@ func newCrossClusterTaskProcessor(
taskFetcher Fetcher,
options *CrossClusterTaskProcessorOptions,
) *crossClusterTaskProcessor {
ctx, cancel := context.WithCancel(context.Background())
sourceCluster := taskFetcher.GetSourceCluster()
logger := shard.GetLogger().WithTags(
tag.ComponentCrossClusterTaskProcessor,
Expand All @@ -130,6 +133,8 @@ func newCrossClusterTaskProcessor(
retryPolicy.SetMaximumInterval(time.Second)
retryPolicy.SetExpirationInterval(options.TaskWaitInterval())
return &crossClusterTaskProcessor{
ctx: ctx,
ctxCancel: cancel,
shard: shard,
taskProcessor: taskProcessor,
taskExecutor: NewCrossClusterTargetTaskExecutor(
Expand Down Expand Up @@ -190,6 +195,7 @@ func (p *crossClusterTaskProcessor) Stop() {
}

close(p.shutdownCh)
p.ctxCancel()
p.redispatcher.Stop()

if success := common.AwaitWaitGroup(&p.shutdownWG, time.Minute); !success {
Expand Down Expand Up @@ -219,9 +225,13 @@ func (p *crossClusterTaskProcessor) processLoop() {
sw := p.metricsScope.StartTimer(metrics.CrossClusterFetchLatency)

var taskRequests []*types.CrossClusterTaskRequest
err := p.taskFetcher.Fetch(p.shard.GetShardID()).Get(context.Background(), &taskRequests)
err := p.taskFetcher.Fetch(p.shard.GetShardID()).Get(p.ctx, &taskRequests)
sw.Stop()
if err != nil {
if err == errTaskFetcherShutdown {
return
}

p.logger.Error("Unable to fetch cross cluster tasks", tag.Error(err))
if common.IsServiceBusyError(err) {
p.metricsScope.IncCounter(metrics.CrossClusterFetchServiceBusyFailures)
Expand All @@ -231,6 +241,7 @@ func (p *crossClusterTaskProcessor) processLoop() {
))
} else {
p.metricsScope.IncCounter(metrics.CrossClusterFetchFailures)
// note we rely on the aggregation interval in task fetcher as the backoff
}
continue
}
Expand Down Expand Up @@ -273,7 +284,7 @@ func (p *crossClusterTaskProcessor) processTaskRequests(
TargetCluster: p.shard.GetClusterMetadata().GetCurrentClusterName(),
FetchNewTasks: p.numPendingTasks() < p.options.MaxPendingTasks(),
}
taskWaitContext, cancel := context.WithTimeout(context.Background(), p.options.TaskWaitInterval())
taskWaitContext, cancel := context.WithTimeout(p.ctx, p.options.TaskWaitInterval())
deadlineExceeded := false
for taskID, taskFuture := range taskFutures {
if deadlineExceeded && !taskFuture.IsReady() {
Expand All @@ -282,6 +293,13 @@ func (p *crossClusterTaskProcessor) processTaskRequests(

var taskResponse types.CrossClusterTaskResponse
if err := taskFuture.Get(taskWaitContext, &taskResponse); err != nil {
if p.ctx.Err() != nil {
// root context is no longer valid, component is being shut down,
// we can return directly
cancel()
return
}

if err == context.DeadlineExceeded {
// switch to a valid context here, otherwise Get() will always return an error.
// using context.Background() is fine since we will only be calling Get() with it
Expand Down Expand Up @@ -361,7 +379,13 @@ func (p *crossClusterTaskProcessor) respondPendingTaskLoop() {
for taskID, taskFuture := range p.pendingTasks {
if taskFuture.IsReady() {
var taskResponse types.CrossClusterTaskResponse
if err := taskFuture.Get(context.Background(), &taskResponse); err != nil {
if err := taskFuture.Get(p.ctx, &taskResponse); err != nil {
if p.ctx.Err() != nil {
// we are in shutdown logic
p.taskLock.Unlock()
return
}

// this case should not happen,
// task failure should be converted to FailCause in the response by the processing logic
taskResponse = types.CrossClusterTaskResponse{
Expand Down Expand Up @@ -433,7 +457,7 @@ func (p *crossClusterTaskProcessor) respondTaskCompletedWithRetry(

var response *types.RespondCrossClusterTasksCompletedResponse
op := func() error {
ctx, cancel := context.WithTimeout(context.Background(), respondCrossClusterTaskTimeout)
ctx, cancel := context.WithTimeout(p.ctx, respondCrossClusterTaskTimeout)
defer cancel()
var err error
response, err = p.sourceAdminClient.RespondCrossClusterTasksCompleted(ctx, request)
Expand All @@ -443,7 +467,7 @@ func (p *crossClusterTaskProcessor) respondTaskCompletedWithRetry(
}
return err
}
err := p.throttleRetry.Do(context.Background(), op)
err := p.throttleRetry.Do(p.ctx, op)

return response, err
}
Expand Down
26 changes: 20 additions & 6 deletions service/history/task/fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ type (
}

fetchTaskFunc func(
ctx context.Context,
clientBean client.Bean,
sourceCluster string,
currentCluster string,
Expand All @@ -76,7 +77,9 @@ type (
shutdownCh chan struct{}
requestChan chan fetchRequest

fetchTaskFunc fetchTaskFunc
fetchCtx context.Context
fetchCtxCancel context.CancelFunc
fetchTaskFunc fetchTaskFunc
}
)

Expand Down Expand Up @@ -111,6 +114,7 @@ func NewCrossClusterTaskFetchers(
}

func crossClusterTaskFetchFn(
ctx context.Context,
clientBean client.Bean,
sourceCluster string,
currentCluster string,
Expand All @@ -128,7 +132,7 @@ func crossClusterTaskFetchFn(
ShardIDs: shardIDs,
TargetCluster: currentCluster,
}
ctx, cancel := context.WithTimeout(context.Background(), defaultFetchTimeout)
ctx, cancel := context.WithTimeout(ctx, defaultFetchTimeout)
defer cancel()
resp, err := adminClient.GetCrossClusterTasks(ctx, request)
if err != nil {
Expand Down Expand Up @@ -199,6 +203,7 @@ func newTaskFetcher(
metricsClient metrics.Client,
logger log.Logger,
) *fetcherImpl {
fetchCtx, fetchCtxCancel := context.WithCancel(context.Background())
return &fetcherImpl{
status: common.DaemonStatusInitialized,
currentCluster: currentCluster,
Expand All @@ -213,9 +218,11 @@ func newTaskFetcher(
tag.ComponentCrossClusterTaskFetcher,
tag.SourceCluster(sourceCluster),
),
shutdownCh: make(chan struct{}),
requestChan: make(chan fetchRequest, defaultRequestChanBufferSize),
fetchTaskFunc: fetchTaskFunc,
shutdownCh: make(chan struct{}),
requestChan: make(chan fetchRequest, defaultRequestChanBufferSize),
fetchCtx: fetchCtx,
fetchCtxCancel: fetchCtxCancel,
fetchTaskFunc: fetchTaskFunc,
}
}

Expand All @@ -239,6 +246,7 @@ func (f *fetcherImpl) Stop() {
}

close(f.shutdownCh)
f.fetchCtxCancel()
if success := common.AwaitWaitGroup(&f.shutdownWG, time.Minute); !success {
f.logger.Warn("Task fetcher timedout on shutdown.", tag.LifeCycleStopTimedout)
}
Expand Down Expand Up @@ -338,7 +346,13 @@ func (f *fetcherImpl) fetch(
sw := f.metricsScope.StartTimer(metrics.CrossClusterFetchLatency)
defer sw.Stop()

responseByShard, err := f.fetchTaskFunc(f.clientBean, f.sourceCluster, f.currentCluster, outstandingRequests)
responseByShard, err := f.fetchTaskFunc(
f.fetchCtx,
f.clientBean,
f.sourceCluster,
f.currentCluster,
outstandingRequests,
)
if err != nil {
return err
}
Expand Down
1 change: 1 addition & 0 deletions service/history/task/fetcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ func (s *fetcherSuite) TestAggregator() {
}

func (s *fetcherSuite) testFetchTaskFn(
ctx context.Context,
clientBean client.Bean,
sourceCluster string,
currentCluster string,
Expand Down

0 comments on commit c894177

Please sign in to comment.