Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

backend, net: return handshake errors to the client #294

Merged
merged 3 commits into from
May 31, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions pkg/proxy/backend/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,22 +111,19 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
if err := clientIO.WriteInitialHandshake(proxyCapability, auth.salt, mysql.AuthNativePassword, handshakeHandler.GetServerVersion()); err != nil {
return err
}
pkt, isSSL, err := clientIO.ReadSSLRequestOrHandshakeResp()
pkt, isSSL, err := clientIO.ReadSSLRequestOrHandshakeResp(logger)
if err != nil {
return err
}
frontendCapability := pnet.Capability(binary.LittleEndian.Uint32(pkt))
if isSSL {
if _, err = clientIO.ServerTLSHandshake(frontendTLSConfig); err != nil {
return err
return pnet.WrapUserError(err, err.Error())
}
pkt, _, err = clientIO.ReadSSLRequestOrHandshakeResp()
pkt, _, err = clientIO.ReadSSLRequestOrHandshakeResp(logger)
if err != nil {
return err
}
if len(pkt) <= 32 {
return errors.WithStack(errors.New("expect handshake resp"))
}
Comment on lines -127 to -129
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is already checked in ReadSSLRequestOrHandshakeResp.

frontendCapabilityResponse := pnet.Capability(binary.LittleEndian.Uint32(pkt))
if frontendCapability != frontendCapabilityResponse {
common := frontendCapability & frontendCapabilityResponse
Expand All @@ -137,7 +134,11 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
}
if commonCaps := frontendCapability & requiredFrontendCaps; commonCaps != requiredFrontendCaps {
logger.Error("require frontend capabilities", zap.Stringer("common", commonCaps), zap.Stringer("required", requiredFrontendCaps))
return errors.Wrapf(ErrCapabilityNegotiation, "require %s from frontend", requiredFrontendCaps&^commonCaps)
mysqlErr := mysql.NewErr(mysql.ErrNotSupportedAuthMode)
if writeErr := clientIO.WriteErrPacket(mysqlErr); writeErr != nil {
return writeErr
}
return mysqlErr
}
commonCaps := frontendCapability & proxyCapability
if frontendCapability^commonCaps != 0 {
Expand All @@ -159,10 +160,10 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
if errors.As(err, &warning) {
logger.Warn("parse handshake response encounters error", zap.Error(err))
} else if err != nil {
return WrapUserError(err, parsePktErrMsg)
return pnet.WrapUserError(err, parsePktErrMsg)
}
if err = handshakeHandler.HandleHandshakeResp(cctx, clientResp); err != nil {
return WrapUserError(err, err.Error())
return pnet.WrapUserError(err, err.Error())
}
auth.user = clientResp.User
auth.dbname = clientResp.DB
Expand All @@ -172,13 +173,13 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
// In case of testing, backendIO is passed manually that we don't want to bother with the routing logic.
backendIO, err := getBackendIO(cctx, auth, clientResp, 15*time.Second)
if err != nil {
return WrapUserError(err, connectErrMsg)
return pnet.WrapUserError(err, connectErrMsg)
}
backendIO.ResetSequence()

// write proxy header
if err := auth.writeProxyProtocol(clientIO, backendIO); err != nil {
return WrapUserError(err, handshakeErrMsg)
return pnet.WrapUserError(err, handshakeErrMsg)
}

// read backend initial handshake
Expand All @@ -190,11 +191,11 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
}
return err
}
return WrapUserError(err, handshakeErrMsg)
return pnet.WrapUserError(err, handshakeErrMsg)
}

if err := auth.verifyBackendCaps(logger, backendCapability); err != nil {
return WrapUserError(err, capabilityErrMsg)
return pnet.WrapUserError(err, capabilityErrMsg)
}

if common := proxyCapability & backendCapability; (proxyCapability^common)&^pnet.ClientSSL != 0 {
Expand All @@ -215,7 +216,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
// send an unknown auth plugin so that the backend will request the auth data again.
unknownAuthPlugin, nil, 0,
); err != nil {
return WrapUserError(err, handshakeErrMsg)
return pnet.WrapUserError(err, handshakeErrMsg)
}

