From 17906417f5985a75eab7a987f82056b9681454ed Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Mon, 7 Nov 2022 23:18:37 +0800 Subject: [PATCH 01/26] Add system variables --- executor/set_test.go | 11 +++---- sessionctx/variable/noop.go | 5 ---- sessionctx/variable/sysvar.go | 56 +++++++++++++++++++++++++++++++---- 3 files changed, 54 insertions(+), 18 deletions(-) diff --git a/executor/set_test.go b/executor/set_test.go index a8d4a0b059246..4413d986597d5 100644 --- a/executor/set_test.go +++ b/executor/set_test.go @@ -1407,14 +1407,11 @@ func TestValidateSetVar(t *testing.T) { tk.MustExec("set @@innodb_lock_wait_timeout = 1073741825") tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1292|Truncated incorrect innodb_lock_wait_timeout value: '1073741825'")) - tk.MustExec("set @@global.validate_password_number_count=-1") - tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1292|Truncated incorrect validate_password_number_count value: '-1'")) + tk.MustExec("set @@global.validate_password.number_count=-1") + tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1292|Truncated incorrect validate_password.number_count value: '-1'")) - tk.MustExec("set @@global.validate_password_length=-1") - tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1292|Truncated incorrect validate_password_length value: '-1'")) - - tk.MustExec("set @@global.validate_password_length=8") - tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustExec("set @@global.validate_password.length=-1") + tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1292|Truncated incorrect validate_password.length value: '-1'")) err = tk.ExecToErr("set @@tx_isolation=''") require.True(t, terror.ErrorEqual(err, variable.ErrWrongValueForVar), fmt.Sprintf("err %v", err)) diff --git a/sessionctx/variable/noop.go b/sessionctx/variable/noop.go index 398ea09f3ec92..5019ab90af115 100644 --- a/sessionctx/variable/noop.go +++ b/sessionctx/variable/noop.go @@ -58,8 +58,6 @@ var noopSysVars = []*SysVar{ {Scope: ScopeGlobal | ScopeSession, Name: BigTables, Value: Off, Type: TypeBool}, {Scope: ScopeNone, Name: "skip_external_locking", Value: "1"}, {Scope: ScopeNone, Name: "innodb_sync_array_size", Value: "1"}, - {Scope: ScopeGlobal, Name: ValidatePasswordCheckUserName, Value: Off, Type: TypeBool}, - {Scope: ScopeGlobal, Name: ValidatePasswordNumberCount, Value: "1", Type: TypeUnsigned, MinValue: 0, MaxValue: math.MaxUint64}, {Scope: ScopeSession, Name: "gtid_next", Value: ""}, {Scope: ScopeGlobal, Name: "ndb_show_foreign_key_mock_tables", Value: ""}, {Scope: ScopeNone, Name: "multi_range_count", Value: "256"}, @@ -117,7 +115,6 @@ var noopSysVars = []*SysVar{ {Scope: ScopeNone, Name: "innodb_log_group_home_dir", Value: "./"}, {Scope: ScopeNone, Name: "performance_schema_events_statements_history_size", Value: "10"}, {Scope: ScopeGlobal, Name: GeneralLog, Value: Off, Type: TypeBool}, - {Scope: ScopeGlobal, Name: "validate_password_dictionary_file", Value: ""}, {Scope: ScopeGlobal, Name: BinlogOrderCommits, Value: On, Type: TypeBool}, {Scope: ScopeGlobal, Name: "key_cache_division_limit", Value: "100"}, {Scope: ScopeGlobal | ScopeSession, Name: "max_insert_delayed_threads", Value: "20"}, @@ -463,7 +460,6 @@ var noopSysVars = []*SysVar{ {Scope: ScopeGlobal | ScopeSession, Name: "eq_range_index_dive_limit", Value: "200", IsHintUpdatable: true}, {Scope: ScopeNone, Name: "performance_schema_events_stages_history_size", Value: "10"}, {Scope: ScopeGlobal | ScopeSession, Name: "ndb_join_pushdown", Value: ""}, - {Scope: ScopeGlobal, Name: "validate_password_special_char_count", Value: "1"}, {Scope: ScopeNone, Name: "performance_schema_max_thread_instances", Value: "402"}, {Scope: ScopeGlobal | ScopeSession, Name: "ndbinfo_show_hidden", Value: ""}, {Scope: ScopeGlobal | ScopeSession, Name: "net_read_timeout", Value: "30"}, @@ -472,7 +468,6 @@ var noopSysVars = []*SysVar{ {Scope: ScopeGlobal, Name: "sync_relay_log_info", Value: "10000"}, {Scope: ScopeGlobal | ScopeSession, Name: "optimizer_trace_limit", Value: "1"}, {Scope: ScopeNone, Name: "innodb_ft_max_token_size", Value: "84"}, - {Scope: ScopeGlobal, Name: ValidatePasswordLength, Value: "8", Type: TypeUnsigned, MinValue: 0, MaxValue: math.MaxUint64}, {Scope: ScopeGlobal, Name: "ndb_log_binlog_index", Value: ""}, {Scope: ScopeGlobal, Name: "innodb_api_bk_commit_interval", Value: "5"}, {Scope: ScopeNone, Name: "innodb_undo_directory", Value: "."}, diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 580a8deca7b7e..0ea82baf40c78 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -486,6 +486,39 @@ var defaultSysVars = []*SysVar{ } return normalizedValue, nil }}, + {Scope: ScopeGlobal, Name: ValidatePasswordEnable, Value: Off, Type: TypeBool}, + {Scope: ScopeGlobal, Name: ValidatePasswordPolicy, Value: "MEDIUM", Type: TypeEnum, PossibleValues: []string{"LOW", "MEDIUM", "STRONG"}}, + {Scope: ScopeGlobal, Name: ValidatePasswordCheckUserName, Value: On, Type: TypeBool}, + {Scope: ScopeGlobal, Name: ValidatePasswordLength, Value: "8", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt64, + Validation: func(vars *SessionVars, normalizedValue string, originalValue string, scope ScopeFlag) (string, error) { + var numberCount, specialCharCount, mixedCaseCount int64 + if numberCountStr, err := vars.GlobalVarsAccessor.GetGlobalSysVar(ValidatePasswordNumberCount); err != nil { + return "", err + } else if numberCount, err = strconv.ParseInt(numberCountStr, 10, 64); err != nil { + return "", err + } + if specialCharCountStr, err := vars.GlobalVarsAccessor.GetGlobalSysVar(ValidatePasswordNumberCount); err != nil { + return "", err + } else if specialCharCount, err = strconv.ParseInt(specialCharCountStr, 10, 64); err != nil { + return "", err + } + if mixedCaseCountStr, err := vars.GlobalVarsAccessor.GetGlobalSysVar(ValidatePasswordNumberCount); err != nil { + return "", err + } else if mixedCaseCount, err = strconv.ParseInt(mixedCaseCountStr, 10, 64); err != nil { + return "", err + } + if length, err := strconv.ParseInt(normalizedValue, 10, 64); err != nil { + return "", err + } else if length < numberCount+specialCharCount+2*mixedCaseCount { + return "", ErrWrongValueForVar.GenWithStackByArgs(ValidatePasswordLength, normalizedValue) + } + return normalizedValue, nil + }, + }, + {Scope: ScopeGlobal, Name: ValidatePasswordMixedCaseCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt64}, + {Scope: ScopeGlobal, Name: ValidatePasswordNumberCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt64}, + {Scope: ScopeGlobal, Name: ValidatePasswordSpecialCharCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt64}, + {Scope: ScopeGlobal, Name: ValidatePasswordDictionaryFile, Value: "", Type: TypeStr}, /* TiDB specific variables */ {Scope: ScopeGlobal, Name: TiDBTSOClientBatchMaxWaitTime, Value: strconv.FormatFloat(DefTiDBTSOClientBatchMaxWaitTime, 'f', -1, 64), Type: TypeFloat, MinValue: 0, MaxValue: 10, @@ -2104,10 +2137,6 @@ const ( BlockEncryptionMode = "block_encryption_mode" // WaitTimeout is the name for 'wait_timeout' system variable. WaitTimeout = "wait_timeout" - // ValidatePasswordNumberCount is the name of 'validate_password_number_count' system variable. - ValidatePasswordNumberCount = "validate_password_number_count" - // ValidatePasswordLength is the name of 'validate_password_length' system variable. - ValidatePasswordLength = "validate_password_length" // Version is the name of 'version' system variable. Version = "version" // VersionComment is the name of 'version_comment' system variable. @@ -2130,8 +2159,6 @@ const ( BinlogOrderCommits = "binlog_order_commits" // MasterVerifyChecksum is the name for 'master_verify_checksum' system variable. MasterVerifyChecksum = "master_verify_checksum" - // ValidatePasswordCheckUserName is the name for 'validate_password_check_user_name' system variable. - ValidatePasswordCheckUserName = "validate_password_check_user_name" // SuperReadOnly is the name for 'super_read_only' system variable. SuperReadOnly = "super_read_only" // SQLNotes is the name for 'sql_notes' system variable. @@ -2298,4 +2325,21 @@ const ( RandSeed2 = "rand_seed2" // SQLRequirePrimaryKey is the name of `sql_require_primary_key` system variable. SQLRequirePrimaryKey = "sql_require_primary_key" + // ValidatePasswordEnable turns on/off the validation of password. + ValidatePasswordEnable = "validate_password.enable" + // ValidatePasswordPolicy specifies the password policy enforced by validate_password. + ValidatePasswordPolicy = "validate_password.policy" + // ValidatePasswordCheckUserName controls whether validate_password compares passwords to the user name part of + // the effective user account for the current session + ValidatePasswordCheckUserName = "validate_password.check_user_name" + // ValidatePasswordLength specified the minimum number of characters that validate_password requires passwords to have + ValidatePasswordLength = "validate_password.length" + // ValidatePasswordMixedCaseCount specified the minimum number of lowercase and uppercase characters that validate_password requires + ValidatePasswordMixedCaseCount = "validate_password.mixed_case_count" + // ValidatePasswordNumberCount specified the minimum number of numeric (digit) characters that validate_password requires + ValidatePasswordNumberCount = "validate_password.number_count" + // ValidatePasswordSpecialCharCount specified the minimum number of nonalphanumeric characters that validate_password requires + ValidatePasswordSpecialCharCount = "validate_password.special_char_count" + // ValidatePasswordDictionaryFile specified the path name of the dictionary file that validate_password uses for checking passwords + ValidatePasswordDictionaryFile = "validate_password.dictionary_file" ) From 9a3f17790bbefd96b4c2900fbbb62af11a321eb8 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Tue, 8 Nov 2022 17:06:10 +0800 Subject: [PATCH 02/26] TODO: checkDictionary --- executor/errors.go | 1 + executor/simple.go | 123 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 124 insertions(+) diff --git a/executor/errors.go b/executor/errors.go index 4a0c7f9215875..3bb8935b7ec3a 100644 --- a/executor/errors.go +++ b/executor/errors.go @@ -69,6 +69,7 @@ var ( ErrFuncNotEnabled = dbterror.ClassExecutor.NewStdErr(mysql.ErrNotSupportedYet, parser_mysql.Message("%-.32s is not supported. To enable this experimental feature, set '%-.32s' in the configuration file.", nil)) errSavepointNotExists = dbterror.ClassExecutor.NewStd(mysql.ErrSpDoesNotExist) ErrForeignKeyCascadeDepthExceeded = dbterror.ClassExecutor.NewStd(mysql.ErrForeignKeyCascadeDepthExceeded) + ErrNotValidPassword = dbterror.ClassExecutor.NewStd(mysql.ErrNotValidPassword) ErrWrongStringLength = dbterror.ClassDDL.NewStd(mysql.ErrWrongStringLength) errUnsupportedFlashbackTmpTable = dbterror.ClassDDL.NewStdErr(mysql.ErrUnsupportedDDLOperation, parser_mysql.Message("Recover/flashback table is not supported on temporary tables", nil)) diff --git a/executor/simple.go b/executor/simple.go index 29d40775fa06e..4ed4584185110 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -20,9 +20,11 @@ import ( "encoding/json" "fmt" "os" + "strconv" "strings" "syscall" "time" + "unicode" "github.com/ngaut/pools" "github.com/pingcap/errors" @@ -783,6 +785,114 @@ func (e *SimpleExec) executeRollback(s *ast.RollbackStmt) error { return nil } +func (e *SimpleExec) authUsingCleartextPwd(authOpt *ast.AuthOption) bool { + if authOpt == nil || !authOpt.ByAuthString { + return false + } + if authOpt.AuthString == mysql.AuthNativePassword || + authOpt.AuthString == mysql.AuthTiDBSM3Password || + authOpt.AuthString == mysql.AuthCachingSha2Password { + return true + } + return false +} + +func (e *SimpleExec) checkUserNameInPassword(pwd string) error { + pwdBytes := hack.Slice(pwd) + userName := hack.Slice(e.ctx.GetSessionVars().User.AuthUsername) + userNameLen := len(userName) + if userNameLen == 0 { + return nil + } + if bytes.Contains(pwdBytes, userName) { + return ErrNotValidPassword.GenWithStack("Password Contains User Name") + } + var reverseUserName []byte + for i := range userName { + reverseUserName = append(reverseUserName, userName[userNameLen-1-i]) + } + if bytes.Contains(pwdBytes, reverseUserName) { + return ErrNotValidPassword.GenWithStack("Password Contains User Name") + } + return nil +} + +func (e *SimpleExec) checkDictionary(pwd string) error { + // TODO: dictionary_file + return nil +} + +func (e *SimpleExec) validatePassword(pwd string) error { + if validatePwd, err := e.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.ValidatePasswordEnable); err != nil { + return err + } else if !variable.TiDBOptOn(validatePwd) { + return nil + } + + runes := []rune(pwd) + // LOW + if validateLengthStr, err := e.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.ValidatePasswordLength); err != nil { + return err + } else if validateLength, err := strconv.ParseInt(validateLengthStr, 10, 64); err != nil { + return err + } else if (int64)(len(runes)) < validateLength { + return ErrNotValidPassword.GenWithStack("Require Password Length: %d", validateLength) + } + validatePolicy, err := e.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.ValidatePasswordPolicy) + if err != nil { + return err + } + if err = e.checkUserNameInPassword(pwd); err != nil { + return err + } + if validatePolicy == "LOW" { + return nil + } + + // MEDIUM + var lowerCaseCount, upperCaseCount, numberCount, specialCharCount int64 + for _, r := range runes { + if unicode.IsUpper(r) { + upperCaseCount++ + } else if unicode.IsLower(r) { + lowerCaseCount++ + } else if unicode.IsDigit(r) { + numberCount++ + } else { + specialCharCount++ + } + } + if mixedCaseCountStr, err := e.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.ValidatePasswordMixedCaseCount); err != nil { + return err + } else if mixedCaseCount, err := strconv.ParseInt(mixedCaseCountStr, 10, 64); err != nil { + return err + } else if lowerCaseCount < mixedCaseCount { + return ErrNotValidPassword.GenWithStack("Require Password Lowercase Count: %d", mixedCaseCount) + } else if upperCaseCount < mixedCaseCount { + return ErrNotValidPassword.GenWithStack("Require Password Uppercase Count: %d", mixedCaseCount) + } + if requireNumberCountStr, err := e.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.ValidatePasswordNumberCount); err != nil { + return err + } else if requireNumberCount, err := strconv.ParseInt(requireNumberCountStr, 10, 64); err != nil { + return err + } else if numberCount < requireNumberCount { + return ErrNotValidPassword.GenWithStack("Require Password Digit Count: %d", requireNumberCount) + } + if requireSpecialCharCountStr, err := e.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.ValidatePasswordSpecialCharCount); err != nil { + return err + } else if requireSpecialCharCount, err := strconv.ParseInt(requireSpecialCharCountStr, 10, 64); err != nil { + return err + } else if specialCharCount < requireSpecialCharCount { + return ErrNotValidPassword.GenWithStack("Require Password Non-alphanumeric Count: %d", requireSpecialCharCount) + } + if validatePolicy == "MEDIUM" { + return nil + } + + // STRONG + return e.checkDictionary(pwd) +} + func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStmt) error { internalCtx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnPrivilege) // Check `CREATE USER` privilege. @@ -874,6 +984,11 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm e.ctx.GetSessionVars().StmtCtx.AppendNote(err) continue } + if e.authUsingCleartextPwd(spec.AuthOpt) { + if err := e.validatePassword(spec.AuthOpt.AuthString); err != nil { + return err + } + } pwd, ok := spec.EncodedPassword() if !ok { @@ -1087,6 +1202,11 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt) default: return ErrPluginIsNotLoaded.GenWithStackByArgs(spec.AuthOpt.AuthPlugin) } + if e.authUsingCleartextPwd(spec.AuthOpt) { + if err := e.validatePassword(spec.AuthOpt.AuthString); err != nil { + return err + } + } pwd, ok := spec.EncodedPassword() if !ok { return errors.Trace(ErrPasswordFormat) @@ -1603,6 +1723,9 @@ func (e *SimpleExec) executeSetPwd(ctx context.Context, s *ast.SetPwdStmt) error if err != nil { return err } + if err := e.validatePassword(s.Password); err != nil { + return err + } var pwd string switch authplugin { case mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password: From 20ff355b35fbb9bfeeca395194f457b9f39c1fb0 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Tue, 8 Nov 2022 19:49:22 +0800 Subject: [PATCH 03/26] Fix --- errors.toml | 5 +++++ executor/simple.go | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/errors.toml b/errors.toml index 31a56ef6b1d17..a50e484a5a29b 100644 --- a/errors.toml +++ b/errors.toml @@ -1451,6 +1451,11 @@ error = ''' SET PASSWORD has no significance for user '%-.48s'@'%-.255s' as authentication plugin does not support it. ''' +["executor:1819"] +error = ''' +Your password does not satisfy the current policy requirements +''' + ["executor:1827"] error = ''' The password hash doesn't have the expected format. Check if the correct password algorithm is being used with the PASSWORD() function. diff --git a/executor/simple.go b/executor/simple.go index 4ed4584185110..dd002ccf804b2 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -807,9 +807,9 @@ func (e *SimpleExec) checkUserNameInPassword(pwd string) error { if bytes.Contains(pwdBytes, userName) { return ErrNotValidPassword.GenWithStack("Password Contains User Name") } - var reverseUserName []byte + reverseUserName := make([]byte, userNameLen) for i := range userName { - reverseUserName = append(reverseUserName, userName[userNameLen-1-i]) + reverseUserName[i] = userName[userNameLen-1-i] } if bytes.Contains(pwdBytes, reverseUserName) { return ErrNotValidPassword.GenWithStack("Password Contains User Name") From 1779510f619b15b9ca01c2157eed2b23ca4b4f62 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Wed, 9 Nov 2022 00:45:21 +0800 Subject: [PATCH 04/26] TODO: add UT --- executor/BUILD.bazel | 1 + executor/simple.go | 37 ++++++------- sessionctx/variable/BUILD.bazel | 1 + sessionctx/variable/sysvar.go | 5 +- util/validate-password/BUILD.bazel | 13 +++++ util/validate-password/dictionary.go | 78 ++++++++++++++++++++++++++++ 6 files changed, 116 insertions(+), 19 deletions(-) create mode 100644 util/validate-password/BUILD.bazel create mode 100644 util/validate-password/dictionary.go diff --git a/executor/BUILD.bazel b/executor/BUILD.bazel index 8d9eb3af53211..c810b579ea876 100644 --- a/executor/BUILD.bazel +++ b/executor/BUILD.bazel @@ -195,6 +195,7 @@ go_library( "//util/tls", "//util/topsql", "//util/topsql/state", + "//util/validate-password", "@com_github_burntsushi_toml//:toml", "@com_github_gogo_protobuf//proto", "@com_github_ngaut_pools//:pools", diff --git a/executor/simple.go b/executor/simple.go index dd002ccf804b2..3950e71c87c1b 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -57,6 +57,7 @@ import ( "github.com/pingcap/tidb/util/sqlexec" "github.com/pingcap/tidb/util/timeutil" "github.com/pingcap/tidb/util/tls" + validatePwd "github.com/pingcap/tidb/util/validate-password" "github.com/pingcap/tipb/go-tipb" tikvutil "github.com/tikv/client-go/v2/util" "go.uber.org/zap" @@ -797,35 +798,30 @@ func (e *SimpleExec) authUsingCleartextPwd(authOpt *ast.AuthOption) bool { return false } -func (e *SimpleExec) checkUserNameInPassword(pwd string) error { +func (e *SimpleExec) validateUserNameInPassword(pwd, username string) error { pwdBytes := hack.Slice(pwd) - userName := hack.Slice(e.ctx.GetSessionVars().User.AuthUsername) - userNameLen := len(userName) + usernameBytes := hack.Slice(username) + userNameLen := len(usernameBytes) if userNameLen == 0 { return nil } - if bytes.Contains(pwdBytes, userName) { + if bytes.Contains(pwdBytes, usernameBytes) { return ErrNotValidPassword.GenWithStack("Password Contains User Name") } - reverseUserName := make([]byte, userNameLen) - for i := range userName { - reverseUserName[i] = userName[userNameLen-1-i] + usernameReversedBytes := make([]byte, userNameLen) + for i := range usernameBytes { + usernameReversedBytes[i] = usernameBytes[userNameLen-1-i] } - if bytes.Contains(pwdBytes, reverseUserName) { - return ErrNotValidPassword.GenWithStack("Password Contains User Name") + if bytes.Contains(pwdBytes, usernameReversedBytes) { + return ErrNotValidPassword.GenWithStack("Password Contains Reversed User Name") } return nil } -func (e *SimpleExec) checkDictionary(pwd string) error { - // TODO: dictionary_file - return nil -} - func (e *SimpleExec) validatePassword(pwd string) error { - if validatePwd, err := e.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.ValidatePasswordEnable); err != nil { + if validatePwdEnable, err := e.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.ValidatePasswordEnable); err != nil { return err - } else if !variable.TiDBOptOn(validatePwd) { + } else if !variable.TiDBOptOn(validatePwdEnable) { return nil } @@ -842,7 +838,9 @@ func (e *SimpleExec) validatePassword(pwd string) error { if err != nil { return err } - if err = e.checkUserNameInPassword(pwd); err != nil { + if err = e.validateUserNameInPassword(pwd, e.ctx.GetSessionVars().User.AuthUsername); err != nil { + return err + } else if err = e.validateUserNameInPassword(pwd, e.ctx.GetSessionVars().User.Username); err != nil { return err } if validatePolicy == "LOW" { @@ -890,7 +888,10 @@ func (e *SimpleExec) validatePassword(pwd string) error { } // STRONG - return e.checkDictionary(pwd) + if !validatePwd.ValidateDictionaryPassword(pwd) { + return ErrNotValidPassword.GenWithStack("Password contains word in the dictionary") + } + return nil } func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStmt) error { diff --git a/sessionctx/variable/BUILD.bazel b/sessionctx/variable/BUILD.bazel index fa4865079e8bf..60375b1e92336 100644 --- a/sessionctx/variable/BUILD.bazel +++ b/sessionctx/variable/BUILD.bazel @@ -57,6 +57,7 @@ go_library( "//util/timeutil", "//util/tls", "//util/topsql/state", + "//util/validate-password", "//util/versioninfo", "@com_github_pingcap_errors//:errors", "@com_github_tikv_client_go_v2//config", diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 0ea82baf40c78..652c0e986ea09 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -44,6 +44,7 @@ import ( "github.com/pingcap/tidb/util/tikvutil" "github.com/pingcap/tidb/util/tls" topsqlstate "github.com/pingcap/tidb/util/topsql/state" + validatePwd "github.com/pingcap/tidb/util/validate-password" "github.com/pingcap/tidb/util/versioninfo" tikvcfg "github.com/tikv/client-go/v2/config" tikvstore "github.com/tikv/client-go/v2/kv" @@ -518,7 +519,9 @@ var defaultSysVars = []*SysVar{ {Scope: ScopeGlobal, Name: ValidatePasswordMixedCaseCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt64}, {Scope: ScopeGlobal, Name: ValidatePasswordNumberCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt64}, {Scope: ScopeGlobal, Name: ValidatePasswordSpecialCharCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt64}, - {Scope: ScopeGlobal, Name: ValidatePasswordDictionaryFile, Value: "", Type: TypeStr}, + {Scope: ScopeGlobal, Name: ValidatePasswordDictionaryFile, Value: "", Type: TypeStr, SetGlobal: func(_ context.Context, s *SessionVars, val string) error { + return validatePwd.UpdateDictionaryFile(val) + }}, /* TiDB specific variables */ {Scope: ScopeGlobal, Name: TiDBTSOClientBatchMaxWaitTime, Value: strconv.FormatFloat(DefTiDBTSOClientBatchMaxWaitTime, 'f', -1, 64), Type: TypeFloat, MinValue: 0, MaxValue: 10, diff --git a/util/validate-password/BUILD.bazel b/util/validate-password/BUILD.bazel new file mode 100644 index 0000000000000..498bea77f5b68 --- /dev/null +++ b/util/validate-password/BUILD.bazel @@ -0,0 +1,13 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "validate-password", + srcs = ["dictionary.go"], + importpath = "github.com/pingcap/tidb/util/validate-password", + visibility = ["//visibility:public"], + deps = [ + "//util/hack", + "//util/mathutil", + "@com_github_pingcap_errors//:errors", + ], +) diff --git a/util/validate-password/dictionary.go b/util/validate-password/dictionary.go new file mode 100644 index 0000000000000..397572874882c --- /dev/null +++ b/util/validate-password/dictionary.go @@ -0,0 +1,78 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validate_password + +import ( + "bufio" + "os" + "strings" + "sync" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/util/hack" + "github.com/pingcap/tidb/util/mathutil" +) + +type dictionaryImpl struct { + cache map[string]struct{} + m sync.RWMutex +} + +var dictionary = dictionaryImpl{cache: make(map[string]struct{})} + +func UpdateDictionaryFile(filePath string) error { + dictionary.m.Lock() + defer dictionary.m.Unlock() + newDictionary := make(map[string]struct{}) + file, err := os.Open(filePath) + if err != nil { + return err + } + if fileInfo, err := file.Stat(); err != nil { + return err + } else if fileInfo.Size() > 1*1024*1024 { + return errors.New("Too Large Dictionary. The maximum permitted file size is 1MB") + } + s := bufio.NewScanner(file) + for s.Scan() { + line := strings.ToLower(string(hack.String(s.Bytes()))) + if len(line) >= 4 && len(line) <= 100 { + newDictionary[line] = struct{}{} + } + } + if err := s.Err(); err != nil { + return err + } + dictionary.cache = newDictionary + return file.Close() +} + +func ValidateDictionaryPassword(pwd string) bool { + dictionary.m.RLock() + dictionary.m.RUnlock() + if len(dictionary.cache) == 0 { + return true + } + pwdLen := len(pwd) + for subStrLen := mathutil.Min(100, pwdLen); subStrLen >= 4; subStrLen-- { + for subStrPos := 0; subStrPos+subStrLen <= pwdLen; subStrPos++ { + subStr := pwd[subStrPos : subStrPos+subStrLen] + if _, ok := dictionary.cache[subStr]; ok { + return false + } + } + } + return true +} From 983023ae24823ac3f17f445ec9a6e3389061ad4b Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Wed, 9 Nov 2022 01:01:39 +0800 Subject: [PATCH 05/26] Fix --- util/validate-password/dictionary.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/util/validate-password/dictionary.go b/util/validate-password/dictionary.go index 397572874882c..1b20b392fbd6e 100644 --- a/util/validate-password/dictionary.go +++ b/util/validate-password/dictionary.go @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -package validate_password +package validator import ( "bufio" "os" + "path/filepath" "strings" "sync" @@ -32,11 +33,12 @@ type dictionaryImpl struct { var dictionary = dictionaryImpl{cache: make(map[string]struct{})} +// UpdateDictionaryFile update the dictionary for validating password. func UpdateDictionaryFile(filePath string) error { dictionary.m.Lock() defer dictionary.m.Unlock() newDictionary := make(map[string]struct{}) - file, err := os.Open(filePath) + file, err := os.Open(filepath.Clean(filePath)) if err != nil { return err } @@ -59,9 +61,10 @@ func UpdateDictionaryFile(filePath string) error { return file.Close() } +// ValidateDictionaryPassword checks if the password contains words in the dictionary. func ValidateDictionaryPassword(pwd string) bool { dictionary.m.RLock() - dictionary.m.RUnlock() + defer dictionary.m.RUnlock() if len(dictionary.cache) == 0 { return true } From 21ad50ccaefe41312b44eea840fed1df81d37c82 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Wed, 9 Nov 2022 13:56:09 +0800 Subject: [PATCH 06/26] Add basic UT --- executor/BUILD.bazel | 1 + executor/simple.go | 46 ++++++++++++---------- executor/simple_test.go | 41 +++++++++++++++++++ util/validate-password/BUILD.bazel | 9 ++++- util/validate-password/dictionary.go | 47 ++++++++++++++++++++-- util/validate-password/dictionary_test.go | 48 +++++++++++++++++++++++ 6 files changed, 166 insertions(+), 26 deletions(-) create mode 100644 util/validate-password/dictionary_test.go diff --git a/executor/BUILD.bazel b/executor/BUILD.bazel index c810b579ea876..06386aca9f168 100644 --- a/executor/BUILD.bazel +++ b/executor/BUILD.bazel @@ -416,6 +416,7 @@ go_test( "//util/tableutil", "//util/timeutil", "//util/topsql/state", + "//util/validate-password", "@com_github_golang_protobuf//proto", "@com_github_gorilla_mux//:mux", "@com_github_jarcoal_httpmock//:httpmock", diff --git a/executor/simple.go b/executor/simple.go index 3950e71c87c1b..23cd726baabb8 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -790,9 +790,10 @@ func (e *SimpleExec) authUsingCleartextPwd(authOpt *ast.AuthOption) bool { if authOpt == nil || !authOpt.ByAuthString { return false } - if authOpt.AuthString == mysql.AuthNativePassword || - authOpt.AuthString == mysql.AuthTiDBSM3Password || - authOpt.AuthString == mysql.AuthCachingSha2Password { + if authOpt.AuthPlugin == mysql.AuthNativePassword || + authOpt.AuthPlugin == mysql.AuthTiDBSM3Password || + authOpt.AuthPlugin == mysql.AuthCachingSha2Password || + authOpt.AuthPlugin == "" { return true } return false @@ -818,31 +819,34 @@ func (e *SimpleExec) validateUserNameInPassword(pwd, username string) error { return nil } -func (e *SimpleExec) validatePassword(pwd string) error { - if validatePwdEnable, err := e.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.ValidatePasswordEnable); err != nil { +func (e *SimpleExec) validatePassword(pwd string, currentUser *auth.UserIdentity) error { + globalVars := e.ctx.GetSessionVars().GlobalVarsAccessor + if validatePwdEnable, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordEnable); err != nil { return err } else if !variable.TiDBOptOn(validatePwdEnable) { return nil } runes := []rune(pwd) - // LOW - if validateLengthStr, err := e.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.ValidatePasswordLength); err != nil { + validatePolicy, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordPolicy) + if err != nil { + return err + } + if validateLengthStr, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordLength); err != nil { return err } else if validateLength, err := strconv.ParseInt(validateLengthStr, 10, 64); err != nil { return err } else if (int64)(len(runes)) < validateLength { return ErrNotValidPassword.GenWithStack("Require Password Length: %d", validateLength) } - validatePolicy, err := e.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.ValidatePasswordPolicy) - if err != nil { - return err - } - if err = e.validateUserNameInPassword(pwd, e.ctx.GetSessionVars().User.AuthUsername); err != nil { - return err - } else if err = e.validateUserNameInPassword(pwd, e.ctx.GetSessionVars().User.Username); err != nil { - return err + if currentUser != nil { + if err = e.validateUserNameInPassword(pwd, currentUser.AuthUsername); err != nil { + return err + } else if err = e.validateUserNameInPassword(pwd, currentUser.Username); err != nil { + return err + } } + // LOW if validatePolicy == "LOW" { return nil } @@ -860,7 +864,7 @@ func (e *SimpleExec) validatePassword(pwd string) error { specialCharCount++ } } - if mixedCaseCountStr, err := e.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.ValidatePasswordMixedCaseCount); err != nil { + if mixedCaseCountStr, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordMixedCaseCount); err != nil { return err } else if mixedCaseCount, err := strconv.ParseInt(mixedCaseCountStr, 10, 64); err != nil { return err @@ -869,14 +873,14 @@ func (e *SimpleExec) validatePassword(pwd string) error { } else if upperCaseCount < mixedCaseCount { return ErrNotValidPassword.GenWithStack("Require Password Uppercase Count: %d", mixedCaseCount) } - if requireNumberCountStr, err := e.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.ValidatePasswordNumberCount); err != nil { + if requireNumberCountStr, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordNumberCount); err != nil { return err } else if requireNumberCount, err := strconv.ParseInt(requireNumberCountStr, 10, 64); err != nil { return err } else if numberCount < requireNumberCount { return ErrNotValidPassword.GenWithStack("Require Password Digit Count: %d", requireNumberCount) } - if requireSpecialCharCountStr, err := e.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.ValidatePasswordSpecialCharCount); err != nil { + if requireSpecialCharCountStr, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordSpecialCharCount); err != nil { return err } else if requireSpecialCharCount, err := strconv.ParseInt(requireSpecialCharCountStr, 10, 64); err != nil { return err @@ -986,7 +990,7 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm continue } if e.authUsingCleartextPwd(spec.AuthOpt) { - if err := e.validatePassword(spec.AuthOpt.AuthString); err != nil { + if err := e.validatePassword(spec.AuthOpt.AuthString, e.ctx.GetSessionVars().User); err != nil { return err } } @@ -1204,7 +1208,7 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt) return ErrPluginIsNotLoaded.GenWithStackByArgs(spec.AuthOpt.AuthPlugin) } if e.authUsingCleartextPwd(spec.AuthOpt) { - if err := e.validatePassword(spec.AuthOpt.AuthString); err != nil { + if err := e.validatePassword(spec.AuthOpt.AuthString, e.ctx.GetSessionVars().User); err != nil { return err } } @@ -1724,7 +1728,7 @@ func (e *SimpleExec) executeSetPwd(ctx context.Context, s *ast.SetPwdStmt) error if err != nil { return err } - if err := e.validatePassword(s.Password); err != nil { + if err := e.validatePassword(s.Password, e.ctx.GetSessionVars().User); err != nil { return err } var pwd string diff --git a/executor/simple_test.go b/executor/simple_test.go index 8b284fb9b42e5..ffe76d4dfdde6 100644 --- a/executor/simple_test.go +++ b/executor/simple_test.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/tidb/server" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/util" + validator "github.com/pingcap/tidb/util/validate-password" "github.com/stretchr/testify/require" tikvutil "github.com/tikv/client-go/v2/util" ) @@ -122,3 +123,43 @@ func TestUserAttributes(t *testing.T) { testkit.Rows("root % ", "testuser % {\"comment\": \"1234\"}", "testuser1 % {\"age\": 20, \"name\": \"Tom\"}", "testuser2 % ")) tk.MustGetErrCode(`SELECT user, host, user_attributes FROM mysql.user ORDER BY user`, mysql.ErrTableaccessDenied) } + +func TestValidatePassword(t *testing.T) { + store, _ := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + require.NoError(t, tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil)) + + tk.MustQuery("SELECT @@global.validate_password.enable").Check(testkit.Rows("0")) + tk.MustExec("DROP USER IF EXISTS testuser") + tk.MustExec("CREATE USER testuser IDENTIFIED BY '12345678'") + tk.MustExec("SET GLOBAL validate_password.enable = 1") + tk.MustQuery("SELECT @@global.validate_password.enable").Check(testkit.Rows("1")) + + tk.MustExec("SET GLOBAL validate_password.policy = 'LOW'") + + // check user name + tk.MustQuery("SELECT @@global.validate_password.check_user_name").Check(testkit.Rows("1")) + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY 'abcdroot1234'", "Password Contains User Name") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY 'abcdtoor1234'", "Password Contains Reversed User Name") + + // LOW: Length + tk.MustQuery("SELECT @@global.validate_password.length").Check(testkit.Rows("8")) + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '1234567'", "Require Password Length: 8") + + // MEDIUM: Length; numeric, lowercase/uppercase, and special characters + tk.MustExec("SET GLOBAL validate_password.policy = 'MEDIUM'") + tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abc1234567'") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!ABC1234567'", "Require Password Lowercase Count: 1") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!abc1234567'", "Require Password Uppercase Count: 1") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!ABCDabcd'", "Require Password Digit Count: 1") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY 'Abc1234567'", "Require Password Non-alphanumeric Count: 1") + + // STRONG: Length; numeric, lowercase/uppercase, and special characters; dictionary file + tk.MustExec("SET GLOBAL validate_password.policy = 'STRONG'") + dictFile, err := validator.CreateTmpDictWithContent("3.dict", []byte("1234\n5678")) + require.NoError(t, err) + tk.MustExec(fmt.Sprintf("SET GLOBAL validate_password.dictionary_file = '%s'", dictFile)) + tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abc123567'") + tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abc43218765'") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!Abc1234567'", "Password contains word in the dictionary") +} diff --git a/util/validate-password/BUILD.bazel b/util/validate-password/BUILD.bazel index 498bea77f5b68..6ea767148ebb7 100644 --- a/util/validate-password/BUILD.bazel +++ b/util/validate-password/BUILD.bazel @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "validate-password", @@ -11,3 +11,10 @@ go_library( "@com_github_pingcap_errors//:errors", ], ) + +go_test( + name = "validate-password_test", + srcs = ["dictionary_test.go"], + embed = [":validate-password"], + deps = ["@com_github_stretchr_testify//require"], +) diff --git a/util/validate-password/dictionary.go b/util/validate-password/dictionary.go index 1b20b392fbd6e..437fdbb3e04b6 100644 --- a/util/validate-password/dictionary.go +++ b/util/validate-password/dictionary.go @@ -31,6 +31,9 @@ type dictionaryImpl struct { m sync.RWMutex } +const maxPwdLength int = 100 +const minPwdLength int = 4 + var dictionary = dictionaryImpl{cache: make(map[string]struct{})} // UpdateDictionaryFile update the dictionary for validating password. @@ -50,7 +53,7 @@ func UpdateDictionaryFile(filePath string) error { s := bufio.NewScanner(file) for s.Scan() { line := strings.ToLower(string(hack.String(s.Bytes()))) - if len(line) >= 4 && len(line) <= 100 { + if len(line) >= minPwdLength && len(line) <= maxPwdLength { newDictionary[line] = struct{}{} } } @@ -68,9 +71,9 @@ func ValidateDictionaryPassword(pwd string) bool { if len(dictionary.cache) == 0 { return true } - pwdLen := len(pwd) - for subStrLen := mathutil.Min(100, pwdLen); subStrLen >= 4; subStrLen-- { - for subStrPos := 0; subStrPos+subStrLen <= pwdLen; subStrPos++ { + pwdLength := len(pwd) + for subStrLen := mathutil.Min(maxPwdLength, pwdLength); subStrLen >= minPwdLength; subStrLen-- { + for subStrPos := 0; subStrPos+subStrLen <= pwdLength; subStrPos++ { subStr := pwd[subStrPos : subStrPos+subStrLen] if _, ok := dictionary.cache[subStr]; ok { return false @@ -79,3 +82,39 @@ func ValidateDictionaryPassword(pwd string) bool { } return true } + +// CreateTmpDictWithSize is only used for test. +func CreateTmpDictWithSize(filename string, size int) (string, error) { + filename = filepath.Join(os.TempDir(), filename) + file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY, os.ModePerm) + if err != nil { + return "", err + } + if size > 0 { + n, err := file.Write(make([]byte, size)) + if err != nil { + return "", err + } else if n != size { + return "", errors.New("") + } + } + return filename, file.Close() +} + +// CreateTmpDictWithContent is only used for test. +func CreateTmpDictWithContent(filename string, content []byte) (string, error) { + filename = filepath.Join(os.TempDir(), filename) + file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY, os.ModePerm) + if err != nil { + return "", err + } + if len(content) > 0 { + n, err := file.Write(content) + if err != nil { + return "", err + } else if n != len(content) { + return "", errors.New("") + } + } + return filename, file.Close() +} diff --git a/util/validate-password/dictionary_test.go b/util/validate-password/dictionary_test.go new file mode 100644 index 0000000000000..5b06ffa183b05 --- /dev/null +++ b/util/validate-password/dictionary_test.go @@ -0,0 +1,48 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validator + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUpdateDictionaryFile(t *testing.T) { + tooLargeDict, err := CreateTmpDictWithSize("1.dict", 2*1024*1024) + require.NoError(t, err) + err = UpdateDictionaryFile(tooLargeDict) + require.ErrorContains(t, err, "Too Large Dictionary. The maximum permitted file size is 1MB") + + dict, err := CreateTmpDictWithContent("2.dict", []byte("abc\n1234\n5678")) + require.NoError(t, err) + require.NoError(t, UpdateDictionaryFile(dict)) + _, ok := dictionary.cache["1234"] + require.True(t, ok) + _, ok = dictionary.cache["5678"] + require.True(t, ok) + _, ok = dictionary.cache["abc"] + require.False(t, ok) +} + +func TestValidateDictionaryPassword(t *testing.T) { + dict, err := CreateTmpDictWithContent("3.dict", []byte("1234\n5678")) + require.NoError(t, err) + require.NoError(t, UpdateDictionaryFile(dict)) + require.True(t, ValidateDictionaryPassword("abcdefg")) + require.True(t, ValidateDictionaryPassword("abcd123efg")) + require.False(t, ValidateDictionaryPassword("abcd1234efg")) + require.False(t, ValidateDictionaryPassword("abcd12345efg")) +} From 58460d7b892096db97886e7554f9378f59aaad78 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Wed, 9 Nov 2022 16:35:42 +0800 Subject: [PATCH 07/26] Finish UT --- executor/simple.go | 44 +++++++------ executor/simple_test.go | 98 +++++++++++++++++++--------- sessionctx/variable/sysvar.go | 5 ++ util/validate-password/dictionary.go | 7 ++ 4 files changed, 103 insertions(+), 51 deletions(-) diff --git a/executor/simple.go b/executor/simple.go index 23cd726baabb8..a3c9c4cbc36ad 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -786,14 +786,13 @@ func (e *SimpleExec) executeRollback(s *ast.RollbackStmt) error { return nil } -func (e *SimpleExec) authUsingCleartextPwd(authOpt *ast.AuthOption) bool { +func (e *SimpleExec) authUsingCleartextPwd(authOpt *ast.AuthOption, authPlugin string) bool { if authOpt == nil || !authOpt.ByAuthString { return false } - if authOpt.AuthPlugin == mysql.AuthNativePassword || - authOpt.AuthPlugin == mysql.AuthTiDBSM3Password || - authOpt.AuthPlugin == mysql.AuthCachingSha2Password || - authOpt.AuthPlugin == "" { + if authPlugin == mysql.AuthNativePassword || + authPlugin == mysql.AuthTiDBSM3Password || + authPlugin == mysql.AuthCachingSha2Password { return true } return false @@ -819,13 +818,16 @@ func (e *SimpleExec) validateUserNameInPassword(pwd, username string) error { return nil } +func (e *SimpleExec) enableValidatePassword() bool { + validatePwdEnable, err := e.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.ValidatePasswordEnable) + if err != nil { + return false + } + return variable.TiDBOptOn(validatePwdEnable) +} + func (e *SimpleExec) validatePassword(pwd string, currentUser *auth.UserIdentity) error { globalVars := e.ctx.GetSessionVars().GlobalVarsAccessor - if validatePwdEnable, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordEnable); err != nil { - return err - } else if !variable.TiDBOptOn(validatePwdEnable) { - return nil - } runes := []rune(pwd) validatePolicy, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordPolicy) @@ -989,7 +991,11 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm e.ctx.GetSessionVars().StmtCtx.AppendNote(err) continue } - if e.authUsingCleartextPwd(spec.AuthOpt) { + authPlugin := mysql.AuthNativePassword + if spec.AuthOpt != nil && spec.AuthOpt.AuthPlugin != "" { + authPlugin = spec.AuthOpt.AuthPlugin + } + if e.enableValidatePassword() && e.authUsingCleartextPwd(spec.AuthOpt, authPlugin) { if err := e.validatePassword(spec.AuthOpt.AuthString, e.ctx.GetSessionVars().User); err != nil { return err } @@ -999,10 +1005,6 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm if !ok { return errors.Trace(ErrPasswordFormat) } - authPlugin := mysql.AuthNativePassword - if spec.AuthOpt != nil && spec.AuthOpt.AuthPlugin != "" { - authPlugin = spec.AuthOpt.AuthPlugin - } switch authPlugin { case mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password, mysql.AuthSocket, mysql.AuthTiDBAuthToken: @@ -1191,11 +1193,11 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt) var fields []alterField if spec.AuthOpt != nil { if spec.AuthOpt.AuthPlugin == "" { - authplugin, err := e.userAuthPlugin(spec.User.Username, spec.User.Hostname) + curAuthplugin, err := e.userAuthPlugin(spec.User.Username, spec.User.Hostname) if err != nil { return err } - spec.AuthOpt.AuthPlugin = authplugin + spec.AuthOpt.AuthPlugin = curAuthplugin } switch spec.AuthOpt.AuthPlugin { case mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password, mysql.AuthSocket, "": @@ -1207,7 +1209,7 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt) default: return ErrPluginIsNotLoaded.GenWithStackByArgs(spec.AuthOpt.AuthPlugin) } - if e.authUsingCleartextPwd(spec.AuthOpt) { + if e.enableValidatePassword() && e.authUsingCleartextPwd(spec.AuthOpt, spec.AuthOpt.AuthPlugin) { if err := e.validatePassword(spec.AuthOpt.AuthString, e.ctx.GetSessionVars().User); err != nil { return err } @@ -1728,8 +1730,10 @@ func (e *SimpleExec) executeSetPwd(ctx context.Context, s *ast.SetPwdStmt) error if err != nil { return err } - if err := e.validatePassword(s.Password, e.ctx.GetSessionVars().User); err != nil { - return err + if e.enableValidatePassword() { + if err := e.validatePassword(s.Password, e.ctx.GetSessionVars().User); err != nil { + return err + } } var pwd string switch authplugin { diff --git a/executor/simple_test.go b/executor/simple_test.go index ffe76d4dfdde6..b72fe0b9542ef 100644 --- a/executor/simple_test.go +++ b/executor/simple_test.go @@ -125,41 +125,77 @@ func TestUserAttributes(t *testing.T) { } func TestValidatePassword(t *testing.T) { + // Some test cases come from mysql-server/mysql-test: + // t/validate_password_component.test + // t/validate_password_component_check_user.test + store, _ := testkit.CreateMockStoreAndDomain(t) tk := testkit.NewTestKit(t, store) require.NoError(t, tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil)) - tk.MustQuery("SELECT @@global.validate_password.enable").Check(testkit.Rows("0")) - tk.MustExec("DROP USER IF EXISTS testuser") - tk.MustExec("CREATE USER testuser IDENTIFIED BY '12345678'") - tk.MustExec("SET GLOBAL validate_password.enable = 1") - tk.MustQuery("SELECT @@global.validate_password.enable").Check(testkit.Rows("1")) - - tk.MustExec("SET GLOBAL validate_password.policy = 'LOW'") - - // check user name - tk.MustQuery("SELECT @@global.validate_password.check_user_name").Check(testkit.Rows("1")) - tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY 'abcdroot1234'", "Password Contains User Name") - tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY 'abcdtoor1234'", "Password Contains Reversed User Name") - - // LOW: Length - tk.MustQuery("SELECT @@global.validate_password.length").Check(testkit.Rows("8")) - tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '1234567'", "Require Password Length: 8") - - // MEDIUM: Length; numeric, lowercase/uppercase, and special characters - tk.MustExec("SET GLOBAL validate_password.policy = 'MEDIUM'") - tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abc1234567'") - tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!ABC1234567'", "Require Password Lowercase Count: 1") - tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!abc1234567'", "Require Password Uppercase Count: 1") - tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!ABCDabcd'", "Require Password Digit Count: 1") - tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY 'Abc1234567'", "Require Password Non-alphanumeric Count: 1") - - // STRONG: Length; numeric, lowercase/uppercase, and special characters; dictionary file - tk.MustExec("SET GLOBAL validate_password.policy = 'STRONG'") + authPlugins := []string{mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password} dictFile, err := validator.CreateTmpDictWithContent("3.dict", []byte("1234\n5678")) require.NoError(t, err) - tk.MustExec(fmt.Sprintf("SET GLOBAL validate_password.dictionary_file = '%s'", dictFile)) - tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abc123567'") - tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abc43218765'") - tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!Abc1234567'", "Password contains word in the dictionary") + tk.MustExec("CREATE USER \"\"@localhost") + tk.MustExec("GRANT ALL PRIVILEGES ON test.* TO \"\"@localhost;") + + for _, authPlugin := range authPlugins { + tk.MustQuery("SELECT @@global.validate_password.enable").Check(testkit.Rows("0")) + tk.MustExec("DROP USER IF EXISTS testuser") + tk.MustExec(fmt.Sprintf("CREATE USER testuser IDENTIFIED WITH %s BY '12345678'", authPlugin)) + tk.MustExec("SET GLOBAL validate_password.enable = 1") + tk.MustQuery("SELECT @@global.validate_password.enable").Check(testkit.Rows("1")) + + tk.MustExec("SET GLOBAL validate_password.policy = 'LOW'") + + // check user name + tk.MustQuery("SELECT @@global.validate_password.check_user_name").Check(testkit.Rows("1")) + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!Abcdroot1234'", "Password Contains User Name") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!Abcdtoor1234'", "Password Contains Reversed User Name") + tk.MustExec("SET PASSWORD FOR 'testuser' = 'testuser'") // password the same as the user name, but run by root + tk.MustExec("ALTER USER testuser IDENTIFIED BY 'testuser'") + + // LOW: Length + tk.MustQuery("SELECT @@global.validate_password.length").Check(testkit.Rows("8")) + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '1234567'", "Require Password Length: 8") + tk.MustExec("SET GLOBAL validate_password.length = 12") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!Abcdefg123'", "Require Password Length: 12") + tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abcdefg1234'") + tk.MustExec("SET GLOBAL validate_password.length = 8") + + // MEDIUM: Length; numeric, lowercase/uppercase, and special characters + tk.MustExec("SET GLOBAL validate_password.policy = 'MEDIUM'") + tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abc1234567'") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!ABC1234567'", "Require Password Lowercase Count: 1") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!abc1234567'", "Require Password Uppercase Count: 1") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!ABCDabcd'", "Require Password Digit Count: 1") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY 'Abc1234567'", "Require Password Non-alphanumeric Count: 1") + tk.MustExec("SET GLOBAL validate_password.special_char_count = 0") + tk.MustExec("ALTER USER testuser IDENTIFIED BY 'Abc1234567'") + tk.MustExec("SET GLOBAL validate_password.special_char_count = 1") + tk.MustContainErrMsg("SET GLOBAL validate_password.length = 3", "Variable 'validate_password.length' can't be set to the value of '3'") + + // STRONG: Length; numeric, lowercase/uppercase, and special characters; dictionary file + tk.MustExec("SET GLOBAL validate_password.policy = 'STRONG'") + tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abc1234567'") + tk.MustExec(fmt.Sprintf("SET GLOBAL validate_password.dictionary_file = '%s'", dictFile)) + tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abc123567'") + tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abc43218765'") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!Abc1234567'", "Password contains word in the dictionary") + tk.MustExec("SET GLOBAL validate_password.dictionary_file = ''") + tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abc1234567'") + + // "IDENTIFIED AS 'xxx'" is not affected by validation + tk.MustExec(fmt.Sprintf("ALTER USER testuser IDENTIFIED WITH '%s' AS ''", authPlugin)) + + // if the username is '', all password can pass the check_user_name + subtk := testkit.NewTestKit(t, store) + require.NoError(t, subtk.Session().Auth(&auth.UserIdentity{Hostname: "localhost"}, nil, nil)) + subtk.MustQuery("SELECT user(), current_user()").Check(testkit.Rows("@localhost @localhost")) + subtk.MustQuery("SELECT @@global.validate_password.check_user_name").Check(testkit.Rows("1")) + subtk.MustQuery("SELECT @@global.validate_password.enable").Check(testkit.Rows("1")) + subtk.MustExec("ALTER USER ''@localhost IDENTIFIED BY ''") + + tk.MustExec("SET GLOBAL validate_password.enable = 0") + } } diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 652c0e986ea09..95cfc5ea52d67 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -520,6 +520,11 @@ var defaultSysVars = []*SysVar{ {Scope: ScopeGlobal, Name: ValidatePasswordNumberCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt64}, {Scope: ScopeGlobal, Name: ValidatePasswordSpecialCharCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt64}, {Scope: ScopeGlobal, Name: ValidatePasswordDictionaryFile, Value: "", Type: TypeStr, SetGlobal: func(_ context.Context, s *SessionVars, val string) error { + // Use 'SET @@global.validate_password.dictionary_file = ""' to clean the dictionary. + if len(val) == 0 { + validatePwd.Clean() + return nil + } return validatePwd.UpdateDictionaryFile(val) }}, diff --git a/util/validate-password/dictionary.go b/util/validate-password/dictionary.go index 437fdbb3e04b6..d9f049457263b 100644 --- a/util/validate-password/dictionary.go +++ b/util/validate-password/dictionary.go @@ -36,6 +36,13 @@ const minPwdLength int = 4 var dictionary = dictionaryImpl{cache: make(map[string]struct{})} +// Clean removes all the words in the dictionary. +func Clean() { + dictionary.m.Lock() + defer dictionary.m.Unlock() + dictionary.cache = make(map[string]struct{}) +} + // UpdateDictionaryFile update the dictionary for validating password. func UpdateDictionaryFile(filePath string) error { dictionary.m.Lock() From 44ec82e7e3130b5002ce6d4e032e34ab8dca47c2 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Wed, 9 Nov 2022 17:42:42 +0800 Subject: [PATCH 08/26] Update UT --- executor/simple_test.go | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/executor/simple_test.go b/executor/simple_test.go index b72fe0b9542ef..c072953110c99 100644 --- a/executor/simple_test.go +++ b/executor/simple_test.go @@ -136,18 +136,15 @@ func TestValidatePassword(t *testing.T) { authPlugins := []string{mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password} dictFile, err := validator.CreateTmpDictWithContent("3.dict", []byte("1234\n5678")) require.NoError(t, err) - tk.MustExec("CREATE USER \"\"@localhost") - tk.MustExec("GRANT ALL PRIVILEGES ON test.* TO \"\"@localhost;") + tk.MustQuery("SELECT @@global.validate_password.enable").Check(testkit.Rows("0")) + tk.MustExec("SET GLOBAL validate_password.enable = 1") + tk.MustQuery("SELECT @@global.validate_password.enable").Check(testkit.Rows("1")) for _, authPlugin := range authPlugins { - tk.MustQuery("SELECT @@global.validate_password.enable").Check(testkit.Rows("0")) tk.MustExec("DROP USER IF EXISTS testuser") - tk.MustExec(fmt.Sprintf("CREATE USER testuser IDENTIFIED WITH %s BY '12345678'", authPlugin)) - tk.MustExec("SET GLOBAL validate_password.enable = 1") - tk.MustQuery("SELECT @@global.validate_password.enable").Check(testkit.Rows("1")) + tk.MustExec(fmt.Sprintf("CREATE USER testuser IDENTIFIED WITH %s BY '!Abc12345678'", authPlugin)) tk.MustExec("SET GLOBAL validate_password.policy = 'LOW'") - // check user name tk.MustQuery("SELECT @@global.validate_password.check_user_name").Check(testkit.Rows("1")) tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!Abcdroot1234'", "Password Contains User Name") @@ -187,15 +184,18 @@ func TestValidatePassword(t *testing.T) { // "IDENTIFIED AS 'xxx'" is not affected by validation tk.MustExec(fmt.Sprintf("ALTER USER testuser IDENTIFIED WITH '%s' AS ''", authPlugin)) - - // if the username is '', all password can pass the check_user_name - subtk := testkit.NewTestKit(t, store) - require.NoError(t, subtk.Session().Auth(&auth.UserIdentity{Hostname: "localhost"}, nil, nil)) - subtk.MustQuery("SELECT user(), current_user()").Check(testkit.Rows("@localhost @localhost")) - subtk.MustQuery("SELECT @@global.validate_password.check_user_name").Check(testkit.Rows("1")) - subtk.MustQuery("SELECT @@global.validate_password.enable").Check(testkit.Rows("1")) - subtk.MustExec("ALTER USER ''@localhost IDENTIFIED BY ''") - - tk.MustExec("SET GLOBAL validate_password.enable = 0") } + + // if the username is '', all password can pass the check_user_name + tk.MustExec("CREATE USER ''@'localhost'") + tk.MustExec("GRANT ALL PRIVILEGES ON mysql.* TO ''@'localhost';") + subtk := testkit.NewTestKit(t, store) + require.NoError(t, subtk.Session().Auth(&auth.UserIdentity{Hostname: "localhost"}, nil, nil)) + subtk.MustQuery("SELECT user(), current_user()").Check(testkit.Rows("@localhost @localhost")) + subtk.MustQuery("SELECT @@global.validate_password.check_user_name").Check(testkit.Rows("1")) + subtk.MustQuery("SELECT @@global.validate_password.enable").Check(testkit.Rows("1")) + subtk.MustExec("ALTER USER ''@'localhost' IDENTIFIED BY ''") + subtk.MustExec("ALTER USER ''@'localhost' IDENTIFIED BY 'abcd'") + + tk.MustExec("SET GLOBAL validate_password.enable = 1") } From d05d9ca116181bc588e47929db87129b41a8530b Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Fri, 11 Nov 2022 18:36:48 +0800 Subject: [PATCH 09/26] TODO --- executor/simple.go | 111 ++------------------ executor/simple_test.go | 12 ++- expression/builtin_encryption.go | 63 +++++++++++- expression/builtin_encryption_vec.go | 9 ++ util/validate-password/dictionary.go | 118 +++++++++++++++++++++- util/validate-password/dictionary_test.go | 8 +- util/validate-password/errors.go | 25 +++++ 7 files changed, 231 insertions(+), 115 deletions(-) create mode 100644 util/validate-password/errors.go diff --git a/executor/simple.go b/executor/simple.go index a3c9c4cbc36ad..69b6925db89ea 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -19,13 +19,6 @@ import ( "context" "encoding/json" "fmt" - "os" - "strconv" - "strings" - "syscall" - "time" - "unicode" - "github.com/ngaut/pools" "github.com/pingcap/errors" "github.com/pingcap/tidb/config" @@ -61,6 +54,10 @@ import ( "github.com/pingcap/tipb/go-tipb" tikvutil "github.com/tikv/client-go/v2/util" "go.uber.org/zap" + "os" + "strings" + "syscall" + "time" ) var ( @@ -798,26 +795,6 @@ func (e *SimpleExec) authUsingCleartextPwd(authOpt *ast.AuthOption, authPlugin s return false } -func (e *SimpleExec) validateUserNameInPassword(pwd, username string) error { - pwdBytes := hack.Slice(pwd) - usernameBytes := hack.Slice(username) - userNameLen := len(usernameBytes) - if userNameLen == 0 { - return nil - } - if bytes.Contains(pwdBytes, usernameBytes) { - return ErrNotValidPassword.GenWithStack("Password Contains User Name") - } - usernameReversedBytes := make([]byte, userNameLen) - for i := range usernameBytes { - usernameReversedBytes[i] = usernameBytes[userNameLen-1-i] - } - if bytes.Contains(pwdBytes, usernameReversedBytes) { - return ErrNotValidPassword.GenWithStack("Password Contains Reversed User Name") - } - return nil -} - func (e *SimpleExec) enableValidatePassword() bool { validatePwdEnable, err := e.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.ValidatePasswordEnable) if err != nil { @@ -826,80 +803,6 @@ func (e *SimpleExec) enableValidatePassword() bool { return variable.TiDBOptOn(validatePwdEnable) } -func (e *SimpleExec) validatePassword(pwd string, currentUser *auth.UserIdentity) error { - globalVars := e.ctx.GetSessionVars().GlobalVarsAccessor - - runes := []rune(pwd) - validatePolicy, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordPolicy) - if err != nil { - return err - } - if validateLengthStr, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordLength); err != nil { - return err - } else if validateLength, err := strconv.ParseInt(validateLengthStr, 10, 64); err != nil { - return err - } else if (int64)(len(runes)) < validateLength { - return ErrNotValidPassword.GenWithStack("Require Password Length: %d", validateLength) - } - if currentUser != nil { - if err = e.validateUserNameInPassword(pwd, currentUser.AuthUsername); err != nil { - return err - } else if err = e.validateUserNameInPassword(pwd, currentUser.Username); err != nil { - return err - } - } - // LOW - if validatePolicy == "LOW" { - return nil - } - - // MEDIUM - var lowerCaseCount, upperCaseCount, numberCount, specialCharCount int64 - for _, r := range runes { - if unicode.IsUpper(r) { - upperCaseCount++ - } else if unicode.IsLower(r) { - lowerCaseCount++ - } else if unicode.IsDigit(r) { - numberCount++ - } else { - specialCharCount++ - } - } - if mixedCaseCountStr, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordMixedCaseCount); err != nil { - return err - } else if mixedCaseCount, err := strconv.ParseInt(mixedCaseCountStr, 10, 64); err != nil { - return err - } else if lowerCaseCount < mixedCaseCount { - return ErrNotValidPassword.GenWithStack("Require Password Lowercase Count: %d", mixedCaseCount) - } else if upperCaseCount < mixedCaseCount { - return ErrNotValidPassword.GenWithStack("Require Password Uppercase Count: %d", mixedCaseCount) - } - if requireNumberCountStr, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordNumberCount); err != nil { - return err - } else if requireNumberCount, err := strconv.ParseInt(requireNumberCountStr, 10, 64); err != nil { - return err - } else if numberCount < requireNumberCount { - return ErrNotValidPassword.GenWithStack("Require Password Digit Count: %d", requireNumberCount) - } - if requireSpecialCharCountStr, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordSpecialCharCount); err != nil { - return err - } else if requireSpecialCharCount, err := strconv.ParseInt(requireSpecialCharCountStr, 10, 64); err != nil { - return err - } else if specialCharCount < requireSpecialCharCount { - return ErrNotValidPassword.GenWithStack("Require Password Non-alphanumeric Count: %d", requireSpecialCharCount) - } - if validatePolicy == "MEDIUM" { - return nil - } - - // STRONG - if !validatePwd.ValidateDictionaryPassword(pwd) { - return ErrNotValidPassword.GenWithStack("Password contains word in the dictionary") - } - return nil -} - func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStmt) error { internalCtx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnPrivilege) // Check `CREATE USER` privilege. @@ -996,7 +899,7 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm authPlugin = spec.AuthOpt.AuthPlugin } if e.enableValidatePassword() && e.authUsingCleartextPwd(spec.AuthOpt, authPlugin) { - if err := e.validatePassword(spec.AuthOpt.AuthString, e.ctx.GetSessionVars().User); err != nil { + if err := validatePwd.ValidatePassword(e.ctx.GetSessionVars(), spec.AuthOpt.AuthString); err != nil { return err } } @@ -1210,7 +1113,7 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt) return ErrPluginIsNotLoaded.GenWithStackByArgs(spec.AuthOpt.AuthPlugin) } if e.enableValidatePassword() && e.authUsingCleartextPwd(spec.AuthOpt, spec.AuthOpt.AuthPlugin) { - if err := e.validatePassword(spec.AuthOpt.AuthString, e.ctx.GetSessionVars().User); err != nil { + if err := validatePwd.ValidatePassword(e.ctx.GetSessionVars(), spec.AuthOpt.AuthString); err != nil { return err } } @@ -1731,7 +1634,7 @@ func (e *SimpleExec) executeSetPwd(ctx context.Context, s *ast.SetPwdStmt) error return err } if e.enableValidatePassword() { - if err := e.validatePassword(s.Password, e.ctx.GetSessionVars().User); err != nil { + if err := validatePwd.ValidatePassword(e.ctx.GetSessionVars(), s.Password); err != nil { return err } } diff --git a/executor/simple_test.go b/executor/simple_test.go index c072953110c99..f4df110938f05 100644 --- a/executor/simple_test.go +++ b/executor/simple_test.go @@ -147,16 +147,20 @@ func TestValidatePassword(t *testing.T) { tk.MustExec("SET GLOBAL validate_password.policy = 'LOW'") // check user name tk.MustQuery("SELECT @@global.validate_password.check_user_name").Check(testkit.Rows("1")) - tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!Abcdroot1234'", "Password Contains User Name") - tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!Abcdtoor1234'", "Password Contains Reversed User Name") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!Abcdroot1234'", "Password Contains (Reversed) User Name") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!Abcdtoor1234'", "Password Contains (Reversed) User Name") tk.MustExec("SET PASSWORD FOR 'testuser' = 'testuser'") // password the same as the user name, but run by root tk.MustExec("ALTER USER testuser IDENTIFIED BY 'testuser'") + tk.MustExec("SET GLOBAL validate_password.check_user_name = 0") + tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abcdroot1234'") + tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abcdtoor1234'") + tk.MustExec("SET GLOBAL validate_password.check_user_name = 1") // LOW: Length tk.MustQuery("SELECT @@global.validate_password.length").Check(testkit.Rows("8")) - tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '1234567'", "Require Password Length: 8") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '1234567'", "Require Password Length") tk.MustExec("SET GLOBAL validate_password.length = 12") - tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!Abcdefg123'", "Require Password Length: 12") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!Abcdefg123'", "Require Password Length") tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abcdefg1234'") tk.MustExec("SET GLOBAL validate_password.length = 8") diff --git a/expression/builtin_encryption.go b/expression/builtin_encryption.go index a206a9d4970bb..aca16c718b8ae 100644 --- a/expression/builtin_encryption.go +++ b/expression/builtin_encryption.go @@ -25,6 +25,7 @@ import ( "crypto/sha512" "encoding/binary" "fmt" + validatePwd "github.com/pingcap/tidb/util/validate-password" "hash" "io" "strings" @@ -73,6 +74,7 @@ var ( _ builtinFunc = &builtinSHA2Sig{} _ builtinFunc = &builtinUncompressSig{} _ builtinFunc = &builtinUncompressedLengthSig{} + _ builtinFunc = &builtinValidatePasswordStrengthSig{} ) // aesModeAttr indicates that the key length and iv attribute for specific block_encryption_mode. @@ -1010,5 +1012,64 @@ type validatePasswordStrengthFunctionClass struct { } func (c *validatePasswordStrengthFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { - return nil, errFunctionNotExists.GenWithStackByArgs("FUNCTION", "VALIDATE_PASSWORD_STRENGTH") + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETString) + if err != nil { + return nil, err + } + charset, collate := ctx.GetSessionVars().GetCharsetInfo() + bf.tp.SetCharset(charset) + bf.tp.SetCollate(collate) + bf.tp.SetFlen(args[0].GetType().GetFlen()) + sig := &builtinValidatePasswordStrengthSig{bf} + //sig.setPbCode(tipb.ScalarFuncSig_ValidatePasswordStrength) + return sig, nil +} + +type builtinValidatePasswordStrengthSig struct { + baseBuiltinFunc +} + +func (b *builtinValidatePasswordStrengthSig) Clone() builtinFunc { + newSig := &builtinValidatePasswordStrengthSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals VALIDATE_PASSWORD_STRENGTH(str). +// See https://dev.mysql.com/doc/refman/8.0/en/encryption-functions.html#function_validate-password-strength +func (b *builtinValidatePasswordStrengthSig) evalInt(row chunk.Row) (int64, bool, error) { + globalVars := b.ctx.GetSessionVars().GlobalVarsAccessor + if validation, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordEnable); err != nil { + return 0, true, err + } else if !variable.TiDBOptOn(validation) { + return 0, false, nil + } + + str, isNull, err := b.args[0].EvalString(b.ctx, row) + if err != nil { + return 0, true, err + } else if isNull { + return 0, true, nil + } else if len(str) < 4 { + return 0, false, nil + } + + if ok, err := validatePwd.ValidateUserNameInPassword(str, b.ctx.GetSessionVars()); err != nil { + return 0, true, err + } else if !ok { + return 0, false, nil + } + + if ok, err := validatePwd.ValidateLow(str, &globalVars); err != nil { + return 0, false, err + } else if !ok { + return 25, false, nil + } + + // TODO + + return 100, false, nil } diff --git a/expression/builtin_encryption_vec.go b/expression/builtin_encryption_vec.go index e9a1d45ae67be..f5bac5660ad99 100644 --- a/expression/builtin_encryption_vec.go +++ b/expression/builtin_encryption_vec.go @@ -863,3 +863,12 @@ func (b *builtinUncompressedLengthSig) vecEvalInt(input *chunk.Chunk, result *ch } return nil } + +func (b *builtinValidatePasswordStrengthSig) vectorized() bool { + return true +} + +func (b *builtinValidatePasswordStrengthSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) error { + // TODO + return nil +} diff --git a/util/validate-password/dictionary.go b/util/validate-password/dictionary.go index d9f049457263b..e5f03becbe180 100644 --- a/util/validate-password/dictionary.go +++ b/util/validate-password/dictionary.go @@ -16,10 +16,14 @@ package validator import ( "bufio" + "bytes" + "github.com/pingcap/tidb/sessionctx/variable" "os" "path/filepath" + "strconv" "strings" "sync" + "unicode" "github.com/pingcap/errors" "github.com/pingcap/tidb/util/hack" @@ -71,8 +75,8 @@ func UpdateDictionaryFile(filePath string) error { return file.Close() } -// ValidateDictionaryPassword checks if the password contains words in the dictionary. -func ValidateDictionaryPassword(pwd string) bool { +// validateDictionaryPassword checks if the password contains words in the dictionary. +func validateDictionaryPassword(pwd string) bool { dictionary.m.RLock() defer dictionary.m.RUnlock() if len(dictionary.cache) == 0 { @@ -125,3 +129,113 @@ func CreateTmpDictWithContent(filename string, content []byte) (string, error) { } return filename, file.Close() } + +func ValidateUserNameInPassword(pwd string, sessionVars *variable.SessionVars) (bool, error) { + currentUser := sessionVars.User + globalVars := sessionVars.GlobalVarsAccessor + pwdBytes := hack.Slice(pwd) + if checkUserName, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordCheckUserName); err != nil { + return false, err + } else if currentUser != nil && variable.TiDBOptOn(checkUserName) { + for _, username := range []string{currentUser.AuthUsername, currentUser.Username} { + usernameBytes := hack.Slice(username) + userNameLen := len(usernameBytes) + if userNameLen == 0 { + continue + } + if bytes.Contains(pwdBytes, usernameBytes) { + return false, nil + } + usernameReversedBytes := make([]byte, userNameLen) + for i := range usernameBytes { + usernameReversedBytes[i] = usernameBytes[userNameLen-1-i] + } + if bytes.Contains(pwdBytes, usernameReversedBytes) { + return false, nil + } + } + } + return true, nil +} + +func ValidateLow(pwd string, globalVars *variable.GlobalVarAccessor) (bool, error) { + if validateLengthStr, err := (*globalVars).GetGlobalSysVar(variable.ValidatePasswordLength); err != nil { + return false, err + } else if validateLength, err := strconv.ParseInt(validateLengthStr, 10, 64); err != nil { + return false, err + } else if (int64)(len([]rune(pwd))) < validateLength { + return false, nil + } + return true, nil +} + +func ValidatePassword(sessionVars *variable.SessionVars, pwd string) error { + globalVars := sessionVars.GlobalVarsAccessor + + runes := []rune(pwd) + validatePolicy, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordPolicy) + if err != nil { + return err + } + if ok, err := ValidateUserNameInPassword(pwd, sessionVars); err != nil { + return err + } else if !ok { + return ErrNotValidPassword.GenWithStack("Password Contains (Reversed) User Name") + } + if ok, err := ValidateLow(pwd, &globalVars); err != nil { + return err + } else if !ok { + return ErrNotValidPassword.GenWithStack("Require Password Length") + } + + // LOW + if validatePolicy == "LOW" { + return nil + } + + // MEDIUM + var lowerCaseCount, upperCaseCount, numberCount, specialCharCount int64 + for _, r := range runes { + if unicode.IsUpper(r) { + upperCaseCount++ + } else if unicode.IsLower(r) { + lowerCaseCount++ + } else if unicode.IsDigit(r) { + numberCount++ + } else { + specialCharCount++ + } + } + if mixedCaseCountStr, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordMixedCaseCount); err != nil { + return err + } else if mixedCaseCount, err := strconv.ParseInt(mixedCaseCountStr, 10, 64); err != nil { + return err + } else if lowerCaseCount < mixedCaseCount { + return ErrNotValidPassword.GenWithStack("Require Password Lowercase Count: %d", mixedCaseCount) + } else if upperCaseCount < mixedCaseCount { + return ErrNotValidPassword.GenWithStack("Require Password Uppercase Count: %d", mixedCaseCount) + } + if requireNumberCountStr, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordNumberCount); err != nil { + return err + } else if requireNumberCount, err := strconv.ParseInt(requireNumberCountStr, 10, 64); err != nil { + return err + } else if numberCount < requireNumberCount { + return ErrNotValidPassword.GenWithStack("Require Password Digit Count: %d", requireNumberCount) + } + if requireSpecialCharCountStr, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordSpecialCharCount); err != nil { + return err + } else if requireSpecialCharCount, err := strconv.ParseInt(requireSpecialCharCountStr, 10, 64); err != nil { + return err + } else if specialCharCount < requireSpecialCharCount { + return ErrNotValidPassword.GenWithStack("Require Password Non-alphanumeric Count: %d", requireSpecialCharCount) + } + if validatePolicy == "MEDIUM" { + return nil + } + + // STRONG + if !validateDictionaryPassword(pwd) { + return ErrNotValidPassword.GenWithStack("Password contains word in the dictionary") + } + return nil +} diff --git a/util/validate-password/dictionary_test.go b/util/validate-password/dictionary_test.go index 5b06ffa183b05..87ada4ef30c9c 100644 --- a/util/validate-password/dictionary_test.go +++ b/util/validate-password/dictionary_test.go @@ -41,8 +41,8 @@ func TestValidateDictionaryPassword(t *testing.T) { dict, err := CreateTmpDictWithContent("3.dict", []byte("1234\n5678")) require.NoError(t, err) require.NoError(t, UpdateDictionaryFile(dict)) - require.True(t, ValidateDictionaryPassword("abcdefg")) - require.True(t, ValidateDictionaryPassword("abcd123efg")) - require.False(t, ValidateDictionaryPassword("abcd1234efg")) - require.False(t, ValidateDictionaryPassword("abcd12345efg")) + require.True(t, validateDictionaryPassword("abcdefg")) + require.True(t, validateDictionaryPassword("abcd123efg")) + require.False(t, validateDictionaryPassword("abcd1234efg")) + require.False(t, validateDictionaryPassword("abcd12345efg")) } diff --git a/util/validate-password/errors.go b/util/validate-password/errors.go new file mode 100644 index 0000000000000..bd301aa8bf02e --- /dev/null +++ b/util/validate-password/errors.go @@ -0,0 +1,25 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validator + +import ( + mysql "github.com/pingcap/tidb/errno" + "github.com/pingcap/tidb/util/dbterror" +) + +// Error instances. +var ( + ErrNotValidPassword = dbterror.ClassExecutor.NewStd(mysql.ErrNotValidPassword) +) From 37b9aabcc3a4a7e2e38aa76f819354057a42c8cf Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Mon, 14 Nov 2022 21:03:42 +0800 Subject: [PATCH 10/26] TODO: add UT in builtin_encryption_vec.go and integration_test.go --- ddl/ddl_api.go | 2 +- ddl/schematracker/dm_tracker.go | 2 +- executor/BUILD.bazel | 2 - executor/adapter.go | 2 +- executor/errors.go | 1 - executor/partition_table_test.go | 2 +- executor/simple.go | 7 +- executor/simple_test.go | 12 +- expression/builtin_encryption.go | 44 ++--- expression/builtin_encryption_test.go | 53 ++++++ expression/builtin_encryption_vec.go | 33 +++- expression/integration_test.go | 2 + sessionctx/binloginfo/binloginfo_test.go | 2 +- sessionctx/variable/BUILD.bazel | 4 +- sessionctx/variable/error.go | 1 + .../variable/password_dictionary.go | 153 +++++++++--------- .../variable/password_dictionary_test.go | 22 +-- sessionctx/variable/session.go | 2 +- sessionctx/variable/sysvar.go | 5 +- util/validate-password/BUILD.bazel | 20 --- util/validate-password/errors.go | 25 --- 21 files changed, 221 insertions(+), 175 deletions(-) rename util/validate-password/dictionary.go => sessionctx/variable/password_dictionary.go (59%) rename util/validate-password/dictionary_test.go => sessionctx/variable/password_dictionary_test.go (69%) delete mode 100644 util/validate-password/BUILD.bazel delete mode 100644 util/validate-password/errors.go diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index f60a80ece0d06..a1ee5728202f9 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -5071,7 +5071,7 @@ func (d *ddl) AlterColumn(ctx sessionctx.Context, ident ast.Ident, spec *ast.Alt } col := table.ToColumn(oldCol.Clone()) - // Clean the NoDefaultValueFlag value. + // CleanPasswordDictionary the NoDefaultValueFlag value. col.DelFlag(mysql.NoDefaultValueFlag) if len(specNewColumn.Options) == 0 { col.DefaultIsExpr = false diff --git a/ddl/schematracker/dm_tracker.go b/ddl/schematracker/dm_tracker.go index afb3a75c1974b..813503a9b8c45 100644 --- a/ddl/schematracker/dm_tracker.go +++ b/ddl/schematracker/dm_tracker.go @@ -610,7 +610,7 @@ func (d SchemaTracker) alterColumn(ctx sessionctx.Context, ident ast.Ident, spec return dbterror.ErrBadField.GenWithStackByArgs(colName, ident.Name) } - // Clean the NoDefaultValueFlag value. + // CleanPasswordDictionary the NoDefaultValueFlag value. oldCol.DelFlag(mysql.NoDefaultValueFlag) if len(specNewColumn.Options) == 0 { oldCol.DefaultIsExpr = false diff --git a/executor/BUILD.bazel b/executor/BUILD.bazel index 06386aca9f168..8d9eb3af53211 100644 --- a/executor/BUILD.bazel +++ b/executor/BUILD.bazel @@ -195,7 +195,6 @@ go_library( "//util/tls", "//util/topsql", "//util/topsql/state", - "//util/validate-password", "@com_github_burntsushi_toml//:toml", "@com_github_gogo_protobuf//proto", "@com_github_ngaut_pools//:pools", @@ -416,7 +415,6 @@ go_test( "//util/tableutil", "//util/timeutil", "//util/topsql/state", - "//util/validate-password", "@com_github_golang_protobuf//proto", "@com_github_gorilla_mux//:mux", "@com_github_jarcoal_httpmock//:httpmock", diff --git a/executor/adapter.go b/executor/adapter.go index db9fbbaa929e0..2a6e78443c414 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -1381,7 +1381,7 @@ func (a *ExecStmt) FinishExecuteStmt(txnTS uint64, err error, hasMoreResults boo } // Reset DurationParse due to the next statement may not need to be parsed (not a text protocol query). sessVars.DurationParse = 0 - // Clean the stale read flag when statement execution finish + // CleanPasswordDictionary the stale read flag when statement execution finish sessVars.StmtCtx.IsStaleness = false if sessVars.StmtCtx.ReadFromTableCache { diff --git a/executor/errors.go b/executor/errors.go index 3bb8935b7ec3a..4a0c7f9215875 100644 --- a/executor/errors.go +++ b/executor/errors.go @@ -69,7 +69,6 @@ var ( ErrFuncNotEnabled = dbterror.ClassExecutor.NewStdErr(mysql.ErrNotSupportedYet, parser_mysql.Message("%-.32s is not supported. To enable this experimental feature, set '%-.32s' in the configuration file.", nil)) errSavepointNotExists = dbterror.ClassExecutor.NewStd(mysql.ErrSpDoesNotExist) ErrForeignKeyCascadeDepthExceeded = dbterror.ClassExecutor.NewStd(mysql.ErrForeignKeyCascadeDepthExceeded) - ErrNotValidPassword = dbterror.ClassExecutor.NewStd(mysql.ErrNotValidPassword) ErrWrongStringLength = dbterror.ClassDDL.NewStd(mysql.ErrWrongStringLength) errUnsupportedFlashbackTmpTable = dbterror.ClassDDL.NewStdErr(mysql.ErrUnsupportedDDLOperation, parser_mysql.Message("Recover/flashback table is not supported on temporary tables", nil)) diff --git a/executor/partition_table_test.go b/executor/partition_table_test.go index 3640b4e155097..a8cbcf3e05114 100644 --- a/executor/partition_table_test.go +++ b/executor/partition_table_test.go @@ -3268,7 +3268,7 @@ func TestIssue26251(t *testing.T) { t.Fail() } - // Clean up + // CleanPasswordDictionary up <-ch tk2.MustExec("rollback") } diff --git a/executor/simple.go b/executor/simple.go index 69b6925db89ea..f06d41e19a587 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -50,7 +50,6 @@ import ( "github.com/pingcap/tidb/util/sqlexec" "github.com/pingcap/tidb/util/timeutil" "github.com/pingcap/tidb/util/tls" - validatePwd "github.com/pingcap/tidb/util/validate-password" "github.com/pingcap/tipb/go-tipb" tikvutil "github.com/tikv/client-go/v2/util" "go.uber.org/zap" @@ -899,7 +898,7 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm authPlugin = spec.AuthOpt.AuthPlugin } if e.enableValidatePassword() && e.authUsingCleartextPwd(spec.AuthOpt, authPlugin) { - if err := validatePwd.ValidatePassword(e.ctx.GetSessionVars(), spec.AuthOpt.AuthString); err != nil { + if err := variable.ValidatePassword(e.ctx.GetSessionVars(), spec.AuthOpt.AuthString); err != nil { return err } } @@ -1113,7 +1112,7 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt) return ErrPluginIsNotLoaded.GenWithStackByArgs(spec.AuthOpt.AuthPlugin) } if e.enableValidatePassword() && e.authUsingCleartextPwd(spec.AuthOpt, spec.AuthOpt.AuthPlugin) { - if err := validatePwd.ValidatePassword(e.ctx.GetSessionVars(), spec.AuthOpt.AuthString); err != nil { + if err := variable.ValidatePassword(e.ctx.GetSessionVars(), spec.AuthOpt.AuthString); err != nil { return err } } @@ -1634,7 +1633,7 @@ func (e *SimpleExec) executeSetPwd(ctx context.Context, s *ast.SetPwdStmt) error return err } if e.enableValidatePassword() { - if err := validatePwd.ValidatePassword(e.ctx.GetSessionVars(), s.Password); err != nil { + if err := variable.ValidatePassword(e.ctx.GetSessionVars(), s.Password); err != nil { return err } } diff --git a/executor/simple_test.go b/executor/simple_test.go index f4df110938f05..46759202282ae 100644 --- a/executor/simple_test.go +++ b/executor/simple_test.go @@ -24,9 +24,9 @@ import ( "github.com/pingcap/tidb/parser/auth" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/server" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/util" - validator "github.com/pingcap/tidb/util/validate-password" "github.com/stretchr/testify/require" tikvutil "github.com/tikv/client-go/v2/util" ) @@ -134,7 +134,7 @@ func TestValidatePassword(t *testing.T) { require.NoError(t, tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil)) authPlugins := []string{mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password} - dictFile, err := validator.CreateTmpDictWithContent("3.dict", []byte("1234\n5678")) + dictFile, err := variable.CreateTmpDictWithContent("3.dict", []byte("1234\n5678")) require.NoError(t, err) tk.MustQuery("SELECT @@global.validate_password.enable").Check(testkit.Rows("0")) tk.MustExec("SET GLOBAL validate_password.enable = 1") @@ -147,8 +147,8 @@ func TestValidatePassword(t *testing.T) { tk.MustExec("SET GLOBAL validate_password.policy = 'LOW'") // check user name tk.MustQuery("SELECT @@global.validate_password.check_user_name").Check(testkit.Rows("1")) - tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!Abcdroot1234'", "Password Contains (Reversed) User Name") - tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!Abcdtoor1234'", "Password Contains (Reversed) User Name") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!Abcdroot1234'", "Password Contains User Name") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!Abcdtoor1234'", "Password Contains Reversed User Name") tk.MustExec("SET PASSWORD FOR 'testuser' = 'testuser'") // password the same as the user name, but run by root tk.MustExec("ALTER USER testuser IDENTIFIED BY 'testuser'") tk.MustExec("SET GLOBAL validate_password.check_user_name = 0") @@ -158,9 +158,9 @@ func TestValidatePassword(t *testing.T) { // LOW: Length tk.MustQuery("SELECT @@global.validate_password.length").Check(testkit.Rows("8")) - tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '1234567'", "Require Password Length") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '1234567'", "Require Password Length: 8") tk.MustExec("SET GLOBAL validate_password.length = 12") - tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!Abcdefg123'", "Require Password Length") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!Abcdefg123'", "Require Password Length: 12") tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abcdefg1234'") tk.MustExec("SET GLOBAL validate_password.length = 8") diff --git a/expression/builtin_encryption.go b/expression/builtin_encryption.go index aca16c718b8ae..891b8ff050518 100644 --- a/expression/builtin_encryption.go +++ b/expression/builtin_encryption.go @@ -25,7 +25,6 @@ import ( "crypto/sha512" "encoding/binary" "fmt" - validatePwd "github.com/pingcap/tidb/util/validate-password" "hash" "io" "strings" @@ -1019,10 +1018,7 @@ func (c *validatePasswordStrengthFunctionClass) getFunction(ctx sessionctx.Conte if err != nil { return nil, err } - charset, collate := ctx.GetSessionVars().GetCharsetInfo() - bf.tp.SetCharset(charset) - bf.tp.SetCollate(collate) - bf.tp.SetFlen(args[0].GetType().GetFlen()) + bf.tp.SetFlen(21) sig := &builtinValidatePasswordStrengthSig{bf} //sig.setPbCode(tipb.ScalarFuncSig_ValidatePasswordStrength) return sig, nil @@ -1042,34 +1038,38 @@ func (b *builtinValidatePasswordStrengthSig) Clone() builtinFunc { // See https://dev.mysql.com/doc/refman/8.0/en/encryption-functions.html#function_validate-password-strength func (b *builtinValidatePasswordStrengthSig) evalInt(row chunk.Row) (int64, bool, error) { globalVars := b.ctx.GetSessionVars().GlobalVarsAccessor - if validation, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordEnable); err != nil { + str, isNull, err := b.args[0].EvalString(b.ctx, row) + if err != nil || isNull { return 0, true, err - } else if !variable.TiDBOptOn(validation) { + } else if len([]rune(str)) < 4 { return 0, false, nil } - - str, isNull, err := b.args[0].EvalString(b.ctx, row) - if err != nil { + if validation, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordEnable); err != nil { return 0, true, err - } else if isNull { - return 0, true, nil - } else if len(str) < 4 { + } else if !variable.TiDBOptOn(validation) { return 0, false, nil } + return b.validateStr(str, &globalVars) +} - if ok, err := validatePwd.ValidateUserNameInPassword(str, b.ctx.GetSessionVars()); err != nil { +func (b *builtinValidatePasswordStrengthSig) validateStr(str string, globalVars *variable.GlobalVarAccessor) (int64, bool, error) { + if warn, err := variable.ValidateUserNameInPassword(str, b.ctx.GetSessionVars()); err != nil { return 0, true, err - } else if !ok { + } else if len(warn) > 0 { return 0, false, nil } - - if ok, err := validatePwd.ValidateLow(str, &globalVars); err != nil { - return 0, false, err - } else if !ok { + if warn, err := variable.ValidatePasswordLowPolicy(str, globalVars); err != nil { + return 0, true, err + } else if len(warn) > 0 { return 25, false, nil } - - // TODO - + if warn, err := variable.ValidatePasswordMediumPolicy(str, globalVars); err != nil { + return 0, true, err + } else if len(warn) > 0 { + return 50, false, nil + } + if ok := variable.ValidateDictionaryPassword(str); !ok { + return 75, false, nil + } return 100, false, nil } diff --git a/expression/builtin_encryption_test.go b/expression/builtin_encryption_test.go index 0f74ab611aa48..1da7a0b638469 100644 --- a/expression/builtin_encryption_test.go +++ b/expression/builtin_encryption_test.go @@ -15,12 +15,14 @@ package expression import ( + "context" "encoding/hex" "fmt" "strings" "testing" "github.com/pingcap/tidb/parser/ast" + "github.com/pingcap/tidb/parser/auth" "github.com/pingcap/tidb/parser/charset" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/terror" @@ -631,6 +633,57 @@ func TestUncompressLength(t *testing.T) { } } +func TestValidatePasswordStrength(t *testing.T) { + ctx := createContext(t) + ctx.GetSessionVars().User = &auth.UserIdentity{Username: "testuser"} + tempDict, err := variable.CreateTmpDictWithContent("tempDictionary.txt", []byte("1234\n")) + require.NoError(t, err) + globalVarsAccessor := variable.NewMockGlobalAccessor4Tests() + ctx.GetSessionVars().GlobalVarsAccessor = globalVarsAccessor + err = globalVarsAccessor.SetGlobalSysVar(context.Background(), variable.ValidatePasswordDictionaryFile, tempDict) + require.NoError(t, err) + + tests := []struct { + in interface{} + expect interface{} + }{ + {nil, nil}, + {"123", 0}, + {"testuser123", 0}, + {"resutset123", 0}, + {"12345", 25}, + {"12345678", 50}, + {"!Abc12345678", 75}, + {"!Abc87654321", 100}, + } + + fc := funcs[ast.ValidatePasswordStrength] + // disable password validation + for _, test := range tests { + arg := types.NewDatum(test.in) + f, err := fc.getFunction(ctx, datumsToConstants([]types.Datum{arg})) + require.NoErrorf(t, err, "%v", test) + out, err := evalBuiltinFunc(f, chunk.Row{}) + require.NoErrorf(t, err, "%v", test) + if test.expect == nil { + require.Equal(t, types.NewDatum(nil), out) + } else { + require.Equalf(t, types.NewDatum(0), out, "%v", test) + } + } + // enable password validation + err = globalVarsAccessor.SetGlobalSysVar(context.Background(), variable.ValidatePasswordEnable, "ON") + require.NoError(t, err) + for _, test := range tests { + arg := types.NewDatum(test.in) + f, err := fc.getFunction(ctx, datumsToConstants([]types.Datum{arg})) + require.NoErrorf(t, err, "%v", test) + out, err := evalBuiltinFunc(f, chunk.Row{}) + require.NoErrorf(t, err, "%v", test) + require.Equalf(t, types.NewDatum(test.expect), out, "%v", test) + } +} + func TestPassword(t *testing.T) { ctx := createContext(t) cases := []struct { diff --git a/expression/builtin_encryption_vec.go b/expression/builtin_encryption_vec.go index f5bac5660ad99..aa71078491d24 100644 --- a/expression/builtin_encryption_vec.go +++ b/expression/builtin_encryption_vec.go @@ -30,6 +30,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/parser/auth" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/encrypt" @@ -869,6 +870,36 @@ func (b *builtinValidatePasswordStrengthSig) vectorized() bool { } func (b *builtinValidatePasswordStrengthSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) error { - // TODO + n := input.NumRows() + buf, err := b.bufAllocator.get() + if err != nil { + return err + } + defer b.bufAllocator.put(buf) + if err := b.args[0].VecEvalString(b.ctx, input, buf); err != nil { + return err + } + + result.ResizeInt64(n, true) + i64s := result.Int64s() + globalVars := b.ctx.GetSessionVars().GlobalVarsAccessor + enableValidation := false + if validation, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordEnable); err != nil { + return err + } else { + enableValidation = variable.TiDBOptOn(validation) + } + for i := 0; i < n; i++ { + if result.IsNull(i) { + continue + } + if !enableValidation { + i64s[i] = 0 + } else if score, isNull, err := b.validateStr(buf.GetString(i), &globalVars); err != nil { + return err + } else if !isNull { + i64s[i] = score + } + } return nil } diff --git a/expression/integration_test.go b/expression/integration_test.go index d3b307fb9653b..09c716df232b0 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -1143,6 +1143,8 @@ func TestEncryptionBuiltin(t *testing.T) { tk.MustQuery("SELECT RANDOM_BYTES(1024);") result = tk.MustQuery("SELECT RANDOM_BYTES(NULL);") result.Check(testkit.Rows("")) + + // TODO: for VALIDATE_PASSWORD_STRENGTH } func TestOpBuiltin(t *testing.T) { diff --git a/sessionctx/binloginfo/binloginfo_test.go b/sessionctx/binloginfo/binloginfo_test.go index 3c777a9436234..8b2b0746f33aa 100644 --- a/sessionctx/binloginfo/binloginfo_test.go +++ b/sessionctx/binloginfo/binloginfo_test.go @@ -479,7 +479,7 @@ func TestZIgnoreError(t *testing.T) { tk.MustExec("insert into t values (1)") tk.MustExec("insert into t values (1)") - // Clean up. + // CleanPasswordDictionary up. s.pump.mu.Lock() s.pump.mu.mockFail = false s.pump.mu.Unlock() diff --git a/sessionctx/variable/BUILD.bazel b/sessionctx/variable/BUILD.bazel index b6bf2ce552bc3..18500687d1dac 100644 --- a/sessionctx/variable/BUILD.bazel +++ b/sessionctx/variable/BUILD.bazel @@ -6,6 +6,7 @@ go_library( "error.go", "mock_globalaccessor.go", "noop.go", + "password_dictionary.go", "removed.go", "sequence_state.go", "session.go", @@ -43,6 +44,7 @@ go_library( "//util/disk", "//util/execdetails", "//util/gctuner", + "//util/hack", "//util/kvcache", "//util/logutil", "//util/mathutil", @@ -57,7 +59,6 @@ go_library( "//util/timeutil", "//util/tls", "//util/topsql/state", - "//util/validate-password", "//util/versioninfo", "@com_github_pingcap_errors//:errors", "@com_github_tikv_client_go_v2//config", @@ -77,6 +78,7 @@ go_test( srcs = [ "main_test.go", "mock_globalaccessor_test.go", + "password_dictionary_test.go", "removed_test.go", "session_test.go", "statusvar_test.go", diff --git a/sessionctx/variable/error.go b/sessionctx/variable/error.go index 60928932f0f06..f760cba8bfcd5 100644 --- a/sessionctx/variable/error.go +++ b/sessionctx/variable/error.go @@ -39,6 +39,7 @@ var ( errLocalVariable = dbterror.ClassVariable.NewStd(mysql.ErrLocalVariable) errValueNotSupportedWhen = dbterror.ClassVariable.NewStdErr(mysql.ErrNotSupportedYet, pmysql.Message("%s = OFF is not supported when %s = ON", nil)) ErrStmtNotFound = dbterror.ClassOptimizer.NewStd(mysql.ErrPreparedStmtNotFound) + ErrNotValidPassword = dbterror.ClassExecutor.NewStd(mysql.ErrNotValidPassword) // ErrFunctionsNoopImpl is an error to say the behavior is protected by the tidb_enable_noop_functions sysvar. // This is copied from expression.ErrFunctionsNoopImpl to prevent circular dependencies. // It needs to be public for tests. diff --git a/util/validate-password/dictionary.go b/sessionctx/variable/password_dictionary.go similarity index 59% rename from util/validate-password/dictionary.go rename to sessionctx/variable/password_dictionary.go index e5f03becbe180..7b83e4246fbf7 100644 --- a/util/validate-password/dictionary.go +++ b/sessionctx/variable/password_dictionary.go @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -package validator +package variable import ( "bufio" "bytes" - "github.com/pingcap/tidb/sessionctx/variable" + "fmt" "os" "path/filepath" "strconv" @@ -30,7 +30,7 @@ import ( "github.com/pingcap/tidb/util/mathutil" ) -type dictionaryImpl struct { +type PasswordDictionaryImpl struct { cache map[string]struct{} m sync.RWMutex } @@ -38,19 +38,19 @@ type dictionaryImpl struct { const maxPwdLength int = 100 const minPwdLength int = 4 -var dictionary = dictionaryImpl{cache: make(map[string]struct{})} +var passwordDictionary = PasswordDictionaryImpl{cache: make(map[string]struct{})} -// Clean removes all the words in the dictionary. -func Clean() { - dictionary.m.Lock() - defer dictionary.m.Unlock() - dictionary.cache = make(map[string]struct{}) +// CleanPasswordDictionary removes all the words in the dictionary. +func CleanPasswordDictionary() { + passwordDictionary.m.Lock() + defer passwordDictionary.m.Unlock() + passwordDictionary.cache = make(map[string]struct{}) } -// UpdateDictionaryFile update the dictionary for validating password. -func UpdateDictionaryFile(filePath string) error { - dictionary.m.Lock() - defer dictionary.m.Unlock() +// UpdatePasswordDictionary update the dictionary for validating password. +func UpdatePasswordDictionary(filePath string) error { + passwordDictionary.m.Lock() + defer passwordDictionary.m.Unlock() newDictionary := make(map[string]struct{}) file, err := os.Open(filepath.Clean(filePath)) if err != nil { @@ -71,22 +71,22 @@ func UpdateDictionaryFile(filePath string) error { if err := s.Err(); err != nil { return err } - dictionary.cache = newDictionary + passwordDictionary.cache = newDictionary return file.Close() } -// validateDictionaryPassword checks if the password contains words in the dictionary. -func validateDictionaryPassword(pwd string) bool { - dictionary.m.RLock() - defer dictionary.m.RUnlock() - if len(dictionary.cache) == 0 { +// ValidateDictionaryPassword checks if the password contains words in the dictionary. +func ValidateDictionaryPassword(pwd string) bool { + passwordDictionary.m.RLock() + defer passwordDictionary.m.RUnlock() + if len(passwordDictionary.cache) == 0 { return true } pwdLength := len(pwd) for subStrLen := mathutil.Min(maxPwdLength, pwdLength); subStrLen >= minPwdLength; subStrLen-- { for subStrPos := 0; subStrPos+subStrLen <= pwdLength; subStrPos++ { subStr := pwd[subStrPos : subStrPos+subStrLen] - if _, ok := dictionary.cache[subStr]; ok { + if _, ok := passwordDictionary.cache[subStr]; ok { return false } } @@ -130,13 +130,13 @@ func CreateTmpDictWithContent(filename string, content []byte) (string, error) { return filename, file.Close() } -func ValidateUserNameInPassword(pwd string, sessionVars *variable.SessionVars) (bool, error) { +func ValidateUserNameInPassword(pwd string, sessionVars *SessionVars) (string, error) { currentUser := sessionVars.User globalVars := sessionVars.GlobalVarsAccessor pwdBytes := hack.Slice(pwd) - if checkUserName, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordCheckUserName); err != nil { - return false, err - } else if currentUser != nil && variable.TiDBOptOn(checkUserName) { + if checkUserName, err := globalVars.GetGlobalSysVar(ValidatePasswordCheckUserName); err != nil { + return "", err + } else if currentUser != nil && TiDBOptOn(checkUserName) { for _, username := range []string{currentUser.AuthUsername, currentUser.Username} { usernameBytes := hack.Slice(username) userNameLen := len(usernameBytes) @@ -144,58 +144,34 @@ func ValidateUserNameInPassword(pwd string, sessionVars *variable.SessionVars) ( continue } if bytes.Contains(pwdBytes, usernameBytes) { - return false, nil + return "Password Contains User Name", nil } usernameReversedBytes := make([]byte, userNameLen) for i := range usernameBytes { usernameReversedBytes[i] = usernameBytes[userNameLen-1-i] } if bytes.Contains(pwdBytes, usernameReversedBytes) { - return false, nil + return "Password Contains Reversed User Name", nil } } } - return true, nil + return "", nil } -func ValidateLow(pwd string, globalVars *variable.GlobalVarAccessor) (bool, error) { - if validateLengthStr, err := (*globalVars).GetGlobalSysVar(variable.ValidatePasswordLength); err != nil { - return false, err +func ValidatePasswordLowPolicy(pwd string, globalVars *GlobalVarAccessor) (string, error) { + if validateLengthStr, err := (*globalVars).GetGlobalSysVar(ValidatePasswordLength); err != nil { + return "", err } else if validateLength, err := strconv.ParseInt(validateLengthStr, 10, 64); err != nil { - return false, err + return "", err } else if (int64)(len([]rune(pwd))) < validateLength { - return false, nil + return fmt.Sprintf("Require Password Length: %d", validateLength), nil } - return true, nil + return "", nil } -func ValidatePassword(sessionVars *variable.SessionVars, pwd string) error { - globalVars := sessionVars.GlobalVarsAccessor - - runes := []rune(pwd) - validatePolicy, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordPolicy) - if err != nil { - return err - } - if ok, err := ValidateUserNameInPassword(pwd, sessionVars); err != nil { - return err - } else if !ok { - return ErrNotValidPassword.GenWithStack("Password Contains (Reversed) User Name") - } - if ok, err := ValidateLow(pwd, &globalVars); err != nil { - return err - } else if !ok { - return ErrNotValidPassword.GenWithStack("Require Password Length") - } - - // LOW - if validatePolicy == "LOW" { - return nil - } - - // MEDIUM +func ValidatePasswordMediumPolicy(pwd string, globalVars *GlobalVarAccessor) (string, error) { var lowerCaseCount, upperCaseCount, numberCount, specialCharCount int64 - for _, r := range runes { + for _, r := range []rune(pwd) { if unicode.IsUpper(r) { upperCaseCount++ } else if unicode.IsLower(r) { @@ -206,35 +182,66 @@ func ValidatePassword(sessionVars *variable.SessionVars, pwd string) error { specialCharCount++ } } - if mixedCaseCountStr, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordMixedCaseCount); err != nil { - return err + if mixedCaseCountStr, err := (*globalVars).GetGlobalSysVar(ValidatePasswordMixedCaseCount); err != nil { + return "", err } else if mixedCaseCount, err := strconv.ParseInt(mixedCaseCountStr, 10, 64); err != nil { - return err + return "", err } else if lowerCaseCount < mixedCaseCount { - return ErrNotValidPassword.GenWithStack("Require Password Lowercase Count: %d", mixedCaseCount) + return fmt.Sprintf("Require Password Lowercase Count: %d", mixedCaseCount), nil } else if upperCaseCount < mixedCaseCount { - return ErrNotValidPassword.GenWithStack("Require Password Uppercase Count: %d", mixedCaseCount) + return fmt.Sprintf("Require Password Uppercase Count: %d", mixedCaseCount), nil } - if requireNumberCountStr, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordNumberCount); err != nil { - return err + if requireNumberCountStr, err := (*globalVars).GetGlobalSysVar(ValidatePasswordNumberCount); err != nil { + return "", err } else if requireNumberCount, err := strconv.ParseInt(requireNumberCountStr, 10, 64); err != nil { - return err + return "", err } else if numberCount < requireNumberCount { - return ErrNotValidPassword.GenWithStack("Require Password Digit Count: %d", requireNumberCount) + return fmt.Sprintf("Require Password Digit Count: %d", requireNumberCount), nil } - if requireSpecialCharCountStr, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordSpecialCharCount); err != nil { - return err + if requireSpecialCharCountStr, err := (*globalVars).GetGlobalSysVar(ValidatePasswordSpecialCharCount); err != nil { + return "", err } else if requireSpecialCharCount, err := strconv.ParseInt(requireSpecialCharCountStr, 10, 64); err != nil { - return err + return "", err } else if specialCharCount < requireSpecialCharCount { - return ErrNotValidPassword.GenWithStack("Require Password Non-alphanumeric Count: %d", requireSpecialCharCount) + return fmt.Sprintf("Require Password Non-alphanumeric Count: %d", requireSpecialCharCount), nil + } + return "", nil +} + +func ValidatePassword(sessionVars *SessionVars, pwd string) error { + globalVars := sessionVars.GlobalVarsAccessor + + validatePolicy, err := globalVars.GetGlobalSysVar(ValidatePasswordPolicy) + if err != nil { + return err + } + if warn, err := ValidateUserNameInPassword(pwd, sessionVars); err != nil { + return err + } else if len(warn) > 0 { + return ErrNotValidPassword.GenWithStack(warn) + } + if warn, err := ValidatePasswordLowPolicy(pwd, &globalVars); err != nil { + return err + } else if len(warn) > 0 { + return ErrNotValidPassword.GenWithStack(warn) + } + // LOW + if validatePolicy == "LOW" { + return nil + } + + // MEDIUM + if warn, err := ValidatePasswordMediumPolicy(pwd, &globalVars); err != nil { + return err + } else if len(warn) > 0 { + return ErrNotValidPassword.GenWithStack(warn) } if validatePolicy == "MEDIUM" { return nil } // STRONG - if !validateDictionaryPassword(pwd) { + if !ValidateDictionaryPassword(pwd) { return ErrNotValidPassword.GenWithStack("Password contains word in the dictionary") } return nil diff --git a/util/validate-password/dictionary_test.go b/sessionctx/variable/password_dictionary_test.go similarity index 69% rename from util/validate-password/dictionary_test.go rename to sessionctx/variable/password_dictionary_test.go index 87ada4ef30c9c..fa9c85ab5f8e8 100644 --- a/util/validate-password/dictionary_test.go +++ b/sessionctx/variable/password_dictionary_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package validator +package variable import ( "testing" @@ -23,26 +23,26 @@ import ( func TestUpdateDictionaryFile(t *testing.T) { tooLargeDict, err := CreateTmpDictWithSize("1.dict", 2*1024*1024) require.NoError(t, err) - err = UpdateDictionaryFile(tooLargeDict) + err = UpdatePasswordDictionary(tooLargeDict) require.ErrorContains(t, err, "Too Large Dictionary. The maximum permitted file size is 1MB") dict, err := CreateTmpDictWithContent("2.dict", []byte("abc\n1234\n5678")) require.NoError(t, err) - require.NoError(t, UpdateDictionaryFile(dict)) - _, ok := dictionary.cache["1234"] + require.NoError(t, UpdatePasswordDictionary(dict)) + _, ok := passwordDictionary.cache["1234"] require.True(t, ok) - _, ok = dictionary.cache["5678"] + _, ok = passwordDictionary.cache["5678"] require.True(t, ok) - _, ok = dictionary.cache["abc"] + _, ok = passwordDictionary.cache["abc"] require.False(t, ok) } func TestValidateDictionaryPassword(t *testing.T) { dict, err := CreateTmpDictWithContent("3.dict", []byte("1234\n5678")) require.NoError(t, err) - require.NoError(t, UpdateDictionaryFile(dict)) - require.True(t, validateDictionaryPassword("abcdefg")) - require.True(t, validateDictionaryPassword("abcd123efg")) - require.False(t, validateDictionaryPassword("abcd1234efg")) - require.False(t, validateDictionaryPassword("abcd12345efg")) + require.NoError(t, UpdatePasswordDictionary(dict)) + require.True(t, ValidateDictionaryPassword("abcdefg")) + require.True(t, ValidateDictionaryPassword("abcd123efg")) + require.False(t, ValidateDictionaryPassword("abcd1234efg")) + require.False(t, ValidateDictionaryPassword("abcd12345efg")) } diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index b11b727079630..249d6f5f23fc0 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -102,7 +102,7 @@ type ReuseChunkPool struct { Alloc chunk.Allocator } -// Clean does some clean work. +// CleanPasswordDictionary does some clean work. func (r *RetryInfo) Clean() { r.autoIncrementIDs.clean() r.autoRandomIDs.clean() diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 69a3d3ca1c252..58d6c67b04d99 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -44,7 +44,6 @@ import ( "github.com/pingcap/tidb/util/tikvutil" "github.com/pingcap/tidb/util/tls" topsqlstate "github.com/pingcap/tidb/util/topsql/state" - validatePwd "github.com/pingcap/tidb/util/validate-password" "github.com/pingcap/tidb/util/versioninfo" tikvcfg "github.com/tikv/client-go/v2/config" tikvstore "github.com/tikv/client-go/v2/kv" @@ -522,10 +521,10 @@ var defaultSysVars = []*SysVar{ {Scope: ScopeGlobal, Name: ValidatePasswordDictionaryFile, Value: "", Type: TypeStr, SetGlobal: func(_ context.Context, s *SessionVars, val string) error { // Use 'SET @@global.validate_password.dictionary_file = ""' to clean the dictionary. if len(val) == 0 { - validatePwd.Clean() + CleanPasswordDictionary() return nil } - return validatePwd.UpdateDictionaryFile(val) + return UpdatePasswordDictionary(val) }}, /* TiDB specific variables */ diff --git a/util/validate-password/BUILD.bazel b/util/validate-password/BUILD.bazel deleted file mode 100644 index 6ea767148ebb7..0000000000000 --- a/util/validate-password/BUILD.bazel +++ /dev/null @@ -1,20 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") - -go_library( - name = "validate-password", - srcs = ["dictionary.go"], - importpath = "github.com/pingcap/tidb/util/validate-password", - visibility = ["//visibility:public"], - deps = [ - "//util/hack", - "//util/mathutil", - "@com_github_pingcap_errors//:errors", - ], -) - -go_test( - name = "validate-password_test", - srcs = ["dictionary_test.go"], - embed = [":validate-password"], - deps = ["@com_github_stretchr_testify//require"], -) diff --git a/util/validate-password/errors.go b/util/validate-password/errors.go deleted file mode 100644 index bd301aa8bf02e..0000000000000 --- a/util/validate-password/errors.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package validator - -import ( - mysql "github.com/pingcap/tidb/errno" - "github.com/pingcap/tidb/util/dbterror" -) - -// Error instances. -var ( - ErrNotValidPassword = dbterror.ClassExecutor.NewStd(mysql.ErrNotValidPassword) -) From e611e4a8e5ed2009dad8923cff9b904c64dba2ea Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Mon, 14 Nov 2022 21:30:14 +0800 Subject: [PATCH 11/26] Fix --- expression/builtin_encryption_vec.go | 6 +++--- sessionctx/variable/password_dictionary.go | 14 ++++++++++---- sessionctx/variable/session.go | 2 +- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/expression/builtin_encryption_vec.go b/expression/builtin_encryption_vec.go index aa71078491d24..7ba71f7329c7f 100644 --- a/expression/builtin_encryption_vec.go +++ b/expression/builtin_encryption_vec.go @@ -884,11 +884,11 @@ func (b *builtinValidatePasswordStrengthSig) vecEvalInt(input *chunk.Chunk, resu i64s := result.Int64s() globalVars := b.ctx.GetSessionVars().GlobalVarsAccessor enableValidation := false - if validation, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordEnable); err != nil { + validation, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordEnable) + if err != nil { return err - } else { - enableValidation = variable.TiDBOptOn(validation) } + enableValidation = variable.TiDBOptOn(validation) for i := 0; i < n; i++ { if result.IsNull(i) { continue diff --git a/sessionctx/variable/password_dictionary.go b/sessionctx/variable/password_dictionary.go index 7b83e4246fbf7..06f475932d634 100644 --- a/sessionctx/variable/password_dictionary.go +++ b/sessionctx/variable/password_dictionary.go @@ -30,6 +30,7 @@ import ( "github.com/pingcap/tidb/util/mathutil" ) +// PasswordDictionaryImpl is the dictionary for validating password. type PasswordDictionaryImpl struct { cache map[string]struct{} m sync.RWMutex @@ -130,6 +131,7 @@ func CreateTmpDictWithContent(filename string, content []byte) (string, error) { return filename, file.Close() } +// ValidateUserNameInPassword checks whether pwd exists in the dictionary. func ValidateUserNameInPassword(pwd string, sessionVars *SessionVars) (string, error) { currentUser := sessionVars.User globalVars := sessionVars.GlobalVarsAccessor @@ -158,6 +160,7 @@ func ValidateUserNameInPassword(pwd string, sessionVars *SessionVars) (string, e return "", nil } +// ValidatePasswordLowPolicy checks whether pwd satisfies the low policy of password validation. func ValidatePasswordLowPolicy(pwd string, globalVars *GlobalVarAccessor) (string, error) { if validateLengthStr, err := (*globalVars).GetGlobalSysVar(ValidatePasswordLength); err != nil { return "", err @@ -169,14 +172,16 @@ func ValidatePasswordLowPolicy(pwd string, globalVars *GlobalVarAccessor) (strin return "", nil } +// ValidatePasswordMediumPolicy checks whether pwd satisfies the medium policy of password validation. func ValidatePasswordMediumPolicy(pwd string, globalVars *GlobalVarAccessor) (string, error) { var lowerCaseCount, upperCaseCount, numberCount, specialCharCount int64 - for _, r := range []rune(pwd) { - if unicode.IsUpper(r) { + runes := []rune(pwd) + for i := 0; i < len(runes); i++ { + if unicode.IsUpper(runes[i]) { upperCaseCount++ - } else if unicode.IsLower(r) { + } else if unicode.IsLower(runes[i]) { lowerCaseCount++ - } else if unicode.IsDigit(r) { + } else if unicode.IsDigit(runes[i]) { numberCount++ } else { specialCharCount++ @@ -208,6 +213,7 @@ func ValidatePasswordMediumPolicy(pwd string, globalVars *GlobalVarAccessor) (st return "", nil } +// ValidatePassword checks whether the pwd can be used. func ValidatePassword(sessionVars *SessionVars, pwd string) error { globalVars := sessionVars.GlobalVarsAccessor diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 249d6f5f23fc0..b11b727079630 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -102,7 +102,7 @@ type ReuseChunkPool struct { Alloc chunk.Allocator } -// CleanPasswordDictionary does some clean work. +// Clean does some clean work. func (r *RetryInfo) Clean() { r.autoIncrementIDs.clean() r.autoRandomIDs.clean() From 0406cfbff81ae2c21a9d9944312722bb24cefa56 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Mon, 14 Nov 2022 22:35:38 +0800 Subject: [PATCH 12/26] Fix --- executor/simple.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/executor/simple.go b/executor/simple.go index f06d41e19a587..942367593cde5 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -19,6 +19,11 @@ import ( "context" "encoding/json" "fmt" + "os" + "strings" + "syscall" + "time" + "github.com/ngaut/pools" "github.com/pingcap/errors" "github.com/pingcap/tidb/config" @@ -53,10 +58,6 @@ import ( "github.com/pingcap/tipb/go-tipb" tikvutil "github.com/tikv/client-go/v2/util" "go.uber.org/zap" - "os" - "strings" - "syscall" - "time" ) var ( From cec6fb11255dc5236ccd839799486c514b97a0a3 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Tue, 15 Nov 2022 16:18:07 +0800 Subject: [PATCH 13/26] Update --- executor/BUILD.bazel | 2 + executor/simple.go | 7 +- executor/simple_test.go | 4 +- expression/BUILD.bazel | 2 + expression/builtin_encryption.go | 7 +- expression/builtin_encryption_test.go | 3 +- expression/builtin_encryption_vec.go | 5 +- expression/builtin_encryption_vec_test.go | 3 + expression/integration_test.go | 23 +++- sessionctx/variable/BUILD.bazel | 2 - sessionctx/variable/varsutil.go | 72 ++++++++++++ util/password-validation/BUILD.bazel | 23 ++++ .../password_validation.go | 108 ++++-------------- .../password_validation_test.go | 25 ++-- 14 files changed, 173 insertions(+), 113 deletions(-) create mode 100644 util/password-validation/BUILD.bazel rename sessionctx/variable/password_dictionary.go => util/password-validation/password_validation.go (59%) rename sessionctx/variable/password_dictionary_test.go => util/password-validation/password_validation_test.go (60%) diff --git a/executor/BUILD.bazel b/executor/BUILD.bazel index 6a300dbeaf654..1f65c981a4801 100644 --- a/executor/BUILD.bazel +++ b/executor/BUILD.bazel @@ -177,6 +177,7 @@ go_library( "//util/mathutil", "//util/memory", "//util/mvmap", + "//util/password-validation", "//util/pdapi", "//util/plancodec", "//util/printer", @@ -406,6 +407,7 @@ go_test( "//util/memory", "//util/mock", "//util/paging", + "//util/password-validation", "//util/pdapi", "//util/plancodec", "//util/ranger", diff --git a/executor/simple.go b/executor/simple.go index c0fd9c12dd931..102dbc1947e6d 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -51,6 +51,7 @@ import ( "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/logutil" + pwdValidate "github.com/pingcap/tidb/util/password-validation" "github.com/pingcap/tidb/util/sem" "github.com/pingcap/tidb/util/sqlexec" "github.com/pingcap/tidb/util/timeutil" @@ -899,7 +900,7 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm authPlugin = spec.AuthOpt.AuthPlugin } if e.enableValidatePassword() && e.authUsingCleartextPwd(spec.AuthOpt, authPlugin) { - if err := variable.ValidatePassword(e.ctx.GetSessionVars(), spec.AuthOpt.AuthString); err != nil { + if err := pwdValidate.ValidatePassword(e.ctx.GetSessionVars(), spec.AuthOpt.AuthString); err != nil { return err } } @@ -1113,7 +1114,7 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt) return ErrPluginIsNotLoaded.GenWithStackByArgs(spec.AuthOpt.AuthPlugin) } if e.enableValidatePassword() && e.authUsingCleartextPwd(spec.AuthOpt, spec.AuthOpt.AuthPlugin) { - if err := variable.ValidatePassword(e.ctx.GetSessionVars(), spec.AuthOpt.AuthString); err != nil { + if err := pwdValidate.ValidatePassword(e.ctx.GetSessionVars(), spec.AuthOpt.AuthString); err != nil { return err } } @@ -1634,7 +1635,7 @@ func (e *SimpleExec) executeSetPwd(ctx context.Context, s *ast.SetPwdStmt) error return err } if e.enableValidatePassword() { - if err := variable.ValidatePassword(e.ctx.GetSessionVars(), s.Password); err != nil { + if err := pwdValidate.ValidatePassword(e.ctx.GetSessionVars(), s.Password); err != nil { return err } } diff --git a/executor/simple_test.go b/executor/simple_test.go index 46759202282ae..35fe979dbb28e 100644 --- a/executor/simple_test.go +++ b/executor/simple_test.go @@ -24,9 +24,9 @@ import ( "github.com/pingcap/tidb/parser/auth" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/server" - "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/util" + pwdValidate "github.com/pingcap/tidb/util/password-validation" "github.com/stretchr/testify/require" tikvutil "github.com/tikv/client-go/v2/util" ) @@ -134,7 +134,7 @@ func TestValidatePassword(t *testing.T) { require.NoError(t, tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil)) authPlugins := []string{mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password} - dictFile, err := variable.CreateTmpDictWithContent("3.dict", []byte("1234\n5678")) + dictFile, err := pwdValidate.CreateTmpDictWithContent("3.dict", []byte("1234\n5678")) require.NoError(t, err) tk.MustQuery("SELECT @@global.validate_password.enable").Check(testkit.Rows("0")) tk.MustExec("SET GLOBAL validate_password.enable = 1") diff --git a/expression/BUILD.bazel b/expression/BUILD.bazel index 032c44054dba2..471874910cb27 100644 --- a/expression/BUILD.bazel +++ b/expression/BUILD.bazel @@ -97,6 +97,7 @@ go_library( "//util/mathutil", "//util/mock", "//util/parser", + "//util/password-validation", "//util/plancodec", "//util/printer", "//util/sem", @@ -220,6 +221,7 @@ go_test( "//util/hack", "//util/mathutil", "//util/mock", + "//util/password-validation", "//util/printer", "//util/sem", "//util/sqlexec", diff --git a/expression/builtin_encryption.go b/expression/builtin_encryption.go index 891b8ff050518..d5bde2e900836 100644 --- a/expression/builtin_encryption.go +++ b/expression/builtin_encryption.go @@ -37,6 +37,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/encrypt" + pwdValidate "github.com/pingcap/tidb/util/password-validation" "github.com/pingcap/tipb/go-tipb" ) @@ -1053,17 +1054,17 @@ func (b *builtinValidatePasswordStrengthSig) evalInt(row chunk.Row) (int64, bool } func (b *builtinValidatePasswordStrengthSig) validateStr(str string, globalVars *variable.GlobalVarAccessor) (int64, bool, error) { - if warn, err := variable.ValidateUserNameInPassword(str, b.ctx.GetSessionVars()); err != nil { + if warn, err := pwdValidate.ValidateUserNameInPassword(str, b.ctx.GetSessionVars()); err != nil { return 0, true, err } else if len(warn) > 0 { return 0, false, nil } - if warn, err := variable.ValidatePasswordLowPolicy(str, globalVars); err != nil { + if warn, err := pwdValidate.ValidatePasswordLowPolicy(str, globalVars); err != nil { return 0, true, err } else if len(warn) > 0 { return 25, false, nil } - if warn, err := variable.ValidatePasswordMediumPolicy(str, globalVars); err != nil { + if warn, err := pwdValidate.ValidatePasswordMediumPolicy(str, globalVars); err != nil { return 0, true, err } else if len(warn) > 0 { return 50, false, nil diff --git a/expression/builtin_encryption_test.go b/expression/builtin_encryption_test.go index 1da7a0b638469..f508837c706d2 100644 --- a/expression/builtin_encryption_test.go +++ b/expression/builtin_encryption_test.go @@ -31,6 +31,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/hack" + pwdValidate "github.com/pingcap/tidb/util/password-validation" "github.com/stretchr/testify/require" ) @@ -636,7 +637,7 @@ func TestUncompressLength(t *testing.T) { func TestValidatePasswordStrength(t *testing.T) { ctx := createContext(t) ctx.GetSessionVars().User = &auth.UserIdentity{Username: "testuser"} - tempDict, err := variable.CreateTmpDictWithContent("tempDictionary.txt", []byte("1234\n")) + tempDict, err := pwdValidate.CreateTmpDictWithContent("tempDictionary.txt", []byte("1234\n")) require.NoError(t, err) globalVarsAccessor := variable.NewMockGlobalAccessor4Tests() ctx.GetSessionVars().GlobalVarsAccessor = globalVarsAccessor diff --git a/expression/builtin_encryption_vec.go b/expression/builtin_encryption_vec.go index 7ba71f7329c7f..ff71913f8d70b 100644 --- a/expression/builtin_encryption_vec.go +++ b/expression/builtin_encryption_vec.go @@ -880,7 +880,8 @@ func (b *builtinValidatePasswordStrengthSig) vecEvalInt(input *chunk.Chunk, resu return err } - result.ResizeInt64(n, true) + result.ResizeInt64(n, false) + result.MergeNulls(buf) i64s := result.Int64s() globalVars := b.ctx.GetSessionVars().GlobalVarsAccessor enableValidation := false @@ -899,6 +900,8 @@ func (b *builtinValidatePasswordStrengthSig) vecEvalInt(input *chunk.Chunk, resu return err } else if !isNull { i64s[i] = score + } else { + result.SetNull(i, true) } } return nil diff --git a/expression/builtin_encryption_vec_test.go b/expression/builtin_encryption_vec_test.go index c6caa1eb60d51..46395e51bcb6b 100644 --- a/expression/builtin_encryption_vec_test.go +++ b/expression/builtin_encryption_vec_test.go @@ -75,6 +75,9 @@ var vecBuiltinEncryptionCases = map[string][]vecExprBenchCase{ ast.Decode: { {retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETString, types.ETString}, geners: []dataGenerator{newRandLenStrGener(10, 20)}}, }, + ast.ValidatePasswordStrength: { + {retEvalType: types.ETInt, childrenTypes: []types.EvalType{types.ETString}}, + }, } func TestVectorizedBuiltinEncryptionFunc(t *testing.T) { diff --git a/expression/integration_test.go b/expression/integration_test.go index 09c716df232b0..b748907f91552 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -45,6 +45,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/collate" + pwdValidate "github.com/pingcap/tidb/util/password-validation" "github.com/pingcap/tidb/util/sem" "github.com/pingcap/tidb/util/sqlexec" "github.com/pingcap/tidb/util/versioninfo" @@ -968,6 +969,7 @@ func TestEncryptionBuiltin(t *testing.T) { tk := testkit.NewTestKit(t, store) tk.MustExec("use test") + tk.Session().GetSessionVars().User = &auth.UserIdentity{Username: "root"} ctx := context.Background() // for password @@ -1144,7 +1146,26 @@ func TestEncryptionBuiltin(t *testing.T) { result = tk.MustQuery("SELECT RANDOM_BYTES(NULL);") result.Check(testkit.Rows("")) - // TODO: for VALIDATE_PASSWORD_STRENGTH + // for VALIDATE_PASSWORD_STRENGTH + tempDict, err := pwdValidate.CreateTmpDictWithContent("4.txt", []byte("password\n")) + require.NoError(t, err) + tk.MustExec(fmt.Sprintf("SET GLOBAL validate_password.dictionary_file='%s'", tempDict)) + tk.MustExec("SET GLOBAL validate_password.enable = 1") + tk.MustQuery("SELECT validate_password_strength('root')").Check(testkit.Rows("0")) + tk.MustQuery("SELECT validate_password_strength('toor')").Check(testkit.Rows("0")) + tk.MustQuery("SELECT validate_password_strength('ROOT')").Check(testkit.Rows("25")) + tk.MustQuery("SELECT validate_password_strength('TOOR')").Check(testkit.Rows("25")) + tk.MustQuery("SELECT validate_password_strength('fooHoHo%1')").Check(testkit.Rows("100")) + tk.MustQuery("SELECT validate_password_strength('pass')").Check(testkit.Rows("25")) + tk.MustQuery("SELECT validate_password_strength('password')").Check(testkit.Rows("50")) + tk.MustQuery("SELECT validate_password_strength('password0000')").Check(testkit.Rows("50")) + tk.MustQuery("SELECT validate_password_strength('password1A#')").Check(testkit.Rows("75")) + tk.MustQuery("SELECT validate_password_strength('PA12wrd!#')").Check(testkit.Rows("100")) + tk.MustQuery("SELECT VALIDATE_PASSWORD_STRENGTH(REPEAT(\"aA1#\", 26))").Check(testkit.Rows("100")) + tk.MustQuery("SELECT validate_password_strength(null)").Check(testkit.Rows("")) + tk.MustQuery("SELECT validate_password_strength('null')").Check(testkit.Rows("25")) + tk.MustQuery("SELECT VALIDATE_PASSWORD_STRENGTH( 0x6E616E646F73617135234552 )").Check(testkit.Rows("100")) + tk.MustQuery("SELECT VALIDATE_PASSWORD_STRENGTH(CAST(0xd2 AS BINARY(10)))").Check(testkit.Rows("50")) } func TestOpBuiltin(t *testing.T) { diff --git a/sessionctx/variable/BUILD.bazel b/sessionctx/variable/BUILD.bazel index 18500687d1dac..857bbc8e07437 100644 --- a/sessionctx/variable/BUILD.bazel +++ b/sessionctx/variable/BUILD.bazel @@ -6,7 +6,6 @@ go_library( "error.go", "mock_globalaccessor.go", "noop.go", - "password_dictionary.go", "removed.go", "sequence_state.go", "session.go", @@ -78,7 +77,6 @@ go_test( srcs = [ "main_test.go", "mock_globalaccessor_test.go", - "password_dictionary_test.go", "removed_test.go", "session_test.go", "statusvar_test.go", diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index b64500b91d208..e3832c9253cc3 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -15,10 +15,14 @@ package variable import ( + "bufio" "fmt" "io" + "os" + "path/filepath" "strconv" "strings" + "sync" "sync/atomic" "time" @@ -28,6 +32,8 @@ import ( "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/collate" + "github.com/pingcap/tidb/util/hack" + "github.com/pingcap/tidb/util/mathutil" "github.com/pingcap/tidb/util/memory" "github.com/pingcap/tidb/util/timeutil" "github.com/tikv/client-go/v2/oracle" @@ -531,6 +537,72 @@ func collectAllowFuncName4ExpressionIndex() string { return strings.Join(str, ", ") } +// PasswordDictionaryImpl is the dictionary for validating password. +type PasswordDictionaryImpl struct { + Cache map[string]struct{} + m sync.RWMutex +} + +const MaxPwdValidationLength int = 100 +const MinPwdValidationLength int = 4 + +// PasswordDictionary is the dictionary for validating password. +var PasswordDictionary = PasswordDictionaryImpl{Cache: make(map[string]struct{})} + +// CleanPasswordDictionary removes all the words in the dictionary. +func CleanPasswordDictionary() { + PasswordDictionary.m.Lock() + defer PasswordDictionary.m.Unlock() + PasswordDictionary.Cache = make(map[string]struct{}) +} + +// UpdatePasswordDictionary update the dictionary for validating password. +func UpdatePasswordDictionary(filePath string) error { + PasswordDictionary.m.Lock() + defer PasswordDictionary.m.Unlock() + newDictionary := make(map[string]struct{}) + file, err := os.Open(filepath.Clean(filePath)) + if err != nil { + return err + } + if fileInfo, err := file.Stat(); err != nil { + return err + } else if fileInfo.Size() > 1*1024*1024 { + return errors.New("Too Large Dictionary. The maximum permitted file size is 1MB") + } + s := bufio.NewScanner(file) + for s.Scan() { + line := strings.ToLower(string(hack.String(s.Bytes()))) + if len(line) >= MinPwdValidationLength && len(line) <= MaxPwdValidationLength { + newDictionary[line] = struct{}{} + } + } + if err := s.Err(); err != nil { + return err + } + PasswordDictionary.Cache = newDictionary + return file.Close() +} + +// ValidateDictionaryPassword checks if the password contains words in the dictionary. +func ValidateDictionaryPassword(pwd string) bool { + PasswordDictionary.m.RLock() + defer PasswordDictionary.m.RUnlock() + if len(PasswordDictionary.Cache) == 0 { + return true + } + pwdLength := len(pwd) + for subStrLen := mathutil.Min(MaxPwdValidationLength, pwdLength); subStrLen >= MinPwdValidationLength; subStrLen-- { + for subStrPos := 0; subStrPos+subStrLen <= pwdLength; subStrPos++ { + subStr := pwd[subStrPos : subStrPos+subStrLen] + if _, ok := PasswordDictionary.Cache[subStr]; ok { + return false + } + } + } + return true +} + // GAFunction4ExpressionIndex stores functions GA for expression index. var GAFunction4ExpressionIndex = map[string]struct{}{ ast.Lower: {}, diff --git a/util/password-validation/BUILD.bazel b/util/password-validation/BUILD.bazel new file mode 100644 index 0000000000000..b37c1ffddbe0f --- /dev/null +++ b/util/password-validation/BUILD.bazel @@ -0,0 +1,23 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "password-validation", + srcs = ["password_validation.go"], + importpath = "github.com/pingcap/tidb/util/password-validation", + visibility = ["//visibility:public"], + deps = [ + "//sessionctx/variable", + "//util/hack", + "@com_github_pingcap_errors//:errors", + ], +) + +go_test( + name = "password-validation_test", + srcs = ["password_validation_test.go"], + embed = [":password-validation"], + deps = [ + "//sessionctx/variable", + "@com_github_stretchr_testify//require", + ], +) diff --git a/sessionctx/variable/password_dictionary.go b/util/password-validation/password_validation.go similarity index 59% rename from sessionctx/variable/password_dictionary.go rename to util/password-validation/password_validation.go index 06f475932d634..793b36c3d302a 100644 --- a/sessionctx/variable/password_dictionary.go +++ b/util/password-validation/password_validation.go @@ -12,91 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. -package variable +package password_validation import ( - "bufio" "bytes" "fmt" "os" "path/filepath" "strconv" - "strings" - "sync" "unicode" "github.com/pingcap/errors" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/hack" - "github.com/pingcap/tidb/util/mathutil" ) -// PasswordDictionaryImpl is the dictionary for validating password. -type PasswordDictionaryImpl struct { - cache map[string]struct{} - m sync.RWMutex -} - -const maxPwdLength int = 100 -const minPwdLength int = 4 - -var passwordDictionary = PasswordDictionaryImpl{cache: make(map[string]struct{})} - -// CleanPasswordDictionary removes all the words in the dictionary. -func CleanPasswordDictionary() { - passwordDictionary.m.Lock() - defer passwordDictionary.m.Unlock() - passwordDictionary.cache = make(map[string]struct{}) -} - -// UpdatePasswordDictionary update the dictionary for validating password. -func UpdatePasswordDictionary(filePath string) error { - passwordDictionary.m.Lock() - defer passwordDictionary.m.Unlock() - newDictionary := make(map[string]struct{}) - file, err := os.Open(filepath.Clean(filePath)) - if err != nil { - return err - } - if fileInfo, err := file.Stat(); err != nil { - return err - } else if fileInfo.Size() > 1*1024*1024 { - return errors.New("Too Large Dictionary. The maximum permitted file size is 1MB") - } - s := bufio.NewScanner(file) - for s.Scan() { - line := strings.ToLower(string(hack.String(s.Bytes()))) - if len(line) >= minPwdLength && len(line) <= maxPwdLength { - newDictionary[line] = struct{}{} - } - } - if err := s.Err(); err != nil { - return err - } - passwordDictionary.cache = newDictionary - return file.Close() -} - -// ValidateDictionaryPassword checks if the password contains words in the dictionary. -func ValidateDictionaryPassword(pwd string) bool { - passwordDictionary.m.RLock() - defer passwordDictionary.m.RUnlock() - if len(passwordDictionary.cache) == 0 { - return true - } - pwdLength := len(pwd) - for subStrLen := mathutil.Min(maxPwdLength, pwdLength); subStrLen >= minPwdLength; subStrLen-- { - for subStrPos := 0; subStrPos+subStrLen <= pwdLength; subStrPos++ { - subStr := pwd[subStrPos : subStrPos+subStrLen] - if _, ok := passwordDictionary.cache[subStr]; ok { - return false - } - } - } - return true -} - -// CreateTmpDictWithSize is only used for test. -func CreateTmpDictWithSize(filename string, size int) (string, error) { +// createTmpDictWithSize is only used for test. +func createTmpDictWithSize(filename string, size int) (string, error) { filename = filepath.Join(os.TempDir(), filename) file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY, os.ModePerm) if err != nil { @@ -132,13 +64,13 @@ func CreateTmpDictWithContent(filename string, content []byte) (string, error) { } // ValidateUserNameInPassword checks whether pwd exists in the dictionary. -func ValidateUserNameInPassword(pwd string, sessionVars *SessionVars) (string, error) { +func ValidateUserNameInPassword(pwd string, sessionVars *variable.SessionVars) (string, error) { currentUser := sessionVars.User globalVars := sessionVars.GlobalVarsAccessor pwdBytes := hack.Slice(pwd) - if checkUserName, err := globalVars.GetGlobalSysVar(ValidatePasswordCheckUserName); err != nil { + if checkUserName, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordCheckUserName); err != nil { return "", err - } else if currentUser != nil && TiDBOptOn(checkUserName) { + } else if currentUser != nil && variable.TiDBOptOn(checkUserName) { for _, username := range []string{currentUser.AuthUsername, currentUser.Username} { usernameBytes := hack.Slice(username) userNameLen := len(usernameBytes) @@ -161,8 +93,8 @@ func ValidateUserNameInPassword(pwd string, sessionVars *SessionVars) (string, e } // ValidatePasswordLowPolicy checks whether pwd satisfies the low policy of password validation. -func ValidatePasswordLowPolicy(pwd string, globalVars *GlobalVarAccessor) (string, error) { - if validateLengthStr, err := (*globalVars).GetGlobalSysVar(ValidatePasswordLength); err != nil { +func ValidatePasswordLowPolicy(pwd string, globalVars *variable.GlobalVarAccessor) (string, error) { + if validateLengthStr, err := (*globalVars).GetGlobalSysVar(variable.ValidatePasswordLength); err != nil { return "", err } else if validateLength, err := strconv.ParseInt(validateLengthStr, 10, 64); err != nil { return "", err @@ -173,7 +105,7 @@ func ValidatePasswordLowPolicy(pwd string, globalVars *GlobalVarAccessor) (strin } // ValidatePasswordMediumPolicy checks whether pwd satisfies the medium policy of password validation. -func ValidatePasswordMediumPolicy(pwd string, globalVars *GlobalVarAccessor) (string, error) { +func ValidatePasswordMediumPolicy(pwd string, globalVars *variable.GlobalVarAccessor) (string, error) { var lowerCaseCount, upperCaseCount, numberCount, specialCharCount int64 runes := []rune(pwd) for i := 0; i < len(runes); i++ { @@ -187,7 +119,7 @@ func ValidatePasswordMediumPolicy(pwd string, globalVars *GlobalVarAccessor) (st specialCharCount++ } } - if mixedCaseCountStr, err := (*globalVars).GetGlobalSysVar(ValidatePasswordMixedCaseCount); err != nil { + if mixedCaseCountStr, err := (*globalVars).GetGlobalSysVar(variable.ValidatePasswordMixedCaseCount); err != nil { return "", err } else if mixedCaseCount, err := strconv.ParseInt(mixedCaseCountStr, 10, 64); err != nil { return "", err @@ -196,14 +128,14 @@ func ValidatePasswordMediumPolicy(pwd string, globalVars *GlobalVarAccessor) (st } else if upperCaseCount < mixedCaseCount { return fmt.Sprintf("Require Password Uppercase Count: %d", mixedCaseCount), nil } - if requireNumberCountStr, err := (*globalVars).GetGlobalSysVar(ValidatePasswordNumberCount); err != nil { + if requireNumberCountStr, err := (*globalVars).GetGlobalSysVar(variable.ValidatePasswordNumberCount); err != nil { return "", err } else if requireNumberCount, err := strconv.ParseInt(requireNumberCountStr, 10, 64); err != nil { return "", err } else if numberCount < requireNumberCount { return fmt.Sprintf("Require Password Digit Count: %d", requireNumberCount), nil } - if requireSpecialCharCountStr, err := (*globalVars).GetGlobalSysVar(ValidatePasswordSpecialCharCount); err != nil { + if requireSpecialCharCountStr, err := (*globalVars).GetGlobalSysVar(variable.ValidatePasswordSpecialCharCount); err != nil { return "", err } else if requireSpecialCharCount, err := strconv.ParseInt(requireSpecialCharCountStr, 10, 64); err != nil { return "", err @@ -214,22 +146,22 @@ func ValidatePasswordMediumPolicy(pwd string, globalVars *GlobalVarAccessor) (st } // ValidatePassword checks whether the pwd can be used. -func ValidatePassword(sessionVars *SessionVars, pwd string) error { +func ValidatePassword(sessionVars *variable.SessionVars, pwd string) error { globalVars := sessionVars.GlobalVarsAccessor - validatePolicy, err := globalVars.GetGlobalSysVar(ValidatePasswordPolicy) + validatePolicy, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordPolicy) if err != nil { return err } if warn, err := ValidateUserNameInPassword(pwd, sessionVars); err != nil { return err } else if len(warn) > 0 { - return ErrNotValidPassword.GenWithStack(warn) + return variable.ErrNotValidPassword.GenWithStack(warn) } if warn, err := ValidatePasswordLowPolicy(pwd, &globalVars); err != nil { return err } else if len(warn) > 0 { - return ErrNotValidPassword.GenWithStack(warn) + return variable.ErrNotValidPassword.GenWithStack(warn) } // LOW if validatePolicy == "LOW" { @@ -240,15 +172,15 @@ func ValidatePassword(sessionVars *SessionVars, pwd string) error { if warn, err := ValidatePasswordMediumPolicy(pwd, &globalVars); err != nil { return err } else if len(warn) > 0 { - return ErrNotValidPassword.GenWithStack(warn) + return variable.ErrNotValidPassword.GenWithStack(warn) } if validatePolicy == "MEDIUM" { return nil } // STRONG - if !ValidateDictionaryPassword(pwd) { - return ErrNotValidPassword.GenWithStack("Password contains word in the dictionary") + if !variable.ValidateDictionaryPassword(pwd) { + return variable.ErrNotValidPassword.GenWithStack("Password contains word in the dictionary") } return nil } diff --git a/sessionctx/variable/password_dictionary_test.go b/util/password-validation/password_validation_test.go similarity index 60% rename from sessionctx/variable/password_dictionary_test.go rename to util/password-validation/password_validation_test.go index fa9c85ab5f8e8..55921577d4e21 100644 --- a/sessionctx/variable/password_dictionary_test.go +++ b/util/password-validation/password_validation_test.go @@ -12,37 +12,38 @@ // See the License for the specific language governing permissions and // limitations under the License. -package variable +package password_validation import ( "testing" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/stretchr/testify/require" ) func TestUpdateDictionaryFile(t *testing.T) { - tooLargeDict, err := CreateTmpDictWithSize("1.dict", 2*1024*1024) + tooLargeDict, err := createTmpDictWithSize("1.dict", 2*1024*1024) require.NoError(t, err) - err = UpdatePasswordDictionary(tooLargeDict) + err = variable.UpdatePasswordDictionary(tooLargeDict) require.ErrorContains(t, err, "Too Large Dictionary. The maximum permitted file size is 1MB") dict, err := CreateTmpDictWithContent("2.dict", []byte("abc\n1234\n5678")) require.NoError(t, err) - require.NoError(t, UpdatePasswordDictionary(dict)) - _, ok := passwordDictionary.cache["1234"] + require.NoError(t, variable.UpdatePasswordDictionary(dict)) + _, ok := variable.PasswordDictionary.Cache["1234"] require.True(t, ok) - _, ok = passwordDictionary.cache["5678"] + _, ok = variable.PasswordDictionary.Cache["5678"] require.True(t, ok) - _, ok = passwordDictionary.cache["abc"] + _, ok = variable.PasswordDictionary.Cache["abc"] require.False(t, ok) } func TestValidateDictionaryPassword(t *testing.T) { dict, err := CreateTmpDictWithContent("3.dict", []byte("1234\n5678")) require.NoError(t, err) - require.NoError(t, UpdatePasswordDictionary(dict)) - require.True(t, ValidateDictionaryPassword("abcdefg")) - require.True(t, ValidateDictionaryPassword("abcd123efg")) - require.False(t, ValidateDictionaryPassword("abcd1234efg")) - require.False(t, ValidateDictionaryPassword("abcd12345efg")) + require.NoError(t, variable.UpdatePasswordDictionary(dict)) + require.True(t, variable.ValidateDictionaryPassword("abcdefg")) + require.True(t, variable.ValidateDictionaryPassword("abcd123efg")) + require.False(t, variable.ValidateDictionaryPassword("abcd1234efg")) + require.False(t, variable.ValidateDictionaryPassword("abcd12345efg")) } From 319442356452b638df1573302a11213529187885 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Tue, 15 Nov 2022 17:07:58 +0800 Subject: [PATCH 14/26] Fix --- executor/simple.go | 8 ++++---- executor/simple_test.go | 4 ++-- expression/builtin_encryption.go | 8 ++++---- expression/builtin_encryption_test.go | 4 ++-- expression/integration_test.go | 4 ++-- sessionctx/variable/varsutil.go | 3 +++ util/password-validation/password_validation.go | 2 +- util/password-validation/password_validation_test.go | 2 +- 8 files changed, 19 insertions(+), 16 deletions(-) diff --git a/executor/simple.go b/executor/simple.go index 102dbc1947e6d..e083fd918bff3 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -51,7 +51,7 @@ import ( "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/logutil" - pwdValidate "github.com/pingcap/tidb/util/password-validation" + pwdValidator "github.com/pingcap/tidb/util/password-validation" "github.com/pingcap/tidb/util/sem" "github.com/pingcap/tidb/util/sqlexec" "github.com/pingcap/tidb/util/timeutil" @@ -900,7 +900,7 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm authPlugin = spec.AuthOpt.AuthPlugin } if e.enableValidatePassword() && e.authUsingCleartextPwd(spec.AuthOpt, authPlugin) { - if err := pwdValidate.ValidatePassword(e.ctx.GetSessionVars(), spec.AuthOpt.AuthString); err != nil { + if err := pwdValidator.ValidatePassword(e.ctx.GetSessionVars(), spec.AuthOpt.AuthString); err != nil { return err } } @@ -1114,7 +1114,7 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt) return ErrPluginIsNotLoaded.GenWithStackByArgs(spec.AuthOpt.AuthPlugin) } if e.enableValidatePassword() && e.authUsingCleartextPwd(spec.AuthOpt, spec.AuthOpt.AuthPlugin) { - if err := pwdValidate.ValidatePassword(e.ctx.GetSessionVars(), spec.AuthOpt.AuthString); err != nil { + if err := pwdValidator.ValidatePassword(e.ctx.GetSessionVars(), spec.AuthOpt.AuthString); err != nil { return err } } @@ -1635,7 +1635,7 @@ func (e *SimpleExec) executeSetPwd(ctx context.Context, s *ast.SetPwdStmt) error return err } if e.enableValidatePassword() { - if err := pwdValidate.ValidatePassword(e.ctx.GetSessionVars(), s.Password); err != nil { + if err := pwdValidator.ValidatePassword(e.ctx.GetSessionVars(), s.Password); err != nil { return err } } diff --git a/executor/simple_test.go b/executor/simple_test.go index 35fe979dbb28e..db7cec831d0c7 100644 --- a/executor/simple_test.go +++ b/executor/simple_test.go @@ -26,7 +26,7 @@ import ( "github.com/pingcap/tidb/server" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/util" - pwdValidate "github.com/pingcap/tidb/util/password-validation" + pwdValidator "github.com/pingcap/tidb/util/password-validation" "github.com/stretchr/testify/require" tikvutil "github.com/tikv/client-go/v2/util" ) @@ -134,7 +134,7 @@ func TestValidatePassword(t *testing.T) { require.NoError(t, tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil)) authPlugins := []string{mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password} - dictFile, err := pwdValidate.CreateTmpDictWithContent("3.dict", []byte("1234\n5678")) + dictFile, err := pwdValidator.CreateTmpDictWithContent("3.dict", []byte("1234\n5678")) require.NoError(t, err) tk.MustQuery("SELECT @@global.validate_password.enable").Check(testkit.Rows("0")) tk.MustExec("SET GLOBAL validate_password.enable = 1") diff --git a/expression/builtin_encryption.go b/expression/builtin_encryption.go index d5bde2e900836..4ca541ee9a6f6 100644 --- a/expression/builtin_encryption.go +++ b/expression/builtin_encryption.go @@ -37,7 +37,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/encrypt" - pwdValidate "github.com/pingcap/tidb/util/password-validation" + pwdValidator "github.com/pingcap/tidb/util/password-validation" "github.com/pingcap/tipb/go-tipb" ) @@ -1054,17 +1054,17 @@ func (b *builtinValidatePasswordStrengthSig) evalInt(row chunk.Row) (int64, bool } func (b *builtinValidatePasswordStrengthSig) validateStr(str string, globalVars *variable.GlobalVarAccessor) (int64, bool, error) { - if warn, err := pwdValidate.ValidateUserNameInPassword(str, b.ctx.GetSessionVars()); err != nil { + if warn, err := pwdValidator.ValidateUserNameInPassword(str, b.ctx.GetSessionVars()); err != nil { return 0, true, err } else if len(warn) > 0 { return 0, false, nil } - if warn, err := pwdValidate.ValidatePasswordLowPolicy(str, globalVars); err != nil { + if warn, err := pwdValidator.ValidatePasswordLowPolicy(str, globalVars); err != nil { return 0, true, err } else if len(warn) > 0 { return 25, false, nil } - if warn, err := pwdValidate.ValidatePasswordMediumPolicy(str, globalVars); err != nil { + if warn, err := pwdValidator.ValidatePasswordMediumPolicy(str, globalVars); err != nil { return 0, true, err } else if len(warn) > 0 { return 50, false, nil diff --git a/expression/builtin_encryption_test.go b/expression/builtin_encryption_test.go index f508837c706d2..ad64bb9a86bd0 100644 --- a/expression/builtin_encryption_test.go +++ b/expression/builtin_encryption_test.go @@ -31,7 +31,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/hack" - pwdValidate "github.com/pingcap/tidb/util/password-validation" + pwdValidator "github.com/pingcap/tidb/util/password-validation" "github.com/stretchr/testify/require" ) @@ -637,7 +637,7 @@ func TestUncompressLength(t *testing.T) { func TestValidatePasswordStrength(t *testing.T) { ctx := createContext(t) ctx.GetSessionVars().User = &auth.UserIdentity{Username: "testuser"} - tempDict, err := pwdValidate.CreateTmpDictWithContent("tempDictionary.txt", []byte("1234\n")) + tempDict, err := pwdValidator.CreateTmpDictWithContent("tempDictionary.txt", []byte("1234\n")) require.NoError(t, err) globalVarsAccessor := variable.NewMockGlobalAccessor4Tests() ctx.GetSessionVars().GlobalVarsAccessor = globalVarsAccessor diff --git a/expression/integration_test.go b/expression/integration_test.go index b748907f91552..15301a8bb0f7e 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -45,7 +45,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/collate" - pwdValidate "github.com/pingcap/tidb/util/password-validation" + pwdValidator "github.com/pingcap/tidb/util/password-validation" "github.com/pingcap/tidb/util/sem" "github.com/pingcap/tidb/util/sqlexec" "github.com/pingcap/tidb/util/versioninfo" @@ -1147,7 +1147,7 @@ func TestEncryptionBuiltin(t *testing.T) { result.Check(testkit.Rows("")) // for VALIDATE_PASSWORD_STRENGTH - tempDict, err := pwdValidate.CreateTmpDictWithContent("4.txt", []byte("password\n")) + tempDict, err := pwdValidator.CreateTmpDictWithContent("4.txt", []byte("password\n")) require.NoError(t, err) tk.MustExec(fmt.Sprintf("SET GLOBAL validate_password.dictionary_file='%s'", tempDict)) tk.MustExec("SET GLOBAL validate_password.enable = 1") diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index e3832c9253cc3..83226cc1a662c 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -543,7 +543,10 @@ type PasswordDictionaryImpl struct { m sync.RWMutex } +// MaxPwdValidationLength is the max length of word in dictionary. const MaxPwdValidationLength int = 100 + +// MinPwdValidationLength is the min length of word in dictionary. const MinPwdValidationLength int = 4 // PasswordDictionary is the dictionary for validating password. diff --git a/util/password-validation/password_validation.go b/util/password-validation/password_validation.go index 793b36c3d302a..f3750417ba7dc 100644 --- a/util/password-validation/password_validation.go +++ b/util/password-validation/password_validation.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package password_validation +package validator import ( "bytes" diff --git a/util/password-validation/password_validation_test.go b/util/password-validation/password_validation_test.go index 55921577d4e21..5161331c8dafd 100644 --- a/util/password-validation/password_validation_test.go +++ b/util/password-validation/password_validation_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package password_validation +package validator import ( "testing" From abed795b7a14852b1b3e1b7c35404b68d3b39096 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Wed, 16 Nov 2022 12:04:07 +0800 Subject: [PATCH 15/26] Update --- ddl/ddl_api.go | 2 +- ddl/schematracker/dm_tracker.go | 2 +- executor/BUILD.bazel | 1 - executor/adapter.go | 2 +- executor/partition_table_test.go | 2 +- executor/simple_test.go | 3 +- expression/BUILD.bazel | 1 - expression/builtin_encryption_test.go | 4 +- expression/integration_test.go | 4 +- sessionctx/binloginfo/binloginfo_test.go | 2 +- sessionctx/variable/BUILD.bazel | 2 + sessionctx/variable/varsutil.go | 36 +++---- sessionctx/variable/varsutil_test.go | 46 ++++++++ util/password-validation/BUILD.bazel | 2 +- .../password_validation.go | 39 ------- .../password_validation_test.go | 100 ++++++++++++++---- util/util.go | 20 ++++ 17 files changed, 176 insertions(+), 92 deletions(-) diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index b440a3459bf1d..6111dd3fb4fdb 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -5071,7 +5071,7 @@ func (d *ddl) AlterColumn(ctx sessionctx.Context, ident ast.Ident, spec *ast.Alt } col := table.ToColumn(oldCol.Clone()) - // CleanPasswordDictionary the NoDefaultValueFlag value. + // Clean the NoDefaultValueFlag value. col.DelFlag(mysql.NoDefaultValueFlag) if len(specNewColumn.Options) == 0 { col.DefaultIsExpr = false diff --git a/ddl/schematracker/dm_tracker.go b/ddl/schematracker/dm_tracker.go index 53990bde6b0d6..75f8fa35b429d 100644 --- a/ddl/schematracker/dm_tracker.go +++ b/ddl/schematracker/dm_tracker.go @@ -648,7 +648,7 @@ func (d SchemaTracker) alterColumn(ctx sessionctx.Context, ident ast.Ident, spec return dbterror.ErrBadField.GenWithStackByArgs(colName, ident.Name) } - // CleanPasswordDictionary the NoDefaultValueFlag value. + // Clean the NoDefaultValueFlag value. oldCol.DelFlag(mysql.NoDefaultValueFlag) if len(specNewColumn.Options) == 0 { oldCol.DefaultIsExpr = false diff --git a/executor/BUILD.bazel b/executor/BUILD.bazel index 1f65c981a4801..cf91360b17a60 100644 --- a/executor/BUILD.bazel +++ b/executor/BUILD.bazel @@ -407,7 +407,6 @@ go_test( "//util/memory", "//util/mock", "//util/paging", - "//util/password-validation", "//util/pdapi", "//util/plancodec", "//util/ranger", diff --git a/executor/adapter.go b/executor/adapter.go index 2a6e78443c414..db9fbbaa929e0 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -1381,7 +1381,7 @@ func (a *ExecStmt) FinishExecuteStmt(txnTS uint64, err error, hasMoreResults boo } // Reset DurationParse due to the next statement may not need to be parsed (not a text protocol query). sessVars.DurationParse = 0 - // CleanPasswordDictionary the stale read flag when statement execution finish + // Clean the stale read flag when statement execution finish sessVars.StmtCtx.IsStaleness = false if sessVars.StmtCtx.ReadFromTableCache { diff --git a/executor/partition_table_test.go b/executor/partition_table_test.go index 5dd6e09fde83a..50bb68a7b5235 100644 --- a/executor/partition_table_test.go +++ b/executor/partition_table_test.go @@ -3268,7 +3268,7 @@ func TestIssue26251(t *testing.T) { t.Fail() } - // CleanPasswordDictionary up + // Clean up <-ch tk2.MustExec("rollback") } diff --git a/executor/simple_test.go b/executor/simple_test.go index db7cec831d0c7..5cba352971214 100644 --- a/executor/simple_test.go +++ b/executor/simple_test.go @@ -26,7 +26,6 @@ import ( "github.com/pingcap/tidb/server" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/util" - pwdValidator "github.com/pingcap/tidb/util/password-validation" "github.com/stretchr/testify/require" tikvutil "github.com/tikv/client-go/v2/util" ) @@ -134,7 +133,7 @@ func TestValidatePassword(t *testing.T) { require.NoError(t, tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil)) authPlugins := []string{mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password} - dictFile, err := pwdValidator.CreateTmpDictWithContent("3.dict", []byte("1234\n5678")) + dictFile, err := util.CreateTmpDictWithContent("3.dict", []byte("1234\n5678")) require.NoError(t, err) tk.MustQuery("SELECT @@global.validate_password.enable").Check(testkit.Rows("0")) tk.MustExec("SET GLOBAL validate_password.enable = 1") diff --git a/expression/BUILD.bazel b/expression/BUILD.bazel index 471874910cb27..fc1752ef19e63 100644 --- a/expression/BUILD.bazel +++ b/expression/BUILD.bazel @@ -221,7 +221,6 @@ go_test( "//util/hack", "//util/mathutil", "//util/mock", - "//util/password-validation", "//util/printer", "//util/sem", "//util/sqlexec", diff --git a/expression/builtin_encryption_test.go b/expression/builtin_encryption_test.go index ad64bb9a86bd0..f657ab9a8ae0c 100644 --- a/expression/builtin_encryption_test.go +++ b/expression/builtin_encryption_test.go @@ -29,9 +29,9 @@ import ( "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/hack" - pwdValidator "github.com/pingcap/tidb/util/password-validation" "github.com/stretchr/testify/require" ) @@ -637,7 +637,7 @@ func TestUncompressLength(t *testing.T) { func TestValidatePasswordStrength(t *testing.T) { ctx := createContext(t) ctx.GetSessionVars().User = &auth.UserIdentity{Username: "testuser"} - tempDict, err := pwdValidator.CreateTmpDictWithContent("tempDictionary.txt", []byte("1234\n")) + tempDict, err := util.CreateTmpDictWithContent("tempDictionary.txt", []byte("1234\n")) require.NoError(t, err) globalVarsAccessor := variable.NewMockGlobalAccessor4Tests() ctx.GetSessionVars().GlobalVarsAccessor = globalVarsAccessor diff --git a/expression/integration_test.go b/expression/integration_test.go index 15301a8bb0f7e..a2befd9a1e247 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -43,9 +43,9 @@ import ( "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/collate" - pwdValidator "github.com/pingcap/tidb/util/password-validation" "github.com/pingcap/tidb/util/sem" "github.com/pingcap/tidb/util/sqlexec" "github.com/pingcap/tidb/util/versioninfo" @@ -1147,7 +1147,7 @@ func TestEncryptionBuiltin(t *testing.T) { result.Check(testkit.Rows("")) // for VALIDATE_PASSWORD_STRENGTH - tempDict, err := pwdValidator.CreateTmpDictWithContent("4.txt", []byte("password\n")) + tempDict, err := util.CreateTmpDictWithContent("4.txt", []byte("password\n")) require.NoError(t, err) tk.MustExec(fmt.Sprintf("SET GLOBAL validate_password.dictionary_file='%s'", tempDict)) tk.MustExec("SET GLOBAL validate_password.enable = 1") diff --git a/sessionctx/binloginfo/binloginfo_test.go b/sessionctx/binloginfo/binloginfo_test.go index 8b2b0746f33aa..3c777a9436234 100644 --- a/sessionctx/binloginfo/binloginfo_test.go +++ b/sessionctx/binloginfo/binloginfo_test.go @@ -479,7 +479,7 @@ func TestZIgnoreError(t *testing.T) { tk.MustExec("insert into t values (1)") tk.MustExec("insert into t values (1)") - // CleanPasswordDictionary up. + // Clean up. s.pump.mu.Lock() s.pump.mu.mockFail = false s.pump.mu.Unlock() diff --git a/sessionctx/variable/BUILD.bazel b/sessionctx/variable/BUILD.bazel index 857bbc8e07437..28a4e0bbd7eba 100644 --- a/sessionctx/variable/BUILD.bazel +++ b/sessionctx/variable/BUILD.bazel @@ -99,10 +99,12 @@ go_test( "//testkit", "//testkit/testsetup", "//types", + "//util", "//util/chunk", "//util/execdetails", "//util/memory", "//util/mock", + "@com_github_pingcap_errors//:errors", "@com_github_pingcap_failpoint//:failpoint", "@com_github_stretchr_testify//require", "@com_github_tikv_client_go_v2//util", diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index 83226cc1a662c..d58368f1912fc 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -537,32 +537,28 @@ func collectAllowFuncName4ExpressionIndex() string { return strings.Join(str, ", ") } -// PasswordDictionaryImpl is the dictionary for validating password. -type PasswordDictionaryImpl struct { +type passwordDictionaryImpl struct { Cache map[string]struct{} m sync.RWMutex } -// MaxPwdValidationLength is the max length of word in dictionary. -const MaxPwdValidationLength int = 100 +const maxPwdValidationLength int = 100 -// MinPwdValidationLength is the min length of word in dictionary. -const MinPwdValidationLength int = 4 +const minPwdValidationLength int = 4 -// PasswordDictionary is the dictionary for validating password. -var PasswordDictionary = PasswordDictionaryImpl{Cache: make(map[string]struct{})} +var passwordDictionary = passwordDictionaryImpl{Cache: make(map[string]struct{})} // CleanPasswordDictionary removes all the words in the dictionary. func CleanPasswordDictionary() { - PasswordDictionary.m.Lock() - defer PasswordDictionary.m.Unlock() - PasswordDictionary.Cache = make(map[string]struct{}) + passwordDictionary.m.Lock() + defer passwordDictionary.m.Unlock() + passwordDictionary.Cache = make(map[string]struct{}) } // UpdatePasswordDictionary update the dictionary for validating password. func UpdatePasswordDictionary(filePath string) error { - PasswordDictionary.m.Lock() - defer PasswordDictionary.m.Unlock() + passwordDictionary.m.Lock() + defer passwordDictionary.m.Unlock() newDictionary := make(map[string]struct{}) file, err := os.Open(filepath.Clean(filePath)) if err != nil { @@ -576,29 +572,29 @@ func UpdatePasswordDictionary(filePath string) error { s := bufio.NewScanner(file) for s.Scan() { line := strings.ToLower(string(hack.String(s.Bytes()))) - if len(line) >= MinPwdValidationLength && len(line) <= MaxPwdValidationLength { + if len(line) >= minPwdValidationLength && len(line) <= maxPwdValidationLength { newDictionary[line] = struct{}{} } } if err := s.Err(); err != nil { return err } - PasswordDictionary.Cache = newDictionary + passwordDictionary.Cache = newDictionary return file.Close() } // ValidateDictionaryPassword checks if the password contains words in the dictionary. func ValidateDictionaryPassword(pwd string) bool { - PasswordDictionary.m.RLock() - defer PasswordDictionary.m.RUnlock() - if len(PasswordDictionary.Cache) == 0 { + passwordDictionary.m.RLock() + defer passwordDictionary.m.RUnlock() + if len(passwordDictionary.Cache) == 0 { return true } pwdLength := len(pwd) - for subStrLen := mathutil.Min(MaxPwdValidationLength, pwdLength); subStrLen >= MinPwdValidationLength; subStrLen-- { + for subStrLen := mathutil.Min(maxPwdValidationLength, pwdLength); subStrLen >= minPwdValidationLength; subStrLen-- { for subStrPos := 0; subStrPos+subStrLen <= pwdLength; subStrPos++ { subStr := pwd[subStrPos : subStrPos+subStrLen] - if _, ok := PasswordDictionary.Cache[subStr]; ok { + if _, ok := passwordDictionary.Cache[subStr]; ok { return false } } diff --git a/sessionctx/variable/varsutil_test.go b/sessionctx/variable/varsutil_test.go index 69c9caf294e5e..3c15be9a6b214 100644 --- a/sessionctx/variable/varsutil_test.go +++ b/sessionctx/variable/varsutil_test.go @@ -16,15 +16,19 @@ package variable import ( "context" + "os" + "path/filepath" "reflect" "strconv" "testing" "time" + "github.com/pingcap/errors" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/terror" + "github.com/pingcap/tidb/util" "github.com/stretchr/testify/require" ) @@ -737,3 +741,45 @@ func TestAssertionLevel(t *testing.T) { require.Equal(t, AssertionLevelFast, tidbOptAssertionLevel(AssertionFastStr)) require.Equal(t, AssertionLevelOff, tidbOptAssertionLevel("bogus")) } + +func TestUpdateDictionaryFile(t *testing.T) { + tooLargeDict, err := func(filename string, size int) (string, error) { + filename = filepath.Join(os.TempDir(), filename) + file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY, os.ModePerm) + if err != nil { + return "", err + } + if size > 0 { + n, err := file.Write(make([]byte, size)) + if err != nil { + return "", err + } else if n != size { + return "", errors.New("") + } + } + return filename, file.Close() + }("1.dict", 2*1024*1024) + require.NoError(t, err) + err = UpdatePasswordDictionary(tooLargeDict) + require.ErrorContains(t, err, "Too Large Dictionary. The maximum permitted file size is 1MB") + + dict, err := util.CreateTmpDictWithContent("2.dict", []byte("abc\n1234\n5678")) + require.NoError(t, err) + require.NoError(t, UpdatePasswordDictionary(dict)) + _, ok := passwordDictionary.Cache["1234"] + require.True(t, ok) + _, ok = passwordDictionary.Cache["5678"] + require.True(t, ok) + _, ok = passwordDictionary.Cache["abc"] + require.False(t, ok) +} + +func TestValidateDictionaryPassword(t *testing.T) { + dict, err := util.CreateTmpDictWithContent("3.dict", []byte("1234\n5678")) + require.NoError(t, err) + require.NoError(t, UpdatePasswordDictionary(dict)) + require.True(t, ValidateDictionaryPassword("abcdefg")) + require.True(t, ValidateDictionaryPassword("abcd123efg")) + require.False(t, ValidateDictionaryPassword("abcd1234efg")) + require.False(t, ValidateDictionaryPassword("abcd12345efg")) +} diff --git a/util/password-validation/BUILD.bazel b/util/password-validation/BUILD.bazel index b37c1ffddbe0f..c3649a3a15383 100644 --- a/util/password-validation/BUILD.bazel +++ b/util/password-validation/BUILD.bazel @@ -8,7 +8,6 @@ go_library( deps = [ "//sessionctx/variable", "//util/hack", - "@com_github_pingcap_errors//:errors", ], ) @@ -17,6 +16,7 @@ go_test( srcs = ["password_validation_test.go"], embed = [":password-validation"], deps = [ + "//parser/auth", "//sessionctx/variable", "@com_github_stretchr_testify//require", ], diff --git a/util/password-validation/password_validation.go b/util/password-validation/password_validation.go index f3750417ba7dc..44fa1a23a2a15 100644 --- a/util/password-validation/password_validation.go +++ b/util/password-validation/password_validation.go @@ -17,52 +17,13 @@ package validator import ( "bytes" "fmt" - "os" - "path/filepath" "strconv" "unicode" - "github.com/pingcap/errors" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/hack" ) -// createTmpDictWithSize is only used for test. -func createTmpDictWithSize(filename string, size int) (string, error) { - filename = filepath.Join(os.TempDir(), filename) - file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY, os.ModePerm) - if err != nil { - return "", err - } - if size > 0 { - n, err := file.Write(make([]byte, size)) - if err != nil { - return "", err - } else if n != size { - return "", errors.New("") - } - } - return filename, file.Close() -} - -// CreateTmpDictWithContent is only used for test. -func CreateTmpDictWithContent(filename string, content []byte) (string, error) { - filename = filepath.Join(os.TempDir(), filename) - file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY, os.ModePerm) - if err != nil { - return "", err - } - if len(content) > 0 { - n, err := file.Write(content) - if err != nil { - return "", err - } else if n != len(content) { - return "", errors.New("") - } - } - return filename, file.Close() -} - // ValidateUserNameInPassword checks whether pwd exists in the dictionary. func ValidateUserNameInPassword(pwd string, sessionVars *variable.SessionVars) (string, error) { currentUser := sessionVars.User diff --git a/util/password-validation/password_validation_test.go b/util/password-validation/password_validation_test.go index 5161331c8dafd..b143fe159a520 100644 --- a/util/password-validation/password_validation_test.go +++ b/util/password-validation/password_validation_test.go @@ -15,35 +15,97 @@ package validator import ( + "context" "testing" + "github.com/pingcap/tidb/parser/auth" "github.com/pingcap/tidb/sessionctx/variable" "github.com/stretchr/testify/require" ) -func TestUpdateDictionaryFile(t *testing.T) { - tooLargeDict, err := createTmpDictWithSize("1.dict", 2*1024*1024) +func TestValidateUserNameInPassword(t *testing.T) { + sessionVars := variable.NewSessionVars(nil) + sessionVars.User = &auth.UserIdentity{Username: "user", AuthUsername: "authuser"} + sessionVars.GlobalVarsAccessor = variable.NewMockGlobalAccessor4Tests() + testcases := []struct { + pwd string + warn string + }{ + {"", ""}, + {"user", "Password Contains User Name"}, + {"authuser", "Password Contains User Name"}, + {"resu000", "Password Contains Reversed User Name"}, + {"resuhtua", "Password Contains Reversed User Name"}, + {"User", ""}, + {"authUser", ""}, + {"Resu", ""}, + {"Resuhtua", ""}, + } + // Enable check_user_name + err := sessionVars.GlobalVarsAccessor.SetGlobalSysVar(context.Background(), variable.ValidatePasswordCheckUserName, "ON") require.NoError(t, err) - err = variable.UpdatePasswordDictionary(tooLargeDict) - require.ErrorContains(t, err, "Too Large Dictionary. The maximum permitted file size is 1MB") + for _, testcase := range testcases { + warn, err := ValidateUserNameInPassword(testcase.pwd, sessionVars) + require.NoError(t, err) + require.Equal(t, testcase.warn, warn, testcase.pwd) + } - dict, err := CreateTmpDictWithContent("2.dict", []byte("abc\n1234\n5678")) + // Disable check_user_name + err = sessionVars.GlobalVarsAccessor.SetGlobalSysVar(context.Background(), variable.ValidatePasswordCheckUserName, "OFF") require.NoError(t, err) - require.NoError(t, variable.UpdatePasswordDictionary(dict)) - _, ok := variable.PasswordDictionary.Cache["1234"] - require.True(t, ok) - _, ok = variable.PasswordDictionary.Cache["5678"] - require.True(t, ok) - _, ok = variable.PasswordDictionary.Cache["abc"] - require.False(t, ok) + for _, testcase := range testcases { + warn, err := ValidateUserNameInPassword(testcase.pwd, sessionVars) + require.NoError(t, err) + require.Equal(t, "", warn, testcase.pwd) + } } -func TestValidateDictionaryPassword(t *testing.T) { - dict, err := CreateTmpDictWithContent("3.dict", []byte("1234\n5678")) +func TestValidatePasswordLowPolicy(t *testing.T) { + sessionVars := variable.NewSessionVars(nil) + sessionVars.GlobalVarsAccessor = variable.NewMockGlobalAccessor4Tests() + sessionVars.GlobalVarsAccessor.(*variable.MockGlobalAccessor).SessionVars = sessionVars + err := sessionVars.GlobalVarsAccessor.SetGlobalSysVar(context.Background(), variable.ValidatePasswordLength, "8") require.NoError(t, err) - require.NoError(t, variable.UpdatePasswordDictionary(dict)) - require.True(t, variable.ValidateDictionaryPassword("abcdefg")) - require.True(t, variable.ValidateDictionaryPassword("abcd123efg")) - require.False(t, variable.ValidateDictionaryPassword("abcd1234efg")) - require.False(t, variable.ValidateDictionaryPassword("abcd12345efg")) + + warn, err := ValidatePasswordLowPolicy("1234", &sessionVars.GlobalVarsAccessor) + require.NoError(t, err) + require.Equal(t, "Require Password Length: 8", warn) + warn, err = ValidatePasswordLowPolicy("12345678", &sessionVars.GlobalVarsAccessor) + require.NoError(t, err) + require.Equal(t, "", warn) + + err = sessionVars.GlobalVarsAccessor.SetGlobalSysVar(context.Background(), variable.ValidatePasswordLength, "12") + require.NoError(t, err) + warn, err = ValidatePasswordLowPolicy("12345678", &sessionVars.GlobalVarsAccessor) + require.NoError(t, err) + require.Equal(t, "Require Password Length: 12", warn) +} + +func TestValidatePasswordMediumPolicy(t *testing.T) { + sessionVars := variable.NewSessionVars(nil) + sessionVars.GlobalVarsAccessor = variable.NewMockGlobalAccessor4Tests() + sessionVars.GlobalVarsAccessor.(*variable.MockGlobalAccessor).SessionVars = sessionVars + + err := sessionVars.GlobalVarsAccessor.SetGlobalSysVar(context.Background(), variable.ValidatePasswordMixedCaseCount, "1") + require.NoError(t, err) + err = sessionVars.GlobalVarsAccessor.SetGlobalSysVar(context.Background(), variable.ValidatePasswordSpecialCharCount, "2") + require.NoError(t, err) + err = sessionVars.GlobalVarsAccessor.SetGlobalSysVar(context.Background(), variable.ValidatePasswordNumberCount, "3") + require.NoError(t, err) + + warn, err := ValidatePasswordMediumPolicy("!@A123", &sessionVars.GlobalVarsAccessor) + require.NoError(t, err) + require.Equal(t, "Require Password Lowercase Count: 1", warn) + warn, err = ValidatePasswordMediumPolicy("!@a123", &sessionVars.GlobalVarsAccessor) + require.NoError(t, err) + require.Equal(t, "Require Password Uppercase Count: 1", warn) + warn, err = ValidatePasswordMediumPolicy("!@Aa12", &sessionVars.GlobalVarsAccessor) + require.NoError(t, err) + require.Equal(t, "Require Password Digit Count: 3", warn) + warn, err = ValidatePasswordMediumPolicy("!Aa123", &sessionVars.GlobalVarsAccessor) + require.NoError(t, err) + require.Equal(t, "Require Password Non-alphanumeric Count: 2", warn) + warn, err = ValidatePasswordMediumPolicy("!@Aa123", &sessionVars.GlobalVarsAccessor) + require.NoError(t, err) + require.Equal(t, "", warn) } diff --git a/util/util.go b/util/util.go index 8af2876240486..06b5c6f552ac9 100644 --- a/util/util.go +++ b/util/util.go @@ -19,6 +19,8 @@ import ( "fmt" "io/ioutil" "net/http" + "os" + "path/filepath" "strconv" "strings" "time" @@ -28,6 +30,24 @@ import ( "go.uber.org/zap" ) +// CreateTmpDictWithContent is only used for test. +func CreateTmpDictWithContent(filename string, content []byte) (string, error) { + filename = filepath.Join(os.TempDir(), filename) + file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY, os.ModePerm) + if err != nil { + return "", err + } + if len(content) > 0 { + n, err := file.Write(content) + if err != nil { + return "", err + } else if n != len(content) { + return "", errors.New("") + } + } + return filename, file.Close() +} + // SliceToMap converts slice to map // nolint:unused func SliceToMap(slice []string) map[string]interface{} { From 7a563ede890c6143ee95e8db2dc32b8e2683617c Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Wed, 16 Nov 2022 23:25:00 +0800 Subject: [PATCH 16/26] Remove comment --- executor/simple_test.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/executor/simple_test.go b/executor/simple_test.go index 5cba352971214..32603290c5f63 100644 --- a/executor/simple_test.go +++ b/executor/simple_test.go @@ -124,10 +124,6 @@ func TestUserAttributes(t *testing.T) { } func TestValidatePassword(t *testing.T) { - // Some test cases come from mysql-server/mysql-test: - // t/validate_password_component.test - // t/validate_password_component_check_user.test - store, _ := testkit.CreateMockStoreAndDomain(t) tk := testkit.NewTestKit(t, store) require.NoError(t, tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil)) From b35cd0e60ec31d32e944c3855f633812f30d1506 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Thu, 17 Nov 2022 16:13:27 +0800 Subject: [PATCH 17/26] create user must specify password --- executor/simple.go | 11 ++++++++--- executor/simple_test.go | 15 ++++++++++----- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/executor/simple.go b/executor/simple.go index e083fd918bff3..d915356ac01ad 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -899,9 +899,14 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm if spec.AuthOpt != nil && spec.AuthOpt.AuthPlugin != "" { authPlugin = spec.AuthOpt.AuthPlugin } - if e.enableValidatePassword() && e.authUsingCleartextPwd(spec.AuthOpt, authPlugin) { - if err := pwdValidator.ValidatePassword(e.ctx.GetSessionVars(), spec.AuthOpt.AuthString); err != nil { - return err + if e.enableValidatePassword() { + if spec.AuthOpt == nil || !spec.AuthOpt.ByAuthString && spec.AuthOpt.HashString == "" { + return variable.ErrNotValidPassword.GenWithStackByArgs() + } + if e.authUsingCleartextPwd(spec.AuthOpt, authPlugin) { + if err := pwdValidator.ValidatePassword(e.ctx.GetSessionVars(), spec.AuthOpt.AuthString); err != nil { + return err + } } } pwd, ok := spec.EncodedPassword() diff --git a/executor/simple_test.go b/executor/simple_test.go index 32603290c5f63..e0dfe846295b1 100644 --- a/executor/simple_test.go +++ b/executor/simple_test.go @@ -126,7 +126,13 @@ func TestUserAttributes(t *testing.T) { func TestValidatePassword(t *testing.T) { store, _ := testkit.CreateMockStoreAndDomain(t) tk := testkit.NewTestKit(t, store) - require.NoError(t, tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil)) + subtk := testkit.NewTestKit(t, store) + err := tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil) + require.NoError(t, err) + tk.MustExec("CREATE USER ''@'localhost'") + tk.MustExec("GRANT ALL PRIVILEGES ON mysql.* TO ''@'localhost';") + err = subtk.Session().Auth(&auth.UserIdentity{Hostname: "localhost"}, nil, nil) + require.NoError(t, err) authPlugins := []string{mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password} dictFile, err := util.CreateTmpDictWithContent("3.dict", []byte("1234\n5678")) @@ -184,12 +190,11 @@ func TestValidatePassword(t *testing.T) { // "IDENTIFIED AS 'xxx'" is not affected by validation tk.MustExec(fmt.Sprintf("ALTER USER testuser IDENTIFIED WITH '%s' AS ''", authPlugin)) } + tk.MustContainErrMsg("CREATE USER 'testuser1'@'localhost'", "Your password does not satisfy the current policy requirements") + tk.MustContainErrMsg("CREATE USER 'testuser1'@'localhost' IDENTIFIED WITH 'caching_sha2_password'", "Your password does not satisfy the current policy requirements") + tk.MustContainErrMsg("CREATE USER 'testuser1'@'localhost' IDENTIFIED WITH 'caching_sha2_password' AS ''", "Your password does not satisfy the current policy requirements") // if the username is '', all password can pass the check_user_name - tk.MustExec("CREATE USER ''@'localhost'") - tk.MustExec("GRANT ALL PRIVILEGES ON mysql.* TO ''@'localhost';") - subtk := testkit.NewTestKit(t, store) - require.NoError(t, subtk.Session().Auth(&auth.UserIdentity{Hostname: "localhost"}, nil, nil)) subtk.MustQuery("SELECT user(), current_user()").Check(testkit.Rows("@localhost @localhost")) subtk.MustQuery("SELECT @@global.validate_password.check_user_name").Check(testkit.Rows("1")) subtk.MustQuery("SELECT @@global.validate_password.enable").Check(testkit.Rows("1")) From cd0c43799939f755a6b88f1cd89e66d47804c366 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Thu, 17 Nov 2022 17:33:57 +0800 Subject: [PATCH 18/26] dictionary_file -> dictionary --- executor/simple_test.go | 6 +- expression/builtin_encryption.go | 6 +- expression/builtin_encryption_test.go | 5 +- expression/integration_test.go | 5 +- sessionctx/variable/BUILD.bazel | 3 - sessionctx/variable/sysvar.go | 13 +--- sessionctx/variable/varsutil.go | 71 ------------------- sessionctx/variable/varsutil_test.go | 46 ------------ util/password-validation/BUILD.bazel | 1 + .../password_validation.go | 39 +++++++++- .../password_validation_test.go | 26 +++++++ util/util.go | 20 ------ 12 files changed, 75 insertions(+), 166 deletions(-) diff --git a/executor/simple_test.go b/executor/simple_test.go index e0dfe846295b1..31fb9f55b7801 100644 --- a/executor/simple_test.go +++ b/executor/simple_test.go @@ -135,8 +135,6 @@ func TestValidatePassword(t *testing.T) { require.NoError(t, err) authPlugins := []string{mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password} - dictFile, err := util.CreateTmpDictWithContent("3.dict", []byte("1234\n5678")) - require.NoError(t, err) tk.MustQuery("SELECT @@global.validate_password.enable").Check(testkit.Rows("0")) tk.MustExec("SET GLOBAL validate_password.enable = 1") tk.MustQuery("SELECT @@global.validate_password.enable").Check(testkit.Rows("1")) @@ -180,11 +178,11 @@ func TestValidatePassword(t *testing.T) { // STRONG: Length; numeric, lowercase/uppercase, and special characters; dictionary file tk.MustExec("SET GLOBAL validate_password.policy = 'STRONG'") tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abc1234567'") - tk.MustExec(fmt.Sprintf("SET GLOBAL validate_password.dictionary_file = '%s'", dictFile)) + tk.MustExec(fmt.Sprintf("SET GLOBAL validate_password.dictionary = '%s'", "1234;5678")) tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abc123567'") tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abc43218765'") tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!Abc1234567'", "Password contains word in the dictionary") - tk.MustExec("SET GLOBAL validate_password.dictionary_file = ''") + tk.MustExec("SET GLOBAL validate_password.dictionary = ''") tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abc1234567'") // "IDENTIFIED AS 'xxx'" is not affected by validation diff --git a/expression/builtin_encryption.go b/expression/builtin_encryption.go index 4ca541ee9a6f6..fb451f9714cd4 100644 --- a/expression/builtin_encryption.go +++ b/expression/builtin_encryption.go @@ -730,7 +730,6 @@ func (c *sm3FunctionClass) getFunction(ctx sessionctx.Context, args []Expression bf.tp.SetCollate(collate) bf.tp.SetFlen(40) sig := &builtinSM3Sig{bf} - //sig.setPbCode(tipb.ScalarFuncSig_SM3) // TODO return sig, nil } @@ -1021,7 +1020,6 @@ func (c *validatePasswordStrengthFunctionClass) getFunction(ctx sessionctx.Conte } bf.tp.SetFlen(21) sig := &builtinValidatePasswordStrengthSig{bf} - //sig.setPbCode(tipb.ScalarFuncSig_ValidatePasswordStrength) return sig, nil } @@ -1069,7 +1067,9 @@ func (b *builtinValidatePasswordStrengthSig) validateStr(str string, globalVars } else if len(warn) > 0 { return 50, false, nil } - if ok := variable.ValidateDictionaryPassword(str); !ok { + if ok, err := pwdValidator.ValidateDictionaryPassword(str, globalVars); err != nil { + return 0, true, err + } else if !ok { return 75, false, nil } return 100, false, nil diff --git a/expression/builtin_encryption_test.go b/expression/builtin_encryption_test.go index f657ab9a8ae0c..087fb3f35e466 100644 --- a/expression/builtin_encryption_test.go +++ b/expression/builtin_encryption_test.go @@ -29,7 +29,6 @@ import ( "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" - "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/hack" "github.com/stretchr/testify/require" @@ -637,11 +636,9 @@ func TestUncompressLength(t *testing.T) { func TestValidatePasswordStrength(t *testing.T) { ctx := createContext(t) ctx.GetSessionVars().User = &auth.UserIdentity{Username: "testuser"} - tempDict, err := util.CreateTmpDictWithContent("tempDictionary.txt", []byte("1234\n")) - require.NoError(t, err) globalVarsAccessor := variable.NewMockGlobalAccessor4Tests() ctx.GetSessionVars().GlobalVarsAccessor = globalVarsAccessor - err = globalVarsAccessor.SetGlobalSysVar(context.Background(), variable.ValidatePasswordDictionaryFile, tempDict) + err := globalVarsAccessor.SetGlobalSysVar(context.Background(), variable.ValidatePasswordDictionary, "1234") require.NoError(t, err) tests := []struct { diff --git a/expression/integration_test.go b/expression/integration_test.go index a2befd9a1e247..871e70fbfc306 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -43,7 +43,6 @@ import ( "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/types" - "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/sem" @@ -1147,9 +1146,7 @@ func TestEncryptionBuiltin(t *testing.T) { result.Check(testkit.Rows("")) // for VALIDATE_PASSWORD_STRENGTH - tempDict, err := util.CreateTmpDictWithContent("4.txt", []byte("password\n")) - require.NoError(t, err) - tk.MustExec(fmt.Sprintf("SET GLOBAL validate_password.dictionary_file='%s'", tempDict)) + tk.MustExec(fmt.Sprintf("SET GLOBAL validate_password.dictionary='%s'", "password")) tk.MustExec("SET GLOBAL validate_password.enable = 1") tk.MustQuery("SELECT validate_password_strength('root')").Check(testkit.Rows("0")) tk.MustQuery("SELECT validate_password_strength('toor')").Check(testkit.Rows("0")) diff --git a/sessionctx/variable/BUILD.bazel b/sessionctx/variable/BUILD.bazel index 28a4e0bbd7eba..7c6bcd5330e25 100644 --- a/sessionctx/variable/BUILD.bazel +++ b/sessionctx/variable/BUILD.bazel @@ -43,7 +43,6 @@ go_library( "//util/disk", "//util/execdetails", "//util/gctuner", - "//util/hack", "//util/kvcache", "//util/logutil", "//util/mathutil", @@ -99,12 +98,10 @@ go_test( "//testkit", "//testkit/testsetup", "//types", - "//util", "//util/chunk", "//util/execdetails", "//util/memory", "//util/mock", - "@com_github_pingcap_errors//:errors", "@com_github_pingcap_failpoint//:failpoint", "@com_github_stretchr_testify//require", "@com_github_tikv_client_go_v2//util", diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 58d6c67b04d99..6cbe4b6a14f06 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -518,14 +518,7 @@ var defaultSysVars = []*SysVar{ {Scope: ScopeGlobal, Name: ValidatePasswordMixedCaseCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt64}, {Scope: ScopeGlobal, Name: ValidatePasswordNumberCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt64}, {Scope: ScopeGlobal, Name: ValidatePasswordSpecialCharCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt64}, - {Scope: ScopeGlobal, Name: ValidatePasswordDictionaryFile, Value: "", Type: TypeStr, SetGlobal: func(_ context.Context, s *SessionVars, val string) error { - // Use 'SET @@global.validate_password.dictionary_file = ""' to clean the dictionary. - if len(val) == 0 { - CleanPasswordDictionary() - return nil - } - return UpdatePasswordDictionary(val) - }}, + {Scope: ScopeGlobal, Name: ValidatePasswordDictionary, Value: "", Type: TypeStr}, /* TiDB specific variables */ {Scope: ScopeGlobal, Name: TiDBTSOClientBatchMaxWaitTime, Value: strconv.FormatFloat(DefTiDBTSOClientBatchMaxWaitTime, 'f', -1, 64), Type: TypeFloat, MinValue: 0, MaxValue: 10, @@ -2356,6 +2349,6 @@ const ( ValidatePasswordNumberCount = "validate_password.number_count" // ValidatePasswordSpecialCharCount specified the minimum number of nonalphanumeric characters that validate_password requires ValidatePasswordSpecialCharCount = "validate_password.special_char_count" - // ValidatePasswordDictionaryFile specified the path name of the dictionary file that validate_password uses for checking passwords - ValidatePasswordDictionaryFile = "validate_password.dictionary_file" + // ValidatePasswordDictionary specified the dictionary that validate_password uses for checking passwords. Each word is seperated by semicolon (;). + ValidatePasswordDictionary = "validate_password.dictionary" ) diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index d58368f1912fc..b64500b91d208 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -15,14 +15,10 @@ package variable import ( - "bufio" "fmt" "io" - "os" - "path/filepath" "strconv" "strings" - "sync" "sync/atomic" "time" @@ -32,8 +28,6 @@ import ( "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/collate" - "github.com/pingcap/tidb/util/hack" - "github.com/pingcap/tidb/util/mathutil" "github.com/pingcap/tidb/util/memory" "github.com/pingcap/tidb/util/timeutil" "github.com/tikv/client-go/v2/oracle" @@ -537,71 +531,6 @@ func collectAllowFuncName4ExpressionIndex() string { return strings.Join(str, ", ") } -type passwordDictionaryImpl struct { - Cache map[string]struct{} - m sync.RWMutex -} - -const maxPwdValidationLength int = 100 - -const minPwdValidationLength int = 4 - -var passwordDictionary = passwordDictionaryImpl{Cache: make(map[string]struct{})} - -// CleanPasswordDictionary removes all the words in the dictionary. -func CleanPasswordDictionary() { - passwordDictionary.m.Lock() - defer passwordDictionary.m.Unlock() - passwordDictionary.Cache = make(map[string]struct{}) -} - -// UpdatePasswordDictionary update the dictionary for validating password. -func UpdatePasswordDictionary(filePath string) error { - passwordDictionary.m.Lock() - defer passwordDictionary.m.Unlock() - newDictionary := make(map[string]struct{}) - file, err := os.Open(filepath.Clean(filePath)) - if err != nil { - return err - } - if fileInfo, err := file.Stat(); err != nil { - return err - } else if fileInfo.Size() > 1*1024*1024 { - return errors.New("Too Large Dictionary. The maximum permitted file size is 1MB") - } - s := bufio.NewScanner(file) - for s.Scan() { - line := strings.ToLower(string(hack.String(s.Bytes()))) - if len(line) >= minPwdValidationLength && len(line) <= maxPwdValidationLength { - newDictionary[line] = struct{}{} - } - } - if err := s.Err(); err != nil { - return err - } - passwordDictionary.Cache = newDictionary - return file.Close() -} - -// ValidateDictionaryPassword checks if the password contains words in the dictionary. -func ValidateDictionaryPassword(pwd string) bool { - passwordDictionary.m.RLock() - defer passwordDictionary.m.RUnlock() - if len(passwordDictionary.Cache) == 0 { - return true - } - pwdLength := len(pwd) - for subStrLen := mathutil.Min(maxPwdValidationLength, pwdLength); subStrLen >= minPwdValidationLength; subStrLen-- { - for subStrPos := 0; subStrPos+subStrLen <= pwdLength; subStrPos++ { - subStr := pwd[subStrPos : subStrPos+subStrLen] - if _, ok := passwordDictionary.Cache[subStr]; ok { - return false - } - } - } - return true -} - // GAFunction4ExpressionIndex stores functions GA for expression index. var GAFunction4ExpressionIndex = map[string]struct{}{ ast.Lower: {}, diff --git a/sessionctx/variable/varsutil_test.go b/sessionctx/variable/varsutil_test.go index 3c15be9a6b214..69c9caf294e5e 100644 --- a/sessionctx/variable/varsutil_test.go +++ b/sessionctx/variable/varsutil_test.go @@ -16,19 +16,15 @@ package variable import ( "context" - "os" - "path/filepath" "reflect" "strconv" "testing" "time" - "github.com/pingcap/errors" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/terror" - "github.com/pingcap/tidb/util" "github.com/stretchr/testify/require" ) @@ -741,45 +737,3 @@ func TestAssertionLevel(t *testing.T) { require.Equal(t, AssertionLevelFast, tidbOptAssertionLevel(AssertionFastStr)) require.Equal(t, AssertionLevelOff, tidbOptAssertionLevel("bogus")) } - -func TestUpdateDictionaryFile(t *testing.T) { - tooLargeDict, err := func(filename string, size int) (string, error) { - filename = filepath.Join(os.TempDir(), filename) - file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY, os.ModePerm) - if err != nil { - return "", err - } - if size > 0 { - n, err := file.Write(make([]byte, size)) - if err != nil { - return "", err - } else if n != size { - return "", errors.New("") - } - } - return filename, file.Close() - }("1.dict", 2*1024*1024) - require.NoError(t, err) - err = UpdatePasswordDictionary(tooLargeDict) - require.ErrorContains(t, err, "Too Large Dictionary. The maximum permitted file size is 1MB") - - dict, err := util.CreateTmpDictWithContent("2.dict", []byte("abc\n1234\n5678")) - require.NoError(t, err) - require.NoError(t, UpdatePasswordDictionary(dict)) - _, ok := passwordDictionary.Cache["1234"] - require.True(t, ok) - _, ok = passwordDictionary.Cache["5678"] - require.True(t, ok) - _, ok = passwordDictionary.Cache["abc"] - require.False(t, ok) -} - -func TestValidateDictionaryPassword(t *testing.T) { - dict, err := util.CreateTmpDictWithContent("3.dict", []byte("1234\n5678")) - require.NoError(t, err) - require.NoError(t, UpdatePasswordDictionary(dict)) - require.True(t, ValidateDictionaryPassword("abcdefg")) - require.True(t, ValidateDictionaryPassword("abcd123efg")) - require.False(t, ValidateDictionaryPassword("abcd1234efg")) - require.False(t, ValidateDictionaryPassword("abcd12345efg")) -} diff --git a/util/password-validation/BUILD.bazel b/util/password-validation/BUILD.bazel index c3649a3a15383..4a517c79c7e30 100644 --- a/util/password-validation/BUILD.bazel +++ b/util/password-validation/BUILD.bazel @@ -8,6 +8,7 @@ go_library( deps = [ "//sessionctx/variable", "//util/hack", + "//util/mathutil", ], ) diff --git a/util/password-validation/password_validation.go b/util/password-validation/password_validation.go index 44fa1a23a2a15..0b3954ed659eb 100644 --- a/util/password-validation/password_validation.go +++ b/util/password-validation/password_validation.go @@ -18,12 +18,47 @@ import ( "bytes" "fmt" "strconv" + "strings" "unicode" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/hack" + "github.com/pingcap/tidb/util/mathutil" ) +const maxPwdValidationLength int = 100 + +const minPwdValidationLength int = 4 + +// ValidateDictionaryPassword checks if the password contains words in the dictionary. +func ValidateDictionaryPassword(pwd string, globalVars *variable.GlobalVarAccessor) (bool, error) { + dictionary, err := (*globalVars).GetGlobalSysVar(variable.ValidatePasswordDictionary) + pwdLength := len(pwd) + if err != nil { + return false, err + } + words := strings.Split(dictionary, ";") + if len(words) == 0 { + return true, nil + } + cache := make(map[string]interface{}, len(words)) + for _, word := range words { + word = strings.ToLower(word) + if len(word) >= minPwdValidationLength && len(word) <= maxPwdValidationLength { + cache[word] = nil + } + } + for subStrLen := mathutil.Min(maxPwdValidationLength, pwdLength); subStrLen >= minPwdValidationLength; subStrLen-- { + for subStrPos := 0; subStrPos+subStrLen <= pwdLength; subStrPos++ { + subStr := strings.ToLower(pwd[subStrPos : subStrPos+subStrLen]) + if _, ok := cache[subStr]; ok { + return false, nil + } + } + } + return true, nil +} + // ValidateUserNameInPassword checks whether pwd exists in the dictionary. func ValidateUserNameInPassword(pwd string, sessionVars *variable.SessionVars) (string, error) { currentUser := sessionVars.User @@ -140,7 +175,9 @@ func ValidatePassword(sessionVars *variable.SessionVars, pwd string) error { } // STRONG - if !variable.ValidateDictionaryPassword(pwd) { + if ok, err := ValidateDictionaryPassword(pwd, &globalVars); err != nil { + return err + } else if !ok { return variable.ErrNotValidPassword.GenWithStack("Password contains word in the dictionary") } return nil diff --git a/util/password-validation/password_validation_test.go b/util/password-validation/password_validation_test.go index b143fe159a520..323cba33ba409 100644 --- a/util/password-validation/password_validation_test.go +++ b/util/password-validation/password_validation_test.go @@ -23,6 +23,32 @@ import ( "github.com/stretchr/testify/require" ) +func TestValidateDictionaryPassword(t *testing.T) { + vars := variable.NewSessionVars(nil) + mock := variable.NewMockGlobalAccessor4Tests() + mock.SessionVars = vars + vars.GlobalVarsAccessor = mock + + err := mock.SetGlobalSysVar(context.Background(), variable.ValidatePasswordDictionary, "1234;5678;HIJK") + require.NoError(t, err) + testcases := []struct { + pwd string + result bool + }{ + {"abcdefg", true}, + {"abcd123efg", true}, + {"abcd1234efg", false}, + {"abcd12345efg", false}, + {"abcd123efghij", true}, + {"abcd123efghijk", false}, + } + for _, testcase := range testcases { + ok, err := ValidateDictionaryPassword(testcase.pwd, &vars.GlobalVarsAccessor) + require.NoError(t, err) + require.Equal(t, testcase.result, ok, testcase.pwd) + } +} + func TestValidateUserNameInPassword(t *testing.T) { sessionVars := variable.NewSessionVars(nil) sessionVars.User = &auth.UserIdentity{Username: "user", AuthUsername: "authuser"} diff --git a/util/util.go b/util/util.go index 06b5c6f552ac9..8af2876240486 100644 --- a/util/util.go +++ b/util/util.go @@ -19,8 +19,6 @@ import ( "fmt" "io/ioutil" "net/http" - "os" - "path/filepath" "strconv" "strings" "time" @@ -30,24 +28,6 @@ import ( "go.uber.org/zap" ) -// CreateTmpDictWithContent is only used for test. -func CreateTmpDictWithContent(filename string, content []byte) (string, error) { - filename = filepath.Join(os.TempDir(), filename) - file, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY, os.ModePerm) - if err != nil { - return "", err - } - if len(content) > 0 { - n, err := file.Write(content) - if err != nil { - return "", err - } else if n != len(content) { - return "", errors.New("") - } - } - return filename, file.Close() -} - // SliceToMap converts slice to map // nolint:unused func SliceToMap(slice []string) map[string]interface{} { From 4321cbc534cae06b9812d361a5a4a205a67f04d6 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Thu, 17 Nov 2022 17:45:50 +0800 Subject: [PATCH 19/26] Fix --- sessionctx/variable/sysvar.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 6cbe4b6a14f06..9cdef86f68202 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -2349,6 +2349,6 @@ const ( ValidatePasswordNumberCount = "validate_password.number_count" // ValidatePasswordSpecialCharCount specified the minimum number of nonalphanumeric characters that validate_password requires ValidatePasswordSpecialCharCount = "validate_password.special_char_count" - // ValidatePasswordDictionary specified the dictionary that validate_password uses for checking passwords. Each word is seperated by semicolon (;). + // ValidatePasswordDictionary specified the dictionary that validate_password uses for checking passwords. Each word is separated by semicolon (;). ValidatePasswordDictionary = "validate_password.dictionary" ) From ca8fc6ef62974d27487efbe465acb51deecc253f Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Wed, 23 Nov 2022 14:48:12 +0800 Subject: [PATCH 20/26] Update --- executor/set_test.go | 8 ++++ executor/simple_test.go | 4 +- sessionctx/variable/sysvar.go | 68 +++++++++++++++++++++++++-------- sessionctx/variable/varsutil.go | 25 ++++++++++++ 4 files changed, 88 insertions(+), 17 deletions(-) diff --git a/executor/set_test.go b/executor/set_test.go index 476fc3a4aebb6..a4a54a37a3595 100644 --- a/executor/set_test.go +++ b/executor/set_test.go @@ -853,6 +853,14 @@ func TestSetVar(t *testing.T) { tk.MustQuery("select @@global.tidb_opt_range_max_size").Check(testkit.Rows("1048576")) tk.MustExec("set session tidb_opt_range_max_size = 2097152") tk.MustQuery("select @@session.tidb_opt_range_max_size").Check(testkit.Rows("2097152")) + + // test for password validation + tk.MustQuery("SELECT @@GLOBAL.validate_password.enable").Check(testkit.Rows("0")) + tk.MustQuery("SELECT @@GLOBAL.validate_password.length").Check(testkit.Rows("8")) + tk.MustExec("SET GLOBAL validate_password.length = 3") + tk.MustQuery("SELECT @@GLOBAL.validate_password.length").Check(testkit.Rows("4")) + tk.MustExec("SET GLOBAL validate_password.mixed_case_count = 2") + tk.MustQuery("SELECT @@GLOBAL.validate_password.length").Check(testkit.Rows("6")) } func TestGetSetNoopVars(t *testing.T) { diff --git a/executor/simple_test.go b/executor/simple_test.go index 4e9477e92eb9d..c2c0506e2ff15 100644 --- a/executor/simple_test.go +++ b/executor/simple_test.go @@ -163,6 +163,7 @@ func TestValidatePassword(t *testing.T) { tk.MustExec("SET GLOBAL validate_password.check_user_name = 1") // LOW: Length + tk.MustExec("SET GLOBAL validate_password.length = 8") tk.MustQuery("SELECT @@global.validate_password.length").Check(testkit.Rows("8")) tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '1234567'", "Require Password Length: 8") tk.MustExec("SET GLOBAL validate_password.length = 12") @@ -180,7 +181,8 @@ func TestValidatePassword(t *testing.T) { tk.MustExec("SET GLOBAL validate_password.special_char_count = 0") tk.MustExec("ALTER USER testuser IDENTIFIED BY 'Abc1234567'") tk.MustExec("SET GLOBAL validate_password.special_char_count = 1") - tk.MustContainErrMsg("SET GLOBAL validate_password.length = 3", "Variable 'validate_password.length' can't be set to the value of '3'") + tk.MustExec("SET GLOBAL validate_password.length = 3") + tk.MustQuery("SELECT @@GLOBAL.validate_password.length").Check(testkit.Rows("4")) // STRONG: Length; numeric, lowercase/uppercase, and special characters; dictionary file tk.MustExec("SET GLOBAL validate_password.policy = 'STRONG'") diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 250113dffd080..110c389b339a9 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -489,35 +489,71 @@ var defaultSysVars = []*SysVar{ {Scope: ScopeGlobal, Name: ValidatePasswordEnable, Value: Off, Type: TypeBool}, {Scope: ScopeGlobal, Name: ValidatePasswordPolicy, Value: "MEDIUM", Type: TypeEnum, PossibleValues: []string{"LOW", "MEDIUM", "STRONG"}}, {Scope: ScopeGlobal, Name: ValidatePasswordCheckUserName, Value: On, Type: TypeBool}, - {Scope: ScopeGlobal, Name: ValidatePasswordLength, Value: "8", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt64, + {Scope: ScopeGlobal, Name: ValidatePasswordLength, Value: "8", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt32, Validation: func(vars *SessionVars, normalizedValue string, originalValue string, scope ScopeFlag) (string, error) { - var numberCount, specialCharCount, mixedCaseCount int64 - if numberCountStr, err := vars.GlobalVarsAccessor.GetGlobalSysVar(ValidatePasswordNumberCount); err != nil { + _, numberCount, specialCharCount, mixedCaseCount, err := getPasswordValidationLength(vars) + if err != nil { return "", err - } else if numberCount, err = strconv.ParseInt(numberCountStr, 10, 64); err != nil { + } + if length, err := strconv.ParseInt(normalizedValue, 10, 64); err != nil { + return "", err + } else if minLength := numberCount + specialCharCount + 2*mixedCaseCount; length < minLength { + return strconv.FormatInt(minLength, 10), nil + } + return normalizedValue, nil + }, + }, + {Scope: ScopeGlobal, Name: ValidatePasswordMixedCaseCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt32, + Validation: func(vars *SessionVars, normalizedValue string, originalValue string, scope ScopeFlag) (string, error) { + length, numberCount, specialCharCount, mixedCaseCount, err := getPasswordValidationLength(vars) + if err != nil { return "", err } - if specialCharCountStr, err := vars.GlobalVarsAccessor.GetGlobalSysVar(ValidatePasswordNumberCount); err != nil { + if mixedCaseCount, err = strconv.ParseInt(normalizedValue, 10, 64); err != nil { return "", err - } else if specialCharCount, err = strconv.ParseInt(specialCharCountStr, 10, 64); err != nil { + } + if minLength := numberCount + specialCharCount + 2*mixedCaseCount; length < minLength { + err = vars.GlobalVarsAccessor.SetGlobalSysVar(context.Background(), ValidatePasswordLength, strconv.FormatInt(minLength, 10)) + if err != nil { + return "", err + } + } + return normalizedValue, nil + }}, + {Scope: ScopeGlobal, Name: ValidatePasswordNumberCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt32, + Validation: func(vars *SessionVars, normalizedValue string, originalValue string, scope ScopeFlag) (string, error) { + length, numberCount, specialCharCount, mixedCaseCount, err := getPasswordValidationLength(vars) + if err != nil { return "", err } - if mixedCaseCountStr, err := vars.GlobalVarsAccessor.GetGlobalSysVar(ValidatePasswordNumberCount); err != nil { + if numberCount, err = strconv.ParseInt(normalizedValue, 10, 64); err != nil { return "", err - } else if mixedCaseCount, err = strconv.ParseInt(mixedCaseCountStr, 10, 64); err != nil { + } + if minLength := numberCount + specialCharCount + 2*mixedCaseCount; length < minLength { + err = vars.GlobalVarsAccessor.SetGlobalSysVar(context.Background(), ValidatePasswordLength, strconv.FormatInt(minLength, 10)) + if err != nil { + return "", err + } + } + return normalizedValue, nil + }}, + {Scope: ScopeGlobal, Name: ValidatePasswordSpecialCharCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt32, + Validation: func(vars *SessionVars, normalizedValue string, originalValue string, scope ScopeFlag) (string, error) { + length, numberCount, specialCharCount, mixedCaseCount, err := getPasswordValidationLength(vars) + if err != nil { return "", err } - if length, err := strconv.ParseInt(normalizedValue, 10, 64); err != nil { + if specialCharCount, err = strconv.ParseInt(normalizedValue, 10, 64); err != nil { return "", err - } else if length < numberCount+specialCharCount+2*mixedCaseCount { - return "", ErrWrongValueForVar.GenWithStackByArgs(ValidatePasswordLength, normalizedValue) + } + if minLength := numberCount + specialCharCount + 2*mixedCaseCount; length < minLength { + err = vars.GlobalVarsAccessor.SetGlobalSysVar(context.Background(), ValidatePasswordLength, strconv.FormatInt(minLength, 10)) + if err != nil { + return "", err + } } return normalizedValue, nil - }, - }, - {Scope: ScopeGlobal, Name: ValidatePasswordMixedCaseCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt64}, - {Scope: ScopeGlobal, Name: ValidatePasswordNumberCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt64}, - {Scope: ScopeGlobal, Name: ValidatePasswordSpecialCharCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt64}, + }}, {Scope: ScopeGlobal, Name: ValidatePasswordDictionary, Value: "", Type: TypeStr}, /* TiDB specific variables */ diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index b64500b91d208..2527e0c539b58 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -531,6 +531,31 @@ func collectAllowFuncName4ExpressionIndex() string { return strings.Join(str, ", ") } +func getPasswordValidationLength(vars *SessionVars) (length, numberCount, specialCharCount, mixedCaseCount int64, err error) { + var lengthStr, numberCountStr, specialCharCountStr, mixedCaseCountStr string + if lengthStr, err = vars.GlobalVarsAccessor.GetGlobalSysVar(ValidatePasswordLength); err != nil { + return + } else if length, err = strconv.ParseInt(lengthStr, 10, 64); err != nil { + return + } + if numberCountStr, err = vars.GlobalVarsAccessor.GetGlobalSysVar(ValidatePasswordNumberCount); err != nil { + return + } else if numberCount, err = strconv.ParseInt(numberCountStr, 10, 64); err != nil { + return + } + if specialCharCountStr, err = vars.GlobalVarsAccessor.GetGlobalSysVar(ValidatePasswordSpecialCharCount); err != nil { + return + } else if specialCharCount, err = strconv.ParseInt(specialCharCountStr, 10, 64); err != nil { + return + } + if mixedCaseCountStr, err = vars.GlobalVarsAccessor.GetGlobalSysVar(ValidatePasswordMixedCaseCount); err != nil { + return + } else if mixedCaseCount, err = strconv.ParseInt(mixedCaseCountStr, 10, 64); err != nil { + return + } + return +} + // GAFunction4ExpressionIndex stores functions GA for expression index. var GAFunction4ExpressionIndex = map[string]struct{}{ ast.Lower: {}, From e1024de5d5e0c94dfa4966a99bb5f19c3f620992 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Wed, 23 Nov 2022 16:01:00 +0800 Subject: [PATCH 21/26] Fix --- sessionctx/variable/sysvar.go | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 110c389b339a9..0ca8bdea9f803 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -505,11 +505,12 @@ var defaultSysVars = []*SysVar{ }, {Scope: ScopeGlobal, Name: ValidatePasswordMixedCaseCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt32, Validation: func(vars *SessionVars, normalizedValue string, originalValue string, scope ScopeFlag) (string, error) { - length, numberCount, specialCharCount, mixedCaseCount, err := getPasswordValidationLength(vars) + length, numberCount, specialCharCount, _, err := getPasswordValidationLength(vars) if err != nil { return "", err } - if mixedCaseCount, err = strconv.ParseInt(normalizedValue, 10, 64); err != nil { + mixedCaseCount, err := strconv.ParseInt(normalizedValue, 10, 64) + if err != nil { return "", err } if minLength := numberCount + specialCharCount + 2*mixedCaseCount; length < minLength { @@ -522,11 +523,12 @@ var defaultSysVars = []*SysVar{ }}, {Scope: ScopeGlobal, Name: ValidatePasswordNumberCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt32, Validation: func(vars *SessionVars, normalizedValue string, originalValue string, scope ScopeFlag) (string, error) { - length, numberCount, specialCharCount, mixedCaseCount, err := getPasswordValidationLength(vars) + length, _, specialCharCount, mixedCaseCount, err := getPasswordValidationLength(vars) if err != nil { return "", err } - if numberCount, err = strconv.ParseInt(normalizedValue, 10, 64); err != nil { + numberCount, err := strconv.ParseInt(normalizedValue, 10, 64) + if err != nil { return "", err } if minLength := numberCount + specialCharCount + 2*mixedCaseCount; length < minLength { @@ -539,11 +541,12 @@ var defaultSysVars = []*SysVar{ }}, {Scope: ScopeGlobal, Name: ValidatePasswordSpecialCharCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt32, Validation: func(vars *SessionVars, normalizedValue string, originalValue string, scope ScopeFlag) (string, error) { - length, numberCount, specialCharCount, mixedCaseCount, err := getPasswordValidationLength(vars) + length, numberCount, _, mixedCaseCount, err := getPasswordValidationLength(vars) if err != nil { return "", err } - if specialCharCount, err = strconv.ParseInt(normalizedValue, 10, 64); err != nil { + specialCharCount, err := strconv.ParseInt(normalizedValue, 10, 64) + if err != nil { return "", err } if minLength := numberCount + specialCharCount + 2*mixedCaseCount; length < minLength { From 4667f3a04f0aa79307d73ad2a2b5819e6d3b3dbd Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Wed, 23 Nov 2022 16:19:38 +0800 Subject: [PATCH 22/26] Fix --- executor/simple.go | 2 +- executor/simple_test.go | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/executor/simple.go b/executor/simple.go index a3b2194a24c99..1ae9df724e3b5 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -899,7 +899,7 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm if spec.AuthOpt != nil && spec.AuthOpt.AuthPlugin != "" { authPlugin = spec.AuthOpt.AuthPlugin } - if e.enableValidatePassword() { + if e.enableValidatePassword() && !s.IsCreateRole { if spec.AuthOpt == nil || !spec.AuthOpt.ByAuthString && spec.AuthOpt.HashString == "" { return variable.ErrNotValidPassword.GenWithStackByArgs() } diff --git a/executor/simple_test.go b/executor/simple_test.go index c2c0506e2ff15..44c7691d805f8 100644 --- a/executor/simple_test.go +++ b/executor/simple_test.go @@ -208,5 +208,7 @@ func TestValidatePassword(t *testing.T) { subtk.MustExec("ALTER USER ''@'localhost' IDENTIFIED BY ''") subtk.MustExec("ALTER USER ''@'localhost' IDENTIFIED BY 'abcd'") + // CREATE ROLE is not affected by password validation tk.MustExec("SET GLOBAL validate_password.enable = 1") + tk.MustExec("CREATE ROLE role1") } From df65a25f08fbeeeb39fdc3651842aa43e6e663ce Mon Sep 17 00:00:00 2001 From: CbcWestwolf <1004626265@qq.com> Date: Thu, 24 Nov 2022 10:44:46 +0800 Subject: [PATCH 23/26] Update executor/simple.go Co-authored-by: xiongjiwei --- executor/simple.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/executor/simple.go b/executor/simple.go index 1ae9df724e3b5..8e05a43501e28 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -788,12 +788,9 @@ func (e *SimpleExec) authUsingCleartextPwd(authOpt *ast.AuthOption, authPlugin s if authOpt == nil || !authOpt.ByAuthString { return false } - if authPlugin == mysql.AuthNativePassword || + return authPlugin == mysql.AuthNativePassword || authPlugin == mysql.AuthTiDBSM3Password || - authPlugin == mysql.AuthCachingSha2Password { - return true - } - return false + authPlugin == mysql.AuthCachingSha2Password } func (e *SimpleExec) enableValidatePassword() bool { From 106f7c2288fb74bea329f8773a1d4c03326cd226 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Thu, 24 Nov 2022 10:52:18 +0800 Subject: [PATCH 24/26] Update --- executor/simple.go | 8 ++++---- util/password-validation/password_validation.go | 13 ++----------- 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/executor/simple.go b/executor/simple.go index 1ae9df724e3b5..b27d02cc164c0 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -796,7 +796,7 @@ func (e *SimpleExec) authUsingCleartextPwd(authOpt *ast.AuthOption, authPlugin s return false } -func (e *SimpleExec) enableValidatePassword() bool { +func (e *SimpleExec) isValidatePasswordEnabled() bool { validatePwdEnable, err := e.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.ValidatePasswordEnable) if err != nil { return false @@ -899,7 +899,7 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm if spec.AuthOpt != nil && spec.AuthOpt.AuthPlugin != "" { authPlugin = spec.AuthOpt.AuthPlugin } - if e.enableValidatePassword() && !s.IsCreateRole { + if e.isValidatePasswordEnabled() && !s.IsCreateRole { if spec.AuthOpt == nil || !spec.AuthOpt.ByAuthString && spec.AuthOpt.HashString == "" { return variable.ErrNotValidPassword.GenWithStackByArgs() } @@ -1118,7 +1118,7 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt) default: return ErrPluginIsNotLoaded.GenWithStackByArgs(spec.AuthOpt.AuthPlugin) } - if e.enableValidatePassword() && e.authUsingCleartextPwd(spec.AuthOpt, spec.AuthOpt.AuthPlugin) { + if e.isValidatePasswordEnabled() && e.authUsingCleartextPwd(spec.AuthOpt, spec.AuthOpt.AuthPlugin) { if err := pwdValidator.ValidatePassword(e.ctx.GetSessionVars(), spec.AuthOpt.AuthString); err != nil { return err } @@ -1639,7 +1639,7 @@ func (e *SimpleExec) executeSetPwd(ctx context.Context, s *ast.SetPwdStmt) error if err != nil { return err } - if e.enableValidatePassword() { + if e.isValidatePasswordEnabled() { if err := pwdValidator.ValidatePassword(e.ctx.GetSessionVars(), s.Password); err != nil { return err } diff --git a/util/password-validation/password_validation.go b/util/password-validation/password_validation.go index 0b3954ed659eb..edd0bd39ec38a 100644 --- a/util/password-validation/password_validation.go +++ b/util/password-validation/password_validation.go @@ -23,7 +23,6 @@ import ( "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/hack" - "github.com/pingcap/tidb/util/mathutil" ) const maxPwdValidationLength int = 100 @@ -33,7 +32,6 @@ const minPwdValidationLength int = 4 // ValidateDictionaryPassword checks if the password contains words in the dictionary. func ValidateDictionaryPassword(pwd string, globalVars *variable.GlobalVarAccessor) (bool, error) { dictionary, err := (*globalVars).GetGlobalSysVar(variable.ValidatePasswordDictionary) - pwdLength := len(pwd) if err != nil { return false, err } @@ -41,17 +39,10 @@ func ValidateDictionaryPassword(pwd string, globalVars *variable.GlobalVarAccess if len(words) == 0 { return true, nil } - cache := make(map[string]interface{}, len(words)) + pwd = strings.ToLower(pwd) for _, word := range words { - word = strings.ToLower(word) if len(word) >= minPwdValidationLength && len(word) <= maxPwdValidationLength { - cache[word] = nil - } - } - for subStrLen := mathutil.Min(maxPwdValidationLength, pwdLength); subStrLen >= minPwdValidationLength; subStrLen-- { - for subStrPos := 0; subStrPos+subStrLen <= pwdLength; subStrPos++ { - subStr := strings.ToLower(pwd[subStrPos : subStrPos+subStrLen]) - if _, ok := cache[subStr]; ok { + if strings.Contains(pwd, strings.ToLower(word)) { return false, nil } } From e97d7d3c0aded92ce2e5acabc9f4bf74900b43e3 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Thu, 24 Nov 2022 11:58:56 +0800 Subject: [PATCH 25/26] Update --- sessionctx/variable/sysvar.go | 84 +++++++++++++++------------- sessionctx/variable/tidb_vars.go | 5 ++ sessionctx/variable/varsutil.go | 25 --------- util/password-validation/BUILD.bazel | 1 - 4 files changed, 51 insertions(+), 64 deletions(-) diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 0ca8bdea9f803..6eff7b416d181 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -491,72 +491,80 @@ var defaultSysVars = []*SysVar{ {Scope: ScopeGlobal, Name: ValidatePasswordCheckUserName, Value: On, Type: TypeBool}, {Scope: ScopeGlobal, Name: ValidatePasswordLength, Value: "8", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt32, Validation: func(vars *SessionVars, normalizedValue string, originalValue string, scope ScopeFlag) (string, error) { - _, numberCount, specialCharCount, mixedCaseCount, err := getPasswordValidationLength(vars) + numberCount, specialCharCount, mixedCaseCount := PasswordValidtaionNumberCount.Load(), PasswordValidationSpecialCharCount.Load(), PasswordValidationMixedCaseCount.Load() + length, err := strconv.ParseInt(normalizedValue, 10, 32) if err != nil { return "", err } - if length, err := strconv.ParseInt(normalizedValue, 10, 64); err != nil { - return "", err - } else if minLength := numberCount + specialCharCount + 2*mixedCaseCount; length < minLength { - return strconv.FormatInt(minLength, 10), nil + if minLength := numberCount + specialCharCount + 2*mixedCaseCount; int32(length) < minLength { + return strconv.FormatInt(int64(minLength), 10), nil } return normalizedValue, nil }, + SetGlobal: func(_ context.Context, s *SessionVars, val string) error { + PasswordValidationLength.Store(int32(TidbOptInt64(val, 8))) + return nil + }, GetGlobal: func(_ context.Context, s *SessionVars) (string, error) { + return fmt.Sprintf("%d", PasswordValidationLength.Load()), nil + }, }, {Scope: ScopeGlobal, Name: ValidatePasswordMixedCaseCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt32, Validation: func(vars *SessionVars, normalizedValue string, originalValue string, scope ScopeFlag) (string, error) { - length, numberCount, specialCharCount, _, err := getPasswordValidationLength(vars) - if err != nil { - return "", err - } - mixedCaseCount, err := strconv.ParseInt(normalizedValue, 10, 64) + length, numberCount, specialCharCount := PasswordValidationLength.Load(), PasswordValidtaionNumberCount.Load(), PasswordValidationSpecialCharCount.Load() + mixedCaseCount, err := strconv.ParseInt(normalizedValue, 10, 32) if err != nil { return "", err } - if minLength := numberCount + specialCharCount + 2*mixedCaseCount; length < minLength { - err = vars.GlobalVarsAccessor.SetGlobalSysVar(context.Background(), ValidatePasswordLength, strconv.FormatInt(minLength, 10)) - if err != nil { - return "", err - } + if minLength := numberCount + specialCharCount + 2*int32(mixedCaseCount); length < minLength { + PasswordValidationLength.Store(minLength) } return normalizedValue, nil - }}, + }, + SetGlobal: func(_ context.Context, s *SessionVars, val string) error { + PasswordValidationMixedCaseCount.Store(int32(TidbOptInt64(val, 1))) + return nil + }, GetGlobal: func(_ context.Context, s *SessionVars) (string, error) { + return fmt.Sprintf("%d", PasswordValidationMixedCaseCount.Load()), nil + }, + }, {Scope: ScopeGlobal, Name: ValidatePasswordNumberCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt32, Validation: func(vars *SessionVars, normalizedValue string, originalValue string, scope ScopeFlag) (string, error) { - length, _, specialCharCount, mixedCaseCount, err := getPasswordValidationLength(vars) + length, specialCharCount, mixedCaseCount := PasswordValidationLength.Load(), PasswordValidationSpecialCharCount.Load(), PasswordValidationMixedCaseCount.Load() + numberCount, err := strconv.ParseInt(normalizedValue, 10, 32) if err != nil { return "", err } - numberCount, err := strconv.ParseInt(normalizedValue, 10, 64) - if err != nil { - return "", err - } - if minLength := numberCount + specialCharCount + 2*mixedCaseCount; length < minLength { - err = vars.GlobalVarsAccessor.SetGlobalSysVar(context.Background(), ValidatePasswordLength, strconv.FormatInt(minLength, 10)) - if err != nil { - return "", err - } + if minLength := int32(numberCount) + specialCharCount + 2*mixedCaseCount; length < minLength { + PasswordValidationLength.Store(minLength) } return normalizedValue, nil - }}, + }, + SetGlobal: func(_ context.Context, s *SessionVars, val string) error { + PasswordValidtaionNumberCount.Store(int32(TidbOptInt64(val, 1))) + return nil + }, GetGlobal: func(_ context.Context, s *SessionVars) (string, error) { + return fmt.Sprintf("%d", PasswordValidtaionNumberCount.Load()), nil + }, + }, {Scope: ScopeGlobal, Name: ValidatePasswordSpecialCharCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt32, Validation: func(vars *SessionVars, normalizedValue string, originalValue string, scope ScopeFlag) (string, error) { - length, numberCount, _, mixedCaseCount, err := getPasswordValidationLength(vars) - if err != nil { - return "", err - } - specialCharCount, err := strconv.ParseInt(normalizedValue, 10, 64) + length, numberCount, mixedCaseCount := PasswordValidationLength.Load(), PasswordValidtaionNumberCount.Load(), PasswordValidationMixedCaseCount.Load() + specialCharCount, err := strconv.ParseInt(normalizedValue, 10, 32) if err != nil { return "", err } - if minLength := numberCount + specialCharCount + 2*mixedCaseCount; length < minLength { - err = vars.GlobalVarsAccessor.SetGlobalSysVar(context.Background(), ValidatePasswordLength, strconv.FormatInt(minLength, 10)) - if err != nil { - return "", err - } + if minLength := numberCount + int32(specialCharCount) + 2*mixedCaseCount; length < minLength { + PasswordValidationLength.Store(minLength) } return normalizedValue, nil - }}, + }, + SetGlobal: func(_ context.Context, s *SessionVars, val string) error { + PasswordValidationSpecialCharCount.Store(int32(TidbOptInt64(val, 1))) + return nil + }, GetGlobal: func(_ context.Context, s *SessionVars) (string, error) { + return fmt.Sprintf("%d", PasswordValidationSpecialCharCount.Load()), nil + }, + }, {Scope: ScopeGlobal, Name: ValidatePasswordDictionary, Value: "", Type: TypeStr}, /* TiDB specific variables */ diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go index 3511775de08f1..decb791927d61 100644 --- a/sessionctx/variable/tidb_vars.go +++ b/sessionctx/variable/tidb_vars.go @@ -1150,6 +1150,11 @@ var ( // It should be a const and shouldn't be modified after tidb is started. DefTiDBServerMemoryLimit = serverMemoryLimitDefaultValue() GOGCTunerThreshold = atomic.NewFloat64(DefTiDBGOGCTunerThreshold) + + PasswordValidationLength = atomic.NewInt32(8) + PasswordValidationMixedCaseCount = atomic.NewInt32(1) + PasswordValidtaionNumberCount = atomic.NewInt32(1) + PasswordValidationSpecialCharCount = atomic.NewInt32(1) ) var ( diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index 2527e0c539b58..b64500b91d208 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -531,31 +531,6 @@ func collectAllowFuncName4ExpressionIndex() string { return strings.Join(str, ", ") } -func getPasswordValidationLength(vars *SessionVars) (length, numberCount, specialCharCount, mixedCaseCount int64, err error) { - var lengthStr, numberCountStr, specialCharCountStr, mixedCaseCountStr string - if lengthStr, err = vars.GlobalVarsAccessor.GetGlobalSysVar(ValidatePasswordLength); err != nil { - return - } else if length, err = strconv.ParseInt(lengthStr, 10, 64); err != nil { - return - } - if numberCountStr, err = vars.GlobalVarsAccessor.GetGlobalSysVar(ValidatePasswordNumberCount); err != nil { - return - } else if numberCount, err = strconv.ParseInt(numberCountStr, 10, 64); err != nil { - return - } - if specialCharCountStr, err = vars.GlobalVarsAccessor.GetGlobalSysVar(ValidatePasswordSpecialCharCount); err != nil { - return - } else if specialCharCount, err = strconv.ParseInt(specialCharCountStr, 10, 64); err != nil { - return - } - if mixedCaseCountStr, err = vars.GlobalVarsAccessor.GetGlobalSysVar(ValidatePasswordMixedCaseCount); err != nil { - return - } else if mixedCaseCount, err = strconv.ParseInt(mixedCaseCountStr, 10, 64); err != nil { - return - } - return -} - // GAFunction4ExpressionIndex stores functions GA for expression index. var GAFunction4ExpressionIndex = map[string]struct{}{ ast.Lower: {}, diff --git a/util/password-validation/BUILD.bazel b/util/password-validation/BUILD.bazel index 4a517c79c7e30..c3649a3a15383 100644 --- a/util/password-validation/BUILD.bazel +++ b/util/password-validation/BUILD.bazel @@ -8,7 +8,6 @@ go_library( deps = [ "//sessionctx/variable", "//util/hack", - "//util/mathutil", ], ) From 9b7e323305e576e456f9c307246ed8b82bac6431 Mon Sep 17 00:00:00 2001 From: cbcwestwolf <1004626265@qq.com> Date: Thu, 24 Nov 2022 17:43:51 +0800 Subject: [PATCH 26/26] Update --- sessionctx/variable/noop.go | 1 + sessionctx/variable/sysvar.go | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/sessionctx/variable/noop.go b/sessionctx/variable/noop.go index 5019ab90af115..5505fe65a3623 100644 --- a/sessionctx/variable/noop.go +++ b/sessionctx/variable/noop.go @@ -115,6 +115,7 @@ var noopSysVars = []*SysVar{ {Scope: ScopeNone, Name: "innodb_log_group_home_dir", Value: "./"}, {Scope: ScopeNone, Name: "performance_schema_events_statements_history_size", Value: "10"}, {Scope: ScopeGlobal, Name: GeneralLog, Value: Off, Type: TypeBool}, + {Scope: ScopeGlobal, Name: "validate_password_dictionary_file", Value: ""}, {Scope: ScopeGlobal, Name: BinlogOrderCommits, Value: On, Type: TypeBool}, {Scope: ScopeGlobal, Name: "key_cache_division_limit", Value: "100"}, {Scope: ScopeGlobal | ScopeSession, Name: "max_insert_delayed_threads", Value: "20"}, diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 6eff7b416d181..060c542bddd77 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -505,7 +505,7 @@ var defaultSysVars = []*SysVar{ PasswordValidationLength.Store(int32(TidbOptInt64(val, 8))) return nil }, GetGlobal: func(_ context.Context, s *SessionVars) (string, error) { - return fmt.Sprintf("%d", PasswordValidationLength.Load()), nil + return strconv.FormatInt(int64(PasswordValidationLength.Load()), 10), nil }, }, {Scope: ScopeGlobal, Name: ValidatePasswordMixedCaseCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt32, @@ -524,7 +524,7 @@ var defaultSysVars = []*SysVar{ PasswordValidationMixedCaseCount.Store(int32(TidbOptInt64(val, 1))) return nil }, GetGlobal: func(_ context.Context, s *SessionVars) (string, error) { - return fmt.Sprintf("%d", PasswordValidationMixedCaseCount.Load()), nil + return strconv.FormatInt(int64(PasswordValidationMixedCaseCount.Load()), 10), nil }, }, {Scope: ScopeGlobal, Name: ValidatePasswordNumberCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt32, @@ -543,7 +543,7 @@ var defaultSysVars = []*SysVar{ PasswordValidtaionNumberCount.Store(int32(TidbOptInt64(val, 1))) return nil }, GetGlobal: func(_ context.Context, s *SessionVars) (string, error) { - return fmt.Sprintf("%d", PasswordValidtaionNumberCount.Load()), nil + return strconv.FormatInt(int64(PasswordValidtaionNumberCount.Load()), 10), nil }, }, {Scope: ScopeGlobal, Name: ValidatePasswordSpecialCharCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt32, @@ -562,7 +562,7 @@ var defaultSysVars = []*SysVar{ PasswordValidationSpecialCharCount.Store(int32(TidbOptInt64(val, 1))) return nil }, GetGlobal: func(_ context.Context, s *SessionVars) (string, error) { - return fmt.Sprintf("%d", PasswordValidationSpecialCharCount.Load()), nil + return strconv.FormatInt(int64(PasswordValidationSpecialCharCount.Load()), 10), nil }, }, {Scope: ScopeGlobal, Name: ValidatePasswordDictionary, Value: "", Type: TypeStr}, @@ -1127,7 +1127,7 @@ var defaultSysVars = []*SysVar{ MemoryUsageAlarmKeepRecordNum.Store(TidbOptInt64(val, DefMemoryUsageAlarmKeepRecordNum)) return nil }, GetGlobal: func(_ context.Context, s *SessionVars) (string, error) { - return fmt.Sprintf("%d", MemoryUsageAlarmKeepRecordNum.Load()), nil + return strconv.FormatInt(MemoryUsageAlarmKeepRecordNum.Load(), 10), nil }}, /* The system variables below have GLOBAL and SESSION scope */