diff --git a/p2p/net/swarm/dial_ranker.go b/p2p/net/swarm/dial_ranker.go new file mode 100644 index 0000000000..a901ee22fb --- /dev/null +++ b/p2p/net/swarm/dial_ranker.go @@ -0,0 +1,131 @@ +package swarm + +import ( + "time" + + "github.com/libp2p/go-libp2p/core/network" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" +) + +const ( + publicTCPDelay = 300 * time.Millisecond + privateTCPDelay = 30 * time.Millisecond + relayDelay = 500 * time.Millisecond +) + +func noDelayRanker(addrs []ma.Multiaddr) []*network.AddrDelay { + res := make([]*network.AddrDelay, len(addrs)) + for i, a := range addrs { + res[i] = &network.AddrDelay{Addr: a, Delay: 0} + } + return res +} + +// defaultDialRanker is the default ranking logic. +// +// we consider private, public ip4, public ip6, relay addresses separately. +// +// In each group, if a quic address is present, we delay tcp addresses. +// +// private: 30 ms delay. +// public ip4: 300 ms delay. +// public ip6: 300 ms delay. +// +// If a quic-v1 address is present we don't dial quic or webtransport address on the same (ip,port) combination. +// If a tcp address is present we don't dial ws or wss address on the same (ip, port) combination. +// If direct addresses are present we delay all relay addresses by 500 millisecond +func defaultDialRanker(addrs []ma.Multiaddr) []*network.AddrDelay { + ip4 := make([]ma.Multiaddr, 0, len(addrs)) + ip6 := make([]ma.Multiaddr, 0, len(addrs)) + pvt := make([]ma.Multiaddr, 0, len(addrs)) + relay := make([]ma.Multiaddr, 0, len(addrs)) + + res := make([]*network.AddrDelay, 0, len(addrs)) + for _, a := range addrs { + switch { + case !manet.IsPublicAddr(a): + pvt = append(pvt, a) + case isProtocolAddr(a, ma.P_IP4): + ip4 = append(ip4, a) + case isProtocolAddr(a, ma.P_IP6): + ip6 = append(ip6, a) + case isRelayAddr(a): + relay = append(relay, a) + default: + res = append(res, &network.AddrDelay{Addr: a, Delay: 0}) + } + } + var roffset time.Duration = 0 + if len(ip4) > 0 || len(ip6) > 0 { + roffset = relayDelay + } + + res = append(res, getAddrDelay(pvt, privateTCPDelay, 0)...) + res = append(res, getAddrDelay(ip4, publicTCPDelay, 0)...) + res = append(res, getAddrDelay(ip6, publicTCPDelay, 0)...) + res = append(res, getAddrDelay(relay, publicTCPDelay, roffset)...) + return res +} + +func getAddrDelay(addrs []ma.Multiaddr, tcpDelay time.Duration, offset time.Duration) []*network.AddrDelay { + var hasQuic, hasQuicV1 bool + quicV1Addr := make(map[string]struct{}) + tcpAddr := make(map[string]struct{}) + for _, a := range addrs { + switch { + case isProtocolAddr(a, ma.P_WEBTRANSPORT): + case isProtocolAddr(a, ma.P_QUIC): + hasQuic = true + case isProtocolAddr(a, ma.P_QUIC_V1): + hasQuicV1 = true + quicV1Addr[addrPort(a, ma.P_UDP)] = struct{}{} + case isProtocolAddr(a, ma.P_WS) || isProtocolAddr(a, ma.P_WSS): + case isProtocolAddr(a, ma.P_TCP): + tcpAddr[addrPort(a, ma.P_TCP)] = struct{}{} + } + } + + res := make([]*network.AddrDelay, 0, len(addrs)) + for _, a := range addrs { + delay := offset + switch { + case isProtocolAddr(a, ma.P_WEBTRANSPORT): + if hasQuicV1 { + if _, ok := quicV1Addr[addrPort(a, ma.P_UDP)]; ok { + continue + } + } + case isProtocolAddr(a, ma.P_QUIC): + if hasQuicV1 { + if _, ok := quicV1Addr[addrPort(a, ma.P_UDP)]; ok { + continue + } + } + case isProtocolAddr(a, ma.P_WS) || isProtocolAddr(a, ma.P_WSS): + if _, ok := tcpAddr[addrPort(a, ma.P_TCP)]; ok { + continue + } + if hasQuic || hasQuicV1 { + delay = tcpDelay + } + case isProtocolAddr(a, ma.P_TCP): + if hasQuic || hasQuicV1 { + delay = tcpDelay + } + } + res = append(res, &network.AddrDelay{Addr: a, Delay: delay}) + } + return res +} + +func addrPort(a ma.Multiaddr, p int) string { + c, _ := ma.SplitFirst(a) + port, _ := a.ValueForProtocol(p) + return c.Value() + ":" + port +} + +func isProtocolAddr(a ma.Multiaddr, p int) bool { + _, err := a.ValueForProtocol(p) + return err == nil +} diff --git a/p2p/net/swarm/dial_ranker_test.go b/p2p/net/swarm/dial_ranker_test.go new file mode 100644 index 0000000000..4adf2eada6 --- /dev/null +++ b/p2p/net/swarm/dial_ranker_test.go @@ -0,0 +1,253 @@ +package swarm + +import ( + "fmt" + "sort" + "testing" + + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/test" + ma "github.com/multiformats/go-multiaddr" +) + +func TestNoDelayRanker(t *testing.T) { + addrs := []ma.Multiaddr{ + ma.StringCast("/ip4/1.2.3.4/tcp/1"), + ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1"), + } + addrDelays := noDelayRanker(addrs) + if len(addrs) != len(addrDelays) { + t.Errorf("addrDelay should have the same number of elements as addr") + } + + for _, a := range addrs { + for _, ad := range addrDelays { + if a.Equal(ad.Addr) { + if ad.Delay != 0 { + t.Errorf("expected 0 delay, got %s", ad.Delay) + } + } + } + } +} + +func TestDelayRankerTCPDelay(t *testing.T) { + pquicv1 := ma.StringCast("/ip4/192.168.0.100/udp/1/quic-v1") + ptcp := ma.StringCast("/ip4/192.168.0.100/tcp/1/") + + quic := ma.StringCast("/ip4/1.2.3.4/udp/1/quic") + quicv1 := ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1") + tcp := ma.StringCast("/ip4/1.2.3.5/tcp/1/") + + tcp6 := ma.StringCast("/ip6/1::1/tcp/1") + quicv16 := ma.StringCast("/ip6/1::2/udp/1/quic-v1") + + testCase := []struct { + name string + addrs []ma.Multiaddr + output []*network.AddrDelay + }{ + { + name: "quic prioritised over tcp", + addrs: []ma.Multiaddr{quic, tcp}, + output: []*network.AddrDelay{ + {Addr: quic, Delay: 0}, + {Addr: tcp, Delay: publicTCPDelay}, + }, + }, + { + name: "quic-v1 prioritised over tcp", + addrs: []ma.Multiaddr{quicv1, tcp}, + output: []*network.AddrDelay{ + {Addr: quicv1, Delay: 0}, + {Addr: tcp, Delay: publicTCPDelay}, + }, + }, + { + name: "ip6 treated separately", + addrs: []ma.Multiaddr{quicv16, tcp6, quic}, + output: []*network.AddrDelay{ + {Addr: quicv16, Delay: 0}, + {Addr: quic, Delay: 0}, + {Addr: tcp6, Delay: publicTCPDelay}, + }, + }, + { + name: "private addrs treated separately", + addrs: []ma.Multiaddr{pquicv1, ptcp}, + output: []*network.AddrDelay{ + {Addr: pquicv1, Delay: 0}, + {Addr: ptcp, Delay: privateTCPDelay}, + }, + }, + } + for _, tc := range testCase { + t.Run(tc.name, func(t *testing.T) { + res := defaultDialRanker(tc.addrs) + if len(res) != len(tc.output) { + for _, a := range res { + log.Errorf("%v", a) + } + for _, a := range tc.output { + log.Errorf("%v", a) + } + t.Errorf("expected elems: %d got: %d", len(tc.output), len(res)) + } + sort.Slice(res, func(i, j int) bool { + if res[i].Delay == res[j].Delay { + return res[i].Addr.String() < res[j].Addr.String() + } + return res[i].Delay < res[j].Delay + }) + sort.Slice(tc.output, func(i, j int) bool { + if tc.output[i].Delay == tc.output[j].Delay { + return tc.output[i].Addr.String() < tc.output[j].Addr.String() + } + return tc.output[i].Delay < tc.output[j].Delay + }) + }) + } +} + +func TestDelayRankerAddrDropped(t *testing.T) { + pquic := ma.StringCast("/ip4/192.168.0.100/udp/1/quic") + pquicv1 := ma.StringCast("/ip4/192.168.0.100/udp/1/quic-v1") + + quicAddr := ma.StringCast("/ip4/1.2.3.4/udp/1/quic") + quicAddr2 := ma.StringCast("/ip4/1.2.3.4/udp/2/quic") + quicv1Addr := ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1") + wt := ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1/webtransport/") + wt2 := ma.StringCast("/ip4/1.2.3.4/udp/2/quic-v1/webtransport/") + + quic6 := ma.StringCast("/ip6/1::1/udp/1/quic") + quicv16 := ma.StringCast("/ip6/1::1/udp/1/quic-v1") + + tcp := ma.StringCast("/ip4/1.2.3.5/tcp/1/") + ws := ma.StringCast("/ip4/1.2.3.5/tcp/1/ws") + ws2 := ma.StringCast("/ip4/1.2.3.4/tcp/1/ws") + wss := ma.StringCast("/ip4/1.2.3.5/tcp/1/wss") + + testCase := []struct { + name string + addrs []ma.Multiaddr + output []*network.AddrDelay + }{ + { + name: "quic dropped when quic-v1 present", + addrs: []ma.Multiaddr{quicAddr, quicv1Addr, quicAddr2}, + output: []*network.AddrDelay{ + {Addr: quicv1Addr, Delay: 0}, + {Addr: quicAddr2, Delay: 0}, + }, + }, + { + name: "webtransport dropped when quicv1 present", + addrs: []ma.Multiaddr{quicv1Addr, wt, wt2, quicAddr}, + output: []*network.AddrDelay{ + {Addr: quicv1Addr, Delay: 0}, + {Addr: wt2, Delay: 0}, + }, + }, + { + name: "ip6 quic dropped when quicv1 present", + addrs: []ma.Multiaddr{quicv16, quic6}, + output: []*network.AddrDelay{ + {Addr: quicv16, Delay: 0}, + }, + }, + { + name: "web socket removed when tcp present", + addrs: []ma.Multiaddr{quicAddr, tcp, ws, wss, ws2}, + output: []*network.AddrDelay{ + {Addr: quicAddr, Delay: 0}, + {Addr: tcp, Delay: publicTCPDelay}, + {Addr: ws2, Delay: publicTCPDelay}, + }, + }, + { + name: "private quic dropped when quiv1 present", + addrs: []ma.Multiaddr{pquic, pquicv1}, + output: []*network.AddrDelay{ + {Addr: pquicv1, Delay: 0}, + }, + }, + } + for _, tc := range testCase { + t.Run(tc.name, func(t *testing.T) { + res := defaultDialRanker(tc.addrs) + if len(res) != len(tc.output) { + for _, a := range res { + log.Errorf("%v", a) + } + for _, a := range tc.output { + log.Errorf("%v", a) + } + t.Errorf("expected elems: %d got: %d", len(tc.output), len(res)) + } + sort.Slice(res, func(i, j int) bool { + if res[i].Delay == res[j].Delay { + return res[i].Addr.String() < res[j].Addr.String() + } + return res[i].Delay < res[j].Delay + }) + sort.Slice(tc.output, func(i, j int) bool { + if tc.output[i].Delay == tc.output[j].Delay { + return tc.output[i].Addr.String() < tc.output[j].Addr.String() + } + return tc.output[i].Delay < tc.output[j].Delay + }) + }) + } +} + +func TestDelayRankerRelay(t *testing.T) { + quicAddr := ma.StringCast("/ip4/1.2.3.4/udp/1/quic") + quicAddr2 := ma.StringCast("/ip4/1.2.3.4/udp/2/quic") + + pid := test.RandPeerIDFatal(t) + r1 := ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/tcp/1/p2p-circuit/p2p/%s", pid)) + r2 := ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/udp/1/quic/p2p-circuit/p2p/%s", pid)) + + testCase := []struct { + name string + addrs []ma.Multiaddr + output []*network.AddrDelay + }{ + { + name: "relay address delayed", + addrs: []ma.Multiaddr{quicAddr, quicAddr2, r1, r2}, + output: []*network.AddrDelay{ + {Addr: quicAddr, Delay: 0}, + {Addr: quicAddr2, Delay: 0}, + {Addr: r2, Delay: relayDelay}, + {Addr: r1, Delay: publicTCPDelay + relayDelay}, + }, + }, + } + for _, tc := range testCase { + t.Run(tc.name, func(t *testing.T) { + res := defaultDialRanker(tc.addrs) + if len(res) != len(tc.output) { + for _, a := range res { + log.Errorf("%v", a) + } + for _, a := range tc.output { + log.Errorf("%v", a) + } + t.Errorf("expected elems: %d got: %d", len(tc.output), len(res)) + } + sort.Slice(res, func(i, j int) bool { + if res[i].Delay == res[j].Delay { + return res[i].Addr.String() < res[j].Addr.String() + } + return res[i].Delay < res[j].Delay + }) + sort.Slice(tc.output, func(i, j int) bool { + if tc.output[i].Delay == tc.output[j].Delay { + return tc.output[i].Addr.String() < tc.output[j].Addr.String() + } + return tc.output[i].Delay < tc.output[j].Delay + }) + }) + } +} diff --git a/p2p/net/swarm/dial_worker.go b/p2p/net/swarm/dial_worker.go index f805371cc6..6f9445edee 100644 --- a/p2p/net/swarm/dial_worker.go +++ b/p2p/net/swarm/dial_worker.go @@ -2,13 +2,15 @@ package swarm import ( "context" + "math" + "sort" "sync" + "time" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" ma "github.com/multiformats/go-multiaddr" - manet "github.com/multiformats/go-multiaddr/net" ) // ///////////////////////////////////////////////////////////////////////////////// @@ -26,49 +28,66 @@ type dialResponse struct { err error } +// pendRequest tracks a pending dial request type pendRequest struct { - req dialRequest // the original request - err *DialError // dial error accumulator - addrs map[ma.Multiaddr]struct{} // pending addr dials + req dialRequest // the original request + err *DialError + // addrs is the set of addresses relevant to this request for which + // there are pending dials. Request is completed if any dial succeeds or + // this map is empty. + addrs map[ma.Multiaddr]struct{} } +// addrDial tracks dial on a single address. +// we track pendRequests per address and not on dial worker to support direct dial +// requests. type addrDial struct { addr ma.Multiaddr ctx context.Context conn *Conn err error - requests []int + requests []*pendRequest + delay time.Duration } type dialWorker struct { - s *Swarm - peer peer.ID - reqch <-chan dialRequest - reqno int - requests map[int]*pendRequest - pending map[ma.Multiaddr]*addrDial - resch chan dialResult - - connected bool // true when a connection has been successfully established - - nextDial []ma.Multiaddr - - // ready when we have more addresses to dial (nextDial is not empty) - triggerDial <-chan struct{} + s *Swarm + peer peer.ID + reqch <-chan dialRequest + requests map[*pendRequest]struct{} + resch chan dialResult + connected bool + + // trackedDials tracks all dials made by the worker loop. + // an item is only removed from the map in case of a backoff error which is + // required to support simultaneous connect requests. + trackedDials map[ma.Multiaddr]*addrDial + // dialQueue is the list of addresses that will be dialed. + dialQueue []*network.AddrDelay + // dialTimer is used to trigger dials to addresses from dialQueue. + dialTimer *time.Timer + // currDials is the number of dials inflight. + currDials int + // loopStTime is the starting time of dial loop. Delays on addresses in + // dialQueue are wrt this time. + loopStTime time.Time + // timerRunning indicates whether dialTimer is running. + timerRunning bool // for testing wg sync.WaitGroup } func newDialWorker(s *Swarm, p peer.ID, reqch <-chan dialRequest) *dialWorker { - return &dialWorker{ - s: s, - peer: p, - reqch: reqch, - requests: make(map[int]*pendRequest), - pending: make(map[ma.Multiaddr]*addrDial), - resch: make(chan dialResult), + w := &dialWorker{ + s: s, + peer: p, + reqch: reqch, + requests: make(map[*pendRequest]struct{}), + trackedDials: make(map[ma.Multiaddr]*addrDial), + resch: make(chan dialResult), } + return w } func (w *dialWorker) loop() { @@ -76,10 +95,9 @@ func (w *dialWorker) loop() { defer w.wg.Done() defer w.s.limiter.clearAllPeerDials(w.peer) - // used to signal readiness to dial and completion of the dial - ready := make(chan struct{}) - close(ready) - + w.loopStTime = time.Now() + w.dialTimer = time.NewTimer(math.MaxInt64) + w.timerRunning = true loop: for { select { @@ -95,14 +113,17 @@ loop: } addrs, err := w.s.addrsForDial(req.ctx, w.peer) + if err != nil { req.resch <- dialResponse{err: err} continue loop } // at this point, len(addrs) > 0 or else it would be error from addrsForDial - // ranke them to process in order - addrs = w.rankAddrs(addrs) + + // rank them to process in order + simConnect, _, _ := network.GetSimultaneousConnect(req.ctx) + addrDelays := w.rankAddrs(addrs, simConnect) // create the pending request object pr := &pendRequest{ @@ -110,89 +131,99 @@ loop: err: &DialError{Peer: w.peer}, addrs: make(map[ma.Multiaddr]struct{}), } - for _, a := range addrs { - pr.addrs[a] = struct{}{} + for _, a := range addrDelays { + pr.addrs[a.Addr] = struct{}{} } // check if any of the addrs has been successfully dialed and accumulate // errors from complete dials while collecting new addrs to dial/join - var todial []ma.Multiaddr - var tojoin []*addrDial - - for _, a := range addrs { - ad, ok := w.pending[a] + tojoin := 0 + newdials := 0 + for _, nad := range addrDelays { + ad, ok := w.trackedDials[nad.Addr] if !ok { - todial = append(todial, a) + w.dialQueue = append(w.dialQueue, nad) + w.trackedDials[nad.Addr] = &addrDial{ + addr: nad.Addr, + ctx: req.ctx, + requests: []*pendRequest{pr}, + delay: nad.Delay, + } + newdials++ continue } - if ad.conn != nil { - // dial to this addr was successful, complete the request - req.resch <- dialResponse{conn: ad.conn} - continue loop - } - + // check if this dial has already errored. + // this dial couldn't have succeeded because bestAcceptableConnToPeer + // didn't return a connection if ad.err != nil { // dial to this addr errored, accumulate the error - pr.err.recordErr(a, ad.err) - delete(pr.addrs, a) + pr.err.recordErr(nad.Addr, ad.err) + delete(pr.addrs, nad.Addr) continue } - // dial is still pending, add to the join list - tojoin = append(tojoin, ad) - } - - if len(todial) == 0 && len(tojoin) == 0 { - // all request applicable addrs have been dialed, we must have errored - req.resch <- dialResponse{err: pr.err} - continue loop - } - - // the request has some pending or new dials, track it and schedule new dials - w.reqno++ - w.requests[w.reqno] = pr + // dial is still pending + tojoin++ + + // update delay for pending dials + // we only decrease the delay to not override a simulteneous connect + // 0 delay with a higher delay for a normal dial request + if ad.delay > nad.Delay { + for _, aa := range w.dialQueue { + if aa.Addr.Equal(nad.Addr) { + aa.Delay = nad.Delay + ad.delay = nad.Delay + break + } + } + } - for _, ad := range tojoin { + // update dial context for simulataneous connect request if simConnect, isClient, reason := network.GetSimultaneousConnect(req.ctx); simConnect { if simConnect, _, _ := network.GetSimultaneousConnect(ad.ctx); !simConnect { ad.ctx = network.WithSimultaneousConnect(ad.ctx, isClient, reason) } } - ad.requests = append(ad.requests, w.reqno) - } - - if len(todial) > 0 { - for _, a := range todial { - w.pending[a] = &addrDial{addr: a, ctx: req.ctx, requests: []int{w.reqno}} - } - w.nextDial = append(w.nextDial, todial...) - w.nextDial = w.rankAddrs(w.nextDial) + ad.requests = append(ad.requests, pr) + } - // trigger a new dial now to account for the new addrs we added - w.triggerDial = ready + if newdials+tojoin == 0 { + // all request applicable addrs have been dialed, we must have errored + req.resch <- dialResponse{err: pr.err} + continue loop } - case <-w.triggerDial: - for _, addr := range w.nextDial { - // spawn the dial - ad := w.pending[addr] - err := w.s.dialNextAddr(ad.ctx, w.peer, addr, w.resch) + w.requests[pr] = struct{}{} + sort.Slice(w.dialQueue, func(i, j int) bool { return w.dialQueue[i].Delay < w.dialQueue[j].Delay }) + w.scheduleNextDial() + + case <-w.dialTimer.C: + // Dial the highest priority addresses without checking + // delay timer. An early trigger means there are no inflight + // dials. + var i int + for i = 0; i < len(w.dialQueue); i++ { + a := w.dialQueue[i] + if a.Delay != w.dialQueue[0].Delay { + break + } + ad := w.trackedDials[a.Addr] + err := w.s.dialNextAddr(ad.ctx, w.peer, a.Addr, w.resch) if err != nil { w.dispatchError(ad, err) + } else { + w.currDials++ } } - - w.nextDial = nil - w.triggerDial = nil + w.dialQueue = w.dialQueue[i:] + w.timerRunning = false + w.scheduleNextDial() case res := <-w.resch: - if res.Conn != nil { - w.connected = true - } - - ad := w.pending[res.Addr] + w.currDials-- + ad := w.trackedDials[res.Addr] if res.Conn != nil { // we got a connection, add it to the swarm @@ -201,19 +232,20 @@ loop: // oops no, we failed to add it to the swarm res.Conn.Close() w.dispatchError(ad, err) + w.scheduleNextDial() continue loop } + w.connected = true // dispatch to still pending requests - for _, reqno := range ad.requests { - pr, ok := w.requests[reqno] + for _, pr := range ad.requests { + _, ok := w.requests[pr] if !ok { // it has already dispatched a connection continue } - pr.req.resch <- dialResponse{conn: conn} - delete(w.requests, reqno) + delete(w.requests, pr) } ad.conn = conn @@ -230,15 +262,31 @@ loop: } w.dispatchError(ad, res.Err) + w.scheduleNextDial() } } } +func (w *dialWorker) scheduleNextDial() { + if len(w.dialQueue) > 0 { + d := w.dialQueue[0].Delay + if w.currDials == 0 { + // no active dials, dial next address immediately + d = 0 + } + if w.timerRunning && !w.dialTimer.Stop() { + <-w.dialTimer.C + } + w.dialTimer.Reset(time.Until(w.loopStTime.Add(d))) + w.timerRunning = true + } +} + // dispatches an error to a specific addr dial func (w *dialWorker) dispatchError(ad *addrDial, err error) { ad.err = err - for _, reqno := range ad.requests { - pr, ok := w.requests[reqno] + for _, pr := range ad.requests { + _, ok := w.requests[pr] if !ok { // has already been dispatched continue @@ -258,7 +306,7 @@ func (w *dialWorker) dispatchError(ad *addrDial, err error) { } else { pr.req.resch <- dialResponse{err: pr.err} } - delete(w.requests, reqno) + delete(w.requests, pr) } } @@ -271,43 +319,13 @@ func (w *dialWorker) dispatchError(ad *addrDial, err error) { // it is also necessary to preserve consisent behaviour with the old dialer -- TestDialBackoff // regresses without this. if err == ErrDialBackoff { - delete(w.pending, ad.addr) + delete(w.trackedDials, ad.addr) } } -// ranks addresses in descending order of preference for dialing, with the following rules: -// NonRelay > Relay -// NonWS > WS -// Private > Public -// UDP > TCP -func (w *dialWorker) rankAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { - addrTier := func(a ma.Multiaddr) (tier int) { - if isRelayAddr(a) { - tier |= 0b1000 - } - if isExpensiveAddr(a) { - tier |= 0b0100 - } - if !manet.IsPrivateAddr(a) { - tier |= 0b0010 - } - if isFdConsumingAddr(a) { - tier |= 0b0001 - } - - return tier - } - - tiers := make([][]ma.Multiaddr, 16) - for _, a := range addrs { - tier := addrTier(a) - tiers[tier] = append(tiers[tier], a) +func (w *dialWorker) rankAddrs(addrs []ma.Multiaddr, simConnect bool) []*network.AddrDelay { + if simConnect { + return noDelayRanker(addrs) } - - result := make([]ma.Multiaddr, 0, len(addrs)) - for _, tier := range tiers { - result = append(result, tier...) - } - - return result + return w.s.dialRanker(addrs) } diff --git a/p2p/net/swarm/dial_worker_test.go b/p2p/net/swarm/dial_worker_test.go index 2c441106b1..a416eacd08 100644 --- a/p2p/net/swarm/dial_worker_test.go +++ b/p2p/net/swarm/dial_worker_test.go @@ -5,11 +5,13 @@ import ( "crypto/rand" "errors" "fmt" + "net" "sync" "testing" "time" "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/sec" @@ -24,6 +26,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/transport/tcp" ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" "github.com/stretchr/testify/require" ) @@ -88,6 +91,19 @@ func makeUpgrader(t *testing.T, n *Swarm) transport.Upgrader { return u } +func makeTcpListener(t *testing.T) (net.Listener, ma.Multiaddr) { + t.Helper() + lst, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + t.Error(err) + } + addr, err := manet.FromNetAddr(lst.Addr()) + if err != nil { + t.Error(err) + } + return lst, addr +} + func TestDialWorkerLoopBasic(t *testing.T) { s1 := makeSwarm(t) s2 := makeSwarm(t) @@ -342,3 +358,206 @@ func TestDialWorkerLoopConcurrentFailureStress(t *testing.T) { close(reqch) worker.wg.Wait() } + +func TestDialWorkerLoopRanking(t *testing.T) { + s1 := makeSwarm(t) + s2 := makeSwarm(t) + defer s1.Close() + defer s2.Close() + + var quicAddr, tcpAddr ma.Multiaddr + for _, a := range s2.ListenAddresses() { + if _, err := a.ValueForProtocol(ma.P_QUIC); err == nil { + quicAddr = a + } + if _, err := a.ValueForProtocol(ma.P_TCP); err == nil { + tcpAddr = a + } + } + + tcpL1, silAddr1 := makeTcpListener(t) + ch1 := make(chan struct{}) + defer tcpL1.Close() + tcpL2, silAddr2 := makeTcpListener(t) + ch2 := make(chan struct{}) + defer tcpL2.Close() + tcpL3, silAddr3 := makeTcpListener(t) + ch3 := make(chan struct{}) + defer tcpL3.Close() + + acceptAndIgnore := func(ch chan struct{}, l net.Listener) func() { + return func() { + for { + _, err := l.Accept() + if err != nil { + break + } + ch <- struct{}{} + } + } + } + go acceptAndIgnore(ch1, tcpL1)() + go acceptAndIgnore(ch2, tcpL2)() + go acceptAndIgnore(ch3, tcpL3)() + + ranker := func(addrs []ma.Multiaddr) []*network.AddrDelay { + res := make([]*network.AddrDelay, 0) + for _, a := range addrs { + switch { + case a.Equal(silAddr1): + res = append(res, &network.AddrDelay{Addr: a, Delay: 0}) + case a.Equal(silAddr2): + res = append(res, &network.AddrDelay{Addr: a, Delay: 1 * time.Second}) + case a.Equal(tcpAddr): + res = append(res, &network.AddrDelay{Addr: a, Delay: 2 * time.Second}) + case a.Equal(silAddr3): + res = append(res, &network.AddrDelay{Addr: a, Delay: 3 * time.Second}) + default: + t.Errorf("unexpected address %s", a) + } + } + return res + } + + // should connect to quic with both tcp and quic address + s1.dialRanker = ranker + s2addrs := []ma.Multiaddr{tcpAddr, silAddr1, silAddr2, silAddr3} + s1.Peerstore().AddAddrs(s2.LocalPeer(), s2addrs, peerstore.PermanentAddrTTL) + reqch := make(chan dialRequest) + resch := make(chan dialResponse) + worker1 := newDialWorker(s1, s2.LocalPeer(), reqch) + go worker1.loop() + defer worker1.wg.Wait() + + reqch <- dialRequest{ctx: context.Background(), resch: resch} + select { + case <-ch1: + case <-time.After(1 * time.Second): + t.Fatal("expected dial to tcp1") + case <-resch: + t.Fatalf("didn't expect connection to succeed") + } + select { + case <-ch2: + case <-time.After(2 * time.Second): + t.Fatalf("expected dial to tcp2") + case <-resch: + t.Fatalf("didn't expect connection to succeed") + } + select { + case res := <-resch: + if !res.conn.RemoteMultiaddr().Equal(tcpAddr) { + log.Errorf("invalid connection address. expected %s got %s", tcpAddr, res.conn.RemoteMultiaddr()) + } + case <-time.After(2 * time.Second): + t.Fatalf("expected dial to succeed") + } + close(reqch) + s1.ClosePeer(s2.LocalPeer()) + s1.peers.ClearAddrs(s2.LocalPeer()) + select { + case <-ch3: + t.Errorf("didn't expect tcp call") + case <-time.After(2 * time.Second): + } + + quicFirstRanker := func(addrs []ma.Multiaddr) []*network.AddrDelay { + m := make([]*network.AddrDelay, 0) + for _, a := range addrs { + if _, err := a.ValueForProtocol(ma.P_TCP); err == nil { + m = append(m, &network.AddrDelay{Addr: a, Delay: 500 * time.Millisecond}) + } else { + m = append(m, &network.AddrDelay{Addr: a, Delay: 0}) + } + } + return m + } + + // tcp should connect after delay + s1.dialRanker = quicFirstRanker + s2.ListenClose(quicAddr) + s1.Peerstore().AddAddrs(s2.LocalPeer(), []ma.Multiaddr{quicAddr, tcpAddr}, peerstore.PermanentAddrTTL) + reqch = make(chan dialRequest) + resch = make(chan dialResponse) + worker2 := newDialWorker(s1, s2.LocalPeer(), reqch) + go worker2.loop() + defer worker2.wg.Wait() + + reqch <- dialRequest{ctx: context.Background(), resch: resch} + select { + case res := <-resch: + t.Fatalf("expected a delay before connecting %s", res.conn.LocalMultiaddr()) + case <-time.After(400 * time.Millisecond): + } + select { + case res := <-resch: + require.NoError(t, res.err) + if _, err := res.conn.LocalMultiaddr().ValueForProtocol(ma.P_TCP); err != nil { + t.Fatalf("expected tcp connection %s", res.conn.LocalMultiaddr()) + } + case <-time.After(1 * time.Second): + t.Fatal("dial didn't complete") + } + close(reqch) + s1.ClosePeer(s2.LocalPeer()) + s1.peers.ClearAddrs(s2.LocalPeer()) + s2.Listen(quicAddr) + + // should dial tcp immediately if there's no quic address available + s1.Peerstore().AddAddrs(s2.LocalPeer(), []ma.Multiaddr{tcpAddr}, peerstore.PermanentAddrTTL) + reqch = make(chan dialRequest) + resch = make(chan dialResponse) + worker3 := newDialWorker(s1, s2.LocalPeer(), reqch) + go worker3.loop() + defer worker3.wg.Wait() + + reqch <- dialRequest{ctx: context.Background(), resch: resch} + select { + case res := <-resch: + require.NoError(t, res.err) + if _, err := res.conn.LocalMultiaddr().ValueForProtocol(ma.P_TCP); err != nil { + t.Fatalf("expected tcp connection, got: %s", res.conn.LocalMultiaddr()) + } + case <-time.After(500 * time.Millisecond): + t.Fatal("dial didn't complete") + } + close(reqch) + s1.ClosePeer(s2.LocalPeer()) + s1.peers.ClearAddrs(s2.LocalPeer()) + + // should dial next immediately when one connection errors after timeout + quicFirstLargeDelayRanker := func(addrs []ma.Multiaddr) []*network.AddrDelay { + m := make([]*network.AddrDelay, 0) + for _, a := range addrs { + if _, err := a.ValueForProtocol(ma.P_TCP); err == nil { + m = append(m, &network.AddrDelay{Addr: a, Delay: 10 * time.Second}) + } else { + m = append(m, &network.AddrDelay{Addr: a, Delay: 0}) + } + } + return m + } + + s1.dialRanker = quicFirstLargeDelayRanker + s2.ListenClose(quicAddr) + s1.Peerstore().AddAddrs(s2.LocalPeer(), []ma.Multiaddr{tcpAddr, quicAddr}, peerstore.PermanentAddrTTL) + reqch = make(chan dialRequest) + resch = make(chan dialResponse) + worker4 := newDialWorker(s1, s2.LocalPeer(), reqch) + go worker4.loop() + defer worker4.wg.Wait() + + reqch <- dialRequest{ctx: context.Background(), resch: resch} + select { + case res := <-resch: + require.NoError(t, res.err) + if _, err := res.conn.LocalMultiaddr().ValueForProtocol(ma.P_TCP); err != nil { + t.Fatal("expected tcp connection") + } + case <-time.After(2 * time.Second): + t.Fatal("dial didn't complete") + } + close(reqch) + s1.ClosePeer(s2.LocalPeer()) + s1.peers.ClearAddrs(s2.LocalPeer()) +} diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index cd19e726ed..3f4444f356 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -100,6 +100,23 @@ func WithResourceManager(m network.ResourceManager) Option { } } +// WithNoDialDelay configures swarm to dial all addresses for a peer without +// any delay +func WithNoDialDelay() Option { + return func(s *Swarm) error { + s.dialRanker = noDelayRanker + return nil + } +} + +// WithDialRanker configures swarm to use d as the DialRanker +func WithDialRanker(d network.DialRanker) Option { + return func(s *Swarm) error { + s.dialRanker = d + return nil + } +} + // Swarm is a connection muxer, allowing connections to other peers to // be opened and closed, while still using the same Chan for all // communication. The Chan sends/receives Messages, which note the @@ -163,6 +180,8 @@ type Swarm struct { bwc metrics.Reporter metricsTracer MetricsTracer + + dialRanker network.DialRanker } // NewSwarm constructs a Swarm. @@ -181,6 +200,7 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, eventBus event.Bus, opts dialTimeout: defaultDialTimeout, dialTimeoutLocal: defaultDialTimeoutLocal, maResolver: madns.DefaultResolver, + dialRanker: defaultDialRanker, } s.conns.m = make(map[peer.ID][]*Conn) diff --git a/p2p/net/swarm/swarm_dial.go b/p2p/net/swarm/swarm_dial.go index 5423a199b7..256cff9dea 100644 --- a/p2p/net/swarm/swarm_dial.go +++ b/p2p/net/swarm/swarm_dial.go @@ -542,13 +542,6 @@ func isFdConsumingAddr(addr ma.Multiaddr) bool { return err1 == nil || err2 == nil } -func isExpensiveAddr(addr ma.Multiaddr) bool { - _, wsErr := addr.ValueForProtocol(ma.P_WS) - _, wssErr := addr.ValueForProtocol(ma.P_WSS) - _, wtErr := addr.ValueForProtocol(ma.P_WEBTRANSPORT) - return wsErr == nil || wssErr == nil || wtErr == nil -} - func isRelayAddr(addr ma.Multiaddr) bool { _, err := addr.ValueForProtocol(ma.P_CIRCUIT) return err == nil