From e6bf2ad45806a0004654ce0f915183a312a870d2 Mon Sep 17 00:00:00 2001 From: xhe Date: Mon, 12 Dec 2022 10:27:34 +0800 Subject: [PATCH] server: support custom handler for Server (#145) --- cmd/tiproxy/main.go | 20 ++++++-------------- pkg/proxy/backend/authenticator.go | 21 +++++++++++++++++---- pkg/proxy/backend/backend_conn_mgr.go | 4 ++-- pkg/proxy/backend/backend_conn_mgr_test.go | 2 +- pkg/proxy/backend/handshake_handler.go | 17 ++++++++++------- pkg/proxy/backend/mock_proxy_test.go | 7 +++---- pkg/sctx/context.go | 22 +++++++++++----------- pkg/server/server.go | 14 +++++++++++++- 8 files changed, 63 insertions(+), 44 deletions(-) diff --git a/cmd/tiproxy/main.go b/cmd/tiproxy/main.go index a58f24ec..c1f5dc18 100644 --- a/cmd/tiproxy/main.go +++ b/cmd/tiproxy/main.go @@ -35,12 +35,12 @@ func main() { configFile := rootCmd.PersistentFlags().String("config", "conf/proxy.yaml", "proxy config file path") logEncoder := rootCmd.PersistentFlags().String("log_encoder", "tidb", "log in format of tidb, console, or json") logLevel := rootCmd.PersistentFlags().String("log_level", "", "log level") - clusterName := rootCmd.PersistentFlags().String("cluster_name", "tiproxy", "default cluster name, used to generate node name and differential clusters in dns discovery") - nodeName := rootCmd.PersistentFlags().String("node_name", "", "by default, it is generate prefixed by cluster-name") - pubAddr := rootCmd.PersistentFlags().String("pub_addr", "127.0.0.1", "IP or domain, will be used as the accessible addr for others") - bootstrapClusters := rootCmd.PersistentFlags().StringSlice("bootstrap_clusters", []string{}, "lists of other nodes in the cluster, e.g. 'n1=xxx,n2=xxx', where xx are IPs or domains") - bootstrapDiscoveryUrl := rootCmd.PersistentFlags().String("bootstrap_discovery_etcd", "", "etcd discovery service url") - bootstrapDiscoveryDNS := rootCmd.PersistentFlags().String("bootstrap_discovery_dns", "", "dns srv discovery") + _ = rootCmd.PersistentFlags().String("cluster_name", "tiproxy", "default cluster name, used to generate node name and differential clusters in dns discovery") + _ = rootCmd.PersistentFlags().String("node_name", "", "by default, it is generate prefixed by cluster-name") + _ = rootCmd.PersistentFlags().String("pub_addr", "127.0.0.1", "IP or domain, will be used as the accessible addr for others") + _ = rootCmd.PersistentFlags().StringSlice("bootstrap_clusters", []string{}, "lists of other nodes in the cluster, e.g. 'n1=xxx,n2=xxx', where xx are IPs or domains") + _ = rootCmd.PersistentFlags().String("bootstrap_discovery_etcd", "", "etcd discovery service url") + _ = rootCmd.PersistentFlags().String("bootstrap_discovery_dns", "", "dns srv discovery") rootCmd.RunE = func(cmd *cobra.Command, _ []string) error { proxyConfigData, err := os.ReadFile(*configFile) @@ -62,14 +62,6 @@ func main() { sctx := &sctx.Context{ Config: cfg, - Cluster: sctx.Cluster{ - PubAddr: *pubAddr, - ClusterName: *clusterName, - NodeName: *nodeName, - BootstrapDurl: *bootstrapDiscoveryUrl, - BootstrapDdns: *bootstrapDiscoveryDNS, - BootstrapClusters: *bootstrapClusters, - }, } srv, err := server.NewServer(cmd.Context(), sctx) diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index 367c6ba0..6954383a 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -15,11 +15,11 @@ package backend import ( - "context" "crypto/tls" "encoding/binary" "fmt" "net" + "sync" "github.com/pingcap/TiProxy/lib/util/errors" pnet "github.com/pingcap/TiProxy/pkg/proxy/net" @@ -47,6 +47,7 @@ const SupportedServerCapabilities = pnet.ClientLongPassword | pnet.ClientFoundRo // Authenticator handshakes with the client and the backend. type Authenticator struct { backendTLSConfig *tls.Config + ctxmap sync.Map supportedServerCapabilities pnet.Capability dbname string // default database name serverAddr string @@ -146,8 +147,8 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet auth.capability = commonCaps.Uint32() resp := pnet.ParseHandshakeResponse(pkt) - ctx := context.WithValue(context.Background(), ContextKeyClientAddr, clientIO.SourceAddr().String()) - if err = handshakeHandler.HandleHandshakeResp(ctx, resp); err != nil { + auth.SetValue(ContextKeyClientAddr, clientIO.SourceAddr().String()) + if err = handshakeHandler.HandleHandshakeResp(auth, resp); err != nil { return err } auth.user = resp.User @@ -155,7 +156,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet auth.collation = resp.Collation auth.attrs = resp.Attrs - backendIO, err := getBackend(ctx, auth, resp) + backendIO, err := getBackend(auth, auth, resp) if err != nil { return err } @@ -347,3 +348,15 @@ func (auth *Authenticator) changeUser(username, db string) { func (auth *Authenticator) updateCurrentDB(db string) { auth.dbname = db } + +func (auth *Authenticator) SetValue(key, val any) { + auth.ctxmap.Store(key, val) +} + +func (auth *Authenticator) Value(key any) any { + v, ok := auth.ctxmap.Load(key) + if !ok { + return nil + } + return v +} diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index 8107256d..e2d66e41 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -54,7 +54,7 @@ type redirectResult struct { to string } -type backendIOGetter func(ctx context.Context, auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) +type backendIOGetter func(ctx ConnContext, auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) // BackendConnManager migrates a session from one BackendConnection to another. // @@ -105,7 +105,7 @@ func NewBackendConnManager(logger *zap.Logger, handshakeHandler HandshakeHandler signalReceived: make(chan struct{}, 1), redirectResCh: make(chan *redirectResult, 1), } - mgr.getBackendIO = func(ctx context.Context, auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) { + mgr.getBackendIO = func(ctx ConnContext, auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) { router, err := handshakeHandler.GetRouter(ctx, resp) if err != nil { return nil, err diff --git a/pkg/proxy/backend/backend_conn_mgr_test.go b/pkg/proxy/backend/backend_conn_mgr_test.go index b3a29e0a..8f5e99b8 100644 --- a/pkg/proxy/backend/backend_conn_mgr_test.go +++ b/pkg/proxy/backend/backend_conn_mgr_test.go @@ -116,7 +116,7 @@ func newBackendMgrTester(t *testing.T, cfg ...cfgOverrider) *backendMgrTester { return tester } -func (ts *backendMgrTester) getBackendIO(ctx context.Context, auth *Authenticator, _ *pnet.HandshakeResp) (*pnet.PacketIO, error) { +func (ts *backendMgrTester) getBackendIO(ctx ConnContext, auth *Authenticator, _ *pnet.HandshakeResp) (*pnet.PacketIO, error) { addr := ts.tc.backendListener.Addr().String() ts.mp.backendConn = NewBackendConnection(addr) if err := ts.mp.backendConn.Connect(); err != nil { diff --git a/pkg/proxy/backend/handshake_handler.go b/pkg/proxy/backend/handshake_handler.go index 32be5dfe..028fdcef 100644 --- a/pkg/proxy/backend/handshake_handler.go +++ b/pkg/proxy/backend/handshake_handler.go @@ -15,8 +15,6 @@ package backend import ( - "context" - "github.com/pingcap/TiProxy/lib/util/errors" "github.com/pingcap/TiProxy/pkg/manager/namespace" "github.com/pingcap/TiProxy/pkg/manager/router" @@ -31,15 +29,20 @@ func (k contextKey) String() string { // Context keys. var ( - ContextKeyClientAddr = contextKey("client_addr") + ContextKeyClientAddr contextKey = "client_addr" ) var _ HandshakeHandler = (*DefaultHandshakeHandler)(nil) +type ConnContext interface { + SetValue(key, val any) + Value(key any) any +} + type HandshakeHandler interface { - HandleHandshakeResp(ctx context.Context, resp *pnet.HandshakeResp) error + HandleHandshakeResp(ctx ConnContext, resp *pnet.HandshakeResp) error GetCapability() pnet.Capability - GetRouter(ctx context.Context, resp *pnet.HandshakeResp) (router.Router, error) + GetRouter(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) } type DefaultHandshakeHandler struct { @@ -52,7 +55,7 @@ func NewDefaultHandshakeHandler(nsManager *namespace.NamespaceManager) *DefaultH } } -func (handler *DefaultHandshakeHandler) HandleHandshakeResp(context.Context, *pnet.HandshakeResp) error { +func (handler *DefaultHandshakeHandler) HandleHandshakeResp(ConnContext, *pnet.HandshakeResp) error { return nil } @@ -60,7 +63,7 @@ func (handler *DefaultHandshakeHandler) GetCapability() pnet.Capability { return SupportedServerCapabilities } -func (handler *DefaultHandshakeHandler) GetRouter(ctx context.Context, resp *pnet.HandshakeResp) (router.Router, error) { +func (handler *DefaultHandshakeHandler) GetRouter(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) { ns, ok := handler.nsManager.GetNamespaceByUser(resp.User) if !ok { ns, ok = handler.nsManager.GetNamespace("default") diff --git a/pkg/proxy/backend/mock_proxy_test.go b/pkg/proxy/backend/mock_proxy_test.go index 706a2369..7dc07fd0 100644 --- a/pkg/proxy/backend/mock_proxy_test.go +++ b/pkg/proxy/backend/mock_proxy_test.go @@ -15,7 +15,6 @@ package backend import ( - "context" "crypto/tls" "testing" @@ -65,7 +64,7 @@ func newMockProxy(t *testing.T, cfg *proxyConfig) *mockProxy { } func (mp *mockProxy) authenticateFirstTime(clientIO, backendIO *pnet.PacketIO) error { - if err := mp.authenticator.handshakeFirstTime(mp.logger, clientIO, mp.handshakeHandler, func(context.Context, *Authenticator, *pnet.HandshakeResp) (*pnet.PacketIO, error) { + if err := mp.authenticator.handshakeFirstTime(mp.logger, clientIO, mp.handshakeHandler, func(ConnContext, *Authenticator, *pnet.HandshakeResp) (*pnet.PacketIO, error) { return backendIO, nil }, mp.frontendTLSConfig, mp.backendTLSConfig); err != nil { return err @@ -108,11 +107,11 @@ type CustomHandshakeHandler struct { outAttrs map[string]string } -func (handler *CustomHandshakeHandler) GetRouter(ctx context.Context, resp *pnet.HandshakeResp) (router.Router, error) { +func (handler *CustomHandshakeHandler) GetRouter(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) { return nil, nil } -func (handler *CustomHandshakeHandler) HandleHandshakeResp(ctx context.Context, resp *pnet.HandshakeResp) error { +func (handler *CustomHandshakeHandler) HandleHandshakeResp(ctx ConnContext, resp *pnet.HandshakeResp) error { handler.inUsername = resp.User resp.User = handler.outUsername handler.inAddr = ctx.Value(ContextKeyClientAddr).(string) diff --git a/pkg/sctx/context.go b/pkg/sctx/context.go index 34c7f6b7..f9cd5e08 100644 --- a/pkg/sctx/context.go +++ b/pkg/sctx/context.go @@ -14,18 +14,18 @@ package sctx -import "github.com/pingcap/TiProxy/lib/config" - -type Cluster struct { - PubAddr string - ClusterName string - NodeName string - BootstrapDurl string - BootstrapDdns string - BootstrapClusters []string -} +import ( + "github.com/gin-gonic/gin" + "github.com/pingcap/TiProxy/lib/config" + "github.com/pingcap/TiProxy/pkg/proxy/backend" +) type Context struct { Config *config.Config - Cluster Cluster + Handler ServerHandler +} + +type ServerHandler interface { + backend.HandshakeHandler + RegisterHTTP(c *gin.Engine) error } diff --git a/pkg/server/server.go b/pkg/server/server.go index 27960e55..df117b23 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -68,6 +68,7 @@ func NewServer(ctx context.Context, sctx *sctx.Context) (srv *Server, err error) } cfg := sctx.Config + handler := sctx.Handler // set up logger var lg *zap.Logger @@ -114,6 +115,12 @@ func NewServer(ctx context.Context, sctx *sctx.Context) (srv *Server, err error) srv.HTTPListener = tls.NewListener(srv.HTTPListener, tlscfg) } + if handler != nil { + if err := handler.RegisterHTTP(engine); err != nil { + return nil, errors.WithStack(err) + } + } + srv.wg.Run(func() { slogger.Info("HTTP closed", zap.Error(engine.RunListener(srv.HTTPListener))) }) @@ -182,7 +189,12 @@ func NewServer(ctx context.Context, sctx *sctx.Context) (srv *Server, err error) // setup proxy server { - hsHandler := backend.NewDefaultHandshakeHandler(srv.NamespaceManager) + var hsHandler backend.HandshakeHandler + if handler != nil { + hsHandler = handler + } else { + hsHandler = backend.NewDefaultHandshakeHandler(srv.NamespaceManager) + } srv.Proxy, err = proxy.NewSQLServer(lg.Named("proxy"), cfg.Proxy, srv.CertManager, hsHandler) if err != nil { err = errors.WithStack(err)