From 34a6ec9a0bb9c4ead4a318e88d35d485278a156e Mon Sep 17 00:00:00 2001 From: Morgan Tocker Date: Fri, 26 Nov 2021 14:58:46 -0700 Subject: [PATCH 1/4] server, session: fix socket auth --- server/conn.go | 12 +++++++++++- server/conn_test.go | 7 +++++-- session/session.go | 22 ++++++++++++++++++++++ 3 files changed, 38 insertions(+), 3 deletions(-) diff --git a/server/conn.go b/server/conn.go index f74c00a5550cc..e703cc0a4e205 100644 --- a/server/conn.go +++ b/server/conn.go @@ -712,6 +712,7 @@ func (cc *clientConn) handleAuthPlugin(ctx context.Context, resp *handshakeRespo newAuth, err := cc.checkAuthPlugin(ctx, resp) if err != nil { logutil.Logger(ctx).Warn("failed to check the user authplugin", zap.Error(err)) + return err } if len(newAuth) > 0 { resp.Auth = newAuth @@ -850,11 +851,20 @@ func (cc *clientConn) checkAuthPlugin(ctx context.Context, resp *handshakeRespon if err != nil { return nil, err } - userplugin, err := cc.ctx.AuthPluginForUser(&auth.UserIdentity{Username: cc.user, Hostname: host}) + // Find the identity of the user based on username and peer host. + identity, err := cc.ctx.MatchIdentity(cc.user, host) + if err != nil { + return nil, errAccessDenied.FastGenByArgs(cc.user, host, hasPassword) + } + // Get the plugin for the identity. + userplugin, err := cc.ctx.AuthPluginForUser(identity) if err != nil { return nil, err } if userplugin == mysql.AuthSocket { + if !cc.isUnixSocket { + return nil, errAccessDenied.FastGenByArgs(cc.user, host, hasPassword) + } resp.AuthPlugin = mysql.AuthSocket user, err := user.LookupId(fmt.Sprint(cc.socketCredUID)) if err != nil { diff --git a/server/conn_test.go b/server/conn_test.go index 34ec75124a828..7bdf588f0af73 100644 --- a/server/conn_test.go +++ b/server/conn_test.go @@ -930,14 +930,17 @@ func TestHandleAuthPlugin(t *testing.T) { pkt: &packetIO{ bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), }, - server: srv, + server: srv, + collation: mysql.DefaultCollationID, + peerHost: "127.0.0.1", } ctx := context.Background() resp := handshakeResponse41{ Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, } err = cc.handleAuthPlugin(ctx, &resp) - require.NoError(t, err) + require.Error(t, err) // Does not have username or password + require.Equal(t, errAccessDenied.FastGenByArgs("", "127.0.0.1", "NO"), err) resp.Capability = mysql.ClientProtocol41 err = cc.handleAuthPlugin(ctx, &resp) diff --git a/session/session.go b/session/session.go index 2228431007048..ab44f7144cd49 100644 --- a/session/session.go +++ b/session/session.go @@ -146,6 +146,7 @@ type Session interface { Auth(user *auth.UserIdentity, auth []byte, salt []byte) bool AuthWithoutVerification(user *auth.UserIdentity) bool AuthPluginForUser(user *auth.UserIdentity) (string, error) + MatchIdentity(username, remoteHost string) (*auth.UserIdentity, error) ShowProcess() *util.ProcessInfo // Return the information of the txn current running TxnInfo() *txninfo.TxnInfo @@ -2254,6 +2255,27 @@ func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []by return false } +// MatchIdentity finds the matching username + password in the MySQL privilege tables +// for a username + hostname, since MySQL can have wildcards. +func (s *session) MatchIdentity(username, remoteHost string) (*auth.UserIdentity, error) { + pm := privilege.GetPrivilegeManager(s) + var success bool + var user = &auth.UserIdentity{} + user.Username, user.Hostname, success = pm.GetAuthWithoutVerification(username, remoteHost) + if success { + return user, nil + } + // Check Hosts + for _, addr := range s.getHostByIP(remoteHost) { + user.Username, user.Hostname, success = pm.GetAuthWithoutVerification(username, addr) + if success { + return user, nil + } + } + // This error will not be returned to the user, access denied will be instead + return nil, fmt.Errorf("could not find matching user in MatchIdentity: %s, %s", username, remoteHost) +} + // AuthWithoutVerification is required by the ResetConnection RPC func (s *session) AuthWithoutVerification(user *auth.UserIdentity) bool { pm := privilege.GetPrivilegeManager(s) From a72b73cd7abe280c687a70ba915a14bf6735b165 Mon Sep 17 00:00:00 2001 From: Morgan Tocker Date: Fri, 26 Nov 2021 15:21:38 -0700 Subject: [PATCH 2/4] Add small refactor --- session/session.go | 35 ++++++++++------------------------- 1 file changed, 10 insertions(+), 25 deletions(-) diff --git a/session/session.go b/session/session.go index ab44f7144cd49..42a55e9d2f44e 100644 --- a/session/session.go +++ b/session/session.go @@ -2224,6 +2224,9 @@ func (s *session) AuthPluginForUser(user *auth.UserIdentity) (string, error) { return authplugin, nil } +// Auth validates a user using an authentication string and salt. +// If the password fails, it will keep trying other users until exhausted. +// This means it can not be refactored to use MatchIdentity yet. func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []byte) bool { pm := privilege.GetPrivilegeManager(s) @@ -2279,33 +2282,15 @@ func (s *session) MatchIdentity(username, remoteHost string) (*auth.UserIdentity // AuthWithoutVerification is required by the ResetConnection RPC func (s *session) AuthWithoutVerification(user *auth.UserIdentity) bool { pm := privilege.GetPrivilegeManager(s) - - // Check IP or localhost. - var success bool - user.AuthUsername, user.AuthHostname, success = pm.GetAuthWithoutVerification(user.Username, user.Hostname) - if success { - s.sessionVars.User = user - s.sessionVars.ActiveRoles = pm.GetDefaultRoles(user.AuthUsername, user.AuthHostname) - return true - } else if user.Hostname == variable.DefHostname { + authUser, err := s.MatchIdentity(user.Username, user.Hostname) + if err != nil { return false } - - // Check Hostname. - for _, addr := range s.getHostByIP(user.Hostname) { - u, h, success := pm.GetAuthWithoutVerification(user.Username, addr) - if success { - s.sessionVars.User = &auth.UserIdentity{ - Username: user.Username, - Hostname: addr, - AuthUsername: u, - AuthHostname: h, - } - s.sessionVars.ActiveRoles = pm.GetDefaultRoles(u, h) - return true - } - } - return false + user.AuthUsername = authUser.Username + user.AuthHostname = authUser.Hostname + s.sessionVars.User = user + s.sessionVars.ActiveRoles = pm.GetDefaultRoles(user.AuthUsername, user.AuthHostname) + return true } func (s *session) getHostByIP(ip string) []string { From 8ce2c8013e021e9f4554026b7bfe59e7e8ca39b1 Mon Sep 17 00:00:00 2001 From: Morgan Tocker Date: Fri, 26 Nov 2021 19:50:47 -0700 Subject: [PATCH 3/4] Add test for MatchIdentity --- session/session_test.go | 46 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/session/session_test.go b/session/session_test.go index 3057ef6ae9ab4..7ead8c1ed926e 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -18,6 +18,7 @@ import ( "context" "flag" "fmt" + "net" "os" "path" "runtime" @@ -700,6 +701,51 @@ func (s *testSessionSuite) TestGlobalVarAccessor(c *C) { c.Assert(v, Equals, "OFF") } +func (s *testSessionSuite) TestMatchIdentity(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("CREATE USER `useridentity`@`%`") + tk.MustExec("CREATE USER `useridentity`@`localhost`") + tk.MustExec("CREATE USER `useridentity`@`192.168.1.1`") + tk.MustExec("CREATE USER `useridentity`@`example.com`") + + // The MySQL matching rule is most specific to least specific. + // So if I log in from 192.168.1.1 I should match that entry always. + identity, err := tk.Se.MatchIdentity("useridentity", "192.168.1.1") + c.Assert(err, IsNil) + c.Assert(identity.Username, Equals, "useridentity") + c.Assert(identity.Hostname, Equals, "192.168.1.1") + + // If I log in from localhost, I should match localhost + identity, err = tk.Se.MatchIdentity("useridentity", "localhost") + c.Assert(err, IsNil) + c.Assert(identity.Username, Equals, "useridentity") + c.Assert(identity.Hostname, Equals, "localhost") + + // If I log in from 192.168.1.2 I should match wildcard. + identity, err = tk.Se.MatchIdentity("useridentity", "192.168.1.2") + c.Assert(err, IsNil) + c.Assert(identity.Username, Equals, "useridentity") + c.Assert(identity.Hostname, Equals, "%") + + identity, err = tk.Se.MatchIdentity("useridentity", "127.0.0.1") + c.Assert(err, IsNil) + c.Assert(identity.Username, Equals, "useridentity") + // FIXME: we *should* match localhost instead + c.Assert(identity.Hostname, Equals, "%") + + // This uses the lookup of example.com to get an IP address. + // We then login with that IP address, but expect it to match the example.com + // entry in the privileges table (by reverse lookup). + ips, err := net.LookupHost("example.com") + c.Assert(err, IsNil) + identity, err = tk.Se.MatchIdentity("useridentity", ips[0]) + c.Assert(err, IsNil) + c.Assert(identity.Username, Equals, "useridentity") + // FIXME: we *should* match example.com instead + // as long as skip-name-resolve is not set (DEFAULT) + c.Assert(identity.Hostname, Equals, "%") +} + func (s *testSessionSuite) TestGetSysVariables(c *C) { tk := testkit.NewTestKitWithInit(c, s.store) From 454a6f52286f8e06a156aa5318f37dc2aa449d7c Mon Sep 17 00:00:00 2001 From: Morgan Tocker Date: Fri, 26 Nov 2021 22:07:29 -0700 Subject: [PATCH 4/4] Refactor code to improve MatchIdentity consistency --- executor/coprocessor.go | 4 +- privilege/privilege.go | 9 +++- privilege/privileges/cache.go | 48 ++++++++++++++++++- privilege/privileges/privileges.go | 28 +++++++---- session/session.go | 77 ++++++++---------------------- session/session_test.go | 3 +- 6 files changed, 94 insertions(+), 75 deletions(-) diff --git a/executor/coprocessor.go b/executor/coprocessor.go index 4d8595640d6fd..8970629c9f4b8 100644 --- a/executor/coprocessor.go +++ b/executor/coprocessor.go @@ -144,8 +144,8 @@ func (h *CoprocessorDAGHandler) buildDAGExecutor(req *coprocessor.Request) (Exec Username: dagReq.User.UserName, Hostname: dagReq.User.UserHost, } - authName, authHost, success := pm.GetAuthWithoutVerification(dagReq.User.UserName, dagReq.User.UserHost) - if success { + authName, authHost, success := pm.MatchIdentity(dagReq.User.UserName, dagReq.User.UserHost, false) + if success && pm.GetAuthWithoutVerification(authName, authHost) { h.sctx.GetSessionVars().User.AuthUsername = authName h.sctx.GetSessionVars().User.AuthHostname = authHost h.sctx.GetSessionVars().ActiveRoles = pm.GetDefaultRoles(authName, authHost) diff --git a/privilege/privilege.go b/privilege/privilege.go index e0b9d41f41b1d..af5ff9924ffe9 100644 --- a/privilege/privilege.go +++ b/privilege/privilege.go @@ -59,10 +59,15 @@ type Manager interface { RequestDynamicVerificationWithUser(privName string, grantable bool, user *auth.UserIdentity) bool // ConnectionVerification verifies user privilege for connection. - ConnectionVerification(user, host string, auth, salt []byte, tlsState *tls.ConnectionState) (string, string, bool) + // Requires exact match on user name and host name. + ConnectionVerification(user, host string, auth, salt []byte, tlsState *tls.ConnectionState) bool // GetAuthWithoutVerification uses to get auth name without verification. - GetAuthWithoutVerification(user, host string) (string, string, bool) + // Requires exact match on user name and host name. + GetAuthWithoutVerification(user, host string) bool + + // MatchIdentity matches an identity + MatchIdentity(user, host string, skipNameResolve bool) (string, string, bool) // DBIsVisible returns true is the database is visible to current user. DBIsVisible(activeRole []*auth.RoleIdentity, db string) bool diff --git a/privilege/privileges/cache.go b/privilege/privileges/cache.go index c306443457caa..26e06ec85769b 100644 --- a/privilege/privileges/cache.go +++ b/privilege/privileges/cache.go @@ -31,6 +31,7 @@ import ( "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/terror" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/chunk" @@ -860,6 +861,9 @@ func decodeSetToPrivilege(s types.Set) mysql.PrivilegeType { // See https://dev.mysql.com/doc/refman/5.7/en/account-names.html func (record *baseRecord) hostMatch(s string) bool { if record.hostIPNet == nil { + if record.Host == "localhost" && net.ParseIP(s).IsLoopback() { + return true + } return false } ip := net.ParseIP(s).To4() @@ -902,14 +906,54 @@ func patternMatch(str string, patChars, patTypes []byte) bool { return stringutil.DoMatchBytes(str, patChars, patTypes) } -// connectionVerification verifies the connection have access to TiDB server. -func (p *MySQLPrivilege) connectionVerification(user, host string) *UserRecord { +// matchIdentity finds an identity to match a user + host +// using the correct rules according to MySQL. +func (p *MySQLPrivilege) matchIdentity(user, host string, skipNameResolve bool) *UserRecord { for i := 0; i < len(p.User); i++ { record := &p.User[i] if record.match(user, host) { return record } } + + // If skip-name resolve is not enabled, and the host is not localhost + // we can fallback and try to resolve with all addrs that match. + // TODO: this is imported from previous code in session.Auth(), and can be improved in future. + if !skipNameResolve && host != variable.DefHostname { + addrs, err := net.LookupAddr(host) + if err != nil { + logutil.BgLogger().Warn( + "net.LookupAddr returned an error during auth check", + zap.String("host", host), + zap.Error(err), + ) + return nil + } + for _, addr := range addrs { + for i := 0; i < len(p.User); i++ { + record := &p.User[i] + if record.match(user, addr) { + return record + } + } + } + } + return nil +} + +// connectionVerification verifies the username + hostname according to exact +// match from the mysql.user privilege table. call matchIdentity() first if you +// do not have an exact match yet. +func (p *MySQLPrivilege) connectionVerification(user, host string) *UserRecord { + records, exists := p.UserMap[user] + if exists { + for i := 0; i < len(records); i++ { + record := &records[i] + if record.Host == host { // exact match + return record + } + } + } return nil } diff --git a/privilege/privileges/privileges.go b/privilege/privileges/privileges.go index ea7059d72804e..7b499bdd64100 100644 --- a/privilege/privileges/privileges.go +++ b/privilege/privileges/privileges.go @@ -256,8 +256,21 @@ func (p *UserPrivileges) GetAuthPlugin(user, host string) (string, error) { return "", errors.New("Failed to get plugin for user") } +// MatchIdentity implements the Manager interface. +func (p *UserPrivileges) MatchIdentity(user, host string, skipNameResolve bool) (u string, h string, success bool) { + if SkipWithGrant { + return user, host, true + } + mysqlPriv := p.Handle.Get() + record := mysqlPriv.matchIdentity(user, host, skipNameResolve) + if record != nil { + return record.User, record.Host, true + } + return "", "", false +} + // GetAuthWithoutVerification implements the Manager interface. -func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (u string, h string, success bool) { +func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (success bool) { if SkipWithGrant { p.user = user p.host = host @@ -273,16 +286,14 @@ func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (u string return } - u = record.User - h = record.Host p.user = user - p.host = h + p.host = record.Host success = true return } // ConnectionVerification implements the Manager interface. -func (p *UserPrivileges) ConnectionVerification(user, host string, authentication, salt []byte, tlsState *tls.ConnectionState) (u string, h string, success bool) { +func (p *UserPrivileges) ConnectionVerification(user, host string, authentication, salt []byte, tlsState *tls.ConnectionState) (success bool) { if SkipWithGrant { p.user = user p.host = host @@ -298,9 +309,6 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio return } - u = record.User - h = record.Host - globalPriv := mysqlPriv.matchGlobalPriv(user, host) if globalPriv != nil { if !p.checkSSL(globalPriv, tlsState) { @@ -328,7 +336,7 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio // empty password if len(pwd) == 0 && len(authentication) == 0 { p.user = user - p.host = h + p.host = record.Host success = true return } @@ -371,7 +379,7 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio } p.user = user - p.host = h + p.host = record.Host success = true return } diff --git a/session/session.go b/session/session.go index 42a55e9d2f44e..2f7ea477c13f7 100644 --- a/session/session.go +++ b/session/session.go @@ -24,7 +24,6 @@ import ( "crypto/tls" "encoding/json" "fmt" - "net" "runtime/pprof" "runtime/trace" "strconv" @@ -2229,31 +2228,16 @@ func (s *session) AuthPluginForUser(user *auth.UserIdentity) (string, error) { // This means it can not be refactored to use MatchIdentity yet. func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []byte) bool { pm := privilege.GetPrivilegeManager(s) - - // Check IP or localhost. - var success bool - user.AuthUsername, user.AuthHostname, success = pm.ConnectionVerification(user.Username, user.Hostname, authentication, salt, s.sessionVars.TLSConnectionState) - if success { + authUser, err := s.MatchIdentity(user.Username, user.Hostname) + if err != nil { + return false + } + if pm.ConnectionVerification(authUser.Username, authUser.Hostname, authentication, salt, s.sessionVars.TLSConnectionState) { + user.AuthUsername = authUser.Username + user.AuthHostname = authUser.Hostname s.sessionVars.User = user s.sessionVars.ActiveRoles = pm.GetDefaultRoles(user.AuthUsername, user.AuthHostname) return true - } else if user.Hostname == variable.DefHostname { - return false - } - - // Check Hostname. - for _, addr := range s.getHostByIP(user.Hostname) { - u, h, success := pm.ConnectionVerification(user.Username, addr, authentication, salt, s.sessionVars.TLSConnectionState) - if success { - s.sessionVars.User = &auth.UserIdentity{ - Username: user.Username, - Hostname: addr, - AuthUsername: u, - AuthHostname: h, - } - s.sessionVars.ActiveRoles = pm.GetDefaultRoles(u, h) - return true - } } return false } @@ -2263,18 +2247,16 @@ func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []by func (s *session) MatchIdentity(username, remoteHost string) (*auth.UserIdentity, error) { pm := privilege.GetPrivilegeManager(s) var success bool + var skipNameResolve bool var user = &auth.UserIdentity{} - user.Username, user.Hostname, success = pm.GetAuthWithoutVerification(username, remoteHost) + varVal, err := s.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.SkipNameResolve) + if err == nil && variable.TiDBOptOn(varVal) { + skipNameResolve = true + } + user.Username, user.Hostname, success = pm.MatchIdentity(username, remoteHost, skipNameResolve) if success { return user, nil } - // Check Hosts - for _, addr := range s.getHostByIP(remoteHost) { - user.Username, user.Hostname, success = pm.GetAuthWithoutVerification(username, addr) - if success { - return user, nil - } - } // This error will not be returned to the user, access denied will be instead return nil, fmt.Errorf("could not find matching user in MatchIdentity: %s, %s", username, remoteHost) } @@ -2286,33 +2268,14 @@ func (s *session) AuthWithoutVerification(user *auth.UserIdentity) bool { if err != nil { return false } - user.AuthUsername = authUser.Username - user.AuthHostname = authUser.Hostname - s.sessionVars.User = user - s.sessionVars.ActiveRoles = pm.GetDefaultRoles(user.AuthUsername, user.AuthHostname) - return true -} - -func (s *session) getHostByIP(ip string) []string { - if ip == "127.0.0.1" { - return []string{variable.DefHostname} - } - skipNameResolve, err := s.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.SkipNameResolve) - if err == nil && variable.TiDBOptOn(skipNameResolve) { - return []string{ip} // user wants to skip name resolution - } - addrs, err := net.LookupAddr(ip) - if err != nil { - // These messages can be noisy. - // See: https://github.com/pingcap/tidb/pull/13989 - logutil.BgLogger().Debug( - "net.LookupAddr returned an error during auth check", - zap.String("ip", ip), - zap.Error(err), - ) - return []string{ip} + if pm.GetAuthWithoutVerification(authUser.Username, authUser.Hostname) { + user.AuthUsername = authUser.Username + user.AuthHostname = authUser.Hostname + s.sessionVars.User = user + s.sessionVars.ActiveRoles = pm.GetDefaultRoles(user.AuthUsername, user.AuthHostname) + return true } - return addrs + return false } // RefreshVars implements the sessionctx.Context interface. diff --git a/session/session_test.go b/session/session_test.go index 7ead8c1ed926e..f8897e22f8e2a 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -730,8 +730,7 @@ func (s *testSessionSuite) TestMatchIdentity(c *C) { identity, err = tk.Se.MatchIdentity("useridentity", "127.0.0.1") c.Assert(err, IsNil) c.Assert(identity.Username, Equals, "useridentity") - // FIXME: we *should* match localhost instead - c.Assert(identity.Hostname, Equals, "%") + c.Assert(identity.Hostname, Equals, "localhost") // This uses the lookup of example.com to get an IP address. // We then login with that IP address, but expect it to match the example.com