Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

backend: use BackendConnMgr as ConnContext #172

Merged
merged 5 commits into from
Jan 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pkg/manager/router/backend_selector.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ func (bs *BackendSelector) Reset() {
}

func (bs *BackendSelector) Next() string {
if len(bs.cur) > 0 {
bs.cur = bs.routeOnce(bs.excluded)
if bs.cur != "" {
bs.excluded = append(bs.excluded, bs.cur)
}
bs.cur = bs.routeOnce(bs.excluded)
return bs.cur
}

Expand Down
8 changes: 4 additions & 4 deletions pkg/manager/router/router_static.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,18 @@ package router
var _ Router = &StaticRouter{}

type StaticRouter struct {
addr []string
cnt int
addrs []string
cnt int
}

func NewStaticRouter(addr []string) *StaticRouter {
return &StaticRouter{addr: addr}
return &StaticRouter{addrs: addr}
}

func (r *StaticRouter) GetBackendSelector() BackendSelector {
return BackendSelector{
routeOnce: func(excluded []string) string {
for _, addr := range r.addr {
for _, addr := range r.addrs {
found := false
for _, e := range excluded {
if e == addr {
Expand Down
153 changes: 59 additions & 94 deletions pkg/proxy/backend/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"encoding/binary"
"fmt"
"net"
"sync"
"time"

"github.com/pingcap/TiProxy/lib/util/errors"
Expand Down Expand Up @@ -47,19 +46,14 @@ const SupportedServerCapabilities = pnet.ClientLongPassword | pnet.ClientFoundRo

// Authenticator handshakes with the client and the backend.
type Authenticator struct {
backendTLSConfig *tls.Config
ctxmap sync.Map
supportedServerCapabilities pnet.Capability
dbname string // default database name
serverAddr string
clientAddr string
user string
attrs map[string]string
salt []byte
capability uint32 // client capability
collation uint8
proxyProtocol bool
requireBackendTLS bool
dbname string // default database name
user string
attrs map[string]string
salt []byte
capability uint32 // client capability
collation uint8
proxyProtocol bool
requireBackendTLS bool
}

func (auth *Authenticator) String() string {
Expand Down Expand Up @@ -104,17 +98,16 @@ func (auth *Authenticator) verifyBackendCaps(logger *zap.Logger, backendCapabili

type backendIOGetter func(ctx ConnContext, auth *Authenticator, resp *pnet.HandshakeResp, timeout time.Duration) (*pnet.PacketIO, error)

func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet.PacketIO, handshakeHandler HandshakeHandler,
func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnContext, clientIO *pnet.PacketIO, handshakeHandler HandshakeHandler,
getBackendIO backendIOGetter, frontendTLSConfig, backendTLSConfig *tls.Config) error {
clientIO.ResetSequence()
auth.clientAddr = clientIO.SourceAddr().String()

proxyCapability := auth.supportedServerCapabilities
proxyCapability := handshakeHandler.GetCapability()
if frontendTLSConfig == nil {
proxyCapability ^= pnet.ClientSSL
}

if err := clientIO.WriteInitialHandshake(proxyCapability.Uint32(), auth.salt, mysql.AuthNativePassword); err != nil {
if err := clientIO.WriteInitialHandshake(proxyCapability, auth.salt, mysql.AuthNativePassword); err != nil {
return err
}
pkt, isSSL, err := clientIO.ReadSSLRequestOrHandshakeResp()
Expand Down Expand Up @@ -151,17 +144,17 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet
}
auth.capability = commonCaps.Uint32()

resp := pnet.ParseHandshakeResponse(pkt)
if err = handshakeHandler.HandleHandshakeResp(auth, resp); err != nil {
clientResp := pnet.ParseHandshakeResponse(pkt)
if err = handshakeHandler.HandleHandshakeResp(cctx, clientResp); err != nil {
return err
}
auth.user = resp.User
auth.dbname = resp.DB
auth.collation = resp.Collation
auth.attrs = resp.Attrs
auth.user = clientResp.User
auth.dbname = clientResp.DB
auth.collation = clientResp.Collation
auth.attrs = clientResp.Attrs

// In case of testing, backendIO is passed manually that we don't want to bother with the routing logic.
backendIO, err := getBackendIO(auth, auth, resp, 5*time.Second)
backendIO, err := getBackendIO(cctx, auth, clientResp, 5*time.Second)
if err != nil {
return err
}
Expand All @@ -173,11 +166,10 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet
}

// read backend initial handshake
_, backendCapabilityU, err := auth.readInitialHandshake(backendIO)
_, backendCapability, err := auth.readInitialHandshake(backendIO)
if err != nil {
return err
}
backendCapability := pnet.Capability(backendCapabilityU)

if err := auth.verifyBackendCaps(logger, backendCapability); err != nil {
return err
Expand All @@ -195,38 +187,12 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, clientIO *pnet
logger.Info("backend does not support capabilities from proxy", zap.Stringer("common", common), zap.Stringer("proxy", proxyCapability^common), zap.Stringer("backend", backendCapability^common))
}

// Send an unknown auth plugin so that the backend will request the auth data again.
resp.AuthPlugin = unknownAuthPlugin
resp.Capability = auth.capability

if backendCapability&pnet.ClientSSL != 0 && backendTLSConfig != nil {
resp.Capability |= mysql.ClientSSL
pkt = pnet.MakeHandshakeResponse(resp)
// write SSL Packet
if err := backendIO.WritePacket(pkt[:32], true); err != nil {
return err
}
auth.backendTLSConfig = backendTLSConfig.Clone()
addr := backendIO.RemoteAddr().String()
if auth.serverAddr != "" {
// NOTE: should use DNS name as much as possible
// Usually certs are signed with domain instead of IP addrs
// And `RemoteAddr()` will return IP addr
addr = auth.serverAddr
}
host, _, err := net.SplitHostPort(addr)
if err == nil {
auth.backendTLSConfig.ServerName = host
}
if err = backendIO.ClientTLSHandshake(auth.backendTLSConfig); err != nil {
return err
}
} else {
pkt = pnet.MakeHandshakeResponse(resp)
}

// forward client handshake resp
if err := backendIO.WritePacket(pkt, true); err != nil {
if err := auth.writeAuthHandshake(
backendIO, backendTLSConfig, backendCapability,
// send an unknown auth plugin so that the backend will request the auth data again.
unknownAuthPlugin, nil, 0,
); err != nil {
return err
}

Expand Down Expand Up @@ -258,7 +224,7 @@ func forwardMsg(srcIO, destIO *pnet.PacketIO) (data []byte, err error) {
return
}

func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, backendIO *pnet.PacketIO, sessionToken string) error {
func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, backendIO *pnet.PacketIO, backendTLSConfig *tls.Config, sessionToken string) error {
if len(sessionToken) == 0 {
return errors.New("session token is empty")
}
Expand All @@ -268,24 +234,26 @@ func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, bac
return err
}

_, serverCapability, err := auth.readInitialHandshake(backendIO)
_, backendCapability, err := auth.readInitialHandshake(backendIO)
if err != nil {
return err
}

if err := auth.verifyBackendCaps(logger, pnet.Capability(serverCapability)); err != nil {
if err := auth.verifyBackendCaps(logger, pnet.Capability(backendCapability)); err != nil {
return err
}

tokenBytes := hack.Slice(sessionToken)
if err = auth.writeAuthHandshake(backendIO, tokenBytes, serverCapability&mysql.ClientSSL != 0); err != nil {
if err = auth.writeAuthHandshake(
backendIO, backendTLSConfig, backendCapability,
mysql.AuthTiDBSessionToken, hack.Slice(sessionToken), mysql.ClientPluginAuth,
); err != nil {
return err
}

return auth.handleSecondAuthResult(backendIO)
}

func (auth *Authenticator) readInitialHandshake(backendIO *pnet.PacketIO) (serverPkt []byte, capability uint32, err error) {
func (auth *Authenticator) readInitialHandshake(backendIO *pnet.PacketIO) (serverPkt []byte, capability pnet.Capability, err error) {
if serverPkt, err = backendIO.ReadPacket(); err != nil {
return
}
Expand All @@ -297,32 +265,49 @@ func (auth *Authenticator) readInitialHandshake(backendIO *pnet.PacketIO) (serve
return
}

func (auth *Authenticator) writeAuthHandshake(backendIO *pnet.PacketIO, authData []byte, tls bool) error {
func (auth *Authenticator) writeAuthHandshake(
backendIO *pnet.PacketIO,
backendTLSConfig *tls.Config,
backendCapability pnet.Capability,
authPlugin string,
authData []byte,
authCap uint32,
) error {
// Always handshake with SSL enabled and enable auth_plugin.
resp := &pnet.HandshakeResp{
User: auth.user,
DB: auth.dbname,
AuthPlugin: mysql.AuthTiDBSessionToken,
Attrs: auth.attrs,
AuthData: authData,
Capability: auth.capability | mysql.ClientSSL | mysql.ClientPluginAuth,
Collation: auth.collation,
AuthData: authData,
Capability: auth.capability | authCap,
AuthPlugin: authPlugin,
}
data := pnet.MakeHandshakeResponse(resp)

if tls && auth.backendTLSConfig != nil {
// write SSL req
if err := backendIO.WritePacket(data[:32], true); err != nil {
var pkt []byte
if backendCapability&pnet.ClientSSL != 0 && backendTLSConfig != nil {
pkt = pnet.MakeHandshakeResponse(resp)
resp.Capability |= mysql.ClientSSL
// write SSL Packet
if err := backendIO.WritePacket(pkt[:32], true); err != nil {
return err
}
// Send TLS / SSL request packet. The server must have supported TLS.
if err := backendIO.ClientTLSHandshake(auth.backendTLSConfig); err != nil {
tcfg := backendTLSConfig.Clone()
addr := backendIO.RemoteAddr().String()
host, _, err := net.SplitHostPort(addr)
if err == nil {
tcfg.ServerName = host
}
if err := backendIO.ClientTLSHandshake(tcfg); err != nil {
return err
}
} else {
pkt = pnet.MakeHandshakeResponse(resp)
}

// write handshake resp
return backendIO.WritePacket(data, true)
return backendIO.WritePacket(pkt, true)
}

func (auth *Authenticator) handleSecondAuthResult(backendIO *pnet.PacketIO) error {
Expand Down Expand Up @@ -353,23 +338,3 @@ func (auth *Authenticator) changeUser(username, db string) {
func (auth *Authenticator) updateCurrentDB(db string) {
auth.dbname = db
}

func (auth *Authenticator) ClientAddr() string {
return auth.clientAddr
}

func (auth *Authenticator) ServerAddr() string {
return auth.serverAddr
}

func (auth *Authenticator) SetValue(key, val any) {
auth.ctxmap.Store(key, val)
}

func (auth *Authenticator) Value(key any) any {
v, ok := auth.ctxmap.Load(key)
if !ok {
return nil
}
return v
}
Loading