diff --git a/config/config.go b/config/config.go index bee597545b..94e18be27f 100644 --- a/config/config.go +++ b/config/config.go @@ -30,6 +30,7 @@ import ( circuitv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" + "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" ma "github.com/multiformats/go-multiaddr" madns "github.com/multiformats/go-multiaddr-dns" @@ -78,6 +79,7 @@ type Config struct { PeerKey crypto.PrivKey + QUICReuse []fx.Option Transports []fx.Option Muxers []tptu.StreamMuxer SecurityTransports []Security @@ -239,6 +241,13 @@ func (cfg *Config) addTransports(h host.Host) error { ))) } + fxopts = append(fxopts, fx.Provide(PrivKeyToStatelessResetKey)) + if cfg.QUICReuse != nil { + fxopts = append(fxopts, cfg.QUICReuse...) + } else { + fxopts = append(fxopts, fx.Provide(quicreuse.NewConnManager)) // TODO: close the ConnManager when shutting down the node + } + fxopts = append(fxopts, fx.Invoke( fx.Annotate( func(tpts []transport.Transport) error { diff --git a/config/quic_stateless_reset.go b/config/quic_stateless_reset.go new file mode 100644 index 0000000000..3cbb6970ac --- /dev/null +++ b/config/quic_stateless_reset.go @@ -0,0 +1,27 @@ +package config + +import ( + "crypto/sha256" + "io" + + "golang.org/x/crypto/hkdf" + + "github.com/libp2p/go-libp2p/core/crypto" + + "github.com/lucas-clemente/quic-go" +) + +const statelessResetKeyInfo = "libp2p quic stateless reset key" + +func PrivKeyToStatelessResetKey(key crypto.PrivKey) (quic.StatelessResetKey, error) { + var statelessResetKey quic.StatelessResetKey + keyBytes, err := key.Raw() + if err != nil { + return statelessResetKey, err + } + keyReader := hkdf.New(sha256.New, keyBytes, nil, []byte(statelessResetKeyInfo)) + if _, err := io.ReadFull(keyReader, statelessResetKey[:]); err != nil { + return statelessResetKey, err + } + return statelessResetKey, nil +} diff --git a/go.mod b/go.mod index fc9c2b8640..f26865a3c4 100644 --- a/go.mod +++ b/go.mod @@ -33,7 +33,7 @@ require ( github.com/libp2p/zeroconf/v2 v2.2.0 github.com/lucas-clemente/quic-go v0.31.0 github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd - github.com/marten-seemann/webtransport-go v0.3.0 + github.com/marten-seemann/webtransport-go v0.4.0 github.com/mikioh/tcpinfo v0.0.0-20190314235526-30a79bb1804b github.com/minio/sha256-simd v1.0.0 github.com/mr-tron/base58 v1.2.0 diff --git a/go.sum b/go.sum index 37a7db8fef..51e24299fa 100644 --- a/go.sum +++ b/go.sum @@ -333,8 +333,8 @@ github.com/marten-seemann/qtls-go1-19 v0.1.1 h1:mnbxeq3oEyQxQXwI4ReCgW9DPoPR94sN github.com/marten-seemann/qtls-go1-19 v0.1.1/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI= github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd h1:br0buuQ854V8u83wA0rVZ8ttrq5CpaPZdvrK0LP2lOk= github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd/go.mod h1:QuCEs1Nt24+FYQEqAAncTDPJIuGs+LxK1MCiFL25pMU= -github.com/marten-seemann/webtransport-go v0.3.0 h1:TqUSf7/qZN8bJyuGrDMz9nDrfMbgH8p7KqV3TYrkBgo= -github.com/marten-seemann/webtransport-go v0.3.0/go.mod h1:4xcfySgZMLP4aG5GBGj1egP7NlpfwgYJ1WJMvPPiVMU= +github.com/marten-seemann/webtransport-go v0.4.0 h1:seNdLfPIEQCZFrWlSF/o8jfx2DBib08lSyt95iC0jhs= +github.com/marten-seemann/webtransport-go v0.4.0/go.mod h1:4xcfySgZMLP4aG5GBGj1egP7NlpfwgYJ1WJMvPPiVMU= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= diff --git a/libp2p_test.go b/libp2p_test.go index 9a469bd1a5..c5f667610c 100644 --- a/libp2p_test.go +++ b/libp2p_test.go @@ -173,7 +173,7 @@ func TestChainOptions(t *testing.T) { func TestTransportConstructorTCP(t *testing.T) { h, err := New( - Transport(tcp.NewTCPTransport), + Transport(tcp.NewTCPTransport, tcp.DisableReuseport()), DisableRelay(), ) require.NoError(t, err) @@ -186,7 +186,7 @@ func TestTransportConstructorTCP(t *testing.T) { func TestTransportConstructorQUIC(t *testing.T) { h, err := New( - Transport(quic.NewTransport, quic.DisableReuseport()), + Transport(quic.NewTransport), DisableRelay(), ) require.NoError(t, err) @@ -248,7 +248,7 @@ func TestTransportConstructorWithWrongOpts(t *testing.T) { Transport(quic.NewTransport, tcp.DisableReuseport()), DisableRelay(), ) - require.EqualError(t, err, "transport option of type tcp.Option not assignable to libp2pquic.Option") + require.EqualError(t, err, "transport constructor doesn't take any options") } func TestSecurityConstructor(t *testing.T) { diff --git a/options.go b/options.go index 1da2a70d14..381706f7d2 100644 --- a/options.go +++ b/options.go @@ -25,6 +25,7 @@ import ( tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader" relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" + "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" ma "github.com/multiformats/go-multiaddr" madns "github.com/multiformats/go-multiaddr-dns" @@ -96,6 +97,33 @@ func Muxer(name string, muxer network.Multiplexer) Option { } } +func QUICReuse(constructor interface{}, opts ...quicreuse.Option) Option { + return func(cfg *Config) error { + tag := `group:"quicreuseopts"` + typ := reflect.ValueOf(constructor).Type() + numParams := typ.NumIn() + isVariadic := typ.IsVariadic() + + if !isVariadic && len(opts) > 0 { + return errors.New("QUICReuse constructor doesn't take any options") + } + + var params []string + if isVariadic && len(opts) > 0 { + // If there are options, apply the tag. + // Since options are variadic, they have to be the last argument of the constructor. + params = make([]string, numParams) + params[len(params)-1] = tag + } + + cfg.QUICReuse = append(cfg.QUICReuse, fx.Provide(fx.Annotate(constructor, fx.ParamTags(params...)))) + for _, opt := range opts { + cfg.QUICReuse = append(cfg.QUICReuse, fx.Supply(fx.Annotate(opt, fx.ResultTags(tag)))) + } + return nil + } +} + // Transport configures libp2p to use the given transport (or transport // constructor). // diff --git a/p2p/net/swarm/dial_worker_test.go b/p2p/net/swarm/dial_worker_test.go index 346550c1ea..3ec0d9cee5 100644 --- a/p2p/net/swarm/dial_worker_test.go +++ b/p2p/net/swarm/dial_worker_test.go @@ -19,6 +19,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/muxer/yamux" tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader" quic "github.com/libp2p/go-libp2p/p2p/transport/quic" + "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" "github.com/libp2p/go-libp2p/p2p/transport/tcp" ma "github.com/multiformats/go-multiaddr" @@ -58,7 +59,11 @@ func makeSwarm(t *testing.T) *Swarm { t.Fatal(err) } - quicTransport, err := quic.NewTransport(priv, nil, nil, nil) + reuse, err := quicreuse.NewConnManager([32]byte{}) + if err != nil { + t.Fatal(err) + } + quicTransport, err := quic.NewTransport(priv, reuse, nil, nil, nil) if err != nil { t.Fatal(err) } diff --git a/p2p/net/swarm/swarm_addr_test.go b/p2p/net/swarm/swarm_addr_test.go index 4809c4c01e..42efcfef82 100644 --- a/p2p/net/swarm/swarm_addr_test.go +++ b/p2p/net/swarm/swarm_addr_test.go @@ -5,15 +5,15 @@ import ( "fmt" "testing" - "github.com/libp2p/go-libp2p/core/peer" - ic "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/test" "github.com/libp2p/go-libp2p/p2p/net/swarm" swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" circuitv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" quic "github.com/libp2p/go-libp2p/p2p/transport/quic" + "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" "github.com/libp2p/go-libp2p/p2p/transport/tcp" webtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport" @@ -80,10 +80,13 @@ func TestDialAddressSelection(t *testing.T) { tcpTr, err := tcp.NewTCPTransport(nil, nil) require.NoError(t, err) require.NoError(t, s.AddTransport(tcpTr)) - quicTr, err := quic.NewTransport(priv, nil, nil, nil) + reuse, err := quicreuse.NewConnManager([32]byte{}) + require.NoError(t, err) + defer reuse.Close() + quicTr, err := quic.NewTransport(priv, reuse, nil, nil, nil) require.NoError(t, err) require.NoError(t, s.AddTransport(quicTr)) - webtransportTr, err := webtransport.New(priv, nil, nil) + webtransportTr, err := webtransport.New(priv, reuse, nil, nil) require.NoError(t, err) require.NoError(t, s.AddTransport(webtransportTr)) h := sha256.Sum256([]byte("foo")) diff --git a/p2p/net/swarm/testing/testing.go b/p2p/net/swarm/testing/testing.go index 08fe7ad0f3..ea062cfd68 100644 --- a/p2p/net/swarm/testing/testing.go +++ b/p2p/net/swarm/testing/testing.go @@ -5,6 +5,8 @@ import ( "testing" "time" + "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" + "github.com/libp2p/go-libp2p/core/connmgr" "github.com/libp2p/go-libp2p/core/control" "github.com/libp2p/go-libp2p/core/crypto" @@ -160,7 +162,11 @@ func GenSwarm(t *testing.T, opts ...Option) *swarm.Swarm { } } if !cfg.disableQUIC { - quicTransport, err := quic.NewTransport(priv, nil, cfg.connectionGater, nil) + reuse, err := quicreuse.NewConnManager([32]byte{}) + if err != nil { + t.Fatal(err) + } + quicTransport, err := quic.NewTransport(priv, reuse, nil, cfg.connectionGater, nil) if err != nil { t.Fatal(err) } diff --git a/p2p/test/quic/quic_test.go b/p2p/test/quic/quic_test.go new file mode 100644 index 0000000000..a0263833f8 --- /dev/null +++ b/p2p/test/quic/quic_test.go @@ -0,0 +1,169 @@ +package quic_test + +import ( + "context" + "testing" + + "github.com/libp2p/go-libp2p" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + libp2pquic "github.com/libp2p/go-libp2p/p2p/transport/quic" + "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" + webtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport" + + ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/require" +) + +func getQUICMultiaddrCode(addr ma.Multiaddr) int { + if _, err := addr.ValueForProtocol(ma.P_QUIC); err == nil { + return ma.P_QUIC + } + if _, err := addr.ValueForProtocol(ma.P_QUIC_V1); err == nil { + return ma.P_QUIC_V1 + } + return 0 +} + +func TestQUICVersions(t *testing.T) { + h1, err := libp2p.New( + libp2p.Transport(libp2pquic.NewTransport), + libp2p.Transport(webtransport.New), + libp2p.ListenAddrStrings( + "/ip4/127.0.0.1/udp/12345/quic", // QUIC draft-29 + "/ip4/127.0.0.1/udp/12345/quic-v1", // QUIC v1 + ), + ) + require.NoError(t, err) + defer h1.Close() + + addrs := h1.Addrs() + require.Len(t, addrs, 2) + var quicDraft29Addr, quicV1Addr ma.Multiaddr + for _, addr := range addrs { + switch getQUICMultiaddrCode(addr) { + case ma.P_QUIC: + quicDraft29Addr = addr + case ma.P_QUIC_V1: + quicV1Addr = addr + } + } + require.NotNil(t, quicDraft29Addr, "expected to be listening on a QUIC draft-29 address") + require.NotNil(t, quicV1Addr, "expected to be listening on a QUIC v1 address") + + // connect using QUIC draft-29 + h2, err := libp2p.New( + libp2p.Transport(libp2pquic.NewTransport), + ) + require.NoError(t, err) + require.NoError(t, h2.Connect(context.Background(), peer.AddrInfo{ID: h1.ID(), Addrs: []ma.Multiaddr{quicDraft29Addr}})) + conns := h2.Network().ConnsToPeer(h1.ID()) + require.Len(t, conns, 1) + require.Equal(t, ma.P_QUIC, getQUICMultiaddrCode(conns[0].LocalMultiaddr())) + require.Equal(t, ma.P_QUIC, getQUICMultiaddrCode(conns[0].RemoteMultiaddr())) + h2.Close() + + // connect using QUIC v1 + h3, err := libp2p.New( + libp2p.Transport(libp2pquic.NewTransport), + ) + require.NoError(t, err) + require.NoError(t, h3.Connect(context.Background(), peer.AddrInfo{ID: h1.ID(), Addrs: []ma.Multiaddr{quicV1Addr}})) + conns = h3.Network().ConnsToPeer(h1.ID()) + require.Len(t, conns, 1) + require.Equal(t, ma.P_QUIC_V1, getQUICMultiaddrCode(conns[0].LocalMultiaddr())) + require.Equal(t, ma.P_QUIC_V1, getQUICMultiaddrCode(conns[0].RemoteMultiaddr())) + h3.Close() +} + +func TestDisableQUICDraft29(t *testing.T) { + h1, err := libp2p.New( + libp2p.QUICReuse(quicreuse.NewConnManager, quicreuse.DisableDraft29()), + libp2p.Transport(libp2pquic.NewTransport), + libp2p.Transport(webtransport.New), + libp2p.ListenAddrStrings( + "/ip4/127.0.0.1/udp/12346/quic", // QUIC draft-29 + "/ip4/127.0.0.1/udp/12346/quic-v1", // QUIC v1 + ), + ) + require.NoError(t, err) + defer h1.Close() + + addrs := h1.Addrs() + require.Len(t, addrs, 1) + require.Equal(t, ma.P_QUIC_V1, getQUICMultiaddrCode(addrs[0])) + + // connect using QUIC draft-29 + h2, err := libp2p.New( + libp2p.Transport(libp2pquic.NewTransport), + ) + require.NoError(t, err) + defer h2.Close() + require.ErrorContains(t, + h2.Connect(context.Background(), peer.AddrInfo{ID: h1.ID(), Addrs: []ma.Multiaddr{ma.StringCast("/ip4/127.0.0.1/udp/12346/quic")}}), + "no compatible QUIC version found", + ) + // make sure that dialing QUIC v1 works + require.NoError(t, h2.Connect(context.Background(), peer.AddrInfo{ID: h1.ID(), Addrs: []ma.Multiaddr{addrs[0]}})) +} + +func TestQUICAndWebTransport(t *testing.T) { + h1, err := libp2p.New( + libp2p.QUICReuse(quicreuse.NewConnManager, quicreuse.DisableDraft29()), + libp2p.Transport(libp2pquic.NewTransport), + libp2p.Transport(webtransport.New), + libp2p.ListenAddrStrings( + "/ip4/127.0.0.1/udp/12347/quic-v1", + "/ip4/127.0.0.1/udp/12347/quic-v1/webtransport", + ), + ) + require.NoError(t, err) + defer h1.Close() + + addrs := h1.Addrs() + require.Len(t, addrs, 2) + require.Equal(t, ma.P_QUIC_V1, getQUICMultiaddrCode(addrs[0])) + require.Equal(t, ma.P_QUIC_V1, getQUICMultiaddrCode(addrs[1])) + var quicAddr, webtransportAddr ma.Multiaddr + for _, addr := range addrs { + if _, err := addr.ValueForProtocol(ma.P_WEBTRANSPORT); err == nil { + webtransportAddr = addr + } else { + quicAddr = addr + } + } + require.NotNil(t, webtransportAddr, "expected to have a WebTransport address") + require.NotNil(t, quicAddr, "expected to have a QUIC v1 address") + + h2, err := libp2p.New( + libp2p.Transport(libp2pquic.NewTransport), + libp2p.NoListenAddrs, + ) + require.NoError(t, err) + require.NoError(t, h2.Connect(context.Background(), peer.AddrInfo{ID: h1.ID(), Addrs: h1.Addrs()})) + for _, conns := range [][]network.Conn{h2.Network().ConnsToPeer(h1.ID()), h1.Network().ConnsToPeer(h2.ID())} { + require.Len(t, conns, 1) + if _, err := conns[0].LocalMultiaddr().ValueForProtocol(ma.P_WEBTRANSPORT); err == nil { + t.Fatalf("expected a QUIC connection, got a WebTransport connection (%s <-> %s)", conns[0].LocalMultiaddr(), conns[0].RemoteMultiaddr()) + } + require.Equal(t, ma.P_QUIC_V1, getQUICMultiaddrCode(conns[0].LocalMultiaddr())) + require.Equal(t, ma.P_QUIC_V1, getQUICMultiaddrCode(conns[0].RemoteMultiaddr())) + } + h2.Close() + + h3, err := libp2p.New( + libp2p.Transport(webtransport.New), + libp2p.NoListenAddrs, + ) + require.NoError(t, err) + require.NoError(t, h3.Connect(context.Background(), peer.AddrInfo{ID: h1.ID(), Addrs: h1.Addrs()})) + for _, conns := range [][]network.Conn{h3.Network().ConnsToPeer(h1.ID()), h1.Network().ConnsToPeer(h3.ID())} { + require.Len(t, conns, 1) + if _, err := conns[0].LocalMultiaddr().ValueForProtocol(ma.P_WEBTRANSPORT); err != nil { + t.Fatalf("expected a WebTransport connection, got a QUIC connection (%s <-> %s)", conns[0].LocalMultiaddr(), conns[0].RemoteMultiaddr()) + } + require.Equal(t, ma.P_QUIC_V1, getQUICMultiaddrCode(conns[0].LocalMultiaddr())) + require.Equal(t, ma.P_QUIC_V1, getQUICMultiaddrCode(conns[0].RemoteMultiaddr())) + } + h3.Close() +} diff --git a/p2p/transport/quic/cmd/client/main.go b/p2p/transport/quic/cmd/client/main.go index f8071b6e6d..f33d65ecd4 100644 --- a/p2p/transport/quic/cmd/client/main.go +++ b/p2p/transport/quic/cmd/client/main.go @@ -8,6 +8,8 @@ import ( "log" "os" + "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" + ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peer" libp2pquic "github.com/libp2p/go-libp2p/p2p/transport/quic" @@ -39,7 +41,11 @@ func run(raddr string, p string) error { return err } - t, err := libp2pquic.NewTransport(priv, nil, nil, nil) + reuse, err := quicreuse.NewConnManager([32]byte{}) + if err != nil { + return err + } + t, err := libp2pquic.NewTransport(priv, reuse, nil, nil, nil) if err != nil { return err } diff --git a/p2p/transport/quic/cmd/server/main.go b/p2p/transport/quic/cmd/server/main.go index 7122de6058..a939c4725f 100644 --- a/p2p/transport/quic/cmd/server/main.go +++ b/p2p/transport/quic/cmd/server/main.go @@ -7,6 +7,8 @@ import ( "log" "os" + "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" + ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peer" tpt "github.com/libp2p/go-libp2p/core/transport" @@ -39,7 +41,11 @@ func run(port string) error { return err } - t, err := libp2pquic.NewTransport(priv, nil, nil, nil) + reuse, err := quicreuse.NewConnManager([32]byte{}) + if err != nil { + return err + } + t, err := libp2pquic.NewTransport(priv, reuse, nil, nil, nil) if err != nil { return err } diff --git a/p2p/transport/quic/conn.go b/p2p/transport/quic/conn.go index cc74b68cb2..999615ceb8 100644 --- a/p2p/transport/quic/conn.go +++ b/p2p/transport/quic/conn.go @@ -2,7 +2,6 @@ package libp2pquic import ( "context" - "net" ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/network" @@ -13,17 +12,8 @@ import ( ma "github.com/multiformats/go-multiaddr" ) -type pConn interface { - net.PacketConn - - // count conn reference - DecreaseCount() - IncreaseCount() -} - type conn struct { quicConn quic.Connection - pconn pConn transport *transport scope network.ConnManagementScope @@ -44,7 +34,6 @@ var _ tpt.CapableConn = &conn{} func (c *conn) Close() error { c.transport.removeConn(c.quicConn) err := c.quicConn.CloseWithError(0, "") - c.pconn.DecreaseCount() c.scope.Done() return err } diff --git a/p2p/transport/quic/conn_test.go b/p2p/transport/quic/conn_test.go index 42d56293e3..770377869a 100644 --- a/p2p/transport/quic/conn_test.go +++ b/p2p/transport/quic/conn_test.go @@ -13,6 +13,8 @@ import ( "testing" "time" + "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" + ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/network" mocknetwork "github.com/libp2p/go-libp2p/core/network/mocks" @@ -30,12 +32,12 @@ import ( type connTestCase struct { Name string - Options []Option + Options []quicreuse.Option } var connTestCases = []*connTestCase{ - {"reuseport_on", []Option{DisableDraft29()}}, - {"reuseport_off", []Option{DisableReuseport(), DisableDraft29()}}, + {"reuseport_on", []quicreuse.Option{quicreuse.DisableDraft29()}}, + {"reuseport_off", []quicreuse.Option{quicreuse.DisableReuseport(), quicreuse.DisableDraft29()}}, } func createPeer(t *testing.T) (peer.ID, ic.PrivKey) { @@ -66,6 +68,14 @@ func runServer(t *testing.T, tr tpt.Transport, addr string) tpt.Listener { return ln } +func newConnManager(t *testing.T, opts ...quicreuse.Option) *quicreuse.ConnManager { + t.Helper() + cm, err := quicreuse.NewConnManager([32]byte{}, opts...) + require.NoError(t, err) + t.Cleanup(func() { cm.Close() }) + return cm +} + func TestHandshake(t *testing.T) { for _, tc := range connTestCases { t.Run(tc.Name, func(t *testing.T) { @@ -77,12 +87,13 @@ func TestHandshake(t *testing.T) { func testHandshake(t *testing.T, tc *connTestCase) { serverID, serverKey := createPeer(t) clientID, clientKey := createPeer(t) - serverTransport, err := NewTransport(serverKey, nil, nil, nil, tc.Options...) + + serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) defer serverTransport.(io.Closer).Close() handshake := func(t *testing.T, ln tpt.Listener) { - clientTransport, err := NewTransport(clientKey, nil, nil, nil, tc.Options...) + clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) defer clientTransport.(io.Closer).Close() conn, err := clientTransport.Dial(context.Background(), ln.Multiaddrs()[0], serverID) @@ -132,7 +143,7 @@ func testResourceManagerSuccess(t *testing.T, tc *connTestCase) { defer ctrl.Finish() serverRcmgr := mocknetwork.NewMockResourceManager(ctrl) - serverTransport, err := NewTransport(serverKey, nil, nil, serverRcmgr, tc.Options...) + serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, serverRcmgr) require.NoError(t, err) defer serverTransport.(io.Closer).Close() ln, err := serverTransport.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1")) @@ -140,7 +151,7 @@ func testResourceManagerSuccess(t *testing.T, tc *connTestCase) { defer ln.Close() clientRcmgr := mocknetwork.NewMockResourceManager(ctrl) - clientTransport, err := NewTransport(clientKey, nil, nil, clientRcmgr, tc.Options...) + clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, clientRcmgr) require.NoError(t, err) defer clientTransport.(io.Closer).Close() @@ -181,7 +192,7 @@ func testResourceManagerDialDenied(t *testing.T, tc *connTestCase) { defer ctrl.Finish() rcmgr := mocknetwork.NewMockResourceManager(ctrl) - clientTransport, err := NewTransport(clientKey, nil, nil, rcmgr, tc.Options...) + clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, rcmgr) require.NoError(t, err) defer clientTransport.(io.Closer).Close() @@ -214,7 +225,7 @@ func testResourceManagerAcceptDenied(t *testing.T, tc *connTestCase) { defer ctrl.Finish() clientRcmgr := mocknetwork.NewMockResourceManager(ctrl) - clientTransport, err := NewTransport(clientKey, nil, nil, clientRcmgr, tc.Options...) + clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, clientRcmgr) require.NoError(t, err) defer clientTransport.(io.Closer).Close() @@ -226,7 +237,7 @@ func testResourceManagerAcceptDenied(t *testing.T, tc *connTestCase) { serverConnScope.EXPECT().SetPeer(clientID).Return(rerr), serverConnScope.EXPECT().Done(), ) - serverTransport, err := NewTransport(serverKey, nil, nil, serverRcmgr, tc.Options...) + serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, serverRcmgr) require.NoError(t, err) defer serverTransport.(io.Closer).Close() ln, err := serverTransport.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1")) @@ -269,13 +280,13 @@ func testStreams(t *testing.T, tc *connTestCase) { serverID, serverKey := createPeer(t) _, clientKey := createPeer(t) - serverTransport, err := NewTransport(serverKey, nil, nil, nil, tc.Options...) + serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) defer serverTransport.(io.Closer).Close() ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic-v1") defer ln.Close() - clientTransport, err := NewTransport(clientKey, nil, nil, nil, tc.Options...) + clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) defer clientTransport.(io.Closer).Close() conn, err := clientTransport.Dial(context.Background(), ln.Multiaddrs()[0], serverID) @@ -310,12 +321,12 @@ func testHandshakeFailPeerIDMismatch(t *testing.T, tc *connTestCase) { _, clientKey := createPeer(t) thirdPartyID, _ := createPeer(t) - serverTransport, err := NewTransport(serverKey, nil, nil, nil, tc.Options...) + serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) defer serverTransport.(io.Closer).Close() ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic-v1") - clientTransport, err := NewTransport(clientKey, nil, nil, nil, tc.Options...) + clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) // dial, but expect the wrong peer ID _, err = clientTransport.Dial(context.Background(), ln.Multiaddrs()[0], thirdPartyID) @@ -356,7 +367,7 @@ func testConnectionGating(t *testing.T, tc *connTestCase) { cg := NewMockConnectionGater(mockCtrl) t.Run("accepted connections", func(t *testing.T) { - serverTransport, err := NewTransport(serverKey, nil, cg, nil, tc.Options...) + serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, cg, nil) defer serverTransport.(io.Closer).Close() require.NoError(t, err) ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic-v1") @@ -371,7 +382,7 @@ func testConnectionGating(t *testing.T, tc *connTestCase) { require.NoError(t, err) }() - clientTransport, err := NewTransport(clientKey, nil, nil, nil, tc.Options...) + clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) defer clientTransport.(io.Closer).Close() // make sure that connection attempts fails @@ -386,7 +397,6 @@ func testConnectionGating(t *testing.T, tc *connTestCase) { // now allow the address and make sure the connection goes through cg.EXPECT().InterceptAccept(gomock.Any()).Return(true) cg.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) - clientTransport.(*transport).clientConfig.HandshakeIdleTimeout = 2 * time.Second conn, err = clientTransport.Dial(context.Background(), ln.Multiaddrs()[0], serverID) require.NoError(t, err) defer conn.Close() @@ -401,7 +411,7 @@ func testConnectionGating(t *testing.T, tc *connTestCase) { }) t.Run("secured connections", func(t *testing.T) { - serverTransport, err := NewTransport(serverKey, nil, nil, nil, tc.Options...) + serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) defer serverTransport.(io.Closer).Close() ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic-v1") @@ -410,7 +420,7 @@ func testConnectionGating(t *testing.T, tc *connTestCase) { cg := NewMockConnectionGater(mockCtrl) cg.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()) - clientTransport, err := NewTransport(clientKey, nil, cg, nil, tc.Options...) + clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, cg, nil) require.NoError(t, err) defer clientTransport.(io.Closer).Close() @@ -421,7 +431,6 @@ func testConnectionGating(t *testing.T, tc *connTestCase) { // now allow the peerId and make sure the connection goes through cg.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) - clientTransport.(*transport).clientConfig.HandshakeIdleTimeout = 2 * time.Second conn, err := clientTransport.Dial(context.Background(), ln.Multiaddrs()[0], serverID) require.NoError(t, err) conn.Close() @@ -441,12 +450,12 @@ func testDialTwo(t *testing.T, tc *connTestCase) { _, clientKey := createPeer(t) serverID2, serverKey2 := createPeer(t) - serverTransport, err := NewTransport(serverKey, nil, nil, nil, tc.Options...) + serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) defer serverTransport.(io.Closer).Close() ln1 := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic-v1") defer ln1.Close() - serverTransport2, err := NewTransport(serverKey2, nil, nil, nil, tc.Options...) + serverTransport2, err := NewTransport(serverKey2, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) defer serverTransport2.(io.Closer).Close() ln2 := runServer(t, serverTransport2, "/ip4/127.0.0.1/udp/0/quic-v1") @@ -472,7 +481,7 @@ func testDialTwo(t *testing.T, tc *connTestCase) { } }() - clientTransport, err := NewTransport(clientKey, nil, nil, nil, tc.Options...) + clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) defer clientTransport.(io.Closer).Close() c1, err := clientTransport.Dial(context.Background(), ln1.Multiaddrs()[0], serverID) @@ -517,41 +526,28 @@ func TestStatelessReset(t *testing.T) { } func testStatelessReset(t *testing.T, tc *connTestCase) { - origGarbageCollectInterval := garbageCollectInterval - origMaxUnusedDuration := maxUnusedDuration - - garbageCollectInterval = 50 * time.Millisecond - maxUnusedDuration = 0 - - t.Cleanup(func() { - garbageCollectInterval = origGarbageCollectInterval - maxUnusedDuration = origMaxUnusedDuration - }) - serverID, serverKey := createPeer(t) _, clientKey := createPeer(t) - serverTransport, err := NewTransport(serverKey, nil, nil, nil, tc.Options...) + serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) defer serverTransport.(io.Closer).Close() ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic-v1") var drop uint32 - serverPort := ln.Addr().(*net.UDPAddr).Port + dropCallback := func(quicproxy.Direction, []byte) bool { return atomic.LoadUint32(&drop) > 0 } proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), - DropPacket: func(quicproxy.Direction, []byte) bool { - return atomic.LoadUint32(&drop) > 0 - }, + RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), + DropPacket: dropCallback, }) require.NoError(t, err) - defer proxy.Close() + proxyLocalAddr := proxy.LocalAddr() // establish a connection - clientTransport, err := NewTransport(clientKey, nil, nil, nil, tc.Options...) + clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) defer clientTransport.(io.Closer).Close() - proxyAddr, err := toQuicMultiaddr(proxy.LocalAddr(), quic.Version1) + proxyAddr, err := quicreuse.ToQuicMultiaddr(proxy.LocalAddr(), quic.Version1) require.NoError(t, err) conn, err := clientTransport.Dial(context.Background(), proxyAddr, serverID) require.NoError(t, err) @@ -577,18 +573,23 @@ func testStatelessReset(t *testing.T, tc *connTestCase) { atomic.StoreUint32(&drop, 1) ln.Close() (<-connChan).Close() + proxy.Close() - // The kernel might take a while to free up the UPD port. - // Retry starting the listener until we're successful (with a 3s timeout). - require.Eventually(t, func() bool { - var err error - ln, err = serverTransport.Listen(ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/%d/quic-v1", serverPort))) - return err == nil - }, 3*time.Second, 50*time.Millisecond) + // Start another listener (on a different port). + ln, err = serverTransport.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1")) + require.NoError(t, err) defer ln.Close() // Now that the new server is up, re-enable packet forwarding. atomic.StoreUint32(&drop, 0) + // Recreate the proxy, such that its client-facing port stays constant. + proxy, err = quicproxy.NewQuicProxy(proxyLocalAddr.String(), &quicproxy.Opts{ + RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), + DropPacket: dropCallback, + }) + require.NoError(t, err) + defer proxy.Close() + // Trigger something (not too small) to be sent, so that we receive the stateless reset. // The new server doesn't have any state for the previously established connection. // We expect it to send a stateless reset. @@ -597,7 +598,8 @@ func testStatelessReset(t *testing.T, tc *connTestCase) { _, rerr = str.Read([]byte{0, 0}) } require.Error(t, rerr) - require.Contains(t, rerr.Error(), "received a stateless reset") + var statelessResetErr *quic.StatelessResetError + require.ErrorAs(t, rerr, &statelessResetErr) } // Hole punching is only expected to work with reuseport enabled. @@ -606,7 +608,7 @@ func TestHolePunching(t *testing.T) { serverID, serverKey := createPeer(t) clientID, clientKey := createPeer(t) - t1, err := NewTransport(serverKey, nil, nil, nil) + t1, err := NewTransport(serverKey, newConnManager(t), nil, nil, nil) require.NoError(t, err) defer t1.(io.Closer).Close() laddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic-v1") @@ -620,7 +622,7 @@ func TestHolePunching(t *testing.T) { require.Error(t, err, "didn't expect to accept any connections") }() - t2, err := NewTransport(clientKey, nil, nil, nil) + t2, err := NewTransport(clientKey, newConnManager(t), nil, nil, nil) require.NoError(t, err) defer t2.(io.Closer).Close() ln2, err := t2.Listen(laddr) @@ -679,7 +681,7 @@ func TestHolePunching(t *testing.T) { func TestGetErrorWhenListeningWithDraft29WhenDisabled(t *testing.T) { _, serverKey := createPeer(t) - t1, err := NewTransport(serverKey, nil, nil, nil, DisableDraft29()) + t1, err := NewTransport(serverKey, newConnManager(t, quicreuse.DisableDraft29()), nil, nil, nil) require.NoError(t, err) defer t1.(io.Closer).Close() laddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic") @@ -710,12 +712,12 @@ func TestClientCanDialDifferentQUICVersions(t *testing.T) { serverID, serverKey := createPeer(t) _, clientKey := createPeer(t) - var serverOpts []Option + var serverOpts []quicreuse.Option if tc.serverDisablesDraft29 { - serverOpts = append(serverOpts, DisableDraft29()) + serverOpts = append(serverOpts, quicreuse.DisableDraft29()) } - t1, err := NewTransport(serverKey, nil, nil, nil, serverOpts...) + t1, err := NewTransport(serverKey, newConnManager(t, serverOpts...), nil, nil, nil) require.NoError(t, err) defer t1.(io.Closer).Close() laddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic-v1") @@ -723,14 +725,14 @@ func TestClientCanDialDifferentQUICVersions(t *testing.T) { ln1, err := t1.Listen(laddr) require.NoError(t, err) - t2, err := NewTransport(clientKey, nil, nil, nil) + t2, err := NewTransport(clientKey, newConnManager(t), nil, nil, nil) require.NoError(t, err) defer t2.(io.Closer).Close() ctx := context.Background() for _, a := range ln1.Multiaddrs() { - _, v, err := fromQuicMultiaddr(a) + _, v, err := quicreuse.FromQuicMultiaddr(a) require.NoError(t, err) done := make(chan struct{}) @@ -740,9 +742,9 @@ func TestClientCanDialDifferentQUICVersions(t *testing.T) { require.NoError(t, err) defer conn.Close() - _, versionConnLocal, err := fromQuicMultiaddr(conn.LocalMultiaddr()) + _, versionConnLocal, err := quicreuse.FromQuicMultiaddr(conn.LocalMultiaddr()) require.NoError(t, err) - _, versionConnRemote, err := fromQuicMultiaddr(conn.RemoteMultiaddr()) + _, versionConnRemote, err := quicreuse.FromQuicMultiaddr(conn.RemoteMultiaddr()) require.NoError(t, err) require.Equal(t, v, versionConnLocal) @@ -751,9 +753,9 @@ func TestClientCanDialDifferentQUICVersions(t *testing.T) { conn, err := t2.Dial(ctx, a, serverID) require.NoError(t, err) - _, versionConnLocal, err := fromQuicMultiaddr(conn.LocalMultiaddr()) + _, versionConnLocal, err := quicreuse.FromQuicMultiaddr(conn.LocalMultiaddr()) require.NoError(t, err) - _, versionConnRemote, err := fromQuicMultiaddr(conn.RemoteMultiaddr()) + _, versionConnRemote, err := quicreuse.FromQuicMultiaddr(conn.RemoteMultiaddr()) require.NoError(t, err) require.Equal(t, v, versionConnLocal) diff --git a/p2p/transport/quic/listener.go b/p2p/transport/quic/listener.go index 3bdba01762..ec7929fc94 100644 --- a/p2p/transport/quic/listener.go +++ b/p2p/transport/quic/listener.go @@ -2,7 +2,6 @@ package libp2pquic import ( "context" - "crypto/tls" "errors" "net" @@ -11,17 +10,15 @@ import ( "github.com/libp2p/go-libp2p/core/peer" tpt "github.com/libp2p/go-libp2p/core/transport" p2ptls "github.com/libp2p/go-libp2p/p2p/security/tls" + "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" "github.com/lucas-clemente/quic-go" ma "github.com/multiformats/go-multiaddr" ) -var quicListen = quic.Listen // so we can mock it in tests - // A listener listens for QUIC connections. type listener struct { - quicListener quic.Listener - conn pConn + reuseListener quicreuse.Listener transport *transport rcmgr network.ResourceManager privKey ic.PrivKey @@ -31,38 +28,19 @@ type listener struct { var _ tpt.Listener = &listener{} -func newListener(pconn pConn, t *transport, localPeer peer.ID, key ic.PrivKey, identity *p2ptls.Identity, rcmgr network.ResourceManager, enableDraft29 bool) (tpt.Listener, error) { - var tlsConf tls.Config - tlsConf.GetConfigForClient = func(_ *tls.ClientHelloInfo) (*tls.Config, error) { - // return a tls.Config that verifies the peer's certificate chain. - // Note that since we have no way of associating an incoming QUIC connection with - // the peer ID calculated here, we don't actually receive the peer's public key - // from the key chan. - conf, _ := identity.ConfigForPeer("") - return conf, nil - } - ln, err := quicListen(pconn, &tlsConf, t.serverConfig) - if err != nil { - return nil, err - } - localMultiaddr, err := toQuicMultiaddr(ln.Addr(), quic.Version1) - if err != nil { - return nil, err - } - - localMultiaddrs := map[quic.VersionNumber]ma.Multiaddr{quic.Version1: localMultiaddr} - - if enableDraft29 { - localMultiaddr, err := toQuicMultiaddr(ln.Addr(), quic.VersionDraft29) - if err != nil { - return nil, err +func newListener(ln quicreuse.Listener, t *transport, localPeer peer.ID, key ic.PrivKey, rcmgr network.ResourceManager) (tpt.Listener, error) { + localMultiaddrs := make(map[quic.VersionNumber]ma.Multiaddr) + for _, addr := range ln.Multiaddrs() { + if _, err := addr.ValueForProtocol(ma.P_QUIC); err == nil { + localMultiaddrs[quic.VersionDraft29] = addr + } + if _, err := addr.ValueForProtocol(ma.P_QUIC_V1); err == nil { + localMultiaddrs[quic.Version1] = addr } - localMultiaddrs[quic.VersionDraft29] = localMultiaddr } return &listener{ - conn: pconn, - quicListener: ln, + reuseListener: ln, transport: t, rcmgr: rcmgr, privKey: key, @@ -74,13 +52,13 @@ func newListener(pconn pConn, t *transport, localPeer peer.ID, key ic.PrivKey, i // Accept accepts new connections. func (l *listener) Accept() (tpt.CapableConn, error) { for { - qconn, err := l.quicListener.Accept(context.Background()) + qconn, err := l.reuseListener.Accept(context.Background()) if err != nil { return nil, err } c, err := l.setupConn(qconn) if err != nil { - qconn.CloseWithError(0, err.Error()) + qconn.CloseWithError(1, err.Error()) continue } if l.transport.gater != nil && !(l.transport.gater.InterceptAccept(c) && l.transport.gater.InterceptSecured(network.DirInbound, c.remotePeerID, c)) { @@ -109,7 +87,7 @@ func (l *listener) Accept() (tpt.CapableConn, error) { } func (l *listener) setupConn(qconn quic.Connection) (*conn, error) { - remoteMultiaddr, err := toQuicMultiaddr(qconn.RemoteAddr(), qconn.ConnectionState().Version) + remoteMultiaddr, err := quicreuse.ToQuicMultiaddr(qconn.RemoteAddr(), qconn.ConnectionState().Version) if err != nil { return nil, err } @@ -144,10 +122,8 @@ func (l *listener) setupConn(qconn quic.Connection) (*conn, error) { return nil, errors.New("unknown QUIC version:" + qconn.ConnectionState().Version.String()) } - l.conn.IncreaseCount() return &conn{ quicConn: qconn, - pconn: l.conn, transport: l.transport, scope: connScope, localPeer: l.localPeer, @@ -161,23 +137,12 @@ func (l *listener) setupConn(qconn quic.Connection) (*conn, error) { // Close closes the listener. func (l *listener) Close() error { - defer l.conn.DecreaseCount() - - if err := l.quicListener.Close(); err != nil { - return err - } - - if _, ok := l.conn.(*noreuseConn); ok { - // if we use a `noreuseConn`, close the underlying connection - return l.conn.Close() - } - - return nil + return l.reuseListener.Close() } // Addr returns the address of this listener. func (l *listener) Addr() net.Addr { - return l.quicListener.Addr() + return l.reuseListener.Addr() } // Multiaddr returns the multiaddress of this listener. diff --git a/p2p/transport/quic/listener_test.go b/p2p/transport/quic/listener_test.go index a43ad2fed2..d45f065e18 100644 --- a/p2p/transport/quic/listener_test.go +++ b/p2p/transport/quic/listener_test.go @@ -3,9 +3,7 @@ package libp2pquic import ( "crypto/rand" "crypto/rsa" - "crypto/tls" "crypto/x509" - "errors" "fmt" "io" "net" @@ -16,7 +14,6 @@ import ( "github.com/libp2p/go-libp2p/core/network" tpt "github.com/libp2p/go-libp2p/core/transport" - "github.com/lucas-clemente/quic-go" ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/require" ) @@ -26,35 +23,11 @@ func newTransport(t *testing.T, rcmgr network.ResourceManager) tpt.Transport { require.NoError(t, err) key, err := ic.UnmarshalRsaPrivateKey(x509.MarshalPKCS1PrivateKey(rsaKey)) require.NoError(t, err) - tr, err := NewTransport(key, nil, nil, rcmgr) + tr, err := NewTransport(key, newConnManager(t), nil, nil, rcmgr) require.NoError(t, err) return tr } -// The conn passed to quic-go should be a conn that quic-go can be -// type-asserted to a UDPConn. That way, it can use all kinds of optimizations. -func TestConnUsedForListening(t *testing.T) { - origQuicListen := quicListen - t.Cleanup(func() { quicListen = origQuicListen }) - - var conn net.PacketConn - quicListen = func(c net.PacketConn, _ *tls.Config, _ *quic.Config) (quic.Listener, error) { - conn = c - return nil, errors.New("listen error") - } - localAddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic") - require.NoError(t, err) - - tr := newTransport(t, nil) - defer tr.(io.Closer).Close() - _, err = tr.Listen(localAddr) - require.EqualError(t, err, "listen error") - require.NotNil(t, conn) - defer conn.Close() - _, ok := conn.(quic.OOBCapablePacketConn) - require.True(t, ok) -} - func TestListenAddr(t *testing.T) { tr := newTransport(t, nil) defer tr.(io.Closer).Close() diff --git a/p2p/transport/quic/options.go b/p2p/transport/quic/options.go deleted file mode 100644 index 8a7c8ce399..0000000000 --- a/p2p/transport/quic/options.go +++ /dev/null @@ -1,44 +0,0 @@ -package libp2pquic - -type Option func(opts *config) error - -type config struct { - disableReuseport bool - disableDraft29 bool - metrics bool -} - -func (cfg *config) apply(opts ...Option) error { - for _, opt := range opts { - if err := opt(cfg); err != nil { - return err - } - } - - return nil -} - -func DisableReuseport() Option { - return func(cfg *config) error { - cfg.disableReuseport = true - return nil - } -} - -// DisableDraft29 disables support for QUIC draft-29. -// This option should be set, unless support for this legacy QUIC version is needed for backwards compatibility. -// Support for QUIC draft-29 is already deprecated and will be removed in the future, see https://github.com/libp2p/go-libp2p/issues/1841. -func DisableDraft29() Option { - return func(cfg *config) error { - cfg.disableDraft29 = true - return nil - } -} - -// WithMetrics enables Prometheus metrics collection. -func WithMetrics() Option { - return func(cfg *config) error { - cfg.metrics = true - return nil - } -} diff --git a/p2p/transport/quic/transport.go b/p2p/transport/quic/transport.go index 7f598bf3c9..0c34d43640 100644 --- a/p2p/transport/quic/transport.go +++ b/p2p/transport/quic/transport.go @@ -2,15 +2,15 @@ package libp2pquic import ( "context" + "crypto/tls" "errors" "fmt" - "io" "math/rand" "net" "sync" "time" - "golang.org/x/crypto/hkdf" + manet "github.com/multiformats/go-multiaddr/net" "github.com/libp2p/go-libp2p/core/connmgr" ic "github.com/libp2p/go-libp2p/core/crypto" @@ -19,138 +19,35 @@ import ( "github.com/libp2p/go-libp2p/core/pnet" tpt "github.com/libp2p/go-libp2p/core/transport" p2ptls "github.com/libp2p/go-libp2p/p2p/security/tls" - "github.com/libp2p/go-libp2p/p2p/transport/internal/quicutils" + "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" ma "github.com/multiformats/go-multiaddr" mafmt "github.com/multiformats/go-multiaddr-fmt" - manet "github.com/multiformats/go-multiaddr/net" logging "github.com/ipfs/go-log/v2" "github.com/lucas-clemente/quic-go" - quiclogging "github.com/lucas-clemente/quic-go/logging" - "github.com/minio/sha256-simd" ) var log = logging.Logger("quic-transport") var ErrHolePunching = errors.New("hole punching attempted; no active dial") -var quicDialContext = quic.DialContext // so we can mock it in tests - var HolePunchTimeout = 5 * time.Second -var quicConfig = &quic.Config{ - MaxIncomingStreams: 256, - MaxIncomingUniStreams: -1, // disable unidirectional streams - MaxStreamReceiveWindow: 10 * (1 << 20), // 10 MB - MaxConnectionReceiveWindow: 15 * (1 << 20), // 15 MB - RequireAddressValidation: func(net.Addr) bool { - // TODO(#1535): require source address validation when under load - return false - }, - KeepAlivePeriod: 15 * time.Second, - Versions: []quic.VersionNumber{quic.VersionDraft29, quic.Version1}, -} - -const statelessResetKeyInfo = "libp2p quic stateless reset key" const errorCodeConnectionGating = 0x47415445 // GATE in ASCII -type noreuseConn struct { - *net.UDPConn -} - -func (c *noreuseConn) IncreaseCount() {} -func (c *noreuseConn) DecreaseCount() {} - -type connManager struct { - reuseUDP4 *reuse - reuseUDP6 *reuse - reuseportEnable bool -} - -func newConnManager(reuseport bool) (*connManager, error) { - reuseUDP4 := newReuse() - reuseUDP6 := newReuse() - return &connManager{ - reuseUDP4: reuseUDP4, - reuseUDP6: reuseUDP6, - reuseportEnable: reuseport, - }, nil -} - -func (c *connManager) getReuse(network string) (*reuse, error) { - switch network { - case "udp4": - return c.reuseUDP4, nil - case "udp6": - return c.reuseUDP6, nil - default: - return nil, errors.New("invalid network: must be either udp4 or udp6") - } -} - -func (c *connManager) Listen(network string, laddr *net.UDPAddr) (pConn, error) { - if c.reuseportEnable { - reuse, err := c.getReuse(network) - if err != nil { - return nil, err - } - return reuse.Listen(network, laddr) - } - - conn, err := net.ListenUDP(network, laddr) - if err != nil { - return nil, err - } - return &noreuseConn{conn}, nil -} - -func (c *connManager) Dial(network string, raddr *net.UDPAddr) (pConn, error) { - if c.reuseportEnable { - reuse, err := c.getReuse(network) - if err != nil { - return nil, err - } - return reuse.Dial(network, raddr) - } - - var laddr *net.UDPAddr - switch network { - case "udp4": - laddr = &net.UDPAddr{IP: net.IPv4zero, Port: 0} - case "udp6": - laddr = &net.UDPAddr{IP: net.IPv6zero, Port: 0} - } - conn, err := net.ListenUDP(network, laddr) - if err != nil { - return nil, err - } - return &noreuseConn{conn}, nil -} - -func (c *connManager) Close() error { - if err := c.reuseUDP6.Close(); err != nil { - return err - } - return c.reuseUDP4.Close() -} - // The Transport implements the tpt.Transport interface for QUIC connections. type transport struct { - privKey ic.PrivKey - localPeer peer.ID - identity *p2ptls.Identity - connManager *connManager - serverConfig *quic.Config - clientConfig *quic.Config - gater connmgr.ConnectionGater - rcmgr network.ResourceManager + privKey ic.PrivKey + localPeer peer.ID + identity *p2ptls.Identity + connManager *quicreuse.ConnManager + gater connmgr.ConnectionGater + rcmgr network.ResourceManager holePunchingMx sync.Mutex holePunching map[holePunchKey]*activeHolePunch - enableDraft29 bool - connMx sync.Mutex conns map[quic.Connection]*conn } @@ -168,12 +65,7 @@ type activeHolePunch struct { } // NewTransport creates a new QUIC transport -func NewTransport(key ic.PrivKey, psk pnet.PSK, gater connmgr.ConnectionGater, rcmgr network.ResourceManager, opts ...Option) (tpt.Transport, error) { - var cfg config - if err := cfg.apply(opts...); err != nil { - return nil, fmt.Errorf("unable to apply quic-tpt option(s): %w", err) - } - +func NewTransport(key ic.PrivKey, connManager *quicreuse.ConnManager, psk pnet.PSK, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) (tpt.Transport, error) { if len(psk) > 0 { log.Error("QUIC doesn't support private networks yet.") return nil, errors.New("QUIC doesn't support private networks yet") @@ -186,77 +78,28 @@ func NewTransport(key ic.PrivKey, psk pnet.PSK, gater connmgr.ConnectionGater, r if err != nil { return nil, err } - connManager, err := newConnManager(!cfg.disableReuseport) - if err != nil { - return nil, err - } + if rcmgr == nil { rcmgr = &network.NullResourceManager{} } - qconfig := quicConfig.Clone() - if cfg.disableDraft29 { - qconfig.Versions = []quic.VersionNumber{quic.Version1} - } - - keyBytes, err := key.Raw() - if err != nil { - return nil, err - } - keyReader := hkdf.New(sha256.New, keyBytes, nil, []byte(statelessResetKeyInfo)) - var statelessResetKey quic.StatelessResetKey - if _, err := io.ReadFull(keyReader, statelessResetKey[:]); err != nil { - return nil, err - } - qconfig.StatelessResetKey = &statelessResetKey - var tracers []quiclogging.Tracer - if qlogTracer := quicutils.QLOGTracer; qlogTracer != nil { - tracers = append(tracers, qlogTracer) - } - if cfg.metrics { - tracers = append(tracers, &metricsTracer{}) - } - if len(tracers) > 0 { - qconfig.Tracer = quiclogging.NewMultiplexedTracer(tracers...) - } - tr := &transport{ - privKey: key, - localPeer: localPeer, - identity: identity, - connManager: connManager, - gater: gater, - rcmgr: rcmgr, - conns: make(map[quic.Connection]*conn), - holePunching: make(map[holePunchKey]*activeHolePunch), - enableDraft29: !cfg.disableDraft29, - } - qconfig.AllowConnectionWindowIncrease = tr.allowWindowIncrease - tr.serverConfig = qconfig - tr.clientConfig = qconfig.Clone() - return tr, nil + return &transport{ + privKey: key, + localPeer: localPeer, + identity: identity, + connManager: connManager, + gater: gater, + rcmgr: rcmgr, + conns: make(map[quic.Connection]*conn), + holePunching: make(map[holePunchKey]*activeHolePunch), + }, nil } // Dial dials a new QUIC connection func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { - _, v, err := fromQuicMultiaddr(raddr) - if err != nil { - return nil, err - } - netw, host, err := manet.DialArgs(raddr) - if err != nil { - return nil, err - } - addr, err := net.ResolveUDPAddr(netw, host) - if err != nil { - return nil, err - } - remoteMultiaddr, err := toQuicMultiaddr(addr, v) - if err != nil { - return nil, err - } tlsConf, keyCh := t.identity.ConfigForPeer(p) if ok, isClient, _ := network.GetSimultaneousConnect(ctx); ok && !isClient { - return t.holePunch(ctx, netw, addr, p) + return t.holePunch(ctx, raddr, p) } scope, err := t.rcmgr.OpenConnection(network.DirOutbound, false, raddr) @@ -269,27 +112,11 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp scope.Done() return nil, err } - pconn, err := t.connManager.Dial(netw, addr) + pconn, err := t.connManager.DialQUIC(ctx, raddr, tlsConf, t.allowWindowIncrease) if err != nil { return nil, err } - clientConfig := t.clientConfig.Clone() - if v == quic.Version1 { - // The endpoint has explicit support for version 1, so we'll only use that version. - clientConfig.Versions = []quic.VersionNumber{quic.Version1} - } else if v == quic.VersionDraft29 { - clientConfig.Versions = []quic.VersionNumber{quic.VersionDraft29} - } else { - return nil, errors.New("unknown QUIC version") - } - - qconn, err := quicDialContext(ctx, pconn, addr, host, tlsConf, clientConfig) - if err != nil { - scope.Done() - pconn.DecreaseCount() - return nil, err - } // Should be ready by this point, don't block. var remotePubKey ic.PubKey select { @@ -297,19 +124,18 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp default: } if remotePubKey == nil { - pconn.DecreaseCount() + pconn.CloseWithError(1, "") scope.Done() return nil, errors.New("p2p/transport/quic BUG: expected remote pub key to be set") } - localMultiaddr, err := toQuicMultiaddr(pconn.LocalAddr(), v) + localMultiaddr, err := quicreuse.ToQuicMultiaddr(pconn.LocalAddr(), pconn.ConnectionState().Version) if err != nil { - qconn.CloseWithError(0, "") + pconn.CloseWithError(1, "") return nil, err } c := &conn{ - quicConn: qconn, - pconn: pconn, + quicConn: pconn, transport: t, scope: scope, privKey: t.privKey, @@ -317,13 +143,13 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp localMultiaddr: localMultiaddr, remotePubKey: remotePubKey, remotePeerID: p, - remoteMultiaddr: remoteMultiaddr, + remoteMultiaddr: raddr, } if t.gater != nil && !t.gater.InterceptSecured(network.DirOutbound, p, c) { - qconn.CloseWithError(errorCodeConnectionGating, "connection gated") + pconn.CloseWithError(errorCodeConnectionGating, "connection gated") return nil, fmt.Errorf("secured connection gated") } - t.addConn(qconn, c) + t.addConn(pconn, c) return c, nil } @@ -339,7 +165,15 @@ func (t *transport) removeConn(conn quic.Connection) { t.connMx.Unlock() } -func (t *transport) holePunch(ctx context.Context, network string, addr *net.UDPAddr, p peer.ID) (tpt.CapableConn, error) { +func (t *transport) holePunch(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { + network, saddr, err := manet.DialArgs(raddr) + if err != nil { + return nil, err + } + addr, err := net.ResolveUDPAddr(network, saddr) + if err != nil { + return nil, err + } pconn, err := t.connManager.Dial(network, addr) if err != nil { return nil, err @@ -425,35 +259,27 @@ func (t *transport) CanDial(addr ma.Multiaddr) bool { // Listen listens for new QUIC connections on the passed multiaddr. func (t *transport) Listen(addr ma.Multiaddr) (tpt.Listener, error) { - _, v, err := fromQuicMultiaddr(addr) + var tlsConf tls.Config + tlsConf.GetConfigForClient = func(_ *tls.ClientHelloInfo) (*tls.Config, error) { + // return a tls.Config that verifies the peer's certificate chain. + // Note that since we have no way of associating an incoming QUIC connection with + // the peer ID calculated here, we don't actually receive the peer's public key + // from the key chan. + conf, _ := t.identity.ConfigForPeer("") + return conf, nil + } + tlsConf.NextProtos = []string{"libp2p"} + + ln, err := t.connManager.ListenQUIC(addr, &tlsConf, t.allowWindowIncrease) if err != nil { return nil, err } - if v == quic.VersionDraft29 && !t.enableDraft29 { - return nil, errors.New("can't listen on `/quic` multiaddr (QUIC draft 29 version) when draft 29 support is disabled") - } - - lnet, host, err := manet.DialArgs(addr) + l, err := newListener(ln, t, t.localPeer, t.privKey, t.rcmgr) if err != nil { + _ = ln.Close() return nil, err } - laddr, err := net.ResolveUDPAddr(lnet, host) - if err != nil { - return nil, err - } - conn, err := t.connManager.Listen(lnet, laddr) - if err != nil { - return nil, err - } - ln, err := newListener(conn, t, t.localPeer, t.privKey, t.identity, t.rcmgr, t.enableDraft29) - if err != nil { - if !t.connManager.reuseportEnable { - conn.Close() - } - conn.DecreaseCount() - return nil, err - } - return ln, nil + return l, nil } func (t *transport) allowWindowIncrease(conn quic.Connection, size uint64) bool { @@ -477,10 +303,7 @@ func (t *transport) Proxy() bool { // Protocols returns the set of protocols handled by this transport. func (t *transport) Protocols() []int { - if t.enableDraft29 { - return []int{ma.P_QUIC, ma.P_QUIC_V1} - } - return []int{ma.P_QUIC_V1} + return t.connManager.Protocols() } func (t *transport) String() string { @@ -488,5 +311,5 @@ func (t *transport) String() string { } func (t *transport) Close() error { - return t.connManager.Close() + return nil } diff --git a/p2p/transport/quic/transport_test.go b/p2p/transport/quic/transport_test.go index 0508614a94..3eab8281ec 100644 --- a/p2p/transport/quic/transport_test.go +++ b/p2p/transport/quic/transport_test.go @@ -1,20 +1,15 @@ package libp2pquic import ( - "context" "crypto/rand" "crypto/rsa" - "crypto/tls" "crypto/x509" - "errors" "io" - "net" "testing" ic "github.com/libp2p/go-libp2p/core/crypto" tpt "github.com/libp2p/go-libp2p/core/transport" - "github.com/lucas-clemente/quic-go" ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/require" ) @@ -25,7 +20,7 @@ func getTransport(t *testing.T) tpt.Transport { require.NoError(t, err) key, err := ic.UnmarshalRsaPrivateKey(x509.MarshalPKCS1PrivateKey(rsaKey)) require.NoError(t, err) - tr, err := NewTransport(key, nil, nil, nil) + tr, err := NewTransport(key, newConnManager(t), nil, nil, nil) require.NoError(t, err) return tr } @@ -74,28 +69,3 @@ func TestCanDial(t *testing.T) { } } } - -// The connection passed to quic-go needs to be type-assertable to a net.UDPConn, -// in order to enable features like batch processing and ECN. -func TestConnectionPassedToQUIC(t *testing.T) { - tr := getTransport(t) - defer tr.(io.Closer).Close() - - origQuicDialContext := quicDialContext - defer func() { quicDialContext = origQuicDialContext }() - - var conn net.PacketConn - quicDialContext = func(_ context.Context, c net.PacketConn, _ net.Addr, _ string, _ *tls.Config, _ *quic.Config) (quic.Connection, error) { - conn = c - return nil, errors.New("listen error") - } - remoteAddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic") - require.NoError(t, err) - _, err = tr.Dial(context.Background(), remoteAddr, "remote peer id") - require.EqualError(t, err, "listen error") - require.NotNil(t, conn) - defer conn.Close() - if _, ok := conn.(quic.OOBCapablePacketConn); !ok { - t.Fatal("connection passed to quic-go cannot be type asserted to a *net.UDPConn") - } -} diff --git a/p2p/transport/quicreuse/config.go b/p2p/transport/quicreuse/config.go new file mode 100644 index 0000000000..1bd90821eb --- /dev/null +++ b/p2p/transport/quicreuse/config.go @@ -0,0 +1,23 @@ +package quicreuse + +import ( + "net" + "time" + + "github.com/lucas-clemente/quic-go" +) + +var quicConfig = &quic.Config{ + MaxIncomingStreams: 256, + MaxIncomingUniStreams: 5, // allow some unidirectional streams, in case we speak WebTransport + MaxStreamReceiveWindow: 10 * (1 << 20), // 10 MB + MaxConnectionReceiveWindow: 15 * (1 << 20), // 15 MB + RequireAddressValidation: func(net.Addr) bool { + // TODO(#1535): require source address validation when under load + return false + }, + KeepAlivePeriod: 15 * time.Second, + Versions: []quic.VersionNumber{quic.VersionDraft29, quic.Version1}, + // We don't use datagrams (yet), but this is necessary for WebTransport + EnableDatagrams: true, +} diff --git a/p2p/transport/quicreuse/connmgr.go b/p2p/transport/quicreuse/connmgr.go new file mode 100644 index 0000000000..36fdb582c0 --- /dev/null +++ b/p2p/transport/quicreuse/connmgr.go @@ -0,0 +1,234 @@ +package quicreuse + +import ( + "context" + "crypto/tls" + "errors" + "net" + "sync" + + "github.com/lucas-clemente/quic-go" + quiclogging "github.com/lucas-clemente/quic-go/logging" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" +) + +var quicDialContext = quic.DialContext // so we can mock it in tests + +type ConnManager struct { + reuseUDP4 *reuse + reuseUDP6 *reuse + enableDraft29 bool + enableReuseport bool + enableMetrics bool + + serverConfig *quic.Config + clientConfig *quic.Config + + connsMu sync.Mutex + conns map[string]connListenerEntry +} + +type connListenerEntry struct { + refCount int + ln *connListener +} + +func NewConnManager(statelessResetKey quic.StatelessResetKey, opts ...Option) (*ConnManager, error) { + cm := &ConnManager{ + enableReuseport: true, + enableDraft29: true, + conns: make(map[string]connListenerEntry), + } + for _, o := range opts { + if err := o(cm); err != nil { + return nil, err + } + } + + quicConf := quicConfig.Clone() + quicConf.StatelessResetKey = &statelessResetKey + + var tracers []quiclogging.Tracer + if qlogTracer != nil { + tracers = append(tracers, qlogTracer) + } + if cm.enableMetrics { + tracers = append(tracers, &metricsTracer{}) + } + if len(tracers) > 0 { + quicConf.Tracer = quiclogging.NewMultiplexedTracer(tracers...) + } + serverConfig := quicConf.Clone() + if !cm.enableDraft29 { + serverConfig.Versions = []quic.VersionNumber{quic.Version1} + } + + cm.clientConfig = quicConf + cm.serverConfig = serverConfig + if cm.enableReuseport { + cm.reuseUDP4 = newReuse() + cm.reuseUDP6 = newReuse() + } + return cm, nil +} + +func (c *ConnManager) getReuse(network string) (*reuse, error) { + switch network { + case "udp4": + return c.reuseUDP4, nil + case "udp6": + return c.reuseUDP6, nil + default: + return nil, errors.New("invalid network: must be either udp4 or udp6") + } +} + +func (c *ConnManager) ListenQUIC(addr ma.Multiaddr, tlsConf *tls.Config, allowWindowIncrease func(conn quic.Connection, delta uint64) bool) (Listener, error) { + if !c.enableDraft29 { + if _, err := addr.ValueForProtocol(ma.P_QUIC); err == nil { + return nil, errors.New("can't listen on `/quic` multiaddr (QUIC draft 29 version) when draft 29 support is disabled") + } + } + + netw, host, err := manet.DialArgs(addr) + if err != nil { + return nil, err + } + laddr, err := net.ResolveUDPAddr(netw, host) + if err != nil { + return nil, err + } + + c.connsMu.Lock() + defer c.connsMu.Unlock() + + key := laddr.String() + entry, ok := c.conns[key] + if !ok { + conn, err := c.listen(netw, laddr) + if err != nil { + return nil, err + } + ln, err := newConnListener(conn, c.serverConfig, c.enableDraft29) + if err != nil { + return nil, err + } + key = conn.LocalAddr().String() + entry = connListenerEntry{ln: ln} + } + l, err := entry.ln.Add(tlsConf, allowWindowIncrease, func() { c.onListenerClosed(key) }) + if err != nil { + if entry.refCount <= 0 { + entry.ln.Close() + } + return nil, err + } + entry.refCount++ + c.conns[key] = entry + return l, nil +} + +func (c *ConnManager) onListenerClosed(key string) { + c.connsMu.Lock() + defer c.connsMu.Unlock() + + entry := c.conns[key] + entry.refCount = entry.refCount - 1 + if entry.refCount <= 0 { + delete(c.conns, key) + entry.ln.Close() + } else { + c.conns[key] = entry + } +} + +func (c *ConnManager) listen(network string, laddr *net.UDPAddr) (pConn, error) { + if c.enableReuseport { + reuse, err := c.getReuse(network) + if err != nil { + return nil, err + } + return reuse.Listen(network, laddr) + } + + conn, err := net.ListenUDP(network, laddr) + if err != nil { + return nil, err + } + return &noreuseConn{conn}, nil +} + +func (c *ConnManager) DialQUIC(ctx context.Context, raddr ma.Multiaddr, tlsConf *tls.Config, allowWindowIncrease func(conn quic.Connection, delta uint64) bool) (quic.Connection, error) { + naddr, v, err := FromQuicMultiaddr(raddr) + if err != nil { + return nil, err + } + netw, host, err := manet.DialArgs(raddr) + if err != nil { + return nil, err + } + + quicConf := c.clientConfig.Clone() + quicConf.AllowConnectionWindowIncrease = allowWindowIncrease + + if v == quic.Version1 { + // The endpoint has explicit support for QUIC v1, so we'll only use that version. + quicConf.Versions = []quic.VersionNumber{quic.Version1} + } else if v == quic.VersionDraft29 { + quicConf.Versions = []quic.VersionNumber{quic.VersionDraft29} + } else { + return nil, errors.New("unknown QUIC version") + } + + pconn, err := c.Dial(netw, naddr) + if err != nil { + return nil, err + } + conn, err := quicDialContext(ctx, pconn, naddr, host, tlsConf, quicConf) + if err != nil { + pconn.DecreaseCount() + return nil, err + } + return conn, nil +} + +func (c *ConnManager) Dial(network string, raddr *net.UDPAddr) (pConn, error) { + if c.enableReuseport { + reuse, err := c.getReuse(network) + if err != nil { + return nil, err + } + return reuse.Dial(network, raddr) + } + + var laddr *net.UDPAddr + switch network { + case "udp4": + laddr = &net.UDPAddr{IP: net.IPv4zero, Port: 0} + case "udp6": + laddr = &net.UDPAddr{IP: net.IPv6zero, Port: 0} + } + conn, err := net.ListenUDP(network, laddr) + if err != nil { + return nil, err + } + return &noreuseConn{conn}, nil +} + +func (c *ConnManager) Protocols() []int { + if c.enableDraft29 { + return []int{ma.P_QUIC, ma.P_QUIC_V1} + } + return []int{ma.P_QUIC_V1} +} + +func (c *ConnManager) Close() error { + if !c.enableReuseport { + return nil + } + if err := c.reuseUDP6.Close(); err != nil { + return err + } + return c.reuseUDP4.Close() +} diff --git a/p2p/transport/quicreuse/connmgr_test.go b/p2p/transport/quicreuse/connmgr_test.go new file mode 100644 index 0000000000..a130bb2d1c --- /dev/null +++ b/p2p/transport/quicreuse/connmgr_test.go @@ -0,0 +1,281 @@ +package quicreuse + +import ( + "context" + "crypto/rand" + "crypto/tls" + "errors" + "fmt" + "net" + "testing" + "time" + + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/peer" + libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls" + + "github.com/lucas-clemente/quic-go" + + ma "github.com/multiformats/go-multiaddr" + + "github.com/stretchr/testify/require" +) + +func checkClosed(t *testing.T, cm *ConnManager) { + for _, r := range []*reuse{cm.reuseUDP4, cm.reuseUDP6} { + if r == nil { + continue + } + r.mutex.Lock() + for _, conn := range r.global { + require.Zero(t, conn.GetCount()) + } + for _, conns := range r.unicast { + for _, conn := range conns { + require.Zero(t, conn.GetCount()) + } + } + r.mutex.Unlock() + } + require.Eventually(t, func() bool { return !isGarbageCollectorRunning() }, 200*time.Millisecond, 10*time.Millisecond) +} + +func TestListenQUICDraft29Disabled(t *testing.T) { + cm, err := NewConnManager([32]byte{}, DisableDraft29(), DisableReuseport()) + require.NoError(t, err) + defer cm.Close() + _, err = cm.ListenQUIC(ma.StringCast("/ip4/127.0.0.1/udp/0/quic"), &tls.Config{}, nil) + require.EqualError(t, err, "can't listen on `/quic` multiaddr (QUIC draft 29 version) when draft 29 support is disabled") + ln, err := cm.ListenQUIC(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), &tls.Config{NextProtos: []string{"proto"}}, nil) + require.NoError(t, err) + require.NoError(t, ln.Close()) + require.False(t, isGarbageCollectorRunning()) +} + +func TestListenOnSameProto(t *testing.T) { + t.Run("with reuseport", func(t *testing.T) { + testListenOnSameProto(t, true) + }) + + t.Run("without reuseport", func(t *testing.T) { + testListenOnSameProto(t, false) + }) +} + +func testListenOnSameProto(t *testing.T, enableReuseport bool) { + var opts []Option + if !enableReuseport { + opts = append(opts, DisableReuseport()) + } + cm, err := NewConnManager([32]byte{}, opts...) + require.NoError(t, err) + defer checkClosed(t, cm) + defer cm.Close() + + const alpn = "proto" + + var tlsConf tls.Config + tlsConf.NextProtos = []string{alpn} + ln1, err := cm.ListenQUIC(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), &tls.Config{NextProtos: []string{alpn}}, nil) + require.NoError(t, err) + defer ln1.Close() + + addr := ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/%d/quic-v1", ln1.Addr().(*net.UDPAddr).Port)) + _, err = cm.ListenQUIC(addr, &tls.Config{NextProtos: []string{alpn}}, nil) + require.EqualError(t, err, "already listening for protocol "+alpn) + + // listening on a different address works + ln2, err := cm.ListenQUIC(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), &tls.Config{NextProtos: []string{alpn}}, nil) + require.NoError(t, err) + defer ln2.Close() +} + +// The conn passed to quic-go should be a conn that quic-go can be +// type-asserted to a UDPConn. That way, it can use all kinds of optimizations. +func TestConnectionPassedToQUICForListening(t *testing.T) { + origQuicListen := quicListen + t.Cleanup(func() { quicListen = origQuicListen }) + + var conn net.PacketConn + quicListen = func(c net.PacketConn, _ *tls.Config, _ *quic.Config) (quic.Listener, error) { + conn = c + return nil, errors.New("listen error") + } + + cm, err := NewConnManager([32]byte{}, DisableReuseport()) + require.NoError(t, err) + defer cm.Close() + + _, err = cm.ListenQUIC(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), &tls.Config{NextProtos: []string{"proto"}}, nil) + require.EqualError(t, err, "listen error") + require.NotNil(t, conn) + defer conn.Close() + if _, ok := conn.(quic.OOBCapablePacketConn); !ok { + t.Fatal("connection passed to quic-go cannot be type asserted to a *net.UDPConn") + } +} + +type mockFailAcceptListener struct { + addr net.Addr +} + +// Accept implements quic.Listener +func (l *mockFailAcceptListener) Accept(context.Context) (quic.Connection, error) { + return nil, fmt.Errorf("Some error") +} + +// Addr implements quic.Listener +func (l *mockFailAcceptListener) Addr() net.Addr { + return l.addr +} + +// Close implements quic.Listener +func (l *mockFailAcceptListener) Close() error { + return nil +} + +var _ quic.Listener = &mockFailAcceptListener{} + +func TestAcceptErrorGetCleanedUp(t *testing.T) { + origQuicListen := quicListen + t.Cleanup(func() { quicListen = origQuicListen }) + + quicListen = func(c net.PacketConn, _ *tls.Config, _ *quic.Config) (quic.Listener, error) { + return &mockFailAcceptListener{ + addr: c.LocalAddr(), + }, nil + } + + cm, err := NewConnManager([32]byte{}, DisableReuseport()) + require.NoError(t, err) + defer cm.Close() + + l, err := cm.ListenQUIC(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), &tls.Config{NextProtos: []string{"proto"}}, nil) + require.NoError(t, err) + defer l.Close() + _, err = l.Accept(context.Background()) + require.EqualError(t, err, "accept goroutine finished") + +} + +// The connection passed to quic-go needs to be type-assertable to a net.UDPConn, +// in order to enable features like batch processing and ECN. +func TestConnectionPassedToQUICForDialing(t *testing.T) { + origQuicDialContext := quicDialContext + defer func() { quicDialContext = origQuicDialContext }() + + var conn net.PacketConn + quicDialContext = func(_ context.Context, c net.PacketConn, _ net.Addr, _ string, _ *tls.Config, _ *quic.Config) (quic.Connection, error) { + conn = c + return nil, errors.New("dial error") + } + + cm, err := NewConnManager([32]byte{}, DisableReuseport()) + require.NoError(t, err) + defer cm.Close() + + _, err = cm.DialQUIC(context.Background(), ma.StringCast("/ip4/127.0.0.1/udp/1234/quic-v1"), &tls.Config{}, nil) + require.EqualError(t, err, "dial error") + require.NotNil(t, conn) + defer conn.Close() + if _, ok := conn.(quic.OOBCapablePacketConn); !ok { + t.Fatal("connection passed to quic-go cannot be type asserted to a *net.UDPConn") + } +} + +func getTLSConfForProto(t *testing.T, alpn string) (peer.ID, *tls.Config) { + t.Helper() + priv, _, err := crypto.GenerateEd25519Key(rand.Reader) + require.NoError(t, err) + id, err := peer.IDFromPrivateKey(priv) + require.NoError(t, err) + // We use the libp2p TLS certificate here, just because it's convenient. + identity, err := libp2ptls.NewIdentity(priv) + require.NoError(t, err) + var tlsConf tls.Config + tlsConf.NextProtos = []string{alpn} + tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { + c, _ := identity.ConfigForPeer("") + c.NextProtos = tlsConf.NextProtos + return c, nil + } + return id, &tlsConf +} + +func connectWithProtocol(t *testing.T, addr net.Addr, alpn string) (peer.ID, error) { + t.Helper() + clientKey, _, err := crypto.GenerateEd25519Key(rand.Reader) + require.NoError(t, err) + clientIdentity, err := libp2ptls.NewIdentity(clientKey) + require.NoError(t, err) + tlsConf, peerChan := clientIdentity.ConfigForPeer("") + cconn, err := net.ListenUDP("udp4", nil) + tlsConf.NextProtos = []string{alpn} + require.NoError(t, err) + c, err := quic.Dial(cconn, addr, "localhost", tlsConf, nil) + if err != nil { + return "", err + } + defer c.CloseWithError(0, "") + require.Equal(t, alpn, c.ConnectionState().TLS.NegotiatedProtocol) + serverID, err := peer.IDFromPublicKey(<-peerChan) + require.NoError(t, err) + return serverID, nil +} + +func TestListener(t *testing.T) { + t.Run("with reuseport", func(t *testing.T) { + testListener(t, true) + }) + + t.Run("without reuseport", func(t *testing.T) { + testListener(t, false) + }) +} + +func testListener(t *testing.T, enableReuseport bool) { + var opts []Option + if !enableReuseport { + opts = append(opts, DisableReuseport()) + } + cm, err := NewConnManager([32]byte{}, opts...) + require.NoError(t, err) + + id1, tlsConf1 := getTLSConfForProto(t, "proto1") + ln1, err := cm.ListenQUIC(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), tlsConf1, nil) + require.NoError(t, err) + + id2, tlsConf2 := getTLSConfForProto(t, "proto2") + ln2, err := cm.ListenQUIC( + ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/%d/quic-v1", ln1.Addr().(*net.UDPAddr).Port)), + tlsConf2, + nil, + ) + require.NoError(t, err) + require.Equal(t, ln1.Addr(), ln2.Addr()) + + // Test that the right certificate is served. + id, err := connectWithProtocol(t, ln1.Addr(), "proto1") + require.NoError(t, err) + require.Equal(t, id1, id) + id, err = connectWithProtocol(t, ln1.Addr(), "proto2") + require.NoError(t, err) + require.Equal(t, id2, id) + // No such protocol registered. + _, err = connectWithProtocol(t, ln1.Addr(), "proto3") + require.Error(t, err) + + // Now close the first listener to test that it's properly deregistered. + require.NoError(t, ln1.Close()) + _, err = connectWithProtocol(t, ln1.Addr(), "proto1") + require.Error(t, err) + // connecting to the other listener should still be possible + id, err = connectWithProtocol(t, ln1.Addr(), "proto2") + require.NoError(t, err) + require.Equal(t, id2, id) + + ln2.Close() + cm.Close() + + checkClosed(t, cm) +} diff --git a/p2p/transport/quicreuse/listener.go b/p2p/transport/quicreuse/listener.go new file mode 100644 index 0000000000..b71478fd50 --- /dev/null +++ b/p2p/transport/quicreuse/listener.go @@ -0,0 +1,222 @@ +package quicreuse + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "sync" + + "github.com/lucas-clemente/quic-go" + ma "github.com/multiformats/go-multiaddr" +) + +var quicListen = quic.Listen // so we can mock it in tests + +type Listener interface { + Accept(context.Context) (quic.Connection, error) + Addr() net.Addr + Multiaddrs() []ma.Multiaddr + io.Closer +} + +type protoConf struct { + ln *listener + tlsConf *tls.Config + allowWindowIncrease func(conn quic.Connection, delta uint64) bool +} + +type connListener struct { + l quic.Listener + conn pConn + running chan struct{} + addrs []ma.Multiaddr + + protocolsMu sync.Mutex + protocols map[string]protoConf +} + +func newConnListener(c pConn, quicConfig *quic.Config, enableDraft29 bool) (*connListener, error) { + localMultiaddrs := make([]ma.Multiaddr, 0, 2) + a, err := ToQuicMultiaddr(c.LocalAddr(), quic.Version1) + if err != nil { + return nil, err + } + localMultiaddrs = append(localMultiaddrs, a) + if enableDraft29 { + a, err := ToQuicMultiaddr(c.LocalAddr(), quic.VersionDraft29) + if err != nil { + return nil, err + } + localMultiaddrs = append(localMultiaddrs, a) + } + cl := &connListener{ + protocols: map[string]protoConf{}, + running: make(chan struct{}), + conn: c, + addrs: localMultiaddrs, + } + tlsConf := &tls.Config{ + GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) { + cl.protocolsMu.Lock() + defer cl.protocolsMu.Unlock() + for _, proto := range info.SupportedProtos { + if entry, ok := cl.protocols[proto]; ok { + conf := entry.tlsConf + if conf.GetConfigForClient != nil { + return conf.GetConfigForClient(info) + } + return conf, nil + } + } + return nil, fmt.Errorf("no supported protocol found. offered: %+v", info.SupportedProtos) + }, + } + quicConf := quicConfig.Clone() + quicConf.AllowConnectionWindowIncrease = cl.allowWindowIncrease + ln, err := quicListen(c, tlsConf, quicConf) + if err != nil { + return nil, err + } + cl.l = ln + go cl.Run() // This go routine shuts down once the underlying quic.Listener is closed (or returns an error). + return cl, nil +} + +func (l *connListener) allowWindowIncrease(conn quic.Connection, delta uint64) bool { + l.protocolsMu.Lock() + defer l.protocolsMu.Unlock() + + conf, ok := l.protocols[conn.ConnectionState().TLS.ConnectionState.NegotiatedProtocol] + if !ok { + return false + } + return conf.allowWindowIncrease(conn, delta) +} + +func (l *connListener) Add(tlsConf *tls.Config, allowWindowIncrease func(conn quic.Connection, delta uint64) bool, onRemove func()) (Listener, error) { + l.protocolsMu.Lock() + defer l.protocolsMu.Unlock() + + if len(tlsConf.NextProtos) == 0 { + return nil, errors.New("no ALPN found in tls.Config") + } + + for _, proto := range tlsConf.NextProtos { + if _, ok := l.protocols[proto]; ok { + return nil, fmt.Errorf("already listening for protocol %s", proto) + } + } + + ln := newSingleListener(l.l.Addr(), l.addrs, func() { + l.protocolsMu.Lock() + for _, proto := range tlsConf.NextProtos { + delete(l.protocols, proto) + } + l.protocolsMu.Unlock() + onRemove() + }, l.running) + for _, proto := range tlsConf.NextProtos { + l.protocols[proto] = protoConf{ + ln: ln, + tlsConf: tlsConf, + allowWindowIncrease: allowWindowIncrease, + } + } + return ln, nil +} + +func (l *connListener) Run() error { + defer close(l.running) + defer l.conn.DecreaseCount() + for { + conn, err := l.l.Accept(context.Background()) + if err != nil { + return err + } + proto := conn.ConnectionState().TLS.NegotiatedProtocol + + l.protocolsMu.Lock() + ln, ok := l.protocols[proto] + if !ok { + l.protocolsMu.Unlock() + return fmt.Errorf("negotiated unknown protocol: %s", proto) + } + ln.ln.add(conn) + l.protocolsMu.Unlock() + } +} + +func (l *connListener) Close() error { + err := l.l.Close() + <-l.running // wait for Run to return + return err +} + +const queueLen = 16 + +// A listener for a single ALPN protocol (set). +type listener struct { + queue chan quic.Connection + acceptLoopRunning chan struct{} + addr net.Addr + addrs []ma.Multiaddr + remove func() + closeOnce sync.Once +} + +var _ Listener = &listener{} + +func newSingleListener(addr net.Addr, addrs []ma.Multiaddr, remove func(), running chan struct{}) *listener { + return &listener{ + queue: make(chan quic.Connection, queueLen), + acceptLoopRunning: running, + remove: remove, + addr: addr, + addrs: addrs, + } +} + +func (l *listener) add(c quic.Connection) { + select { + case l.queue <- c: + default: + c.CloseWithError(1, "queue full") + } +} + +func (l *listener) Accept(ctx context.Context) (quic.Connection, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-l.acceptLoopRunning: + return nil, errors.New("accept goroutine finished") + case c, ok := <-l.queue: + if !ok { + return nil, errors.New("listener closed") + } + return c, nil + } +} + +func (l *listener) Addr() net.Addr { + return l.addr +} + +func (l *listener) Multiaddrs() []ma.Multiaddr { + return l.addrs +} + +func (l *listener) Close() error { + l.closeOnce.Do(func() { + l.remove() + close(l.queue) + // drain the queue + for conn := range l.queue { + conn.CloseWithError(1, "closing") + } + }) + return nil +} diff --git a/p2p/transport/quicreuse/options.go b/p2p/transport/quicreuse/options.go new file mode 100644 index 0000000000..a700a0544d --- /dev/null +++ b/p2p/transport/quicreuse/options.go @@ -0,0 +1,28 @@ +package quicreuse + +type Option func(*ConnManager) error + +func DisableReuseport() Option { + return func(m *ConnManager) error { + m.enableReuseport = false + return nil + } +} + +// DisableDraft29 disables support for QUIC draft-29. +// This option should be set, unless support for this legacy QUIC version is needed for backwards compatibility. +// Support for QUIC draft-29 is already deprecated and will be removed in the future, see https://github.com/libp2p/go-libp2p/issues/1841. +func DisableDraft29() Option { + return func(m *ConnManager) error { + m.enableDraft29 = false + return nil + } +} + +// EnableMetrics enables Prometheus metrics collection. +func EnableMetrics() Option { + return func(m *ConnManager) error { + m.enableMetrics = true + return nil + } +} diff --git a/p2p/transport/quic/quic_multiaddr.go b/p2p/transport/quicreuse/quic_multiaddr.go similarity index 74% rename from p2p/transport/quic/quic_multiaddr.go rename to p2p/transport/quicreuse/quic_multiaddr.go index 47dc7e905f..afd8fbb779 100644 --- a/p2p/transport/quic/quic_multiaddr.go +++ b/p2p/transport/quicreuse/quic_multiaddr.go @@ -1,4 +1,4 @@ -package libp2pquic +package quicreuse import ( "errors" @@ -9,10 +9,12 @@ import ( manet "github.com/multiformats/go-multiaddr/net" ) -var quicV1MA ma.Multiaddr = ma.StringCast("/quic-v1") -var quicDraft29MA ma.Multiaddr = ma.StringCast("/quic") +var ( + quicV1MA = ma.StringCast("/quic-v1") + quicDraft29MA = ma.StringCast("/quic") +) -func toQuicMultiaddr(na net.Addr, version quic.VersionNumber) (ma.Multiaddr, error) { +func ToQuicMultiaddr(na net.Addr, version quic.VersionNumber) (ma.Multiaddr, error) { udpMA, err := manet.FromNetAddr(na) if err != nil { return nil, err @@ -27,7 +29,7 @@ func toQuicMultiaddr(na net.Addr, version quic.VersionNumber) (ma.Multiaddr, err } } -func fromQuicMultiaddr(addr ma.Multiaddr) (net.Addr, quic.VersionNumber, error) { +func FromQuicMultiaddr(addr ma.Multiaddr) (*net.UDPAddr, quic.VersionNumber, error) { var version quic.VersionNumber var partsBeforeQUIC []ma.Multiaddr ma.ForEach(addr, func(c ma.Component) bool { @@ -54,5 +56,9 @@ func fromQuicMultiaddr(addr ma.Multiaddr) (net.Addr, quic.VersionNumber, error) if err != nil { return nil, version, err } - return netAddr, version, err + udpAddr, ok := netAddr.(*net.UDPAddr) + if !ok { + return nil, 0, errors.New("not a *net.UDPAddr") + } + return udpAddr, version, nil } diff --git a/p2p/transport/quic/quic_multiaddr_test.go b/p2p/transport/quicreuse/quic_multiaddr_test.go similarity index 78% rename from p2p/transport/quic/quic_multiaddr_test.go rename to p2p/transport/quicreuse/quic_multiaddr_test.go index 34d971f812..fa18a72836 100644 --- a/p2p/transport/quic/quic_multiaddr_test.go +++ b/p2p/transport/quicreuse/quic_multiaddr_test.go @@ -1,4 +1,4 @@ -package libp2pquic +package quicreuse import ( "net" @@ -11,14 +11,14 @@ import ( func TestConvertToQuicMultiaddr(t *testing.T) { addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 42), Port: 1337} - maddr, err := toQuicMultiaddr(addr, quic.VersionDraft29) + maddr, err := ToQuicMultiaddr(addr, quic.VersionDraft29) require.NoError(t, err) require.Equal(t, maddr.String(), "/ip4/192.168.0.42/udp/1337/quic") } func TestConvertToQuicV1Multiaddr(t *testing.T) { addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 42), Port: 1337} - maddr, err := toQuicMultiaddr(addr, quic.Version1) + maddr, err := ToQuicMultiaddr(addr, quic.Version1) require.NoError(t, err) require.Equal(t, maddr.String(), "/ip4/192.168.0.42/udp/1337/quic-v1") } @@ -26,10 +26,8 @@ func TestConvertToQuicV1Multiaddr(t *testing.T) { func TestConvertFromQuicDraft29Multiaddr(t *testing.T) { maddr, err := ma.NewMultiaddr("/ip4/192.168.0.42/udp/1337/quic") require.NoError(t, err) - addr, v, err := fromQuicMultiaddr(maddr) + udpAddr, v, err := FromQuicMultiaddr(maddr) require.NoError(t, err) - udpAddr, ok := addr.(*net.UDPAddr) - require.True(t, ok) require.Equal(t, udpAddr.IP, net.IPv4(192, 168, 0, 42)) require.Equal(t, udpAddr.Port, 1337) require.Equal(t, v, quic.VersionDraft29) @@ -38,10 +36,8 @@ func TestConvertFromQuicDraft29Multiaddr(t *testing.T) { func TestConvertFromQuicV1Multiaddr(t *testing.T) { maddr, err := ma.NewMultiaddr("/ip4/192.168.0.42/udp/1337/quic-v1") require.NoError(t, err) - addr, v, err := fromQuicMultiaddr(maddr) + udpAddr, v, err := FromQuicMultiaddr(maddr) require.NoError(t, err) - udpAddr, ok := addr.(*net.UDPAddr) - require.True(t, ok) require.Equal(t, udpAddr.IP, net.IPv4(192, 168, 0, 42)) require.Equal(t, udpAddr.Port, 1337) require.Equal(t, v, quic.Version1) diff --git a/p2p/transport/quic/reuse.go b/p2p/transport/quicreuse/reuse.go similarity index 95% rename from p2p/transport/quic/reuse.go rename to p2p/transport/quicreuse/reuse.go index 43eb2cd361..4cb46f23d3 100644 --- a/p2p/transport/quic/reuse.go +++ b/p2p/transport/quicreuse/reuse.go @@ -1,4 +1,4 @@ -package libp2pquic +package quicreuse import ( "net" @@ -9,6 +9,23 @@ import ( "github.com/libp2p/go-netroute" ) +type pConn interface { + net.PacketConn + + // count conn reference + DecreaseCount() + IncreaseCount() +} + +type noreuseConn struct { + *net.UDPConn +} + +func (c *noreuseConn) IncreaseCount() {} +func (c *noreuseConn) DecreaseCount() { + c.UDPConn.Close() +} + // Constant. Defined as variables to simplify testing. var ( garbageCollectInterval = 30 * time.Second diff --git a/p2p/transport/quic/reuse_test.go b/p2p/transport/quicreuse/reuse_test.go similarity index 98% rename from p2p/transport/quic/reuse_test.go rename to p2p/transport/quicreuse/reuse_test.go index 473c125d36..36a109c80d 100644 --- a/p2p/transport/quic/reuse_test.go +++ b/p2p/transport/quicreuse/reuse_test.go @@ -1,4 +1,4 @@ -package libp2pquic +package quicreuse import ( "bytes" @@ -44,7 +44,7 @@ func platformHasRoutingTables() bool { func isGarbageCollectorRunning() bool { var b bytes.Buffer pprof.Lookup("goroutine").WriteTo(&b, 1) - return strings.Contains(b.String(), "quic.(*reuse).gc") + return strings.Contains(b.String(), "quicreuse.(*reuse).gc") } func cleanup(t *testing.T, reuse *reuse) { diff --git a/p2p/transport/internal/quicutils/tracer.go b/p2p/transport/quicreuse/tracer.go similarity index 97% rename from p2p/transport/internal/quicutils/tracer.go rename to p2p/transport/quicreuse/tracer.go index 272b1be290..f7a8767ae4 100644 --- a/p2p/transport/internal/quicutils/tracer.go +++ b/p2p/transport/quicreuse/tracer.go @@ -1,4 +1,4 @@ -package quicutils +package quicreuse import ( "bufio" @@ -17,11 +17,11 @@ var log = golog.Logger("quic-utils") // QLOGTracer holds a qlog tracer, if qlogging is enabled (enabled using the QLOGDIR environment variable). // Otherwise it is nil. -var QLOGTracer logging.Tracer +var qlogTracer logging.Tracer func init() { if qlogDir := os.Getenv("QLOGDIR"); len(qlogDir) > 0 { - QLOGTracer = initQlogger(qlogDir) + qlogTracer = initQlogger(qlogDir) } } diff --git a/p2p/transport/quic/tracer_metrics.go b/p2p/transport/quicreuse/tracer_metrics.go similarity index 99% rename from p2p/transport/quic/tracer_metrics.go rename to p2p/transport/quicreuse/tracer_metrics.go index 0ba48e0fdc..3282b74715 100644 --- a/p2p/transport/quic/tracer_metrics.go +++ b/p2p/transport/quicreuse/tracer_metrics.go @@ -1,4 +1,4 @@ -package libp2pquic +package quicreuse import ( "context" diff --git a/p2p/transport/internal/quicutils/tracer_test.go b/p2p/transport/quicreuse/tracer_test.go similarity index 99% rename from p2p/transport/internal/quicutils/tracer_test.go rename to p2p/transport/quicreuse/tracer_test.go index 0d095d342d..6c8ee2ed1d 100644 --- a/p2p/transport/internal/quicutils/tracer_test.go +++ b/p2p/transport/quicreuse/tracer_test.go @@ -1,4 +1,4 @@ -package quicutils +package quicreuse import ( "bytes" diff --git a/p2p/transport/webtransport/crypto.go b/p2p/transport/webtransport/crypto.go index dc2c1f03a7..d988625669 100644 --- a/p2p/transport/webtransport/crypto.go +++ b/p2p/transport/webtransport/crypto.go @@ -17,6 +17,7 @@ import ( ic "github.com/libp2p/go-libp2p/core/crypto" + "github.com/lucas-clemente/quic-go/http3" "github.com/multiformats/go-multihash" "golang.org/x/crypto/hkdf" ) @@ -34,6 +35,7 @@ func getTLSConf(key ic.PrivKey, start, end time.Time) (*tls.Config, error) { PrivateKey: priv, Leaf: cert, }}, + NextProtos: []string{http3.NextProtoH3}, }, nil } diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index 03ff72f81c..f82c0944bf 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -2,7 +2,6 @@ package libp2pwebtransport import ( "context" - "crypto/tls" "errors" "fmt" "net" @@ -13,11 +12,10 @@ import ( tpt "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/security/noise" "github.com/libp2p/go-libp2p/p2p/security/noise/pb" + "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" - "github.com/lucas-clemente/quic-go/http3" "github.com/marten-seemann/webtransport-go" ma "github.com/multiformats/go-multiaddr" - manet "github.com/multiformats/go-multiaddr/net" ) var errClosed = errors.New("closed") @@ -27,8 +25,8 @@ const handshakeTimeout = 10 * time.Second type listener struct { transport *transport - tlsConf *tls.Config isStaticTLSConf bool + reuseListener quicreuse.Listener server webtransport.Server @@ -45,42 +43,21 @@ type listener struct { var _ tpt.Listener = &listener{} -func newListener(laddr ma.Multiaddr, t *transport, tlsConf *tls.Config) (tpt.Listener, error) { - network, addr, err := manet.DialArgs(laddr) +func newListener(reuseListener quicreuse.Listener, t *transport, isStaticTLSConf bool) (tpt.Listener, error) { + localMultiaddr, err := toWebtransportMultiaddr(reuseListener.Addr()) if err != nil { return nil, err } - udpAddr, err := net.ResolveUDPAddr(network, addr) - if err != nil { - return nil, err - } - udpConn, err := net.ListenUDP(network, udpAddr) - if err != nil { - return nil, err - } - localMultiaddr, err := toWebtransportMultiaddr(udpConn.LocalAddr()) - if err != nil { - return nil, err - } - isStaticTLSConf := tlsConf != nil - if tlsConf == nil { - tlsConf = &tls.Config{GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { - return t.certManager.GetConfig(), nil - }} - } + ln := &listener{ + reuseListener: reuseListener, transport: t, - tlsConf: tlsConf, isStaticTLSConf: isStaticTLSConf, queue: make(chan tpt.CapableConn, queueLen), serverClosed: make(chan struct{}), - addr: udpConn.LocalAddr(), + addr: reuseListener.Addr(), multiaddr: localMultiaddr, server: webtransport.Server{ - H3: http3.Server{ - QuicConfig: t.quicConfig, - TLSConfig: tlsConf, - }, CheckOrigin: func(r *http.Request) bool { return true }, }, } @@ -90,10 +67,13 @@ func newListener(laddr ma.Multiaddr, t *transport, tlsConf *tls.Config) (tpt.Lis ln.server.H3.Handler = mux go func() { defer close(ln.serverClosed) - defer func() { udpConn.Close() }() - if err := ln.server.Serve(udpConn); err != nil { - // TODO: only output if the server hasn't been closed - log.Debugw("serving failed", "addr", udpConn.LocalAddr(), "error", err) + for { + conn, err := ln.reuseListener.Accept(context.Background()) + if err != nil { + log.Debugw("serving failed", "addr", ln.Addr(), "error", err) + return + } + go ln.server.ServeQUICConn(conn) } }() return ln, nil @@ -227,6 +207,7 @@ func (l *listener) Multiaddrs() []ma.Multiaddr { func (l *listener) Close() error { l.ctxCancel() + l.reuseListener.Close() err := l.server.Close() <-l.serverClosed return err diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 0b133e08fd..f21d8f8743 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -18,7 +18,7 @@ import ( tpt "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/security/noise" "github.com/libp2p/go-libp2p/p2p/security/noise/pb" - "github.com/libp2p/go-libp2p/p2p/transport/internal/quicutils" + "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" "github.com/benbjohnson/clock" logging "github.com/ipfs/go-log/v2" @@ -73,9 +73,9 @@ type transport struct { pid peer.ID clock clock.Clock - quicConfig *quic.Config - rcmgr network.ResourceManager - gater connmgr.ConnectionGater + connManager *quicreuse.ConnManager + rcmgr network.ResourceManager + gater connmgr.ConnectionGater listenOnce sync.Once listenOnceErr error @@ -93,22 +93,20 @@ var _ tpt.Transport = &transport{} var _ tpt.Resolver = &transport{} var _ io.Closer = &transport{} -func New(key ic.PrivKey, gater connmgr.ConnectionGater, rcmgr network.ResourceManager, opts ...Option) (tpt.Transport, error) { +func New(key ic.PrivKey, connManager *quicreuse.ConnManager, gater connmgr.ConnectionGater, rcmgr network.ResourceManager, opts ...Option) (tpt.Transport, error) { id, err := peer.IDFromPrivateKey(key) if err != nil { return nil, err } t := &transport{ - pid: id, - privKey: key, - rcmgr: rcmgr, - gater: gater, - clock: clock.New(), - conns: map[uint64]*conn{}, - } - t.quicConfig = &quic.Config{ - AllowConnectionWindowIncrease: t.allowWindowIncrease, - Versions: []quic.VersionNumber{quic.Version1}} + pid: id, + privKey: key, + rcmgr: rcmgr, + gater: gater, + clock: clock.New(), + connManager: connManager, + conns: map[uint64]*conn{}, + } for _, opt := range opts { if err := opt(t); err != nil { return nil, err @@ -119,9 +117,6 @@ func New(key ic.PrivKey, gater connmgr.ConnectionGater, rcmgr network.ResourceMa return nil, err } t.noise = n - if qlogTracer := quicutils.QLOGTracer; qlogTracer != nil { - t.quicConfig.Tracer = qlogTracer - } return t, nil } @@ -130,6 +125,7 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp if err != nil { return nil, err } + url := fmt.Sprintf("https://%s%s?type=noise", addr, webtransportHTTPEndpoint) certHashes, err := extractCertHashes(raddr) if err != nil { return nil, err @@ -148,7 +144,8 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp return nil, err } - sess, err := t.dial(ctx, addr, sni, certHashes) + maddr, _ := ma.SplitFunc(raddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_WEBTRANSPORT }) + sess, err := t.dial(ctx, maddr, url, sni, certHashes) if err != nil { scope.Done() return nil, err @@ -169,14 +166,14 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp return conn, nil } -func (t *transport) dial(ctx context.Context, addr string, sni string, certHashes []multihash.DecodedMultihash) (*webtransport.Session, error) { - url := fmt.Sprintf("https://%s%s?type=noise", addr, webtransportHTTPEndpoint) +func (t *transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string, certHashes []multihash.DecodedMultihash) (*webtransport.Session, error) { var tlsConf *tls.Config if t.tlsClientConf != nil { tlsConf = t.tlsClientConf.Clone() } else { tlsConf = &tls.Config{} } + tlsConf.NextProtos = append(tlsConf.NextProtos, http3.NextProtoH3) if sni != "" { tlsConf.ServerName = sni @@ -190,10 +187,15 @@ func (t *transport) dial(ctx context.Context, addr string, sni string, certHashe return verifyRawCerts(rawCerts, certHashes) } } + conn, err := t.connManager.DialQUIC(ctx, addr, tlsConf, t.allowWindowIncrease) + if err != nil { + return nil, err + } dialer := webtransport.Dialer{ RoundTripper: &http3.RoundTripper{ - TLSClientConfig: tlsConf, - QuicConfig: t.quicConfig, + Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { + return conn.(quic.EarlyConnection), nil + }, }, } rsp, sess, err := dialer.Dial(ctx, url, nil) @@ -302,7 +304,19 @@ func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { return nil, t.listenOnceErr } } - return newListener(laddr, t, t.staticTLSConf) + tlsConf := t.staticTLSConf.Clone() + if tlsConf == nil { + tlsConf = &tls.Config{GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + return t.certManager.GetConfig(), nil + }} + } + tlsConf.NextProtos = append(tlsConf.NextProtos, http3.NextProtoH3) + + ln, err := t.connManager.ListenQUIC(laddr, tlsConf, t.allowWindowIncrease) + if err != nil { + return nil, err + } + return newListener(ln, t, t.staticTLSConf != nil) } func (t *transport) Protocols() []int { @@ -367,7 +381,7 @@ func extractSNI(maddr ma.Multiaddr) (sni string, foundSniComponent bool) { } // Resolve implements transport.Resolver -func (t *transport) Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multiaddr, error) { +func (t *transport) Resolve(_ context.Context, maddr ma.Multiaddr) ([]ma.Multiaddr, error) { sni, foundSniComponent := extractSNI(maddr) if foundSniComponent || sni == "" { diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index e6415d9fff..96857564ce 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -23,6 +23,10 @@ import ( "testing/quick" "time" + "github.com/lucas-clemente/quic-go/http3" + + "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" + "github.com/benbjohnson/clock" ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/network" @@ -98,9 +102,17 @@ func getCerthashComponent(t *testing.T, b []byte) ma.Multiaddr { return ha } +func newConnManager(t *testing.T, opts ...quicreuse.Option) *quicreuse.ConnManager { + t.Helper() + cm, err := quicreuse.NewConnManager([32]byte{}, opts...) + require.NoError(t, err) + t.Cleanup(func() { cm.Close() }) + return cm +} + func TestTransport(t *testing.T) { serverID, serverKey := newIdentity(t) - tr, err := libp2pwebtransport.New(serverKey, nil, &network.NullResourceManager{}) + tr, err := libp2pwebtransport.New(serverKey, newConnManager(t), nil, &network.NullResourceManager{}) require.NoError(t, err) defer tr.(io.Closer).Close() ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) @@ -110,7 +122,7 @@ func TestTransport(t *testing.T) { addrChan := make(chan ma.Multiaddr) go func() { _, clientKey := newIdentity(t) - tr2, err := libp2pwebtransport.New(clientKey, nil, &network.NullResourceManager{}) + tr2, err := libp2pwebtransport.New(clientKey, newConnManager(t), nil, &network.NullResourceManager{}) require.NoError(t, err) defer tr2.(io.Closer).Close() @@ -146,7 +158,7 @@ func TestTransport(t *testing.T) { func TestHashVerification(t *testing.T) { serverID, serverKey := newIdentity(t) - tr, err := libp2pwebtransport.New(serverKey, nil, &network.NullResourceManager{}) + tr, err := libp2pwebtransport.New(serverKey, newConnManager(t), nil, &network.NullResourceManager{}) require.NoError(t, err) defer tr.(io.Closer).Close() ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) @@ -159,7 +171,7 @@ func TestHashVerification(t *testing.T) { }() _, clientKey := newIdentity(t) - tr2, err := libp2pwebtransport.New(clientKey, nil, &network.NullResourceManager{}) + tr2, err := libp2pwebtransport.New(clientKey, newConnManager(t), nil, &network.NullResourceManager{}) require.NoError(t, err) defer tr2.(io.Closer).Close() @@ -197,7 +209,7 @@ func TestCanDial(t *testing.T) { } _, key := newIdentity(t) - tr, err := libp2pwebtransport.New(key, nil, &network.NullResourceManager{}) + tr, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{}) require.NoError(t, err) defer tr.(io.Closer).Close() @@ -223,7 +235,7 @@ func TestListenAddrValidity(t *testing.T) { } _, key := newIdentity(t) - tr, err := libp2pwebtransport.New(key, nil, &network.NullResourceManager{}) + tr, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{}) require.NoError(t, err) defer tr.(io.Closer).Close() @@ -240,7 +252,7 @@ func TestListenAddrValidity(t *testing.T) { func TestListenerAddrs(t *testing.T) { _, key := newIdentity(t) - tr, err := libp2pwebtransport.New(key, nil, &network.NullResourceManager{}) + tr, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{}) require.NoError(t, err) defer tr.(io.Closer).Close() @@ -263,7 +275,7 @@ func TestResourceManagerDialing(t *testing.T) { p := peer.ID("foobar") _, key := newIdentity(t) - tr, err := libp2pwebtransport.New(key, nil, rcmgr) + tr, err := libp2pwebtransport.New(key, newConnManager(t), nil, rcmgr) require.NoError(t, err) defer tr.(io.Closer).Close() @@ -278,7 +290,7 @@ func TestResourceManagerDialing(t *testing.T) { func TestResourceManagerListening(t *testing.T) { clientID, key := newIdentity(t) - cl, err := libp2pwebtransport.New(key, nil, &network.NullResourceManager{}) + cl, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{}) require.NoError(t, err) defer cl.(io.Closer).Close() @@ -287,7 +299,7 @@ func TestResourceManagerListening(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() rcmgr := mocknetwork.NewMockResourceManager(ctrl) - tr, err := libp2pwebtransport.New(key, nil, rcmgr) + tr, err := libp2pwebtransport.New(key, newConnManager(t), nil, rcmgr) require.NoError(t, err) ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) require.NoError(t, err) @@ -313,7 +325,7 @@ func TestResourceManagerListening(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() rcmgr := mocknetwork.NewMockResourceManager(ctrl) - tr, err := libp2pwebtransport.New(key, nil, rcmgr) + tr, err := libp2pwebtransport.New(key, newConnManager(t), nil, rcmgr) require.NoError(t, err) ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) require.NoError(t, err) @@ -357,7 +369,7 @@ func TestConnectionGaterDialing(t *testing.T) { connGater := NewMockConnectionGater(ctrl) serverID, serverKey := newIdentity(t) - tr, err := libp2pwebtransport.New(serverKey, nil, &network.NullResourceManager{}) + tr, err := libp2pwebtransport.New(serverKey, newConnManager(t), nil, &network.NullResourceManager{}) require.NoError(t, err) defer tr.(io.Closer).Close() ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) @@ -368,7 +380,7 @@ func TestConnectionGaterDialing(t *testing.T) { require.Equal(t, stripCertHashes(ln.Multiaddrs()[0]), addrs.RemoteMultiaddr()) }) _, key := newIdentity(t) - cl, err := libp2pwebtransport.New(key, connGater, &network.NullResourceManager{}) + cl, err := libp2pwebtransport.New(key, newConnManager(t), connGater, &network.NullResourceManager{}) require.NoError(t, err) defer cl.(io.Closer).Close() _, err = cl.Dial(context.Background(), ln.Multiaddrs()[0], serverID) @@ -381,7 +393,7 @@ func TestConnectionGaterInterceptAccept(t *testing.T) { connGater := NewMockConnectionGater(ctrl) serverID, serverKey := newIdentity(t) - tr, err := libp2pwebtransport.New(serverKey, connGater, &network.NullResourceManager{}) + tr, err := libp2pwebtransport.New(serverKey, newConnManager(t), connGater, &network.NullResourceManager{}) require.NoError(t, err) defer tr.(io.Closer).Close() ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) @@ -394,7 +406,7 @@ func TestConnectionGaterInterceptAccept(t *testing.T) { }) _, key := newIdentity(t) - cl, err := libp2pwebtransport.New(key, nil, &network.NullResourceManager{}) + cl, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{}) require.NoError(t, err) defer cl.(io.Closer).Close() _, err = cl.Dial(context.Background(), ln.Multiaddrs()[0], serverID) @@ -407,7 +419,7 @@ func TestConnectionGaterInterceptSecured(t *testing.T) { connGater := NewMockConnectionGater(ctrl) serverID, serverKey := newIdentity(t) - tr, err := libp2pwebtransport.New(serverKey, connGater, &network.NullResourceManager{}) + tr, err := libp2pwebtransport.New(serverKey, newConnManager(t), connGater, &network.NullResourceManager{}) require.NoError(t, err) defer tr.(io.Closer).Close() ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) @@ -415,7 +427,7 @@ func TestConnectionGaterInterceptSecured(t *testing.T) { defer ln.Close() clientID, key := newIdentity(t) - cl, err := libp2pwebtransport.New(key, nil, &network.NullResourceManager{}) + cl, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{}) require.NoError(t, err) defer cl.(io.Closer).Close() @@ -473,7 +485,7 @@ func TestStaticTLSConf(t *testing.T) { tlsConf := getTLSConf(t, net.ParseIP("127.0.0.1"), time.Now(), time.Now().Add(365*24*time.Hour)) serverID, serverKey := newIdentity(t) - tr, err := libp2pwebtransport.New(serverKey, nil, &network.NullResourceManager{}, libp2pwebtransport.WithTLSConfig(tlsConf)) + tr, err := libp2pwebtransport.New(serverKey, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithTLSConfig(tlsConf)) require.NoError(t, err) defer tr.(io.Closer).Close() ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) @@ -483,7 +495,7 @@ func TestStaticTLSConf(t *testing.T) { t.Run("fails when the certificate is invalid", func(t *testing.T) { _, key := newIdentity(t) - cl, err := libp2pwebtransport.New(key, nil, &network.NullResourceManager{}) + cl, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{}) require.NoError(t, err) defer cl.(io.Closer).Close() @@ -497,7 +509,7 @@ func TestStaticTLSConf(t *testing.T) { t.Run("fails when dialing with a wrong certhash", func(t *testing.T) { _, key := newIdentity(t) - cl, err := libp2pwebtransport.New(key, nil, &network.NullResourceManager{}) + cl, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{}) require.NoError(t, err) defer cl.(io.Closer).Close() @@ -512,7 +524,7 @@ func TestStaticTLSConf(t *testing.T) { store := x509.NewCertPool() store.AddCert(tlsConf.Certificates[0].Leaf) tlsConf := &tls.Config{RootCAs: store} - cl, err := libp2pwebtransport.New(key, nil, &network.NullResourceManager{}, libp2pwebtransport.WithTLSClientConfig(tlsConf)) + cl, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithTLSClientConfig(tlsConf)) require.NoError(t, err) defer cl.(io.Closer).Close() @@ -525,7 +537,7 @@ func TestStaticTLSConf(t *testing.T) { func TestAcceptQueueFilledUp(t *testing.T) { serverID, serverKey := newIdentity(t) - tr, err := libp2pwebtransport.New(serverKey, nil, &network.NullResourceManager{}) + tr, err := libp2pwebtransport.New(serverKey, newConnManager(t), nil, &network.NullResourceManager{}) require.NoError(t, err) defer tr.(io.Closer).Close() ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) @@ -535,7 +547,7 @@ func TestAcceptQueueFilledUp(t *testing.T) { newConn := func() (tpt.CapableConn, error) { t.Helper() _, key := newIdentity(t) - cl, err := libp2pwebtransport.New(key, nil, &network.NullResourceManager{}) + cl, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{}) require.NoError(t, err) defer cl.(io.Closer).Close() return cl.Dial(context.Background(), ln.Multiaddrs()[0], serverID) @@ -565,7 +577,7 @@ func TestSNIIsSent(t *testing.T) { return tlsConf, nil }, } - tr, err := libp2pwebtransport.New(key, nil, &network.NullResourceManager{}, libp2pwebtransport.WithTLSConfig(tlsConf)) + tr, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithTLSConfig(tlsConf)) require.NoError(t, err) defer tr.(io.Closer).Close() @@ -573,7 +585,7 @@ func TestSNIIsSent(t *testing.T) { require.NoError(t, err) _, key2 := newIdentity(t) - clientTr, err := libp2pwebtransport.New(key2, nil, &network.NullResourceManager{}) + clientTr, err := libp2pwebtransport.New(key2, newConnManager(t), nil, &network.NullResourceManager{}) require.NoError(t, err) defer tr.(io.Closer).Close() @@ -631,7 +643,7 @@ func TestFlowControlWindowIncrease(t *testing.T) { serverID, serverKey := newIdentity(t) serverWindowIncreases := make(chan int, 100) serverRcmgr := &reportingRcmgr{report: serverWindowIncreases} - tr, err := libp2pwebtransport.New(serverKey, nil, serverRcmgr) + tr, err := libp2pwebtransport.New(serverKey, newConnManager(t), nil, serverRcmgr) require.NoError(t, err) defer tr.(io.Closer).Close() ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) @@ -658,7 +670,7 @@ func TestFlowControlWindowIncrease(t *testing.T) { _, clientKey := newIdentity(t) clientWindowIncreases := make(chan int, 100) clientRcmgr := &reportingRcmgr{report: clientWindowIncreases} - tr2, err := libp2pwebtransport.New(clientKey, nil, clientRcmgr) + tr2, err := libp2pwebtransport.New(clientKey, newConnManager(t), nil, clientRcmgr) require.NoError(t, err) defer tr2.(io.Closer).Close() @@ -724,7 +736,7 @@ func TestFlowControlWindowIncrease(t *testing.T) { var errTimeout = errors.New("timeout") -func serverSendsBackValidCert(timeSinceUnixEpoch time.Duration, keySeed int64, randomClientSkew time.Duration) error { +func serverSendsBackValidCert(t *testing.T, timeSinceUnixEpoch time.Duration, keySeed int64, randomClientSkew time.Duration) error { if timeSinceUnixEpoch < 0 { timeSinceUnixEpoch = -timeSinceUnixEpoch } @@ -741,21 +753,15 @@ func serverSendsBackValidCert(timeSinceUnixEpoch time.Duration, keySeed int64, r cl.Set(start) priv, _, err := test.SeededTestKeyPair(ic.Ed25519, 256, keySeed) - if err != nil { - return err - } - tr, err := libp2pwebtransport.New(priv, nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl)) - if err != nil { - return err - } + require.NoError(t, err) + tr, err := libp2pwebtransport.New(priv, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl)) + require.NoError(t, err) l, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) - if err != nil { - return err - } + require.NoError(t, err) defer l.Close() conn, err := quic.DialAddr(l.Addr().String(), &tls.Config{ - NextProtos: []string{"h3"}, + NextProtos: []string{http3.NextProtoH3}, InsecureSkipVerify: true, VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { for _, c := range rawCerts { @@ -790,7 +796,7 @@ func serverSendsBackValidCert(timeSinceUnixEpoch time.Duration, keySeed int64, r func TestServerSendsBackValidCert(t *testing.T) { var maxTimeoutErrors = 10 require.NoError(t, quick.Check(func(timeSinceUnixEpoch time.Duration, keySeed int64, randomClientSkew time.Duration) bool { - err := serverSendsBackValidCert(timeSinceUnixEpoch, keySeed, randomClientSkew) + err := serverSendsBackValidCert(t, timeSinceUnixEpoch, keySeed, randomClientSkew) if err == errTimeout { maxTimeoutErrors -= 1 if maxTimeoutErrors <= 0 { @@ -827,7 +833,7 @@ func TestServerRotatesCertCorrectly(t *testing.T) { if err != nil { return false } - tr, err := libp2pwebtransport.New(priv, nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl)) + tr, err := libp2pwebtransport.New(priv, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl)) if err != nil { return false } @@ -841,7 +847,7 @@ func TestServerRotatesCertCorrectly(t *testing.T) { // These two certificates together are valid for at most certValidity - (4*clockSkewAllowance) cl.Add(certValidity - (4 * clockSkewAllowance) - time.Second) - tr, err = libp2pwebtransport.New(priv, nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl)) + tr, err = libp2pwebtransport.New(priv, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl)) if err != nil { return false } @@ -887,7 +893,7 @@ func TestServerRotatesCertCorrectlyAfterSteps(t *testing.T) { priv, _, err := test.RandTestKeyPair(ic.Ed25519, 256) require.NoError(t, err) - tr, err := libp2pwebtransport.New(priv, nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl)) + tr, err := libp2pwebtransport.New(priv, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl)) require.NoError(t, err) l, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) @@ -900,7 +906,7 @@ func TestServerRotatesCertCorrectlyAfterSteps(t *testing.T) { // e.g. certhash/A/certhash/B ... -> ... certhash/B/certhash/C ... -> ... certhash/C/certhash/D for i := 0; i < 200; i++ { cl.Add(24 * time.Hour) - tr, err := libp2pwebtransport.New(priv, nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl)) + tr, err := libp2pwebtransport.New(priv, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl)) require.NoError(t, err) l, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) require.NoError(t, err)