diff --git a/eth/handler_eth_test.go b/eth/handler_eth_test.go index aad2c72b1b..e85c74e6f2 100644 --- a/eth/handler_eth_test.go +++ b/eth/handler_eth_test.go @@ -239,6 +239,76 @@ func testForkIDSplit(t *testing.T, protocol uint) { func TestRecvTransactions65(t *testing.T) { testRecvTransactions(t, eth.ETH65) } func TestRecvTransactions66(t *testing.T) { testRecvTransactions(t, eth.ETH66) } +func TestWaitDiffExtensionTimout(t *testing.T) { + t.Parallel() + + // Create a message handler, configure it to accept transactions and watch them + handler := newTestHandler() + defer handler.close() + + // Create a source peer to send messages through and a sink handler to receive them + _, p2pSink := p2p.MsgPipe() + defer p2pSink.Close() + + protos := []p2p.Protocol{ + { + Name: "diff", + Version: 1, + }, + } + + sink := eth.NewPeer(eth.ETH67, p2p.NewPeerWithProtocols(enode.ID{2}, protos, "", []p2p.Cap{ + { + Name: "diff", + Version: 1, + }, + }), p2pSink, nil) + defer sink.Close() + + err := handler.handler.runEthPeer(sink, func(peer *eth.Peer) error { + return eth.Handle((*ethHandler)(handler.handler), peer) + }) + + if err == nil || err.Error() != "peer wait timeout" { + t.Fatalf("error should be `peer wait timeout`") + } +} + +func TestWaitSnapExtensionTimout(t *testing.T) { + t.Parallel() + + // Create a message handler, configure it to accept transactions and watch them + handler := newTestHandler() + defer handler.close() + + // Create a source peer to send messages through and a sink handler to receive them + _, p2pSink := p2p.MsgPipe() + defer p2pSink.Close() + + protos := []p2p.Protocol{ + { + Name: "snap", + Version: 1, + }, + } + + sink := eth.NewPeer(eth.ETH67, p2p.NewPeerWithProtocols(enode.ID{2}, protos, "", []p2p.Cap{ + { + Name: "snap", + Version: 1, + }, + }), p2pSink, nil) + defer sink.Close() + + err := handler.handler.runEthPeer(sink, func(peer *eth.Peer) error { + return eth.Handle((*ethHandler)(handler.handler), peer) + }) + + if err == nil || err.Error() != "peer wait timeout" { + t.Fatalf("error should be `peer wait timeout`") + } +} + func testRecvTransactions(t *testing.T, protocol uint) { t.Parallel() diff --git a/eth/peerset.go b/eth/peerset.go index 220b01d832..0f5245a05e 100644 --- a/eth/peerset.go +++ b/eth/peerset.go @@ -20,6 +20,7 @@ import ( "errors" "math/big" "sync" + "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/eth/downloader" @@ -38,19 +39,28 @@ var ( // to the peer set, but one with the same id already exists. errPeerAlreadyRegistered = errors.New("peer already registered") + // errPeerWaitTimeout is returned if a peer waits extension for too long + errPeerWaitTimeout = errors.New("peer wait timeout") + // errPeerNotRegistered is returned if a peer is attempted to be removed from // a peer set, but no peer with the given id exists. errPeerNotRegistered = errors.New("peer not registered") // errSnapWithoutEth is returned if a peer attempts to connect only on the - // snap protocol without advertizing the eth main protocol. + // snap protocol without advertising the eth main protocol. errSnapWithoutEth = errors.New("peer connected on snap without compatible eth support") // errDiffWithoutEth is returned if a peer attempts to connect only on the - // diff protocol without advertizing the eth main protocol. + // diff protocol without advertising the eth main protocol. errDiffWithoutEth = errors.New("peer connected on diff without compatible eth support") ) +const ( + // extensionWaitTimeout is the maximum allowed time for the extension wait to + // complete before dropping the connection as malicious. + extensionWaitTimeout = 10 * time.Second +) + // peerSet represents the collection of active peers currently participating in // the `eth` protocol, with or without the `snap` extension. type peerSet struct { @@ -169,7 +179,16 @@ func (ps *peerSet) waitSnapExtension(peer *eth.Peer) (*snap.Peer, error) { ps.snapWait[id] = wait ps.lock.Unlock() - return <-wait, nil + select { + case peer := <-wait: + return peer, nil + + case <-time.After(extensionWaitTimeout): + ps.lock.Lock() + delete(ps.snapWait, id) + ps.lock.Unlock() + return nil, errPeerWaitTimeout + } } // waitDiffExtension blocks until all satellite protocols are connected and tracked @@ -203,7 +222,16 @@ func (ps *peerSet) waitDiffExtension(peer *eth.Peer) (*diff.Peer, error) { ps.diffWait[id] = wait ps.lock.Unlock() - return <-wait, nil + select { + case peer := <-wait: + return peer, nil + + case <-time.After(extensionWaitTimeout): + ps.lock.Lock() + delete(ps.diffWait, id) + ps.lock.Unlock() + return nil, errPeerWaitTimeout + } } func (ps *peerSet) GetDiffPeer(pid string) downloader.IDiffPeer { diff --git a/eth/protocols/diff/handshake.go b/eth/protocols/diff/handshake.go index 4198ea88a1..0f17fb9e8b 100644 --- a/eth/protocols/diff/handshake.go +++ b/eth/protocols/diff/handshake.go @@ -26,7 +26,7 @@ import ( const ( // handshakeTimeout is the maximum allowed time for the `diff` handshake to - // complete before dropping the connection.= as malicious. + // complete before dropping the connection as malicious. handshakeTimeout = 5 * time.Second ) diff --git a/p2p/peer.go b/p2p/peer.go index e057e689f6..3b633108db 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -129,6 +129,16 @@ func NewPeer(id enode.ID, name string, caps []Cap) *Peer { return peer } +// NewPeerWithProtocols returns a peer for testing purposes. +func NewPeerWithProtocols(id enode.ID, protocols []Protocol, name string, caps []Cap) *Peer { + pipe, _ := net.Pipe() + node := enode.SignNull(new(enr.Record), id) + conn := &conn{fd: pipe, transport: nil, node: node, caps: caps, name: name} + peer := newPeer(log.Root(), conn, protocols) + close(peer.closed) // ensures Disconnect doesn't block + return peer +} + // ID returns the node's public key. func (p *Peer) ID() enode.ID { return p.rw.node.ID() diff --git a/p2p/server.go b/p2p/server.go index dbaee12ea1..2a38550abf 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -63,6 +63,9 @@ const ( // Maximum amount of time allowed for writing a complete message. frameWriteTimeout = 20 * time.Second + + // Maximum time to wait before stop the p2p server + stopTimeout = 5 * time.Second ) var errServerStopped = errors.New("server stopped") @@ -403,7 +406,18 @@ func (srv *Server) Stop() { } close(srv.quit) srv.lock.Unlock() - srv.loopWG.Wait() + + stopChan := make(chan struct{}) + go func() { + srv.loopWG.Wait() + close(stopChan) + }() + + select { + case <-stopChan: + case <-time.After(stopTimeout): + srv.log.Warn("stop p2p server timeout, forcing stop") + } } // sharedUDPConn implements a shared connection. Write sends messages to the underlying connection while read returns diff --git a/p2p/server_test.go b/p2p/server_test.go index a5b3190aed..4ad8eb18e6 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -203,6 +203,29 @@ func TestServerDial(t *testing.T) { } } +func TestServerStopTimeout(t *testing.T) { + srv := &Server{Config: Config{ + PrivateKey: newkey(), + MaxPeers: 1, + NoDiscovery: true, + Logger: testlog.Logger(t, log.LvlTrace).New("server", "1"), + }} + srv.Start() + srv.loopWG.Add(1) + + stopChan := make(chan struct{}) + go func() { + srv.Stop() + close(stopChan) + }() + + select { + case <-stopChan: + case <-time.After(10 * time.Second): + t.Error("server should be shutdown in 10 seconds") + } +} + // This test checks that RemovePeer disconnects the peer if it is connected. func TestServerRemovePeerDisconnect(t *testing.T) { srv1 := &Server{Config: Config{