diff --git a/server/core/region_option.go b/server/core/region_option.go
index 51fc9d051bc3..0c3f6a325a14 100644
--- a/server/core/region_option.go
+++ b/server/core/region_option.go
@@ -248,10 +248,10 @@ func WithPromoteLearner(peerID uint64) RegionCreateOption {
 }
 
 // WithReplacePeerStore replaces a peer's storeID with another ID.
-func WithReplacePeerStore(peerID, newStoreID uint64) RegionCreateOption {
+func WithReplacePeerStore(oldStoreID, newStoreID uint64) RegionCreateOption {
 	return func(region *RegionInfo) {
 		for _, p := range region.GetPeers() {
-			if p.GetId() == peerID {
+			if p.GetStoreId() == oldStoreID {
 				p.StoreId = newStoreID
 			}
 		}
diff --git a/server/schedule/filter/filters.go b/server/schedule/filter/filters.go
index 766f941312ab..cb3ba7342143 100644
--- a/server/schedule/filter/filters.go
+++ b/server/schedule/filter/filters.go
@@ -531,23 +531,23 @@ type RegionFitter interface {
 }
 
 type ruleFitFilter struct {
-	scope   string
-	fitter  RegionFitter
-	region  *core.RegionInfo
-	oldFit  *placement.RegionFit
-	oldPeer uint64
+	scope    string
+	fitter   RegionFitter
+	region   *core.RegionInfo
+	oldFit   *placement.RegionFit
+	oldStore uint64
 }
 
 // NewRuleFitFilter creates a filter that ensures after replace a peer with new
 // one, the isolation level will not decrease. Its function is the same as
 // distinctScoreFilter but used when placement rules is enabled.
-func NewRuleFitFilter(scope string, fitter RegionFitter, region *core.RegionInfo, oldPeerID uint64) Filter {
+func NewRuleFitFilter(scope string, fitter RegionFitter, region *core.RegionInfo, oldStoreID uint64) Filter {
 	return &ruleFitFilter{
-		scope:   scope,
-		fitter:  fitter,
-		region:  region,
-		oldFit:  fitter.FitRegion(region),
-		oldPeer: oldPeerID,
+		scope:    scope,
+		fitter:   fitter,
+		region:   region,
+		oldFit:   fitter.FitRegion(region),
+		oldStore: oldStoreID,
 	}
 }
 
@@ -564,7 +564,7 @@ func (f *ruleFitFilter) Source(opt opt.Options, store *core.StoreInfo) bool {
 }
 
 func (f *ruleFitFilter) Target(opt opt.Options, store *core.StoreInfo) bool {
-	region := f.region.Clone(core.WithReplacePeerStore(f.oldPeer, store.GetID()))
+	region := f.region.Clone(core.WithReplacePeerStore(f.oldStore, store.GetID()))
 	newFit := f.fitter.FitRegion(region)
 	return placement.CompareRegionFit(f.oldFit, newFit) > 0
 }
diff --git a/server/schedule/region_scatterer.go b/server/schedule/region_scatterer.go
index 90a84c69bce9..36aa19395e6c 100644
--- a/server/schedule/region_scatterer.go
+++ b/server/schedule/region_scatterer.go
@@ -142,7 +142,12 @@ func (r *RegionScatterer) selectPeerToReplace(stores map[uint64]*core.StoreInfo,
 	if sourceStore == nil {
 		log.Error("failed to get the store", zap.Uint64("store-id", storeID))
 	}
-	scoreGuard := filter.NewDistinctScoreFilter(r.name, r.cluster.GetLocationLabels(), regionStores, sourceStore)
+	var scoreGuard filter.Filter
+	if r.cluster.IsPlacementRulesEnabled() {
+		scoreGuard = filter.NewRuleFitFilter(r.name, r.cluster, region, oldPeer.GetStoreId())
+	} else {
+		scoreGuard = filter.NewDistinctScoreFilter(r.name, r.cluster.GetLocationLabels(), regionStores, sourceStore)
+	}
 
 	candidates := make([]*core.StoreInfo, 0, len(stores))
 	for _, store := range stores {
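The region_option.go and filters.go changes move in lockstep: ruleFitFilter now records the store being vacated rather than a peer ID, and WithReplacePeerStore matches peers by the store they live on, which is the identifier the scheduler call sites below actually have. Target then clones the region as if the peer had already moved to the candidate store and accepts the candidate only when placement.CompareRegionFit reports a strictly better fit. A minimal sketch of the new semantics follows (the IDs and metapb literals are made up for illustration, not taken from the patch; it assumes the usual imports of server/core and kvproto's metapb):

	// Peer 101 lives on store 2. Passing the *store* ID now moves that
	// peer; under the old peer-ID comparison, 2 would have matched no
	// peer and the clone would have been an unchanged copy.
	region := core.NewRegionInfo(&metapb.Region{
		Id:    1,
		Peers: []*metapb.Peer{{Id: 101, StoreId: 2}},
	}, &metapb.Peer{Id: 101, StoreId: 2})

	moved := region.Clone(core.WithReplacePeerStore(2, 5))
	// moved.GetStorePeer(5) != nil, moved.GetStorePeer(2) == nil
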
zap.Uint64("store-id", leaderStoreID)) return nil } - - scoreGuard := filter.NewDistinctScoreFilter(l.GetName(), cluster.GetLocationLabels(), stores, source) + var scoreGuard filter.Filter + if cluster.IsPlacementRulesEnabled() { + scoreGuard = filter.NewRuleFitFilter(l.GetName(), cluster, region, leaderStoreID) + } else { + scoreGuard = filter.NewDistinctScoreFilter(l.GetName(), cluster.GetLocationLabels(), stores, source) + } excludeStores := region.GetStoreIds() for _, storeID := range l.cacheRegions.assignedStoreIds { if _, ok := excludeStores[storeID]; !ok { diff --git a/server/schedulers/hot_region.go b/server/schedulers/hot_region.go index 06284f75b8a3..8a0437003a73 100644 --- a/server/schedulers/hot_region.go +++ b/server/schedulers/hot_region.go @@ -401,10 +401,24 @@ func (h *balanceHotRegionsScheduler) balanceByPeer(cluster opt.Cluster, storesSt if srcStore == nil { log.Error("failed to get the source store", zap.Uint64("store-id", srcStoreID)) } + + srcPeer := srcRegion.GetStorePeer(srcStoreID) + if srcPeer == nil { + log.Debug("region does not peer on source store, maybe stat out of date", zap.Uint64("region-id", rs.RegionID)) + continue + } + + var scoreGuard filter.Filter + if cluster.IsPlacementRulesEnabled() { + scoreGuard = filter.NewRuleFitFilter(h.GetName(), cluster, srcRegion, srcStoreID) + } else { + scoreGuard = filter.NewDistinctScoreFilter(h.GetName(), cluster.GetLocationLabels(), cluster.GetRegionStores(srcRegion), srcStore) + } + filters := []filter.Filter{ filter.StoreStateFilter{ActionScope: h.GetName(), MoveRegion: true}, filter.NewExcludedFilter(h.GetName(), srcRegion.GetStoreIds(), srcRegion.GetStoreIds()), - filter.NewDistinctScoreFilter(h.GetName(), cluster.GetLocationLabels(), cluster.GetRegionStores(srcRegion), srcStore), + scoreGuard, } candidateStoreIDs := make([]uint64, 0, len(stores)) for _, store := range stores { diff --git a/server/schedulers/hot_test.go b/server/schedulers/hot_test.go index 2e059c63ebbe..f9a051be02ad 100644 --- a/server/schedulers/hot_test.go +++ b/server/schedulers/hot_test.go @@ -42,6 +42,12 @@ func (s *testHotWriteRegionSchedulerSuite) TestSchedule(c *C) { c.Assert(err, IsNil) opt.HotRegionCacheHitsThreshold = 0 + s.checkSchedule(c, tc, opt, hb) + opt.EnablePlacementRules = true + s.checkSchedule(c, tc, opt, hb) +} + +func (s *testHotWriteRegionSchedulerSuite) checkSchedule(c *C, tc *mockcluster.Cluster, opt *mockoption.ScheduleOptions, hb schedule.Scheduler) { // Add stores 1, 2, 3, 4, 5, 6 with region counts 3, 2, 2, 2, 0, 0. tc.AddLabelsStore(1, 3, map[string]string{"zone": "z1", "host": "h1"}) diff --git a/server/schedulers/scheduler_test.go b/server/schedulers/scheduler_test.go index cb9237429691..3998229b199a 100644 --- a/server/schedulers/scheduler_test.go +++ b/server/schedulers/scheduler_test.go @@ -353,6 +353,12 @@ func (s *testShuffleHotRegionSchedulerSuite) TestBalance(c *C) { hb, err := schedule.CreateScheduler(ShuffleHotRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder("shuffle-hot-region", []string{"", ""})) c.Assert(err, IsNil) + s.checkBalance(c, tc, opt, hb) + opt.EnablePlacementRules = true + s.checkBalance(c, tc, opt, hb) +} + +func (s *testShuffleHotRegionSchedulerSuite) checkBalance(c *C, tc *mockcluster.Cluster, opt *mockoption.ScheduleOptions, hb schedule.Scheduler) { // Add stores 1, 2, 3, 4, 5, 6 with hot peer counts 3, 2, 2, 2, 0, 0. 
diff --git a/server/schedulers/hot_test.go b/server/schedulers/hot_test.go
index 2e059c63ebbe..f9a051be02ad 100644
--- a/server/schedulers/hot_test.go
+++ b/server/schedulers/hot_test.go
@@ -42,6 +42,12 @@ func (s *testHotWriteRegionSchedulerSuite) TestSchedule(c *C) {
 	c.Assert(err, IsNil)
 	opt.HotRegionCacheHitsThreshold = 0
 
+	s.checkSchedule(c, tc, opt, hb)
+	opt.EnablePlacementRules = true
+	s.checkSchedule(c, tc, opt, hb)
+}
+
+func (s *testHotWriteRegionSchedulerSuite) checkSchedule(c *C, tc *mockcluster.Cluster, opt *mockoption.ScheduleOptions, hb schedule.Scheduler) {
 	// Add stores 1, 2, 3, 4, 5, 6 with region counts 3, 2, 2, 2, 0, 0.
 	tc.AddLabelsStore(1, 3, map[string]string{"zone": "z1", "host": "h1"})
diff --git a/server/schedulers/scheduler_test.go b/server/schedulers/scheduler_test.go
index cb9237429691..3998229b199a 100644
--- a/server/schedulers/scheduler_test.go
+++ b/server/schedulers/scheduler_test.go
@@ -353,6 +353,12 @@ func (s *testShuffleHotRegionSchedulerSuite) TestBalance(c *C) {
 	hb, err := schedule.CreateScheduler(ShuffleHotRegionType, schedule.NewOperatorController(ctx, nil, nil), core.NewStorage(kv.NewMemoryKV()), schedule.ConfigSliceDecoder("shuffle-hot-region", []string{"", ""}))
 	c.Assert(err, IsNil)
 
+	s.checkBalance(c, tc, opt, hb)
+	opt.EnablePlacementRules = true
+	s.checkBalance(c, tc, opt, hb)
+}
+
+func (s *testShuffleHotRegionSchedulerSuite) checkBalance(c *C, tc *mockcluster.Cluster, opt *mockoption.ScheduleOptions, hb schedule.Scheduler) {
 	// Add stores 1, 2, 3, 4, 5, 6 with hot peer counts 3, 2, 2, 2, 0, 0.
 	tc.AddLabelsStore(1, 3, map[string]string{"zone": "z1", "host": "h1"})
 	tc.AddLabelsStore(2, 2, map[string]string{"zone": "z2", "host": "h2"})
diff --git a/server/schedulers/shuffle_hot_region.go b/server/schedulers/shuffle_hot_region.go
index 57b72076ac5e..0db31a39fe15 100644
--- a/server/schedulers/shuffle_hot_region.go
+++ b/server/schedulers/shuffle_hot_region.go
@@ -147,10 +147,18 @@ func (s *shuffleHotRegionScheduler) randomSchedule(cluster opt.Cluster, storeSta
 		if srcStore == nil {
 			log.Error("failed to get the source store", zap.Uint64("store-id", srcStoreID))
 		}
+
+		var scoreGuard filter.Filter
+		if cluster.IsPlacementRulesEnabled() {
+			scoreGuard = filter.NewRuleFitFilter(s.GetName(), cluster, srcRegion, srcStoreID)
+		} else {
+			scoreGuard = filter.NewDistinctScoreFilter(s.GetName(), cluster.GetLocationLabels(), cluster.GetRegionStores(srcRegion), srcStore)
+		}
+
 		filters := []filter.Filter{
 			filter.StoreStateFilter{ActionScope: s.GetName(), MoveRegion: true},
 			filter.NewExcludedFilter(s.GetName(), srcRegion.GetStoreIds(), srcRegion.GetStoreIds()),
-			filter.NewDistinctScoreFilter(s.GetName(), cluster.GetLocationLabels(), cluster.GetRegionStores(srcRegion), srcStore),
+			scoreGuard,
 		}
 		stores := cluster.GetStores()
		destStoreIDs := make([]uint64, 0, len(stores))
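Both test suites follow the same refactor: the scenario body moves into a check* helper so that TestSchedule and TestBalance can run it twice, once per guard implementation. A sketch of the idea as a generic helper (the function below is hypothetical and not in the patch; the tests inline these four lines instead):

	// runWithAndWithoutRules exercises the same assertions under both
	// scoreGuard paths: DistinctScoreFilter first, then RuleFitFilter.
	// (Hypothetical; the diff's TestSchedule/TestBalance write this out.)
	func runWithAndWithoutRules(opt *mockoption.ScheduleOptions, check func()) {
		opt.EnablePlacementRules = false
		check()
		opt.EnablePlacementRules = true
		check()
	}

In the diff, the flag flips between the two helper calls and the helper itself rebuilds the store and region fixtures, so both passes start from the same cluster shape.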