Skip to content

Commit

Permalink
backend, net: track traffic for each client connection (pingcap#178)
Browse files Browse the repository at this point in the history
  • Loading branch information
djshow832 authored and xhebox committed Mar 7, 2023
1 parent ff86891 commit c928982
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 21 deletions.
57 changes: 38 additions & 19 deletions pkg/proxy/backend/backend_conn_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,11 @@ type BackendConnManager struct {
closeStatus atomic.Int32
// cancelFunc is used to cancel the signal processing goroutine.
cancelFunc context.CancelFunc
clientIO *pnet.PacketIO
backendIO *pnet.PacketIO
backendTLS *tls.Config
handshakeHandler HandshakeHandler
ctxmap sync.Map
clientAddr string
connectionID uint64
}

Expand Down Expand Up @@ -147,7 +147,7 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe

mgr.backendTLS = backendTLSConfig

mgr.clientAddr = clientIO.RemoteAddr().String()
mgr.clientIO = clientIO
err := mgr.authenticator.handshakeFirstTime(mgr.logger.Named("authenticator"), mgr, clientIO, mgr.handshakeHandler, mgr.getBackendIO, frontendTLSConfig, backendTLSConfig)
mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), err)
if err != nil {
Expand All @@ -158,7 +158,7 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe
childCtx, cancelFunc := context.WithCancel(ctx)
mgr.cancelFunc = cancelFunc
mgr.wg.Run(func() {
mgr.processSignals(childCtx, clientIO)
mgr.processSignals(childCtx)
})
return nil
}
Expand Down Expand Up @@ -219,7 +219,7 @@ func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticato

