diff --git a/internal/http2/transport.go b/internal/http2/transport.go index 115eab6f..0007cdee 100644 --- a/internal/http2/transport.go +++ b/internal/http2/transport.go @@ -15,14 +15,6 @@ import ( "crypto/tls" "errors" "fmt" - "github.com/imroc/req/v3/http2" - "github.com/imroc/req/v3/internal/ascii" - "github.com/imroc/req/v3/internal/common" - "github.com/imroc/req/v3/internal/dump" - "github.com/imroc/req/v3/internal/header" - "github.com/imroc/req/v3/internal/netutil" - "github.com/imroc/req/v3/internal/transport" - reqtls "github.com/imroc/req/v3/pkg/tls" "io" "io/fs" "log" @@ -43,6 +35,15 @@ import ( "golang.org/x/net/http/httpguts" "golang.org/x/net/http2/hpack" "golang.org/x/net/idna" + + "github.com/imroc/req/v3/http2" + "github.com/imroc/req/v3/internal/ascii" + "github.com/imroc/req/v3/internal/common" + "github.com/imroc/req/v3/internal/dump" + "github.com/imroc/req/v3/internal/header" + "github.com/imroc/req/v3/internal/netutil" + "github.com/imroc/req/v3/internal/transport" + reqtls "github.com/imroc/req/v3/pkg/tls" ) const ( @@ -157,7 +158,6 @@ func (t *Transport) pingTimeout() time.Duration { return 15 * time.Second } return t.PingTimeout - } func (t *Transport) connPool() ClientConnPool { @@ -585,18 +585,72 @@ func (t *Transport) newTLSConfig(host string) *tls.Config { return cfg } +var zeroDialer net.Dialer + +type tlsHandshakeTimeoutError struct{} + +func (tlsHandshakeTimeoutError) Timeout() bool { return true } +func (tlsHandshakeTimeoutError) Temporary() bool { return true } +func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" } + // dialTLSWithContext uses tls.Dialer, added in Go 1.15, to open a TLS // connection. func (t *Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (reqtls.Conn, error) { - dialer := &tls.Dialer{ - Config: cfg, - } - conn, err := dialer.DialContext(ctx, network, addr) - if err != nil { - return nil, err + if t.TLSHandshakeContext != nil { + conn, err := zeroDialer.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + var firstTLSHost string + if firstTLSHost, _, err = net.SplitHostPort(addr); err != nil { + return nil, err + } + trace := httptrace.ContextClientTrace(ctx) + errc := make(chan error, 2) + var timer *time.Timer // for canceling TLS handshake + if d := t.TLSHandshakeTimeout; d != 0 { + timer = time.AfterFunc(d, func() { + errc <- tlsHandshakeTimeoutError{} + }) + } + go func() { + if trace != nil && trace.TLSHandshakeStart != nil { + trace.TLSHandshakeStart() + } + tlsCn, tlsState, err := t.TLSHandshakeContext(ctx, firstTLSHost, conn) + if err != nil { + if timer != nil { + timer.Stop() + } + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(tls.ConnectionState{}, err) + } + } else { + conn = tlsCn + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(*tlsState, nil) + } + } + errc <- err + }() + if err := <-errc; err != nil { + conn.Close() + return nil, err + } else { + tlsCn := conn.(reqtls.Conn) + return tlsCn, nil + } + } else { + dialer := &tls.Dialer{ + Config: cfg, + } + conn, err := dialer.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + tlsCn := conn.(reqtls.Conn) + return tlsCn, nil } - tlsCn := conn.(reqtls.Conn) - return tlsCn, nil } func (t *Transport) dialTLS(ctx context.Context) func(string, string, *tls.Config) (net.Conn, error) { @@ -1771,7 +1825,6 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) if a := cs.flow.available(); a > 0 { take := a if int(take) > maxBytes { - take = int32(maxBytes) // can't truncate int; take is int32 } if take > int32(cc.maxFrameSize) { @@ -1928,7 +1981,7 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail break } vals = append(vals, v[:p]) - //writeHeader("cookie", v[:p]) + // writeHeader("cookie", v[:p]) p++ // strip space after semicolon if any. for p+1 <= len(v) && v[p] == ' ' { @@ -1938,7 +1991,7 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail } if len(v) > 0 { vals = append(vals, v) - //writeHeader("cookie", v) + // writeHeader("cookie", v) } } writeHeader("cookie", vals...) @@ -2641,6 +2694,7 @@ func (b transportResponseBody) Close() error { } return nil } + func (rl *clientConnReadLoop) processData(f *DataFrame) error { cc := rl.cc cs := rl.streamByID(f.StreamID)