Skip to content

Commit

Permalink
proxy: remove NamespaceManager from HandshakeHandler interface (pingc…
Browse files Browse the repository at this point in the history
  • Loading branch information
disksing authored and xhebox committed Mar 13, 2023
1 parent 5100029 commit a0b5eb3
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 23 deletions.
7 changes: 2 additions & 5 deletions pkg/proxy/backend/backend_conn_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import (
gomysql "github.com/go-mysql-org/go-mysql/mysql"
"github.com/pingcap/TiProxy/lib/util/errors"
"github.com/pingcap/TiProxy/lib/util/waitgroup"
"github.com/pingcap/TiProxy/pkg/manager/namespace"
"github.com/pingcap/TiProxy/pkg/manager/router"
pnet "github.com/pingcap/TiProxy/pkg/proxy/net"
"github.com/pingcap/tidb/parser/mysql"
Expand Down Expand Up @@ -84,20 +83,18 @@ type BackendConnManager struct {
// cancelFunc is used to cancel the signal processing goroutine.
cancelFunc context.CancelFunc
backendConn *BackendConnection
nsmgr *namespace.NamespaceManager
handshakeHandler HandshakeHandler
getBackendIO backendIOGetter
connectionID uint64
}

// NewBackendConnManager creates a BackendConnManager.
func NewBackendConnManager(logger *zap.Logger, nsmgr *namespace.NamespaceManager, handshakeHandler HandshakeHandler,
func NewBackendConnManager(logger *zap.Logger, handshakeHandler HandshakeHandler,
connectionID uint64, proxyProtocol, requireBackendTLS bool) *BackendConnManager {
mgr := &BackendConnManager{
logger: logger,
connectionID: connectionID,
cmdProcessor: NewCmdProcessor(),
nsmgr: nsmgr,
handshakeHandler: handshakeHandler,
authenticator: &Authenticator{
supportedServerCapabilities: handshakeHandler.GetCapability(),
Expand All @@ -109,7 +106,7 @@ func NewBackendConnManager(logger *zap.Logger, nsmgr *namespace.NamespaceManager
redirectResCh: make(chan *redirectResult, 1),
}
mgr.getBackendIO = func(auth *Authenticator, resp *pnet.HandshakeResp) (*pnet.PacketIO, error) {
ns, err := handshakeHandler.GetNamespace(nsmgr, resp)
ns, err := handshakeHandler.GetNamespace(resp)
if err != nil {
return nil, err
}
Expand Down
15 changes: 9 additions & 6 deletions pkg/proxy/backend/handshake_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,17 @@ var _ HandshakeHandler = (*DefaultHandshakeHandler)(nil)
type HandshakeHandler interface {
HandleHandshakeResp(resp *pnet.HandshakeResp, sourceAddr string) error
GetCapability() pnet.Capability
GetNamespace(nsMgr *namespace.NamespaceManager, resp *pnet.HandshakeResp) (*namespace.Namespace, error)
GetNamespace(resp *pnet.HandshakeResp) (*namespace.Namespace, error)
}

type DefaultHandshakeHandler struct {
nsManager *namespace.NamespaceManager
}

func NewDefaultHandshakeHandler() *DefaultHandshakeHandler {
return &DefaultHandshakeHandler{}
func NewDefaultHandshakeHandler(nsManager *namespace.NamespaceManager) *DefaultHandshakeHandler {
return &DefaultHandshakeHandler{
nsManager: nsManager,
}
}

func (handler *DefaultHandshakeHandler) HandleHandshakeResp(*pnet.HandshakeResp, string) error {
Expand All @@ -43,10 +46,10 @@ func (handler *DefaultHandshakeHandler) GetCapability() pnet.Capability {
return SupportedServerCapabilities
}

func (handler *DefaultHandshakeHandler) GetNamespace(nsMgr *namespace.NamespaceManager, resp *pnet.HandshakeResp) (*namespace.Namespace, error) {
ns, ok := nsMgr.GetNamespaceByUser(resp.User)
func (handler *DefaultHandshakeHandler) GetNamespace(resp *pnet.HandshakeResp) (*namespace.Namespace, error) {
ns, ok := handler.nsManager.GetNamespaceByUser(resp.User)
if !ok {
ns, ok = nsMgr.GetNamespace("default")
ns, ok = handler.nsManager.GetNamespace("default")
}
if !ok {
return nil, errors.New("failed to find a namespace")
Expand Down
6 changes: 3 additions & 3 deletions pkg/proxy/backend/mock_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type proxyConfig struct {

func newProxyConfig() *proxyConfig {
return &proxyConfig{
handler: NewDefaultHandshakeHandler(),
handler: NewDefaultHandshakeHandler(nil),
capability: defaultTestBackendCapability,
sessionToken: mockToken,
}
Expand All @@ -57,7 +57,7 @@ func newMockProxy(t *testing.T, cfg *proxyConfig) *mockProxy {
mp := &mockProxy{
proxyConfig: cfg,
logger: logger.CreateLoggerForTest(t).Named("mockProxy"),
BackendConnManager: NewBackendConnManager(logger.CreateLoggerForTest(t), nil, cfg.handler, 0, false, false),
BackendConnManager: NewBackendConnManager(logger.CreateLoggerForTest(t), cfg.handler, 0, false, false),
}
mp.cmdProcessor.capability = cfg.capability
return mp
Expand Down Expand Up @@ -107,7 +107,7 @@ type CustomHandshakeHandler struct {
outAttrs map[string]string
}

func (handler *CustomHandshakeHandler) GetNamespace(nsMgr *namespace.NamespaceManager, resp *pnet.HandshakeResp) (*namespace.Namespace, error) {
func (handler *CustomHandshakeHandler) GetNamespace(resp *pnet.HandshakeResp) (*namespace.Namespace, error) {
return &namespace.Namespace{}, nil
}

Expand Down
5 changes: 2 additions & 3 deletions pkg/proxy/client/client_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"net"

"github.com/pingcap/TiProxy/lib/util/errors"
"github.com/pingcap/TiProxy/pkg/manager/namespace"
"github.com/pingcap/TiProxy/pkg/proxy/backend"
pnet "github.com/pingcap/TiProxy/pkg/proxy/net"
"github.com/pingcap/tidb/parser/mysql"
Expand All @@ -41,8 +40,8 @@ type ClientConnection struct {
}

func NewClientConnection(logger *zap.Logger, conn net.Conn, frontendTLSConfig *tls.Config, backendTLSConfig *tls.Config,
nsmgr *namespace.NamespaceManager, connID uint64, proxyProtocol, requireBackendTLS bool) *ClientConnection {
bemgr := backend.NewBackendConnManager(logger.Named("be"), nsmgr, backend.NewDefaultHandshakeHandler(), connID, proxyProtocol, requireBackendTLS)
hsHandler backend.HandshakeHandler, connID uint64, proxyProtocol, requireBackendTLS bool) *ClientConnection {
bemgr := backend.NewBackendConnManager(logger.Named("be"), hsHandler, connID, proxyProtocol, requireBackendTLS)
opts := make([]pnet.PacketIOption, 0, 2)
opts = append(opts, pnet.WithWrapError(ErrClientConn))
if proxyProtocol {
Expand Down
10 changes: 5 additions & 5 deletions pkg/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ import (
"github.com/pingcap/TiProxy/lib/util/errors"
"github.com/pingcap/TiProxy/lib/util/waitgroup"
"github.com/pingcap/TiProxy/pkg/manager/cert"
mgrns "github.com/pingcap/TiProxy/pkg/manager/namespace"
"github.com/pingcap/TiProxy/pkg/metrics"
"github.com/pingcap/TiProxy/pkg/proxy/backend"
"github.com/pingcap/TiProxy/pkg/proxy/client"
pnet "github.com/pingcap/TiProxy/pkg/proxy/net"
"go.uber.org/zap"
Expand All @@ -43,21 +43,21 @@ type SQLServer struct {
listener net.Listener
logger *zap.Logger
certMgr *cert.CertManager
nsmgr *mgrns.NamespaceManager
hsHandler backend.HandshakeHandler
requireBackendTLS bool
wg waitgroup.WaitGroup

mu serverState
}

// NewSQLServer creates a new SQLServer.
func NewSQLServer(logger *zap.Logger, cfg config.ProxyServer, certMgr *cert.CertManager, nsmgr *mgrns.NamespaceManager) (*SQLServer, error) {
func NewSQLServer(logger *zap.Logger, cfg config.ProxyServer, certMgr *cert.CertManager, hsHandler backend.HandshakeHandler) (*SQLServer, error) {
var err error

s := &SQLServer{
logger: logger,
certMgr: certMgr,
nsmgr: nsmgr,
hsHandler: hsHandler,
requireBackendTLS: cfg.RequireBackendTLS,
mu: serverState{
connID: 0,
Expand Down Expand Up @@ -124,7 +124,7 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn) {
connID := s.mu.connID
s.mu.connID++
logger := s.logger.With(zap.Uint64("connID", connID), zap.String("remoteAddr", conn.RemoteAddr().String()))
clientConn := client.NewClientConnection(logger.Named("conn"), conn, s.certMgr.ServerTLS(), s.certMgr.SQLTLS(), s.nsmgr, connID, s.mu.proxyProtocol, s.requireBackendTLS)
clientConn := client.NewClientConnection(logger.Named("conn"), conn, s.certMgr.ServerTLS(), s.certMgr.SQLTLS(), s.hsHandler, connID, s.mu.proxyProtocol, s.requireBackendTLS)
s.mu.clients[connID] = clientConn
s.mu.Unlock()

Expand Down
4 changes: 3 additions & 1 deletion pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
"github.com/pingcap/TiProxy/pkg/manager/router"
"github.com/pingcap/TiProxy/pkg/metrics"
"github.com/pingcap/TiProxy/pkg/proxy"
"github.com/pingcap/TiProxy/pkg/proxy/backend"
"github.com/pingcap/TiProxy/pkg/sctx"
"github.com/pingcap/TiProxy/pkg/server/api"
clientv3 "go.etcd.io/etcd/client/v3"
Expand Down Expand Up @@ -194,7 +195,8 @@ func NewServer(ctx context.Context, sctx *sctx.Context) (srv *Server, err error)

// setup proxy server
{
srv.Proxy, err = proxy.NewSQLServer(lg.Named("proxy"), cfg.Proxy, srv.CertManager, srv.NamespaceManager)
hsHandler := backend.NewDefaultHandshakeHandler(srv.NamespaceManager)
srv.Proxy, err = proxy.NewSQLServer(lg.Named("proxy"), cfg.Proxy, srv.CertManager, hsHandler)
if err != nil {
err = errors.WithStack(err)
return
Expand Down

0 comments on commit a0b5eb3

Please sign in to comment.