Skip to content

Commit a8891d9

Browse files
committed
remove peer for diff cache when peer closed
1 parent 85c037d commit a8891d9

File tree

9 files changed

+223
-35
lines changed

9 files changed

+223
-35
lines changed

consensus/parlia/parlia.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -895,7 +895,7 @@ func (p *Parlia) AllowLightProcess(chain consensus.ChainReader, currentHeader *t
895895
validators := snap.validators()
896896

897897
validatorNum := int64(len(validators))
898-
// It is not allowed if the only two validators
898+
// It is not allowed if only two validators
899899
if validatorNum <= 2 {
900900
return false
901901
}

core/blockchain.go

+42-4
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,9 @@ const (
9494
diffLayerFreezerBlockLimit = 864000 // The number of diff layers that should be kept in disk.
9595
diffLayerPruneRecheckInterval = 1 * time.Second // The interval to prune unverified diff layers
9696
maxDiffQueueDist = 2048 // Maximum allowed distance from the chain head to queue diffLayers
97-
maxDiffLimit = 2048 // Maximum number of unique diff layers a peer may have delivered
97+
maxDiffLimit = 2048 // Maximum number of unique diff layers a peer may have responded
9898
maxDiffForkDist = 11 // Maximum allowed backward distance from the chain head
99+
maxDiffLimitForBroadcast = 128 // Maximum number of unique diff layers a peer may have broadcasted
99100

100101
// BlockChainVersion ensures that an incompatible database forces a resync from scratch.
101102
//
@@ -2534,6 +2535,34 @@ func (bc *BlockChain) removeDiffLayers(diffHash common.Hash) {
25342535
}
25352536
}
25362537

2538+
func (bc *BlockChain) RemoveDiffPeer(pid string) {
2539+
bc.diffMux.Lock()
2540+
defer bc.diffMux.Unlock()
2541+
if invaliDiffHashes := bc.diffPeersToDiffHashes[pid]; invaliDiffHashes != nil {
2542+
for invalidDiffHash := range invaliDiffHashes {
2543+
lastDiffHash := false
2544+
if peers, ok := bc.diffHashToPeers[invalidDiffHash]; ok {
2545+
delete(peers, pid)
2546+
if len(peers) == 0 {
2547+
lastDiffHash = true
2548+
delete(bc.diffHashToPeers, invalidDiffHash)
2549+
}
2550+
}
2551+
if lastDiffHash {
2552+
affectedBlockHash := bc.diffHashToBlockHash[invalidDiffHash]
2553+
if diffs, exist := bc.blockHashToDiffLayers[affectedBlockHash]; exist {
2554+
delete(diffs, invalidDiffHash)
2555+
if len(diffs) == 0 {
2556+
delete(bc.blockHashToDiffLayers, affectedBlockHash)
2557+
}
2558+
}
2559+
delete(bc.diffHashToBlockHash, invalidDiffHash)
2560+
}
2561+
}
2562+
delete(bc.diffPeersToDiffHashes, pid)
2563+
}
2564+
}
2565+
25372566
func (bc *BlockChain) untrustedDiffLayerPruneLoop() {
25382567
recheck := time.Tick(diffLayerPruneRecheckInterval)
25392568
bc.wg.Add(1)
@@ -2595,7 +2624,7 @@ func (bc *BlockChain) pruneDiffLayer() {
25952624
}
25962625

25972626
// Process received diff layers
2598-
func (bc *BlockChain) HandleDiffLayer(diffLayer *types.DiffLayer, pid string) error {
2627+
func (bc *BlockChain) HandleDiffLayer(diffLayer *types.DiffLayer, pid string, fulfilled bool) error {
25992628
// Basic check
26002629
currentHeight := bc.CurrentBlock().NumberU64()
26012630
if diffLayer.Number > currentHeight && diffLayer.Number-currentHeight > maxDiffQueueDist {
@@ -2610,6 +2639,13 @@ func (bc *BlockChain) HandleDiffLayer(diffLayer *types.DiffLayer, pid string) er
26102639
bc.diffMux.Lock()
26112640
defer bc.diffMux.Unlock()
26122641

2642+
if !fulfilled {
2643+
if len(bc.diffPeersToDiffHashes[pid]) > maxDiffLimitForBroadcast {
2644+
log.Error("too many accumulated diffLayers", "pid", pid)
2645+
return nil
2646+
}
2647+
}
2648+
26132649
if len(bc.diffPeersToDiffHashes[pid]) > maxDiffLimit {
26142650
log.Error("too many accumulated diffLayers", "pid", pid)
26152651
return nil
@@ -2618,12 +2654,14 @@ func (bc *BlockChain) HandleDiffLayer(diffLayer *types.DiffLayer, pid string) er
26182654
if _, alreadyHas := bc.diffPeersToDiffHashes[pid][diffLayer.DiffHash]; alreadyHas {
26192655
return nil
26202656
}
2621-
} else {
2622-
bc.diffPeersToDiffHashes[pid] = make(map[common.Hash]struct{})
26232657
}
2658+
bc.diffPeersToDiffHashes[pid] = make(map[common.Hash]struct{})
26242659
bc.diffPeersToDiffHashes[pid][diffLayer.DiffHash] = struct{}{}
26252660
if _, exist := bc.diffNumToBlockHashes[diffLayer.Number]; !exist {
26262661
bc.diffNumToBlockHashes[diffLayer.Number] = make(map[common.Hash]struct{})
2662+
}
2663+
if len(bc.diffNumToBlockHashes[diffLayer.Number]) > 4 {
2664+
26272665
}
26282666
bc.diffNumToBlockHashes[diffLayer.Number][diffLayer.BlockHash] = struct{}{}
26292667

core/blockchain_diff_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ func TestProcessDiffLayer(t *testing.T) {
143143
if err != nil {
144144
t.Errorf("failed to decode rawdata %v", err)
145145
}
146-
lightBackend.Chain().HandleDiffLayer(diff, "testpid")
146+
lightBackend.Chain().HandleDiffLayer(diff, "testpid", true)
147147
_, err = lightBackend.chain.insertChain([]*types.Block{block}, true)
148148
if err != nil {
149149
t.Errorf("failed to insert block %v", err)
@@ -158,7 +158,7 @@ func TestProcessDiffLayer(t *testing.T) {
158158
bz, _ := rlp.EncodeToBytes(&latestAccount)
159159
diff.Accounts[0].Blob = bz
160160

161-
lightBackend.Chain().HandleDiffLayer(diff, "testpid")
161+
lightBackend.Chain().HandleDiffLayer(diff, "testpid", true)
162162

163163
_, err := lightBackend.chain.insertChain([]*types.Block{nextBlock}, true)
164164
if err != nil {
@@ -216,8 +216,8 @@ func TestPruneDiffLayer(t *testing.T) {
216216
header := fullBackend.chain.GetHeaderByNumber(num)
217217
rawDiff := fullBackend.chain.GetDiffLayerRLP(header.Hash())
218218
diff, _ := rawDataToDiffLayer(rawDiff)
219-
fullBackend.Chain().HandleDiffLayer(diff, "testpid1")
220-
fullBackend.Chain().HandleDiffLayer(diff, "testpid2")
219+
fullBackend.Chain().HandleDiffLayer(diff, "testpid1", true)
220+
fullBackend.Chain().HandleDiffLayer(diff, "testpid2", true)
221221

222222
}
223223
fullBackend.chain.pruneDiffLayer()

eth/downloader/downloader.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,10 @@ type Downloader struct {
161161
quitLock sync.Mutex // Lock to prevent double closes
162162

163163
// Testing hooks
164-
syncInitHook func(uint64, uint64) // Method to call upon initiating a new sync run
164+
syncInitHook func(uint64, uint64) // Method to call upon initiating a new sync run
165165
bodyFetchHook func([]*types.Header, ...interface{}) // Method to call upon starting a block body fetch
166166
receiptFetchHook func([]*types.Header, ...interface{}) // Method to call upon starting a receipt fetch
167-
chainInsertHook func([]*fetchResult) // Method to call upon inserting a chain of blocks (possibly in multiple invocations)
167+
chainInsertHook func([]*fetchResult) // Method to call upon inserting a chain of blocks (possibly in multiple invocations)
168168
}
169169

170170
// LightChain encapsulates functions required to synchronise a light chain.
@@ -230,7 +230,7 @@ type IPeerSet interface {
230230
GetDiffPeer(string) IDiffPeer
231231
}
232232

233-
func DiffBodiesFetchOption(peers IPeerSet) DownloadOption {
233+
func EnableDiffFetchOp(peers IPeerSet) DownloadOption {
234234
return func(dl *Downloader) *Downloader {
235235
var hook = func(headers []*types.Header, args ...interface{}) {
236236
if len(args) < 2 {

eth/handler.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ func newHandler(config *handlerConfig) (*handler, error) {
194194
}
195195
var downloadOptions []downloader.DownloadOption
196196
if h.diffSync {
197-
downloadOptions = append(downloadOptions, downloader.DiffBodiesFetchOption(h.peers))
197+
downloadOptions = append(downloadOptions, downloader.EnableDiffFetchOp(h.peers))
198198
}
199199
h.downloader = downloader.New(h.checkpointNumber, config.Database, h.stateBloom, h.eventMux, h.chain, nil, h.removePeer, downloadOptions...)
200200

eth/handler_diff.go

+26-17
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ func (h *diffHandler) RunPeer(peer *diff.Peer, hand diff.Handler) error {
3535
if err := peer.Handshake(h.diffSync); err != nil {
3636
return err
3737
}
38+
defer h.chain.RemoveDiffPeer(peer.ID())
3839
return (*handler)(h).runDiffExtension(peer, hand)
3940
}
4041

@@ -55,26 +56,34 @@ func (h *diffHandler) Handle(peer *diff.Peer, packet diff.Packet) error {
5556
// data packet for the local node to consume.
5657
switch packet := packet.(type) {
5758
case *diff.DiffLayersPacket:
58-
diffs, err := packet.Unpack()
59-
if err != nil {
60-
return err
61-
}
62-
for _, d := range diffs {
63-
if d != nil {
64-
if err := d.Validate(); err != nil {
65-
return err
66-
}
67-
}
68-
}
69-
for _, diff := range diffs {
70-
err := h.chain.HandleDiffLayer(diff, peer.ID())
71-
if err != nil {
72-
return err
73-
}
74-
}
59+
return h.handleDiffLayerPackage(packet, peer.ID(), false)
60+
61+
case *diff.FullDiffLayersPacket:
62+
return h.handleDiffLayerPackage(&packet.DiffLayersPacket, peer.ID(), true)
7563

7664
default:
7765
return fmt.Errorf("unexpected diff packet type: %T", packet)
7866
}
7967
return nil
8068
}
69+
70+
func (h *diffHandler) handleDiffLayerPackage(packet *diff.DiffLayersPacket, pid string, fulfilled bool) error {
71+
diffs, err := packet.Unpack()
72+
if err != nil {
73+
return err
74+
}
75+
for _, d := range diffs {
76+
if d != nil {
77+
if err := d.Validate(); err != nil {
78+
return err
79+
}
80+
}
81+
}
82+
for _, diff := range diffs {
83+
err := h.chain.HandleDiffLayer(diff, pid, fulfilled)
84+
if err != nil {
85+
return err
86+
}
87+
}
88+
return nil
89+
}

eth/protocols/diff/handler.go

+7-2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ const (
2020
maxDiffLayerServe = 1024
2121
)
2222

23+
var requestTracker = NewTracker(time.Minute)
24+
2325
// Handler is a callback to invoke from an outside runner after the boilerplate
2426
// exchanges have passed.
2527
type Handler func(peer *Peer) error
@@ -139,8 +141,11 @@ func handleMessage(backend Backend, peer *Peer) error {
139141
if err := msg.Decode(res); err != nil {
140142
return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
141143
}
142-
requestTracker.Fulfil(peer.id, peer.version, FullDiffLayerMsg, res.RequestId)
143-
return backend.Handle(peer, &res.DiffLayersPacket)
144+
if fulfilled := requestTracker.Fulfil(peer.id, peer.version, FullDiffLayerMsg, res.RequestId); fulfilled {
145+
return backend.Handle(peer, res)
146+
} else {
147+
return fmt.Errorf("%w: %v", errUnexpectedMsg, msg.Code)
148+
}
144149
default:
145150
return fmt.Errorf("%w: %v", errInvalidMsgCode, msg.Code)
146151
}

eth/protocols/diff/protocol.go

+1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ var (
5858
errMsgTooLarge = errors.New("message too long")
5959
errDecode = errors.New("invalid message")
6060
errInvalidMsgCode = errors.New("invalid message code")
61+
errUnexpectedMsg = errors.New("unexpected message code")
6162
errBadRequest = errors.New("bad request")
6263
errNoCapMsg = errors.New("miss cap message during handshake")
6364
)

eth/protocols/diff/tracker.go

+138-3
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,145 @@
1717
package diff
1818

1919
import (
20+
"container/list"
21+
"fmt"
22+
"sync"
2023
"time"
2124

22-
"github.com/ethereum/go-ethereum/p2p/tracker"
25+
"github.com/ethereum/go-ethereum/log"
2326
)
2427

25-
// requestTracker is a singleton tracker for request times.
26-
var requestTracker = tracker.New(ProtocolName, time.Minute)
28+
const (
29+
// maxTrackedPackets is a huge number to act as a failsafe on the number of
30+
// pending requests the node will track. It should never be hit unless an
31+
// attacker figures out a way to spin requests.
32+
maxTrackedPackets = 10000
33+
)
34+
35+
// request tracks sent network requests which have not yet received a response.
36+
type request struct {
37+
peer string
38+
version uint // Protocol version
39+
40+
reqCode uint64 // Protocol message code of the request
41+
resCode uint64 // Protocol message code of the expected response
42+
43+
time time.Time // Timestamp when the request was made
44+
expire *list.Element // Expiration marker to untrack it
45+
}
46+
47+
type Tracker struct {
48+
timeout time.Duration // Global timeout after which to drop a tracked packet
49+
50+
pending map[uint64]*request // Currently pending requests
51+
expire *list.List // Linked list tracking the expiration order
52+
wake *time.Timer // Timer tracking the expiration of the next item
53+
54+
lock sync.Mutex // Lock protecting from concurrent updates
55+
}
56+
57+
func NewTracker(timeout time.Duration) *Tracker {
58+
return &Tracker{
59+
timeout: timeout,
60+
pending: make(map[uint64]*request),
61+
expire: list.New(),
62+
}
63+
}
64+
65+
// Track adds a network request to the tracker to wait for a response to arrive
66+
// or until the request it cancelled or times out.
67+
func (t *Tracker) Track(peer string, version uint, reqCode uint64, resCode uint64, id uint64) {
68+
t.lock.Lock()
69+
defer t.lock.Unlock()
70+
71+
// If there's a duplicate request, we've just random-collided (or more probably,
72+
// we have a bug), report it. We could also add a metric, but we're not really
73+
// expecting ourselves to be buggy, so a noisy warning should be enough.
74+
if _, ok := t.pending[id]; ok {
75+
log.Error("Network request id collision", "version", version, "code", reqCode, "id", id)
76+
return
77+
}
78+
// If we have too many pending requests, bail out instead of leaking memory
79+
if pending := len(t.pending); pending >= maxTrackedPackets {
80+
log.Error("Request tracker exceeded allowance", "pending", pending, "peer", peer, "version", version, "code", reqCode)
81+
return
82+
}
83+
// Id doesn't exist yet, start tracking it
84+
t.pending[id] = &request{
85+
peer: peer,
86+
version: version,
87+
reqCode: reqCode,
88+
resCode: resCode,
89+
time: time.Now(),
90+
expire: t.expire.PushBack(id),
91+
}
92+
93+
// If we've just inserted the first item, start the expiration timer
94+
if t.wake == nil {
95+
t.wake = time.AfterFunc(t.timeout, t.clean)
96+
}
97+
}
98+
99+
// clean is called automatically when a preset time passes without a response
100+
// being dleivered for the first network request.
101+
func (t *Tracker) clean() {
102+
t.lock.Lock()
103+
defer t.lock.Unlock()
104+
105+
// Expire anything within a certain threshold (might be no items at all if
106+
// we raced with the delivery)
107+
for t.expire.Len() > 0 {
108+
// Stop iterating if the next pending request is still alive
109+
var (
110+
head = t.expire.Front()
111+
id = head.Value.(uint64)
112+
req = t.pending[id]
113+
)
114+
if time.Since(req.time) < t.timeout+5*time.Millisecond {
115+
break
116+
}
117+
// Nope, dead, drop it
118+
t.expire.Remove(head)
119+
delete(t.pending, id)
120+
}
121+
t.schedule()
122+
}
123+
124+
// schedule starts a timer to trigger on the expiration of the first network
125+
// packet.
126+
func (t *Tracker) schedule() {
127+
if t.expire.Len() == 0 {
128+
t.wake = nil
129+
return
130+
}
131+
t.wake = time.AfterFunc(time.Until(t.pending[t.expire.Front().Value.(uint64)].time.Add(t.timeout)), t.clean)
132+
}
133+
134+
// Fulfil fills a pending request, if any is available.
135+
func (t *Tracker) Fulfil(peer string, version uint, code uint64, id uint64) bool {
136+
t.lock.Lock()
137+
defer t.lock.Unlock()
138+
139+
// If it's a non existing request, track as stale response
140+
req, ok := t.pending[id]
141+
if !ok {
142+
return false
143+
}
144+
// If the response is funky, it might be some active attack
145+
if req.peer != peer || req.version != version || req.resCode != code {
146+
log.Warn("Network response id collision",
147+
"have", fmt.Sprintf("%s:/%d:%d", peer, version, code),
148+
"want", fmt.Sprintf("%s:/%d:%d", peer, req.version, req.resCode),
149+
)
150+
return false
151+
}
152+
// Everything matches, mark the request serviced
153+
t.expire.Remove(req.expire)
154+
delete(t.pending, id)
155+
if req.expire.Prev() == nil {
156+
if t.wake.Stop() {
157+
t.schedule()
158+
}
159+
}
160+
return true
161+
}

0 commit comments

Comments
 (0)