Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

net, backend: fix wrong format in error packet #357

Merged
merged 1 commit into from
Sep 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions pkg/proxy/backend/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"net"
"time"

"github.com/pingcap/tidb/parser/mysql"
"github.com/go-mysql-org/go-mysql/mysql"
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tiproxy/lib/util/errors"
pnet "github.com/pingcap/tiproxy/pkg/proxy/net"
Expand Down Expand Up @@ -73,7 +73,7 @@ func (auth *Authenticator) writeProxyProtocol(clientIO, backendIO *pnet.PacketIO
}

func (auth *Authenticator) verifyBackendCaps(logger *zap.Logger, backendCapability pnet.Capability) error {
requiredBackendCaps := defRequiredBackendCaps & pnet.Capability(auth.capability)
requiredBackendCaps := defRequiredBackendCaps & auth.capability
if auth.requireBackendTLS {
requiredBackendCaps |= pnet.ClientSSL
}
Expand All @@ -100,7 +100,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
}

cid, _ := cctx.Value(ConnContextKeyConnID).(uint64)
if err := clientIO.WriteInitialHandshake(proxyCapability, auth.salt, mysql.AuthNativePassword, handshakeHandler.GetServerVersion(), cid); err != nil {
if err := clientIO.WriteInitialHandshake(proxyCapability, auth.salt, pnet.AuthNativePassword, handshakeHandler.GetServerVersion(), cid); err != nil {
return err
}
pkt, isSSL, err := clientIO.ReadSSLRequestOrHandshakeResp()
Expand All @@ -126,7 +126,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
}
if commonCaps := frontendCapability & requiredFrontendCaps; commonCaps != requiredFrontendCaps {
logger.Error("require frontend capabilities", zap.Stringer("common", commonCaps), zap.Stringer("required", requiredFrontendCaps))
if writeErr := clientIO.WriteErrPacket(mysql.ErrNotSupportedAuthMode); writeErr != nil {
if writeErr := clientIO.WriteErrPacket(mysql.NewDefaultError(mysql.ER_NOT_SUPPORTED_AUTH_MODE)); writeErr != nil {
return writeErr
}
return errors.Wrapf(ErrCapabilityNegotiation, "require %s from frontend", requiredFrontendCaps&^commonCaps)
Expand Down Expand Up @@ -220,14 +220,14 @@ loop:
return err
}
switch serverPkt[0] {
case mysql.OKHeader:
case pnet.OKHeader.Byte():
return nil
case mysql.ErrHeader:
case pnet.ErrHeader.Byte():
return pnet.ParseErrorPacket(serverPkt)
default: // mysql.AuthSwitchRequest, ShaCommand
if serverPkt[0] == mysql.AuthSwitchRequest {
if serverPkt[0] == pnet.AuthSwitchHeader.Byte() {
pluginName = string(serverPkt[1 : bytes.IndexByte(serverPkt[1:], 0)+1])
} else if serverPkt[0] == 1 && pluginName == mysql.AuthCachingSha2Password && len(serverPkt) == 2 && serverPkt[1] == 3 {
} else if serverPkt[0] == 1 && pluginName == pnet.AuthCachingSha2Password && len(serverPkt) == 2 && serverPkt[1] == 3 {
// caching_sha2_password fast path
continue loop
}
Expand Down Expand Up @@ -262,7 +262,7 @@ func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, bac
return err
}

if err := auth.verifyBackendCaps(logger, pnet.Capability(backendCapability)); err != nil {
if err := auth.verifyBackendCaps(logger, backendCapability); err != nil {
return err
}

Expand Down Expand Up @@ -320,7 +320,7 @@ func (auth *Authenticator) writeAuthHandshake(
enableTLS = true
} else {
// When client TLS is disabled, also disables proxy TLS.
enableTLS = pnet.Capability(auth.capability)&pnet.ClientSSL != 0 && backendCapability&pnet.ClientSSL != 0 && backendTLSConfig != nil
enableTLS = auth.capability&pnet.ClientSSL != 0 && backendCapability&pnet.ClientSSL != 0 && backendTLSConfig != nil
}
if enableTLS {
resp.Capability |= pnet.ClientSSL
Expand Down Expand Up @@ -355,9 +355,9 @@ func (auth *Authenticator) handleSecondAuthResult(backendIO *pnet.PacketIO) erro
}

switch data[0] {
case mysql.OKHeader:
case pnet.OKHeader.Byte():
return nil
case mysql.ErrHeader:
case pnet.ErrHeader.Byte():
return pnet.ParseErrorPacket(data)
default: // mysql.AuthSwitchRequest, ShaCommand:
return errors.Errorf("read unexpected command: %#x", data[0])
Expand Down
35 changes: 17 additions & 18 deletions pkg/proxy/backend/mock_backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ import (
"crypto/tls"
"encoding/binary"

gomysql "github.com/go-mysql-org/go-mysql/mysql"
"github.com/pingcap/tidb/parser/mysql"
"github.com/go-mysql-org/go-mysql/mysql"
pnet "github.com/pingcap/tiproxy/pkg/proxy/net"
)

Expand Down Expand Up @@ -74,7 +73,7 @@ func (mb *mockBackend) authenticate(packetIO *pnet.PacketIO) error {
}
// upgrade to TLS
capability := binary.LittleEndian.Uint16(clientPkt[:2])
sslEnabled := uint32(capability)&mysql.ClientSSL > 0 && mb.capability&pnet.ClientSSL > 0
sslEnabled := pnet.Capability(capability)&pnet.ClientSSL > 0 && mb.capability&pnet.ClientSSL > 0
if sslEnabled {
if _, err = packetIO.ServerTLSHandshake(mb.tlsConfig); err != nil {
return err
Expand All @@ -98,7 +97,7 @@ func (mb *mockBackend) authenticate(packetIO *pnet.PacketIO) error {
}

func (mb *mockBackend) verifyPassword(packetIO *pnet.PacketIO, resp *pnet.HandshakeResp) error {
if resp.AuthPlugin != mysql.AuthTiDBSessionToken {
if resp.AuthPlugin != pnet.AuthTiDBSessionToken {
var err error
if err = packetIO.WriteSwitchRequest(mb.authPlugin, mb.salt); err != nil {
return err
Expand All @@ -107,7 +106,7 @@ func (mb *mockBackend) verifyPassword(packetIO *pnet.PacketIO, resp *pnet.Handsh
return err
}
switch mb.authPlugin {
case mysql.AuthCachingSha2Password:
case pnet.AuthCachingSha2Password:
if err = packetIO.WriteShaCommand(); err != nil {
return err
}
Expand All @@ -121,7 +120,7 @@ func (mb *mockBackend) verifyPassword(packetIO *pnet.PacketIO, resp *pnet.Handsh
return err
}
} else {
if err := packetIO.WriteErrPacket(mysql.ErrAccessDenied); err != nil {
if err := packetIO.WriteErrPacket(mysql.NewDefaultError(mysql.ER_ACCESS_DENIED_ERROR)); err != nil {
return err
}
}
Expand Down Expand Up @@ -150,7 +149,7 @@ func (mb *mockBackend) respondOnce(packetIO *pnet.PacketIO) error {
case responseTypeOK:
return mb.respondOK(packetIO)
case responseTypeErr:
return packetIO.WriteErrPacket(mysql.ErrUnknown)
return packetIO.WriteErrPacket(mysql.NewDefaultError(mysql.ER_UNKNOWN_ERROR))
case responseTypeResultSet:
if pnet.Command(pkt[0]) == pnet.ComQuery && string(pkt[1:]) == sqlQueryState {
return mb.respondSessionStates(packetIO)
Expand Down Expand Up @@ -179,16 +178,16 @@ func (mb *mockBackend) respondOnce(packetIO *pnet.PacketIO) error {
case responseTypeNone:
return nil
}
return packetIO.WriteErrPacket(mysql.ErrUnknown)
return packetIO.WriteErrPacket(mysql.NewDefaultError(mysql.ER_UNKNOWN_ERROR))
}

func (mb *mockBackend) respondOK(packetIO *pnet.PacketIO) error {
for i := 0; i < mb.stmtNum; i++ {
status := mb.status
if i < mb.stmtNum-1 {
status |= mysql.ServerMoreResultsExists
status |= mysql.SERVER_MORE_RESULTS_EXISTS
} else {
status &= ^mysql.ServerMoreResultsExists
status &= ^mysql.SERVER_MORE_RESULTS_EXISTS
}
if err := packetIO.WriteOKPacket(status, pnet.OKHeader); err != nil {
return err
Expand Down Expand Up @@ -242,16 +241,16 @@ func (mb *mockBackend) respondResultSet(packetIO *pnet.PacketIO) error {
}

func (mb *mockBackend) writeResultSet(packetIO *pnet.PacketIO, names []string, values [][]any) error {
rs, err := gomysql.BuildSimpleTextResultset(names, values)
rs, err := mysql.BuildSimpleTextResultset(names, values)
if err != nil {
return err
}
for i := 0; i < mb.stmtNum; i++ {
status := mb.status
if i < mb.stmtNum-1 {
status |= mysql.ServerMoreResultsExists
status |= mysql.SERVER_MORE_RESULTS_EXISTS
} else {
status &= ^mysql.ServerMoreResultsExists
status &= ^mysql.SERVER_MORE_RESULTS_EXISTS
}
data := pnet.DumpLengthEncodedInt(nil, uint64(len(names)))
if err := packetIO.WritePacket(data, false); err != nil {
Expand All @@ -263,7 +262,7 @@ func (mb *mockBackend) writeResultSet(packetIO *pnet.PacketIO, names []string, v
}
}

if status&mysql.ServerStatusCursorExists == 0 {
if status&mysql.SERVER_STATUS_CURSOR_EXISTS == 0 {
if mb.capability&pnet.ClientDeprecateEOF == 0 {
if err := packetIO.WriteEOFPacket(status); err != nil {
return err
Expand Down Expand Up @@ -291,12 +290,12 @@ func (mb *mockBackend) respondLoadFile(packetIO *pnet.PacketIO) error {
for i := 0; i < mb.stmtNum; i++ {
status := mb.status
if i < mb.stmtNum-1 {
status |= mysql.ServerMoreResultsExists
status |= mysql.SERVER_MORE_RESULTS_EXISTS
} else {
status &= ^mysql.ServerMoreResultsExists
status &= ^mysql.SERVER_MORE_RESULTS_EXISTS
}
data := make([]byte, 0, 1+len(mockCmdStr))
data = append(data, mysql.LocalInFileHeader)
data = append(data, pnet.LocalInFileHeader.Byte())
data = append(data, []byte(mockCmdStr)...)
if err := packetIO.WritePacket(data, true); err != nil {
return err
Expand All @@ -321,7 +320,7 @@ func (mb *mockBackend) respondLoadFile(packetIO *pnet.PacketIO) error {

// respond to Prepare
func (mb *mockBackend) respondPrepare(packetIO *pnet.PacketIO) error {
data := []byte{mysql.OKHeader}
data := []byte{pnet.OKHeader.Byte()}
data = pnet.DumpUint32(data, uint32(mockCmdInt))
data = pnet.DumpUint16(data, uint16(mb.columns))
data = pnet.DumpUint16(data, uint16(mb.params))
Expand Down
32 changes: 10 additions & 22 deletions pkg/proxy/net/packetio_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package net

import (
"encoding/binary"
"fmt"

"github.com/go-mysql-org/go-mysql/mysql"
"github.com/pingcap/tiproxy/lib/util/errors"
Expand Down Expand Up @@ -96,26 +95,14 @@ func (p *PacketIO) ReadSSLRequestOrHandshakeResp() (pkt []byte, isSSL bool, err
}

// WriteErrPacket writes an Error packet.
func (p *PacketIO) WriteErrPacket(code uint16, message ...any) error {
data := make([]byte, 0, 9+len(message))
func (p *PacketIO) WriteErrPacket(merr *mysql.MyError) error {
data := make([]byte, 0, 9+len(merr.Message))
data = append(data, ErrHeader.Byte())
data = append(data, byte(code), byte(code>>8))

// TODO: ClientProtocol41 must be enabled for state
data = append(data, byte(merr.Code), byte(merr.Code>>8))
// ClientProtocol41 is always enabled.
data = append(data, '#')
s, ok := mysql.MySQLState[code]
if !ok {
s = mysql.DEFAULT_MYSQL_STATE
}
data = append(data, s...)

var msg string
if format, ok := mysql.MySQLErrName[code]; ok {
msg = fmt.Sprintf(format, message...)
} else {
msg = fmt.Sprint(message...)
}
data = append(data, msg...)
data = append(data, merr.State...)
data = append(data, merr.Message...)
return p.WritePacket(data, true)
}

Expand All @@ -124,7 +111,7 @@ func (p *PacketIO) WriteOKPacket(status uint16, header Header) error {
data := make([]byte, 0, 7)
data = append(data, header.Byte())
data = append(data, 0, 0)
// ClientProtocol41 must be enabled.
// ClientProtocol41 is always enabled.
data = DumpUint16(data, status)
data = append(data, 0, 0)
return p.WritePacket(data, true)
Expand All @@ -135,7 +122,7 @@ func (p *PacketIO) WriteEOFPacket(status uint16) error {
data := make([]byte, 0, 5)
data = append(data, EOFHeader.Byte())
data = append(data, 0, 0)
// ClientProtocol41 must be enabled.
// ClientProtocol41 is always enabled.
data = DumpUint16(data, status)
return p.WritePacket(data, true)
}
Expand All @@ -149,7 +136,8 @@ func (p *PacketIO) WriteUserError(err error) {
if !errors.As(err, &ue) {
return
}
if writeErr := p.WriteErrPacket(mysql.ER_UNKNOWN_ERROR, ue.UserMsg()); writeErr != nil {
myErr := mysql.NewError(mysql.ER_UNKNOWN_ERROR, ue.UserMsg())
if writeErr := p.WriteErrPacket(myErr); writeErr != nil {
p.logger.Error("writing error to client failed", zap.NamedError("mysql_err", err), zap.NamedError("write_err", writeErr))
}
}
30 changes: 30 additions & 0 deletions pkg/proxy/net/packetio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"testing"
"time"

"github.com/go-mysql-org/go-mysql/mysql"
"github.com/pingcap/tiproxy/lib/config"
"github.com/pingcap/tiproxy/lib/util/logger"
"github.com/pingcap/tiproxy/lib/util/security"
Expand Down Expand Up @@ -257,3 +258,32 @@ func TestKeepAlive(t *testing.T) {
1,
)
}

func TestPredefinedPacket(t *testing.T) {
testTCPConn(t,
func(t *testing.T, cli *PacketIO) {
data, err := cli.ReadPacket()
require.NoError(t, err)
merr := ParseErrorPacket(data).(*mysql.MyError)
require.Equal(t, uint16(mysql.ER_UNKNOWN_ERROR), merr.Code)
require.Equal(t, "Unknown error", merr.Message)

data, err = cli.ReadPacket()
require.NoError(t, err)
merr = ParseErrorPacket(data).(*mysql.MyError)
require.Equal(t, uint16(mysql.ER_UNKNOWN_ERROR), merr.Code)
require.Equal(t, "test error", merr.Message)

data, err = cli.ReadPacket()
require.NoError(t, err)
res := ParseOKPacket(data)
require.Equal(t, uint16(100), res.Status)
},
func(t *testing.T, srv *PacketIO) {
require.NoError(t, srv.WriteErrPacket(mysql.NewDefaultError(mysql.ER_UNKNOWN_ERROR)))
require.NoError(t, srv.WriteErrPacket(mysql.NewError(mysql.ER_UNKNOWN_ERROR, "test error")))
require.NoError(t, srv.WriteOKPacket(100, OKHeader))
},
1,
)
}