Update replication fetcher lifecycle (#2421)
* Update replication fetcher lifecycle
yux0 authored Jan 27, 2022
1 parent 735e057 commit b042650
Showing 3 changed files with 55 additions and 30 deletions.
client/clientBean.go — 49 changes: 38 additions & 11 deletions
@@ -41,6 +41,8 @@ import (
"go.temporal.io/server/common/cluster"
)

const clientBeanCallbackID = "clientBean"

type (
// Bean is a collection of clients
Bean interface {
@@ -57,10 +59,10 @@ type (

clientBeanImpl struct {
sync.Mutex
historyClient historyservice.HistoryServiceClient
matchingClient atomic.Value
clusterMetdata cluster.Metadata
factory Factory
historyClient historyservice.HistoryServiceClient
matchingClient atomic.Value
clusterMetadata cluster.Metadata
factory Factory

remoteAdminClientsLock sync.RWMutex
remoteAdminClients map[string]adminservice.AdminServiceClient
@@ -101,13 +103,38 @@ func NewClientBean(factory Factory, clusterMetadata cluster.Metadata) (Bean, err
remoteFrontendClients[clusterName] = remoteFrontendClient
}

return &clientBeanImpl{
bean := &clientBeanImpl{
factory: factory,
historyClient: historyClient,
clusterMetdata: clusterMetadata,
clusterMetadata: clusterMetadata,
remoteAdminClients: remoteAdminClients,
remoteFrontendClients: remoteFrontendClients,
}, nil
}
bean.registerClientEviction()
return bean, nil
}

func (h *clientBeanImpl) registerClientEviction() {
currentCluster := h.clusterMetadata.GetCurrentClusterName()
h.clusterMetadata.RegisterMetadataChangeCallback(
clientBeanCallbackID,
func(oldClusterMetadata map[string]*cluster.ClusterInformation, newClusterMetadata map[string]*cluster.ClusterInformation) {
for clusterName := range newClusterMetadata {
if clusterName == currentCluster {
continue
}
h.remoteAdminClientsLock.Lock()
if _, ok := h.remoteAdminClients[clusterName]; ok {
delete(h.remoteAdminClients, clusterName)
}
h.remoteAdminClientsLock.Unlock()
h.remoteFrontendClientsLock.Lock()
if _, ok := h.remoteFrontendClients[clusterName]; ok {
delete(h.remoteFrontendClients, clusterName)
}
h.remoteFrontendClientsLock.Unlock()
}
})
}

func (h *clientBeanImpl) GetHistoryClient() historyservice.HistoryServiceClient {
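For readers skimming the hunk above: registerClientEviction hooks the client bean into cluster-metadata change notifications, so cached remote admin/frontend clients are dropped whenever a remote cluster's metadata changes. Below is a minimal, self-contained sketch of that eviction pattern; the metadataNotifier and clientCache types are illustrative stand-ins, not Temporal's cluster.Metadata or client.Bean APIs.

```go
package main

import (
	"fmt"
	"sync"
)

// clusterInfo is a stand-in for cluster.ClusterInformation.
type clusterInfo struct{ Enabled bool }

// metadataNotifier is a hypothetical stand-in for the callback registration
// offered by cluster metadata; only the shape matters here.
type metadataNotifier struct {
	callbacks map[string]func(oldMeta, newMeta map[string]*clusterInfo)
}

func (n *metadataNotifier) register(id string, cb func(oldMeta, newMeta map[string]*clusterInfo)) {
	if n.callbacks == nil {
		n.callbacks = make(map[string]func(oldMeta, newMeta map[string]*clusterInfo))
	}
	n.callbacks[id] = cb
}

func (n *metadataNotifier) notify(oldMeta, newMeta map[string]*clusterInfo) {
	for _, cb := range n.callbacks {
		cb(oldMeta, newMeta)
	}
}

// clientCache mimics the remote-client maps held by clientBeanImpl.
type clientCache struct {
	mu      sync.RWMutex
	clients map[string]string // cluster name -> placeholder "client"
	current string
}

// registerEviction mirrors registerClientEviction: whenever cluster metadata
// changes, drop the cached client for every remote cluster named in the
// update so the next lookup rebuilds it from fresh metadata.
func (c *clientCache) registerEviction(n *metadataNotifier) {
	n.register("clientBean", func(_, newMeta map[string]*clusterInfo) {
		for name := range newMeta {
			if name == c.current {
				continue
			}
			c.mu.Lock()
			delete(c.clients, name)
			c.mu.Unlock()
		}
	})
}

func main() {
	cache := &clientCache{
		clients: map[string]string{"cluster-b": "conn-1"},
		current: "cluster-a",
	}
	notifier := &metadataNotifier{}
	cache.registerEviction(notifier)

	// A metadata update that mentions cluster-b evicts its cached client;
	// the current cluster is always skipped.
	notifier.notify(nil, map[string]*clusterInfo{"cluster-b": {Enabled: true}})
	fmt.Println(cache.clients) // map[]
}
```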
@@ -134,13 +161,13 @@ func (h *clientBeanImpl) SetMatchingClient(
}

func (h *clientBeanImpl) GetFrontendClient() workflowservice.WorkflowServiceClient {
return h.remoteFrontendClients[h.clusterMetdata.GetCurrentClusterName()]
return h.remoteFrontendClients[h.clusterMetadata.GetCurrentClusterName()]
}

func (h *clientBeanImpl) SetFrontendClient(
client workflowservice.WorkflowServiceClient,
) {
h.remoteFrontendClients[h.clusterMetdata.GetCurrentClusterName()] = client
h.remoteFrontendClients[h.clusterMetadata.GetCurrentClusterName()] = client
}

func (h *clientBeanImpl) GetRemoteAdminClient(cluster string) adminservice.AdminServiceClient {
@@ -149,7 +176,7 @@ func (h *clientBeanImpl) GetRemoteAdminClient(cluster string) adminservice.Admin
h.remoteAdminClientsLock.RUnlock()

if !ok {
clusterInfo, clusterFound := h.clusterMetdata.GetAllClusterInfo()[cluster]
clusterInfo, clusterFound := h.clusterMetadata.GetAllClusterInfo()[cluster]
if !clusterFound {
panic(fmt.Sprintf(
"Unknown cluster name: %v with given cluster information map: %v.",
@@ -189,7 +216,7 @@ func (h *clientBeanImpl) GetRemoteFrontendClient(cluster string) workflowservice
h.remoteFrontendClientsLock.RUnlock()

if !ok {
clusterInfo, clusterFound := h.clusterMetdata.GetAllClusterInfo()[cluster]
clusterInfo, clusterFound := h.clusterMetadata.GetAllClusterInfo()[cluster]
if !clusterFound {
panic(fmt.Sprintf(
"Unknown cluster name: %v with given cluster information map: %v.",
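The two GetRemoteAdminClient/GetRemoteFrontendClient hunks above keep the existing read-through behavior: after eviction, the next lookup misses the cache, re-reads cluster metadata, and builds a fresh client (the real code panics on an unknown cluster name; the sketch returns an error instead). A rough, self-contained sketch under those assumptions — the metadata and dial fields stand in for GetAllClusterInfo and the client factory:

```go
package main

import (
	"fmt"
	"sync"
)

// clusterInfo is a stand-in for cluster.ClusterInformation; only the RPC
// address matters for this sketch.
type clusterInfo struct {
	Enabled    bool
	RPCAddress string
}

// readThroughCache sketches the GetRemoteAdminClient path: serve from the
// cache when possible, otherwise re-read cluster metadata and build a client.
type readThroughCache struct {
	mu       sync.RWMutex
	clients  map[string]string             // cluster name -> placeholder client
	metadata func() map[string]clusterInfo // stand-in for GetAllClusterInfo
	dial     func(addr string) string      // stand-in for the factory call
}

func (c *readThroughCache) get(cluster string) (string, error) {
	c.mu.RLock()
	client, ok := c.clients[cluster]
	c.mu.RUnlock()
	if ok {
		return client, nil
	}

	info, found := c.metadata()[cluster]
	if !found {
		// The real code panics here; an error keeps the sketch tame.
		return "", fmt.Errorf("unknown cluster name: %v", cluster)
	}

	c.mu.Lock()
	defer c.mu.Unlock()
	if client, ok = c.clients[cluster]; !ok { // re-check under the write lock
		client = c.dial(info.RPCAddress)
		c.clients[cluster] = client
	}
	return client, nil
}

func main() {
	cache := &readThroughCache{
		clients: map[string]string{},
		metadata: func() map[string]clusterInfo {
			return map[string]clusterInfo{"cluster-b": {Enabled: true, RPCAddress: "b:7233"}}
		},
		dial: func(addr string) string { return "conn->" + addr },
	}
	fmt.Println(cache.get("cluster-b")) // conn->b:7233 <nil>
	fmt.Println(cache.get("cluster-x")) // unknown cluster name: cluster-x
}
```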
service/history/replicationTaskFetcher.go — 30 changes: 14 additions & 16 deletions
@@ -87,7 +87,6 @@ type (
config *configs.Config
numWorker int
logger log.Logger
remotePeer adminservice.AdminServiceClient
rateLimiter quotas.RateLimiter
requestChan chan *replicationTaskRequest
shutdownChan chan struct{}
@@ -101,7 +100,7 @@ type (
sourceCluster string
config *configs.Config
logger log.Logger
remotePeer adminservice.AdminServiceClient
clientBean client.Bean
rateLimiter quotas.RateLimiter
requestChan chan *replicationTaskRequest
shutdownChan chan struct{}
@@ -175,13 +174,12 @@ func (f *ReplicationTaskFetchersImpl) GetOrCreateFetcher(clusterName string) Rep

func (f *ReplicationTaskFetchersImpl) createReplicationFetcherLocked(clusterName string) ReplicationTaskFetcher {
currentCluster := f.clusterMetadata.GetCurrentClusterName()
remoteAdminClient := f.clientBean.GetRemoteAdminClient(clusterName)
fetcher := newReplicationTaskFetcher(
f.logger,
clusterName,
currentCluster,
f.config,
remoteAdminClient,
f.clientBean,
)
fetcher.Start()
f.fetchers[clusterName] = fetcher
@@ -196,16 +194,16 @@ func (f *ReplicationTaskFetchersImpl) listenClusterMetadataChange() {
f.fetchersLock.Lock()
defer f.fetchersLock.Unlock()

for clusterName := range newClusterMetadata {
// Fetchers are initialized lazily, so the callback only needs to handle the removal case.
for clusterName, newClusterInfo := range newClusterMetadata {
if clusterName == currentCluster {
continue
}
if fetcher, ok := f.fetchers[clusterName]; ok {
fetcher.Stop()
delete(f.fetchers, clusterName)
}
if clusterInfo := newClusterMetadata[clusterName]; clusterInfo != nil && clusterInfo.Enabled {
f.createReplicationFetcherLocked(clusterName)
if newClusterInfo == nil || !newClusterInfo.Enabled {
fetcher.Stop()
delete(f.fetchers, clusterName)
}
}
}
},
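The rewritten listener above leans on the fact that fetchers are created lazily by GetOrCreateFetcher, so the metadata callback only has to stop and drop a fetcher whose source cluster is removed or disabled. A compressed sketch of that lifecycle, with a trivial fetcher type standing in for ReplicationTaskFetcher:

```go
package main

import (
	"fmt"
	"sync"
)

type clusterInfo struct{ Enabled bool }

// fetcher is a trivial stand-in for ReplicationTaskFetcher.
type fetcher struct{ cluster string }

func (f *fetcher) Stop() { fmt.Println("stopped fetcher for", f.cluster) }

type fetchers struct {
	mu        sync.Mutex
	current   string
	byCluster map[string]*fetcher
}

// getOrCreate mirrors GetOrCreateFetcher: a fetcher is built on first use for
// a given source cluster.
func (fs *fetchers) getOrCreate(cluster string) *fetcher {
	fs.mu.Lock()
	defer fs.mu.Unlock()
	if f, ok := fs.byCluster[cluster]; ok {
		return f
	}
	f := &fetcher{cluster: cluster}
	fs.byCluster[cluster] = f
	return f
}

// onMetadataChange mirrors the rewritten listener: because creation is lazy,
// the callback only handles removal — stop a fetcher whose cluster entry is
// missing or disabled in the new metadata.
func (fs *fetchers) onMetadataChange(newMeta map[string]*clusterInfo) {
	fs.mu.Lock()
	defer fs.mu.Unlock()
	for cluster, info := range newMeta {
		if cluster == fs.current {
			continue
		}
		if f, ok := fs.byCluster[cluster]; ok && (info == nil || !info.Enabled) {
			f.Stop()
			delete(fs.byCluster, cluster)
		}
	}
}

func main() {
	fs := &fetchers{current: "cluster-a", byCluster: map[string]*fetcher{}}
	fs.getOrCreate("cluster-b")
	// Disabling cluster-b in a metadata update stops and removes its fetcher;
	// it will be re-created lazily if the cluster is enabled again.
	fs.onMetadataChange(map[string]*clusterInfo{"cluster-b": {Enabled: false}})
	fmt.Println(len(fs.byCluster)) // 0
}
```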
@@ -218,7 +216,7 @@ func newReplicationTaskFetcher(
sourceCluster string,
currentCluster string,
config *configs.Config,
sourceFrontend adminservice.AdminServiceClient,
clientBean client.Bean,
) *ReplicationTaskFetcherImpl {
numWorker := config.ReplicationTaskFetcherParallelism()
requestChan := make(chan *replicationTaskRequest, requestChanBufferSize)
@@ -234,7 +232,7 @@ func newReplicationTaskFetcher(
sourceCluster,
currentCluster,
config,
sourceFrontend,
clientBean,
rateLimiter,
requestChan,
shutdownChan,
@@ -246,7 +244,6 @@ func newReplicationTaskFetcher(
config: config,
numWorker: numWorker,
logger: log.With(logger, tag.ClusterName(sourceCluster)),
remotePeer: sourceFrontend,
currentCluster: currentCluster,
sourceCluster: sourceCluster,
rateLimiter: rateLimiter,
@@ -309,7 +306,7 @@ func newReplicationTaskFetcherWorker(
sourceCluster string,
currentCluster string,
config *configs.Config,
sourceFrontend adminservice.AdminServiceClient,
clientBean client.Bean,
rateLimiter quotas.RateLimiter,
requestChan chan *replicationTaskRequest,
shutdownChan chan struct{},
@@ -320,7 +317,7 @@ func newReplicationTaskFetcherWorker(
sourceCluster: sourceCluster,
config: config,
logger: logger,
remotePeer: sourceFrontend,
clientBean: clientBean,
rateLimiter: rateLimiter,
requestChan: requestChan,
shutdownChan: shutdownChan,
@@ -426,7 +423,8 @@ func (f *replicationTaskFetcherWorker) getMessages() error {
Tokens: tokens,
ClusterName: f.currentCluster,
}
response, err := f.remotePeer.GetReplicationMessages(ctx, request)
remoteClient := f.clientBean.GetRemoteAdminClient(f.sourceCluster)
response, err := remoteClient.GetReplicationMessages(ctx, request)
if err != nil {
f.logger.Error("Failed to get replication tasks", tag.Error(err))
for _, req := range requestByShard {
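With the cached remotePeer field gone, each poll in getMessages resolves the remote admin client through the client bean at call time, so an evicted and re-created client is picked up on the next request. A minimal sketch of that indirection, using placeholder interfaces rather than the real client.Bean and adminservice types:

```go
package main

import (
	"context"
	"fmt"
)

// adminClient is a placeholder for adminservice.AdminServiceClient.
type adminClient interface {
	GetReplicationMessages(ctx context.Context, cluster string) (string, error)
}

// clientBean is a placeholder for client.Bean.
type clientBean interface {
	GetRemoteAdminClient(cluster string) adminClient
}

// worker keeps only the bean and the source cluster name, mirroring the
// updated replicationTaskFetcherWorker.
type worker struct {
	bean          clientBean
	sourceCluster string
}

// getMessages resolves the remote client on every call instead of holding a
// client captured at construction time.
func (w *worker) getMessages(ctx context.Context) (string, error) {
	remote := w.bean.GetRemoteAdminClient(w.sourceCluster)
	return remote.GetReplicationMessages(ctx, w.sourceCluster)
}

// fakeClient and fakeBean show the call path end to end.
type fakeClient struct{ addr string }

func (c fakeClient) GetReplicationMessages(_ context.Context, cluster string) (string, error) {
	return "tasks from " + cluster + " via " + c.addr, nil
}

type fakeBean struct{ addr string }

func (b fakeBean) GetRemoteAdminClient(string) adminClient { return fakeClient{addr: b.addr} }

func main() {
	w := &worker{bean: fakeBean{addr: "b:7233"}, sourceCluster: "cluster-b"}
	out, _ := w.getMessages(context.Background())
	fmt.Println(out)
}
```

The trade-off is one extra map lookup per poll in exchange for decoupling the worker from the client's lifetime, which is what lets the client-bean eviction above take effect without restarting fetchers.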
service/history/replicationTaskFetcher_test.go — 6 changes: 3 additions & 3 deletions
@@ -95,7 +95,7 @@ func (s *replicationTaskFetcherSuite) SetupTest() {
cluster.TestAlternativeClusterName,
cluster.TestCurrentClusterName,
s.config,
s.frontendClient,
s.mockResource.ClientBean,
)
}

@@ -306,7 +306,7 @@ func (s *replicationTaskFetcherSuite) TestConcurrentFetchAndProcess_Success() {
cluster.TestAlternativeClusterName,
cluster.TestCurrentClusterName,
s.config,
s.frontendClient,
s.mockResource.ClientBean,
)

s.frontendClient.EXPECT().GetReplicationMessages(
@@ -350,7 +350,7 @@ func (s *replicationTaskFetcherSuite) TestConcurrentFetchAndProcess_Error() {
cluster.TestAlternativeClusterName,
cluster.TestCurrentClusterName,
s.config,
s.frontendClient,
s.mockResource.ClientBean,
)

s.frontendClient.EXPECT().GetReplicationMessages(
