diff --git a/p2p/net/filter/filter.go b/p2p/net/filter/filter.go index a6d87c07f1..63a34b7d59 100644 --- a/p2p/net/filter/filter.go +++ b/p2p/net/filter/filter.go @@ -2,7 +2,6 @@ package filter import ( "net" - "strings" "sync" manet "gx/ipfs/QmYVqhVfbK4BKvbW88Lhm26b3ud14sTBvcm1H7uWUx1Fkp/go-multiaddr-net" @@ -27,18 +26,24 @@ func (fs *Filters) AddDialFilter(f *net.IPNet) { } func (f *Filters) AddrBlocked(a ma.Multiaddr) bool { - _, addr, err := manet.DialArgs(a) + maddr := ma.Split(a) + if len(maddr) == 0 { + return false + } + netaddr, err := manet.ToNetAddr(maddr[0]) if err != nil { // if we cant parse it, its probably not blocked return false } + netip := net.ParseIP(netaddr.String()) + if netip == nil { + return false + } - ipstr := strings.Split(addr, ":")[0] - ip := net.ParseIP(ipstr) f.mu.RLock() defer f.mu.RUnlock() for _, ft := range f.filters { - if ft.Contains(ip) { + if ft.Contains(netip) { return true } } diff --git a/p2p/net/filter/filter_test.go b/p2p/net/filter/filter_test.go new file mode 100644 index 0000000000..8a0c46a9f8 --- /dev/null +++ b/p2p/net/filter/filter_test.go @@ -0,0 +1,51 @@ +package filter + +import ( + "net" + "testing" + + ma "gx/ipfs/QmcobAGsCjYt5DXoq9et9L8yR8er7o7Cu3DTvpaq12jYSz/go-multiaddr" +) + +func TestFilter(t *testing.T) { + f := NewFilters() + for _, cidr := range []string{ + "1.2.3.0/24", + "4.3.2.1/32", + "fd00::/8", + "fc00::1/128", + } { + _, ipnet, _ := net.ParseCIDR(cidr) + f.AddDialFilter(ipnet) + } + + for _, blocked := range []string{ + "/ip4/1.2.3.4/tcp/123", + "/ip4/4.3.2.1/udp/123", + "/ip6/fd00::2/tcp/321", + "/ip6/fc00::1/udp/321", + } { + maddr, err := ma.NewMultiaddr(blocked) + if err != nil { + t.Error(err) + } + if !f.AddrBlocked(maddr) { + t.Fatalf("expected %s to be blocked", blocked) + } + } + + for _, notBlocked := range []string{ + "/ip4/1.2.4.1/tcp/123", + "/ip4/4.3.2.2/udp/123", + "/ip6/fe00::1/tcp/321", + "/ip6/fc00::2/udp/321", + } { + maddr, err := ma.NewMultiaddr(notBlocked) + if err != nil { + t.Error(err) + } + if f.AddrBlocked(maddr) { + t.Fatalf("expected %s to not be blocked", notBlocked) + } + } +} diff --git a/p2p/net/swarm/swarm_test.go b/p2p/net/swarm/swarm_test.go index c69c7db38f..6c81a9c14f 100644 --- a/p2p/net/swarm/swarm_test.go +++ b/p2p/net/swarm/swarm_test.go @@ -303,7 +303,7 @@ func TestAddrBlocking(t *testing.T) { swarms := makeSwarms(ctx, t, 2) swarms[0].SetConnHandler(func(conn *Conn) { - t.Fatal("no connections should happen!") + t.Fatalf("no connections should happen! -- %s", conn) }) _, block, err := net.ParseCIDR("127.0.0.1/8")