Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

privilege, session, server: consistently map user login to identity #30204

Merged
merged 7 commits into from
Nov 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions executor/coprocessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ func (h *CoprocessorDAGHandler) buildDAGExecutor(req *coprocessor.Request) (Exec
Username: dagReq.User.UserName,
Hostname: dagReq.User.UserHost,
}
authName, authHost, success := pm.GetAuthWithoutVerification(dagReq.User.UserName, dagReq.User.UserHost)
if success {
authName, authHost, success := pm.MatchIdentity(dagReq.User.UserName, dagReq.User.UserHost, false)
if success && pm.GetAuthWithoutVerification(authName, authHost) {
h.sctx.GetSessionVars().User.AuthUsername = authName
h.sctx.GetSessionVars().User.AuthHostname = authHost
h.sctx.GetSessionVars().ActiveRoles = pm.GetDefaultRoles(authName, authHost)
Expand Down
9 changes: 7 additions & 2 deletions privilege/privilege.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,15 @@ type Manager interface {
RequestDynamicVerificationWithUser(privName string, grantable bool, user *auth.UserIdentity) bool

// ConnectionVerification verifies user privilege for connection.
ConnectionVerification(user, host string, auth, salt []byte, tlsState *tls.ConnectionState) (string, string, bool)
// Requires exact match on user name and host name.
ConnectionVerification(user, host string, auth, salt []byte, tlsState *tls.ConnectionState) bool

// GetAuthWithoutVerification uses to get auth name without verification.
GetAuthWithoutVerification(user, host string) (string, string, bool)
// Requires exact match on user name and host name.
GetAuthWithoutVerification(user, host string) bool

// MatchIdentity matches an identity
MatchIdentity(user, host string, skipNameResolve bool) (string, string, bool)

// DBIsVisible returns true is the database is visible to current user.
DBIsVisible(activeRole []*auth.RoleIdentity, db string) bool
Expand Down
48 changes: 46 additions & 2 deletions privilege/privileges/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -860,6 +861,9 @@ func decodeSetToPrivilege(s types.Set) mysql.PrivilegeType {
// See https://dev.mysql.com/doc/refman/5.7/en/account-names.html
func (record *baseRecord) hostMatch(s string) bool {
if record.hostIPNet == nil {
if record.Host == "localhost" && net.ParseIP(s).IsLoopback() {
return true
}
return false
}
ip := net.ParseIP(s).To4()
Expand Down Expand Up @@ -902,14 +906,54 @@ func patternMatch(str string, patChars, patTypes []byte) bool {
return stringutil.DoMatchBytes(str, patChars, patTypes)
}

// connectionVerification verifies the connection have access to TiDB server.
func (p *MySQLPrivilege) connectionVerification(user, host string) *UserRecord {
// matchIdentity finds an identity to match a user + host
// using the correct rules according to MySQL.
func (p *MySQLPrivilege) matchIdentity(user, host string, skipNameResolve bool) *UserRecord {
for i := 0; i < len(p.User); i++ {
record := &p.User[i]
if record.match(user, host) {
return record
}
}

// If skip-name resolve is not enabled, and the host is not localhost
// we can fallback and try to resolve with all addrs that match.
// TODO: this is imported from previous code in session.Auth(), and can be improved in future.
if !skipNameResolve && host != variable.DefHostname {
addrs, err := net.LookupAddr(host)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code branch may have some performance impact if matchIdentity is frequently called.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's just used in connection verification, the performance impact is OK
But if it's used in query privilege verification, I'm afraid the impact is not negligible

if err != nil {
logutil.BgLogger().Warn(
"net.LookupAddr returned an error during auth check",
zap.String("host", host),
zap.Error(err),
)
return nil
}
for _, addr := range addrs {
for i := 0; i < len(p.User); i++ {
record := &p.User[i]
if record.match(user, addr) {
return record
}
}
}
}
return nil
}

// connectionVerification verifies the username + hostname according to exact
// match from the mysql.user privilege table. call matchIdentity() first if you
// do not have an exact match yet.
func (p *MySQLPrivilege) connectionVerification(user, host string) *UserRecord {
records, exists := p.UserMap[user]
if exists {
for i := 0; i < len(records); i++ {
record := &records[i]
if record.Host == host { // exact match
return record
}
}
}
return nil
}

Expand Down
28 changes: 18 additions & 10 deletions privilege/privileges/privileges.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,21 @@ func (p *UserPrivileges) GetAuthPlugin(user, host string) (string, error) {
return "", errors.New("Failed to get plugin for user")
}

// MatchIdentity implements the Manager interface.
func (p *UserPrivileges) MatchIdentity(user, host string, skipNameResolve bool) (u string, h string, success bool) {
if SkipWithGrant {
return user, host, true
}
mysqlPriv := p.Handle.Get()
record := mysqlPriv.matchIdentity(user, host, skipNameResolve)
if record != nil {
return record.User, record.Host, true
}
return "", "", false
}

// GetAuthWithoutVerification implements the Manager interface.
func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (u string, h string, success bool) {
func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (success bool) {
if SkipWithGrant {
p.user = user
p.host = host
Expand All @@ -273,16 +286,14 @@ func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (u string
return
}

u = record.User
h = record.Host
p.user = user
p.host = h
p.host = record.Host
success = true
return
}

// ConnectionVerification implements the Manager interface.
func (p *UserPrivileges) ConnectionVerification(user, host string, authentication, salt []byte, tlsState *tls.ConnectionState) (u string, h string, success bool) {
func (p *UserPrivileges) ConnectionVerification(user, host string, authentication, salt []byte, tlsState *tls.ConnectionState) (success bool) {
if SkipWithGrant {
p.user = user
p.host = host
Expand All @@ -298,9 +309,6 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio
return
}

u = record.User
h = record.Host

globalPriv := mysqlPriv.matchGlobalPriv(user, host)
if globalPriv != nil {
if !p.checkSSL(globalPriv, tlsState) {
Expand Down Expand Up @@ -328,7 +336,7 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio
// empty password
if len(pwd) == 0 && len(authentication) == 0 {
p.user = user
p.host = h
p.host = record.Host
success = true
return
}
Expand Down Expand Up @@ -371,7 +379,7 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio
}

p.user = user
p.host = h
p.host = record.Host
success = true
return
}
Expand Down
19 changes: 14 additions & 5 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,7 @@ func (cc *clientConn) handleAuthPlugin(ctx context.Context, resp *handshakeRespo
newAuth, err := cc.checkAuthPlugin(ctx, resp)
if err != nil {
logutil.Logger(ctx).Warn("failed to check the user authplugin", zap.Error(err))
return err
}
if len(newAuth) > 0 {
resp.Auth = newAuth
Expand Down Expand Up @@ -858,16 +859,24 @@ func (cc *clientConn) checkAuthPlugin(ctx context.Context, resp *handshakeRespon
if err != nil {
return nil, err
}
userplugin, err := cc.ctx.AuthPluginForUser(&auth.UserIdentity{Username: cc.user, Hostname: host})
failpoint.Inject("FakeUser", func(val failpoint.Value) {
userplugin = val.(string)
})
// Find the identity of the user based on username and peer host.
identity, err := cc.ctx.MatchIdentity(cc.user, host)
if err != nil {
return nil, errAccessDenied.FastGenByArgs(cc.user, host, hasPassword)
}
// Get the plugin for the identity.
userplugin, err := cc.ctx.AuthPluginForUser(identity)
if err != nil {
// This happens if the account doesn't exist
logutil.Logger(ctx).Warn("Failed to get authentication method for user",
zap.String("user", cc.user), zap.String("host", host))
}
failpoint.Inject("FakeUser", func(val failpoint.Value) {
userplugin = val.(string)
})
if userplugin == mysql.AuthSocket {
if !cc.isUnixSocket {
return nil, errAccessDenied.FastGenByArgs(cc.user, host, hasPassword)
}
resp.AuthPlugin = mysql.AuthSocket
user, err := user.LookupId(fmt.Sprint(cc.socketCredUID))
if err != nil {
Expand Down
23 changes: 19 additions & 4 deletions server/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -924,6 +924,12 @@ func TestHandleAuthPlugin(t *testing.T) {
require.NoError(t, err)
ctx := context.Background()

tk := testkit.NewTestKit(t, store)
tk.MustExec("CREATE USER unativepassword")
defer func() {
tk.MustExec("DROP USER unativepassword")
}()

// 5.7 or newer client trying to authenticate with mysql_native_password
cc := &clientConn{
connectionID: 1,
Expand All @@ -935,6 +941,7 @@ func TestHandleAuthPlugin(t *testing.T) {
bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)),
},
server: srv,
user: "unativepassword",
}
resp := handshakeResponse41{
Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth,
Expand All @@ -955,6 +962,7 @@ func TestHandleAuthPlugin(t *testing.T) {
bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)),
},
server: srv,
user: "unativepassword",
}
resp = handshakeResponse41{
Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth,
Expand All @@ -976,6 +984,7 @@ func TestHandleAuthPlugin(t *testing.T) {
bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)),
},
server: srv,
user: "unativepassword",
}
resp = handshakeResponse41{
Capability: mysql.ClientProtocol41,
Expand All @@ -998,14 +1007,15 @@ func TestHandleAuthPlugin(t *testing.T) {
bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)),
},
server: srv,
user: "unativepassword",
}
resp = handshakeResponse41{
Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth,
AuthPlugin: mysql.AuthNativePassword,
}
err = cc.handleAuthPlugin(ctx, &resp)
require.NoError(t, err)
require.Equal(t, resp.Auth, []byte(mysql.AuthNativePassword))
require.Equal(t, []byte(mysql.AuthNativePassword), resp.Auth)
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch"))

// 8.0 or newer client trying to authenticate with caching_sha2_password
Expand All @@ -1020,14 +1030,15 @@ func TestHandleAuthPlugin(t *testing.T) {
bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)),
},
server: srv,
user: "unativepassword",
}
resp = handshakeResponse41{
Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth,
AuthPlugin: mysql.AuthCachingSha2Password,
}
err = cc.handleAuthPlugin(ctx, &resp)
require.NoError(t, err)
require.Equal(t, resp.Auth, []byte(mysql.AuthNativePassword))
require.Equal(t, []byte(mysql.AuthNativePassword), resp.Auth)
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch"))

// MySQL 5.1 or older client, without authplugin support
Expand All @@ -1041,6 +1052,7 @@ func TestHandleAuthPlugin(t *testing.T) {
bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)),
},
server: srv,
user: "unativepassword",
}
resp = handshakeResponse41{
Capability: mysql.ClientProtocol41,
Expand All @@ -1064,14 +1076,15 @@ func TestHandleAuthPlugin(t *testing.T) {
bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)),
},
server: srv,
user: "unativepassword",
}
resp = handshakeResponse41{
Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth,
AuthPlugin: mysql.AuthNativePassword,
}
err = cc.handleAuthPlugin(ctx, &resp)
require.NoError(t, err)
require.Equal(t, resp.Auth, []byte(mysql.AuthCachingSha2Password))
require.Equal(t, []byte(mysql.AuthCachingSha2Password), resp.Auth)
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch"))

// 8.0 or newer client trying to authenticate with caching_sha2_password
Expand All @@ -1086,14 +1099,15 @@ func TestHandleAuthPlugin(t *testing.T) {
bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)),
},
server: srv,
user: "unativepassword",
}
resp = handshakeResponse41{
Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth,
AuthPlugin: mysql.AuthCachingSha2Password,
}
err = cc.handleAuthPlugin(ctx, &resp)
require.NoError(t, err)
require.Equal(t, resp.Auth, []byte(mysql.AuthCachingSha2Password))
require.Equal(t, []byte(mysql.AuthCachingSha2Password), resp.Auth)
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch"))

// MySQL 5.1 or older client, without authplugin support
Expand All @@ -1107,6 +1121,7 @@ func TestHandleAuthPlugin(t *testing.T) {
bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)),
},
server: srv,
user: "unativepassword",
}
resp = handshakeResponse41{
Capability: mysql.ClientProtocol41,
Expand Down
Loading