diff --git a/pkg/schedule/scatter/region_scatterer.go b/pkg/schedule/scatter/region_scatterer.go
index 34f6a8a1d65..5e5246ee348 100644
--- a/pkg/schedule/scatter/region_scatterer.go
+++ b/pkg/schedule/scatter/region_scatterer.go
@@ -333,6 +333,11 @@ func (r *RegionScatterer) scatterRegion(region *core.RegionInfo, group string, s
 	selectedStores := make(map[uint64]struct{}, len(region.GetPeers())) // selected StoreID set
 	leaderCandidateStores := make([]uint64, 0, len(region.GetPeers())) // StoreID allowed to become Leader
 	scatterWithSameEngine := func(peers map[uint64]*metapb.Peer, context engineContext) { // peers: StoreID -> Peer
+		filterLen := len(context.filterFuncs) + 2
+		filters := make([]filter.Filter, filterLen)
+		for i, filterFunc := range context.filterFuncs {
+			filters[i] = filterFunc()
+		}
 		for _, peer := range peers {
 			if _, ok := selectedStores[peer.GetStoreId()]; ok {
 				if allowLeader(oldFit, peer) {
@@ -342,8 +347,14 @@ func (r *RegionScatterer) scatterRegion(region *core.RegionInfo, group string, s
 				continue
 			}
 			for {
-				candidates := r.selectCandidates(group, region, oldFit, peer.GetStoreId(), selectedStores, context)
-				newPeer := r.selectStore(group, peer, peer.GetStoreId(), candidates, context)
+				sourceStore := r.cluster.GetStore(peer.GetStoreId())
+				if sourceStore == nil {
+					log.Error("failed to get the store", zap.Uint64("store-id", peer.GetStoreId()), errs.ZapError(errs.ErrGetSourceStore))
+					continue
+				}
+				filters[filterLen-2] = filter.NewExcludedFilter(r.name, nil, selectedStores)
+				filters[filterLen-1] = filter.NewPlacementSafeguard(r.name, r.cluster.GetSharedConfig(), r.cluster.GetBasicCluster(), r.cluster.GetRuleManager(), region, sourceStore, oldFit)
+				newPeer := r.selectCandidates(context, group, peer, filters)
 				targetPeers[newPeer.GetStoreId()] = newPeer
 				selectedStores[newPeer.GetStoreId()] = struct{}{}
 				// If the selected peer is a peer other than origin peer in this region,
@@ -435,23 +446,8 @@ func isSameDistribution(region *core.RegionInfo, targetPeers map[uint64]*metapb.
 	return region.GetLeader().GetStoreId() == targetLeader
 }
 
-func (r *RegionScatterer) selectCandidates(group string, region *core.RegionInfo, oldFit *placement.RegionFit,
-	sourceStoreID uint64, selectedStores map[uint64]struct{}, context engineContext) []uint64 {
-	sourceStore := r.cluster.GetStore(sourceStoreID)
-	if sourceStore == nil {
-		log.Error("failed to get the store", zap.Uint64("store-id", sourceStoreID), errs.ZapError(errs.ErrGetSourceStore))
-		return nil
-	}
-	filters := []filter.Filter{
-		filter.NewExcludedFilter(r.name, nil, selectedStores),
-	}
-	scoreGuard := filter.NewPlacementSafeguard(r.name, r.cluster.GetSharedConfig(), r.cluster.GetBasicCluster(), r.cluster.GetRuleManager(), region, sourceStore, oldFit)
-	for _, filterFunc := range context.filterFuncs {
-		filters = append(filters, filterFunc())
-	}
-	filters = append(filters, scoreGuard)
+func (r *RegionScatterer) selectCandidates(context engineContext, group string, peer *metapb.Peer, filters []filter.Filter) *metapb.Peer {
 	stores := r.cluster.GetStores()
-	candidates := make([]uint64, 0)
 	maxStoreTotalCount := uint64(0)
 	minStoreTotalCount := uint64(math.MaxUint64)
 	for _, store := range stores {
@@ -463,37 +459,27 @@ func (r *RegionScatterer) selectCandidates(group string, region *core.RegionInfo
 			minStoreTotalCount = count
 		}
 	}
+
+	var newPeer *metapb.Peer
+	minCount := uint64(math.MaxUint64)
+	sourceHit := uint64(math.MaxUint64)
 	for _, store := range stores {
 		storeCount := context.selectedPeer.Get(store.GetID(), group)
+		if store.GetID() == peer.GetStoreId() {
+			sourceHit = storeCount
+		}
 		// If storeCount is equal to the maxStoreTotalCount, we should skip this store as candidate.
 		// If the storeCount are all the same for the whole cluster(maxStoreTotalCount == minStoreTotalCount), any store
 		// could be selected as candidate.
 		if storeCount < maxStoreTotalCount || maxStoreTotalCount == minStoreTotalCount {
 			if filter.Target(r.cluster.GetSharedConfig(), store, filters) {
-				candidates = append(candidates, store.GetID())
-			}
-		}
-	}
-	return candidates
-}
-
-func (r *RegionScatterer) selectStore(group string, peer *metapb.Peer, sourceStoreID uint64, candidates []uint64, context engineContext) *metapb.Peer {
-	if len(candidates) < 1 {
-		return peer
-	}
-	var newPeer *metapb.Peer
-	minCount := uint64(math.MaxUint64)
-	sourceHit := uint64(math.MaxUint64)
-	for _, storeID := range candidates {
-		count := context.selectedPeer.Get(storeID, group)
-		if storeID == sourceStoreID {
-			sourceHit = count
-		}
-		if count < minCount {
-			minCount = count
-			newPeer = &metapb.Peer{
-				StoreId: storeID,
-				Role:    peer.GetRole(),
+				if storeCount < minCount {
+					minCount = storeCount
+					newPeer = &metapb.Peer{
+						StoreId: store.GetID(),
+						Role:    peer.GetRole(),
+					}
+				}
 			}
 		}
 	}
diff --git a/pkg/schedule/scatter/region_scatterer_test.go b/pkg/schedule/scatter/region_scatterer_test.go
index c0724e481f6..d531f62c661 100644
--- a/pkg/schedule/scatter/region_scatterer_test.go
+++ b/pkg/schedule/scatter/region_scatterer_test.go
@@ -18,7 +18,6 @@ import (
 	"context"
 	"fmt"
 	"math"
-	"math/rand"
 	"strconv"
 	"sync"
 	"testing"
@@ -533,48 +532,11 @@ func TestSelectedStoreGC(t *testing.T) {
 	re.False(ok)
 }
 
-// TestRegionFromDifferentGroups test the multi regions. each region have its own group.
-// After scatter, the distribution for the whole cluster should be well.
-func TestRegionFromDifferentGroups(t *testing.T) {
-	re := require.New(t)
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-	opt := mockconfig.NewTestOptions()
-	tc := mockcluster.NewCluster(ctx, opt)
-	stream := hbstream.NewTestHeartbeatStreams(ctx, tc.ID, tc, false)
-	oc := operator.NewController(ctx, tc.GetBasicCluster(), tc.GetSharedConfig(), stream)
-	// Add 6 stores.
-	storeCount := 6
-	for i := uint64(1); i <= uint64(storeCount); i++ {
-		tc.AddRegionStore(i, 0)
-	}
-	scatterer := NewRegionScatterer(ctx, tc, oc, tc.AddSuspectRegions)
-	regionCount := 50
-	for i := 1; i <= regionCount; i++ {
-		p := rand.Perm(storeCount)
-		scatterer.scatterRegion(tc.AddLeaderRegion(uint64(i), uint64(p[0])+1, uint64(p[1])+1, uint64(p[2])+1), fmt.Sprintf("t%d", i), false)
-	}
-	check := func(ss *selectedStores) {
-		max := uint64(0)
-		min := uint64(math.MaxUint64)
-		for i := uint64(1); i <= uint64(storeCount); i++ {
-			count := ss.TotalCountByStore(i)
-			if count > max {
-				max = count
-			}
-			if count < min {
-				min = count
-			}
-		}
-		re.LessOrEqual(max-min, uint64(2))
-	}
-	check(scatterer.ordinaryEngine.selectedPeer)
-}
-
 func TestRegionHasLearner(t *testing.T) {
 	re := require.New(t)
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
+	group := "group"
 	opt := mockconfig.NewTestOptions()
 	tc := mockcluster.NewCluster(ctx, opt)
 	stream := hbstream.NewTestHeartbeatStreams(ctx, tc.ID, tc, false)
@@ -617,14 +579,14 @@ func TestRegionHasLearner(t *testing.T) {
 	scatterer := NewRegionScatterer(ctx, tc, oc, tc.AddSuspectRegions)
 	regionCount := 50
 	for i := 1; i <= regionCount; i++ {
-		_, err := scatterer.Scatter(tc.AddRegionWithLearner(uint64(i), uint64(1), []uint64{uint64(2), uint64(3)}, []uint64{7}), "group", false)
+		_, err := scatterer.Scatter(tc.AddRegionWithLearner(uint64(i), uint64(1), []uint64{uint64(2), uint64(3)}, []uint64{7}), group, false)
 		re.NoError(err)
 	}
 	check := func(ss *selectedStores) {
 		max := uint64(0)
 		min := uint64(math.MaxUint64)
 		for i := uint64(1); i <= max; i++ {
-			count := ss.TotalCountByStore(i)
+			count := ss.Get(i, group)
 			if count > max {
 				max = count
 			}
@@ -639,7 +601,7 @@ func TestRegionHasLearner(t *testing.T) {
 		max := uint64(0)
 		min := uint64(math.MaxUint64)
 		for i := uint64(1); i <= voterCount; i++ {
-			count := ss.TotalCountByStore(i)
+			count := ss.Get(i, group)
 			if count > max {
 				max = count
 			}
@@ -691,6 +653,7 @@ func TestSelectedStoresTooFewPeers(t *testing.T) {
 		region := tc.AddLeaderRegion(i+200, i%3+2, (i+1)%3+2, (i+2)%3+2)
 		op := scatterer.scatterRegion(region, group, false)
 		re.False(isPeerCountChanged(op))
+		re.Equal(group, op.AdditionalInfos["group"])
 	}
 }
 
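Reviewer note: for anyone who wants to poke at the selection heuristic outside of PD, the sketch below is a minimal, self-contained restatement of the single-pass logic that the reworked selectCandidates performs: remember the source store's pick count, skip stores already at the cluster-wide maximum unless all counts are equal, and keep the eligible store with the fewest picks. The store type, the selectLeastPicked name, and the eligible callback are simplified stand-ins for illustration only; they are not part of the PD code base.

package main

import (
	"fmt"
	"math"
)

// store is a simplified stand-in for PD's core.StoreInfo: just an ID plus the
// number of times the scatterer has already picked it for the current group.
type store struct {
	id        uint64
	pickCount uint64
}

// selectLeastPicked mirrors the selection loop in the refactored selectCandidates:
// it walks the stores once, skips any store that already sits at the cluster-wide
// maximum pick count (unless every store has the same count), applies an
// eligibility predicate in place of filter.Target, and returns the eligible store
// with the fewest picks. It also reports the source store's pick count so the
// caller can decide to keep the origin peer.
func selectLeastPicked(stores []store, sourceID uint64, eligible func(store) bool) (best *store, sourceHit uint64) {
	maxCount, minTotal := uint64(0), uint64(math.MaxUint64)
	for _, s := range stores {
		if s.pickCount > maxCount {
			maxCount = s.pickCount
		}
		if s.pickCount < minTotal {
			minTotal = s.pickCount
		}
	}

	minCount := uint64(math.MaxUint64)
	sourceHit = math.MaxUint64
	for i := range stores {
		s := stores[i]
		if s.id == sourceID {
			sourceHit = s.pickCount
		}
		// Skip stores already at the maximum unless all counts are equal.
		if s.pickCount >= maxCount && maxCount != minTotal {
			continue
		}
		if !eligible(s) {
			continue
		}
		if s.pickCount < minCount {
			minCount = s.pickCount
			best = &stores[i]
		}
	}
	return best, sourceHit
}

func main() {
	stores := []store{{id: 1, pickCount: 3}, {id: 2, pickCount: 1}, {id: 3, pickCount: 2}}
	best, sourceHit := selectLeastPicked(stores, 1, func(store) bool { return true })
	fmt.Printf("best=%d sourceHit=%d\n", best.id, sourceHit) // best=2 sourceHit=3
}

Folding the old selectCandidates/selectStore pair into one traversal avoids building the intermediate candidates slice and keeps the min-count bookkeeping next to the filtering decision, which is the intent of this patch.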