// ExecuteCmd forwards messages between the client and the backend.
// If it finds that the session is ready for redirection, it migrates the session.
func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte, clientIO *pnet.PacketIO) error {
func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) error {
if len(request) < 1 {
return mysql.ErrMalformPacket
}
Expand All @@ -233,7 +233,7 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte, c
return nil
}
waitingRedirect := atomic.LoadPointer(&mgr.signal) != nil
holdRequest, err := mgr.cmdProcessor.executeCmd(request, clientIO, mgr.backendIO, waitingRedirect)
holdRequest, err := mgr.cmdProcessor.executeCmd(request, mgr.clientIO, mgr.backendIO, waitingRedirect)
if !holdRequest {
addCmdMetrics(cmd, mgr.ServerAddr(), startTime)
}
Expand Down Expand Up @@ -269,17 +269,17 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte, c
// Even if it meets an MySQL error, it may have changed the status, such as when executing multi-statements.
if mgr.cmdProcessor.finishedTxn() {
if waitingRedirect && holdRequest {
mgr.tryRedirect(ctx, clientIO)
mgr.tryRedirect(ctx)
// Execute the held request no matter redirection succeeds or not.
_, err = mgr.cmdProcessor.executeCmd(request, clientIO, mgr.backendIO, false)
_, err = mgr.cmdProcessor.executeCmd(request, mgr.clientIO, mgr.backendIO, false)
addCmdMetrics(cmd, mgr.ServerAddr(), startTime)
if err != nil && !IsMySQLError(err) {
return err
}
} else if mgr.closeStatus.Load() == statusNotifyClose {
mgr.tryGracefulClose(ctx, clientIO)
mgr.tryGracefulClose(ctx)
} else if waitingRedirect {
mgr.tryRedirect(ctx, clientIO)
mgr.tryRedirect(ctx)
}
}
// Ignore MySQL errors, only return unexpected errors.
Expand Down Expand Up @@ -325,17 +325,17 @@ func (mgr *BackendConnManager) querySessionStates() (sessionStates, sessionToken
// processSignals runs in a goroutine to:
// - Receive redirection signals and then try to migrate the session.
// - Send redirection results to the event receiver.
func (mgr *BackendConnManager) processSignals(ctx context.Context, clientIO *pnet.PacketIO) {
func (mgr *BackendConnManager) processSignals(ctx context.Context) {
for {
select {
case s := <-mgr.signalReceived:
// Redirect the session immediately just in case the session is finishedTxn.
mgr.processLock.Lock()
switch s {
case signalTypeGracefulClose:
mgr.tryGracefulClose(ctx, clientIO)
mgr.tryGracefulClose(ctx)
case signalTypeRedirect:
mgr.tryRedirect(ctx, clientIO)
mgr.tryRedirect(ctx)
}
mgr.processLock.Unlock()
case rs := <-mgr.redirectResCh:
Expand All @@ -348,7 +348,7 @@ func (mgr *BackendConnManager) processSignals(ctx context.Context, clientIO *pne

// tryRedirect tries to migrate the session if the session is redirect-able.
// NOTE: processLock should be held before calling this function.
func (mgr *BackendConnManager) tryRedirect(ctx context.Context, clientIO *pnet.PacketIO) {
func (mgr *BackendConnManager) tryRedirect(ctx context.Context) {
switch mgr.closeStatus.Load() {
case statusNotifyClose, statusClosing, statusClosed:
return
Expand Down Expand Up @@ -389,8 +389,7 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context, clientIO *pnet.P
}
newBackendIO := pnet.NewPacketIO(cn, pnet.WithRemoteAddr(rs.to))

mgr.clientAddr = clientIO.RemoteAddr().String()
if rs.err = mgr.authenticator.handshakeSecondTime(mgr.logger, clientIO, newBackendIO, mgr.backendTLS, sessionToken); rs.err == nil {
if rs.err = mgr.authenticator.handshakeSecondTime(mgr.logger, mgr.clientIO, newBackendIO, mgr.backendTLS, sessionToken); rs.err == nil {
rs.err = mgr.initSessionStates(newBackendIO, sessionStates)
} else {
mgr.handshakeHandler.OnHandshake(mgr, newBackendIO.RemoteAddr().String(), rs.err)
Expand Down Expand Up @@ -475,28 +474,48 @@ func (mgr *BackendConnManager) GracefulClose() {
mgr.signalReceived <- signalTypeGracefulClose
}

func (mgr *BackendConnManager) tryGracefulClose(ctx context.Context, clientIO *pnet.PacketIO) {
func (mgr *BackendConnManager) tryGracefulClose(ctx context.Context) {
if mgr.closeStatus.Load() != statusNotifyClose {
return
}
if !mgr.cmdProcessor.finishedTxn() {
return
}
// Closing clientIO will cause the whole connection to be closed.
if err := clientIO.GracefulClose(); err != nil {
mgr.logger.Warn("graceful close client IO error", zap.Stringer("addr", clientIO.RemoteAddr()), zap.Error(err))
if err := mgr.clientIO.GracefulClose(); err != nil {
mgr.logger.Warn("graceful close client IO error", zap.Stringer("addr", mgr.clientIO.RemoteAddr()), zap.Error(err))
}
mgr.closeStatus.Store(statusClosing)
}

func (mgr *BackendConnManager) ClientAddr() string {
return mgr.clientAddr
if mgr.clientIO == nil {
return ""
}
return mgr.clientIO.RemoteAddr().String()
}

func (mgr *BackendConnManager) ServerAddr() string {
if mgr.backendIO == nil {
return ""
}
return mgr.backendIO.RemoteAddr().String()
}

func (mgr *BackendConnManager) ClientInBytes() uint64 {
if mgr.clientIO == nil {
return 0
}
return mgr.clientIO.InBytes()
}

func (mgr *BackendConnManager) ClientOutBytes() uint64 {
if mgr.clientIO == nil {
return 0
}
return mgr.clientIO.OutBytes()
}

func (mgr *BackendConnManager) SetValue(key, val any) {
mgr.ctxmap.Store(key, val)
}
Expand Down
4 changes: 3 additions & 1 deletion pkg/proxy/backend/backend_conn_mgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func (ts *backendMgrTester) forwardCmd4Proxy(clientIO, backendIO *pnet.PacketIO)
require.NoError(ts.t, err)
prevCounter, err := readCmdCounter(request[0], ts.tc.backendListener.Addr().String())
require.NoError(ts.t, err)
rsErr := ts.mp.ExecuteCmd(context.Background(), request, clientIO)
rsErr := ts.mp.ExecuteCmd(context.Background(), request)
curCounter, err := readCmdCounter(request[0], ts.tc.backendListener.Addr().String())
require.NoError(ts.t, err)
require.Equal(ts.t, prevCounter+1, curCounter)
Expand Down Expand Up @@ -225,6 +225,8 @@ func (ts *backendMgrTester) checkConnClosed(_, _ *pnet.PacketIO) error {
func (ts *backendMgrTester) runTests(runners []runner) {
for _, runner := range runners {
ts.runAndCheck(ts.t, nil, runner.client, runner.backend, runner.proxy)
require.Equal(ts.t, ts.tc.clientIO.InBytes(), ts.mp.ClientOutBytes())
require.Equal(ts.t, ts.tc.clientIO.OutBytes(), ts.mp.ClientInBytes())
}
}

Expand Down
2 changes: 2 additions & 0 deletions pkg/proxy/backend/handshake_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ var _ HandshakeHandler = (*DefaultHandshakeHandler)(nil)
type ConnContext interface {
ClientAddr() string
ServerAddr() string
ClientInBytes() uint64
ClientOutBytes() uint64
SetValue(key, val any)
Value(key any) any
}
Expand Down
2 changes: 2 additions & 0 deletions pkg/proxy/backend/testsuite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ func (ts *testSuite) runAndCheck(t *testing.T, c checker, clientRunner, backendR
// Ensure all the packets are forwarded.
msg := fmt.Sprintf("cmd:%d responseType:%d", ts.mc.cmd, ts.mb.respondType)
require.Equal(t, ts.tc.backendIO.GetSequence(), ts.tc.clientIO.GetSequence(), msg)
require.Equal(t, ts.tc.clientIO.OutBytes(), ts.tc.proxyCIO.InBytes(), msg)
require.Equal(t, ts.tc.clientIO.InBytes(), ts.tc.proxyCIO.OutBytes(), msg)
}
} else {
c(t, ts)
Expand Down
2 changes: 1 addition & 1 deletion pkg/proxy/client/client_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (cc *ClientConnection) processMsg(ctx context.Context) error {
if err != nil {
return err
}
err = cc.connMgr.ExecuteCmd(ctx, clientPkt, cc.pkt)
err = cc.connMgr.ExecuteCmd(ctx, clientPkt)
if err != nil {
return err
}
Expand Down
15 changes: 15 additions & 0 deletions pkg/proxy/net/packetio.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ func (f *rdbufConn) Read(b []byte) (int, error) {

// PacketIO is a helper to read and write sql and proxy protocol.
type PacketIO struct {
inBytes uint64
outBytes uint64
conn net.Conn
buf *bufio.ReadWriter
proxyInited *atomic.Bool
Expand Down Expand Up @@ -142,6 +144,7 @@ func (p *PacketIO) readOnePacket() ([]byte, bool, error) {
if _, err := io.ReadFull(p.conn, header[:]); err != nil {
return nil, false, errors.Wrap(ErrReadConn, err)
}
p.inBytes += 4

// probe proxy V2
refill := false
Expand All @@ -164,6 +167,7 @@ func (p *PacketIO) readOnePacket() ([]byte, bool, error) {
if _, err := io.ReadFull(p.conn, header[:]); err != nil {
return nil, false, errors.Wrap(ErrReadConn, err)
}
p.inBytes += 4
}

sequence := header[3]
Expand All @@ -177,6 +181,7 @@ func (p *PacketIO) readOnePacket() ([]byte, bool, error) {
if _, err := io.ReadFull(p.conn, data); err != nil {
return nil, false, errors.Wrap(ErrReadConn, err)
}
p.inBytes += uint64(length)
return data, length == mysql.MaxPayloadLen, nil
}

Expand Down Expand Up @@ -214,10 +219,12 @@ func (p *PacketIO) writeOnePacket(data []byte) (int, bool, error) {
if _, err := io.Copy(p.buf, bytes.NewReader(header[:])); err != nil {
return 0, more, errors.Wrap(ErrWriteConn, err)
}
p.outBytes += 4

if _, err := io.Copy(p.buf, bytes.NewReader(data[:length])); err != nil {
return 0, more, errors.Wrap(ErrWriteConn, err)
}
p.outBytes += uint64(length)

return length, more, nil
}
Expand All @@ -239,6 +246,14 @@ func (p *PacketIO) WritePacket(data []byte, flush bool) (err error) {
return nil
}

func (p *PacketIO) InBytes() uint64 {
return p.inBytes
}

func (p *PacketIO) OutBytes() uint64 {
return p.outBytes
}

func (p *PacketIO) Flush() error {
if err := p.buf.Flush(); err != nil {
return p.wrapErr(errors.Wrap(ErrFlushConn, err))
Expand Down
6 changes: 6 additions & 0 deletions pkg/proxy/net/packetio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,11 @@ func TestPacketIO(t *testing.T) {
// send anything
require.NoError(t, cli.WritePacket(expectMsg, true))

outBytes := len(expectMsg) + 4
for _, l := range pktLengths {
require.NoError(t, cli.WritePacket(make([]byte, l), true))
outBytes += l + (l/(mysql.MaxPayloadLen)+1)*4
require.Equal(t, uint64(outBytes), cli.OutBytes())
}

// skip handshake
Expand All @@ -116,10 +119,13 @@ func TestPacketIO(t *testing.T) {
require.NoError(t, err)
require.Equal(t, expectMsg, msg)

inBytes := len(expectMsg) + 4
for _, l := range pktLengths {
msg, err = srv.ReadPacket()
require.NoError(t, err)
require.Equal(t, l, len(msg))
inBytes += l + (l/(mysql.MaxPayloadLen)+1)*4
require.Equal(t, uint64(inBytes), srv.InBytes())
}

// send handshake
Expand Down
4 changes: 4 additions & 0 deletions pkg/proxy/net/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,14 @@ func (p *PacketIO) parseProxyV2() (*Proxy, error) {
if err != nil {
return nil, errors.WithStack(errors.Wrap(ErrReadConn, err))
}
p.inBytes += 8

var hdr [4]byte

if _, err := io.ReadFull(p.buf, hdr[:]); err != nil {
return nil, errors.WithStack(err)
}
p.inBytes += 4

m := &Proxy{}
m.Version = ProxyVersion(hdr[0] >> 4)
Expand All @@ -182,6 +184,7 @@ func (p *PacketIO) parseProxyV2() (*Proxy, error) {
if _, err := io.ReadFull(p.buf, buf); err != nil {
return nil, errors.WithStack(err)
}
p.inBytes += uint64(len(buf))

addressFamily := ProxyAddressFamily(hdr[1] >> 4)
network := ProxyNetwork(hdr[1] & 0xF)
Expand Down Expand Up @@ -285,6 +288,7 @@ func (p *PacketIO) WriteProxyV2(m *Proxy) error {
if _, err := io.Copy(p.buf, bytes.NewReader(buf)); err != nil {
return errors.Wrap(ErrWriteConn, err)
}
p.outBytes += uint64(len(buf))
// according to the spec, we better flush to avoid server hanging
return p.Flush()
}

0 comments on commit c928982

Please sign in to comment.