Skip to content

Commit

Permalink
server: support custom handler for Server (pingcap#145)
Browse files Browse the repository at this point in the history
  • Loading branch information
xhebox committed Mar 7, 2023
1 parent e097ad1 commit e6bf2ad
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 44 deletions.
20 changes: 6 additions & 14 deletions cmd/tiproxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ func main() {
configFile := rootCmd.PersistentFlags().String("config", "conf/proxy.yaml", "proxy config file path")
logEncoder := rootCmd.PersistentFlags().String("log_encoder", "tidb", "log in format of tidb, console, or json")
logLevel := rootCmd.PersistentFlags().String("log_level", "", "log level")
clusterName := rootCmd.PersistentFlags().String("cluster_name", "tiproxy", "default cluster name, used to generate node name and differential clusters in dns discovery")
nodeName := rootCmd.PersistentFlags().String("node_name", "", "by default, it is generate prefixed by cluster-name")
pubAddr := rootCmd.PersistentFlags().String("pub_addr", "127.0.0.1", "IP or domain, will be used as the accessible addr for others")
bootstrapClusters := rootCmd.PersistentFlags().StringSlice("bootstrap_clusters", []string{}, "lists of other nodes in the cluster, e.g. 'n1=xxx,n2=xxx', where xx are IPs or domains")
bootstrapDiscoveryUrl := rootCmd.PersistentFlags().String("bootstrap_discovery_etcd", "", "etcd discovery service url")
bootstrapDiscoveryDNS := rootCmd.PersistentFlags().String("bootstrap_discovery_dns", "", "dns srv discovery")
_ = rootCmd.PersistentFlags().String("cluster_name", "tiproxy", "default cluster name, used to generate node name and differential clusters in dns discovery")
_ = rootCmd.PersistentFlags().String("node_name", "", "by default, it is generate prefixed by cluster-name")
_ = rootCmd.PersistentFlags().String("pub_addr", "127.0.0.1", "IP or domain, will be used as the accessible addr for others")
_ = rootCmd.PersistentFlags().StringSlice("bootstrap_clusters", []string{}, "lists of other nodes in the cluster, e.g. 'n1=xxx,n2=xxx', where xx are IPs or domains")
_ = rootCmd.PersistentFlags().String("bootstrap_discovery_etcd", "", "etcd discovery service url")
_ = rootCmd.PersistentFlags().String("bootstrap_discovery_dns", "", "dns srv discovery")

rootCmd.RunE = func(cmd *cobra.Command, _ []string) error {
proxyConfigData, err := os.ReadFile(*configFile)
Expand All @@ -62,14 +62,6 @@ func main() {

sctx := &sctx.Context{
Config: cfg,
Cluster: sctx.Cluster{
PubAddr: *pubAddr,
ClusterName: *clusterName,
NodeName: *nodeName,
BootstrapDurl: *bootstrapDiscoveryUrl,
BootstrapDdns: *bootstrapDiscoveryDNS,
BootstrapClusters: *bootstrapClusters,
},
}

srv, err := server.NewServer(cmd.Context(), sctx)
Expand Down
21 changes: 17 additions & 4 deletions pkg/proxy/backend/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
package backend

import (
"context"
"crypto/tls"
"encoding/binary"
"fmt"
"net"
"sync"

"github.com/pingcap/TiProxy/lib/util/errors"
pnet "github.com/pingcap/TiProxy/pkg/proxy/net"
Expand Down Expand Up @@ -47,6 +47,7 @@ const SupportedServerCapabilities = pnet.ClientLongPassword | pnet.ClientFoundRo
// Authenticator handshakes with the client and the backend.
type Authenticator struct {
backendTLSConfig *tls.Config
ctxmap sync.Map
supportedServerCapabilities pnet.Capability
dbname string // default database name
serverAddr string
Expand Down Expand Up @@ -146,16 +147,16 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet
auth.capability = commonCaps.Uint32()

resp := pnet.ParseHandshakeResponse(pkt)
ctx := context.WithValue(context.Background(), ContextKeyClientAddr, clientIO.SourceAddr().String())
if err = handshakeHandler.HandleHandshakeResp(ctx, resp); err != nil {
auth.SetValue(ContextKeyClientAddr, clientIO.SourceAddr().String())
if err = handshakeHandler.HandleHandshakeResp(auth, resp); err != nil {
return err
}
auth.user = resp.User
auth.dbname = resp.DB
auth.collation = resp.Collation
auth.attrs = resp.Attrs

backendIO, err := getBackend(ctx, auth, resp)
backendIO, err := getBackend(auth, auth, resp)
if err != nil {
return err
}
Expand Down Expand Up @@ -347,3 +348,15 @@ func (auth *Authenticator) changeUser(username, db string) {
func (auth *Authenticator) updateCurrentDB(db string) {
auth.dbname = db
}

func (auth *Authenticator) SetValue(key, val any) {
auth.ctxmap.Store(key, val)
}

func (auth *Authenticator) Value(key any) any {
v, ok := auth.ctxmap.Load(key)
if !ok {
return nil
}
return v
}
4 changes: 2 additions & 2 deletions pkg/proxy/backend/backend_conn_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ type redirectResult struct {
to string
}

type backendIOGetter func(ctx context.Context, auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error)
type backendIOGetter func(ctx ConnContext, auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error)

// BackendConnManager migrates a session from one BackendConnection to another.
//
Expand Down Expand Up @@ -105,7 +105,7 @@ func NewBackendConnManager(logger *zap.Logger, handshakeHandler HandshakeHandler
signalReceived: make(chan struct{}, 1),
redirectResCh: make(chan *redirectResult, 1),
}
mgr.getBackendIO = func(ctx context.Context, auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) {
mgr.getBackendIO = func(ctx ConnContext, auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) {
router, err := handshakeHandler.GetRouter(ctx, resp)
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion pkg/proxy/backend/backend_conn_mgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func newBackendMgrTester(t *testing.T, cfg ...cfgOverrider) *backendMgrTester {
return tester
}

func (ts *backendMgrTester) getBackendIO(ctx context.Context, auth *Authenticator, _ *pnet.HandshakeResp) (*pnet.PacketIO, error) {
func (ts *backendMgrTester) getBackendIO(ctx ConnContext, auth *Authenticator, _ *pnet.HandshakeResp) (*pnet.PacketIO, error) {
addr := ts.tc.backendListener.Addr().String()
ts.mp.backendConn = NewBackendConnection(addr)
if err := ts.mp.backendConn.Connect(); err != nil {
Expand Down
17 changes: 10 additions & 7 deletions pkg/proxy/backend/handshake_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
package backend

import (
"context"

"github.com/pingcap/TiProxy/lib/util/errors"
"github.com/pingcap/TiProxy/pkg/manager/namespace"
"github.com/pingcap/TiProxy/pkg/manager/router"
Expand All @@ -31,15 +29,20 @@ func (k contextKey) String() string {

// Context keys.
var (
ContextKeyClientAddr = contextKey("client_addr")
ContextKeyClientAddr contextKey = "client_addr"
)

var _ HandshakeHandler = (*DefaultHandshakeHandler)(nil)

type ConnContext interface {
SetValue(key, val any)
Value(key any) any
}

type HandshakeHandler interface {
HandleHandshakeResp(ctx context.Context, resp *pnet.HandshakeResp) error
HandleHandshakeResp(ctx ConnContext, resp *pnet.HandshakeResp) error
GetCapability() pnet.Capability
GetRouter(ctx context.Context, resp *pnet.HandshakeResp) (router.Router, error)
GetRouter(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error)
}

type DefaultHandshakeHandler struct {
Expand All @@ -52,15 +55,15 @@ func NewDefaultHandshakeHandler(nsManager *namespace.NamespaceManager) *DefaultH
}
}

func (handler *DefaultHandshakeHandler) HandleHandshakeResp(context.Context, *pnet.HandshakeResp) error {
func (handler *DefaultHandshakeHandler) HandleHandshakeResp(ConnContext, *pnet.HandshakeResp) error {
return nil
}

func (handler *DefaultHandshakeHandler) GetCapability() pnet.Capability {
return SupportedServerCapabilities
}

func (handler *DefaultHandshakeHandler) GetRouter(ctx context.Context, resp *pnet.HandshakeResp) (router.Router, error) {
func (handler *DefaultHandshakeHandler) GetRouter(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) {
ns, ok := handler.nsManager.GetNamespaceByUser(resp.User)
if !ok {
ns, ok = handler.nsManager.GetNamespace("default")
Expand Down
7 changes: 3 additions & 4 deletions pkg/proxy/backend/mock_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package backend

import (
"context"
"crypto/tls"
"testing"

Expand Down Expand Up @@ -65,7 +64,7 @@ func newMockProxy(t *testing.T, cfg *proxyConfig) *mockProxy {
}

func (mp *mockProxy) authenticateFirstTime(clientIO, backendIO *pnet.PacketIO) error {
if err := mp.authenticator.handshakeFirstTime(mp.logger, clientIO, mp.handshakeHandler, func(context.Context, *Authenticator, *pnet.HandshakeResp) (*pnet.PacketIO, error) {
if err := mp.authenticator.handshakeFirstTime(mp.logger, clientIO, mp.handshakeHandler, func(ConnContext, *Authenticator, *pnet.HandshakeResp) (*pnet.PacketIO, error) {
return backendIO, nil
}, mp.frontendTLSConfig, mp.backendTLSConfig); err != nil {
return err
Expand Down Expand Up @@ -108,11 +107,11 @@ type CustomHandshakeHandler struct {
outAttrs map[string]string
}

func (handler *CustomHandshakeHandler) GetRouter(ctx context.Context, resp *pnet.HandshakeResp) (router.Router, error) {
func (handler *CustomHandshakeHandler) GetRouter(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) {
return nil, nil
}

func (handler *CustomHandshakeHandler) HandleHandshakeResp(ctx context.Context, resp *pnet.HandshakeResp) error {
func (handler *CustomHandshakeHandler) HandleHandshakeResp(ctx ConnContext, resp *pnet.HandshakeResp) error {
handler.inUsername = resp.User
resp.User = handler.outUsername
handler.inAddr = ctx.Value(ContextKeyClientAddr).(string)
Expand Down
22 changes: 11 additions & 11 deletions pkg/sctx/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@

package sctx

import "github.com/pingcap/TiProxy/lib/config"

type Cluster struct {
PubAddr string
ClusterName string
NodeName string
BootstrapDurl string
BootstrapDdns string
BootstrapClusters []string
}
import (
"github.com/gin-gonic/gin"
"github.com/pingcap/TiProxy/lib/config"
"github.com/pingcap/TiProxy/pkg/proxy/backend"
)

type Context struct {
Config *config.Config
Cluster Cluster
Handler ServerHandler
}

type ServerHandler interface {
backend.HandshakeHandler
RegisterHTTP(c *gin.Engine) error
}
14 changes: 13 additions & 1 deletion pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ func NewServer(ctx context.Context, sctx *sctx.Context) (srv *Server, err error)
}

cfg := sctx.Config
handler := sctx.Handler

// set up logger
var lg *zap.Logger
Expand Down Expand Up @@ -114,6 +115,12 @@ func NewServer(ctx context.Context, sctx *sctx.Context) (srv *Server, err error)
srv.HTTPListener = tls.NewListener(srv.HTTPListener, tlscfg)
}

if handler != nil {
if err := handler.RegisterHTTP(engine); err != nil {
return nil, errors.WithStack(err)
}
}

srv.wg.Run(func() {
slogger.Info("HTTP closed", zap.Error(engine.RunListener(srv.HTTPListener)))
})
Expand Down Expand Up @@ -182,7 +189,12 @@ func NewServer(ctx context.Context, sctx *sctx.Context) (srv *Server, err error)

// setup proxy server
{
hsHandler := backend.NewDefaultHandshakeHandler(srv.NamespaceManager)
var hsHandler backend.HandshakeHandler
if handler != nil {
hsHandler = handler
} else {
hsHandler = backend.NewDefaultHandshakeHandler(srv.NamespaceManager)
}
srv.Proxy, err = proxy.NewSQLServer(lg.Named("proxy"), cfg.Proxy, srv.CertManager, hsHandler)
if err != nil {
err = errors.WithStack(err)
Expand Down

0 comments on commit e6bf2ad

Please sign in to comment.