Skip to content
This repository has been archived by the owner on May 26, 2022. It is now read-only.

Commit

Permalink
fix race condition when accepting hole-punched connections
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Jul 3, 2021
1 parent 28491d9 commit 62696b5
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 28 deletions.
20 changes: 8 additions & 12 deletions listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,22 +77,18 @@ func (l *listener) Accept() (tpt.CapableConn, error) {

// return through active hole punching if any
key := sess.RemoteAddr().String()
var wasHolePunch bool
l.transport.holePunchingMx.Lock()
holePunch, ok := l.transport.holePunching[key]
if ok && !holePunch.fulfilled {
holePunch.connCh <- conn
wasHolePunch = true
l.transport.holePunching[key].fulfilled = true
}
l.transport.holePunchingMx.Unlock()
if ok {
select {
case holePunch.connCh <- conn:
// We need to delete the entry from the map here,
// in case we accept two connections from the same address.
l.transport.holePunchingMx.Lock()
delete(l.transport.holePunching, key)
l.transport.holePunchingMx.Unlock()
continue
default:
}
if wasHolePunch {
continue
}

return conn, nil
}
}
Expand Down
53 changes: 37 additions & 16 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@ import (
"sync"
"time"

"github.com/libp2p/go-libp2p-core/connmgr"
n "github.com/libp2p/go-libp2p-core/network"

"github.com/minio/sha256-simd"
"golang.org/x/crypto/hkdf"

logging "github.com/ipfs/go-log"
"github.com/libp2p/go-libp2p-core/connmgr"
ic "github.com/libp2p/go-libp2p-core/crypto"
n "github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/pnet"
tpt "github.com/libp2p/go-libp2p-core/transport"
Expand Down Expand Up @@ -106,13 +105,14 @@ type transport struct {
gater connmgr.ConnectionGater

holePunchingMx sync.Mutex
holePunching map[string]activeHolePunch
holePunching map[string]*activeHolePunch
}

var _ tpt.Transport = &transport{}

type activeHolePunch struct {
connCh chan tpt.CapableConn
connCh chan tpt.CapableConn
fulfilled bool
}

// NewTransport creates a new QUIC transport
Expand Down Expand Up @@ -153,7 +153,7 @@ func NewTransport(key ic.PrivKey, psk pnet.PSK, gater connmgr.ConnectionGater) (
serverConfig: config,
clientConfig: config.Clone(),
gater: gater,
holePunching: make(map[string]activeHolePunch),
holePunching: make(map[string]*activeHolePunch),
}, nil
}

Expand Down Expand Up @@ -235,26 +235,34 @@ func (t *transport) holePunch(ctx context.Context, network string, addr *net.UDP
ctx, cancel := context.WithTimeout(ctx, HolePunchTimeout)
defer cancel()

connCh := make(chan tpt.CapableConn)

key := addr.String()
t.holePunchingMx.Lock()
t.holePunching[key] = activeHolePunch{connCh: connCh}
if _, ok := t.holePunching[key]; ok {
t.holePunchingMx.Unlock()
return nil, fmt.Errorf("already punching hole for %s", addr)
}
connCh := make(chan tpt.CapableConn, 1)
t.holePunching[key] = &activeHolePunch{connCh: connCh}
t.holePunchingMx.Unlock()

payload := make([]byte, 64)
var timer *time.Timer
defer func() {
if timer != nil {
timer.Stop()
}
}()

payload := make([]byte, 64)
var punchErr error
loop:
for i := 0; ; i++ {
if _, err := rand.Read(payload); err != nil {
return nil, err
punchErr = err
break
}
if _, err := pconn.UDPConn.WriteToUDP(payload, addr); err != nil {
return nil, err
punchErr = err
break
}

maxSleep := 10 * (i + 1) * (i + 1) // in ms
Expand All @@ -269,15 +277,28 @@ func (t *transport) holePunch(ctx context.Context, network string, addr *net.UDP
}
select {
case c := <-connCh:
return c, nil
case <-timer.C:
case <-ctx.Done():
t.holePunchingMx.Lock()
delete(t.holePunching, key)
t.holePunchingMx.Unlock()
return nil, ErrHolePunching
return c, nil
case <-timer.C:
case <-ctx.Done():
punchErr = ErrHolePunching
break loop
}
}
// we only arrive here if punchErr != nil
t.holePunchingMx.Lock()
defer func() {
delete(t.holePunching, key)
t.holePunchingMx.Unlock()
}()
select {
case c := <-t.holePunching[key].connCh:
return c, nil
default:
return nil, punchErr
}
}

// Don't use mafmt.QUIC as we don't want to dial DNS addresses. Just /ip{4,6}/udp/quic
Expand Down

0 comments on commit 62696b5

Please sign in to comment.