Skip to content

Commit

Permalink
backend: refine the error message when require TLS (#359)
Browse files Browse the repository at this point in the history
  • Loading branch information
djshow832 authored Sep 5, 2023
1 parent 53a957f commit be8f39b
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 33 deletions.
11 changes: 4 additions & 7 deletions pkg/proxy/backend/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (

var (
ErrCapabilityNegotiation = errors.New("capability negotiation failed")
ErrTLSConfigRequired = errors.New("require TLS config on TiProxy when require-backend-tls=true")
)

const unknownAuthPlugin = "auth_unknown_plugin"
Expand Down Expand Up @@ -74,17 +73,15 @@ func (auth *Authenticator) writeProxyProtocol(clientIO, backendIO *pnet.PacketIO

func (auth *Authenticator) verifyBackendCaps(logger *zap.Logger, backendCapability pnet.Capability) error {
requiredBackendCaps := defRequiredBackendCaps & auth.capability
if auth.requireBackendTLS {
requiredBackendCaps |= pnet.ClientSSL
}

if commonCaps := backendCapability & requiredBackendCaps; commonCaps != requiredBackendCaps {
// The error cannot be sent to the client because the client only expects an initial handshake packet.
// The only way is to log it and disconnect.
logger.Error("require backend capabilities", zap.Stringer("common", commonCaps), zap.Stringer("required", requiredBackendCaps^commonCaps))
return errors.Wrapf(ErrCapabilityNegotiation, "require %s from backend", requiredBackendCaps^commonCaps)
}

if auth.requireBackendTLS && (backendCapability&pnet.ClientSSL == 0) {
return pnet.WrapUserError(errors.New("backend doesn't enable TLS"), requireTiDBTLSErrMsg)
}
return nil
}

Expand Down Expand Up @@ -315,7 +312,7 @@ func (auth *Authenticator) writeAuthHandshake(
var enableTLS bool
if auth.requireBackendTLS {
if backendTLSConfig == nil {
return ErrTLSConfigRequired
return pnet.WrapUserError(errors.New("tiproxy doesn't enable TLS"), requireProxyTLSErrMsg)
}
enableTLS = true
} else {
Expand Down
50 changes: 49 additions & 1 deletion pkg/proxy/backend/authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"testing"

"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tiproxy/lib/util/errors"
pnet "github.com/pingcap/tiproxy/pkg/proxy/net"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -220,7 +221,7 @@ func TestCustomAuth(t *testing.T) {
require.Equal(t, ts.mc.username, inUser)
require.Equal(t, reUser, ts.mb.username)
require.Equal(t, reAttrs, ts.mb.attrs)
require.Equal(t, reCap&pnet.ClientDeprecateEOF, pnet.Capability(ts.mb.capability)&pnet.ClientDeprecateEOF)
require.Equal(t, reCap&pnet.ClientDeprecateEOF, ts.mb.capability&pnet.ClientDeprecateEOF)
}
ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) {})
checker()
Expand Down Expand Up @@ -290,3 +291,50 @@ func TestAuthFail(t *testing.T) {
clean()
}
}

func TestRequireBackendTLS(t *testing.T) {
tests := []struct {
cfg cfgOverrider
errMsg string
}{
{
cfg: func(cfg *testConfig) {
cfg.proxyConfig.bcConfig.RequireBackendTLS = true
cfg.proxyConfig.backendTLSConfig = nil
cfg.backendConfig.capability |= pnet.ClientSSL
},
errMsg: requireProxyTLSErrMsg,
},
{
cfg: func(cfg *testConfig) {
cfg.proxyConfig.bcConfig.RequireBackendTLS = true
cfg.backendConfig.tlsConfig = nil
cfg.backendConfig.capability &= ^pnet.ClientSSL
},
errMsg: requireTiDBTLSErrMsg,
},
{
cfg: func(cfg *testConfig) {
cfg.proxyConfig.bcConfig.RequireBackendTLS = false
cfg.proxyConfig.backendTLSConfig = nil
cfg.backendConfig.tlsConfig = nil
cfg.backendConfig.capability &= ^pnet.ClientSSL
},
},
}

tc := newTCPConnSuite(t)
for _, tt := range tests {
ts, clean := newTestSuite(t, tc, tt.cfg)
ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) {
if len(tt.errMsg) > 0 {
var userError *pnet.UserError
require.True(t, errors.As(ts.mp.err, &userError))
require.Equal(t, tt.errMsg, userError.UserMsg())
} else {
require.NoError(t, ts.mp.err)
}
})
clean()
}
}
6 changes: 3 additions & 3 deletions pkg/proxy/backend/backend_conn_mgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -784,7 +784,7 @@ func TestOnTraffic(t *testing.T) {
0xce,
}
ts := newBackendMgrTester(t, func(config *testConfig) {
config.proxyConfig.checkBackendInterval = 10 * time.Millisecond
config.proxyConfig.bcConfig.CheckBackendInterval = 10 * time.Millisecond
config.proxyConfig.handler.onTraffic = func(cc ConnContext) {
require.Equal(t, uint64(inbytes[i]), cc.ClientInBytes())
require.Equal(t, uint64(outbytes[i]), cc.ClientOutBytes())
Expand Down Expand Up @@ -873,7 +873,7 @@ func TestGetBackendIO(t *testing.T) {

func TestBackendInactive(t *testing.T) {
ts := newBackendMgrTester(t, func(config *testConfig) {
config.proxyConfig.checkBackendInterval = 10 * time.Millisecond
config.proxyConfig.bcConfig.CheckBackendInterval = 10 * time.Millisecond
})
runners := []runner{
// 1st handshake
Expand Down Expand Up @@ -957,7 +957,7 @@ func TestBackendInactive(t *testing.T) {

func TestKeepAlive(t *testing.T) {
ts := newBackendMgrTester(t, func(config *testConfig) {
config.proxyConfig.checkBackendInterval = 10 * time.Millisecond
config.proxyConfig.bcConfig.CheckBackendInterval = 10 * time.Millisecond
})
runners := []runner{
{
Expand Down
10 changes: 6 additions & 4 deletions pkg/proxy/backend/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ import (
)

const (
connectErrMsg = "No available TiDB instances, please check TiDB cluster"
parsePktErrMsg = "TiProxy fails to parse the packet, please contact PingCAP"
handshakeErrMsg = "TiProxy fails to connect to TiDB, please check network"
capabilityErrMsg = "Verify TiDB capability failed, please upgrade TiDB"
connectErrMsg = "No available TiDB instances, please check TiDB cluster"
parsePktErrMsg = "TiProxy fails to parse the packet, please contact PingCAP"
handshakeErrMsg = "TiProxy fails to connect to TiDB, please check network"
capabilityErrMsg = "Verify TiDB capability failed, please upgrade TiDB"
requireProxyTLSErrMsg = "Require TLS config on TiProxy when require-backend-tls=true"
requireTiDBTLSErrMsg = "Require TLS config on TiDB when require-backend-tls=true"
)

var (
Expand Down
2 changes: 1 addition & 1 deletion pkg/proxy/backend/mock_backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func (mb *mockBackend) authenticate(packetIO *pnet.PacketIO) error {
mb.db = resp.DB
mb.authData = resp.AuthData
mb.attrs = resp.Attrs
mb.capability = pnet.Capability(resp.Capability)
mb.capability = resp.Capability
// verify password
return mb.verifyPassword(packetIO, resp)
}
Expand Down
32 changes: 15 additions & 17 deletions pkg/proxy/backend/mock_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,22 @@ import (
)

type proxyConfig struct {
frontendTLSConfig *tls.Config
backendTLSConfig *tls.Config
handler *CustomHandshakeHandler
checkBackendInterval time.Duration
sessionToken string
capability pnet.Capability
waitRedirect bool
connectionID uint64
frontendTLSConfig *tls.Config
backendTLSConfig *tls.Config
handler *CustomHandshakeHandler
bcConfig *BCConfig
sessionToken string
capability pnet.Capability
waitRedirect bool
connectionID uint64
}

func newProxyConfig() *proxyConfig {
return &proxyConfig{
handler: &CustomHandshakeHandler{},
capability: defaultTestBackendCapability,
sessionToken: mockToken,
checkBackendInterval: CheckBackendInterval,
handler: &CustomHandshakeHandler{},
capability: defaultTestBackendCapability,
sessionToken: mockToken,
bcConfig: &BCConfig{},
}
}

Expand All @@ -49,11 +49,9 @@ type mockProxy struct {
func newMockProxy(t *testing.T, cfg *proxyConfig) *mockProxy {
lg, _ := logger.CreateLoggerForTest(t)
mp := &mockProxy{
proxyConfig: cfg,
logger: lg.Named("mockProxy"),
BackendConnManager: NewBackendConnManager(lg, cfg.handler, cfg.connectionID, &BCConfig{
CheckBackendInterval: cfg.checkBackendInterval,
}),
proxyConfig: cfg,
logger: lg.Named("mockProxy"),
BackendConnManager: NewBackendConnManager(lg, cfg.handler, cfg.connectionID, cfg.bcConfig),
}
mp.cmdProcessor.capability = cfg.capability
return mp
Expand Down

0 comments on commit be8f39b

Please sign in to comment.