diff --git a/.gitignore b/.gitignore index c71bd34b114bd..48acac8905ae4 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ coverage.out *.iml *.swp *.log +*.test.bin tags profile.coverprofile explain_test diff --git a/Makefile b/Makefile index d49b0db744da4..30e01b2c1a1e1 100644 --- a/Makefile +++ b/Makefile @@ -126,6 +126,10 @@ devgotest: failpoint-enable $(GOTEST) -ldflags '$(TEST_LDFLAGS)' $(EXTRA_TEST_ARGS) -cover $(PACKAGES_TIDB_TESTS) -check.p true > gotest.log || { $(FAILPOINT_DISABLE); grep -v '^\([[]20\|PASS:\|ok \)' 'gotest.log'; exit 1; } @$(FAILPOINT_DISABLE) +ut: failpoint-enable tools/bin/ut + tools/bin/ut $(X); + @$(FAILPOINT_DISABLE) + gotest: failpoint-enable @echo "Running in native mode." @export log_level=info; export TZ='Asia/Shanghai'; \ @@ -220,6 +224,10 @@ failpoint-disable: tools/bin/failpoint-ctl # Restoring gofail failpoints... @$(FAILPOINT_DISABLE) +tools/bin/ut: tools/check/ut.go + cd tools/check; \ + $(GO) build -o ../bin/ut ut.go + tools/bin/megacheck: tools/check/go.mod cd tools/check; \ $(GO) build -o ../bin/megacheck honnef.co/go/tools/cmd/megacheck diff --git a/br/pkg/lightning/backend/local/local.go b/br/pkg/lightning/backend/local/local.go index eb7ab37802e4e..b703acec49395 100644 --- a/br/pkg/lightning/backend/local/local.go +++ b/br/pkg/lightning/backend/local/local.go @@ -150,9 +150,9 @@ type local struct { duplicateDetection bool duplicateDB *pebble.DB errorMgr *errormanager.ErrorManager -} -var bufferPool = membuf.NewPool(1024, manual.Allocator{}) + bufferPool *membuf.Pool +} func openDuplicateDB(storeDir string) (*pebble.DB, error) { dbPath := filepath.Join(storeDir, duplicateDBName) @@ -244,6 +244,8 @@ func NewLocalBackend( checkTiKVAvaliable: cfg.App.CheckRequirements, duplicateDB: duplicateDB, errorMgr: errorMgr, + + bufferPool: membuf.NewPool(membuf.WithAllocator(manual.Allocator{})), } local.conns = common.NewGRPCConns() if err = local.checkMultiIngestSupport(ctx); err != nil { @@ -423,6 +425,7 @@ func (local *local) Close() { engine.unlock() } local.conns.Close() + local.bufferPool.Destroy() if local.duplicateDB != nil { // Check whether there are duplicates. 
@@ -776,7 +779,7 @@ func (local *local) WriteToTiKV( requests = append(requests, req) } - bytesBuf := bufferPool.NewBuffer() + bytesBuf := local.bufferPool.NewBuffer() defer bytesBuf.Destroy() pairs := make([]*sst.Pair, 0, local.batchWriteKVPairs) count := 0 @@ -1664,14 +1667,14 @@ func (local *local) LocalWriter(ctx context.Context, cfg *backend.LocalWriterCon return nil, errors.Errorf("could not find engine for %s", engineUUID.String()) } engine := e.(*Engine) - return openLocalWriter(cfg, engine, local.localWriterMemCacheSize) + return openLocalWriter(cfg, engine, local.localWriterMemCacheSize, local.bufferPool.NewBuffer()) } -func openLocalWriter(cfg *backend.LocalWriterConfig, engine *Engine, cacheSize int64) (*Writer, error) { +func openLocalWriter(cfg *backend.LocalWriterConfig, engine *Engine, cacheSize int64, kvBuffer *membuf.Buffer) (*Writer, error) { w := &Writer{ engine: engine, memtableSizeLimit: cacheSize, - kvBuffer: bufferPool.NewBuffer(), + kvBuffer: kvBuffer, isKVSorted: cfg.IsKVSorted, isWriteBatchSorted: true, } diff --git a/br/pkg/lightning/backend/local/local_test.go b/br/pkg/lightning/backend/local/local_test.go index 747034068c463..35c13692dce3e 100644 --- a/br/pkg/lightning/backend/local/local_test.go +++ b/br/pkg/lightning/backend/local/local_test.go @@ -46,6 +46,7 @@ import ( "github.com/pingcap/tidb/br/pkg/lightning/backend/kv" "github.com/pingcap/tidb/br/pkg/lightning/common" "github.com/pingcap/tidb/br/pkg/lightning/mydump" + "github.com/pingcap/tidb/br/pkg/membuf" "github.com/pingcap/tidb/br/pkg/mock" "github.com/pingcap/tidb/br/pkg/pdutil" "github.com/pingcap/tidb/br/pkg/restore" @@ -357,7 +358,10 @@ func testLocalWriter(c *C, needSort bool, partitialSort bool) { f.wg.Add(1) go f.ingestSSTLoop() sorted := needSort && !partitialSort - w, err := openLocalWriter(&backend.LocalWriterConfig{IsKVSorted: sorted}, f, 1024) + pool := membuf.NewPool() + defer pool.Destroy() + kvBuffer := pool.NewBuffer() + w, err := openLocalWriter(&backend.LocalWriterConfig{IsKVSorted: sorted}, f, 1024, kvBuffer) c.Assert(err, IsNil) ctx := context.Background() diff --git a/br/pkg/lightning/backend/local/localhelper_test.go b/br/pkg/lightning/backend/local/localhelper_test.go index d901b3c2711e6..52a9b71286087 100644 --- a/br/pkg/lightning/backend/local/localhelper_test.go +++ b/br/pkg/lightning/backend/local/localhelper_test.go @@ -69,6 +69,11 @@ func newTestClient( } } +// ScatterRegions scatters regions in a batch. +func (c *testClient) ScatterRegions(ctx context.Context, regionInfo []*restore.RegionInfo) error { + return nil +} + func (c *testClient) GetAllRegions() map[uint64]*restore.RegionInfo { c.mu.RLock() defer c.mu.RUnlock() diff --git a/br/pkg/membuf/buffer.go b/br/pkg/membuf/buffer.go index 172d99baec9aa..49ffbae8afdf3 100644 --- a/br/pkg/membuf/buffer.go +++ b/br/pkg/membuf/buffer.go @@ -14,9 +14,11 @@ package membuf -const bigValueSize = 1 << 16 // 64K - -var allocBufLen = 1 << 20 // 1M +const ( + defaultPoolSize = 1024 + defaultBlockSize = 1 << 20 // 1M + defaultLargeAllocThreshold = 1 << 16 // 64K +) // Allocator is the abstract interface for allocating and freeing memory. type Allocator interface { @@ -38,30 +40,71 @@ func (stdAllocator) Free(_ []byte) {} // garbage collector which always release the memory so late. Use a fixed size chan to reuse // can decrease the memory usage to 1/3 compare with sync.Pool. 
type Pool struct { - allocator Allocator - recycleCh chan []byte + allocator Allocator + blockSize int + blockCache chan []byte + largeAllocThreshold int +} + +// Option configures a pool. +type Option func(p *Pool) + +// WithPoolSize configures how many blocks cached by this pool. +func WithPoolSize(size int) Option { + return func(p *Pool) { + p.blockCache = make(chan []byte, size) + } +} + +// WithBlockSize configures the size of each block. +func WithBlockSize(size int) Option { + return func(p *Pool) { + p.blockSize = size + } +} + +// WithAllocator specifies the allocator used by pool to allocate and free memory. +func WithAllocator(allocator Allocator) Option { + return func(p *Pool) { + p.allocator = allocator + } +} + +// WithLargeAllocThreshold configures the threshold for large allocation of a Buffer. +// If allocate size is larger than this threshold, bytes will be allocated directly +// by the make built-in function and won't be tracked by the pool. +func WithLargeAllocThreshold(threshold int) Option { + return func(p *Pool) { + p.largeAllocThreshold = threshold + } } // NewPool creates a new pool. -func NewPool(size int, allocator Allocator) *Pool { - return &Pool{ - allocator: allocator, - recycleCh: make(chan []byte, size), +func NewPool(opts ...Option) *Pool { + p := &Pool{ + allocator: stdAllocator{}, + blockSize: defaultBlockSize, + blockCache: make(chan []byte, defaultPoolSize), + largeAllocThreshold: defaultLargeAllocThreshold, + } + for _, opt := range opts { + opt(p) } + return p } func (p *Pool) acquire() []byte { select { - case b := <-p.recycleCh: + case b := <-p.blockCache: return b default: - return p.allocator.Alloc(allocBufLen) + return p.allocator.Alloc(p.blockSize) } } func (p *Pool) release(b []byte) { select { - case p.recycleCh <- b: + case p.blockCache <- b: default: p.allocator.Free(b) } @@ -72,10 +115,12 @@ func (p *Pool) NewBuffer() *Buffer { return &Buffer{pool: p, bufs: make([][]byte, 0, 128), curBufIdx: -1} } -var globalPool = NewPool(1024, stdAllocator{}) - -// NewBuffer creates a new buffer in global pool. -func NewBuffer() *Buffer { return globalPool.NewBuffer() } +func (p *Pool) Destroy() { + close(p.blockCache) + for b := range p.blockCache { + p.allocator.Free(b) + } +} // Buffer represents the reuse buffer. type Buffer struct { @@ -123,12 +168,12 @@ func (b *Buffer) Destroy() { // TotalSize represents the total memory size of this Buffer. func (b *Buffer) TotalSize() int64 { - return int64(len(b.bufs) * allocBufLen) + return int64(len(b.bufs) * b.pool.blockSize) } // AllocBytes allocates bytes with the given length. 
func (b *Buffer) AllocBytes(n int) []byte { - if n > bigValueSize { + if n > b.pool.largeAllocThreshold { return make([]byte, n) } if b.curIdx+n > b.curBufLen { diff --git a/br/pkg/membuf/buffer_test.go b/br/pkg/membuf/buffer_test.go index c5d095d299f9c..fa45c5c4e34b1 100644 --- a/br/pkg/membuf/buffer_test.go +++ b/br/pkg/membuf/buffer_test.go @@ -21,10 +21,6 @@ import ( "github.com/stretchr/testify/require" ) -func init() { - allocBufLen = 1024 -} - type testAllocator struct { allocs int frees int @@ -41,7 +37,13 @@ func (t *testAllocator) Free(_ []byte) { func TestBufferPool(t *testing.T) { allocator := &testAllocator{} - pool := NewPool(2, allocator) + pool := NewPool( + WithPoolSize(2), + WithAllocator(allocator), + WithBlockSize(1024), + WithLargeAllocThreshold(512), + ) + defer pool.Destroy() bytesBuf := pool.NewBuffer() bytesBuf.AllocBytes(256) @@ -53,6 +55,10 @@ func TestBufferPool(t *testing.T) { bytesBuf.AllocBytes(767) require.Equal(t, 2, allocator.allocs) + largeBytes := bytesBuf.AllocBytes(513) + require.Equal(t, 513, len(largeBytes)) + require.Equal(t, 2, allocator.allocs) + require.Equal(t, 0, allocator.frees) bytesBuf.Destroy() require.Equal(t, 0, allocator.frees) @@ -67,7 +73,9 @@ func TestBufferPool(t *testing.T) { } func TestBufferIsolation(t *testing.T) { - bytesBuf := NewBuffer() + pool := NewPool(WithBlockSize(1024)) + defer pool.Destroy() + bytesBuf := pool.NewBuffer() defer bytesBuf.Destroy() b1 := bytesBuf.AllocBytes(16) diff --git a/br/pkg/restore/split.go b/br/pkg/restore/split.go index c962a2109aac6..ada8662522c21 100644 --- a/br/pkg/restore/split.go +++ b/br/pkg/restore/split.go @@ -24,6 +24,8 @@ import ( "github.com/tikv/pd/pkg/codec" "go.uber.org/multierr" "go.uber.org/zap" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) // Constants for split retry machinery. @@ -112,6 +114,7 @@ SplitRegions: regionMap[region.Region.GetId()] = region } for regionID, keys := range splitKeyMap { + log.Info("get split keys for region", zap.Int("len", len(keys)), zap.Uint64("region", regionID)) var newRegions []*RegionInfo region := regionMap[regionID] log.Info("split regions", @@ -142,6 +145,7 @@ SplitRegions: logutil.Keys(keys), rtree.ZapRanges(ranges)) continue SplitRegions } + log.Info("scattered regions", zap.Int("count", len(newRegions))) if len(newRegions) != len(keys) { log.Warn("split key count and new region count mismatch", zap.Int("new region count", len(newRegions)), @@ -294,8 +298,6 @@ func (rs *RegionSplitter) ScatterRegionsWithBackoffer(ctx context.Context, newRe log.Info("trying to scatter regions...", zap.Int("remain", len(newRegionSet))) var errs error for _, region := range newRegionSet { - // Wait for a while until the regions successfully split. - rs.waitForSplit(ctx, region.Region.Id) err := rs.client.ScatterRegion(ctx, region) if err == nil { // it is safe accroding to the Go language spec. @@ -328,15 +330,54 @@ func (rs *RegionSplitter) ScatterRegionsWithBackoffer(ctx context.Context, newRe } +// isUnsupportedError checks whether we should fall back to the ScatterRegion API when encountering this error. +func isUnsupportedError(err error) bool { + s, ok := status.FromError(errors.Cause(err)) + if !ok { + // Not a gRPC error. Something else went wrong. + return false + } + // In two conditions, we fall back to ScatterRegion: + // (1) If the RPC endpoint returns UNIMPLEMENTED. (This is just to make the test cases less magical.)
+ // (2) If the Message is "region 0 not found": + // In fact, PD reuses the gRPC endpoint `ScatterRegion` for the batch version of scattering. + // When the request contains the field `regionIDs`, it uses the batch version; + // otherwise, it uses the old version and scatters the region with `regionID` in the request. + // When facing 4.x, BR (which uses v5.x PD clients and calls `ScatterRegions`!) would set `regionIDs`, + // which would be ignored by protocol buffers, leaving `regionID` as zero. + // Then the older version of PD would try to search for the region with ID 0. + // (Then it consistently fails and returns "region 0 not found".) + return s.Code() == codes.Unimplemented || + strings.Contains(s.Message(), "region 0 not found") +} + // ScatterRegions scatter the regions. func (rs *RegionSplitter) ScatterRegions(ctx context.Context, newRegions []*RegionInfo) { - rs.ScatterRegionsWithBackoffer( - ctx, newRegions, - // backoff about 6s, or we give up scattering this region. - &exponentialBackoffer{ - attempt: 7, - baseBackoff: 100 * time.Millisecond, - }) + for _, region := range newRegions { + // Wait for a while until the regions successfully split. + rs.waitForSplit(ctx, region.Region.Id) + } + + err := utils.WithRetry(ctx, func() error { + err := rs.client.ScatterRegions(ctx, newRegions) + if isUnsupportedError(err) { + log.Warn("batch scatter isn't supported, falling back to the old method", logutil.ShortError(err)) + rs.ScatterRegionsWithBackoffer( + ctx, newRegions, + // backoff about 6s, or we give up scattering this region. + &exponentialBackoffer{ + attempt: 7, + baseBackoff: 100 * time.Millisecond, + }) + return nil + } + return err + // The retry is for temporary network errors while sending requests. + }, &exponentialBackoffer{attempt: 3, baseBackoff: 500 * time.Millisecond}) + + if err != nil { + log.Warn("failed to batch scatter region", logutil.ShortError(err)) + } } func CheckRegionConsistency(startKey, endKey []byte, regions []*RegionInfo) error { diff --git a/br/pkg/restore/split_client.go b/br/pkg/restore/split_client.go index 10a9913d8e683..ed24fc3984a52 100755 --- a/br/pkg/restore/split_client.go +++ b/br/pkg/restore/split_client.go @@ -60,6 +60,8 @@ type SplitClient interface { BatchSplitRegionsWithOrigin(ctx context.Context, regionInfo *RegionInfo, keys [][]byte) (*RegionInfo, []*RegionInfo, error) // ScatterRegion scatters a specified region. ScatterRegion(ctx context.Context, regionInfo *RegionInfo) error + // ScatterRegions scatters regions in a batch. + ScatterRegions(ctx context.Context, regionInfo []*RegionInfo) error // GetOperator gets the status of operator of the specified region. GetOperator(ctx context.Context, regionID uint64) (*pdpb.GetOperatorResponse, error) // ScanRegion gets a list of regions, starts from the region that contains key. @@ -114,6 +116,24 @@ func (c *pdClient) needScatter(ctx context.Context) bool { return c.needScatterVal } +// ScatterRegions scatters regions in a batch.
+func (c *pdClient) ScatterRegions(ctx context.Context, regionInfo []*RegionInfo) error { + c.mu.Lock() + defer c.mu.Unlock() + regionsID := make([]uint64, 0, len(regionInfo)) + for _, v := range regionInfo { + regionsID = append(regionsID, v.Region.Id) + } + resp, err := c.client.ScatterRegions(ctx, regionsID) + if err != nil { + return err + } + if pbErr := resp.GetHeader().GetError(); pbErr.GetType() != pdpb.ErrorType_OK { + return errors.Annotatef(berrors.ErrPDInvalidResponse, "pd returns error during batch scattering: %s", pbErr) + } + return nil +} + func (c *pdClient) GetStore(ctx context.Context, storeID uint64) (*metapb.Store, error) { c.mu.Lock() defer c.mu.Unlock() diff --git a/br/pkg/restore/split_test.go b/br/pkg/restore/split_test.go index 5e43d3378e579..fdfbba8df54d0 100644 --- a/br/pkg/restore/split_test.go +++ b/br/pkg/restore/split_test.go @@ -21,17 +21,19 @@ import ( "github.com/stretchr/testify/require" "github.com/tikv/pd/server/core" "github.com/tikv/pd/server/schedule/placement" + "go.uber.org/multierr" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) type TestClient struct { - mu sync.RWMutex - stores map[uint64]*metapb.Store - regions map[uint64]*restore.RegionInfo - regionsInfo *core.RegionsInfo // For now it's only used in ScanRegions - nextRegionID uint64 - injectInScatter func(*restore.RegionInfo) error + mu sync.RWMutex + stores map[uint64]*metapb.Store + regions map[uint64]*restore.RegionInfo + regionsInfo *core.RegionsInfo // For now it's only used in ScanRegions + nextRegionID uint64 + injectInScatter func(*restore.RegionInfo) error + supportBatchScatter bool scattered map[uint64]bool } @@ -55,6 +57,36 @@ func NewTestClient( } } +func (c *TestClient) InstallBatchScatterSupport() { + c.supportBatchScatter = true +} + +// ScatterRegions scatters regions in a batch. 
+func (c *TestClient) ScatterRegions(ctx context.Context, regionInfo []*restore.RegionInfo) error { + if !c.supportBatchScatter { + return status.Error(codes.Unimplemented, "Ah, yep") + } + regions := map[uint64]*restore.RegionInfo{} + for _, region := range regionInfo { + regions[region.Region.Id] = region + } + var err error + for i := 0; i < 3; i++ { + if len(regions) == 0 { + return nil + } + for id, region := range regions { + splitErr := c.ScatterRegion(ctx, region) + if splitErr == nil { + delete(regions, id) + } + err = multierr.Append(err, splitErr) + + } + } + return nil +} + func (c *TestClient) GetAllRegions() map[uint64]*restore.RegionInfo { c.mu.RLock() defer c.mu.RUnlock() @@ -282,7 +314,18 @@ func TestScatterFinishInTime(t *testing.T) { // [, aay), [aay, bba), [bba, bbf), [bbf, bbh), [bbh, bbj), // [bbj, cca), [cca, xxe), [xxe, xxz), [xxz, ) func TestSplitAndScatter(t *testing.T) { - client := initTestClient() + t.Run("BatchScatter", func(t *testing.T) { + client := initTestClient() + client.InstallBatchScatterSupport() + runTestSplitAndScatterWith(t, client) + }) + t.Run("BackwardCompatibility", func(t *testing.T) { + client := initTestClient() + runTestSplitAndScatterWith(t, client) + }) +} + +func runTestSplitAndScatterWith(t *testing.T, client *TestClient) { ranges := initRanges() rewriteRules := initRewriteRules() regionSplitter := restore.NewRegionSplitter(client) @@ -320,7 +363,6 @@ func TestSplitAndScatter(t *testing.T) { t.Fatalf("region %d has not been scattered: %#v", key, regions[key]) } } - } // region: [, aay), [aay, bba), [bba, bbh), [bbh, cca), [cca, ) diff --git a/ddl/placement/errors.go b/ddl/placement/errors.go index 4f98e0fa4c2ec..b609827bd4ce7 100644 --- a/ddl/placement/errors.go +++ b/ddl/placement/errors.go @@ -43,4 +43,8 @@ var ( ErrNoRulesToDrop = errors.New("no rule of such role to drop") // ErrInvalidPlacementOptions is from bundle.go. ErrInvalidPlacementOptions = errors.New("invalid placement option") + // ErrInvalidConstraintsMappingWrongSeparator is wrong separator in mapping. + ErrInvalidConstraintsMappingWrongSeparator = errors.New("mappings use a colon and space (“: ”) to mark each key/value pair") + // ErrInvalidConstraintsMappingNoColonFound is no colon found in mapping. + ErrInvalidConstraintsMappingNoColonFound = errors.New("no colon found") ) diff --git a/ddl/placement/rule.go b/ddl/placement/rule.go index 88cd5067153f8..518dd369414b2 100644 --- a/ddl/placement/rule.go +++ b/ddl/placement/rule.go @@ -16,6 +16,7 @@ package placement import ( "fmt" + "regexp" "strings" "gopkg.in/yaml.v2" @@ -61,6 +62,18 @@ func NewRule(role PeerRoleType, replicas uint64, cnst Constraints) *Rule { } } +var wrongSeparatorRegexp = regexp.MustCompile(`[^"':]+:\d`) + +func getYamlMapFormatError(str string) error { + if !strings.Contains(str, ":") { + return ErrInvalidConstraintsMappingNoColonFound + } + if wrongSeparatorRegexp.MatchString(str) { + return ErrInvalidConstraintsMappingWrongSeparator + } + return nil +} + // NewRules constructs []*Rule from a yaml-compatible representation of // 'array' or 'dict' constraints. // Refer to https://github.com/pingcap/tidb/blob/master/docs/design/2020-06-24-placement-rules-in-sql.md. 
@@ -86,6 +99,9 @@ func NewRules(role PeerRoleType, replicas uint64, cnstr string) ([]*Rule, error) ruleCnt := 0 for labels, cnt := range constraints2 { if cnt <= 0 { + if err := getYamlMapFormatError(string(cnstbytes)); err != nil { + return rules, err + } return rules, fmt.Errorf("%w: count of labels '%s' should be positive, but got %d", ErrInvalidConstraintsMapcnt, labels, cnt) } ruleCnt += cnt diff --git a/ddl/placement/rule_test.go b/ddl/placement/rule_test.go index 9432448127a4a..f38819c278998 100644 --- a/ddl/placement/rule_test.go +++ b/ddl/placement/rule_test.go @@ -16,21 +16,20 @@ package placement import ( "errors" + "reflect" + "testing" . "github.com/pingcap/check" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testRuleSuite{}) - -type testRuleSuite struct{} - -func (t *testRuleSuite) TestClone(c *C) { +func TestClone(t *testing.T) { rule := &Rule{ID: "434"} newRule := rule.Clone() newRule.ID = "121" - c.Assert(rule, DeepEquals, &Rule{ID: "434"}) - c.Assert(newRule, DeepEquals, &Rule{ID: "121"}) + require.Equal(t, &Rule{ID: "434"}, rule) + require.Equal(t, &Rule{ID: "121"}, newRule) } func matchRules(t1, t2 []*Rule, prefix string, c *C) { @@ -50,7 +49,22 @@ func matchRules(t1, t2 []*Rule, prefix string, c *C) { } } -func (t *testRuleSuite) TestNewRuleAndNewRules(c *C) { +func matchRulesT(t1, t2 []*Rule, prefix string, t *testing.T) { + require.Equal(t, len(t2), len(t1), prefix) + for i := range t1 { + found := false + for j := range t2 { + ok := reflect.DeepEqual(t2[j], t1[i]) + if ok { + found = true + break + } + } + require.True(t, found, "%s\n\ncan not found %d rule\n%+v\n%+v", prefix, i, t1[i], t2) + } +} + +func TestNewRuleAndNewRules(t *testing.T) { type TestCase struct { name string input string @@ -58,7 +72,7 @@ func (t *testRuleSuite) TestNewRuleAndNewRules(c *C) { output []*Rule err error } - tests := []TestCase{} + var tests []TestCase tests = append(tests, TestCase{ name: "empty constraints", @@ -175,14 +189,21 @@ func (t *testRuleSuite) TestNewRuleAndNewRules(c *C) { err: ErrInvalidConstraintFormat, }) - for _, t := range tests { - comment := Commentf("[%s]", t.name) - output, err := NewRules(Voter, t.replicas, t.input) - if t.err == nil { - c.Assert(err, IsNil, comment) - matchRules(t.output, output, comment.CheckCommentString(), c) + tests = append(tests, TestCase{ + name: "invalid map separator", + input: `{+region=us-east-2:2}`, + replicas: 6, + err: ErrInvalidConstraintsMappingWrongSeparator, + }) + + for _, tt := range tests { + comment := Commentf("[%s]", tt.name) + output, err := NewRules(Voter, tt.replicas, tt.input) + if tt.err == nil { + require.NoError(t, err, comment) + matchRulesT(tt.output, output, comment.CheckCommentString(), t) } else { - c.Assert(errors.Is(err, t.err), IsTrue, Commentf("[%s]\n%s\n%s\n", t.name, err, t.err)) + require.True(t, errors.Is(err, tt.err), "[%s]\n%s\n%s\n", tt.name, err, tt.err) } } } diff --git a/ddl/table.go b/ddl/table.go index 83f7ad0b0e58a..625b4f39df759 100644 --- a/ddl/table.go +++ b/ddl/table.go @@ -955,6 +955,7 @@ func (w *worker) onSetTableFlashReplica(t *meta.Meta, job *model.Job) (ver int64 } if replicaInfo.Count > 0 && tableHasPlacementSettings(tblInfo) { + job.State = model.JobStateCancelled return ver, errors.Trace(ErrIncompatibleTiFlashAndPlacement) } @@ -1279,6 +1280,7 @@ func onAlterTablePartitionPlacement(t *meta.Meta, job *model.Job) (ver int64, er } if tblInfo.TiFlashReplica != nil && tblInfo.TiFlashReplica.Count > 0 { + job.State = model.JobStateCancelled return 0, 
errors.Trace(ErrIncompatibleTiFlashAndPlacement) } @@ -1350,6 +1352,7 @@ func onAlterTablePlacement(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, } if tblInfo.TiFlashReplica != nil && tblInfo.TiFlashReplica.Count > 0 { + job.State = model.JobStateCancelled return 0, errors.Trace(ErrIncompatibleTiFlashAndPlacement) } diff --git a/executor/adapter.go b/executor/adapter.go index be83c42d99940..37998a3bfc7c4 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -315,7 +315,7 @@ func (a *ExecStmt) RebuildPlan(ctx context.Context) (int64, error) { sessiontxn.AssertTxnManagerInfoSchema(a.Ctx, ret.InfoSchema) }) - a.InfoSchema = ret.InfoSchema + a.InfoSchema = sessiontxn.GetTxnManager(a.Ctx).GetTxnInfoSchema() a.SnapshotTS = ret.LastSnapshotTS a.IsStaleness = ret.IsStaleness a.ReplicaReadScope = ret.ReadReplicaScope diff --git a/executor/compiler.go b/executor/compiler.go index baf49979572c2..5debec73f1590 100644 --- a/executor/compiler.go +++ b/executor/compiler.go @@ -73,7 +73,8 @@ func (c *Compiler) Compile(ctx context.Context, stmtNode ast.StmtNode) (*ExecStm sessiontxn.AssertTxnManagerInfoSchema(c.Ctx, ret.InfoSchema) }) - finalPlan, names, err := planner.Optimize(ctx, c.Ctx, stmtNode, ret.InfoSchema) + is := sessiontxn.GetTxnManager(c.Ctx).GetTxnInfoSchema() + finalPlan, names, err := planner.Optimize(ctx, c.Ctx, stmtNode, is) if err != nil { return nil, err } @@ -96,7 +97,7 @@ func (c *Compiler) Compile(ctx context.Context, stmtNode ast.StmtNode) (*ExecStm SnapshotTS: ret.LastSnapshotTS, IsStaleness: ret.IsStaleness, ReplicaReadScope: ret.ReadReplicaScope, - InfoSchema: ret.InfoSchema, + InfoSchema: is, Plan: finalPlan, LowerPriority: lowerPriority, Text: stmtNode.Text(), diff --git a/executor/prepared.go b/executor/prepared.go index 3b703c75a9cf5..4f63bce491ab9 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -317,7 +317,7 @@ func (e *DeallocateExec) Next(ctx context.Context, req *chunk.Chunk) error { delete(vars.PreparedStmtNameToID, e.Name) if plannercore.PreparedPlanCacheEnabled() { bindSQL := planner.GetBindSQL4PlanCache(e.ctx, prepared.Stmt) - e.ctx.PreparedPlanCache().Delete(plannercore.NewPSTMTPlanCacheKey( + e.ctx.PreparedPlanCache().Delete(plannercore.NewPlanCacheKey( vars, id, prepared.SchemaVersion, bindSQL, )) } diff --git a/executor/write_test.go b/executor/write_test.go index 11e402f446631..ec326d0b6d436 100644 --- a/executor/write_test.go +++ b/executor/write_test.go @@ -1096,6 +1096,20 @@ func TestReplace(t *testing.T) { tk.MustExec("drop table t1, t2") } +func TestReplaceWithCICollation(t *testing.T) { + collate.SetNewCollationEnabledForTest(true) + defer collate.SetNewCollationEnabledForTest(false) + store, clean := testkit.CreateMockStore(t) + defer clean() + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + + tk.MustExec("create table t (a varchar(20) charset utf8mb4 collate utf8mb4_general_ci primary key);") + tk.MustExec("replace into t(a) values (_binary'A '),(_binary'A');") + tk.MustQuery("select a from t use index(primary);").Check(testkit.Rows("A")) + tk.MustQuery("select a from t ignore index(primary);").Check(testkit.Rows("A")) +} + func TestGeneratedColumnForInsert(t *testing.T) { store, clean := testkit.CreateMockStore(t) defer clean() diff --git a/expression/builtin_string.go b/expression/builtin_string.go index c494d9fcb5c10..acac019139708 100644 --- a/expression/builtin_string.go +++ b/expression/builtin_string.go @@ -1150,7 +1150,7 @@ func (b *builtinConvertSig) evalString(row chunk.Row) (string, bool, error) { 
return string(ret), false, err } enc := charset.FindEncoding(resultTp.Charset) - if !charset.IsValidString(enc, expr) { + if !enc.IsValid(hack.Slice(expr)) { replace, _ := enc.Transform(nil, hack.Slice(expr), charset.OpReplace) return string(replace), false, nil } diff --git a/expression/builtin_string_vec.go b/expression/builtin_string_vec.go index 3da555f9319ed..202a3d74ed3f1 100644 --- a/expression/builtin_string_vec.go +++ b/expression/builtin_string_vec.go @@ -689,7 +689,7 @@ func (b *builtinConvertSig) vecEvalString(input *chunk.Chunk, result *chunk.Colu continue } exprI := expr.GetBytes(i) - if !charset.IsValid(enc, exprI) { + if !enc.IsValid(exprI) { encBuf, _ = enc.Transform(encBuf, exprI, charset.OpReplace) result.AppendBytes(encBuf) } else { diff --git a/expression/collation.go b/expression/collation.go index 8dc5df02e55e0..813560775e2b4 100644 --- a/expression/collation.go +++ b/expression/collation.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/collate" + "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/logutil" ) @@ -315,7 +316,7 @@ func safeConvert(ctx sessionctx.Context, ec *ExprCollation, args ...Expression) if isNull { continue } - if !charset.IsValidString(enc, str) { + if !enc.IsValid(hack.Slice(str)) { return false } } else { diff --git a/parser/charset/encoding.go b/parser/charset/encoding.go index 25257c44e440b..bf3d6b8ff269c 100644 --- a/parser/charset/encoding.go +++ b/parser/charset/encoding.go @@ -57,6 +57,8 @@ type Encoding interface { Tp() EncodingTp // Peek returns the next char. Peek(src []byte) []byte + // IsValid checks whether the utf-8 bytes can be convert to valid string in current encoding. + IsValid(src []byte) bool // Foreach iterates the characters in in current encoding. Foreach(src []byte, op Op, fn func(from, to []byte, ok bool) bool) // Transform map the bytes in src to dest according to Op. @@ -101,21 +103,6 @@ const ( OpDecodeReplace = opToUTF8 | opTruncateReplace | opCollectTo ) -// IsValid checks whether the bytes is valid in current encoding. -func IsValid(e Encoding, src []byte) bool { - isValid := true - e.Foreach(src, opFromUTF8, func(from, to []byte, ok bool) bool { - isValid = ok - return ok - }) - return isValid -} - -// IsValidString is a string version of IsValid. -func IsValidString(e Encoding, str string) bool { - return IsValid(e, Slice(str)) -} - // CountValidBytes counts the first valid bytes in src that // can be encode to the current encoding. func CountValidBytes(e Encoding, src []byte) int { diff --git a/parser/charset/encoding_ascii.go b/parser/charset/encoding_ascii.go index df5fed9c3bce2..34432d5b42e3c 100644 --- a/parser/charset/encoding_ascii.go +++ b/parser/charset/encoding_ascii.go @@ -49,8 +49,19 @@ func (e *encodingASCII) Peek(src []byte) []byte { return src[:1] } +// IsValid implements Encoding interface. 
+func (e *encodingASCII) IsValid(src []byte) bool { + srcLen := len(src) + for i := 0; i < srcLen; i++ { + if src[i] > go_unicode.MaxASCII { + return false + } + } + return true +} + func (e *encodingASCII) Transform(dest, src []byte, op Op) ([]byte, error) { - if IsValid(e, src) { + if e.IsValid(src) { return src, nil } return e.encodingBase.Transform(dest, src, op) diff --git a/parser/charset/encoding_base.go b/parser/charset/encoding_base.go index 275db24c5a3d6..213596c6aec55 100644 --- a/parser/charset/encoding_base.go +++ b/parser/charset/encoding_base.go @@ -42,6 +42,15 @@ func (b encodingBase) ToLower(src string) string { return strings.ToLower(src) } +func (b encodingBase) IsValid(src []byte) bool { + isValid := true + b.self.Foreach(src, opFromUTF8, func(from, to []byte, ok bool) bool { + isValid = ok + return ok + }) + return isValid +} + func (b encodingBase) Transform(dest, src []byte, op Op) (result []byte, err error) { if dest == nil { dest = make([]byte, len(src)) diff --git a/parser/charset/encoding_bin.go b/parser/charset/encoding_bin.go index 30fd87644c571..30b35ceb1d856 100644 --- a/parser/charset/encoding_bin.go +++ b/parser/charset/encoding_bin.go @@ -47,6 +47,11 @@ func (e *encodingBin) Peek(src []byte) []byte { return src[:1] } +// IsValid implements Encoding interface. +func (e *encodingBin) IsValid(src []byte) bool { + return true +} + // Foreach implements Encoding interface. func (e *encodingBin) Foreach(src []byte, op Op, fn func(from, to []byte, ok bool) bool) { for i := 0; i < len(src); i++ { diff --git a/parser/charset/encoding_latin1.go b/parser/charset/encoding_latin1.go index 1d2992b87642d..d627ed63ec419 100644 --- a/parser/charset/encoding_latin1.go +++ b/parser/charset/encoding_latin1.go @@ -41,6 +41,11 @@ func (e *encodingLatin1) Peek(src []byte) []byte { return src[:1] } +// IsValid implements Encoding interface. +func (e *encodingLatin1) IsValid(src []byte) bool { + return true +} + // Tp implements Encoding interface. func (e *encodingLatin1) Tp() EncodingTp { return EncodingTpLatin1 diff --git a/parser/charset/encoding_test.go b/parser/charset/encoding_test.go index a78aa640d8be5..27d41dbf5ebd2 100644 --- a/parser/charset/encoding_test.go +++ b/parser/charset/encoding_test.go @@ -133,8 +133,7 @@ func TestEncodingValidate(t *testing.T) { enc = charset.EncodingUTF8MB3StrictImpl } strBytes := []byte(tc.str) - ok := charset.IsValid(enc, strBytes) - require.Equal(t, tc.ok, ok, msg) + require.Equal(t, tc.ok, enc.IsValid(strBytes), msg) replace, _ := enc.Transform(nil, strBytes, charset.OpReplace) require.Equal(t, tc.expected, string(replace), msg) } diff --git a/parser/charset/encoding_utf8.go b/parser/charset/encoding_utf8.go index 871a5e5ec33c1..499ce5ea50de7 100644 --- a/parser/charset/encoding_utf8.go +++ b/parser/charset/encoding_utf8.go @@ -67,9 +67,17 @@ func (e *encodingUTF8) Peek(src []byte) []byte { return src[:nextLen] } +// IsValid implements Encoding interface. +func (e *encodingUTF8) IsValid(src []byte) bool { + if utf8.Valid(src) { + return true + } + return e.encodingBase.IsValid(src) +} + // Transform implements Encoding interface. func (e *encodingUTF8) Transform(dest, src []byte, op Op) ([]byte, error) { - if IsValid(e, src) { + if e.IsValid(src) { return src, nil } return e.encodingBase.Transform(dest, src, op) @@ -93,6 +101,11 @@ type encodingUTF8MB3Strict struct { encodingUTF8 } +// IsValid implements Encoding interface. 
+func (e *encodingUTF8MB3Strict) IsValid(src []byte) bool { + return e.encodingBase.IsValid(src) +} + // Foreach implements Encoding interface. func (e *encodingUTF8MB3Strict) Foreach(src []byte, op Op, fn func(srcCh, dstCh []byte, ok bool) bool) { for i, w := 0, 0; i < len(src); i += w { @@ -107,7 +120,7 @@ func (e *encodingUTF8MB3Strict) Foreach(src []byte, op Op, fn func(srcCh, dstCh // Transform implements Encoding interface. func (e *encodingUTF8MB3Strict) Transform(dest, src []byte, op Op) ([]byte, error) { - if IsValid(e, src) { + if e.IsValid(src) { return src, nil } return e.encodingBase.Transform(dest, src, op) diff --git a/planner/core/cache.go b/planner/core/cache.go index 4113f3e911e88..2ef974340c063 100644 --- a/planner/core/cache.go +++ b/planner/core/cache.go @@ -66,7 +66,11 @@ func PreparedPlanCacheEnabled() bool { return isEnabled == preparedPlanCacheEnabled } -type pstmtPlanCacheKey struct { +// planCacheKey is used to access Plan Cache. We put some variables that do not affect the plan into planCacheKey, such as the sql text. +// Put the parameters that may affect the plan in planCacheValue, such as bindSQL. +// However, due to some compatibility reasons, we will temporarily keep some system variable-related values in planCacheKey. +// At the same time, because these variables have a small impact on plan, we will move them to PlanCacheValue later if necessary. +type planCacheKey struct { database string connID uint64 pstmtID uint32 @@ -81,7 +85,7 @@ type pstmtPlanCacheKey struct { } // Hash implements Key interface. -func (key *pstmtPlanCacheKey) Hash() []byte { +func (key *planCacheKey) Hash() []byte { if len(key.hash) == 0 { var ( dbBytes = hack.Slice(key.database) @@ -114,7 +118,7 @@ func (key *pstmtPlanCacheKey) Hash() []byte { // SetPstmtIDSchemaVersion implements PstmtCacheKeyMutator interface to change pstmtID and schemaVersion of cacheKey. // so we can reuse Key instead of new every time. func SetPstmtIDSchemaVersion(key kvcache.Key, pstmtID uint32, schemaVersion int64, isolationReadEngines map[kv.StoreType]struct{}) { - psStmtKey, isPsStmtKey := key.(*pstmtPlanCacheKey) + psStmtKey, isPsStmtKey := key.(*planCacheKey) if !isPsStmtKey { return } @@ -127,13 +131,13 @@ func SetPstmtIDSchemaVersion(key kvcache.Key, pstmtID uint32, schemaVersion int6 psStmtKey.hash = psStmtKey.hash[:0] } -// NewPSTMTPlanCacheKey creates a new pstmtPlanCacheKey object. -func NewPSTMTPlanCacheKey(sessionVars *variable.SessionVars, pstmtID uint32, schemaVersion int64, bindSQL string) kvcache.Key { +// NewPlanCacheKey creates a new planCacheKey object. +func NewPlanCacheKey(sessionVars *variable.SessionVars, pstmtID uint32, schemaVersion int64, bindSQL string) kvcache.Key { timezoneOffset := 0 if sessionVars.TimeZone != nil { _, timezoneOffset = time.Now().In(sessionVars.TimeZone).Zone() } - key := &pstmtPlanCacheKey{ + key := &planCacheKey{ database: sessionVars.CurrentDB, connID: sessionVars.ConnectionID, pstmtID: pstmtID, @@ -175,16 +179,16 @@ func (s FieldSlice) Equal(tps []*types.FieldType) bool { return true } -// PSTMTPlanCacheValue stores the cached Statement and StmtNode. -type PSTMTPlanCacheValue struct { +// PlanCacheValue stores the cached Statement and StmtNode. +type PlanCacheValue struct { Plan Plan OutPutNames []*types.FieldName TblInfo2UnionScan map[*model.TableInfo]bool UserVarTypes FieldSlice } -// NewPSTMTPlanCacheValue creates a SQLCacheValue. 
-func NewPSTMTPlanCacheValue(plan Plan, names []*types.FieldName, srcMap map[*model.TableInfo]bool, userVarTps []*types.FieldType) *PSTMTPlanCacheValue { +// NewPlanCacheValue creates a PlanCacheValue. +func NewPlanCacheValue(plan Plan, names []*types.FieldName, srcMap map[*model.TableInfo]bool, userVarTps []*types.FieldType) *PlanCacheValue { dstMap := make(map[*model.TableInfo]bool) for k, v := range srcMap { dstMap[k] = v @@ -193,7 +197,7 @@ func NewPSTMTPlanCacheValue(plan Plan, names []*types.FieldName, srcMap map[*mod for i, tp := range userVarTps { userVarTypes[i] = *tp } - return &PSTMTPlanCacheValue{ + return &PlanCacheValue{ Plan: plan, OutPutNames: names, TblInfo2UnionScan: dstMap, diff --git a/planner/core/cache_test.go b/planner/core/cache_test.go index 074d1e4cf2828..c75a4b3963713 100644 --- a/planner/core/cache_test.go +++ b/planner/core/cache_test.go @@ -28,6 +28,6 @@ func TestCacheKey(t *testing.T) { ctx.GetSessionVars().SQLMode = mysql.ModeNone ctx.GetSessionVars().TimeZone = time.UTC ctx.GetSessionVars().ConnectionID = 0 - key := NewPSTMTPlanCacheKey(ctx.GetSessionVars(), 1, 1, "") + key := NewPlanCacheKey(ctx.GetSessionVars(), 1, 1, "") require.Equal(t, []byte{0x74, 0x65, 0x73, 0x74, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x74, 0x69, 0x64, 0x62, 0x74, 0x69, 0x6b, 0x76, 0x74, 0x69, 0x66, 0x6c, 0x61, 0x73, 0x68, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, key.Hash()) } diff --git a/planner/core/common_plans.go b/planner/core/common_plans.go index d3e56600b2c25..54f30d7d998ea 100644 --- a/planner/core/common_plans.go +++ b/planner/core/common_plans.go @@ -404,7 +404,7 @@ func (e *Execute) getPhysicalPlan(ctx context.Context, sctx sessionctx.Context, var bindSQL string if prepared.UseCache { bindSQL = GetBindSQL4PlanCache(sctx, prepared.Stmt) - cacheKey = NewPSTMTPlanCacheKey(sctx.GetSessionVars(), e.ExecID, prepared.SchemaVersion, bindSQL) + cacheKey = NewPlanCacheKey(sctx.GetSessionVars(), e.ExecID, prepared.SchemaVersion, bindSQL) } tps := make([]*types.FieldType, len(e.UsingVars)) for i, param := range e.UsingVars { @@ -445,7 +445,7 @@ func (e *Execute) getPhysicalPlan(ctx context.Context, sctx sessionctx.Context, if err := e.checkPreparedPriv(ctx, sctx, preparedStmt, is); err != nil { return err } - cachedVals := cacheValue.([]*PSTMTPlanCacheValue) + cachedVals := cacheValue.([]*PlanCacheValue) for _, cachedVal := range cachedVals { if !cachedVal.UserVarTypes.Equal(tps) { continue @@ -510,30 +510,30 @@ REBUILD: // rebuild key to exclude kv.TiFlash when stmt is not read only if _, isolationReadContainTiFlash := sessVars.IsolationReadEngines[kv.TiFlash]; isolationReadContainTiFlash && !IsReadOnly(stmt, sessVars) { delete(sessVars.IsolationReadEngines, kv.TiFlash) - cacheKey = NewPSTMTPlanCacheKey(sessVars, e.ExecID, prepared.SchemaVersion, sessVars.StmtCtx.BindSQL) + cacheKey = NewPlanCacheKey(sessVars, e.ExecID, prepared.SchemaVersion, sessVars.StmtCtx.BindSQL) sessVars.IsolationReadEngines[kv.TiFlash] = struct{}{} } else { // We need to reconstruct the plan cache key based on the bindSQL.
- cacheKey = NewPSTMTPlanCacheKey(sessVars, e.ExecID, prepared.SchemaVersion, sessVars.StmtCtx.BindSQL) + cacheKey = NewPlanCacheKey(sessVars, e.ExecID, prepared.SchemaVersion, sessVars.StmtCtx.BindSQL) } - cached := NewPSTMTPlanCacheValue(p, names, stmtCtx.TblInfo2UnionScan, tps) + cached := NewPlanCacheValue(p, names, stmtCtx.TblInfo2UnionScan, tps) preparedStmt.NormalizedPlan, preparedStmt.PlanDigest = NormalizePlan(p) stmtCtx.SetPlanDigest(preparedStmt.NormalizedPlan, preparedStmt.PlanDigest) if cacheVals, exists := sctx.PreparedPlanCache().Get(cacheKey); exists { hitVal := false - for i, cacheVal := range cacheVals.([]*PSTMTPlanCacheValue) { + for i, cacheVal := range cacheVals.([]*PlanCacheValue) { if cacheVal.UserVarTypes.Equal(tps) { hitVal = true - cacheVals.([]*PSTMTPlanCacheValue)[i] = cached + cacheVals.([]*PlanCacheValue)[i] = cached break } } if !hitVal { - cacheVals = append(cacheVals.([]*PSTMTPlanCacheValue), cached) + cacheVals = append(cacheVals.([]*PlanCacheValue), cached) } sctx.PreparedPlanCache().Put(cacheKey, cacheVals) } else { - sctx.PreparedPlanCache().Put(cacheKey, []*PSTMTPlanCacheValue{cached}) + sctx.PreparedPlanCache().Put(cacheKey, []*PlanCacheValue{cached}) } } err = e.setFoundInPlanCache(sctx, false) diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index bd7ae44d36cdf..9721cc68d730c 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -58,6 +58,8 @@ import ( "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/plancodec" "github.com/pingcap/tidb/util/set" + "go.uber.org/zap" + "golang.org/x/sync/singleflight" ) const ( @@ -4207,10 +4209,14 @@ func (b *PlanBuilder) buildDataSource(ctx context.Context, tn *ast.TableName, as if r := recover(); r != nil { } }() - err := cachedTable.UpdateLockForRead(ctx, store, startTS) - if err != nil { - log.Warn("Update Lock Info Error") - } + _, err, _ := sf.Do(fmt.Sprintf("%d", tableInfo.ID), func() (interface{}, error) { + err := cachedTable.UpdateLockForRead(ctx, store, startTS) + if err != nil { + log.Warn("Update Lock Info Error", zap.Error(err)) + } + return nil, nil + }) + terror.Log(err) }() } } @@ -4238,6 +4244,8 @@ func (b *PlanBuilder) buildDataSource(ctx context.Context, tn *ast.TableName, as return result, nil } +var sf singleflight.Group + func (b *PlanBuilder) timeRangeForSummaryTable() QueryTimeRange { const defaultSummaryDuration = 30 * time.Minute hints := b.TableHints() diff --git a/server/driver_tidb.go b/server/driver_tidb.go index 9a13eea632962..dfe88d5364114 100644 --- a/server/driver_tidb.go +++ b/server/driver_tidb.go @@ -167,7 +167,7 @@ func (ts *TiDBStatement) Close() error { } preparedAst := preparedObj.PreparedAst bindSQL := planner.GetBindSQL4PlanCache(ts.ctx, preparedAst.Stmt) - ts.ctx.PreparedPlanCache().Delete(core.NewPSTMTPlanCacheKey( + ts.ctx.PreparedPlanCache().Delete(core.NewPlanCacheKey( ts.ctx.GetSessionVars(), ts.id, preparedObj.PreparedAst.SchemaVersion, bindSQL)) } ts.ctx.GetSessionVars().RemovePreparedStmt(ts.id) diff --git a/session/session.go b/session/session.go index cd1ab05da5498..f685b44d0be43 100644 --- a/session/session.go +++ b/session/session.go @@ -314,7 +314,7 @@ func (s *session) cleanRetryInfo() { if ok { preparedAst = preparedObj.PreparedAst bindSQL := planner.GetBindSQL4PlanCache(s, preparedAst.Stmt) - cacheKey = plannercore.NewPSTMTPlanCacheKey(s.sessionVars, firstStmtID, preparedAst.SchemaVersion, bindSQL) + cacheKey = plannercore.NewPlanCacheKey(s.sessionVars, 
firstStmtID, preparedAst.SchemaVersion, bindSQL) } } } @@ -2130,7 +2130,9 @@ func (s *session) ExecutePreparedStmt(ctx context.Context, stmtID uint32, args [ txnCtxProvider := &sessiontxn.SimpleTxnContextProvider{ InfoSchema: is, } - if err = sessiontxn.GetTxnManager(s).SetContextProvider(txnCtxProvider); err != nil { + + txnManager := sessiontxn.GetTxnManager(s) + if err = txnManager.SetContextProvider(txnCtxProvider); err != nil { return nil, err } @@ -2143,9 +2145,9 @@ func (s *session) ExecutePreparedStmt(ctx context.Context, stmtID uint32, args [ defer s.txn.onStmtEnd() if ok { - return s.cachedPlanExec(ctx, is, snapshotTS, stmtID, preparedStmt, args) + return s.cachedPlanExec(ctx, txnManager.GetTxnInfoSchema(), snapshotTS, stmtID, preparedStmt, args) } - return s.preparedStmtExec(ctx, is, snapshotTS, stmtID, preparedStmt, args) + return s.preparedStmtExec(ctx, txnManager.GetTxnInfoSchema(), snapshotTS, stmtID, preparedStmt, args) } func (s *session) DropPreparedStmt(stmtID uint32) error { diff --git a/table/column.go b/table/column.go index d7e9a9ec5dadb..90404f08cbb69 100644 --- a/table/column.go +++ b/table/column.go @@ -363,18 +363,18 @@ func validateStringDatum(ctx sessionctx.Context, origin, casted *types.Datum, co src := casted.GetBytes() encBytes, err := enc.Transform(nil, src, charset.OpDecode) if err != nil { - casted.SetBytesAsString(encBytes, charset.CollationUTF8MB4, 0) + casted.SetBytesAsString(encBytes, col.Collate, 0) nSrc := charset.CountValidBytesDecode(enc, src) return handleWrongCharsetValue(ctx, col, src, nSrc) } - casted.SetBytesAsString(encBytes, charset.CollationUTF8MB4, 0) + casted.SetBytesAsString(encBytes, col.Collate, 0) return nil } // Check if the string is valid in the given column charset. str := casted.GetBytes() - if !charset.IsValid(enc, str) { + if !enc.IsValid(str) { replace, _ := enc.Transform(nil, str, charset.OpReplace) - casted.SetBytesAsString(replace, charset.CollationUTF8MB4, 0) + casted.SetBytesAsString(replace, col.Collate, 0) nSrc := charset.CountValidBytes(enc, str) return handleWrongCharsetValue(ctx, col, str, nSrc) } diff --git a/table/column_test.go b/table/column_test.go index 02cbb12237afc..27e35f94757ba 100644 --- a/table/column_test.go +++ b/table/column_test.go @@ -303,6 +303,18 @@ func TestCastValue(t *testing.T) { colInfoS.Charset = charset.CharsetASCII _, err = CastValue(ctx, types.NewDatum([]byte{0x32, 0xf0}), &colInfoS, false, true) require.NoError(t, err) + + colInfoS.Charset = charset.CharsetUTF8MB4 + colInfoS.Collate = "utf8mb4_general_ci" + val, err = CastValue(ctx, types.NewBinaryLiteralDatum([]byte{0xE5, 0xA5, 0xBD}), &colInfoS, false, false) + require.NoError(t, err) + require.Equal(t, "utf8mb4_general_ci", val.Collation()) + val, err = CastValue(ctx, types.NewBinaryLiteralDatum([]byte{0xE5, 0xA5, 0xBD, 0x81}), &colInfoS, false, false) + require.Error(t, err, "[table:1366]Incorrect string value '\\x81' for column ''") + require.Equal(t, "utf8mb4_general_ci", val.Collation()) + val, err = CastValue(ctx, types.NewDatum([]byte{0xE5, 0xA5, 0xBD, 0x81}), &colInfoS, false, false) + require.Error(t, err, "[table:1366]Incorrect string value '\\x81' for column ''") + require.Equal(t, "utf8mb4_general_ci", val.Collation()) } func TestGetDefaultValue(t *testing.T) { diff --git a/table/tables/cache_test.go b/table/tables/cache_test.go index 62e48ccd24c94..a4dc5b4d43d68 100644 --- a/table/tables/cache_test.go +++ b/table/tables/cache_test.go @@ -221,6 +221,8 @@ func TestCacheTableBasicReadAndWrite(t *testing.T) { if 
tk.HasPlan("select * from write_tmp1", "UnionScan") { break } + // Wait for the cache to be loaded. + time.Sleep(50 * time.Millisecond) } require.True(t, i < 10) diff --git a/tools/check/go.mod b/tools/check/go.mod index 9c9c2b8d3da55..81ee48b2242cd 100644 --- a/tools/check/go.mod +++ b/tools/check/go.mod @@ -16,6 +16,7 @@ require ( github.com/pingcap/failpoint v0.0.0-20200702092429-9f69995143ce // indirect github.com/securego/gosec v0.0.0-20181211171558-12400f9a1ca7 github.com/shurcooL/vfsgen v0.0.0-20181202132449-6a9ea43bcacd + go.uber.org/automaxprocs v1.4.0 // indirect gopkg.in/alecthomas/gometalinter.v2 v2.0.12 // indirect gopkg.in/alecthomas/gometalinter.v3 v3.0.0 // indirect gopkg.in/alecthomas/kingpin.v2 v2.2.6 // indirect diff --git a/tools/check/go.sum b/tools/check/go.sum index ca6214c124823..776ad3f913a32 100644 --- a/tools/check/go.sum +++ b/tools/check/go.sum @@ -96,6 +96,8 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5 github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk= go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/automaxprocs v1.4.0 h1:CpDZl6aOlLhReez+8S3eEotD7Jx0Os++lemPlMULQP0= +go.uber.org/automaxprocs v1.4.0/go.mod h1:/mTEdr7LvHhs0v7mjdxDreTz1OG5zdZGqgOnhWiR/+Q= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= diff --git a/tools/check/ut.go b/tools/check/ut.go new file mode 100644 index 0000000000000..6fad0f471fe2c --- /dev/null +++ b/tools/check/ut.go @@ -0,0 +1,536 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "bytes" + "fmt" + "math/rand" + "os" + "os/exec" + "path" + "runtime" + "strings" + "sync" + "time" + + // Set the correct GOMAXPROCS when it runs inside docker.
+ _ "go.uber.org/automaxprocs" +) + +func usage() { + msg := `// run all tests +ut + +// show usage +ut -h + +// list all packages +ut list + +// list test cases of a single package +ut list $package + +// run all tests +ut run + +// run all test cases of a single package +ut run $package + +// run test cases of a single package +ut run $package $test + +// build all test packages +ut build + +// build a test package +ut build xxx` + fmt.Println(msg) +} + +const modulePath = "github.com/pingcap/tidb" + +type task struct { + pkg string + test string + old bool +} + +var P int +var workDir string + +func cmdList(args ...string) { + pkgs, err := listPackages() + if err != nil { + fmt.Println("list package error", err) + return + } + + // list all packages + if len(args) == 0 { + for _, pkg := range pkgs { + fmt.Println(pkg) + } + return + } + + // list test cases of a single package + if len(args) == 1 { + pkg := args[0] + pkgs = filter(pkgs, func(s string) bool { return s == pkg }) + if len(pkgs) != 1 { + fmt.Println("package does not exist", pkg) + return + } + + err := buildTestBinary(pkg) + if err != nil { + fmt.Println("build package error", pkg, err) + return + } + exist, err := testBinaryExist(pkg) + if err != nil { + fmt.Println("check test binary existence error", err) + return + } + if !exist { + fmt.Println("no test case in ", pkg) + return + } + + res, err := listTestCases(pkg, nil) + if err != nil { + fmt.Println("list test cases for package error", err) + return + } + for _, x := range res { + fmt.Println(x.test) + } + } +} + +func cmdBuild(args ...string) { + pkgs, err := listPackages() + if err != nil { + fmt.Println("list package error", err) + return + } + + // build all packages + if len(args) == 0 { + for _, pkg := range pkgs { + err := buildTestBinary(pkg) + if err != nil { + fmt.Println("build package error", pkg, err) + return + } + } + return + } + + // build test binary of a single package + if len(args) >= 1 { + pkg := args[0] + err := buildTestBinary(pkg) + if err != nil { + fmt.Println("build package error", pkg, err) + return + } + } +} + +func cmdRun(args ...string) { + var err error + pkgs, err := listPackages() + if err != nil { + fmt.Println("list packages error", err) + return + } + tasks := make([]task, 0, 5000) + // run all tests + if len(args) == 0 { + for _, pkg := range pkgs { + fmt.Println("handling package", pkg) + err := buildTestBinary(pkg) + if err != nil { + fmt.Println("build package error", pkg, err) + return + } + + exist, err := testBinaryExist(pkg) + if err != nil { + fmt.Println("check test binary existence error", err) + return + } + if !exist { + fmt.Println("no test case in ", pkg) + continue + } + + tasks, err = listTestCases(pkg, tasks) + if err != nil { + fmt.Println("list test cases error", err) + return + } + } + } + + // run tests for a single package + if len(args) == 1 { + pkg := args[0] + err := buildTestBinary(pkg) + if err != nil { + fmt.Println("build package error", pkg, err) + return + } + exist, err := testBinaryExist(pkg) + if err != nil { + fmt.Println("check test binary existence error", err) + return + } + + if !exist { + fmt.Println("no test case in ", pkg) + return + } + tasks, err = listTestCases(pkg, tasks) + if err != nil { + fmt.Println("list test cases error", err) + return + } + } + + // run a single test + if len(args) == 2 { + pkg := args[0] + err := buildTestBinary(pkg) + if err != nil { + fmt.Println("build package error", pkg, err) + return + } + exist, err := testBinaryExist(pkg) + if err != nil { + 
fmt.Println("check test binary existence error", err) + return + } + if !exist { + fmt.Println("no test case in ", pkg) + return + } + + tasks, err = listTestCases(pkg, tasks) + if err != nil { + fmt.Println("list test cases error", err) + return + } + // filter the test cases to run + tmp := tasks[:0] + for _, task := range tasks { + if strings.Contains(task.test, args[1]) { + tmp = append(tmp, task) + } + } + tasks = tmp + } + fmt.Println("building tasks finished...", len(tasks)) + + numactl := numactlExist() + taskCh := make(chan task, 100) + var wg sync.WaitGroup + for i := 0; i < P; i++ { + n := numa{fmt.Sprintf("%d", i), numactl} + wg.Add(1) + go n.worker(&wg, taskCh) + } + + shuffle(tasks) + for _, task := range tasks { + taskCh <- task + } + close(taskCh) + wg.Wait() +} + +func main() { + // Get the correct count of CPUs if it's in docker. + P = runtime.GOMAXPROCS(0) + rand.Seed(time.Now().Unix()) + var err error + workDir, err = os.Getwd() + if err != nil { + fmt.Println("os.Getwd() error", err) + } + + if len(os.Args) == 1 { + // run all tests + cmdRun() + return + } + + if len(os.Args) >= 2 { + switch os.Args[1] { + case "list": + cmdList(os.Args[2:]...) + case "build": + cmdBuild(os.Args[2:]...) + case "run": + cmdRun(os.Args[2:]...) + default: + usage() + } + } +} + +func listTestCases(pkg string, tasks []task) ([]task, error) { + newCases, err := listNewTestCases(pkg) + if err != nil { + fmt.Println("list test case error", pkg, err) + return nil, withTrace(err) + } + for _, c := range newCases { + tasks = append(tasks, task{pkg, c, false}) + } + + oldCases, err := listOldTestCases(pkg) + if err != nil { + fmt.Println("list old test case error", pkg, err) + return nil, withTrace(err) + } + for _, c := range oldCases { + tasks = append(tasks, task{pkg, c, true}) + } + return tasks, nil +} + +func listPackages() ([]string, error) { + cmd := exec.Command("go", "list", "./...") + ss, err := cmdToLines(cmd) + if err != nil { + return nil, withTrace(err) + } + + ret := ss[:0] + for _, s := range ss { + if !strings.HasPrefix(s, modulePath) { + continue + } + pkg := s[len(modulePath)+1:] + if skipDIR(pkg) { + continue + } + ret = append(ret, pkg) + } + return ret, nil +} + +type numa struct { + cpu string + numactl bool +} + +func (n *numa) worker(wg *sync.WaitGroup, ch chan task) { + defer wg.Done() + for t := range ch { + start := time.Now() + if err := n.runTestCase(t.pkg, t.test, t.old); err != nil { + fmt.Println("run test case error", t.pkg, t.test, t.old, time.Since(start), err) + } + } +} + +func (n *numa) runTestCase(pkg string, fn string, old bool) error { + exe := "./" + testFileName(pkg) + var cmd *exec.Cmd + if n.numactl { + cmd = n.testCommandWithNumaCtl(exe, fn, old) + } else { + cmd = n.testCommand(exe, fn, old) + } + cmd.Dir = path.Join(workDir, pkg) + _, err := cmd.CombinedOutput() + if err != nil { + // fmt.Println("run test case error", pkg, fn, string(output)) + return err + } + return nil +} + +func (n *numa) testCommandWithNumaCtl(exe string, fn string, old bool) *exec.Cmd { + if old { + // numactl --physcpubind 3 -- session.test -test.run '^TestT$' -check.f testTxnStateSerialSuite.TestTxnInfoWithPSProtoco + return exec.Command( + "numactl", "--physcpubind", n.cpu, "--", + exe, + "-test.timeout", "20s", + "-test.cpu", "1", "-test.run", "^TestT$", "-check.f", fn) + } + + // numactl --physcpubind 3 -- session.test -test.run TestClusteredPrefixColum + return exec.Command( + "numactl", "--physcpubind", n.cpu, "--", + exe, + "-test.timeout", "20s", + "-test.cpu", "1", 
"-test.run", fn) +} + +func (n *numa) testCommand(exe string, fn string, old bool) *exec.Cmd { + if old { + // session.test -test.run '^TestT$' -check.f testTxnStateSerialSuite.TestTxnInfoWithPSProtoco + return exec.Command( + exe, + "-test.timeout", "20s", + "-test.cpu", "1", "-test.run", "^TestT$", "-check.f", fn) + } + + // session.test -test.run TestClusteredPrefixColum + return exec.Command( + exe, + "-test.timeout", "20s", + "-test.cpu", "1", "-test.run", fn) +} + +func skipDIR(pkg string) bool { + skipDir := []string{"br", "cmd", "dumpling"} + for _, ignore := range skipDir { + if strings.HasPrefix(pkg, ignore) { + return true + } + } + return false +} + +func buildTestBinary(pkg string) error { + // go test -c + cmd := exec.Command("go", "test", "-c", "-vet", "off", "-o", testFileName(pkg)) + cmd.Dir = path.Join(workDir, pkg) + err := cmd.Run() + return withTrace(err) +} + +func testBinaryExist(pkg string) (bool, error) { + _, err := os.Stat(testFileFullPath(pkg)) + if err != nil { + if _, ok := err.(*os.PathError); ok { + return false, nil + } + } + return true, withTrace(err) +} +func numactlExist() bool { + find, err := exec.Command("which", "numactl").Output() + if err == nil && len(find) > 0 { + return true + } + return false +} + +func testFileName(pkg string) string { + _, file := path.Split(pkg) + return file+".test.bin" +} + +func testFileFullPath(pkg string) string { + return path.Join(workDir, pkg, testFileName(pkg)) +} + +func listNewTestCases(pkg string) ([]string, error) { + exe := "./" + testFileName(pkg) + + // session.test -test.list Test + cmd := exec.Command(exe, "-test.list", "Test") + cmd.Dir = path.Join(workDir, pkg) + res, err := cmdToLines(cmd) + if err != nil { + return nil, withTrace(err) + } + return filter(res, func(s string) bool { + return strings.HasPrefix(s, "Test") && s != "TestT" && s != "TestBenchDaily" + }), nil +} + +func listOldTestCases(pkg string) (res []string, err error) { + exe := "./" + testFileName(pkg) + + // Maybe the restructure is finish on this package. 
+ cmd := exec.Command(exe, "-h") + cmd.Dir = path.Join(workDir, pkg) + buf, err := cmd.CombinedOutput() + if err != nil { + err = withTrace(err) + return + } + if !bytes.Contains(buf, []byte("check.list")) { + // there is no old test case in pkg + return + } + + // session.test -test.run TestT -check.list Test + cmd = exec.Command(exe, "-test.run", "^TestT$", "-check.list", "Test") + cmd.Dir = path.Join(workDir, pkg) + res, err = cmdToLines(cmd) + res = filter(res, func(s string) bool { return strings.Contains(s, "Test") }) + return res, withTrace(err) +} + +func cmdToLines(cmd *exec.Cmd) ([]string, error) { + res, err := cmd.Output() + if err != nil { + return nil, withTrace(err) + } + ss := bytes.Split(res, []byte{'\n'}) + ret := make([]string, len(ss)) + for i, s := range ss { + ret[i] = string(s) + } + return ret, nil +} + +func filter(input []string, f func(string) bool) []string { + ret := input[:0] + for _, s := range input { + if f(s) { + ret = append(ret, s) + } + } + return ret +} + +func shuffle(tasks []task) { + for i := 0; i < len(tasks); i++ { + pos := rand.Intn(len(tasks)) + tasks[i], tasks[pos] = tasks[pos], tasks[i] + } +} + +type errWithStack struct { + err error + buf []byte +} + +func (e *errWithStack) Error() string { + return e.err.Error() + "\n" + string(e.buf) +} + +func withTrace(err error) error { + if err == nil { + return err + } + if _, ok := err.(*errWithStack); ok { + return err + } + var stack [4096]byte + sz := runtime.Stack(stack[:], false) + return &errWithStack{err, stack[:sz]} +}