diff --git a/.github/workflows/pd-tests.yaml b/.github/workflows/pd-tests.yaml index 9084c7545a8..223187737e0 100644 --- a/.github/workflows/pd-tests.yaml +++ b/.github/workflows/pd-tests.yaml @@ -25,24 +25,33 @@ jobs: strategy: fail-fast: true matrix: - worker_id: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] + include: + - worker_id: 1 + name: 'Unit Test(1)' + - worker_id: 2 + name: 'Unit Test(2)' + - worker_id: 3 + name: 'Tools Test' + - worker_id: 4 + name: 'Client Integration Test' + - worker_id: 5 + name: 'TSO Integration Test' + - worker_id: 6 + name: 'MicroService Integration Test' outputs: - job-total: 13 + job-total: 6 steps: - name: Checkout code uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: go-version: '1.21' - - name: Make Test + - name: ${{ matrix.name }} env: WORKER_ID: ${{ matrix.worker_id }} - WORKER_COUNT: 13 - JOB_COUNT: 9 # 10 is tools test, 11, 12, 13 are for other integrations jobs run: | - make ci-test-job JOB_COUNT=$(($JOB_COUNT)) JOB_INDEX=$WORKER_ID + make ci-test-job JOB_INDEX=$WORKER_ID mv covprofile covprofile_$WORKER_ID - sed -i "/failpoint_binding/d" covprofile_$WORKER_ID - name: Upload coverage result ${{ matrix.worker_id }} uses: actions/upload-artifact@v4 with: @@ -62,7 +71,11 @@ jobs: - name: Merge env: TOTAL_JOBS: ${{needs.chunks.outputs.job-total}} - run: for i in $(seq 1 $TOTAL_JOBS); do cat covprofile_$i >> covprofile; done + run: | + for i in $(seq 1 $TOTAL_JOBS); do cat covprofile_$i >> covprofile; done + sed -i "/failpoint_binding/d" covprofile + # only keep the first line(`mode: aomic`) of the coverage profile + sed -i '2,${/mode: atomic/d;}' covprofile - name: Send coverage uses: codecov/codecov-action@v4.2.0 with: diff --git a/Makefile b/Makefile index 205896c377a..dca00012114 100644 --- a/Makefile +++ b/Makefile @@ -127,7 +127,7 @@ regions-dump: stores-dump: cd tools && CGO_ENABLED=0 go build -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -o $(BUILD_BIN_PATH)/stores-dump stores-dump/main.go pd-ut: pd-xprog - cd tools && GOEXPERIMENT=$(BUILD_GOEXPERIMENT) CGO_ENABLED=$(BUILD_TOOL_CGO_ENABLED) go build -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -o $(BUILD_BIN_PATH)/pd-ut pd-ut/ut.go + cd tools && GOEXPERIMENT=$(BUILD_GOEXPERIMENT) CGO_ENABLED=$(BUILD_TOOL_CGO_ENABLED) go build -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -o $(BUILD_BIN_PATH)/pd-ut pd-ut/ut.go pd-ut/coverProfile.go pd-xprog: cd tools && GOEXPERIMENT=$(BUILD_GOEXPERIMENT) CGO_ENABLED=$(BUILD_TOOL_CGO_ENABLED) go build -tags xprog -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -o $(BUILD_BIN_PATH)/xprog pd-ut/xprog.go @@ -227,7 +227,8 @@ failpoint-disable: install-tools ut: pd-ut @$(FAILPOINT_ENABLE) - ./bin/pd-ut run --race + # only run unit tests + ./bin/pd-ut run --ignore tests --race @$(CLEAN_UT_BINARY) @$(FAILPOINT_DISABLE) @@ -251,7 +252,7 @@ basic-test: install-tools go test $(BASIC_TEST_PKGS) || { $(FAILPOINT_DISABLE); exit 1; } @$(FAILPOINT_DISABLE) -ci-test-job: install-tools dashboard-ui +ci-test-job: install-tools dashboard-ui pd-ut @$(FAILPOINT_ENABLE) ./scripts/ci-subtask.sh $(JOB_COUNT) $(JOB_INDEX) || { $(FAILPOINT_DISABLE); exit 1; } @$(FAILPOINT_DISABLE) diff --git a/codecov.yml b/codecov.yml index bb439917e78..936eb3bbb11 100644 --- a/codecov.yml +++ b/codecov.yml @@ -24,9 +24,3 @@ flag_management: target: 74% # increase it if you want to enforce higher coverage for project, current setting as 74% is for do not let the error be reported and lose the meaning of warning. - type: patch target: 74% # increase it if you want to enforce higher coverage for project, current setting as 74% is for do not let the error be reported and lose the meaning of warning. - -ignore: - # Ignore the tool tests - - tests/dashboard - - tests/pdbackup - - tests/pdctl diff --git a/pkg/core/region.go b/pkg/core/region.go index be8fcddc179..c9a8455d4de 100644 --- a/pkg/core/region.go +++ b/pkg/core/region.go @@ -914,6 +914,8 @@ type RegionsInfo struct { learners map[uint64]*regionTree // storeID -> sub regionTree witnesses map[uint64]*regionTree // storeID -> sub regionTree pendingPeers map[uint64]*regionTree // storeID -> sub regionTree + // This tree is used to check the overlaps among all the subtrees. + overlapTree *regionTree } // NewRegionsInfo creates RegionsInfo with tree, regions, leaders and followers @@ -927,6 +929,7 @@ func NewRegionsInfo() *RegionsInfo { learners: make(map[uint64]*regionTree), witnesses: make(map[uint64]*regionTree), pendingPeers: make(map[uint64]*regionTree), + overlapTree: newRegionTreeWithCountRef(), } } @@ -1041,10 +1044,10 @@ func (r *RegionsInfo) CheckAndPutRootTree(ctx *MetaProcessContext, region *Regio // Usually used with CheckAndPutRootTree together. func (r *RegionsInfo) CheckAndPutSubTree(region *RegionInfo) { // new region get from root tree again - var newRegion *RegionInfo - newRegion = r.GetRegion(region.GetID()) + newRegion := r.GetRegion(region.GetID()) if newRegion == nil { - newRegion = region + // Make sure there is this region in the root tree, so as to ensure the correctness of reference count + return } r.UpdateSubTreeOrderInsensitive(newRegion) } @@ -1066,110 +1069,98 @@ func (r *RegionsInfo) UpdateSubTreeOrderInsensitive(region *RegionInfo) { origin = originItem.RegionInfo } rangeChanged := true - if origin != nil { + rangeChanged = !origin.rangeEqualsTo(region) + if r.preUpdateSubTreeLocked(rangeChanged, !origin.peersEqualTo(region), true, origin, region) { + return + } + } + r.updateSubTreeLocked(rangeChanged, nil, region) +} + +func (r *RegionsInfo) preUpdateSubTreeLocked( + rangeChanged, peerChanged, orderInsensitive bool, + origin, region *RegionInfo, +) (done bool) { + if orderInsensitive { re := region.GetRegionEpoch() oe := origin.GetRegionEpoch() isTermBehind := region.GetTerm() > 0 && region.GetTerm() < origin.GetTerm() if (isTermBehind || re.GetVersion() < oe.GetVersion() || re.GetConfVer() < oe.GetConfVer()) && !region.isRegionRecreated() { // Region meta is stale, skip. - return + return true } - rangeChanged = !origin.rangeEqualsTo(region) - - if rangeChanged || !origin.peersEqualTo(region) { - // If the range or peers have changed, the sub regionTree needs to be cleaned up. - // TODO: Improve performance by deleting only the different peers. - r.removeRegionFromSubTreeLocked(origin) - } else { - // The region tree and the subtree update is not atomic and the region tree is updated first. - // If there are two thread needs to update region tree, - // t1: thread-A update region tree - // t2: thread-B: update region tree again - // t3: thread-B: update subtree - // t4: thread-A: update region subtree - // to keep region tree consistent with subtree, we need to drop this update. - if tree, ok := r.subRegions[region.GetID()]; ok { - r.updateSubTreeStat(origin, region) - tree.RegionInfo = region - } - return + } + if rangeChanged || peerChanged { + // If the range or peers have changed, clean up the subtrees before updating them. + // TODO: improve performance by deleting only the different peers. + r.removeRegionFromSubTreeLocked(origin) + } else { + // The region tree and the subtree update is not atomic and the region tree is updated first. + // If there are two thread needs to update region tree, + // t1: thread-A update region tree + // t2: thread-B: update region tree again + // t3: thread-B: update subtree + // t4: thread-A: update region subtree + // to keep region tree consistent with subtree, we need to drop this update. + if tree, ok := r.subRegions[region.GetID()]; ok { + r.updateSubTreeStat(origin, region) + tree.RegionInfo = region } + return true } + return false +} +func (r *RegionsInfo) updateSubTreeLocked(rangeChanged bool, overlaps []*RegionInfo, region *RegionInfo) { if rangeChanged { - overlaps := r.getOverlapRegionFromSubTreeLocked(region) - for _, re := range overlaps { - r.removeRegionFromSubTreeLocked(re) + // TODO: only perform the remove operation on the overlapped peer. + if len(overlaps) == 0 { + // If the range has changed but the overlapped regions are not provided, collect them by `[]*regionItem`. + for _, item := range r.getOverlapRegionFromOverlapTreeLocked(region) { + r.removeRegionFromSubTreeLocked(item.RegionInfo) + } + } else { + // Remove all provided overlapped regions from the subtrees. + for _, overlap := range overlaps { + r.removeRegionFromSubTreeLocked(overlap) + } } } - + // Reinsert the region into all subtrees. item := ®ionItem{region} r.subRegions[region.GetID()] = item - // It has been removed and all information needs to be updated again. - // Set peers then. - setPeer := func(peersMap map[uint64]*regionTree, storeID uint64, item *regionItem, countRef bool) { + r.overlapTree.update(item, false) + // Add leaders and followers. + setPeer := func(peersMap map[uint64]*regionTree, storeID uint64) { store, ok := peersMap[storeID] if !ok { - if !countRef { - store = newRegionTree() - } else { - store = newRegionTreeWithCountRef() - } + store = newRegionTree() peersMap[storeID] = store } store.update(item, false) } - - // Add to leaders and followers. for _, peer := range region.GetVoters() { storeID := peer.GetStoreId() if peer.GetId() == region.leader.GetId() { - // Add leader peer to leaders. - setPeer(r.leaders, storeID, item, true) + setPeer(r.leaders, storeID) } else { - // Add follower peer to followers. - setPeer(r.followers, storeID, item, false) + setPeer(r.followers, storeID) } } - + // Add other peers. setPeers := func(peersMap map[uint64]*regionTree, peers []*metapb.Peer) { for _, peer := range peers { - storeID := peer.GetStoreId() - setPeer(peersMap, storeID, item, false) + setPeer(peersMap, peer.GetStoreId()) } } - // Add to learners. setPeers(r.learners, region.GetLearners()) - // Add to witnesses. setPeers(r.witnesses, region.GetWitnesses()) - // Add to PendingPeers setPeers(r.pendingPeers, region.GetPendingPeers()) } -func (r *RegionsInfo) getOverlapRegionFromSubTreeLocked(region *RegionInfo) []*RegionInfo { - it := ®ionItem{RegionInfo: region} - overlaps := make([]*RegionInfo, 0) - overlapsMap := make(map[uint64]struct{}) - collectFromItemSlice := func(peersMap map[uint64]*regionTree, storeID uint64) { - if tree, ok := peersMap[storeID]; ok { - items := tree.overlaps(it) - for _, item := range items { - if _, ok := overlapsMap[item.GetID()]; !ok { - overlapsMap[item.GetID()] = struct{}{} - overlaps = append(overlaps, item.RegionInfo) - } - } - } - } - for _, peer := range region.GetMeta().GetPeers() { - storeID := peer.GetStoreId() - collectFromItemSlice(r.leaders, storeID) - collectFromItemSlice(r.followers, storeID) - collectFromItemSlice(r.learners, storeID) - collectFromItemSlice(r.witnesses, storeID) - } - return overlaps +func (r *RegionsInfo) getOverlapRegionFromOverlapTreeLocked(region *RegionInfo) []*regionItem { + return r.overlapTree.overlaps(®ionItem{RegionInfo: region}) } // GetRelevantRegions returns the relevant regions for a given region. @@ -1275,72 +1266,11 @@ func (r *RegionsInfo) UpdateSubTree(region, origin *RegionInfo, overlaps []*Regi r.st.Lock() defer r.st.Unlock() if origin != nil { - if rangeChanged || !origin.peersEqualTo(region) { - // If the range or peers have changed, the sub regionTree needs to be cleaned up. - // TODO: Improve performance by deleting only the different peers. - r.removeRegionFromSubTreeLocked(origin) - } else { - // The region tree and the subtree update is not atomic and the region tree is updated first. - // If there are two thread needs to update region tree, - // t1: thread-A update region tree - // t2: thread-B: update region tree again - // t3: thread-B: update subtree - // t4: thread-A: update region subtree - // to keep region tree consistent with subtree, we need to drop this update. - if tree, ok := r.subRegions[region.GetID()]; ok { - r.updateSubTreeStat(origin, region) - tree.RegionInfo = region - } + if r.preUpdateSubTreeLocked(rangeChanged, !origin.peersEqualTo(region), false, origin, region) { return } } - if rangeChanged { - for _, re := range overlaps { - r.removeRegionFromSubTreeLocked(re) - } - } - - item := ®ionItem{region} - r.subRegions[region.GetID()] = item - // It has been removed and all information needs to be updated again. - // Set peers then. - setPeer := func(peersMap map[uint64]*regionTree, storeID uint64, item *regionItem, countRef bool) { - store, ok := peersMap[storeID] - if !ok { - if !countRef { - store = newRegionTree() - } else { - store = newRegionTreeWithCountRef() - } - peersMap[storeID] = store - } - store.update(item, false) - } - - // Add to leaders and followers. - for _, peer := range region.GetVoters() { - storeID := peer.GetStoreId() - if peer.GetId() == region.leader.GetId() { - // Add leader peer to leaders. - setPeer(r.leaders, storeID, item, true) - } else { - // Add follower peer to followers. - setPeer(r.followers, storeID, item, false) - } - } - - setPeers := func(peersMap map[uint64]*regionTree, peers []*metapb.Peer) { - for _, peer := range peers { - storeID := peer.GetStoreId() - setPeer(peersMap, storeID, item, false) - } - } - // Add to learners. - setPeers(r.learners, region.GetLearners()) - // Add to witnesses. - setPeers(r.witnesses, region.GetWitnesses()) - // Add to PendingPeers - setPeers(r.pendingPeers, region.GetPendingPeers()) + r.updateSubTreeLocked(rangeChanged, overlaps, region) } func (r *RegionsInfo) updateSubTreeStat(origin *RegionInfo, region *RegionInfo) { @@ -1394,7 +1324,7 @@ func (r *RegionsInfo) RemoveRegion(region *RegionInfo) { // ResetRegionCache resets the regions info. func (r *RegionsInfo) ResetRegionCache() { r.t.Lock() - r.tree = newRegionTree() + r.tree = newRegionTreeWithCountRef() r.regions = make(map[uint64]*regionItem) r.t.Unlock() r.st.Lock() @@ -1404,6 +1334,7 @@ func (r *RegionsInfo) ResetRegionCache() { r.learners = make(map[uint64]*regionTree) r.witnesses = make(map[uint64]*regionTree) r.pendingPeers = make(map[uint64]*regionTree) + r.overlapTree = newRegionTreeWithCountRef() } // RemoveRegionFromSubTree removes RegionInfo from regionSubTrees @@ -1416,7 +1347,6 @@ func (r *RegionsInfo) RemoveRegionFromSubTree(region *RegionInfo) { // removeRegionFromSubTreeLocked removes RegionInfo from regionSubTrees func (r *RegionsInfo) removeRegionFromSubTreeLocked(region *RegionInfo) { - // Remove from leaders and followers. for _, peer := range region.GetMeta().GetPeers() { storeID := peer.GetStoreId() r.leaders[storeID].remove(region) @@ -1425,6 +1355,7 @@ func (r *RegionsInfo) removeRegionFromSubTreeLocked(region *RegionInfo) { r.witnesses[storeID].remove(region) r.pendingPeers[storeID].remove(region) } + r.overlapTree.remove(region) delete(r.subRegions, region.GetMeta().GetId()) } diff --git a/pkg/core/region_test.go b/pkg/core/region_test.go index 43629fccda0..1b8f20cf9b2 100644 --- a/pkg/core/region_test.go +++ b/pkg/core/region_test.go @@ -778,27 +778,24 @@ func BenchmarkRandomSetRegionWithGetRegionSizeByRangeParallel(b *testing.B) { ) } -const keyLength = 100 - -func randomBytes(n int) []byte { - bytes := make([]byte, n) - _, err := rand.Read(bytes) - if err != nil { - panic(err) - } - return bytes -} +const ( + peerNum = 3 + storeNum = 10 + keyLength = 100 +) func newRegionInfoIDRandom(idAllocator id.Allocator) *RegionInfo { var ( peers []*metapb.Peer leader *metapb.Peer ) - storeNum := 10 - for i := 0; i < 3; i++ { + // Randomly select a peer as the leader. + leaderIdx := mrand.Intn(peerNum) + for i := 0; i < peerNum; i++ { id, _ := idAllocator.Alloc() - p := &metapb.Peer{Id: id, StoreId: uint64(i%storeNum + 1)} - if i == 0 { + // Randomly distribute the peers to different stores. + p := &metapb.Peer{Id: id, StoreId: uint64(mrand.Intn(storeNum) + 1)} + if i == leaderIdx { leader = p } peers = append(peers, p) @@ -817,13 +814,19 @@ func newRegionInfoIDRandom(idAllocator id.Allocator) *RegionInfo { ) } +func randomBytes(n int) []byte { + bytes := make([]byte, n) + _, err := rand.Read(bytes) + if err != nil { + panic(err) + } + return bytes +} + func BenchmarkAddRegion(b *testing.B) { regions := NewRegionsInfo() idAllocator := mockid.NewIDAllocator() - var items []*RegionInfo - for i := 0; i < 10000000; i++ { - items = append(items, newRegionInfoIDRandom(idAllocator)) - } + items := generateRegionItems(idAllocator, 10000000) b.ResetTimer() for i := 0; i < b.N; i++ { origin, overlaps, rangeChanged := regions.SetRegion(items[i]) @@ -831,6 +834,54 @@ func BenchmarkAddRegion(b *testing.B) { } } +func BenchmarkUpdateSubTreeOrderInsensitive(b *testing.B) { + idAllocator := mockid.NewIDAllocator() + for _, size := range []int{10, 100, 1000, 10000, 100000, 1000000, 10000000} { + regions := NewRegionsInfo() + items := generateRegionItems(idAllocator, size) + // Update the subtrees from an empty `*RegionsInfo`. + b.Run(fmt.Sprintf("from empty with size %d", size), func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + for idx := range items { + regions.UpdateSubTreeOrderInsensitive(items[idx]) + } + } + }) + + // Update the subtrees from a non-empty `*RegionsInfo` with the same regions, + // which means the regions are completely non-overlapped. + b.Run(fmt.Sprintf("from non-overlapped regions with size %d", size), func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + for idx := range items { + regions.UpdateSubTreeOrderInsensitive(items[idx]) + } + } + }) + + // Update the subtrees from a non-empty `*RegionsInfo` with different regions, + // which means the regions are most likely overlapped. + b.Run(fmt.Sprintf("from overlapped regions with size %d", size), func(b *testing.B) { + items = generateRegionItems(idAllocator, size) + b.ResetTimer() + for i := 0; i < b.N; i++ { + for idx := range items { + regions.UpdateSubTreeOrderInsensitive(items[idx]) + } + } + }) + } +} + +func generateRegionItems(idAllocator *mockid.IDAllocator, size int) []*RegionInfo { + items := make([]*RegionInfo, size) + for i := 0; i < size; i++ { + items[i] = newRegionInfoIDRandom(idAllocator) + } + return items +} + func BenchmarkRegionFromHeartbeat(b *testing.B) { peers := make([]*metapb.Peer, 0, 3) for i := uint64(1); i <= 3; i++ { @@ -1021,3 +1072,27 @@ func TestUpdateRegionEventualConsistency(t *testing.T) { re.Equal(int32(2), item.GetRef()) } } + +func TestCheckAndPutSubTree(t *testing.T) { + re := require.New(t) + regions := NewRegionsInfo() + region := NewTestRegionInfo(1, 1, []byte("a"), []byte("b")) + regions.CheckAndPutSubTree(region) + // should failed to put because the root tree is missing + re.Equal(0, regions.tree.length()) +} + +func TestCntRefAfterResetRegionCache(t *testing.T) { + re := require.New(t) + regions := NewRegionsInfo() + // Put the region first. + region := NewTestRegionInfo(1, 1, []byte("a"), []byte("b")) + regions.CheckAndPutRegion(region) + re.Equal(int32(2), region.GetRef()) + regions.ResetRegionCache() + // Put the region after reset. + region = NewTestRegionInfo(1, 1, []byte("a"), []byte("b")) + re.Zero(region.GetRef()) + regions.CheckAndPutRegion(region) + re.Equal(int32(2), region.GetRef()) +} diff --git a/scripts/ci-subtask.sh b/scripts/ci-subtask.sh index c00cba9c0a4..9bdce420d4a 100755 --- a/scripts/ci-subtask.sh +++ b/scripts/ci-subtask.sh @@ -3,63 +3,34 @@ # ./ci-subtask.sh ROOT_PATH_COV=$(pwd)/covprofile - -if [[ $2 -gt 9 ]]; then - # run tools tests - if [[ $2 -eq 10 ]]; then +# Currently, we only have 3 integration tests, so we can hardcode the task index. +integrations_dir=$(pwd)/tests/integrations + +case $1 in + 1) + # unit tests ignore `tests` + ./bin/pd-ut run --race --ignore tests --coverprofile $ROOT_PATH_COV || exit 1 + ;; + 2) + # unit tests only in `tests` + ./bin/pd-ut run tests --race --coverprofile $ROOT_PATH_COV || exit 1 + ;; + 3) + # tools tests cd ./tools && make ci-test-job && cat covprofile >> $ROOT_PATH_COV || exit 1 - exit - fi - - # Currently, we only have 3 integration tests, so we can hardcode the task index. - integrations_dir=$(pwd)/tests/integrations - integrations_tasks=($(find "$integrations_dir" -mindepth 1 -maxdepth 1 -type d)) - for t in "${integrations_tasks[@]}"; do - if [[ "$t" = "$integrations_dir/client" && $2 -eq 11 ]]; then - cd ./client && make ci-test-job && cat covprofile >> $ROOT_PATH_COV || exit 1 - cd $integrations_dir && make ci-test-job test_name=client && cat ./client/covprofile >> $ROOT_PATH_COV || exit 1 - elif [[ "$t" = "$integrations_dir/tso" && $2 -eq 12 ]]; then - cd $integrations_dir && make ci-test-job test_name=tso && cat ./tso/covprofile >> $ROOT_PATH_COV || exit 1 - elif [[ "$t" = "$integrations_dir/mcs" && $2 -eq 13 ]]; then - cd $integrations_dir && make ci-test-job test_name=mcs && cat ./mcs/covprofile >> $ROOT_PATH_COV || exit 1 - fi - done -else - # Get package test list. - packages=($(go list ./...)) - dirs=($(find . -iname "*_test.go" -exec dirname {} \; | sort -u | sed -e "s/^\./github.com\/tikv\/pd/")) - tasks=($(comm -12 <(printf "%s\n" "${packages[@]}") <(printf "%s\n" "${dirs[@]}"))) - - weight() { - [[ $1 == "github.com/tikv/pd/server/api" ]] && return 30 - [[ $1 == "github.com/tikv/pd/pkg/schedule" ]] && return 30 - [[ $1 == "github.com/tikv/pd/pkg/core" ]] && return 30 - [[ $1 == "github.com/tikv/pd/tests/server/api" ]] && return 30 - [[ $1 =~ "pd/tests" ]] && return 5 - return 1 - } - - # Create an associative array to store the weight of each task. - declare -A task_weights - for t in ${tasks[@]}; do - weight $t - task_weights[$t]=$? - done - - # Sort tasks by weight in descending order. - tasks=($(printf "%s\n" "${tasks[@]}" | sort -rn)) - - scores=($(seq "$1" | xargs -I{} echo 0)) - - res=() - for t in ${tasks[@]}; do - min_i=0 - for i in ${!scores[@]}; do - [[ ${scores[i]} -lt ${scores[$min_i]} ]] && min_i=$i - done - scores[$min_i]=$((${scores[$min_i]} + ${task_weights[$t]})) - [[ $(($min_i + 1)) -eq $2 ]] && res+=($t) - done - - CGO_ENABLED=1 go test -timeout=15m -tags deadlock -race -cover -covermode=atomic -coverprofile=$ROOT_PATH_COV -coverpkg=./... ${res[@]} -fi + ;; + 4) + # integration test client + ./bin/pd-ut it run client --race --coverprofile $ROOT_PATH_COV || exit 1 + # client tests + cd ./client && make ci-test-job && cat covprofile >> $ROOT_PATH_COV || exit 1 + ;; + 5) + # integration test tso + ./bin/pd-ut it run tso --race --coverprofile $ROOT_PATH_COV || exit 1 + ;; + 6) + # integration test mcs + ./bin/pd-ut it run mcs --race --coverprofile $ROOT_PATH_COV || exit 1 + ;; +esac diff --git a/server/api/diagnostic_test.go b/server/api/diagnostic_test.go index c98717902c5..8c4089a8710 100644 --- a/server/api/diagnostic_test.go +++ b/server/api/diagnostic_test.go @@ -36,7 +36,7 @@ type diagnosticTestSuite struct { cleanup tu.CleanupFunc urlPrefix string configPrefix string - schedulerPrifex string + schedulerPrefix string } func TestDiagnosticTestSuite(t *testing.T) { @@ -50,7 +50,7 @@ func (suite *diagnosticTestSuite) SetupSuite() { addr := suite.svr.GetAddr() suite.urlPrefix = fmt.Sprintf("%s%s/api/v1/schedulers/diagnostic", addr, apiPrefix) - suite.schedulerPrifex = fmt.Sprintf("%s%s/api/v1/schedulers", addr, apiPrefix) + suite.schedulerPrefix = fmt.Sprintf("%s%s/api/v1/schedulers", addr, apiPrefix) suite.configPrefix = fmt.Sprintf("%s%s/api/v1/config", addr, apiPrefix) mustBootstrapCluster(re, suite.svr) @@ -108,7 +108,7 @@ func (suite *diagnosticTestSuite) TestSchedulerDiagnosticAPI() { input["name"] = schedulers.BalanceRegionName body, err := json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, suite.schedulerPrifex, body, tu.StatusOK(re)) + err = tu.CheckPostJSON(testDialClient, suite.schedulerPrefix, body, tu.StatusOK(re)) re.NoError(err) suite.checkStatus("pending", balanceRegionURL) @@ -116,21 +116,23 @@ func (suite *diagnosticTestSuite) TestSchedulerDiagnosticAPI() { input["delay"] = 30 pauseArgs, err := json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, suite.schedulerPrifex+"/"+schedulers.BalanceRegionName, pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(testDialClient, suite.schedulerPrefix+"/"+schedulers.BalanceRegionName, pauseArgs, tu.StatusOK(re)) re.NoError(err) suite.checkStatus("paused", balanceRegionURL) input["delay"] = 0 pauseArgs, err = json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, suite.schedulerPrifex+"/"+schedulers.BalanceRegionName, pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(testDialClient, suite.schedulerPrefix+"/"+schedulers.BalanceRegionName, pauseArgs, tu.StatusOK(re)) re.NoError(err) suite.checkStatus("pending", balanceRegionURL) + fmt.Println("before put region") mustPutRegion(re, suite.svr, 1000, 1, []byte("a"), []byte("b"), core.SetApproximateSize(60)) + fmt.Println("after put region") suite.checkStatus("normal", balanceRegionURL) - deleteURL := fmt.Sprintf("%s/%s", suite.schedulerPrifex, schedulers.BalanceRegionName) + deleteURL := fmt.Sprintf("%s/%s", suite.schedulerPrefix, schedulers.BalanceRegionName) err = tu.CheckDelete(testDialClient, deleteURL, tu.StatusOK(re)) re.NoError(err) suite.checkStatus("disabled", balanceRegionURL) diff --git a/tests/integrations/mcs/keyspace/tso_keyspace_group_test.go b/tests/integrations/mcs/keyspace/tso_keyspace_group_test.go index 78431eb72c6..0c7683b569c 100644 --- a/tests/integrations/mcs/keyspace/tso_keyspace_group_test.go +++ b/tests/integrations/mcs/keyspace/tso_keyspace_group_test.go @@ -47,7 +47,6 @@ type keyspaceGroupTestSuite struct { cluster *tests.TestCluster server *tests.TestServer backendEndpoints string - dialClient *http.Client } func TestKeyspaceGroupTestSuite(t *testing.T) { @@ -67,11 +66,6 @@ func (suite *keyspaceGroupTestSuite) SetupTest() { suite.server = cluster.GetLeaderServer() re.NoError(suite.server.BootstrapCluster()) suite.backendEndpoints = suite.server.GetAddr() - suite.dialClient = &http.Client{ - Transport: &http.Transport{ - DisableKeepAlives: true, - }, - } suite.cleanupFunc = func() { cancel() } @@ -81,7 +75,6 @@ func (suite *keyspaceGroupTestSuite) TearDownTest() { re := suite.Require() suite.cleanupFunc() suite.cluster.Destroy() - suite.dialClient.CloseIdleConnections() re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/keyspace/acceleratedAllocNodes")) } @@ -347,7 +340,7 @@ func (suite *keyspaceGroupTestSuite) tryAllocNodesForKeyspaceGroup(re *require.A re.NoError(err) httpReq, err := http.NewRequest(http.MethodPost, suite.server.GetAddr()+keyspaceGroupsPrefix+fmt.Sprintf("/%d/alloc", id), bytes.NewBuffer(data)) re.NoError(err) - resp, err := suite.dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer resp.Body.Close() nodes := make([]endpoint.KeyspaceGroupMember, 0) @@ -364,7 +357,7 @@ func (suite *keyspaceGroupTestSuite) tryCreateKeyspaceGroup(re *require.Assertio re.NoError(err) httpReq, err := http.NewRequest(http.MethodPost, suite.server.GetAddr()+keyspaceGroupsPrefix, bytes.NewBuffer(data)) re.NoError(err) - resp, err := suite.dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer resp.Body.Close() return resp.StatusCode @@ -373,7 +366,7 @@ func (suite *keyspaceGroupTestSuite) tryCreateKeyspaceGroup(re *require.Assertio func (suite *keyspaceGroupTestSuite) tryGetKeyspaceGroup(re *require.Assertions, id uint32) (*endpoint.KeyspaceGroup, int) { httpReq, err := http.NewRequest(http.MethodGet, suite.server.GetAddr()+keyspaceGroupsPrefix+fmt.Sprintf("/%d", id), http.NoBody) re.NoError(err) - resp, err := suite.dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer resp.Body.Close() kg := &endpoint.KeyspaceGroup{} @@ -390,7 +383,7 @@ func (suite *keyspaceGroupTestSuite) trySetNodesForKeyspaceGroup(re *require.Ass re.NoError(err) httpReq, err := http.NewRequest(http.MethodPatch, suite.server.GetAddr()+keyspaceGroupsPrefix+fmt.Sprintf("/%d", id), bytes.NewBuffer(data)) re.NoError(err) - resp, err := suite.dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer resp.Body.Close() if resp.StatusCode != http.StatusOK { diff --git a/tests/integrations/mcs/members/member_test.go b/tests/integrations/mcs/members/member_test.go index 87a667e5344..d650d1ded4f 100644 --- a/tests/integrations/mcs/members/member_test.go +++ b/tests/integrations/mcs/members/member_test.go @@ -34,7 +34,7 @@ type memberTestSuite struct { cluster *tests.TestCluster server *tests.TestServer backendEndpoints string - dialClient pdClient.Client + pdClient pdClient.Client } func TestMemberTestSuite(t *testing.T) { @@ -53,7 +53,7 @@ func (suite *memberTestSuite) SetupTest() { suite.server = cluster.GetLeaderServer() re.NoError(suite.server.BootstrapCluster()) suite.backendEndpoints = suite.server.GetAddr() - suite.dialClient = pdClient.NewClient("mcs-member-test", []string{suite.server.GetAddr()}) + suite.pdClient = pdClient.NewClient("mcs-member-test", []string{suite.server.GetAddr()}) // TSO nodes := make(map[string]bs.Server) @@ -86,30 +86,30 @@ func (suite *memberTestSuite) TearDownTest() { for _, cleanup := range suite.cleanupFunc { cleanup() } - if suite.dialClient != nil { - suite.dialClient.Close() + if suite.pdClient != nil { + suite.pdClient.Close() } suite.cluster.Destroy() } func (suite *memberTestSuite) TestMembers() { re := suite.Require() - members, err := suite.dialClient.GetMicroServiceMembers(suite.ctx, "tso") + members, err := suite.pdClient.GetMicroServiceMembers(suite.ctx, "tso") re.NoError(err) re.Len(members, utils.DefaultKeyspaceGroupReplicaCount) - members, err = suite.dialClient.GetMicroServiceMembers(suite.ctx, "scheduling") + members, err = suite.pdClient.GetMicroServiceMembers(suite.ctx, "scheduling") re.NoError(err) re.Len(members, 3) } func (suite *memberTestSuite) TestPrimary() { re := suite.Require() - primary, err := suite.dialClient.GetMicroServicePrimary(suite.ctx, "tso") + primary, err := suite.pdClient.GetMicroServicePrimary(suite.ctx, "tso") re.NoError(err) re.NotEmpty(primary) - primary, err = suite.dialClient.GetMicroServicePrimary(suite.ctx, "scheduling") + primary, err = suite.pdClient.GetMicroServicePrimary(suite.ctx, "scheduling") re.NoError(err) re.NotEmpty(primary) } diff --git a/tests/integrations/mcs/resourcemanager/resource_manager_test.go b/tests/integrations/mcs/resourcemanager/resource_manager_test.go index 17673213a97..ab7cd5321ad 100644 --- a/tests/integrations/mcs/resourcemanager/resource_manager_test.go +++ b/tests/integrations/mcs/resourcemanager/resource_manager_test.go @@ -957,7 +957,7 @@ func (suite *resourceManagerClientTestSuite) TestBasicResourceGroupCURD() { } createJSON, err := json.Marshal(group) re.NoError(err) - resp, err := http.Post(getAddr(i)+"/resource-manager/api/v1/config/group", "application/json", strings.NewReader(string(createJSON))) + resp, err := tests.TestDialClient.Post(getAddr(i)+"/resource-manager/api/v1/config/group", "application/json", strings.NewReader(string(createJSON))) re.NoError(err) resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) @@ -982,7 +982,7 @@ func (suite *resourceManagerClientTestSuite) TestBasicResourceGroupCURD() { } // Get Resource Group - resp, err = http.Get(getAddr(i) + "/resource-manager/api/v1/config/group/" + tcase.name) + resp, err = tests.TestDialClient.Get(getAddr(i) + "/resource-manager/api/v1/config/group/" + tcase.name) re.NoError(err) re.Equal(http.StatusOK, resp.StatusCode) respString, err := io.ReadAll(resp.Body) @@ -995,7 +995,7 @@ func (suite *resourceManagerClientTestSuite) TestBasicResourceGroupCURD() { // Last one, Check list and delete all resource groups if i == len(testCasesSet1)-1 { - resp, err := http.Get(getAddr(i) + "/resource-manager/api/v1/config/groups") + resp, err := tests.TestDialClient.Get(getAddr(i) + "/resource-manager/api/v1/config/groups") re.NoError(err) re.Equal(http.StatusOK, resp.StatusCode) respString, err := io.ReadAll(resp.Body) @@ -1023,7 +1023,7 @@ func (suite *resourceManagerClientTestSuite) TestBasicResourceGroupCURD() { } // verify again - resp1, err := http.Get(getAddr(i) + "/resource-manager/api/v1/config/groups") + resp1, err := tests.TestDialClient.Get(getAddr(i) + "/resource-manager/api/v1/config/groups") re.NoError(err) re.Equal(http.StatusOK, resp1.StatusCode) respString1, err := io.ReadAll(resp1.Body) diff --git a/tests/integrations/mcs/resourcemanager/server_test.go b/tests/integrations/mcs/resourcemanager/server_test.go index 4e1fb018d56..24de29db3a6 100644 --- a/tests/integrations/mcs/resourcemanager/server_test.go +++ b/tests/integrations/mcs/resourcemanager/server_test.go @@ -63,7 +63,7 @@ func TestResourceManagerServer(t *testing.T) { // Test registered REST HTTP Handler url := addr + "/resource-manager/api/v1/config" { - resp, err := http.Get(url + "/groups") + resp, err := tests.TestDialClient.Get(url + "/groups") re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) @@ -78,13 +78,13 @@ func TestResourceManagerServer(t *testing.T) { } createJSON, err := json.Marshal(group) re.NoError(err) - resp, err := http.Post(url+"/group", "application/json", strings.NewReader(string(createJSON))) + resp, err := tests.TestDialClient.Post(url+"/group", "application/json", strings.NewReader(string(createJSON))) re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) } { - resp, err := http.Get(url + "/group/pingcap") + resp, err := tests.TestDialClient.Get(url + "/group/pingcap") re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) @@ -95,7 +95,7 @@ func TestResourceManagerServer(t *testing.T) { // Test metrics handler { - resp, err := http.Get(addr + "/metrics") + resp, err := tests.TestDialClient.Get(addr + "/metrics") re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) @@ -106,7 +106,7 @@ func TestResourceManagerServer(t *testing.T) { // Test status handler { - resp, err := http.Get(addr + "/status") + resp, err := tests.TestDialClient.Get(addr + "/status") re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) diff --git a/tests/integrations/mcs/scheduling/api_test.go b/tests/integrations/mcs/scheduling/api_test.go index e9033e5016a..cf2c6dd2508 100644 --- a/tests/integrations/mcs/scheduling/api_test.go +++ b/tests/integrations/mcs/scheduling/api_test.go @@ -29,12 +29,6 @@ import ( "github.com/tikv/pd/tests" ) -var testDialClient = &http.Client{ - Transport: &http.Transport{ - DisableKeepAlives: true, - }, -} - type apiTestSuite struct { suite.Suite env *tests.SchedulingTestEnvironment @@ -56,7 +50,6 @@ func (suite *apiTestSuite) TearDownSuite() { re := suite.Require() re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/schedule/changeCoordinatorTicker")) re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/mcs/scheduling/server/changeRunCollectWaitTime")) - testDialClient.CloseIdleConnections() } func (suite *apiTestSuite) TestGetCheckerByName() { @@ -84,14 +77,14 @@ func (suite *apiTestSuite) checkGetCheckerByName(cluster *tests.TestCluster) { name := testCase.name // normal run resp := make(map[string]any) - err := testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) + err := testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) re.NoError(err) re.False(resp["paused"].(bool)) // paused err = co.PauseOrResumeChecker(name, 30) re.NoError(err) resp = make(map[string]any) - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) re.NoError(err) re.True(resp["paused"].(bool)) // resumed @@ -99,7 +92,7 @@ func (suite *apiTestSuite) checkGetCheckerByName(cluster *tests.TestCluster) { re.NoError(err) time.Sleep(time.Second) resp = make(map[string]any) - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) re.NoError(err) re.False(resp["paused"].(bool)) } @@ -121,29 +114,29 @@ func (suite *apiTestSuite) checkAPIForward(cluster *tests.TestCluster) { }) // Test operators - err := testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators"), &respSlice, + err := testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators"), &respSlice, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) re.Empty(respSlice) - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators"), []byte(``), + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators"), []byte(``), testutil.StatusNotOK(re), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators/2"), nil, + err = testutil.CheckGetJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators/2"), nil, testutil.StatusNotOK(re), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.CheckDelete(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators/2"), + err = testutil.CheckDelete(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators/2"), testutil.StatusNotOK(re), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators/records"), nil, + err = testutil.CheckGetJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators/records"), nil, testutil.StatusNotOK(re), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) // Test checker - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "checker/merge"), &resp, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "checker/merge"), &resp, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) re.False(resp["paused"].(bool)) @@ -154,7 +147,7 @@ func (suite *apiTestSuite) checkAPIForward(cluster *tests.TestCluster) { input["delay"] = delay pauseArgs, err := json.Marshal(input) re.NoError(err) - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "checker/merge"), pauseArgs, + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "checker/merge"), pauseArgs, testutil.StatusOK(re), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) } @@ -173,7 +166,7 @@ func (suite *apiTestSuite) checkAPIForward(cluster *tests.TestCluster) { // "/schedulers", http.MethodPost // "/schedulers/{name}", http.MethodDelete testutil.Eventually(re, func() bool { - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers"), &respSlice, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers"), &respSlice, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) return slice.Contains(respSlice, "balance-leader-scheduler") @@ -184,18 +177,18 @@ func (suite *apiTestSuite) checkAPIForward(cluster *tests.TestCluster) { input["delay"] = delay pauseArgs, err := json.Marshal(input) re.NoError(err) - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers/balance-leader-scheduler"), pauseArgs, + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers/balance-leader-scheduler"), pauseArgs, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) } postScheduler(30) postScheduler(0) - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers/diagnostic/balance-leader-scheduler"), &resp, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers/diagnostic/balance-leader-scheduler"), &resp, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "scheduler-config"), &resp, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "scheduler-config"), &resp, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) re.Contains(resp, "balance-leader-scheduler") @@ -206,16 +199,16 @@ func (suite *apiTestSuite) checkAPIForward(cluster *tests.TestCluster) { "balance-hot-region-scheduler", } for _, schedulerName := range schedulers { - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s/%s/%s", urlPrefix, "scheduler-config", schedulerName, "list"), &resp, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s/%s/%s", urlPrefix, "scheduler-config", schedulerName, "list"), &resp, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) } - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers"), nil, + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers"), nil, testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader)) re.NoError(err) - err = testutil.CheckDelete(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers/balance-leader-scheduler"), + err = testutil.CheckDelete(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers/balance-leader-scheduler"), testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader)) re.NoError(err) @@ -223,74 +216,74 @@ func (suite *apiTestSuite) checkAPIForward(cluster *tests.TestCluster) { input["name"] = "balance-leader-scheduler" b, err := json.Marshal(input) re.NoError(err) - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers"), b, + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "schedulers"), b, testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader)) re.NoError(err) // Test hotspot var hotRegions statistics.StoreHotPeersInfos - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "hotspot/regions/write"), &hotRegions, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "hotspot/regions/write"), &hotRegions, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "hotspot/regions/read"), &hotRegions, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "hotspot/regions/read"), &hotRegions, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) var stores handler.HotStoreStats - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "hotspot/stores"), &stores, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "hotspot/stores"), &stores, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) var buckets handler.HotBucketsResponse - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "hotspot/buckets"), &buckets, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "hotspot/buckets"), &buckets, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) var history storage.HistoryHotRegions - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "hotspot/regions/history"), &history, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "hotspot/regions/history"), &history, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) // Test region label var labelRules []*labeler.LabelRule - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/region-label/rules"), &labelRules, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/region-label/rules"), &labelRules, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.ReadGetJSONWithBody(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/region-label/rules/ids"), []byte(`["rule1", "rule3"]`), + err = testutil.ReadGetJSONWithBody(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/region-label/rules/ids"), []byte(`["rule1", "rule3"]`), &labelRules, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/region-label/rule/rule1"), nil, + err = testutil.CheckGetJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/region-label/rule/rule1"), nil, testutil.StatusNotOK(re), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "region/id/1"), nil, + err = testutil.CheckGetJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "region/id/1"), nil, testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader)) re.NoError(err) - err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "region/id/1/label/key"), nil, + err = testutil.CheckGetJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "region/id/1/label/key"), nil, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "region/id/1/labels"), nil, + err = testutil.CheckGetJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "region/id/1/labels"), nil, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) // Test Region body := fmt.Sprintf(`{"start_key":"%s", "end_key": "%s"}`, hex.EncodeToString([]byte("a1")), hex.EncodeToString([]byte("a3"))) - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "regions/accelerate-schedule"), []byte(body), + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "regions/accelerate-schedule"), []byte(body), testutil.StatusOK(re), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) body = fmt.Sprintf(`[{"start_key":"%s", "end_key": "%s"}, {"start_key":"%s", "end_key": "%s"}]`, hex.EncodeToString([]byte("a1")), hex.EncodeToString([]byte("a3")), hex.EncodeToString([]byte("a4")), hex.EncodeToString([]byte("a6"))) - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "regions/accelerate-schedule/batch"), []byte(body), + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "regions/accelerate-schedule/batch"), []byte(body), testutil.StatusOK(re), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) body = fmt.Sprintf(`{"start_key":"%s", "end_key": "%s"}`, hex.EncodeToString([]byte("b1")), hex.EncodeToString([]byte("b3"))) - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "regions/scatter"), []byte(body), + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "regions/scatter"), []byte(body), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) body = fmt.Sprintf(`{"retry_limit":%v, "split_keys": ["%s","%s","%s"]}`, 3, hex.EncodeToString([]byte("bbb")), hex.EncodeToString([]byte("ccc")), hex.EncodeToString([]byte("ddd"))) - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "regions/split"), []byte(body), + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "regions/split"), []byte(body), testutil.StatusOK(re), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf(`%s/regions/replicated?startKey=%s&endKey=%s`, urlPrefix, hex.EncodeToString([]byte("a1")), hex.EncodeToString([]byte("a2"))), nil, + err = testutil.CheckGetJSON(tests.TestDialClient, fmt.Sprintf(`%s/regions/replicated?startKey=%s&endKey=%s`, urlPrefix, hex.EncodeToString([]byte("a1")), hex.EncodeToString([]byte("a2"))), nil, testutil.StatusOK(re), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) // Test rules: only forward `GET` request @@ -308,73 +301,73 @@ func (suite *apiTestSuite) checkAPIForward(cluster *tests.TestCluster) { rulesArgs, err := json.Marshal(rules) re.NoError(err) - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), &rules, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), &rules, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), rulesArgs, + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), rulesArgs, testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader)) re.NoError(err) - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/batch"), rulesArgs, + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/batch"), rulesArgs, testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader)) re.NoError(err) - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/group/pd"), &rules, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/group/pd"), &rules, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/region/2"), &rules, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/region/2"), &rules, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) var fit placement.RegionFit - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/region/2/detail"), &fit, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/region/2/detail"), &fit, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/key/0000000000000001"), &rules, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules/key/0000000000000001"), &rules, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule/pd/2"), nil, + err = testutil.CheckGetJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule/pd/2"), nil, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.CheckDelete(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule/pd/2"), + err = testutil.CheckDelete(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule/pd/2"), testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader)) re.NoError(err) - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule"), rulesArgs, + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule"), rulesArgs, testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader)) re.NoError(err) - err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule_group/pd"), nil, + err = testutil.CheckGetJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule_group/pd"), nil, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.CheckDelete(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule_group/pd"), + err = testutil.CheckDelete(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule_group/pd"), testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader)) re.NoError(err) - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule_group"), rulesArgs, + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule_group"), rulesArgs, testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader)) re.NoError(err) - err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule_groups"), nil, + err = testutil.CheckGetJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rule_groups"), nil, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule"), nil, + err = testutil.CheckGetJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule"), nil, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule"), rulesArgs, + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule"), rulesArgs, testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader)) re.NoError(err) - err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule/pd"), nil, + err = testutil.CheckGetJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule/pd"), nil, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) - err = testutil.CheckDelete(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule/pd"), + err = testutil.CheckDelete(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule/pd"), testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader)) re.NoError(err) - err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule/pd"), rulesArgs, + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule/pd"), rulesArgs, testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader)) re.NoError(err) // test redirect is disabled - err = testutil.CheckGetJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule/pd"), nil, + err = testutil.CheckGetJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule/pd"), nil, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/%s", urlPrefix, "config/placement-rule/pd"), http.NoBody) re.NoError(err) req.Header.Set(apiutil.XForbiddenForwardToMicroServiceHeader, "true") - httpResp, err := testDialClient.Do(req) + httpResp, err := tests.TestDialClient.Do(req) re.NoError(err) re.Equal(http.StatusOK, httpResp.StatusCode) defer httpResp.Body.Close() @@ -395,7 +388,7 @@ func (suite *apiTestSuite) checkConfig(cluster *tests.TestCluster) { urlPrefix := fmt.Sprintf("%s/scheduling/api/v1/config", addr) var cfg config.Config - testutil.ReadGetJSON(re, testDialClient, urlPrefix, &cfg) + testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &cfg) re.Equal(cfg.GetListenAddr(), s.GetConfig().GetListenAddr()) re.Equal(cfg.Schedule.LeaderScheduleLimit, s.GetConfig().Schedule.LeaderScheduleLimit) re.Equal(cfg.Schedule.EnableCrossTableMerge, s.GetConfig().Schedule.EnableCrossTableMerge) @@ -427,7 +420,7 @@ func (suite *apiTestSuite) checkConfigForward(cluster *tests.TestCluster) { // Test config forward // Expect to get same config in scheduling server and api server testutil.Eventually(re, func() bool { - testutil.ReadGetJSON(re, testDialClient, urlPrefix, &cfg) + testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &cfg) re.Equal(cfg["schedule"].(map[string]any)["leader-schedule-limit"], float64(opts.GetLeaderScheduleLimit())) re.Equal(cfg["replication"].(map[string]any)["max-replicas"], @@ -442,10 +435,10 @@ func (suite *apiTestSuite) checkConfigForward(cluster *tests.TestCluster) { "max-replicas": 4, }) re.NoError(err) - err = testutil.CheckPostJSON(testDialClient, urlPrefix, reqData, testutil.StatusOK(re)) + err = testutil.CheckPostJSON(tests.TestDialClient, urlPrefix, reqData, testutil.StatusOK(re)) re.NoError(err) testutil.Eventually(re, func() bool { - testutil.ReadGetJSON(re, testDialClient, urlPrefix, &cfg) + testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &cfg) return cfg["replication"].(map[string]any)["max-replicas"] == 4. && opts.GetReplicationConfig().MaxReplicas == 4. }) @@ -454,11 +447,11 @@ func (suite *apiTestSuite) checkConfigForward(cluster *tests.TestCluster) { // Expect to get new config in scheduling server but not old config in api server opts.GetScheduleConfig().LeaderScheduleLimit = 100 re.Equal(100, int(opts.GetLeaderScheduleLimit())) - testutil.ReadGetJSON(re, testDialClient, urlPrefix, &cfg) + testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &cfg) re.Equal(100., cfg["schedule"].(map[string]any)["leader-schedule-limit"]) opts.GetReplicationConfig().MaxReplicas = 5 re.Equal(5, int(opts.GetReplicationConfig().MaxReplicas)) - testutil.ReadGetJSON(re, testDialClient, urlPrefix, &cfg) + testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &cfg) re.Equal(5., cfg["replication"].(map[string]any)["max-replicas"]) } @@ -480,11 +473,11 @@ func (suite *apiTestSuite) checkAdminRegionCache(cluster *tests.TestCluster) { addr := schedulingServer.GetAddr() urlPrefix := fmt.Sprintf("%s/scheduling/api/v1/admin/cache/regions", addr) - err := testutil.CheckDelete(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "30"), testutil.StatusOK(re)) + err := testutil.CheckDelete(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "30"), testutil.StatusOK(re)) re.NoError(err) re.Equal(2, schedulingServer.GetCluster().GetRegionCount([]byte{}, []byte{})) - err = testutil.CheckDelete(testDialClient, urlPrefix, testutil.StatusOK(re)) + err = testutil.CheckDelete(tests.TestDialClient, urlPrefix, testutil.StatusOK(re)) re.NoError(err) re.Equal(0, schedulingServer.GetCluster().GetRegionCount([]byte{}, []byte{})) } @@ -509,12 +502,12 @@ func (suite *apiTestSuite) checkAdminRegionCacheForward(cluster *tests.TestClust addr := cluster.GetLeaderServer().GetAddr() urlPrefix := fmt.Sprintf("%s/pd/api/v1/admin/cache/region", addr) - err := testutil.CheckDelete(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "30"), testutil.StatusOK(re)) + err := testutil.CheckDelete(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "30"), testutil.StatusOK(re)) re.NoError(err) re.Equal(2, schedulingServer.GetCluster().GetRegionCount([]byte{}, []byte{})) re.Equal(2, apiServer.GetRaftCluster().GetRegionCount([]byte{}, []byte{}).Count) - err = testutil.CheckDelete(testDialClient, urlPrefix+"s", testutil.StatusOK(re)) + err = testutil.CheckDelete(tests.TestDialClient, urlPrefix+"s", testutil.StatusOK(re)) re.NoError(err) re.Equal(0, schedulingServer.GetCluster().GetRegionCount([]byte{}, []byte{})) re.Equal(0, apiServer.GetRaftCluster().GetRegionCount([]byte{}, []byte{}).Count) @@ -544,14 +537,14 @@ func (suite *apiTestSuite) checkFollowerForward(cluster *tests.TestCluster) { if sche := cluster.GetSchedulingPrimaryServer(); sche != nil { // follower will forward to scheduling server directly re.NotEqual(cluster.GetLeaderServer().GetAddr(), followerAddr) - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), &rules, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), &rules, testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true"), ) re.NoError(err) } else { // follower will forward to leader server re.NotEqual(cluster.GetLeaderServer().GetAddr(), followerAddr) - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), &rules, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), &rules, testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader), ) re.NoError(err) @@ -560,7 +553,7 @@ func (suite *apiTestSuite) checkFollowerForward(cluster *tests.TestCluster) { // follower will forward to leader server re.NotEqual(cluster.GetLeaderServer().GetAddr(), followerAddr) results := make(map[string]any) - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config"), &results, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config"), &results, testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader), ) re.NoError(err) @@ -576,7 +569,7 @@ func (suite *apiTestSuite) checkMetrics(cluster *tests.TestCluster) { testutil.Eventually(re, func() bool { return s.IsServing() }, testutil.WithWaitFor(5*time.Second), testutil.WithTickInterval(50*time.Millisecond)) - resp, err := http.Get(s.GetConfig().GetAdvertiseListenAddr() + "/metrics") + resp, err := tests.TestDialClient.Get(s.GetConfig().GetAdvertiseListenAddr() + "/metrics") re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) @@ -595,7 +588,7 @@ func (suite *apiTestSuite) checkStatus(cluster *tests.TestCluster) { testutil.Eventually(re, func() bool { return s.IsServing() }, testutil.WithWaitFor(5*time.Second), testutil.WithTickInterval(50*time.Millisecond)) - resp, err := http.Get(s.GetConfig().GetAdvertiseListenAddr() + "/status") + resp, err := tests.TestDialClient.Get(s.GetConfig().GetAdvertiseListenAddr() + "/status") re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) @@ -659,34 +652,34 @@ func (suite *apiTestSuite) checkStores(cluster *tests.TestCluster) { apiServerAddr := cluster.GetLeaderServer().GetAddr() urlPrefix := fmt.Sprintf("%s/pd/api/v1/stores", apiServerAddr) var resp map[string]any - err := testutil.ReadGetJSON(re, testDialClient, urlPrefix, &resp) + err := testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &resp) re.NoError(err) re.Equal(3, int(resp["count"].(float64))) re.Len(resp["stores"].([]any), 3) scheServerAddr := cluster.GetSchedulingPrimaryServer().GetAddr() urlPrefix = fmt.Sprintf("%s/scheduling/api/v1/stores", scheServerAddr) - err = testutil.ReadGetJSON(re, testDialClient, urlPrefix, &resp) + err = testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &resp) re.NoError(err) re.Equal(3, int(resp["count"].(float64))) re.Len(resp["stores"].([]any), 3) // Test /stores/{id} urlPrefix = fmt.Sprintf("%s/scheduling/api/v1/stores/1", scheServerAddr) - err = testutil.ReadGetJSON(re, testDialClient, urlPrefix, &resp) + err = testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &resp) re.NoError(err) re.Equal("tikv1", resp["store"].(map[string]any)["address"]) re.Equal("Up", resp["store"].(map[string]any)["state_name"]) urlPrefix = fmt.Sprintf("%s/scheduling/api/v1/stores/6", scheServerAddr) - err = testutil.ReadGetJSON(re, testDialClient, urlPrefix, &resp) + err = testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &resp) re.NoError(err) re.Equal("tikv6", resp["store"].(map[string]any)["address"]) re.Equal("Offline", resp["store"].(map[string]any)["state_name"]) urlPrefix = fmt.Sprintf("%s/scheduling/api/v1/stores/7", scheServerAddr) - err = testutil.ReadGetJSON(re, testDialClient, urlPrefix, &resp) + err = testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &resp) re.NoError(err) re.Equal("tikv7", resp["store"].(map[string]any)["address"]) re.Equal("Tombstone", resp["store"].(map[string]any)["state_name"]) urlPrefix = fmt.Sprintf("%s/scheduling/api/v1/stores/233", scheServerAddr) - testutil.CheckGetJSON(testDialClient, urlPrefix, nil, + testutil.CheckGetJSON(tests.TestDialClient, urlPrefix, nil, testutil.Status(re, http.StatusNotFound), testutil.StringContain(re, "not found")) } @@ -703,27 +696,27 @@ func (suite *apiTestSuite) checkRegions(cluster *tests.TestCluster) { apiServerAddr := cluster.GetLeaderServer().GetAddr() urlPrefix := fmt.Sprintf("%s/pd/api/v1/regions", apiServerAddr) var resp map[string]any - err := testutil.ReadGetJSON(re, testDialClient, urlPrefix, &resp) + err := testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &resp) re.NoError(err) re.Equal(3, int(resp["count"].(float64))) re.Len(resp["regions"].([]any), 3) scheServerAddr := cluster.GetSchedulingPrimaryServer().GetAddr() urlPrefix = fmt.Sprintf("%s/scheduling/api/v1/regions", scheServerAddr) - err = testutil.ReadGetJSON(re, testDialClient, urlPrefix, &resp) + err = testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &resp) re.NoError(err) re.Equal(3, int(resp["count"].(float64))) re.Len(resp["regions"].([]any), 3) // Test /regions/{id} and /regions/count urlPrefix = fmt.Sprintf("%s/scheduling/api/v1/regions/1", scheServerAddr) - err = testutil.ReadGetJSON(re, testDialClient, urlPrefix, &resp) + err = testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &resp) re.NoError(err) key := fmt.Sprintf("%x", "a") re.Equal(key, resp["start_key"]) urlPrefix = fmt.Sprintf("%s/scheduling/api/v1/regions/count", scheServerAddr) - err = testutil.ReadGetJSON(re, testDialClient, urlPrefix, &resp) + err = testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &resp) re.NoError(err) re.Equal(3., resp["count"]) urlPrefix = fmt.Sprintf("%s/scheduling/api/v1/regions/233", scheServerAddr) - testutil.CheckGetJSON(testDialClient, urlPrefix, nil, + testutil.CheckGetJSON(tests.TestDialClient, urlPrefix, nil, testutil.Status(re, http.StatusNotFound), testutil.StringContain(re, "not found")) } diff --git a/tests/integrations/mcs/tso/api_test.go b/tests/integrations/mcs/tso/api_test.go index dc9bfa1e291..4d6f9b33e3b 100644 --- a/tests/integrations/mcs/tso/api_test.go +++ b/tests/integrations/mcs/tso/api_test.go @@ -42,13 +42,6 @@ const ( tsoKeyspaceGroupsPrefix = "/tso/api/v1/keyspace-groups" ) -// dialClient used to dial http request. -var dialClient = &http.Client{ - Transport: &http.Transport{ - DisableKeepAlives: true, - }, -} - type tsoAPITestSuite struct { suite.Suite ctx context.Context @@ -110,13 +103,13 @@ func (suite *tsoAPITestSuite) TestForwardResetTS() { // Test reset ts input := []byte(`{"tso":"121312", "force-use-larger":true}`) - err := testutil.CheckPostJSON(dialClient, url, input, + err := testutil.CheckPostJSON(tests.TestDialClient, url, input, testutil.StatusOK(re), testutil.StringContain(re, "Reset ts successfully"), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) // Test reset ts with invalid tso input = []byte(`{}`) - err = testutil.CheckPostJSON(dialClient, url, input, + err = testutil.CheckPostJSON(tests.TestDialClient, url, input, testutil.StatusNotOK(re), testutil.StringContain(re, "invalid tso value"), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) } @@ -124,7 +117,7 @@ func (suite *tsoAPITestSuite) TestForwardResetTS() { func mustGetKeyspaceGroupMembers(re *require.Assertions, server *tso.Server) map[uint32]*apis.KeyspaceGroupMember { httpReq, err := http.NewRequest(http.MethodGet, server.GetAddr()+tsoKeyspaceGroupsPrefix+"/members", http.NoBody) re.NoError(err) - httpResp, err := dialClient.Do(httpReq) + httpResp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer httpResp.Body.Close() data, err := io.ReadAll(httpResp.Body) @@ -177,14 +170,14 @@ func TestTSOServerStartFirst(t *testing.T) { re.NoError(err) httpReq, err := http.NewRequest(http.MethodPost, addr+"/pd/api/v2/tso/keyspace-groups/0/split", bytes.NewBuffer(jsonBody)) re.NoError(err) - httpResp, err := dialClient.Do(httpReq) + httpResp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer httpResp.Body.Close() re.Equal(http.StatusOK, httpResp.StatusCode) httpReq, err = http.NewRequest(http.MethodGet, addr+"/pd/api/v2/tso/keyspace-groups/0", http.NoBody) re.NoError(err) - httpResp, err = dialClient.Do(httpReq) + httpResp, err = tests.TestDialClient.Do(httpReq) re.NoError(err) data, err := io.ReadAll(httpResp.Body) re.NoError(err) @@ -219,20 +212,20 @@ func TestForwardOnlyTSONoScheduling(t *testing.T) { // Test /operators, it should not forward when there is no scheduling server. var slice []string - err = testutil.ReadGetJSON(re, dialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators"), &slice, + err = testutil.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "operators"), &slice, testutil.WithoutHeader(re, apiutil.XForwardedToMicroServiceHeader)) re.NoError(err) re.Empty(slice) // Test admin/reset-ts, it should forward to tso server. input := []byte(`{"tso":"121312", "force-use-larger":true}`) - err = testutil.CheckPostJSON(dialClient, fmt.Sprintf("%s/%s", urlPrefix, "admin/reset-ts"), input, + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "admin/reset-ts"), input, testutil.StatusOK(re), testutil.StringContain(re, "Reset ts successfully"), testutil.WithHeader(re, apiutil.XForwardedToMicroServiceHeader, "true")) re.NoError(err) // If close tso server, it should try forward to tso server, but return error in api mode. ttc.Destroy() - err = testutil.CheckPostJSON(dialClient, fmt.Sprintf("%s/%s", urlPrefix, "admin/reset-ts"), input, + err = testutil.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, "admin/reset-ts"), input, testutil.Status(re, http.StatusInternalServerError), testutil.StringContain(re, "[PD:apiutil:ErrRedirect]redirect failed")) re.NoError(err) } @@ -241,7 +234,7 @@ func (suite *tsoAPITestSuite) TestMetrics() { re := suite.Require() primary := suite.tsoCluster.WaitForDefaultPrimaryServing(re) - resp, err := http.Get(primary.GetConfig().GetAdvertiseListenAddr() + "/metrics") + resp, err := tests.TestDialClient.Get(primary.GetConfig().GetAdvertiseListenAddr() + "/metrics") re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) @@ -254,7 +247,7 @@ func (suite *tsoAPITestSuite) TestStatus() { re := suite.Require() primary := suite.tsoCluster.WaitForDefaultPrimaryServing(re) - resp, err := http.Get(primary.GetConfig().GetAdvertiseListenAddr() + "/status") + resp, err := tests.TestDialClient.Get(primary.GetConfig().GetAdvertiseListenAddr() + "/status") re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) @@ -271,7 +264,7 @@ func (suite *tsoAPITestSuite) TestConfig() { re := suite.Require() primary := suite.tsoCluster.WaitForDefaultPrimaryServing(re) - resp, err := http.Get(primary.GetConfig().GetAdvertiseListenAddr() + "/tso/api/v1/config") + resp, err := tests.TestDialClient.Get(primary.GetConfig().GetAdvertiseListenAddr() + "/tso/api/v1/config") re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) diff --git a/tests/integrations/mcs/tso/server_test.go b/tests/integrations/mcs/tso/server_test.go index 108740e46f9..260395e4209 100644 --- a/tests/integrations/mcs/tso/server_test.go +++ b/tests/integrations/mcs/tso/server_test.go @@ -111,13 +111,13 @@ func (suite *tsoServerTestSuite) TestTSOServerStartAndStopNormally() { url := s.GetAddr() + tsoapi.APIPathPrefix + "/admin/reset-ts" // Test reset ts input := []byte(`{"tso":"121312", "force-use-larger":true}`) - err = testutil.CheckPostJSON(dialClient, url, input, + err = testutil.CheckPostJSON(tests.TestDialClient, url, input, testutil.StatusOK(re), testutil.StringContain(re, "Reset ts successfully")) re.NoError(err) // Test reset ts with invalid tso input = []byte(`{}`) - err = testutil.CheckPostJSON(dialClient, suite.backendEndpoints+"/pd/api/v1/admin/reset-ts", input, + err = testutil.CheckPostJSON(tests.TestDialClient, suite.backendEndpoints+"/pd/api/v1/admin/reset-ts", input, testutil.StatusNotOK(re), testutil.StringContain(re, "invalid tso value")) re.NoError(err) } @@ -583,7 +583,7 @@ func (suite *CommonTestSuite) TestBootstrapDefaultKeyspaceGroup() { // check the default keyspace group check := func() { - resp, err := http.Get(suite.pdLeader.GetServer().GetConfig().AdvertiseClientUrls + "/pd/api/v2/tso/keyspace-groups") + resp, err := tests.TestDialClient.Get(suite.pdLeader.GetServer().GetConfig().AdvertiseClientUrls + "/pd/api/v2/tso/keyspace-groups") re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) diff --git a/tests/scheduling_cluster.go b/tests/scheduling_cluster.go index 1768c4128cc..434a6bd9a48 100644 --- a/tests/scheduling_cluster.go +++ b/tests/scheduling_cluster.go @@ -113,7 +113,7 @@ func (tc *TestSchedulingCluster) WaitForPrimaryServing(re *require.Assertions) * } } return false - }, testutil.WithWaitFor(5*time.Second), testutil.WithTickInterval(50*time.Millisecond)) + }, testutil.WithWaitFor(10*time.Second), testutil.WithTickInterval(50*time.Millisecond)) return primary } diff --git a/tests/server/api/api_test.go b/tests/server/api/api_test.go index b32a3afdf25..091d1488177 100644 --- a/tests/server/api/api_test.go +++ b/tests/server/api/api_test.go @@ -66,7 +66,7 @@ func TestReconnect(t *testing.T) { re.NotEmpty(leader) for name, s := range cluster.GetServers() { if name != leader { - res, err := http.Get(s.GetConfig().AdvertiseClientUrls + "/pd/api/v1/version") + res, err := tests.TestDialClient.Get(s.GetConfig().AdvertiseClientUrls + "/pd/api/v1/version") re.NoError(err) res.Body.Close() re.Equal(http.StatusOK, res.StatusCode) @@ -83,7 +83,7 @@ func TestReconnect(t *testing.T) { for name, s := range cluster.GetServers() { if name != leader { testutil.Eventually(re, func() bool { - res, err := http.Get(s.GetConfig().AdvertiseClientUrls + "/pd/api/v1/version") + res, err := tests.TestDialClient.Get(s.GetConfig().AdvertiseClientUrls + "/pd/api/v1/version") re.NoError(err) defer res.Body.Close() return res.StatusCode == http.StatusOK @@ -98,7 +98,7 @@ func TestReconnect(t *testing.T) { for name, s := range cluster.GetServers() { if name != leader && name != newLeader { testutil.Eventually(re, func() bool { - res, err := http.Get(s.GetConfig().AdvertiseClientUrls + "/pd/api/v1/version") + res, err := tests.TestDialClient.Get(s.GetConfig().AdvertiseClientUrls + "/pd/api/v1/version") re.NoError(err) defer res.Body.Close() return res.StatusCode == http.StatusServiceUnavailable @@ -148,7 +148,7 @@ func (suite *middlewareTestSuite) TestRequestInfoMiddleware() { data, err := json.Marshal(input) re.NoError(err) req, _ := http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) - resp, err := dialClient.Do(req) + resp, err := tests.TestDialClient.Do(req) re.NoError(err) resp.Body.Close() re.True(leader.GetServer().GetServiceMiddlewarePersistOptions().IsAuditEnabled()) @@ -156,7 +156,7 @@ func (suite *middlewareTestSuite) TestRequestInfoMiddleware() { labels := make(map[string]any) labels["testkey"] = "testvalue" data, _ = json.Marshal(labels) - resp, err = dialClient.Post(leader.GetAddr()+"/pd/api/v1/debug/pprof/profile?seconds=1", "application/json", bytes.NewBuffer(data)) + resp, err = tests.TestDialClient.Post(leader.GetAddr()+"/pd/api/v1/debug/pprof/profile?seconds=1", "application/json", bytes.NewBuffer(data)) re.NoError(err) _, err = io.ReadAll(resp.Body) resp.Body.Close() @@ -176,7 +176,7 @@ func (suite *middlewareTestSuite) TestRequestInfoMiddleware() { data, err = json.Marshal(input) re.NoError(err) req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) resp.Body.Close() re.False(leader.GetServer().GetServiceMiddlewarePersistOptions().IsAuditEnabled()) @@ -199,7 +199,7 @@ func BenchmarkDoRequestWithServiceMiddleware(b *testing.B) { } data, _ := json.Marshal(input) req, _ := http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) - resp, _ := dialClient.Do(req) + resp, _ := tests.TestDialClient.Do(req) resp.Body.Close() b.StartTimer() for i := 0; i < b.N; i++ { @@ -219,14 +219,14 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { data, err := json.Marshal(input) re.NoError(err) req, _ := http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) - resp, err := dialClient.Do(req) + resp, err := tests.TestDialClient.Do(req) re.NoError(err) resp.Body.Close() re.True(leader.GetServer().GetServiceMiddlewarePersistOptions().IsRateLimitEnabled()) // returns StatusOK when no rate-limit config req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/admin/log", strings.NewReader("\"info\"")) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) _, err = io.ReadAll(resp.Body) resp.Body.Close() @@ -240,7 +240,7 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { jsonBody, err := json.Marshal(input) re.NoError(err) req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/service-middleware/config/rate-limit", bytes.NewBuffer(jsonBody)) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) _, err = io.ReadAll(resp.Body) resp.Body.Close() @@ -249,7 +249,7 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { for i := 0; i < 3; i++ { req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/admin/log", strings.NewReader("\"info\"")) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) data, err := io.ReadAll(resp.Body) resp.Body.Close() @@ -266,7 +266,7 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { time.Sleep(time.Second * 2) for i := 0; i < 2; i++ { req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/admin/log", strings.NewReader("\"info\"")) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) data, err := io.ReadAll(resp.Body) resp.Body.Close() @@ -283,7 +283,7 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { time.Sleep(time.Second) for i := 0; i < 2; i++ { req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/admin/log", strings.NewReader("\"info\"")) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) data, err := io.ReadAll(resp.Body) resp.Body.Close() @@ -310,7 +310,7 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { for i := 0; i < 3; i++ { req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/admin/log", strings.NewReader("\"info\"")) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) data, err := io.ReadAll(resp.Body) resp.Body.Close() @@ -327,7 +327,7 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { time.Sleep(time.Second * 2) for i := 0; i < 2; i++ { req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/admin/log", strings.NewReader("\"info\"")) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) data, err := io.ReadAll(resp.Body) resp.Body.Close() @@ -344,7 +344,7 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { time.Sleep(time.Second) for i := 0; i < 2; i++ { req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/admin/log", strings.NewReader("\"info\"")) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) data, err := io.ReadAll(resp.Body) resp.Body.Close() @@ -359,14 +359,14 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { data, err = json.Marshal(input) re.NoError(err) req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) resp.Body.Close() re.False(leader.GetServer().GetServiceMiddlewarePersistOptions().IsRateLimitEnabled()) for i := 0; i < 3; i++ { req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/admin/log", strings.NewReader("\"info\"")) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) _, err = io.ReadAll(resp.Body) resp.Body.Close() @@ -381,7 +381,7 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { data, err = json.Marshal(input) re.NoError(err) req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) resp.Body.Close() re.True(leader.GetServer().GetServiceMiddlewarePersistOptions().IsRateLimitEnabled()) @@ -392,7 +392,7 @@ func (suite *middlewareTestSuite) TestSwaggerUrl() { leader := suite.cluster.GetLeaderServer() re.NotNil(leader) req, _ := http.NewRequest(http.MethodGet, leader.GetAddr()+"/swagger/ui/index", http.NoBody) - resp, err := dialClient.Do(req) + resp, err := tests.TestDialClient.Do(req) re.NoError(err) re.Equal(http.StatusNotFound, resp.StatusCode) resp.Body.Close() @@ -408,20 +408,20 @@ func (suite *middlewareTestSuite) TestAuditPrometheusBackend() { data, err := json.Marshal(input) re.NoError(err) req, _ := http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) - resp, err := dialClient.Do(req) + resp, err := tests.TestDialClient.Do(req) re.NoError(err) resp.Body.Close() re.True(leader.GetServer().GetServiceMiddlewarePersistOptions().IsAuditEnabled()) timeUnix := time.Now().Unix() - 20 req, _ = http.NewRequest(http.MethodGet, fmt.Sprintf("%s/pd/api/v1/trend?from=%d", leader.GetAddr(), timeUnix), http.NoBody) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) _, err = io.ReadAll(resp.Body) resp.Body.Close() re.NoError(err) req, _ = http.NewRequest(http.MethodGet, leader.GetAddr()+"/metrics", http.NoBody) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) defer resp.Body.Close() content, _ := io.ReadAll(resp.Body) @@ -440,14 +440,14 @@ func (suite *middlewareTestSuite) TestAuditPrometheusBackend() { timeUnix = time.Now().Unix() - 20 req, _ = http.NewRequest(http.MethodGet, fmt.Sprintf("%s/pd/api/v1/trend?from=%d", leader.GetAddr(), timeUnix), http.NoBody) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) _, err = io.ReadAll(resp.Body) resp.Body.Close() re.NoError(err) req, _ = http.NewRequest(http.MethodGet, leader.GetAddr()+"/metrics", http.NoBody) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) defer resp.Body.Close() content, _ = io.ReadAll(resp.Body) @@ -460,7 +460,7 @@ func (suite *middlewareTestSuite) TestAuditPrometheusBackend() { data, err = json.Marshal(input) re.NoError(err) req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) resp.Body.Close() re.False(leader.GetServer().GetServiceMiddlewarePersistOptions().IsAuditEnabled()) @@ -478,13 +478,13 @@ func (suite *middlewareTestSuite) TestAuditLocalLogBackend() { data, err := json.Marshal(input) re.NoError(err) req, _ := http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) - resp, err := dialClient.Do(req) + resp, err := tests.TestDialClient.Do(req) re.NoError(err) resp.Body.Close() re.True(leader.GetServer().GetServiceMiddlewarePersistOptions().IsAuditEnabled()) req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/admin/log", strings.NewReader("\"info\"")) - resp, err = dialClient.Do(req) + resp, err = tests.TestDialClient.Do(req) re.NoError(err) _, err = io.ReadAll(resp.Body) resp.Body.Close() @@ -506,7 +506,7 @@ func BenchmarkDoRequestWithLocalLogAudit(b *testing.B) { } data, _ := json.Marshal(input) req, _ := http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) - resp, _ := dialClient.Do(req) + resp, _ := tests.TestDialClient.Do(req) resp.Body.Close() b.StartTimer() for i := 0; i < b.N; i++ { @@ -528,7 +528,7 @@ func BenchmarkDoRequestWithPrometheusAudit(b *testing.B) { } data, _ := json.Marshal(input) req, _ := http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) - resp, _ := dialClient.Do(req) + resp, _ := tests.TestDialClient.Do(req) resp.Body.Close() b.StartTimer() for i := 0; i < b.N; i++ { @@ -550,7 +550,7 @@ func BenchmarkDoRequestWithoutServiceMiddleware(b *testing.B) { } data, _ := json.Marshal(input) req, _ := http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) - resp, _ := dialClient.Do(req) + resp, _ := tests.TestDialClient.Do(req) resp.Body.Close() b.StartTimer() for i := 0; i < b.N; i++ { @@ -563,7 +563,7 @@ func BenchmarkDoRequestWithoutServiceMiddleware(b *testing.B) { func doTestRequestWithLogAudit(srv *tests.TestServer) { req, _ := http.NewRequest(http.MethodDelete, fmt.Sprintf("%s/pd/api/v1/admin/cache/regions", srv.GetAddr()), http.NoBody) req.Header.Set(apiutil.XCallerIDHeader, "test") - resp, _ := dialClient.Do(req) + resp, _ := tests.TestDialClient.Do(req) resp.Body.Close() } @@ -571,7 +571,7 @@ func doTestRequestWithPrometheus(srv *tests.TestServer) { timeUnix := time.Now().Unix() - 20 req, _ := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/pd/api/v1/trend?from=%d", srv.GetAddr(), timeUnix), http.NoBody) req.Header.Set(apiutil.XCallerIDHeader, "test") - resp, _ := dialClient.Do(req) + resp, _ := tests.TestDialClient.Do(req) resp.Body.Close() } @@ -635,7 +635,7 @@ func (suite *redirectorTestSuite) TestAllowFollowerHandle() { request, err := http.NewRequest(http.MethodGet, addr, http.NoBody) re.NoError(err) request.Header.Add(apiutil.PDAllowFollowerHandleHeader, "true") - resp, err := dialClient.Do(request) + resp, err := tests.TestDialClient.Do(request) re.NoError(err) re.Equal("", resp.Header.Get(apiutil.PDRedirectorHeader)) defer resp.Body.Close() @@ -660,7 +660,7 @@ func (suite *redirectorTestSuite) TestNotLeader() { // Request to follower without redirectorHeader is OK. request, err := http.NewRequest(http.MethodGet, addr, http.NoBody) re.NoError(err) - resp, err := dialClient.Do(request) + resp, err := tests.TestDialClient.Do(request) re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) @@ -670,7 +670,7 @@ func (suite *redirectorTestSuite) TestNotLeader() { // Request to follower with redirectorHeader will fail. request.RequestURI = "" request.Header.Set(apiutil.PDRedirectorHeader, "pd") - resp1, err := dialClient.Do(request) + resp1, err := tests.TestDialClient.Do(request) re.NoError(err) defer resp1.Body.Close() re.NotEqual(http.StatusOK, resp1.StatusCode) @@ -689,7 +689,7 @@ func (suite *redirectorTestSuite) TestXForwardedFor() { addr := follower.GetAddr() + "/pd/api/v1/regions" request, err := http.NewRequest(http.MethodGet, addr, http.NoBody) re.NoError(err) - resp, err := dialClient.Do(request) + resp, err := tests.TestDialClient.Do(request) re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) @@ -701,7 +701,7 @@ func (suite *redirectorTestSuite) TestXForwardedFor() { } func mustRequestSuccess(re *require.Assertions, s *server.Server) http.Header { - resp, err := dialClient.Get(s.GetAddr() + "/pd/api/v1/version") + resp, err := tests.TestDialClient.Get(s.GetAddr() + "/pd/api/v1/version") re.NoError(err) defer resp.Body.Close() _, err = io.ReadAll(resp.Body) @@ -795,7 +795,7 @@ func TestRemovingProgress(t *testing.T) { } url := leader.GetAddr() + "/pd/api/v1/stores/progress?action=removing" req, _ := http.NewRequest(http.MethodGet, url, http.NoBody) - resp, err := dialClient.Do(req) + resp, err := tests.TestDialClient.Do(req) re.NoError(err) defer resp.Body.Close() if resp.StatusCode != http.StatusOK { @@ -819,7 +819,7 @@ func TestRemovingProgress(t *testing.T) { } url := leader.GetAddr() + "/pd/api/v1/stores/progress?action=removing" req, _ := http.NewRequest(http.MethodGet, url, http.NoBody) - resp, err := dialClient.Do(req) + resp, err := tests.TestDialClient.Do(req) re.NoError(err) defer resp.Body.Close() if resp.StatusCode != http.StatusOK { @@ -996,7 +996,7 @@ func TestPreparingProgress(t *testing.T) { } url := leader.GetAddr() + "/pd/api/v1/stores/progress?action=preparing" req, _ := http.NewRequest(http.MethodGet, url, http.NoBody) - resp, err := dialClient.Do(req) + resp, err := tests.TestDialClient.Do(req) re.NoError(err) defer resp.Body.Close() if resp.StatusCode != http.StatusNotFound { @@ -1021,7 +1021,7 @@ func TestPreparingProgress(t *testing.T) { } url := leader.GetAddr() + "/pd/api/v1/stores/progress?action=preparing" req, _ := http.NewRequest(http.MethodGet, url, http.NoBody) - resp, err := dialClient.Do(req) + resp, err := tests.TestDialClient.Do(req) re.NoError(err) defer resp.Body.Close() output, err := io.ReadAll(resp.Body) @@ -1078,7 +1078,7 @@ func TestPreparingProgress(t *testing.T) { func sendRequest(re *require.Assertions, url string, method string, statusCode int) []byte { req, _ := http.NewRequest(method, url, http.NoBody) - resp, err := dialClient.Do(req) + resp, err := tests.TestDialClient.Do(req) re.NoError(err) re.Equal(statusCode, resp.StatusCode) output, err := io.ReadAll(resp.Body) diff --git a/tests/server/api/checker_test.go b/tests/server/api/checker_test.go index 0304d7fd369..54298b405f1 100644 --- a/tests/server/api/checker_test.go +++ b/tests/server/api/checker_test.go @@ -73,14 +73,14 @@ func testErrCases(re *require.Assertions, cluster *tests.TestCluster) { input := make(map[string]any) pauseArgs, err := json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/merge", pauseArgs, tu.StatusNotOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/merge", pauseArgs, tu.StatusNotOK(re)) re.NoError(err) // negative delay input["delay"] = -10 pauseArgs, err = json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/merge", pauseArgs, tu.StatusNotOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/merge", pauseArgs, tu.StatusNotOK(re)) re.NoError(err) // wrong name @@ -88,12 +88,12 @@ func testErrCases(re *require.Assertions, cluster *tests.TestCluster) { input["delay"] = 30 pauseArgs, err = json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusNotOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusNotOK(re)) re.NoError(err) input["delay"] = 0 pauseArgs, err = json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusNotOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusNotOK(re)) re.NoError(err) } @@ -102,28 +102,28 @@ func testGetStatus(re *require.Assertions, cluster *tests.TestCluster, name stri urlPrefix := fmt.Sprintf("%s/pd/api/v1/checker", cluster.GetLeaderServer().GetAddr()) // normal run resp := make(map[string]any) - err := tu.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) + err := tu.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) re.NoError(err) re.False(resp["paused"].(bool)) // paused input["delay"] = 30 pauseArgs, err := json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusOK(re)) re.NoError(err) resp = make(map[string]any) - err = tu.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) + err = tu.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) re.NoError(err) re.True(resp["paused"].(bool)) // resumed input["delay"] = 0 pauseArgs, err = json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusOK(re)) re.NoError(err) time.Sleep(time.Second) resp = make(map[string]any) - err = tu.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) + err = tu.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) re.NoError(err) re.False(resp["paused"].(bool)) } @@ -137,18 +137,18 @@ func testPauseOrResume(re *require.Assertions, cluster *tests.TestCluster, name input["delay"] = 30 pauseArgs, err := json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusOK(re)) re.NoError(err) - err = tu.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) + err = tu.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) re.NoError(err) re.True(resp["paused"].(bool)) input["delay"] = 1 pauseArgs, err = json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusOK(re)) re.NoError(err) time.Sleep(time.Second) - err = tu.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) + err = tu.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) re.NoError(err) re.False(resp["paused"].(bool)) @@ -157,14 +157,14 @@ func testPauseOrResume(re *require.Assertions, cluster *tests.TestCluster, name input["delay"] = 30 pauseArgs, err = json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusOK(re)) re.NoError(err) input["delay"] = 0 pauseArgs, err = json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/"+name, pauseArgs, tu.StatusOK(re)) re.NoError(err) - err = tu.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) + err = tu.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s/%s", urlPrefix, name), &resp) re.NoError(err) re.False(resp["paused"].(bool)) } diff --git a/tests/server/api/operator_test.go b/tests/server/api/operator_test.go index 752325ea3da..c3b86f9fde0 100644 --- a/tests/server/api/operator_test.go +++ b/tests/server/api/operator_test.go @@ -18,7 +18,6 @@ import ( "encoding/json" "errors" "fmt" - "net/http" "sort" "strconv" "strings" @@ -35,15 +34,6 @@ import ( "github.com/tikv/pd/tests" ) -var ( - // testDialClient used to dial http request. only used for test. - testDialClient = &http.Client{ - Transport: &http.Transport{ - DisableKeepAlives: true, - }, - } -) - type operatorTestSuite struct { suite.Suite env *tests.SchedulingTestEnvironment @@ -112,35 +102,35 @@ func (suite *operatorTestSuite) checkAddRemovePeer(cluster *tests.TestCluster) { urlPrefix := fmt.Sprintf("%s/pd/api/v1", cluster.GetLeaderServer().GetAddr()) regionURL := fmt.Sprintf("%s/operators/%d", urlPrefix, region.GetId()) - err := tu.CheckGetJSON(testDialClient, regionURL, nil, + err := tu.CheckGetJSON(tests.TestDialClient, regionURL, nil, tu.StatusNotOK(re), tu.StringContain(re, "operator not found")) re.NoError(err) recordURL := fmt.Sprintf("%s/operators/records?from=%s", urlPrefix, strconv.FormatInt(time.Now().Unix(), 10)) - err = tu.CheckGetJSON(testDialClient, recordURL, nil, + err = tu.CheckGetJSON(tests.TestDialClient, recordURL, nil, tu.StatusNotOK(re), tu.StringContain(re, "operator not found")) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"add-peer", "region_id": 1, "store_id": 3}`), tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"add-peer", "region_id": 1, "store_id": 3}`), tu.StatusOK(re)) re.NoError(err) - err = tu.CheckGetJSON(testDialClient, regionURL, nil, + err = tu.CheckGetJSON(tests.TestDialClient, regionURL, nil, tu.StatusOK(re), tu.StringContain(re, "add learner peer 1 on store 3"), tu.StringContain(re, "RUNNING")) re.NoError(err) - err = tu.CheckDelete(testDialClient, regionURL, tu.StatusOK(re)) + err = tu.CheckDelete(tests.TestDialClient, regionURL, tu.StatusOK(re)) re.NoError(err) - err = tu.CheckGetJSON(testDialClient, recordURL, nil, + err = tu.CheckGetJSON(tests.TestDialClient, recordURL, nil, tu.StatusOK(re), tu.StringContain(re, "admin-add-peer {add peer: store [3]}")) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"remove-peer", "region_id": 1, "store_id": 2}`), tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"remove-peer", "region_id": 1, "store_id": 2}`), tu.StatusOK(re)) re.NoError(err) - err = tu.CheckGetJSON(testDialClient, regionURL, nil, + err = tu.CheckGetJSON(tests.TestDialClient, regionURL, nil, tu.StatusOK(re), tu.StringContain(re, "remove peer on store 2"), tu.StringContain(re, "RUNNING")) re.NoError(err) - err = tu.CheckDelete(testDialClient, regionURL, tu.StatusOK(re)) + err = tu.CheckDelete(tests.TestDialClient, regionURL, tu.StatusOK(re)) re.NoError(err) - err = tu.CheckGetJSON(testDialClient, recordURL, nil, + err = tu.CheckGetJSON(tests.TestDialClient, recordURL, nil, tu.StatusOK(re), tu.StringContain(re, "admin-remove-peer {rm peer: store [2]}")) re.NoError(err) @@ -150,26 +140,26 @@ func (suite *operatorTestSuite) checkAddRemovePeer(cluster *tests.TestCluster) { NodeState: metapb.NodeState_Serving, LastHeartbeat: time.Now().UnixNano(), }) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"add-learner", "region_id": 1, "store_id": 4}`), tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"add-learner", "region_id": 1, "store_id": 4}`), tu.StatusOK(re)) re.NoError(err) - err = tu.CheckGetJSON(testDialClient, regionURL, nil, + err = tu.CheckGetJSON(tests.TestDialClient, regionURL, nil, tu.StatusOK(re), tu.StringContain(re, "add learner peer 2 on store 4")) re.NoError(err) // Fail to add peer to tombstone store. err = cluster.GetLeaderServer().GetRaftCluster().RemoveStore(3, true) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"add-peer", "region_id": 1, "store_id": 3}`), tu.StatusNotOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"add-peer", "region_id": 1, "store_id": 3}`), tu.StatusNotOK(re)) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"transfer-peer", "region_id": 1, "from_store_id": 1, "to_store_id": 3}`), tu.StatusNotOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"transfer-peer", "region_id": 1, "from_store_id": 1, "to_store_id": 3}`), tu.StatusNotOK(re)) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"transfer-region", "region_id": 1, "to_store_ids": [1, 2, 3]}`), tu.StatusNotOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"transfer-region", "region_id": 1, "to_store_ids": [1, 2, 3]}`), tu.StatusNotOK(re)) re.NoError(err) // Fail to get operator if from is latest. time.Sleep(time.Second) url := fmt.Sprintf("%s/operators/records?from=%s", urlPrefix, strconv.FormatInt(time.Now().Unix(), 10)) - err = tu.CheckGetJSON(testDialClient, url, nil, + err = tu.CheckGetJSON(tests.TestDialClient, url, nil, tu.StatusNotOK(re), tu.StringContain(re, "operator not found")) re.NoError(err) } @@ -214,17 +204,17 @@ func (suite *operatorTestSuite) checkMergeRegionOperator(cluster *tests.TestClus tests.MustPutRegionInfo(re, cluster, r3) urlPrefix := fmt.Sprintf("%s/pd/api/v1", cluster.GetLeaderServer().GetAddr()) - err := tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 10, "target_region_id": 20}`), tu.StatusOK(re)) + err := tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 10, "target_region_id": 20}`), tu.StatusOK(re)) re.NoError(err) - tu.CheckDelete(testDialClient, fmt.Sprintf("%s/operators/%d", urlPrefix, 10), tu.StatusOK(re)) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 20, "target_region_id": 10}`), tu.StatusOK(re)) + tu.CheckDelete(tests.TestDialClient, fmt.Sprintf("%s/operators/%d", urlPrefix, 10), tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 20, "target_region_id": 10}`), tu.StatusOK(re)) re.NoError(err) - tu.CheckDelete(testDialClient, fmt.Sprintf("%s/operators/%d", urlPrefix, 10), tu.StatusOK(re)) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 10, "target_region_id": 30}`), + tu.CheckDelete(tests.TestDialClient, fmt.Sprintf("%s/operators/%d", urlPrefix, 10), tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 10, "target_region_id": 30}`), tu.StatusNotOK(re), tu.StringContain(re, "not adjacent")) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 30, "target_region_id": 10}`), + err = tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 30, "target_region_id": 10}`), tu.StatusNotOK(re), tu.StringContain(re, "not adjacent")) re.NoError(err) } @@ -287,7 +277,7 @@ func (suite *operatorTestSuite) checkTransferRegionWithPlacementRule(cluster *te urlPrefix := fmt.Sprintf("%s/pd/api/v1", cluster.GetLeaderServer().GetAddr()) regionURL := fmt.Sprintf("%s/operators/%d", urlPrefix, region.GetId()) - err := tu.CheckGetJSON(testDialClient, regionURL, nil, + err := tu.CheckGetJSON(tests.TestDialClient, regionURL, nil, tu.StatusNotOK(re), tu.StringContain(re, "operator not found")) re.NoError(err) convertStepsToStr := func(steps []string) string { @@ -462,7 +452,7 @@ func (suite *operatorTestSuite) checkTransferRegionWithPlacementRule(cluster *te } reqData, e := json.Marshal(data) re.NoError(e) - err := tu.CheckPostJSON(testDialClient, url, reqData, tu.StatusOK(re)) + err := tu.CheckPostJSON(tests.TestDialClient, url, reqData, tu.StatusOK(re)) re.NoError(err) if sche := cluster.GetSchedulingPrimaryServer(); sche != nil { // wait for the scheduling server to update the config @@ -491,19 +481,19 @@ func (suite *operatorTestSuite) checkTransferRegionWithPlacementRule(cluster *te re.NoError(err) } if testCase.expectedError == nil { - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), testCase.input, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), testCase.input, tu.StatusOK(re)) } else { - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), testCase.input, + err = tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), testCase.input, tu.StatusNotOK(re), tu.StringContain(re, testCase.expectedError.Error())) } re.NoError(err) if len(testCase.expectSteps) > 0 { - err = tu.CheckGetJSON(testDialClient, regionURL, nil, + err = tu.CheckGetJSON(tests.TestDialClient, regionURL, nil, tu.StatusOK(re), tu.StringContain(re, testCase.expectSteps)) re.NoError(err) - err = tu.CheckDelete(testDialClient, regionURL, tu.StatusOK(re)) + err = tu.CheckDelete(tests.TestDialClient, regionURL, tu.StatusOK(re)) } else { - err = tu.CheckDelete(testDialClient, regionURL, tu.StatusNotOK(re)) + err = tu.CheckDelete(tests.TestDialClient, regionURL, tu.StatusNotOK(re)) } re.NoError(err) } @@ -552,7 +542,7 @@ func (suite *operatorTestSuite) checkGetOperatorsAsObject(cluster *tests.TestClu resp := make([]operator.OpObject, 0) // No operator. - err := tu.ReadGetJSON(re, testDialClient, objURL, &resp) + err := tu.ReadGetJSON(re, tests.TestDialClient, objURL, &resp) re.NoError(err) re.Empty(resp) @@ -564,9 +554,9 @@ func (suite *operatorTestSuite) checkGetOperatorsAsObject(cluster *tests.TestClu r3 := core.NewTestRegionInfo(30, 1, []byte("c"), []byte("d"), core.SetWrittenBytes(500), core.SetReadBytes(800), core.SetRegionConfVer(3), core.SetRegionVersion(2)) tests.MustPutRegionInfo(re, cluster, r3) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 10, "target_region_id": 20}`), tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 10, "target_region_id": 20}`), tu.StatusOK(re)) re.NoError(err) - err = tu.ReadGetJSON(re, testDialClient, objURL, &resp) + err = tu.ReadGetJSON(re, tests.TestDialClient, objURL, &resp) re.NoError(err) re.Len(resp, 2) less := func(i, j int) bool { @@ -601,9 +591,9 @@ func (suite *operatorTestSuite) checkGetOperatorsAsObject(cluster *tests.TestClu } regionInfo := core.NewRegionInfo(region, peer1) tests.MustPutRegionInfo(re, cluster, regionInfo) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"add-peer", "region_id": 40, "store_id": 3}`), tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"add-peer", "region_id": 40, "store_id": 3}`), tu.StatusOK(re)) re.NoError(err) - err = tu.ReadGetJSON(re, testDialClient, objURL, &resp) + err = tu.ReadGetJSON(re, tests.TestDialClient, objURL, &resp) re.NoError(err) re.Len(resp, 3) sort.Slice(resp, less) @@ -651,15 +641,15 @@ func (suite *operatorTestSuite) checkRemoveOperators(cluster *tests.TestCluster) tests.MustPutRegionInfo(re, cluster, r3) urlPrefix := fmt.Sprintf("%s/pd/api/v1", cluster.GetLeaderServer().GetAddr()) - err := tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 10, "target_region_id": 20}`), tu.StatusOK(re)) + err := tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"merge-region", "source_region_id": 10, "target_region_id": 20}`), tu.StatusOK(re)) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"add-peer", "region_id": 30, "store_id": 4}`), tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/operators", urlPrefix), []byte(`{"name":"add-peer", "region_id": 30, "store_id": 4}`), tu.StatusOK(re)) re.NoError(err) url := fmt.Sprintf("%s/operators", urlPrefix) - err = tu.CheckGetJSON(testDialClient, url, nil, tu.StatusOK(re), tu.StringContain(re, "merge: region 10 to 20"), tu.StringContain(re, "add peer: store [4]")) + err = tu.CheckGetJSON(tests.TestDialClient, url, nil, tu.StatusOK(re), tu.StringContain(re, "merge: region 10 to 20"), tu.StringContain(re, "add peer: store [4]")) re.NoError(err) - err = tu.CheckDelete(testDialClient, url, tu.StatusOK(re)) + err = tu.CheckDelete(tests.TestDialClient, url, tu.StatusOK(re)) re.NoError(err) - err = tu.CheckGetJSON(testDialClient, url, nil, tu.StatusOK(re), tu.StringNotContain(re, "merge: region 10 to 20"), tu.StringNotContain(re, "add peer: store [4]")) + err = tu.CheckGetJSON(tests.TestDialClient, url, nil, tu.StatusOK(re), tu.StringNotContain(re, "merge: region 10 to 20"), tu.StringNotContain(re, "add peer: store [4]")) re.NoError(err) } diff --git a/tests/server/api/region_test.go b/tests/server/api/region_test.go index f36d9cbccf7..2ff0b5d4b86 100644 --- a/tests/server/api/region_test.go +++ b/tests/server/api/region_test.go @@ -57,7 +57,7 @@ func (suite *regionTestSuite) TearDownTest() { pdAddr := cluster.GetConfig().GetClientURL() for _, region := range leader.GetRegions() { url := fmt.Sprintf("%s/pd/api/v1/admin/cache/region/%d", pdAddr, region.GetID()) - err := tu.CheckDelete(testDialClient, url, tu.StatusOK(re)) + err := tu.CheckDelete(tests.TestDialClient, url, tu.StatusOK(re)) re.NoError(err) } re.Empty(leader.GetRegions()) @@ -71,7 +71,7 @@ func (suite *regionTestSuite) TearDownTest() { data, err := json.Marshal([]placement.GroupBundle{def}) re.NoError(err) urlPrefix := cluster.GetLeaderServer().GetAddr() - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/pd/api/v1/config/placement-rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/pd/api/v1/config/placement-rule", data, tu.StatusOK(re)) re.NoError(err) // clean stores for _, store := range leader.GetStores() { @@ -132,7 +132,7 @@ func (suite *regionTestSuite) checkSplitRegions(cluster *tests.TestCluster) { re.Equal([]uint64{newRegionID}, s.NewRegionsID) } re.NoError(failpoint.Enable("github.com/tikv/pd/pkg/schedule/handler/splitResponses", fmt.Sprintf("return(%v)", newRegionID))) - err := tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/regions/split", urlPrefix), []byte(body), checkOpt) + err := tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/regions/split", urlPrefix), []byte(body), checkOpt) re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/schedule/handler/splitResponses")) re.NoError(err) } @@ -162,7 +162,7 @@ func (suite *regionTestSuite) checkAccelerateRegionsScheduleInRange(cluster *tes checkRegionCount(re, cluster, regionCount) body := fmt.Sprintf(`{"start_key":"%s", "end_key": "%s"}`, hex.EncodeToString([]byte("a1")), hex.EncodeToString([]byte("a3"))) - err := tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/regions/accelerate-schedule", urlPrefix), []byte(body), + err := tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/regions/accelerate-schedule", urlPrefix), []byte(body), tu.StatusOK(re)) re.NoError(err) idList := leader.GetRaftCluster().GetSuspectRegions() @@ -198,7 +198,7 @@ func (suite *regionTestSuite) checkAccelerateRegionsScheduleInRanges(cluster *te body := fmt.Sprintf(`[{"start_key":"%s", "end_key": "%s"}, {"start_key":"%s", "end_key": "%s"}]`, hex.EncodeToString([]byte("a1")), hex.EncodeToString([]byte("a3")), hex.EncodeToString([]byte("a4")), hex.EncodeToString([]byte("a6"))) - err := tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/regions/accelerate-schedule/batch", urlPrefix), []byte(body), + err := tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/regions/accelerate-schedule/batch", urlPrefix), []byte(body), tu.StatusOK(re)) re.NoError(err) idList := leader.GetRaftCluster().GetSuspectRegions() @@ -239,7 +239,7 @@ func (suite *regionTestSuite) checkScatterRegions(cluster *tests.TestCluster) { checkRegionCount(re, cluster, 3) body := fmt.Sprintf(`{"start_key":"%s", "end_key": "%s"}`, hex.EncodeToString([]byte("b1")), hex.EncodeToString([]byte("b3"))) - err := tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/regions/scatter", urlPrefix), []byte(body), tu.StatusOK(re)) + err := tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/regions/scatter", urlPrefix), []byte(body), tu.StatusOK(re)) re.NoError(err) oc := leader.GetRaftCluster().GetOperatorController() if sche := cluster.GetSchedulingPrimaryServer(); sche != nil { @@ -253,7 +253,7 @@ func (suite *regionTestSuite) checkScatterRegions(cluster *tests.TestCluster) { re.True(op1 != nil || op2 != nil || op3 != nil) body = `{"regions_id": [701, 702, 703]}` - err = tu.CheckPostJSON(testDialClient, fmt.Sprintf("%s/regions/scatter", urlPrefix), []byte(body), tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, fmt.Sprintf("%s/regions/scatter", urlPrefix), []byte(body), tu.StatusOK(re)) re.NoError(err) } @@ -295,40 +295,40 @@ func (suite *regionTestSuite) checkRegionsReplicated(cluster *tests.TestCluster) // invalid url url := fmt.Sprintf(`%s/regions/replicated?startKey=%s&endKey=%s`, urlPrefix, "_", "t") - err := tu.CheckGetJSON(testDialClient, url, nil, tu.Status(re, http.StatusBadRequest)) + err := tu.CheckGetJSON(tests.TestDialClient, url, nil, tu.Status(re, http.StatusBadRequest)) re.NoError(err) url = fmt.Sprintf(`%s/regions/replicated?startKey=%s&endKey=%s`, urlPrefix, hex.EncodeToString(r1.GetStartKey()), "_") - err = tu.CheckGetJSON(testDialClient, url, nil, tu.Status(re, http.StatusBadRequest)) + err = tu.CheckGetJSON(tests.TestDialClient, url, nil, tu.Status(re, http.StatusBadRequest)) re.NoError(err) // correct test url = fmt.Sprintf(`%s/regions/replicated?startKey=%s&endKey=%s`, urlPrefix, hex.EncodeToString(r1.GetStartKey()), hex.EncodeToString(r1.GetEndKey())) - err = tu.CheckGetJSON(testDialClient, url, nil, tu.StatusOK(re)) + err = tu.CheckGetJSON(tests.TestDialClient, url, nil, tu.StatusOK(re)) re.NoError(err) // test one rule data, err := json.Marshal(bundle) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) re.NoError(err) tu.Eventually(re, func() bool { respBundle := make([]placement.GroupBundle, 0) - err = tu.CheckGetJSON(testDialClient, urlPrefix+"/config/placement-rule", nil, + err = tu.CheckGetJSON(tests.TestDialClient, urlPrefix+"/config/placement-rule", nil, tu.StatusOK(re), tu.ExtractJSON(re, &respBundle)) re.NoError(err) return len(respBundle) == 1 && respBundle[0].ID == "5" }) tu.Eventually(re, func() bool { - err = tu.ReadGetJSON(re, testDialClient, url, &status) + err = tu.ReadGetJSON(re, tests.TestDialClient, url, &status) re.NoError(err) return status == "REPLICATED" }) re.NoError(failpoint.Enable("github.com/tikv/pd/pkg/schedule/handler/mockPending", "return(true)")) - err = tu.ReadGetJSON(re, testDialClient, url, &status) + err = tu.ReadGetJSON(re, tests.TestDialClient, url, &status) re.NoError(err) re.Equal("PENDING", status) re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/schedule/handler/mockPending")) @@ -342,19 +342,19 @@ func (suite *regionTestSuite) checkRegionsReplicated(cluster *tests.TestCluster) }) data, err = json.Marshal(bundle) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) re.NoError(err) tu.Eventually(re, func() bool { respBundle := make([]placement.GroupBundle, 0) - err = tu.CheckGetJSON(testDialClient, urlPrefix+"/config/placement-rule", nil, + err = tu.CheckGetJSON(tests.TestDialClient, urlPrefix+"/config/placement-rule", nil, tu.StatusOK(re), tu.ExtractJSON(re, &respBundle)) re.NoError(err) return len(respBundle) == 1 && len(respBundle[0].Rules) == 2 }) tu.Eventually(re, func() bool { - err = tu.ReadGetJSON(re, testDialClient, url, &status) + err = tu.ReadGetJSON(re, tests.TestDialClient, url, &status) re.NoError(err) return status == "REPLICATED" }) @@ -371,12 +371,12 @@ func (suite *regionTestSuite) checkRegionsReplicated(cluster *tests.TestCluster) }) data, err = json.Marshal(bundle) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) re.NoError(err) tu.Eventually(re, func() bool { respBundle := make([]placement.GroupBundle, 0) - err = tu.CheckGetJSON(testDialClient, urlPrefix+"/config/placement-rule", nil, + err = tu.CheckGetJSON(tests.TestDialClient, urlPrefix+"/config/placement-rule", nil, tu.StatusOK(re), tu.ExtractJSON(re, &respBundle)) re.NoError(err) if len(respBundle) != 2 { @@ -388,7 +388,7 @@ func (suite *regionTestSuite) checkRegionsReplicated(cluster *tests.TestCluster) }) tu.Eventually(re, func() bool { - err = tu.ReadGetJSON(re, testDialClient, url, &status) + err = tu.ReadGetJSON(re, tests.TestDialClient, url, &status) re.NoError(err) return status == "INPROGRESS" }) @@ -398,7 +398,7 @@ func (suite *regionTestSuite) checkRegionsReplicated(cluster *tests.TestCluster) tests.MustPutRegionInfo(re, cluster, r1) tu.Eventually(re, func() bool { - err = tu.ReadGetJSON(re, testDialClient, url, &status) + err = tu.ReadGetJSON(re, tests.TestDialClient, url, &status) re.NoError(err) return status == "REPLICATED" }) @@ -422,9 +422,9 @@ func pauseAllCheckers(re *require.Assertions, cluster *tests.TestCluster) { for _, checkerName := range checkerNames { resp := make(map[string]any) url := fmt.Sprintf("%s/pd/api/v1/checker/%s", addr, checkerName) - err := tu.CheckPostJSON(testDialClient, url, []byte(`{"delay":1000}`), tu.StatusOK(re)) + err := tu.CheckPostJSON(tests.TestDialClient, url, []byte(`{"delay":1000}`), tu.StatusOK(re)) re.NoError(err) - err = tu.ReadGetJSON(re, testDialClient, url, &resp) + err = tu.ReadGetJSON(re, tests.TestDialClient, url, &resp) re.NoError(err) re.True(resp["paused"].(bool)) } diff --git a/tests/server/api/rule_test.go b/tests/server/api/rule_test.go index 4f60b5cfb28..16077a308f6 100644 --- a/tests/server/api/rule_test.go +++ b/tests/server/api/rule_test.go @@ -71,7 +71,7 @@ func (suite *ruleTestSuite) TearDownTest() { data, err := json.Marshal([]placement.GroupBundle{def}) re.NoError(err) urlPrefix := cluster.GetLeaderServer().GetAddr() - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/pd/api/v1/config/placement-rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/pd/api/v1/config/placement-rule", data, tu.StatusOK(re)) re.NoError(err) } suite.env.RunFuncInTwoModes(cleanFunc) @@ -171,7 +171,7 @@ func (suite *ruleTestSuite) checkSet(cluster *tests.TestCluster) { // clear suspect keyRanges to prevent test case from others leaderServer.GetRaftCluster().ClearSuspectKeyRanges() if testCase.success { - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", testCase.rawData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/rule", testCase.rawData, tu.StatusOK(re)) popKeyRangeMap := map[string]struct{}{} for i := 0; i < len(testCase.popKeyRange)/2; i++ { v, got := leaderServer.GetRaftCluster().PopOneSuspectKeyRange() @@ -185,7 +185,7 @@ func (suite *ruleTestSuite) checkSet(cluster *tests.TestCluster) { re.True(ok) } } else { - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", testCase.rawData, + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/rule", testCase.rawData, tu.StatusNotOK(re), tu.StringEqual(re, testCase.response)) } @@ -206,7 +206,7 @@ func (suite *ruleTestSuite) checkGet(cluster *tests.TestCluster) { rule := placement.Rule{GroupID: "a", ID: "20", StartKeyHex: "1111", EndKeyHex: "3333", Role: placement.Voter, Count: 1} data, err := json.Marshal(rule) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) re.NoError(err) testCases := []struct { @@ -234,11 +234,11 @@ func (suite *ruleTestSuite) checkGet(cluster *tests.TestCluster) { url := fmt.Sprintf("%s/rule/%s/%s", urlPrefix, testCase.rule.GroupID, testCase.rule.ID) if testCase.found { tu.Eventually(re, func() bool { - err = tu.ReadGetJSON(re, testDialClient, url, &resp) + err = tu.ReadGetJSON(re, tests.TestDialClient, url, &resp) return compareRule(&resp, &testCase.rule) }) } else { - err = tu.CheckGetJSON(testDialClient, url, nil, tu.Status(re, testCase.code)) + err = tu.CheckGetJSON(tests.TestDialClient, url, nil, tu.Status(re, testCase.code)) } re.NoError(err) } @@ -257,11 +257,11 @@ func (suite *ruleTestSuite) checkGetAll(cluster *tests.TestCluster) { rule := placement.Rule{GroupID: "b", ID: "20", StartKeyHex: "1111", EndKeyHex: "3333", Role: placement.Voter, Count: 1} data, err := json.Marshal(rule) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) re.NoError(err) var resp2 []*placement.Rule - err = tu.ReadGetJSON(re, testDialClient, urlPrefix+"/rules", &resp2) + err = tu.ReadGetJSON(re, tests.TestDialClient, urlPrefix+"/rules", &resp2) re.NoError(err) re.NotEmpty(resp2) } @@ -369,13 +369,13 @@ func (suite *ruleTestSuite) checkSetAll(cluster *tests.TestCluster) { for _, testCase := range testCases { suite.T().Log(testCase.name) if testCase.success { - err := tu.CheckPostJSON(testDialClient, urlPrefix+"/rules", testCase.rawData, tu.StatusOK(re)) + err := tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/rules", testCase.rawData, tu.StatusOK(re)) re.NoError(err) if testCase.isDefaultRule { re.Equal(int(leaderServer.GetPersistOptions().GetReplicationConfig().MaxReplicas), testCase.count) } } else { - err := tu.CheckPostJSON(testDialClient, urlPrefix+"/rules", testCase.rawData, + err := tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/rules", testCase.rawData, tu.StringEqual(re, testCase.response)) re.NoError(err) } @@ -395,13 +395,13 @@ func (suite *ruleTestSuite) checkGetAllByGroup(cluster *tests.TestCluster) { rule := placement.Rule{GroupID: "c", ID: "20", StartKeyHex: "1111", EndKeyHex: "3333", Role: placement.Voter, Count: 1} data, err := json.Marshal(rule) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) re.NoError(err) rule1 := placement.Rule{GroupID: "c", ID: "30", StartKeyHex: "1111", EndKeyHex: "3333", Role: placement.Voter, Count: 1} data, err = json.Marshal(rule1) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) re.NoError(err) testCases := []struct { @@ -426,7 +426,7 @@ func (suite *ruleTestSuite) checkGetAllByGroup(cluster *tests.TestCluster) { var resp []*placement.Rule url := fmt.Sprintf("%s/rules/group/%s", urlPrefix, testCase.groupID) tu.Eventually(re, func() bool { - err = tu.ReadGetJSON(re, testDialClient, url, &resp) + err = tu.ReadGetJSON(re, tests.TestDialClient, url, &resp) re.NoError(err) if len(resp) != testCase.count { return false @@ -452,7 +452,7 @@ func (suite *ruleTestSuite) checkGetAllByRegion(cluster *tests.TestCluster) { rule := placement.Rule{GroupID: "e", ID: "20", StartKeyHex: "1111", EndKeyHex: "3333", Role: placement.Voter, Count: 1} data, err := json.Marshal(rule) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) re.NoError(err) r := core.NewTestRegionInfo(4, 1, []byte{0x22, 0x22}, []byte{0x33, 0x33}) @@ -489,7 +489,7 @@ func (suite *ruleTestSuite) checkGetAllByRegion(cluster *tests.TestCluster) { if testCase.success { tu.Eventually(re, func() bool { - err = tu.ReadGetJSON(re, testDialClient, url, &resp) + err = tu.ReadGetJSON(re, tests.TestDialClient, url, &resp) for _, r := range resp { if r.GroupID == "e" { return compareRule(r, &rule) @@ -498,7 +498,7 @@ func (suite *ruleTestSuite) checkGetAllByRegion(cluster *tests.TestCluster) { return true }) } else { - err = tu.CheckGetJSON(testDialClient, url, nil, tu.Status(re, testCase.code)) + err = tu.CheckGetJSON(tests.TestDialClient, url, nil, tu.Status(re, testCase.code)) } re.NoError(err) } @@ -517,7 +517,7 @@ func (suite *ruleTestSuite) checkGetAllByKey(cluster *tests.TestCluster) { rule := placement.Rule{GroupID: "f", ID: "40", StartKeyHex: "8888", EndKeyHex: "9111", Role: placement.Voter, Count: 1} data, err := json.Marshal(rule) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) re.NoError(err) testCases := []struct { @@ -553,11 +553,11 @@ func (suite *ruleTestSuite) checkGetAllByKey(cluster *tests.TestCluster) { url := fmt.Sprintf("%s/rules/key/%s", urlPrefix, testCase.key) if testCase.success { tu.Eventually(re, func() bool { - err = tu.ReadGetJSON(re, testDialClient, url, &resp) + err = tu.ReadGetJSON(re, tests.TestDialClient, url, &resp) return len(resp) == testCase.respSize }) } else { - err = tu.CheckGetJSON(testDialClient, url, nil, tu.Status(re, testCase.code)) + err = tu.CheckGetJSON(tests.TestDialClient, url, nil, tu.Status(re, testCase.code)) } re.NoError(err) } @@ -576,7 +576,7 @@ func (suite *ruleTestSuite) checkDelete(cluster *tests.TestCluster) { rule := placement.Rule{GroupID: "g", ID: "10", StartKeyHex: "8888", EndKeyHex: "9111", Role: placement.Voter, Count: 1} data, err := json.Marshal(rule) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/rule", data, tu.StatusOK(re)) re.NoError(err) oldStartKey, err := hex.DecodeString(rule.StartKeyHex) re.NoError(err) @@ -610,7 +610,7 @@ func (suite *ruleTestSuite) checkDelete(cluster *tests.TestCluster) { url := fmt.Sprintf("%s/rule/%s/%s", urlPrefix, testCase.groupID, testCase.id) // clear suspect keyRanges to prevent test case from others leaderServer.GetRaftCluster().ClearSuspectKeyRanges() - err = tu.CheckDelete(testDialClient, url, tu.StatusOK(re)) + err = tu.CheckDelete(tests.TestDialClient, url, tu.StatusOK(re)) re.NoError(err) if len(testCase.popKeyRange) > 0 { popKeyRangeMap := map[string]struct{}{} @@ -747,10 +747,10 @@ func (suite *ruleTestSuite) checkBatch(cluster *tests.TestCluster) { for _, testCase := range testCases { suite.T().Log(testCase.name) if testCase.success { - err := tu.CheckPostJSON(testDialClient, urlPrefix+"/rules/batch", testCase.rawData, tu.StatusOK(re)) + err := tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/rules/batch", testCase.rawData, tu.StatusOK(re)) re.NoError(err) } else { - err := tu.CheckPostJSON(testDialClient, urlPrefix+"/rules/batch", testCase.rawData, + err := tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/rules/batch", testCase.rawData, tu.StatusNotOK(re), tu.StringEqual(re, testCase.response)) re.NoError(err) @@ -793,7 +793,7 @@ func (suite *ruleTestSuite) checkBundle(cluster *tests.TestCluster) { } data, err := json.Marshal(b2) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/placement-rule/foo", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/placement-rule/foo", data, tu.StatusOK(re)) re.NoError(err) // Get @@ -803,7 +803,7 @@ func (suite *ruleTestSuite) checkBundle(cluster *tests.TestCluster) { assertBundlesEqual(re, urlPrefix+"/placement-rule", []placement.GroupBundle{b1, b2}, 2) // Delete - err = tu.CheckDelete(testDialClient, urlPrefix+"/placement-rule/pd", tu.StatusOK(re)) + err = tu.CheckDelete(tests.TestDialClient, urlPrefix+"/placement-rule/pd", tu.StatusOK(re)) re.NoError(err) // GetAll again @@ -815,14 +815,14 @@ func (suite *ruleTestSuite) checkBundle(cluster *tests.TestCluster) { b3 := placement.GroupBundle{ID: "foobar", Index: 100} data, err = json.Marshal([]placement.GroupBundle{b1, b2, b3}) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/placement-rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/placement-rule", data, tu.StatusOK(re)) re.NoError(err) // GetAll again assertBundlesEqual(re, urlPrefix+"/placement-rule", []placement.GroupBundle{b1, b2, b3}, 3) // Delete using regexp - err = tu.CheckDelete(testDialClient, urlPrefix+"/placement-rule/"+url.PathEscape("foo.*")+"?regexp", tu.StatusOK(re)) + err = tu.CheckDelete(tests.TestDialClient, urlPrefix+"/placement-rule/"+url.PathEscape("foo.*")+"?regexp", tu.StatusOK(re)) re.NoError(err) // GetAll again @@ -838,7 +838,7 @@ func (suite *ruleTestSuite) checkBundle(cluster *tests.TestCluster) { } data, err = json.Marshal(b4) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/placement-rule/"+id, data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/placement-rule/"+id, data, tu.StatusOK(re)) re.NoError(err) b4.ID = id @@ -859,7 +859,7 @@ func (suite *ruleTestSuite) checkBundle(cluster *tests.TestCluster) { } data, err = json.Marshal([]placement.GroupBundle{b1, b4, b5}) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/placement-rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/placement-rule", data, tu.StatusOK(re)) re.NoError(err) b5.Rules[0].GroupID = b5.ID @@ -891,7 +891,7 @@ func (suite *ruleTestSuite) checkBundleBadRequest(cluster *tests.TestCluster) { {"/placement-rule", `[{"group_id":"foo", "rules": [{"group_id":"bar", "id":"baz", "role":"voter", "count":1}]}]`, false}, } for _, testCase := range testCases { - err := tu.CheckPostJSON(testDialClient, urlPrefix+testCase.uri, []byte(testCase.data), + err := tu.CheckPostJSON(tests.TestDialClient, urlPrefix+testCase.uri, []byte(testCase.data), func(_ []byte, code int, _ http.Header) { re.Equal(testCase.ok, code == http.StatusOK) }) @@ -976,12 +976,12 @@ func (suite *ruleTestSuite) checkLeaderAndVoter(cluster *tests.TestCluster) { for _, bundle := range bundles { data, err := json.Marshal(bundle) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) re.NoError(err) tu.Eventually(re, func() bool { respBundle := make([]placement.GroupBundle, 0) - err := tu.CheckGetJSON(testDialClient, urlPrefix+"/config/placement-rule", nil, + err := tu.CheckGetJSON(tests.TestDialClient, urlPrefix+"/config/placement-rule", nil, tu.StatusOK(re), tu.ExtractJSON(re, &respBundle)) re.NoError(err) re.Len(respBundle, 1) @@ -1144,7 +1144,7 @@ func (suite *ruleTestSuite) checkConcurrencyWith(cluster *tests.TestCluster, re.NoError(err) for j := 0; j < 10; j++ { expectResult.Lock() - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) re.NoError(err) expectResult.val = i expectResult.Unlock() @@ -1158,7 +1158,7 @@ func (suite *ruleTestSuite) checkConcurrencyWith(cluster *tests.TestCluster, re.NotZero(expectResult.val) tu.Eventually(re, func() bool { respBundle := make([]placement.GroupBundle, 0) - err := tu.CheckGetJSON(testDialClient, urlPrefix+"/config/placement-rule", nil, + err := tu.CheckGetJSON(tests.TestDialClient, urlPrefix+"/config/placement-rule", nil, tu.StatusOK(re), tu.ExtractJSON(re, &respBundle)) re.NoError(err) re.Len(respBundle, 1) @@ -1197,7 +1197,7 @@ func (suite *ruleTestSuite) checkLargeRules(cluster *tests.TestCluster) { func assertBundleEqual(re *require.Assertions, url string, expectedBundle placement.GroupBundle) { var bundle placement.GroupBundle tu.Eventually(re, func() bool { - err := tu.ReadGetJSON(re, testDialClient, url, &bundle) + err := tu.ReadGetJSON(re, tests.TestDialClient, url, &bundle) if err != nil { return false } @@ -1208,7 +1208,7 @@ func assertBundleEqual(re *require.Assertions, url string, expectedBundle placem func assertBundlesEqual(re *require.Assertions, url string, expectedBundles []placement.GroupBundle, expectedLen int) { var bundles []placement.GroupBundle tu.Eventually(re, func() bool { - err := tu.ReadGetJSON(re, testDialClient, url, &bundles) + err := tu.ReadGetJSON(re, tests.TestDialClient, url, &bundles) if err != nil { return false } @@ -1253,12 +1253,12 @@ func (suite *ruleTestSuite) postAndCheckRuleBundle(urlPrefix string, bundle []pl re := suite.Require() data, err := json.Marshal(bundle) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) re.NoError(err) tu.Eventually(re, func() bool { respBundle := make([]placement.GroupBundle, 0) - err = tu.CheckGetJSON(testDialClient, urlPrefix+"/config/placement-rule", nil, + err = tu.CheckGetJSON(tests.TestDialClient, urlPrefix+"/config/placement-rule", nil, tu.StatusOK(re), tu.ExtractJSON(re, &respBundle)) re.NoError(err) if len(respBundle) != len(bundle) { @@ -1364,19 +1364,19 @@ func (suite *regionRuleTestSuite) checkRegionPlacementRule(cluster *tests.TestCl fit := &placement.RegionFit{} u := fmt.Sprintf("%s/config/rules/region/%d/detail", urlPrefix, 1) - err := tu.ReadGetJSON(re, testDialClient, u, fit) + err := tu.ReadGetJSON(re, tests.TestDialClient, u, fit) re.NoError(err) re.Len(fit.RuleFits, 1) re.Len(fit.OrphanPeers, 1) u = fmt.Sprintf("%s/config/rules/region/%d/detail", urlPrefix, 2) fit = &placement.RegionFit{} - err = tu.ReadGetJSON(re, testDialClient, u, fit) + err = tu.ReadGetJSON(re, tests.TestDialClient, u, fit) re.NoError(err) re.Len(fit.RuleFits, 2) re.Empty(fit.OrphanPeers) u = fmt.Sprintf("%s/config/rules/region/%d/detail", urlPrefix, 3) fit = &placement.RegionFit{} - err = tu.ReadGetJSON(re, testDialClient, u, fit) + err = tu.ReadGetJSON(re, tests.TestDialClient, u, fit) re.NoError(err) re.Empty(fit.RuleFits) re.Len(fit.OrphanPeers, 2) @@ -1384,26 +1384,26 @@ func (suite *regionRuleTestSuite) checkRegionPlacementRule(cluster *tests.TestCl var label labeler.LabelRule escapedID := url.PathEscape("keyspaces/0") u = fmt.Sprintf("%s/config/region-label/rule/%s", urlPrefix, escapedID) - err = tu.ReadGetJSON(re, testDialClient, u, &label) + err = tu.ReadGetJSON(re, tests.TestDialClient, u, &label) re.NoError(err) re.Equal("keyspaces/0", label.ID) var labels []labeler.LabelRule u = fmt.Sprintf("%s/config/region-label/rules", urlPrefix) - err = tu.ReadGetJSON(re, testDialClient, u, &labels) + err = tu.ReadGetJSON(re, tests.TestDialClient, u, &labels) re.NoError(err) re.Len(labels, 1) re.Equal("keyspaces/0", labels[0].ID) u = fmt.Sprintf("%s/config/region-label/rules/ids", urlPrefix) - err = tu.CheckGetJSON(testDialClient, u, []byte(`["rule1", "rule3"]`), func(resp []byte, _ int, _ http.Header) { + err = tu.CheckGetJSON(tests.TestDialClient, u, []byte(`["rule1", "rule3"]`), func(resp []byte, _ int, _ http.Header) { err := json.Unmarshal(resp, &labels) re.NoError(err) re.Empty(labels) }) re.NoError(err) - err = tu.CheckGetJSON(testDialClient, u, []byte(`["keyspaces/0"]`), func(resp []byte, _ int, _ http.Header) { + err = tu.CheckGetJSON(tests.TestDialClient, u, []byte(`["keyspaces/0"]`), func(resp []byte, _ int, _ http.Header) { err := json.Unmarshal(resp, &labels) re.NoError(err) re.Len(labels, 1) @@ -1412,12 +1412,12 @@ func (suite *regionRuleTestSuite) checkRegionPlacementRule(cluster *tests.TestCl re.NoError(err) u = fmt.Sprintf("%s/config/rules/region/%d/detail", urlPrefix, 4) - err = tu.CheckGetJSON(testDialClient, u, nil, tu.Status(re, http.StatusNotFound), tu.StringContain( + err = tu.CheckGetJSON(tests.TestDialClient, u, nil, tu.Status(re, http.StatusNotFound), tu.StringContain( re, "region 4 not found")) re.NoError(err) u = fmt.Sprintf("%s/config/rules/region/%s/detail", urlPrefix, "id") - err = tu.CheckGetJSON(testDialClient, u, nil, tu.Status(re, http.StatusBadRequest), tu.StringContain( + err = tu.CheckGetJSON(tests.TestDialClient, u, nil, tu.Status(re, http.StatusBadRequest), tu.StringContain( re, errs.ErrRegionInvalidID.Error())) re.NoError(err) @@ -1426,7 +1426,7 @@ func (suite *regionRuleTestSuite) checkRegionPlacementRule(cluster *tests.TestCl reqData, e := json.Marshal(data) re.NoError(e) u = fmt.Sprintf("%s/config", urlPrefix) - err = tu.CheckPostJSON(testDialClient, u, reqData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, u, reqData, tu.StatusOK(re)) re.NoError(err) if sche := cluster.GetSchedulingPrimaryServer(); sche != nil { // wait for the scheduling server to update the config @@ -1435,7 +1435,7 @@ func (suite *regionRuleTestSuite) checkRegionPlacementRule(cluster *tests.TestCl }) } u = fmt.Sprintf("%s/config/rules/region/%d/detail", urlPrefix, 1) - err = tu.CheckGetJSON(testDialClient, u, nil, tu.Status(re, http.StatusPreconditionFailed), tu.StringContain( + err = tu.CheckGetJSON(tests.TestDialClient, u, nil, tu.Status(re, http.StatusPreconditionFailed), tu.StringContain( re, "placement rules feature is disabled")) re.NoError(err) } diff --git a/tests/server/api/scheduler_test.go b/tests/server/api/scheduler_test.go index 4f71315803a..10631dab158 100644 --- a/tests/server/api/scheduler_test.go +++ b/tests/server/api/scheduler_test.go @@ -84,12 +84,12 @@ func (suite *scheduleTestSuite) checkOriginAPI(cluster *tests.TestCluster) { input["store_id"] = 1 body, err := json.Marshal(input) re.NoError(err) - re.NoError(tu.CheckPostJSON(testDialClient, urlPrefix, body, tu.StatusOK(re))) + re.NoError(tu.CheckPostJSON(tests.TestDialClient, urlPrefix, body, tu.StatusOK(re))) suite.assertSchedulerExists(urlPrefix, "evict-leader-scheduler") resp := make(map[string]any) listURL := fmt.Sprintf("%s%s%s/%s/list", leaderAddr, apiPrefix, server.SchedulerConfigHandlerPath, "evict-leader-scheduler") - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) re.Len(resp["store-id-ranges"], 1) input1 := make(map[string]any) input1["name"] = "evict-leader-scheduler" @@ -97,35 +97,35 @@ func (suite *scheduleTestSuite) checkOriginAPI(cluster *tests.TestCluster) { body, err = json.Marshal(input1) re.NoError(err) re.NoError(failpoint.Enable("github.com/tikv/pd/pkg/schedule/schedulers/persistFail", "return(true)")) - re.NoError(tu.CheckPostJSON(testDialClient, urlPrefix, body, tu.StatusNotOK(re))) + re.NoError(tu.CheckPostJSON(tests.TestDialClient, urlPrefix, body, tu.StatusNotOK(re))) suite.assertSchedulerExists(urlPrefix, "evict-leader-scheduler") resp = make(map[string]any) - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) re.Len(resp["store-id-ranges"], 1) re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/schedule/schedulers/persistFail")) - re.NoError(tu.CheckPostJSON(testDialClient, urlPrefix, body, tu.StatusOK(re))) + re.NoError(tu.CheckPostJSON(tests.TestDialClient, urlPrefix, body, tu.StatusOK(re))) suite.assertSchedulerExists(urlPrefix, "evict-leader-scheduler") resp = make(map[string]any) - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) re.Len(resp["store-id-ranges"], 2) deleteURL := fmt.Sprintf("%s/%s", urlPrefix, "evict-leader-scheduler-1") - err = tu.CheckDelete(testDialClient, deleteURL, tu.StatusOK(re)) + err = tu.CheckDelete(tests.TestDialClient, deleteURL, tu.StatusOK(re)) re.NoError(err) suite.assertSchedulerExists(urlPrefix, "evict-leader-scheduler") resp1 := make(map[string]any) - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp1)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp1)) re.Len(resp1["store-id-ranges"], 1) deleteURL = fmt.Sprintf("%s/%s", urlPrefix, "evict-leader-scheduler-2") re.NoError(failpoint.Enable("github.com/tikv/pd/server/config/persistFail", "return(true)")) - err = tu.CheckDelete(testDialClient, deleteURL, tu.Status(re, http.StatusInternalServerError)) + err = tu.CheckDelete(tests.TestDialClient, deleteURL, tu.Status(re, http.StatusInternalServerError)) re.NoError(err) suite.assertSchedulerExists(urlPrefix, "evict-leader-scheduler") re.NoError(failpoint.Disable("github.com/tikv/pd/server/config/persistFail")) - err = tu.CheckDelete(testDialClient, deleteURL, tu.StatusOK(re)) + err = tu.CheckDelete(tests.TestDialClient, deleteURL, tu.StatusOK(re)) re.NoError(err) assertNoScheduler(re, urlPrefix, "evict-leader-scheduler") - re.NoError(tu.CheckGetJSON(testDialClient, listURL, nil, tu.Status(re, http.StatusNotFound))) - err = tu.CheckDelete(testDialClient, deleteURL, tu.Status(re, http.StatusNotFound)) + re.NoError(tu.CheckGetJSON(tests.TestDialClient, listURL, nil, tu.Status(re, http.StatusNotFound))) + err = tu.CheckDelete(tests.TestDialClient, deleteURL, tu.Status(re, http.StatusNotFound)) re.NoError(err) } @@ -164,7 +164,7 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { listURL := fmt.Sprintf("%s%s%s/%s/list", leaderAddr, apiPrefix, server.SchedulerConfigHandlerPath, name) resp := make(map[string]any) tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return resp["batch"] == 4.0 }) dataMap := make(map[string]any) @@ -172,15 +172,15 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { updateURL := fmt.Sprintf("%s%s%s/%s/config", leaderAddr, apiPrefix, server.SchedulerConfigHandlerPath, name) body, err := json.Marshal(dataMap) re.NoError(err) - re.NoError(tu.CheckPostJSON(testDialClient, updateURL, body, tu.StatusOK(re))) + re.NoError(tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.StatusOK(re))) resp = make(map[string]any) tu.Eventually(re, func() bool { // wait for scheduling server to be synced. - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return resp["batch"] == 3.0 }) // update again - err = tu.CheckPostJSON(testDialClient, updateURL, body, + err = tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.StatusOK(re), tu.StringEqual(re, "\"Config is the same with origin, so do nothing.\"\n")) re.NoError(err) @@ -189,17 +189,17 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { dataMap["batch"] = 100 body, err = json.Marshal(dataMap) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, updateURL, body, + err = tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.Status(re, http.StatusBadRequest), tu.StringEqual(re, "\"invalid batch size which should be an integer between 1 and 10\"\n")) re.NoError(err) resp = make(map[string]any) tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return resp["batch"] == 3.0 }) // empty body - err = tu.CheckPostJSON(testDialClient, updateURL, nil, + err = tu.CheckPostJSON(tests.TestDialClient, updateURL, nil, tu.Status(re, http.StatusInternalServerError), tu.StringEqual(re, "\"unexpected end of JSON input\"\n")) re.NoError(err) @@ -208,7 +208,7 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { dataMap["error"] = 3 body, err = json.Marshal(dataMap) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, updateURL, body, + err = tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.Status(re, http.StatusBadRequest), tu.StringEqual(re, "\"Config item is not found.\"\n")) re.NoError(err) @@ -245,7 +245,7 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { "history-sample-interval": "30s", } tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) re.Equal(len(expectMap), len(resp), "expect %v, got %v", expectMap, resp) for key := range expectMap { if !reflect.DeepEqual(resp[key], expectMap[key]) { @@ -260,10 +260,10 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { updateURL := fmt.Sprintf("%s%s%s/%s/config", leaderAddr, apiPrefix, server.SchedulerConfigHandlerPath, name) body, err := json.Marshal(dataMap) re.NoError(err) - re.NoError(tu.CheckPostJSON(testDialClient, updateURL, body, tu.StatusOK(re))) + re.NoError(tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.StatusOK(re))) resp = make(map[string]any) tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) for key := range expectMap { if !reflect.DeepEqual(resp[key], expectMap[key]) { return false @@ -273,7 +273,7 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { }) // update again - err = tu.CheckPostJSON(testDialClient, updateURL, body, + err = tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.StatusOK(re), tu.StringEqual(re, "Config is the same with origin, so do nothing.")) re.NoError(err) @@ -282,7 +282,7 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { dataMap["error"] = 3 body, err = json.Marshal(dataMap) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, updateURL, body, + err = tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.Status(re, http.StatusBadRequest), tu.StringEqual(re, "Config item is not found.")) re.NoError(err) @@ -295,7 +295,7 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { listURL := fmt.Sprintf("%s%s%s/%s/list", leaderAddr, apiPrefix, server.SchedulerConfigHandlerPath, name) resp := make(map[string]any) tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return resp["degree"] == 3.0 && resp["split-limit"] == 0.0 }) dataMap := make(map[string]any) @@ -303,19 +303,19 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { updateURL := fmt.Sprintf("%s%s%s/%s/config", leaderAddr, apiPrefix, server.SchedulerConfigHandlerPath, name) body, err := json.Marshal(dataMap) re.NoError(err) - re.NoError(tu.CheckPostJSON(testDialClient, updateURL, body, tu.StatusOK(re))) + re.NoError(tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.StatusOK(re))) resp = make(map[string]any) tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return resp["degree"] == 4.0 }) // update again - err = tu.CheckPostJSON(testDialClient, updateURL, body, + err = tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.StatusOK(re), tu.StringEqual(re, "Config is the same with origin, so do nothing.")) re.NoError(err) // empty body - err = tu.CheckPostJSON(testDialClient, updateURL, nil, + err = tu.CheckPostJSON(tests.TestDialClient, updateURL, nil, tu.Status(re, http.StatusInternalServerError), tu.StringEqual(re, "\"unexpected end of JSON input\"\n")) re.NoError(err) @@ -324,7 +324,7 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { dataMap["error"] = 3 body, err = json.Marshal(dataMap) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, updateURL, body, + err = tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.Status(re, http.StatusBadRequest), tu.StringEqual(re, "Config item is not found.")) re.NoError(err) @@ -353,7 +353,7 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { resp := make(map[string]any) listURL := fmt.Sprintf("%s%s%s/%s/list", leaderAddr, apiPrefix, server.SchedulerConfigHandlerPath, name) tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return resp["batch"] == 4.0 }) dataMap := make(map[string]any) @@ -361,14 +361,14 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { updateURL := fmt.Sprintf("%s%s%s/%s/config", leaderAddr, apiPrefix, server.SchedulerConfigHandlerPath, name) body, err := json.Marshal(dataMap) re.NoError(err) - re.NoError(tu.CheckPostJSON(testDialClient, updateURL, body, tu.StatusOK(re))) + re.NoError(tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.StatusOK(re))) resp = make(map[string]any) tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return resp["batch"] == 3.0 }) // update again - err = tu.CheckPostJSON(testDialClient, updateURL, body, + err = tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.StatusOK(re), tu.StringEqual(re, "\"Config is the same with origin, so do nothing.\"\n")) re.NoError(err) @@ -377,17 +377,17 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { dataMap["batch"] = 100 body, err = json.Marshal(dataMap) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, updateURL, body, + err = tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.Status(re, http.StatusBadRequest), tu.StringEqual(re, "\"invalid batch size which should be an integer between 1 and 10\"\n")) re.NoError(err) resp = make(map[string]any) tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return resp["batch"] == 3.0 }) // empty body - err = tu.CheckPostJSON(testDialClient, updateURL, nil, + err = tu.CheckPostJSON(tests.TestDialClient, updateURL, nil, tu.Status(re, http.StatusInternalServerError), tu.StringEqual(re, "\"unexpected end of JSON input\"\n")) re.NoError(err) @@ -396,7 +396,7 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { dataMap["error"] = 3 body, err = json.Marshal(dataMap) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, updateURL, body, + err = tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.Status(re, http.StatusBadRequest), tu.StringEqual(re, "\"Config item is not found.\"\n")) re.NoError(err) @@ -412,7 +412,7 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { expectedMap := make(map[string]any) expectedMap["1"] = []any{map[string]any{"end-key": "", "start-key": ""}} tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return reflect.DeepEqual(expectedMap, resp["store-id-ranges"]) }) @@ -423,25 +423,25 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { updateURL := fmt.Sprintf("%s%s%s/%s/config", leaderAddr, apiPrefix, server.SchedulerConfigHandlerPath, name) body, err := json.Marshal(input) re.NoError(err) - re.NoError(tu.CheckPostJSON(testDialClient, updateURL, body, tu.StatusOK(re))) + re.NoError(tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.StatusOK(re))) expectedMap["2"] = []any{map[string]any{"end-key": "", "start-key": ""}} resp = make(map[string]any) tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return reflect.DeepEqual(expectedMap, resp["store-id-ranges"]) }) // using /pd/v1/schedule-config/grant-leader-scheduler/config to delete exists store from grant-leader-scheduler deleteURL := fmt.Sprintf("%s%s%s/%s/delete/%s", leaderAddr, apiPrefix, server.SchedulerConfigHandlerPath, name, "2") - err = tu.CheckDelete(testDialClient, deleteURL, tu.StatusOK(re)) + err = tu.CheckDelete(tests.TestDialClient, deleteURL, tu.StatusOK(re)) re.NoError(err) delete(expectedMap, "2") resp = make(map[string]any) tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return reflect.DeepEqual(expectedMap, resp["store-id-ranges"]) }) - err = tu.CheckDelete(testDialClient, deleteURL, tu.Status(re, http.StatusNotFound)) + err = tu.CheckDelete(tests.TestDialClient, deleteURL, tu.Status(re, http.StatusNotFound)) re.NoError(err) }, }, @@ -454,7 +454,7 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { resp := make(map[string]any) listURL := fmt.Sprintf("%s%s%s/%s/list", leaderAddr, apiPrefix, server.SchedulerConfigHandlerPath, name) tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return resp["start-key"] == "" && resp["end-key"] == "" && resp["range-name"] == "test" }) resp["start-key"] = "a_00" @@ -462,10 +462,10 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { updateURL := fmt.Sprintf("%s%s%s/%s/config", leaderAddr, apiPrefix, server.SchedulerConfigHandlerPath, name) body, err := json.Marshal(resp) re.NoError(err) - re.NoError(tu.CheckPostJSON(testDialClient, updateURL, body, tu.StatusOK(re))) + re.NoError(tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.StatusOK(re))) resp = make(map[string]any) tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return resp["start-key"] == "a_00" && resp["end-key"] == "a_99" && resp["range-name"] == "test" }) }, @@ -481,7 +481,7 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { expectedMap := make(map[string]any) expectedMap["3"] = []any{map[string]any{"end-key": "", "start-key": ""}} tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return reflect.DeepEqual(expectedMap, resp["store-id-ranges"]) }) @@ -492,25 +492,25 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { updateURL := fmt.Sprintf("%s%s%s/%s/config", leaderAddr, apiPrefix, server.SchedulerConfigHandlerPath, name) body, err := json.Marshal(input) re.NoError(err) - re.NoError(tu.CheckPostJSON(testDialClient, updateURL, body, tu.StatusOK(re))) + re.NoError(tu.CheckPostJSON(tests.TestDialClient, updateURL, body, tu.StatusOK(re))) expectedMap["4"] = []any{map[string]any{"end-key": "", "start-key": ""}} resp = make(map[string]any) tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return reflect.DeepEqual(expectedMap, resp["store-id-ranges"]) }) // using /pd/v1/schedule-config/evict-leader-scheduler/config to delete exist store from evict-leader-scheduler deleteURL := fmt.Sprintf("%s%s%s/%s/delete/%s", leaderAddr, apiPrefix, server.SchedulerConfigHandlerPath, name, "4") - err = tu.CheckDelete(testDialClient, deleteURL, tu.StatusOK(re)) + err = tu.CheckDelete(tests.TestDialClient, deleteURL, tu.StatusOK(re)) re.NoError(err) delete(expectedMap, "4") resp = make(map[string]any) tu.Eventually(re, func() bool { - re.NoError(tu.ReadGetJSON(re, testDialClient, listURL, &resp)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, listURL, &resp)) return reflect.DeepEqual(expectedMap, resp["store-id-ranges"]) }) - err = tu.CheckDelete(testDialClient, deleteURL, tu.Status(re, http.StatusNotFound)) + err = tu.CheckDelete(tests.TestDialClient, deleteURL, tu.Status(re, http.StatusNotFound)) re.NoError(err) }, }, @@ -558,7 +558,7 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { input["delay"] = 30 pauseArgs, err := json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/all", pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/all", pauseArgs, tu.StatusOK(re)) re.NoError(err) for _, testCase := range testCases { @@ -572,7 +572,7 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { input["delay"] = 1 pauseArgs, err = json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/all", pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/all", pauseArgs, tu.StatusOK(re)) re.NoError(err) time.Sleep(time.Second) for _, testCase := range testCases { @@ -588,12 +588,12 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { input["delay"] = 30 pauseArgs, err = json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/all", pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/all", pauseArgs, tu.StatusOK(re)) re.NoError(err) input["delay"] = 0 pauseArgs, err = json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/all", pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/all", pauseArgs, tu.StatusOK(re)) re.NoError(err) for _, testCase := range testCases { createdName := testCase.createdName @@ -642,14 +642,14 @@ func (suite *scheduleTestSuite) checkDisable(cluster *tests.TestCluster) { u := fmt.Sprintf("%s%s/api/v1/config/schedule", leaderAddr, apiPrefix) var scheduleConfig sc.ScheduleConfig - err = tu.ReadGetJSON(re, testDialClient, u, &scheduleConfig) + err = tu.ReadGetJSON(re, tests.TestDialClient, u, &scheduleConfig) re.NoError(err) originSchedulers := scheduleConfig.Schedulers scheduleConfig.Schedulers = sc.SchedulerConfigs{sc.SchedulerConfig{Type: "shuffle-leader", Disable: true}} body, err = json.Marshal(scheduleConfig) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, u, body, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, u, body, tu.StatusOK(re)) re.NoError(err) assertNoScheduler(re, urlPrefix, name) @@ -659,7 +659,7 @@ func (suite *scheduleTestSuite) checkDisable(cluster *tests.TestCluster) { scheduleConfig.Schedulers = originSchedulers body, err = json.Marshal(scheduleConfig) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, u, body, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, u, body, tu.StatusOK(re)) re.NoError(err) deleteScheduler(re, urlPrefix, name) @@ -667,13 +667,13 @@ func (suite *scheduleTestSuite) checkDisable(cluster *tests.TestCluster) { } func addScheduler(re *require.Assertions, urlPrefix string, body []byte) { - err := tu.CheckPostJSON(testDialClient, urlPrefix, body, tu.StatusOK(re)) + err := tu.CheckPostJSON(tests.TestDialClient, urlPrefix, body, tu.StatusOK(re)) re.NoError(err) } func deleteScheduler(re *require.Assertions, urlPrefix string, createdName string) { deleteURL := fmt.Sprintf("%s/%s", urlPrefix, createdName) - err := tu.CheckDelete(testDialClient, deleteURL, tu.StatusOK(re)) + err := tu.CheckDelete(tests.TestDialClient, deleteURL, tu.StatusOK(re)) re.NoError(err) } @@ -682,9 +682,9 @@ func (suite *scheduleTestSuite) testPauseOrResume(re *require.Assertions, urlPre createdName = name } var schedulers []string - tu.ReadGetJSON(re, testDialClient, urlPrefix, &schedulers) + tu.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &schedulers) if !slice.Contains(schedulers, createdName) { - err := tu.CheckPostJSON(testDialClient, urlPrefix, body, tu.StatusOK(re)) + err := tu.CheckPostJSON(tests.TestDialClient, urlPrefix, body, tu.StatusOK(re)) re.NoError(err) } suite.assertSchedulerExists(urlPrefix, createdName) // wait for scheduler to be synced. @@ -694,14 +694,14 @@ func (suite *scheduleTestSuite) testPauseOrResume(re *require.Assertions, urlPre input["delay"] = 30 pauseArgs, err := json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/"+createdName, pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/"+createdName, pauseArgs, tu.StatusOK(re)) re.NoError(err) isPaused := isSchedulerPaused(re, urlPrefix, createdName) re.True(isPaused) input["delay"] = 1 pauseArgs, err = json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/"+createdName, pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/"+createdName, pauseArgs, tu.StatusOK(re)) re.NoError(err) time.Sleep(time.Second * 2) isPaused = isSchedulerPaused(re, urlPrefix, createdName) @@ -712,12 +712,12 @@ func (suite *scheduleTestSuite) testPauseOrResume(re *require.Assertions, urlPre input["delay"] = 30 pauseArgs, err = json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/"+createdName, pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/"+createdName, pauseArgs, tu.StatusOK(re)) re.NoError(err) input["delay"] = 0 pauseArgs, err = json.Marshal(input) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, urlPrefix+"/"+createdName, pauseArgs, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, urlPrefix+"/"+createdName, pauseArgs, tu.StatusOK(re)) re.NoError(err) isPaused = isSchedulerPaused(re, urlPrefix, createdName) re.False(isPaused) @@ -742,7 +742,7 @@ func (suite *scheduleTestSuite) checkEmptySchedulers(cluster *tests.TestCluster) } for _, query := range []string{"", "?status=paused", "?status=disabled"} { schedulers := make([]string, 0) - re.NoError(tu.ReadGetJSON(re, testDialClient, urlPrefix+query, &schedulers)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, urlPrefix+query, &schedulers)) for _, scheduler := range schedulers { if strings.Contains(query, "disable") { input := make(map[string]any) @@ -755,7 +755,7 @@ func (suite *scheduleTestSuite) checkEmptySchedulers(cluster *tests.TestCluster) } } tu.Eventually(re, func() bool { - resp, err := apiutil.GetJSON(testDialClient, urlPrefix+query, nil) + resp, err := apiutil.GetJSON(tests.TestDialClient, urlPrefix+query, nil) re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) @@ -770,7 +770,7 @@ func (suite *scheduleTestSuite) assertSchedulerExists(urlPrefix string, schedule var schedulers []string re := suite.Require() tu.Eventually(re, func() bool { - err := tu.ReadGetJSON(re, testDialClient, urlPrefix, &schedulers, + err := tu.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &schedulers, tu.StatusOK(re)) re.NoError(err) return slice.Contains(schedulers, scheduler) @@ -780,7 +780,7 @@ func (suite *scheduleTestSuite) assertSchedulerExists(urlPrefix string, schedule func assertNoScheduler(re *require.Assertions, urlPrefix string, scheduler string) { var schedulers []string tu.Eventually(re, func() bool { - err := tu.ReadGetJSON(re, testDialClient, urlPrefix, &schedulers, + err := tu.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &schedulers, tu.StatusOK(re)) re.NoError(err) return !slice.Contains(schedulers, scheduler) @@ -789,7 +789,7 @@ func assertNoScheduler(re *require.Assertions, urlPrefix string, scheduler strin func isSchedulerPaused(re *require.Assertions, urlPrefix, name string) bool { var schedulers []string - err := tu.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s?status=paused", urlPrefix), &schedulers, + err := tu.ReadGetJSON(re, tests.TestDialClient, fmt.Sprintf("%s?status=paused", urlPrefix), &schedulers, tu.StatusOK(re)) re.NoError(err) for _, scheduler := range schedulers { diff --git a/tests/server/api/testutil.go b/tests/server/api/testutil.go index 1b2f3d09e3d..163a25c9bbb 100644 --- a/tests/server/api/testutil.go +++ b/tests/server/api/testutil.go @@ -23,6 +23,7 @@ import ( "path" "github.com/stretchr/testify/require" + "github.com/tikv/pd/tests" ) const ( @@ -30,13 +31,6 @@ const ( schedulerConfigPrefix = "/pd/api/v1/scheduler-config" ) -// dialClient used to dial http request. -var dialClient = &http.Client{ - Transport: &http.Transport{ - DisableKeepAlives: true, - }, -} - // MustAddScheduler adds a scheduler with HTTP API. func MustAddScheduler( re *require.Assertions, serverAddr string, @@ -53,7 +47,7 @@ func MustAddScheduler( httpReq, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s%s", serverAddr, schedulersPrefix), bytes.NewBuffer(data)) re.NoError(err) // Send request. - resp, err := dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer resp.Body.Close() data, err = io.ReadAll(resp.Body) @@ -65,7 +59,7 @@ func MustAddScheduler( func MustDeleteScheduler(re *require.Assertions, serverAddr, schedulerName string) { httpReq, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("%s%s/%s", serverAddr, schedulersPrefix, schedulerName), http.NoBody) re.NoError(err) - resp, err := dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer resp.Body.Close() data, err := io.ReadAll(resp.Body) @@ -84,7 +78,7 @@ func MustCallSchedulerConfigAPI( args = append([]string{schedulerConfigPrefix, schedulerName}, args...) httpReq, err := http.NewRequest(method, fmt.Sprintf("%s%s", serverAddr, path.Join(args...)), bytes.NewBuffer(data)) re.NoError(err) - resp, err := dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer resp.Body.Close() data, err = io.ReadAll(resp.Body) diff --git a/tests/server/apiv2/handlers/testutil.go b/tests/server/apiv2/handlers/testutil.go index c5682aafbce..1a40e8d1ac7 100644 --- a/tests/server/apiv2/handlers/testutil.go +++ b/tests/server/apiv2/handlers/testutil.go @@ -34,13 +34,6 @@ const ( keyspaceGroupsPrefix = "/pd/api/v2/tso/keyspace-groups" ) -// dialClient used to dial http request. -var dialClient = &http.Client{ - Transport: &http.Transport{ - DisableKeepAlives: true, - }, -} - func sendLoadRangeRequest(re *require.Assertions, server *tests.TestServer, token, limit string) *handlers.LoadAllKeyspacesResponse { // Construct load range request. httpReq, err := http.NewRequest(http.MethodGet, server.GetAddr()+keyspacesPrefix, http.NoBody) @@ -50,7 +43,7 @@ func sendLoadRangeRequest(re *require.Assertions, server *tests.TestServer, toke query.Add("limit", limit) httpReq.URL.RawQuery = query.Encode() // Send request. - httpResp, err := dialClient.Do(httpReq) + httpResp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer httpResp.Body.Close() re.Equal(http.StatusOK, httpResp.StatusCode) @@ -67,7 +60,7 @@ func sendUpdateStateRequest(re *require.Assertions, server *tests.TestServer, na re.NoError(err) httpReq, err := http.NewRequest(http.MethodPut, server.GetAddr()+keyspacesPrefix+"/"+name+"/state", bytes.NewBuffer(data)) re.NoError(err) - httpResp, err := dialClient.Do(httpReq) + httpResp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer httpResp.Body.Close() if httpResp.StatusCode != http.StatusOK { @@ -86,7 +79,7 @@ func MustCreateKeyspace(re *require.Assertions, server *tests.TestServer, reques re.NoError(err) httpReq, err := http.NewRequest(http.MethodPost, server.GetAddr()+keyspacesPrefix, bytes.NewBuffer(data)) re.NoError(err) - resp, err := dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) @@ -110,7 +103,7 @@ func mustUpdateKeyspaceConfig(re *require.Assertions, server *tests.TestServer, re.NoError(err) httpReq, err := http.NewRequest(http.MethodPatch, server.GetAddr()+keyspacesPrefix+"/"+name+"/config", bytes.NewBuffer(data)) re.NoError(err) - resp, err := dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) @@ -122,7 +115,7 @@ func mustUpdateKeyspaceConfig(re *require.Assertions, server *tests.TestServer, } func mustLoadKeyspaces(re *require.Assertions, server *tests.TestServer, name string) *keyspacepb.KeyspaceMeta { - resp, err := dialClient.Get(server.GetAddr() + keyspacesPrefix + "/" + name) + resp, err := tests.TestDialClient.Get(server.GetAddr() + keyspacesPrefix + "/" + name) re.NoError(err) defer resp.Body.Close() re.Equal(http.StatusOK, resp.StatusCode) @@ -143,7 +136,7 @@ func MustLoadKeyspaceGroups(re *require.Assertions, server *tests.TestServer, to query.Add("limit", limit) httpReq.URL.RawQuery = query.Encode() // Send request. - httpResp, err := dialClient.Do(httpReq) + httpResp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer httpResp.Body.Close() data, err := io.ReadAll(httpResp.Body) @@ -159,7 +152,7 @@ func tryCreateKeyspaceGroup(re *require.Assertions, server *tests.TestServer, re re.NoError(err) httpReq, err := http.NewRequest(http.MethodPost, server.GetAddr()+keyspaceGroupsPrefix, bytes.NewBuffer(data)) re.NoError(err) - resp, err := dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer resp.Body.Close() data, err = io.ReadAll(resp.Body) @@ -184,7 +177,7 @@ func MustLoadKeyspaceGroupByID(re *require.Assertions, server *tests.TestServer, func TryLoadKeyspaceGroupByID(re *require.Assertions, server *tests.TestServer, id uint32) (*endpoint.KeyspaceGroup, int) { httpReq, err := http.NewRequest(http.MethodGet, server.GetAddr()+keyspaceGroupsPrefix+fmt.Sprintf("/%d", id), http.NoBody) re.NoError(err) - resp, err := dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer resp.Body.Close() data, err := io.ReadAll(resp.Body) @@ -214,7 +207,7 @@ func FailCreateKeyspaceGroupWithCode(re *require.Assertions, server *tests.TestS func MustDeleteKeyspaceGroup(re *require.Assertions, server *tests.TestServer, id uint32) { httpReq, err := http.NewRequest(http.MethodDelete, server.GetAddr()+keyspaceGroupsPrefix+fmt.Sprintf("/%d", id), http.NoBody) re.NoError(err) - resp, err := dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer resp.Body.Close() data, err := io.ReadAll(resp.Body) @@ -229,7 +222,7 @@ func MustSplitKeyspaceGroup(re *require.Assertions, server *tests.TestServer, id httpReq, err := http.NewRequest(http.MethodPost, server.GetAddr()+keyspaceGroupsPrefix+fmt.Sprintf("/%d/split", id), bytes.NewBuffer(data)) re.NoError(err) // Send request. - resp, err := dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer resp.Body.Close() data, err = io.ReadAll(resp.Body) @@ -245,7 +238,7 @@ func MustFinishSplitKeyspaceGroup(re *require.Assertions, server *tests.TestServ return false } // Send request. - resp, err := dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) if err != nil { return false } @@ -270,7 +263,7 @@ func MustMergeKeyspaceGroup(re *require.Assertions, server *tests.TestServer, id httpReq, err := http.NewRequest(http.MethodPost, server.GetAddr()+keyspaceGroupsPrefix+fmt.Sprintf("/%d/merge", id), bytes.NewBuffer(data)) re.NoError(err) // Send request. - resp, err := dialClient.Do(httpReq) + resp, err := tests.TestDialClient.Do(httpReq) re.NoError(err) defer resp.Body.Close() data, err = io.ReadAll(resp.Body) diff --git a/tests/server/cluster/cluster_test.go b/tests/server/cluster/cluster_test.go index aea5ff73968..61a4561c55a 100644 --- a/tests/server/cluster/cluster_test.go +++ b/tests/server/cluster/cluster_test.go @@ -753,20 +753,19 @@ func TestConcurrentHandleRegion(t *testing.T) { re.NoError(err) peerID, err := id.Alloc() re.NoError(err) - regionID, err := id.Alloc() - re.NoError(err) peer := &metapb.Peer{Id: peerID, StoreId: store.GetId()} regionReq := &pdpb.RegionHeartbeatRequest{ Header: testutil.NewRequestHeader(clusterID), Region: &metapb.Region{ - Id: regionID, + // mock error msg to trigger stream.Recv() + Id: 0, Peers: []*metapb.Peer{peer}, }, Leader: peer, } err = stream.Send(regionReq) re.NoError(err) - // make sure the first store can receive one response + // make sure the first store can receive one response(error msg) if i == 0 { wg.Add(1) } diff --git a/tests/server/config/config_test.go b/tests/server/config/config_test.go index 57e4272f7ea..67d7478caa0 100644 --- a/tests/server/config/config_test.go +++ b/tests/server/config/config_test.go @@ -36,13 +36,6 @@ import ( "github.com/tikv/pd/tests" ) -// testDialClient used to dial http request. -var testDialClient = &http.Client{ - Transport: &http.Transport{ - DisableKeepAlives: true, - }, -} - func TestRateLimitConfigReload(t *testing.T) { re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) @@ -65,7 +58,7 @@ func TestRateLimitConfigReload(t *testing.T) { data, err := json.Marshal(input) re.NoError(err) req, _ := http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) - resp, err := testDialClient.Do(req) + resp, err := tests.TestDialClient.Do(req) re.NoError(err) resp.Body.Close() re.True(leader.GetServer().GetServiceMiddlewarePersistOptions().IsRateLimitEnabled()) @@ -109,7 +102,7 @@ func (suite *configTestSuite) checkConfigAll(cluster *tests.TestCluster) { addr := fmt.Sprintf("%s/pd/api/v1/config", urlPrefix) cfg := &config.Config{} tu.Eventually(re, func() bool { - err := tu.ReadGetJSON(re, testDialClient, addr, cfg) + err := tu.ReadGetJSON(re, tests.TestDialClient, addr, cfg) re.NoError(err) return cfg.PDServerCfg.DashboardAddress != "auto" }) @@ -118,7 +111,7 @@ func (suite *configTestSuite) checkConfigAll(cluster *tests.TestCluster) { r := map[string]int{"max-replicas": 5} postData, err := json.Marshal(r) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusOK(re)) re.NoError(err) l := map[string]any{ "location-labels": "zone,rack", @@ -126,7 +119,7 @@ func (suite *configTestSuite) checkConfigAll(cluster *tests.TestCluster) { } postData, err = json.Marshal(l) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusOK(re)) re.NoError(err) l = map[string]any{ @@ -134,7 +127,7 @@ func (suite *configTestSuite) checkConfigAll(cluster *tests.TestCluster) { } postData, err = json.Marshal(l) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusOK(re)) re.NoError(err) cfg.Replication.MaxReplicas = 5 cfg.Replication.LocationLabels = []string{"zone", "rack"} @@ -143,7 +136,7 @@ func (suite *configTestSuite) checkConfigAll(cluster *tests.TestCluster) { tu.Eventually(re, func() bool { newCfg := &config.Config{} - err = tu.ReadGetJSON(re, testDialClient, addr, newCfg) + err = tu.ReadGetJSON(re, tests.TestDialClient, addr, newCfg) re.NoError(err) return suite.Equal(newCfg, cfg) }) @@ -160,7 +153,7 @@ func (suite *configTestSuite) checkConfigAll(cluster *tests.TestCluster) { } postData, err = json.Marshal(l) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusOK(re)) re.NoError(err) cfg.Schedule.EnableTiKVSplitRegion = false cfg.Schedule.TolerantSizeRatio = 2.5 @@ -174,7 +167,7 @@ func (suite *configTestSuite) checkConfigAll(cluster *tests.TestCluster) { cfg.ClusterVersion = *v tu.Eventually(re, func() bool { newCfg1 := &config.Config{} - err = tu.ReadGetJSON(re, testDialClient, addr, newCfg1) + err = tu.ReadGetJSON(re, tests.TestDialClient, addr, newCfg1) re.NoError(err) return suite.Equal(cfg, newCfg1) }) @@ -183,7 +176,7 @@ func (suite *configTestSuite) checkConfigAll(cluster *tests.TestCluster) { l["schedule.enable-tikv-split-region"] = "true" postData, err = json.Marshal(l) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusOK(re)) re.NoError(err) // illegal prefix @@ -192,7 +185,7 @@ func (suite *configTestSuite) checkConfigAll(cluster *tests.TestCluster) { } postData, err = json.Marshal(l) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusNotOK(re), tu.StringContain(re, "not found")) re.NoError(err) @@ -203,7 +196,7 @@ func (suite *configTestSuite) checkConfigAll(cluster *tests.TestCluster) { } postData, err = json.Marshal(l) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusNotOK(re), tu.StringContain(re, "cannot update config prefix")) re.NoError(err) @@ -214,7 +207,7 @@ func (suite *configTestSuite) checkConfigAll(cluster *tests.TestCluster) { } postData, err = json.Marshal(l) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusNotOK(re), tu.StringContain(re, "not found")) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusNotOK(re), tu.StringContain(re, "not found")) re.NoError(err) } @@ -230,16 +223,16 @@ func (suite *configTestSuite) checkConfigSchedule(cluster *tests.TestCluster) { addr := fmt.Sprintf("%s/pd/api/v1/config/schedule", urlPrefix) scheduleConfig := &sc.ScheduleConfig{} - re.NoError(tu.ReadGetJSON(re, testDialClient, addr, scheduleConfig)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, addr, scheduleConfig)) scheduleConfig.MaxStoreDownTime.Duration = time.Second postData, err := json.Marshal(scheduleConfig) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusOK(re)) re.NoError(err) tu.Eventually(re, func() bool { scheduleConfig1 := &sc.ScheduleConfig{} - re.NoError(tu.ReadGetJSON(re, testDialClient, addr, scheduleConfig1)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, addr, scheduleConfig1)) return reflect.DeepEqual(*scheduleConfig1, *scheduleConfig) }) } @@ -255,33 +248,33 @@ func (suite *configTestSuite) checkConfigReplication(cluster *tests.TestCluster) addr := fmt.Sprintf("%s/pd/api/v1/config/replicate", urlPrefix) rc := &sc.ReplicationConfig{} - err := tu.ReadGetJSON(re, testDialClient, addr, rc) + err := tu.ReadGetJSON(re, tests.TestDialClient, addr, rc) re.NoError(err) rc.MaxReplicas = 5 rc1 := map[string]int{"max-replicas": 5} postData, err := json.Marshal(rc1) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusOK(re)) re.NoError(err) rc.LocationLabels = []string{"zone", "rack"} rc2 := map[string]string{"location-labels": "zone,rack"} postData, err = json.Marshal(rc2) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusOK(re)) re.NoError(err) rc.IsolationLevel = "zone" rc3 := map[string]string{"isolation-level": "zone"} postData, err = json.Marshal(rc3) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusOK(re)) re.NoError(err) rc4 := &sc.ReplicationConfig{} tu.Eventually(re, func() bool { - err = tu.ReadGetJSON(re, testDialClient, addr, rc4) + err = tu.ReadGetJSON(re, tests.TestDialClient, addr, rc4) re.NoError(err) return reflect.DeepEqual(*rc4, *rc) }) @@ -299,7 +292,7 @@ func (suite *configTestSuite) checkConfigLabelProperty(cluster *tests.TestCluste addr := urlPrefix + "/pd/api/v1/config/label-property" loadProperties := func() config.LabelPropertyConfig { var cfg config.LabelPropertyConfig - err := tu.ReadGetJSON(re, testDialClient, addr, &cfg) + err := tu.ReadGetJSON(re, tests.TestDialClient, addr, &cfg) re.NoError(err) return cfg } @@ -313,7 +306,7 @@ func (suite *configTestSuite) checkConfigLabelProperty(cluster *tests.TestCluste `{"type": "bar", "action": "set", "label-key": "host", "label-value": "h1"}`, } for _, cmd := range cmds { - err := tu.CheckPostJSON(testDialClient, addr, []byte(cmd), tu.StatusOK(re)) + err := tu.CheckPostJSON(tests.TestDialClient, addr, []byte(cmd), tu.StatusOK(re)) re.NoError(err) } @@ -330,7 +323,7 @@ func (suite *configTestSuite) checkConfigLabelProperty(cluster *tests.TestCluste `{"type": "bar", "action": "delete", "label-key": "host", "label-value": "h1"}`, } for _, cmd := range cmds { - err := tu.CheckPostJSON(testDialClient, addr, []byte(cmd), tu.StatusOK(re)) + err := tu.CheckPostJSON(tests.TestDialClient, addr, []byte(cmd), tu.StatusOK(re)) re.NoError(err) } @@ -353,7 +346,7 @@ func (suite *configTestSuite) checkConfigDefault(cluster *tests.TestCluster) { r := map[string]int{"max-replicas": 5} postData, err := json.Marshal(r) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusOK(re)) re.NoError(err) l := map[string]any{ "location-labels": "zone,rack", @@ -361,7 +354,7 @@ func (suite *configTestSuite) checkConfigDefault(cluster *tests.TestCluster) { } postData, err = json.Marshal(l) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusOK(re)) re.NoError(err) l = map[string]any{ @@ -369,12 +362,12 @@ func (suite *configTestSuite) checkConfigDefault(cluster *tests.TestCluster) { } postData, err = json.Marshal(l) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusOK(re)) re.NoError(err) addr = fmt.Sprintf("%s/pd/api/v1/config/default", urlPrefix) defaultCfg := &config.Config{} - err = tu.ReadGetJSON(re, testDialClient, addr, defaultCfg) + err = tu.ReadGetJSON(re, tests.TestDialClient, addr, defaultCfg) re.NoError(err) re.Equal(uint64(3), defaultCfg.Replication.MaxReplicas) @@ -398,10 +391,10 @@ func (suite *configTestSuite) checkConfigPDServer(cluster *tests.TestCluster) { } postData, err := json.Marshal(ms) re.NoError(err) - re.NoError(tu.CheckPostJSON(testDialClient, addrPost, postData, tu.StatusOK(re))) + re.NoError(tu.CheckPostJSON(tests.TestDialClient, addrPost, postData, tu.StatusOK(re))) addrGet := fmt.Sprintf("%s/pd/api/v1/config/pd-server", urlPrefix) sc := &config.PDServerConfig{} - re.NoError(tu.ReadGetJSON(re, testDialClient, addrGet, sc)) + re.NoError(tu.ReadGetJSON(re, tests.TestDialClient, addrGet, sc)) re.Equal(bool(true), sc.UseRegionStorage) re.Equal("table", sc.KeyType) re.Equal(typeutil.StringSlice([]string{}), sc.RuntimeServices) @@ -525,28 +518,28 @@ func (suite *configTestSuite) checkConfigTTL(cluster *tests.TestCluster) { re.NoError(err) // test no config and cleaning up - err = tu.CheckPostJSON(testDialClient, createTTLUrl(urlPrefix, 0), postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, createTTLUrl(urlPrefix, 0), postData, tu.StatusOK(re)) re.NoError(err) assertTTLConfig(re, cluster, false) // test time goes by - err = tu.CheckPostJSON(testDialClient, createTTLUrl(urlPrefix, 5), postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, createTTLUrl(urlPrefix, 5), postData, tu.StatusOK(re)) re.NoError(err) assertTTLConfig(re, cluster, true) time.Sleep(5 * time.Second) assertTTLConfig(re, cluster, false) // test cleaning up - err = tu.CheckPostJSON(testDialClient, createTTLUrl(urlPrefix, 5), postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, createTTLUrl(urlPrefix, 5), postData, tu.StatusOK(re)) re.NoError(err) assertTTLConfig(re, cluster, true) - err = tu.CheckPostJSON(testDialClient, createTTLUrl(urlPrefix, 0), postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, createTTLUrl(urlPrefix, 0), postData, tu.StatusOK(re)) re.NoError(err) assertTTLConfig(re, cluster, false) postData, err = json.Marshal(invalidTTLConfig) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, createTTLUrl(urlPrefix, 1), postData, + err = tu.CheckPostJSON(tests.TestDialClient, createTTLUrl(urlPrefix, 1), postData, tu.StatusNotOK(re), tu.StringEqual(re, "\"unsupported ttl config schedule.invalid-ttl-config\"\n")) re.NoError(err) @@ -557,7 +550,7 @@ func (suite *configTestSuite) checkConfigTTL(cluster *tests.TestCluster) { postData, err = json.Marshal(mergeConfig) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, createTTLUrl(urlPrefix, 1), postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, createTTLUrl(urlPrefix, 1), postData, tu.StatusOK(re)) re.NoError(err) assertTTLConfigItemEqual(re, cluster, "max-merge-region-size", uint64(999)) // max-merge-region-keys should keep consistence with max-merge-region-size. @@ -569,7 +562,7 @@ func (suite *configTestSuite) checkConfigTTL(cluster *tests.TestCluster) { } postData, err = json.Marshal(mergeConfig) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, createTTLUrl(urlPrefix, 10), postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, createTTLUrl(urlPrefix, 10), postData, tu.StatusOK(re)) re.NoError(err) assertTTLConfigItemEqual(re, cluster, "enable-tikv-split-region", true) } @@ -585,7 +578,7 @@ func (suite *configTestSuite) checkTTLConflict(cluster *tests.TestCluster) { addr := createTTLUrl(urlPrefix, 1) postData, err := json.Marshal(ttlConfig) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusOK(re)) re.NoError(err) assertTTLConfig(re, cluster, true) @@ -593,16 +586,16 @@ func (suite *configTestSuite) checkTTLConflict(cluster *tests.TestCluster) { postData, err = json.Marshal(cfg) re.NoError(err) addr = fmt.Sprintf("%s/pd/api/v1/config", urlPrefix) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusNotOK(re), tu.StringEqual(re, "\"need to clean up TTL first for schedule.max-snapshot-count\"\n")) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusNotOK(re), tu.StringEqual(re, "\"need to clean up TTL first for schedule.max-snapshot-count\"\n")) re.NoError(err) addr = fmt.Sprintf("%s/pd/api/v1/config/schedule", urlPrefix) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusNotOK(re), tu.StringEqual(re, "\"need to clean up TTL first for schedule.max-snapshot-count\"\n")) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusNotOK(re), tu.StringEqual(re, "\"need to clean up TTL first for schedule.max-snapshot-count\"\n")) re.NoError(err) cfg = map[string]any{"schedule.max-snapshot-count": 30} postData, err = json.Marshal(cfg) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, createTTLUrl(urlPrefix, 0), postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, createTTLUrl(urlPrefix, 0), postData, tu.StatusOK(re)) re.NoError(err) - err = tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(re)) + err = tu.CheckPostJSON(tests.TestDialClient, addr, postData, tu.StatusOK(re)) re.NoError(err) } diff --git a/tests/server/member/member_test.go b/tests/server/member/member_test.go index 92ed11a75ce..c581eb39390 100644 --- a/tests/server/member/member_test.go +++ b/tests/server/member/member_test.go @@ -84,15 +84,13 @@ func TestMemberDelete(t *testing.T) { {path: fmt.Sprintf("id/%d", members[1].GetServerID()), members: []*config.Config{leader.GetConfig()}}, } - httpClient := &http.Client{Timeout: 15 * time.Second, Transport: &http.Transport{DisableKeepAlives: true}} - defer httpClient.CloseIdleConnections() for _, table := range tables { t.Log(time.Now(), "try to delete:", table.path) testutil.Eventually(re, func() bool { addr := leader.GetConfig().ClientUrls + "/pd/api/v1/members/" + table.path req, err := http.NewRequest(http.MethodDelete, addr, http.NoBody) re.NoError(err) - res, err := httpClient.Do(req) + res, err := tests.TestDialClient.Do(req) re.NoError(err) defer res.Body.Close() // Check by status. @@ -105,7 +103,7 @@ func TestMemberDelete(t *testing.T) { } // Check by member list. cluster.WaitLeader() - if err = checkMemberList(re, *httpClient, leader.GetConfig().ClientUrls, table.members); err != nil { + if err = checkMemberList(re, leader.GetConfig().ClientUrls, table.members); err != nil { t.Logf("check member fail: %v", err) time.Sleep(time.Second) return false @@ -122,9 +120,9 @@ func TestMemberDelete(t *testing.T) { } } -func checkMemberList(re *require.Assertions, httpClient http.Client, clientURL string, configs []*config.Config) error { +func checkMemberList(re *require.Assertions, clientURL string, configs []*config.Config) error { addr := clientURL + "/pd/api/v1/members" - res, err := httpClient.Get(addr) + res, err := tests.TestDialClient.Get(addr) re.NoError(err) defer res.Body.Close() buf, err := io.ReadAll(res.Body) @@ -183,7 +181,7 @@ func TestLeaderPriority(t *testing.T) { func post(t *testing.T, re *require.Assertions, url string, body string) { testutil.Eventually(re, func() bool { - res, err := http.Post(url, "", bytes.NewBufferString(body)) // #nosec + res, err := tests.TestDialClient.Post(url, "", bytes.NewBufferString(body)) // #nosec re.NoError(err) b, err := io.ReadAll(res.Body) res.Body.Close() diff --git a/tests/testutil.go b/tests/testutil.go index 5d9905af64c..ea52bce310e 100644 --- a/tests/testutil.go +++ b/tests/testutil.go @@ -17,8 +17,12 @@ package tests import ( "context" "fmt" + "math/rand" + "net" + "net/http" "os" "runtime" + "strconv" "strings" "sync" "testing" @@ -45,6 +49,45 @@ import ( "go.uber.org/zap" ) +var ( + TestDialClient = &http.Client{ + Transport: &http.Transport{ + DisableKeepAlives: true, + }, + } + + testPortMutex sync.Mutex + testPortMap = make(map[string]struct{}) +) + +// SetRangePort sets the range of ports for test. +func SetRangePort(start, end int) { + portRange := []int{start, end} + dialContext := func(ctx context.Context, network, addr string) (net.Conn, error) { + dialer := &net.Dialer{} + randomPort := strconv.Itoa(rand.Intn(portRange[1]-portRange[0]) + portRange[0]) + testPortMutex.Lock() + for i := 0; i < 10; i++ { + if _, ok := testPortMap[randomPort]; !ok { + break + } + randomPort = strconv.Itoa(rand.Intn(portRange[1]-portRange[0]) + portRange[0]) + } + testPortMutex.Unlock() + localAddr, err := net.ResolveTCPAddr(network, "0.0.0.0:"+randomPort) + if err != nil { + return nil, err + } + dialer.LocalAddr = localAddr + return dialer.DialContext(ctx, network, addr) + } + + TestDialClient.Transport = &http.Transport{ + DisableKeepAlives: true, + DialContext: dialContext, + } +} + var once sync.Once // InitLogger initializes the logger for test. @@ -157,7 +200,7 @@ func WaitForPrimaryServing(re *require.Assertions, serverMap map[string]bs.Serve } } return false - }, testutil.WithWaitFor(5*time.Second), testutil.WithTickInterval(50*time.Millisecond)) + }, testutil.WithWaitFor(10*time.Second), testutil.WithTickInterval(50*time.Millisecond)) return primary } diff --git a/tools/go.mod b/tools/go.mod index 8d0f0d4ec35..2febbe1ad68 100644 --- a/tools/go.mod +++ b/tools/go.mod @@ -35,6 +35,7 @@ require ( go.uber.org/goleak v1.3.0 go.uber.org/zap v1.27.0 golang.org/x/text v0.14.0 + golang.org/x/tools v0.14.0 google.golang.org/grpc v1.62.1 ) @@ -172,7 +173,6 @@ require ( golang.org/x/sync v0.6.0 // indirect golang.org/x/sys v0.18.0 // indirect golang.org/x/time v0.5.0 // indirect - golang.org/x/tools v0.14.0 // indirect google.golang.org/appengine v1.6.8 // indirect google.golang.org/genproto v0.0.0-20240401170217-c3f982113cda // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240401170217-c3f982113cda // indirect diff --git a/tools/pd-ut/README.md b/tools/pd-ut/README.md index 77b59bea4f7..805ee5cf322 100644 --- a/tools/pd-ut/README.md +++ b/tools/pd-ut/README.md @@ -63,4 +63,8 @@ pd-ut run --junitfile xxx // test with race flag pd-ut run --race + +// test with coverprofile +pd-ut run --coverprofile xxx +go tool cover --func=xxx ``` diff --git a/tools/pd-ut/coverProfile.go b/tools/pd-ut/coverProfile.go new file mode 100644 index 00000000000..0ed1c3f3c61 --- /dev/null +++ b/tools/pd-ut/coverProfile.go @@ -0,0 +1,176 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "bufio" + "fmt" + "os" + "path" + "sort" + + "golang.org/x/tools/cover" +) + +func collectCoverProfileFile() { + // Combine all the cover file of single test function into a whole. + files, err := os.ReadDir(coverFileTempDir) + if err != nil { + fmt.Println("collect cover file error:", err) + os.Exit(-1) + } + + w, err := os.Create(coverProfile) + if err != nil { + fmt.Println("create cover file error:", err) + os.Exit(-1) + } + //nolint: errcheck + defer w.Close() + w.WriteString("mode: atomic\n") + + result := make(map[string]*cover.Profile) + for _, file := range files { + if file.IsDir() { + continue + } + collectOneCoverProfileFile(result, file) + } + + w1 := bufio.NewWriter(w) + for _, prof := range result { + for _, block := range prof.Blocks { + fmt.Fprintf(w1, "%s:%d.%d,%d.%d %d %d\n", + prof.FileName, + block.StartLine, + block.StartCol, + block.EndLine, + block.EndCol, + block.NumStmt, + block.Count, + ) + } + if err := w1.Flush(); err != nil { + fmt.Println("flush data to cover profile file error:", err) + os.Exit(-1) + } + } +} + +func collectOneCoverProfileFile(result map[string]*cover.Profile, file os.DirEntry) { + f, err := os.Open(path.Join(coverFileTempDir, file.Name())) + if err != nil { + fmt.Println("open temp cover file error:", err) + os.Exit(-1) + } + //nolint: errcheck + defer f.Close() + + profs, err := cover.ParseProfilesFromReader(f) + if err != nil { + fmt.Println("parse cover profile file error:", err) + os.Exit(-1) + } + mergeProfile(result, profs) +} + +func mergeProfile(m map[string]*cover.Profile, profs []*cover.Profile) { + for _, prof := range profs { + sort.Sort(blocksByStart(prof.Blocks)) + old, ok := m[prof.FileName] + if !ok { + m[prof.FileName] = prof + continue + } + + // Merge samples from the same location. + // The data has already been sorted. + tmp := old.Blocks[:0] + var i, j int + for i < len(old.Blocks) && j < len(prof.Blocks) { + v1 := old.Blocks[i] + v2 := prof.Blocks[j] + + switch compareProfileBlock(v1, v2) { + case -1: + tmp = appendWithReduce(tmp, v1) + i++ + case 1: + tmp = appendWithReduce(tmp, v2) + j++ + default: + tmp = appendWithReduce(tmp, v1) + tmp = appendWithReduce(tmp, v2) + i++ + j++ + } + } + for ; i < len(old.Blocks); i++ { + tmp = appendWithReduce(tmp, old.Blocks[i]) + } + for ; j < len(prof.Blocks); j++ { + tmp = appendWithReduce(tmp, prof.Blocks[j]) + } + + m[prof.FileName] = old + } +} + +// appendWithReduce works like append(), but it merge the duplicated values. +func appendWithReduce(input []cover.ProfileBlock, b cover.ProfileBlock) []cover.ProfileBlock { + if len(input) >= 1 { + last := &input[len(input)-1] + if b.StartLine == last.StartLine && + b.StartCol == last.StartCol && + b.EndLine == last.EndLine && + b.EndCol == last.EndCol { + if b.NumStmt != last.NumStmt { + panic(fmt.Errorf("inconsistent NumStmt: changed from %d to %d", last.NumStmt, b.NumStmt)) + } + // Merge the data with the last one of the slice. + last.Count |= b.Count + return input + } + } + return append(input, b) +} + +type blocksByStart []cover.ProfileBlock + +func compareProfileBlock(x, y cover.ProfileBlock) int { + if x.StartLine < y.StartLine { + return -1 + } + if x.StartLine > y.StartLine { + return 1 + } + + // Now x.StartLine == y.StartLine + if x.StartCol < y.StartCol { + return -1 + } + if x.StartCol > y.StartCol { + return 1 + } + + return 0 +} + +func (b blocksByStart) Len() int { return len(b) } +func (b blocksByStart) Swap(i, j int) { b[i], b[j] = b[j], b[i] } +func (b blocksByStart) Less(i, j int) bool { + bi, bj := b[i], b[j] + return bi.StartLine < bj.StartLine || bi.StartLine == bj.StartLine && bi.StartCol < bj.StartCol +} diff --git a/tools/pd-ut/ut.go b/tools/pd-ut/ut.go index 7fc96ee11cf..9419363c152 100644 --- a/tools/pd-ut/ut.go +++ b/tools/pd-ut/ut.go @@ -74,27 +74,49 @@ pd-ut build xxx pd-ut run --junitfile xxx // test with race flag -pd-ut run --race` +pd-ut run --race + +// test with coverprofile +pd-ut run --coverprofile xxx +go tool cover --func=xxx` fmt.Println(msg) return true } -const modulePath = "github.com/tikv/pd" +var ( + modulePath = "github.com/tikv/pd" + integrationsTestPath = "tests/integrations" +) var ( // runtime - p int - buildParallel int - workDir string + p int + buildParallel int + workDir string + coverFileTempDir string // arguments - race bool - junitFile string + race bool + junitFile string + coverProfile string + ignoreDir string ) func main() { race = handleFlag("--race") junitFile = stripFlag("--junitfile") + coverProfile = stripFlag("--coverprofile") + ignoreDir = stripFlag("--ignore") + + if coverProfile != "" { + var err error + coverFileTempDir, err = os.MkdirTemp(os.TempDir(), "cov") + if err != nil { + fmt.Println("create temp dir fail", coverFileTempDir) + os.Exit(1) + } + defer os.RemoveAll(coverFileTempDir) + } // Get the correct count of CPU if it's in docker. p = runtime.GOMAXPROCS(0) @@ -120,6 +142,18 @@ func main() { isSucceed = cmdBuild(os.Args[2:]...) case "run": isSucceed = cmdRun(os.Args[2:]...) + case "it": + // run integration tests + if len(os.Args) >= 3 { + modulePath = path.Join(modulePath, integrationsTestPath) + workDir = path.Join(workDir, integrationsTestPath) + switch os.Args[2] { + case "run": + isSucceed = cmdRun(os.Args[3:]...) + default: + isSucceed = usage() + } + } default: isSucceed = usage() } @@ -204,10 +238,16 @@ func cmdBuild(args ...string) bool { // build test binary of a single package if len(args) >= 1 { - pkg := args[0] - err := buildTestBinary(pkg) + var dirPkgs []string + for _, pkg := range pkgs { + if strings.Contains(pkg, args[0]) { + dirPkgs = append(dirPkgs, pkg) + } + } + + err := buildTestBinaryMulti(dirPkgs) if err != nil { - log.Println("build package error", pkg, err) + log.Println("build package error", dirPkgs, err) return false } } @@ -248,23 +288,32 @@ func cmdRun(args ...string) bool { // run tests for a single package if len(args) == 1 { - pkg := args[0] - err := buildTestBinary(pkg) - if err != nil { - log.Println("build package error", pkg, err) - return false + var dirPkgs []string + for _, pkg := range pkgs { + if strings.Contains(pkg, args[0]) { + dirPkgs = append(dirPkgs, pkg) + } } - exist, err := testBinaryExist(pkg) + + err := buildTestBinaryMulti(dirPkgs) if err != nil { - log.Println("check test binary existence error", err) + log.Println("build package error", dirPkgs, err) return false } - if !exist { - fmt.Println("no test case in ", pkg) - return false + for _, pkg := range dirPkgs { + exist, err := testBinaryExist(pkg) + if err != nil { + fmt.Println("check test binary existence error", err) + return false + } + if !exist { + fmt.Println("no test case in ", pkg) + continue + } + + tasks = listTestCases(pkg, tasks) } - tasks = listTestCases(pkg, tasks) } // run a single test @@ -326,6 +375,10 @@ func cmdRun(args ...string) bool { } } + if coverProfile != "" { + collectCoverProfileFile() + } + for _, work := range works { if work.Fail { return false @@ -336,7 +389,7 @@ func cmdRun(args ...string) bool { // stripFlag strip the '--flag xxx' from the command line os.Args // Example of the os.Args changes -// Before: ut run pkg TestXXX --junitfile yyy +// Before: ut run pkg TestXXX --coverprofile xxx --junitfile yyy // After: ut run pkg TestXXX // The value of the flag is returned. func stripFlag(flag string) string { @@ -421,6 +474,7 @@ func filterTestCases(tasks []task, arg1 string) ([]task, error) { func listPackages() ([]string, error) { cmd := exec.Command("go", "list", "./...") + cmd.Dir = workDir ss, err := cmdToLines(cmd) if err != nil { return nil, withTrace(err) @@ -565,7 +619,16 @@ func failureCases(input []JUnitTestCase) int { func (*numa) testCommand(pkg string, fn string) *exec.Cmd { args := make([]string, 0, 10) exe := "./" + testFileName(pkg) - args = append(args, "-test.cpu", "1") + if coverProfile != "" { + fileName := strings.ReplaceAll(pkg, "/", "_") + "." + fn + tmpFile := path.Join(coverFileTempDir, fileName) + args = append(args, "-test.coverprofile", tmpFile) + } + if strings.Contains(fn, "Suite") { + args = append(args, "-test.cpu", fmt.Sprint(p/2)) + } else { + args = append(args, "-test.cpu", "1") + } if !race { args = append(args, []string{"-test.timeout", "2m"}...) } else { @@ -580,7 +643,10 @@ func (*numa) testCommand(pkg string, fn string) *exec.Cmd { } func skipDIR(pkg string) bool { - skipDir := []string{"tests", "bin", "cmd", "tools"} + skipDir := []string{"bin", "cmd", "realcluster"} + if ignoreDir != "" { + skipDir = append(skipDir, ignoreDir) + } for _, ignore := range skipDir { if strings.HasPrefix(pkg, ignore) { return true @@ -593,8 +659,14 @@ func generateBuildCache() error { // cd cmd/pd-server && go test -tags=tso_function_test,deadlock -exec-=true -vet=off -toolexec=go-compile-without-link cmd := exec.Command("go", "test", "-exec=true", "-vet", "off", "--tags=tso_function_test,deadlock") goCompileWithoutLink := fmt.Sprintf("-toolexec=%s/tools/pd-ut/go-compile-without-link.sh", workDir) - cmd.Args = append(cmd.Args, goCompileWithoutLink) cmd.Dir = fmt.Sprintf("%s/cmd/pd-server", workDir) + if strings.Contains(workDir, integrationsTestPath) { + cmd.Dir = fmt.Sprintf("%s/cmd/pd-server", workDir[:strings.LastIndex(workDir, integrationsTestPath)]) + goCompileWithoutLink = fmt.Sprintf("-toolexec=%s/tools/pd-ut/go-compile-without-link.sh", + workDir[:strings.LastIndex(workDir, integrationsTestPath)]) + } + cmd.Args = append(cmd.Args, goCompileWithoutLink) + cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr if err := cmd.Run(); err != nil { @@ -612,7 +684,11 @@ func buildTestBinaryMulti(pkgs []string) error { } // go test --exec=xprog --tags=tso_function_test,deadlock -vet=off --count=0 $(pkgs) + // workPath just like `/data/nvme0n1/husharp/proj/pd/tests/integrations` xprogPath := path.Join(workDir, "bin/xprog") + if strings.Contains(workDir, integrationsTestPath) { + xprogPath = path.Join(workDir[:strings.LastIndex(workDir, integrationsTestPath)], "bin/xprog") + } packages := make([]string, 0, len(pkgs)) for _, pkg := range pkgs { packages = append(packages, path.Join(modulePath, pkg)) @@ -620,6 +696,13 @@ func buildTestBinaryMulti(pkgs []string) error { p := strconv.Itoa(buildParallel) cmd := exec.Command("go", "test", "-p", p, "--exec", xprogPath, "-vet", "off", "--tags=tso_function_test,deadlock") + if coverProfile != "" { + coverpkg := "./..." + if strings.Contains(workDir, integrationsTestPath) { + coverpkg = "../../..." + } + cmd.Args = append(cmd.Args, "-cover", fmt.Sprintf("-coverpkg=%s", coverpkg)) + } cmd.Args = append(cmd.Args, packages...) cmd.Dir = workDir cmd.Stdout = os.Stdout @@ -633,6 +716,9 @@ func buildTestBinaryMulti(pkgs []string) error { func buildTestBinary(pkg string) error { //nolint:gosec cmd := exec.Command("go", "test", "-c", "-vet", "off", "--tags=tso_function_test,deadlock", "-o", testFileName(pkg), "-v") + if coverProfile != "" { + cmd.Args = append(cmd.Args, "-cover", "-coverpkg=./...") + } if race { cmd.Args = append(cmd.Args, "-race") }