Skip to content

Commit

Permalink
feat: lazy prefetch, fixes protocols where the server speaks first
Browse files Browse the repository at this point in the history
  • Loading branch information
ydylla committed Aug 5, 2024
1 parent 154bf6f commit d3998f5
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 71 deletions.
55 changes: 41 additions & 14 deletions layer4/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"errors"
"net"
"os"
"sync"
"time"

Expand Down Expand Up @@ -67,10 +68,11 @@ type Connection struct {

Logger *zap.Logger

buf []byte // stores matching data
offset int
frozenOffset int
matching bool
buf []byte // stores matching data
offset int
frozenOffset int
matching bool
shouldPrefetchBeforeRead bool // indicates prefetch should be called before a Read

bytesRead, bytesWritten uint64
}
Expand All @@ -83,6 +85,16 @@ var ErrMatchingBufferFull = errors.New("matching buffer is full")
// and once depleted (or if there isn't one), it continues
// reading from the underlying connection.
func (cx *Connection) Read(p []byte) (n int, err error) {
// Lazy prefetch to support protocols where the server speaks first,
// see https://github.com/mholt/caddy-l4/issues/228 & https://github.com/mholt/caddy-l4/issues/212
if cx.matching && cx.shouldPrefetchBeforeRead {
err = cx.prefetch()
cx.shouldPrefetchBeforeRead = false
if err != nil {
return 0, err
}
}

// if we are matching and consumed the buffer exit with error
if cx.matching && (len(cx.buf) == 0 || len(cx.buf) == cx.offset) {
return 0, ErrConsumedAllPrefetchedBytes
Expand Down Expand Up @@ -122,14 +134,16 @@ func (cx *Connection) Write(p []byte) (n int, err error) {
// our Connection type (for example, `tls.Server()`).
func (cx *Connection) Wrap(conn net.Conn) *Connection {
return &Connection{
Conn: conn,
Context: cx.Context,
Logger: cx.Logger,
buf: cx.buf,
offset: cx.offset,
matching: cx.matching,
bytesRead: cx.bytesRead,
bytesWritten: cx.bytesWritten,
Conn: conn,
Context: cx.Context,
Logger: cx.Logger,
buf: cx.buf,
offset: cx.offset,
frozenOffset: cx.frozenOffset,
matching: cx.matching,
shouldPrefetchBeforeRead: cx.shouldPrefetchBeforeRead,
bytesRead: cx.bytesRead,
bytesWritten: cx.bytesWritten,
}
}

Expand All @@ -156,6 +170,9 @@ func (cx *Connection) prefetch() (err error) {
cx.bytesRead += uint64(n)

if err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
err = ErrMatchingTimeout
}
return err
}

Expand Down Expand Up @@ -215,8 +232,18 @@ func (cx *Connection) GetVar(key string) interface{} {

// MatchingBytes returns all bytes currently available for matching. This is only intended for reading.
// Do not write into the slice. It's a view of the internal buffer and you will likely mess up the connection.
func (cx *Connection) MatchingBytes() []byte {
return cx.buf[cx.offset:]
func (cx *Connection) MatchingBytes() ([]byte, error) {
// Lazy prefetch to support protocols where the server speaks first,
// see https://github.com/mholt/caddy-l4/issues/228 & https://github.com/mholt/caddy-l4/issues/212
if cx.matching && cx.shouldPrefetchBeforeRead {
err := cx.prefetch()
cx.shouldPrefetchBeforeRead = false
if err != nil {
return nil, err
}
}

return cx.buf[cx.offset:], nil
}

var (
Expand Down
6 changes: 5 additions & 1 deletion layer4/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,11 @@ func (l *listener) handle(conn net.Conn) {
err = l.compiledRoute.Handle(cx)
duration := time.Since(start)
if err != nil && err != errHijacked {
l.logger.Error("handling connection", zap.Error(err))
logFunc := l.logger.Error
if errors.Is(err, ErrMatchingTimeout) {
logFunc = l.logger.Warn
}
logFunc("matching connection", zap.String("remote", cx.RemoteAddr().String()), zap.Error(err))
}

l.logger.Debug("connection stats",
Expand Down
42 changes: 12 additions & 30 deletions layer4/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"encoding/json"
"errors"
"fmt"
"os"
"time"

"github.com/caddyserver/caddy/v2"
Expand Down Expand Up @@ -107,54 +106,34 @@ func (routes RouteList) Compile(logger *zap.Logger, matchingTimeout time.Duratio
return err
}

routeIdx := -1 // init with -1 because before first use we increment it

notMatchingRoutes := make(map[int]struct{}, len(routes))

router:
for i := 0; i < 10000; i++ { // Limit number of tries to mitigate endless matching bugs.

// Do not call prefetch if this is the first loop iteration and there already is some data available,
// since this means we are at the start of a subroute handler and previous prefetch calls likely already fetched all bytes available from the client.
// Which means it would block the subroute handler. In the second iteration (if no subroute routes match) blocking is the correct behaviour.
if i != 0 || cx.buf == nil || len(cx.buf[cx.offset:]) == 0 {
err = cx.prefetch()
if err != nil {
logFunc := logger.Error
if errors.Is(err, os.ErrDeadlineExceeded) {
err = ErrMatchingTimeout
logFunc = logger.Warn
}
logFunc("matching connection", zap.String("remote", cx.RemoteAddr().String()), zap.Error(err))
return nil // return nil so the error does not get logged again
}
cx.shouldPrefetchBeforeRead = true
}

// Use a wrapping routeIdx similar to a container/ring to try routes in a strictly circular fashion.
// After a match continue with the routes after the matched one, instead of starting at the beginning.
// This is done for backwards compatibility with configs written before the "Non blocking matchers & matching timeout" rewrite.
// See https://github.com/mholt/caddy-l4/pull/192 and https://github.com/mholt/caddy-l4/pull/192#issuecomment-2143681952.
for j := 0; j < len(routes); j++ {
routeIdx++
if routeIdx >= len(routes) {
routeIdx = 0
}

for routeIdx, route := range routes {
// Skip routes that signaled they definitely can not match
if _, ok := notMatchingRoutes[routeIdx]; ok {
continue
}

route := routes[routeIdx]

// A route must match at least one of the matcher sets
matched, err := route.matcherSets.AnyMatch(cx)
if errors.Is(err, ErrConsumedAllPrefetchedBytes) {
continue // ignore and try next route
}
if err != nil {
logger.Error("matching connection", zap.String("remote", cx.RemoteAddr().String()), zap.Error(err))
return nil
return err
}
if matched {
// remove deadline after we matched
Expand All @@ -179,7 +158,8 @@ func (routes RouteList) Compile(logger *zap.Logger, matchingTimeout time.Duratio
}
err = handler.Handle(cx)
if err != nil {
return err
logger.Error("handling connection", zap.String("remote", cx.RemoteAddr().String()), zap.Error(err))
return nil // return nil so the error does not get logged again
}

// If handler is terminal we stop routing
Expand All @@ -196,9 +176,12 @@ func (routes RouteList) Compile(logger *zap.Logger, matchingTimeout time.Duratio
// For example if the current route required multiple prefetch calls until it matched.
// Then routes with an index after the current one were also tried on this now old/consumed data.
clear(notMatchingRoutes)
// We jump back to the router loop to call prefetch again after the match,
// because the handler likely consumed all data.
continue router

// If all data was consumed by the handler
// enable prefetch and continue with next route
if len(cx.buf[cx.offset:]) == 0 {
cx.shouldPrefetchBeforeRead = true
}
} else {
// Remember to not try this route again
notMatchingRoutes[routeIdx] = struct{}{}
Expand All @@ -212,7 +195,6 @@ func (routes RouteList) Compile(logger *zap.Logger, matchingTimeout time.Duratio
}
}

logger.Error("matching connection", zap.String("remote", cx.RemoteAddr().String()), zap.Error(errors.New("number of prefetch calls exhausted")))
return nil
return errors.New("number of matching tries exhausted")
})
}
50 changes: 28 additions & 22 deletions layer4/routes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,52 @@ import (
"context"
"encoding/json"
"errors"
"io"
"net"
"testing"
"time"

"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"go.uber.org/zap/zaptest/observer"
)

// testIoMatcher is a stateless test-only matcher; its Match method reads
// from the connection so that tests exercise matchers that perform I/O.
type testIoMatcher struct{}

// CaddyModule returns the module information for testIoMatcher: its module ID
// and a constructor that yields a fresh *testIoMatcher.
func (testIoMatcher) CaddyModule() caddy.ModuleInfo {
	newMatcher := func() caddy.Module { return new(testIoMatcher) }

	return caddy.ModuleInfo{
		ID:  "layer4.matchers.testIoMatcher",
		New: newMatcher,
	}
}

// Match tries to read exactly one byte from cx and reports whether a byte
// was obtained. Any error returned by io.ReadFull is passed through unchanged.
func (m *testIoMatcher) Match(cx *Connection) (bool, error) {
	var probe [1]byte
	read, err := io.ReadFull(cx, probe[:])
	return read > 0, err
}

func TestCompiledRouteTimeoutWorks(t *testing.T) {
ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()})
defer cancel()

routes := RouteList{&Route{}}
caddy.RegisterModule(testIoMatcher{})

routes := RouteList{&Route{
MatcherSetsRaw: caddyhttp.RawMatcherSets{
caddy.ModuleMap{"testIoMatcher": json.RawMessage("{}")}, // any io using matcher
},
}}

err := routes.Provision(ctx)
if err != nil {
t.Fatalf("provision failed | %s", err)
}

matched := false
loggerCore, logs := observer.New(zapcore.WarnLevel)
compiledRoutes := routes.Compile(zap.New(loggerCore), 5*time.Millisecond,

compiledRoutes := routes.Compile(zap.NewNop(), 5*time.Millisecond,
NextHandlerFunc(func(con *Connection, next Handler) error {
matched = true
return next.Handle(con)
Expand All @@ -42,23 +63,8 @@ func TestCompiledRouteTimeoutWorks(t *testing.T) {
defer cx.Close()

err = compiledRoutes.Handle(cx)
if err != nil {
t.Fatalf("handle failed | %s", err)
}

// verify the matching aborted error was logged
if logs.Len() != 1 {
t.Fatalf("logs should contain 1 entry but has %d", logs.Len())
}
logEntry := logs.All()[0]
if logEntry.Level != zapcore.WarnLevel {
t.Fatalf("wrong log level | %s", logEntry.Level)
}
if logEntry.Message != "matching connection" {
t.Fatalf("wrong log message | %s", logEntry.Message)
}
if !(logEntry.Context[1].Key == "error" && errors.Is(logEntry.Context[1].Interface.(error), ErrMatchingTimeout)) {
t.Fatalf("wrong error | %v", logEntry.Context[1].Interface)
if !errors.Is(err, ErrMatchingTimeout) {
t.Fatalf("expected ErrMatchingTimeout but got %s", err)
}

// since matching failed no handler should be called
Expand Down
7 changes: 6 additions & 1 deletion layer4/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package layer4

import (
"bytes"
"errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -166,7 +167,11 @@ func (s Server) handle(conn net.Conn) {
err := s.compiledRoute.Handle(cx)
duration := time.Since(start)
if err != nil {
s.logger.Error("handling connection", zap.String("remote", cx.RemoteAddr().String()), zap.Error(err))
logFunc := s.logger.Error
if errors.Is(err, ErrMatchingTimeout) {
logFunc = s.logger.Warn
}
logFunc("matching connection", zap.String("remote", cx.RemoteAddr().String()), zap.Error(err))
}

s.logger.Debug("connection stats",
Expand Down
5 changes: 4 additions & 1 deletion modules/l4http/httpmatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ func (m MatchHTTP) Match(cx *layer4.Connection) (bool, error) {
if !ok {
var err error

data := cx.MatchingBytes()
data, err := cx.MatchingBytes()
if err != nil {
return false, err
}
match, err := m.isHttp(data)
if !match {
return match, err
Expand Down
9 changes: 7 additions & 2 deletions modules/l4http/httpmatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/tls"
"encoding/base64"
"encoding/json"
"errors"
"net"
"testing"
"time"
Expand Down Expand Up @@ -52,7 +53,6 @@ func httpMatchTester(t *testing.T, matchers json.RawMessage, data []byte) (bool,
}))

err = compiledRoute.Handle(cx)
assertNoError(t, err)

return matched, err
}
Expand Down Expand Up @@ -234,17 +234,22 @@ func TestHttpMatchingGarbage(t *testing.T) {
matchers := json.RawMessage("[{\"host\":[\"localhost\"]}]")

matched, err := httpMatchTester(t, matchers, []byte("not a valid http request"))
assertNoError(t, err)
if matched {
t.Fatalf("matcher did match")
}
if !errors.Is(err, layer4.ErrMatchingTimeout) {
t.Fatalf("Unexpected error: %s\n", err)
}

validHttp2MagicWithoutHeadersFrame, err := base64.StdEncoding.DecodeString("UFJJICogSFRUUC8yLjANCg0KU00NCg0KAAASBAAAAAAAAAMAAABkAAQCAAAAAAIAAAAATm8gbG9uZ2VyIHZhbGlkIGh0dHAyIHJlcXVlc3QgZnJhbWVz")
assertNoError(t, err)
matched, err = httpMatchTester(t, matchers, validHttp2MagicWithoutHeadersFrame)
if matched {
t.Fatalf("matcher did match")
}
if !errors.Is(err, layer4.ErrMatchingTimeout) {
t.Fatalf("Unexpected error: %s\n", err)
}
}

func TestMatchHTTP_isHttp(t *testing.T) {
Expand Down

0 comments on commit d3998f5

Please sign in to comment.