Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: fix data race on the SetResourceGroupTagger #491

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions txnkv/txnsnapshot/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,14 @@ func (s *Scanner) Next() error {
if !s.valid {
return errors.New("scanner iterator is invalid")
}
s.snapshot.interceptorMutex.RLock()
if s.snapshot.interceptor != nil {
s.snapshot.mu.RLock()
if s.snapshot.mu.interceptor != nil {
// User has called snapshot.SetRPCInterceptor() to explicitly set an interceptor, we
// need to bind it to ctx so that the internal client can perceive and execute
// it before initiating an RPC request.
bo.SetCtx(interceptor.WithRPCInterceptor(bo.GetCtx(), s.snapshot.interceptor))
bo.SetCtx(interceptor.WithRPCInterceptor(bo.GetCtx(), s.snapshot.mu.interceptor))
}
s.snapshot.interceptorMutex.RUnlock()
s.snapshot.mu.RUnlock()
var err error
for {
s.idx++
Expand Down Expand Up @@ -228,7 +228,7 @@ func (s *Scanner) getData(bo *retry.Backoffer) error {
Priority: s.snapshot.priority.ToPB(),
NotFillCache: s.snapshot.notFillCache,
IsolationLevel: s.snapshot.isolationLevel.ToPB(),
ResourceGroupTag: s.snapshot.resourceGroupTag,
ResourceGroupTag: s.snapshot.mu.resourceGroupTag,
},
StartKey: s.nextStartKey,
EndKey: reqEndKey,
Expand All @@ -247,11 +247,11 @@ func (s *Scanner) getData(bo *retry.Backoffer) error {
Priority: s.snapshot.priority.ToPB(),
NotFillCache: s.snapshot.notFillCache,
TaskId: s.snapshot.mu.taskID,
ResourceGroupTag: s.snapshot.resourceGroupTag,
ResourceGroupTag: s.snapshot.mu.resourceGroupTag,
IsolationLevel: s.snapshot.isolationLevel.ToPB(),
})
if s.snapshot.resourceGroupTag == nil && s.snapshot.resourceGroupTagger != nil {
s.snapshot.resourceGroupTagger(req)
if s.snapshot.mu.resourceGroupTag == nil && s.snapshot.mu.resourceGroupTagger != nil {
s.snapshot.mu.resourceGroupTagger(req)
}
s.snapshot.mu.RUnlock()
resp, err := sender.SendReq(bo, req, loc.Region, client.ReadTimeoutMedium)
Expand Down
64 changes: 33 additions & 31 deletions txnkv/txnsnapshot/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,16 +133,14 @@ type KVSnapshot struct {
readReplicaScope string
// MatchStoreLabels indicates the labels the store should be matched
matchStoreLabels []*metapb.StoreLabel
// resourceGroupTag is use to set the kv request resource group tag.
resourceGroupTag []byte
// resourceGroupTagger is use to set the kv request resource group tag if resourceGroupTag is nil.
resourceGroupTagger tikvrpc.ResourceGroupTagger
// interceptor is used to decorate the RPC request logic related to the snapshot.
interceptor interceptor.RPCInterceptor
}
sampleStep uint32
// resourceGroupTag is use to set the kv request resource group tag.
resourceGroupTag []byte
// resourceGroupTagger is use to set the kv request resource group tag if resourceGroupTag is nil.
resourceGroupTagger tikvrpc.ResourceGroupTagger
// interceptorMutex is a lock for interceptor
interceptorMutex sync.RWMutex
// interceptor is used to decorate the RPC request logic related to the snapshot.
interceptor interceptor.RPCInterceptor
}

// NewTiKVSnapshot creates a snapshot of an TiKV store.
Expand Down Expand Up @@ -209,14 +207,14 @@ func (s *KVSnapshot) BatchGet(ctx context.Context, keys [][]byte) (map[string][]

ctx = context.WithValue(ctx, retry.TxnStartKey, s.version)
bo := retry.NewBackofferWithVars(ctx, batchGetMaxBackoff, s.vars)
s.interceptorMutex.RLock()
if s.interceptor != nil {
s.mu.RLock()
if s.mu.interceptor != nil {
// User has called snapshot.SetRPCInterceptor() to explicitly set an interceptor, we
// need to bind it to ctx so that the internal client can perceive and execute
// it before initiating an RPC request.
bo.SetCtx(interceptor.WithRPCInterceptor(bo.GetCtx(), s.interceptor))
bo.SetCtx(interceptor.WithRPCInterceptor(bo.GetCtx(), s.mu.interceptor))
}
s.interceptorMutex.RUnlock()
s.mu.RUnlock()
// Create a map to collect key-values from region servers.
var mu sync.Mutex
err := s.batchGetKeysByRegions(bo, keys, func(k, v []byte) {
Expand Down Expand Up @@ -367,11 +365,11 @@ func (s *KVSnapshot) batchGetSingleRegion(bo *retry.Backoffer, batch batchKeys,
Priority: s.priority.ToPB(),
NotFillCache: s.notFillCache,
TaskId: s.mu.taskID,
ResourceGroupTag: s.resourceGroupTag,
ResourceGroupTag: s.mu.resourceGroupTag,
IsolationLevel: s.isolationLevel.ToPB(),
})
if s.resourceGroupTag == nil && s.resourceGroupTagger != nil {
s.resourceGroupTagger(req)
if s.mu.resourceGroupTag == nil && s.mu.resourceGroupTagger != nil {
s.mu.resourceGroupTagger(req)
}
scope := s.mu.readReplicaScope
isStaleness := s.mu.isStaleness
Expand Down Expand Up @@ -483,14 +481,14 @@ func (s *KVSnapshot) Get(ctx context.Context, k []byte) ([]byte, error) {

ctx = context.WithValue(ctx, retry.TxnStartKey, s.version)
bo := retry.NewBackofferWithVars(ctx, getMaxBackoff, s.vars)
s.interceptorMutex.RLock()
if s.interceptor != nil {
s.mu.RLock()
if s.mu.interceptor != nil {
// User has called snapshot.SetRPCInterceptor() to explicitly set an interceptor, we
// need to bind it to ctx so that the internal client can perceive and execute
// it before initiating an RPC request.
bo.SetCtx(interceptor.WithRPCInterceptor(bo.GetCtx(), s.interceptor))
bo.SetCtx(interceptor.WithRPCInterceptor(bo.GetCtx(), s.mu.interceptor))
}
s.interceptorMutex.RUnlock()
s.mu.RUnlock()
val, err := s.get(ctx, bo, k)
s.recordBackoffInfo(bo)
if err != nil {
Expand Down Expand Up @@ -546,11 +544,11 @@ func (s *KVSnapshot) get(ctx context.Context, bo *retry.Backoffer, k []byte) ([]
Priority: s.priority.ToPB(),
NotFillCache: s.notFillCache,
TaskId: s.mu.taskID,
ResourceGroupTag: s.resourceGroupTag,
ResourceGroupTag: s.mu.resourceGroupTag,
IsolationLevel: s.isolationLevel.ToPB(),
})
if s.resourceGroupTag == nil && s.resourceGroupTagger != nil {
s.resourceGroupTagger(req)
if s.mu.resourceGroupTag == nil && s.mu.resourceGroupTagger != nil {
s.mu.resourceGroupTagger(req)
}
isStaleness := s.mu.isStaleness
matchStoreLabels := s.mu.matchStoreLabels
Expand Down Expand Up @@ -748,34 +746,38 @@ func (s *KVSnapshot) SetMatchStoreLabels(labels []*metapb.StoreLabel) {

// SetResourceGroupTag sets resource group tag of the kv request.
func (s *KVSnapshot) SetResourceGroupTag(tag []byte) {
s.resourceGroupTag = tag
s.mu.Lock()
defer s.mu.Unlock()
s.mu.resourceGroupTag = tag
hawkingrei marked this conversation as resolved.
Show resolved Hide resolved
}

// SetResourceGroupTagger sets resource group tagger of the kv request.
// Before sending the request, if resourceGroupTag is not nil, use
// resourceGroupTag directly, otherwise use resourceGroupTagger.
func (s *KVSnapshot) SetResourceGroupTagger(tagger tikvrpc.ResourceGroupTagger) {
s.resourceGroupTagger = tagger
s.mu.Lock()
defer s.mu.Unlock()
s.mu.resourceGroupTagger = tagger
}

// SetRPCInterceptor sets interceptor.RPCInterceptor for the snapshot.
// interceptor.RPCInterceptor will be executed before each RPC request is initiated.
// Note that SetRPCInterceptor will replace the previously set interceptor.
func (s *KVSnapshot) SetRPCInterceptor(it interceptor.RPCInterceptor) {
s.interceptorMutex.Lock()
defer s.interceptorMutex.Unlock()
s.interceptor = it
s.mu.Lock()
defer s.mu.Unlock()
s.mu.interceptor = it
}

// AddRPCInterceptor adds an interceptor, the order of addition is the order of execution.
func (s *KVSnapshot) AddRPCInterceptor(it interceptor.RPCInterceptor) {
s.interceptorMutex.Lock()
defer s.interceptorMutex.Unlock()
if s.interceptor == nil {
s.mu.Lock()
defer s.mu.Unlock()
if s.mu.interceptor == nil {
s.SetRPCInterceptor(it)
return
}
s.interceptor = interceptor.ChainRPCInterceptors(s.interceptor, it)
s.mu.interceptor = interceptor.ChainRPCInterceptors(s.mu.interceptor, it)
}

// SnapCacheHitCount gets the snapshot cache hit count. Only for test.
Expand Down