diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index e4dcee1e..c4ea467f 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -31,6 +31,7 @@ var ( ErrCapabilityNegotiation = errors.New("capability negotiation failed") ) +const unknownAuthPlugin = "auth_unknown_plugin" const requiredFrontendCaps = pnet.ClientProtocol41 const defRequiredBackendCaps = pnet.ClientDeprecateEOF @@ -180,7 +181,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet } // Send an unknown auth plugin so that the backend will request the auth data again. - resp.AuthPlugin = "auth_unknown_plugin" + resp.AuthPlugin = unknownAuthPlugin resp.Capability = auth.capability if backendCapability&pnet.ClientSSL != 0 && backendTLSConfig != nil { diff --git a/pkg/proxy/backend/cmd_processor_exec.go b/pkg/proxy/backend/cmd_processor_exec.go index 1695aab8..9639fea9 100644 --- a/pkg/proxy/backend/cmd_processor_exec.go +++ b/pkg/proxy/backend/cmd_processor_exec.go @@ -44,15 +44,21 @@ func (cp *CmdProcessor) executeCmd(request []byte, clientIO, backendIO *pnet.Pac } return true, err } - - if err = backendIO.WritePacket(request, true); err != nil { - return false, err - } return false, cp.forwardCommand(clientIO, backendIO, request) } func (cp *CmdProcessor) forwardCommand(clientIO, backendIO *pnet.PacketIO, request []byte) error { cmd := request[0] + if cmd != mysql.ComChangeUser { + if err := backendIO.WritePacket(request, true); err != nil { + return err + } + } else { + user, db := pnet.ParseChangeUser(request) + if err := backendIO.WritePacket(pnet.MakeChangeUser(user, db, unknownAuthPlugin, nil), true); err != nil { + return err + } + } switch cmd { case mysql.ComStmtPrepare: return cp.forwardPrepareCmd(clientIO, backendIO) @@ -276,8 +282,6 @@ func (cp *CmdProcessor) forwardSendLongDataCmd(request []byte) error { } func (cp *CmdProcessor) forwardChangeUserCmd(clientIO, backendIO *pnet.PacketIO, request []byte) error { - // Currently, TiDB responses with an OK or Err packet. But according to the MySQL doc, the server may send a - // switch auth request. for { response, err := forwardOnePacket(clientIO, backendIO, true) if err != nil { diff --git a/pkg/proxy/backend/mock_client_test.go b/pkg/proxy/backend/mock_client_test.go index ddb0f813..32733a83 100644 --- a/pkg/proxy/backend/mock_client_test.go +++ b/pkg/proxy/backend/mock_client_test.go @@ -169,7 +169,7 @@ func (mc *mockClient) request(packetIO *pnet.PacketIO) error { } func (mc *mockClient) requestChangeUser(packetIO *pnet.PacketIO) error { - data := pnet.MakeChangeUser(mc.username, mc.dbName, mc.authData) + data := pnet.MakeChangeUser(mc.username, mc.dbName, mysql.AuthNativePassword, mc.authData) if err := packetIO.WritePacket(data, true); err != nil { return err } diff --git a/pkg/proxy/net/mysql.go b/pkg/proxy/net/mysql.go index 62492a2b..f9a4944f 100644 --- a/pkg/proxy/net/mysql.go +++ b/pkg/proxy/net/mysql.go @@ -216,7 +216,7 @@ func MakeHandshakeResponse(resp *HandshakeResp) []byte { } // MakeChangeUser creates the data of COM_CHANGE_USER. It's only used for testing. -func MakeChangeUser(username, db string, authData []byte) []byte { +func MakeChangeUser(username, db, authPlugin string, authData []byte) []byte { length := 1 + len(username) + 1 + len(authData) + 1 + len(db) + 1 data := make([]byte, 0, length) data = append(data, mysql.ComChangeUser) @@ -226,6 +226,13 @@ func MakeChangeUser(username, db string, authData []byte) []byte { data = append(data, authData...) data = append(data, []byte(db)...) data = append(data, 0x00) + + // character set + data = append(data, 0x00) + data = append(data, 0x00) + + data = append(data, []byte(authPlugin)...) + data = append(data, 0x00) return data }