From 7a35af8a4cfdabce87fe858d502f17f5472d3b5b Mon Sep 17 00:00:00 2001 From: disksing Date: Fri, 26 Mar 2021 19:21:23 +0800 Subject: [PATCH] mocktikv: split rpcHandler to kvHandler and coprHandler (#22857) --- store/mockstore/mocktikv/analyze.go | 6 +- store/mockstore/mocktikv/checksum.go | 2 +- store/mockstore/mocktikv/cop_handler_dag.go | 38 +-- store/mockstore/mocktikv/rpc.go | 323 +++++++------------- store/mockstore/mocktikv/session.go | 146 +++++++++ 5 files changed, 272 insertions(+), 243 deletions(-) create mode 100644 store/mockstore/mocktikv/session.go diff --git a/store/mockstore/mocktikv/analyze.go b/store/mockstore/mocktikv/analyze.go index a575f5536015d..2a013d63313a2 100644 --- a/store/mockstore/mocktikv/analyze.go +++ b/store/mockstore/mocktikv/analyze.go @@ -33,7 +33,7 @@ import ( "github.com/pingcap/tipb/go-tipb" ) -func (h *rpcHandler) handleCopAnalyzeRequest(req *coprocessor.Request) *coprocessor.Response { +func (h coprHandler) handleCopAnalyzeRequest(req *coprocessor.Request) *coprocessor.Response { resp := &coprocessor.Response{} if len(req.Ranges) == 0 { return resp @@ -62,7 +62,7 @@ func (h *rpcHandler) handleCopAnalyzeRequest(req *coprocessor.Request) *coproces return resp } -func (h *rpcHandler) handleAnalyzeIndexReq(req *coprocessor.Request, analyzeReq *tipb.AnalyzeReq) (*coprocessor.Response, error) { +func (h coprHandler) handleAnalyzeIndexReq(req *coprocessor.Request, analyzeReq *tipb.AnalyzeReq) (*coprocessor.Response, error) { ranges, err := h.extractKVRanges(req.Ranges, false) if err != nil { return nil, errors.Trace(err) @@ -125,7 +125,7 @@ type analyzeColumnsExec struct { fields []*ast.ResultField } -func (h *rpcHandler) handleAnalyzeColumnsReq(req *coprocessor.Request, analyzeReq *tipb.AnalyzeReq) (_ *coprocessor.Response, err error) { +func (h coprHandler) handleAnalyzeColumnsReq(req *coprocessor.Request, analyzeReq *tipb.AnalyzeReq) (_ *coprocessor.Response, err error) { sc := flagsToStatementContext(analyzeReq.Flags) sc.TimeZone, err = constructTimeZone("", int(analyzeReq.TimeZoneOffset)) if err != nil { diff --git a/store/mockstore/mocktikv/checksum.go b/store/mockstore/mocktikv/checksum.go index 13f54d26ab5a6..5c99a55ee70bf 100644 --- a/store/mockstore/mocktikv/checksum.go +++ b/store/mockstore/mocktikv/checksum.go @@ -20,7 +20,7 @@ import ( "github.com/pingcap/tipb/go-tipb" ) -func (h *rpcHandler) handleCopChecksumRequest(req *coprocessor.Request) *coprocessor.Response { +func (h coprHandler) handleCopChecksumRequest(req *coprocessor.Request) *coprocessor.Response { resp := &tipb.ChecksumResponse{ Checksum: 1, TotalKvs: 1, diff --git a/store/mockstore/mocktikv/cop_handler_dag.go b/store/mockstore/mocktikv/cop_handler_dag.go index d020d058467ee..82c75e99bb69f 100644 --- a/store/mockstore/mocktikv/cop_handler_dag.go +++ b/store/mockstore/mocktikv/cop_handler_dag.go @@ -54,7 +54,7 @@ type dagContext struct { evalCtx *evalContext } -func (h *rpcHandler) handleCopDAGRequest(req *coprocessor.Request) *coprocessor.Response { +func (h coprHandler) handleCopDAGRequest(req *coprocessor.Request) *coprocessor.Response { resp := &coprocessor.Response{} dagCtx, e, dagReq, err := h.buildDAGExecutor(req) if err != nil { @@ -88,7 +88,7 @@ func (h *rpcHandler) handleCopDAGRequest(req *coprocessor.Request) *coprocessor. return buildResp(selResp, execDetails, err) } -func (h *rpcHandler) buildDAGExecutor(req *coprocessor.Request) (*dagContext, executor, *tipb.DAGRequest, error) { +func (h coprHandler) buildDAGExecutor(req *coprocessor.Request) (*dagContext, executor, *tipb.DAGRequest, error) { if len(req.Ranges) == 0 { return nil, nil, nil, errors.New("request range is null") } @@ -133,7 +133,7 @@ func constructTimeZone(name string, offset int) (*time.Location, error) { return timeutil.ConstructTimeZone(name, offset) } -func (h *rpcHandler) handleCopStream(ctx context.Context, req *coprocessor.Request) (tikvpb.Tikv_CoprocessorStreamClient, error) { +func (h coprHandler) handleCopStream(ctx context.Context, req *coprocessor.Request) (tikvpb.Tikv_CoprocessorStreamClient, error) { dagCtx, e, dagReq, err := h.buildDAGExecutor(req) if err != nil { return nil, errors.Trace(err) @@ -147,7 +147,7 @@ func (h *rpcHandler) handleCopStream(ctx context.Context, req *coprocessor.Reque }, nil } -func (h *rpcHandler) buildExec(ctx *dagContext, curr *tipb.Executor) (executor, *tipb.Executor, error) { +func (h coprHandler) buildExec(ctx *dagContext, curr *tipb.Executor) (executor, *tipb.Executor, error) { var currExec executor var err error var childExec *tipb.Executor @@ -179,7 +179,7 @@ func (h *rpcHandler) buildExec(ctx *dagContext, curr *tipb.Executor) (executor, return currExec, childExec, errors.Trace(err) } -func (h *rpcHandler) buildDAGForTiFlash(ctx *dagContext, farther *tipb.Executor) (executor, error) { +func (h coprHandler) buildDAGForTiFlash(ctx *dagContext, farther *tipb.Executor) (executor, error) { curr, child, err := h.buildExec(ctx, farther) if err != nil { return nil, errors.Trace(err) @@ -194,7 +194,7 @@ func (h *rpcHandler) buildDAGForTiFlash(ctx *dagContext, farther *tipb.Executor) return curr, nil } -func (h *rpcHandler) buildDAG(ctx *dagContext, executors []*tipb.Executor) (executor, error) { +func (h coprHandler) buildDAG(ctx *dagContext, executors []*tipb.Executor) (executor, error) { var src executor for i := 0; i < len(executors); i++ { curr, _, err := h.buildExec(ctx, executors[i]) @@ -207,7 +207,7 @@ func (h *rpcHandler) buildDAG(ctx *dagContext, executors []*tipb.Executor) (exec return src, nil } -func (h *rpcHandler) buildTableScan(ctx *dagContext, executor *tipb.Executor) (*tableScanExec, error) { +func (h coprHandler) buildTableScan(ctx *dagContext, executor *tipb.Executor) (*tableScanExec, error) { columns := executor.TblScan.Columns ctx.evalCtx.setColumnInfo(columns) ranges, err := h.extractKVRanges(ctx.keyRanges, executor.TblScan.Desc) @@ -258,7 +258,7 @@ func (h *rpcHandler) buildTableScan(ctx *dagContext, executor *tipb.Executor) (* return e, nil } -func (h *rpcHandler) buildIndexScan(ctx *dagContext, executor *tipb.Executor) (*indexScanExec, error) { +func (h coprHandler) buildIndexScan(ctx *dagContext, executor *tipb.Executor) (*indexScanExec, error) { var err error columns := executor.IdxScan.Columns ctx.evalCtx.setColumnInfo(columns) @@ -311,7 +311,7 @@ func (h *rpcHandler) buildIndexScan(ctx *dagContext, executor *tipb.Executor) (* return e, nil } -func (h *rpcHandler) buildSelection(ctx *dagContext, executor *tipb.Executor) (*selectionExec, error) { +func (h coprHandler) buildSelection(ctx *dagContext, executor *tipb.Executor) (*selectionExec, error) { var err error var relatedColOffsets []int pbConds := executor.Selection.Conditions @@ -335,7 +335,7 @@ func (h *rpcHandler) buildSelection(ctx *dagContext, executor *tipb.Executor) (* }, nil } -func (h *rpcHandler) getAggInfo(ctx *dagContext, executor *tipb.Executor) ([]aggregation.Aggregation, []expression.Expression, []int, error) { +func (h coprHandler) getAggInfo(ctx *dagContext, executor *tipb.Executor) ([]aggregation.Aggregation, []expression.Expression, []int, error) { length := len(executor.Aggregation.AggFunc) aggs := make([]aggregation.Aggregation, 0, length) var err error @@ -366,7 +366,7 @@ func (h *rpcHandler) getAggInfo(ctx *dagContext, executor *tipb.Executor) ([]agg return aggs, groupBys, relatedColOffsets, nil } -func (h *rpcHandler) buildHashAgg(ctx *dagContext, executor *tipb.Executor) (*hashAggExec, error) { +func (h coprHandler) buildHashAgg(ctx *dagContext, executor *tipb.Executor) (*hashAggExec, error) { aggs, groupBys, relatedColOffsets, err := h.getAggInfo(ctx, executor) if err != nil { return nil, errors.Trace(err) @@ -384,7 +384,7 @@ func (h *rpcHandler) buildHashAgg(ctx *dagContext, executor *tipb.Executor) (*ha }, nil } -func (h *rpcHandler) buildStreamAgg(ctx *dagContext, executor *tipb.Executor) (*streamAggExec, error) { +func (h coprHandler) buildStreamAgg(ctx *dagContext, executor *tipb.Executor) (*streamAggExec, error) { aggs, groupBys, relatedColOffsets, err := h.getAggInfo(ctx, executor) if err != nil { return nil, errors.Trace(err) @@ -406,7 +406,7 @@ func (h *rpcHandler) buildStreamAgg(ctx *dagContext, executor *tipb.Executor) (* }, nil } -func (h *rpcHandler) buildTopN(ctx *dagContext, executor *tipb.Executor) (*topNExec, error) { +func (h coprHandler) buildTopN(ctx *dagContext, executor *tipb.Executor) (*topNExec, error) { topN := executor.TopN var err error var relatedColOffsets []int @@ -664,7 +664,7 @@ func (mock *mockCopStreamClient) readBlockFromExecutor() (tipb.Chunk, bool, *cop return chunk, finish, &ran, mock.exec.Counts(), warnings, nil } -func (h *rpcHandler) initSelectResponse(err error, warnings []stmtctx.SQLWarn, counts []int64) *tipb.SelectResponse { +func (h coprHandler) initSelectResponse(err error, warnings []stmtctx.SQLWarn, counts []int64) *tipb.SelectResponse { selResp := &tipb.SelectResponse{ Error: toPBError(err), OutputCounts: counts, @@ -675,7 +675,7 @@ func (h *rpcHandler) initSelectResponse(err error, warnings []stmtctx.SQLWarn, c return selResp } -func (h *rpcHandler) fillUpData4SelectResponse(selResp *tipb.SelectResponse, dagReq *tipb.DAGRequest, dagCtx *dagContext, rows [][][]byte) error { +func (h coprHandler) fillUpData4SelectResponse(selResp *tipb.SelectResponse, dagReq *tipb.DAGRequest, dagCtx *dagContext, rows [][][]byte) error { switch dagReq.EncodeType { case tipb.EncodeType_TypeDefault: h.encodeDefault(selResp, rows, dagReq.OutputOffsets) @@ -690,7 +690,7 @@ func (h *rpcHandler) fillUpData4SelectResponse(selResp *tipb.SelectResponse, dag return nil } -func (h *rpcHandler) constructRespSchema(dagCtx *dagContext) []*types.FieldType { +func (h coprHandler) constructRespSchema(dagCtx *dagContext) []*types.FieldType { var root *tipb.Executor if len(dagCtx.dagReq.Executors) == 0 { root = dagCtx.dagReq.RootExecutor @@ -717,7 +717,7 @@ func (h *rpcHandler) constructRespSchema(dagCtx *dagContext) []*types.FieldType return schema } -func (h *rpcHandler) encodeDefault(selResp *tipb.SelectResponse, rows [][][]byte, colOrdinal []uint32) { +func (h coprHandler) encodeDefault(selResp *tipb.SelectResponse, rows [][][]byte, colOrdinal []uint32) { var chunks []tipb.Chunk for i := range rows { requestedRow := dummySlice @@ -730,7 +730,7 @@ func (h *rpcHandler) encodeDefault(selResp *tipb.SelectResponse, rows [][][]byte selResp.EncodeType = tipb.EncodeType_TypeDefault } -func (h *rpcHandler) encodeChunk(selResp *tipb.SelectResponse, rows [][][]byte, colTypes []*types.FieldType, colOrdinal []uint32, loc *time.Location) error { +func (h coprHandler) encodeChunk(selResp *tipb.SelectResponse, rows [][][]byte, colTypes []*types.FieldType, colOrdinal []uint32, loc *time.Location) error { var chunks []tipb.Chunk respColTypes := make([]*types.FieldType, 0, len(colOrdinal)) for _, ordinal := range colOrdinal { @@ -826,7 +826,7 @@ func toPBError(err error) *tipb.Error { } // extractKVRanges extracts kv.KeyRanges slice from a SelectRequest. -func (h *rpcHandler) extractKVRanges(keyRanges []*coprocessor.KeyRange, descScan bool) (kvRanges []kv.KeyRange, err error) { +func (h coprHandler) extractKVRanges(keyRanges []*coprocessor.KeyRange, descScan bool) (kvRanges []kv.KeyRange, err error) { for _, kran := range keyRanges { if bytes.Compare(kran.GetStart(), kran.GetEnd()) >= 0 { err = errors.Errorf("invalid range, start should be smaller than end: %v %v", kran.GetStart(), kran.GetEnd()) diff --git a/store/mockstore/mocktikv/rpc.go b/store/mockstore/mocktikv/rpc.go index 2ef026b408249..320f545c550ca 100644 --- a/store/mockstore/mocktikv/rpc.go +++ b/store/mockstore/mocktikv/rpc.go @@ -22,7 +22,6 @@ import ( "sync" "time" - "github.com/golang/protobuf/proto" "github.com/opentracing/opentracing-go" "github.com/pingcap/errors" "github.com/pingcap/failpoint" @@ -32,7 +31,6 @@ import ( "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/parser/terror" - "github.com/pingcap/tidb/ddl/placement" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/store/tikv/tikvrpc" "github.com/pingcap/tipb/go-tipb" @@ -141,132 +139,13 @@ func convertToPbPairs(pairs []Pair) []*kvrpcpb.KvPair { return kvPairs } -// rpcHandler mocks tikv's side handler behavior. In general, you may assume +// kvHandler mocks tikv's side handler behavior. In general, you may assume // TiKV just translate the logic from Go to Rust. -type rpcHandler struct { - cluster *Cluster - mvccStore MVCCStore - - // storeID stores id for current request - storeID uint64 - // startKey is used for handling normal request. - startKey []byte - endKey []byte - // rawStartKey is used for handling coprocessor request. - rawStartKey []byte - rawEndKey []byte - // isolationLevel is used for current request. - isolationLevel kvrpcpb.IsolationLevel - resolvedLocks []uint64 +type kvHandler struct { + *Session } -func isTiFlashStore(store *metapb.Store) bool { - for _, l := range store.GetLabels() { - if l.GetKey() == placement.EngineLabelKey && l.GetValue() == placement.EngineLabelTiFlash { - return true - } - } - return false -} - -func (h *rpcHandler) checkRequestContext(ctx *kvrpcpb.Context) *errorpb.Error { - ctxPeer := ctx.GetPeer() - if ctxPeer != nil && ctxPeer.GetStoreId() != h.storeID { - return &errorpb.Error{ - Message: *proto.String("store not match"), - StoreNotMatch: &errorpb.StoreNotMatch{}, - } - } - region, leaderID := h.cluster.GetRegion(ctx.GetRegionId()) - // No region found. - if region == nil { - return &errorpb.Error{ - Message: *proto.String("region not found"), - RegionNotFound: &errorpb.RegionNotFound{ - RegionId: *proto.Uint64(ctx.GetRegionId()), - }, - } - } - var storePeer, leaderPeer *metapb.Peer - for _, p := range region.Peers { - if p.GetStoreId() == h.storeID { - storePeer = p - } - if p.GetId() == leaderID { - leaderPeer = p - } - } - // The Store does not contain a Peer of the Region. - if storePeer == nil { - return &errorpb.Error{ - Message: *proto.String("region not found"), - RegionNotFound: &errorpb.RegionNotFound{ - RegionId: *proto.Uint64(ctx.GetRegionId()), - }, - } - } - // No leader. - if leaderPeer == nil { - return &errorpb.Error{ - Message: *proto.String("no leader"), - NotLeader: &errorpb.NotLeader{ - RegionId: *proto.Uint64(ctx.GetRegionId()), - }, - } - } - // The Peer on the Store is not leader. If it's tiflash store , we pass this check. - if storePeer.GetId() != leaderPeer.GetId() && !isTiFlashStore(h.cluster.GetStore(storePeer.GetStoreId())) { - return &errorpb.Error{ - Message: *proto.String("not leader"), - NotLeader: &errorpb.NotLeader{ - RegionId: *proto.Uint64(ctx.GetRegionId()), - Leader: leaderPeer, - }, - } - } - // Region epoch does not match. - if !proto.Equal(region.GetRegionEpoch(), ctx.GetRegionEpoch()) { - nextRegion, _ := h.cluster.GetRegionByKey(region.GetEndKey()) - currentRegions := []*metapb.Region{region} - if nextRegion != nil { - currentRegions = append(currentRegions, nextRegion) - } - return &errorpb.Error{ - Message: *proto.String("epoch not match"), - EpochNotMatch: &errorpb.EpochNotMatch{ - CurrentRegions: currentRegions, - }, - } - } - h.startKey, h.endKey = region.StartKey, region.EndKey - h.isolationLevel = ctx.IsolationLevel - h.resolvedLocks = ctx.ResolvedLocks - return nil -} - -func (h *rpcHandler) checkRequestSize(size int) *errorpb.Error { - // TiKV has a limitation on raft log size. - // mocktikv has no raft inside, so we check the request's size instead. - if size >= requestMaxSize { - return &errorpb.Error{ - RaftEntryTooLarge: &errorpb.RaftEntryTooLarge{}, - } - } - return nil -} - -func (h *rpcHandler) checkRequest(ctx *kvrpcpb.Context, size int) *errorpb.Error { - if err := h.checkRequestContext(ctx); err != nil { - return err - } - return h.checkRequestSize(size) -} - -func (h *rpcHandler) checkKeyInRegion(key []byte) bool { - return regionContains(h.startKey, h.endKey, NewMvccKey(key)) -} - -func (h *rpcHandler) handleKvGet(req *kvrpcpb.GetRequest) *kvrpcpb.GetResponse { +func (h kvHandler) handleKvGet(req *kvrpcpb.GetRequest) *kvrpcpb.GetResponse { if !h.checkKeyInRegion(req.Key) { panic("KvGet: key not in region") } @@ -282,7 +161,7 @@ func (h *rpcHandler) handleKvGet(req *kvrpcpb.GetRequest) *kvrpcpb.GetResponse { } } -func (h *rpcHandler) handleKvScan(req *kvrpcpb.ScanRequest) *kvrpcpb.ScanResponse { +func (h kvHandler) handleKvScan(req *kvrpcpb.ScanRequest) *kvrpcpb.ScanResponse { endKey := MvccKey(h.endKey).Raw() var pairs []Pair if !req.Reverse { @@ -314,7 +193,7 @@ func (h *rpcHandler) handleKvScan(req *kvrpcpb.ScanRequest) *kvrpcpb.ScanRespons } } -func (h *rpcHandler) handleKvPrewrite(req *kvrpcpb.PrewriteRequest) *kvrpcpb.PrewriteResponse { +func (h kvHandler) handleKvPrewrite(req *kvrpcpb.PrewriteRequest) *kvrpcpb.PrewriteResponse { regionID := req.Context.RegionId h.cluster.handleDelay(req.StartVersion, regionID) @@ -329,7 +208,7 @@ func (h *rpcHandler) handleKvPrewrite(req *kvrpcpb.PrewriteRequest) *kvrpcpb.Pre } } -func (h *rpcHandler) handleKvPessimisticLock(req *kvrpcpb.PessimisticLockRequest) *kvrpcpb.PessimisticLockResponse { +func (h kvHandler) handleKvPessimisticLock(req *kvrpcpb.PessimisticLockRequest) *kvrpcpb.PessimisticLockResponse { for _, m := range req.Mutations { if !h.checkKeyInRegion(m.Key) { panic("KvPessimisticLock: key not in region") @@ -350,7 +229,7 @@ func simulateServerSideWaitLock(errs []error) { } } -func (h *rpcHandler) handleKvPessimisticRollback(req *kvrpcpb.PessimisticRollbackRequest) *kvrpcpb.PessimisticRollbackResponse { +func (h kvHandler) handleKvPessimisticRollback(req *kvrpcpb.PessimisticRollbackRequest) *kvrpcpb.PessimisticRollbackResponse { for _, key := range req.Keys { if !h.checkKeyInRegion(key) { panic("KvPessimisticRollback: key not in region") @@ -362,7 +241,7 @@ func (h *rpcHandler) handleKvPessimisticRollback(req *kvrpcpb.PessimisticRollbac } } -func (h *rpcHandler) handleKvCommit(req *kvrpcpb.CommitRequest) *kvrpcpb.CommitResponse { +func (h kvHandler) handleKvCommit(req *kvrpcpb.CommitRequest) *kvrpcpb.CommitResponse { for _, k := range req.Keys { if !h.checkKeyInRegion(k) { panic("KvCommit: key not in region") @@ -376,7 +255,7 @@ func (h *rpcHandler) handleKvCommit(req *kvrpcpb.CommitRequest) *kvrpcpb.CommitR return &resp } -func (h *rpcHandler) handleKvCleanup(req *kvrpcpb.CleanupRequest) *kvrpcpb.CleanupResponse { +func (h kvHandler) handleKvCleanup(req *kvrpcpb.CleanupRequest) *kvrpcpb.CleanupResponse { if !h.checkKeyInRegion(req.Key) { panic("KvCleanup: key not in region") } @@ -392,7 +271,7 @@ func (h *rpcHandler) handleKvCleanup(req *kvrpcpb.CleanupRequest) *kvrpcpb.Clean return &resp } -func (h *rpcHandler) handleKvCheckTxnStatus(req *kvrpcpb.CheckTxnStatusRequest) *kvrpcpb.CheckTxnStatusResponse { +func (h kvHandler) handleKvCheckTxnStatus(req *kvrpcpb.CheckTxnStatusRequest) *kvrpcpb.CheckTxnStatusResponse { if !h.checkKeyInRegion(req.PrimaryKey) { panic("KvCheckTxnStatus: key not in region") } @@ -406,7 +285,7 @@ func (h *rpcHandler) handleKvCheckTxnStatus(req *kvrpcpb.CheckTxnStatusRequest) return &resp } -func (h *rpcHandler) handleTxnHeartBeat(req *kvrpcpb.TxnHeartBeatRequest) *kvrpcpb.TxnHeartBeatResponse { +func (h kvHandler) handleTxnHeartBeat(req *kvrpcpb.TxnHeartBeatRequest) *kvrpcpb.TxnHeartBeatResponse { if !h.checkKeyInRegion(req.PrimaryLock) { panic("KvTxnHeartBeat: key not in region") } @@ -419,7 +298,7 @@ func (h *rpcHandler) handleTxnHeartBeat(req *kvrpcpb.TxnHeartBeatRequest) *kvrpc return &resp } -func (h *rpcHandler) handleKvBatchGet(req *kvrpcpb.BatchGetRequest) *kvrpcpb.BatchGetResponse { +func (h kvHandler) handleKvBatchGet(req *kvrpcpb.BatchGetRequest) *kvrpcpb.BatchGetResponse { for _, k := range req.Keys { if !h.checkKeyInRegion(k) { panic("KvBatchGet: key not in region") @@ -431,7 +310,7 @@ func (h *rpcHandler) handleKvBatchGet(req *kvrpcpb.BatchGetRequest) *kvrpcpb.Bat } } -func (h *rpcHandler) handleMvccGetByKey(req *kvrpcpb.MvccGetByKeyRequest) *kvrpcpb.MvccGetByKeyResponse { +func (h kvHandler) handleMvccGetByKey(req *kvrpcpb.MvccGetByKeyRequest) *kvrpcpb.MvccGetByKeyResponse { debugger, ok := h.mvccStore.(MVCCDebugger) if !ok { return &kvrpcpb.MvccGetByKeyResponse{ @@ -447,7 +326,7 @@ func (h *rpcHandler) handleMvccGetByKey(req *kvrpcpb.MvccGetByKeyRequest) *kvrpc return &resp } -func (h *rpcHandler) handleMvccGetByStartTS(req *kvrpcpb.MvccGetByStartTsRequest) *kvrpcpb.MvccGetByStartTsResponse { +func (h kvHandler) handleMvccGetByStartTS(req *kvrpcpb.MvccGetByStartTsRequest) *kvrpcpb.MvccGetByStartTsResponse { debugger, ok := h.mvccStore.(MVCCDebugger) if !ok { return &kvrpcpb.MvccGetByStartTsResponse{ @@ -459,7 +338,7 @@ func (h *rpcHandler) handleMvccGetByStartTS(req *kvrpcpb.MvccGetByStartTsRequest return &resp } -func (h *rpcHandler) handleKvBatchRollback(req *kvrpcpb.BatchRollbackRequest) *kvrpcpb.BatchRollbackResponse { +func (h kvHandler) handleKvBatchRollback(req *kvrpcpb.BatchRollbackRequest) *kvrpcpb.BatchRollbackResponse { err := h.mvccStore.Rollback(req.Keys, req.StartVersion) if err != nil { return &kvrpcpb.BatchRollbackResponse{ @@ -469,7 +348,7 @@ func (h *rpcHandler) handleKvBatchRollback(req *kvrpcpb.BatchRollbackRequest) *k return &kvrpcpb.BatchRollbackResponse{} } -func (h *rpcHandler) handleKvScanLock(req *kvrpcpb.ScanLockRequest) *kvrpcpb.ScanLockResponse { +func (h kvHandler) handleKvScanLock(req *kvrpcpb.ScanLockRequest) *kvrpcpb.ScanLockResponse { startKey := MvccKey(h.startKey).Raw() endKey := MvccKey(h.endKey).Raw() locks, err := h.mvccStore.ScanLock(startKey, endKey, req.GetMaxVersion()) @@ -483,7 +362,7 @@ func (h *rpcHandler) handleKvScanLock(req *kvrpcpb.ScanLockRequest) *kvrpcpb.Sca } } -func (h *rpcHandler) handleKvResolveLock(req *kvrpcpb.ResolveLockRequest) *kvrpcpb.ResolveLockResponse { +func (h kvHandler) handleKvResolveLock(req *kvrpcpb.ResolveLockRequest) *kvrpcpb.ResolveLockResponse { startKey := MvccKey(h.startKey).Raw() endKey := MvccKey(h.endKey).Raw() err := h.mvccStore.ResolveLock(startKey, endKey, req.GetStartVersion(), req.GetCommitVersion()) @@ -495,7 +374,7 @@ func (h *rpcHandler) handleKvResolveLock(req *kvrpcpb.ResolveLockRequest) *kvrpc return &kvrpcpb.ResolveLockResponse{} } -func (h *rpcHandler) handleKvGC(req *kvrpcpb.GCRequest) *kvrpcpb.GCResponse { +func (h kvHandler) handleKvGC(req *kvrpcpb.GCRequest) *kvrpcpb.GCResponse { startKey := MvccKey(h.startKey).Raw() endKey := MvccKey(h.endKey).Raw() err := h.mvccStore.GC(startKey, endKey, req.GetSafePoint()) @@ -507,7 +386,7 @@ func (h *rpcHandler) handleKvGC(req *kvrpcpb.GCRequest) *kvrpcpb.GCResponse { return &kvrpcpb.GCResponse{} } -func (h *rpcHandler) handleKvDeleteRange(req *kvrpcpb.DeleteRangeRequest) *kvrpcpb.DeleteRangeResponse { +func (h kvHandler) handleKvDeleteRange(req *kvrpcpb.DeleteRangeRequest) *kvrpcpb.DeleteRangeResponse { if !h.checkKeyInRegion(req.StartKey) { panic("KvDeleteRange: key not in region") } @@ -519,7 +398,7 @@ func (h *rpcHandler) handleKvDeleteRange(req *kvrpcpb.DeleteRangeRequest) *kvrpc return &resp } -func (h *rpcHandler) handleKvRawGet(req *kvrpcpb.RawGetRequest) *kvrpcpb.RawGetResponse { +func (h kvHandler) handleKvRawGet(req *kvrpcpb.RawGetRequest) *kvrpcpb.RawGetResponse { rawKV, ok := h.mvccStore.(RawKV) if !ok { return &kvrpcpb.RawGetResponse{ @@ -531,7 +410,7 @@ func (h *rpcHandler) handleKvRawGet(req *kvrpcpb.RawGetRequest) *kvrpcpb.RawGetR } } -func (h *rpcHandler) handleKvRawBatchGet(req *kvrpcpb.RawBatchGetRequest) *kvrpcpb.RawBatchGetResponse { +func (h kvHandler) handleKvRawBatchGet(req *kvrpcpb.RawBatchGetRequest) *kvrpcpb.RawBatchGetResponse { rawKV, ok := h.mvccStore.(RawKV) if !ok { // TODO should we add error ? @@ -554,7 +433,7 @@ func (h *rpcHandler) handleKvRawBatchGet(req *kvrpcpb.RawBatchGetRequest) *kvrpc } } -func (h *rpcHandler) handleKvRawPut(req *kvrpcpb.RawPutRequest) *kvrpcpb.RawPutResponse { +func (h kvHandler) handleKvRawPut(req *kvrpcpb.RawPutRequest) *kvrpcpb.RawPutResponse { rawKV, ok := h.mvccStore.(RawKV) if !ok { return &kvrpcpb.RawPutResponse{ @@ -565,7 +444,7 @@ func (h *rpcHandler) handleKvRawPut(req *kvrpcpb.RawPutRequest) *kvrpcpb.RawPutR return &kvrpcpb.RawPutResponse{} } -func (h *rpcHandler) handleKvRawBatchPut(req *kvrpcpb.RawBatchPutRequest) *kvrpcpb.RawBatchPutResponse { +func (h kvHandler) handleKvRawBatchPut(req *kvrpcpb.RawBatchPutRequest) *kvrpcpb.RawBatchPutResponse { rawKV, ok := h.mvccStore.(RawKV) if !ok { return &kvrpcpb.RawBatchPutResponse{ @@ -582,7 +461,7 @@ func (h *rpcHandler) handleKvRawBatchPut(req *kvrpcpb.RawBatchPutRequest) *kvrpc return &kvrpcpb.RawBatchPutResponse{} } -func (h *rpcHandler) handleKvRawDelete(req *kvrpcpb.RawDeleteRequest) *kvrpcpb.RawDeleteResponse { +func (h kvHandler) handleKvRawDelete(req *kvrpcpb.RawDeleteRequest) *kvrpcpb.RawDeleteResponse { rawKV, ok := h.mvccStore.(RawKV) if !ok { return &kvrpcpb.RawDeleteResponse{ @@ -593,7 +472,7 @@ func (h *rpcHandler) handleKvRawDelete(req *kvrpcpb.RawDeleteRequest) *kvrpcpb.R return &kvrpcpb.RawDeleteResponse{} } -func (h *rpcHandler) handleKvRawBatchDelete(req *kvrpcpb.RawBatchDeleteRequest) *kvrpcpb.RawBatchDeleteResponse { +func (h kvHandler) handleKvRawBatchDelete(req *kvrpcpb.RawBatchDeleteRequest) *kvrpcpb.RawBatchDeleteResponse { rawKV, ok := h.mvccStore.(RawKV) if !ok { return &kvrpcpb.RawBatchDeleteResponse{ @@ -604,7 +483,7 @@ func (h *rpcHandler) handleKvRawBatchDelete(req *kvrpcpb.RawBatchDeleteRequest) return &kvrpcpb.RawBatchDeleteResponse{} } -func (h *rpcHandler) handleKvRawDeleteRange(req *kvrpcpb.RawDeleteRangeRequest) *kvrpcpb.RawDeleteRangeResponse { +func (h kvHandler) handleKvRawDeleteRange(req *kvrpcpb.RawDeleteRangeRequest) *kvrpcpb.RawDeleteRangeResponse { rawKV, ok := h.mvccStore.(RawKV) if !ok { return &kvrpcpb.RawDeleteRangeResponse{ @@ -615,7 +494,7 @@ func (h *rpcHandler) handleKvRawDeleteRange(req *kvrpcpb.RawDeleteRangeRequest) return &kvrpcpb.RawDeleteRangeResponse{} } -func (h *rpcHandler) handleKvRawScan(req *kvrpcpb.RawScanRequest) *kvrpcpb.RawScanResponse { +func (h kvHandler) handleKvRawScan(req *kvrpcpb.RawScanRequest) *kvrpcpb.RawScanResponse { rawKV, ok := h.mvccStore.(RawKV) if !ok { errStr := "not implemented" @@ -654,7 +533,7 @@ func (h *rpcHandler) handleKvRawScan(req *kvrpcpb.RawScanRequest) *kvrpcpb.RawSc } } -func (h *rpcHandler) handleSplitRegion(req *kvrpcpb.SplitRegionRequest) *kvrpcpb.SplitRegionResponse { +func (h kvHandler) handleSplitRegion(req *kvrpcpb.SplitRegionRequest) *kvrpcpb.SplitRegionResponse { keys := req.GetSplitKeys() resp := &kvrpcpb.SplitRegionResponse{Regions: make([]*metapb.Region, 0, len(keys)+1)} for i, key := range keys { @@ -690,7 +569,11 @@ func drainRowsFromExecutor(ctx context.Context, e executor, req *tipb.DAGRequest } } -func (h *rpcHandler) handleBatchCopRequest(ctx context.Context, req *coprocessor.BatchRequest) (*mockBatchCopDataClient, error) { +type coprHandler struct { + *Session +} + +func (h coprHandler) handleBatchCopRequest(ctx context.Context, req *coprocessor.BatchRequest) (*mockBatchCopDataClient, error) { client := &mockBatchCopDataClient{} for _, ri := range req.Regions { cop := coprocessor.Request{ @@ -766,7 +649,7 @@ func (c *RPCClient) getAndCheckStoreByAddr(addr string) (*metapb.Store, error) { return nil, errors.New("connection refused") } -func (c *RPCClient) checkArgs(ctx context.Context, addr string) (*rpcHandler, error) { +func (c *RPCClient) checkArgs(ctx context.Context, addr string) (*Session, error) { if err := checkGoContext(ctx); err != nil { return nil, err } @@ -775,13 +658,13 @@ func (c *RPCClient) checkArgs(ctx context.Context, addr string) (*rpcHandler, er if err != nil { return nil, err } - handler := &rpcHandler{ + session := &Session{ cluster: c.Cluster, mvccStore: c.MvccStore, // set store id for current request storeID: store.GetId(), } - return handler, nil + return session, nil } // GRPCClientFactory is the GRPC client factory. @@ -828,25 +711,25 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R return c.redirectRequestToRPCServer(ctx, addr, req, timeout) } - handler, err := c.checkArgs(ctx, addr) + session, err := c.checkArgs(ctx, addr) if err != nil { return nil, err } switch req.Type { case tikvrpc.CmdGet: r := req.Get() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.GetResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvGet(r) + resp.Resp = kvHandler{session}.handleKvGet(r) case tikvrpc.CmdScan: r := req.Scan() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.ScanResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvScan(r) + resp.Resp = kvHandler{session}.handleKvScan(r) case tikvrpc.CmdPrewrite: failpoint.Inject("rpcPrewriteResult", func(val failpoint.Value) { @@ -859,25 +742,25 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R }) r := req.Prewrite() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.PrewriteResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvPrewrite(r) + resp.Resp = kvHandler{session}.handleKvPrewrite(r) case tikvrpc.CmdPessimisticLock: r := req.PessimisticLock() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.PessimisticLockResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvPessimisticLock(r) + resp.Resp = kvHandler{session}.handleKvPessimisticLock(r) case tikvrpc.CmdPessimisticRollback: r := req.PessimisticRollback() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.PessimisticRollbackResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvPessimisticRollback(r) + resp.Resp = kvHandler{session}.handleKvPessimisticRollback(r) case tikvrpc.CmdCommit: failpoint.Inject("rpcCommitResult", func(val failpoint.Value) { switch val.(string) { @@ -895,11 +778,11 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R }) r := req.Commit() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.CommitResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvCommit(r) + resp.Resp = kvHandler{session}.handleKvCommit(r) failpoint.Inject("rpcCommitTimeout", func(val failpoint.Value) { if val.(bool) { failpoint.Return(nil, undeterminedErr) @@ -907,122 +790,122 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R }) case tikvrpc.CmdCleanup: r := req.Cleanup() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.CleanupResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvCleanup(r) + resp.Resp = kvHandler{session}.handleKvCleanup(r) case tikvrpc.CmdCheckTxnStatus: r := req.CheckTxnStatus() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.CheckTxnStatusResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvCheckTxnStatus(r) + resp.Resp = kvHandler{session}.handleKvCheckTxnStatus(r) case tikvrpc.CmdTxnHeartBeat: r := req.TxnHeartBeat() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.TxnHeartBeatResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleTxnHeartBeat(r) + resp.Resp = kvHandler{session}.handleTxnHeartBeat(r) case tikvrpc.CmdBatchGet: r := req.BatchGet() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.BatchGetResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvBatchGet(r) + resp.Resp = kvHandler{session}.handleKvBatchGet(r) case tikvrpc.CmdBatchRollback: r := req.BatchRollback() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.BatchRollbackResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvBatchRollback(r) + resp.Resp = kvHandler{session}.handleKvBatchRollback(r) case tikvrpc.CmdScanLock: r := req.ScanLock() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.ScanLockResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvScanLock(r) + resp.Resp = kvHandler{session}.handleKvScanLock(r) case tikvrpc.CmdResolveLock: r := req.ResolveLock() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.ResolveLockResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvResolveLock(r) + resp.Resp = kvHandler{session}.handleKvResolveLock(r) case tikvrpc.CmdGC: r := req.GC() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.GCResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvGC(r) + resp.Resp = kvHandler{session}.handleKvGC(r) case tikvrpc.CmdDeleteRange: r := req.DeleteRange() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.DeleteRangeResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvDeleteRange(r) + resp.Resp = kvHandler{session}.handleKvDeleteRange(r) case tikvrpc.CmdRawGet: r := req.RawGet() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.RawGetResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvRawGet(r) + resp.Resp = kvHandler{session}.handleKvRawGet(r) case tikvrpc.CmdRawBatchGet: r := req.RawBatchGet() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.RawBatchGetResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvRawBatchGet(r) + resp.Resp = kvHandler{session}.handleKvRawBatchGet(r) case tikvrpc.CmdRawPut: r := req.RawPut() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.RawPutResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvRawPut(r) + resp.Resp = kvHandler{session}.handleKvRawPut(r) case tikvrpc.CmdRawBatchPut: r := req.RawBatchPut() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.RawBatchPutResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvRawBatchPut(r) + resp.Resp = kvHandler{session}.handleKvRawBatchPut(r) case tikvrpc.CmdRawDelete: r := req.RawDelete() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.RawDeleteResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvRawDelete(r) + resp.Resp = kvHandler{session}.handleKvRawDelete(r) case tikvrpc.CmdRawBatchDelete: r := req.RawBatchDelete() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.RawBatchDeleteResponse{RegionError: err} } - resp.Resp = handler.handleKvRawBatchDelete(r) + resp.Resp = kvHandler{session}.handleKvRawBatchDelete(r) case tikvrpc.CmdRawDeleteRange: r := req.RawDeleteRange() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.RawDeleteRangeResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvRawDeleteRange(r) + resp.Resp = kvHandler{session}.handleKvRawDeleteRange(r) case tikvrpc.CmdRawScan: r := req.RawScan() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.RawScanResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleKvRawScan(r) + resp.Resp = kvHandler{session}.handleKvRawScan(r) case tikvrpc.CmdUnsafeDestroyRange: panic("unimplemented") case tikvrpc.CmdRegisterLockObserver: @@ -1035,20 +918,20 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R return nil, errors.New("unimplemented") case tikvrpc.CmdCop: r := req.Cop() - if err := handler.checkRequestContext(reqCtx); err != nil { + if err := session.checkRequestContext(reqCtx); err != nil { resp.Resp = &coprocessor.Response{RegionError: err} return resp, nil } - handler.rawStartKey = MvccKey(handler.startKey).Raw() - handler.rawEndKey = MvccKey(handler.endKey).Raw() + session.rawStartKey = MvccKey(session.startKey).Raw() + session.rawEndKey = MvccKey(session.endKey).Raw() var res *coprocessor.Response switch r.GetTp() { case kv.ReqTypeDAG: - res = handler.handleCopDAGRequest(r) + res = coprHandler{session}.handleCopDAGRequest(r) case kv.ReqTypeAnalyze: - res = handler.handleCopAnalyzeRequest(r) + res = coprHandler{session}.handleCopAnalyzeRequest(r) case kv.ReqTypeChecksum: - res = handler.handleCopChecksumRequest(r) + res = coprHandler{session}.handleCopChecksumRequest(r) default: panic(fmt.Sprintf("unknown coprocessor request type: %v", r.GetTp())) } @@ -1066,7 +949,7 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R } }) r := req.BatchCop() - if err := handler.checkRequestContext(reqCtx); err != nil { + if err := session.checkRequestContext(reqCtx); err != nil { resp.Resp = &tikvrpc.BatchCopStreamResponse{ Tikv_BatchCoprocessorClient: &mockBathCopErrClient{Error: err}, BatchResponse: &coprocessor.BatchResponse{ @@ -1076,7 +959,7 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R return resp, nil } ctx1, cancel := context.WithCancel(ctx) - batchCopStream, err := handler.handleBatchCopRequest(ctx1, r) + batchCopStream, err := coprHandler{session}.handleBatchCopRequest(ctx1, r) if err != nil { cancel() return nil, errors.Trace(err) @@ -1094,7 +977,7 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R resp.Resp = batchResp case tikvrpc.CmdCopStream: r := req.Cop() - if err := handler.checkRequestContext(reqCtx); err != nil { + if err := session.checkRequestContext(reqCtx); err != nil { resp.Resp = &tikvrpc.CopStreamResponse{ Tikv_CoprocessorStreamClient: &mockCopStreamErrClient{Error: err}, Response: &coprocessor.Response{ @@ -1103,10 +986,10 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R } return resp, nil } - handler.rawStartKey = MvccKey(handler.startKey).Raw() - handler.rawEndKey = MvccKey(handler.endKey).Raw() + session.rawStartKey = MvccKey(session.startKey).Raw() + session.rawEndKey = MvccKey(session.endKey).Raw() ctx1, cancel := context.WithCancel(ctx) - copStream, err := handler.handleCopStream(ctx1, r) + copStream, err := coprHandler{session}.handleCopStream(ctx1, r) if err != nil { cancel() return nil, errors.Trace(err) @@ -1127,31 +1010,31 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R resp.Resp = streamResp case tikvrpc.CmdMvccGetByKey: r := req.MvccGetByKey() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.MvccGetByKeyResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleMvccGetByKey(r) + resp.Resp = kvHandler{session}.handleMvccGetByKey(r) case tikvrpc.CmdMvccGetByStartTs: r := req.MvccGetByStartTs() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.MvccGetByStartTsResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleMvccGetByStartTS(r) + resp.Resp = kvHandler{session}.handleMvccGetByStartTS(r) case tikvrpc.CmdSplitRegion: r := req.SplitRegion() - if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + if err := session.checkRequest(reqCtx, r.Size()); err != nil { resp.Resp = &kvrpcpb.SplitRegionResponse{RegionError: err} return resp, nil } - resp.Resp = handler.handleSplitRegion(r) + resp.Resp = kvHandler{session}.handleSplitRegion(r) // DebugGetRegionProperties is for fast analyze in mock tikv. case tikvrpc.CmdDebugGetRegionProperties: r := req.DebugGetRegionProperties() region, _ := c.Cluster.GetRegion(r.RegionId) var reqCtx kvrpcpb.Context - scanResp := handler.handleKvScan(&kvrpcpb.ScanRequest{ + scanResp := kvHandler{session}.handleKvScan(&kvrpcpb.ScanRequest{ Context: &reqCtx, StartKey: MvccKey(region.StartKey).Raw(), EndKey: MvccKey(region.EndKey).Raw(), diff --git a/store/mockstore/mocktikv/session.go b/store/mockstore/mocktikv/session.go new file mode 100644 index 0000000000000..4d5e8b61678d8 --- /dev/null +++ b/store/mockstore/mocktikv/session.go @@ -0,0 +1,146 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package mocktikv + +import ( + "github.com/gogo/protobuf/proto" + "github.com/pingcap/kvproto/pkg/errorpb" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/tidb/ddl/placement" +) + +// Session stores session scope rpc data. +type Session struct { + cluster *Cluster + mvccStore MVCCStore + + // storeID stores id for current request + storeID uint64 + // startKey is used for handling normal request. + startKey []byte + endKey []byte + // rawStartKey is used for handling coprocessor request. + rawStartKey []byte + rawEndKey []byte + // isolationLevel is used for current request. + isolationLevel kvrpcpb.IsolationLevel + resolvedLocks []uint64 +} + +func (s *Session) checkRequestContext(ctx *kvrpcpb.Context) *errorpb.Error { + ctxPeer := ctx.GetPeer() + if ctxPeer != nil && ctxPeer.GetStoreId() != s.storeID { + return &errorpb.Error{ + Message: *proto.String("store not match"), + StoreNotMatch: &errorpb.StoreNotMatch{}, + } + } + region, leaderID := s.cluster.GetRegion(ctx.GetRegionId()) + // No region found. + if region == nil { + return &errorpb.Error{ + Message: *proto.String("region not found"), + RegionNotFound: &errorpb.RegionNotFound{ + RegionId: *proto.Uint64(ctx.GetRegionId()), + }, + } + } + var storePeer, leaderPeer *metapb.Peer + for _, p := range region.Peers { + if p.GetStoreId() == s.storeID { + storePeer = p + } + if p.GetId() == leaderID { + leaderPeer = p + } + } + // The Store does not contain a Peer of the Region. + if storePeer == nil { + return &errorpb.Error{ + Message: *proto.String("region not found"), + RegionNotFound: &errorpb.RegionNotFound{ + RegionId: *proto.Uint64(ctx.GetRegionId()), + }, + } + } + // No leader. + if leaderPeer == nil { + return &errorpb.Error{ + Message: *proto.String("no leader"), + NotLeader: &errorpb.NotLeader{ + RegionId: *proto.Uint64(ctx.GetRegionId()), + }, + } + } + // The Peer on the Store is not leader. If it's tiflash store , we pass this check. + if storePeer.GetId() != leaderPeer.GetId() && !isTiFlashStore(s.cluster.GetStore(storePeer.GetStoreId())) { + return &errorpb.Error{ + Message: *proto.String("not leader"), + NotLeader: &errorpb.NotLeader{ + RegionId: *proto.Uint64(ctx.GetRegionId()), + Leader: leaderPeer, + }, + } + } + // Region epoch does not match. + if !proto.Equal(region.GetRegionEpoch(), ctx.GetRegionEpoch()) { + nextRegion, _ := s.cluster.GetRegionByKey(region.GetEndKey()) + currentRegions := []*metapb.Region{region} + if nextRegion != nil { + currentRegions = append(currentRegions, nextRegion) + } + return &errorpb.Error{ + Message: *proto.String("epoch not match"), + EpochNotMatch: &errorpb.EpochNotMatch{ + CurrentRegions: currentRegions, + }, + } + } + s.startKey, s.endKey = region.StartKey, region.EndKey + s.isolationLevel = ctx.IsolationLevel + s.resolvedLocks = ctx.ResolvedLocks + return nil +} + +func (s *Session) checkRequestSize(size int) *errorpb.Error { + // TiKV has a limitation on raft log size. + // mocktikv has no raft inside, so we check the request's size instead. + if size >= requestMaxSize { + return &errorpb.Error{ + RaftEntryTooLarge: &errorpb.RaftEntryTooLarge{}, + } + } + return nil +} + +func (s *Session) checkRequest(ctx *kvrpcpb.Context, size int) *errorpb.Error { + if err := s.checkRequestContext(ctx); err != nil { + return err + } + return s.checkRequestSize(size) +} + +func (s *Session) checkKeyInRegion(key []byte) bool { + return regionContains(s.startKey, s.endKey, NewMvccKey(key)) +} + +func isTiFlashStore(store *metapb.Store) bool { + for _, l := range store.GetLabels() { + if l.GetKey() == placement.EngineLabelKey && l.GetValue() == placement.EngineLabelTiFlash { + return true + } + } + return false +}