diff --git a/p2p/discover/portal_protocol.go b/p2p/discover/portal_protocol.go index 46d9e33377b1..27286df293bd 100644 --- a/p2p/discover/portal_protocol.go +++ b/p2p/discover/portal_protocol.go @@ -160,7 +160,9 @@ func DefaultPortalProtocolConfig() *PortalProtocolConfig { } type PortalProtocol struct { - table *Table + table *Table + cachedIdsLock sync.Mutex + cachedIds map[string]enode.ID protocolId string protocolName string @@ -200,6 +202,7 @@ func NewPortalProtocol(config *PortalProtocolConfig, protocolId string, privateK protocolName := portalwire.NetworkNameMap[protocolId] protocol := &PortalProtocol{ + cachedIds: make(map[string]enode.ID), protocolId: protocolId, protocolName: protocolName, ListenAddr: config.ListenAddr, @@ -294,39 +297,25 @@ func (p *PortalProtocol) setupUDPListening() error { var err error p.packetRouter = utp.NewPacketRouter( func(buf []byte, addr *net.UDPAddr) (int, error) { - nodes := p.table.Nodes() - var target *enode.Node - for _, n := range nodes { - if addr.Port != n.UDP() { - continue - } - if addr.IP != nil && addr.IP.To4().String() == n.IP().To4().String() { - target = n + p.Log.Info("will send to target data", "ip", addr.IP.To4().String(), "port", addr.Port, "bufLength", len(buf)) - break - } - if addr.IP == nil { - nodeIp := n.IP().To4().String() - if nodeIp == "127.0.0.1" || nodeIp == "0.0.0.0" { - target = n - break - } - } - } + p.cachedIdsLock.Lock() + defer p.cachedIdsLock.Unlock() + if id, ok := p.cachedIds[addr.String()]; ok { + //_, err := p.DiscV5.TalkRequestToID(id, addr, string(portalwire.UTPNetwork), buf) + req := &v5wire.TalkRequest{Protocol: string(portalwire.UTPNetwork), Message: buf} + p.DiscV5.sendFromAnotherThread(id, addr, req) - if target == nil { - p.Log.Warn("not fount target node info", "ip", addr.IP.To4().String(), "port", addr.Port, "bufLength", len(buf)) - return 0, fmt.Errorf("not found target node info") + return len(buf), err + } else { + p.Log.Warn("not found target node info", "ip", addr.IP.To4().String(), "port", addr.Port, "bufLength", len(buf)) + return 0, fmt.Errorf("not found target node id") } - p.Log.Trace("send to target data", "ip", addr.IP.To4().String(), "port", addr.Port, "bufLength", len(buf)) - req := &v5wire.TalkRequest{Protocol: string(portalwire.UTPNetwork), Message: buf} - p.DiscV5.sendFromAnotherThread(target.ID(), addr, req) - - return len(buf), err }) ctx := context.Background() var logger *zap.Logger + if p.Log.Enabled(ctx, log.LevelDebug) || p.Log.Enabled(ctx, log.LevelTrace) { logger, err = zap.NewDevelopmentConfig().Build() } else { @@ -370,6 +359,23 @@ func (p *PortalProtocol) setupDiscV5AndTable() error { return nil } +func (p *PortalProtocol) putCacheNodeId(node *enode.Node) { + p.cachedIdsLock.Lock() + defer p.cachedIdsLock.Unlock() + addr := &net.UDPAddr{IP: node.IP(), Port: node.UDP()} + if _, ok := p.cachedIds[addr.String()]; !ok { + p.cachedIds[addr.String()] = node.ID() + } +} + +func (p *PortalProtocol) putCacheId(id enode.ID, addr *net.UDPAddr) { + p.cachedIdsLock.Lock() + defer p.cachedIdsLock.Unlock() + if _, ok := p.cachedIds[addr.String()]; !ok { + p.cachedIds[addr.String()] = id + } +} + func (p *PortalProtocol) ping(node *enode.Node) (uint64, error) { pong, err := p.pingInner(node) if err != nil { @@ -513,6 +519,9 @@ func (p *PortalProtocol) processOffer(target *enode.Node, resp []byte, request * return nil, fmt.Errorf("invalid accept response") } + p.Log.Info("will process Offer", "id", target.ID(), "ip", target.IP().To4().String(), "port", target.UDP()) + p.putCacheNodeId(target) + accept := &portalwire.Accept{} err = accept.UnmarshalSSZ(resp[1:]) if err != nil { @@ -581,8 +590,8 @@ func (p *PortalProtocol) processOffer(target *enode.Node, resp []byte, request * connctx, conncancel := context.WithTimeout(ctx, defaultUTPConnectTimeout) laddr := p.utp.Addr().(*utp.Addr) raddr := &utp.Addr{IP: target.IP(), Port: target.UDP()} - conn, err = utp.DialUTPOptions("utp", laddr, raddr, utp.WithContext(connctx), utp.WithSocketManager(p.utpSm), utp.WithConnId(uint32(connId))) p.Log.Info("will connect to: ", "addr", raddr.String(), "connId", connId) + conn, err = utp.DialUTPOptions("utp", laddr, raddr, utp.WithContext(connctx), utp.WithSocketManager(p.utpSm), utp.WithConnId(uint32(connId))) if err != nil { conncancel() p.Log.Error("failed to dial utp connection", "err", err) @@ -636,6 +645,9 @@ func (p *PortalProtocol) processContent(target *enode.Node, resp []byte) (byte, return 0xff, nil, fmt.Errorf("invalid content response") } + p.Log.Info("will process content", "id", target.ID(), "ip", target.IP().To4().String(), "port", target.UDP()) + p.putCacheNodeId(target) + switch resp[1] { case portalwire.ContentRawSelector: content := &portalwire.Content{} @@ -660,8 +672,8 @@ func (p *PortalProtocol) processContent(target *enode.Node, resp []byte) (byte, laddr := p.utp.Addr().(*utp.Addr) raddr := &utp.Addr{IP: target.IP(), Port: target.UDP()} connId := binary.BigEndian.Uint16(connIdMsg.Id[:]) - conn, err := utp.DialUTPOptions("utp", laddr, raddr, utp.WithContext(connctx), utp.WithSocketManager(p.utpSm), utp.WithConnId(uint32(connId))) p.Log.Info("will connect to: ", "addr", raddr.String(), "connId", connId) + conn, err := utp.DialUTPOptions("utp", laddr, raddr, utp.WithContext(connctx), utp.WithSocketManager(p.utpSm), utp.WithConnId(uint32(connId))) if err != nil { conncancel() return 0xff, nil, err @@ -787,16 +799,18 @@ func (p *PortalProtocol) handleUtpTalkRequest(id enode.ID, addr *net.UDPAddr, ms if n := p.DiscV5.getNode(id); n != nil { p.table.addSeenNode(wrapNode(n)) } + + p.putCacheId(id, addr) p.Log.Trace("receive utp data", "addr", addr, "msg-length", len(msg)) p.packetRouter.ReceiveMessage(msg, addr) return []byte("") } func (p *PortalProtocol) handleTalkRequest(id enode.ID, addr *net.UDPAddr, msg []byte) []byte { - p.Log.Trace("handleTalkRequest", "id", id, "addr", addr) if n := p.DiscV5.getNode(id); n != nil { p.table.addSeenNode(wrapNode(n)) } + p.putCacheId(id, addr) msgCode := msg[0] @@ -961,6 +975,8 @@ func (p *PortalProtocol) handleFindContent(id enode.ID, addr *net.UDPAddr, reque return nil, err } + p.putCacheId(id, addr) + if errors.Is(err, ContentNotFound) { closestNodes := p.findNodesCloseToContent(contentId, portalFindnodesResultLimit) for i, n := range closestNodes { @@ -1030,14 +1046,13 @@ func (p *PortalProtocol) handleFindContent(id enode.ID, addr *net.UDPAddr, reque default: ctx, cancel := context.WithTimeout(bctx, defaultUTPConnectTimeout) var conn *utp.Conn + p.Log.Debug("will accept find content conn from: ", "source", addr, "connId", connId) conn, err = p.utp.AcceptUTPContext(ctx, connIdSend) - p.Log.Info("will accept from: ", "source", addr, "connId", connId) if err != nil { - p.Log.Error("failed to accept utp connection", "connId", connIdSend, "err", err) + p.Log.Error("failed to accept utp connection for handle find content", "connId", connIdSend, "err", err) cancel() return } - p.Log.Info("") cancel() err = conn.SetWriteDeadline(time.Now().Add(defaultUTPWriteTimeout)) @@ -1138,6 +1153,8 @@ func (p *PortalProtocol) handleOffer(id enode.ID, addr *net.UDPAddr, request *po } } + p.putCacheId(id, addr) + idBuffer := make([]byte, 2) if contentKeyBitlist.Count() != 0 { connId := p.connIdGen.GenCid(id, false) @@ -1151,10 +1168,10 @@ func (p *PortalProtocol) handleOffer(id enode.ID, addr *net.UDPAddr, request *po default: ctx, cancel := context.WithTimeout(bctx, defaultUTPConnectTimeout) var conn *utp.Conn + p.Log.Debug("will accept offer conn from: ", "source", addr, "connId", connId) conn, err = p.utp.AcceptUTPContext(ctx, connIdSend) - p.Log.Info("will accept from: ", "source", addr, "connId", connId) if err != nil { - p.Log.Error("failed to accept utp connection", "connId", connIdSend, "err", err) + p.Log.Error("failed to accept utp connection for handle offer", "connId", connIdSend, "err", err) cancel() return }