diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index 26dc7eb6..ed236e95 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -106,11 +106,11 @@ type BackendConnManager struct { closeStatus atomic.Int32 // cancelFunc is used to cancel the signal processing goroutine. cancelFunc context.CancelFunc + clientIO *pnet.PacketIO backendIO *pnet.PacketIO backendTLS *tls.Config handshakeHandler HandshakeHandler ctxmap sync.Map - clientAddr string connectionID uint64 } @@ -147,7 +147,7 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe mgr.backendTLS = backendTLSConfig - mgr.clientAddr = clientIO.RemoteAddr().String() + mgr.clientIO = clientIO err := mgr.authenticator.handshakeFirstTime(mgr.logger.Named("authenticator"), mgr, clientIO, mgr.handshakeHandler, mgr.getBackendIO, frontendTLSConfig, backendTLSConfig) mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), err) if err != nil { @@ -158,7 +158,7 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe childCtx, cancelFunc := context.WithCancel(ctx) mgr.cancelFunc = cancelFunc mgr.wg.Run(func() { - mgr.processSignals(childCtx, clientIO) + mgr.processSignals(childCtx) }) return nil } @@ -219,7 +219,7 @@ func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticato // ExecuteCmd forwards messages between the client and the backend. // If it finds that the session is ready for redirection, it migrates the session. -func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte, clientIO *pnet.PacketIO) error { +func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) error { if len(request) < 1 { return mysql.ErrMalformPacket } @@ -233,7 +233,7 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte, c return nil } waitingRedirect := atomic.LoadPointer(&mgr.signal) != nil - holdRequest, err := mgr.cmdProcessor.executeCmd(request, clientIO, mgr.backendIO, waitingRedirect) + holdRequest, err := mgr.cmdProcessor.executeCmd(request, mgr.clientIO, mgr.backendIO, waitingRedirect) if !holdRequest { addCmdMetrics(cmd, mgr.ServerAddr(), startTime) } @@ -269,17 +269,17 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte, c // Even if it meets an MySQL error, it may have changed the status, such as when executing multi-statements. if mgr.cmdProcessor.finishedTxn() { if waitingRedirect && holdRequest { - mgr.tryRedirect(ctx, clientIO) + mgr.tryRedirect(ctx) // Execute the held request no matter redirection succeeds or not. - _, err = mgr.cmdProcessor.executeCmd(request, clientIO, mgr.backendIO, false) + _, err = mgr.cmdProcessor.executeCmd(request, mgr.clientIO, mgr.backendIO, false) addCmdMetrics(cmd, mgr.ServerAddr(), startTime) if err != nil && !IsMySQLError(err) { return err } } else if mgr.closeStatus.Load() == statusNotifyClose { - mgr.tryGracefulClose(ctx, clientIO) + mgr.tryGracefulClose(ctx) } else if waitingRedirect { - mgr.tryRedirect(ctx, clientIO) + mgr.tryRedirect(ctx) } } // Ignore MySQL errors, only return unexpected errors. @@ -325,7 +325,7 @@ func (mgr *BackendConnManager) querySessionStates() (sessionStates, sessionToken // processSignals runs in a goroutine to: // - Receive redirection signals and then try to migrate the session. // - Send redirection results to the event receiver. -func (mgr *BackendConnManager) processSignals(ctx context.Context, clientIO *pnet.PacketIO) { +func (mgr *BackendConnManager) processSignals(ctx context.Context) { for { select { case s := <-mgr.signalReceived: @@ -333,9 +333,9 @@ func (mgr *BackendConnManager) processSignals(ctx context.Context, clientIO *pne mgr.processLock.Lock() switch s { case signalTypeGracefulClose: - mgr.tryGracefulClose(ctx, clientIO) + mgr.tryGracefulClose(ctx) case signalTypeRedirect: - mgr.tryRedirect(ctx, clientIO) + mgr.tryRedirect(ctx) } mgr.processLock.Unlock() case rs := <-mgr.redirectResCh: @@ -348,7 +348,7 @@ func (mgr *BackendConnManager) processSignals(ctx context.Context, clientIO *pne // tryRedirect tries to migrate the session if the session is redirect-able. // NOTE: processLock should be held before calling this function. -func (mgr *BackendConnManager) tryRedirect(ctx context.Context, clientIO *pnet.PacketIO) { +func (mgr *BackendConnManager) tryRedirect(ctx context.Context) { switch mgr.closeStatus.Load() { case statusNotifyClose, statusClosing, statusClosed: return @@ -389,8 +389,7 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context, clientIO *pnet.P } newBackendIO := pnet.NewPacketIO(cn, pnet.WithRemoteAddr(rs.to)) - mgr.clientAddr = clientIO.RemoteAddr().String() - if rs.err = mgr.authenticator.handshakeSecondTime(mgr.logger, clientIO, newBackendIO, mgr.backendTLS, sessionToken); rs.err == nil { + if rs.err = mgr.authenticator.handshakeSecondTime(mgr.logger, mgr.clientIO, newBackendIO, mgr.backendTLS, sessionToken); rs.err == nil { rs.err = mgr.initSessionStates(newBackendIO, sessionStates) } else { mgr.handshakeHandler.OnHandshake(mgr, newBackendIO.RemoteAddr().String(), rs.err) @@ -475,7 +474,7 @@ func (mgr *BackendConnManager) GracefulClose() { mgr.signalReceived <- signalTypeGracefulClose } -func (mgr *BackendConnManager) tryGracefulClose(ctx context.Context, clientIO *pnet.PacketIO) { +func (mgr *BackendConnManager) tryGracefulClose(ctx context.Context) { if mgr.closeStatus.Load() != statusNotifyClose { return } @@ -483,20 +482,40 @@ func (mgr *BackendConnManager) tryGracefulClose(ctx context.Context, clientIO *p return } // Closing clientIO will cause the whole connection to be closed. - if err := clientIO.GracefulClose(); err != nil { - mgr.logger.Warn("graceful close client IO error", zap.Stringer("addr", clientIO.RemoteAddr()), zap.Error(err)) + if err := mgr.clientIO.GracefulClose(); err != nil { + mgr.logger.Warn("graceful close client IO error", zap.Stringer("addr", mgr.clientIO.RemoteAddr()), zap.Error(err)) } mgr.closeStatus.Store(statusClosing) } func (mgr *BackendConnManager) ClientAddr() string { - return mgr.clientAddr + if mgr.clientIO == nil { + return "" + } + return mgr.clientIO.RemoteAddr().String() } func (mgr *BackendConnManager) ServerAddr() string { + if mgr.backendIO == nil { + return "" + } return mgr.backendIO.RemoteAddr().String() } +func (mgr *BackendConnManager) ClientInBytes() uint64 { + if mgr.clientIO == nil { + return 0 + } + return mgr.clientIO.InBytes() +} + +func (mgr *BackendConnManager) ClientOutBytes() uint64 { + if mgr.clientIO == nil { + return 0 + } + return mgr.clientIO.OutBytes() +} + func (mgr *BackendConnManager) SetValue(key, val any) { mgr.ctxmap.Store(key, val) } diff --git a/pkg/proxy/backend/backend_conn_mgr_test.go b/pkg/proxy/backend/backend_conn_mgr_test.go index fc74469f..5329e41f 100644 --- a/pkg/proxy/backend/backend_conn_mgr_test.go +++ b/pkg/proxy/backend/backend_conn_mgr_test.go @@ -159,7 +159,7 @@ func (ts *backendMgrTester) forwardCmd4Proxy(clientIO, backendIO *pnet.PacketIO) require.NoError(ts.t, err) prevCounter, err := readCmdCounter(request[0], ts.tc.backendListener.Addr().String()) require.NoError(ts.t, err) - rsErr := ts.mp.ExecuteCmd(context.Background(), request, clientIO) + rsErr := ts.mp.ExecuteCmd(context.Background(), request) curCounter, err := readCmdCounter(request[0], ts.tc.backendListener.Addr().String()) require.NoError(ts.t, err) require.Equal(ts.t, prevCounter+1, curCounter) @@ -225,6 +225,8 @@ func (ts *backendMgrTester) checkConnClosed(_, _ *pnet.PacketIO) error { func (ts *backendMgrTester) runTests(runners []runner) { for _, runner := range runners { ts.runAndCheck(ts.t, nil, runner.client, runner.backend, runner.proxy) + require.Equal(ts.t, ts.tc.clientIO.InBytes(), ts.mp.ClientOutBytes()) + require.Equal(ts.t, ts.tc.clientIO.OutBytes(), ts.mp.ClientInBytes()) } } diff --git a/pkg/proxy/backend/handshake_handler.go b/pkg/proxy/backend/handshake_handler.go index 84407907..07b12db9 100644 --- a/pkg/proxy/backend/handshake_handler.go +++ b/pkg/proxy/backend/handshake_handler.go @@ -27,6 +27,8 @@ var _ HandshakeHandler = (*DefaultHandshakeHandler)(nil) type ConnContext interface { ClientAddr() string ServerAddr() string + ClientInBytes() uint64 + ClientOutBytes() uint64 SetValue(key, val any) Value(key any) any } diff --git a/pkg/proxy/backend/testsuite_test.go b/pkg/proxy/backend/testsuite_test.go index 01621fe2..609b9f9a 100644 --- a/pkg/proxy/backend/testsuite_test.go +++ b/pkg/proxy/backend/testsuite_test.go @@ -171,6 +171,8 @@ func (ts *testSuite) runAndCheck(t *testing.T, c checker, clientRunner, backendR // Ensure all the packets are forwarded. msg := fmt.Sprintf("cmd:%d responseType:%d", ts.mc.cmd, ts.mb.respondType) require.Equal(t, ts.tc.backendIO.GetSequence(), ts.tc.clientIO.GetSequence(), msg) + require.Equal(t, ts.tc.clientIO.OutBytes(), ts.tc.proxyCIO.InBytes(), msg) + require.Equal(t, ts.tc.clientIO.InBytes(), ts.tc.proxyCIO.OutBytes(), msg) } } else { c(t, ts) diff --git a/pkg/proxy/client/client_conn.go b/pkg/proxy/client/client_conn.go index bb23cdfd..d9760c04 100644 --- a/pkg/proxy/client/client_conn.go +++ b/pkg/proxy/client/client_conn.go @@ -87,7 +87,7 @@ func (cc *ClientConnection) processMsg(ctx context.Context) error { if err != nil { return err } - err = cc.connMgr.ExecuteCmd(ctx, clientPkt, cc.pkt) + err = cc.connMgr.ExecuteCmd(ctx, clientPkt) if err != nil { return err } diff --git a/pkg/proxy/net/packetio.go b/pkg/proxy/net/packetio.go index bedf3fcf..8f11a4ec 100644 --- a/pkg/proxy/net/packetio.go +++ b/pkg/proxy/net/packetio.go @@ -74,6 +74,8 @@ func (f *rdbufConn) Read(b []byte) (int, error) { // PacketIO is a helper to read and write sql and proxy protocol. type PacketIO struct { + inBytes uint64 + outBytes uint64 conn net.Conn buf *bufio.ReadWriter proxyInited *atomic.Bool @@ -142,6 +144,7 @@ func (p *PacketIO) readOnePacket() ([]byte, bool, error) { if _, err := io.ReadFull(p.conn, header[:]); err != nil { return nil, false, errors.Wrap(ErrReadConn, err) } + p.inBytes += 4 // probe proxy V2 refill := false @@ -164,6 +167,7 @@ func (p *PacketIO) readOnePacket() ([]byte, bool, error) { if _, err := io.ReadFull(p.conn, header[:]); err != nil { return nil, false, errors.Wrap(ErrReadConn, err) } + p.inBytes += 4 } sequence := header[3] @@ -177,6 +181,7 @@ func (p *PacketIO) readOnePacket() ([]byte, bool, error) { if _, err := io.ReadFull(p.conn, data); err != nil { return nil, false, errors.Wrap(ErrReadConn, err) } + p.inBytes += uint64(length) return data, length == mysql.MaxPayloadLen, nil } @@ -214,10 +219,12 @@ func (p *PacketIO) writeOnePacket(data []byte) (int, bool, error) { if _, err := io.Copy(p.buf, bytes.NewReader(header[:])); err != nil { return 0, more, errors.Wrap(ErrWriteConn, err) } + p.outBytes += 4 if _, err := io.Copy(p.buf, bytes.NewReader(data[:length])); err != nil { return 0, more, errors.Wrap(ErrWriteConn, err) } + p.outBytes += uint64(length) return length, more, nil } @@ -239,6 +246,14 @@ func (p *PacketIO) WritePacket(data []byte, flush bool) (err error) { return nil } +func (p *PacketIO) InBytes() uint64 { + return p.inBytes +} + +func (p *PacketIO) OutBytes() uint64 { + return p.outBytes +} + func (p *PacketIO) Flush() error { if err := p.buf.Flush(); err != nil { return p.wrapErr(errors.Wrap(ErrFlushConn, err)) diff --git a/pkg/proxy/net/packetio_test.go b/pkg/proxy/net/packetio_test.go index 2c9c8924..54e78741 100644 --- a/pkg/proxy/net/packetio_test.go +++ b/pkg/proxy/net/packetio_test.go @@ -88,8 +88,11 @@ func TestPacketIO(t *testing.T) { // send anything require.NoError(t, cli.WritePacket(expectMsg, true)) + outBytes := len(expectMsg) + 4 for _, l := range pktLengths { require.NoError(t, cli.WritePacket(make([]byte, l), true)) + outBytes += l + (l/(mysql.MaxPayloadLen)+1)*4 + require.Equal(t, uint64(outBytes), cli.OutBytes()) } // skip handshake @@ -116,10 +119,13 @@ func TestPacketIO(t *testing.T) { require.NoError(t, err) require.Equal(t, expectMsg, msg) + inBytes := len(expectMsg) + 4 for _, l := range pktLengths { msg, err = srv.ReadPacket() require.NoError(t, err) require.Equal(t, l, len(msg)) + inBytes += l + (l/(mysql.MaxPayloadLen)+1)*4 + require.Equal(t, uint64(inBytes), srv.InBytes()) } // send handshake diff --git a/pkg/proxy/net/proxy.go b/pkg/proxy/net/proxy.go index cc228b7d..b672913a 100644 --- a/pkg/proxy/net/proxy.go +++ b/pkg/proxy/net/proxy.go @@ -167,12 +167,14 @@ func (p *PacketIO) parseProxyV2() (*Proxy, error) { if err != nil { return nil, errors.WithStack(errors.Wrap(ErrReadConn, err)) } + p.inBytes += 8 var hdr [4]byte if _, err := io.ReadFull(p.buf, hdr[:]); err != nil { return nil, errors.WithStack(err) } + p.inBytes += 4 m := &Proxy{} m.Version = ProxyVersion(hdr[0] >> 4) @@ -182,6 +184,7 @@ func (p *PacketIO) parseProxyV2() (*Proxy, error) { if _, err := io.ReadFull(p.buf, buf); err != nil { return nil, errors.WithStack(err) } + p.inBytes += uint64(len(buf)) addressFamily := ProxyAddressFamily(hdr[1] >> 4) network := ProxyNetwork(hdr[1] & 0xF) @@ -285,6 +288,7 @@ func (p *PacketIO) WriteProxyV2(m *Proxy) error { if _, err := io.Copy(p.buf, bytes.NewReader(buf)); err != nil { return errors.Wrap(ErrWriteConn, err) } + p.outBytes += uint64(len(buf)) // according to the spec, we better flush to avoid server hanging return p.Flush() }