From b27cb07173a21c276783ec9595c5db057f623f14 Mon Sep 17 00:00:00 2001 From: ShuNing Date: Fri, 21 Jul 2023 15:53:26 +0800 Subject: [PATCH] resourcecontrol: fix nil pointer (#900) --- internal/resourcecontrol/resource_control.go | 5 +++-- internal/resourcecontrol/resource_control_test.go | 8 ++++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/internal/resourcecontrol/resource_control.go b/internal/resourcecontrol/resource_control.go index 15d631058..23a48b3c2 100644 --- a/internal/resourcecontrol/resource_control.go +++ b/internal/resourcecontrol/resource_control.go @@ -50,8 +50,9 @@ func MakeRequestInfo(req *tikvrpc.Request) *RequestInfo { bypass = true } } + storeID := req.Context.GetPeer().GetStoreId() if !req.IsTxnWriteRequest() && !req.IsRawWriteRequest() { - return &RequestInfo{writeBytes: -1, storeID: req.Context.Peer.StoreId, bypass: bypass} + return &RequestInfo{writeBytes: -1, storeID: storeID, bypass: bypass} } var writeBytes int64 @@ -69,7 +70,7 @@ func MakeRequestInfo(req *tikvrpc.Request) *RequestInfo { writeBytes += int64(len(k)) } } - return &RequestInfo{writeBytes: writeBytes, storeID: req.Context.Peer.StoreId, replicaNumber: req.ReplicaNumber, bypass: bypass} + return &RequestInfo{writeBytes: writeBytes, storeID: storeID, replicaNumber: req.ReplicaNumber, bypass: bypass} } // IsWrite returns whether the request is a write request. diff --git a/internal/resourcecontrol/resource_control_test.go b/internal/resourcecontrol/resource_control_test.go index dba077e26..25f6f72aa 100644 --- a/internal/resourcecontrol/resource_control_test.go +++ b/internal/resourcecontrol/resource_control_test.go @@ -37,4 +37,12 @@ func TestMakeRequestInfo(t *testing.T) { assert.Equal(t, uint64(3), info.WriteBytes()) assert.False(t, info.Bypass()) assert.Equal(t, uint64(3), info.StoreID()) + + // Test Nil Peer in Context + req = &tikvrpc.Request{Type: tikvrpc.CmdCommit, Req: commitReq, ReplicaNumber: 2, Context: kvrpcpb.Context{}} + info = MakeRequestInfo(req) + assert.True(t, info.IsWrite()) + assert.Equal(t, uint64(3), info.WriteBytes()) + assert.False(t, info.Bypass()) + assert.Equal(t, uint64(0), info.StoreID()) }