Skip to content

Commit

Permalink
swarm: implement smart dialing logic
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Apr 25, 2023
1 parent f7a45b6 commit ea7f3b4
Show file tree
Hide file tree
Showing 6 changed files with 769 additions and 135 deletions.
131 changes: 131 additions & 0 deletions p2p/net/swarm/dial_ranker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package swarm

import (
"time"

"github.com/libp2p/go-libp2p/core/network"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
)

const (
publicTCPDelay = 300 * time.Millisecond
privateTCPDelay = 30 * time.Millisecond
relayDelay = 500 * time.Millisecond
)

func noDelayRanker(addrs []ma.Multiaddr) []*network.AddrDelay {
res := make([]*network.AddrDelay, len(addrs))
for i, a := range addrs {
res[i] = &network.AddrDelay{Addr: a, Delay: 0}
}
return res
}

// defaultDialRanker is the default ranking logic.
//
// we consider private, public ip4, public ip6, relay addresses separately.
//
// In each group, if a quic address is present, we delay tcp addresses.
//
// private: 30 ms delay.
// public ip4: 300 ms delay.
// public ip6: 300 ms delay.
//
// If a quic-v1 address is present we don't dial quic or webtransport address on the same (ip,port) combination.
// If a tcp address is present we don't dial ws or wss address on the same (ip, port) combination.
// If direct addresses are present we delay all relay addresses by 500 millisecond
func defaultDialRanker(addrs []ma.Multiaddr) []*network.AddrDelay {
ip4 := make([]ma.Multiaddr, 0, len(addrs))
ip6 := make([]ma.Multiaddr, 0, len(addrs))
pvt := make([]ma.Multiaddr, 0, len(addrs))
relay := make([]ma.Multiaddr, 0, len(addrs))

res := make([]*network.AddrDelay, 0, len(addrs))
for _, a := range addrs {
switch {
case !manet.IsPublicAddr(a):
pvt = append(pvt, a)
case isProtocolAddr(a, ma.P_IP4):
ip4 = append(ip4, a)
case isProtocolAddr(a, ma.P_IP6):
ip6 = append(ip6, a)
case isRelayAddr(a):
relay = append(relay, a)
default:
res = append(res, &network.AddrDelay{Addr: a, Delay: 0})
}
}
var roffset time.Duration = 0
if len(ip4) > 0 || len(ip6) > 0 {
roffset = relayDelay
}

res = append(res, getAddrDelay(pvt, privateTCPDelay, 0)...)
res = append(res, getAddrDelay(ip4, publicTCPDelay, 0)...)
res = append(res, getAddrDelay(ip6, publicTCPDelay, 0)...)
res = append(res, getAddrDelay(relay, publicTCPDelay, roffset)...)
return res
}

func getAddrDelay(addrs []ma.Multiaddr, tcpDelay time.Duration, offset time.Duration) []*network.AddrDelay {
var hasQuic, hasQuicV1 bool
quicV1Addr := make(map[string]struct{})
tcpAddr := make(map[string]struct{})
for _, a := range addrs {
switch {
case isProtocolAddr(a, ma.P_WEBTRANSPORT):
case isProtocolAddr(a, ma.P_QUIC):
hasQuic = true
case isProtocolAddr(a, ma.P_QUIC_V1):
hasQuicV1 = true
quicV1Addr[addrPort(a, ma.P_UDP)] = struct{}{}
case isProtocolAddr(a, ma.P_WS) || isProtocolAddr(a, ma.P_WSS):
case isProtocolAddr(a, ma.P_TCP):
tcpAddr[addrPort(a, ma.P_TCP)] = struct{}{}
}
}

res := make([]*network.AddrDelay, 0, len(addrs))
for _, a := range addrs {
delay := offset
switch {
case isProtocolAddr(a, ma.P_WEBTRANSPORT):
if hasQuicV1 {
if _, ok := quicV1Addr[addrPort(a, ma.P_UDP)]; ok {
continue
}
}
case isProtocolAddr(a, ma.P_QUIC):
if hasQuicV1 {
if _, ok := quicV1Addr[addrPort(a, ma.P_UDP)]; ok {
continue
}
}
case isProtocolAddr(a, ma.P_WS) || isProtocolAddr(a, ma.P_WSS):
if _, ok := tcpAddr[addrPort(a, ma.P_TCP)]; ok {
continue
}
if hasQuic || hasQuicV1 {
delay = tcpDelay
}
case isProtocolAddr(a, ma.P_TCP):
if hasQuic || hasQuicV1 {
delay = tcpDelay
}
}
res = append(res, &network.AddrDelay{Addr: a, Delay: delay})
}
return res
}

