diff --git a/executor/coprocessor.go b/executor/coprocessor.go index 6eb438d5aaeb5..3811475fa9212 100644 --- a/executor/coprocessor.go +++ b/executor/coprocessor.go @@ -144,8 +144,8 @@ func (h *CoprocessorDAGHandler) buildDAGExecutor(req *coprocessor.Request) (Exec Username: dagReq.User.UserName, Hostname: dagReq.User.UserHost, } - authName, authHost, success := pm.GetAuthWithoutVerification(dagReq.User.UserName, dagReq.User.UserHost) - if success { + authName, authHost, success := pm.MatchIdentity(dagReq.User.UserName, dagReq.User.UserHost, false) + if success && pm.GetAuthWithoutVerification(authName, authHost) { h.sctx.GetSessionVars().User.AuthUsername = authName h.sctx.GetSessionVars().User.AuthHostname = authHost h.sctx.GetSessionVars().ActiveRoles = pm.GetDefaultRoles(authName, authHost) diff --git a/privilege/privilege.go b/privilege/privilege.go index e0b9d41f41b1d..af5ff9924ffe9 100644 --- a/privilege/privilege.go +++ b/privilege/privilege.go @@ -59,10 +59,15 @@ type Manager interface { RequestDynamicVerificationWithUser(privName string, grantable bool, user *auth.UserIdentity) bool // ConnectionVerification verifies user privilege for connection. - ConnectionVerification(user, host string, auth, salt []byte, tlsState *tls.ConnectionState) (string, string, bool) + // Requires exact match on user name and host name. + ConnectionVerification(user, host string, auth, salt []byte, tlsState *tls.ConnectionState) bool // GetAuthWithoutVerification uses to get auth name without verification. - GetAuthWithoutVerification(user, host string) (string, string, bool) + // Requires exact match on user name and host name. + GetAuthWithoutVerification(user, host string) bool + + // MatchIdentity matches an identity + MatchIdentity(user, host string, skipNameResolve bool) (string, string, bool) // DBIsVisible returns true is the database is visible to current user. DBIsVisible(activeRole []*auth.RoleIdentity, db string) bool diff --git a/privilege/privileges/cache.go b/privilege/privileges/cache.go index dc90170a500ef..4d388e073b205 100644 --- a/privilege/privileges/cache.go +++ b/privilege/privileges/cache.go @@ -31,6 +31,7 @@ import ( "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/terror" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/chunk" @@ -848,6 +849,9 @@ func decodeSetToPrivilege(s types.Set) mysql.PrivilegeType { // See https://dev.mysql.com/doc/refman/5.7/en/account-names.html func (record *baseRecord) hostMatch(s string) bool { if record.hostIPNet == nil { + if record.Host == "localhost" && net.ParseIP(s).IsLoopback() { + return true + } return false } ip := net.ParseIP(s).To4() @@ -890,14 +894,54 @@ func patternMatch(str string, patChars, patTypes []byte) bool { return stringutil.DoMatchBytes(str, patChars, patTypes) } -// connectionVerification verifies the connection have access to TiDB server. -func (p *MySQLPrivilege) connectionVerification(user, host string) *UserRecord { +// matchIdentity finds an identity to match a user + host +// using the correct rules according to MySQL. +func (p *MySQLPrivilege) matchIdentity(user, host string, skipNameResolve bool) *UserRecord { for i := 0; i < len(p.User); i++ { record := &p.User[i] if record.match(user, host) { return record } } + + // If skip-name resolve is not enabled, and the host is not localhost + // we can fallback and try to resolve with all addrs that match. + // TODO: this is imported from previous code in session.Auth(), and can be improved in future. + if !skipNameResolve && host != variable.DefHostname { + addrs, err := net.LookupAddr(host) + if err != nil { + logutil.BgLogger().Warn( + "net.LookupAddr returned an error during auth check", + zap.String("host", host), + zap.Error(err), + ) + return nil + } + for _, addr := range addrs { + for i := 0; i < len(p.User); i++ { + record := &p.User[i] + if record.match(user, addr) { + return record + } + } + } + } + return nil +} + +// connectionVerification verifies the username + hostname according to exact +// match from the mysql.user privilege table. call matchIdentity() first if you +// do not have an exact match yet. +func (p *MySQLPrivilege) connectionVerification(user, host string) *UserRecord { + records, exists := p.UserMap[user] + if exists { + for i := 0; i < len(records); i++ { + record := &records[i] + if record.Host == host { // exact match + return record + } + } + } return nil } diff --git a/privilege/privileges/privileges.go b/privilege/privileges/privileges.go index 104c2c3782387..fea22acef641c 100644 --- a/privilege/privileges/privileges.go +++ b/privilege/privileges/privileges.go @@ -256,8 +256,21 @@ func (p *UserPrivileges) GetAuthPlugin(user, host string) (string, error) { return "", errors.New("Failed to get plugin for user") } +// MatchIdentity implements the Manager interface. +func (p *UserPrivileges) MatchIdentity(user, host string, skipNameResolve bool) (u string, h string, success bool) { + if SkipWithGrant { + return user, host, true + } + mysqlPriv := p.Handle.Get() + record := mysqlPriv.matchIdentity(user, host, skipNameResolve) + if record != nil { + return record.User, record.Host, true + } + return "", "", false +} + // GetAuthWithoutVerification implements the Manager interface. -func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (u string, h string, success bool) { +func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (success bool) { if SkipWithGrant { p.user = user p.host = host @@ -273,16 +286,14 @@ func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (u string return } - u = record.User - h = record.Host p.user = user - p.host = h + p.host = record.Host success = true return } // ConnectionVerification implements the Manager interface. -func (p *UserPrivileges) ConnectionVerification(user, host string, authentication, salt []byte, tlsState *tls.ConnectionState) (u string, h string, success bool) { +func (p *UserPrivileges) ConnectionVerification(user, host string, authentication, salt []byte, tlsState *tls.ConnectionState) (success bool) { if SkipWithGrant { p.user = user p.host = host @@ -298,9 +309,6 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio return } - u = record.User - h = record.Host - globalPriv := mysqlPriv.matchGlobalPriv(user, host) if globalPriv != nil { if !p.checkSSL(globalPriv, tlsState) { @@ -328,7 +336,7 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio // empty password if len(pwd) == 0 && len(authentication) == 0 { p.user = user - p.host = h + p.host = record.Host success = true return } @@ -371,7 +379,7 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio } p.user = user - p.host = h + p.host = record.Host success = true return } diff --git a/server/conn.go b/server/conn.go index 49a42fe54bb94..26e80dc6b8656 100644 --- a/server/conn.go +++ b/server/conn.go @@ -711,6 +711,7 @@ func (cc *clientConn) handleAuthPlugin(ctx context.Context, resp *handshakeRespo newAuth, err := cc.checkAuthPlugin(ctx, &resp.AuthPlugin) if err != nil { logutil.Logger(ctx).Warn("failed to check the user authplugin", zap.Error(err)) + return err } if len(newAuth) > 0 { resp.Auth = newAuth @@ -858,8 +859,30 @@ func (cc *clientConn) checkAuthPlugin(ctx context.Context, authPlugin *string) ( if err != nil { return nil, err } +<<<<<<< HEAD if userplugin == mysql.AuthSocket { *authPlugin = mysql.AuthSocket +======= + // Find the identity of the user based on username and peer host. + identity, err := cc.ctx.MatchIdentity(cc.user, host) + if err != nil { + return nil, errAccessDenied.FastGenByArgs(cc.user, host, hasPassword) + } + // Get the plugin for the identity. + userplugin, err := cc.ctx.AuthPluginForUser(identity) + if err != nil { + logutil.Logger(ctx).Warn("Failed to get authentication method for user", + zap.String("user", cc.user), zap.String("host", host)) + } + failpoint.Inject("FakeUser", func(val failpoint.Value) { + userplugin = val.(string) + }) + if userplugin == mysql.AuthSocket { + if !cc.isUnixSocket { + return nil, errAccessDenied.FastGenByArgs(cc.user, host, hasPassword) + } + resp.AuthPlugin = mysql.AuthSocket +>>>>>>> 7fc6ebbda... privilege, session, server: consistently map user login to identity (#30204) user, err := user.LookupId(fmt.Sprint(cc.socketCredUID)) if err != nil { return nil, err diff --git a/server/conn_test.go b/server/conn_test.go index dc50900e41624..ae63ca72e94a1 100644 --- a/server/conn_test.go +++ b/server/conn_test.go @@ -906,19 +906,35 @@ func TestHandleAuthPlugin(t *testing.T) { srv, err := NewServer(cfg, drv) require.NoError(t, err) +<<<<<<< HEAD +======= + tk := testkit.NewTestKit(t, store) + tk.MustExec("CREATE USER unativepassword") + defer func() { + tk.MustExec("DROP USER unativepassword") + }() + + // 5.7 or newer client trying to authenticate with mysql_native_password +>>>>>>> 7fc6ebbda... privilege, session, server: consistently map user login to identity (#30204) cc := &clientConn{ connectionID: 1, alloc: arena.NewAllocator(1024), pkt: &packetIO{ bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), }, +<<<<<<< HEAD collation: mysql.DefaultCollationID, server: srv, user: "root", +======= + server: srv, + user: "unativepassword", +>>>>>>> 7fc6ebbda... privilege, session, server: consistently map user login to identity (#30204) } ctx := context.Background() resp := handshakeResponse41{ Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, +<<<<<<< HEAD } err = cc.handleAuthPlugin(ctx, &resp) require.NoError(t, err) @@ -926,4 +942,237 @@ func TestHandleAuthPlugin(t *testing.T) { resp.Capability = mysql.ClientProtocol41 err = cc.handleAuthPlugin(ctx, &resp) require.NoError(t, err) +======= + AuthPlugin: mysql.AuthNativePassword, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.NoError(t, err) + + // 8.0 or newer client trying to authenticate with caching_sha2_password + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)")) + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + chunkAlloc: chunk.NewAllocator(), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, + AuthPlugin: mysql.AuthCachingSha2Password, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.NoError(t, err) + require.Equal(t, resp.Auth, []byte(mysql.AuthNativePassword)) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) + + // MySQL 5.1 or older client, without authplugin support + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + chunkAlloc: chunk.NewAllocator(), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.NoError(t, err) + + // === Target account has mysql_native_password === + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeUser", "return(\"mysql_native_password\")")) + + // 5.7 or newer client trying to authenticate with mysql_native_password + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)")) + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + chunkAlloc: chunk.NewAllocator(), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, + AuthPlugin: mysql.AuthNativePassword, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.NoError(t, err) + require.Equal(t, []byte(mysql.AuthNativePassword), resp.Auth) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) + + // 8.0 or newer client trying to authenticate with caching_sha2_password + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)")) + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + chunkAlloc: chunk.NewAllocator(), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, + AuthPlugin: mysql.AuthCachingSha2Password, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.NoError(t, err) + require.Equal(t, []byte(mysql.AuthNativePassword), resp.Auth) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) + + // MySQL 5.1 or older client, without authplugin support + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + chunkAlloc: chunk.NewAllocator(), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.NoError(t, err) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeUser")) + + // === Target account has caching_sha2_password === + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeUser", "return(\"caching_sha2_password\")")) + + // 5.7 or newer client trying to authenticate with mysql_native_password + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)")) + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + chunkAlloc: chunk.NewAllocator(), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, + AuthPlugin: mysql.AuthNativePassword, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.NoError(t, err) + require.Equal(t, []byte(mysql.AuthCachingSha2Password), resp.Auth) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) + + // 8.0 or newer client trying to authenticate with caching_sha2_password + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)")) + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + chunkAlloc: chunk.NewAllocator(), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, + AuthPlugin: mysql.AuthCachingSha2Password, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.NoError(t, err) + require.Equal(t, []byte(mysql.AuthCachingSha2Password), resp.Auth) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) + + // MySQL 5.1 or older client, without authplugin support + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + chunkAlloc: chunk.NewAllocator(), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.Error(t, err) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeUser")) +} + +func TestAuthPlugin2(t *testing.T) { + + t.Parallel() + + store, clean := testkit.CreateMockStore(t) + defer clean() + + cfg := newTestConfig() + cfg.Socket = "" + cfg.Port = 0 + cfg.Status.StatusPort = 0 + + drv := NewTiDBDriver(store) + srv, err := NewServer(cfg, drv) + require.NoError(t, err) + + cc := &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + chunkAlloc: chunk.NewAllocator(), + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "root", + } + ctx := context.Background() + se, _ := session.CreateSession4Test(store) + tc := &TiDBContext{ + Session: se, + stmts: make(map[int]*TiDBStatement), + } + cc.ctx = tc + + resp := handshakeResponse41{ + Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, + } + + cc.isUnixSocket = true + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)")) + respAuthSwitch, err := cc.checkAuthPlugin(ctx, &resp) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) + require.Equal(t, respAuthSwitch, []byte(mysql.AuthNativePassword)) + require.NoError(t, err) + +>>>>>>> 7fc6ebbda... privilege, session, server: consistently map user login to identity (#30204) } diff --git a/session/session.go b/session/session.go index aa08d554c9f13..913faa3591aa9 100644 --- a/session/session.go +++ b/session/session.go @@ -24,7 +24,6 @@ import ( "crypto/tls" "encoding/json" "fmt" - "net" "runtime/pprof" "runtime/trace" "strconv" @@ -146,6 +145,7 @@ type Session interface { Auth(user *auth.UserIdentity, auth []byte, salt []byte) bool AuthWithoutVerification(user *auth.UserIdentity) bool AuthPluginForUser(user *auth.UserIdentity) (string, error) + MatchIdentity(username, remoteHost string) (*auth.UserIdentity, error) ShowProcess() *util.ProcessInfo // Return the information of the txn current running TxnInfo() *txninfo.TxnInfo @@ -2211,91 +2211,61 @@ func (s *session) AuthPluginForUser(user *auth.UserIdentity) (string, error) { return authplugin, nil } +// Auth validates a user using an authentication string and salt. +// If the password fails, it will keep trying other users until exhausted. +// This means it can not be refactored to use MatchIdentity yet. func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []byte) bool { pm := privilege.GetPrivilegeManager(s) - - // Check IP or localhost. - var success bool - user.AuthUsername, user.AuthHostname, success = pm.ConnectionVerification(user.Username, user.Hostname, authentication, salt, s.sessionVars.TLSConnectionState) - if success { + authUser, err := s.MatchIdentity(user.Username, user.Hostname) + if err != nil { + return false + } + if pm.ConnectionVerification(authUser.Username, authUser.Hostname, authentication, salt, s.sessionVars.TLSConnectionState) { + user.AuthUsername = authUser.Username + user.AuthHostname = authUser.Hostname s.sessionVars.User = user s.sessionVars.ActiveRoles = pm.GetDefaultRoles(user.AuthUsername, user.AuthHostname) return true - } else if user.Hostname == variable.DefHostname { - return false } + return false +} - // Check Hostname. - for _, addr := range s.getHostByIP(user.Hostname) { - u, h, success := pm.ConnectionVerification(user.Username, addr, authentication, salt, s.sessionVars.TLSConnectionState) - if success { - s.sessionVars.User = &auth.UserIdentity{ - Username: user.Username, - Hostname: addr, - AuthUsername: u, - AuthHostname: h, - } - s.sessionVars.ActiveRoles = pm.GetDefaultRoles(u, h) - return true - } +// MatchIdentity finds the matching username + password in the MySQL privilege tables +// for a username + hostname, since MySQL can have wildcards. +func (s *session) MatchIdentity(username, remoteHost string) (*auth.UserIdentity, error) { + pm := privilege.GetPrivilegeManager(s) + var success bool + var skipNameResolve bool + var user = &auth.UserIdentity{} + varVal, err := s.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.SkipNameResolve) + if err == nil && variable.TiDBOptOn(varVal) { + skipNameResolve = true } - return false + user.Username, user.Hostname, success = pm.MatchIdentity(username, remoteHost, skipNameResolve) + if success { + return user, nil + } + // This error will not be returned to the user, access denied will be instead + return nil, fmt.Errorf("could not find matching user in MatchIdentity: %s, %s", username, remoteHost) } // AuthWithoutVerification is required by the ResetConnection RPC func (s *session) AuthWithoutVerification(user *auth.UserIdentity) bool { pm := privilege.GetPrivilegeManager(s) - - // Check IP or localhost. - var success bool - user.AuthUsername, user.AuthHostname, success = pm.GetAuthWithoutVerification(user.Username, user.Hostname) - if success { + authUser, err := s.MatchIdentity(user.Username, user.Hostname) + if err != nil { + return false + } + if pm.GetAuthWithoutVerification(authUser.Username, authUser.Hostname) { + user.AuthUsername = authUser.Username + user.AuthHostname = authUser.Hostname s.sessionVars.User = user s.sessionVars.ActiveRoles = pm.GetDefaultRoles(user.AuthUsername, user.AuthHostname) return true - } else if user.Hostname == variable.DefHostname { - return false - } - - // Check Hostname. - for _, addr := range s.getHostByIP(user.Hostname) { - u, h, success := pm.GetAuthWithoutVerification(user.Username, addr) - if success { - s.sessionVars.User = &auth.UserIdentity{ - Username: user.Username, - Hostname: addr, - AuthUsername: u, - AuthHostname: h, - } - s.sessionVars.ActiveRoles = pm.GetDefaultRoles(u, h) - return true - } } return false } -func (s *session) getHostByIP(ip string) []string { - if ip == "127.0.0.1" { - return []string{variable.DefHostname} - } - skipNameResolve, err := s.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.SkipNameResolve) - if err == nil && variable.TiDBOptOn(skipNameResolve) { - return []string{ip} // user wants to skip name resolution - } - addrs, err := net.LookupAddr(ip) - if err != nil { - // These messages can be noisy. - // See: https://github.com/pingcap/tidb/pull/13989 - logutil.BgLogger().Debug( - "net.LookupAddr returned an error during auth check", - zap.String("ip", ip), - zap.Error(err), - ) - return []string{ip} - } - return addrs -} - // RefreshVars implements the sessionctx.Context interface. func (s *session) RefreshVars(ctx context.Context) error { pruneMode, err := s.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.TiDBPartitionPruneMode) diff --git a/session/session_test.go b/session/session_test.go index b2b386fefbd61..b4aadd0306bc1 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -18,6 +18,7 @@ import ( "context" "flag" "fmt" + "net" "os" "path" "runtime" @@ -691,6 +692,50 @@ func (s *testSessionSuite) TestGlobalVarAccessor(c *C) { c.Assert(terror.ErrorEqual(err, variable.ErrUnknownTimeZone), IsTrue) } +func (s *testSessionSuite) TestMatchIdentity(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("CREATE USER `useridentity`@`%`") + tk.MustExec("CREATE USER `useridentity`@`localhost`") + tk.MustExec("CREATE USER `useridentity`@`192.168.1.1`") + tk.MustExec("CREATE USER `useridentity`@`example.com`") + + // The MySQL matching rule is most specific to least specific. + // So if I log in from 192.168.1.1 I should match that entry always. + identity, err := tk.Se.MatchIdentity("useridentity", "192.168.1.1") + c.Assert(err, IsNil) + c.Assert(identity.Username, Equals, "useridentity") + c.Assert(identity.Hostname, Equals, "192.168.1.1") + + // If I log in from localhost, I should match localhost + identity, err = tk.Se.MatchIdentity("useridentity", "localhost") + c.Assert(err, IsNil) + c.Assert(identity.Username, Equals, "useridentity") + c.Assert(identity.Hostname, Equals, "localhost") + + // If I log in from 192.168.1.2 I should match wildcard. + identity, err = tk.Se.MatchIdentity("useridentity", "192.168.1.2") + c.Assert(err, IsNil) + c.Assert(identity.Username, Equals, "useridentity") + c.Assert(identity.Hostname, Equals, "%") + + identity, err = tk.Se.MatchIdentity("useridentity", "127.0.0.1") + c.Assert(err, IsNil) + c.Assert(identity.Username, Equals, "useridentity") + c.Assert(identity.Hostname, Equals, "localhost") + + // This uses the lookup of example.com to get an IP address. + // We then login with that IP address, but expect it to match the example.com + // entry in the privileges table (by reverse lookup). + ips, err := net.LookupHost("example.com") + c.Assert(err, IsNil) + identity, err = tk.Se.MatchIdentity("useridentity", ips[0]) + c.Assert(err, IsNil) + c.Assert(identity.Username, Equals, "useridentity") + // FIXME: we *should* match example.com instead + // as long as skip-name-resolve is not set (DEFAULT) + c.Assert(identity.Hostname, Equals, "%") +} + func (s *testSessionSuite) TestGetSysVariables(c *C) { tk := testkit.NewTestKitWithInit(c, s.store)