From 38ad6602c98cb5a417ccc5a28c030524d89e323a Mon Sep 17 00:00:00 2001 From: vyzo Date: Mon, 5 Jul 2021 03:00:27 +0300 Subject: [PATCH] add hole punching support (#194) --- conn_test.go | 53 +++++++++++++++++++++- go.mod | 2 +- go.sum | 4 +- listener.go | 18 +++++++- transport.go | 110 ++++++++++++++++++++++++++++++++++++++++++++-- transport_test.go | 2 +- 6 files changed, 179 insertions(+), 10 deletions(-) diff --git a/conn_test.go b/conn_test.go index 96559cc..f962008 100644 --- a/conn_test.go +++ b/conn_test.go @@ -11,14 +11,14 @@ import ( "sync/atomic" "time" - gomock "github.com/golang/mock/gomock" ic "github.com/libp2p/go-libp2p-core/crypto" + n "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" tpt "github.com/libp2p/go-libp2p-core/transport" - quicproxy "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy" ma "github.com/multiformats/go-multiaddr" + "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -345,4 +345,53 @@ var _ = Describe("Connection", func() { Expect(rerr).To(HaveOccurred()) Expect(rerr.Error()).To(ContainSubstring("received a stateless reset")) }) + + It("hole punches", func() { + t1, err := NewTransport(serverKey, nil, nil) + Expect(err).ToNot(HaveOccurred()) + laddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic") + Expect(err).ToNot(HaveOccurred()) + ln1, err := t1.Listen(laddr) + Expect(err).ToNot(HaveOccurred()) + done1 := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done1) + if _, err := ln1.Accept(); err == nil { + Fail("didn't expect to accept any connections") + } + }() + + t2, err := NewTransport(clientKey, nil, nil) + Expect(err).ToNot(HaveOccurred()) + ln2, err := t2.Listen(laddr) + Expect(err).ToNot(HaveOccurred()) + done2 := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done2) + if _, err := ln2.Accept(); err == nil { + Fail("didn't expect to accept any connections") + } + }() + connChan := make(chan tpt.CapableConn) + go func() { + defer GinkgoRecover() + conn, err := t2.Dial(n.WithSimultaneousConnect(context.Background(), ""), ln1.Multiaddr(), serverID) + Expect(err).ToNot(HaveOccurred()) + connChan <- conn + }() + conn1, err := t1.Dial(n.WithSimultaneousConnect(context.Background(), ""), ln2.Multiaddr(), clientID) + Expect(err).ToNot(HaveOccurred()) + defer conn1.Close() + Expect(conn1.RemotePeer()).To(Equal(clientID)) + var conn2 tpt.CapableConn + Eventually(connChan).Should(Receive(&conn2)) + defer conn2.Close() + Expect(conn2.RemotePeer()).To(Equal(serverID)) + ln1.Close() + ln2.Close() + Eventually(done1).Should(BeClosed()) + Eventually(done2).Should(BeClosed()) + }) }) diff --git a/go.mod b/go.mod index 966cee0..423b1c4 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/golang/mock v1.6.0 github.com/ipfs/go-log v1.0.4 github.com/klauspost/compress v1.11.7 - github.com/libp2p/go-libp2p-core v0.8.0 + github.com/libp2p/go-libp2p-core v0.8.5 github.com/libp2p/go-libp2p-tls v0.1.3 github.com/libp2p/go-netroute v0.1.3 github.com/lucas-clemente/quic-go v0.21.1 diff --git a/go.sum b/go.sum index b924e0d..5c39641 100644 --- a/go.sum +++ b/go.sum @@ -226,8 +226,8 @@ github.com/libp2p/go-buffer-pool v0.0.2/go.mod h1:MvaB6xw5vOrDl8rYZGLFdKAuk/hRoR github.com/libp2p/go-flow-metrics v0.0.1/go.mod h1:Iv1GH0sG8DtYN3SVJ2eG221wMiNpZxBdp967ls1g+k8= github.com/libp2p/go-flow-metrics v0.0.3/go.mod h1:HeoSNUrOJVK1jEpDqVEiUOIXqhbnS27omG0uWU5slZs= github.com/libp2p/go-libp2p-core v0.0.1/go.mod h1:g/VxnTZ/1ygHxH3dKok7Vno1VfpvGcGip57wjTU4fco= -github.com/libp2p/go-libp2p-core v0.8.0 h1:5K3mT+64qDTKbV3yTdbMCzJ7O6wbNsavAEb8iqBvBcI= -github.com/libp2p/go-libp2p-core v0.8.0/go.mod h1:FfewUH/YpvWbEB+ZY9AQRQ4TAD8sJBt/G1rVvhz5XT8= +github.com/libp2p/go-libp2p-core v0.8.5 h1:aEgbIcPGsKy6zYcC+5AJivYFedhYa4sW7mIpWpUaLKw= +github.com/libp2p/go-libp2p-core v0.8.5/go.mod h1:FfewUH/YpvWbEB+ZY9AQRQ4TAD8sJBt/G1rVvhz5XT8= github.com/libp2p/go-libp2p-tls v0.1.3 h1:twKMhMu44jQO+HgQK9X8NHO5HkeJu2QbhLzLJpa8oNM= github.com/libp2p/go-libp2p-tls v0.1.3/go.mod h1:wZfuewxOndz5RTnCAxFliGjvYSDA40sKitV4c50uI1M= github.com/libp2p/go-maddr-filter v0.1.0/go.mod h1:VzZhTXkMucEGGEOSKddrwGiOv0tUhgnKqNEmIAz/bPU= diff --git a/listener.go b/listener.go index a61f312..d91a249 100644 --- a/listener.go +++ b/listener.go @@ -12,7 +12,7 @@ import ( p2ptls "github.com/libp2p/go-libp2p-tls" - quic "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go" ma "github.com/multiformats/go-multiaddr" ) @@ -74,6 +74,21 @@ func (l *listener) Accept() (tpt.CapableConn, error) { sess.CloseWithError(errorCodeConnectionGating, "connection gated") continue } + + // return through active hole punching if any + key := holePunchKey{addr: sess.RemoteAddr().String(), peer: conn.remotePeerID} + var wasHolePunch bool + l.transport.holePunchingMx.Lock() + holePunch, ok := l.transport.holePunching[key] + if ok && !holePunch.fulfilled { + holePunch.connCh <- conn + wasHolePunch = true + holePunch.fulfilled = true + } + l.transport.holePunchingMx.Unlock() + if wasHolePunch { + continue + } return conn, nil } } @@ -92,6 +107,7 @@ func (l *listener) setupConn(sess quic.Session) (*conn, error) { if err != nil { return nil, err } + remoteMultiaddr, err := toQuicMultiaddr(sess.RemoteAddr()) if err != nil { return nil, err diff --git a/transport.go b/transport.go index 7e92916..b7d962e 100644 --- a/transport.go +++ b/transport.go @@ -1,20 +1,23 @@ package libp2pquic import ( + "bytes" "context" "errors" "fmt" "io" + "math/rand" "net" - - "github.com/libp2p/go-libp2p-core/connmgr" - n "github.com/libp2p/go-libp2p-core/network" + "sync" + "time" "github.com/minio/sha256-simd" "golang.org/x/crypto/hkdf" logging "github.com/ipfs/go-log" + "github.com/libp2p/go-libp2p-core/connmgr" ic "github.com/libp2p/go-libp2p-core/crypto" + n "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/pnet" tpt "github.com/libp2p/go-libp2p-core/transport" @@ -27,8 +30,12 @@ import ( var log = logging.Logger("quic-transport") +var ErrHolePunching = errors.New("hole punching attempted; no active dial") + var quicDialContext = quic.DialContext // so we can mock it in tests +var HolePunchTimeout = 5 * time.Second + var quicConfig = &quic.Config{ MaxIncomingStreams: 1000, MaxIncomingUniStreams: -1, // disable unidirectional streams @@ -96,10 +103,23 @@ type transport struct { serverConfig *quic.Config clientConfig *quic.Config gater connmgr.ConnectionGater + + holePunchingMx sync.Mutex + holePunching map[holePunchKey]*activeHolePunch } var _ tpt.Transport = &transport{} +type holePunchKey struct { + addr string + peer peer.ID +} + +type activeHolePunch struct { + connCh chan tpt.CapableConn + fulfilled bool +} + // NewTransport creates a new QUIC transport func NewTransport(key ic.PrivKey, psk pnet.PSK, gater connmgr.ConnectionGater) (tpt.Transport, error) { if len(psk) > 0 { @@ -138,6 +158,7 @@ func NewTransport(key ic.PrivKey, psk pnet.PSK, gater connmgr.ConnectionGater) ( serverConfig: config, clientConfig: config.Clone(), gater: gater, + holePunching: make(map[holePunchKey]*activeHolePunch), }, nil } @@ -156,6 +177,13 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp return nil, err } tlsConf, keyCh := t.identity.ConfigForPeer(p) + + if simConnect, _ := n.GetSimultaneousConnect(ctx); simConnect { + if bytes.Compare([]byte(t.localPeer), []byte(p)) < 0 { + return t.holePunch(ctx, network, addr, p) + } + } + pconn, err := t.connManager.Dial(network, addr) if err != nil { return nil, err @@ -202,6 +230,82 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp return conn, nil } +func (t *transport) holePunch(ctx context.Context, network string, addr *net.UDPAddr, p peer.ID) (tpt.CapableConn, error) { + pconn, err := t.connManager.Dial(network, addr) + if err != nil { + return nil, err + } + defer pconn.DecreaseCount() + + ctx, cancel := context.WithTimeout(ctx, HolePunchTimeout) + defer cancel() + + key := holePunchKey{addr: addr.String(), peer: p} + t.holePunchingMx.Lock() + if _, ok := t.holePunching[key]; ok { + t.holePunchingMx.Unlock() + return nil, fmt.Errorf("already punching hole for %s", addr) + } + connCh := make(chan tpt.CapableConn, 1) + t.holePunching[key] = &activeHolePunch{connCh: connCh} + t.holePunchingMx.Unlock() + + var timer *time.Timer + defer func() { + if timer != nil { + timer.Stop() + } + }() + + payload := make([]byte, 64) + var punchErr error +loop: + for i := 0; ; i++ { + if _, err := rand.Read(payload); err != nil { + punchErr = err + break + } + if _, err := pconn.UDPConn.WriteToUDP(payload, addr); err != nil { + punchErr = err + break + } + + maxSleep := 10 * (i + 1) * (i + 1) // in ms + if maxSleep > 200 { + maxSleep = 200 + } + d := 10*time.Millisecond + time.Duration(rand.Intn(maxSleep))*time.Millisecond + if timer == nil { + timer = time.NewTimer(d) + } else { + timer.Reset(d) + } + select { + case c := <-connCh: + t.holePunchingMx.Lock() + delete(t.holePunching, key) + t.holePunchingMx.Unlock() + return c, nil + case <-timer.C: + case <-ctx.Done(): + punchErr = ErrHolePunching + break loop + } + } + // we only arrive here if punchErr != nil + t.holePunchingMx.Lock() + defer func() { + delete(t.holePunching, key) + t.holePunchingMx.Unlock() + }() + select { + case c := <-t.holePunching[key].connCh: + return c, nil + default: + return nil, punchErr + } +} + // Don't use mafmt.QUIC as we don't want to dial DNS addresses. Just /ip{4,6}/udp/quic var dialMatcher = mafmt.And(mafmt.IP, mafmt.Base(ma.P_UDP), mafmt.Base(ma.P_QUIC)) diff --git a/transport_test.go b/transport_test.go index 226ec2f..00685fc 100644 --- a/transport_test.go +++ b/transport_test.go @@ -11,7 +11,7 @@ import ( ic "github.com/libp2p/go-libp2p-core/crypto" tpt "github.com/libp2p/go-libp2p-core/transport" - quic "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go" ma "github.com/multiformats/go-multiaddr" . "github.com/onsi/ginkgo"