diff --git a/br/pkg/lightning/backend/local/BUILD.bazel b/br/pkg/lightning/backend/local/BUILD.bazel index 01c1bd2c42001..dbd3ad17df790 100644 --- a/br/pkg/lightning/backend/local/BUILD.bazel +++ b/br/pkg/lightning/backend/local/BUILD.bazel @@ -167,6 +167,7 @@ go_test( "@com_github_tikv_pd_client//errs", "@org_golang_google_grpc//:grpc", "@org_golang_google_grpc//codes", + "@org_golang_google_grpc//encoding", "@org_golang_google_grpc//status", "@org_uber_go_atomic//:atomic", ], diff --git a/br/pkg/lightning/backend/local/local_test.go b/br/pkg/lightning/backend/local/local_test.go index 902dd906fa040..6c2aa6fb435d1 100644 --- a/br/pkg/lightning/backend/local/local_test.go +++ b/br/pkg/lightning/backend/local/local_test.go @@ -29,6 +29,7 @@ import ( "sync/atomic" "testing" "time" + _ "unsafe" "github.com/cockroachdb/pebble" "github.com/docker/go-units" @@ -59,6 +60,7 @@ import ( pd "github.com/tikv/pd/client" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/encoding" "google.golang.org/grpc/status" ) @@ -742,6 +744,33 @@ func (m mockWriteClient) CloseAndRecv() (*sst.WriteResponse, error) { return m.writeResp, nil } +type baseCodec interface { + Marshal(v interface{}) ([]byte, error) + Unmarshal(data []byte, v interface{}) error +} + +//go:linkname newContextWithRPCInfo google.golang.org/grpc.newContextWithRPCInfo +func newContextWithRPCInfo(ctx context.Context, failfast bool, codec baseCodec, cp grpc.Compressor, comp encoding.Compressor) context.Context + +type mockCodec struct{} + +func (m mockCodec) Marshal(v interface{}) ([]byte, error) { + return nil, nil +} + +func (m mockCodec) Unmarshal(data []byte, v interface{}) error { + return nil +} + +func (m mockWriteClient) Context() context.Context { + ctx := context.Background() + return newContextWithRPCInfo(ctx, false, mockCodec{}, nil, nil) +} + +func (m mockWriteClient) SendMsg(_ interface{}) error { + return nil +} + func (c *mockImportClient) Write(ctx context.Context, opts ...grpc.CallOption) (sst.ImportSST_WriteClient, error) { if c.apiInvokeRecorder != nil { c.apiInvokeRecorder["Write"] = append(c.apiInvokeRecorder["Write"], c.store.GetId()) diff --git a/br/pkg/lightning/backend/local/region_job.go b/br/pkg/lightning/backend/local/region_job.go index a75b0ac908662..f4d898bb40a20 100644 --- a/br/pkg/lightning/backend/local/region_job.go +++ b/br/pkg/lightning/backend/local/region_job.go @@ -39,6 +39,7 @@ import ( "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/mathutil" "go.uber.org/zap" + "google.golang.org/grpc" ) type jobStageTp string @@ -225,7 +226,11 @@ func (local *Backend) writeToTiKV(ctx context.Context, j *regionJob) error { leaderID := j.region.Leader.GetId() clients := make([]sst.ImportSST_WriteClient, 0, len(region.GetPeers())) allPeers := make([]*metapb.Peer, 0, len(region.GetPeers())) - requests := make([]*sst.WriteRequest, 0, len(region.GetPeers())) + req := &sst.WriteRequest{ + Chunk: &sst.WriteRequest_Meta{ + Meta: meta, + }, + } for _, peer := range region.GetPeers() { cli, err := clientFactory.Create(ctx, peer.StoreId) if err != nil { @@ -238,23 +243,17 @@ func (local *Backend) writeToTiKV(ctx context.Context, j *regionJob) error { } // Bind uuid for this write request - req := &sst.WriteRequest{ - Chunk: &sst.WriteRequest_Meta{ - Meta: meta, - }, - } if err = wstream.Send(req); err != nil { return annotateErr(err, peer) } - req.Chunk = &sst.WriteRequest_Batch{ - Batch: &sst.WriteBatch{ - CommitTs: j.engine.TS, - }, - } clients = append(clients, wstream) - requests = append(requests, req) allPeers = append(allPeers, peer) } + req.Chunk = &sst.WriteRequest_Batch{ + Batch: &sst.WriteBatch{ + CommitTs: j.engine.TS, + }, + } bytesBuf := bufferPool.NewBuffer() defer bytesBuf.Destroy() @@ -271,12 +270,19 @@ func (local *Backend) writeToTiKV(ctx context.Context, j *regionJob) error { } flushKVs := func() error { + req.Chunk.(*sst.WriteRequest_Batch).Batch.Pairs = pairs[:count] + preparedMsg := &grpc.PreparedMsg{} + // by reading the source code, Encode need to find codec and compression from the stream + // because all stream has the same codec and compression, we can use any one of them + if err := preparedMsg.Encode(clients[0], req); err != nil { + return err + } + for i := range clients { if err := writeLimiter.WaitN(ctx, allPeers[i].StoreId, int(size)); err != nil { return errors.Trace(err) } - requests[i].Chunk.(*sst.WriteRequest_Batch).Batch.Pairs = pairs[:count] - if err := clients[i].Send(requests[i]); err != nil { + if err := clients[i].SendMsg(preparedMsg); err != nil { return annotateErr(err, allPeers[i]) } }