From 58e9b2088b069a2c6a1e58be9dac752e46593e44 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Mon, 27 Nov 2023 12:07:14 +0800 Subject: [PATCH] client/http: implement more rule and batch related interfaces (#7430) ref tikv/pd#7300 - Implement more rule and batch related interfaces. - Add more types and methods. - Refine the tests. Signed-off-by: JmPotato --- client/http/api.go | 35 ++-- client/http/client.go | 101 +++++++++-- client/http/types.go | 61 ++++++- pkg/schedule/checker/checker_controller.go | 5 + pkg/utils/tsoutil/tsoutil.go | 5 + server/cluster/cluster.go | 6 +- server/cluster/scheduling_controller.go | 7 + tests/integrations/client/http_client_test.go | 160 ++++++++++++++---- 8 files changed, 319 insertions(+), 61 deletions(-) diff --git a/client/http/api.go b/client/http/api.go index 1826e2231ee..c6d4f2dfb74 100644 --- a/client/http/api.go +++ b/client/http/api.go @@ -23,19 +23,20 @@ import ( // The following constants are the paths of PD HTTP APIs. const ( // Metadata - HotRead = "/pd/api/v1/hotspot/regions/read" - HotWrite = "/pd/api/v1/hotspot/regions/write" - HotHistory = "/pd/api/v1/hotspot/regions/history" - RegionByIDPrefix = "/pd/api/v1/region/id" - regionByKey = "/pd/api/v1/region/key" - Regions = "/pd/api/v1/regions" - regionsByKey = "/pd/api/v1/regions/key" - RegionsByStoreIDPrefix = "/pd/api/v1/regions/store" - EmptyRegions = "/pd/api/v1/regions/check/empty-region" - AccelerateSchedule = "/pd/api/v1/regions/accelerate-schedule" - store = "/pd/api/v1/store" - Stores = "/pd/api/v1/stores" - StatsRegion = "/pd/api/v1/stats/region" + HotRead = "/pd/api/v1/hotspot/regions/read" + HotWrite = "/pd/api/v1/hotspot/regions/write" + HotHistory = "/pd/api/v1/hotspot/regions/history" + RegionByIDPrefix = "/pd/api/v1/region/id" + regionByKey = "/pd/api/v1/region/key" + Regions = "/pd/api/v1/regions" + regionsByKey = "/pd/api/v1/regions/key" + RegionsByStoreIDPrefix = "/pd/api/v1/regions/store" + EmptyRegions = "/pd/api/v1/regions/check/empty-region" + AccelerateSchedule = "/pd/api/v1/regions/accelerate-schedule" + AccelerateScheduleInBatch = "/pd/api/v1/regions/accelerate-schedule/batch" + store = "/pd/api/v1/store" + Stores = "/pd/api/v1/stores" + StatsRegion = "/pd/api/v1/stats/region" // Config Config = "/pd/api/v1/config" ClusterVersion = "/pd/api/v1/config/cluster-version" @@ -44,8 +45,11 @@ const ( // Rule PlacementRule = "/pd/api/v1/config/rule" PlacementRules = "/pd/api/v1/config/rules" + PlacementRulesInBatch = "/pd/api/v1/config/rules/batch" placementRulesByGroup = "/pd/api/v1/config/rules/group" PlacementRuleBundle = "/pd/api/v1/config/placement-rule" + placementRuleGroup = "/pd/api/v1/config/rule_group" + placementRuleGroups = "/pd/api/v1/config/rule_groups" RegionLabelRule = "/pd/api/v1/config/region-label/rule" RegionLabelRules = "/pd/api/v1/config/region-label/rules" RegionLabelRulesByIDs = "/pd/api/v1/config/region-label/rules/ids" @@ -136,6 +140,11 @@ func PlacementRuleBundleWithPartialParameter(partial bool) string { return fmt.Sprintf("%s?partial=%t", PlacementRuleBundle, partial) } +// PlacementRuleGroupByID returns the path of PD HTTP API to get placement rule group by ID. +func PlacementRuleGroupByID(id string) string { + return fmt.Sprintf("%s/%s", placementRuleGroup, id) +} + // SchedulerByName returns the scheduler API with the given scheduler name. 
func SchedulerByName(name string) string { return fmt.Sprintf("%s/%s", Schedulers, name) diff --git a/client/http/client.go b/client/http/client.go index 880489aa85c..1d8c2d5c427 100644 --- a/client/http/client.go +++ b/client/http/client.go @@ -18,6 +18,7 @@ import ( "bytes" "context" "crypto/tls" + "encoding/hex" "encoding/json" "fmt" "io" @@ -46,25 +47,31 @@ type Client interface { GetRegionByID(context.Context, uint64) (*RegionInfo, error) GetRegionByKey(context.Context, []byte) (*RegionInfo, error) GetRegions(context.Context) (*RegionsInfo, error) - GetRegionsByKeyRange(context.Context, []byte, []byte, int) (*RegionsInfo, error) + GetRegionsByKeyRange(context.Context, *KeyRange, int) (*RegionsInfo, error) GetRegionsByStoreID(context.Context, uint64) (*RegionsInfo, error) GetHotReadRegions(context.Context) (*StoreHotPeersInfos, error) GetHotWriteRegions(context.Context) (*StoreHotPeersInfos, error) - GetRegionStatusByKeyRange(context.Context, []byte, []byte) (*RegionStats, error) + GetRegionStatusByKeyRange(context.Context, *KeyRange) (*RegionStats, error) GetStores(context.Context) (*StoresInfo, error) /* Rule-related interfaces */ GetAllPlacementRuleBundles(context.Context) ([]*GroupBundle, error) GetPlacementRuleBundleByGroup(context.Context, string) (*GroupBundle, error) GetPlacementRulesByGroup(context.Context, string) ([]*Rule, error) SetPlacementRule(context.Context, *Rule) error + SetPlacementRuleInBatch(context.Context, []*RuleOp) error SetPlacementRuleBundles(context.Context, []*GroupBundle, bool) error DeletePlacementRule(context.Context, string, string) error + GetAllPlacementRuleGroups(context.Context) ([]*RuleGroup, error) + GetPlacementRuleGroupByID(context.Context, string) (*RuleGroup, error) + SetPlacementRuleGroup(context.Context, *RuleGroup) error + DeletePlacementRuleGroupByID(context.Context, string) error GetAllRegionLabelRules(context.Context) ([]*LabelRule, error) GetRegionLabelRulesByIDs(context.Context, []string) ([]*LabelRule, error) SetRegionLabelRule(context.Context, *LabelRule) error PatchRegionLabelRules(context.Context, *LabelRulePatch) error /* Scheduling-related interfaces */ - AccelerateSchedule(context.Context, []byte, []byte) error + AccelerateSchedule(context.Context, *KeyRange) error + AccelerateScheduleInBatch(context.Context, []*KeyRange) error /* Other interfaces */ GetMinResolvedTSByStoresIDs(context.Context, []uint64) (uint64, map[uint64]uint64, error) @@ -308,10 +315,10 @@ func (c *client) GetRegions(ctx context.Context) (*RegionsInfo, error) { } // GetRegionsByKeyRange gets the regions info by key range. If the limit is -1, it will return all regions within the range. -func (c *client) GetRegionsByKeyRange(ctx context.Context, startKey, endKey []byte, limit int) (*RegionsInfo, error) { +func (c *client) GetRegionsByKeyRange(ctx context.Context, keyRange *KeyRange, limit int) (*RegionsInfo, error) { var regions RegionsInfo err := c.requestWithRetry(ctx, - "GetRegionsByKeyRange", RegionsByKey(startKey, endKey, limit), + "GetRegionsByKeyRange", RegionsByKey(keyRange.StartKey, keyRange.EndKey, limit), http.MethodGet, http.NoBody, ®ions) if err != nil { return nil, err @@ -356,10 +363,10 @@ func (c *client) GetHotWriteRegions(ctx context.Context) (*StoreHotPeersInfos, e } // GetRegionStatusByKeyRange gets the region status by key range. 
-func (c *client) GetRegionStatusByKeyRange(ctx context.Context, startKey, endKey []byte) (*RegionStats, error) {
+func (c *client) GetRegionStatusByKeyRange(ctx context.Context, keyRange *KeyRange) (*RegionStats, error) {
 	var regionStats RegionStats
 	err := c.requestWithRetry(ctx,
-		"GetRegionStatusByKeyRange", RegionStatsByKeyRange(startKey, endKey),
+		"GetRegionStatusByKeyRange", RegionStatsByKeyRange(keyRange.StartKey, keyRange.EndKey),
 		http.MethodGet, http.NoBody, &regionStats,
 	)
 	if err != nil {
@@ -427,6 +434,17 @@ func (c *client) SetPlacementRule(ctx context.Context, rule *Rule) error {
 		http.MethodPost, bytes.NewBuffer(ruleJSON), nil)
 }
 
+// SetPlacementRuleInBatch sets the placement rules in batch.
+func (c *client) SetPlacementRuleInBatch(ctx context.Context, ruleOps []*RuleOp) error {
+	ruleOpsJSON, err := json.Marshal(ruleOps)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	return c.requestWithRetry(ctx,
+		"SetPlacementRuleInBatch", PlacementRulesInBatch,
+		http.MethodPost, bytes.NewBuffer(ruleOpsJSON), nil)
+}
+
 // SetPlacementRuleBundles sets the placement rule bundles.
 // If `partial` is false, all old configurations will be over-written and dropped.
 func (c *client) SetPlacementRuleBundles(ctx context.Context, bundles []*GroupBundle, partial bool) error {
@@ -446,6 +464,48 @@ func (c *client) DeletePlacementRule(ctx context.Context, group, id string) erro
 		http.MethodDelete, http.NoBody, nil)
 }
 
+// GetAllPlacementRuleGroups gets all placement rule groups.
+func (c *client) GetAllPlacementRuleGroups(ctx context.Context) ([]*RuleGroup, error) {
+	var ruleGroups []*RuleGroup
+	err := c.requestWithRetry(ctx,
+		"GetAllPlacementRuleGroups", placementRuleGroups,
+		http.MethodGet, http.NoBody, &ruleGroups)
+	if err != nil {
+		return nil, err
+	}
+	return ruleGroups, nil
+}
+
+// GetPlacementRuleGroupByID gets the placement rule group by ID.
+func (c *client) GetPlacementRuleGroupByID(ctx context.Context, id string) (*RuleGroup, error) {
+	var ruleGroup RuleGroup
+	err := c.requestWithRetry(ctx,
+		"GetPlacementRuleGroupByID", PlacementRuleGroupByID(id),
+		http.MethodGet, http.NoBody, &ruleGroup)
+	if err != nil {
+		return nil, err
+	}
+	return &ruleGroup, nil
+}
+
+// SetPlacementRuleGroup sets the placement rule group.
+func (c *client) SetPlacementRuleGroup(ctx context.Context, ruleGroup *RuleGroup) error {
+	ruleGroupJSON, err := json.Marshal(ruleGroup)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	return c.requestWithRetry(ctx,
+		"SetPlacementRuleGroup", placementRuleGroup,
+		http.MethodPost, bytes.NewBuffer(ruleGroupJSON), nil)
+}
+
+// DeletePlacementRuleGroupByID deletes the placement rule group by ID.
+func (c *client) DeletePlacementRuleGroupByID(ctx context.Context, id string) error {
+	return c.requestWithRetry(ctx,
+		"DeletePlacementRuleGroupByID", PlacementRuleGroupByID(id),
+		http.MethodDelete, http.NoBody, nil)
+}
+
 // GetAllRegionLabelRules gets all region label rules.
 func (c *client) GetAllRegionLabelRules(ctx context.Context) ([]*LabelRule, error) {
 	var labelRules []*LabelRule
@@ -497,17 +557,34 @@ func (c *client) PatchRegionLabelRules(ctx context.Context, labelRulePatch *Labe
 }
 
 // AccelerateSchedule accelerates the scheduling of the regions within the given key range.
-func (c *client) AccelerateSchedule(ctx context.Context, startKey, endKey []byte) error { - input := map[string]string{ - "start_key": url.QueryEscape(string(startKey)), - "end_key": url.QueryEscape(string(endKey)), +func (c *client) AccelerateSchedule(ctx context.Context, keyRange *KeyRange) error { + inputJSON, err := json.Marshal(map[string]string{ + "start_key": url.QueryEscape(hex.EncodeToString(keyRange.StartKey)), + "end_key": url.QueryEscape(hex.EncodeToString(keyRange.EndKey)), + }) + if err != nil { + return errors.Trace(err) + } + return c.requestWithRetry(ctx, + "AccelerateSchedule", AccelerateSchedule, + http.MethodPost, bytes.NewBuffer(inputJSON), nil) +} + +// AccelerateScheduleInBatch accelerates the scheduling of the regions within the given key ranges in batch. +func (c *client) AccelerateScheduleInBatch(ctx context.Context, keyRanges []*KeyRange) error { + input := make([]map[string]string, 0, len(keyRanges)) + for _, keyRange := range keyRanges { + input = append(input, map[string]string{ + "start_key": url.QueryEscape(hex.EncodeToString(keyRange.StartKey)), + "end_key": url.QueryEscape(hex.EncodeToString(keyRange.EndKey)), + }) } inputJSON, err := json.Marshal(input) if err != nil { return errors.Trace(err) } return c.requestWithRetry(ctx, - "AccelerateSchedule", AccelerateSchedule, + "AccelerateScheduleInBatch", AccelerateScheduleInBatch, http.MethodPost, bytes.NewBuffer(inputJSON), nil) } diff --git a/client/http/types.go b/client/http/types.go index f948286c2b5..56d59bafa58 100644 --- a/client/http/types.go +++ b/client/http/types.go @@ -14,7 +14,16 @@ package http -import "time" +import ( + "encoding/json" + "time" +) + +// KeyRange defines a range of keys. +type KeyRange struct { + StartKey []byte `json:"start_key"` + EndKey []byte `json:"end_key"` +} // NOTICE: the structures below are copied from the PD API definitions. // Please make sure the consistency if any change happens to the PD API. @@ -247,6 +256,56 @@ type Rule struct { CreateTimestamp uint64 `json:"create_timestamp,omitempty"` // only set at runtime, recorded rule create timestamp } +// String returns the string representation of this rule. +func (r *Rule) String() string { + b, _ := json.Marshal(r) + return string(b) +} + +// Clone returns a copy of Rule. +func (r *Rule) Clone() *Rule { + var clone Rule + json.Unmarshal([]byte(r.String()), &clone) + clone.StartKey = append(r.StartKey[:0:0], r.StartKey...) + clone.EndKey = append(r.EndKey[:0:0], r.EndKey...) + return &clone +} + +// RuleOpType indicates the operation type +type RuleOpType string + +const ( + // RuleOpAdd a placement rule, only need to specify the field *Rule + RuleOpAdd RuleOpType = "add" + // RuleOpDel a placement rule, only need to specify the field `GroupID`, `ID`, `MatchID` + RuleOpDel RuleOpType = "del" +) + +// RuleOp is for batching placement rule actions. +// The action type is distinguished by the field `Action`. +type RuleOp struct { + *Rule // information of the placement rule to add/delete the operation type + Action RuleOpType `json:"action"` + DeleteByIDPrefix bool `json:"delete_by_id_prefix"` // if action == delete, delete by the prefix of id +} + +func (r RuleOp) String() string { + b, _ := json.Marshal(r) + return string(b) +} + +// RuleGroup defines properties of a rule group. 
+type RuleGroup struct { + ID string `json:"id,omitempty"` + Index int `json:"index,omitempty"` + Override bool `json:"override,omitempty"` +} + +func (g *RuleGroup) String() string { + b, _ := json.Marshal(g) + return string(b) +} + // GroupBundle represents a rule group and all rules belong to the group. type GroupBundle struct { ID string `json:"group_id"` diff --git a/pkg/schedule/checker/checker_controller.go b/pkg/schedule/checker/checker_controller.go index 68b794f417a..355226cd2d8 100644 --- a/pkg/schedule/checker/checker_controller.go +++ b/pkg/schedule/checker/checker_controller.go @@ -221,6 +221,11 @@ func (c *Controller) ClearSuspectKeyRanges() { c.suspectKeyRanges.Clear() } +// ClearSuspectRegions clears the suspect regions, only for unit test +func (c *Controller) ClearSuspectRegions() { + c.suspectRegions.Clear() +} + // IsPendingRegion returns true if the given region is in the pending list. func (c *Controller) IsPendingRegion(regionID uint64) bool { _, exist := c.ruleChecker.pendingList.Get(regionID) diff --git a/pkg/utils/tsoutil/tsoutil.go b/pkg/utils/tsoutil/tsoutil.go index 796012ae031..43d8b09aa49 100644 --- a/pkg/utils/tsoutil/tsoutil.go +++ b/pkg/utils/tsoutil/tsoutil.go @@ -25,6 +25,11 @@ const ( logicalBits = (1 << physicalShiftBits) - 1 ) +// TimeToTS converts a `time.Time` to an `uint64` TS. +func TimeToTS(t time.Time) uint64 { + return ComposeTS(t.UnixNano()/int64(time.Millisecond), 0) +} + // ParseTS parses the ts to (physical,logical). func ParseTS(ts uint64) (time.Time, uint64) { physical, logical := ParseTSUint64(ts) diff --git a/server/cluster/cluster.go b/server/cluster/cluster.go index 1c3d8a03a98..ec8ca3a0d65 100644 --- a/server/cluster/cluster.go +++ b/server/cluster/cluster.go @@ -2241,7 +2241,9 @@ func (c *RaftCluster) SetMinResolvedTS(storeID, minResolvedTS uint64) error { return nil } -func (c *RaftCluster) checkAndUpdateMinResolvedTS() (uint64, bool) { +// CheckAndUpdateMinResolvedTS checks and updates the min resolved ts of the cluster. +// This is exported for testing purpose. 
+func (c *RaftCluster) CheckAndUpdateMinResolvedTS() (uint64, bool) { c.Lock() defer c.Unlock() @@ -2284,7 +2286,7 @@ func (c *RaftCluster) runMinResolvedTSJob() { case <-ticker.C: interval = c.opt.GetMinResolvedTSPersistenceInterval() if interval != 0 { - if current, needPersist := c.checkAndUpdateMinResolvedTS(); needPersist { + if current, needPersist := c.CheckAndUpdateMinResolvedTS(); needPersist { c.storage.SaveMinResolvedTS(current) } } else { diff --git a/server/cluster/scheduling_controller.go b/server/cluster/scheduling_controller.go index 5e8cb8462df..a36e7159cfd 100644 --- a/server/cluster/scheduling_controller.go +++ b/server/cluster/scheduling_controller.go @@ -438,6 +438,13 @@ func (sc *schedulingController) ClearSuspectKeyRanges() { sc.coordinator.GetCheckerController().ClearSuspectKeyRanges() } +// ClearSuspectRegions clears the suspect regions, only for unit test +func (sc *schedulingController) ClearSuspectRegions() { + sc.mu.RLock() + defer sc.mu.RUnlock() + sc.coordinator.GetCheckerController().ClearSuspectRegions() +} + // AddSuspectKeyRange adds the key range with the its ruleID as the key // The instance of each keyRange is like following format: // [2][]byte: start key/end key diff --git a/tests/integrations/client/http_client_test.go b/tests/integrations/client/http_client_test.go index 213aa57de46..dc901b1d290 100644 --- a/tests/integrations/client/http_client_test.go +++ b/tests/integrations/client/http_client_test.go @@ -17,13 +17,19 @@ package client_test import ( "context" "math" + "net/http" "sort" "testing" + "time" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" pd "github.com/tikv/pd/client/http" + "github.com/tikv/pd/pkg/core" "github.com/tikv/pd/pkg/schedule/labeler" "github.com/tikv/pd/pkg/schedule/placement" + "github.com/tikv/pd/pkg/utils/testutil" + "github.com/tikv/pd/pkg/utils/tsoutil" "github.com/tikv/pd/tests" ) @@ -69,21 +75,30 @@ func (suite *httpClientTestSuite) TearDownSuite() { func (suite *httpClientTestSuite) TestGetMinResolvedTSByStoresIDs() { re := suite.Require() - // Get the cluster-level min resolved TS. + testMinResolvedTS := tsoutil.TimeToTS(time.Now()) + raftCluster := suite.cluster.GetLeaderServer().GetRaftCluster() + err := raftCluster.SetMinResolvedTS(1, testMinResolvedTS) + re.NoError(err) + // Make sure the min resolved TS is updated. + testutil.Eventually(re, func() bool { + minResolvedTS, _ := raftCluster.CheckAndUpdateMinResolvedTS() + return minResolvedTS == testMinResolvedTS + }) + // Wait for the cluster-level min resolved TS to be initialized. minResolvedTS, storeMinResolvedTSMap, err := suite.client.GetMinResolvedTSByStoresIDs(suite.ctx, nil) re.NoError(err) - re.Greater(minResolvedTS, uint64(0)) + re.Equal(testMinResolvedTS, minResolvedTS) re.Empty(storeMinResolvedTSMap) // Get the store-level min resolved TS. minResolvedTS, storeMinResolvedTSMap, err = suite.client.GetMinResolvedTSByStoresIDs(suite.ctx, []uint64{1}) re.NoError(err) - re.Greater(minResolvedTS, uint64(0)) + re.Equal(testMinResolvedTS, minResolvedTS) re.Len(storeMinResolvedTSMap, 1) re.Equal(minResolvedTS, storeMinResolvedTSMap[1]) // Get the store-level min resolved TS with an invalid store ID. 
minResolvedTS, storeMinResolvedTSMap, err = suite.client.GetMinResolvedTSByStoresIDs(suite.ctx, []uint64{1, 2}) re.NoError(err) - re.Greater(minResolvedTS, uint64(0)) + re.Equal(testMinResolvedTS, minResolvedTS) re.Len(storeMinResolvedTSMap, 2) re.Equal(minResolvedTS, storeMinResolvedTSMap[1]) re.Equal(uint64(math.MaxUint64), storeMinResolvedTSMap[2]) @@ -98,15 +113,15 @@ func (suite *httpClientTestSuite) TestRule() { bundle, err := suite.client.GetPlacementRuleBundleByGroup(suite.ctx, placement.DefaultGroupID) re.NoError(err) re.Equal(bundles[0], bundle) - rules, err := suite.client.GetPlacementRulesByGroup(suite.ctx, placement.DefaultGroupID) - re.NoError(err) - re.Len(rules, 1) - re.Equal(placement.DefaultGroupID, rules[0].GroupID) - re.Equal(placement.DefaultRuleID, rules[0].ID) - re.Equal(pd.Voter, rules[0].Role) - re.Equal(3, rules[0].Count) + // Check if we have the default rule. + suite.checkRule(re, &pd.Rule{ + GroupID: placement.DefaultGroupID, + ID: placement.DefaultRuleID, + Role: pd.Voter, + Count: 3, + }, 1, true) // Should be the same as the rules in the bundle. - re.Equal(bundle.Rules, rules) + suite.checkRule(re, bundle.Rules[0], 1, true) testRule := &pd.Rule{ GroupID: placement.DefaultGroupID, ID: "test", @@ -115,20 +130,24 @@ func (suite *httpClientTestSuite) TestRule() { } err = suite.client.SetPlacementRule(suite.ctx, testRule) re.NoError(err) - rules, err = suite.client.GetPlacementRulesByGroup(suite.ctx, placement.DefaultGroupID) - re.NoError(err) - re.Len(rules, 2) - re.Equal(placement.DefaultGroupID, rules[1].GroupID) - re.Equal("test", rules[1].ID) - re.Equal(pd.Voter, rules[1].Role) - re.Equal(3, rules[1].Count) + suite.checkRule(re, testRule, 2, true) err = suite.client.DeletePlacementRule(suite.ctx, placement.DefaultGroupID, "test") re.NoError(err) - rules, err = suite.client.GetPlacementRulesByGroup(suite.ctx, placement.DefaultGroupID) + suite.checkRule(re, testRule, 1, false) + testRuleOp := &pd.RuleOp{ + Rule: testRule, + Action: pd.RuleOpAdd, + } + err = suite.client.SetPlacementRuleInBatch(suite.ctx, []*pd.RuleOp{testRuleOp}) + re.NoError(err) + suite.checkRule(re, testRule, 2, true) + testRuleOp = &pd.RuleOp{ + Rule: testRule, + Action: pd.RuleOpDel, + } + err = suite.client.SetPlacementRuleInBatch(suite.ctx, []*pd.RuleOp{testRuleOp}) re.NoError(err) - re.Len(rules, 1) - re.Equal(placement.DefaultGroupID, rules[0].GroupID) - re.Equal(placement.DefaultRuleID, rules[0].ID) + suite.checkRule(re, testRule, 1, false) err = suite.client.SetPlacementRuleBundles(suite.ctx, []*pd.GroupBundle{ { ID: placement.DefaultGroupID, @@ -136,14 +155,63 @@ func (suite *httpClientTestSuite) TestRule() { }, }, true) re.NoError(err) - bundles, err = suite.client.GetAllPlacementRuleBundles(suite.ctx) + suite.checkRule(re, testRule, 1, true) + ruleGroups, err := suite.client.GetAllPlacementRuleGroups(suite.ctx) re.NoError(err) - re.Len(bundles, 1) - re.Equal(placement.DefaultGroupID, bundles[0].ID) - re.Len(bundles[0].Rules, 1) - // Make sure the create timestamp is not zero to pass the later assertion. 
- testRule.CreateTimestamp = bundles[0].Rules[0].CreateTimestamp - re.Equal(testRule, bundles[0].Rules[0]) + re.Len(ruleGroups, 1) + re.Equal(placement.DefaultGroupID, ruleGroups[0].ID) + ruleGroup, err := suite.client.GetPlacementRuleGroupByID(suite.ctx, placement.DefaultGroupID) + re.NoError(err) + re.Equal(ruleGroups[0], ruleGroup) + testRuleGroup := &pd.RuleGroup{ + ID: "test-group", + Index: 1, + Override: true, + } + err = suite.client.SetPlacementRuleGroup(suite.ctx, testRuleGroup) + re.NoError(err) + ruleGroup, err = suite.client.GetPlacementRuleGroupByID(suite.ctx, testRuleGroup.ID) + re.NoError(err) + re.Equal(testRuleGroup, ruleGroup) + err = suite.client.DeletePlacementRuleGroupByID(suite.ctx, testRuleGroup.ID) + re.NoError(err) + ruleGroup, err = suite.client.GetPlacementRuleGroupByID(suite.ctx, testRuleGroup.ID) + re.ErrorContains(err, http.StatusText(http.StatusNotFound)) + re.Empty(ruleGroup) +} + +func (suite *httpClientTestSuite) checkRule( + re *require.Assertions, + rule *pd.Rule, totalRuleCount int, exist bool, +) { + // Check through the `GetPlacementRulesByGroup` API. + rules, err := suite.client.GetPlacementRulesByGroup(suite.ctx, rule.GroupID) + re.NoError(err) + checkRuleFunc(re, rules, rule, totalRuleCount, exist) + // Check through the `GetPlacementRuleBundleByGroup` API. + bundle, err := suite.client.GetPlacementRuleBundleByGroup(suite.ctx, rule.GroupID) + re.NoError(err) + checkRuleFunc(re, bundle.Rules, rule, totalRuleCount, exist) +} + +func checkRuleFunc( + re *require.Assertions, + rules []*pd.Rule, rule *pd.Rule, totalRuleCount int, exist bool, +) { + re.Len(rules, totalRuleCount) + for _, r := range rules { + if r.ID != rule.ID { + continue + } + re.Equal(rule.GroupID, r.GroupID) + re.Equal(rule.ID, r.ID) + re.Equal(rule.Role, r.Role) + re.Equal(rule.Count, r.Count) + return + } + if exist { + re.Failf("Failed to check the rule", "rule %+v not found", rule) + } } func (suite *httpClientTestSuite) TestRegionLabel() { @@ -202,10 +270,36 @@ func (suite *httpClientTestSuite) TestRegionLabel() { func (suite *httpClientTestSuite) TestAccelerateSchedule() { re := suite.Require() - suspectRegions := suite.cluster.GetLeaderServer().GetRaftCluster().GetSuspectRegions() + raftCluster := suite.cluster.GetLeaderServer().GetRaftCluster() + for _, region := range []*core.RegionInfo{ + core.NewTestRegionInfo(10, 1, []byte("a1"), []byte("a2")), + core.NewTestRegionInfo(11, 1, []byte("a2"), []byte("a3")), + } { + err := raftCluster.HandleRegionHeartbeat(region) + re.NoError(err) + } + suspectRegions := raftCluster.GetSuspectRegions() re.Len(suspectRegions, 0) - err := suite.client.AccelerateSchedule(suite.ctx, []byte("a1"), []byte("a2")) + err := suite.client.AccelerateSchedule(suite.ctx, &pd.KeyRange{ + StartKey: []byte("a1"), + EndKey: []byte("a2")}) re.NoError(err) - suspectRegions = suite.cluster.GetLeaderServer().GetRaftCluster().GetSuspectRegions() + suspectRegions = raftCluster.GetSuspectRegions() re.Len(suspectRegions, 1) + raftCluster.ClearSuspectRegions() + suspectRegions = raftCluster.GetSuspectRegions() + re.Len(suspectRegions, 0) + err = suite.client.AccelerateScheduleInBatch(suite.ctx, []*pd.KeyRange{ + { + StartKey: []byte("a1"), + EndKey: []byte("a2"), + }, + { + StartKey: []byte("a2"), + EndKey: []byte("a3"), + }, + }) + re.NoError(err) + suspectRegions = raftCluster.GetSuspectRegions() + re.Len(suspectRegions, 2) }
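
Editor's note, not part of the patch: the sketch below shows how the interfaces added above might be exercised from calling code. It assumes an already constructed client/http Client (the constructor is not shown in this diff) and uses "pd" as the rule group ID, matching placement.DefaultGroupID in the tests; the package and function names here are illustrative only.

package example

import (
	"context"

	pd "github.com/tikv/pd/client/http"
)

// demoBatchAndRuleGroups sketches the new batch and rule-group calls; error
// handling is kept minimal for brevity.
func demoBatchAndRuleGroups(ctx context.Context, client pd.Client) error {
	// Add a rule through the new batch endpoint (/pd/api/v1/config/rules/batch).
	rule := &pd.Rule{GroupID: "pd", ID: "demo-rule", Role: pd.Voter, Count: 3}
	if err := client.SetPlacementRuleInBatch(ctx, []*pd.RuleOp{
		{Rule: rule, Action: pd.RuleOpAdd},
	}); err != nil {
		return err
	}

	// Read the rule group back and bump its index via the rule_group endpoints.
	group, err := client.GetPlacementRuleGroupByID(ctx, "pd")
	if err != nil {
		return err
	}
	group.Index = 1
	if err := client.SetPlacementRuleGroup(ctx, group); err != nil {
		return err
	}

	// Ask PD to prioritize scheduling for two key ranges in one request
	// (/pd/api/v1/regions/accelerate-schedule/batch).
	return client.AccelerateScheduleInBatch(ctx, []*pd.KeyRange{
		{StartKey: []byte("a1"), EndKey: []byte("a2")},
		{StartKey: []byte("a2"), EndKey: []byte("a3")},
	})
}

As in the tests above, GetRegionsByKeyRange and GetRegionStatusByKeyRange now take the same *KeyRange value instead of separate start/end byte slices.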