func addrPort(a ma.Multiaddr, p int) string {
c, _ := ma.SplitFirst(a)
port, _ := a.ValueForProtocol(p)
return c.Value() + ":" + port
}

func isProtocolAddr(a ma.Multiaddr, p int) bool {
_, err := a.ValueForProtocol(p)
return err == nil
}
253 changes: 253 additions & 0 deletions p2p/net/swarm/dial_ranker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
package swarm

import (
"fmt"
"sort"
"testing"

"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/test"
ma "github.com/multiformats/go-multiaddr"
)

func TestNoDelayRanker(t *testing.T) {
addrs := []ma.Multiaddr{
ma.StringCast("/ip4/1.2.3.4/tcp/1"),
ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1"),
}
addrDelays := noDelayRanker(addrs)
if len(addrs) != len(addrDelays) {
t.Errorf("addrDelay should have the same number of elements as addr")
}

for _, a := range addrs {
for _, ad := range addrDelays {
if a.Equal(ad.Addr) {
if ad.Delay != 0 {
t.Errorf("expected 0 delay, got %s", ad.Delay)
}
}
}
}
}

func TestDelayRankerTCPDelay(t *testing.T) {
pquicv1 := ma.StringCast("/ip4/192.168.0.100/udp/1/quic-v1")
ptcp := ma.StringCast("/ip4/192.168.0.100/tcp/1/")

quic := ma.StringCast("/ip4/1.2.3.4/udp/1/quic")
quicv1 := ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1")
tcp := ma.StringCast("/ip4/1.2.3.5/tcp/1/")

tcp6 := ma.StringCast("/ip6/1::1/tcp/1")
quicv16 := ma.StringCast("/ip6/1::2/udp/1/quic-v1")

testCase := []struct {
name string
addrs []ma.Multiaddr
output []*network.AddrDelay
}{
{
name: "quic prioritised over tcp",
addrs: []ma.Multiaddr{quic, tcp},
output: []*network.AddrDelay{
{Addr: quic, Delay: 0},
{Addr: tcp, Delay: publicTCPDelay},
},
},
{
name: "quic-v1 prioritised over tcp",
addrs: []ma.Multiaddr{quicv1, tcp},
output: []*network.AddrDelay{
{Addr: quicv1, Delay: 0},
{Addr: tcp, Delay: publicTCPDelay},
},
},
{
name: "ip6 treated separately",
addrs: []ma.Multiaddr{quicv16, tcp6, quic},
output: []*network.AddrDelay{
{Addr: quicv16, Delay: 0},
{Addr: quic, Delay: 0},
{Addr: tcp6, Delay: publicTCPDelay},
},
},
{
name: "private addrs treated separately",
addrs: []ma.Multiaddr{pquicv1, ptcp},
output: []*network.AddrDelay{
{Addr: pquicv1, Delay: 0},
{Addr: ptcp, Delay: privateTCPDelay},
},
},
}
for _, tc := range testCase {
t.Run(tc.name, func(t *testing.T) {
res := defaultDialRanker(tc.addrs)
if len(res) != len(tc.output) {
for _, a := range res {
log.Errorf("%v", a)
}
for _, a := range tc.output {
log.Errorf("%v", a)
}
t.Errorf("expected elems: %d got: %d", len(tc.output), len(res))
}
sort.Slice(res, func(i, j int) bool {
if res[i].Delay == res[j].Delay {
return res[i].Addr.String() < res[j].Addr.String()
}
return res[i].Delay < res[j].Delay
})
sort.Slice(tc.output, func(i, j int) bool {
if tc.output[i].Delay == tc.output[j].Delay {
return tc.output[i].Addr.String() < tc.output[j].Addr.String()
}
return tc.output[i].Delay < tc.output[j].Delay
})
})
}
}

