diff --git a/pkg/proxy/backend/authenticator_test.go b/pkg/proxy/backend/authenticator_test.go index 774c8493..bb354af0 100644 --- a/pkg/proxy/backend/authenticator_test.go +++ b/pkg/proxy/backend/authenticator_test.go @@ -27,50 +27,50 @@ func TestUnsupportedCapability(t *testing.T) { cfgs := [][]cfgOverrider{ { func(cfg *testConfig) { - cfg.clientConfig.capability = defaultTestClientCapability & ^pnet.ClientSSL + cfg.clientConfig.capability &= ^pnet.ClientSSL }, func(cfg *testConfig) { - cfg.clientConfig.capability = defaultTestClientCapability | pnet.ClientSSL + cfg.clientConfig.capability |= pnet.ClientSSL }, }, { func(cfg *testConfig) { - cfg.backendConfig.capability = defaultTestBackendCapability & ^pnet.ClientSSL + cfg.backendConfig.capability &= ^pnet.ClientSSL }, func(cfg *testConfig) { - cfg.backendConfig.capability = defaultTestBackendCapability | pnet.ClientSSL + cfg.backendConfig.capability |= pnet.ClientSSL }, }, { func(cfg *testConfig) { - cfg.backendConfig.capability = defaultTestBackendCapability & ^pnet.ClientDeprecateEOF + cfg.backendConfig.capability &= ^pnet.ClientDeprecateEOF }, func(cfg *testConfig) { - cfg.backendConfig.capability = defaultTestBackendCapability | pnet.ClientDeprecateEOF + cfg.backendConfig.capability |= pnet.ClientDeprecateEOF }, }, { func(cfg *testConfig) { - cfg.clientConfig.capability = defaultTestClientCapability & ^pnet.ClientProtocol41 + cfg.clientConfig.capability &= ^pnet.ClientProtocol41 }, func(cfg *testConfig) { - cfg.clientConfig.capability = defaultTestClientCapability | pnet.ClientProtocol41 + cfg.clientConfig.capability |= pnet.ClientProtocol41 }, }, { func(cfg *testConfig) { - cfg.backendConfig.capability = defaultTestClientCapability & ^pnet.ClientPSMultiResults + cfg.backendConfig.capability &= ^pnet.ClientPSMultiResults }, func(cfg *testConfig) { - cfg.backendConfig.capability = defaultTestClientCapability | pnet.ClientPSMultiResults + cfg.backendConfig.capability |= pnet.ClientPSMultiResults }, }, { func(cfg *testConfig) { - cfg.clientConfig.capability = defaultTestClientCapability & ^pnet.ClientPSMultiResults + cfg.clientConfig.capability &= ^pnet.ClientPSMultiResults }, func(cfg *testConfig) { - cfg.clientConfig.capability = defaultTestClientCapability | pnet.ClientPSMultiResults + cfg.clientConfig.capability |= pnet.ClientPSMultiResults }, }, } @@ -151,27 +151,27 @@ func TestCapability(t *testing.T) { cfgs := [][]cfgOverrider{ { func(cfg *testConfig) { - cfg.clientConfig.capability = defaultTestClientCapability & ^pnet.ClientConnectWithDB + cfg.clientConfig.capability &= ^pnet.ClientConnectWithDB }, func(cfg *testConfig) { - cfg.clientConfig.capability = defaultTestClientCapability | pnet.ClientConnectWithDB + cfg.clientConfig.capability |= pnet.ClientConnectWithDB }, }, { func(cfg *testConfig) { - cfg.clientConfig.capability = defaultTestClientCapability & ^pnet.ClientConnectAttrs + cfg.clientConfig.capability &= ^pnet.ClientConnectAttrs }, func(cfg *testConfig) { - cfg.clientConfig.capability = defaultTestClientCapability | pnet.ClientConnectAttrs + cfg.clientConfig.capability |= pnet.ClientConnectAttrs cfg.clientConfig.attrs = map[string]string{"key": "value"} }, }, { func(cfg *testConfig) { - cfg.clientConfig.capability = defaultTestClientCapability & ^pnet.ClientSecureConnection + cfg.clientConfig.capability &= ^pnet.ClientSecureConnection }, func(cfg *testConfig) { - cfg.clientConfig.capability = defaultTestClientCapability | pnet.ClientSecureConnection + cfg.clientConfig.capability |= pnet.ClientSecureConnection }, }, } diff --git a/pkg/proxy/backend/testsuite_test.go b/pkg/proxy/backend/testsuite_test.go index c55528a9..477ad15a 100644 --- a/pkg/proxy/backend/testsuite_test.go +++ b/pkg/proxy/backend/testsuite_test.go @@ -77,7 +77,10 @@ func getCfgCombinations(cfgs [][]cfgOverrider) [][]cfgOverrider { // Append the cfg to each of the existing overrider list. for _, cfg := range cfgList { for _, o := range cfgOverriders { - newOverriders = append(newOverriders, append(o, cfg)) + newOverrider := make([]cfgOverrider, 0, len(o)+1) + newOverrider = append(newOverrider, o...) + newOverrider = append(newOverrider, cfg) + newOverriders = append(newOverriders, newOverrider) } } cfgOverriders = newOverriders @@ -187,9 +190,13 @@ func (ts *testSuite) authenticateFirstTime(t *testing.T, c checker) { // Check the data received by client equals to the data sent from the server and vice versa. require.Equal(t, ts.mb.authSucceed, ts.mc.authSucceed) require.Equal(t, ts.mc.username, ts.mb.username) - require.Equal(t, ts.mc.dbName, ts.mb.db) + if ts.mc.capability&pnet.ClientConnectWithDB > 0 { + require.Equal(t, ts.mc.dbName, ts.mb.db) + } require.Equal(t, ts.mc.authData, ts.mb.authData) - require.Equal(t, ts.mc.attrs, ts.mb.attrs) + if ts.mc.capability&pnet.ClientConnectAttrs > 0 { + require.Equal(t, ts.mc.attrs, ts.mb.attrs) + } } }