Skip to content

Commit

Permalink
privilege, session, server: consistently map user login to identity (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ti-srebot authored Feb 21, 2022
1 parent 7175897 commit 27ffd11
Show file tree
Hide file tree
Showing 11 changed files with 395 additions and 116 deletions.
4 changes: 2 additions & 2 deletions executor/coprocessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ func (h *CoprocessorDAGHandler) buildDAGExecutor(req *coprocessor.Request) (Exec
Username: dagReq.User.UserName,
Hostname: dagReq.User.UserHost,
}
authName, authHost, success := pm.GetAuthWithoutVerification(dagReq.User.UserName, dagReq.User.UserHost)
if success {
authName, authHost, success := pm.MatchIdentity(dagReq.User.UserName, dagReq.User.UserHost, false)
if success && pm.GetAuthWithoutVerification(authName, authHost) {
h.sctx.GetSessionVars().User.AuthUsername = authName
h.sctx.GetSessionVars().User.AuthHostname = authHost
h.sctx.GetSessionVars().ActiveRoles = pm.GetDefaultRoles(authName, authHost)
Expand Down
9 changes: 7 additions & 2 deletions privilege/privilege.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,15 @@ type Manager interface {
RequestDynamicVerificationWithUser(privName string, grantable bool, user *auth.UserIdentity) bool

// ConnectionVerification verifies user privilege for connection.
ConnectionVerification(user, host string, auth, salt []byte, tlsState *tls.ConnectionState) (string, string, bool)
// Requires exact match on user name and host name.
ConnectionVerification(user, host string, auth, salt []byte, tlsState *tls.ConnectionState) bool

// GetAuthWithoutVerification uses to get auth name without verification.
GetAuthWithoutVerification(user, host string) (string, string, bool)
// Requires exact match on user name and host name.
GetAuthWithoutVerification(user, host string) bool

// MatchIdentity matches an identity
MatchIdentity(user, host string, skipNameResolve bool) (string, string, bool)

// DBIsVisible returns true is the database is visible to current user.
DBIsVisible(activeRole []*auth.RoleIdentity, db string) bool
Expand Down
48 changes: 46 additions & 2 deletions privilege/privileges/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -848,6 +849,9 @@ func decodeSetToPrivilege(s types.Set) mysql.PrivilegeType {
// See https://dev.mysql.com/doc/refman/5.7/en/account-names.html
func (record *baseRecord) hostMatch(s string) bool {
if record.hostIPNet == nil {
if record.Host == "localhost" && net.ParseIP(s).IsLoopback() {
return true
}
return false
}
ip := net.ParseIP(s).To4()
Expand Down Expand Up @@ -890,14 +894,54 @@ func patternMatch(str string, patChars, patTypes []byte) bool {
return stringutil.DoMatchBytes(str, patChars, patTypes)
}

// connectionVerification verifies the connection have access to TiDB server.
func (p *MySQLPrivilege) connectionVerification(user, host string) *UserRecord {
// matchIdentity finds an identity to match a user + host
// using the correct rules according to MySQL.
func (p *MySQLPrivilege) matchIdentity(user, host string, skipNameResolve bool) *UserRecord {
for i := 0; i < len(p.User); i++ {
record := &p.User[i]
if record.match(user, host) {
return record
}
}

// If skip-name resolve is not enabled, and the host is not localhost
// we can fallback and try to resolve with all addrs that match.
// TODO: this is imported from previous code in session.Auth(), and can be improved in future.
if !skipNameResolve && host != variable.DefHostname {
addrs, err := net.LookupAddr(host)
if err != nil {
logutil.BgLogger().Warn(
"net.LookupAddr returned an error during auth check",
zap.String("host", host),
zap.Error(err),
)
return nil
}
for _, addr := range addrs {
for i := 0; i < len(p.User); i++ {
record := &p.User[i]
if record.match(user, addr) {
return record
}
}
}
}
return nil
}

// connectionVerification verifies the username + hostname according to exact
// match from the mysql.user privilege table. call matchIdentity() first if you
// do not have an exact match yet.
func (p *MySQLPrivilege) connectionVerification(user, host string) *UserRecord {
records, exists := p.UserMap[user]
if exists {
for i := 0; i < len(records); i++ {
record := &records[i]
if record.Host == host { // exact match
return record
}
}
}
return nil
}

Expand Down
28 changes: 18 additions & 10 deletions privilege/privileges/privileges.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,21 @@ func (p *UserPrivileges) GetAuthPlugin(user, host string) (string, error) {
return "", errors.New("Failed to get plugin for user")
}

// MatchIdentity implements the Manager interface.
func (p *UserPrivileges) MatchIdentity(user, host string, skipNameResolve bool) (u string, h string, success bool) {
if SkipWithGrant {
return user, host, true
}
mysqlPriv := p.Handle.Get()
record := mysqlPriv.matchIdentity(user, host, skipNameResolve)
if record != nil {
return record.User, record.Host, true
}
return "", "", false
}

// GetAuthWithoutVerification implements the Manager interface.
func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (u string, h string, success bool) {
func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (success bool) {
if SkipWithGrant {
p.user = user
p.host = host
Expand All @@ -273,16 +286,14 @@ func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (u string
return
}

u = record.User
h = record.Host
p.user = user
p.host = h
p.host = record.Host
success = true
return
}

// ConnectionVerification implements the Manager interface.
func (p *UserPrivileges) ConnectionVerification(user, host string, authentication, salt []byte, tlsState *tls.ConnectionState) (u string, h string, success bool) {
func (p *UserPrivileges) ConnectionVerification(user, host string, authentication, salt []byte, tlsState *tls.ConnectionState) (success bool) {
if SkipWithGrant {
p.user = user
p.host = host
Expand All @@ -298,9 +309,6 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio
return
}

u = record.User
h = record.Host

globalPriv := mysqlPriv.matchGlobalPriv(user, host)
if globalPriv != nil {
if !p.checkSSL(globalPriv, tlsState) {
Expand Down Expand Up @@ -328,7 +336,7 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio
// empty password
if len(pwd) == 0 && len(authentication) == 0 {
p.user = user
p.host = h
p.host = record.Host
success = true
return
}
Expand Down Expand Up @@ -371,7 +379,7 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio
}

p.user = user
p.host = h
p.host = record.Host
success = true
return
}
Expand Down
83 changes: 56 additions & 27 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,9 @@ func (cc *clientConn) String() string {
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
// https://bugs.mysql.com/bug.php?id=93044
func (cc *clientConn) authSwitchRequest(ctx context.Context, plugin string) ([]byte, error) {
failpoint.Inject("FakeAuthSwitch", func() {
failpoint.Return([]byte(plugin), nil)
})
enclen := 1 + len(plugin) + 1 + len(cc.salt) + 1
data := cc.alloc.AllocWithLen(4, enclen)
data = append(data, mysql.AuthSwitchRequest) // switch request
Expand Down Expand Up @@ -708,40 +711,29 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con

func (cc *clientConn) handleAuthPlugin(ctx context.Context, resp *handshakeResponse41) error {
if resp.Capability&mysql.ClientPluginAuth > 0 {
newAuth, err := cc.checkAuthPlugin(ctx, &resp.AuthPlugin)
newAuth, err := cc.checkAuthPlugin(ctx, resp)
if err != nil {
logutil.Logger(ctx).Warn("failed to check the user authplugin", zap.Error(err))
return err
}
if len(newAuth) > 0 {
resp.Auth = newAuth
}

switch resp.AuthPlugin {
case mysql.AuthCachingSha2Password:
resp.Auth, err = cc.authSha(ctx)
if err != nil {
return err
}
case mysql.AuthNativePassword:
case mysql.AuthSocket:
default:
logutil.Logger(ctx).Warn("Unknown Auth Plugin", zap.String("plugin", resp.AuthPlugin))
}
} else {
// MySQL 5.1 and older clients don't support authentication plugins.
logutil.Logger(ctx).Warn("Client without Auth Plugin support; Please upgrade client")
if cc.ctx == nil {
err := cc.openSession()
if err != nil {
return err
}
}
userplugin, err := cc.ctx.AuthPluginForUser(&auth.UserIdentity{Username: cc.user, Hostname: cc.peerHost})
_, err := cc.checkAuthPlugin(ctx, resp)
if err != nil {
return err
}
if userplugin != mysql.AuthNativePassword && userplugin != "" {
return errNotSupportedAuthMode
}
resp.AuthPlugin = mysql.AuthNativePassword
}
return nil
Expand Down Expand Up @@ -845,7 +837,7 @@ func (cc *clientConn) openSessionAndDoAuth(authData []byte, authPlugin string) e
}

// Check if the Authentication Plugin of the server, client and user configuration matches
func (cc *clientConn) checkAuthPlugin(ctx context.Context, authPlugin *string) ([]byte, error) {
func (cc *clientConn) checkAuthPlugin(ctx context.Context, resp *handshakeResponse41) ([]byte, error) {
// Open a context unless this was done before.
if cc.ctx == nil {
err := cc.openSession()
Expand All @@ -854,22 +846,54 @@ func (cc *clientConn) checkAuthPlugin(ctx context.Context, authPlugin *string) (
}
}

userplugin, err := cc.ctx.AuthPluginForUser(&auth.UserIdentity{Username: cc.user, Hostname: cc.peerHost})
authData := resp.Auth
hasPassword := "YES"
if len(authData) == 0 {
hasPassword = "NO"
}
host, _, err := cc.PeerHost(hasPassword)
if err != nil {
return nil, err
}
// Find the identity of the user based on username and peer host.
identity, err := cc.ctx.MatchIdentity(cc.user, host)
if err != nil {
return nil, errAccessDenied.FastGenByArgs(cc.user, host, hasPassword)
}
// Get the plugin for the identity.
userplugin, err := cc.ctx.AuthPluginForUser(identity)
if err != nil {
logutil.Logger(ctx).Warn("Failed to get authentication method for user",
zap.String("user", cc.user), zap.String("host", host))
}
failpoint.Inject("FakeUser", func(val failpoint.Value) {
userplugin = val.(string)
})
if userplugin == mysql.AuthSocket {
*authPlugin = mysql.AuthSocket
if !cc.isUnixSocket {
return nil, errAccessDenied.FastGenByArgs(cc.user, host, hasPassword)
}
resp.AuthPlugin = mysql.AuthSocket
user, err := user.LookupId(fmt.Sprint(cc.socketCredUID))
if err != nil {
return nil, err
}
return []byte(user.Username), nil
}
if len(userplugin) == 0 {
logutil.Logger(ctx).Warn("No user plugin set, assuming MySQL Native Password",
zap.String("user", cc.user), zap.String("host", cc.peerHost))
*authPlugin = mysql.AuthNativePassword
// No user plugin set, assuming MySQL Native Password
// This happens if the account doesn't exist or if the account doesn't have
// a password set.
if resp.AuthPlugin != mysql.AuthNativePassword {
if resp.Capability&mysql.ClientPluginAuth > 0 {
resp.AuthPlugin = mysql.AuthNativePassword
authData, err := cc.authSwitchRequest(ctx, mysql.AuthNativePassword)
if err != nil {
return nil, err
}
return authData, nil
}
}
return nil, nil
}

Expand All @@ -878,13 +902,18 @@ func (cc *clientConn) checkAuthPlugin(ctx context.Context, authPlugin *string) (
// or if the authentication method send by the server doesn't match the authentication
// method send by the client (*authPlugin) then we need to switch the authentication
// method to match the one configured for that specific user.
if (cc.authPlugin != userplugin) || (cc.authPlugin != *authPlugin) {
authData, err := cc.authSwitchRequest(ctx, userplugin)
if err != nil {
return nil, err
if (cc.authPlugin != userplugin) || (cc.authPlugin != resp.AuthPlugin) {
if resp.Capability&mysql.ClientPluginAuth > 0 {
authData, err := cc.authSwitchRequest(ctx, userplugin)
if err != nil {
return nil, err
}
resp.AuthPlugin = userplugin
return authData, nil
} else if userplugin != mysql.AuthNativePassword {
// MySQL 5.1 and older don't support authentication plugins yet
return nil, errNotSupportedAuthMode
}
*authPlugin = userplugin
return authData, nil
}

return nil, nil
Expand Down
Loading

0 comments on commit 27ffd11

Please sign in to comment.