func TestDelayRankerAddrDropped(t *testing.T) {
pquic := ma.StringCast("/ip4/192.168.0.100/udp/1/quic")
pquicv1 := ma.StringCast("/ip4/192.168.0.100/udp/1/quic-v1")

quicAddr := ma.StringCast("/ip4/1.2.3.4/udp/1/quic")
quicAddr2 := ma.StringCast("/ip4/1.2.3.4/udp/2/quic")
quicv1Addr := ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1")
wt := ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1/webtransport/")
wt2 := ma.StringCast("/ip4/1.2.3.4/udp/2/quic-v1/webtransport/")

quic6 := ma.StringCast("/ip6/1::1/udp/1/quic")
quicv16 := ma.StringCast("/ip6/1::1/udp/1/quic-v1")

tcp := ma.StringCast("/ip4/1.2.3.5/tcp/1/")
ws := ma.StringCast("/ip4/1.2.3.5/tcp/1/ws")
ws2 := ma.StringCast("/ip4/1.2.3.4/tcp/1/ws")
wss := ma.StringCast("/ip4/1.2.3.5/tcp/1/wss")

testCase := []struct {
name string
addrs []ma.Multiaddr
output []*network.AddrDelay
}{
{
name: "quic dropped when quic-v1 present",
addrs: []ma.Multiaddr{quicAddr, quicv1Addr, quicAddr2},
output: []*network.AddrDelay{
{Addr: quicv1Addr, Delay: 0},
{Addr: quicAddr2, Delay: 0},
},
},
{
name: "webtransport dropped when quicv1 present",
addrs: []ma.Multiaddr{quicv1Addr, wt, wt2, quicAddr},
output: []*network.AddrDelay{
{Addr: quicv1Addr, Delay: 0},
{Addr: wt2, Delay: 0},
},
},
{
name: "ip6 quic dropped when quicv1 present",
addrs: []ma.Multiaddr{quicv16, quic6},
output: []*network.AddrDelay{
{Addr: quicv16, Delay: 0},
},
},
{
name: "web socket removed when tcp present",
addrs: []ma.Multiaddr{quicAddr, tcp, ws, wss, ws2},
output: []*network.AddrDelay{
{Addr: quicAddr, Delay: 0},
{Addr: tcp, Delay: publicTCPDelay},
{Addr: ws2, Delay: publicTCPDelay},
},
},
{
name: "private quic dropped when quiv1 present",
addrs: []ma.Multiaddr{pquic, pquicv1},
output: []*network.AddrDelay{
{Addr: pquicv1, Delay: 0},
},
},
}
for _, tc := range testCase {
t.Run(tc.name, func(t *testing.T) {
res := defaultDialRanker(tc.addrs)
if len(res) != len(tc.output) {
for _, a := range res {
log.Errorf("%v", a)
}
for _, a := range tc.output {
log.Errorf("%v", a)
}
t.Errorf("expected elems: %d got: %d", len(tc.output), len(res))
}
sort.Slice(res, func(i, j int) bool {
if res[i].Delay == res[j].Delay {
return res[i].Addr.String() < res[j].Addr.String()
}
return res[i].Delay < res[j].Delay
})
sort.Slice(tc.output, func(i, j int) bool {
if tc.output[i].Delay == tc.output[j].Delay {
return tc.output[i].Addr.String() < tc.output[j].Addr.String()
}
return tc.output[i].Delay < tc.output[j].Delay
})
})
}
}

func TestDelayRankerRelay(t *testing.T) {
quicAddr := ma.StringCast("/ip4/1.2.3.4/udp/1/quic")
quicAddr2 := ma.StringCast("/ip4/1.2.3.4/udp/2/quic")

pid := test.RandPeerIDFatal(t)
r1 := ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/tcp/1/p2p-circuit/p2p/%s", pid))
r2 := ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/udp/1/quic/p2p-circuit/p2p/%s", pid))

testCase := []struct {
name string
addrs []ma.Multiaddr
output []*network.AddrDelay
}{
{
name: "relay address delayed",
addrs: []ma.Multiaddr{quicAddr, quicAddr2, r1, r2},
output: []*network.AddrDelay{
{Addr: quicAddr, Delay: 0},
{Addr: quicAddr2, Delay: 0},
{Addr: r2, Delay: relayDelay},
{Addr: r1, Delay: publicTCPDelay + relayDelay},
},
},
}
for _, tc := range testCase {
t.Run(tc.name, func(t *testing.T) {
res := defaultDialRanker(tc.addrs)
if len(res) != len(tc.output) {
for _, a := range res {
log.Errorf("%v", a)
}
for _, a := range tc.output {
log.Errorf("%v", a)
}
t.Errorf("expected elems: %d got: %d", len(tc.output), len(res))
}
sort.Slice(res, func(i, j int) bool {
if res[i].Delay == res[j].Delay {
return res[i].Addr.String() < res[j].Addr.String()
}
return res[i].Delay < res[j].Delay
})
sort.Slice(tc.output, func(i, j int) bool {
if tc.output[i].Delay == tc.output[j].Delay {
return tc.output[i].Addr.String() < tc.output[j].Addr.String()
}
return tc.output[i].Delay < tc.output[j].Delay
})
})
}
}
Loading

0 comments on commit ea7f3b4

Please sign in to comment.