From e00a5a705d28b48d320db916525706eb5b05d40a Mon Sep 17 00:00:00 2001
From: Jason Tackaberry
Date: Sun, 13 Aug 2023 19:13:05 -0400
Subject: [PATCH 1/3] Close upstream UDP conns when downstream is closed

---
 modules/l4proxy/proxy.go | 26 +++++++++++++++++++++-----
 1 file changed, 21 insertions(+), 5 deletions(-)

diff --git a/modules/l4proxy/proxy.go b/modules/l4proxy/proxy.go
index 02160ce..d976e57 100644
--- a/modules/l4proxy/proxy.go
+++ b/modules/l4proxy/proxy.go
@@ -23,6 +23,7 @@ import (
 	"net"
 	"runtime/debug"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/caddyserver/caddy/v2"
@@ -253,6 +254,7 @@ func (h *Handler) proxy(down *layer4.Connection, upConns []net.Conn) {
 	}
 
 	var wg sync.WaitGroup
+	var downClosed atomic.Bool
 
 	for _, up := range upConns {
 		wg.Add(1)
@@ -261,11 +263,16 @@ func (h *Handler) proxy(down *layer4.Connection, upConns []net.Conn) {
 			defer wg.Done()
 
 			if _, err := io.Copy(down, up); err != nil {
-				h.logger.Error("upstream connection",
-					zap.String("local_address", up.LocalAddr().String()),
-					zap.String("remote_address", up.RemoteAddr().String()),
-					zap.Error(err),
-				)
+				// If the downstream connection has been closed, we can assume this is
+				// the reason io.Copy() errored. That's normal operation for UDP
+				// connections after idle timeout, so don't log an error in that case.
+				if !downClosed.Load() {
+					h.logger.Error("upstream connection",
+						zap.String("local_address", up.LocalAddr().String()),
+						zap.String("remote_address", up.RemoteAddr().String()),
+						zap.Error(err),
+					)
+				}
 			}
 		}(up)
 	}
@@ -280,9 +287,18 @@ func (h *Handler) proxy(down *layer4.Connection, upConns []net.Conn) {
 
 		// Shut down the writing side of all upstream connections, in case
 		// that the downstream connection is half closed. (issue #40)
+		//
+		// UDP connections meanwhile don't implement CloseWrite(), but in order
+		// to ensure io.Copy() in the per-upstream goroutines (above) returns,
+		// we need to close the socket. This will cause io.Copy() to return an
+		// error, which in this particular case is expected, so we signal the
+		// intentional closure by setting this flag.
+		downClosed.Store(true)
 		for _, up := range upConns {
 			if conn, ok := up.(closeWriter); ok {
 				_ = conn.CloseWrite()
+			} else {
+				up.Close()
 			}
 		}
 	}()

From ca42e7ed472fa640d219aac54fb66606583e9495 Mon Sep 17 00:00:00 2001
From: Jason Tackaberry
Date: Sun, 13 Aug 2023 19:13:51 -0400
Subject: [PATCH 2/3] Ensure UDP buffers are sufficiently sized

---
 layer4/server.go | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/layer4/server.go b/layer4/server.go
index 6f1609f..06cd1b9 100644
--- a/layer4/server.go
+++ b/layer4/server.go
@@ -143,6 +143,10 @@ func (pc packetConn) RemoteAddr() net.Addr { return pc.addr }
 
 var udpBufPool = sync.Pool{
 	New: func() interface{} {
-		return make([]byte, 1024)
+		// Buffers need to be as large as the largest datagram we'll consume, because
+		// ReadFrom() can't resume partial reads. (This is standard for UDP
+		// sockets on *nix.) So our buffer sizes are 9000 bytes to accommodate
+		// networks with jumbo frames. See also https://github.com/golang/go/issues/18056
+		return make([]byte, 9000)
 	},
 }

From b91d7e8fb9057bfebdbe68892fb167f9d8efb6ef Mon Sep 17 00:00:00 2001
From: Jason Tackaberry
Date: Sun, 13 Aug 2023 19:18:10 -0400
Subject: [PATCH 3/3] Use single goroutine per UDP connection

---
 layer4/server.go | 160 +++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 140 insertions(+), 20 deletions(-)

diff --git a/layer4/server.go b/layer4/server.go
index 06cd1b9..8463e31 100644
--- a/layer4/server.go
+++ b/layer4/server.go
@@ -17,6 +17,7 @@ package layer4
 import (
 	"bytes"
 	"fmt"
+	"io"
 	"net"
 	"sync"
 	"time"
@@ -76,23 +77,69 @@ func (s Server) serve(ln net.Listener) error {
 }
 
 func (s Server) servePacket(pc net.PacketConn) error {
+	// Spawn a goroutine whose only job is to consume packets from the socket
+	// and send to the packets channel.
+	packets := make(chan packet, 10)
+	go func(packets chan packet) {
+		for {
+			buf := udpBufPool.Get().([]byte)
+			n, addr, err := pc.ReadFrom(buf)
+			if err != nil {
+				if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
+					continue
+				}
+				packets <- packet{err: err}
+				return
+			}
+			packets <- packet{
+				pooledBuf: buf,
+				n:         n,
+				addr:      addr,
+			}
+		}
+	}(packets)
+
+	// udpConns tracks active packetConns by downstream address:port. They will
+	// be removed from this map after being closed.
+	udpConns := make(map[string]*packetConn)
+	// closeCh is used to receive notifications of socket closures from
+	// packetConn, which allows us to remove stale connections (whose
+	// proxy handlers have completed) from the udpConns map.
+	closeCh := make(chan string, 10)
 	for {
-		buf := udpBufPool.Get().([]byte)
-		n, addr, err := pc.ReadFrom(buf)
-		if err != nil {
-			if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
-				continue
+		select {
+		case addr := <-closeCh:
+			// UDP connection is closed (either implicitly through timeout or by
+			// explicit call to Close()).
+			delete(udpConns, addr)
+
+		case pkt := <-packets:
+			if pkt.err != nil {
+				return pkt.err
 			}
-			return err
+			conn, ok := udpConns[pkt.addr.String()]
+			if !ok {
+				// No existing proxy handler is running for this downstream.
+				// Create one now.
+				conn = &packetConn{
+					PacketConn: pc,
+					readCh:     make(chan *packet, 5),
+					addr:       pkt.addr,
+					closeCh:    closeCh,
+				}
+				udpConns[pkt.addr.String()] = conn
+				go func(conn *packetConn) {
+					s.handle(conn)
+					// It might seem cleaner to send to closeCh here rather than
+					// in packetConn, but doing it earlier in packetConn closes
+					// the gap between the proxy handler shutting down and new
+					// packets coming in from the same downstream. Should that
+					// happen, we'll just spin up a new handler concurrent to
+					// the old one shutting down.
+				}(conn)
+			}
+			conn.readCh <- &pkt
 		}
-		go func(buf []byte, n int, addr net.Addr) {
-			defer udpBufPool.Put(buf)
-			s.handle(packetConn{
-				PacketConn: pc,
-				buf:        bytes.NewBuffer(buf[:n]),
-				addr:       addr,
-			})
-		}(buf, n, addr)
 	}
 }
 
@@ -120,22 +167,95 @@ func (s Server) handle(conn net.Conn) {
 	)
 }
 
+type packet struct {
+	// The underlying bytes slice that was gotten from udpBufPool. It's up to
+	// packetConn to return it to udpBufPool once it's consumed.
+	pooledBuf []byte
+	// Number of bytes read from socket
+	n int
+	// Error that occurred while reading from socket
+	err error
+	// Address of downstream
+	addr net.Addr
+}
+
 type packetConn struct {
 	net.PacketConn
-	buf  *bytes.Buffer
-	addr net.Addr
+	addr    net.Addr
+	readCh  chan *packet
+	closeCh chan string
+	// If not nil, then the previous Read() call didn't consume all the data
+	// from the buffer, and this packet will be reused in the next Read()
+	// without waiting for readCh.
+	lastPacket *packet
+	lastBuf    *bytes.Buffer
 }
 
-func (pc packetConn) Read(b []byte) (n int, err error) {
-	return pc.buf.Read(b)
+func (pc *packetConn) Read(b []byte) (n int, err error) {
+	if pc.lastPacket != nil {
+		// There is a partial buffer to continue reading from the previous
+		// packet.
+		n, err = pc.lastBuf.Read(b)
+		if pc.lastBuf.Len() == 0 {
+			udpBufPool.Put(pc.lastPacket.pooledBuf)
+			pc.lastPacket = nil
+			pc.lastBuf = nil
+		}
+		return
+	}
+	select {
+	case pkt := <-pc.readCh:
+		if pkt == nil {
+			// Channel is closed. Return EOF below.
+			break
+		}
+		buf := bytes.NewBuffer(pkt.pooledBuf[:pkt.n])
+		n, err = buf.Read(b)
+		if buf.Len() == 0 {
+			// Buffer fully consumed, release it.
+			udpBufPool.Put(pkt.pooledBuf)
+		} else {
+			// Buffer only partially consumed. Keep track of it for
+			// next Read() call.
+			pc.lastPacket = pkt
+			pc.lastBuf = buf
+		}
+		return
+	// TODO: idle timeout should be configurable per server
+	case <-time.After(30 * time.Second):
+		break
+	}
+	// Idle timeout simulates socket closure.
+	//
+	// Although Close() also does this, we inform the server loop early about
+	// the closure to ensure that if any new packets are received from this
+	// connection in the meantime, a new handler will be started.
+	pc.closeCh <- pc.addr.String()
+	// Returning EOF here ensures that io.Copy() waiting on the downstream for
+	// reads will terminate.
+	return 0, io.EOF
 }
 
 func (pc packetConn) Write(b []byte) (n int, err error) {
 	return pc.PacketConn.WriteTo(b, pc.addr)
 }
 
-func (pc packetConn) Close() error {
-	// Do nothing, we don't want to close the UDP server
+func (pc *packetConn) Close() error {
+	if pc.lastPacket != nil {
+		udpBufPool.Put(pc.lastPacket.pooledBuf)
+		pc.lastPacket = nil
+	}
+	// This will abort any active Read() from another goroutine and return EOF
+	close(pc.readCh)
+	// Drain pending packets to ensure we release buffers back to the pool
+	for pkt := range pc.readCh {
+		udpBufPool.Put(pkt.pooledBuf)
+	}
+	// We may have already done this earlier in Read(), but just in case
+	// Read() wasn't being called, (re-)notify server loop we're closed.
+	pc.closeCh <- pc.addr.String()
+	// We don't call net.PacketConn.Close() here as we would stop the UDP
+	// server.
 	return nil
 }
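
Note (not part of the patch series): the sketch below is a stripped-down, hypothetical illustration of the dispatch pattern PATCH 3 introduces in servePacket(): one goroutine owns ReadFrom(), datagrams are fanned out to a handler goroutine per downstream address, and each handler reports back over a close channel when its idle timeout fires. The port (4040), the 2-second timeout, and the echo behaviour are arbitrary stand-ins for this demo, not values taken from caddy-l4.

// udp_fanout_demo.go — standalone sketch of per-downstream UDP dispatch.
package main

import (
	"log"
	"net"
	"time"
)

type peerPacket struct {
	buf  []byte
	addr net.Addr
}

func main() {
	pc, err := net.ListenPacket("udp", ":4040")
	if err != nil {
		log.Fatal(err)
	}
	defer pc.Close()

	packets := make(chan peerPacket, 10)
	closeCh := make(chan string, 10)

	// Single reader goroutine: the only place ReadFrom is called.
	go func() {
		for {
			buf := make([]byte, 9000) // sized for jumbo frames, as in PATCH 2
			n, addr, err := pc.ReadFrom(buf)
			if err != nil {
				close(packets)
				return
			}
			packets <- peerPacket{buf: buf[:n], addr: addr}
		}
	}()

	peers := make(map[string]chan peerPacket) // per-downstream inboxes
	for {
		select {
		case addr := <-closeCh:
			// Handler idled out; forget its inbox so the next packet from
			// this peer spins up a fresh handler.
			delete(peers, addr)
		case pkt, ok := <-packets:
			if !ok {
				return // reader goroutine hit a fatal socket error
			}
			key := pkt.addr.String()
			inbox, exists := peers[key]
			if !exists {
				inbox = make(chan peerPacket, 5)
				peers[key] = inbox
				go handlePeer(pc, key, inbox, closeCh)
			}
			inbox <- pkt
		}
	}
}

// handlePeer echoes datagrams back to a single peer until it has been idle
// for two seconds, then notifies the dispatch loop and exits.
func handlePeer(pc net.PacketConn, key string, inbox chan peerPacket, closeCh chan string) {
	for {
		select {
		case pkt := <-inbox:
			if _, err := pc.WriteTo(pkt.buf, pkt.addr); err != nil {
				log.Printf("write to %s: %v", key, err)
			}
		case <-time.After(2 * time.Second):
			closeCh <- key
			return
		}
	}
}

With the demo running, something like `printf hi | nc -u -w1 127.0.0.1 4040` exercises one handler; sending again after a few seconds of silence shows a new handler being created for the same peer.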