From 03c24c2d766bb8e069dba41e07e6c272d345fbe1 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Sat, 18 May 2024 12:55:39 -0700 Subject: [PATCH] http2: use synthetic time in server tests Change newServerTester to return a server using fake time and a fake net.Conn. Change-Id: I9d5db0cbe75696aed6d99ff1cd2369c2dea426c3 Reviewed-on: https://go-review.googlesource.com/c/net/+/586247 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- http2/http2.go | 13 ++ http2/server.go | 87 ++++++++--- http2/server_push_test.go | 6 +- http2/server_test.go | 298 +++++++++++++++++++++++++------------- http2/sync_test.go | 13 +- http2/transport.go | 7 +- http2/transport_test.go | 50 ++++--- 7 files changed, 318 insertions(+), 156 deletions(-) diff --git a/http2/http2.go b/http2/http2.go index 6f90f98e4..003e649f3 100644 --- a/http2/http2.go +++ b/http2/http2.go @@ -17,6 +17,7 @@ package http2 // import "golang.org/x/net/http2" import ( "bufio" + "context" "crypto/tls" "fmt" "io" @@ -26,6 +27,7 @@ import ( "strconv" "strings" "sync" + "time" "golang.org/x/net/http/httpguts" ) @@ -377,3 +379,14 @@ func validPseudoPath(v string) bool { // makes that struct also non-comparable, and generally doesn't add // any size (as long as it's first). type incomparable [0]func() + +// synctestGroupInterface is the methods of synctestGroup used by Server and Transport. +// It's defined as an interface here to let us keep synctestGroup entirely test-only +// and not a part of non-test builds. +type synctestGroupInterface interface { + Join() + Now() time.Time + NewTimer(d time.Duration) timer + AfterFunc(d time.Duration, f func()) timer + ContextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) +} diff --git a/http2/server.go b/http2/server.go index 778ef636c..d23640d01 100644 --- a/http2/server.go +++ b/http2/server.go @@ -154,6 +154,39 @@ type Server struct { // so that we don't embed a Mutex in this struct, which will make the // struct non-copyable, which might break some callers. state *serverInternalState + + // Synchronization group used for testing. + // Outside of tests, this is nil. + group synctestGroupInterface +} + +func (s *Server) markNewGoroutine() { + if s.group != nil { + s.group.Join() + } +} + +func (s *Server) now() time.Time { + if s.group != nil { + return s.group.Now() + } + return time.Now() +} + +// newTimer creates a new time.Timer, or a synthetic timer in tests. +func (s *Server) newTimer(d time.Duration) timer { + if s.group != nil { + return s.group.NewTimer(d) + } + return timeTimer{time.NewTimer(d)} +} + +// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests. +func (s *Server) afterFunc(d time.Duration, f func()) timer { + if s.group != nil { + return s.group.AfterFunc(d, f) + } + return timeTimer{time.AfterFunc(d, f)} } func (s *Server) initialConnRecvWindowSize() int32 { @@ -400,6 +433,10 @@ func (o *ServeConnOpts) handler() http.Handler { // // The opts parameter is optional. If nil, default values are used. func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) { + s.serveConn(c, opts, nil) +} + +func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverConn)) { baseCtx, cancel := serverConnBaseContext(c, opts) defer cancel() @@ -426,6 +463,9 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) { pushEnabled: true, sawClientPreface: opts.SawClientPreface, } + if newf != nil { + newf(sc) + } s.state.registerConn(sc) defer s.state.unregisterConn(sc) @@ -599,8 +639,8 @@ type serverConn struct { inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop needToSendGoAway bool // we need to schedule a GOAWAY frame write goAwayCode ErrCode - shutdownTimer *time.Timer // nil until used - idleTimer *time.Timer // nil if unused + shutdownTimer timer // nil until used + idleTimer timer // nil if unused // Owned by the writeFrameAsync goroutine: headerWriteBuf bytes.Buffer @@ -649,12 +689,12 @@ type stream struct { flow outflow // limits writing from Handler to client inflow inflow // what the client is allowed to POST/etc to us state streamState - resetQueued bool // RST_STREAM queued for write; set by sc.resetStream - gotTrailerHeader bool // HEADER frame for trailers was seen - wroteHeaders bool // whether we wrote headers (not status 100) - readDeadline *time.Timer // nil if unused - writeDeadline *time.Timer // nil if unused - closeErr error // set before cw is closed + resetQueued bool // RST_STREAM queued for write; set by sc.resetStream + gotTrailerHeader bool // HEADER frame for trailers was seen + wroteHeaders bool // whether we wrote headers (not status 100) + readDeadline timer // nil if unused + writeDeadline timer // nil if unused + closeErr error // set before cw is closed trailer http.Header // accumulated trailers reqTrailer http.Header // handler's Request.Trailer @@ -811,6 +851,7 @@ type readFrameResult struct { // consumer is done with the frame. // It's run on its own goroutine. func (sc *serverConn) readFrames() { + sc.srv.markNewGoroutine() gate := make(chan struct{}) gateDone := func() { gate <- struct{}{} } for { @@ -843,6 +884,7 @@ type frameWriteResult struct { // At most one goroutine can be running writeFrameAsync at a time per // serverConn. func (sc *serverConn) writeFrameAsync(wr FrameWriteRequest, wd *writeData) { + sc.srv.markNewGoroutine() var err error if wd == nil { err = wr.write.writeFrame(sc) @@ -922,13 +964,13 @@ func (sc *serverConn) serve() { sc.setConnState(http.StateIdle) if sc.srv.IdleTimeout > 0 { - sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer) + sc.idleTimer = sc.srv.afterFunc(sc.srv.IdleTimeout, sc.onIdleTimer) defer sc.idleTimer.Stop() } go sc.readFrames() // closed by defer sc.conn.Close above - settingsTimer := time.AfterFunc(firstSettingsTimeout, sc.onSettingsTimer) + settingsTimer := sc.srv.afterFunc(firstSettingsTimeout, sc.onSettingsTimer) defer settingsTimer.Stop() loopNum := 0 @@ -1057,10 +1099,10 @@ func (sc *serverConn) readPreface() error { errc <- nil } }() - timer := time.NewTimer(prefaceTimeout) // TODO: configurable on *Server? + timer := sc.srv.newTimer(prefaceTimeout) // TODO: configurable on *Server? defer timer.Stop() select { - case <-timer.C: + case <-timer.C(): return errPrefaceTimeout case err := <-errc: if err == nil { @@ -1425,7 +1467,7 @@ func (sc *serverConn) goAway(code ErrCode) { func (sc *serverConn) shutDownIn(d time.Duration) { sc.serveG.check() - sc.shutdownTimer = time.AfterFunc(d, sc.onShutdownTimer) + sc.shutdownTimer = sc.srv.afterFunc(d, sc.onShutdownTimer) } func (sc *serverConn) resetStream(se StreamError) { @@ -2022,7 +2064,7 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error { // (in Go 1.8), though. That's a more sane option anyway. if sc.hs.ReadTimeout > 0 { sc.conn.SetReadDeadline(time.Time{}) - st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout) + st.readDeadline = sc.srv.afterFunc(sc.hs.ReadTimeout, st.onReadTimeout) } return sc.scheduleHandler(id, rw, req, handler) @@ -2120,7 +2162,7 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream st.flow.add(sc.initialStreamSendWindowSize) st.inflow.init(sc.srv.initialStreamRecvWindowSize()) if sc.hs.WriteTimeout > 0 { - st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout) + st.writeDeadline = sc.srv.afterFunc(sc.hs.WriteTimeout, st.onWriteTimeout) } sc.streams[id] = st @@ -2344,6 +2386,7 @@ func (sc *serverConn) handlerDone() { // Run on its own goroutine. func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) { + sc.srv.markNewGoroutine() defer sc.sendServeMsg(handlerDoneMsg) didPanic := true defer func() { @@ -2640,7 +2683,7 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) { var date string if _, ok := rws.snapHeader["Date"]; !ok { // TODO(bradfitz): be faster here, like net/http? measure. - date = time.Now().UTC().Format(http.TimeFormat) + date = rws.conn.srv.now().UTC().Format(http.TimeFormat) } for _, v := range rws.snapHeader["Trailer"] { @@ -2762,7 +2805,7 @@ func (rws *responseWriterState) promoteUndeclaredTrailers() { func (w *responseWriter) SetReadDeadline(deadline time.Time) error { st := w.rws.stream - if !deadline.IsZero() && deadline.Before(time.Now()) { + if !deadline.IsZero() && deadline.Before(w.rws.conn.srv.now()) { // If we're setting a deadline in the past, reset the stream immediately // so writes after SetWriteDeadline returns will fail. st.onReadTimeout() @@ -2778,9 +2821,9 @@ func (w *responseWriter) SetReadDeadline(deadline time.Time) error { if deadline.IsZero() { st.readDeadline = nil } else if st.readDeadline == nil { - st.readDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onReadTimeout) + st.readDeadline = sc.srv.afterFunc(deadline.Sub(w.rws.conn.srv.now()), st.onReadTimeout) } else { - st.readDeadline.Reset(deadline.Sub(time.Now())) + st.readDeadline.Reset(deadline.Sub(w.rws.conn.srv.now())) } }) return nil @@ -2788,7 +2831,7 @@ func (w *responseWriter) SetReadDeadline(deadline time.Time) error { func (w *responseWriter) SetWriteDeadline(deadline time.Time) error { st := w.rws.stream - if !deadline.IsZero() && deadline.Before(time.Now()) { + if !deadline.IsZero() && deadline.Before(w.rws.conn.srv.now()) { // If we're setting a deadline in the past, reset the stream immediately // so writes after SetWriteDeadline returns will fail. st.onWriteTimeout() @@ -2804,9 +2847,9 @@ func (w *responseWriter) SetWriteDeadline(deadline time.Time) error { if deadline.IsZero() { st.writeDeadline = nil } else if st.writeDeadline == nil { - st.writeDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onWriteTimeout) + st.writeDeadline = sc.srv.afterFunc(deadline.Sub(w.rws.conn.srv.now()), st.onWriteTimeout) } else { - st.writeDeadline.Reset(deadline.Sub(time.Now())) + st.writeDeadline.Reset(deadline.Sub(w.rws.conn.srv.now())) } }) return nil diff --git a/http2/server_push_test.go b/http2/server_push_test.go index e90b28883..97b00e85a 100644 --- a/http2/server_push_test.go +++ b/http2/server_push_test.go @@ -105,7 +105,7 @@ func TestServer_Push_Success(t *testing.T) { errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI()) } }) - stURL = st.ts.URL + stURL = "https://" + st.authority() // Send one request, which should push two responses. st.greet() @@ -169,7 +169,7 @@ func TestServer_Push_Success(t *testing.T) { return checkPushPromise(f, 2, [][2]string{ {":method", "GET"}, {":scheme", "https"}, - {":authority", st.ts.Listener.Addr().String()}, + {":authority", st.authority()}, {":path", "/pushed?get"}, {"user-agent", userAgent}, }) @@ -178,7 +178,7 @@ func TestServer_Push_Success(t *testing.T) { return checkPushPromise(f, 4, [][2]string{ {":method", "HEAD"}, {":scheme", "https"}, - {":authority", st.ts.Listener.Addr().String()}, + {":authority", st.authority()}, {":path", "/pushed?head"}, {"cookie", cookie}, {"user-agent", userAgent}, diff --git a/http2/server_test.go b/http2/server_test.go index efa2b2207..506c51643 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -15,6 +15,7 @@ import ( "fmt" "io" "log" + "math" "net" "net/http" "net/http/httptest" @@ -66,7 +67,9 @@ func (sb *safeBuffer) Len() int { type serverTester struct { cc net.Conn // client conn t testing.TB - ts *httptest.Server + group *synctestGroup + h1server *http.Server + h2server *Server fr *Framer serverLogBuf safeBuffer // logger for httptest.Server logFilter []string // substrings to filter out @@ -109,6 +112,8 @@ func newTestServer(t testing.TB, handler http.HandlerFunc, opts ...interface{}) switch v := opt.(type) { case func(*httptest.Server): v(ts) + case func(*http.Server): + v(ts.Config) case func(*Server): v(h2server) default: @@ -140,14 +145,95 @@ type serverTesterOpt string var optFramerReuseFrames = serverTesterOpt("frame_reuse_frames") -var optQuiet = func(ts *httptest.Server) { - ts.Config.ErrorLog = log.New(io.Discard, "", 0) +var optQuiet = func(server *http.Server) { + server.ErrorLog = log.New(io.Discard, "", 0) } func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}) *serverTester { + t.Helper() + g := newSynctest(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)) + h1server := &http.Server{} + h2server := &Server{ + group: g, + } + tlsState := tls.ConnectionState{ + Version: tls.VersionTLS13, + ServerName: "go.dev", + CipherSuite: tls.TLS_AES_128_GCM_SHA256, + } + for _, opt := range opts { + switch v := opt.(type) { + case func(*Server): + v(h2server) + case func(*http.Server): + v(h1server) + case func(*tls.ConnectionState): + v(&tlsState) + default: + t.Fatalf("unknown newServerTester option type %T", v) + } + } + ConfigureServer(h1server, h2server) + + cli, srv := synctestNetPipe(g) + cli.SetReadDeadline(g.Now()) + cli.autoWait = true + + st := &serverTester{ + t: t, + cc: cli, + group: g, + h1server: h1server, + h2server: h2server, + } + st.hpackEnc = hpack.NewEncoder(&st.headerBuf) + st.hpackDec = hpack.NewDecoder(initialHeaderTableSize, st.onHeaderField) + if h1server.ErrorLog == nil { + h1server.ErrorLog = log.New(io.MultiWriter(stderrv(), twriter{t: t, st: st}, &st.serverLogBuf), "", log.LstdFlags) + } + + t.Cleanup(func() { + st.Close() + }) + + connc := make(chan *serverConn) + go func() { + g.Join() + h2server.serveConn(&netConnWithConnectionState{ + Conn: srv, + state: tlsState, + }, &ServeConnOpts{ + Handler: handler, + BaseConfig: h1server, + }, func(sc *serverConn) { + connc <- sc + }) + }() + st.sc = <-connc + + st.fr = NewFramer(st.cc, st.cc) + g.Wait() + return st +} + +type netConnWithConnectionState struct { + net.Conn + state tls.ConnectionState +} + +func (c *netConnWithConnectionState) ConnectionState() tls.ConnectionState { + return c.state +} + +// newServerTesterWithRealConn creates a test server listening on a localhost port. +// Mostly superseded by newServerTester, which creates a test server using a fake +// net.Conn and synthetic time. This function is still around because some benchmarks +// rely on it; new tests should use newServerTester. +func newServerTesterWithRealConn(t testing.TB, handler http.HandlerFunc, opts ...interface{}) *serverTester { resetHooks() ts := httptest.NewUnstartedServer(handler) + t.Cleanup(ts.Close) tlsConfig := &tls.Config{ InsecureSkipVerify: true, @@ -162,6 +248,8 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{} v(tlsConfig) case func(*httptest.Server): v(ts) + case func(*http.Server): + v(ts.Config) case func(*Server): v(h2server) case serverTesterOpt: @@ -185,8 +273,7 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{} ts.Config.TLSConfig.MinVersion = tls.VersionTLS10 st := &serverTester{ - t: t, - ts: ts, + t: t, } st.hpackEnc = hpack.NewEncoder(&st.headerBuf) st.hpackDec = hpack.NewDecoder(initialHeaderTableSize, st.onHeaderField) @@ -234,6 +321,20 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{} return st } +// sync waits for all goroutines to idle. +func (st *serverTester) sync() { + st.group.Wait() +} + +// advance advances synthetic time by a duration. +func (st *serverTester) advance(d time.Duration) { + st.group.AdvanceTime(d) +} + +func (st *serverTester) authority() string { + return "dummy.tld" +} + func (st *serverTester) closeConn() { st.scMu.Lock() defer st.scMu.Unlock() @@ -309,7 +410,6 @@ func (st *serverTester) Close() { st.cc.Close() } } - st.ts.Close() if st.cc != nil { st.cc.Close() } @@ -438,7 +538,7 @@ func (st *serverTester) encodeHeader(headers ...string) []byte { } st.headerBuf.Reset() - defaultAuthority := st.ts.Listener.Addr().String() + defaultAuthority := st.authority() if len(headers) == 0 { // Fast path, mostly for benchmarks, so test code doesn't pollute @@ -1245,38 +1345,32 @@ func (l *filterListener) Accept() (net.Conn, error) { } func TestServer_MaxQueuedControlFrames(t *testing.T) { - if testing.Short() { - t.Skip("skipping in short mode") - } + // Goroutine debugging makes this test very slow. + disableGoroutineTracking(t) - st := newServerTester(t, nil, func(ts *httptest.Server) { - // TCP buffer sizes on test systems aren't under our control and can be large. - // Create a conn that blocks after 10000 bytes written. - ts.Listener = &filterListener{ - Listener: ts.Listener, - accept: func(conn net.Conn) (net.Conn, error) { - return newBlockingWriteConn(conn, 10000), nil - }, - } - }) - defer st.Close() + st := newServerTester(t, nil) st.greet() - const extraPings = 500000 // enough to fill the TCP buffers + st.cc.(*synctestNetConn).SetReadBufferSize(0) // all writes block + st.cc.(*synctestNetConn).autoWait = false // don't sync after every write + // Send maxQueuedControlFrames pings, plus a few extra + // to account for ones that enter the server's write buffer. + const extraPings = 2 for i := 0; i < maxQueuedControlFrames+extraPings; i++ { pingData := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} - if err := st.fr.WritePing(false, pingData); err != nil { - if i == 0 { - t.Fatal(err) - } - // We expect the connection to get closed by the server when the TCP - // buffer fills up and the write queue reaches MaxQueuedControlFrames. - t.Logf("sent %d PING frames", i) - return - } + st.fr.WritePing(false, pingData) + } + st.group.Wait() + + // Unblock the server. + // It should have closed the connection after exceeding the control frame limit. + st.cc.(*synctestNetConn).SetReadBufferSize(math.MaxInt) + fr, err := st.readFrame() + if err != nil { + return } - t.Errorf("unexpected success sending all PING frames") + t.Errorf("unexpected frame after exceeding maxQueuedControlFrames; want closed conn\n%v", fr) } func TestServer_RejectsLargeFrames(t *testing.T) { @@ -1762,6 +1856,7 @@ func testServerRejectsConn(t *testing.T, writeReq func(*serverTester)) { writeReq(st) st.wantGoAway() + st.advance(goAwayTimeout) fr, err := st.fr.ReadFrame() if err == nil { @@ -2611,13 +2706,12 @@ func TestServer_NoCrash_HandlerClose_Then_ClientClose(t *testing.T) { func TestServer_Rejects_TLS10(t *testing.T) { testRejectTLS(t, tls.VersionTLS10) } func TestServer_Rejects_TLS11(t *testing.T) { testRejectTLS(t, tls.VersionTLS11) } -func testRejectTLS(t *testing.T, max uint16) { - st := newServerTester(t, nil, func(c *tls.Config) { +func testRejectTLS(t *testing.T, version uint16) { + st := newServerTester(t, nil, func(state *tls.ConnectionState) { // As of 1.18 the default minimum Go TLS version is // 1.2. In order to test rejection of lower versions, - // manually set the minimum version to 1.0 - c.MinVersion = tls.VersionTLS10 - c.MaxVersion = max + // manually set the version to 1.0 + state.Version = version }) defer st.Close() gf := st.wantGoAway() @@ -2627,24 +2721,9 @@ func testRejectTLS(t *testing.T, max uint16) { } func TestServer_Rejects_TLSBadCipher(t *testing.T) { - st := newServerTester(t, nil, func(c *tls.Config) { - // All TLS 1.3 ciphers are good. Test with TLS 1.2. - c.MaxVersion = tls.VersionTLS12 - // Only list bad ones: - c.CipherSuites = []uint16{ - tls.TLS_RSA_WITH_RC4_128_SHA, - tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, - tls.TLS_RSA_WITH_AES_128_CBC_SHA, - tls.TLS_RSA_WITH_AES_256_CBC_SHA, - tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, - tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA, - tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, - tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, - tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - cipher_TLS_RSA_WITH_AES_128_CBC_SHA256, - } + st := newServerTester(t, nil, func(state *tls.ConnectionState) { + state.Version = tls.VersionTLS12 + state.CipherSuite = tls.TLS_RSA_WITH_RC4_128_SHA }) defer st.Close() gf := st.wantGoAway() @@ -2654,18 +2733,30 @@ func TestServer_Rejects_TLSBadCipher(t *testing.T) { } func TestServer_Advertises_Common_Cipher(t *testing.T) { - const requiredSuite = tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 - st := newServerTester(t, nil, func(c *tls.Config) { - // Have the client only support the one required by the spec. - c.CipherSuites = []uint16{requiredSuite} - }, func(ts *httptest.Server) { - var srv *http.Server = ts.Config + ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + }, func(srv *http.Server) { // Have the server configured with no specific cipher suites. // This tests that Go's defaults include the required one. srv.TLSConfig = nil }) - defer st.Close() - st.greet() + + // Have the client only support the one required by the spec. + const requiredSuite = tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 + tlsConfig := tlsConfigInsecure.Clone() + tlsConfig.MaxVersion = tls.VersionTLS12 + tlsConfig.CipherSuites = []uint16{requiredSuite} + tr := &Transport{TLSClientConfig: tlsConfig} + defer tr.CloseIdleConnections() + + req, err := http.NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal(err) + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() } func (st *serverTester) onHeaderField(f hpack.HeaderField) { @@ -2867,8 +2958,8 @@ func TestCompressionErrorOnWrite(t *testing.T) { var serverConfig *http.Server st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { // No response body. - }, func(ts *httptest.Server) { - serverConfig = ts.Config + }, func(s *http.Server) { + serverConfig = s serverConfig.MaxHeaderBytes = maxStrLen }) st.addLogFilter("connection error: COMPRESSION_ERROR") @@ -3141,11 +3232,11 @@ func TestServerDoesntWriteInvalidHeaders(t *testing.T) { } func BenchmarkServerGets(b *testing.B) { - defer disableGoroutineTracking()() + disableGoroutineTracking(b) b.ReportAllocs() const msg = "Hello, world" - st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) { + st := newServerTesterWithRealConn(b, func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, msg) }) defer st.Close() @@ -3173,11 +3264,11 @@ func BenchmarkServerGets(b *testing.B) { } func BenchmarkServerPosts(b *testing.B) { - defer disableGoroutineTracking()() + disableGoroutineTracking(b) b.ReportAllocs() const msg = "Hello, world" - st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) { + st := newServerTesterWithRealConn(b, func(w http.ResponseWriter, r *http.Request) { // Consume the (empty) body from th peer before replying, otherwise // the server will sometimes (depending on scheduling) send the peer a // a RST_STREAM with the CANCEL error code. @@ -3225,7 +3316,7 @@ func BenchmarkServerToClientStreamReuseFrames(b *testing.B) { } func benchmarkServerToClientStream(b *testing.B, newServerOpts ...interface{}) { - defer disableGoroutineTracking()() + disableGoroutineTracking(b) b.ReportAllocs() const msgLen = 1 // default window size @@ -3241,7 +3332,7 @@ func benchmarkServerToClientStream(b *testing.B, newServerOpts ...interface{}) { return msg } - st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) { + st := newServerTesterWithRealConn(b, func(w http.ResponseWriter, r *http.Request) { // Consume the (empty) body from th peer before replying, otherwise // the server will sometimes (depending on scheduling) send the peer a // a RST_STREAM with the CANCEL error code. @@ -3515,17 +3606,17 @@ func TestServerContentLengthCanBeDisabled(t *testing.T) { } } -func disableGoroutineTracking() (restore func()) { +func disableGoroutineTracking(t testing.TB) { old := DebugGoroutines DebugGoroutines = false - return func() { DebugGoroutines = old } + t.Cleanup(func() { DebugGoroutines = old }) } func BenchmarkServer_GetRequest(b *testing.B) { - defer disableGoroutineTracking()() + disableGoroutineTracking(b) b.ReportAllocs() const msg = "Hello, world." - st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) { + st := newServerTesterWithRealConn(b, func(w http.ResponseWriter, r *http.Request) { n, err := io.Copy(io.Discard, r.Body) if err != nil || n > 0 { b.Errorf("Read %d bytes, error %v; want 0 bytes.", n, err) @@ -3554,10 +3645,10 @@ func BenchmarkServer_GetRequest(b *testing.B) { } func BenchmarkServer_PostRequest(b *testing.B) { - defer disableGoroutineTracking()() + disableGoroutineTracking(b) b.ReportAllocs() const msg = "Hello, world." - st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) { + st := newServerTesterWithRealConn(b, func(w http.ResponseWriter, r *http.Request) { n, err := io.Copy(io.Discard, r.Body) if err != nil || n > 0 { b.Errorf("Read %d bytes, error %v; want 0 bytes.", n, err) @@ -3901,6 +3992,7 @@ func TestServerIdleTimeout(t *testing.T) { defer st.Close() st.greet() + st.advance(500 * time.Millisecond) ga := st.wantGoAway() if ga.ErrCode != ErrCodeNo { t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode) @@ -3911,12 +4003,16 @@ func TestServerIdleTimeout_AfterRequest(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } - const timeout = 250 * time.Millisecond + const ( + requestTimeout = 2 * time.Second + idleTimeout = 1 * time.Second + ) - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - time.Sleep(timeout * 2) + var st *serverTester + st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + st.group.Sleep(requestTimeout) }, func(h2s *Server) { - h2s.IdleTimeout = timeout + h2s.IdleTimeout = idleTimeout }) defer st.Close() @@ -3925,10 +4021,12 @@ func TestServerIdleTimeout_AfterRequest(t *testing.T) { // Send a request which takes twice the timeout. Verifies the // idle timeout doesn't fire while we're in a request: st.bodylessReq1() + st.advance(requestTimeout) st.wantHeaders() // But the idle timeout should be rearmed after the request // is done: + st.advance(idleTimeout) ga := st.wantGoAway() if ga.ErrCode != ErrCodeNo { t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode) @@ -4092,6 +4190,8 @@ func TestServerHandlerConnectionClose(t *testing.T) { } sawWindowUpdate = true unblockHandler <- true + st.sync() + st.advance(goAwayTimeout) default: t.Logf("unexpected frame: %v", summarizeFrame(f)) } @@ -4157,20 +4257,9 @@ func TestServer_Headers_HalfCloseRemote(t *testing.T) { } func TestServerGracefulShutdown(t *testing.T) { - var st *serverTester handlerDone := make(chan struct{}) - st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - defer close(handlerDone) - go st.ts.Config.Shutdown(context.Background()) - - ga := st.wantGoAway() - if ga.ErrCode != ErrCodeNo { - t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode) - } - if ga.LastStreamID != 1 { - t.Errorf("GOAWAY LastStreamID = %v; want 1", ga.LastStreamID) - } - + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + <-handlerDone w.Header().Set("x-foo", "bar") }) defer st.Close() @@ -4178,7 +4267,20 @@ func TestServerGracefulShutdown(t *testing.T) { st.greet() st.bodylessReq1() - <-handlerDone + st.sync() + st.h1server.Shutdown(context.Background()) + + ga := st.wantGoAway() + if ga.ErrCode != ErrCodeNo { + t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode) + } + if ga.LastStreamID != 1 { + t.Errorf("GOAWAY LastStreamID = %v; want 1", ga.LastStreamID) + } + + close(handlerDone) + st.sync() + hf := st.wantHeaders() goth := st.decodeHeader(hf.HeaderBlockFragment()) wanth := [][2]string{ @@ -4396,7 +4498,6 @@ func TestNoErrorLoggedOnPostAfterGOAWAY(t *testing.T) { } st.writeData(1, true, []byte(content)) - time.Sleep(200 * time.Millisecond) st.Close() if bytes.Contains(st.serverLogBuf.Bytes(), []byte("PROTOCOL_ERROR")) { @@ -4523,6 +4624,7 @@ func TestProtocolErrorAfterGoAway(t *testing.T) { t.Fatal(err) } + st.advance(goAwayTimeout) for { if _, err := st.readFrame(); err != nil { if err != io.EOF { @@ -4805,8 +4907,8 @@ Frames: func TestServerContinuationFlood(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { fmt.Println(r.Header) - }, func(ts *httptest.Server) { - ts.Config.MaxHeaderBytes = 4096 + }, func(s *http.Server) { + s.MaxHeaderBytes = 4096 }) defer st.Close() diff --git a/http2/sync_test.go b/http2/sync_test.go index bcbbe66ac..9e99a7a15 100644 --- a/http2/sync_test.go +++ b/http2/sync_test.go @@ -31,6 +31,9 @@ type goroutine struct { // newSynctest creates a new group with the synthetic clock set the provided time. func newSynctest(now time.Time) *synctestGroup { return &synctestGroup{ + gids: map[int]bool{ + currentGoroutine(): true, + }, now: now, } } @@ -39,9 +42,6 @@ func newSynctest(now time.Time) *synctestGroup { func (g *synctestGroup) Join() { g.mu.Lock() defer g.mu.Unlock() - if g.gids == nil { - g.gids = map[int]bool{} - } g.gids[currentGoroutine()] = true } @@ -154,6 +154,7 @@ func stacks(all bool) []goroutine { // AdvanceTime advances the synthetic clock by d. func (g *synctestGroup) AdvanceTime(d time.Duration) { + defer g.Wait() g.mu.Lock() defer g.mu.Unlock() g.now = g.now.Add(d) @@ -186,6 +187,12 @@ func (g *synctestGroup) TimeUntilEvent() (d time.Duration, scheduled bool) { return d, scheduled } +// Sleep is time.Sleep, but using synthetic time. +func (g *synctestGroup) Sleep(d time.Duration) { + tm := g.NewTimer(d) + <-tm.C() +} + // NewTimer is time.NewTimer, but using synthetic time. func (g *synctestGroup) NewTimer(d time.Duration) Timer { return g.addTimer(d, &fakeTimer{ diff --git a/http2/transport.go b/http2/transport.go index 2cd3c6ec7..98a49c6b6 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -194,12 +194,7 @@ type Transport struct { type transportTestHooks struct { newclientconn func(*ClientConn) - group interface { - Join() - NewTimer(d time.Duration) timer - AfterFunc(d time.Duration, f func()) timer - ContextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) - } + group synctestGroupInterface } func (t *Transport) markNewGoroutine() { diff --git a/http2/transport_test.go b/http2/transport_test.go index d62407b47..d73f35e02 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -3658,7 +3658,7 @@ func TestTransportNoBodyMeansNoDATA(t *testing.T) { } func benchSimpleRoundTrip(b *testing.B, nReqHeaders, nResHeader int) { - defer disableGoroutineTracking()() + disableGoroutineTracking(b) b.ReportAllocs() ts := newTestServer(b, func(w http.ResponseWriter, r *http.Request) { @@ -3770,10 +3770,10 @@ func BenchmarkDownloadFrameSize(b *testing.B) { b.Run("512k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 512*1024) }) } func benchLargeDownloadRoundTrip(b *testing.B, frameSize uint32) { - defer disableGoroutineTracking()() + disableGoroutineTracking(b) const transferSize = 1024 * 1024 * 1024 // must be multiple of 1M b.ReportAllocs() - st := newServerTester(b, + ts := newTestServer(b, func(w http.ResponseWriter, r *http.Request) { // test 1GB transfer w.Header().Set("Content-Length", strconv.Itoa(transferSize)) @@ -3784,12 +3784,11 @@ func benchLargeDownloadRoundTrip(b *testing.B, frameSize uint32) { } }, optQuiet, ) - defer st.Close() tr := &Transport{TLSClientConfig: tlsConfigInsecure, MaxReadFrameSize: frameSize} defer tr.CloseIdleConnections() - req, err := http.NewRequest("GET", st.ts.URL, nil) + req, err := http.NewRequest("GET", ts.URL, nil) if err != nil { b.Fatal(err) } @@ -4869,33 +4868,36 @@ func TestTransportRetriesOnStreamProtocolError(t *testing.T) { } func TestClientConnReservations(t *testing.T) { - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - }, func(s *Server) { - s.MaxConcurrentStreams = initialMaxConcurrentStreams - }) - defer st.Close() - - tr := &Transport{TLSClientConfig: tlsConfigInsecure} - defer tr.CloseIdleConnections() + tc := newTestClientConn(t) + tc.greet( + Setting{ID: SettingMaxConcurrentStreams, Val: initialMaxConcurrentStreams}, + ) - cc, err := tr.newClientConn(st.cc, false) - if err != nil { - t.Fatal(err) + doRoundTrip := func() { + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + rt.wantStatus(200) } - req, _ := http.NewRequest("GET", st.ts.URL, nil) n := 0 - for n <= initialMaxConcurrentStreams && cc.ReserveNewRequest() { + for n <= initialMaxConcurrentStreams && tc.cc.ReserveNewRequest() { n++ } if n != initialMaxConcurrentStreams { t.Errorf("did %v reservations; want %v", n, initialMaxConcurrentStreams) } - if _, err := cc.RoundTrip(req); err != nil { - t.Fatalf("RoundTrip error = %v", err) - } + doRoundTrip() n2 := 0 - for n2 <= 5 && cc.ReserveNewRequest() { + for n2 <= 5 && tc.cc.ReserveNewRequest() { n2++ } if n2 != 1 { @@ -4904,11 +4906,11 @@ func TestClientConnReservations(t *testing.T) { // Use up all the reservations for i := 0; i < n; i++ { - cc.RoundTrip(req) + doRoundTrip() } n2 = 0 - for n2 <= initialMaxConcurrentStreams && cc.ReserveNewRequest() { + for n2 <= initialMaxConcurrentStreams && tc.cc.ReserveNewRequest() { n2++ } if n2 != n {