diff --git a/layer4/connection.go b/layer4/connection.go index faae99fa..1ae78513 100644 --- a/layer4/connection.go +++ b/layer4/connection.go @@ -18,6 +18,7 @@ import ( "context" "errors" "net" + "os" "sync" "time" @@ -67,10 +68,11 @@ type Connection struct { Logger *zap.Logger - buf []byte // stores matching data - offset int - frozenOffset int - matching bool + buf []byte // stores matching data + offset int + frozenOffset int + matching bool + shouldPrefetchBeforeRead bool // indicates prefetch should be called before a Read bytesRead, bytesWritten uint64 } @@ -83,6 +85,16 @@ var ErrMatchingBufferFull = errors.New("matching buffer is full") // and once depleted (or if there isn't one), it continues // reading from the underlying connection. func (cx *Connection) Read(p []byte) (n int, err error) { + // Lazy prefetch to support protocols where the server speaks first, + // see https://github.com/mholt/caddy-l4/issues/228 & https://github.com/mholt/caddy-l4/issues/212 + if cx.matching && cx.shouldPrefetchBeforeRead { + err = cx.prefetch() + cx.shouldPrefetchBeforeRead = false + if err != nil { + return 0, err + } + } + // if we are matching and consumed the buffer exit with error if cx.matching && (len(cx.buf) == 0 || len(cx.buf) == cx.offset) { return 0, ErrConsumedAllPrefetchedBytes @@ -122,14 +134,16 @@ func (cx *Connection) Write(p []byte) (n int, err error) { // our Connection type (for example, `tls.Server()`). func (cx *Connection) Wrap(conn net.Conn) *Connection { return &Connection{ - Conn: conn, - Context: cx.Context, - Logger: cx.Logger, - buf: cx.buf, - offset: cx.offset, - matching: cx.matching, - bytesRead: cx.bytesRead, - bytesWritten: cx.bytesWritten, + Conn: conn, + Context: cx.Context, + Logger: cx.Logger, + buf: cx.buf, + offset: cx.offset, + frozenOffset: cx.frozenOffset, + matching: cx.matching, + shouldPrefetchBeforeRead: cx.shouldPrefetchBeforeRead, + bytesRead: cx.bytesRead, + bytesWritten: cx.bytesWritten, } } @@ -156,6 +170,9 @@ func (cx *Connection) prefetch() (err error) { cx.bytesRead += uint64(n) if err != nil { + if errors.Is(err, os.ErrDeadlineExceeded) { + err = ErrMatchingTimeout + } return err } @@ -215,8 +232,18 @@ func (cx *Connection) GetVar(key string) interface{} { // MatchingBytes returns all bytes currently available for matching. This is only intended for reading. // Do not write into the slice. It's a view of the internal buffer and you will likely mess up the connection. -func (cx *Connection) MatchingBytes() []byte { - return cx.buf[cx.offset:] +func (cx *Connection) MatchingBytes() ([]byte, error) { + // Lazy prefetch to support protocols where the server speaks first, + // see https://github.com/mholt/caddy-l4/issues/228 & https://github.com/mholt/caddy-l4/issues/212 + if cx.matching && cx.shouldPrefetchBeforeRead { + err := cx.prefetch() + cx.shouldPrefetchBeforeRead = false + if err != nil { + return nil, err + } + } + + return cx.buf[cx.offset:], nil } var ( diff --git a/layer4/listener.go b/layer4/listener.go index 451e749f..34f45988 100644 --- a/layer4/listener.go +++ b/layer4/listener.go @@ -172,7 +172,11 @@ func (l *listener) handle(conn net.Conn) { err = l.compiledRoute.Handle(cx) duration := time.Since(start) if err != nil && err != errHijacked { - l.logger.Error("handling connection", zap.Error(err)) + logFunc := l.logger.Error + if errors.Is(err, ErrMatchingTimeout) { + logFunc = l.logger.Warn + } + logFunc("matching connection", zap.String("remote", cx.RemoteAddr().String()), zap.Error(err)) } l.logger.Debug("connection stats", diff --git a/layer4/routes.go b/layer4/routes.go index 575d91b3..e811464c 100644 --- a/layer4/routes.go +++ b/layer4/routes.go @@ -18,7 +18,6 @@ import ( "encoding/json" "errors" "fmt" - "os" "time" "github.com/caddyserver/caddy/v2" @@ -107,54 +106,34 @@ func (routes RouteList) Compile(logger *zap.Logger, matchingTimeout time.Duratio return err } - routeIdx := -1 // init with -1 because before first use we increment it - notMatchingRoutes := make(map[int]struct{}, len(routes)) - router: for i := 0; i < 10000; i++ { // Limit number of tries to mitigate endless matching bugs. // Do not call prefetch if this is the first loop iteration and there already is some data available, // since this means we are at the start of a subroute handler and previous prefetch calls likely already fetched all bytes available from the client. // Which means it would block the subroute handler. In the second iteration (if no subroute routes match) blocking is the correct behaviour. if i != 0 || cx.buf == nil || len(cx.buf[cx.offset:]) == 0 { - err = cx.prefetch() - if err != nil { - logFunc := logger.Error - if errors.Is(err, os.ErrDeadlineExceeded) { - err = ErrMatchingTimeout - logFunc = logger.Warn - } - logFunc("matching connection", zap.String("remote", cx.RemoteAddr().String()), zap.Error(err)) - return nil // return nil so the error does not get logged again - } + cx.shouldPrefetchBeforeRead = true } // Use a wrapping routeIdx similar to a container/ring to try routes in a strictly circular fashion. // After a match continue with the routes after the matched one, instead of starting at the beginning. // This is done for backwards compatibility with configs written before the "Non blocking matchers & matching timeout" rewrite. // See https://github.com/mholt/caddy-l4/pull/192 and https://github.com/mholt/caddy-l4/pull/192#issuecomment-2143681952. - for j := 0; j < len(routes); j++ { - routeIdx++ - if routeIdx >= len(routes) { - routeIdx = 0 - } - + for routeIdx, route := range routes { // Skip routes that signaled they definitely can not match if _, ok := notMatchingRoutes[routeIdx]; ok { continue } - route := routes[routeIdx] - // A route must match at least one of the matcher sets matched, err := route.matcherSets.AnyMatch(cx) if errors.Is(err, ErrConsumedAllPrefetchedBytes) { continue // ignore and try next route } if err != nil { - logger.Error("matching connection", zap.String("remote", cx.RemoteAddr().String()), zap.Error(err)) - return nil + return err } if matched { // remove deadline after we matched @@ -179,7 +158,8 @@ func (routes RouteList) Compile(logger *zap.Logger, matchingTimeout time.Duratio } err = handler.Handle(cx) if err != nil { - return err + logger.Error("handling connection", zap.String("remote", cx.RemoteAddr().String()), zap.Error(err)) + return nil // return nil so the error does not get logged again } // If handler is terminal we stop routing @@ -196,9 +176,12 @@ func (routes RouteList) Compile(logger *zap.Logger, matchingTimeout time.Duratio // For example if the current route required multiple prefetch calls until it matched. // Then routes with an index after the current one where also tried on this now old/consumed data. clear(notMatchingRoutes) - // We jump back to the router loop to call prefetch again after the match, - // because the handler likely consumed all data. - continue router + + // If all data was consumed by the handler + // enable prefetch and continue with next route + if len(cx.buf[cx.offset:]) == 0 { + cx.shouldPrefetchBeforeRead = true + } } else { // Remember to not try this route again notMatchingRoutes[routeIdx] = struct{}{} @@ -212,7 +195,6 @@ func (routes RouteList) Compile(logger *zap.Logger, matchingTimeout time.Duratio } } - logger.Error("matching connection", zap.String("remote", cx.RemoteAddr().String()), zap.Error(errors.New("number of prefetch calls exhausted"))) - return nil + return errors.New("number of matching tries exhausted") }) } diff --git a/layer4/routes_test.go b/layer4/routes_test.go index c2324a95..a9020f98 100644 --- a/layer4/routes_test.go +++ b/layer4/routes_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "io" "net" "testing" "time" @@ -11,15 +12,35 @@ import ( "github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2/modules/caddyhttp" "go.uber.org/zap" - "go.uber.org/zap/zapcore" - "go.uber.org/zap/zaptest/observer" ) +type testIoMatcher struct { +} + +func (testIoMatcher) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + ID: "layer4.matchers.testIoMatcher", + New: func() caddy.Module { return new(testIoMatcher) }, + } +} + +func (m *testIoMatcher) Match(cx *Connection) (bool, error) { + buf := make([]byte, 1) + n, err := io.ReadFull(cx, buf) + return n > 0, err +} + func TestCompiledRouteTimeoutWorks(t *testing.T) { ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()}) defer cancel() - routes := RouteList{&Route{}} + caddy.RegisterModule(testIoMatcher{}) + + routes := RouteList{&Route{ + MatcherSetsRaw: caddyhttp.RawMatcherSets{ + caddy.ModuleMap{"testIoMatcher": json.RawMessage("{}")}, // any io using matcher + }, + }} err := routes.Provision(ctx) if err != nil { @@ -27,8 +48,8 @@ func TestCompiledRouteTimeoutWorks(t *testing.T) { } matched := false - loggerCore, logs := observer.New(zapcore.WarnLevel) - compiledRoutes := routes.Compile(zap.New(loggerCore), 5*time.Millisecond, + + compiledRoutes := routes.Compile(zap.NewNop(), 5*time.Millisecond, NextHandlerFunc(func(con *Connection, next Handler) error { matched = true return next.Handle(con) @@ -42,23 +63,8 @@ func TestCompiledRouteTimeoutWorks(t *testing.T) { defer cx.Close() err = compiledRoutes.Handle(cx) - if err != nil { - t.Fatalf("handle failed | %s", err) - } - - // verify the matching aborted error was logged - if logs.Len() != 1 { - t.Fatalf("logs should contain 1 entry but has %d", logs.Len()) - } - logEntry := logs.All()[0] - if logEntry.Level != zapcore.WarnLevel { - t.Fatalf("wrong log level | %s", logEntry.Level) - } - if logEntry.Message != "matching connection" { - t.Fatalf("wrong log message | %s", logEntry.Message) - } - if !(logEntry.Context[1].Key == "error" && errors.Is(logEntry.Context[1].Interface.(error), ErrMatchingTimeout)) { - t.Fatalf("wrong error | %v", logEntry.Context[1].Interface) + if !errors.Is(err, ErrMatchingTimeout) { + t.Fatalf("expected ErrMatchingTimeout but got %s", err) } // since matching failed no handler should be called diff --git a/layer4/server.go b/layer4/server.go index 6c5fd2c5..f54ae9d4 100644 --- a/layer4/server.go +++ b/layer4/server.go @@ -16,6 +16,7 @@ package layer4 import ( "bytes" + "errors" "fmt" "io" "net" @@ -166,7 +167,11 @@ func (s Server) handle(conn net.Conn) { err := s.compiledRoute.Handle(cx) duration := time.Since(start) if err != nil { - s.logger.Error("handling connection", zap.String("remote", cx.RemoteAddr().String()), zap.Error(err)) + logFunc := s.logger.Error + if errors.Is(err, ErrMatchingTimeout) { + logFunc = s.logger.Warn + } + logFunc("matching connection", zap.String("remote", cx.RemoteAddr().String()), zap.Error(err)) } s.logger.Debug("connection stats", diff --git a/modules/l4http/httpmatcher.go b/modules/l4http/httpmatcher.go index f59f7997..93972f7f 100644 --- a/modules/l4http/httpmatcher.go +++ b/modules/l4http/httpmatcher.go @@ -82,7 +82,10 @@ func (m MatchHTTP) Match(cx *layer4.Connection) (bool, error) { if !ok { var err error - data := cx.MatchingBytes() + data, err := cx.MatchingBytes() + if err != nil { + return false, err + } match, err := m.isHttp(data) if !match { return match, err diff --git a/modules/l4http/httpmatcher_test.go b/modules/l4http/httpmatcher_test.go index 8c4efb51..c26a2585 100644 --- a/modules/l4http/httpmatcher_test.go +++ b/modules/l4http/httpmatcher_test.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "encoding/base64" "encoding/json" + "errors" "net" "testing" "time" @@ -52,7 +53,6 @@ func httpMatchTester(t *testing.T, matchers json.RawMessage, data []byte) (bool, })) err = compiledRoute.Handle(cx) - assertNoError(t, err) return matched, err } @@ -234,10 +234,12 @@ func TestHttpMatchingGarbage(t *testing.T) { matchers := json.RawMessage("[{\"host\":[\"localhost\"]}]") matched, err := httpMatchTester(t, matchers, []byte("not a valid http request")) - assertNoError(t, err) if matched { t.Fatalf("matcher did match") } + if !errors.Is(err, layer4.ErrMatchingTimeout) { + t.Fatalf("Unexpected error: %s\n", err) + } validHttp2MagicWithoutHeadersFrame, err := base64.StdEncoding.DecodeString("UFJJICogSFRUUC8yLjANCg0KU00NCg0KAAASBAAAAAAAAAMAAABkAAQCAAAAAAIAAAAATm8gbG9uZ2VyIHZhbGlkIGh0dHAyIHJlcXVlc3QgZnJhbWVz") assertNoError(t, err) @@ -245,6 +247,9 @@ func TestHttpMatchingGarbage(t *testing.T) { if matched { t.Fatalf("matcher did match") } + if !errors.Is(err, layer4.ErrMatchingTimeout) { + t.Fatalf("Unexpected error: %s\n", err) + } } func TestMatchHTTP_isHttp(t *testing.T) {