diff --git a/tcp.go b/tcp.go index a8c6d76..1405c6c 100644 --- a/tcp.go +++ b/tcp.go @@ -21,9 +21,7 @@ import ( manet "github.com/multiformats/go-multiaddr/net" ) -// DefaultConnectTimeout is the (default) maximum amount of time the TCP -// transport will spend on the initial TCP connect before giving up. -var DefaultConnectTimeout = 5 * time.Second +const defaultConnectTimeout = 5 * time.Second var log = logging.Logger("tcp-tpt") @@ -97,6 +95,12 @@ func DisableReuseport() Option { return nil } } +func WithConnectionTimeout(d time.Duration) Option { + return func(tr *TcpTransport) error { + tr.connectTimeout = d + return nil + } +} // TcpTransport is the TCP transport. type TcpTransport struct { @@ -108,7 +112,7 @@ type TcpTransport struct { disableReuseport bool // TCP connect timeout - ConnectTimeout time.Duration + connectTimeout time.Duration reuse rtpt.Transport } @@ -118,7 +122,10 @@ var _ transport.Transport = &TcpTransport{} // NewTCPTransport creates a tcp transport object that tracks dialers and listeners // created. It represents an entire TCP stack (though it might not necessarily be). func NewTCPTransport(upgrader *tptu.Upgrader, opts ...Option) (*TcpTransport, error) { - tr := &TcpTransport{Upgrader: upgrader, ConnectTimeout: DefaultConnectTimeout} + tr := &TcpTransport{ + Upgrader: upgrader, + connectTimeout: defaultConnectTimeout, // can be set by using the WithConnectionTimeout option + } for _, o := range opts { if err := o(tr); err != nil { return nil, err @@ -137,9 +144,9 @@ func (t *TcpTransport) CanDial(addr ma.Multiaddr) bool { func (t *TcpTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (manet.Conn, error) { // Apply the deadline iff applicable - if t.ConnectTimeout > 0 { + if t.connectTimeout > 0 { var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, t.ConnectTimeout) + ctx, cancel = context.WithTimeout(ctx, t.connectTimeout) defer cancel() }