// forward other packets
Expand Down
6 changes: 3 additions & 3 deletions pkg/proxy/backend/backend_conn_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe
if err != nil {
mgr.setQuitSourceByErr(err)
mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), err)
WriteUserError(clientIO, err, mgr.logger)
clientIO.WriteUserError(err, mgr.logger)
return err
}
mgr.resetQuitSource()
Expand All @@ -193,7 +193,7 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe
func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticator, resp *pnet.HandshakeResp, timeout time.Duration) (*pnet.PacketIO, error) {
r, err := mgr.handshakeHandler.GetRouter(cctx, resp)
if err != nil {
return nil, WrapUserError(err, err.Error())
return nil, pnet.WrapUserError(err, err.Error())
}
// Reasons to wait:
// - The TiDB instances may not be initialized yet
Expand All @@ -213,7 +213,7 @@ func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticato
addr, err = selector.Next()
}
if err != nil {
return nil, backoff.Permanent(WrapUserError(err, err.Error()))
return nil, backoff.Permanent(pnet.WrapUserError(err, err.Error()))
}
if addr == "" {
return nil, router.ErrNoInstanceToSelect
Expand Down
2 changes: 1 addition & 1 deletion pkg/proxy/backend/cmd_processor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1037,7 +1037,7 @@ func TestNetworkError(t *testing.T) {
clientErrChecker := func(t *testing.T, ts *testSuite) {
require.True(t, pnet.IsDisconnectError(ts.mp.err))
require.True(t, pnet.IsDisconnectError(ts.mc.err))
require.NotNil(t, ts.mp.err.(*UserError))
require.NotNil(t, ts.mp.err.(*pnet.UserError))
}
backendErrChecker := func(t *testing.T, ts *testSuite) {
require.True(t, pnet.IsDisconnectError(ts.mp.err))
Expand Down
50 changes: 0 additions & 50 deletions pkg/proxy/backend/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ package backend

import (
"github.com/pingcap/TiProxy/lib/util/errors"
pnet "github.com/pingcap/TiProxy/pkg/proxy/net"
"github.com/pingcap/tidb/parser/mysql"
"go.uber.org/zap"
)

const (
Expand All @@ -32,50 +29,3 @@ var (
ErrClientConn = errors.New("this is an error from client")
ErrBackendConn = errors.New("this is an error from backend")
)

// UserError is returned to the client.
// err is used to log and userMsg is used to report to the user.
type UserError struct {
err error
userMsg string
}

func WrapUserError(err error, userMsg string) *UserError {
if err == nil {
return nil
}
if ue, ok := err.(*UserError); ok {
return ue
}
return &UserError{
err: err,
userMsg: userMsg,
}
}

func (ue *UserError) UserMsg() string {
return ue.userMsg
}

func (ue *UserError) Unwrap() error {
return ue.err
}

func (ue *UserError) Error() string {
return ue.err.Error()
}

// WriteUserError writes an unknown error to the client.
func WriteUserError(clientIO *pnet.PacketIO, err error, lg *zap.Logger) {
if err == nil {
return
}
var ue *UserError
if !errors.As(err, &ue) {
return
}
myErr := mysql.NewErrf(mysql.ErrUnknown, "%s", nil, ue.UserMsg())
if writeErr := clientIO.WriteErrPacket(myErr); writeErr != nil {
lg.Error("writing error to client failed", zap.NamedError("mysql_err", err), zap.NamedError("write_err", writeErr))
}
}
36 changes: 35 additions & 1 deletion pkg/proxy/net/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

package net

import "github.com/pingcap/TiProxy/lib/util/errors"
import (
"github.com/pingcap/TiProxy/lib/util/errors"
)

var (
ErrExpectSSLRequest = errors.New("expect a SSLRequest packet")
Expand All @@ -24,3 +26,35 @@ var (
ErrCloseConn = errors.New("failed to close the connection")
ErrHandshakeTLS = errors.New("failed to complete tls handshake")
)

// UserError is returned to the client.
// err is used to log and userMsg is used to report to the user.
type UserError struct {
err error
userMsg string
}

func WrapUserError(err error, userMsg string) *UserError {
if err == nil {
return nil
}
if ue, ok := err.(*UserError); ok {
return ue
}
return &UserError{
err: err,
userMsg: userMsg,
}
}

func (ue *UserError) UserMsg() string {
return ue.userMsg
}

func (ue *UserError) Unwrap() error {
return ue.err
}

func (ue *UserError) Error() string {
return ue.err.Error()
}
21 changes: 19 additions & 2 deletions pkg/proxy/net/packetio_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (

"github.com/pingcap/TiProxy/lib/util/errors"
"github.com/pingcap/tidb/parser/mysql"
"go.uber.org/zap"
)

var (
Expand Down Expand Up @@ -86,14 +87,15 @@ func (p *PacketIO) WriteShaCommand() error {
return p.WritePacket([]byte{ShaCommand, FastAuthFail}, true)
}

func (p *PacketIO) ReadSSLRequestOrHandshakeResp() (pkt []byte, isSSL bool, err error) {
func (p *PacketIO) ReadSSLRequestOrHandshakeResp(lg *zap.Logger) (pkt []byte, isSSL bool, err error) {
pkt, err = p.ReadPacket()
if err != nil {
return
}

if len(pkt) < 32 {
err = errors.WithStack(errors.Errorf("%w: but got less than 32 bytes", ErrExpectSSLRequest))
lg.Error("got malformed handshake response", zap.ByteString("packetData", pkt))
err = WrapUserError(mysql.ErrMalformPacket, mysql.ErrMalformPacket.Error())
return
}

Expand Down Expand Up @@ -134,3 +136,18 @@ func (p *PacketIO) WriteEOFPacket(status uint16) error {
data = DumpUint16(data, status)
return p.WritePacket(data, true)
}

// WriteUserError writes an unknown error to the client.
func (p *PacketIO) WriteUserError(err error, lg *zap.Logger) {
djshow832 marked this conversation as resolved.
Show resolved Hide resolved
if err == nil {
return
}
var ue *UserError
if !errors.As(err, &ue) {
return
}
myErr := mysql.NewErrf(mysql.ErrUnknown, "%s", nil, ue.UserMsg())
if writeErr := p.WriteErrPacket(myErr); writeErr != nil {
lg.Error("writing error to client failed", zap.NamedError("mysql_err", err), zap.NamedError("write_err", writeErr))
}
}
6 changes: 4 additions & 2 deletions pkg/proxy/net/packetio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"time"

"github.com/pingcap/TiProxy/lib/config"
"github.com/pingcap/TiProxy/lib/util/logger"
"github.com/pingcap/TiProxy/lib/util/security"
"github.com/pingcap/TiProxy/pkg/testkit"
"github.com/pingcap/tidb/parser/mysql"
Expand Down Expand Up @@ -54,6 +55,7 @@ func testTCPConn(t *testing.T, a func(*testing.T, *PacketIO), b func(*testing.T,
func TestPacketIO(t *testing.T) {
expectMsg := []byte("test")
pktLengths := []int{0, mysql.MaxPayloadLen + 212, mysql.MaxPayloadLen, mysql.MaxPayloadLen * 2}
lg := logger.CreateLoggerForTest(t).Named("TestPacketIO")
testPipeConn(t,
func(t *testing.T, cli *PacketIO) {
var err error
Expand Down Expand Up @@ -107,10 +109,10 @@ func TestPacketIO(t *testing.T) {
require.ErrorIs(t, srv.WriteInitialHandshake(0, make([]byte, 4), mysql.AuthNativePassword, ServerVersion), ErrSaltNotLongEnough)

// expect correct and wrong capability flags
_, isSSL, err := srv.ReadSSLRequestOrHandshakeResp()
_, isSSL, err := srv.ReadSSLRequestOrHandshakeResp(lg)
require.NoError(t, err)
require.True(t, isSSL)
_, isSSL, err = srv.ReadSSLRequestOrHandshakeResp()
_, isSSL, err = srv.ReadSSLRequestOrHandshakeResp(lg)
require.NoError(t, err)
require.False(t, isSSL)
},
Expand Down