From eeb0187588f62271c77215de01722978a6b06c3b Mon Sep 17 00:00:00 2001 From: djshow832 Date: Mon, 16 Jan 2023 11:51:39 +0800 Subject: [PATCH] backend: send the error to the client when the handler encounters an error (#187) --- pkg/proxy/backend/authenticator.go | 20 ++++-- pkg/proxy/backend/backend_conn_mgr.go | 34 ++++++---- pkg/proxy/backend/backend_conn_mgr_test.go | 53 ++++++++++++++- pkg/proxy/backend/cmd_processor_test.go | 5 +- pkg/proxy/backend/common_test.go | 4 +- pkg/proxy/backend/error.go | 75 ++++++++++++++++++++++ pkg/proxy/backend/mock_client_test.go | 9 ++- pkg/proxy/backend/testsuite_test.go | 5 +- 8 files changed, 177 insertions(+), 28 deletions(-) create mode 100644 pkg/proxy/backend/error.go diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index dfc6cc25..b26593f1 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -146,7 +146,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte clientResp := pnet.ParseHandshakeResponse(pkt) if err = handshakeHandler.HandleHandshakeResp(cctx, clientResp); err != nil { - return err + return WrapUserError(err, err.Error()) } auth.user = clientResp.User auth.dbname = clientResp.DB @@ -156,23 +156,29 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte // In case of testing, backendIO is passed manually that we don't want to bother with the routing logic. backendIO, err := getBackendIO(cctx, auth, clientResp, 5*time.Second) if err != nil { - return err + return WrapUserError(err, connectErrMsg) } backendIO.ResetSequence() // write proxy header if err := auth.writeProxyProtocol(clientIO, backendIO); err != nil { - return err + return WrapUserError(err, handshakeErrMsg) } // read backend initial handshake - _, backendCapability, err := auth.readInitialHandshake(backendIO) + serverPkt, backendCapability, err := auth.readInitialHandshake(backendIO) if err != nil { - return err + if IsMySQLError(err) { + if writeErr := clientIO.WritePacket(serverPkt, true); writeErr != nil { + err = writeErr + } + return err + } + return WrapUserError(err, handshakeErrMsg) } if err := auth.verifyBackendCaps(logger, backendCapability); err != nil { - return err + return WrapUserError(err, capabilityErrMsg) } if common := proxyCapability & backendCapability; (proxyCapability^common)&^pnet.ClientSSL != 0 { @@ -193,7 +199,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte // send an unknown auth plugin so that the backend will request the auth data again. unknownAuthPlugin, nil, 0, ); err != nil { - return err + return WrapUserError(err, handshakeErrMsg) } // forward other packets diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index 2da424f8..156b5e1d 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -151,6 +151,7 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe err := mgr.authenticator.handshakeFirstTime(mgr.logger.Named("authenticator"), mgr, clientIO, mgr.handshakeHandler, mgr.getBackendIO, frontendTLSConfig, backendTLSConfig) mgr.handshakeHandler.OnHandshake(mgr, mgr.ServerAddr(), err) if err != nil { + WriteUserError(clientIO, err, mgr.logger) return err } @@ -166,7 +167,7 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO *pnet.Packe func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticator, resp *pnet.HandshakeResp, timeout time.Duration) (*pnet.PacketIO, error) { r, err := mgr.handshakeHandler.GetRouter(cctx, resp) if err != nil { - return nil, err + return nil, WrapUserError(err, err.Error()) } // Reasons to wait: // - The TiDB instances may not be initialized yet @@ -174,31 +175,30 @@ func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticato bctx, cancel := context.WithTimeout(context.Background(), timeout) selector := r.GetBackendSelector() var addr string + var origErr error io, err := backoff.RetryNotifyWithData( func() (*pnet.PacketIO, error) { // Try to connect to all backup backends one by one. - addr, err := selector.Next() + addr, err = selector.Next() + // If all addrs are enumerated, reset and try again. + if err == nil && addr == "" { + selector.Reset() + addr, err = selector.Next() + } if err != nil { - return nil, backoff.Permanent(err) + return nil, backoff.Permanent(WrapUserError(err, err.Error())) } - - // if all addrs are enumerated, reset and try again if addr == "" { - selector.Reset() - if addr, err = selector.Next(); err != nil { - return nil, backoff.Permanent(err) - } - if addr == "" { - return nil, router.ErrNoInstanceToSelect - } + return nil, router.ErrNoInstanceToSelect } - cn, err := net.DialTimeout("tcp", addr, DialTimeout) + var cn net.Conn + cn, err = net.DialTimeout("tcp", addr, DialTimeout) if err != nil { return nil, errors.Wrapf(err, "dial backend %s error", addr) } - if err := selector.Succeed(mgr); err != nil { + if err = selector.Succeed(mgr); err != nil { // Bad luck: the backend has been recycled or shut down just after the selector returns it. if ignoredErr := cn.Close(); ignoredErr != nil { mgr.logger.Error("close backend connection failed", zap.String("addr", addr), zap.Error(ignoredErr)) @@ -215,10 +215,16 @@ func (mgr *BackendConnManager) getBackendIO(cctx ConnContext, auth *Authenticato }, backoff.WithContext(backoff.NewConstantBackOff(200*time.Millisecond), bctx), func(err error, d time.Duration) { + origErr = err mgr.handshakeHandler.OnHandshake(cctx, addr, err) }, ) cancel() + if err != nil && errors.Is(err, context.DeadlineExceeded) { + if origErr != nil { + err = origErr + } + } return io, err } diff --git a/pkg/proxy/backend/backend_conn_mgr_test.go b/pkg/proxy/backend/backend_conn_mgr_test.go index 5329e41f..e3569503 100644 --- a/pkg/proxy/backend/backend_conn_mgr_test.go +++ b/pkg/proxy/backend/backend_conn_mgr_test.go @@ -689,6 +689,57 @@ func TestGracefulCloseBeforeHandshake(t *testing.T) { ts.runTests(runners) } +func TestHandlerReturnError(t *testing.T) { + tests := []struct { + cfg cfgOverrider + errMsg string + }{ + { + cfg: func(config *testConfig) { + config.proxyConfig.handler.getRouter = func(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) { + return nil, errors.New("mocked error") + } + }, + errMsg: "mocked error", + }, + { + cfg: func(config *testConfig) { + config.proxyConfig.handler.handleHandshakeResp = func(ctx ConnContext, resp *pnet.HandshakeResp) error { + return errors.New("mocked error") + } + }, + errMsg: "mocked error", + }, + { + // TODO: make it fail faster. + cfg: func(config *testConfig) { + config.proxyConfig.handler.getRouter = func(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) { + return router.NewStaticRouter(nil), nil + } + }, + errMsg: connectErrMsg, + }, + } + for _, test := range tests { + ts := newBackendMgrTester(t, test.cfg) + rn := runner{ + client: func(packetIO *pnet.PacketIO) error { + err := ts.mc.authenticate(packetIO) + require.NoError(t, err) + require.ErrorContains(t, ts.mc.mysqlErr, test.errMsg) + return nil + }, + proxy: func(clientIO, backendIO *pnet.PacketIO) error { + err := ts.mp.Connect(context.Background(), clientIO, ts.mp.frontendTLSConfig, ts.mp.backendTLSConfig) + require.Error(t, err) + return nil + }, + backend: nil, + } + ts.runAndCheck(ts.t, func(t *testing.T, ts *testSuite) {}, rn.client, rn.backend, rn.proxy) + } +} + func TestGetBackendIO(t *testing.T) { addrs := make([]string, 0, 3) listeners := make([]net.Listener, 0, cap(addrs)) @@ -732,7 +783,7 @@ func TestGetBackendIO(t *testing.T) { err = listeners[i].Close() require.NoError(t, err, message) } else { - require.ErrorIs(t, err, context.DeadlineExceeded, message) + require.Error(t, err, message) } require.True(t, len(badAddrs) <= i, message) badAddrs = make(map[string]struct{}, 3) diff --git a/pkg/proxy/backend/cmd_processor_test.go b/pkg/proxy/backend/cmd_processor_test.go index c5a8bd02..cd859e04 100644 --- a/pkg/proxy/backend/cmd_processor_test.go +++ b/pkg/proxy/backend/cmd_processor_test.go @@ -1036,11 +1036,12 @@ func TestNetworkError(t *testing.T) { } clientErrChecker := func(t *testing.T, ts *testSuite) { require.True(t, pnet.IsDisconnectError(ts.mp.err)) - require.True(t, pnet.IsDisconnectError(ts.mp.err)) + require.True(t, pnet.IsDisconnectError(ts.mc.err)) + require.NotNil(t, ts.mp.err.(*UserError)) } backendErrChecker := func(t *testing.T, ts *testSuite) { require.True(t, pnet.IsDisconnectError(ts.mp.err)) - require.True(t, pnet.IsDisconnectError(ts.mp.err)) + require.True(t, pnet.IsDisconnectError(ts.mb.err)) } proxyErrChecker := func(t *testing.T, ts *testSuite) { require.True(t, pnet.IsDisconnectError(ts.mp.err)) diff --git a/pkg/proxy/backend/common_test.go b/pkg/proxy/backend/common_test.go index 17c69e31..73d71963 100644 --- a/pkg/proxy/backend/common_test.go +++ b/pkg/proxy/backend/common_test.go @@ -87,7 +87,9 @@ func (tc *tcpConnSuite) newConn(t *testing.T, enableRoute bool) func() { if tc.proxyBIO != nil { _ = tc.proxyBIO.Close() } - _ = tc.backendIO.Close() + if tc.backendIO != nil { + _ = tc.backendIO.Close() + } } } diff --git a/pkg/proxy/backend/error.go b/pkg/proxy/backend/error.go new file mode 100644 index 00000000..6d8afb3a --- /dev/null +++ b/pkg/proxy/backend/error.go @@ -0,0 +1,75 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package backend + +import ( + "github.com/pingcap/TiProxy/lib/util/errors" + pnet "github.com/pingcap/TiProxy/pkg/proxy/net" + "github.com/pingcap/tidb/parser/mysql" + "go.uber.org/zap" +) + +const ( + connectErrMsg = "No available TiDB instances, please check TiDB cluster" + handshakeErrMsg = "TiProxy fails to connect to TiDB, please check network" + capabilityErrMsg = "Verify TiDB capability failed, please upgrade TiDB" +) + +// UserError is returned to the client. +// err is used to log and userMsg is used to report to the user. +type UserError struct { + err error + userMsg string +} + +func WrapUserError(err error, userMsg string) *UserError { + if err == nil { + return nil + } + if ue, ok := err.(*UserError); ok { + return ue + } + return &UserError{ + err: err, + userMsg: userMsg, + } +} + +func (ue *UserError) UserMsg() string { + return ue.userMsg +} + +func (ue *UserError) Unwrap() error { + return ue.err +} + +func (ue *UserError) Error() string { + return ue.err.Error() +} + +// WriteUserError writes an unknown error to the client. +func WriteUserError(clientIO *pnet.PacketIO, err error, lg *zap.Logger) { + if err == nil { + return + } + var ue *UserError + if !errors.As(err, &ue) { + return + } + myErr := mysql.NewErrf(mysql.ErrUnknown, "%s", nil, ue.UserMsg()) + if writeErr := clientIO.WriteErrPacket(myErr); writeErr != nil { + lg.Error("writing error to client failed", zap.NamedError("mysql_err", err), zap.NamedError("write_err", writeErr)) + } +} diff --git a/pkg/proxy/backend/mock_client_test.go b/pkg/proxy/backend/mock_client_test.go index 322f4675..e11865c4 100644 --- a/pkg/proxy/backend/mock_client_test.go +++ b/pkg/proxy/backend/mock_client_test.go @@ -62,6 +62,7 @@ type mockClient struct { *clientConfig // Outputs that received from the server and will be checked by the test. authSucceed bool + mysqlErr error } func newMockClient(cfg *clientConfig) *mockClient { @@ -117,6 +118,7 @@ func (mc *mockClient) writePassword(packetIO *pnet.PacketIO) error { return nil case mysql.ErrHeader: mc.authSucceed = false + mc.mysqlErr = pnet.ParseErrorPacket(serverPkt) return nil case mysql.AuthSwitchRequest, pnet.ShaCommand: if err := packetIO.WritePacket(mc.authData, true); err != nil { @@ -182,7 +184,10 @@ func (mc *mockClient) requestChangeUser(packetIO *pnet.PacketIO) error { return err } switch resp[0] { - case mysql.OKHeader, mysql.ErrHeader: + case mysql.OKHeader: + return nil + case mysql.ErrHeader: + mc.mysqlErr = pnet.ParseErrorPacket(resp) return nil default: if err := packetIO.WritePacket(mc.authData, true); err != nil { @@ -268,6 +273,7 @@ func (mc *mockClient) readUntilResultEnd(packetIO *pnet.PacketIO) (pkt []byte, e return } if pkt[0] == mysql.ErrHeader { + mc.mysqlErr = pnet.ParseErrorPacket(pkt) return } if mc.capability&pnet.ClientDeprecateEOF == 0 { @@ -311,6 +317,7 @@ func (mc *mockClient) readResultSet(packetIO *pnet.PacketIO) error { case mysql.OKHeader: serverStatus = binary.LittleEndian.Uint16(pkt[3:]) case mysql.ErrHeader: + mc.mysqlErr = pnet.ParseErrorPacket(pkt) return nil case mysql.LocalInFileHeader: for i := 0; i < mc.filePkts; i++ { diff --git a/pkg/proxy/backend/testsuite_test.go b/pkg/proxy/backend/testsuite_test.go index 609b9f9a..c55528a9 100644 --- a/pkg/proxy/backend/testsuite_test.go +++ b/pkg/proxy/backend/testsuite_test.go @@ -122,7 +122,7 @@ type checker func(t *testing.T, ts *testSuite) func newTestSuite(t *testing.T, tc *tcpConnSuite, overriders ...cfgOverrider) (*testSuite, func()) { ts := &testSuite{} - cfg := newTestConfig(append(overriders, func(config *testConfig) { + overriders = append([]cfgOverrider{func(config *testConfig) { config.backendConfig.tlsConfig = tc.backendTLSConfig config.proxyConfig.backendTLSConfig = tc.clientTLSConfig config.proxyConfig.frontendTLSConfig = tc.backendTLSConfig @@ -130,7 +130,8 @@ func newTestSuite(t *testing.T, tc *tcpConnSuite, overriders ...cfgOverrider) (* config.proxyConfig.handler.getRouter = func(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) { return router.NewStaticRouter([]string{ts.tc.backendListener.Addr().String()}), nil } - })...) + }}, overriders...) + cfg := newTestConfig(overriders...) ts.mb = newMockBackend(cfg.backendConfig) ts.mp = newMockProxy(t, cfg.proxyConfig) ts.mc = newMockClient(cfg.clientConfig)