diff --git a/.golangci.yaml b/.golangci.yaml index 25f88ec2..446419a3 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -22,6 +22,10 @@ issues: linters: - gosec text: "G402:" + - path: pkg/proxy/net/auth.go + linters: + - gosec + text: "G101:" linters: enable: diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index 6b2390bb..28432187 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -42,7 +42,7 @@ type Authenticator struct { user string attrs map[string]string salt []byte - capability uint32 // client capability + capability pnet.Capability collation uint8 proxyProtocol bool requireBackendTLS bool @@ -125,7 +125,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte } if commonCaps := frontendCapability & requiredFrontendCaps; commonCaps != requiredFrontendCaps { logger.Error("require frontend capabilities", zap.Stringer("common", commonCaps), zap.Stringer("required", requiredFrontendCaps)) - if writeErr := clientIO.WriteErrPacket(mysql.NewErr(mysql.ErrNotSupportedAuthMode)); writeErr != nil { + if writeErr := clientIO.WriteErrPacket(mysql.ErrNotSupportedAuthMode); writeErr != nil { return writeErr } return errors.Wrapf(ErrCapabilityNegotiation, "require %s from frontend", requiredFrontendCaps&^commonCaps) @@ -134,12 +134,12 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte if frontendCapability^commonCaps != 0 { logger.Debug("frontend send capabilities unsupported by proxy", zap.Stringer("common", commonCaps), zap.Stringer("frontend", frontendCapability^commonCaps), zap.Stringer("proxy", proxyCapability^commonCaps)) } - auth.capability = commonCaps.Uint32() - if auth.capability&mysql.ClientPluginAuth == 0 { + auth.capability = commonCaps + if auth.capability&pnet.ClientPluginAuth == 0 { logger.Warn("frontend may not support plugin auth", zap.Stringer("capability", commonCaps)) // Some clients (e.g. node/mysql) support ClientAuthPlugin but don't have the capability set correctly. // Always set it to ensure capability. - auth.capability |= mysql.ClientPluginAuth + auth.capability |= pnet.ClientPluginAuth } if isSSL { @@ -267,7 +267,7 @@ func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, bac if err = auth.writeAuthHandshake( backendIO, backendTLSConfig, backendCapability, - mysql.AuthTiDBSessionToken, hack.Slice(sessionToken), mysql.ClientPluginAuth, + pnet.AuthTiDBSessionToken, hack.Slice(sessionToken), pnet.ClientPluginAuth, ); err != nil { return err } @@ -293,7 +293,7 @@ func (auth *Authenticator) writeAuthHandshake( backendCapability pnet.Capability, authPlugin string, authData []byte, - authCap uint32, + authCap pnet.Capability, ) error { // Always handshake with SSL enabled and enable auth_plugin. resp := &pnet.HandshakeResp{ @@ -307,7 +307,7 @@ func (auth *Authenticator) writeAuthHandshake( } if len(resp.Attrs) > 0 { - resp.Capability |= mysql.ClientConnectAtts + resp.Capability |= pnet.ClientConnectAttrs } var pkt []byte @@ -322,7 +322,7 @@ func (auth *Authenticator) writeAuthHandshake( enableTLS = pnet.Capability(auth.capability)&pnet.ClientSSL != 0 && backendCapability&pnet.ClientSSL != 0 && backendTLSConfig != nil } if enableTLS { - resp.Capability |= mysql.ClientSSL + resp.Capability |= pnet.ClientSSL pkt = pnet.MakeHandshakeResponse(resp) // write SSL Packet if err := backendIO.WritePacket(pkt[:32], true); err != nil { @@ -339,7 +339,7 @@ func (auth *Authenticator) writeAuthHandshake( return err } } else { - resp.Capability &= ^mysql.ClientSSL + resp.Capability &= ^pnet.ClientSSL pkt = pnet.MakeHandshakeResponse(resp) } diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index 185ab812..3d9f4d01 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -291,11 +291,11 @@ func (mgr *BackendConnManager) ExecuteCmd(ctx context.Context, request []byte) ( val := binary.LittleEndian.Uint16(request[1:]) switch val { case 0: - mgr.authenticator.capability |= mysql.ClientMultiStatements - mgr.cmdProcessor.capability |= mysql.ClientMultiStatements + mgr.authenticator.capability |= pnet.ClientMultiStatements + mgr.cmdProcessor.capability |= pnet.ClientMultiStatements case 1: - mgr.authenticator.capability &^= mysql.ClientMultiStatements - mgr.cmdProcessor.capability &^= mysql.ClientMultiStatements + mgr.authenticator.capability &^= pnet.ClientMultiStatements + mgr.cmdProcessor.capability &^= pnet.ClientMultiStatements default: err = errors.Errorf("unrecognized set_option value:%d", val) return diff --git a/pkg/proxy/backend/cmd_processor.go b/pkg/proxy/backend/cmd_processor.go index c224dfa5..52b18b0c 100644 --- a/pkg/proxy/backend/cmd_processor.go +++ b/pkg/proxy/backend/cmd_processor.go @@ -22,7 +22,7 @@ const ( type CmdProcessor struct { // Each prepared statement has an independent status. preparedStmtStatus map[int]uint32 - capability uint32 + capability pnet.Capability // Only includes in_trans or quit status. serverStatus uint32 } diff --git a/pkg/proxy/backend/cmd_processor_exec.go b/pkg/proxy/backend/cmd_processor_exec.go index 040fb369..7f9edd38 100644 --- a/pkg/proxy/backend/cmd_processor_exec.go +++ b/pkg/proxy/backend/cmd_processor_exec.go @@ -81,7 +81,7 @@ func (cp *CmdProcessor) forwardCommand(clientIO, backendIO *pnet.PacketIO, reque case mysql.ErrHeader: return cp.handleErrorPacket(response) case mysql.EOFHeader: - if cp.capability&mysql.ClientDeprecateEOF == 0 { + if cp.capability&pnet.ClientDeprecateEOF == 0 { cp.handleEOFPacket(request, response) } else { cp.handleOKPacket(request, response) @@ -111,7 +111,7 @@ func (cp *CmdProcessor) forwardUntilResultEnd(clientIO, backendIO *pnet.PacketIO } return 0, cp.handleErrorPacket(response) } - if cp.capability&mysql.ClientDeprecateEOF == 0 { + if cp.capability&pnet.ClientDeprecateEOF == 0 { if pnet.IsEOFPacket(response) { return cp.handleEOFPacket(request, response), clientIO.Flush() } @@ -136,7 +136,7 @@ func (cp *CmdProcessor) forwardPrepareCmd(clientIO, backendIO *pnet.PacketIO) er numColumns := binary.LittleEndian.Uint16(response[5:]) numParams := binary.LittleEndian.Uint16(response[7:]) expectedPackets := int(numColumns) + int(numParams) - if cp.capability&mysql.ClientDeprecateEOF == 0 { + if cp.capability&pnet.ClientDeprecateEOF == 0 { if numColumns > 0 { expectedPackets++ } @@ -235,7 +235,7 @@ func (cp *CmdProcessor) forwardLoadInFile(clientIO, backendIO *pnet.PacketIO, re } func (cp *CmdProcessor) forwardResultSet(clientIO, backendIO *pnet.PacketIO, request []byte) (uint16, error) { - if cp.capability&mysql.ClientDeprecateEOF == 0 { + if cp.capability&pnet.ClientDeprecateEOF == 0 { var response []byte // read columns for { diff --git a/pkg/proxy/backend/cmd_processor_query.go b/pkg/proxy/backend/cmd_processor_query.go index 87792a45..c148c86e 100644 --- a/pkg/proxy/backend/cmd_processor_query.go +++ b/pkg/proxy/backend/cmd_processor_query.go @@ -69,7 +69,7 @@ func (cp *CmdProcessor) readResultColumns(packetIO *pnet.PacketIO, result *gomys for { if fieldIndex == len(result.Fields) { - if cp.capability&mysql.ClientDeprecateEOF == 0 { + if cp.capability&pnet.ClientDeprecateEOF == 0 { if data, err = packetIO.ReadPacket(); err != nil { return err } @@ -102,7 +102,7 @@ func (cp *CmdProcessor) readResultRows(packetIO *pnet.PacketIO, result *gomysql. if data, err = packetIO.ReadPacket(); err != nil { return err } - if cp.capability&mysql.ClientDeprecateEOF == 0 { + if cp.capability&pnet.ClientDeprecateEOF == 0 { if pnet.IsEOFPacket(data) { result.Status = binary.LittleEndian.Uint16(data[3:]) break diff --git a/pkg/proxy/backend/mock_backend_test.go b/pkg/proxy/backend/mock_backend_test.go index 53b39ffd..f074d2ed 100644 --- a/pkg/proxy/backend/mock_backend_test.go +++ b/pkg/proxy/backend/mock_backend_test.go @@ -33,7 +33,7 @@ func newBackendConfig() *backendConfig { return &backendConfig{ capability: defaultTestBackendCapability, salt: mockSalt, - authPlugin: mysql.AuthCachingSha2Password, + authPlugin: pnet.AuthCachingSha2Password, authSucceed: true, loops: 1, stmtNum: 1, @@ -117,11 +117,11 @@ func (mb *mockBackend) verifyPassword(packetIO *pnet.PacketIO, resp *pnet.Handsh } } if mb.authSucceed { - if err := packetIO.WriteOKPacket(mb.status, mysql.OKHeader); err != nil { + if err := packetIO.WriteOKPacket(mb.status, pnet.OKHeader); err != nil { return err } } else { - if err := packetIO.WriteErrPacket(mysql.NewErr(mysql.ErrAccessDenied)); err != nil { + if err := packetIO.WriteErrPacket(mysql.ErrAccessDenied); err != nil { return err } } @@ -150,7 +150,7 @@ func (mb *mockBackend) respondOnce(packetIO *pnet.PacketIO) error { case responseTypeOK: return mb.respondOK(packetIO) case responseTypeErr: - return packetIO.WriteErrPacket(mysql.NewErr(mysql.ErrUnknown)) + return packetIO.WriteErrPacket(mysql.ErrUnknown) case responseTypeResultSet: if pnet.Command(pkt[0]) == pnet.ComQuery && string(pkt[1:]) == sqlQueryState { return mb.respondSessionStates(packetIO) @@ -171,7 +171,7 @@ func (mb *mockBackend) respondOnce(packetIO *pnet.PacketIO) error { if _, err := packetIO.ReadPacket(); err != nil { return err } - return packetIO.WriteOKPacket(mb.status, mysql.OKHeader) + return packetIO.WriteOKPacket(mb.status, pnet.OKHeader) case responseTypePrepareOK: return mb.respondPrepare(packetIO) case responseTypeRow: @@ -179,7 +179,7 @@ func (mb *mockBackend) respondOnce(packetIO *pnet.PacketIO) error { case responseTypeNone: return nil } - return packetIO.WriteErrPacket(mysql.NewErr(mysql.ErrUnknown)) + return packetIO.WriteErrPacket(mysql.ErrUnknown) } func (mb *mockBackend) respondOK(packetIO *pnet.PacketIO) error { @@ -190,7 +190,7 @@ func (mb *mockBackend) respondOK(packetIO *pnet.PacketIO) error { } else { status &= ^mysql.ServerMoreResultsExists } - if err := packetIO.WriteOKPacket(status, mysql.OKHeader); err != nil { + if err := packetIO.WriteOKPacket(status, pnet.OKHeader); err != nil { return err } } @@ -209,7 +209,7 @@ func (mb *mockBackend) respondColumns(packetIO *pnet.PacketIO) error { func (mb *mockBackend) writeResultEndPacket(packetIO *pnet.PacketIO, status uint16) error { if mb.capability&pnet.ClientDeprecateEOF > 0 { - return packetIO.WriteOKPacket(status, mysql.EOFHeader) + return packetIO.WriteOKPacket(status, pnet.EOFHeader) } return packetIO.WriteEOFPacket(status) } @@ -312,7 +312,7 @@ func (mb *mockBackend) respondLoadFile(packetIO *pnet.PacketIO) error { break } } - if err := packetIO.WriteOKPacket(status, mysql.OKHeader); err != nil { + if err := packetIO.WriteOKPacket(status, pnet.OKHeader); err != nil { return err } } diff --git a/pkg/proxy/backend/mock_client_test.go b/pkg/proxy/backend/mock_client_test.go index ef0ff03f..4774b7ca 100644 --- a/pkg/proxy/backend/mock_client_test.go +++ b/pkg/proxy/backend/mock_client_test.go @@ -78,7 +78,7 @@ func (mc *mockClient) authenticate(packetIO *pnet.PacketIO) error { AuthPlugin: mc.authPlugin, Attrs: mc.attrs, AuthData: mc.authData, - Capability: mc.capability.Uint32(), + Capability: mc.capability, Collation: mc.collation, } pkt = pnet.MakeHandshakeResponse(resp) diff --git a/pkg/proxy/backend/mock_proxy_test.go b/pkg/proxy/backend/mock_proxy_test.go index 3ae4db73..9fcac1a0 100644 --- a/pkg/proxy/backend/mock_proxy_test.go +++ b/pkg/proxy/backend/mock_proxy_test.go @@ -54,7 +54,7 @@ func newMockProxy(t *testing.T, cfg *proxyConfig) *mockProxy { CheckBackendInterval: cfg.checkBackendInterval, }), } - mp.cmdProcessor.capability = cfg.capability.Uint32() + mp.cmdProcessor.capability = cfg.capability return mp } diff --git a/pkg/proxy/net/auth.go b/pkg/proxy/net/auth.go new file mode 100644 index 00000000..8c557591 --- /dev/null +++ b/pkg/proxy/net/auth.go @@ -0,0 +1,14 @@ +// Copyright 2023 PingCAP, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package net + +const ( + AuthNativePassword = "mysql_native_password" + AuthCachingSha2Password = "caching_sha2_password" + AuthTiDBSM3Password = "tidb_sm3_password" + AuthMySQLClearPassword = "mysql_clear_password" + AuthSocket = "auth_socket" + AuthTiDBSessionToken = "tidb_session_token" + AuthTiDBAuthToken = "tidb_auth_token" +) diff --git a/pkg/proxy/net/header.go b/pkg/proxy/net/header.go new file mode 100644 index 00000000..7f3912a6 --- /dev/null +++ b/pkg/proxy/net/header.go @@ -0,0 +1,30 @@ +// Copyright 2023 PingCAP, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package net + +type Header byte + +// Header information. +const ( + OKHeader Header = 0x00 + ErrHeader Header = 0xff + EOFHeader Header = 0xfe + AuthSwitchHeader Header = 0xfe + LocalInFileHeader Header = 0xfb +) + +var headerStrings = map[Header]string{ + OKHeader: "OK", + ErrHeader: "ERR", + EOFHeader: "EOF/AuthSwitch", + LocalInFileHeader: "LOCAL_IN_FILE", +} + +func (f Header) Byte() byte { + return byte(f) +} + +func (f Header) String() string { + return headerStrings[f] +} diff --git a/pkg/proxy/net/mysql.go b/pkg/proxy/net/mysql.go index 781ca17c..9e7246a6 100644 --- a/pkg/proxy/net/mysql.go +++ b/pkg/proxy/net/mysql.go @@ -62,7 +62,7 @@ type HandshakeResp struct { DB string AuthPlugin string AuthData []byte - Capability uint32 + Capability Capability Collation uint8 } @@ -70,7 +70,7 @@ func ParseHandshakeResponse(data []byte) (*HandshakeResp, error) { resp := new(HandshakeResp) pos := 0 // capability - resp.Capability = binary.LittleEndian.Uint32(data[:4]) + resp.Capability = Capability(binary.LittleEndian.Uint32(data[:4])) pos += 4 // skip max packet size pos += 4 @@ -85,7 +85,7 @@ func ParseHandshakeResponse(data []byte) (*HandshakeResp, error) { pos += len(resp.User) + 1 // password - if resp.Capability&mysql.ClientPluginAuthLenencClientData > 0 { + if resp.Capability&ClientPluginAuthLenencClientData > 0 { if data[pos] == 0x1 { // No auth data pos += 2 } else { @@ -96,7 +96,7 @@ func ParseHandshakeResponse(data []byte) (*HandshakeResp, error) { pos += int(num) } } - } else if resp.Capability&mysql.ClientSecureConnection > 0 { + } else if resp.Capability&ClientSecureConnection > 0 { authLen := int(data[pos]) pos++ resp.AuthData = data[pos : pos+authLen] @@ -107,7 +107,7 @@ func ParseHandshakeResponse(data []byte) (*HandshakeResp, error) { } // dbname - if resp.Capability&mysql.ClientConnectWithDB > 0 { + if resp.Capability&ClientConnectWithDB > 0 { if len(data[pos:]) > 0 { idx := bytes.IndexByte(data[pos:], 0) resp.DB = string(data[pos : pos+idx]) @@ -116,7 +116,7 @@ func ParseHandshakeResponse(data []byte) (*HandshakeResp, error) { } // auth plugin - if resp.Capability&mysql.ClientPluginAuth > 0 { + if resp.Capability&ClientPluginAuth > 0 { idx := bytes.IndexByte(data[pos:], 0) s := pos f := pos + idx @@ -128,7 +128,7 @@ func ParseHandshakeResponse(data []byte) (*HandshakeResp, error) { // attrs var err error - if resp.Capability&mysql.ClientConnectAtts > 0 { + if resp.Capability&ClientConnectAttrs > 0 { if num, null, off := ParseLengthEncodedInt(data[pos:]); !null { pos += off row := data[pos : pos+int(num)] @@ -184,11 +184,11 @@ func MakeHandshakeResponse(resp *HandshakeResp) []byte { authResp = DumpLengthEncodedInt(authRespBuf[:0], uint64(len(resp.AuthData))) capability := resp.Capability if len(authResp) > 1 { - capability |= mysql.ClientPluginAuthLenencClientData + capability |= ClientPluginAuthLenencClientData } else { - capability &= ^mysql.ClientPluginAuthLenencClientData + capability &= ^ClientPluginAuthLenencClientData } - if capability&mysql.ClientConnectAtts > 0 { + if capability&ClientConnectAttrs > 0 { attrs = dumpAttrs(resp.Attrs) attrBuf = DumpLengthEncodedInt(attrLenBuf[:0], uint64(len(attrs))) } @@ -197,7 +197,7 @@ func MakeHandshakeResponse(resp *HandshakeResp) []byte { data := make([]byte, length) pos := 0 // capability [32 bit] - DumpUint32(data[:0], capability) + DumpUint32(data[:0], capability.Uint32()) pos += 4 // MaxPacketSize [32 bit] pos += 4 @@ -213,10 +213,10 @@ func MakeHandshakeResponse(resp *HandshakeResp) []byte { pos++ // auth data - if capability&mysql.ClientPluginAuthLenencClientData > 0 { + if capability&ClientPluginAuthLenencClientData > 0 { pos += copy(data[pos:], authResp) pos += copy(data[pos:], resp.AuthData) - } else if capability&mysql.ClientSecureConnection > 0 { + } else if capability&ClientSecureConnection > 0 { data[pos] = byte(len(resp.AuthData)) pos++ pos += copy(data[pos:], resp.AuthData) @@ -227,21 +227,21 @@ func MakeHandshakeResponse(resp *HandshakeResp) []byte { } // db [null terminated string] - if capability&mysql.ClientConnectWithDB > 0 { + if capability&ClientConnectWithDB > 0 { pos += copy(data[pos:], resp.DB) data[pos] = 0x00 pos++ } // auth_plugin [null terminated string] - if capability&mysql.ClientPluginAuth > 0 { + if capability&ClientPluginAuth > 0 { pos += copy(data[pos:], resp.AuthPlugin) data[pos] = 0x00 pos++ } // attrs - if capability&mysql.ClientConnectAtts > 0 { + if capability&ClientConnectAttrs > 0 { pos += copy(data[pos:], attrBuf) pos += copy(data[pos:], attrs) } @@ -337,12 +337,12 @@ func ParseErrorPacket(data []byte) error { // IsOKPacket returns true if it's an OK packet (but not ResultSet OK). func IsOKPacket(data []byte) bool { - return data[0] == mysql.OKHeader + return data[0] == OKHeader.Byte() } // IsEOFPacket returns true if it's an EOF packet. func IsEOFPacket(data []byte) bool { - return data[0] == mysql.EOFHeader && len(data) <= 5 + return data[0] == EOFHeader.Byte() && len(data) <= 5 } // IsResultSetOKPacket returns true if it's an OK packet after the result set when CLIENT_DEPRECATE_EOF is enabled. @@ -350,10 +350,10 @@ func IsEOFPacket(data []byte) bool { // See https://mariadb.com/kb/en/result-set-packets/ func IsResultSetOKPacket(data []byte) bool { // With CLIENT_PROTOCOL_41 enabled, the least length is 7. - return data[0] == mysql.EOFHeader && len(data) >= 7 && len(data) < 0xFFFFFF + return data[0] == EOFHeader.Byte() && len(data) >= 7 && len(data) < 0xFFFFFF } // IsErrorPacket returns true if it's an error packet. func IsErrorPacket(data []byte) bool { - return data[0] == mysql.ErrHeader + return data[0] == ErrHeader.Byte() } diff --git a/pkg/proxy/net/mysql_test.go b/pkg/proxy/net/mysql_test.go index a82fdaee..f44d7afd 100644 --- a/pkg/proxy/net/mysql_test.go +++ b/pkg/proxy/net/mysql_test.go @@ -6,7 +6,6 @@ package net import ( "testing" - "github.com/pingcap/tidb/parser/mysql" "github.com/stretchr/testify/require" ) @@ -17,7 +16,7 @@ func TestHandshakeResp(t *testing.T) { DB: "db", AuthPlugin: "plugin", AuthData: []byte("1234567890"), - Capability: ^mysql.ClientPluginAuthLenencClientData, + Capability: ^ClientPluginAuthLenencClientData, Collation: 0, } b := MakeHandshakeResponse(resp1) diff --git a/pkg/proxy/net/packetio_mysql.go b/pkg/proxy/net/packetio_mysql.go index de2a504b..2132d658 100644 --- a/pkg/proxy/net/packetio_mysql.go +++ b/pkg/proxy/net/packetio_mysql.go @@ -5,9 +5,10 @@ package net import ( "encoding/binary" + "fmt" + "github.com/go-mysql-org/go-mysql/mysql" "github.com/pingcap/TiProxy/lib/util/errors" - "github.com/pingcap/tidb/parser/mysql" "go.uber.org/zap" ) @@ -64,8 +65,9 @@ func (p *PacketIO) WriteInitialHandshake(capability Capability, salt []byte, aut func (p *PacketIO) WriteSwitchRequest(authPlugin string, salt []byte) error { length := 1 + len(authPlugin) + 1 + len(salt) + 1 data := make([]byte, 0, length) - data = append(data, mysql.AuthSwitchRequest) - data = append(data, []byte(authPlugin)...) + // check https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_auth_switch_request.html + data = append(data, byte(AuthSwitchHeader)) + data = append(data, authPlugin...) data = append(data, 0x00) data = append(data, salt...) data = append(data, 0x00) @@ -94,21 +96,33 @@ func (p *PacketIO) ReadSSLRequestOrHandshakeResp() (pkt []byte, isSSL bool, err } // WriteErrPacket writes an Error packet. -func (p *PacketIO) WriteErrPacket(merr *mysql.SQLError) error { - data := make([]byte, 0, 4+len(merr.Message)+len(merr.State)) - data = append(data, mysql.ErrHeader) - data = append(data, byte(merr.Code), byte(merr.Code>>8)) - // ClientProtocol41 must be enabled. +func (p *PacketIO) WriteErrPacket(code uint16, message ...any) error { + data := make([]byte, 0, 9+len(message)) + data = append(data, ErrHeader.Byte()) + data = append(data, byte(code), byte(code>>8)) + + // TODO: ClientProtocol41 must be enabled for state data = append(data, '#') - data = append(data, merr.State...) - data = append(data, merr.Message...) + s, ok := mysql.MySQLState[code] + if !ok { + s = mysql.DEFAULT_MYSQL_STATE + } + data = append(data, s...) + + var msg string + if format, ok := mysql.MySQLErrName[code]; ok { + msg = fmt.Sprintf(format, message...) + } else { + msg = fmt.Sprint(message...) + } + data = append(data, msg...) return p.WritePacket(data, true) } // WriteOKPacket writes an OK packet. It's only for testing. -func (p *PacketIO) WriteOKPacket(status uint16, header byte) error { +func (p *PacketIO) WriteOKPacket(status uint16, header Header) error { data := make([]byte, 0, 7) - data = append(data, header) + data = append(data, header.Byte()) data = append(data, 0, 0) // ClientProtocol41 must be enabled. data = DumpUint16(data, status) @@ -119,7 +133,7 @@ func (p *PacketIO) WriteOKPacket(status uint16, header byte) error { // WriteEOFPacket writes an EOF packet. It's only for testing. func (p *PacketIO) WriteEOFPacket(status uint16) error { data := make([]byte, 0, 5) - data = append(data, mysql.EOFHeader) + data = append(data, EOFHeader.Byte()) data = append(data, 0, 0) // ClientProtocol41 must be enabled. data = DumpUint16(data, status) @@ -135,8 +149,7 @@ func (p *PacketIO) WriteUserError(err error) { if !errors.As(err, &ue) { return } - myErr := mysql.NewErrf(mysql.ErrUnknown, "%s", nil, ue.UserMsg()) - if writeErr := p.WriteErrPacket(myErr); writeErr != nil { + if writeErr := p.WriteErrPacket(mysql.ER_UNKNOWN_ERROR, ue.UserMsg()); writeErr != nil { p.logger.Error("writing error to client failed", zap.NamedError("mysql_err", err), zap.NamedError("write_err", writeErr)) } } diff --git a/pkg/proxy/net/packetio_test.go b/pkg/proxy/net/packetio_test.go index 9d2945fd..701941ed 100644 --- a/pkg/proxy/net/packetio_test.go +++ b/pkg/proxy/net/packetio_test.go @@ -13,7 +13,6 @@ import ( "github.com/pingcap/TiProxy/lib/util/logger" "github.com/pingcap/TiProxy/lib/util/security" "github.com/pingcap/TiProxy/pkg/testkit" - "github.com/pingcap/tidb/parser/mysql" "github.com/stretchr/testify/require" ) @@ -45,7 +44,7 @@ func testTCPConn(t *testing.T, a func(*testing.T, *PacketIO), b func(*testing.T, func TestPacketIO(t *testing.T) { expectMsg := []byte("test") - pktLengths := []int{0, mysql.MaxPayloadLen + 212, mysql.MaxPayloadLen, mysql.MaxPayloadLen * 2} + pktLengths := []int{0, MaxPayloadLen + 212, MaxPayloadLen, MaxPayloadLen * 2} testPipeConn(t, func(t *testing.T, cli *PacketIO) { var err error @@ -56,7 +55,7 @@ func TestPacketIO(t *testing.T) { outBytes := len(expectMsg) + 4 for _, l := range pktLengths { require.NoError(t, cli.WritePacket(make([]byte, l), true)) - outBytes += l + (l/(mysql.MaxPayloadLen)+1)*4 + outBytes += l + (l/(MaxPayloadLen)+1)*4 require.Equal(t, uint64(outBytes), cli.OutBytes()) } @@ -66,7 +65,7 @@ func TestPacketIO(t *testing.T) { // send correct and wrong capability flags var hdr [32]byte - binary.LittleEndian.PutUint32(hdr[:], mysql.ClientSSL) + binary.LittleEndian.PutUint32(hdr[:], ClientSSL.Uint32()) err = cli.WritePacket(hdr[:], true) require.NoError(t, err) @@ -89,14 +88,14 @@ func TestPacketIO(t *testing.T) { msg, err = srv.ReadPacket() require.NoError(t, err) require.Equal(t, l, len(msg)) - inBytes += l + (l/(mysql.MaxPayloadLen)+1)*4 + inBytes += l + (l/(MaxPayloadLen)+1)*4 require.Equal(t, uint64(inBytes), srv.InBytes()) } // send handshake - require.NoError(t, srv.WriteInitialHandshake(0, salt[:], mysql.AuthNativePassword, ServerVersion)) + require.NoError(t, srv.WriteInitialHandshake(0, salt[:], AuthNativePassword, ServerVersion)) // salt should not be long enough - require.ErrorIs(t, srv.WriteInitialHandshake(0, make([]byte, 4), mysql.AuthNativePassword, ServerVersion), ErrSaltNotLongEnough) + require.ErrorIs(t, srv.WriteInitialHandshake(0, make([]byte, 4), AuthNativePassword, ServerVersion), ErrSaltNotLongEnough) // expect correct and wrong capability flags _, isSSL, err := srv.ReadSSLRequestOrHandshakeResp()