diff --git a/errors.toml b/errors.toml index 31a56ef6b1d17..a50e484a5a29b 100644 --- a/errors.toml +++ b/errors.toml @@ -1451,6 +1451,11 @@ error = ''' SET PASSWORD has no significance for user '%-.48s'@'%-.255s' as authentication plugin does not support it. ''' +["executor:1819"] +error = ''' +Your password does not satisfy the current policy requirements +''' + ["executor:1827"] error = ''' The password hash doesn't have the expected format. Check if the correct password algorithm is being used with the PASSWORD() function. diff --git a/executor/BUILD.bazel b/executor/BUILD.bazel index 6a300dbeaf654..cf91360b17a60 100644 --- a/executor/BUILD.bazel +++ b/executor/BUILD.bazel @@ -177,6 +177,7 @@ go_library( "//util/mathutil", "//util/memory", "//util/mvmap", + "//util/password-validation", "//util/pdapi", "//util/plancodec", "//util/printer", diff --git a/executor/set_test.go b/executor/set_test.go index 697209d64836a..a4a54a37a3595 100644 --- a/executor/set_test.go +++ b/executor/set_test.go @@ -853,6 +853,14 @@ func TestSetVar(t *testing.T) { tk.MustQuery("select @@global.tidb_opt_range_max_size").Check(testkit.Rows("1048576")) tk.MustExec("set session tidb_opt_range_max_size = 2097152") tk.MustQuery("select @@session.tidb_opt_range_max_size").Check(testkit.Rows("2097152")) + + // test for password validation + tk.MustQuery("SELECT @@GLOBAL.validate_password.enable").Check(testkit.Rows("0")) + tk.MustQuery("SELECT @@GLOBAL.validate_password.length").Check(testkit.Rows("8")) + tk.MustExec("SET GLOBAL validate_password.length = 3") + tk.MustQuery("SELECT @@GLOBAL.validate_password.length").Check(testkit.Rows("4")) + tk.MustExec("SET GLOBAL validate_password.mixed_case_count = 2") + tk.MustQuery("SELECT @@GLOBAL.validate_password.length").Check(testkit.Rows("6")) } func TestGetSetNoopVars(t *testing.T) { @@ -1407,14 +1415,11 @@ func TestValidateSetVar(t *testing.T) { tk.MustExec("set @@innodb_lock_wait_timeout = 1073741825") tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1292|Truncated incorrect innodb_lock_wait_timeout value: '1073741825'")) - tk.MustExec("set @@global.validate_password_number_count=-1") - tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1292|Truncated incorrect validate_password_number_count value: '-1'")) - - tk.MustExec("set @@global.validate_password_length=-1") - tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1292|Truncated incorrect validate_password_length value: '-1'")) + tk.MustExec("set @@global.validate_password.number_count=-1") + tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1292|Truncated incorrect validate_password.number_count value: '-1'")) - tk.MustExec("set @@global.validate_password_length=8") - tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustExec("set @@global.validate_password.length=-1") + tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1292|Truncated incorrect validate_password.length value: '-1'")) err = tk.ExecToErr("set @@tx_isolation=''") require.True(t, terror.ErrorEqual(err, variable.ErrWrongValueForVar), fmt.Sprintf("err %v", err)) diff --git a/executor/simple.go b/executor/simple.go index 3670670977a20..5953a3fa687f2 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -51,6 +51,7 @@ import ( "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/logutil" + pwdValidator "github.com/pingcap/tidb/util/password-validation" "github.com/pingcap/tidb/util/sem" "github.com/pingcap/tidb/util/sqlexec" "github.com/pingcap/tidb/util/timeutil" @@ -783,6 +784,23 @@ func (e *SimpleExec) executeRollback(s *ast.RollbackStmt) error { return nil } +func (e *SimpleExec) authUsingCleartextPwd(authOpt *ast.AuthOption, authPlugin string) bool { + if authOpt == nil || !authOpt.ByAuthString { + return false + } + return authPlugin == mysql.AuthNativePassword || + authPlugin == mysql.AuthTiDBSM3Password || + authPlugin == mysql.AuthCachingSha2Password +} + +func (e *SimpleExec) isValidatePasswordEnabled() bool { + validatePwdEnable, err := e.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.ValidatePasswordEnable) + if err != nil { + return false + } + return variable.TiDBOptOn(validatePwdEnable) +} + func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStmt) error { internalCtx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnPrivilege) // Check `CREATE USER` privilege. @@ -874,15 +892,25 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm e.ctx.GetSessionVars().StmtCtx.AppendNote(err) continue } + authPlugin := mysql.AuthNativePassword + if spec.AuthOpt != nil && spec.AuthOpt.AuthPlugin != "" { + authPlugin = spec.AuthOpt.AuthPlugin + } + if e.isValidatePasswordEnabled() && !s.IsCreateRole { + if spec.AuthOpt == nil || !spec.AuthOpt.ByAuthString && spec.AuthOpt.HashString == "" { + return variable.ErrNotValidPassword.GenWithStackByArgs() + } + if e.authUsingCleartextPwd(spec.AuthOpt, authPlugin) { + if err := pwdValidator.ValidatePassword(e.ctx.GetSessionVars(), spec.AuthOpt.AuthString); err != nil { + return err + } + } + } pwd, ok := spec.EncodedPassword() if !ok { return errors.Trace(ErrPasswordFormat) } - authPlugin := mysql.AuthNativePassword - if spec.AuthOpt != nil && spec.AuthOpt.AuthPlugin != "" { - authPlugin = spec.AuthOpt.AuthPlugin - } switch authPlugin { case mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password, mysql.AuthSocket, mysql.AuthTiDBAuthToken: @@ -1071,11 +1099,11 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt) var fields []alterField if spec.AuthOpt != nil { if spec.AuthOpt.AuthPlugin == "" { - authplugin, err := e.userAuthPlugin(spec.User.Username, spec.User.Hostname) + curAuthplugin, err := e.userAuthPlugin(spec.User.Username, spec.User.Hostname) if err != nil { return err } - spec.AuthOpt.AuthPlugin = authplugin + spec.AuthOpt.AuthPlugin = curAuthplugin } switch spec.AuthOpt.AuthPlugin { case mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password, mysql.AuthSocket, "": @@ -1087,6 +1115,11 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt) default: return ErrPluginIsNotLoaded.GenWithStackByArgs(spec.AuthOpt.AuthPlugin) } + if e.isValidatePasswordEnabled() && e.authUsingCleartextPwd(spec.AuthOpt, spec.AuthOpt.AuthPlugin) { + if err := pwdValidator.ValidatePassword(e.ctx.GetSessionVars(), spec.AuthOpt.AuthString); err != nil { + return err + } + } pwd, ok := spec.EncodedPassword() if !ok { return errors.Trace(ErrPasswordFormat) @@ -1603,6 +1636,11 @@ func (e *SimpleExec) executeSetPwd(ctx context.Context, s *ast.SetPwdStmt) error if err != nil { return err } + if e.isValidatePasswordEnabled() { + if err := pwdValidator.ValidatePassword(e.ctx.GetSessionVars(), s.Password); err != nil { + return err + } + } var pwd string switch authplugin { case mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password: diff --git a/executor/simple_test.go b/executor/simple_test.go index 13a439ad64d46..44c7691d805f8 100644 --- a/executor/simple_test.go +++ b/executor/simple_test.go @@ -129,3 +129,86 @@ func TestUserAttributes(t *testing.T) { rootTK.MustExec("alter user usr1 comment 'comment1'") rootTK.MustQuery("select user_attributes from mysql.user where user = 'usr1'").Check(testkit.Rows(`{"metadata": {"comment": "comment1"}}`)) } + +func TestValidatePassword(t *testing.T) { + store, _ := testkit.CreateMockStoreAndDomain(t) + tk := testkit.NewTestKit(t, store) + subtk := testkit.NewTestKit(t, store) + err := tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil) + require.NoError(t, err) + tk.MustExec("CREATE USER ''@'localhost'") + tk.MustExec("GRANT ALL PRIVILEGES ON mysql.* TO ''@'localhost';") + err = subtk.Session().Auth(&auth.UserIdentity{Hostname: "localhost"}, nil, nil) + require.NoError(t, err) + + authPlugins := []string{mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password} + tk.MustQuery("SELECT @@global.validate_password.enable").Check(testkit.Rows("0")) + tk.MustExec("SET GLOBAL validate_password.enable = 1") + tk.MustQuery("SELECT @@global.validate_password.enable").Check(testkit.Rows("1")) + + for _, authPlugin := range authPlugins { + tk.MustExec("DROP USER IF EXISTS testuser") + tk.MustExec(fmt.Sprintf("CREATE USER testuser IDENTIFIED WITH %s BY '!Abc12345678'", authPlugin)) + + tk.MustExec("SET GLOBAL validate_password.policy = 'LOW'") + // check user name + tk.MustQuery("SELECT @@global.validate_password.check_user_name").Check(testkit.Rows("1")) + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!Abcdroot1234'", "Password Contains User Name") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!Abcdtoor1234'", "Password Contains Reversed User Name") + tk.MustExec("SET PASSWORD FOR 'testuser' = 'testuser'") // password the same as the user name, but run by root + tk.MustExec("ALTER USER testuser IDENTIFIED BY 'testuser'") + tk.MustExec("SET GLOBAL validate_password.check_user_name = 0") + tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abcdroot1234'") + tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abcdtoor1234'") + tk.MustExec("SET GLOBAL validate_password.check_user_name = 1") + + // LOW: Length + tk.MustExec("SET GLOBAL validate_password.length = 8") + tk.MustQuery("SELECT @@global.validate_password.length").Check(testkit.Rows("8")) + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '1234567'", "Require Password Length: 8") + tk.MustExec("SET GLOBAL validate_password.length = 12") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!Abcdefg123'", "Require Password Length: 12") + tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abcdefg1234'") + tk.MustExec("SET GLOBAL validate_password.length = 8") + + // MEDIUM: Length; numeric, lowercase/uppercase, and special characters + tk.MustExec("SET GLOBAL validate_password.policy = 'MEDIUM'") + tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abc1234567'") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!ABC1234567'", "Require Password Lowercase Count: 1") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!abc1234567'", "Require Password Uppercase Count: 1") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!ABCDabcd'", "Require Password Digit Count: 1") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY 'Abc1234567'", "Require Password Non-alphanumeric Count: 1") + tk.MustExec("SET GLOBAL validate_password.special_char_count = 0") + tk.MustExec("ALTER USER testuser IDENTIFIED BY 'Abc1234567'") + tk.MustExec("SET GLOBAL validate_password.special_char_count = 1") + tk.MustExec("SET GLOBAL validate_password.length = 3") + tk.MustQuery("SELECT @@GLOBAL.validate_password.length").Check(testkit.Rows("4")) + + // STRONG: Length; numeric, lowercase/uppercase, and special characters; dictionary file + tk.MustExec("SET GLOBAL validate_password.policy = 'STRONG'") + tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abc1234567'") + tk.MustExec(fmt.Sprintf("SET GLOBAL validate_password.dictionary = '%s'", "1234;5678")) + tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abc123567'") + tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abc43218765'") + tk.MustContainErrMsg("ALTER USER testuser IDENTIFIED BY '!Abc1234567'", "Password contains word in the dictionary") + tk.MustExec("SET GLOBAL validate_password.dictionary = ''") + tk.MustExec("ALTER USER testuser IDENTIFIED BY '!Abc1234567'") + + // "IDENTIFIED AS 'xxx'" is not affected by validation + tk.MustExec(fmt.Sprintf("ALTER USER testuser IDENTIFIED WITH '%s' AS ''", authPlugin)) + } + tk.MustContainErrMsg("CREATE USER 'testuser1'@'localhost'", "Your password does not satisfy the current policy requirements") + tk.MustContainErrMsg("CREATE USER 'testuser1'@'localhost' IDENTIFIED WITH 'caching_sha2_password'", "Your password does not satisfy the current policy requirements") + tk.MustContainErrMsg("CREATE USER 'testuser1'@'localhost' IDENTIFIED WITH 'caching_sha2_password' AS ''", "Your password does not satisfy the current policy requirements") + + // if the username is '', all password can pass the check_user_name + subtk.MustQuery("SELECT user(), current_user()").Check(testkit.Rows("@localhost @localhost")) + subtk.MustQuery("SELECT @@global.validate_password.check_user_name").Check(testkit.Rows("1")) + subtk.MustQuery("SELECT @@global.validate_password.enable").Check(testkit.Rows("1")) + subtk.MustExec("ALTER USER ''@'localhost' IDENTIFIED BY ''") + subtk.MustExec("ALTER USER ''@'localhost' IDENTIFIED BY 'abcd'") + + // CREATE ROLE is not affected by password validation + tk.MustExec("SET GLOBAL validate_password.enable = 1") + tk.MustExec("CREATE ROLE role1") +} diff --git a/expression/BUILD.bazel b/expression/BUILD.bazel index 032c44054dba2..fc1752ef19e63 100644 --- a/expression/BUILD.bazel +++ b/expression/BUILD.bazel @@ -97,6 +97,7 @@ go_library( "//util/mathutil", "//util/mock", "//util/parser", + "//util/password-validation", "//util/plancodec", "//util/printer", "//util/sem", diff --git a/expression/builtin_encryption.go b/expression/builtin_encryption.go index a206a9d4970bb..fb451f9714cd4 100644 --- a/expression/builtin_encryption.go +++ b/expression/builtin_encryption.go @@ -37,6 +37,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/encrypt" + pwdValidator "github.com/pingcap/tidb/util/password-validation" "github.com/pingcap/tipb/go-tipb" ) @@ -73,6 +74,7 @@ var ( _ builtinFunc = &builtinSHA2Sig{} _ builtinFunc = &builtinUncompressSig{} _ builtinFunc = &builtinUncompressedLengthSig{} + _ builtinFunc = &builtinValidatePasswordStrengthSig{} ) // aesModeAttr indicates that the key length and iv attribute for specific block_encryption_mode. @@ -728,7 +730,6 @@ func (c *sm3FunctionClass) getFunction(ctx sessionctx.Context, args []Expression bf.tp.SetCollate(collate) bf.tp.SetFlen(40) sig := &builtinSM3Sig{bf} - //sig.setPbCode(tipb.ScalarFuncSig_SM3) // TODO return sig, nil } @@ -1010,5 +1011,66 @@ type validatePasswordStrengthFunctionClass struct { } func (c *validatePasswordStrengthFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { - return nil, errFunctionNotExists.GenWithStackByArgs("FUNCTION", "VALIDATE_PASSWORD_STRENGTH") + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETString) + if err != nil { + return nil, err + } + bf.tp.SetFlen(21) + sig := &builtinValidatePasswordStrengthSig{bf} + return sig, nil +} + +type builtinValidatePasswordStrengthSig struct { + baseBuiltinFunc +} + +func (b *builtinValidatePasswordStrengthSig) Clone() builtinFunc { + newSig := &builtinValidatePasswordStrengthSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals VALIDATE_PASSWORD_STRENGTH(str). +// See https://dev.mysql.com/doc/refman/8.0/en/encryption-functions.html#function_validate-password-strength +func (b *builtinValidatePasswordStrengthSig) evalInt(row chunk.Row) (int64, bool, error) { + globalVars := b.ctx.GetSessionVars().GlobalVarsAccessor + str, isNull, err := b.args[0].EvalString(b.ctx, row) + if err != nil || isNull { + return 0, true, err + } else if len([]rune(str)) < 4 { + return 0, false, nil + } + if validation, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordEnable); err != nil { + return 0, true, err + } else if !variable.TiDBOptOn(validation) { + return 0, false, nil + } + return b.validateStr(str, &globalVars) +} + +func (b *builtinValidatePasswordStrengthSig) validateStr(str string, globalVars *variable.GlobalVarAccessor) (int64, bool, error) { + if warn, err := pwdValidator.ValidateUserNameInPassword(str, b.ctx.GetSessionVars()); err != nil { + return 0, true, err + } else if len(warn) > 0 { + return 0, false, nil + } + if warn, err := pwdValidator.ValidatePasswordLowPolicy(str, globalVars); err != nil { + return 0, true, err + } else if len(warn) > 0 { + return 25, false, nil + } + if warn, err := pwdValidator.ValidatePasswordMediumPolicy(str, globalVars); err != nil { + return 0, true, err + } else if len(warn) > 0 { + return 50, false, nil + } + if ok, err := pwdValidator.ValidateDictionaryPassword(str, globalVars); err != nil { + return 0, true, err + } else if !ok { + return 75, false, nil + } + return 100, false, nil } diff --git a/expression/builtin_encryption_test.go b/expression/builtin_encryption_test.go index 0f74ab611aa48..087fb3f35e466 100644 --- a/expression/builtin_encryption_test.go +++ b/expression/builtin_encryption_test.go @@ -15,12 +15,14 @@ package expression import ( + "context" "encoding/hex" "fmt" "strings" "testing" "github.com/pingcap/tidb/parser/ast" + "github.com/pingcap/tidb/parser/auth" "github.com/pingcap/tidb/parser/charset" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/terror" @@ -631,6 +633,55 @@ func TestUncompressLength(t *testing.T) { } } +func TestValidatePasswordStrength(t *testing.T) { + ctx := createContext(t) + ctx.GetSessionVars().User = &auth.UserIdentity{Username: "testuser"} + globalVarsAccessor := variable.NewMockGlobalAccessor4Tests() + ctx.GetSessionVars().GlobalVarsAccessor = globalVarsAccessor + err := globalVarsAccessor.SetGlobalSysVar(context.Background(), variable.ValidatePasswordDictionary, "1234") + require.NoError(t, err) + + tests := []struct { + in interface{} + expect interface{} + }{ + {nil, nil}, + {"123", 0}, + {"testuser123", 0}, + {"resutset123", 0}, + {"12345", 25}, + {"12345678", 50}, + {"!Abc12345678", 75}, + {"!Abc87654321", 100}, + } + + fc := funcs[ast.ValidatePasswordStrength] + // disable password validation + for _, test := range tests { + arg := types.NewDatum(test.in) + f, err := fc.getFunction(ctx, datumsToConstants([]types.Datum{arg})) + require.NoErrorf(t, err, "%v", test) + out, err := evalBuiltinFunc(f, chunk.Row{}) + require.NoErrorf(t, err, "%v", test) + if test.expect == nil { + require.Equal(t, types.NewDatum(nil), out) + } else { + require.Equalf(t, types.NewDatum(0), out, "%v", test) + } + } + // enable password validation + err = globalVarsAccessor.SetGlobalSysVar(context.Background(), variable.ValidatePasswordEnable, "ON") + require.NoError(t, err) + for _, test := range tests { + arg := types.NewDatum(test.in) + f, err := fc.getFunction(ctx, datumsToConstants([]types.Datum{arg})) + require.NoErrorf(t, err, "%v", test) + out, err := evalBuiltinFunc(f, chunk.Row{}) + require.NoErrorf(t, err, "%v", test) + require.Equalf(t, types.NewDatum(test.expect), out, "%v", test) + } +} + func TestPassword(t *testing.T) { ctx := createContext(t) cases := []struct { diff --git a/expression/builtin_encryption_vec.go b/expression/builtin_encryption_vec.go index e9a1d45ae67be..ff71913f8d70b 100644 --- a/expression/builtin_encryption_vec.go +++ b/expression/builtin_encryption_vec.go @@ -30,6 +30,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/parser/auth" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/encrypt" @@ -863,3 +864,45 @@ func (b *builtinUncompressedLengthSig) vecEvalInt(input *chunk.Chunk, result *ch } return nil } + +func (b *builtinValidatePasswordStrengthSig) vectorized() bool { + return true +} + +func (b *builtinValidatePasswordStrengthSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) error { + n := input.NumRows() + buf, err := b.bufAllocator.get() + if err != nil { + return err + } + defer b.bufAllocator.put(buf) + if err := b.args[0].VecEvalString(b.ctx, input, buf); err != nil { + return err + } + + result.ResizeInt64(n, false) + result.MergeNulls(buf) + i64s := result.Int64s() + globalVars := b.ctx.GetSessionVars().GlobalVarsAccessor + enableValidation := false + validation, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordEnable) + if err != nil { + return err + } + enableValidation = variable.TiDBOptOn(validation) + for i := 0; i < n; i++ { + if result.IsNull(i) { + continue + } + if !enableValidation { + i64s[i] = 0 + } else if score, isNull, err := b.validateStr(buf.GetString(i), &globalVars); err != nil { + return err + } else if !isNull { + i64s[i] = score + } else { + result.SetNull(i, true) + } + } + return nil +} diff --git a/expression/builtin_encryption_vec_test.go b/expression/builtin_encryption_vec_test.go index c6caa1eb60d51..46395e51bcb6b 100644 --- a/expression/builtin_encryption_vec_test.go +++ b/expression/builtin_encryption_vec_test.go @@ -75,6 +75,9 @@ var vecBuiltinEncryptionCases = map[string][]vecExprBenchCase{ ast.Decode: { {retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETString, types.ETString}, geners: []dataGenerator{newRandLenStrGener(10, 20)}}, }, + ast.ValidatePasswordStrength: { + {retEvalType: types.ETInt, childrenTypes: []types.EvalType{types.ETString}}, + }, } func TestVectorizedBuiltinEncryptionFunc(t *testing.T) { diff --git a/expression/integration_test.go b/expression/integration_test.go index a0e2f93b3103f..bb5bddfa7d9a4 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -968,6 +968,7 @@ func TestEncryptionBuiltin(t *testing.T) { tk := testkit.NewTestKit(t, store) tk.MustExec("use test") + tk.Session().GetSessionVars().User = &auth.UserIdentity{Username: "root"} ctx := context.Background() // for password @@ -1143,6 +1144,25 @@ func TestEncryptionBuiltin(t *testing.T) { tk.MustQuery("SELECT RANDOM_BYTES(1024);") result = tk.MustQuery("SELECT RANDOM_BYTES(NULL);") result.Check(testkit.Rows("")) + + // for VALIDATE_PASSWORD_STRENGTH + tk.MustExec(fmt.Sprintf("SET GLOBAL validate_password.dictionary='%s'", "password")) + tk.MustExec("SET GLOBAL validate_password.enable = 1") + tk.MustQuery("SELECT validate_password_strength('root')").Check(testkit.Rows("0")) + tk.MustQuery("SELECT validate_password_strength('toor')").Check(testkit.Rows("0")) + tk.MustQuery("SELECT validate_password_strength('ROOT')").Check(testkit.Rows("25")) + tk.MustQuery("SELECT validate_password_strength('TOOR')").Check(testkit.Rows("25")) + tk.MustQuery("SELECT validate_password_strength('fooHoHo%1')").Check(testkit.Rows("100")) + tk.MustQuery("SELECT validate_password_strength('pass')").Check(testkit.Rows("25")) + tk.MustQuery("SELECT validate_password_strength('password')").Check(testkit.Rows("50")) + tk.MustQuery("SELECT validate_password_strength('password0000')").Check(testkit.Rows("50")) + tk.MustQuery("SELECT validate_password_strength('password1A#')").Check(testkit.Rows("75")) + tk.MustQuery("SELECT validate_password_strength('PA12wrd!#')").Check(testkit.Rows("100")) + tk.MustQuery("SELECT VALIDATE_PASSWORD_STRENGTH(REPEAT(\"aA1#\", 26))").Check(testkit.Rows("100")) + tk.MustQuery("SELECT validate_password_strength(null)").Check(testkit.Rows("")) + tk.MustQuery("SELECT validate_password_strength('null')").Check(testkit.Rows("25")) + tk.MustQuery("SELECT VALIDATE_PASSWORD_STRENGTH( 0x6E616E646F73617135234552 )").Check(testkit.Rows("100")) + tk.MustQuery("SELECT VALIDATE_PASSWORD_STRENGTH(CAST(0xd2 AS BINARY(10)))").Check(testkit.Rows("50")) } func TestOpBuiltin(t *testing.T) { diff --git a/sessionctx/variable/error.go b/sessionctx/variable/error.go index 60928932f0f06..f760cba8bfcd5 100644 --- a/sessionctx/variable/error.go +++ b/sessionctx/variable/error.go @@ -39,6 +39,7 @@ var ( errLocalVariable = dbterror.ClassVariable.NewStd(mysql.ErrLocalVariable) errValueNotSupportedWhen = dbterror.ClassVariable.NewStdErr(mysql.ErrNotSupportedYet, pmysql.Message("%s = OFF is not supported when %s = ON", nil)) ErrStmtNotFound = dbterror.ClassOptimizer.NewStd(mysql.ErrPreparedStmtNotFound) + ErrNotValidPassword = dbterror.ClassExecutor.NewStd(mysql.ErrNotValidPassword) // ErrFunctionsNoopImpl is an error to say the behavior is protected by the tidb_enable_noop_functions sysvar. // This is copied from expression.ErrFunctionsNoopImpl to prevent circular dependencies. // It needs to be public for tests. diff --git a/sessionctx/variable/noop.go b/sessionctx/variable/noop.go index 398ea09f3ec92..5505fe65a3623 100644 --- a/sessionctx/variable/noop.go +++ b/sessionctx/variable/noop.go @@ -58,8 +58,6 @@ var noopSysVars = []*SysVar{ {Scope: ScopeGlobal | ScopeSession, Name: BigTables, Value: Off, Type: TypeBool}, {Scope: ScopeNone, Name: "skip_external_locking", Value: "1"}, {Scope: ScopeNone, Name: "innodb_sync_array_size", Value: "1"}, - {Scope: ScopeGlobal, Name: ValidatePasswordCheckUserName, Value: Off, Type: TypeBool}, - {Scope: ScopeGlobal, Name: ValidatePasswordNumberCount, Value: "1", Type: TypeUnsigned, MinValue: 0, MaxValue: math.MaxUint64}, {Scope: ScopeSession, Name: "gtid_next", Value: ""}, {Scope: ScopeGlobal, Name: "ndb_show_foreign_key_mock_tables", Value: ""}, {Scope: ScopeNone, Name: "multi_range_count", Value: "256"}, @@ -463,7 +461,6 @@ var noopSysVars = []*SysVar{ {Scope: ScopeGlobal | ScopeSession, Name: "eq_range_index_dive_limit", Value: "200", IsHintUpdatable: true}, {Scope: ScopeNone, Name: "performance_schema_events_stages_history_size", Value: "10"}, {Scope: ScopeGlobal | ScopeSession, Name: "ndb_join_pushdown", Value: ""}, - {Scope: ScopeGlobal, Name: "validate_password_special_char_count", Value: "1"}, {Scope: ScopeNone, Name: "performance_schema_max_thread_instances", Value: "402"}, {Scope: ScopeGlobal | ScopeSession, Name: "ndbinfo_show_hidden", Value: ""}, {Scope: ScopeGlobal | ScopeSession, Name: "net_read_timeout", Value: "30"}, @@ -472,7 +469,6 @@ var noopSysVars = []*SysVar{ {Scope: ScopeGlobal, Name: "sync_relay_log_info", Value: "10000"}, {Scope: ScopeGlobal | ScopeSession, Name: "optimizer_trace_limit", Value: "1"}, {Scope: ScopeNone, Name: "innodb_ft_max_token_size", Value: "84"}, - {Scope: ScopeGlobal, Name: ValidatePasswordLength, Value: "8", Type: TypeUnsigned, MinValue: 0, MaxValue: math.MaxUint64}, {Scope: ScopeGlobal, Name: "ndb_log_binlog_index", Value: ""}, {Scope: ScopeGlobal, Name: "innodb_api_bk_commit_interval", Value: "5"}, {Scope: ScopeNone, Name: "innodb_undo_directory", Value: "."}, diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index b3902d8f0e431..060c542bddd77 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -486,6 +486,86 @@ var defaultSysVars = []*SysVar{ } return normalizedValue, nil }}, + {Scope: ScopeGlobal, Name: ValidatePasswordEnable, Value: Off, Type: TypeBool}, + {Scope: ScopeGlobal, Name: ValidatePasswordPolicy, Value: "MEDIUM", Type: TypeEnum, PossibleValues: []string{"LOW", "MEDIUM", "STRONG"}}, + {Scope: ScopeGlobal, Name: ValidatePasswordCheckUserName, Value: On, Type: TypeBool}, + {Scope: ScopeGlobal, Name: ValidatePasswordLength, Value: "8", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt32, + Validation: func(vars *SessionVars, normalizedValue string, originalValue string, scope ScopeFlag) (string, error) { + numberCount, specialCharCount, mixedCaseCount := PasswordValidtaionNumberCount.Load(), PasswordValidationSpecialCharCount.Load(), PasswordValidationMixedCaseCount.Load() + length, err := strconv.ParseInt(normalizedValue, 10, 32) + if err != nil { + return "", err + } + if minLength := numberCount + specialCharCount + 2*mixedCaseCount; int32(length) < minLength { + return strconv.FormatInt(int64(minLength), 10), nil + } + return normalizedValue, nil + }, + SetGlobal: func(_ context.Context, s *SessionVars, val string) error { + PasswordValidationLength.Store(int32(TidbOptInt64(val, 8))) + return nil + }, GetGlobal: func(_ context.Context, s *SessionVars) (string, error) { + return strconv.FormatInt(int64(PasswordValidationLength.Load()), 10), nil + }, + }, + {Scope: ScopeGlobal, Name: ValidatePasswordMixedCaseCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt32, + Validation: func(vars *SessionVars, normalizedValue string, originalValue string, scope ScopeFlag) (string, error) { + length, numberCount, specialCharCount := PasswordValidationLength.Load(), PasswordValidtaionNumberCount.Load(), PasswordValidationSpecialCharCount.Load() + mixedCaseCount, err := strconv.ParseInt(normalizedValue, 10, 32) + if err != nil { + return "", err + } + if minLength := numberCount + specialCharCount + 2*int32(mixedCaseCount); length < minLength { + PasswordValidationLength.Store(minLength) + } + return normalizedValue, nil + }, + SetGlobal: func(_ context.Context, s *SessionVars, val string) error { + PasswordValidationMixedCaseCount.Store(int32(TidbOptInt64(val, 1))) + return nil + }, GetGlobal: func(_ context.Context, s *SessionVars) (string, error) { + return strconv.FormatInt(int64(PasswordValidationMixedCaseCount.Load()), 10), nil + }, + }, + {Scope: ScopeGlobal, Name: ValidatePasswordNumberCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt32, + Validation: func(vars *SessionVars, normalizedValue string, originalValue string, scope ScopeFlag) (string, error) { + length, specialCharCount, mixedCaseCount := PasswordValidationLength.Load(), PasswordValidationSpecialCharCount.Load(), PasswordValidationMixedCaseCount.Load() + numberCount, err := strconv.ParseInt(normalizedValue, 10, 32) + if err != nil { + return "", err + } + if minLength := int32(numberCount) + specialCharCount + 2*mixedCaseCount; length < minLength { + PasswordValidationLength.Store(minLength) + } + return normalizedValue, nil + }, + SetGlobal: func(_ context.Context, s *SessionVars, val string) error { + PasswordValidtaionNumberCount.Store(int32(TidbOptInt64(val, 1))) + return nil + }, GetGlobal: func(_ context.Context, s *SessionVars) (string, error) { + return strconv.FormatInt(int64(PasswordValidtaionNumberCount.Load()), 10), nil + }, + }, + {Scope: ScopeGlobal, Name: ValidatePasswordSpecialCharCount, Value: "1", Type: TypeInt, MinValue: 0, MaxValue: math.MaxInt32, + Validation: func(vars *SessionVars, normalizedValue string, originalValue string, scope ScopeFlag) (string, error) { + length, numberCount, mixedCaseCount := PasswordValidationLength.Load(), PasswordValidtaionNumberCount.Load(), PasswordValidationMixedCaseCount.Load() + specialCharCount, err := strconv.ParseInt(normalizedValue, 10, 32) + if err != nil { + return "", err + } + if minLength := numberCount + int32(specialCharCount) + 2*mixedCaseCount; length < minLength { + PasswordValidationLength.Store(minLength) + } + return normalizedValue, nil + }, + SetGlobal: func(_ context.Context, s *SessionVars, val string) error { + PasswordValidationSpecialCharCount.Store(int32(TidbOptInt64(val, 1))) + return nil + }, GetGlobal: func(_ context.Context, s *SessionVars) (string, error) { + return strconv.FormatInt(int64(PasswordValidationSpecialCharCount.Load()), 10), nil + }, + }, + {Scope: ScopeGlobal, Name: ValidatePasswordDictionary, Value: "", Type: TypeStr}, /* TiDB specific variables */ {Scope: ScopeGlobal, Name: TiDBTSOClientBatchMaxWaitTime, Value: strconv.FormatFloat(DefTiDBTSOClientBatchMaxWaitTime, 'f', -1, 64), Type: TypeFloat, MinValue: 0, MaxValue: 10, @@ -1047,7 +1127,7 @@ var defaultSysVars = []*SysVar{ MemoryUsageAlarmKeepRecordNum.Store(TidbOptInt64(val, DefMemoryUsageAlarmKeepRecordNum)) return nil }, GetGlobal: func(_ context.Context, s *SessionVars) (string, error) { - return fmt.Sprintf("%d", MemoryUsageAlarmKeepRecordNum.Load()), nil + return strconv.FormatInt(MemoryUsageAlarmKeepRecordNum.Load(), 10), nil }}, /* The system variables below have GLOBAL and SESSION scope */ @@ -2125,10 +2205,6 @@ const ( BlockEncryptionMode = "block_encryption_mode" // WaitTimeout is the name for 'wait_timeout' system variable. WaitTimeout = "wait_timeout" - // ValidatePasswordNumberCount is the name of 'validate_password_number_count' system variable. - ValidatePasswordNumberCount = "validate_password_number_count" - // ValidatePasswordLength is the name of 'validate_password_length' system variable. - ValidatePasswordLength = "validate_password_length" // Version is the name of 'version' system variable. Version = "version" // VersionComment is the name of 'version_comment' system variable. @@ -2151,8 +2227,6 @@ const ( BinlogOrderCommits = "binlog_order_commits" // MasterVerifyChecksum is the name for 'master_verify_checksum' system variable. MasterVerifyChecksum = "master_verify_checksum" - // ValidatePasswordCheckUserName is the name for 'validate_password_check_user_name' system variable. - ValidatePasswordCheckUserName = "validate_password_check_user_name" // SuperReadOnly is the name for 'super_read_only' system variable. SuperReadOnly = "super_read_only" // SQLNotes is the name for 'sql_notes' system variable. @@ -2319,4 +2393,21 @@ const ( RandSeed2 = "rand_seed2" // SQLRequirePrimaryKey is the name of `sql_require_primary_key` system variable. SQLRequirePrimaryKey = "sql_require_primary_key" + // ValidatePasswordEnable turns on/off the validation of password. + ValidatePasswordEnable = "validate_password.enable" + // ValidatePasswordPolicy specifies the password policy enforced by validate_password. + ValidatePasswordPolicy = "validate_password.policy" + // ValidatePasswordCheckUserName controls whether validate_password compares passwords to the user name part of + // the effective user account for the current session + ValidatePasswordCheckUserName = "validate_password.check_user_name" + // ValidatePasswordLength specified the minimum number of characters that validate_password requires passwords to have + ValidatePasswordLength = "validate_password.length" + // ValidatePasswordMixedCaseCount specified the minimum number of lowercase and uppercase characters that validate_password requires + ValidatePasswordMixedCaseCount = "validate_password.mixed_case_count" + // ValidatePasswordNumberCount specified the minimum number of numeric (digit) characters that validate_password requires + ValidatePasswordNumberCount = "validate_password.number_count" + // ValidatePasswordSpecialCharCount specified the minimum number of nonalphanumeric characters that validate_password requires + ValidatePasswordSpecialCharCount = "validate_password.special_char_count" + // ValidatePasswordDictionary specified the dictionary that validate_password uses for checking passwords. Each word is separated by semicolon (;). + ValidatePasswordDictionary = "validate_password.dictionary" ) diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go index e4fa0b671cebe..a9e278107d270 100644 --- a/sessionctx/variable/tidb_vars.go +++ b/sessionctx/variable/tidb_vars.go @@ -1150,6 +1150,11 @@ var ( // It should be a const and shouldn't be modified after tidb is started. DefTiDBServerMemoryLimit = serverMemoryLimitDefaultValue() GOGCTunerThreshold = atomic.NewFloat64(DefTiDBGOGCTunerThreshold) + + PasswordValidationLength = atomic.NewInt32(8) + PasswordValidationMixedCaseCount = atomic.NewInt32(1) + PasswordValidtaionNumberCount = atomic.NewInt32(1) + PasswordValidationSpecialCharCount = atomic.NewInt32(1) ) var ( diff --git a/util/password-validation/BUILD.bazel b/util/password-validation/BUILD.bazel new file mode 100644 index 0000000000000..c3649a3a15383 --- /dev/null +++ b/util/password-validation/BUILD.bazel @@ -0,0 +1,23 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "password-validation", + srcs = ["password_validation.go"], + importpath = "github.com/pingcap/tidb/util/password-validation", + visibility = ["//visibility:public"], + deps = [ + "//sessionctx/variable", + "//util/hack", + ], +) + +go_test( + name = "password-validation_test", + srcs = ["password_validation_test.go"], + embed = [":password-validation"], + deps = [ + "//parser/auth", + "//sessionctx/variable", + "@com_github_stretchr_testify//require", + ], +) diff --git a/util/password-validation/password_validation.go b/util/password-validation/password_validation.go new file mode 100644 index 0000000000000..edd0bd39ec38a --- /dev/null +++ b/util/password-validation/password_validation.go @@ -0,0 +1,175 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validator + +import ( + "bytes" + "fmt" + "strconv" + "strings" + "unicode" + + "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/util/hack" +) + +const maxPwdValidationLength int = 100 + +const minPwdValidationLength int = 4 + +// ValidateDictionaryPassword checks if the password contains words in the dictionary. +func ValidateDictionaryPassword(pwd string, globalVars *variable.GlobalVarAccessor) (bool, error) { + dictionary, err := (*globalVars).GetGlobalSysVar(variable.ValidatePasswordDictionary) + if err != nil { + return false, err + } + words := strings.Split(dictionary, ";") + if len(words) == 0 { + return true, nil + } + pwd = strings.ToLower(pwd) + for _, word := range words { + if len(word) >= minPwdValidationLength && len(word) <= maxPwdValidationLength { + if strings.Contains(pwd, strings.ToLower(word)) { + return false, nil + } + } + } + return true, nil +} + +// ValidateUserNameInPassword checks whether pwd exists in the dictionary. +func ValidateUserNameInPassword(pwd string, sessionVars *variable.SessionVars) (string, error) { + currentUser := sessionVars.User + globalVars := sessionVars.GlobalVarsAccessor + pwdBytes := hack.Slice(pwd) + if checkUserName, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordCheckUserName); err != nil { + return "", err + } else if currentUser != nil && variable.TiDBOptOn(checkUserName) { + for _, username := range []string{currentUser.AuthUsername, currentUser.Username} { + usernameBytes := hack.Slice(username) + userNameLen := len(usernameBytes) + if userNameLen == 0 { + continue + } + if bytes.Contains(pwdBytes, usernameBytes) { + return "Password Contains User Name", nil + } + usernameReversedBytes := make([]byte, userNameLen) + for i := range usernameBytes { + usernameReversedBytes[i] = usernameBytes[userNameLen-1-i] + } + if bytes.Contains(pwdBytes, usernameReversedBytes) { + return "Password Contains Reversed User Name", nil + } + } + } + return "", nil +} + +// ValidatePasswordLowPolicy checks whether pwd satisfies the low policy of password validation. +func ValidatePasswordLowPolicy(pwd string, globalVars *variable.GlobalVarAccessor) (string, error) { + if validateLengthStr, err := (*globalVars).GetGlobalSysVar(variable.ValidatePasswordLength); err != nil { + return "", err + } else if validateLength, err := strconv.ParseInt(validateLengthStr, 10, 64); err != nil { + return "", err + } else if (int64)(len([]rune(pwd))) < validateLength { + return fmt.Sprintf("Require Password Length: %d", validateLength), nil + } + return "", nil +} + +// ValidatePasswordMediumPolicy checks whether pwd satisfies the medium policy of password validation. +func ValidatePasswordMediumPolicy(pwd string, globalVars *variable.GlobalVarAccessor) (string, error) { + var lowerCaseCount, upperCaseCount, numberCount, specialCharCount int64 + runes := []rune(pwd) + for i := 0; i < len(runes); i++ { + if unicode.IsUpper(runes[i]) { + upperCaseCount++ + } else if unicode.IsLower(runes[i]) { + lowerCaseCount++ + } else if unicode.IsDigit(runes[i]) { + numberCount++ + } else { + specialCharCount++ + } + } + if mixedCaseCountStr, err := (*globalVars).GetGlobalSysVar(variable.ValidatePasswordMixedCaseCount); err != nil { + return "", err + } else if mixedCaseCount, err := strconv.ParseInt(mixedCaseCountStr, 10, 64); err != nil { + return "", err + } else if lowerCaseCount < mixedCaseCount { + return fmt.Sprintf("Require Password Lowercase Count: %d", mixedCaseCount), nil + } else if upperCaseCount < mixedCaseCount { + return fmt.Sprintf("Require Password Uppercase Count: %d", mixedCaseCount), nil + } + if requireNumberCountStr, err := (*globalVars).GetGlobalSysVar(variable.ValidatePasswordNumberCount); err != nil { + return "", err + } else if requireNumberCount, err := strconv.ParseInt(requireNumberCountStr, 10, 64); err != nil { + return "", err + } else if numberCount < requireNumberCount { + return fmt.Sprintf("Require Password Digit Count: %d", requireNumberCount), nil + } + if requireSpecialCharCountStr, err := (*globalVars).GetGlobalSysVar(variable.ValidatePasswordSpecialCharCount); err != nil { + return "", err + } else if requireSpecialCharCount, err := strconv.ParseInt(requireSpecialCharCountStr, 10, 64); err != nil { + return "", err + } else if specialCharCount < requireSpecialCharCount { + return fmt.Sprintf("Require Password Non-alphanumeric Count: %d", requireSpecialCharCount), nil + } + return "", nil +} + +// ValidatePassword checks whether the pwd can be used. +func ValidatePassword(sessionVars *variable.SessionVars, pwd string) error { + globalVars := sessionVars.GlobalVarsAccessor + + validatePolicy, err := globalVars.GetGlobalSysVar(variable.ValidatePasswordPolicy) + if err != nil { + return err + } + if warn, err := ValidateUserNameInPassword(pwd, sessionVars); err != nil { + return err + } else if len(warn) > 0 { + return variable.ErrNotValidPassword.GenWithStack(warn) + } + if warn, err := ValidatePasswordLowPolicy(pwd, &globalVars); err != nil { + return err + } else if len(warn) > 0 { + return variable.ErrNotValidPassword.GenWithStack(warn) + } + // LOW + if validatePolicy == "LOW" { + return nil + } + + // MEDIUM + if warn, err := ValidatePasswordMediumPolicy(pwd, &globalVars); err != nil { + return err + } else if len(warn) > 0 { + return variable.ErrNotValidPassword.GenWithStack(warn) + } + if validatePolicy == "MEDIUM" { + return nil + } + + // STRONG + if ok, err := ValidateDictionaryPassword(pwd, &globalVars); err != nil { + return err + } else if !ok { + return variable.ErrNotValidPassword.GenWithStack("Password contains word in the dictionary") + } + return nil +} diff --git a/util/password-validation/password_validation_test.go b/util/password-validation/password_validation_test.go new file mode 100644 index 0000000000000..323cba33ba409 --- /dev/null +++ b/util/password-validation/password_validation_test.go @@ -0,0 +1,137 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validator + +import ( + "context" + "testing" + + "github.com/pingcap/tidb/parser/auth" + "github.com/pingcap/tidb/sessionctx/variable" + "github.com/stretchr/testify/require" +) + +func TestValidateDictionaryPassword(t *testing.T) { + vars := variable.NewSessionVars(nil) + mock := variable.NewMockGlobalAccessor4Tests() + mock.SessionVars = vars + vars.GlobalVarsAccessor = mock + + err := mock.SetGlobalSysVar(context.Background(), variable.ValidatePasswordDictionary, "1234;5678;HIJK") + require.NoError(t, err) + testcases := []struct { + pwd string + result bool + }{ + {"abcdefg", true}, + {"abcd123efg", true}, + {"abcd1234efg", false}, + {"abcd12345efg", false}, + {"abcd123efghij", true}, + {"abcd123efghijk", false}, + } + for _, testcase := range testcases { + ok, err := ValidateDictionaryPassword(testcase.pwd, &vars.GlobalVarsAccessor) + require.NoError(t, err) + require.Equal(t, testcase.result, ok, testcase.pwd) + } +} + +func TestValidateUserNameInPassword(t *testing.T) { + sessionVars := variable.NewSessionVars(nil) + sessionVars.User = &auth.UserIdentity{Username: "user", AuthUsername: "authuser"} + sessionVars.GlobalVarsAccessor = variable.NewMockGlobalAccessor4Tests() + testcases := []struct { + pwd string + warn string + }{ + {"", ""}, + {"user", "Password Contains User Name"}, + {"authuser", "Password Contains User Name"}, + {"resu000", "Password Contains Reversed User Name"}, + {"resuhtua", "Password Contains Reversed User Name"}, + {"User", ""}, + {"authUser", ""}, + {"Resu", ""}, + {"Resuhtua", ""}, + } + // Enable check_user_name + err := sessionVars.GlobalVarsAccessor.SetGlobalSysVar(context.Background(), variable.ValidatePasswordCheckUserName, "ON") + require.NoError(t, err) + for _, testcase := range testcases { + warn, err := ValidateUserNameInPassword(testcase.pwd, sessionVars) + require.NoError(t, err) + require.Equal(t, testcase.warn, warn, testcase.pwd) + } + + // Disable check_user_name + err = sessionVars.GlobalVarsAccessor.SetGlobalSysVar(context.Background(), variable.ValidatePasswordCheckUserName, "OFF") + require.NoError(t, err) + for _, testcase := range testcases { + warn, err := ValidateUserNameInPassword(testcase.pwd, sessionVars) + require.NoError(t, err) + require.Equal(t, "", warn, testcase.pwd) + } +} + +func TestValidatePasswordLowPolicy(t *testing.T) { + sessionVars := variable.NewSessionVars(nil) + sessionVars.GlobalVarsAccessor = variable.NewMockGlobalAccessor4Tests() + sessionVars.GlobalVarsAccessor.(*variable.MockGlobalAccessor).SessionVars = sessionVars + err := sessionVars.GlobalVarsAccessor.SetGlobalSysVar(context.Background(), variable.ValidatePasswordLength, "8") + require.NoError(t, err) + + warn, err := ValidatePasswordLowPolicy("1234", &sessionVars.GlobalVarsAccessor) + require.NoError(t, err) + require.Equal(t, "Require Password Length: 8", warn) + warn, err = ValidatePasswordLowPolicy("12345678", &sessionVars.GlobalVarsAccessor) + require.NoError(t, err) + require.Equal(t, "", warn) + + err = sessionVars.GlobalVarsAccessor.SetGlobalSysVar(context.Background(), variable.ValidatePasswordLength, "12") + require.NoError(t, err) + warn, err = ValidatePasswordLowPolicy("12345678", &sessionVars.GlobalVarsAccessor) + require.NoError(t, err) + require.Equal(t, "Require Password Length: 12", warn) +} + +func TestValidatePasswordMediumPolicy(t *testing.T) { + sessionVars := variable.NewSessionVars(nil) + sessionVars.GlobalVarsAccessor = variable.NewMockGlobalAccessor4Tests() + sessionVars.GlobalVarsAccessor.(*variable.MockGlobalAccessor).SessionVars = sessionVars + + err := sessionVars.GlobalVarsAccessor.SetGlobalSysVar(context.Background(), variable.ValidatePasswordMixedCaseCount, "1") + require.NoError(t, err) + err = sessionVars.GlobalVarsAccessor.SetGlobalSysVar(context.Background(), variable.ValidatePasswordSpecialCharCount, "2") + require.NoError(t, err) + err = sessionVars.GlobalVarsAccessor.SetGlobalSysVar(context.Background(), variable.ValidatePasswordNumberCount, "3") + require.NoError(t, err) + + warn, err := ValidatePasswordMediumPolicy("!@A123", &sessionVars.GlobalVarsAccessor) + require.NoError(t, err) + require.Equal(t, "Require Password Lowercase Count: 1", warn) + warn, err = ValidatePasswordMediumPolicy("!@a123", &sessionVars.GlobalVarsAccessor) + require.NoError(t, err) + require.Equal(t, "Require Password Uppercase Count: 1", warn) + warn, err = ValidatePasswordMediumPolicy("!@Aa12", &sessionVars.GlobalVarsAccessor) + require.NoError(t, err) + require.Equal(t, "Require Password Digit Count: 3", warn) + warn, err = ValidatePasswordMediumPolicy("!Aa123", &sessionVars.GlobalVarsAccessor) + require.NoError(t, err) + require.Equal(t, "Require Password Non-alphanumeric Count: 2", warn) + warn, err = ValidatePasswordMediumPolicy("!@Aa123", &sessionVars.GlobalVarsAccessor) + require.NoError(t, err) + require.Equal(t, "", warn) +}