From b91d7e8fb9057bfebdbe68892fb167f9d8efb6ef Mon Sep 17 00:00:00 2001
From: Jason Tackaberry
Date: Sun, 13 Aug 2023 19:18:10 -0400
Subject: [PATCH] Use single goroutine per UDP connection

---
 layer4/server.go | 160 +++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 140 insertions(+), 20 deletions(-)

diff --git a/layer4/server.go b/layer4/server.go
index 06cd1b9..8463e31 100644
--- a/layer4/server.go
+++ b/layer4/server.go
@@ -17,6 +17,7 @@ package layer4
 import (
 	"bytes"
 	"fmt"
+	"io"
 	"net"
 	"sync"
 	"time"
@@ -76,23 +77,69 @@ func (s Server) serve(ln net.Listener) error {
 }
 
 func (s Server) servePacket(pc net.PacketConn) error {
+	// Spawn a goroutine whose only job is to consume packets from the socket
+	// and send to the packets channel.
+	packets := make(chan packet, 10)
+	go func(packets chan packet) {
+		for {
+			buf := udpBufPool.Get().([]byte)
+			n, addr, err := pc.ReadFrom(buf)
+			if err != nil {
+				if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
+					continue
+				}
+				packets <- packet{err: err}
+				return
+			}
+			packets <- packet{
+				pooledBuf: buf,
+				n:         n,
+				addr:      addr,
+			}
+		}
+	}(packets)
+
+	// udpConns tracks active packetConns by downstream address:port. They will
+	// be removed from this map after being closed.
+	udpConns := make(map[string]*packetConn)
+	// closeCh is used to receive notifications of socket closures from
+	// packetConn, which allows us to remove stale connections (whose
+	// proxy handlers have completed) from the udpConns map.
+	closeCh := make(chan string, 10)
 	for {
-		buf := udpBufPool.Get().([]byte)
-		n, addr, err := pc.ReadFrom(buf)
-		if err != nil {
-			if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
-				continue
+		select {
+		case addr := <-closeCh:
+			// UDP connection is closed (either implicitly through timeout or by
+			// explicit call to Close()).
+			delete(udpConns, addr)
+
+		case pkt := <-packets:
+			if pkt.err != nil {
+				return pkt.err
 			}
-			return err
+			conn, ok := udpConns[pkt.addr.String()]
+			if !ok {
+				// No existing proxy handler is running for this downstream.
+				// Create one now.
+				conn = &packetConn{
+					PacketConn: pc,
+					readCh:     make(chan *packet, 5),
+					addr:       pkt.addr,
+					closeCh:    closeCh,
+				}
+				udpConns[pkt.addr.String()] = conn
+				go func(conn *packetConn) {
+					s.handle(conn)
+					// It might seem cleaner to send to closeCh here rather than
+					// in packetConn, but doing it earlier in packetConn closes
+					// the gap between the proxy handler shutting down and new
+					// packets coming in from the same downstream. Should that
+					// happen, we'll just spin up a new handler concurrent to
+					// the old one shutting down.
+				}(conn)
+			}
+			conn.readCh <- &pkt
 		}
-		go func(buf []byte, n int, addr net.Addr) {
-			defer udpBufPool.Put(buf)
-			s.handle(packetConn{
-				PacketConn: pc,
-				buf:        bytes.NewBuffer(buf[:n]),
-				addr:       addr,
-			})
-		}(buf, n, addr)
 	}
 }
 
@@ -120,22 +167,95 @@ func (s Server) handle(conn net.Conn) {
 	)
 }
 
+type packet struct {
+	// The underlying bytes slice that was obtained from udpBufPool. It's up to
+	// packetConn to return it to udpBufPool once it's consumed.
+	pooledBuf []byte
+	// Number of bytes read from socket
+	n int
+	// Error that occurred while reading from socket
+	err error
+	// Address of downstream
+	addr net.Addr
+}
+
 type packetConn struct {
 	net.PacketConn
-	buf  *bytes.Buffer
-	addr net.Addr
+	addr    net.Addr
+	readCh  chan *packet
+	closeCh chan string
+	// If not nil, then the previous Read() call didn't consume all the data
+	// from the buffer, and this packet will be reused in the next Read()
+	// without waiting for readCh.
+	lastPacket *packet
+	lastBuf    *bytes.Buffer
 }
 
-func (pc packetConn) Read(b []byte) (n int, err error) {
-	return pc.buf.Read(b)
+func (pc *packetConn) Read(b []byte) (n int, err error) {
+	if pc.lastPacket != nil {
+		// There is a partial buffer to continue reading from the previous
+		// packet.
+		n, err = pc.lastBuf.Read(b)
+		if pc.lastBuf.Len() == 0 {
+			udpBufPool.Put(pc.lastPacket.pooledBuf)
+			pc.lastPacket = nil
+			pc.lastBuf = nil
+		}
+		return
+	}
+	select {
+	case pkt := <-pc.readCh:
+		if pkt == nil {
+			// Channel is closed. Return EOF below.
+			break
+		}
+		buf := bytes.NewBuffer(pkt.pooledBuf[:pkt.n])
+		n, err = buf.Read(b)
+		if buf.Len() == 0 {
+			// Buffer fully consumed, release it.
+			udpBufPool.Put(pkt.pooledBuf)
+		} else {
+			// Buffer only partially consumed. Keep track of it for
+			// next Read() call.
+			pc.lastPacket = pkt
+			pc.lastBuf = buf
+		}
+		return
+	// TODO: idle timeout should be configurable per server
+	case <-time.After(30 * time.Second):
+		break
+	}
+	// Idle timeout simulates socket closure.
+	//
+	// Although Close() also does this, we inform the server loop early about
+	// the closure to ensure that if any new packets are received from this
+	// connection in the meantime, a new handler will be started.
+	pc.closeCh <- pc.addr.String()
+	// Returning EOF here ensures that io.Copy() waiting on the downstream for
+	// reads will terminate.
+	return 0, io.EOF
 }
 
 func (pc packetConn) Write(b []byte) (n int, err error) {
 	return pc.PacketConn.WriteTo(b, pc.addr)
 }
 
-func (pc packetConn) Close() error {
-	// Do nothing, we don't want to close the UDP server
+func (pc *packetConn) Close() error {
+	if pc.lastPacket != nil {
+		udpBufPool.Put(pc.lastPacket.pooledBuf)
+		pc.lastPacket = nil
+	}
+	// This will abort any active Read() from another goroutine and return EOF
+	close(pc.readCh)
+	// Drain pending packets to ensure we release buffers back to the pool
+	for pkt := range pc.readCh {
+		udpBufPool.Put(pkt.pooledBuf)
+	}
+	// We may have already done this earlier in Read(), but just in case
+	// Read() wasn't being called, (re-)notify server loop we're closed.
+	pc.closeCh <- pc.addr.String()
+	// We don't call net.PacketConn.Close() here as we would stop the UDP
+	// server.
 	return nil
 }
 
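
Note (not part of the patch): the sketch below illustrates the pattern the diff adopts: a single goroutine reads from the shared UDP socket and the datagrams are demultiplexed by sender address into per-peer channels, each drained by its own handler goroutine. It is a standalone, simplified illustration, not the patched code; the names (demux, peerStream, handlePeer) are hypothetical, buffers are copied rather than pooled, and finished peers are simply left in the map instead of being pruned via a closeCh as the patch does.

package main

import (
	"fmt"
	"net"
	"time"
)

// peerStream carries datagrams for one downstream address.
type peerStream struct {
	addr net.Addr
	ch   chan []byte
}

// demux reads every datagram from pc and forwards it to a channel keyed by
// the sender's address, starting handlePeer for addresses it hasn't seen.
func demux(pc net.PacketConn) {
	peers := make(map[string]*peerStream)
	buf := make([]byte, 65535)
	for {
		n, addr, err := pc.ReadFrom(buf)
		if err != nil {
			return
		}
		key := addr.String()
		p, ok := peers[key]
		if !ok {
			p = &peerStream{addr: addr, ch: make(chan []byte, 5)}
			peers[key] = p
			go handlePeer(pc, p)
		}
		// Copy because buf is reused on the next ReadFrom.
		pkt := make([]byte, n)
		copy(pkt, buf[:n])
		select {
		case p.ch <- pkt:
		default:
			// Handler is gone or backed up; drop the datagram. The patch
			// instead prunes closed conns from the map via closeCh.
		}
	}
}

// handlePeer consumes the per-peer channel, standing in for a proxy handler.
// Here it just echoes datagrams back and returns after an idle period.
func handlePeer(pc net.PacketConn, p *peerStream) {
	for {
		select {
		case pkt := <-p.ch:
			pc.WriteTo(pkt, p.addr)
		case <-time.After(30 * time.Second):
			// Idle timeout, analogous to the simulated close in the patch.
			return
		}
	}
}

func main() {
	pc, err := net.ListenPacket("udp", "127.0.0.1:0")
	if err != nil {
		panic(err)
	}
	fmt.Println("echo server on", pc.LocalAddr())
	demux(pc)
}

Compared with the pre-patch behavior of spawning one handler goroutine per datagram, routing datagrams through a per-peer channel lets each downstream be consumed as a stream by a single long-lived handler, which is what the patch's packetConn type exposes through the net.Conn interface.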