From 47e10f6acda8b7a551c6c5905327a3028cfa99d9 Mon Sep 17 00:00:00 2001 From: Zsolt Felfoldi Date: Wed, 13 Jan 2021 14:20:57 +0100 Subject: [PATCH 1/4] les: refactored server handler --- les/handler_test.go | 36 +- les/peer.go | 4 +- les/protocol.go | 8 - les/server_handler.go | 792 ++++++++--------------------------------- les/server_requests.go | 538 ++++++++++++++++++++++++++++ 5 files changed, 710 insertions(+), 668 deletions(-) create mode 100644 les/server_requests.go diff --git a/les/handler_test.go b/les/handler_test.go index f5cbeb8efc09..aca078d9c5df 100644 --- a/les/handler_test.go +++ b/les/handler_test.go @@ -64,27 +64,27 @@ func testGetBlockHeaders(t *testing.T, protocol int) { // Create a batch of tests for various scenarios limit := uint64(MaxHeaderFetch) tests := []struct { - query *getBlockHeadersData // The query to execute for header retrieval - expect []common.Hash // The hashes of the block whose headers are expected + query *GetBlockHeadersReq // The query to execute for header retrieval + expect []common.Hash // The hashes of the block whose headers are expected }{ // A single random block should be retrievable by hash and number too { - &getBlockHeadersData{Origin: hashOrNumber{Hash: bc.GetBlockByNumber(limit / 2).Hash()}, Amount: 1}, + &GetBlockHeadersReq{Origin: hashOrNumber{Hash: bc.GetBlockByNumber(limit / 2).Hash()}, Amount: 1}, []common.Hash{bc.GetBlockByNumber(limit / 2).Hash()}, }, { - &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Amount: 1}, + &GetBlockHeadersReq{Origin: hashOrNumber{Number: limit / 2}, Amount: 1}, []common.Hash{bc.GetBlockByNumber(limit / 2).Hash()}, }, // Multiple headers should be retrievable in both directions { - &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Amount: 3}, + &GetBlockHeadersReq{Origin: hashOrNumber{Number: limit / 2}, Amount: 3}, []common.Hash{ bc.GetBlockByNumber(limit / 2).Hash(), bc.GetBlockByNumber(limit/2 + 1).Hash(), bc.GetBlockByNumber(limit/2 + 2).Hash(), }, }, { - &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Amount: 3, Reverse: true}, + &GetBlockHeadersReq{Origin: hashOrNumber{Number: limit / 2}, Amount: 3, Reverse: true}, []common.Hash{ bc.GetBlockByNumber(limit / 2).Hash(), bc.GetBlockByNumber(limit/2 - 1).Hash(), @@ -93,14 +93,14 @@ func testGetBlockHeaders(t *testing.T, protocol int) { }, // Multiple headers with skip lists should be retrievable { - &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Skip: 3, Amount: 3}, + &GetBlockHeadersReq{Origin: hashOrNumber{Number: limit / 2}, Skip: 3, Amount: 3}, []common.Hash{ bc.GetBlockByNumber(limit / 2).Hash(), bc.GetBlockByNumber(limit/2 + 4).Hash(), bc.GetBlockByNumber(limit/2 + 8).Hash(), }, }, { - &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Skip: 3, Amount: 3, Reverse: true}, + &GetBlockHeadersReq{Origin: hashOrNumber{Number: limit / 2}, Skip: 3, Amount: 3, Reverse: true}, []common.Hash{ bc.GetBlockByNumber(limit / 2).Hash(), bc.GetBlockByNumber(limit/2 - 4).Hash(), @@ -109,26 +109,26 @@ func testGetBlockHeaders(t *testing.T, protocol int) { }, // The chain endpoints should be retrievable { - &getBlockHeadersData{Origin: hashOrNumber{Number: 0}, Amount: 1}, + &GetBlockHeadersReq{Origin: hashOrNumber{Number: 0}, Amount: 1}, []common.Hash{bc.GetBlockByNumber(0).Hash()}, }, { - &getBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64()}, Amount: 1}, + &GetBlockHeadersReq{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64()}, Amount: 1}, []common.Hash{bc.CurrentBlock().Hash()}, }, // Ensure protocol limits are honored //{ - // &getBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() - 1}, Amount: limit + 10, Reverse: true}, + // &GetBlockHeadersReq{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() - 1}, Amount: limit + 10, Reverse: true}, // []common.Hash{}, //}, // Check that requesting more than available is handled gracefully { - &getBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() - 4}, Skip: 3, Amount: 3}, + &GetBlockHeadersReq{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() - 4}, Skip: 3, Amount: 3}, []common.Hash{ bc.GetBlockByNumber(bc.CurrentBlock().NumberU64() - 4).Hash(), bc.GetBlockByNumber(bc.CurrentBlock().NumberU64()).Hash(), }, }, { - &getBlockHeadersData{Origin: hashOrNumber{Number: 4}, Skip: 3, Amount: 3, Reverse: true}, + &GetBlockHeadersReq{Origin: hashOrNumber{Number: 4}, Skip: 3, Amount: 3, Reverse: true}, []common.Hash{ bc.GetBlockByNumber(4).Hash(), bc.GetBlockByNumber(0).Hash(), @@ -136,13 +136,13 @@ func testGetBlockHeaders(t *testing.T, protocol int) { }, // Check that requesting more than available is handled gracefully, even if mid skip { - &getBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() - 4}, Skip: 2, Amount: 3}, + &GetBlockHeadersReq{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() - 4}, Skip: 2, Amount: 3}, []common.Hash{ bc.GetBlockByNumber(bc.CurrentBlock().NumberU64() - 4).Hash(), bc.GetBlockByNumber(bc.CurrentBlock().NumberU64() - 1).Hash(), }, }, { - &getBlockHeadersData{Origin: hashOrNumber{Number: 4}, Skip: 2, Amount: 3, Reverse: true}, + &GetBlockHeadersReq{Origin: hashOrNumber{Number: 4}, Skip: 2, Amount: 3, Reverse: true}, []common.Hash{ bc.GetBlockByNumber(4).Hash(), bc.GetBlockByNumber(1).Hash(), @@ -150,10 +150,10 @@ func testGetBlockHeaders(t *testing.T, protocol int) { }, // Check that non existing headers aren't returned { - &getBlockHeadersData{Origin: hashOrNumber{Hash: unknown}, Amount: 1}, + &GetBlockHeadersReq{Origin: hashOrNumber{Hash: unknown}, Amount: 1}, []common.Hash{}, }, { - &getBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() + 1}, Amount: 1}, + &GetBlockHeadersReq{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() + 1}, Amount: 1}, []common.Hash{}, }, } @@ -609,7 +609,7 @@ func TestStopResumeLes3(t *testing.T) { header := server.handler.blockchain.CurrentHeader() req := func() { reqID++ - sendRequest(server.peer.app, GetBlockHeadersMsg, reqID, &getBlockHeadersData{Origin: hashOrNumber{Hash: header.Hash()}, Amount: 1}) + sendRequest(server.peer.app, GetBlockHeadersMsg, reqID, &GetBlockHeadersReq{Origin: hashOrNumber{Hash: header.Hash()}, Amount: 1}) } for i := 1; i <= 5; i++ { // send requests while we still have enough buffer and expect a response diff --git a/les/peer.go b/les/peer.go index 0e2ed52c124b..a43c3bc6a927 100644 --- a/les/peer.go +++ b/les/peer.go @@ -432,14 +432,14 @@ func (p *serverPeer) sendRequest(msgcode, reqID uint64, data interface{}, amount // specified header query, based on the hash of an origin block. func (p *serverPeer) requestHeadersByHash(reqID uint64, origin common.Hash, amount int, skip int, reverse bool) error { p.Log().Debug("Fetching batch of headers", "count", amount, "fromhash", origin, "skip", skip, "reverse", reverse) - return p.sendRequest(GetBlockHeadersMsg, reqID, &getBlockHeadersData{Origin: hashOrNumber{Hash: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse}, amount) + return p.sendRequest(GetBlockHeadersMsg, reqID, &GetBlockHeadersReq{Origin: hashOrNumber{Hash: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse}, amount) } // requestHeadersByNumber fetches a batch of blocks' headers corresponding to the // specified header query, based on the number of an origin block. func (p *serverPeer) requestHeadersByNumber(reqID, origin uint64, amount int, skip int, reverse bool) error { p.Log().Debug("Fetching batch of headers", "count", amount, "fromnum", origin, "skip", skip, "reverse", reverse) - return p.sendRequest(GetBlockHeadersMsg, reqID, &getBlockHeadersData{Origin: hashOrNumber{Number: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse}, amount) + return p.sendRequest(GetBlockHeadersMsg, reqID, &GetBlockHeadersReq{Origin: hashOrNumber{Number: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse}, amount) } // requestBodies fetches a batch of blocks' bodies corresponding to the hashes diff --git a/les/protocol.go b/les/protocol.go index 39d9f5152fd8..a201bb2d9522 100644 --- a/les/protocol.go +++ b/les/protocol.go @@ -229,14 +229,6 @@ type blockInfo struct { Td *big.Int // Total difficulty of one particular block being announced } -// getBlockHeadersData represents a block header query. -type getBlockHeadersData struct { - Origin hashOrNumber // Block from which to retrieve headers - Amount uint64 // Maximum number of headers to retrieve - Skip uint64 // Blocks to skip between consecutive headers - Reverse bool // Query direction (false = rising towards latest, true = falling towards genesis) -} - // hashOrNumber is a combined field for specifying an origin block. type hashOrNumber struct { Hash common.Hash // Block hash from which to retrieve headers (excludes Number) diff --git a/les/server_handler.go b/les/server_handler.go index bec4206e2b85..820b688b4b2e 100644 --- a/les/server_handler.go +++ b/les/server_handler.go @@ -18,8 +18,6 @@ package les import ( "crypto/ecdsa" - "encoding/binary" - "encoding/json" "errors" "sync" "sync/atomic" @@ -223,648 +221,165 @@ func (h *serverHandler) handleMsg(p *clientPeer, wg *sync.WaitGroup) error { } defer msg.Discard() - var ( - maxCost uint64 - task *servingTask - ) p.responseCount++ responseCount := p.responseCount - // accept returns an indicator whether the request can be served. - // If so, deduct the max cost from the flow control buffer. - accept := func(reqID, reqCnt, maxCnt uint64) bool { - // Short circuit if the peer is already frozen or the request is invalid. - inSizeCost := h.server.costTracker.realCost(0, msg.Size, 0) - if p.isFrozen() || reqCnt == 0 || reqCnt > maxCnt { - p.fcClient.OneTimeCost(inSizeCost) - return false - } - // Prepaid max cost units before request been serving. - maxCost = p.fcCosts.getMaxCost(msg.Code, reqCnt) - accepted, bufShort, priority := p.fcClient.AcceptRequest(reqID, responseCount, maxCost) - if !accepted { - p.freeze() - p.Log().Error("Request came too early", "remaining", common.PrettyDuration(time.Duration(bufShort*1000000/p.fcParams.MinRecharge))) - p.fcClient.OneTimeCost(inSizeCost) - return false - } - // Create a multi-stage task, estimate the time it takes for the task to - // execute, and cache it in the request service queue. - factor := h.server.costTracker.globalFactor() - if factor < 0.001 { - factor = 1 - p.Log().Error("Invalid global cost factor", "factor", factor) - } - maxTime := uint64(float64(maxCost) / factor) - task = h.server.servingQueue.newTask(p, maxTime, priority) - if task.start() { - return true - } - p.fcClient.RequestProcessed(reqID, responseCount, maxCost, inSizeCost) - return false - } - // sendResponse sends back the response and updates the flow control statistic. - sendResponse := func(reqID, amount uint64, reply *reply, servingTime uint64) { - p.responseLock.Lock() - defer p.responseLock.Unlock() - // Short circuit if the client is already frozen. - if p.isFrozen() { - realCost := h.server.costTracker.realCost(servingTime, msg.Size, 0) - p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost) - return - } - // Positive correction buffer value with real cost. - var replySize uint32 - if reply != nil { - replySize = reply.size() - } - var realCost uint64 - if h.server.costTracker.testing { - realCost = maxCost // Assign a fake cost for testing purpose - } else { - realCost = h.server.costTracker.realCost(servingTime, msg.Size, replySize) - if realCost > maxCost { - realCost = maxCost - } - } - bv := p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost) - if amount != 0 { - // Feed cost tracker request serving statistic. - h.server.costTracker.updateStats(msg.Code, amount, servingTime, realCost) - // Reduce priority "balance" for the specific peer. - p.balance.RequestServed(realCost) - } - if reply != nil { - p.queueSend(func() { - if err := reply.send(bv); err != nil { - select { - case p.errCh <- err: - default: - } - } - }) + var ( + req struct { + ID uint64 + Data rlp.RawValue } + hreq handlerRequest + decodeErr error + ) + if err := msg.Decode(&req); err != nil { + clientErrorMeter.Mark(1) + return errResp(ErrDecode, "%v: %v", msg, err) } + switch msg.Code { case GetBlockHeadersMsg: p.Log().Trace("Received block header request") - if metrics.EnabledExpensive { - miscInHeaderPacketsMeter.Mark(1) - miscInHeaderTrafficMeter.Mark(int64(msg.Size)) - } - var req struct { - ReqID uint64 - Query getBlockHeadersData - } - if err := msg.Decode(&req); err != nil { - clientErrorMeter.Mark(1) - return errResp(ErrDecode, "%v: %v", msg, err) - } - query := req.Query - if accept(req.ReqID, query.Amount, MaxHeaderFetch) { - wg.Add(1) - go func() { - defer wg.Done() - hashMode := query.Origin.Hash != (common.Hash{}) - first := true - maxNonCanonical := uint64(100) - - // Gather headers until the fetch or network limits is reached - var ( - bytes common.StorageSize - headers []*types.Header - unknown bool - ) - for !unknown && len(headers) < int(query.Amount) && bytes < softResponseLimit { - if !first && !task.waitOrStop() { - sendResponse(req.ReqID, 0, nil, task.servingTime) - return - } - // Retrieve the next header satisfying the query - var origin *types.Header - if hashMode { - if first { - origin = h.blockchain.GetHeaderByHash(query.Origin.Hash) - if origin != nil { - query.Origin.Number = origin.Number.Uint64() - } - } else { - origin = h.blockchain.GetHeader(query.Origin.Hash, query.Origin.Number) - } - } else { - origin = h.blockchain.GetHeaderByNumber(query.Origin.Number) - } - if origin == nil { - break - } - headers = append(headers, origin) - bytes += estHeaderRlpSize - - // Advance to the next header of the query - switch { - case hashMode && query.Reverse: - // Hash based traversal towards the genesis block - ancestor := query.Skip + 1 - if ancestor == 0 { - unknown = true - } else { - query.Origin.Hash, query.Origin.Number = h.blockchain.GetAncestor(query.Origin.Hash, query.Origin.Number, ancestor, &maxNonCanonical) - unknown = query.Origin.Hash == common.Hash{} - } - case hashMode && !query.Reverse: - // Hash based traversal towards the leaf block - var ( - current = origin.Number.Uint64() - next = current + query.Skip + 1 - ) - if next <= current { - infos, _ := json.MarshalIndent(p.Peer.Info(), "", " ") - p.Log().Warn("GetBlockHeaders skip overflow attack", "current", current, "skip", query.Skip, "next", next, "attacker", infos) - unknown = true - } else { - if header := h.blockchain.GetHeaderByNumber(next); header != nil { - nextHash := header.Hash() - expOldHash, _ := h.blockchain.GetAncestor(nextHash, next, query.Skip+1, &maxNonCanonical) - if expOldHash == query.Origin.Hash { - query.Origin.Hash, query.Origin.Number = nextHash, next - } else { - unknown = true - } - } else { - unknown = true - } - } - case query.Reverse: - // Number based traversal towards the genesis block - if query.Origin.Number >= query.Skip+1 { - query.Origin.Number -= query.Skip + 1 - } else { - unknown = true - } - - case !query.Reverse: - // Number based traversal towards the leaf block - query.Origin.Number += query.Skip + 1 - } - first = false - } - reply := p.replyBlockHeaders(req.ReqID, headers) - sendResponse(req.ReqID, query.Amount, reply, task.done()) - if metrics.EnabledExpensive { - miscOutHeaderPacketsMeter.Mark(1) - miscOutHeaderTrafficMeter.Mark(int64(reply.size())) - miscServingTimeHeaderTimer.Update(time.Duration(task.servingTime)) - } - }() - } + r := &GetBlockHeadersReq{} + decodeErr = rlp.DecodeBytes(req.Data, r) + hreq = r case GetBlockBodiesMsg: p.Log().Trace("Received block bodies request") - if metrics.EnabledExpensive { - miscInBodyPacketsMeter.Mark(1) - miscInBodyTrafficMeter.Mark(int64(msg.Size)) - } - var req struct { - ReqID uint64 - Hashes []common.Hash - } - if err := msg.Decode(&req); err != nil { - clientErrorMeter.Mark(1) - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - var ( - bytes int - bodies []rlp.RawValue - ) - reqCnt := len(req.Hashes) - if accept(req.ReqID, uint64(reqCnt), MaxBodyFetch) { - wg.Add(1) - go func() { - defer wg.Done() - for i, hash := range req.Hashes { - if i != 0 && !task.waitOrStop() { - sendResponse(req.ReqID, 0, nil, task.servingTime) - return - } - if bytes >= softResponseLimit { - break - } - body := h.blockchain.GetBodyRLP(hash) - if body == nil { - p.bumpInvalid() - continue - } - bodies = append(bodies, body) - bytes += len(body) - } - reply := p.replyBlockBodiesRLP(req.ReqID, bodies) - sendResponse(req.ReqID, uint64(reqCnt), reply, task.done()) - if metrics.EnabledExpensive { - miscOutBodyPacketsMeter.Mark(1) - miscOutBodyTrafficMeter.Mark(int64(reply.size())) - miscServingTimeBodyTimer.Update(time.Duration(task.servingTime)) - } - }() - } + r := GetBlockBodiesReq{} + decodeErr = rlp.DecodeBytes(req.Data, &r) + hreq = r case GetCodeMsg: p.Log().Trace("Received code request") - if metrics.EnabledExpensive { - miscInCodePacketsMeter.Mark(1) - miscInCodeTrafficMeter.Mark(int64(msg.Size)) - } - var req struct { - ReqID uint64 - Reqs []CodeReq - } - if err := msg.Decode(&req); err != nil { - clientErrorMeter.Mark(1) - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - var ( - bytes int - data [][]byte - ) - reqCnt := len(req.Reqs) - if accept(req.ReqID, uint64(reqCnt), MaxCodeFetch) { - wg.Add(1) - go func() { - defer wg.Done() - for i, request := range req.Reqs { - if i != 0 && !task.waitOrStop() { - sendResponse(req.ReqID, 0, nil, task.servingTime) - return - } - // Look up the root hash belonging to the request - header := h.blockchain.GetHeaderByHash(request.BHash) - if header == nil { - p.Log().Warn("Failed to retrieve associate header for code", "hash", request.BHash) - p.bumpInvalid() - continue - } - // Refuse to search stale state data in the database since looking for - // a non-exist key is kind of expensive. - local := h.blockchain.CurrentHeader().Number.Uint64() - if !h.server.archiveMode && header.Number.Uint64()+core.TriesInMemory <= local { - p.Log().Debug("Reject stale code request", "number", header.Number.Uint64(), "head", local) - p.bumpInvalid() - continue - } - triedb := h.blockchain.StateCache().TrieDB() - - account, err := h.getAccount(triedb, header.Root, common.BytesToHash(request.AccKey)) - if err != nil { - p.Log().Warn("Failed to retrieve account for code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err) - p.bumpInvalid() - continue - } - code, err := h.blockchain.StateCache().ContractCode(common.BytesToHash(request.AccKey), common.BytesToHash(account.CodeHash)) - if err != nil { - p.Log().Warn("Failed to retrieve account code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "codehash", common.BytesToHash(account.CodeHash), "err", err) - continue - } - // Accumulate the code and abort if enough data was retrieved - data = append(data, code) - if bytes += len(code); bytes >= softResponseLimit { - break - } - } - reply := p.replyCode(req.ReqID, data) - sendResponse(req.ReqID, uint64(reqCnt), reply, task.done()) - if metrics.EnabledExpensive { - miscOutCodePacketsMeter.Mark(1) - miscOutCodeTrafficMeter.Mark(int64(reply.size())) - miscServingTimeCodeTimer.Update(time.Duration(task.servingTime)) - } - }() - } + r := GetCodeReq{} + decodeErr = rlp.DecodeBytes(req.Data, &r) + hreq = r case GetReceiptsMsg: p.Log().Trace("Received receipts request") - if metrics.EnabledExpensive { - miscInReceiptPacketsMeter.Mark(1) - miscInReceiptTrafficMeter.Mark(int64(msg.Size)) - } - var req struct { - ReqID uint64 - Hashes []common.Hash - } - if err := msg.Decode(&req); err != nil { - clientErrorMeter.Mark(1) - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - var ( - bytes int - receipts []rlp.RawValue - ) - reqCnt := len(req.Hashes) - if accept(req.ReqID, uint64(reqCnt), MaxReceiptFetch) { - wg.Add(1) - go func() { - defer wg.Done() - for i, hash := range req.Hashes { - if i != 0 && !task.waitOrStop() { - sendResponse(req.ReqID, 0, nil, task.servingTime) - return - } - if bytes >= softResponseLimit { - break - } - // Retrieve the requested block's receipts, skipping if unknown to us - results := h.blockchain.GetReceiptsByHash(hash) - if results == nil { - if header := h.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash { - p.bumpInvalid() - continue - } - } - // If known, encode and queue for response packet - if encoded, err := rlp.EncodeToBytes(results); err != nil { - log.Error("Failed to encode receipt", "err", err) - } else { - receipts = append(receipts, encoded) - bytes += len(encoded) - } - } - reply := p.replyReceiptsRLP(req.ReqID, receipts) - sendResponse(req.ReqID, uint64(reqCnt), reply, task.done()) - if metrics.EnabledExpensive { - miscOutReceiptPacketsMeter.Mark(1) - miscOutReceiptTrafficMeter.Mark(int64(reply.size())) - miscServingTimeReceiptTimer.Update(time.Duration(task.servingTime)) - } - }() - } + r := GetReceiptsReq{} + decodeErr = rlp.DecodeBytes(req.Data, &r) + hreq = r case GetProofsV2Msg: p.Log().Trace("Received les/2 proofs request") - if metrics.EnabledExpensive { - miscInTrieProofPacketsMeter.Mark(1) - miscInTrieProofTrafficMeter.Mark(int64(msg.Size)) - } - var req struct { - ReqID uint64 - Reqs []ProofReq - } - if err := msg.Decode(&req); err != nil { - clientErrorMeter.Mark(1) - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - // Gather state data until the fetch or network limits is reached - var ( - lastBHash common.Hash - root common.Hash - header *types.Header - ) - reqCnt := len(req.Reqs) - if accept(req.ReqID, uint64(reqCnt), MaxProofsFetch) { - wg.Add(1) - go func() { - defer wg.Done() - nodes := light.NewNodeSet() - - for i, request := range req.Reqs { - if i != 0 && !task.waitOrStop() { - sendResponse(req.ReqID, 0, nil, task.servingTime) - return - } - // Look up the root hash belonging to the request - if request.BHash != lastBHash { - root, lastBHash = common.Hash{}, request.BHash - - if header = h.blockchain.GetHeaderByHash(request.BHash); header == nil { - p.Log().Warn("Failed to retrieve header for proof", "hash", request.BHash) - p.bumpInvalid() - continue - } - // Refuse to search stale state data in the database since looking for - // a non-exist key is kind of expensive. - local := h.blockchain.CurrentHeader().Number.Uint64() - if !h.server.archiveMode && header.Number.Uint64()+core.TriesInMemory <= local { - p.Log().Debug("Reject stale trie request", "number", header.Number.Uint64(), "head", local) - p.bumpInvalid() - continue - } - root = header.Root - } - // If a header lookup failed (non existent), ignore subsequent requests for the same header - if root == (common.Hash{}) { - p.bumpInvalid() - continue - } - // Open the account or storage trie for the request - statedb := h.blockchain.StateCache() - - var trie state.Trie - switch len(request.AccKey) { - case 0: - // No account key specified, open an account trie - trie, err = statedb.OpenTrie(root) - if trie == nil || err != nil { - p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "root", root, "err", err) - continue - } - default: - // Account key specified, open a storage trie - account, err := h.getAccount(statedb.TrieDB(), root, common.BytesToHash(request.AccKey)) - if err != nil { - p.Log().Warn("Failed to retrieve account for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err) - p.bumpInvalid() - continue - } - trie, err = statedb.OpenStorageTrie(common.BytesToHash(request.AccKey), account.Root) - if trie == nil || err != nil { - p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "root", account.Root, "err", err) - continue - } - } - // Prove the user's request from the account or stroage trie - if err := trie.Prove(request.Key, request.FromLevel, nodes); err != nil { - p.Log().Warn("Failed to prove state request", "block", header.Number, "hash", header.Hash(), "err", err) - continue - } - if nodes.DataSize() >= softResponseLimit { - break - } - } - reply := p.replyProofsV2(req.ReqID, nodes.NodeList()) - sendResponse(req.ReqID, uint64(reqCnt), reply, task.done()) - if metrics.EnabledExpensive { - miscOutTrieProofPacketsMeter.Mark(1) - miscOutTrieProofTrafficMeter.Mark(int64(reply.size())) - miscServingTimeTrieProofTimer.Update(time.Duration(task.servingTime)) - } - }() - } + r := GetProofsReq{} + decodeErr = rlp.DecodeBytes(req.Data, &r) + hreq = r case GetHelperTrieProofsMsg: p.Log().Trace("Received helper trie proof request") - if metrics.EnabledExpensive { - miscInHelperTriePacketsMeter.Mark(1) - miscInHelperTrieTrafficMeter.Mark(int64(msg.Size)) - } - var req struct { - ReqID uint64 - Reqs []HelperTrieReq - } - if err := msg.Decode(&req); err != nil { - clientErrorMeter.Mark(1) - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - // Gather state data until the fetch or network limits is reached - var ( - auxBytes int - auxData [][]byte - ) - reqCnt := len(req.Reqs) - if accept(req.ReqID, uint64(reqCnt), MaxHelperTrieProofsFetch) { - wg.Add(1) - go func() { - defer wg.Done() - var ( - lastIdx uint64 - lastType uint - root common.Hash - auxTrie *trie.Trie - ) - nodes := light.NewNodeSet() - for i, request := range req.Reqs { - if i != 0 && !task.waitOrStop() { - sendResponse(req.ReqID, 0, nil, task.servingTime) - return - } - if auxTrie == nil || request.Type != lastType || request.TrieIdx != lastIdx { - auxTrie, lastType, lastIdx = nil, request.Type, request.TrieIdx - - var prefix string - if root, prefix = h.getHelperTrie(request.Type, request.TrieIdx); root != (common.Hash{}) { - auxTrie, _ = trie.New(root, trie.NewDatabase(rawdb.NewTable(h.chainDb, prefix))) - } - } - if auxTrie == nil { - sendResponse(req.ReqID, 0, nil, task.servingTime) - return - } - // TODO(rjl493456442) short circuit if the proving is failed. - // The original client side code has a dirty hack to retrieve - // the headers with no valid proof. Keep the compatibility for - // legacy les protocol and drop this hack when the les2/3 are - // not supported. - err := auxTrie.Prove(request.Key, request.FromLevel, nodes) - if p.version >= lpv4 && err != nil { - sendResponse(req.ReqID, 0, nil, task.servingTime) - return - } - if request.AuxReq == htAuxHeader { - data := h.getAuxiliaryHeaders(request) - auxData = append(auxData, data) - auxBytes += len(data) - } - if nodes.DataSize()+auxBytes >= softResponseLimit { - break - } - } - reply := p.replyHelperTrieProofs(req.ReqID, HelperTrieResps{Proofs: nodes.NodeList(), AuxData: auxData}) - sendResponse(req.ReqID, uint64(reqCnt), reply, task.done()) - if metrics.EnabledExpensive { - miscOutHelperTriePacketsMeter.Mark(1) - miscOutHelperTrieTrafficMeter.Mark(int64(reply.size())) - miscServingTimeHelperTrieTimer.Update(time.Duration(task.servingTime)) - } - }() - } + r := GetHelperTrieProofsReq{} + decodeErr = rlp.DecodeBytes(req.Data, &r) + hreq = r case SendTxV2Msg: p.Log().Trace("Received new transactions") - if metrics.EnabledExpensive { - miscInTxsPacketsMeter.Mark(1) - miscInTxsTrafficMeter.Mark(int64(msg.Size)) - } - var req struct { - ReqID uint64 - Txs []*types.Transaction - } - if err := msg.Decode(&req); err != nil { - clientErrorMeter.Mark(1) - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - reqCnt := len(req.Txs) - if accept(req.ReqID, uint64(reqCnt), MaxTxSend) { - wg.Add(1) - go func() { - defer wg.Done() - stats := make([]light.TxStatus, len(req.Txs)) - for i, tx := range req.Txs { - if i != 0 && !task.waitOrStop() { - return - } - hash := tx.Hash() - stats[i] = h.txStatus(hash) - if stats[i].Status == core.TxStatusUnknown { - addFn := h.txpool.AddRemotes - // Add txs synchronously for testing purpose - if h.addTxsSync { - addFn = h.txpool.AddRemotesSync - } - if errs := addFn([]*types.Transaction{tx}); errs[0] != nil { - stats[i].Error = errs[0].Error() - continue - } - stats[i] = h.txStatus(hash) - } - } - reply := p.replyTxStatus(req.ReqID, stats) - sendResponse(req.ReqID, uint64(reqCnt), reply, task.done()) - if metrics.EnabledExpensive { - miscOutTxsPacketsMeter.Mark(1) - miscOutTxsTrafficMeter.Mark(int64(reply.size())) - miscServingTimeTxTimer.Update(time.Duration(task.servingTime)) - } - }() - } + r := SendTxReq{} + decodeErr = rlp.DecodeBytes(req.Data, &r) + hreq = r case GetTxStatusMsg: p.Log().Trace("Received transaction status query request") - if metrics.EnabledExpensive { - miscInTxStatusPacketsMeter.Mark(1) - miscInTxStatusTrafficMeter.Mark(int64(msg.Size)) - } - var req struct { - ReqID uint64 - Hashes []common.Hash - } - if err := msg.Decode(&req); err != nil { - clientErrorMeter.Mark(1) - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - reqCnt := len(req.Hashes) - if accept(req.ReqID, uint64(reqCnt), MaxTxStatus) { - wg.Add(1) - go func() { - defer wg.Done() - stats := make([]light.TxStatus, len(req.Hashes)) - for i, hash := range req.Hashes { - if i != 0 && !task.waitOrStop() { - sendResponse(req.ReqID, 0, nil, task.servingTime) - return - } - stats[i] = h.txStatus(hash) - } - reply := p.replyTxStatus(req.ReqID, stats) - sendResponse(req.ReqID, uint64(reqCnt), reply, task.done()) - if metrics.EnabledExpensive { - miscOutTxStatusPacketsMeter.Mark(1) - miscOutTxStatusTrafficMeter.Mark(int64(reply.size())) - miscServingTimeTxStatusTimer.Update(time.Duration(task.servingTime)) - } - }() - } + r := GetTxStatusReq{} + decodeErr = rlp.DecodeBytes(req.Data, &r) + hreq = r default: p.Log().Trace("Received invalid message", "code", msg.Code) clientErrorMeter.Mark(1) return errResp(ErrInvalidMsgCode, "%v", msg.Code) } + + if metrics.EnabledExpensive { + hreq.InMetrics(int64(msg.Size)) + } + if decodeErr != nil { + clientErrorMeter.Mark(1) + return errResp(ErrDecode, "%v: %v", msg, decodeErr) + } + reqCnt, maxCnt := hreq.ReqAmount() + + // Short circuit if the peer is already frozen or the request is invalid. + inSizeCost := h.server.costTracker.realCost(0, msg.Size, 0) + if p.isFrozen() || reqCnt == 0 || reqCnt > maxCnt { + p.fcClient.OneTimeCost(inSizeCost) + return nil + } + // Prepaid max cost units before request been serving. + maxCost := p.fcCosts.getMaxCost(msg.Code, reqCnt) + accepted, bufShort, priority := p.fcClient.AcceptRequest(req.ID, responseCount, maxCost) + if !accepted { + p.freeze() + p.Log().Error("Request came too early", "remaining", common.PrettyDuration(time.Duration(bufShort*1000000/p.fcParams.MinRecharge))) + p.fcClient.OneTimeCost(inSizeCost) + return nil + } + // Create a multi-stage task, estimate the time it takes for the task to + // execute, and cache it in the request service queue. + factor := h.server.costTracker.globalFactor() + if factor < 0.001 { + factor = 1 + p.Log().Error("Invalid global cost factor", "factor", factor) + } + maxTime := uint64(float64(maxCost) / factor) + task := h.server.servingQueue.newTask(p, maxTime, priority) + if task.start() { + wg.Add(1) + go func() { + defer wg.Done() + reply := hreq.Serve(h, req.ID, p, task.waitOrStop) + if reply != nil { + task.done() + } + + p.responseLock.Lock() + defer p.responseLock.Unlock() + + // Short circuit if the client is already frozen. + if p.isFrozen() { + realCost := h.server.costTracker.realCost(task.servingTime, msg.Size, 0) + p.fcClient.RequestProcessed(req.ID, responseCount, maxCost, realCost) + return + } + // Positive correction buffer value with real cost. + var replySize uint32 + if reply != nil { + replySize = reply.size() + } + var realCost uint64 + if h.server.costTracker.testing { + realCost = maxCost // Assign a fake cost for testing purpose + } else { + realCost = h.server.costTracker.realCost(task.servingTime, msg.Size, replySize) + if realCost > maxCost { + realCost = maxCost + } + } + bv := p.fcClient.RequestProcessed(req.ID, responseCount, maxCost, realCost) + if reply != nil { + // Feed cost tracker request serving statistic. + h.server.costTracker.updateStats(msg.Code, reqCnt, task.servingTime, realCost) + // Reduce priority "balance" for the specific peer. + p.balance.RequestServed(realCost) + p.queueSend(func() { + if err := reply.send(bv); err != nil { + select { + case p.errCh <- err: + default: + } + } + }) + if metrics.EnabledExpensive { + hreq.OutMetrics(int64(replySize), time.Duration(task.servingTime)) + } + } + }() + } else { + p.fcClient.RequestProcessed(req.ID, responseCount, maxCost, inSizeCost) + } + // If the client has made too much invalid request(e.g. request a non-existent data), // reject them to prevent SPAM attack. if p.getInvalid() > maxRequestErrors { @@ -874,8 +389,24 @@ func (h *serverHandler) handleMsg(p *clientPeer, wg *sync.WaitGroup) error { return nil } +func (h *serverHandler) BlockChain() *core.BlockChain { + return h.blockchain +} + +func (h *serverHandler) TxPool() *core.TxPool { + return h.txpool +} + +func (h *serverHandler) ArchiveMode() bool { + return h.server.archiveMode +} + +func (h *serverHandler) AddTxsSync() bool { + return h.addTxsSync +} + // getAccount retrieves an account from the state based on root. -func (h *serverHandler) getAccount(triedb *trie.Database, root, hash common.Hash) (state.Account, error) { +func getAccount(triedb *trie.Database, root, hash common.Hash) (state.Account, error) { trie, err := trie.New(root, triedb) if err != nil { return state.Account{}, err @@ -892,43 +423,24 @@ func (h *serverHandler) getAccount(triedb *trie.Database, root, hash common.Hash } // getHelperTrie returns the post-processed trie root for the given trie ID and section index -func (h *serverHandler) getHelperTrie(typ uint, index uint64) (common.Hash, string) { +func (h *serverHandler) GetHelperTrie(typ uint, index uint64) *trie.Trie { + var ( + root common.Hash + prefix string + ) switch typ { case htCanonical: sectionHead := rawdb.ReadCanonicalHash(h.chainDb, (index+1)*h.server.iConfig.ChtSize-1) - return light.GetChtRoot(h.chainDb, index, sectionHead), light.ChtTablePrefix + root, prefix = light.GetChtRoot(h.chainDb, index, sectionHead), light.ChtTablePrefix case htBloomBits: sectionHead := rawdb.ReadCanonicalHash(h.chainDb, (index+1)*h.server.iConfig.BloomTrieSize-1) - return light.GetBloomTrieRoot(h.chainDb, index, sectionHead), light.BloomTrieTablePrefix - } - return common.Hash{}, "" -} - -// getAuxiliaryHeaders returns requested auxiliary headers for the CHT request. -func (h *serverHandler) getAuxiliaryHeaders(req HelperTrieReq) []byte { - if req.Type == htCanonical && req.AuxReq == htAuxHeader && len(req.Key) == 8 { - blockNum := binary.BigEndian.Uint64(req.Key) - hash := rawdb.ReadCanonicalHash(h.chainDb, blockNum) - return rawdb.ReadHeaderRLP(h.chainDb, hash, blockNum) + root, prefix = light.GetBloomTrieRoot(h.chainDb, index, sectionHead), light.BloomTrieTablePrefix } - return nil -} - -// txStatus returns the status of a specified transaction. -func (h *serverHandler) txStatus(hash common.Hash) light.TxStatus { - var stat light.TxStatus - // Looking the transaction in txpool first. - stat.Status = h.txpool.Status([]common.Hash{hash})[0] - - // If the transaction is unknown to the pool, try looking it up locally. - if stat.Status == core.TxStatusUnknown { - lookup := h.blockchain.GetTransactionLookup(hash) - if lookup != nil { - stat.Status = core.TxStatusIncluded - stat.Lookup = lookup - } + if root == (common.Hash{}) { + return nil } - return stat + trie, _ := trie.New(root, trie.NewDatabase(rawdb.NewTable(h.chainDb, prefix))) + return trie } // broadcastLoop broadcasts new block information to all connected light diff --git a/les/server_requests.go b/les/server_requests.go new file mode 100644 index 000000000000..3da36c801a6b --- /dev/null +++ b/les/server_requests.go @@ -0,0 +1,538 @@ +// Copyright 2021 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package les + +import ( + "encoding/binary" + "encoding/json" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/light" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/trie" +) + +type serverBackend interface { + ArchiveMode() bool + AddTxsSync() bool + BlockChain() *core.BlockChain + TxPool() *core.TxPool + GetHelperTrie(typ uint, index uint64) *trie.Trie +} + +type handlerRequest interface { + ReqAmount() (uint64, uint64) + InMetrics(size int64) + OutMetrics(size int64, servingTime time.Duration) + Serve(b serverBackend, reqID uint64, p *clientPeer, waitOrStop func() bool) *reply +} + +type GetBlockHeadersReq struct { + Origin hashOrNumber // Block from which to retrieve headers + Amount uint64 // Maximum number of headers to retrieve + Skip uint64 // Blocks to skip between consecutive headers + Reverse bool // Query direction (false = rising towards latest, true = falling towards genesis) +} + +func (r *GetBlockHeadersReq) ReqAmount() (uint64, uint64) { return r.Amount, MaxHeaderFetch } + +func (r *GetBlockHeadersReq) InMetrics(size int64) { + miscInHeaderPacketsMeter.Mark(1) + miscInHeaderTrafficMeter.Mark(size) +} + +func (r *GetBlockHeadersReq) OutMetrics(size int64, servingTime time.Duration) { + miscOutHeaderPacketsMeter.Mark(1) + miscOutHeaderTrafficMeter.Mark(size) + miscServingTimeHeaderTimer.Update(servingTime) +} + +func (r *GetBlockHeadersReq) Serve(backend serverBackend, reqID uint64, p *clientPeer, waitOrStop func() bool) *reply { + // Gather headers until the fetch or network limits is reached + var ( + bc = backend.BlockChain() + hashMode = r.Origin.Hash != (common.Hash{}) + first = true + maxNonCanonical = uint64(100) + bytes common.StorageSize + headers []*types.Header + unknown bool + ) + for !unknown && len(headers) < int(r.Amount) && bytes < softResponseLimit { + if !first && !waitOrStop() { + return nil + } + // Retrieve the next header satisfying the r + var origin *types.Header + if hashMode { + if first { + origin = bc.GetHeaderByHash(r.Origin.Hash) + if origin != nil { + r.Origin.Number = origin.Number.Uint64() + } + } else { + origin = bc.GetHeader(r.Origin.Hash, r.Origin.Number) + } + } else { + origin = bc.GetHeaderByNumber(r.Origin.Number) + } + if origin == nil { + break + } + headers = append(headers, origin) + bytes += estHeaderRlpSize + + // Advance to the next header of the r + switch { + case hashMode && r.Reverse: + // Hash based traversal towards the genesis block + ancestor := r.Skip + 1 + if ancestor == 0 { + unknown = true + } else { + r.Origin.Hash, r.Origin.Number = bc.GetAncestor(r.Origin.Hash, r.Origin.Number, ancestor, &maxNonCanonical) + unknown = r.Origin.Hash == common.Hash{} + } + case hashMode && !r.Reverse: + // Hash based traversal towards the leaf block + var ( + current = origin.Number.Uint64() + next = current + r.Skip + 1 + ) + if next <= current { + infos, _ := json.MarshalIndent(p.Peer.Info(), "", " ") + p.Log().Warn("GetBlockHeaders skip overflow attack", "current", current, "skip", r.Skip, "next", next, "attacker", infos) + unknown = true + } else { + if header := bc.GetHeaderByNumber(next); header != nil { + nextHash := header.Hash() + expOldHash, _ := bc.GetAncestor(nextHash, next, r.Skip+1, &maxNonCanonical) + if expOldHash == r.Origin.Hash { + r.Origin.Hash, r.Origin.Number = nextHash, next + } else { + unknown = true + } + } else { + unknown = true + } + } + case r.Reverse: + // Number based traversal towards the genesis block + if r.Origin.Number >= r.Skip+1 { + r.Origin.Number -= r.Skip + 1 + } else { + unknown = true + } + + case !r.Reverse: + // Number based traversal towards the leaf block + r.Origin.Number += r.Skip + 1 + } + first = false + } + return p.replyBlockHeaders(reqID, headers) +} + +type GetBlockBodiesReq []common.Hash + +func (r GetBlockBodiesReq) ReqAmount() (uint64, uint64) { return uint64(len(r)), MaxBodyFetch } + +func (r GetBlockBodiesReq) InMetrics(size int64) { + miscInBodyPacketsMeter.Mark(1) + miscInBodyTrafficMeter.Mark(size) +} + +func (r GetBlockBodiesReq) OutMetrics(size int64, servingTime time.Duration) { + miscOutBodyPacketsMeter.Mark(1) + miscOutBodyTrafficMeter.Mark(size) + miscServingTimeBodyTimer.Update(servingTime) +} + +func (r GetBlockBodiesReq) Serve(backend serverBackend, reqID uint64, p *clientPeer, waitOrStop func() bool) *reply { + var ( + bytes int + bodies []rlp.RawValue + ) + bc := backend.BlockChain() + for i, hash := range r { + if i != 0 && !waitOrStop() { + return nil + } + if bytes >= softResponseLimit { + break + } + body := bc.GetBodyRLP(hash) + if body == nil { + p.bumpInvalid() + continue + } + bodies = append(bodies, body) + bytes += len(body) + } + return p.replyBlockBodiesRLP(reqID, bodies) +} + +type GetCodeReq []CodeReq + +func (r GetCodeReq) ReqAmount() (uint64, uint64) { return uint64(len(r)), MaxCodeFetch } + +func (r GetCodeReq) InMetrics(size int64) { + miscInCodePacketsMeter.Mark(1) + miscInCodeTrafficMeter.Mark(size) +} + +func (r GetCodeReq) OutMetrics(size int64, servingTime time.Duration) { + miscOutCodePacketsMeter.Mark(1) + miscOutCodeTrafficMeter.Mark(size) + miscServingTimeCodeTimer.Update(servingTime) +} + +func (r GetCodeReq) Serve(backend serverBackend, reqID uint64, p *clientPeer, waitOrStop func() bool) *reply { + var ( + bytes int + data [][]byte + ) + bc := backend.BlockChain() + for i, request := range r { + if i != 0 && !waitOrStop() { + return nil + } + // Look up the root hash belonging to the request + header := bc.GetHeaderByHash(request.BHash) + if header == nil { + p.Log().Warn("Failed to retrieve associate header for code", "hash", request.BHash) + p.bumpInvalid() + continue + } + // Refuse to search stale state data in the database since looking for + // a non-exist key is kind of expensive. + local := bc.CurrentHeader().Number.Uint64() + if !backend.ArchiveMode() && header.Number.Uint64()+core.TriesInMemory <= local { + p.Log().Debug("Reject stale code request", "number", header.Number.Uint64(), "head", local) + p.bumpInvalid() + continue + } + triedb := bc.StateCache().TrieDB() + + account, err := getAccount(triedb, header.Root, common.BytesToHash(request.AccKey)) + if err != nil { + p.Log().Warn("Failed to retrieve account for code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err) + p.bumpInvalid() + continue + } + code, err := bc.StateCache().ContractCode(common.BytesToHash(request.AccKey), common.BytesToHash(account.CodeHash)) + if err != nil { + p.Log().Warn("Failed to retrieve account code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "codehash", common.BytesToHash(account.CodeHash), "err", err) + continue + } + // Accumulate the code and abort if enough data was retrieved + data = append(data, code) + if bytes += len(code); bytes >= softResponseLimit { + break + } + } + return p.replyCode(reqID, data) +} + +type GetReceiptsReq []common.Hash + +func (r GetReceiptsReq) ReqAmount() (uint64, uint64) { return uint64(len(r)), MaxReceiptFetch } + +func (r GetReceiptsReq) InMetrics(size int64) { + miscInReceiptPacketsMeter.Mark(1) + miscInReceiptTrafficMeter.Mark(size) +} + +func (r GetReceiptsReq) OutMetrics(size int64, servingTime time.Duration) { + miscOutReceiptPacketsMeter.Mark(1) + miscOutReceiptTrafficMeter.Mark(size) + miscServingTimeReceiptTimer.Update(servingTime) +} + +func (r GetReceiptsReq) Serve(backend serverBackend, reqID uint64, p *clientPeer, waitOrStop func() bool) *reply { + var ( + bytes int + receipts []rlp.RawValue + ) + bc := backend.BlockChain() + for i, hash := range r { + if i != 0 && !waitOrStop() { + return nil + } + if bytes >= softResponseLimit { + break + } + // Retrieve the requested block's receipts, skipping if unknown to us + results := bc.GetReceiptsByHash(hash) + if results == nil { + if header := bc.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash { + p.bumpInvalid() + continue + } + } + // If known, encode and queue for response packet + if encoded, err := rlp.EncodeToBytes(results); err != nil { + log.Error("Failed to encode receipt", "err", err) + } else { + receipts = append(receipts, encoded) + bytes += len(encoded) + } + } + return p.replyReceiptsRLP(reqID, receipts) +} + +type GetProofsReq []ProofReq + +func (r GetProofsReq) ReqAmount() (uint64, uint64) { return uint64(len(r)), MaxProofsFetch } + +func (r GetProofsReq) InMetrics(size int64) { + miscInTrieProofPacketsMeter.Mark(1) + miscInTrieProofTrafficMeter.Mark(size) +} + +func (r GetProofsReq) OutMetrics(size int64, servingTime time.Duration) { + miscOutTrieProofPacketsMeter.Mark(1) + miscOutTrieProofTrafficMeter.Mark(size) + miscServingTimeTrieProofTimer.Update(servingTime) +} + +func (r GetProofsReq) Serve(backend serverBackend, reqID uint64, p *clientPeer, waitOrStop func() bool) *reply { + var ( + lastBHash common.Hash + root common.Hash + header *types.Header + err error + ) + bc := backend.BlockChain() + nodes := light.NewNodeSet() + + for i, request := range r { + if i != 0 && !waitOrStop() { + return nil + } + // Look up the root hash belonging to the request + if request.BHash != lastBHash { + root, lastBHash = common.Hash{}, request.BHash + + if header = bc.GetHeaderByHash(request.BHash); header == nil { + p.Log().Warn("Failed to retrieve header for proof", "hash", request.BHash) + p.bumpInvalid() + continue + } + // Refuse to search stale state data in the database since looking for + // a non-exist key is kind of expensive. + local := bc.CurrentHeader().Number.Uint64() + if !backend.ArchiveMode() && header.Number.Uint64()+core.TriesInMemory <= local { + p.Log().Debug("Reject stale trie request", "number", header.Number.Uint64(), "head", local) + p.bumpInvalid() + continue + } + root = header.Root + } + // If a header lookup failed (non existent), ignore subsequent requests for the same header + if root == (common.Hash{}) { + p.bumpInvalid() + continue + } + // Open the account or storage trie for the request + statedb := bc.StateCache() + + var trie state.Trie + switch len(request.AccKey) { + case 0: + // No account key specified, open an account trie + trie, err = statedb.OpenTrie(root) + if trie == nil || err != nil { + p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "root", root, "err", err) + continue + } + default: + // Account key specified, open a storage trie + account, err := getAccount(statedb.TrieDB(), root, common.BytesToHash(request.AccKey)) + if err != nil { + p.Log().Warn("Failed to retrieve account for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err) + p.bumpInvalid() + continue + } + trie, err = statedb.OpenStorageTrie(common.BytesToHash(request.AccKey), account.Root) + if trie == nil || err != nil { + p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "root", account.Root, "err", err) + continue + } + } + // Prove the user's request from the account or stroage trie + if err := trie.Prove(request.Key, request.FromLevel, nodes); err != nil { + p.Log().Warn("Failed to prove state request", "block", header.Number, "hash", header.Hash(), "err", err) + continue + } + if nodes.DataSize() >= softResponseLimit { + break + } + } + return p.replyProofsV2(reqID, nodes.NodeList()) +} + +type GetHelperTrieProofsReq []HelperTrieReq + +func (r GetHelperTrieProofsReq) ReqAmount() (uint64, uint64) { + return uint64(len(r)), MaxHelperTrieProofsFetch +} + +func (r GetHelperTrieProofsReq) InMetrics(size int64) { + miscInHelperTriePacketsMeter.Mark(1) + miscInHelperTrieTrafficMeter.Mark(size) +} + +func (r GetHelperTrieProofsReq) OutMetrics(size int64, servingTime time.Duration) { + miscOutHelperTriePacketsMeter.Mark(1) + miscOutHelperTrieTrafficMeter.Mark(size) + miscServingTimeHelperTrieTimer.Update(servingTime) +} + +func (r GetHelperTrieProofsReq) Serve(backend serverBackend, reqID uint64, p *clientPeer, waitOrStop func() bool) *reply { + var ( + lastIdx uint64 + lastType uint + auxTrie *trie.Trie + auxBytes int + auxData [][]byte + ) + bc := backend.BlockChain() + nodes := light.NewNodeSet() + for i, request := range r { + if i != 0 && !waitOrStop() { + return nil + } + if auxTrie == nil || request.Type != lastType || request.TrieIdx != lastIdx { + lastType, lastIdx = request.Type, request.TrieIdx + auxTrie = backend.GetHelperTrie(request.Type, request.TrieIdx) + } + if auxTrie == nil { + return nil + } + // TODO(rjl493456442) short circuit if the proving is failed. + // The original client side code has a dirty hack to retrieve + // the headers with no valid proof. Keep the compatibility for + // legacy les protocol and drop this hack when the les2/3 are + // not supported. + err := auxTrie.Prove(request.Key, request.FromLevel, nodes) + if p.version >= lpv4 && err != nil { + return nil + } + if request.Type == htCanonical && request.AuxReq == htAuxHeader && len(request.Key) == 8 { + header := bc.GetHeaderByNumber(binary.BigEndian.Uint64(request.Key)) + data, err := rlp.EncodeToBytes(header) + if err != nil { + log.Error("Failed to encode header", "err", err) + return nil + } + auxData = append(auxData, data) + auxBytes += len(data) + } + if nodes.DataSize()+auxBytes >= softResponseLimit { + break + } + } + return p.replyHelperTrieProofs(reqID, HelperTrieResps{Proofs: nodes.NodeList(), AuxData: auxData}) +} + +type SendTxReq []*types.Transaction + +func (r SendTxReq) ReqAmount() (uint64, uint64) { return uint64(len(r)), MaxTxSend } + +func (r SendTxReq) InMetrics(size int64) { + miscInTxsPacketsMeter.Mark(1) + miscInTxsTrafficMeter.Mark(size) +} + +func (r SendTxReq) OutMetrics(size int64, servingTime time.Duration) { + miscOutTxsPacketsMeter.Mark(1) + miscOutTxsTrafficMeter.Mark(size) + miscServingTimeTxTimer.Update(servingTime) +} + +func (r SendTxReq) Serve(backend serverBackend, reqID uint64, p *clientPeer, waitOrStop func() bool) *reply { + stats := make([]light.TxStatus, len(r)) + for i, tx := range r { + if i != 0 && !waitOrStop() { + return nil + } + hash := tx.Hash() + stats[i] = txStatus(backend, hash) + if stats[i].Status == core.TxStatusUnknown { + addFn := backend.TxPool().AddRemotes + // Add txs synchronously for testing purpose + if backend.AddTxsSync() { + addFn = backend.TxPool().AddRemotesSync + } + if errs := addFn([]*types.Transaction{tx}); errs[0] != nil { + stats[i].Error = errs[0].Error() + continue + } + stats[i] = txStatus(backend, hash) + } + } + return p.replyTxStatus(reqID, stats) +} + +type GetTxStatusReq []common.Hash + +func (r GetTxStatusReq) ReqAmount() (uint64, uint64) { return uint64(len(r)), MaxTxStatus } + +func (r GetTxStatusReq) InMetrics(size int64) { + miscInTxStatusPacketsMeter.Mark(1) + miscInTxStatusTrafficMeter.Mark(size) +} + +func (r GetTxStatusReq) OutMetrics(size int64, servingTime time.Duration) { + miscOutTxStatusPacketsMeter.Mark(1) + miscOutTxStatusTrafficMeter.Mark(size) + miscServingTimeTxStatusTimer.Update(servingTime) +} + +func (r GetTxStatusReq) Serve(backend serverBackend, reqID uint64, p *clientPeer, waitOrStop func() bool) *reply { + stats := make([]light.TxStatus, len(r)) + for i, hash := range r { + if i != 0 && !waitOrStop() { + return nil + } + stats[i] = txStatus(backend, hash) + } + return p.replyTxStatus(reqID, stats) +} + +// txStatus returns the status of a specified transaction. +func txStatus(b serverBackend, hash common.Hash) light.TxStatus { + var stat light.TxStatus + // Looking the transaction in txpool first. + stat.Status = b.TxPool().Status([]common.Hash{hash})[0] + + // If the transaction is unknown to the pool, try looking it up locally. + if stat.Status == core.TxStatusUnknown { + lookup := b.BlockChain().GetTransactionLookup(hash) + if lookup != nil { + stat.Status = core.TxStatusIncluded + stat.Lookup = lookup + } + } + return stat +} From 5a25e6dd7214ded94ebe3fe64ee00f464bb839de Mon Sep 17 00:00:00 2001 From: Zsolt Felfoldi Date: Thu, 14 Jan 2021 16:31:44 +0100 Subject: [PATCH 2/4] tests/fuzzers/les: add fuzzer for les server handler --- les/server_handler.go | 2 +- les/server_requests.go | 2 +- les/test_helper.go | 4 + tests/fuzzers/les/debug/main.go | 41 +++++ tests/fuzzers/les/les-fuzzer.go | 305 ++++++++++++++++++++++++++++++++ 5 files changed, 352 insertions(+), 2 deletions(-) create mode 100644 tests/fuzzers/les/debug/main.go create mode 100644 tests/fuzzers/les/les-fuzzer.go diff --git a/les/server_handler.go b/les/server_handler.go index 820b688b4b2e..d406d3712bdc 100644 --- a/les/server_handler.go +++ b/les/server_handler.go @@ -229,7 +229,7 @@ func (h *serverHandler) handleMsg(p *clientPeer, wg *sync.WaitGroup) error { ID uint64 Data rlp.RawValue } - hreq handlerRequest + hreq HandlerRequest decodeErr error ) if err := msg.Decode(&req); err != nil { diff --git a/les/server_requests.go b/les/server_requests.go index 3da36c801a6b..d2b18a8491ab 100644 --- a/les/server_requests.go +++ b/les/server_requests.go @@ -39,7 +39,7 @@ type serverBackend interface { GetHelperTrie(typ uint, index uint64) *trie.Trie } -type handlerRequest interface { +type HandlerRequest interface { ReqAmount() (uint64, uint64) InMetrics(size int64) OutMetrics(size int64, servingTime time.Duration) diff --git a/les/test_helper.go b/les/test_helper.go index e3f0616a888a..964e9dbf711f 100644 --- a/les/test_helper.go +++ b/les/test_helper.go @@ -586,3 +586,7 @@ func newClientServerEnv(t *testing.T, blocks int, protocol int, callback indexer } return s, c, teardown } + +func NewFuzzerPeer(version int) *clientPeer { + return newClientPeer(version, 0, p2p.NewPeer(enode.ID{}, "", nil), nil) +} diff --git a/tests/fuzzers/les/debug/main.go b/tests/fuzzers/les/debug/main.go new file mode 100644 index 000000000000..09e087d4c88a --- /dev/null +++ b/tests/fuzzers/les/debug/main.go @@ -0,0 +1,41 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package main + +import ( + "fmt" + "io/ioutil" + "os" + + "github.com/ethereum/go-ethereum/tests/fuzzers/les" +) + +func main() { + if len(os.Args) != 2 { + fmt.Fprintf(os.Stderr, "Usage: debug \n") + fmt.Fprintf(os.Stderr, "Example\n") + fmt.Fprintf(os.Stderr, " $ debug ../crashers/4bbef6857c733a87ecf6fd8b9e7238f65eb9862a\n") + os.Exit(1) + } + crasher := os.Args[1] + data, err := ioutil.ReadFile(crasher) + if err != nil { + fmt.Fprintf(os.Stderr, "error loading crasher %v: %v", crasher, err) + os.Exit(1) + } + les.Fuzz(data) +} diff --git a/tests/fuzzers/les/les-fuzzer.go b/tests/fuzzers/les/les-fuzzer.go new file mode 100644 index 000000000000..676b2f2f4529 --- /dev/null +++ b/tests/fuzzers/les/les-fuzzer.go @@ -0,0 +1,305 @@ +package les + +import ( + "bytes" + "encoding/binary" + "io" + "math/big" + "math/rand" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/consensus/ethash" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/core/vm" + "github.com/ethereum/go-ethereum/crypto" + l "github.com/ethereum/go-ethereum/les" + "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/trie" + //fuzz "github.com/google/gofuzz" +) + +var ( + bankKey, _ = crypto.GenerateKey() + bankAddr = crypto.PubkeyToAddress(bankKey.PublicKey) + bankFunds = big.NewInt(1000000000000000000) + + userKey1, _ = crypto.GenerateKey() + userKey2, _ = crypto.GenerateKey() + userAddr1 = crypto.PubkeyToAddress(userKey1.PublicKey) + userAddr2 = crypto.PubkeyToAddress(userKey2.PublicKey) + + testContractAddr common.Address + testContractCode = common.Hex2Bytes("606060405260cc8060106000396000f360606040526000357c01000000000000000000000000000000000000000000000000000000009004806360cd2685146041578063c16431b914606b57603f565b005b6055600480803590602001909190505060a9565b6040518082815260200191505060405180910390f35b60886004808035906020019091908035906020019091905050608a565b005b80600060005083606481101560025790900160005b50819055505b5050565b6000600060005082606481101560025790900160005b5054905060c7565b91905056") +) + +func makeChain(n int) (bc *core.BlockChain, addrHashes, txHashes []common.Hash) { + db := rawdb.NewMemoryDatabase() + gspec := core.Genesis{ + Config: params.TestChainConfig, + Alloc: core.GenesisAlloc{bankAddr: {Balance: bankFunds}}, + GasLimit: 100000000, + } + genesis := gspec.MustCommit(db) + signer := types.HomesteadSigner{} + blocks, _ := core.GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, n, + func(i int, gen *core.BlockGen) { + var ( + tx *types.Transaction + addr common.Address + ) + nonce := uint64(i) + if i%16 == 0 { + tx, _ = types.SignTx(types.NewContractCreation(nonce, big.NewInt(0), 200000, big.NewInt(0), testContractCode), signer, bankKey) + addr = crypto.CreateAddress(userAddr1, nonce) + } else { + key, _ := crypto.GenerateKey() + addr = crypto.PubkeyToAddress(key.PublicKey) + tx, _ = types.SignTx(types.NewTransaction(nonce, addr, big.NewInt(10000), params.TxGas, big.NewInt(100000000000), nil), signer, bankKey) + } + gen.AddTx(tx) + addrHashes = append(addrHashes, crypto.Keccak256Hash(addr[:])) + txHashes = append(txHashes, tx.Hash()) + }) + bc, _ = core.NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{}, nil, nil) + + if _, err := bc.InsertChain(blocks); err != nil { + panic(err) + } + return +} + +type fuzzer struct { + chain *core.BlockChain + pool *core.TxPool + trie *trie.Trie + addr, txs []common.Hash + + input io.Reader + exhausted bool +} + +func newFuzzer(input []byte) *fuzzer { + f := &fuzzer{ + input: bytes.NewReader(input), + } + f.chain, f.addr, f.txs = makeChain(int(f.randomByte())) + f.pool = core.NewTxPool(core.DefaultTxPoolConfig, params.TestChainConfig, f.chain) + f.trie, _ = trie.New(common.Hash{}, trie.NewDatabase(rawdb.NewMemoryDatabase())) + r := rand.New(rand.NewSource(42)) + for i := 0; i < 100; i++ { + key := make([]byte, r.Intn(32)+1) + r.Read(key) + value := make([]byte, r.Intn(64)) + r.Read(value) + f.trie.Update(key, value) + } + return f +} + +func (f *fuzzer) read(size int) []byte { + out := make([]byte, size) + if _, err := f.input.Read(out); err != nil { + f.exhausted = true + } + return out +} + +func (f *fuzzer) randomByte() byte { + d := f.read(1) + return d[0] +} + +func (f *fuzzer) randomBool() bool { + d := f.read(1) + return d[0]&1 == 1 +} + +func (f *fuzzer) randomInt(max int) int { + if max == 0 { + return 0 + } + if max <= 256 { + return int(f.randomByte()) % max + } + var a uint16 + if err := binary.Read(f.input, binary.LittleEndian, &a); err != nil { + f.exhausted = true + } + return int(a % uint16(max)) +} + +func (f *fuzzer) randomX(max int) uint64 { + var a uint16 + if err := binary.Read(f.input, binary.LittleEndian, &a); err != nil { + f.exhausted = true + } + if a < 0x8000 { + return uint64(a%uint16(max+1)) - 1 + } + return (uint64(1)<<(a%64+1) - 1) & (uint64(a) * 343897772345826595) +} + +func (f *fuzzer) randomBlockHash() common.Hash { + b := f.randomByte() + h := f.chain.GetCanonicalHash(uint64(b)) + if h == (common.Hash{}) { + if b&1 == 1 { + h[0] = b + } + } + return h +} + +func (f *fuzzer) randomAddrKey() []byte { + i := int(f.randomByte()) + if i < len(f.addr) { + return f.addr[i].Bytes() + } else { + h := make([]byte, int(f.randomByte())) + if i < len(h) { + h[i] = byte(i) + 1 + } + return h + } +} + +func (f *fuzzer) randomTxHash() common.Hash { + i := int(f.randomByte()) + if i < len(f.txs) { + return f.txs[i] + } else { + var h common.Hash + if i&1 == 1 { + h[0] = byte(i) + } + return h + } +} + +func (f *fuzzer) randomBytes(maxLen int) []byte { + return f.read(f.randomInt(maxLen + 1)) +} + +func (f *fuzzer) BlockChain() *core.BlockChain { + return f.chain +} + +func (f *fuzzer) TxPool() *core.TxPool { + return f.pool +} + +func (f *fuzzer) ArchiveMode() bool { + return false +} + +func (f *fuzzer) AddTxsSync() bool { + return false +} + +func (f *fuzzer) GetHelperTrie(typ uint, index uint64) *trie.Trie { + if index < 2 { + return f.trie + } + return nil +} + +func (f *fuzzer) doFuzz(req l.HandlerRequest) { + defer f.chain.Stop() + //fuzz.NewFromGoFuzz(input).Fuzz(req) + peer := l.NewFuzzerPeer(3) + req.Serve(f, 42, peer, func() bool { return true }) +} + +func Fuzz(input []byte) int { + f := newFuzzer(input) + if f.exhausted { + return -1 + } + for !f.exhausted { + switch f.randomInt(8) { + case 0: + req := &l.GetBlockHeadersReq{ + Amount: f.randomX(l.MaxHeaderFetch + 1), + Skip: f.randomX(10), + Reverse: f.randomBool(), + } + if f.randomBool() { + req.Origin.Hash = f.randomBlockHash() + } else { + req.Origin.Number = uint64(f.randomByte()) + } + f.doFuzz(req) + + case 1: + req := make(l.GetBlockBodiesReq, f.randomInt(l.MaxBodyFetch+1)) + for i := range req { + req[i] = f.randomBlockHash() + } + f.doFuzz(req) + + case 2: + req := make(l.GetCodeReq, f.randomInt(l.MaxCodeFetch+1)) + for i := range req { + req[i] = l.CodeReq{ + BHash: f.randomBlockHash(), + AccKey: f.randomAddrKey(), + } + } + f.doFuzz(req) + + case 3: + req := make(l.GetReceiptsReq, f.randomInt(l.MaxReceiptFetch+1)) + for i := range req { + req[i] = f.randomBlockHash() + } + f.doFuzz(req) + + case 4: + req := make(l.GetProofsReq, f.randomInt(l.MaxProofsFetch+1)) + for i := range req { + req[i] = l.ProofReq{ + BHash: f.randomBlockHash(), + AccKey: f.randomAddrKey(), + Key: f.randomAddrKey(), + FromLevel: uint(f.randomX(3)), + } + } + f.doFuzz(req) + + case 5: + req := make(l.GetHelperTrieProofsReq, f.randomInt(l.MaxHelperTrieProofsFetch+1)) + for i := range req { + req[i] = l.HelperTrieReq{ + Type: uint(f.randomX(3)), + TrieIdx: f.randomX(3), + Key: f.randomAddrKey(), + FromLevel: uint(f.randomX(3)), + AuxReq: uint(f.randomX(3)), + } + } + f.doFuzz(req) + + case 6: + req := make(l.SendTxReq, f.randomInt(l.MaxTxSend+1)) + signer := types.HomesteadSigner{} + for i := range req { + nonce := uint64(f.randomByte()) + if nonce%1 == 0 { + nonce = uint64(len(f.txs)) + } + req[i], _ = types.SignTx(types.NewTransaction(nonce, common.Address{}, big.NewInt(10000), params.TxGas, big.NewInt(1000000000*int64(f.randomByte())), nil), signer, bankKey) + } + f.doFuzz(req) + + case 7: + req := make(l.GetTxStatusReq, f.randomInt(l.MaxTxStatus+1)) + for i := range req { + req[i] = f.randomTxHash() + } + f.doFuzz(req) + } + } + return 0 +} From f011f0b3fe64f533adf8115df2e9f0864c7a4267 Mon Sep 17 00:00:00 2001 From: rjl493456442 Date: Wed, 27 Jan 2021 13:05:17 +0800 Subject: [PATCH 3/4] tests, les: update les fuzzer tests: update les fuzzer tests/fuzzer/les: release resources tests/fuzzer/les: pre-initialize all resources --- go.mod | 1 + les/server_handler.go | 4 +- les/server_requests.go | 8 +- oss-fuzz.sh | 1 + tests/fuzzers/les/les-fuzzer.go | 248 +++++++++++++++++++++----------- 5 files changed, 173 insertions(+), 89 deletions(-) diff --git a/go.mod b/go.mod index cddce4e0d8fb..3299d39e66e5 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/deckarep/golang-set v0.0.0-20180603214616-504e848d77ea github.com/docker/docker v1.4.2-0.20180625184442-8e610b2b55bf github.com/dop251/goja v0.0.0-20200721192441-a695b0cdd498 + github.com/dvyukov/go-fuzz v0.0.0-20210103155950-6a8e9d1f2415 // indirect github.com/edsrzf/mmap-go v1.0.0 github.com/fatih/color v1.7.0 github.com/fjl/memsize v0.0.0-20190710130421-bcb5799ab5e5 diff --git a/les/server_handler.go b/les/server_handler.go index d406d3712bdc..ebe137c6c49b 100644 --- a/les/server_handler.go +++ b/les/server_handler.go @@ -240,8 +240,8 @@ func (h *serverHandler) handleMsg(p *clientPeer, wg *sync.WaitGroup) error { switch msg.Code { case GetBlockHeadersMsg: p.Log().Trace("Received block header request") - r := &GetBlockHeadersReq{} - decodeErr = rlp.DecodeBytes(req.Data, r) + r := GetBlockHeadersReq{} + decodeErr = rlp.DecodeBytes(req.Data, &r) hreq = r case GetBlockBodiesMsg: diff --git a/les/server_requests.go b/les/server_requests.go index d2b18a8491ab..00f217cbe911 100644 --- a/les/server_requests.go +++ b/les/server_requests.go @@ -53,20 +53,20 @@ type GetBlockHeadersReq struct { Reverse bool // Query direction (false = rising towards latest, true = falling towards genesis) } -func (r *GetBlockHeadersReq) ReqAmount() (uint64, uint64) { return r.Amount, MaxHeaderFetch } +func (r GetBlockHeadersReq) ReqAmount() (uint64, uint64) { return r.Amount, MaxHeaderFetch } -func (r *GetBlockHeadersReq) InMetrics(size int64) { +func (r GetBlockHeadersReq) InMetrics(size int64) { miscInHeaderPacketsMeter.Mark(1) miscInHeaderTrafficMeter.Mark(size) } -func (r *GetBlockHeadersReq) OutMetrics(size int64, servingTime time.Duration) { +func (r GetBlockHeadersReq) OutMetrics(size int64, servingTime time.Duration) { miscOutHeaderPacketsMeter.Mark(1) miscOutHeaderTrafficMeter.Mark(size) miscServingTimeHeaderTimer.Update(servingTime) } -func (r *GetBlockHeadersReq) Serve(backend serverBackend, reqID uint64, p *clientPeer, waitOrStop func() bool) *reply { +func (r GetBlockHeadersReq) Serve(backend serverBackend, reqID uint64, p *clientPeer, waitOrStop func() bool) *reply { // Gather headers until the fetch or network limits is reached var ( bc = backend.BlockChain() diff --git a/oss-fuzz.sh b/oss-fuzz.sh index 860ce0952f11..ac93a5a4670e 100644 --- a/oss-fuzz.sh +++ b/oss-fuzz.sh @@ -101,6 +101,7 @@ compile_fuzzer tests/fuzzers/trie Fuzz fuzzTrie compile_fuzzer tests/fuzzers/stacktrie Fuzz fuzzStackTrie compile_fuzzer tests/fuzzers/difficulty Fuzz fuzzDifficulty compile_fuzzer tests/fuzzers/abi Fuzz fuzzAbi +compile_fuzzer tests/fuzzers/les Fuzz fuzzLes compile_fuzzer tests/fuzzers/bls12381 FuzzG1Add fuzz_g1_add compile_fuzzer tests/fuzzers/bls12381 FuzzG1Mul fuzz_g1_mul diff --git a/tests/fuzzers/les/les-fuzzer.go b/tests/fuzzers/les/les-fuzzer.go index 676b2f2f4529..0f07de7bc8dc 100644 --- a/tests/fuzzers/les/les-fuzzer.go +++ b/tests/fuzzers/les/les-fuzzer.go @@ -1,3 +1,19 @@ +// Copyright 2021 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + package les import ( @@ -5,7 +21,6 @@ import ( "encoding/binary" "io" "math/big" - "math/rand" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/consensus/ethash" @@ -17,24 +32,27 @@ import ( l "github.com/ethereum/go-ethereum/les" "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/trie" - //fuzz "github.com/google/gofuzz" ) var ( - bankKey, _ = crypto.GenerateKey() + bankKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") bankAddr = crypto.PubkeyToAddress(bankKey.PublicKey) - bankFunds = big.NewInt(1000000000000000000) - - userKey1, _ = crypto.GenerateKey() - userKey2, _ = crypto.GenerateKey() - userAddr1 = crypto.PubkeyToAddress(userKey1.PublicKey) - userAddr2 = crypto.PubkeyToAddress(userKey2.PublicKey) + bankFunds = new(big.Int).Mul(big.NewInt(100), big.NewInt(params.Ether)) - testContractAddr common.Address + testChainLen = 256 testContractCode = common.Hex2Bytes("606060405260cc8060106000396000f360606040526000357c01000000000000000000000000000000000000000000000000000000009004806360cd2685146041578063c16431b914606b57603f565b005b6055600480803590602001909190505060a9565b6040518082815260200191505060405180910390f35b60886004808035906020019091908035906020019091905050608a565b005b80600060005083606481101560025790900160005b50819055505b5050565b6000600060005082606481101560025790900160005b5054905060c7565b91905056") + + chain *core.BlockChain + addrHashes []common.Hash + txHashes []common.Hash + + chtTrie *trie.Trie + bloomTrie *trie.Trie + chtKeys [][]byte + bloomKeys [][]byte ) -func makeChain(n int) (bc *core.BlockChain, addrHashes, txHashes []common.Hash) { +func makechain() (bc *core.BlockChain, addrHashes, txHashes []common.Hash) { db := rawdb.NewMemoryDatabase() gspec := core.Genesis{ Config: params.TestChainConfig, @@ -43,59 +61,86 @@ func makeChain(n int) (bc *core.BlockChain, addrHashes, txHashes []common.Hash) } genesis := gspec.MustCommit(db) signer := types.HomesteadSigner{} - blocks, _ := core.GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, n, + blocks, _ := core.GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, testChainLen, func(i int, gen *core.BlockGen) { var ( tx *types.Transaction addr common.Address ) nonce := uint64(i) - if i%16 == 0 { + if i%4 == 0 { tx, _ = types.SignTx(types.NewContractCreation(nonce, big.NewInt(0), 200000, big.NewInt(0), testContractCode), signer, bankKey) - addr = crypto.CreateAddress(userAddr1, nonce) + addr = crypto.CreateAddress(bankAddr, nonce) } else { - key, _ := crypto.GenerateKey() - addr = crypto.PubkeyToAddress(key.PublicKey) - tx, _ = types.SignTx(types.NewTransaction(nonce, addr, big.NewInt(10000), params.TxGas, big.NewInt(100000000000), nil), signer, bankKey) + addr = common.BigToAddress(big.NewInt(int64(i))) + tx, _ = types.SignTx(types.NewTransaction(nonce, addr, big.NewInt(10000), params.TxGas, big.NewInt(params.GWei), nil), signer, bankKey) } gen.AddTx(tx) addrHashes = append(addrHashes, crypto.Keccak256Hash(addr[:])) txHashes = append(txHashes, tx.Hash()) }) bc, _ = core.NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{}, nil, nil) - if _, err := bc.InsertChain(blocks); err != nil { panic(err) } return } +func makeTries() (chtTrie *trie.Trie, bloomTrie *trie.Trie, chtKeys, bloomKeys [][]byte) { + chtTrie, _ = trie.New(common.Hash{}, trie.NewDatabase(rawdb.NewMemoryDatabase())) + bloomTrie, _ = trie.New(common.Hash{}, trie.NewDatabase(rawdb.NewMemoryDatabase())) + for i := 0; i < testChainLen; i++ { + // The element in CHT is -> + key := make([]byte, 8) + binary.BigEndian.PutUint64(key, uint64(i+1)) + chtTrie.Update(key, []byte{0x1, 0xf}) + chtKeys = append(chtKeys, key) + + // The element in Bloom trie is <2 byte bit index> + -> bloom + key2 := make([]byte, 10) + binary.BigEndian.PutUint64(key2[2:], uint64(i+1)) + bloomTrie.Update(key2, []byte{0x2, 0xe}) + bloomKeys = append(bloomKeys, key2) + } + return +} + +func init() { + chain, addrHashes, txHashes = makechain() + chtTrie, bloomTrie, chtKeys, bloomKeys = makeTries() +} + type fuzzer struct { - chain *core.BlockChain - pool *core.TxPool - trie *trie.Trie + chain *core.BlockChain + pool *core.TxPool + + chainLen int addr, txs []common.Hash + nonce uint64 + + chtKeys [][]byte + bloomKeys [][]byte + chtTrie *trie.Trie + bloomTrie *trie.Trie input io.Reader exhausted bool } func newFuzzer(input []byte) *fuzzer { - f := &fuzzer{ - input: bytes.NewReader(input), - } - f.chain, f.addr, f.txs = makeChain(int(f.randomByte())) - f.pool = core.NewTxPool(core.DefaultTxPoolConfig, params.TestChainConfig, f.chain) - f.trie, _ = trie.New(common.Hash{}, trie.NewDatabase(rawdb.NewMemoryDatabase())) - r := rand.New(rand.NewSource(42)) - for i := 0; i < 100; i++ { - key := make([]byte, r.Intn(32)+1) - r.Read(key) - value := make([]byte, r.Intn(64)) - r.Read(value) - f.trie.Update(key, value) + return &fuzzer{ + chain: chain, + chainLen: testChainLen, + addr: addrHashes, + txs: txHashes, + chtTrie: chtTrie, + bloomTrie: bloomTrie, + chtKeys: chtKeys, + bloomKeys: bloomKeys, + nonce: uint64(len(txHashes)), + pool: core.NewTxPool(core.DefaultTxPoolConfig, params.TestChainConfig, chain), + input: bytes.NewReader(input), } - return f } func (f *fuzzer) read(size int) []byte { @@ -142,44 +187,43 @@ func (f *fuzzer) randomX(max int) uint64 { } func (f *fuzzer) randomBlockHash() common.Hash { - b := f.randomByte() - h := f.chain.GetCanonicalHash(uint64(b)) - if h == (common.Hash{}) { - if b&1 == 1 { - h[0] = b - } + h := f.chain.GetCanonicalHash(uint64(f.randomInt(3 * f.chainLen))) + if h != (common.Hash{}) { + return h } - return h + return common.BytesToHash(f.read(common.HashLength)) } -func (f *fuzzer) randomAddrKey() []byte { - i := int(f.randomByte()) +func (f *fuzzer) randomAddrHash() []byte { + i := f.randomInt(3 * len(f.addr)) if i < len(f.addr) { return f.addr[i].Bytes() - } else { - h := make([]byte, int(f.randomByte())) - if i < len(h) { - h[i] = byte(i) + 1 - } - return h } + return f.read(common.HashLength) +} + +func (f *fuzzer) randomCHTTrieKey() []byte { + i := f.randomInt(3 * len(f.chtKeys)) + if i < len(f.chtKeys) { + return f.chtKeys[i] + } + return f.read(8) +} + +func (f *fuzzer) randomBloomTrieKey() []byte { + i := f.randomInt(3 * len(f.bloomKeys)) + if i < len(f.bloomKeys) { + return f.bloomKeys[i] + } + return f.read(10) } func (f *fuzzer) randomTxHash() common.Hash { - i := int(f.randomByte()) + i := f.randomInt(3 * len(f.txs)) if i < len(f.txs) { return f.txs[i] - } else { - var h common.Hash - if i&1 == 1 { - h[0] = byte(i) - } - return h } -} - -func (f *fuzzer) randomBytes(maxLen int) []byte { - return f.read(f.randomInt(maxLen + 1)) + return common.BytesToHash(f.read(common.HashLength)) } func (f *fuzzer) BlockChain() *core.BlockChain { @@ -199,20 +243,25 @@ func (f *fuzzer) AddTxsSync() bool { } func (f *fuzzer) GetHelperTrie(typ uint, index uint64) *trie.Trie { - if index < 2 { - return f.trie + if typ == 0 { + return f.chtTrie + } else if typ == 1 { + return f.bloomTrie } return nil } func (f *fuzzer) doFuzz(req l.HandlerRequest) { - defer f.chain.Stop() - //fuzz.NewFromGoFuzz(input).Fuzz(req) - peer := l.NewFuzzerPeer(3) + version := f.randomInt(3) + 2 // [LES2, LES3, LES4] + peer := l.NewFuzzerPeer(version) req.Serve(f, 42, peer, func() bool { return true }) } func Fuzz(input []byte) int { + // We expect some large inputs + if len(input) < 100 { + return -1 + } f := newFuzzer(input) if f.exhausted { return -1 @@ -220,7 +269,7 @@ func Fuzz(input []byte) int { for !f.exhausted { switch f.randomInt(8) { case 0: - req := &l.GetBlockHeadersReq{ + req := l.GetBlockHeadersReq{ Amount: f.randomX(l.MaxHeaderFetch + 1), Skip: f.randomX(10), Reverse: f.randomBool(), @@ -228,7 +277,7 @@ func Fuzz(input []byte) int { if f.randomBool() { req.Origin.Hash = f.randomBlockHash() } else { - req.Origin.Number = uint64(f.randomByte()) + req.Origin.Number = uint64(f.randomInt(f.chainLen * 2)) } f.doFuzz(req) @@ -244,7 +293,7 @@ func Fuzz(input []byte) int { for i := range req { req[i] = l.CodeReq{ BHash: f.randomBlockHash(), - AccKey: f.randomAddrKey(), + AccKey: f.randomAddrHash(), } } f.doFuzz(req) @@ -259,11 +308,19 @@ func Fuzz(input []byte) int { case 4: req := make(l.GetProofsReq, f.randomInt(l.MaxProofsFetch+1)) for i := range req { - req[i] = l.ProofReq{ - BHash: f.randomBlockHash(), - AccKey: f.randomAddrKey(), - Key: f.randomAddrKey(), - FromLevel: uint(f.randomX(3)), + if f.randomBool() { + req[i] = l.ProofReq{ + BHash: f.randomBlockHash(), + AccKey: f.randomAddrHash(), + Key: f.randomAddrHash(), + FromLevel: uint(f.randomX(3)), + } + } else { + req[i] = l.ProofReq{ + BHash: f.randomBlockHash(), + Key: f.randomAddrHash(), + FromLevel: uint(f.randomX(3)), + } } } f.doFuzz(req) @@ -271,12 +328,34 @@ func Fuzz(input []byte) int { case 5: req := make(l.GetHelperTrieProofsReq, f.randomInt(l.MaxHelperTrieProofsFetch+1)) for i := range req { - req[i] = l.HelperTrieReq{ - Type: uint(f.randomX(3)), - TrieIdx: f.randomX(3), - Key: f.randomAddrKey(), - FromLevel: uint(f.randomX(3)), - AuxReq: uint(f.randomX(3)), + switch f.randomInt(3) { + case 0: + // Canonical hash trie + req[i] = l.HelperTrieReq{ + Type: 0, + TrieIdx: f.randomX(3), + Key: f.randomCHTTrieKey(), + FromLevel: uint(f.randomX(3)), + AuxReq: uint(2), + } + case 1: + // Bloom trie + req[i] = l.HelperTrieReq{ + Type: 1, + TrieIdx: f.randomX(3), + Key: f.randomBloomTrieKey(), + FromLevel: uint(f.randomX(3)), + AuxReq: 0, + } + default: + // Random trie + req[i] = l.HelperTrieReq{ + Type: 2, + TrieIdx: f.randomX(3), + Key: f.randomCHTTrieKey(), + FromLevel: uint(f.randomX(3)), + AuxReq: 0, + } } } f.doFuzz(req) @@ -285,9 +364,12 @@ func Fuzz(input []byte) int { req := make(l.SendTxReq, f.randomInt(l.MaxTxSend+1)) signer := types.HomesteadSigner{} for i := range req { - nonce := uint64(f.randomByte()) - if nonce%1 == 0 { - nonce = uint64(len(f.txs)) + var nonce uint64 + if f.randomBool() { + nonce = uint64(f.randomByte()) + } else { + nonce = f.nonce + f.nonce += 1 } req[i], _ = types.SignTx(types.NewTransaction(nonce, common.Address{}, big.NewInt(10000), params.TxGas, big.NewInt(1000000000*int64(f.randomByte())), nil), signer, bankKey) } From 347afb4e5d0c4269012248048018f409d8389e39 Mon Sep 17 00:00:00 2001 From: Zsolt Felfoldi Date: Sun, 7 Feb 2021 21:31:49 +0100 Subject: [PATCH 4/4] les: refactored server handler and fuzzer --- go.mod | 1 - les/handler_test.go | 36 +- les/peer.go | 4 +- les/protocol.go | 57 +++ les/server_handler.go | 98 +--- les/server_requests.go | 869 +++++++++++++++++--------------- tests/fuzzers/les/les-fuzzer.go | 100 ++-- 7 files changed, 610 insertions(+), 555 deletions(-) diff --git a/go.mod b/go.mod index 3299d39e66e5..cddce4e0d8fb 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,6 @@ require ( github.com/deckarep/golang-set v0.0.0-20180603214616-504e848d77ea github.com/docker/docker v1.4.2-0.20180625184442-8e610b2b55bf github.com/dop251/goja v0.0.0-20200721192441-a695b0cdd498 - github.com/dvyukov/go-fuzz v0.0.0-20210103155950-6a8e9d1f2415 // indirect github.com/edsrzf/mmap-go v1.0.0 github.com/fatih/color v1.7.0 github.com/fjl/memsize v0.0.0-20190710130421-bcb5799ab5e5 diff --git a/les/handler_test.go b/les/handler_test.go index aca078d9c5df..07d2ff4fcd6d 100644 --- a/les/handler_test.go +++ b/les/handler_test.go @@ -64,27 +64,27 @@ func testGetBlockHeaders(t *testing.T, protocol int) { // Create a batch of tests for various scenarios limit := uint64(MaxHeaderFetch) tests := []struct { - query *GetBlockHeadersReq // The query to execute for header retrieval - expect []common.Hash // The hashes of the block whose headers are expected + query *GetBlockHeadersData // The query to execute for header retrieval + expect []common.Hash // The hashes of the block whose headers are expected }{ // A single random block should be retrievable by hash and number too { - &GetBlockHeadersReq{Origin: hashOrNumber{Hash: bc.GetBlockByNumber(limit / 2).Hash()}, Amount: 1}, + &GetBlockHeadersData{Origin: hashOrNumber{Hash: bc.GetBlockByNumber(limit / 2).Hash()}, Amount: 1}, []common.Hash{bc.GetBlockByNumber(limit / 2).Hash()}, }, { - &GetBlockHeadersReq{Origin: hashOrNumber{Number: limit / 2}, Amount: 1}, + &GetBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Amount: 1}, []common.Hash{bc.GetBlockByNumber(limit / 2).Hash()}, }, // Multiple headers should be retrievable in both directions { - &GetBlockHeadersReq{Origin: hashOrNumber{Number: limit / 2}, Amount: 3}, + &GetBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Amount: 3}, []common.Hash{ bc.GetBlockByNumber(limit / 2).Hash(), bc.GetBlockByNumber(limit/2 + 1).Hash(), bc.GetBlockByNumber(limit/2 + 2).Hash(), }, }, { - &GetBlockHeadersReq{Origin: hashOrNumber{Number: limit / 2}, Amount: 3, Reverse: true}, + &GetBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Amount: 3, Reverse: true}, []common.Hash{ bc.GetBlockByNumber(limit / 2).Hash(), bc.GetBlockByNumber(limit/2 - 1).Hash(), @@ -93,14 +93,14 @@ func testGetBlockHeaders(t *testing.T, protocol int) { }, // Multiple headers with skip lists should be retrievable { - &GetBlockHeadersReq{Origin: hashOrNumber{Number: limit / 2}, Skip: 3, Amount: 3}, + &GetBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Skip: 3, Amount: 3}, []common.Hash{ bc.GetBlockByNumber(limit / 2).Hash(), bc.GetBlockByNumber(limit/2 + 4).Hash(), bc.GetBlockByNumber(limit/2 + 8).Hash(), }, }, { - &GetBlockHeadersReq{Origin: hashOrNumber{Number: limit / 2}, Skip: 3, Amount: 3, Reverse: true}, + &GetBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Skip: 3, Amount: 3, Reverse: true}, []common.Hash{ bc.GetBlockByNumber(limit / 2).Hash(), bc.GetBlockByNumber(limit/2 - 4).Hash(), @@ -109,26 +109,26 @@ func testGetBlockHeaders(t *testing.T, protocol int) { }, // The chain endpoints should be retrievable { - &GetBlockHeadersReq{Origin: hashOrNumber{Number: 0}, Amount: 1}, + &GetBlockHeadersData{Origin: hashOrNumber{Number: 0}, Amount: 1}, []common.Hash{bc.GetBlockByNumber(0).Hash()}, }, { - &GetBlockHeadersReq{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64()}, Amount: 1}, + &GetBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64()}, Amount: 1}, []common.Hash{bc.CurrentBlock().Hash()}, }, // Ensure protocol limits are honored //{ - // &GetBlockHeadersReq{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() - 1}, Amount: limit + 10, Reverse: true}, + // &GetBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() - 1}, Amount: limit + 10, Reverse: true}, // []common.Hash{}, //}, // Check that requesting more than available is handled gracefully { - &GetBlockHeadersReq{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() - 4}, Skip: 3, Amount: 3}, + &GetBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() - 4}, Skip: 3, Amount: 3}, []common.Hash{ bc.GetBlockByNumber(bc.CurrentBlock().NumberU64() - 4).Hash(), bc.GetBlockByNumber(bc.CurrentBlock().NumberU64()).Hash(), }, }, { - &GetBlockHeadersReq{Origin: hashOrNumber{Number: 4}, Skip: 3, Amount: 3, Reverse: true}, + &GetBlockHeadersData{Origin: hashOrNumber{Number: 4}, Skip: 3, Amount: 3, Reverse: true}, []common.Hash{ bc.GetBlockByNumber(4).Hash(), bc.GetBlockByNumber(0).Hash(), @@ -136,13 +136,13 @@ func testGetBlockHeaders(t *testing.T, protocol int) { }, // Check that requesting more than available is handled gracefully, even if mid skip { - &GetBlockHeadersReq{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() - 4}, Skip: 2, Amount: 3}, + &GetBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() - 4}, Skip: 2, Amount: 3}, []common.Hash{ bc.GetBlockByNumber(bc.CurrentBlock().NumberU64() - 4).Hash(), bc.GetBlockByNumber(bc.CurrentBlock().NumberU64() - 1).Hash(), }, }, { - &GetBlockHeadersReq{Origin: hashOrNumber{Number: 4}, Skip: 2, Amount: 3, Reverse: true}, + &GetBlockHeadersData{Origin: hashOrNumber{Number: 4}, Skip: 2, Amount: 3, Reverse: true}, []common.Hash{ bc.GetBlockByNumber(4).Hash(), bc.GetBlockByNumber(1).Hash(), @@ -150,10 +150,10 @@ func testGetBlockHeaders(t *testing.T, protocol int) { }, // Check that non existing headers aren't returned { - &GetBlockHeadersReq{Origin: hashOrNumber{Hash: unknown}, Amount: 1}, + &GetBlockHeadersData{Origin: hashOrNumber{Hash: unknown}, Amount: 1}, []common.Hash{}, }, { - &GetBlockHeadersReq{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() + 1}, Amount: 1}, + &GetBlockHeadersData{Origin: hashOrNumber{Number: bc.CurrentBlock().NumberU64() + 1}, Amount: 1}, []common.Hash{}, }, } @@ -609,7 +609,7 @@ func TestStopResumeLes3(t *testing.T) { header := server.handler.blockchain.CurrentHeader() req := func() { reqID++ - sendRequest(server.peer.app, GetBlockHeadersMsg, reqID, &GetBlockHeadersReq{Origin: hashOrNumber{Hash: header.Hash()}, Amount: 1}) + sendRequest(server.peer.app, GetBlockHeadersMsg, reqID, &GetBlockHeadersData{Origin: hashOrNumber{Hash: header.Hash()}, Amount: 1}) } for i := 1; i <= 5; i++ { // send requests while we still have enough buffer and expect a response diff --git a/les/peer.go b/les/peer.go index a43c3bc6a927..ead926844ed0 100644 --- a/les/peer.go +++ b/les/peer.go @@ -432,14 +432,14 @@ func (p *serverPeer) sendRequest(msgcode, reqID uint64, data interface{}, amount // specified header query, based on the hash of an origin block. func (p *serverPeer) requestHeadersByHash(reqID uint64, origin common.Hash, amount int, skip int, reverse bool) error { p.Log().Debug("Fetching batch of headers", "count", amount, "fromhash", origin, "skip", skip, "reverse", reverse) - return p.sendRequest(GetBlockHeadersMsg, reqID, &GetBlockHeadersReq{Origin: hashOrNumber{Hash: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse}, amount) + return p.sendRequest(GetBlockHeadersMsg, reqID, &GetBlockHeadersData{Origin: hashOrNumber{Hash: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse}, amount) } // requestHeadersByNumber fetches a batch of blocks' headers corresponding to the // specified header query, based on the number of an origin block. func (p *serverPeer) requestHeadersByNumber(reqID, origin uint64, amount int, skip int, reverse bool) error { p.Log().Debug("Fetching batch of headers", "count", amount, "fromnum", origin, "skip", skip, "reverse", reverse) - return p.sendRequest(GetBlockHeadersMsg, reqID, &GetBlockHeadersReq{Origin: hashOrNumber{Number: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse}, amount) + return p.sendRequest(GetBlockHeadersMsg, reqID, &GetBlockHeadersData{Origin: hashOrNumber{Number: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse}, amount) } // requestBodies fetches a batch of blocks' bodies corresponding to the hashes diff --git a/les/protocol.go b/les/protocol.go index a201bb2d9522..6b49d96d04e9 100644 --- a/les/protocol.go +++ b/les/protocol.go @@ -24,6 +24,7 @@ import ( "math/big" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" lpc "github.com/ethereum/go-ethereum/les/lespay/client" "github.com/ethereum/go-ethereum/p2p/enode" @@ -83,6 +84,62 @@ const ( ResumeMsg = 0x17 ) +// GetBlockHeadersData represents a block header query (the request ID is not included) +type GetBlockHeadersData struct { + Origin hashOrNumber // Block from which to retrieve headers + Amount uint64 // Maximum number of headers to retrieve + Skip uint64 // Blocks to skip between consecutive headers + Reverse bool // Query direction (false = rising towards latest, true = falling towards genesis) +} + +// GetBlockHeadersPacket represents a block header request +type GetBlockHeadersPacket struct { + ReqID uint64 + Query GetBlockHeadersData +} + +// GetBlockBodiesPacket represents a block body request +type GetBlockBodiesPacket struct { + ReqID uint64 + Hashes []common.Hash +} + +// GetCodePacket represents a contract code request +type GetCodePacket struct { + ReqID uint64 + Reqs []CodeReq +} + +// GetReceiptsPacket represents a block receipts request +type GetReceiptsPacket struct { + ReqID uint64 + Hashes []common.Hash +} + +// GetProofsPacket represents a proof request +type GetProofsPacket struct { + ReqID uint64 + Reqs []ProofReq +} + +// GetHelperTrieProofsPacket represents a helper trie proof request +type GetHelperTrieProofsPacket struct { + ReqID uint64 + Reqs []HelperTrieReq +} + +// SendTxPacket represents a transaction propagation request +type SendTxPacket struct { + ReqID uint64 + Txs []*types.Transaction +} + +// GetTxStatusPacket represents a transaction status query +type GetTxStatusPacket struct { + ReqID uint64 + Hashes []common.Hash +} + type requestInfo struct { name string maxCount uint64 diff --git a/les/server_handler.go b/les/server_handler.go index ebe137c6c49b..a88899c4ccef 100644 --- a/les/server_handler.go +++ b/les/server_handler.go @@ -224,92 +224,34 @@ func (h *serverHandler) handleMsg(p *clientPeer, wg *sync.WaitGroup) error { p.responseCount++ responseCount := p.responseCount - var ( - req struct { - ID uint64 - Data rlp.RawValue - } - hreq HandlerRequest - decodeErr error - ) - if err := msg.Decode(&req); err != nil { + req, ok := Les3[msg.Code] + if !ok { + p.Log().Trace("Received invalid message", "code", msg.Code) clientErrorMeter.Mark(1) - return errResp(ErrDecode, "%v: %v", msg, err) + return errResp(ErrInvalidMsgCode, "%v", msg.Code) } + p.Log().Trace("Received " + req.Name) - switch msg.Code { - case GetBlockHeadersMsg: - p.Log().Trace("Received block header request") - r := GetBlockHeadersReq{} - decodeErr = rlp.DecodeBytes(req.Data, &r) - hreq = r - - case GetBlockBodiesMsg: - p.Log().Trace("Received block bodies request") - r := GetBlockBodiesReq{} - decodeErr = rlp.DecodeBytes(req.Data, &r) - hreq = r - - case GetCodeMsg: - p.Log().Trace("Received code request") - r := GetCodeReq{} - decodeErr = rlp.DecodeBytes(req.Data, &r) - hreq = r - - case GetReceiptsMsg: - p.Log().Trace("Received receipts request") - r := GetReceiptsReq{} - decodeErr = rlp.DecodeBytes(req.Data, &r) - hreq = r - - case GetProofsV2Msg: - p.Log().Trace("Received les/2 proofs request") - r := GetProofsReq{} - decodeErr = rlp.DecodeBytes(req.Data, &r) - hreq = r - - case GetHelperTrieProofsMsg: - p.Log().Trace("Received helper trie proof request") - r := GetHelperTrieProofsReq{} - decodeErr = rlp.DecodeBytes(req.Data, &r) - hreq = r - - case SendTxV2Msg: - p.Log().Trace("Received new transactions") - r := SendTxReq{} - decodeErr = rlp.DecodeBytes(req.Data, &r) - hreq = r - - case GetTxStatusMsg: - p.Log().Trace("Received transaction status query request") - r := GetTxStatusReq{} - decodeErr = rlp.DecodeBytes(req.Data, &r) - hreq = r - - default: - p.Log().Trace("Received invalid message", "code", msg.Code) + serve, reqID, reqCnt, err := req.Handle(msg) + if err != nil { clientErrorMeter.Mark(1) - return errResp(ErrInvalidMsgCode, "%v", msg.Code) + return errResp(ErrDecode, "%v: %v", msg, err) } if metrics.EnabledExpensive { - hreq.InMetrics(int64(msg.Size)) - } - if decodeErr != nil { - clientErrorMeter.Mark(1) - return errResp(ErrDecode, "%v: %v", msg, decodeErr) + req.InPacketsMeter.Mark(1) + req.InTrafficMeter.Mark(int64(msg.Size)) } - reqCnt, maxCnt := hreq.ReqAmount() // Short circuit if the peer is already frozen or the request is invalid. inSizeCost := h.server.costTracker.realCost(0, msg.Size, 0) - if p.isFrozen() || reqCnt == 0 || reqCnt > maxCnt { + if p.isFrozen() || reqCnt == 0 || reqCnt > req.MaxCount { p.fcClient.OneTimeCost(inSizeCost) return nil } // Prepaid max cost units before request been serving. maxCost := p.fcCosts.getMaxCost(msg.Code, reqCnt) - accepted, bufShort, priority := p.fcClient.AcceptRequest(req.ID, responseCount, maxCost) + accepted, bufShort, priority := p.fcClient.AcceptRequest(reqID, responseCount, maxCost) if !accepted { p.freeze() p.Log().Error("Request came too early", "remaining", common.PrettyDuration(time.Duration(bufShort*1000000/p.fcParams.MinRecharge))) @@ -329,7 +271,7 @@ func (h *serverHandler) handleMsg(p *clientPeer, wg *sync.WaitGroup) error { wg.Add(1) go func() { defer wg.Done() - reply := hreq.Serve(h, req.ID, p, task.waitOrStop) + reply := serve(h, p, task.waitOrStop) if reply != nil { task.done() } @@ -340,7 +282,7 @@ func (h *serverHandler) handleMsg(p *clientPeer, wg *sync.WaitGroup) error { // Short circuit if the client is already frozen. if p.isFrozen() { realCost := h.server.costTracker.realCost(task.servingTime, msg.Size, 0) - p.fcClient.RequestProcessed(req.ID, responseCount, maxCost, realCost) + p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost) return } // Positive correction buffer value with real cost. @@ -357,7 +299,7 @@ func (h *serverHandler) handleMsg(p *clientPeer, wg *sync.WaitGroup) error { realCost = maxCost } } - bv := p.fcClient.RequestProcessed(req.ID, responseCount, maxCost, realCost) + bv := p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost) if reply != nil { // Feed cost tracker request serving statistic. h.server.costTracker.updateStats(msg.Code, reqCnt, task.servingTime, realCost) @@ -372,12 +314,14 @@ func (h *serverHandler) handleMsg(p *clientPeer, wg *sync.WaitGroup) error { } }) if metrics.EnabledExpensive { - hreq.OutMetrics(int64(replySize), time.Duration(task.servingTime)) + req.OutPacketsMeter.Mark(1) + req.OutTrafficMeter.Mark(int64(replySize)) + req.ServingTimeMeter.Update(time.Duration(task.servingTime)) } } }() } else { - p.fcClient.RequestProcessed(req.ID, responseCount, maxCost, inSizeCost) + p.fcClient.RequestProcessed(reqID, responseCount, maxCost, inSizeCost) } // If the client has made too much invalid request(e.g. request a non-existent data), @@ -389,18 +333,22 @@ func (h *serverHandler) handleMsg(p *clientPeer, wg *sync.WaitGroup) error { return nil } +// BlockChain implements serverBackend func (h *serverHandler) BlockChain() *core.BlockChain { return h.blockchain } +// TxPool implements serverBackend func (h *serverHandler) TxPool() *core.TxPool { return h.txpool } +// ArchiveMode implements serverBackend func (h *serverHandler) ArchiveMode() bool { return h.server.archiveMode } +// AddTxsSync implements serverBackend func (h *serverHandler) AddTxsSync() bool { return h.addTxsSync } diff --git a/les/server_requests.go b/les/server_requests.go index 00f217cbe911..d4af8006f043 100644 --- a/les/server_requests.go +++ b/les/server_requests.go @@ -19,7 +19,6 @@ package les import ( "encoding/binary" "encoding/json" - "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" @@ -27,10 +26,12 @@ import ( "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" ) +// serverBackend defines the backend functions needed for serving LES requests type serverBackend interface { ArchiveMode() bool AddTxsSync() bool @@ -39,302 +40,257 @@ type serverBackend interface { GetHelperTrie(typ uint, index uint64) *trie.Trie } -type HandlerRequest interface { - ReqAmount() (uint64, uint64) - InMetrics(size int64) - OutMetrics(size int64, servingTime time.Duration) - Serve(b serverBackend, reqID uint64, p *clientPeer, waitOrStop func() bool) *reply +// Decoder is implemented by the messages passed to the handler functions +type Decoder interface { + Decode(val interface{}) error } -type GetBlockHeadersReq struct { - Origin hashOrNumber // Block from which to retrieve headers - Amount uint64 // Maximum number of headers to retrieve - Skip uint64 // Blocks to skip between consecutive headers - Reverse bool // Query direction (false = rising towards latest, true = falling towards genesis) +// RequestType is a static struct that describes an LES request type and references +// its handler function. +type RequestType struct { + Name string + MaxCount uint64 + InPacketsMeter, InTrafficMeter, OutPacketsMeter, OutTrafficMeter metrics.Meter + ServingTimeMeter metrics.Timer + Handle func(msg Decoder) (serve serveRequestFn, reqID, amount uint64, err error) } -func (r GetBlockHeadersReq) ReqAmount() (uint64, uint64) { return r.Amount, MaxHeaderFetch } - -func (r GetBlockHeadersReq) InMetrics(size int64) { - miscInHeaderPacketsMeter.Mark(1) - miscInHeaderTrafficMeter.Mark(size) -} - -func (r GetBlockHeadersReq) OutMetrics(size int64, servingTime time.Duration) { - miscOutHeaderPacketsMeter.Mark(1) - miscOutHeaderTrafficMeter.Mark(size) - miscServingTimeHeaderTimer.Update(servingTime) +// serveRequestFn is returned by the request handler functions after decoding the request. +// This function does the actual request serving using the supplied backend. waitOrStop is +// called between serving individual request items and may block if the serving process +// needs to be throttled. If it returns false then the process is terminated. +// The reply is not sent by this function yet. The flow control feedback value is supplied +// by the protocol handler when calling the send function of the returned reply struct. +type serveRequestFn func(backend serverBackend, peer *clientPeer, waitOrStop func() bool) *reply + +// Les3 contains the request types supported by les/2 and les/3 +var Les3 = map[uint64]RequestType{ + GetBlockHeadersMsg: RequestType{ + Name: "block header request", + MaxCount: MaxHeaderFetch, + InPacketsMeter: miscInHeaderPacketsMeter, + InTrafficMeter: miscInHeaderTrafficMeter, + OutPacketsMeter: miscOutHeaderPacketsMeter, + OutTrafficMeter: miscOutHeaderTrafficMeter, + ServingTimeMeter: miscServingTimeHeaderTimer, + Handle: handleGetBlockHeaders, + }, + GetBlockBodiesMsg: RequestType{ + Name: "block bodies request", + MaxCount: MaxBodyFetch, + InPacketsMeter: miscInBodyPacketsMeter, + InTrafficMeter: miscInBodyTrafficMeter, + OutPacketsMeter: miscOutBodyPacketsMeter, + OutTrafficMeter: miscOutBodyTrafficMeter, + ServingTimeMeter: miscServingTimeBodyTimer, + Handle: handleGetBlockBodies, + }, + GetCodeMsg: RequestType{ + Name: "code request", + MaxCount: MaxCodeFetch, + InPacketsMeter: miscInCodePacketsMeter, + InTrafficMeter: miscInCodeTrafficMeter, + OutPacketsMeter: miscOutCodePacketsMeter, + OutTrafficMeter: miscOutCodeTrafficMeter, + ServingTimeMeter: miscServingTimeCodeTimer, + Handle: handleGetCode, + }, + GetReceiptsMsg: RequestType{ + Name: "receipts request", + MaxCount: MaxReceiptFetch, + InPacketsMeter: miscInReceiptPacketsMeter, + InTrafficMeter: miscInReceiptTrafficMeter, + OutPacketsMeter: miscOutReceiptPacketsMeter, + OutTrafficMeter: miscOutReceiptTrafficMeter, + ServingTimeMeter: miscServingTimeReceiptTimer, + Handle: handleGetReceipts, + }, + GetProofsV2Msg: RequestType{ + Name: "les/2 proofs request", + MaxCount: MaxProofsFetch, + InPacketsMeter: miscInTrieProofPacketsMeter, + InTrafficMeter: miscInTrieProofTrafficMeter, + OutPacketsMeter: miscOutTrieProofPacketsMeter, + OutTrafficMeter: miscOutTrieProofTrafficMeter, + ServingTimeMeter: miscServingTimeTrieProofTimer, + Handle: handleGetProofs, + }, + GetHelperTrieProofsMsg: RequestType{ + Name: "helper trie proof request", + MaxCount: MaxHelperTrieProofsFetch, + InPacketsMeter: miscInHelperTriePacketsMeter, + InTrafficMeter: miscInHelperTrieTrafficMeter, + OutPacketsMeter: miscOutHelperTriePacketsMeter, + OutTrafficMeter: miscOutHelperTrieTrafficMeter, + ServingTimeMeter: miscServingTimeHelperTrieTimer, + Handle: handleGetHelperTrieProofs, + }, + SendTxV2Msg: RequestType{ + Name: "new transactions", + MaxCount: MaxTxSend, + InPacketsMeter: miscInTxsPacketsMeter, + InTrafficMeter: miscInTxsTrafficMeter, + OutPacketsMeter: miscOutTxsPacketsMeter, + OutTrafficMeter: miscOutTxsTrafficMeter, + ServingTimeMeter: miscServingTimeTxTimer, + Handle: handleSendTx, + }, + GetTxStatusMsg: RequestType{ + Name: "transaction status query request", + MaxCount: MaxTxStatus, + InPacketsMeter: miscInTxStatusPacketsMeter, + InTrafficMeter: miscInTxStatusTrafficMeter, + OutPacketsMeter: miscOutTxStatusPacketsMeter, + OutTrafficMeter: miscOutTxStatusTrafficMeter, + ServingTimeMeter: miscServingTimeTxStatusTimer, + Handle: handleGetTxStatus, + }, } -func (r GetBlockHeadersReq) Serve(backend serverBackend, reqID uint64, p *clientPeer, waitOrStop func() bool) *reply { - // Gather headers until the fetch or network limits is reached - var ( - bc = backend.BlockChain() - hashMode = r.Origin.Hash != (common.Hash{}) - first = true - maxNonCanonical = uint64(100) - bytes common.StorageSize - headers []*types.Header - unknown bool - ) - for !unknown && len(headers) < int(r.Amount) && bytes < softResponseLimit { - if !first && !waitOrStop() { - return nil - } - // Retrieve the next header satisfying the r - var origin *types.Header - if hashMode { - if first { - origin = bc.GetHeaderByHash(r.Origin.Hash) - if origin != nil { - r.Origin.Number = origin.Number.Uint64() +// handleGetBlockHeaders handles a block header request +func handleGetBlockHeaders(msg Decoder) (serveRequestFn, uint64, uint64, error) { + var r GetBlockHeadersPacket + if err := msg.Decode(&r); err != nil { + return nil, 0, 0, err + } + return func(backend serverBackend, p *clientPeer, waitOrStop func() bool) *reply { + // Gather headers until the fetch or network limits is reached + var ( + bc = backend.BlockChain() + hashMode = r.Query.Origin.Hash != (common.Hash{}) + first = true + maxNonCanonical = uint64(100) + bytes common.StorageSize + headers []*types.Header + unknown bool + ) + for !unknown && len(headers) < int(r.Query.Amount) && bytes < softResponseLimit { + if !first && !waitOrStop() { + return nil + } + // Retrieve the next header satisfying the r + var origin *types.Header + if hashMode { + if first { + origin = bc.GetHeaderByHash(r.Query.Origin.Hash) + if origin != nil { + r.Query.Origin.Number = origin.Number.Uint64() + } + } else { + origin = bc.GetHeader(r.Query.Origin.Hash, r.Query.Origin.Number) } } else { - origin = bc.GetHeader(r.Origin.Hash, r.Origin.Number) + origin = bc.GetHeaderByNumber(r.Query.Origin.Number) } - } else { - origin = bc.GetHeaderByNumber(r.Origin.Number) - } - if origin == nil { - break - } - headers = append(headers, origin) - bytes += estHeaderRlpSize - - // Advance to the next header of the r - switch { - case hashMode && r.Reverse: - // Hash based traversal towards the genesis block - ancestor := r.Skip + 1 - if ancestor == 0 { - unknown = true - } else { - r.Origin.Hash, r.Origin.Number = bc.GetAncestor(r.Origin.Hash, r.Origin.Number, ancestor, &maxNonCanonical) - unknown = r.Origin.Hash == common.Hash{} - } - case hashMode && !r.Reverse: - // Hash based traversal towards the leaf block - var ( - current = origin.Number.Uint64() - next = current + r.Skip + 1 - ) - if next <= current { - infos, _ := json.MarshalIndent(p.Peer.Info(), "", " ") - p.Log().Warn("GetBlockHeaders skip overflow attack", "current", current, "skip", r.Skip, "next", next, "attacker", infos) - unknown = true - } else { - if header := bc.GetHeaderByNumber(next); header != nil { - nextHash := header.Hash() - expOldHash, _ := bc.GetAncestor(nextHash, next, r.Skip+1, &maxNonCanonical) - if expOldHash == r.Origin.Hash { - r.Origin.Hash, r.Origin.Number = nextHash, next + if origin == nil { + break + } + headers = append(headers, origin) + bytes += estHeaderRlpSize + + // Advance to the next header of the r + switch { + case hashMode && r.Query.Reverse: + // Hash based traversal towards the genesis block + ancestor := r.Query.Skip + 1 + if ancestor == 0 { + unknown = true + } else { + r.Query.Origin.Hash, r.Query.Origin.Number = bc.GetAncestor(r.Query.Origin.Hash, r.Query.Origin.Number, ancestor, &maxNonCanonical) + unknown = r.Query.Origin.Hash == common.Hash{} + } + case hashMode && !r.Query.Reverse: + // Hash based traversal towards the leaf block + var ( + current = origin.Number.Uint64() + next = current + r.Query.Skip + 1 + ) + if next <= current { + infos, _ := json.MarshalIndent(p.Peer.Info(), "", " ") + p.Log().Warn("GetBlockHeaders skip overflow attack", "current", current, "skip", r.Query.Skip, "next", next, "attacker", infos) + unknown = true + } else { + if header := bc.GetHeaderByNumber(next); header != nil { + nextHash := header.Hash() + expOldHash, _ := bc.GetAncestor(nextHash, next, r.Query.Skip+1, &maxNonCanonical) + if expOldHash == r.Query.Origin.Hash { + r.Query.Origin.Hash, r.Query.Origin.Number = nextHash, next + } else { + unknown = true + } } else { unknown = true } + } + case r.Query.Reverse: + // Number based traversal towards the genesis block + if r.Query.Origin.Number >= r.Query.Skip+1 { + r.Query.Origin.Number -= r.Query.Skip + 1 } else { unknown = true } - } - case r.Reverse: - // Number based traversal towards the genesis block - if r.Origin.Number >= r.Skip+1 { - r.Origin.Number -= r.Skip + 1 - } else { - unknown = true - } - case !r.Reverse: - // Number based traversal towards the leaf block - r.Origin.Number += r.Skip + 1 + case !r.Query.Reverse: + // Number based traversal towards the leaf block + r.Query.Origin.Number += r.Query.Skip + 1 + } + first = false } - first = false - } - return p.replyBlockHeaders(reqID, headers) -} - -type GetBlockBodiesReq []common.Hash - -func (r GetBlockBodiesReq) ReqAmount() (uint64, uint64) { return uint64(len(r)), MaxBodyFetch } - -func (r GetBlockBodiesReq) InMetrics(size int64) { - miscInBodyPacketsMeter.Mark(1) - miscInBodyTrafficMeter.Mark(size) + return p.replyBlockHeaders(r.ReqID, headers) + }, r.ReqID, r.Query.Amount, nil } -func (r GetBlockBodiesReq) OutMetrics(size int64, servingTime time.Duration) { - miscOutBodyPacketsMeter.Mark(1) - miscOutBodyTrafficMeter.Mark(size) - miscServingTimeBodyTimer.Update(servingTime) -} - -func (r GetBlockBodiesReq) Serve(backend serverBackend, reqID uint64, p *clientPeer, waitOrStop func() bool) *reply { - var ( - bytes int - bodies []rlp.RawValue - ) - bc := backend.BlockChain() - for i, hash := range r { - if i != 0 && !waitOrStop() { - return nil - } - if bytes >= softResponseLimit { - break - } - body := bc.GetBodyRLP(hash) - if body == nil { - p.bumpInvalid() - continue - } - bodies = append(bodies, body) - bytes += len(body) +// handleGetBlockBodies handles a block body request +func handleGetBlockBodies(msg Decoder) (serveRequestFn, uint64, uint64, error) { + var r GetBlockBodiesPacket + if err := msg.Decode(&r); err != nil { + return nil, 0, 0, err } - return p.replyBlockBodiesRLP(reqID, bodies) -} - -type GetCodeReq []CodeReq - -func (r GetCodeReq) ReqAmount() (uint64, uint64) { return uint64(len(r)), MaxCodeFetch } - -func (r GetCodeReq) InMetrics(size int64) { - miscInCodePacketsMeter.Mark(1) - miscInCodeTrafficMeter.Mark(size) -} - -func (r GetCodeReq) OutMetrics(size int64, servingTime time.Duration) { - miscOutCodePacketsMeter.Mark(1) - miscOutCodeTrafficMeter.Mark(size) - miscServingTimeCodeTimer.Update(servingTime) -} - -func (r GetCodeReq) Serve(backend serverBackend, reqID uint64, p *clientPeer, waitOrStop func() bool) *reply { - var ( - bytes int - data [][]byte - ) - bc := backend.BlockChain() - for i, request := range r { - if i != 0 && !waitOrStop() { - return nil - } - // Look up the root hash belonging to the request - header := bc.GetHeaderByHash(request.BHash) - if header == nil { - p.Log().Warn("Failed to retrieve associate header for code", "hash", request.BHash) - p.bumpInvalid() - continue - } - // Refuse to search stale state data in the database since looking for - // a non-exist key is kind of expensive. - local := bc.CurrentHeader().Number.Uint64() - if !backend.ArchiveMode() && header.Number.Uint64()+core.TriesInMemory <= local { - p.Log().Debug("Reject stale code request", "number", header.Number.Uint64(), "head", local) - p.bumpInvalid() - continue - } - triedb := bc.StateCache().TrieDB() - - account, err := getAccount(triedb, header.Root, common.BytesToHash(request.AccKey)) - if err != nil { - p.Log().Warn("Failed to retrieve account for code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err) - p.bumpInvalid() - continue - } - code, err := bc.StateCache().ContractCode(common.BytesToHash(request.AccKey), common.BytesToHash(account.CodeHash)) - if err != nil { - p.Log().Warn("Failed to retrieve account code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "codehash", common.BytesToHash(account.CodeHash), "err", err) - continue - } - // Accumulate the code and abort if enough data was retrieved - data = append(data, code) - if bytes += len(code); bytes >= softResponseLimit { - break - } - } - return p.replyCode(reqID, data) -} - -type GetReceiptsReq []common.Hash - -func (r GetReceiptsReq) ReqAmount() (uint64, uint64) { return uint64(len(r)), MaxReceiptFetch } - -func (r GetReceiptsReq) InMetrics(size int64) { - miscInReceiptPacketsMeter.Mark(1) - miscInReceiptTrafficMeter.Mark(size) -} - -func (r GetReceiptsReq) OutMetrics(size int64, servingTime time.Duration) { - miscOutReceiptPacketsMeter.Mark(1) - miscOutReceiptTrafficMeter.Mark(size) - miscServingTimeReceiptTimer.Update(servingTime) -} - -func (r GetReceiptsReq) Serve(backend serverBackend, reqID uint64, p *clientPeer, waitOrStop func() bool) *reply { - var ( - bytes int - receipts []rlp.RawValue - ) - bc := backend.BlockChain() - for i, hash := range r { - if i != 0 && !waitOrStop() { - return nil - } - if bytes >= softResponseLimit { - break - } - // Retrieve the requested block's receipts, skipping if unknown to us - results := bc.GetReceiptsByHash(hash) - if results == nil { - if header := bc.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash { + return func(backend serverBackend, p *clientPeer, waitOrStop func() bool) *reply { + var ( + bytes int + bodies []rlp.RawValue + ) + bc := backend.BlockChain() + for i, hash := range r.Hashes { + if i != 0 && !waitOrStop() { + return nil + } + if bytes >= softResponseLimit { + break + } + body := bc.GetBodyRLP(hash) + if body == nil { p.bumpInvalid() continue } + bodies = append(bodies, body) + bytes += len(body) } - // If known, encode and queue for response packet - if encoded, err := rlp.EncodeToBytes(results); err != nil { - log.Error("Failed to encode receipt", "err", err) - } else { - receipts = append(receipts, encoded) - bytes += len(encoded) - } - } - return p.replyReceiptsRLP(reqID, receipts) -} - -type GetProofsReq []ProofReq - -func (r GetProofsReq) ReqAmount() (uint64, uint64) { return uint64(len(r)), MaxProofsFetch } - -func (r GetProofsReq) InMetrics(size int64) { - miscInTrieProofPacketsMeter.Mark(1) - miscInTrieProofTrafficMeter.Mark(size) + return p.replyBlockBodiesRLP(r.ReqID, bodies) + }, r.ReqID, uint64(len(r.Hashes)), nil } -func (r GetProofsReq) OutMetrics(size int64, servingTime time.Duration) { - miscOutTrieProofPacketsMeter.Mark(1) - miscOutTrieProofTrafficMeter.Mark(size) - miscServingTimeTrieProofTimer.Update(servingTime) -} - -func (r GetProofsReq) Serve(backend serverBackend, reqID uint64, p *clientPeer, waitOrStop func() bool) *reply { - var ( - lastBHash common.Hash - root common.Hash - header *types.Header - err error - ) - bc := backend.BlockChain() - nodes := light.NewNodeSet() - - for i, request := range r { - if i != 0 && !waitOrStop() { - return nil - } - // Look up the root hash belonging to the request - if request.BHash != lastBHash { - root, lastBHash = common.Hash{}, request.BHash - - if header = bc.GetHeaderByHash(request.BHash); header == nil { - p.Log().Warn("Failed to retrieve header for proof", "hash", request.BHash) +// handleGetCode handles a contract code request +func handleGetCode(msg Decoder) (serveRequestFn, uint64, uint64, error) { + var r GetCodePacket + if err := msg.Decode(&r); err != nil { + return nil, 0, 0, err + } + return func(backend serverBackend, p *clientPeer, waitOrStop func() bool) *reply { + var ( + bytes int + data [][]byte + ) + bc := backend.BlockChain() + for i, request := range r.Reqs { + if i != 0 && !waitOrStop() { + return nil + } + // Look up the root hash belonging to the request + header := bc.GetHeaderByHash(request.BHash) + if header == nil { + p.Log().Warn("Failed to retrieve associate header for code", "hash", request.BHash) p.bumpInvalid() continue } @@ -342,182 +298,257 @@ func (r GetProofsReq) Serve(backend serverBackend, reqID uint64, p *clientPeer, // a non-exist key is kind of expensive. local := bc.CurrentHeader().Number.Uint64() if !backend.ArchiveMode() && header.Number.Uint64()+core.TriesInMemory <= local { - p.Log().Debug("Reject stale trie request", "number", header.Number.Uint64(), "head", local) + p.Log().Debug("Reject stale code request", "number", header.Number.Uint64(), "head", local) p.bumpInvalid() continue } - root = header.Root - } - // If a header lookup failed (non existent), ignore subsequent requests for the same header - if root == (common.Hash{}) { - p.bumpInvalid() - continue - } - // Open the account or storage trie for the request - statedb := bc.StateCache() + triedb := bc.StateCache().TrieDB() - var trie state.Trie - switch len(request.AccKey) { - case 0: - // No account key specified, open an account trie - trie, err = statedb.OpenTrie(root) - if trie == nil || err != nil { - p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "root", root, "err", err) - continue - } - default: - // Account key specified, open a storage trie - account, err := getAccount(statedb.TrieDB(), root, common.BytesToHash(request.AccKey)) + account, err := getAccount(triedb, header.Root, common.BytesToHash(request.AccKey)) if err != nil { - p.Log().Warn("Failed to retrieve account for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err) + p.Log().Warn("Failed to retrieve account for code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err) p.bumpInvalid() continue } - trie, err = statedb.OpenStorageTrie(common.BytesToHash(request.AccKey), account.Root) - if trie == nil || err != nil { - p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "root", account.Root, "err", err) + code, err := bc.StateCache().ContractCode(common.BytesToHash(request.AccKey), common.BytesToHash(account.CodeHash)) + if err != nil { + p.Log().Warn("Failed to retrieve account code", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "codehash", common.BytesToHash(account.CodeHash), "err", err) continue } + // Accumulate the code and abort if enough data was retrieved + data = append(data, code) + if bytes += len(code); bytes >= softResponseLimit { + break + } } - // Prove the user's request from the account or stroage trie - if err := trie.Prove(request.Key, request.FromLevel, nodes); err != nil { - p.Log().Warn("Failed to prove state request", "block", header.Number, "hash", header.Hash(), "err", err) - continue - } - if nodes.DataSize() >= softResponseLimit { - break - } - } - return p.replyProofsV2(reqID, nodes.NodeList()) -} - -type GetHelperTrieProofsReq []HelperTrieReq - -func (r GetHelperTrieProofsReq) ReqAmount() (uint64, uint64) { - return uint64(len(r)), MaxHelperTrieProofsFetch -} - -func (r GetHelperTrieProofsReq) InMetrics(size int64) { - miscInHelperTriePacketsMeter.Mark(1) - miscInHelperTrieTrafficMeter.Mark(size) -} - -func (r GetHelperTrieProofsReq) OutMetrics(size int64, servingTime time.Duration) { - miscOutHelperTriePacketsMeter.Mark(1) - miscOutHelperTrieTrafficMeter.Mark(size) - miscServingTimeHelperTrieTimer.Update(servingTime) + return p.replyCode(r.ReqID, data) + }, r.ReqID, uint64(len(r.Reqs)), nil } -func (r GetHelperTrieProofsReq) Serve(backend serverBackend, reqID uint64, p *clientPeer, waitOrStop func() bool) *reply { - var ( - lastIdx uint64 - lastType uint - auxTrie *trie.Trie - auxBytes int - auxData [][]byte - ) - bc := backend.BlockChain() - nodes := light.NewNodeSet() - for i, request := range r { - if i != 0 && !waitOrStop() { - return nil - } - if auxTrie == nil || request.Type != lastType || request.TrieIdx != lastIdx { - lastType, lastIdx = request.Type, request.TrieIdx - auxTrie = backend.GetHelperTrie(request.Type, request.TrieIdx) - } - if auxTrie == nil { - return nil - } - // TODO(rjl493456442) short circuit if the proving is failed. - // The original client side code has a dirty hack to retrieve - // the headers with no valid proof. Keep the compatibility for - // legacy les protocol and drop this hack when the les2/3 are - // not supported. - err := auxTrie.Prove(request.Key, request.FromLevel, nodes) - if p.version >= lpv4 && err != nil { - return nil - } - if request.Type == htCanonical && request.AuxReq == htAuxHeader && len(request.Key) == 8 { - header := bc.GetHeaderByNumber(binary.BigEndian.Uint64(request.Key)) - data, err := rlp.EncodeToBytes(header) - if err != nil { - log.Error("Failed to encode header", "err", err) +// handleGetReceipts handles a block receipts request +func handleGetReceipts(msg Decoder) (serveRequestFn, uint64, uint64, error) { + var r GetReceiptsPacket + if err := msg.Decode(&r); err != nil { + return nil, 0, 0, err + } + return func(backend serverBackend, p *clientPeer, waitOrStop func() bool) *reply { + var ( + bytes int + receipts []rlp.RawValue + ) + bc := backend.BlockChain() + for i, hash := range r.Hashes { + if i != 0 && !waitOrStop() { return nil } - auxData = append(auxData, data) - auxBytes += len(data) - } - if nodes.DataSize()+auxBytes >= softResponseLimit { - break + if bytes >= softResponseLimit { + break + } + // Retrieve the requested block's receipts, skipping if unknown to us + results := bc.GetReceiptsByHash(hash) + if results == nil { + if header := bc.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash { + p.bumpInvalid() + continue + } + } + // If known, encode and queue for response packet + if encoded, err := rlp.EncodeToBytes(results); err != nil { + log.Error("Failed to encode receipt", "err", err) + } else { + receipts = append(receipts, encoded) + bytes += len(encoded) + } } - } - return p.replyHelperTrieProofs(reqID, HelperTrieResps{Proofs: nodes.NodeList(), AuxData: auxData}) + return p.replyReceiptsRLP(r.ReqID, receipts) + }, r.ReqID, uint64(len(r.Hashes)), nil } -type SendTxReq []*types.Transaction - -func (r SendTxReq) ReqAmount() (uint64, uint64) { return uint64(len(r)), MaxTxSend } - -func (r SendTxReq) InMetrics(size int64) { - miscInTxsPacketsMeter.Mark(1) - miscInTxsTrafficMeter.Mark(size) -} - -func (r SendTxReq) OutMetrics(size int64, servingTime time.Duration) { - miscOutTxsPacketsMeter.Mark(1) - miscOutTxsTrafficMeter.Mark(size) - miscServingTimeTxTimer.Update(servingTime) -} - -func (r SendTxReq) Serve(backend serverBackend, reqID uint64, p *clientPeer, waitOrStop func() bool) *reply { - stats := make([]light.TxStatus, len(r)) - for i, tx := range r { - if i != 0 && !waitOrStop() { - return nil - } - hash := tx.Hash() - stats[i] = txStatus(backend, hash) - if stats[i].Status == core.TxStatusUnknown { - addFn := backend.TxPool().AddRemotes - // Add txs synchronously for testing purpose - if backend.AddTxsSync() { - addFn = backend.TxPool().AddRemotesSync - } - if errs := addFn([]*types.Transaction{tx}); errs[0] != nil { - stats[i].Error = errs[0].Error() +// handleGetProofs handles a proof request +func handleGetProofs(msg Decoder) (serveRequestFn, uint64, uint64, error) { + var r GetProofsPacket + if err := msg.Decode(&r); err != nil { + return nil, 0, 0, err + } + return func(backend serverBackend, p *clientPeer, waitOrStop func() bool) *reply { + var ( + lastBHash common.Hash + root common.Hash + header *types.Header + err error + ) + bc := backend.BlockChain() + nodes := light.NewNodeSet() + + for i, request := range r.Reqs { + if i != 0 && !waitOrStop() { + return nil + } + // Look up the root hash belonging to the request + if request.BHash != lastBHash { + root, lastBHash = common.Hash{}, request.BHash + + if header = bc.GetHeaderByHash(request.BHash); header == nil { + p.Log().Warn("Failed to retrieve header for proof", "hash", request.BHash) + p.bumpInvalid() + continue + } + // Refuse to search stale state data in the database since looking for + // a non-exist key is kind of expensive. + local := bc.CurrentHeader().Number.Uint64() + if !backend.ArchiveMode() && header.Number.Uint64()+core.TriesInMemory <= local { + p.Log().Debug("Reject stale trie request", "number", header.Number.Uint64(), "head", local) + p.bumpInvalid() + continue + } + root = header.Root + } + // If a header lookup failed (non existent), ignore subsequent requests for the same header + if root == (common.Hash{}) { + p.bumpInvalid() continue } - stats[i] = txStatus(backend, hash) + // Open the account or storage trie for the request + statedb := bc.StateCache() + + var trie state.Trie + switch len(request.AccKey) { + case 0: + // No account key specified, open an account trie + trie, err = statedb.OpenTrie(root) + if trie == nil || err != nil { + p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "root", root, "err", err) + continue + } + default: + // Account key specified, open a storage trie + account, err := getAccount(statedb.TrieDB(), root, common.BytesToHash(request.AccKey)) + if err != nil { + p.Log().Warn("Failed to retrieve account for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "err", err) + p.bumpInvalid() + continue + } + trie, err = statedb.OpenStorageTrie(common.BytesToHash(request.AccKey), account.Root) + if trie == nil || err != nil { + p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "root", account.Root, "err", err) + continue + } + } + // Prove the user's request from the account or stroage trie + if err := trie.Prove(request.Key, request.FromLevel, nodes); err != nil { + p.Log().Warn("Failed to prove state request", "block", header.Number, "hash", header.Hash(), "err", err) + continue + } + if nodes.DataSize() >= softResponseLimit { + break + } } - } - return p.replyTxStatus(reqID, stats) + return p.replyProofsV2(r.ReqID, nodes.NodeList()) + }, r.ReqID, uint64(len(r.Reqs)), nil } -type GetTxStatusReq []common.Hash - -func (r GetTxStatusReq) ReqAmount() (uint64, uint64) { return uint64(len(r)), MaxTxStatus } - -func (r GetTxStatusReq) InMetrics(size int64) { - miscInTxStatusPacketsMeter.Mark(1) - miscInTxStatusTrafficMeter.Mark(size) +// handleGetHelperTrieProofs handles a helper trie proof request +func handleGetHelperTrieProofs(msg Decoder) (serveRequestFn, uint64, uint64, error) { + var r GetHelperTrieProofsPacket + if err := msg.Decode(&r); err != nil { + return nil, 0, 0, err + } + return func(backend serverBackend, p *clientPeer, waitOrStop func() bool) *reply { + var ( + lastIdx uint64 + lastType uint + auxTrie *trie.Trie + auxBytes int + auxData [][]byte + ) + bc := backend.BlockChain() + nodes := light.NewNodeSet() + for i, request := range r.Reqs { + if i != 0 && !waitOrStop() { + return nil + } + if auxTrie == nil || request.Type != lastType || request.TrieIdx != lastIdx { + lastType, lastIdx = request.Type, request.TrieIdx + auxTrie = backend.GetHelperTrie(request.Type, request.TrieIdx) + } + if auxTrie == nil { + return nil + } + // TODO(rjl493456442) short circuit if the proving is failed. + // The original client side code has a dirty hack to retrieve + // the headers with no valid proof. Keep the compatibility for + // legacy les protocol and drop this hack when the les2/3 are + // not supported. + err := auxTrie.Prove(request.Key, request.FromLevel, nodes) + if p.version >= lpv4 && err != nil { + return nil + } + if request.Type == htCanonical && request.AuxReq == htAuxHeader && len(request.Key) == 8 { + header := bc.GetHeaderByNumber(binary.BigEndian.Uint64(request.Key)) + data, err := rlp.EncodeToBytes(header) + if err != nil { + log.Error("Failed to encode header", "err", err) + return nil + } + auxData = append(auxData, data) + auxBytes += len(data) + } + if nodes.DataSize()+auxBytes >= softResponseLimit { + break + } + } + return p.replyHelperTrieProofs(r.ReqID, HelperTrieResps{Proofs: nodes.NodeList(), AuxData: auxData}) + }, r.ReqID, uint64(len(r.Reqs)), nil } -func (r GetTxStatusReq) OutMetrics(size int64, servingTime time.Duration) { - miscOutTxStatusPacketsMeter.Mark(1) - miscOutTxStatusTrafficMeter.Mark(size) - miscServingTimeTxStatusTimer.Update(servingTime) +// handleSendTx handles a transaction propagation request +func handleSendTx(msg Decoder) (serveRequestFn, uint64, uint64, error) { + var r SendTxPacket + if err := msg.Decode(&r); err != nil { + return nil, 0, 0, err + } + amount := uint64(len(r.Txs)) + return func(backend serverBackend, p *clientPeer, waitOrStop func() bool) *reply { + stats := make([]light.TxStatus, len(r.Txs)) + for i, tx := range r.Txs { + if i != 0 && !waitOrStop() { + return nil + } + hash := tx.Hash() + stats[i] = txStatus(backend, hash) + if stats[i].Status == core.TxStatusUnknown { + addFn := backend.TxPool().AddRemotes + // Add txs synchronously for testing purpose + if backend.AddTxsSync() { + addFn = backend.TxPool().AddRemotesSync + } + if errs := addFn([]*types.Transaction{tx}); errs[0] != nil { + stats[i].Error = errs[0].Error() + continue + } + stats[i] = txStatus(backend, hash) + } + } + return p.replyTxStatus(r.ReqID, stats) + }, r.ReqID, amount, nil } -func (r GetTxStatusReq) Serve(backend serverBackend, reqID uint64, p *clientPeer, waitOrStop func() bool) *reply { - stats := make([]light.TxStatus, len(r)) - for i, hash := range r { - if i != 0 && !waitOrStop() { - return nil - } - stats[i] = txStatus(backend, hash) +// handleGetTxStatus handles a transaction status query +func handleGetTxStatus(msg Decoder) (serveRequestFn, uint64, uint64, error) { + var r GetTxStatusPacket + if err := msg.Decode(&r); err != nil { + return nil, 0, 0, err } - return p.replyTxStatus(reqID, stats) + return func(backend serverBackend, p *clientPeer, waitOrStop func() bool) *reply { + stats := make([]light.TxStatus, len(r.Hashes)) + for i, hash := range r.Hashes { + if i != 0 && !waitOrStop() { + return nil + } + stats[i] = txStatus(backend, hash) + } + return p.replyTxStatus(r.ReqID, stats) + }, r.ReqID, uint64(len(r.Hashes)), nil } // txStatus returns the status of a specified transaction. diff --git a/tests/fuzzers/les/les-fuzzer.go b/tests/fuzzers/les/les-fuzzer.go index 0f07de7bc8dc..9e896c2c1b9e 100644 --- a/tests/fuzzers/les/les-fuzzer.go +++ b/tests/fuzzers/les/les-fuzzer.go @@ -31,6 +31,7 @@ import ( "github.com/ethereum/go-ethereum/crypto" l "github.com/ethereum/go-ethereum/les" "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" ) @@ -251,10 +252,27 @@ func (f *fuzzer) GetHelperTrie(typ uint, index uint64) *trie.Trie { return nil } -func (f *fuzzer) doFuzz(req l.HandlerRequest) { +type dummyMsg struct { + data []byte +} + +func (d dummyMsg) Decode(val interface{}) error { + return rlp.DecodeBytes(d.data, val) +} + +func (f *fuzzer) doFuzz(msgCode uint64, packet interface{}) { version := f.randomInt(3) + 2 // [LES2, LES3, LES4] peer := l.NewFuzzerPeer(version) - req.Serve(f, 42, peer, func() bool { return true }) + enc, err := rlp.EncodeToBytes(packet) + if err != nil { + panic(err) + } + fn, _, _, err := l.Les3[msgCode].Handle(dummyMsg{enc}) + if err != nil { + panic(err) + } + fn(f, peer, func() bool { return true }) + } func Fuzz(input []byte) int { @@ -269,69 +287,71 @@ func Fuzz(input []byte) int { for !f.exhausted { switch f.randomInt(8) { case 0: - req := l.GetBlockHeadersReq{ - Amount: f.randomX(l.MaxHeaderFetch + 1), - Skip: f.randomX(10), - Reverse: f.randomBool(), + req := &l.GetBlockHeadersPacket{ + Query: l.GetBlockHeadersData{ + Amount: f.randomX(l.MaxHeaderFetch + 1), + Skip: f.randomX(10), + Reverse: f.randomBool(), + }, } if f.randomBool() { - req.Origin.Hash = f.randomBlockHash() + req.Query.Origin.Hash = f.randomBlockHash() } else { - req.Origin.Number = uint64(f.randomInt(f.chainLen * 2)) + req.Query.Origin.Number = uint64(f.randomInt(f.chainLen * 2)) } - f.doFuzz(req) + f.doFuzz(l.GetBlockHeadersMsg, req) case 1: - req := make(l.GetBlockBodiesReq, f.randomInt(l.MaxBodyFetch+1)) - for i := range req { - req[i] = f.randomBlockHash() + req := &l.GetBlockBodiesPacket{Hashes: make([]common.Hash, f.randomInt(l.MaxBodyFetch+1))} + for i := range req.Hashes { + req.Hashes[i] = f.randomBlockHash() } - f.doFuzz(req) + f.doFuzz(l.GetBlockBodiesMsg, req) case 2: - req := make(l.GetCodeReq, f.randomInt(l.MaxCodeFetch+1)) - for i := range req { - req[i] = l.CodeReq{ + req := &l.GetCodePacket{Reqs: make([]l.CodeReq, f.randomInt(l.MaxCodeFetch+1))} + for i := range req.Reqs { + req.Reqs[i] = l.CodeReq{ BHash: f.randomBlockHash(), AccKey: f.randomAddrHash(), } } - f.doFuzz(req) + f.doFuzz(l.GetCodeMsg, req) case 3: - req := make(l.GetReceiptsReq, f.randomInt(l.MaxReceiptFetch+1)) - for i := range req { - req[i] = f.randomBlockHash() + req := &l.GetReceiptsPacket{Hashes: make([]common.Hash, f.randomInt(l.MaxReceiptFetch+1))} + for i := range req.Hashes { + req.Hashes[i] = f.randomBlockHash() } - f.doFuzz(req) + f.doFuzz(l.GetReceiptsMsg, req) case 4: - req := make(l.GetProofsReq, f.randomInt(l.MaxProofsFetch+1)) - for i := range req { + req := &l.GetProofsPacket{Reqs: make([]l.ProofReq, f.randomInt(l.MaxProofsFetch+1))} + for i := range req.Reqs { if f.randomBool() { - req[i] = l.ProofReq{ + req.Reqs[i] = l.ProofReq{ BHash: f.randomBlockHash(), AccKey: f.randomAddrHash(), Key: f.randomAddrHash(), FromLevel: uint(f.randomX(3)), } } else { - req[i] = l.ProofReq{ + req.Reqs[i] = l.ProofReq{ BHash: f.randomBlockHash(), Key: f.randomAddrHash(), FromLevel: uint(f.randomX(3)), } } } - f.doFuzz(req) + f.doFuzz(l.GetProofsV2Msg, req) case 5: - req := make(l.GetHelperTrieProofsReq, f.randomInt(l.MaxHelperTrieProofsFetch+1)) - for i := range req { + req := &l.GetHelperTrieProofsPacket{Reqs: make([]l.HelperTrieReq, f.randomInt(l.MaxHelperTrieProofsFetch+1))} + for i := range req.Reqs { switch f.randomInt(3) { case 0: // Canonical hash trie - req[i] = l.HelperTrieReq{ + req.Reqs[i] = l.HelperTrieReq{ Type: 0, TrieIdx: f.randomX(3), Key: f.randomCHTTrieKey(), @@ -340,7 +360,7 @@ func Fuzz(input []byte) int { } case 1: // Bloom trie - req[i] = l.HelperTrieReq{ + req.Reqs[i] = l.HelperTrieReq{ Type: 1, TrieIdx: f.randomX(3), Key: f.randomBloomTrieKey(), @@ -349,7 +369,7 @@ func Fuzz(input []byte) int { } default: // Random trie - req[i] = l.HelperTrieReq{ + req.Reqs[i] = l.HelperTrieReq{ Type: 2, TrieIdx: f.randomX(3), Key: f.randomCHTTrieKey(), @@ -358,12 +378,12 @@ func Fuzz(input []byte) int { } } } - f.doFuzz(req) + f.doFuzz(l.GetHelperTrieProofsMsg, req) case 6: - req := make(l.SendTxReq, f.randomInt(l.MaxTxSend+1)) + req := &l.SendTxPacket{Txs: make([]*types.Transaction, f.randomInt(l.MaxTxSend+1))} signer := types.HomesteadSigner{} - for i := range req { + for i := range req.Txs { var nonce uint64 if f.randomBool() { nonce = uint64(f.randomByte()) @@ -371,16 +391,16 @@ func Fuzz(input []byte) int { nonce = f.nonce f.nonce += 1 } - req[i], _ = types.SignTx(types.NewTransaction(nonce, common.Address{}, big.NewInt(10000), params.TxGas, big.NewInt(1000000000*int64(f.randomByte())), nil), signer, bankKey) + req.Txs[i], _ = types.SignTx(types.NewTransaction(nonce, common.Address{}, big.NewInt(10000), params.TxGas, big.NewInt(1000000000*int64(f.randomByte())), nil), signer, bankKey) } - f.doFuzz(req) + f.doFuzz(l.SendTxV2Msg, req) case 7: - req := make(l.GetTxStatusReq, f.randomInt(l.MaxTxStatus+1)) - for i := range req { - req[i] = f.randomTxHash() + req := &l.GetTxStatusPacket{Hashes: make([]common.Hash, f.randomInt(l.MaxTxStatus+1))} + for i := range req.Hashes { + req.Hashes[i] = f.randomTxHash() } - f.doFuzz(req) + f.doFuzz(l.GetTxStatusMsg, req) } } return 0